In [1]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from nn_mri import ImageFitting_set, SineLayer, get_mgrid
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from skimage.color import rgb2gray, gray2rgb
from torch import nn
from skimage import data, img_as_float
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage.metrics import peak_signal_noise_ratio
from skimage.util import random_noise
from skimage.transform import rescale, resize, downscale_local_mean
from sklearn.cluster import AgglomerativeClustering
import scipy.io as sio
import os
import argparse
from csv import writer

KeyboardInterrupt: 

In [28]:
# Fourier feature mapping
def input_mapping(x, B):
  """Apply a Fourier feature mapping to the input coordinates.

  Args:
    x: array of input coordinates; its last axis is projected with ``B.T``.
    B: projection matrix, or ``None`` to disable the mapping.

  Returns:
    ``x`` itself when ``B`` is ``None``; otherwise the sine and cosine of
    ``(2*pi*x) @ B.T`` concatenated along the last axis.
  """
  if B is not None:
    phase = (2.*np.pi*x) @ B.T
    features = [np.sin(phase), np.cos(phase)]
    return np.concatenate(features, axis=-1)
  return x

SyntaxError: invalid syntax (3661548087.py, line 2)

In [25]:
class Siren(nn.Module):
    """Implicit neural representation built from sine layers, with an optional
    learned perturbation of the input coordinates.

    ``self.net`` is a stack of ``SineLayer`` modules (imported from ``nn_mri``)
    followed by a plain linear output layer. When ``perturb`` is true,
    ``forward`` first shifts the coordinates by ``eps * tanh(...)`` of a small
    two-layer network that sees the coordinates together with the acquisition
    index ``sample``.

    Args:
        in_features: dimensionality of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden ``SineLayer`` modules after the first.
        out_features: dimensionality of the output.
        first_omega_0: ``omega_0`` passed to the first ``SineLayer``.
        hidden_omega_0: ``omega_0`` passed to the hidden ``SineLayer`` modules;
            also scales the initialisation bound of the output layer.
        perturb: enable the coordinate-perturbation branch in ``forward``.
    """

    def __init__(self, in_features, hidden_features,
                 hidden_layers, out_features,
                 first_omega_0=30.,
                 hidden_omega_0=30.,
                 perturb=False):
        super().__init__()
        self.tanh = nn.Tanh()

        # The output layer is created (and initialised) before the sine layers
        # so that the order of random draws matches earlier runs.
        # NOTE: it is registered both as `final_linear` and as the last entry
        # of `net`, so the state dict stores it under two keys. Saved
        # checkpoints rely on that layout, so it is kept as is.
        self.final_linear = nn.Linear(hidden_features, out_features)
        bound = np.sqrt(6 / hidden_features) / hidden_omega_0
        with torch.no_grad():
            self.final_linear.weight.uniform_(-bound, bound)

        # self.net is the INR that calculates signal intensities for its inputs
        layers = [SineLayer(in_features, hidden_features, is_first=True, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features, is_first=False, omega_0=hidden_omega_0))
        layers.append(self.final_linear)
        self.net = nn.Sequential(*layers)

        # Perturbation branch: (coords, acquisition index) -> coordinate offset.
        self.perturb_linear = nn.Linear(in_features + 1, in_features + 1)
        self.perturb_linear2 = nn.Linear(in_features + 1, in_features)

        self.perturb = perturb

    def forward(self, coords, sample=0,eps=0):
        """Evaluate the network at ``coords``.

        Args:
            coords: 2-D tensor of coordinates, one row per query point.
            sample: acquisition index fed to the perturbation branch.
            eps: scale of the coordinate perturbation; ``0`` leaves the
                coordinates unchanged.

        Returns:
            The output of ``self.net`` at the (possibly perturbed) coordinates.
        """
        # Work on a detached copy: no gradient flows back into the caller's
        # coordinate tensor.
        coords = coords.clone().detach().requires_grad_(False)
        if self.perturb:
            # Build the acquisition-index column on the same device as the
            # coordinates instead of forcing CUDA, so the model also runs on
            # CPU or on a non-default GPU.
            acq = torch.full((coords.size(0), 1), float(sample),
                             dtype=torch.float, device=coords.device)
            hidden = self.tanh(self.perturb_linear(torch.cat((coords, acq), -1)))
            perturbation = eps * self.tanh(self.perturb_linear2(hidden))
            coords = coords + perturbation
        return self.net(coords)

