In [1]:

#Importing all the relevant library
%matplotlib inline
import h5py, os
#from functions import transforms as T
#from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import glob
#from functions import transforms as T 
#from functions.subsample import MaskFunc
from PIL import Image
import random
from numpy.fft import fftshift, ifftshift, fftn, ifftn
import cmath

In [2]:
import torchvision.transforms as transforms

In [3]:
import numpy as np
from numpy.fft import fftshift, ifftshift, fftn, ifftn

def transform_kspace_to_image(k, dim=None, img_shape=None):
    """ Computes the Fourier transform from k-space to image space
    along a given or all dimensions
    :param k: k-space data
    :param dim: vector of dimensions to transform
    :param img_shape: desired shape of output image
    :returns: data in image space (along transformed dimensions)
    """
    if not dim:
        dim = range(k.ndim)

    img = fftshift(ifftn(ifftshift(k, axes=dim), s=img_shape, axes=dim), axes=dim)
    #img = fftshift(ifft2(ifftshift(k, dim=dim)), dim=dim)
    img *= np.sqrt(np.prod(np.take(img.shape, dim)))
    return img


def transform_image_to_kspace(img, dim=None, k_shape=None):
    """ Computes the Fourier transform from image space to k-space space
    along a given or all dimensions
    :param img: image space data
    :param dim: vector of dimensions to transform
    :param k_shape: desired shape of output k-space data
    :returns: data in k-space (along transformed dimensions)
    """
    if not dim:
        dim = range(img.ndim)

    k = fftshift(fftn(ifftshift(img, axes=dim), s=k_shape, axes=dim), axes=dim)
    #k = fftshift(fft2(ifftshift(img, dim=dim)), dim=dim)
    k /= np.sqrt(np.prod(np.take(img.shape, dim)))
    return k

In [4]:

def show_slices(data, slice_nums, cmap=None): # visualisation
    fig = plt.figure(figsize=(15,10))
    for i, num in enumerate(slice_nums):
        plt.subplot(1, len(slice_nums), i + 1)
        plt.imshow(data[num], cmap=cmap)
        plt.axis('off')

In [5]:
class MRIDataset(DataLoader):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        subject_id = self.data_list[idx]

        return get_epoch_batch(subject_id)

In [6]:
import random
import numpy as np
from numpy.fft import fftshift, ifftshift, fftn, ifftn
import cmath
def noise_and_kspace(image):
    #change to k-space
    img_fft = fftshift(fftn(image))
    size_img = img_fft.shape
     #np.random.uniform, np.random.normal
    std = np.random.normal(0.000, 0.005) * np.amax(img_fft)
    noise = fftshift(std * np.random.standard_normal(size_img) + std * 1j * np.random.standard_normal(size_img));     #This generates a complex noise signal.
    img_fft_noise = img_fft + noise # k-space
    img_noise = ifftn(ifftshift(img_fft_noise))# revert k-space back to noise
    return img_noise

In [7]:
def get_epoch_batch(subject_id):
    ''' random select a few slices (batch_size) from each volume'''

    fname, rawdata_name = subject_id  
    
#    with h5py.File(rawdata_name, 'r') as data:
#        rawdata = data['kspace'][slice]
   
    im_frame = Image.open(rawdata_name)
    noise_im_frame = noise_and_kspace(im_frame)

    ############################
    #img_und = to_tensor(np.array(noise_im_frame)).unsqueeze(0) # noise image tensor form    
    preprocess = T.Compose([
                       # T.Grayscale(num_output_channels=1),
                           T.Resize(64),    #128 as maximum #64
                           T.CenterCrop(64),
                           T.ToTensor() #,
                            ])
    img_gt = preprocess(Image.fromarray(np.uint8(im_frame)).convert('L'))
    img_und = preprocess(Image.fromarray(np.uint8(noise_im_frame)).convert('L'))
    
    n1 = (img_und**2).sum(dim=-1).sqrt()
    norm = n1.max() 
    if norm < 1e-6: norm = 1e-6
    
    img_gt, img_und = img_gt/norm , img_und/norm

    return img_gt.squeeze(0), img_und.squeeze(0)

