In [1]:
!pip install monai

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
#This is so that the plots are there along with the code in the notebook rather than a popup
%matplotlib inline

import os
from os.path import isfile,join
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import torchvision.transforms as T
from torch.nn.modules import padding
from scipy.ndimage import distance_transform_edt
from monai.metrics import HausdorffDistanceMetric,SurfaceDistanceMetric
import math

In [3]:
# For reproducing results
seed = 58
#random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Getting Data

In [4]:
# GPU | CPU
def get_default_device():
    
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    else:
        return torch.device('cpu')

def to_device(data,device):
    
    if isinstance(data,(list,tuple)):
        return [to_device(x,device) for x in data]
    
    return data.to(device,non_blocking = True)

In [5]:
''' ----------------------------------------------------------
    All the content here is intermediate data used to obtain df
    None of these will be referenced later in the notebook 
    ---------------------------------------------------------- '''

base_path = "../processed/"

file_names = {"image": "img_crp_v2.npy", 
              "esophagus": "/structure/Esophagus_crp_v2.npy",
              "heart": "/structure/Heart_crp_v2.npy",
              "lung_L": "/structure/Lung_L_crp_v2.npy",
              "lung_R": "/structure/Lung_R_crp_v2.npy",
              "spinal_cord": "/structure/SpinalCord_crp_v2.npy"} 

dirs = [(base_path + f + "/") for f in os.listdir(base_path) if not isfile(join(base_path,f))]

data = [[f + file_names["image"],
         f + file_names["esophagus"],
         f + file_names["heart"],
         f + file_names["lung_L"],
         f + file_names["lung_R"],
         f + file_names["spinal_cord"]] for f in dirs]


In [6]:
#df contains all data regarding input data
df = pd.DataFrame(data, columns = ['Image','Esophagus','Heart','Lung_L','Lung_R','SpinalCord'])
df.head()

Unnamed: 0,Image,Esophagus,Heart,Lung_L,Lung_R,SpinalCord
0,processed/LCTSC-Train-S3-005/img_crp_v2.npy,processed/LCTSC-Train-S3-005//structure/Esopha...,processed/LCTSC-Train-S3-005//structure/Heart_...,processed/LCTSC-Train-S3-005//structure/Lung_L...,processed/LCTSC-Train-S3-005//structure/Lung_R...,processed/LCTSC-Train-S3-005//structure/Spinal...
1,processed/LCTSC-Train-S1-004/img_crp_v2.npy,processed/LCTSC-Train-S1-004//structure/Esopha...,processed/LCTSC-Train-S1-004//structure/Heart_...,processed/LCTSC-Train-S1-004//structure/Lung_L...,processed/LCTSC-Train-S1-004//structure/Lung_R...,processed/LCTSC-Train-S1-004//structure/Spinal...
2,processed/LCTSC-Train-S2-004/img_crp_v2.npy,processed/LCTSC-Train-S2-004//structure/Esopha...,processed/LCTSC-Train-S2-004//structure/Heart_...,processed/LCTSC-Train-S2-004//structure/Lung_L...,processed/LCTSC-Train-S2-004//structure/Lung_R...,processed/LCTSC-Train-S2-004//structure/Spinal...
3,processed/LCTSC-Train-S3-008/img_crp_v2.npy,processed/LCTSC-Train-S3-008//structure/Esopha...,processed/LCTSC-Train-S3-008//structure/Heart_...,processed/LCTSC-Train-S3-008//structure/Lung_L...,processed/LCTSC-Train-S3-008//structure/Lung_R...,processed/LCTSC-Train-S3-008//structure/Spinal...
4,processed/LCTSC-Train-S3-012/img_crp_v2.npy,processed/LCTSC-Train-S3-012//structure/Esopha...,processed/LCTSC-Train-S3-012//structure/Heart_...,processed/LCTSC-Train-S3-012//structure/Lung_L...,processed/LCTSC-Train-S3-012//structure/Lung_R...,processed/LCTSC-Train-S3-012//structure/Spinal...


## Datagen

