# Importing required packages

In [None]:
import torch
import torch.nn as nn
import scipy.io
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms
from sklearn.metrics import f1_score
import torch.nn.functional as F
import numpy as np
import scipy
from skimage.transform import radon,iradon,rescale

# Reading the infection masks and CT data

In [None]:
# Load the expert masks and the CT volumes from their MATLAB files.
mask_data = scipy.io.loadmat("infmsk_hw1.mat")
ct_data = scipy.io.loadmat("ctscan_hw1.mat")

# Move the last axis to the front so that volume[i] is one 2-D array.
# Presumably the .mat files store slices along the last axis -- TODO confirm.
slice_axis_first = (2, 0, 1)
mask = np.transpose(mask_data['infmsk'], slice_axis_first)
ct = np.transpose(ct_data['ctscan'], slice_axis_first)
ct_scans = ct  # alias used by the plotting code

# Q1 
U-Net implementation (a residual, batch-normalized variant) based on the paper linked below:

Link: https://arxiv.org/abs/1505.04597

In [None]:
# Prefer the GPU when one is available; fall back to the CPU so the notebook
# still runs without CUDA. NOTE(review): this name is not referenced by the
# cells below -- they run on whatever device the model's parameters live on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class UNet(nn.Module):
    """U-Net (Ronneberger et al., 2015) with optional residual blocks and batch norm.

    Args:
        in_channels: number of channels of the input image.
        n_classes: number of output channels (one per class).
        depth: number of resolution levels; the encoder downsamples with 2x2
            average pooling ``depth - 1`` times.
        wf: level ``i`` uses ``2 ** (wf + i)`` filters.
        padding: whether the 3x3 convolutions are zero-padded
            (``UNetConvBlock`` forces padding when ``residual`` is True).
        batch_norm: insert BatchNorm layers inside the conv blocks.
        up_mode: ``'upconv'`` (transposed convolution) or ``'upsample'``
            (bilinear upsampling followed by a 1x1 convolution).
        residual: use residual skip connections in the blocks.

    Raises:
        ValueError: if ``up_mode`` is not one of the two supported values.

    ``forward`` returns log-probabilities over ``n_classes`` along dim 1
    (the output passes through ``LogSoftmax(dim=1)``). With padding, the
    spatial size presumably equals the input's when H and W are divisible by
    ``2 ** (depth - 1)`` -- confirm for other sizes.
    """

    def __init__(self, in_channels=1, n_classes=3, depth=3, wf=6, padding=True, batch_norm=True, up_mode='upconv', residual=True):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if up_mode not in ('upconv', 'upsample'):
            raise ValueError(f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}")
        self.padding = padding
        self.depth = depth

        widths = [2 ** (wf + level) for level in range(depth)]

        # Attribute names (down_path, up_path, last, softmax) define the
        # state_dict keys, so they must stay stable for saved checkpoints.
        self.down_path = nn.ModuleList()
        prev_channels = in_channels
        for level, width in enumerate(widths):
            # `first` only matters for residual blocks: the first one skips
            # the leading ReLU/BatchNorm pre-activation.
            self.down_path.append(UNetConvBlock(prev_channels, width, padding, batch_norm, residual,
                                                first=(level == 0 and residual)))
            prev_channels = width

        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(UNetUpBlock(prev_channels, width, up_mode, padding, batch_norm, residual))
            prev_channels = width

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        skips = []
        last_level = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            if level < last_level:
                # Keep the pre-pooling features for the matching decoder stage.
                skips.append(x)
                x = F.avg_pool2d(x, 2)

        # Decoder stages consume the skips deepest-first.
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)

        return self.softmax(self.last(x))