In [8]:
def load_data_path(train_data_path, val_data_path):
    """ Go through each subset (training, validation) and list all 
    the file names, the file paths and the slices of subjects in the training and validation sets 
    """

    data_list = {}
    train_and_val = ['train', 'val']
    data_path = [train_data_path, val_data_path]
      
    for i in range(len(data_path)):

        data_list[train_and_val[i]] = []
        
        which_data_path = data_path[i]
        tr = 0
        te = 0
        alfa = 0
    
        for fname in sorted(os.listdir(which_data_path)):
            if fname == '.DS_Store': continue
            
            subject_data_path = os.path.join(which_data_path, fname)
                     
            if not os.path.isfile(subject_data_path): continue 
            
     
            #get information from text file
            # this will return a tuple of root and extension
            split_tup = os.path.splitext(fname)

  
            # extract the file name and extension
            file_name = split_tup[0]
  
                
            # the first 5 slices are mostly noise so it is better to exlude them
            data_list[train_and_val[i]].append((fname, subject_data_path))
    
    return data_list

# RestNet

In [9]:
class baseBlock(torch.nn.Module):
    expansion = 1
    def __init__(self,input_planes,planes,stride=1,dim_change=None):
        super(baseBlock,self).__init__()
        #declare convolutional layers with batch norms
        self.conv1 = torch.nn.Conv2d(input_planes,planes,stride=stride,kernel_size=3,padding=1)
        self.bn1   = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes,planes,stride=1,kernel_size=3,padding=1)
        self.bn2   = torch.nn.BatchNorm2d(planes)
        self.dim_change = dim_change
    def forward(self,x):
        #Save the residue
        res = x
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        if self.dim_change is not None:
            res = self.dim_change(res)
        
        output += res
        output = F.relu(output)

        return output

class ResNet(torch.nn.Module):
    def __init__(self,block,num_layers,classes=10):
        super(ResNet,self).__init__()
        #according to research paper:
        self.input_planes = 64 #256
        self.conv1 = torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1)
        self.layer1 = self._layer(block,64,num_layers[0],stride=1)
        self.layer2 = self._layer(block,64,num_layers[1],stride=1)
        self.layer3 = self._layer(block,64,num_layers[2],stride=1)
        self.layer4 = self._layer(block,64,num_layers[3],stride=1)
        self.layer5 = self._layer(block,64,num_layers[3],stride=1)
        self.layer6 = self._layer(block,64,num_layers[3],stride=1)
        self.layer7 = self._layer(block,64,num_layers[3],stride=1)
        self.layer8 = self._layer(block,64,num_layers[2],stride=1)
        self.layer9 = self._layer(block,64,num_layers[3],stride=1)
        self.layer10 = self._layer(block,64,num_layers[3],stride=1)
       # self.layer11 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer12 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer13 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer14 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer15 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer16 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer17 = self._layer(block,64,num_layers[3],stride=1)
        #self.layer18 = self._layer(block,64,num_layers[3],stride=1)
        self.conv2 = torch.nn.Conv2d(64,1,kernel_size=3,stride=1, padding=1)
        
    
    def _layer(self,block,planes,num_layers,stride=1):
        dim_change = None
        if stride!=1 or planes != self.input_planes*block.expansion:
            dim_change = torch.nn.Sequential(torch.nn.Conv2d(self.input_planes,planes*block.expansion,kernel_size=1,stride=stride),
                                             torch.nn.BatchNorm2d(planes*block.expansion))
        netLayers =[]
        netLayers.append(block(self.input_planes,planes,stride=stride,dim_change=dim_change))
        self.input_planes = planes * block.expansion
        for i in range(1,num_layers):
            netLayers.append(block(self.input_planes,planes))
            self.input_planes = planes * block.expansion
        
        return torch.nn.Sequential(*netLayers)

    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layer10(x)
     #   x = self.layer11(x)
      #  x = self.layer12(x)
      #  x = self.layer13(x)
     #   x = self.layer14(x)
      #  x = self.layer15(x)
     #   x = self.layer16(x)
     #   x = self.layer17(x)
     #   x = self.layer18(x)
        x = self.conv2(x)

        return x

