In [None]:
import torch 
import torch.nn as nn 

import h5py, os
import numpy as np
from matplotlib import pyplot as plt
from functions import transforms as T
from torch.nn import functional as F
from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from math import exp
import torch.optim as optim
from skimage.measure import compare_ssim

In [None]:
# Use the first CUDA device when one is visible, otherwise run everything on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
torch.cuda.empty_cache()

In [None]:
def show_slices(data, slice_nums, cmap=None):
    """Draw data[k] for each k in slice_nums as side-by-side panels of one 15x10-inch figure.

    Args:
        data: indexable collection of 2-D images.
        slice_nums: indices into `data`, one panel per index, left to right.
        cmap: optional matplotlib colormap name passed to imshow.
    """
    plt.figure(figsize=(15, 10))
    n_panels = len(slice_nums)
    for panel, idx in enumerate(slice_nums, start=1):
        plt.subplot(1, n_panels, panel)
        plt.imshow(data[idx], cmap=cmap)
        plt.axis('off')

In [None]:
class MRIDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one (ground-truth, zero-filled) image pair per slice entry.

    Args:
        data_list: sequence of (fname, file_path, slice_index) tuples, e.g. from load_data_path.
        acceleration: acceleration factor(s) forwarded to get_epoch_batch / MaskFunc.
        center_fraction: fully-sampled centre fraction(s), forwarded the same way.
        use_seed: when True, get_epoch_batch seeds the undersampling mask from the file name,
            so every access to the same file uses the same mask (useful for validation).
            Defaults to False, i.e. the mask is not seeded.
    """

    def __init__(self, data_list, acceleration, center_fraction, use_seed=False):
        super().__init__()
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        subject_id = self.data_list[idx]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction,
                               use_seed=self.use_seed)

In [None]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed=False, crop_size=(320, 320)):
    """Load one k-space slice, undersample it and return the cropped image pair.

    Args:
        subject_id: (fname, file_path, slice_index) tuple as produced by load_data_path.
        acc: acceleration factor(s) passed to MaskFunc.
        center_fract: centre fraction(s) passed to MaskFunc.
        use_seed: if True, the mask seed is derived from the file name, so the same file
            always gets the same mask; otherwise MaskFunc is called with seed=None.
        crop_size: (height, width) of the centre crop applied to both outputs.

    Returns:
        (img_gt, img_und): T.complex_abs of the fully-sampled and of the zero-filled
        reconstructions, both divided by the zero-filled image's peak and centre-cropped.

    Raises:
        ValueError: if the slice does not convert to a 4-D tensor after adding the leading axis.
    """
    fname, rawdata_name, slice_idx = subject_id

    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data['kspace'][slice_idx]

    # Add a leading axis; the mask is generated from this full shape.
    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    if slice_kspace.dim() != 4:
        raise ValueError('expected a 4-D k-space tensor, got shape {}'.format(tuple(slice_kspace.shape)))

    # Build the undersampling mask (seeded from the file name when requested).
    mask_func = MaskFunc(center_fractions=center_fract, accelerations=acc)
    seed = tuple(map(ord, fname)) if use_seed else None
    mask = mask_func(np.array(slice_kspace.shape), seed)

    # Zero-fill every k-space location where the mask is 0.
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)

    # Normalise by the zero-filled reconstruction's peak: at inference there is no ground
    # truth, so this is the only scale available. The floor avoids dividing by ~0.
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6:
        norm = 1e-6

    img_gt = T.complex_abs((img_gt / norm).squeeze(0))
    img_und = T.complex_abs((img_und / norm).squeeze(0))

    return T.center_crop(img_gt, crop_size), T.center_crop(img_und, crop_size)

In [None]:
def load_data_path(data_path, limit=60, train_slices=(14, 25)):
    """List (file name, file path, slice index) entries for a train/val split of one directory.

    Directory entries are sorted by name. The first `limit` entries form the training split and
    the rest the validation split. Entries that are not regular files are skipped.

    Args:
        data_path: directory holding HDF5 volumes, each with a 'kspace' dataset whose first
            axis indexes slices.
        limit: number of sorted directory entries assigned to the training split.
        train_slices: half-open (start, stop) slice range kept for training volumes, clipped
            to each volume's slice count. Validation volumes keep every slice.

    Returns:
        dict with keys 'train' and 'val', each a list of (fname, path, slice_index) tuples.
    """
    entries = sorted(os.listdir(data_path))
    splits = {'train': entries[:limit], 'val': entries[limit:]}
    start, stop = train_slices

    data_list = {}
    for split, fnames in splits.items():
        data_list[split] = []
        for fname in fnames:
            subject_data_path = os.path.join(data_path, fname)
            if not os.path.isfile(subject_data_path):
                continue

            with h5py.File(subject_data_path, 'r') as data:
                num_slice = data['kspace'].shape[0]

            if split == 'train':
                # Training keeps only a band of slices (the edge slices are presumably
                # low-signal). Clip so volumes with fewer slices do not yield
                # out-of-range indices that would fail later in get_epoch_batch.
                slice_ids = range(start, min(stop, num_slice))
            else:
                slice_ids = range(num_slice)
            data_list[split].extend((fname, subject_data_path, s) for s in slice_ids)

    return data_list

In [None]:
# PREPARE THE DATA
# Index the volumes; load_data_path returns a dict with 'train' and 'val' entry lists.
data_list = load_data_path('/data/local/NC2019MRI/train')

# Undersampling settings, forwarded unchanged to MaskFunc for every sample.
acc, cen_fract = [4, 8], [0.08, 0.04]

# Loader worker processes (0 means loading happens in the main process).
num_workers = 10

# Training loader; the validation loader is built the same way further down.
train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=14,
    shuffle=True,
    num_workers=num_workers,
)

In [8]:
class ConvolutionalBlock(nn.Module):
    """Two stacked stages of 3x3 Conv2d -> InstanceNorm2d -> ReLU -> Dropout2d.

    The first convolution maps in_chans -> out_chans and the second out_chans -> out_chans.
    padding=1 keeps height and width unchanged.

    Args:
        in_chans (int): channels of the input tensor.
        out_chans (int): channels produced by both convolutions.
        drop_prob (float): Dropout2d probability applied after each ReLU.
    """

    def __init__(self, in_chans, out_chans, drop_prob):
        super().__init__()
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        stages = []
        for stage_in in (in_chans, out_chans):
            stages.extend([
                nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=1),
                nn.InstanceNorm2d(out_chans),
                nn.ReLU(),
                nn.Dropout2d(drop_prob),
            ])
        # Flat Sequential with the convolutions at indices 0 and 4, so state_dict keys
        # (layers.0.*, layers.4.*) stay stable for saved checkpoints.
        self.layers = nn.Sequential(*stages)

    def forward(self, input):
        """Map a [batch, in_chans, H, W] tensor to [batch, out_chans, H, W]."""
        return self.layers(input)



In [9]:
class NeuralNetworkModel(nn.Module):
    """U-Net built from ConvolutionalBlocks.

    Encoder blocks are each followed by 2x2 max pooling. A bottleneck block sits between
    encoder and decoder. Each decoder step upsamples bilinearly, concatenates the matching
    encoder output (skip connection) and applies a block. Three 1x1 convolutions produce
    the output channels.
    """

    def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
            chans (int): Number of output channels of the first convolution layer.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Encoder: the first block lifts in_chans -> chans; each later block doubles the width.
        self.layers_list_downsample = nn.ModuleList([ConvolutionalBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.layers_list_downsample.append(ConvolutionalBlock(ch, ch * 2, drop_prob))
            ch *= 2

        # Bottleneck keeps the channel count of the deepest encoder block.
        self.conv = ConvolutionalBlock(ch, ch, drop_prob)

        # Decoder: inputs are upsampled features concatenated with the skip (ch * 2 channels).
        self.layers_list_upsample = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.layers_list_upsample.append(ConvolutionalBlock(ch * 2, ch // 2, drop_prob))
            ch //= 2
        self.layers_list_upsample.append(ConvolutionalBlock(ch * 2, ch, drop_prob))

        # Three 1x1 convolutions project the `ch` feature maps down to out_chans.
        self.conv2 = nn.Sequential(
            nn.Conv2d(ch, ch // 2, kernel_size=1),
            nn.Conv2d(ch // 2, out_chans, kernel_size=1),
            nn.Conv2d(out_chans, out_chans, kernel_size=1),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        skips = []
        output = input
        # Down-sampling path: keep each block's output for the matching decoder step.
        for layer in self.layers_list_downsample:
            output = layer(output)
            skips.append(output)
            output = F.max_pool2d(output, kernel_size=2)

        output = self.conv(output)

        # Up-sampling path.
        for layer in self.layers_list_upsample:
            skip = skips.pop()
            # Resize to the skip's exact spatial size instead of a fixed x2: max pooling floors
            # odd sizes, so x2 would mismatch the skip whenever H or W is not divisible by
            # 2**num_pool_layers. When it is divisible this is the same x2 bilinear upsampling.
            output = F.interpolate(output, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            output = torch.cat([output, skip], dim=1)
            output = layer(output)
        return self.conv2(output)

In [10]:
import torch.optim as optim

#create a model
model = NeuralNetworkModel(
    in_chans=1,
    out_chans=1,
    chans=32,
    num_pool_layers=4,
    drop_prob=0.0
).to(device)

#inspect parameters 
# print("Before training: \n", model.state_dict())

In [11]:
#loss function
#start point: L1 loss |output - gold standard|

#ssim loss
def gaussian(window_size, sigma):
    """1-D Gaussian kernel of length window_size, centred at index window_size // 2, summing to 1."""
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-((pos - centre) ** 2) / denom) for pos in range(window_size)])
    return weights / weights.sum()