In [26]:
# Load the array stored under key 'pertubed_acq' (spelling as used in the
# .mat file). Later cells average it over its last axis and slice it as
# acquisitions[:, :, i].
# NOTE(review): absolute, user-specific paths; consider a configurable data directory.
filename = '/home/gundogdu/toy2.mat'
acquisitions = sio.loadmat(filename)['pertubed_acq']
# Load the array stored under key 'x_sub'; it is plotted under the title
# 'no-motion' in the training preview.
filename2 = '/home/gundogdu/nomo.mat'
nomo = sio.loadmat(filename2)['x_sub']

## Scheduled INR

In [27]:
from IPython import display
# Restrict this process to one physical GPU. This has to run before the first
# CUDA call in the process to take effect.
os.environ["CUDA_VISIBLE_DEVICES"]="3"

# --- Schedule and preview settings -----------------------------------------
WARMUP_ITERS = 4000     # iterations that only fit the INR to the mean image
PREVIEW_EVERY = 1000    # print the loss and draw a preview every N iterations
MIN_ITERS = 10000       # early stopping is only considered after this many
PERTURB_EPS = 1/128.    # scale of the learned coordinate perturbation
PREVIEW_SIZE = 256      # side length of the preview reconstruction grid


def _to_gpu_normalized(pixels):
    """Copy ``pixels`` to the GPU and scale them so that their maximum is 1."""
    pixels = pixels.cuda()
    return pixels / pixels.max()


# --- Data -------------------------------------------------------------------
# One training image per slice along the last axis, plus their mean image.
mean_img = np.mean(acquisitions,-1)
img_dataset = [Image.fromarray(acquisitions[:,:,inx])
               for inx in range(acquisitions.shape[-1])]

mean_dataset = ImageFitting_set([Image.fromarray(mean_img)])
dataset = ImageFitting_set(img_dataset)

# Targets and coordinates never change during training, so normalise them and
# move them to the GPU once instead of on every iteration.
mean_target = _to_gpu_normalized(mean_dataset.pixels[0])
mean_coords = mean_dataset.coords[0].cuda()
sample_targets = [_to_gpu_normalized(dataset.pixels[s]) for s in range(len(dataset))]
sample_coords = [dataset.coords[s].cuda() for s in range(len(dataset))]

# --- Model and optimizers ---------------------------------------------------
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=128,
                      hidden_layers=3, perturb=True)

img_siren.cuda()

# Separate optimizers: the INR itself, and the perturbation branch (with a
# much smaller learning rate).
inr_params = list(img_siren.net.parameters())
inr_optim = torch.optim.Adam(lr=1e-4, params=inr_params)
perturb_params = list(img_siren.perturb_linear.parameters()) + list(img_siren.perturb_linear2.parameters())
perturb_optim = torch.optim.Adam(lr=1e-6, params=perturb_params)