In [10]:
from skimage.metrics import structural_similarity as cmp_ssim 
from skimage.metrics import mean_squared_error
from skimage.metrics import normalized_root_mse
def ssim(gt, pred):
    """ Compute Structural Similarity Index Metric (SSIM). """
    return cmp_ssim(
         gt, pred, multichannel=False, data_range=gt.max()
    )
#def ssim(gt, pred):
#    """ Compute Structural Similarity Index Metric (SSIM). """
#    return cmp_ssim(
 #       gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max()
 #   )
def mse(gt, pred):
    """ Compute mean squared error. """
    return mean_squared_error(gt, pred)

def nrmse(gt, pred):
    """ Compute normalized root mse. """
    return normalized_root_mse(gt, pred)

In [11]:

    
data_path_train = 'dataBrain'
data_path_val = 'dataBrain'
data_list = load_data_path(data_path_train, data_path_val)
    

num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    
#mae_loss = nn.L1Loss().to('cuda:0')
mae_loss = nn.L1Loss()
lr = 0.0001 # 3e-3
    #acc =8 , network_8fold
network_8fold = ResNet(baseBlock,[2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2])
#network_8fold.to('cuda:0') #move the model on the GPU

    
optimizer2 = optim.Adam(network_8fold.parameters(), lr=lr)

 
train_dataset = MRIDataset(data_list['train'])
val_dataset = MRIDataset(data_list['val'])

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, num_workers=num_workers) 
print("finish data loading- now train")
 

finish data loading- now train


In [12]:
network_8fold

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1): Sequential(
    (0): baseBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): baseBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): baseBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Ba

In [None]:
import time
import random

then = time.time() #Time before the operations start
losses2=[]
mean_loss_list = []
img_nr = 0
epoch_nums = 200 #64 # 5
for epoch in range(epoch_nums):
    for iteration, sample in enumerate(train_loader):
        img_nr += 1
        img_gt, img_und = sample
        
        img_gt = img_gt.unsqueeze(1)#.to('cuda:0') # img_gt = img_gt.unsqueeze(1).to('cuda:0')
        img_und = img_und.unsqueeze(1)#.to('cuda:0') #img_und = img_und.unsqueeze(1).to('cuda:0')
       # img_gt = img_gt.to('cuda:0')
       # img_und = img_und.to('cuda:0')
            
        output = network_8fold(img_und)      #feedforward
        #print(output.shape) #// debug
        loss = mae_loss(output, img_gt)

        optimizer2.zero_grad()       #set current gradients to 0
        loss.backward()      #backpropagate
        optimizer2.step()     #update the weights
        mean_loss_list.append(loss.item())
       # print("Loss value: ", loss.item())
            #compute and print the mean L1 lossscore for the last 20 training images.
        if img_nr%20 == 0:
            print("L1 Loss score: ", np.round(np.mean(mean_loss_list), decimals = 5), "  Image number: ", img_nr, "  Epoch: ", epoch+1)
            mean_loss_list = []
        losses2.append(loss.item() * img_gt.size(0))
        
now = time.time() #Time after it finished

print("It took: ", now-then, " seconds")



L1 Loss score:  0.55654   Image number:  20   Epoch:  1




L1 Loss score:  0.13589   Image number:  40   Epoch:  2




L1 Loss score:  0.0992   Image number:  60   Epoch:  3
L1 Loss score:  0.09001   Image number:  80   Epoch:  3




L1 Loss score:  0.08323   Image number:  100   Epoch:  4




L1 Loss score:  0.08745   Image number:  120   Epoch:  5




L1 Loss score:  0.09275   Image number:  140   Epoch:  6
L1 Loss score:  0.08978   Image number:  160   Epoch:  6




L1 Loss score:  0.08712   Image number:  180   Epoch:  7




L1 Loss score:  0.08407   Image number:  200   Epoch:  8