def create_window(window_size, channel=1):
    """2-D Gaussian window (sigma 1.5) of shape (channel, 1, window_size, window_size).

    The layout matches what F.conv2d expects as a weight when called with groups=channel.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)       # column vector (window_size, 1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float()[None, None]  # outer product -> (1, 1, ws, ws)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity between two image batches using a Gaussian window.

    Args:
        img1, img2: tensors of shape (batch, channel, H, W).
        window_size: Gaussian window side, capped at min(H, W), used when `window` is None.
        window: optional precomputed (channel, 1, k, k) window for grouped convolution.
        size_average: if True return the mean over everything; otherwise one value per image.
        full: if True also return the mean contrast-sensitivity term `cs`.
        val_range: dynamic range L. When None it is guessed from img1: max 255 if
            max(img1) > 128 else 1, min -1 if min(img1) < -0.5 else 0.

    Returns:
        The SSIM value(s), or (ssim, cs) when full=True.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        max_val = 255 if torch.max(img1) > 128 else 1
        min_val = -1 if torch.min(img1) < -0.5 else 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        # create_window builds a float32 window; match the input's dtype too, otherwise
        # conv2d raises a dtype mismatch for e.g. float64 images.
        window = create_window(real_size, channel=channel).to(device=img1.device, dtype=img1.dtype)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    # Stabilising constants from the SSIM definition, scaled by the dynamic range.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        # Average over channel, H and W -> one value per image.
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


# Classes to re-use window
class SSIM(torch.nn.Module):
    """SSIM loss module: returns 1 - ssim(img2, img1), caching the Gaussian window.

    Args:
        window_size (int): side length of the Gaussian window.
        size_average (bool): mean over the whole batch (True) or one value per image.
        val_range: dynamic range L passed to ssim. When None (default), img2.max() is used
            on every call.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Start with a single-channel window; forward rebuilds it when the input's channel
        # count, dtype or device differs.
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        """Return 1 - SSIM between img1 and img2, both (batch, channel, H, W)."""
        (_, channel, _, _) = img1.size()

        # Compare dtype/device values (not the bound `.type` methods, which never compare
        # equal) so the cached window is actually reused across calls.
        if (channel == self.channel
                and self.window.dtype == img1.dtype
                and self.window.device == img1.device):
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        val_range = img2.max() if self.val_range is None else self.val_range
        return 1 - ssim(img2, img1, window=window, window_size=self.window_size,
                        size_average=self.size_average, val_range=val_range)
        
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True):
    """Multi-scale SSIM over 5 scales, halving resolution with 2x2 average pooling.

    Args:
        img1, img2: tensors of shape (batch, channel, H, W); H and W must stay >= 1 after
            4 halvings.
        window_size, val_range: forwarded to ssim at every scale.
        size_average: if True the result is a scalar; otherwise one value per image.
        normalize: map each per-scale term from [-1, 1] to [0, 1] before exponentiation
            (avoids NaNs for unstable models; not part of the original definition).

    Returns:
        prod_{j < M} cs_j ** w_j  *  ssim_M ** w_M, using the 5 fixed weights below.
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)  # (levels,) or (levels, batch) when size_average=False
    mcs = torch.stack(mcs)      # (levels,): ssim always reduces cs to a scalar

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    # Reshape weights so they broadcast along the scale axis also for per-image SSIM values.
    pow2 = mssim ** weights.view(-1, *([1] * (mssim.dim() - 1)))
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # The contrast terms of the first M-1 scales are multiplied together and the last-scale
    # SSIM term enters once (multiplying it inside the product would raise it to the M-1 power).
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output


class MSSSIM(torch.nn.Module):
    """MS-SSIM loss module: returns 1 - msssim(img1, img2) with img1.max() as the value range.

    `channel` is stored for reference only; the window is rebuilt inside msssim on each call.
    """

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size, self.size_average, self.channel = window_size, size_average, channel

    def forward(self, img1, img2):
        score = msssim(img1, img2, window_size=self.window_size,
                       size_average=self.size_average, val_range=img1.max())
        return 1 - score
    
# Candidate training losses; the training loop below picks one.
ssim_loss = SSIM()      # 1 - SSIM
msssim_loss = MSSSIM()  # 1 - MS-SSIM
loss_fn = nn.MSELoss()  # mean squared error (default 'mean' reduction)

In [12]:
# Optimiser hyper-parameters.
lr = 1e-4  # Adam learning rate
wd = 0.0   # L2 weight decay passed to the optimiser; 0.0 disables it

# Alternatives tried: optim.SGD(model.parameters(), lr=lr) and optim.RMSprop(model.parameters(), lr).
optimiser = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

In [13]:
epochs = 10

for epoch in range(epochs):
    model.train()
    epoch_losses = []
    for target_img, input_img in train_loader:
        # Loader yields (ground truth, zero-filled) batches of shape (B, H, W); add a channel axis.
        input_img = input_img.to(device).unsqueeze(1)
        target_img = target_img.to(device).unsqueeze(1)

        output_img = model(input_img)

        loss = ssim_loss(output_img, target_img)
        # Alternatives (msssim_loss / ssim_loss already return 1 - similarity):
        # loss = msssim_loss(output_img, target_img)
        # loss = F.l1_loss(output_img, target_img)
        # loss = loss_fn(output_img, target_img)
        # loss = 0.84 * msssim_loss(output_img, target_img) + 0.16 * F.l1_loss(output_img, target_img)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # Store a plain float so the epoch history does not keep autograd tensors alive.
        epoch_losses.append(loss.item())

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print("Epoch {}'s loss: {}".format(epoch, mean_loss))


Epoch 0's loss: 0.7352574467658997
Epoch 1's loss: 0.49669113755226135
Epoch 2's loss: 0.41145452857017517
Epoch 3's loss: 0.38900166749954224
Epoch 4's loss: 0.37703055143356323


In [None]:
# Release cached GPU memory from training and switch the model to inference mode
# (train(False) is what eval() calls, and likewise returns the model).
torch.cuda.empty_cache()
# model.load_state_dict(torch.load('model_final.h5'))
model.train(False)

In [None]:
# Validation loader: same undersampling settings as training, fixed order (shuffle=False).
# NOTE: masks are not seeded here, so scores can vary slightly between passes.
val_dataset = MRIDataset(data_list['val'], acceleration=acc, center_fraction=cen_fract)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=38, num_workers=num_workers)

In [14]:
def ssim_numpy(gt, pred):
    """SSIM between two 3-D arrays whose first axis is treated as the channel axis.

    Both arrays are moved to channels-last before calling skimage's compare_ssim. The data
    range is taken from gt.max().
    """
    gt_channels_last = np.transpose(gt, (1, 2, 0))
    pred_channels_last = np.transpose(pred, (1, 2, 0))
    return compare_ssim(gt_channels_last, pred_channels_last, multichannel=True, data_range=gt.max())

In [15]:
ssim_scores = []

model.eval()  # idempotent; keeps this cell correct even if run on its own after training

with torch.no_grad():
    for img_gt, img_und in val_loader:
        # Squeeze only the channel axis added for the model, so a final batch of size 1
        # keeps its batch dimension (ssim_numpy transposes a 3-D array).
        output_img = model(img_und.to(device).unsqueeze(1)).squeeze(1).cpu().numpy()
        ssim_scores.append(ssim_numpy(img_gt.numpy(), output_img))

numpy_ssims = np.array(ssim_scores)
print("Mean:", numpy_ssims.mean())

  """