class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions with ReLU, optional BatchNorm and an optional residual skip.

    Layer layout of ``self.block``:
      * residual and not first: ReLU, [BN(in)], Conv, ReLU, [BN(out)], Conv
      * residual and first:     Conv, ReLU, [BN(out)], Conv
      * not residual:           Conv, ReLU, [BN(out)], Conv, ReLU, [BN(out)]
    (bracketed layers only when ``batch_norm`` is True). Residual blocks always
    use padding 1; otherwise the convolutions use padding ``int(padding)``.

    When ``residual`` is True the input is added to the block output; if the
    channel counts differ it is first projected by a 1x1 conv + BatchNorm.
    """

    def __init__(self, in_size, out_size, padding, batch_norm, residual=False, first=False):
        super().__init__()
        self.residual = residual
        self.out_size = out_size
        self.in_size = in_size
        self.batch_norm = batch_norm
        self.first = first
        # Registered unconditionally, so these parameters appear in the
        # state_dict even when the projection is never applied.
        self.residual_input_conv = nn.Conv2d(in_size, out_size, kernel_size=1)
        self.residual_batchnorm = nn.BatchNorm2d(out_size)

        conv_pad = 1 if residual else int(padding)

        def relu_then_norm(channels):
            # ReLU, followed by BatchNorm when enabled.
            if batch_norm:
                return [nn.ReLU(), nn.BatchNorm2d(channels)]
            return [nn.ReLU()]

        layers = []
        if residual and not first:
            layers.extend(relu_then_norm(in_size))
        layers.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=conv_pad))
        layers.extend(relu_then_norm(out_size))
        layers.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=conv_pad))
        if not residual:
            layers.extend(relu_then_norm(out_size))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        if not self.residual:
            return out
        skip = x
        if self.in_size != self.out_size:
            skip = self.residual_batchnorm(self.residual_input_conv(x))
        return out + skip

class UNetUpBlock(nn.Module):
    """Decoder stage: upsample x2, concatenate the center-cropped skip, convolve.

    Args:
        in_size: channels of the incoming decoder features; also the channel
            count the inner conv block expects after concatenation, so the
            skip tensor must supply ``in_size - out_size`` channels.
        out_size: channels produced by the upsampling and by the stage.
        up_mode: ``'upconv'`` -> 2x2 stride-2 transposed conv;
            ``'upsample'`` -> bilinear x2 followed by a 1x1 conv.
        padding, batch_norm: forwarded to the inner (non-residual) UNetConvBlock.
        residual: if True, the concatenated input is added to the output
            (projected by 1x1 conv + BatchNorm when ``in_size != out_size``).
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm, residual=False):
        super().__init__()
        self.residual = residual
        self.in_size = in_size
        self.out_size = out_size
        # Registered unconditionally, so these parameters appear in the
        # state_dict even when the projection is never applied.
        self.residual_input_conv = nn.Conv2d(in_size, out_size, kernel_size=1)
        self.residual_batchnorm = nn.BatchNorm2d(out_size)

        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            upsample = nn.Upsample(mode='bilinear', scale_factor=2)
            projection = nn.Conv2d(in_size, out_size, kernel_size=1)
            self.up = nn.Sequential(upsample, projection)
        # Any other up_mode leaves self.up unset (the UNet constructor checks up_mode).

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    @staticmethod
    def center_crop(layer, target_size):
        """Return the central ``target_size`` (height, width) window of a 4-D (N, C, H, W) tensor."""
        _, _, height, width = layer.size()
        crop_h, crop_w = target_size[0], target_size[1]
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        return layer[:, :, top:top + crop_h, left:left + crop_w]

    def forward(self, x, bridge):
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        merged = torch.cat([upsampled, skip], 1)
        out = self.conv_block(merged)
        if not self.residual:
            return out
        if self.in_size != self.out_size:
            merged = self.residual_batchnorm(self.residual_input_conv(merged))
        return out + merged

# Dataset of (CT slice, mask) pairs used by the train / validation / test loaders

In [None]:
from torch.utils.data import Dataset, DataLoader