L1 Loss score:  0.06975   Image number:  220   Epoch:  9
L1 Loss score:  0.06127   Image number:  240   Epoch:  9




L1 Loss score:  0.07567   Image number:  260   Epoch:  10




L1 Loss score:  0.07845   Image number:  280   Epoch:  11




L1 Loss score:  0.06873   Image number:  300   Epoch:  12
L1 Loss score:  0.06888   Image number:  320   Epoch:  12




L1 Loss score:  0.07293   Image number:  340   Epoch:  13




L1 Loss score:  0.07301   Image number:  360   Epoch:  14




L1 Loss score:  0.0751   Image number:  380   Epoch:  15
L1 Loss score:  0.0787   Image number:  400   Epoch:  15




L1 Loss score:  0.06684   Image number:  420   Epoch:  16




L1 Loss score:  0.05739   Image number:  440   Epoch:  17




L1 Loss score:  0.06071   Image number:  460   Epoch:  18
L1 Loss score:  0.06056   Image number:  480   Epoch:  18




L1 Loss score:  0.0506   Image number:  500   Epoch:  19




L1 Loss score:  0.05389   Image number:  520   Epoch:  20
L1 Loss score:  0.05991   Image number:  540   Epoch:  20




L1 Loss score:  0.05581   Image number:  560   Epoch:  21




L1 Loss score:  0.04641   Image number:  580   Epoch:  22




L1 Loss score:  0.05076   Image number:  600   Epoch:  23
L1 Loss score:  0.05955   Image number:  620   Epoch:  23




L1 Loss score:  0.05997   Image number:  640   Epoch:  24




L1 Loss score:  0.05839   Image number:  660   Epoch:  25




L1 Loss score:  0.05749   Image number:  680   Epoch:  26
L1 Loss score:  0.05281   Image number:  700   Epoch:  26




L1 Loss score:  0.05438   Image number:  720   Epoch:  27




L1 Loss score:  0.04987   Image number:  740   Epoch:  28




L1 Loss score:  0.04293   Image number:  760   Epoch:  29
L1 Loss score:  0.04691   Image number:  780   Epoch:  29




L1 Loss score:  0.04937   Image number:  800   Epoch:  30




L1 Loss score:  0.05222   Image number:  820   Epoch:  31




L1 Loss score:  0.05969   Image number:  840   Epoch:  32
L1 Loss score:  0.05023   Image number:  860   Epoch:  32




L1 Loss score:  0.04313   Image number:  880   Epoch:  33




L1 Loss score:  0.04491   Image number:  900   Epoch:  34




L1 Loss score:  0.04613   Image number:  920   Epoch:  35
L1 Loss score:  0.05041   Image number:  940   Epoch:  35




L1 Loss score:  0.0484   Image number:  960   Epoch:  36




L1 Loss score:  0.04967   Image number:  980   Epoch:  37




L1 Loss score:  0.05429   Image number:  1000   Epoch:  38
L1 Loss score:  0.04009   Image number:  1020   Epoch:  38




L1 Loss score:  0.03283   Image number:  1040   Epoch:  39




L1 Loss score:  0.05354   Image number:  1060   Epoch:  40
L1 Loss score:  0.04433   Image number:  1080   Epoch:  40




L1 Loss score:  0.04205   Image number:  1100   Epoch:  41




L1 Loss score:  0.04447   Image number:  1120   Epoch:  42




L1 Loss score:  0.04027   Image number:  1140   Epoch:  43
L1 Loss score:  0.04349   Image number:  1160   Epoch:  43




L1 Loss score:  0.04183   Image number:  1180   Epoch:  44




L1 Loss score:  0.04485   Image number:  1200   Epoch:  45




L1 Loss score:  0.04645   Image number:  1220   Epoch:  46
L1 Loss score:  0.0474   Image number:  1240   Epoch:  46




L1 Loss score:  0.04596   Image number:  1260   Epoch:  47




L1 Loss score:  0.04214   Image number:  1280   Epoch:  48