# --- Training loop ----------------------------------------------------------
torch.cuda.empty_cache()
ctr = 0
new_loss = 1000
while True:
    ctr += 1
    # During warm-up every iteration fits the INR to the mean image.
    # Afterwards odd iterations do that, and even iterations update the
    # perturbation branch against the individual acquisitions.
    fit_inr = ctr < WARMUP_ITERS or ctr % 2 == 1
    if fit_inr:
        ground_truth, model_input = mean_target, mean_coords
        model_output = img_siren(model_input)
        loss = ((model_output - ground_truth)**2).mean()
        inr_optim.zero_grad()
        loss.backward()
        inr_optim.step()
    else:
        for sample in range(len(dataset)):
            ground_truth, model_input = sample_targets[sample], sample_coords[sample]
            model_output = img_siren(model_input, sample, PERTURB_EPS)
            sample_loss = ((model_output - ground_truth)**2).mean()
            # Sum the per-acquisition losses into one objective.
            loss = sample_loss if not sample else loss + sample_loss

        perturb_optim.zero_grad()
        loss.backward()
        perturb_optim.step()

    if not ctr%PREVIEW_EVERY:
        # new_loss still holds the loss of the previous iteration here.
        print(new_loss)
        model_input  = get_mgrid(PREVIEW_SIZE, 2).cuda()
        # The preview is inference only: skip building an autograd graph.
        with torch.no_grad():
            recon = torch.clamp(img_siren(model_input), min=0).cpu().view(PREVIEW_SIZE,PREVIEW_SIZE).detach().numpy()
        fig, ax = plt.subplots(1,3,figsize=(18,6))
        ax[0].imshow(recon, cmap='gray')
        ax[0].set_title('super')
        ax[1].imshow(rescale(nomo,2), cmap='gray')
        ax[1].set_title('no-motion')
        ax[2].imshow(rescale(mean_img,2), cmap='gray')
        ax[2].set_title('mean')
        for axi in range(3):
            ax[axi].axis('off')

        display.display(fig)
        # Close the figure once shown; otherwise every preview stays alive in
        # memory and is rendered again when the cell finishes.
        plt.close(fig)

    # Stop at the first iteration whose loss exceeds the previous one.
    # NOTE(review): after warm-up consecutive iterations optimise different
    # objectives (mean-image loss vs. summed per-acquisition loss), so this
    # compares losses of different kinds; confirm that is the intended
    # stopping rule.
    if loss.item() > new_loss and ctr>MIN_ITERS:
        break
    new_loss = loss.item()
print('Done')

PATH = 'toy_model.pt'
torch.save(img_siren.state_dict(), PATH)

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:

# Fine-tune the model saved as 'toy_model.pt': the perturbation branch gets the
# larger learning rate, the INR itself a very small one.
FINETUNE_EPS = 4/720.     # scale of the learned coordinate perturbation
FINETUNE_PREVIEW = 720    # side length of the preview reconstruction grid

dataloader = DataLoader(dataset, batch_size=1, pin_memory=True, num_workers=0)
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=128,
                      hidden_layers=3, perturb=True)
torch.cuda.empty_cache()
PATH = 'toy_model.pt'
img_siren.load_state_dict(torch.load(PATH))
img_siren.cuda()


params1 = list(img_siren.perturb_linear.parameters()) + list(img_siren.perturb_linear2.parameters())
optim1 = torch.optim.Adam(lr=1e-4, params=params1)
params2 = list(img_siren.net.parameters())
optim2 = torch.optim.Adam(lr=1e-6, params=params2)
ctr = 0
new_loss = 1000

# Targets and coordinates are constant: normalise and upload them once.
finetune_targets = []
finetune_coords = []
for sample in range(len(dataset)):
    pixels = dataset.pixels[sample].cuda()
    finetune_targets.append(pixels / pixels.max())
    finetune_coords.append(dataset.coords[sample].cuda())

while True:
    for sample in range(len(dataset)):
        ground_truth, model_input = finetune_targets[sample], finetune_coords[sample]
        model_output = img_siren(model_input, sample, FINETUNE_EPS)
        sample_loss = ((model_output - ground_truth)**2).mean()
        # Sum the per-acquisition losses into one objective.
        loss = sample_loss if not sample else loss + sample_loss
    # Step the two optimizers created above. The previous version stepped an
    # `optim` object that this cell never defines, so on a fresh kernel it
    # raised NameError, and with a stale kernel it updated a different model.
    optim1.zero_grad()
    optim2.zero_grad()
    loss.backward()
    optim1.step()
    optim2.step()
    # Stop at the first increase of the loss once enough iterations have run.
    if loss.item() > new_loss and ctr>100000:
        break
    new_loss = loss.item()
    if not ctr%1000:
        print(new_loss)
        model_input  = get_mgrid(FINETUNE_PREVIEW, 2).cuda()
        # The preview is inference only: skip building an autograd graph.
        with torch.no_grad():
            recon = img_siren(model_input, 0, FINETUNE_EPS).cpu().view(FINETUNE_PREVIEW,FINETUNE_PREVIEW).detach().numpy()
        fig, ax = plt.subplots(1,2,figsize=(12,6))
        ax[0].imshow(recon, cmap='gray')
        ax[0].set_title('super')
        ax[1].imshow(rescale(mean_img,4), vmin=0.6, vmax=1,cmap='gray')
        ax[1].set_title('mean')
        display.display(fig)
        # Close the figure once shown so previews do not pile up in memory.
        plt.close(fig)

    ctr +=1
