In [1]:
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 1.1 MB/s eta 0:00:01
Collecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.4-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 329 kB/s eta 0:00:01
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 13.4 MB/s eta 0:00:01
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.4.0 tokenizers-0.11.4 transformers-4.16.2


In [2]:
#This is so that the plots are there along with the code in the notebook rather than a popup
%matplotlib inline

import os
from os.path import isfile,join
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import torchvision.transforms as T
from torch.nn.modules import padding
import math
import sys
from transformers import ViTModel, ViTConfig, ViTFeatureExtractor

## Starting Stuff

In [3]:
# For reproducing results
seed = 48
#random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# GPU | CPU
def get_default_device():
    
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    else:
        return torch.device('cpu')

def to_device(data,device):
    
    if isinstance(data,(list,tuple)):
        return [to_device(x,device) for x in data]
    
    return data.to(device,non_blocking = True)

## Getting Data (Verified)

In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

In [6]:
''' ----------------------------------------------------------
    All the content here is intermediate data used to obtain df
    None of these will be referenced later in the notebook 
    ---------------------------------------------------------- '''

base_path = "../processed/"
#"../processed/"
#"/content/drive/MyDrive/AI Club Project - Segmentation/processed/"

file_names = {"image": "img_crp_v2.npy", 
              "esophagus": "structure/Esophagus_crp_v2.npy",
              "heart": "structure/Heart_crp_v2.npy",
              "lung_L": "structure/Lung_L_crp_v2.npy",
              "lung_R": "structure/Lung_R_crp_v2.npy",
              "spinal_cord": "structure/SpinalCord_crp_v2.npy"} 

train_dirs = [(base_path + f + "/") for f in os.listdir(base_path) if not isfile(join(base_path,f)) and 'Test' not in f]

val_dirs = [(base_path + f + "/") for f in ['LCTSC-Test-S1-104','LCTSC-Test-S1-101','LCTSC-Test-S2-103',
             'LCTSC-Test-S1-102','LCTSC-Test-S3-104','LCTSC-Test-S3-103',
             'LCTSC-Test-S1-103','LCTSC-Test-S2-102','LCTSC-Test-S3-101',
             'LCTSC-Test-S3-102','LCTSC-Test-S2-104','LCTSC-Test-S2-101']]

test_dirs = [(base_path + f + "/") for f in ['LCTSC-Test-S1-204','LCTSC-Test-S1-202','LCTSC-Test-S3-202',
             'LCTSC-Test-S2-204','LCTSC-Test-S2-202','LCTSC-Test-S2-201',
             'LCTSC-Test-S2-203','LCTSC-Test-S1-203','LCTSC-Test-S3-201',
             'LCTSC-Test-S3-203','LCTSC-Test-S3-204','LCTSC-Test-S1-201']]


data_train = [[f + file_names["image"],
         f + file_names["esophagus"],
         f + file_names["heart"],
         f + file_names["lung_L"],
         f + file_names["lung_R"],
         f + file_names["spinal_cord"]] for f in train_dirs]

data_val = [[f + file_names["image"],
         f + file_names["esophagus"],
         f + file_names["heart"],
         f + file_names["lung_L"],
         f + file_names["lung_R"],
         f + file_names["spinal_cord"]] for f in val_dirs]

# for arr in data_val:
#     data_train.append(arr)

data_test = [[f + file_names["image"],
         f + file_names["esophagus"],
         f + file_names["heart"],
         f + file_names["lung_L"],
         f + file_names["lung_R"],
         f + file_names["spinal_cord"]] for f in test_dirs]


In [7]:
#df contains all data regarding input data
df_train = pd.DataFrame(data_train, columns = ['Image','Esophagus','Heart','Lung_L','Lung_R','SpinalCord'])
df_val = pd.DataFrame(data_val, columns = ['Image','Esophagus','Heart','Lung_L','Lung_R','SpinalCord'])
df_test = pd.DataFrame(data_test, columns = ['Image','Esophagus','Heart','Lung_L','Lung_R','SpinalCord'])
df_train.head()