L1 Loss score:  0.04402   Image number:  1300   Epoch:  49
L1 Loss score:  0.0436   Image number:  1320   Epoch:  49




L1 Loss score:  0.0488   Image number:  1340   Epoch:  50




L1 Loss score:  0.0434   Image number:  1360   Epoch:  51




L1 Loss score:  0.05182   Image number:  1380   Epoch:  52
L1 Loss score:  0.03718   Image number:  1400   Epoch:  52




L1 Loss score:  0.0356   Image number:  1420   Epoch:  53




L1 Loss score:  0.03262   Image number:  1440   Epoch:  54




L1 Loss score:  0.03486   Image number:  1460   Epoch:  55
L1 Loss score:  0.04229   Image number:  1480   Epoch:  55




L1 Loss score:  0.04396   Image number:  1500   Epoch:  56




L1 Loss score:  0.03918   Image number:  1520   Epoch:  57




L1 Loss score:  0.03711   Image number:  1540   Epoch:  58
L1 Loss score:  0.03783   Image number:  1560   Epoch:  58




L1 Loss score:  0.0414   Image number:  1580   Epoch:  59




L1 Loss score:  0.03756   Image number:  1600   Epoch:  60
L1 Loss score:  0.04155   Image number:  1620   Epoch:  60




L1 Loss score:  0.03795   Image number:  1640   Epoch:  61




L1 Loss score:  0.04132   Image number:  1660   Epoch:  62




L1 Loss score:  0.04064   Image number:  1680   Epoch:  63
L1 Loss score:  0.03811   Image number:  1700   Epoch:  63




L1 Loss score:  0.03775   Image number:  1720   Epoch:  64




L1 Loss score:  0.03557   Image number:  1740   Epoch:  65




L1 Loss score:  0.03581   Image number:  1760   Epoch:  66
L1 Loss score:  0.04047   Image number:  1780   Epoch:  66




L1 Loss score:  0.03951   Image number:  1800   Epoch:  67




L1 Loss score:  0.03556   Image number:  1820   Epoch:  68




L1 Loss score:  0.03107   Image number:  1840   Epoch:  69
L1 Loss score:  0.03584   Image number:  1860   Epoch:  69




L1 Loss score:  0.03559   Image number:  1880   Epoch:  70




L1 Loss score:  0.04104   Image number:  1900   Epoch:  71




L1 Loss score:  0.0362   Image number:  1920   Epoch:  72
L1 Loss score:  0.04009   Image number:  1940   Epoch:  72




L1 Loss score:  0.03616   Image number:  1960   Epoch:  73




L1 Loss score:  0.03249   Image number:  1980   Epoch:  74




L1 Loss score:  0.03895   Image number:  2000   Epoch:  75
L1 Loss score:  0.032   Image number:  2020   Epoch:  75




L1 Loss score:  0.03084   Image number:  2040   Epoch:  76




L1 Loss score:  0.0279   Image number:  2060   Epoch:  77




L1 Loss score:  0.03561   Image number:  2080   Epoch:  78
L1 Loss score:  0.03232   Image number:  2100   Epoch:  78




L1 Loss score:  0.0274   Image number:  2120   Epoch:  79




L1 Loss score:  0.03243   Image number:  2140   Epoch:  80
L1 Loss score:  0.02808   Image number:  2160   Epoch:  80




L1 Loss score:  0.0329   Image number:  2180   Epoch:  81




L1 Loss score:  0.0338   Image number:  2200   Epoch:  82




L1 Loss score:  0.03239   Image number:  2220   Epoch:  83
L1 Loss score:  0.03612   Image number:  2240   Epoch:  83




L1 Loss score:  0.03437   Image number:  2260   Epoch:  84




L1 Loss score:  0.03565   Image number:  2280   Epoch:  85




L1 Loss score:  0.02495   Image number:  2300   Epoch:  86
L1 Loss score:  0.02955   Image number:  2320   Epoch:  86




L1 Loss score:  0.03837   Image number:  2340   Epoch:  87




L1 Loss score:  0.03193   Image number:  2360   Epoch:  88