class myDataset(Dataset):
    """Map-style dataset yielding one ``(scan, mask)`` tensor pair per slice.

    Each item gets a leading channel axis of size 1 (``unsqueeze(0)``), so a
    slice of shape ``S`` is returned with shape ``(1, *S)``.

    Args:
        ct_volume: NumPy array indexed along axis 0 by slice. Defaults to the
            notebook-level ``ct`` array, so ``myDataset()`` keeps working.
        mask_volume: NumPy label array with the same leading length.
            Defaults to the notebook-level ``mask`` array.

    Raises:
        ValueError: if the two arrays do not have the same number of slices.
    """

    def __init__(self, ct_volume=None, mask_volume=None):
        super().__init__()
        # The globals are only looked up when no array is passed explicitly.
        self.ct = ct if ct_volume is None else ct_volume
        self.mask = mask if mask_volume is None else mask_volume
        if self.ct.shape[0] != self.mask.shape[0]:
            raise ValueError(
                f"ct has {self.ct.shape[0]} slices but mask has {self.mask.shape[0]}")

    def __getitem__(self, index):
        scan = torch.from_numpy(self.ct[index]).unsqueeze(0)
        label = torch.from_numpy(self.mask[index]).unsqueeze(0)
        return scan, label

    def __len__(self):
        return self.ct.shape[0]

# One dataset over every slice; the train/val/test loaders in the next cell
# sample disjoint index subsets of it.
train_ds = myDataset()

# Train / validation / test split (70% / 20% / 10%)

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

n = mask.shape[0]
# Shuffle the slice indices once and cut them 70% / 20% / 10% into disjoint
# train / validation / test subsets.
index = np.random.permutation(n)
train_end, val_end = int(0.7 * n), int(0.9 * n)
a = index[:train_end]           # training slice indices
b = index[train_end:val_end]    # validation slice indices
c = index[val_end:]             # test slice indices

# SubsetRandomSampler already shuffles its subset every epoch. DataLoader
# raises ValueError when shuffle=True is combined with a sampler, so shuffle
# is left at its default.
train_sampler = SubsetRandomSampler(a)
train_loader = DataLoader(train_ds, batch_size=8, sampler=train_sampler)

val_sampler = SubsetRandomSampler(b)
val_loader = DataLoader(train_ds, batch_size=1, sampler=val_sampler)

test_sampler = SubsetRandomSampler(c)
test_loader = DataLoader(train_ds, batch_size=1, sampler=test_sampler)

# Short aliases referenced by the training / evaluation cells.
train_dl, val_dl, test_dl = train_loader, val_loader, test_loader

# Training model

In [None]:
from tqdm import tqdm
unet = UNet(1, 3)
optim = torch.optim.Adam(unet.parameters(), lr=1e-3)
# UNet.forward already ends in LogSoftmax. CrossEntropyLoss applies
# log_softmax again, which leaves log-probabilities unchanged, so this is
# equivalent to NLLLoss on the model output.
loss_fn2 = nn.CrossEntropyLoss()
epochs = 1
# Batches are sent to wherever the model lives: move `unet` (e.g.
# unet.to('cuda')) before this line and the data follows.
model_device = next(unet.parameters()).device

for ep in range(epochs):
    unet.train()
    for x, y in tqdm(train_loader):
        x = x.to(model_device).float()
        # Masks arrive with a size-1 channel axis at dim 1; CrossEntropyLoss
        # wants integer class indices without it.
        y = y.to(model_device).long().squeeze(1)
        optim.zero_grad()
        yt = unet(x)
        loss = loss_fn2(yt, y)
        loss.backward()
        optim.step()

    unet.eval()
    v_loss = 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader):
            x = x.to(model_device).float()
            y = y.to(model_device).long().squeeze(1)
            yt = unet(x)
            v_loss += loss_fn2(yt, y).item()
    # CrossEntropyLoss averages within a batch; also report the mean over batches.
    mean_v_loss = v_loss / max(len(val_loader), 1)
    print(f"epoch {ep + 1}/{epochs}: total val loss = {v_loss}, mean per batch = {mean_v_loss}")

# For calculating metrics