In [7]:
def Datagen_CT(df,img_size,seg_organ,window_size):

    for _, row in df.iterrows():

        image = np.load(row["Image"])
        image = np.moveaxis(image,0,-1)
        image = cv2.resize(image, (img_size,img_size))
        image = np.swapaxes(image,0,-1)

        mask = np.load(row[seg_organ])
        mask = np.moveaxis(mask,0,-1)
        mask = cv2.resize(mask , (img_size,img_size))
        mask = np.swapaxes(mask,0,-1)

        #Padding to fit window_size
        image = np.concatenate((np.zeros((int(window_size/2),img_size,img_size)),image))
        image = np.concatenate((image,np.zeros((int(window_size/2),img_size,img_size))))
        mask = np.concatenate((np.zeros((int(window_size/2),img_size,img_size)),mask))
        mask = np.concatenate((mask,np.zeros((int(window_size/2),img_size,img_size))))

        half_window = window_size//2
        slice_index = half_window + 1 #the index of the next slice to be examined

        while slice_index + window_size + half_window - 1 <= image.shape[0]:
            images = image[slice_index - half_window : slice_index + window_size + half_window,:,:]
            masks = mask[slice_index: slice_index + window_size,:,:]

            slice_index = slice_index + window_size
            images = np.expand_dims(images,axis=1).astype('float32')
            masks = np.expand_dims(masks,axis=1).astype('float32')

            yield images, masks


## Model

In [8]:
class AttentionHead(nn.Module):
    def __init__(self,dim_in,dim_q,dim_k):
        super(AttentionHead,self).__init__()
        self.q = nn.Linear(dim_in,dim_q)
        self.k = nn.Linear(dim_in,dim_k)
        self.v = nn.Linear(dim_in,dim_k)
        self.linear = nn.Linear(dim_k,dim_in)

    def forward(self,query,key,value):
        query = self.q(query)
        key = self.k(key)
        value = self.v(value)

        temp = query.matmul(key.transpose(0,1))
        scale = query.size(-1) ** 0.5 #feature dimension
        softmax = F.softmax(temp/scale, dim=-1)
        attn_mat = softmax.matmul(value)
        return self.linear(attn_mat) 