L1 Loss score:  0.02578   Image number:  2380   Epoch:  89
L1 Loss score:  0.03246   Image number:  2400   Epoch:  89




L1 Loss score:  0.03043   Image number:  2420   Epoch:  90




L1 Loss score:  0.02879   Image number:  2440   Epoch:  91




L1 Loss score:  0.03654   Image number:  2460   Epoch:  92
L1 Loss score:  0.03043   Image number:  2480   Epoch:  92




L1 Loss score:  0.0302   Image number:  2500   Epoch:  93




L1 Loss score:  0.02915   Image number:  2520   Epoch:  94




L1 Loss score:  0.02735   Image number:  2540   Epoch:  95
L1 Loss score:  0.03216   Image number:  2560   Epoch:  95




L1 Loss score:  0.03119   Image number:  2580   Epoch:  96




L1 Loss score:  0.02993   Image number:  2600   Epoch:  97




L1 Loss score:  0.03016   Image number:  2620   Epoch:  98
L1 Loss score:  0.03289   Image number:  2640   Epoch:  98




L1 Loss score:  0.03436   Image number:  2660   Epoch:  99




L1 Loss score:  0.03239   Image number:  2680   Epoch:  100
L1 Loss score:  0.02983   Image number:  2700   Epoch:  100




L1 Loss score:  0.03197   Image number:  2720   Epoch:  101




L1 Loss score:  0.02704   Image number:  2740   Epoch:  102




L1 Loss score:  0.02797   Image number:  2760   Epoch:  103
L1 Loss score:  0.03071   Image number:  2780   Epoch:  103




L1 Loss score:  0.03115   Image number:  2800   Epoch:  104




L1 Loss score:  0.02503   Image number:  2820   Epoch:  105




L1 Loss score:  0.0251   Image number:  2840   Epoch:  106
L1 Loss score:  0.02193   Image number:  2860   Epoch:  106




L1 Loss score:  0.02354   Image number:  2880   Epoch:  107




L1 Loss score:  0.02821   Image number:  2900   Epoch:  108




L1 Loss score:  0.02925   Image number:  2920   Epoch:  109
L1 Loss score:  0.02545   Image number:  2940   Epoch:  109




L1 Loss score:  0.02864   Image number:  2960   Epoch:  110




L1 Loss score:  0.02565   Image number:  2980   Epoch:  111




L1 Loss score:  0.02681   Image number:  3000   Epoch:  112
L1 Loss score:  0.02313   Image number:  3020   Epoch:  112




L1 Loss score:  0.03879   Image number:  3040   Epoch:  113




L1 Loss score:  0.03202   Image number:  3060   Epoch:  114




L1 Loss score:  0.0284   Image number:  3080   Epoch:  115
L1 Loss score:  0.027   Image number:  3100   Epoch:  115




L1 Loss score:  0.02658   Image number:  3120   Epoch:  116




L1 Loss score:  0.02687   Image number:  3140   Epoch:  117




L1 Loss score:  0.02792   Image number:  3160   Epoch:  118
L1 Loss score:  0.02731   Image number:  3180   Epoch:  118




L1 Loss score:  0.02575   Image number:  3200   Epoch:  119




L1 Loss score:  0.02551   Image number:  3220   Epoch:  120
L1 Loss score:  0.02344   Image number:  3240   Epoch:  120




L1 Loss score:  0.02701   Image number:  3260   Epoch:  121




L1 Loss score:  0.02524   Image number:  3280   Epoch:  122


In [None]:
plt.plot(losses2)
plt.show()