In [None]:
from sklearn.metrics import confusion_matrix
import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including RuntimeWarnings such as NumPy divide-by-zero; consider narrowing
# it (e.g. warnings.filterwarnings('ignore', category=...)) so real problems
# stay visible.
warnings.filterwarnings('ignore')
def _class_counts(cm, label):
    """Return (TP, FP, FN, TN) for ``label`` vs. all other labels.

    ``cm`` is a square confusion matrix with rows = true label and
    columns = predicted label (the layout of sklearn's ``confusion_matrix``).
    """
    tp = cm[label, label]
    fp = cm[:, label].sum() - tp   # predicted `label`, truly another label
    fn = cm[label, :].sum() - tp   # truly `label`, predicted another label
    tn = cm.sum() - tp - fp - fn
    return tp, fp, fn, tn


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or None when the denominator is 0."""
    if denominator == 0:
        return None
    return float(numerator) / float(denominator)


def _mean_of_defined(values):
    """Average the non-None entries of ``values``; NaN when there are none."""
    defined = [v for v in values if v is not None]
    return sum(defined) / len(defined) if defined else float("nan")


def give_metrics(y, y_pred):
    """Print dice / sensitivity / specificity / accuracy for the two tissue labels.

    Labels are scored one-vs-rest over the label set {0, 1, 2}: label 2 is
    reported as "normal" and label 1 as "infection"; label 0 only contributes
    to the negatives.

    Every metric is computed per slice and averaged over the slices where it
    is defined (non-zero denominator). For example, infection sensitivity is
    undefined on a slice without infection, and dice is undefined when a
    label is absent from both maps. A metric that is undefined on every slice
    is reported as NaN.

    Args:
        y: ground-truth label maps, one per entry along axis 0.
        y_pred: predicted label maps, same number of entries as ``y``.

    Returns:
        ``{"normal": {...}, "infection": {...}}`` where each inner dict maps
        "dice_score", "sensitivity", "specificity" and "accuracy" to the
        averaged value.

    Raises:
        ValueError: if ``y`` and ``y_pred`` have different lengths.
    """
    n_slices = len(y_pred)
    if len(y) != n_slices:
        raise ValueError(f"y has {len(y)} slices but y_pred has {n_slices}")

    class_labels = {"normal": 2, "infection": 1}
    metric_names = ("dice_score", "sensitivity", "specificity", "accuracy")
    per_slice = {cls: {m: [] for m in metric_names} for cls in class_labels}

    for true_map, pred_map in zip(y, y_pred):
        cm = confusion_matrix(np.asarray(true_map).ravel(), np.asarray(pred_map).ravel(),
                              labels=[0, 1, 2])
        for cls, label in class_labels.items():
            tp, fp, fn, tn = _class_counts(cm, label)
            values = {
                "dice_score": _safe_ratio(2 * tp, 2 * tp + fp + fn),
                "sensitivity": _safe_ratio(tp, tp + fn),
                "specificity": _safe_ratio(tn, tn + fp),
                "accuracy": _safe_ratio(tp + tn, tp + tn + fp + fn),
            }
            for metric, value in values.items():
                per_slice[cls][metric].append(value)

    averages = {
        cls: {metric: _mean_of_defined(vals) for metric, vals in metrics.items()}
        for cls, metrics in per_slice.items()
    }
    normal, infection = averages["normal"], averages["infection"]
    print(f"Averaged dice score = {normal['dice_score']}, sensitivity = {normal['sensitivity']}, \n specificity = {normal['specificity']}, and accuracy = {normal['accuracy']} for the normal")
    print(f"\nAveraged dice score = {infection['dice_score']}, sensitivity = {infection['sensitivity']}, \n specificity = {infection['specificity']}, and accuracy = {infection['accuracy']} for the infection")
    return averages

In [None]:
from sklearn.metrics import multilabel_confusion_matrix

def eval(unet1, loader=None, plot_index=5):
    """Run ``unet1`` over a loader, print segmentation metrics and plot one slice.

    NOTE: this shadows the built-in ``eval``; the name is kept because other
    cells call it.

    Args:
        unet1: segmentation model; its output is reduced with
            ``argmax(dim=1)`` to per-pixel class labels.
        loader: iterable of ``(scan, mask)`` batches with a size-1 channel
            axis at dim 1 (as produced by ``myDataset``). Defaults to
            ``val_loader``.
        plot_index: non-negative position, in loader iteration order, of the
            slice to plot. No plot is drawn if it is out of range.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if loader is None:
        loader = val_loader
    unet1.eval()
    model_device = next(unet1.parameters()).device

    targets, predictions = [], []
    plot_scan = None
    seen = 0
    with torch.no_grad():
        for x, y in loader:
            logits = unet1(x.to(model_device).float())
            predictions.append(logits.argmax(dim=1).cpu().numpy())
            targets.append(y.squeeze(1).cpu().numpy())
            # Keep only the scan that will be plotted. The sampler visits a
            # random subset, so position k here is NOT ct_scans[k].
            if plot_scan is None and seen <= plot_index < seen + len(x):
                plot_scan = x[plot_index - seen, 0].cpu().numpy()
            seen += len(x)

    if not predictions:
        raise ValueError("loader yielded no batches")

    y_all = np.concatenate(targets)
    y_pred_all = np.concatenate(predictions)

    give_metrics(y_all, y_pred_all)

    if plot_scan is not None:
        f, plot = plt.subplots(1, 3)
        plot[0].set_title("CT Scans")
        plot[0].imshow(plot_scan, cmap='gray')
        plot[1].set_title("Expert infection mask")
        plot[1].imshow(y_all[plot_index], cmap='gray')
        plot[2].set_title("Predicted Masks")
        plot[2].imshow(y_pred_all[plot_index], cmap='gray')
        plt.show()