Unnamed: 0,Image,Esophagus,Heart,Lung_L,Lung_R,SpinalCord
0,../processed/LCTSC-Train-S3-005/img_crp_v2.npy,../processed/LCTSC-Train-S3-005/structure/Esop...,../processed/LCTSC-Train-S3-005/structure/Hear...,../processed/LCTSC-Train-S3-005/structure/Lung...,../processed/LCTSC-Train-S3-005/structure/Lung...,../processed/LCTSC-Train-S3-005/structure/Spin...
1,../processed/LCTSC-Train-S1-004/img_crp_v2.npy,../processed/LCTSC-Train-S1-004/structure/Esop...,../processed/LCTSC-Train-S1-004/structure/Hear...,../processed/LCTSC-Train-S1-004/structure/Lung...,../processed/LCTSC-Train-S1-004/structure/Lung...,../processed/LCTSC-Train-S1-004/structure/Spin...
2,../processed/LCTSC-Train-S2-004/img_crp_v2.npy,../processed/LCTSC-Train-S2-004/structure/Esop...,../processed/LCTSC-Train-S2-004/structure/Hear...,../processed/LCTSC-Train-S2-004/structure/Lung...,../processed/LCTSC-Train-S2-004/structure/Lung...,../processed/LCTSC-Train-S2-004/structure/Spin...
3,../processed/LCTSC-Train-S3-008/img_crp_v2.npy,../processed/LCTSC-Train-S3-008/structure/Esop...,../processed/LCTSC-Train-S3-008/structure/Hear...,../processed/LCTSC-Train-S3-008/structure/Lung...,../processed/LCTSC-Train-S3-008/structure/Lung...,../processed/LCTSC-Train-S3-008/structure/Spin...
4,../processed/LCTSC-Train-S3-012/img_crp_v2.npy,../processed/LCTSC-Train-S3-012/structure/Esop...,../processed/LCTSC-Train-S3-012/structure/Hear...,../processed/LCTSC-Train-S3-012/structure/Lung...,../processed/LCTSC-Train-S3-012/structure/Lung...,../processed/LCTSC-Train-S3-012/structure/Spin...


## Datagen

In [8]:
def Datagen_CT(df,img_size,seg_organ,batch_size):

    for _, row in df.iterrows():

        image = np.load(row["Image"])
        image = np.moveaxis(image,0,-1)
        image = cv2.resize(image, (img_size,img_size))
        image = np.swapaxes(image,0,-1)

        mask = np.load(row[seg_organ])
        mask = np.moveaxis(mask,0,-1)
        mask = cv2.resize(mask , (img_size,img_size))
        mask = np.swapaxes(mask,0,-1)
        
        slice_index = 0 #the index of the next slice to be examined

        while slice_index < image.shape[0]:#we put mask here because image has a longer size to adjust for the last slice
            images = image[slice_index: slice_index + batch_size,:,:]
            masks = mask[slice_index: slice_index + batch_size,:,:]

            slice_index = slice_index + batch_size
            images = np.expand_dims(images,axis=1).astype('float32')
            masks = np.expand_dims(masks,axis=1).astype('float32')

            yield images, masks


## Model

In [27]:
class ViT_Encoder(nn.Module):
    
    def __init__(self, init_weight_path=None, image_size=224, patch_size=16, num_channels=3, num_attention_heads = 12, num_hidden_layers=12, hidden_size=768):
        super(ViT_Encoder,self).__init__()

        
        
        self.config = ViTConfig(hidden_size = hidden_size, num_hidden_layers = num_hidden_layers, 
                                num_attention_heads = num_attention_heads, image_size = image_size, 
                                patch_size = patch_size, num_channels = num_channels)

        self.model = ViTModel(self.config)
        self.model.load_state_dict(torch.load(init_weight_path))
        
        self.linear = [nn.Linear(hidden_size,patch_size*patch_size) for _ in range(0,12)]
        self.output_size = image_size//patch_size
  
    def forward(self,x):
        x = self.model(pixel_values=x,output_hidden_states=True,output_attentions = True)
        attn_out = [self.linear[i](x.hidden_states[i][:,1:,:]).permute(0,2,1) for i in range(0,12)]
        attn_out = [out.reshape((out.shape[0],out.shape[1],self.output_size,self.output_size)) for out in attn_out]   
        return attn_out

In [28]:
model = ViT_Encoder(init_weight_path='init_weight/ViT_original_pretrained')