In [None]:
#compute the SSIM score for every image after a feedforward propagation through 
#the network.
#Subtract the image SSIM score before the feedforward prop to obtain the net improvement for every image.
#Print the average improvement and the average SSIM score after the reconstruction.
SSIM_improvement = []
SSIM_score = []
MSE_improvement = []
MSE_score = []
NRMSE_improvement = []
NRMSE_score = []
MIE_improvement = []
MIE_score = []
for i in range(0,len(val_dataset)):
    gt, image = val_dataset[i]
    #image = image.unsqueeze(0).to('cuda:0')
    image = image.unsqueeze(0)
    image = image.unsqueeze(0)
    gt = gt.unsqueeze(0).numpy()
    output = network_8fold(image)
  #  output = output.squeeze(1).cpu().detach().numpy()
    output = output.squeeze(1).detach().numpy()
    image = image.squeeze(1).numpy()
    gt =  np.squeeze(gt)
    output =  np.squeeze(output)
    image =  np.squeeze(image)


    output_loss1 = torch.tensor(ssim(gt, output))
    output_loss2 = torch.tensor(mse(gt, output))
    output_loss3 = torch.tensor(nrmse(gt, output))
    output_loss4 = np.mean(np.abs(gt - output))
  #  image_loss = torch.tensor(ssim(gt, image.squeeze(1).cpu().numpy()))
    image_loss1 = torch.tensor(ssim(gt, image))
    image_loss2 = torch.tensor(mse(gt, image))
    image_loss3 = torch.tensor(nrmse(gt, image))
    image_loss4 = np.mean(np.abs(gt - image))
    SSIM_improvement.append(output_loss1.item()-image_loss1.item())
    SSIM_score.append(output_loss1.item())
    MSE_improvement.append(output_loss2.item()-image_loss2.item())
    MSE_score.append(output_loss2.item())
    NRMSE_improvement.append(output_loss3.item()-image_loss3.item())
    NRMSE_score.append(output_loss3.item())
    MIE_improvement.append(output_loss4 -image_loss4)
    MIE_score.append(output_loss4)

print(np.nanmean(SSIM_improvement))
print(np.nanmean(SSIM_score))
print(np.nanmean(MSE_improvement))
print(np.nanmean(MSE_score))
print(np.nanmean(NRMSE_improvement))
print(np.nanmean(NRMSE_score))
print(np.nanmean(MIE_improvement))
print(np.nanmean(MIE_score))

In [None]:
SSIM_improvement.sort()
plt.plot(SSIM_improvement)

In [None]:
MSE_improvement.sort()
plt.plot(MSE_improvement)

In [None]:
NRMSE_improvement.sort()
plt.plot(NRMSE_improvement)

In [None]:
MIE_improvement.sort()
plt.plot(MSE_improvement)

## save Model

In [None]:
index = 11

In [None]:
#output_dir = f"s3://savemodels/network_8fold/restnet-model{index}.pt"
output_dir = f"./network_8fold/restnet-oasis-model{index}.pt"

In [None]:
#save model to S3 bucket or data
torch.save(network_8fold.state_dict(), output_dir)
#torch.save(network_8fold.state_dict(), './models/resnet18-model.pt')

## Load Model from saved model

In [None]:
index = 10

In [None]:
#output_dir = f"s3://savemodels/network_8fold/restnet-model{index}.pt"
output_dir = f"./network_8fold/restnet-oasis-model{index}.pt"

In [None]:
#load model on CPU: laptop
device = torch.device('cpu')
#model = TheModelClass(*args, **kwargs)
model = ResNet(baseBlock,[2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2])
#model.load_state_dict(torch.load(PATH, map_location=device))
model.load_state_dict(torch.load(output_dir, map_location=device))
model.eval()

## Predict a single image

In [None]:
index = 9

In [None]:
#model_dir = f"s3://savemodels/network_8fold/restnet-model{index}.pt"
model_dir = f"./network_8fold/restnet-oasis-model{index}.pt"

In [None]:
#load model on CPU: laptop
device = torch.device('cpu')
#model = TheModelClass(*args, **kwargs)
model = ResNet(baseBlock,[2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2])
#model.load_state_dict(torch.load(PATH, map_location=device))
model.load_state_dict(torch.load(model_dir, map_location=device))
model.eval()