eval(unet)

# Q2
Segmentation of CT slices reconstructed from limited-angle (4x and 8x angularly undersampled) sinograms

In [None]:
def eval_reconstruction(unet1, loader=None, factors=(4, 8), plot_index=5):
    """Segment CT slices reconstructed from angularly undersampled sinograms.

    For every slice from ``loader`` (default: ``val_loader``), a sinogram over
    180 angles (0..179 degrees, 1 degree apart) is computed with ``radon``.
    For each factor ``k`` in ``factors``, only every k-th angle is kept and
    the slice is reconstructed by filtered back-projection (``iradon``). The
    reconstruction is multiplied by the region where the ground-truth mask is
    non-zero and then segmented with ``unet1``.

    For each factor this prints the ``give_metrics`` report and the mean
    PSNR / SSIM of the reconstructions against the original slices. It then
    plots the reconstructions of slice ``plot_index`` (loader order) next to
    the original slice.

    Returns:
        dict mapping each factor to ``{"psnr": mean_psnr, "ssim": mean_ssim}``.

    Raises:
        ValueError: if ``loader`` yields no slices.
    """
    # skimage is already used by this notebook (radon/iradon); the image
    # quality metrics live in skimage.metrics.
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    if loader is None:
        loader = val_loader
    unet1.eval()
    model_device = next(unet1.parameters()).device
    full_theta = np.arange(180.0)  # identical to radon's default theta

    masks = []
    preds = {k: [] for k in factors}
    psnrs = {k: [] for k in factors}
    ssims = {k: [] for k in factors}
    plot_scan, plot_recons = None, {}

    slice_idx = 0
    for scan_batch, mask_batch in tqdm(loader):
        for scan_t, mask_t in zip(scan_batch, mask_batch):
            # Each item carries a size-1 channel axis; drop it.
            scan = scan_t.squeeze(0).cpu().numpy().astype(np.float64)
            mask_np = mask_t.squeeze(0).cpu().numpy()
            # NOTE(review): the model input is restricted with the ground-truth
            # mask, which leaks label information; confirm this is intended.
            organ_area = (mask_np > 0).astype(int)
            sinogram = radon(scan, theta=full_theta, circle=False, preserve_range=True)
            # Explicit data_range: raw CT intensities are not in the range
            # skimage assumes for float images. Guard against flat slices.
            data_range = float(scan.max() - scan.min()) or 1.0
            masks.append(mask_np)

            for k in factors:
                # Pass the kept angles explicitly. iradon's default spreads the
                # projections evenly over [0, 180), which only matches
                # 0, k, 2k, ... when k divides 180 (true for 4, not for 8).
                recon = iradon(sinogram[:, ::k], theta=full_theta[::k],
                               circle=False, preserve_range=True)
                model_input = torch.from_numpy(recon * organ_area).float()[None, None].to(model_device)
                with torch.no_grad():
                    logits = unet1(model_input)
                preds[k].append(logits.argmax(dim=1).squeeze(0).cpu().numpy())
                psnrs[k].append(peak_signal_noise_ratio(scan, recon, data_range=data_range))
                ssims[k].append(structural_similarity(scan, recon, data_range=data_range))
                if slice_idx == plot_index:
                    plot_recons[k] = recon
            if slice_idx == plot_index:
                plot_scan = scan
            slice_idx += 1

    if not masks:
        raise ValueError("loader yielded no slices")

    mask_stack = np.array(masks)
    summary = {}
    for k in factors:
        print(f"\n{k}x angular undersampling")
        give_metrics(mask_stack, np.array(preds[k]))
        summary[k] = {"psnr": float(np.mean(psnrs[k])), "ssim": float(np.mean(ssims[k]))}
        print(summary[k]["psnr"], summary[k]["ssim"])

    # PLOTTING THE RECONSTRUCTED SCANS (same slice for every factor)
    if plot_scan is not None:
        plt.rcParams["figure.figsize"] = (12, 12)
        f, axarr = plt.subplots(1, len(factors) + 1, squeeze=False)
        axarr = axarr[0]
        for ax, k in zip(axarr, factors):
            ax.set_title(f"{k}x Reconstruction")
            ax.imshow(plot_recons[k], cmap='gray')
        axarr[-1].set_title("CT Scans")
        axarr[-1].imshow(plot_scan, cmap='gray')
        f.tight_layout()
        plt.show()

    return summary