In [29]:
a = torch.randn(5,3,224,224)
out = model(a)

In [31]:
out[0].shape

torch.Size([5, 256, 14, 14])

In [None]:
class EncoderBlock(nn.Module):
    def __init__(self,dim_i,dim_o,up_size):
        
        self.trans_conv1 = nn.ConvTranspose2d(dim_i,dim_o,kernel_size=up_size,stride=up_size)
        self.batchnorm1 = nn.BatchNorm2d(dim_o)
        
        self.trans_conv2 = nn.ConvTranspose2d(dim_i,dim_o,kernel_size=up_size,stride=up_size)
        self.batchnorm2 = nn.BatchNorm2d(dim_o)
    
    def forward(self,x1,x2,img):
        
        x1 = self.trans_conv1(x1)
        x1 = self.batchnorm1(x1)
        x1 = F.leaky_relu(x1)
        
        feat_map = img * x1
        
    
    

In [10]:
class DecoderBlock(nn.Module):

    def __init__(self,dim_i,dim_o,up_size):
        super(DecoderBlock,self).__init__()

        self.trans_conv = nn.ConvTranspose2d(dim_i,dim_o,kernel_size=up_size,stride=up_size)
        self.conv_2 = nn.Conv2d(dim_i,dim_o,kernel_size=3,padding="same")
        self.conv_3 = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")

        self.batchnorm_2 = nn.BatchNorm2d(dim_o)
        self.batchnorm_3 = nn.BatchNorm2d(dim_o)
  
    def forward(self,x,x_skip):
        x1 = self.trans_conv(x)
        x1 = torch.cat([x1,x_skip],dim=1)

        x2 = self.conv_2(x1)
        x2 = self.batchnorm_2(x2)
        x2 = F.leaky_relu(x2)

        x3 = self.conv_3(x2)
        x3 = self.batchnorm_3(x3)
        x3 = F.leaky_relu(x3)
        return x3
    