0.5501071896680699 tensor(0.5507, device='cuda:0')
0.5897546048952125 tensor(0.5954, device='cuda:0')
0.5961351589454073 tensor(0.6018, device='cuda:0')
0.6938788695043581 tensor(0.7011, device='cuda:0')
0.598302807235098 tensor(0.6033, device='cuda:0')
0.5970857088021189 tensor(0.6030, device='cuda:0')
0.5072443215799641 tensor(0.5113, device='cuda:0')
0.5724955347306269 tensor(0.5768, device='cuda:0')
0.6578282980619549 tensor(0.6652, device='cuda:0')
0.5908600671938099 tensor(0.5959, device='cuda:0')
0.542855633543178 tensor(0.5476, device='cuda:0')


In [None]:
SHOW_INDEX = 6  # position within each validation batch to display

with torch.no_grad():
    for iteration, (img_gt, img_und) in enumerate(val_loader):
        # Squeeze only the channel axis so the batch axis survives a size-1 batch.
        output_img = model(img_und.to(device).unsqueeze(1)).squeeze(1).cpu().numpy()
        # Clamp the index so a short final batch does not raise IndexError.
        idx = min(SHOW_INDEX, output_img.shape[0] - 1)
        show_slices([img_gt.numpy()[idx], output_img[idx]], [0, 1], cmap='gray')

        if iteration >= 3:
            break