In [None]:
#compute the SSIM score for every image after a feedforward propagation through 
#the network.
#Subtract the image SSIM score before the feedforward prop to obtain the net improvement for every image.
#Print the average improvement and the average SSIM score after the reconstruction.
SSIM_improvement = []
SSIM_score = []
MSE_improvement = []
MSE_score = []
MIE_improvement = []
MIE_score = []
for i in range(0,len(val_dataset)):
    gt, image = val_dataset[i]
    #image = image.unsqueeze(0).to('cuda:0')
    image = image.unsqueeze(0)
    image = image.unsqueeze(0)
    gt = gt.unsqueeze(0).numpy()
    output = model(image)
  #  output = output.squeeze(1).cpu().detach().numpy()
    output = output.squeeze(1).detach().numpy()
    image = image.squeeze(1).numpy()
    gt =  np.squeeze(gt)
    output =  np.squeeze(output)
    image =  np.squeeze(image)


    output_loss1 = torch.tensor(ssim(gt, output))
    output_loss2 = torch.tensor(mse(gt, output))
    output_loss3 = np.mean(np.abs(gt - output))
  #  image_loss = torch.tensor(ssim(gt, image.squeeze(1).cpu().numpy()))
    image_loss1 = torch.tensor(ssim(gt, image))
    image_loss2 = torch.tensor(mse(gt, image))
    image_loss3 = np.mean(np.abs(gt - image))
    SSIM_improvement.append(output_loss1.item()-image_loss1.item())
    SSIM_score.append(output_loss1.item())
    MSE_improvement.append(output_loss2.item()-image_loss2.item())
    MSE_score.append(output_loss2.item())
    MIE_improvement.append(output_loss3 -image_loss3)
    MIE_score.append(output_loss3)

print(np.nanmean(SSIM_improvement))
print(np.nanmean(SSIM_score))
print(np.nanmean(MSE_improvement))
print(np.nanmean(MSE_score))
print(np.nanmean(MIE_improvement))
print(np.nanmean(MIE_score))

In [None]:
SSIM_improvement.sort()
plt.plot(SSIM_improvement)

In [None]:
MSE_improvement.sort()
plt.plot(MSE_improvement)

In [None]:
MIE_improvement.sort()
plt.plot(MSE_improvement)

In [42]:
from PIL import Image

In [8]:
file_dir = "dataBrain/OAS1_0365_MR1_mpr-1_anon_sag_66.png"

In [None]:
im_frame = Image.open(file_dir)
   
noise_im_frame = noise_and_kspace(im_frame)

preprocess = T.Compose([
                       # T.Grayscale(num_output_channels=1),
                           T.Resize(64),    #128 as maximum
                           T.CenterCrop(64),
                           T.ToTensor() #,
                           #T.Normalize(
                            #        mean=[0.485, 0.456, 0.406],
                               #        std=[0.229, 0.224, 0.225]
                             ##         )
                            ])
img_gt = preprocess(Image.fromarray(np.uint8(im_frame)).convert('L'))
img_und = preprocess(Image.fromarray(np.uint8(noise_im_frame)).convert('L'))
    
n1 = (img_und**2).sum(dim=-1).sqrt()
norm = n1.max() 
if norm < 1e-6: norm = 1e-6
    
img_gt, img_und = img_gt/norm , img_und/norm  
    




In [12]:
from skimage.metrics import structural_similarity as cmp_ssim 
def ssim(gt, pred):
    """ Compute Structural Similarity Index Metric (SSIM). """
    return cmp_ssim(
        gt, pred, multichannel=True, data_range=gt.max()
    )

In [None]:

img_und = img_und.unsqueeze(0)
output = model(img_und)
   # output = output.squeeze(1).cpu().detach().numpy()
output = output.squeeze(1).detach() #.numpy()   #image under numpy form


In [None]:
output.shape

In [None]:
np_rescontruct_image =  output # np.reshape(output, (64, 64))# image noise numpy array
im_reconstruct = T.ToPILImage()(np_rescontruct_image)#Image.fromarray(np_rescontruct_image).convert('L')
im_reconstruct.save("testing/test.png") #for prediction values
im_reconstruct.save("pred1.png")

In [None]:
display(im_reconstruct)

In [None]:
display(im_frame)

In [None]:
display(Image.fromarray(np.uint8(noise_im_frame)))