class SC1(nn.Module):

    def __init__(self,dim_i,dim_o):
        super(SC1,self).__init__()

        self.trans_conv1 = nn.ConvTranspose2d(dim_i,dim_i//2,kernel_size=4,stride=4)
        self.conv_2 = nn.Conv2d(dim_i//2,dim_i//2,kernel_size=5,padding="same")
        
        self.trans_conv2 = nn.ConvTranspose2d(dim_i//2,dim_o,kernel_size=4,stride=4)
        self.conv_3 = nn.Conv2d(dim_o,dim_o,kernel_size=5,padding="same")

        self.batchnorm_2 = nn.BatchNorm2d(dim_i//2)
        self.batchnorm_3 = nn.BatchNorm2d(dim_o)
        
        self.trans_conv_dec = nn.ConvTranspose2d(dim_o*2,dim_o,kernel_size=2,stride=2)
        self.batchnorm_dec = nn.BatchNorm2d(dim_o)
        
        self.convf = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")
        self.batchnorm_f = nn.BatchNorm2d(dim_o)
  
    def forward(self,x,x_dec):
        x1 = self.trans_conv1(x)

        x2 = self.conv_2(x1)
        x2 = self.batchnorm_2(x2)
        x2 = F.leaky_relu(x2)
        
        x2 = self.trans_conv2(x2)

        x3 = self.conv_3(x2)
        x3 = self.batchnorm_3(x3)
        x3 = F.leaky_relu(x3)
        
        x_dec = self.trans_conv_dec(x_dec)
        x_dec = self.batchnorm_dec(x_dec)
        x_dec = F.leaky_relu(x_dec)
        
        xf = x3 * x_dec
        xf = self.convf(xf)
        xf = self.batchnorm_f(xf)
        xf = F.leaky_relu(xf)
        return xf
    
class SC2(nn.Module):

    def __init__(self,dim_i,dim_o):
        super(SC2,self).__init__()

        self.trans_conv1 = nn.ConvTranspose2d(dim_i,dim_i//2,kernel_size=4,stride=4)
        self.conv_2 = nn.Conv2d(dim_i//2,dim_i//2,kernel_size=5,padding="same")
        
        self.trans_conv2 = nn.ConvTranspose2d(dim_i//2,dim_o,kernel_size=2,stride=2)
        self.conv_3 = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")

        self.batchnorm_2 = nn.BatchNorm2d(dim_i//2)
        self.batchnorm_3 = nn.BatchNorm2d(dim_o)
        
        self.trans_conv_dec = nn.ConvTranspose2d(dim_o*2,dim_o,kernel_size=2,stride=2)
        self.batchnorm_dec = nn.BatchNorm2d(dim_o)
  
        self.convf = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")
        self.batchnorm_f = nn.BatchNorm2d(dim_o)
  
    def forward(self,x,x_dec):
        x1 = self.trans_conv1(x)

        x2 = self.conv_2(x1)
        x2 = self.batchnorm_2(x2)
        x2 = F.leaky_relu(x2)
        
        x2 = self.trans_conv2(x2)

        x3 = self.conv_3(x2)
        x3 = self.batchnorm_3(x3)
        x3 = F.leaky_relu(x3)
        
        x_dec = self.trans_conv_dec(x_dec)
        x_dec = self.batchnorm_dec(x_dec)
        x_dec = F.leaky_relu(x_dec)
        
        xf = x3 * x_dec
        xf = self.convf(xf)
        xf = self.batchnorm_f(xf)
        xf = F.leaky_relu(xf)
        return xf

class SC3(nn.Module):

    def __init__(self,dim_i,dim_o):
        super(SC3,self).__init__()

        self.trans_conv1 = nn.ConvTranspose2d(dim_i,dim_i//2,kernel_size=2,stride=2)
        self.conv_2 = nn.Conv2d(dim_i//2,dim_i//2,kernel_size=3,padding="same")
        
        self.trans_conv2 = nn.ConvTranspose2d(dim_i//2,dim_o,kernel_size=2,stride=2)
        self.conv_3 = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")

        self.batchnorm_2 = nn.BatchNorm2d(dim_i//2)
        self.batchnorm_3 = nn.BatchNorm2d(dim_o)
        
        self.trans_conv_dec = nn.ConvTranspose2d(dim_o*2,dim_o,kernel_size=2,stride=2)
        self.batchnorm_dec = nn.BatchNorm2d(dim_o)
        
        self.convf = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")
        self.batchnorm_f = nn.BatchNorm2d(dim_o)
  
    def forward(self,x,x_dec):
        x1 = self.trans_conv1(x)

        x2 = self.conv_2(x1)
        x2 = self.batchnorm_2(x2)
        x2 = F.leaky_relu(x2)
        
        x2 = self.trans_conv2(x2)

        x3 = self.conv_3(x2)
        x3 = self.batchnorm_3(x3)
        x3 = F.leaky_relu(x3)
        
        x_dec = self.trans_conv_dec(x_dec)
        x_dec = self.batchnorm_dec(x_dec)
        x_dec = F.leaky_relu(x_dec)
        
        xf = x3 * x_dec
        xf = self.convf(xf)
        xf = self.batchnorm_f(xf)
        xf = F.leaky_relu(xf)
        return xf
    
class SC4(nn.Module):

    def __init__(self,dim_i,dim_o):
        super(SC4,self).__init__()

        self.trans_conv1 = nn.ConvTranspose2d(dim_i,dim_o,kernel_size=2,stride=2)
        self.conv_2 = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")
        self.conv_3 = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")

        self.batchnorm_2 = nn.BatchNorm2d(dim_o)
        self.batchnorm_3 = nn.BatchNorm2d(dim_o)
        
        self.trans_conv_dec = nn.ConvTranspose2d(dim_o*2,dim_o,kernel_size=2,stride=2)
        self.batchnorm_dec = nn.BatchNorm2d(dim_o)
        
        self.convf = nn.Conv2d(dim_o,dim_o,kernel_size=3,padding="same")
        self.batchnorm_f = nn.BatchNorm2d(dim_o)
  
    def forward(self,x,x_dec):
        x1 = self.trans_conv1(x)

        x2 = self.conv_2(x1)
        x2 = self.batchnorm_2(x2)
        x2 = F.leaky_relu(x2)

        x3 = self.conv_3(x2)
        x3 = self.batchnorm_3(x3)
        x3 = F.leaky_relu(x3)
        
        x_dec = self.trans_conv_dec(x_dec)
        x_dec = self.batchnorm_dec(x_dec)
        x_dec = F.leaky_relu(x_dec)
        
        xf = x3 * x_dec
        xf = self.convf(xf)
        xf = self.batchnorm_f(xf)
        xf = F.leaky_relu(xf)
        return xf

class Decoder(nn.Module):
    
    def __init__(self):
        super(Decoder,self).__init__()

        #since last two maxpools are x4, the filter and strides of first two transpose convs are 4 each
        self.dec1 = DecoderBlock(256,128,2) #14->28
        self.dec2 = DecoderBlock(128,64,2) #28 -> 56
        self.dec3 = DecoderBlock(64,32,2) #56 -> 112
        self.dec4 = DecoderBlock(32,16,2) #112 -> 224
        
        self.sc1 = SC1(256,16)
        self.sc2 = SC2(256,32)
        self.sc3 = SC3(256,64)
        self.sc4 = SC4(256,128)
  
    def forward(self,x,x4,x3,x2,x1): 
        x4 = self.sc4(x4,x)
        xi = self.dec1(x,x4)
        
        x3 = self.sc3(x3,xi)
        xi = self.dec2(xi,x3)
        
        x2 = self.sc2(x2,xi)
        xi = self.dec3(xi,x2)
        
        x1 = self.sc1(x1,xi)
        xi = self.dec4(xi,x1)
        return xi

In [11]:
class SegModel(nn.Module):

    def __init__(self,encoder,decoder):
        super(SegModel,self).__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.conv1 = nn.Conv2d(256,256,kernel_size=3,stride=1,padding="same")
        self.conv2 = nn.Conv2d(256,256,kernel_size=3,stride=1,padding="same")

        self.batchnorm_1 = nn.BatchNorm2d(256)
        self.batchnorm_2 = nn.BatchNorm2d(256)

        self.convf = nn.Conv2d(16,1,kernel_size=1,stride=1)

    def forward(self,x):

        x1,x2,x3,xf = self.encoder(x)

        xm = self.conv1(xf)
        xm = self.batchnorm_1(xm)
        xm = F.leaky_relu(xm)

        xm = self.conv2(xm)
        xm = self.batchnorm_2(xm)
        xm = F.leaky_relu(xm)

        x_dec = self.decoder(xm,xf,x3,x2,x1)

        xf = self.convf(x_dec)
        xf = torch.sigmoid(xf)
        return xf

## Loss Function

In [12]:
def dice_loss(y_pred,y_true,smooth=1):
    y_true_f = torch.flatten(y_true,start_dim=1)
    y_pred_f = torch.flatten(y_pred,start_dim=1)
    
    intersection = torch.sum(y_true_f * y_pred_f,1)
    result = (2. * intersection + smooth) / (torch.sum(y_true_f,1) + torch.sum(y_pred_f,1) + smooth)
    dice_loss = torch.sum(1-result)/y_pred.shape[0]
    return  dice_loss

## Eval Metrics

In [13]:
#IOU [Between 0-1: Higher value => Better results]
def iou_metric(y_pred,y_true):
    y_true_f = torch.flatten(y_true,start_dim=1).detach().cpu()
    y_pred_f = torch.flatten(y_pred,start_dim=1).detach().cpu()
    
    y_pred_f[y_pred_f >= 0.5] = 1
    y_pred_f[y_pred_f < 0.5] = 0

    intersection = torch.sum(y_true_f * y_pred_f,1)
    union = torch.sum(y_true_f,1) + torch.sum(y_pred_f,1) - torch.sum(y_true_f * y_pred_f,1)
    
    iou_score = 0
    for i in range(union.shape[0]):
        if union[i] == 0:
            iou_score += 1
        else:
            iou_score += intersection[i].item()/union[i].item()
    iou_score = iou_score/y_pred.shape[0]
    return iou_score

#Dice [Between 0-1: Higher value => Better results]
def dice_metric(y_pred,y_true):
    y_true_f = torch.flatten(y_true,start_dim=1).detach().cpu()
    y_pred_f = torch.flatten(y_pred,start_dim=1).detach().cpu()
    
    y_pred_f[y_pred_f >= 0.5] = 1
    y_pred_f[y_pred_f < 0.5] = 0

    intersection = torch.sum(y_true_f * y_pred_f,1)
    area_sum = torch.sum(y_true_f,1) + torch.sum(y_pred_f,1)
    
    dice_score = 0
    for i in range(area_sum.shape[0]):
        if area_sum[i] == 0:
            dice_score += 1
        else:
            dice_score += 2. * intersection[i].item() / area_sum[i].item()
    dice_score = dice_score/y_pred.shape[0]
    return  dice_score

#Precision [Between 0-1: Larger value => Better results]
def precision_metric(y_pred,y_true):
    y_true = torch.flatten(y_true,start_dim=1).detach().cpu()
    y_pred = torch.flatten(y_pred,start_dim=1).detach().cpu()

    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0

    tp = torch.sum(y_pred * y_true,1)
    denom = torch.sum(y_pred,1)
    true_sum = torch.sum(y_true,1)
    
    pre = 0
    for i in range(denom.shape[0]):
        if denom[i] == 0 and true_sum[i] == 0:
            pre += 1
        elif denom[i] == 0:
            continue
        else:
            pre += tp[i].item()/denom[i].item()
    pre_score = pre/y_pred.shape[0]
    return pre_score

#Recall [Between 0-1: Larger value => Better results]
def recall_metric(y_pred,y_true):
    y_true = torch.flatten(y_true,start_dim=1).detach().cpu()
    y_pred = torch.flatten(y_pred,start_dim=1).detach().cpu()

    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0

    tp = torch.sum(y_pred * y_true,1)
    denom = torch.sum(y_true,1)
    
    re = 0
    for i in range(denom.shape[0]):
        if denom[i] == 0:
            re += 1
        else:
            re += tp[i].item()/denom[i].item()
    re_score = re/y_pred.shape[0]
    return re_score

## Eval Function

In [14]:
def eval(model,val_df,img_size,seg_organ,batch_size,epoch=0,epochs=0):
    #epoch and epoch size is passed if this function is called during training in between epochs
    model.eval()
    test_dataloader = Datagen_CT(df=val_df,img_size=img_size,seg_organ=seg_organ,batch_size=batch_size)

    iou_val = 0
    dice_val = 0
    pre_val = 0
    re_val = 0

    count = 0
    for batch_idx, (X,y) in enumerate(test_dataloader):
        X = to_device(torch.tensor(np.repeat(X,repeats=3,axis=1), requires_grad=True),get_default_device())
        y = to_device(torch.tensor(y, requires_grad=True),get_default_device())

        output = model(X)
        iou_val += iou_metric(output,y)
        dice_val += dice_metric(output,y)
        pre_val += precision_metric(output,y)
        re_val += recall_metric(output,y)

        count = count + 1
        
    dice_score = dice_val/count
    iou_score = iou_val/count
    pre_score = pre_val/count
    re_score = re_val/count
    
    
    
    print('End of Epoch [{}/{}]   Dice Metric: {:.4f}  IoU Metric: {:.4f}  Precision Metric: {:.4f}  Recall Metric: {:.4f}'.format(epoch+1,epochs,dice_score,iou_score,pre_score,re_score))

## Fit Function

In [15]:
def fit(model,img_size,seg_organ,batch_size,loss_fn,optimizer,scheduler,epochs,train_df,val_df,validate=False,print_every=1,print_epoch=1):


    for epoch in range(epochs):

        model.train()
        train_dataloader = Datagen_CT(df=train_df,img_size=img_size,seg_organ=seg_organ,batch_size=batch_size)
        
        total_loss = 0
        max_out = 0
        for batch_idx, (X,y) in enumerate(train_dataloader):
            X = to_device(torch.tensor(np.repeat(X,repeats=3,axis=1), requires_grad=True),get_default_device())
            y = to_device(torch.tensor(y, requires_grad=True),get_default_device())

            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output,y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            max_out = max(max_out,torch.max(output))

            if batch_idx != 0 and batch_idx % print_every == 0:
                print('Epoch [{}/{}]   Batch {}   Loss: {:.4f}    Max: {:.4f}'.format(epoch+1,epochs,batch_idx,total_loss/print_every,max_out))
                total_loss = 0
                max_out = 0

                if epoch % print_epoch == 0 :
                    plt.subplot(1,3,1)
                    plt.imshow(X.cpu().detach().numpy()[0][0])

                    plt.subplot(1,3,2)
                    plt.imshow(y.cpu().detach().numpy()[0][0])

                    plt.subplot(1,3,3)
                    plt.imshow(output.cpu().detach().numpy()[0][0])

                    plt.show()

        scheduler.step()

        if validate:
            eval(model=model,val_df=val_df,img_size=img_size,seg_organ=seg_organ,batch_size=batch_size,epoch=epoch,epochs=epochs)

## SingleFit

In [16]:
def single_fit(model,img_size,seg_organ,batch_size,loss_fn,optimizer,scheduler,epochs,img,mask):

    for epoch in range(epochs):

        model.train()


        total_loss = 0
        max_out = 0

        X = to_device(torch.tensor(np.repeat(img,repeats=3,axis=1), requires_grad=True),get_default_device())
        y = to_device(torch.tensor(mask, requires_grad=True),get_default_device())

        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output,y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        max_out = max(max_out,torch.max(output))

    
        print('Epoch [{}/{}]   Loss: {:.4f}    Max: {:.4f}'.format(epoch+1,epochs,total_loss,max_out))

        plt.subplot(1,3,1)
        plt.imshow(X.cpu().detach().numpy()[0][0])

        plt.subplot(1,3,2)
        plt.imshow(y.cpu().detach().numpy()[0][0])

        plt.subplot(1,3,3)
        plt.imshow(output.cpu().detach().numpy()[0][0])

        plt.show()
        scheduler.step()

## Training

In [21]:
#Parameters
img_size = 224
epochs = 100
seg_organ = 'Esophagus'
batch_size = 5

In [22]:
#Model Specific
enc = ViT_Encoder(init_weight_path='init_weight/ViT_original_pretrained')
dec = Decoder()
model = SegModel(enc,dec)
model = to_device(model,get_default_device())

In [23]:
#Training Specific
loss_fn = dice_loss

optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5)

In [None]:
fit(model=model, img_size=img_size, seg_organ=seg_organ, batch_size=batch_size,
    loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
    epochs=epochs, train_df=df_train, val_df=df_val, 
    validate=True, print_every=2000, print_epoch=3)

End of Epoch [1/100]   Dice Metric: 0.0588  IoU Metric: 0.0311  Precision Metric: 0.0311  Recall Metric: 0.9973
End of Epoch [2/100]   Dice Metric: 0.3330  IoU Metric: 0.2901  Precision Metric: 0.3133  Recall Metric: 0.4707
End of Epoch [3/100]   Dice Metric: 0.3332  IoU Metric: 0.2992  Precision Metric: 0.3099  Recall Metric: 0.4963
End of Epoch [4/100]   Dice Metric: 0.3109  IoU Metric: 0.2927  Precision Metric: 0.3089  Recall Metric: 0.3411
End of Epoch [5/100]   Dice Metric: 0.4817  IoU Metric: 0.4088  Precision Metric: 0.4950  Recall Metric: 0.5675
End of Epoch [6/100]   Dice Metric: 0.4386  IoU Metric: 0.3714  Precision Metric: 0.4430  Recall Metric: 0.5576
End of Epoch [7/100]   Dice Metric: 0.4769  IoU Metric: 0.3951  Precision Metric: 0.4746  Recall Metric: 0.6285
End of Epoch [8/100]   Dice Metric: 0.4798  IoU Metric: 0.4058  Precision Metric: 0.4849  Recall Metric: 0.5840
End of Epoch [9/100]   Dice Metric: 0.4321  IoU Metric: 0.3728  Precision Metric: 0.4590  Recall Metric:

In [None]:
# optimizer = torch.optim.Adam(model.parameters(),lr=1e-5 * 5)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.1)

# fit(model=model, img_size=img_size, seg_organ=seg_organ, batch_size=batch_size,
#     loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
#     epochs=epochs, train_df=df_train, val_df=df_val, 
#     validate=True, print_every=2000, print_epoch=3)

In [None]:
# d = Datagen_CT(df=df_train,img_size=img_size,seg_organ=seg_organ,batch_size=1)
# for i in range(0,100):
#     img, mask = next(d)

# single_fit(model=model, img_size=img_size, seg_organ=seg_organ, batch_size=batch_size,
#           loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
#           epochs=500, img=img, mask=mask)