In [9]:
#Input size is (window_size, 1, 512, 512)
class UNet(nn.Module):

    def __init__(self):
        super(UNet,self).__init__()

        self.conv1 = nn.Conv2d(1,32,3,padding="same")
        self.conv2 = nn.Conv2d(32,32,3,padding="same")

        self.conv3 = nn.Conv2d(32,64,3,padding="same")
        self.conv4 = nn.Conv2d(64,64,3,padding="same")

        self.conv5 = nn.Conv2d(64,128,3,padding="same")
        self.conv6 = nn.Conv2d(128,128,3,padding="same")

        self.conv7 = nn.Conv2d(128,256,3,padding="same")
        self.conv8 = nn.Conv2d(256,256,3,padding="same")

        self.conv9 = nn.Conv2d(256,512,3,padding="same")
        self.conv10 = nn.Conv2d(512,512,3,padding="same")

        self.conv11 = nn.Conv2d(512,1024,3,padding="same")
        self.conv12 = nn.Conv2d(1024,1024,3,padding="same")

        #since last two maxpools are x4, the filter and strides of first two transpose convs are 4 each
        self.trans_conv1 = nn.ConvTranspose2d(1024,512,kernel_size=4,stride=4)
        self.conv_r10 = nn.Conv2d(1024,512,kernel_size=3,padding="same")
        self.conv_r9 = nn.Conv2d(512,512,kernel_size=3,padding="same")

        self.trans_conv2 = nn.ConvTranspose2d(512,256,kernel_size=4,stride=4)
        self.conv_r8 = nn.Conv2d(512,256,kernel_size=3,padding="same")
        self.conv_r7 = nn.Conv2d(256,256,kernel_size=3,padding="same")

        self.trans_conv3 = nn.ConvTranspose2d(256,128,kernel_size=2,stride=2)
        self.conv_r6 = nn.Conv2d(256,128,kernel_size=3,padding="same")
        self.conv_r5 = nn.Conv2d(128,128,kernel_size=3,padding="same")

        self.trans_conv4 = nn.ConvTranspose2d(128,64,kernel_size=2,stride=2)
        self.conv_r4 = nn.Conv2d(128,64,kernel_size=3,padding="same")
        self.conv_r3 = nn.Conv2d(64,64,kernel_size=3,padding="same")

        self.trans_conv5 = nn.ConvTranspose2d(64,32,kernel_size=2,stride=2)
        self.conv_r2 = nn.Conv2d(64,32,kernel_size=3,padding="same")
        self.conv_r1 = nn.Conv2d(32,32,kernel_size=3,padding="same")

        self.conv_f = nn.Conv2d(32,1,kernel_size=1)

        self.maxpool = nn.MaxPool2d((2,2))
        self.maxpool_x4 = nn.MaxPool2d((4,4))

        self.attn_head1 = AttentionHead(dim_in=16384,dim_q=8192,dim_k=8192)
        # self.attn_head2 = AttentionHead(dim_in=16384,dim_q=8192,dim_k=8192)
        # self.attn_head3 = AttentionHead(dim_in=16384,dim_q=8192,dim_k=8192)

  
    def attention_layer_calc(self,x_flatten,window_size):

        i = 0
        x_attn = None

        while i + window_size <= x_flatten.shape[0]:
            img_enc = x_flatten[i:i+window_size]
            pos_enc = np.zeros(img_enc.shape)
            emb_length = pos_enc.shape[1]

            #calculating the value of positional embedding
            for ind1 in range(pos_enc.shape[0]):
                for ind2 in range(pos_enc.shape[1]):

                    pos_enc_val = ind1/(10000**(2*ind2/emb_length))
                    if ind1%2 == 0:
                        pos_enc[ind1][ind2] = np.sin(pos_enc_val)
                    else:
                        pos_enc[ind1][ind2] = np.cos(pos_enc_val)
      
            pos_enc = to_device(torch.from_numpy(pos_enc.astype('float32')), get_default_device())
            img_enc = img_enc + pos_enc

            temp1 = self.attn_head1(img_enc,img_enc,img_enc)[window_size//2][None,:]
            # temp2 = self.attn_head2(temp1,temp1,temp1)
            # temp3 = self.attn_head3(temp2,temp2,temp2)[window_size//2][None,:]

            i = i + 1
            if x_attn is None:
                print("TEMP1: ",temp1.shape)
                x_attn = temp1
            else:
                x_attn = torch.cat((x_attn,temp1))
                print("x_attn: ",x_attn.shape)

        x_attn = x_attn.reshape([x_attn.shape[0],1024,4,4])
        return x_attn


    def forward(self,x,window_size): #note: the entire batch is passed in 

        x1 = self.conv1(x)
        x1 = F.relu(x1)
        x1 = self.conv2(x1)
        x1 = F.relu(x1)

        x2 = self.maxpool(x1)
        x2 = self.conv3(x2)
        x2 = F.relu(x2)
        x2 = self.conv4(x2)
        x2 = F.relu(x2)

        x3 = self.maxpool(x2)
        x3 = self.conv5(x3)
        x3 = F.relu(x3)
        x3 = self.conv6(x3)
        x3 = F.relu(x3)

        x4 = self.maxpool(x3)
        x4 = self.conv7(x4)
        x4 = F.relu(x4)
        x4 = self.conv8(x4)
        x4 = F.relu(x4)

        x5 = self.maxpool_x4(x4)
        x5 = self.conv9(x5)
        x5 = F.relu(x5)
        x5 = self.conv10(x5)
        x5 = F.relu(x5)

        x6 = self.maxpool_x4(x5)
        x6 = self.conv11(x6)
        x6 = F.relu(x6)
        x6 = self.conv12(x6)
        x6 = F.relu(x6) #Size: 4x4x1024

        #Attention
        x_flatten = torch.flatten(x6,start_dim=1)
        x_attn = self.attention_layer_calc(x_flatten,window_size)

        #expansion
        xi_5 = self.trans_conv1(x_attn)
        xi_5 = torch.cat((xi_5,x5[(window_size//2):-(window_size//2)]),dim=1) #concatenating to the channel dimension
        xi_5 = self.conv_r10(xi_5)
        xi_5 = F.relu(xi_5)
        xi_5 = self.conv_r9(xi_5)
        xi_5 = F.relu(xi_5)

        xi_4 = self.trans_conv2(xi_5)
        xi_4 = torch.cat((xi_4,x4[(window_size//2):-(window_size//2)]),dim=1) 
        xi_4 = self.conv_r8(xi_4)
        xi_4 = F.relu(xi_4)
        xi_4 = self.conv_r7(xi_4)
        xi_4 = F.relu(xi_4)

        xi_3 = self.trans_conv3(xi_4)
        xi_3 = torch.cat((xi_3,x3[(window_size//2):-(window_size//2)]), dim=1)
        xi_3 = self.conv_r6(xi_3)
        xi_3 = F.relu(xi_3)
        xi_3 = self.conv_r5(xi_3)
        xi_3 = F.relu(xi_3)

        xi_2 = self.trans_conv4(xi_3)
        xi_2 = torch.cat((xi_2,x2[(window_size//2):-(window_size//2)]), dim=1)
        xi_2 = self.conv_r4(xi_2)
        xi_2 = F.relu(xi_2)
        xi_2 = self.conv_r3(xi_2)
        xi_2 = F.relu(xi_2)

        xi_1 = self.trans_conv5(xi_2)
        xi_1 = torch.cat((xi_1,x1[(window_size//2):-(window_size//2)]), dim=1)
        xi_1 = self.conv_r2(xi_1)
        xi_1 = F.relu(xi_1)
        xi_1 = self.conv_r1(xi_1)
        xi_1 = F.relu(xi_1)

        xf = self.conv_f(xi_1)
        xf = torch.sigmoid(xf)


        return xf

## Loss Function

In [10]:
def dice_loss(y_pred,y_true,smooth=1):
    
    print("YPRED: ",y_pred.shape)
    print("YTRUE: ",y_true.shape)
    
    y_true_f = torch.flatten(y_true,start_dim=1)
    y_pred_f = torch.flatten(y_pred,start_dim=1)

    intersection = torch.sum(y_true_f * y_pred_f,1)
    result = torch.log((2. * intersection + smooth) / (torch.sum(y_true_f,1) + torch.sum(y_pred_f,1) + smooth))
    dice_loss = -1 * torch.sum(result)/y_pred.shape[0]
    return  dice_loss

## Eval Metrics

In [11]:
#IOU [Between 0-1: Higher value => Better results]
def iou_metric(y_pred,y_true,smooth=1):
    y_true_f = torch.flatten(y_true,start_dim=1)
    y_pred_f = torch.flatten(y_pred,start_dim=1)

    intersection = torch.sum(y_true_f * y_pred_f,1)
    union = torch.sum(y_true_f,1) + torch.sum(y_pred_f,1) - torch.sum(y_true_f * y_pred_f,1)
    iou_score = torch.sum((intersection + smooth)/(union + smooth)) / y_pred.shape[0]
    return  iou_score

#Dice [Between 0-1: Higher value => Better results]
def dice_metric(y_pred,y_true,smooth=1):
    y_true_f = torch.flatten(y_true,start_dim=1)
    y_pred_f = torch.flatten(y_pred,start_dim=1)

    intersection = torch.sum(y_true_f * y_pred_f,1)
    result = torch.sum((2. * intersection + smooth) / (torch.sum(y_true_f,1) + torch.sum(y_pred_f,1) + smooth))
    dice_score = result/y_pred.shape[0]
    return  dice_score

#MSD [Between 0-infinity: Smaller value => Better results]
def msd_metric(y_pred,y_true):
    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0

    msd = SurfaceDistanceMetric(include_background=True)
    val = torch.sum(msd(y_pred,y_true))

    if val.isnan():
    return to_device(torch.tensor(0),get_default_device())

    return val

#HD95 [Between 0-infinity: Smaller value => Better results]
def hd95_metric(y_pred,y_true):
    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0

    hd = HausdorffDistanceMetric(percentile=95,reduction="none",include_background=False)
    val = torch.sum(hd(y_pred,y_true))

    if val.isnan():
        return to_device(torch.tensor(0),get_default_device())

    return val

#Precision [Between 0-1: Larger value => Better results]
def precision_metric(y_pred,y_true):
    y_true = torch.flatten(y_true,start_dim=1)
    y_pred = torch.flatten(y_pred,start_dim=1)

    y_pred[y_pred > 0] = 1
    y_pred[y_pred <= 0] = 0

    tp = torch.sum(y_pred * y_true,1)
    fp = torch.sum(y_pred * (1 - y_true),1)

    if 0 in tp+fp:
    return to_device(torch.tensor(0),get_default_device())

    precision = torch.sum(tp/(tp+fp))/y_pred.shape[0]
    return precision

#Recall [Between 0-1: Larger value => Better results]
def recall_metric(y_pred,y_true):
    y_true = torch.flatten(y_true,start_dim=1)
    y_pred = torch.flatten(y_pred,start_dim=1)

    y_pred[y_pred > 0] = 1
    y_pred[y_pred <= 0] = 0

    tp = torch.sum(y_pred * y_true,1)
    fn = torch.sum((1-y_pred) * y_true,1)

    if 0 in tp+fn:
    return to_device(torch.tensor(0),get_default_device())

    recall = torch.sum(tp/(tp+fn))/y_pred.shape[0]
    return recall

## Eval Function

In [12]:
def eval(model,test_df,img_size,seg_organ,window_size,epoch=0,epochs=0):
    #epoch and epoch size is passed if this function is called during training in between epochs
    model.eval()
    test_dataloader = Datagen_CT(df=test_df,img_size=img_size,seg_organ=seg_organ,window_size=window_size)

    iou_val = 0
    dice_val = 0
    msd_val = 0
    hd95_val = 0
    pre_val = 0
    re_val = 0

    count = 0
    for batch_idx, (X,y) in enumerate(test_dataloader):
        X = to_device(torch.tensor(X, requires_grad=True),get_default_device())
        y = to_device(torch.tensor(y, requires_grad=True),get_default_device())

        output = model(X,window_size)
        iou_val += iou_metric(output,y).item()
        dice_val += dice_metric(output,y).item()
        # msd_val += msd_metric(output,y).item()
        # hd95_val += hd95_metric(output,y).item()
        pre_val += precision_metric(output,y).item()
        re_val += recall_metric(output,y).item()

        count = count + 1
        break #<<<<<-------------REMOVE THIS

    print('-------------------------')
    print('Epoch [{}/{}]   Dice Metric: {:.4f}  IoU Metric: {:.4f}  \n  Precision Metric: {:.4f}  Recall Metric: {:.4f}'.format(epoch+1,epochs,dice_val/count,iou_val/count,pre_val/count,re_val/count))
    print('\n')

## Fit Function

In [13]:
def fit(model,img_size,seg_organ,window_size,loss_fn,optimizer,scheduler,epochs,train_df,test_df,validate=False,print_every=1):


    for epoch in range(epochs):

        model.train()
        train_dataloader = Datagen_CT(df=train_df,img_size=img_size,seg_organ=seg_organ,window_size=window_size)
    
        for batch_idx, (X,y) in enumerate(train_dataloader):
            X = to_device(torch.tensor(X, requires_grad=True),get_default_device())
            y = to_device(torch.tensor(y, requires_grad=True),get_default_device())

            optimizer.zero_grad()
            output = model(X,window_size)
            loss = loss_fn(output,y)
            loss.backward()
            optimizer.step()

            if batch_idx % print_every == 0:
                print('Epoch [{}/{}]   Batch {}   Loss: {:.4f}    Max: {:.4f}'.format(epoch+1,epochs,batch_idx,loss.item(),torch.max(output)))

        

                plt.subplot(1,3,1)
                plt.imshow(X.cpu().detach().numpy()[window_size//2][0])

                plt.subplot(1,3,2)
                plt.imshow(y.cpu().detach().numpy()[0][0])

                plt.subplot(1,3,3)
                plt.imshow(output.cpu().detach().numpy()[0][0])

                plt.show()
            
            break #<<<<<-------------REMOVE THIS


        print('Epoch [{}/{}]   Batch {}   Loss: {:.4f}    lr: {:.10f}'.format(epoch+1,epochs,batch_idx,loss.item(),optimizer.param_groups[0]['lr']))
        scheduler.step()

        if validate:
            eval(model=model,test_df=test_df,img_size=img_size,seg_organ=seg_organ,window_size=window_size,epoch=epoch,epochs=epochs)


## Training

In [14]:
#Parameters
img_size = 512
epochs = 10
seg_organ = 'Lung_L'
train_test_split = 0.8
window_size = 5

#Model Specific
train_df = df[:int(train_test_split*df.shape[0])].reset_index(drop=True)
test_df = df[int(train_test_split*df.shape[0]):].reset_index(drop=True)

model = UNet()
model = to_device(model,get_default_device())
loss_fn = dice_loss
optimizer = torch.optim.Adam(model.parameters(),lr=0.000005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,6,10], gamma=0.1)

In [None]:
fit(model=model, img_size=img_size, seg_organ=seg_organ, window_size=window_size,
    loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
    epochs=epochs, train_df=train_df, test_df=test_df, 
    validate=True, print_every=1)

In [None]:
torch.save(model.state_dict(), 'lung_l')