print('Done')

In [None]:
# Write the most recent preview reconstruction (`recon`, assigned inside the
# training loop) to 'output.mat' under the key 'out'.
sio.savemat('output.mat', {'out':recon})

In [None]:
# Debugging aid: shape of whichever coordinate tensor was last bound to
# `model_input` (depends on which cell ran most recently).
model_input.shape

In [None]:
class kiwi:
    """One scan loaded from a MATLAB ``.mat`` file.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, img_id, file_address='/home/gundogdu/matfiles'):
        """Load the array stored under key ``'img'`` of ``file_address/img_id``.

        Args:
            img_id: file name of the scan inside ``file_address``; also kept
                as the scan's identifier.
            file_address: directory that contains the ``.mat`` files. The
                default is the location used so far, so ``kiwi(name)`` keeps
                working unchanged.
        """
        # Identifier of the scan (its file name); later cells filter on it.
        self.img_id = img_id
        filename = os.path.join(file_address, img_id)
        # Image array read from the 'img' variable of the .mat file.
        self.dwi = sio.loadmat(filename)['img']

In [None]:
file_address = '/home/gundogdu/matfiles'
# os.listdir returns entries in arbitrary order. Sort them so that positional
# lookups into kiwi_scans (and the acquisition index each scan gets during
# training) are the same on every run and every machine.
# NOTE(review): checkpoints trained with a different scan order will see
# different acquisition indices; retrain or confirm the order if that matters.
kiwi_scans = [kiwi(f) for f in sorted(os.listdir(file_address))]

In [None]:
# Pick one scan by position and crop its image to rows 40:90, columns 20:70.
# NOTE(review): index 7 depends on the order in which kiwi_scans was built;
# confirm it selects the intended file.
_kiwi = kiwi_scans[7]
dwi = _kiwi.dwi[40:90,20:70]
# Show the crop in grayscale, titled with the scan's file name.
fig, ax = plt.subplots(1,1, figsize=(6,6))
ax.imshow(dwi, cmap='gray')
ax.set_title(_kiwi.img_id)
ax.axis('off')
plt.show()

In [None]:
# Build the training set from every scan whose file name contains neither
# 'high' nor 'motion', using the same crop for all of them, and compute the
# mean of those crops.
CROP = (slice(40, 90), slice(20, 70))   # rows 40:90, columns 20:70

img_dataset = []
selected_crops = []
for _kiwi in kiwi_scans:
    if 'high' in _kiwi.img_id or 'motion' in _kiwi.img_id:
        continue
    img = _kiwi.dwi[CROP]
    selected_crops.append(img)
    img_dataset.append(Image.fromarray(img))

# Number of scans that went into the dataset.
ctr = len(selected_crops)
if ctr == 0:
    # Previously this fell through to a division by zero and a NaN mean image.
    raise ValueError("no scan matched the selection (names without 'high' and 'motion')")

# Accumulate in float64, as the original running sum did. The output shape
# comes from the crops themselves rather than from whichever `_kiwi` an
# earlier cell happened to leave behind.
mean_img = np.mean(np.stack(selected_crops), axis=0, dtype=np.float64)
dataset = ImageFitting_set(img_dataset)

In [None]:
# Create a fresh 2-in / 1-out Siren with 64 hidden units, 3 hidden layers and
# the perturbation branch disabled, and move it to the GPU.
# NOTE(review): `dataloader` is not referenced by any cell visible in this
# part of the notebook.
dataloader = DataLoader(dataset, batch_size=1, pin_memory=True, num_workers=0)
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=64,
                      hidden_layers=3, perturb=False)
img_siren.cuda()
torch.cuda.empty_cache()
# One Adam optimizer over every parameter of this model instance.
# NOTE(review): `optim` stays bound to this instance's parameters; rebinding
# `img_siren` in a later cell does not update it.
optim = torch.optim.Adam(lr=3e-4, params=img_siren.parameters())

In [None]:
# Create the perturbation-enabled model and, if a checkpoint from an earlier
# run exists, resume from it.
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=64,
                      hidden_layers=3, perturb=True)
PATH = 'model.pt'
if os.path.exists(PATH):
    # map_location='cpu' makes the load independent of the GPU index the
    # checkpoint was saved from; the model is moved to the GPU just below.
    img_siren.load_state_dict(torch.load(PATH, map_location='cpu'))
else:
    # 'model.pt' is only written at the end of the training cell further
    # down, so on a first run there is nothing to load yet.
    print(f"No checkpoint at {PATH!r}; starting from a randomly initialised model.")
img_siren.cuda()

In [None]:
from IPython import display
# Two optimizers: one for the perturbation branch, one for the INR itself.
# `perturb` selects which of them actually learns; the other gets lr=0.
params1 = list(img_siren.perturb_linear.parameters()) + list(img_siren.perturb_linear2.parameters())
params2 = list(img_siren.net.parameters())
perturb = True
if perturb:
    optim1 = torch.optim.Adam(lr=0.00001, params=params1)
    optim2 = torch.optim.Adam(lr=0, params=params2)
    perturb_degree = 1/100.
else:
    optim1 = torch.optim.Adam(lr=0, params=params1)
    optim2 = torch.optim.Adam(lr=3e-4, params=params2)
    perturb_degree = 0

ctr = 0
new_loss = 1000

# Targets and coordinates are constant: normalise and upload them once instead
# of on every iteration.
kiwi_targets = []
kiwi_coords = []
for sample in range(len(dataset)):
    pixels = dataset.pixels[sample].cuda()
    kiwi_targets.append(pixels / pixels.max())
    kiwi_coords.append(dataset.coords[sample].cuda())

while True:
    for sample in range(len(dataset)):
        ground_truth, model_input = kiwi_targets[sample], kiwi_coords[sample]
        model_output = img_siren(model_input, sample, perturb_degree)
        sample_loss = ((model_output - ground_truth)**2).mean()
        # Sum the per-acquisition losses into one objective.
        loss = sample_loss if not sample else loss + sample_loss
    if perturb:
        optim1.zero_grad()
    optim2.zero_grad()
    loss.backward()
    if perturb:
        optim1.step()
    optim2.step()
    # Stop at the first increase of the loss once enough iterations have run.
    if loss.item() > new_loss and ctr>1000:
        break
    new_loss = loss.item()
    if not ctr%500:
        print(new_loss)
        model_input  = get_mgrid(100, 2).cuda()
        # The preview is inference only: skip building an autograd graph.
        with torch.no_grad():
            recon = img_siren(model_input, 0, perturb_degree).cpu().view(100,100).detach().numpy()
        fig, ax = plt.subplots(1,2,figsize=(12,6))
        ax[0].imshow(recon, cmap='gray')
        ax[0].set_title('super')
        ax[1].imshow(rescale(mean_img,2),cmap='gray')
        ax[1].set_title('mean')
        display.display(fig)
        # Close the figure once shown so previews do not pile up in memory.
        plt.close(fig)

    ctr +=1
print('Done')
PATH = 'model.pt'
torch.save(img_siren.state_dict(), PATH)

In [None]:
# Reload the checkpoint and render the unperturbed reconstruction on a
# 128x128 grid.
# NOTE(review): hidden_features=128 here, while the training cell that writes
# 'model.pt' in this notebook builds the model with hidden_features=64;
# load_state_dict fails on a size mismatch, so confirm which checkpoint is meant.
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=128,
                      hidden_layers=3, perturb=True)
PATH = 'model.pt'
img_siren.load_state_dict(torch.load(PATH))
img_siren.cuda()

# Build the grid explicitly. With this line commented out, `model_input` was
# whatever a previous cell left behind (e.g. a 100x100 grid), which cannot be
# reshaped to 128x128.
model_input  = get_mgrid(128, 2).cuda()
# Inference only: skip building an autograd graph.
with torch.no_grad():
    recon = img_siren(model_input).cpu().view(128,128).detach().numpy()


fig, ax = plt.subplots(1,1,figsize=(6,6))
ax.imshow(recon, cmap='gray')
ax.set_title('super')


In [None]:
# Average the model output over acquisition indices 0..8 on a 128x128 grid and
# show it in one column next to the mean image and every individual scan.
mean_recon = np.zeros((128,128))
model_input  = get_mgrid(128, 2).cuda()
# Inference only: skip building an autograd graph.
with torch.no_grad():
    for i in range(9):
        model_output = img_siren(model_input, i, 1/64.0)
        mean_recon += model_output.cpu().view(128,128).detach().numpy()
mean_recon /= 9

# Rows 0-3 are fixed (reconstruction, mean, 'low_high' scan, 'high.mat' scan);
# every remaining scan gets its own row. Sizing the grid from the data avoids
# an IndexError with more than nine such scans and blank rows with fewer.
# The figure height keeps the original 72/13 units per row.
n_other = sum(1 for k in kiwi_scans
              if 'low_high' not in k.img_id and 'high.mat' not in k.img_id)
n_rows = 4 + n_other
fig, ax = plt.subplots(n_rows,1,figsize=(6,72 * n_rows / 13))
ax[0].imshow(mean_recon, cmap='gray')
ax[0].set_title('super')
ax[1].imshow(mean_img,cmap='gray')
ax[1].set_title('mean')
i=4
for inx in range(len(kiwi_scans)):
    _kiwi = kiwi_scans[inx]
    if 'low_high' in _kiwi.img_id:
        # Must be tested first: a name containing 'low_high' may also
        # contain 'high.mat'.
        img = _kiwi.dwi[40:90,20:70]
        ax[2].imshow(img,cmap='gray')
        ax[2].set_title(_kiwi.img_id)
    elif 'high.mat' in _kiwi.img_id:
        # This scan uses a different crop window from the others.
        img = _kiwi.dwi[70:160,40:120]
        ax[3].imshow(img,cmap='gray')
        ax[3].set_title(_kiwi.img_id)
    else:
        img = _kiwi.dwi[40:90,20:70]
        ax[i].imshow(img,cmap='gray')
        ax[i].set_title(_kiwi.img_id)
        i+=1


In [None]:
# Average the model output over acquisition indices 0..8 on a 128x128 grid,
# using a perturbation scale of 1/60.
mean_recon = np.zeros((128,128))
model_input  = get_mgrid(128, 2).cuda()
for i in range(9):
    model_output = img_siren(model_input, i, 1/60.0)
    mean_recon += model_output.cpu().view(128,128).detach().numpy()
mean_recon /= 9
#model_input  = get_mgrid(128, 2).cuda()
#model_output = img_siren(model_input).cpu().view((128,128)).detach().numpy()
# Four panels side by side: mean image (rescaled by 2), averaged
# reconstruction, and the two reference scans filled in by the loop below.
fig, ax = plt.subplots(1,4, figsize=(24,6))
ax[1].imshow(mean_recon, cmap='gray')
ax[1].set_title('super resolution')
ax[0].imshow(rescale(mean_img,2),cmap='gray')
ax[0].set_title('mean image')

for inx in range(len(kiwi_scans)):
    _kiwi = kiwi_scans[inx]
    # 'low_high' is tested first; scans matching neither pattern are skipped.
    if 'low_high' in _kiwi.img_id:
        img = _kiwi.dwi[40:90,20:70]   
        ax[2].imshow(rescale(img,2),cmap='gray')
        ax[2].set_title('no motion')
    elif 'high.mat' in _kiwi.img_id:
        # This scan uses a different crop window and is shown without rescaling.
        img = _kiwi.dwi[70:150,34:115]   
        ax[3].imshow(img,cmap='gray')
        ax[3].set_title('high resolution')
for i in range(4):
    ax[i].axis('off')

In [None]:
# NOTE(review): this cell repeats, statement for statement, the single-column
# comparison cell that appears earlier in the notebook (13 rows, scale 1/64);
# consider keeping only one copy or factoring it into a function.
# Average the model output over acquisition indices 0..8 on a 128x128 grid.
mean_recon = np.zeros((128,128))
model_input  = get_mgrid(128, 2).cuda()
for i in range(9):
    model_output = img_siren(model_input, i, 1/64.0)
    mean_recon += model_output.cpu().view(128,128).detach().numpy()
mean_recon /= 9
#model_input  = get_mgrid(128, 2).cuda()
#model_output = img_siren(model_input).cpu().view((128,128)).detach().numpy()
# Rows 0-3 are fixed (reconstruction, mean, 'low_high' scan, 'high.mat' scan);
# the remaining scans fill rows 4 and up.
# NOTE(review): the grid has 13 rows, so more than nine remaining scans would
# index past the last axis.
fig, ax = plt.subplots(13,1,figsize=(6,72))
ax[0].imshow(mean_recon, cmap='gray')
ax[0].set_title('super')
ax[1].imshow(mean_img,cmap='gray')
ax[1].set_title('mean')
i=4
for inx in range(len(kiwi_scans)):
    _kiwi = kiwi_scans[inx]
    if 'low_high' in _kiwi.img_id:
        img = _kiwi.dwi[40:90,20:70]   
        ax[2].imshow(img,cmap='gray')
        ax[2].set_title(_kiwi.img_id)
    elif 'high.mat' in _kiwi.img_id:
        # This scan uses a different crop window from the others.
        img = _kiwi.dwi[70:160,40:120]   
        ax[3].imshow(img,cmap='gray')
        ax[3].set_title(_kiwi.img_id)
    else:
        img = _kiwi.dwi[40:90,20:70]   
        ax[i].imshow(img,cmap='gray')
        ax[i].set_title(_kiwi.img_id)
        i+=1


In [None]:
# Create a fresh 2-in / 1-out Siren with 128 hidden units, 3 hidden layers and
# the perturbation branch disabled, and move it to the GPU.
# NOTE(review): `dataloader` is not referenced by any cell visible in this
# part of the notebook.
dataloader = DataLoader(dataset, batch_size=1, pin_memory=True, num_workers=0)
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=128,
                      hidden_layers=3, perturb=False)
img_siren.cuda()
torch.cuda.empty_cache()
# One Adam optimizer over every parameter of this model instance.
optim = torch.optim.Adam(lr=3e-4, params=img_siren.parameters())

In [None]:
# Fit the model (no perturbation) to all images of `dataset` at once, using
# the `optim` optimizer created together with the model.
ctr = 0
new_loss = 1000

# Targets and coordinates are constant: normalise and upload them once instead
# of on every iteration.
train_targets = []
train_coords = []
for sample in range(len(dataset)):
    pixels = dataset.pixels[sample].cuda()
    train_targets.append(pixels / pixels.max())
    train_coords.append(dataset.coords[sample].cuda())

while True:
    # Iterate over however many images the dataset holds, as the other
    # training loops in this notebook do, rather than a hard-coded 9.
    for sample in range(len(dataset)):
        ground_truth, model_input = train_targets[sample], train_coords[sample]
        model_output = img_siren(model_input)
        sample_loss = ((model_output - ground_truth)**2).mean()
        # Sum the per-image losses into one objective.
        loss = sample_loss if not sample else loss + sample_loss
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Stop at the first increase of the loss once enough iterations have run.
    if loss.item() > new_loss and ctr>100:
        break
    new_loss = loss.item()
    if not ctr%500:
        print(new_loss)

    ctr +=1

In [None]:
# Debugging aid: largest value in the most recent model output.
model_output.max()

In [None]:
# Debugging aid: largest value in the most recent target tensor.
ground_truth.max()

In [None]:
# NOTE(review): within the cells visible here, `_case` is first assigned in
# the following cell, so on a fresh kernel this line raises NameError unless
# `_case` is defined elsewhere.
_case.pt_id

In [None]:
# Load the trained model of one case and compare the plain average of its
# acquisitions with the averaged model output.
# NOTE(review): `cases` and `calculate_CNR_SNR` are not defined in the cells
# visible here; they must come from elsewhere.
_case = cases[0]
seed = 0
img_siren = Siren(in_features=2, out_features=1, 
                      hidden_features=128,
                      hidden_layers=3, perturb=True)



# Checkpoint name is built from the case's pt_id and the seed.
PATH = os.path.join('models', _case.pt_id + '_' + str(seed) + '.pt')
img_siren.load_state_dict(torch.load(PATH))
img_siren.cuda()

_slice = _case.cancer_slice
# NOTE(review): `b` is not used again in the visible cells; later code reads
# _case.b[3] directly.
b = _case.b[3]
b0 = _case.b0[:, :, _slice]
dwi = _case.b3[:, :, _slice, :]
# Reference image: mean of the selected slice over the last axis of `dwi`.
img = np.mean(dwi,-1)
# Model image: average of the outputs for every index along axis 3 of
# _case.b3, evaluated on a 128x128 grid with perturbation scale 1/128.
mean_recon = np.zeros((128,128))
model_input  = get_mgrid(128, 2).cuda()
for i in range(_case.b3.shape[3]):
    model_output = img_siren(model_input, i, 1.0/128.0)
    mean_recon += model_output.cpu().view(128,128).detach().numpy()
mean_recon /= _case.b3.shape[3]
# Left: reference mean image; right: averaged model output.
fig, axes = plt.subplots(1,2, figsize=(12,6))
axes[1].imshow(mean_recon)
axes[0].imshow(img)
axes[0].axis('off')
axes[1].axis('off')
plt.show()
# Print the values returned by calculate_CNR_SNR for both images, rounded to
# three decimals.
print([round(x,3) for x in calculate_CNR_SNR(_case, img)])
print([round(x,3) for x in calculate_CNR_SNR(_case, mean_recon)])

In [None]:
# Run calc_adc on the reference image and on the averaged model output, using
# the same b0 slice and _case.b[3] for both.
# NOTE(review): calc_adc is defined outside the visible cells; presumably it
# returns an ADC map — confirm.
adc_in = calc_adc(img, _case.b0[:,:,_slice], _case.b[3])
adc_out = calc_adc(mean_recon, _case.b0[:,:,_slice], _case.b[3])

In [None]:
# Print the values returned by calculate_CNR_SNR for both calc_adc results,
# rounded to two decimals.
print([round(x,2) for x in calculate_CNR_SNR(_case, adc_in)])
print([round(x,2) for x in calculate_CNR_SNR(_case, adc_out)])

In [None]:
# Show the [35:95, 35:95] crop of both calc_adc results side by side with a
# shared display range of 0 to 3.
fig, ax = plt.subplots(1,2,figsize=(18,18))
ax[0].imshow(adc_in[35:95, 35:95], cmap='gray',vmin=0.0, vmax = 3)
ax[1].imshow(adc_out[35:95, 35:95], cmap='gray',vmin=0.0,vmax=3)
# Last expression: the cell output is the maximum of adc_out.
adc_out.max()

In [None]:
# Evaluate the model on a 512x512 grid and average the outputs over every
# index along axis 3 of _case.b3 (perturbation scale 1/128).
big_size = 512
big_mean = np.zeros((big_size,big_size))
model_input  = get_mgrid(big_size, 2).cuda()
for i in range(_case.b3.shape[3]):
    big_mean += img_siren(model_input, i, 1.0/128).cpu().view(big_size,big_size).detach().numpy()
big_mean /= _case.b3.shape[3]

# Left: reference mean image rescaled by 4; right: averaged model output.
fig, axes = plt.subplots(1,2, figsize=(25,25))
axes[1].imshow(big_mean, cmap='gray')
axes[0].imshow(rescale(img,4), cmap = 'gray')
axes[0].axis('off')
axes[1].axis('off')
plt.show()


# calc_adc comparison on the same crop as before, with indices multiplied by 4.
# Right: calc_adc applied to the 512x512 output with b0 rescaled by 4.
# Left: calc_adc applied at the original size, then rescaled by 4.
fig, axes = plt.subplots(1,2, figsize=(25,25))
axes[1].imshow(calc_adc(big_mean, rescale(b0,4), _case.b[3])[35*4:95*4,35*4:95*4], cmap='gray',vmin=0,vmax=3)
axes[0].imshow(rescale(calc_adc(img, b0, _case.b[3]),4)[35*4:95*4,35*4:95*4], cmap = 'gray',vmin=0,vmax=3)
axes[0].axis('off')
axes[1].axis('off')
plt.show()