eval_reconstruction(unet)

# Q3: sensitivity of the trained model to scaling all of its weights by (1 + eta)

In [None]:
import random

# Pick one random slice for the qualitative plot. randrange keeps the index
# inside the loaded volume instead of assuming exactly 3555 slices.
x = random.randrange(ct.shape[0])
# Kept on the CPU here; moved to the model's device right before inference.
scan = torch.from_numpy(ct[x])
mask_ = torch.from_numpy(mask[x])

for eta in [-0.01, -0.001, 0.001, 0.01]:
    new_model = UNet(1, 3)
    # NOTE(review): no cell in this notebook writes this checkpoint; it must
    # come from an earlier run. map_location='cpu' lets a GPU-saved
    # checkpoint load on a CPU-only machine.
    new_model.load_state_dict(
        torch.load('unet_modelxapoo9.756179238078403.pt', map_location='cpu'))
    model_device = next(new_model.parameters()).device

    # Scale every learnable parameter by (1 + eta). BatchNorm running
    # statistics are buffers, not parameters, so they stay unchanged.
    with torch.no_grad():
        for param in new_model.parameters():
            param.mul_(1 + eta)

    print(f"eta = {eta}")
    eval(new_model)

    # Predict the randomly chosen slice and show it next to its expert mask.
    new_model.eval()
    with torch.no_grad():
        logits = new_model(scan.unsqueeze(0).unsqueeze(0).float().to(model_device))
    yt = logits.argmax(dim=1).squeeze(0).cpu().numpy()

    f, plot = plt.subplots(1, 3)
    plot[0].set_title("CT Scans")
    plot[0].imshow(ct_scans[x], cmap='gray')
    plot[1].set_title("Expert infection mask")
    plot[1].imshow(mask[x], cmap='gray')
    plot[2].set_title("Predicted Masks")
    plot[2].imshow(yt, cmap='gray')
    plt.show()
