In [None]:
## Mounting google drive for file handling
## (Colab-only cell: the drive is mounted at /content/drive)

from google.colab import drive
drive.mount("/content/drive",force_remount=True)

In [None]:
## Copying dist utils for dice module utilities
## Replace the placeholder with the real path of the dist_utils folder; it is
## copied recursively into /content/ so that the `dist_utils.*` imports resolve.

!cp -rv "ENTER FILE PATH OF dist_utils HERE" "/content/"

In [None]:
## Importing necessary libraries

import torch as tor
import torch.nn as nn
import torchvision as tv
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision as tv
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF

import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from dist_utils.nn_layers.cnn_utils import *
import math
import time
from PIL import Image as im

from sklearn.metrics import f1_score 
import dist_utils.postprocessing as pp
from sklearn.metrics import jaccard_score 
import random

import statistics as stats
import copy
from dist_utils.flops_compute import *

In [None]:
def calc_total_mean(datafiles,num_chn = 3,verbose = False):

    """
        Method to calculate the dataset mean for zero-centering

        The result is the element-wise (per-pixel) mean image, not a scalar.
        Images are accumulated in float64: cv2.imread returns uint8 arrays, and
        summing them directly wraps around at 256.

        Args:
            datafiles - list of data file names
            num_chn - Number of channels of the data images (3 or 1)
            verbose - verbosity for debugging

        Returns:
            float32 array with the mean of the dataset.
            For num_chn == 3 every image is resized to 512 x 512 before
            averaging and the channel order is whatever cv2.imread returns
            (BGR according to the OpenCV documentation).
            For num_chn == 1 images are read as grayscale and are NOT resized,
            so all files must already share one size.

        Raises:
            ValueError - if num_chn is not 1 or 3, if datafiles is empty, or if
                         a file cannot be read by cv2.imread
    """

    if(num_chn not in (1,3)):
        raise ValueError("Incorrect number of channels")

    num_files = len(datafiles)
    if(num_files == 0):
        raise ValueError("datafiles is empty, cannot compute a dataset mean")

    img_sum = None

    for e,file in enumerate(datafiles):

        if(num_chn == 3):
            img = cv2.imread(file)
        else:
            img = cv2.imread(file,0)

        # cv2.imread signals an unreadable file by returning None
        if(img is None):
            raise ValueError("Could not read image file: " + str(file))

        if(num_chn == 3):
            img = cv2.resize(img,(512,512))

        # widen before accumulating so the running sum cannot overflow uint8
        img = np.asarray(img,dtype = np.float64)
        img_sum = img if img_sum is None else img_sum + img

        if(verbose):
            print(e,file)


    return np.float32(img_sum / num_files)

In [None]:
"""
NOTE: This code has been forked from https://github.com/sacmehta/EdgeNets. We thank sacmehta and team for this class definition. 
"""

class dice(nn.Module):
    '''
    Volume-wise separable convolution unit.

    Depth-wise convolutions are run along the channel, width and height axes of
    the input volume; the three results are concatenated, fused back to
    ``channel_in`` maps and finally projected to ``channel_out`` maps that are
    scaled by a learned per-channel gate.
    '''
    def __init__(self, channel_in, channel_out, height, width, kernel_size=3, dilation=[1, 1, 1], shuffle=True):
        '''
        :param channel_in: # of input channels
        :param channel_out: # of output channels
        :param height: Height of the input volume the height-wise convolution is built for
        :param width: Width of the input volume the width-wise convolution is built for
        :param kernel_size: Kernel size, shared by all three axis-wise convolutions
        :param dilation: List of 3 dilation rates, used for the channel-wise,
                         width-wise and height-wise convolutions respectively
        :param shuffle: If True, the concatenated maps are passed through ``Shuffle(3)`` before fusion
        '''
        super().__init__()
        assert len(dilation) == 3

        def axis_conv(planes, rate):
            # Depth-wise convolution (groups == planes). For odd kernel sizes this
            # padding leaves the two convolved dimensions unchanged.
            return nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=1, groups=planes,
                             padding=int((kernel_size - 1) / 2) * rate, bias=False, dilation=rate)

        # Creation order is kept (channel, width, height) so parameter
        # initialisation and state_dict layout match the original definition.
        self.conv_channel = axis_conv(channel_in, dilation[0])
        self.conv_width = axis_conv(width, dilation[1])
        self.conv_height = axis_conv(height, dilation[2])

        # BR / CBR / Shuffle come from dist_utils.nn_layers.cnn_utils
        # (presumably BatchNorm+activation, Conv+BatchNorm+activation and a
        # channel shuffle — confirm against that module).
        self.br_act = BR(3 * channel_in)
        self.weight_avg_layer = CBR(3 * channel_in, channel_in, kSize=1, stride=1, groups=channel_in)

        # Projection from channel_in to channel_out maps
        self.proj_layer = CBR(channel_in, channel_out, kSize=3, stride=1,
                              groups=math.gcd(channel_in, channel_out))

        # Global pooling -> bottleneck (channel_in // 4) -> sigmoid gate per output channel
        squeezed = channel_in // 4
        self.linear_comb_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(channel_in, squeezed, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel_out, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

        self.vol_shuffle = Shuffle(3)

        # Plain attributes below are also what __repr__ formats from self.__dict__
        self.width = width
        self.height = height
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.shuffle = shuffle
        self.ksize = kernel_size
        self.dilation = dilation

    @staticmethod
    def _resample(volume, size, enlarge):
        '''
        Bring ``volume`` to spatial ``size``: bilinear interpolation when it has
        to grow, adaptive average pooling when it has to shrink.
        '''
        if enlarge:
            return F.interpolate(volume, mode='bilinear', size=size, align_corners=True)
        return F.adaptive_avg_pool2d(volume, output_size=size)

    def forward(self, x):
        '''
        :param x: input batch, unpacked as (batch, C, H, W)
        :return: output batch with ``channel_out`` maps and the input's H x W
        '''
        _, _, in_h, in_w = x.size()

        # Channel-wise branch works on the input as it is
        ch_branch = self.conv_channel(x)

        # Height-wise branch: match the height the conv was built for, swap the
        # channel and height axes, convolve, swap back.
        h_input = x.clone()
        if in_h != self.height:
            h_input = self._resample(h_input, (self.height, in_w), enlarge=in_h < self.height)
        h_input = h_input.transpose(1, 2).contiguous()
        h_branch = self.conv_height(h_input).transpose(1, 2).contiguous()

        produced_h = h_branch.size(2)
        if in_h != produced_h:
            # return to the caller's spatial size
            h_branch = self._resample(h_branch, (in_h, in_w), enlarge=produced_h < in_h)

        # Width-wise branch: same idea with the channel and width axes swapped
        w_input = x.clone()
        if in_w != self.width:
            w_input = self._resample(w_input, (in_h, self.width), enlarge=in_w < self.width)
        w_input = w_input.transpose(1, 3).contiguous()
        w_branch = self.conv_width(w_input).transpose(1, 3).contiguous()

        produced_w = w_branch.size(3)
        if in_w != produced_w:
            w_branch = self._resample(w_branch, (in_h, in_w), enlarge=produced_w < in_w)

        # Merge the three branches along the channel axis: 3C x H x W
        merged = self.br_act(torch.cat((ch_branch, h_branch, w_branch), 1))
        if self.shuffle:
            merged = self.vol_shuffle(merged)

        fused = self.weight_avg_layer(merged)
        gate = self.linear_comb_layer(fused)
        projected = self.proj_layer(fused)
        return projected * gate

    def __repr__(self):
        template = ('{name}(in_channels={channel_in}, out_channels={channel_out}, kernel_size={ksize}, vol_shuffle={shuffle}, '
                    'width={width}, height={height}, dilation={dilation})')
        return template.format(name=type(self).__name__, **self.__dict__)


In [None]:
class aspp(nn.Module):

    """
        ASPP (atrous spatial pyramid pooling) module built from dice units

        Four dice branches (one with kernel size 1, three with kernel size 3 and
        the given dilation rates) and one average-pooling branch are resized to
        prev_dim, concatenated along the channel axis and reduced back to
        in_channels by a 1 x 1 convolution.

        Args:
            in_channels - number of channels of the input (and of the output)
            mid_channels - number of channels produced by each dice branch
            prev_dim - (height, width) of the incoming feature map
            rates - exactly three dilation rates, one per dilated branch
    """

    def __init__(self,in_channels,mid_channels,prev_dim,rates):

        super().__init__()

        rate_a,rate_b,rate_c = rates
        feat_h,feat_w = prev_dim[0],prev_dim[1]

        # Branch order is kept so that initialisation and state_dict keys do not change
        self.branch1 = dice(in_channels,mid_channels,feat_h,feat_w,kernel_size = 1)
        self.branch2 = dice(in_channels,mid_channels,feat_h,feat_w,kernel_size = 3,dilation = [rate_a] * 3)
        self.branch3 = dice(in_channels,mid_channels,feat_h,feat_w,kernel_size = 3,dilation = [rate_b] * 3)
        self.branch4 = dice(in_channels,mid_channels,feat_h,feat_w,kernel_size = 3,dilation = [rate_c] * 3)

        # Pooling window equal to prev_dim
        self.branch5 = nn.AvgPool2d(kernel_size = prev_dim)

        self.prev_dim = prev_dim

        self.upsample = nn.UpsamplingBilinear2d(size = prev_dim)

        # 4 dice branches (mid_channels each) + pooled input (in_channels)
        fused_channels = mid_channels * 4 + in_channels
        self.final_layer = nn.Conv2d(in_channels = fused_channels,out_channels = in_channels,kernel_size = (1,1))

    def forward(self,x):

        branches = (self.branch1,self.branch2,self.branch3,self.branch4,self.branch5)

        # every branch output is brought to prev_dim before concatenation
        pyramid = [self.upsample(branch(x)) for branch in branches]

        return self.final_layer(tor.cat(pyramid,dim = 1))

In [None]:
class att_block(nn.Module):

    """
        Attention block

        The skip features x and the gating features g are both mapped to
        mid_chn channels with 1 x 1 convolutions; g is additionally upsampled by
        a stride-2 transposed convolution. Their sum goes through ReLU, a
        1 x 1 convolution to a single channel and a sigmoid. The resulting map
        is expanded to in_chnx channels and multiplied with x.

        Args:
            in_chnx - number of channels of x
            in_chng - number of channels of g
            mid_chn - number of channels of the intermediate representation
    """

    def __init__(self,in_chnx,in_chng,mid_chn):

        super().__init__()

        pointwise = (1,1)

        self.lx = nn.Conv2d(in_chnx,mid_chn,kernel_size = pointwise,padding = 0)

        self.lg = nn.Conv2d(in_chng,mid_chn,kernel_size = pointwise,padding = 0)
        # stride 2 with output_padding 1 doubles the spatial size of g
        self.upconv = nn.ConvTranspose2d(in_channels = mid_chn,out_channels = mid_chn, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.lmid = nn.Conv2d(mid_chn,1,kernel_size = pointwise,padding = 0)

        self.resamp = nn.Conv2d(1,in_chnx,kernel_size = pointwise,padding = 0)

    def forward(self,x,g):

        skip_feat = self.lx(x)
        gate_feat = self.upconv(self.lg(g))

        # additive attention: ReLU on the sum, then squash to one channel in (0,1)
        attention = F.relu(skip_feat + gate_feat)
        attention = tor.sigmoid(self.lmid(attention))

        # expand to the channel count of x and re-weight x
        return self.resamp(attention) * x

In [None]:
class dist_dice(nn.Module):

    """
        Kidney-SegNet model definition

        Encoder-decoder network built from dice units: four encoder levels
        (separated by 2 x 2 max-pooling), an ASPP module on the deepest level
        and three decoder levels whose skip connections are gated by attention
        blocks. The layers are built for 512 x 512 inputs.

        Args:
            in_channels - number of channels of the input image
            nfeat - number of feature maps at the first level (doubled per level)
            mid_channels - channels per branch inside the ASPP module

        Note: upconv1/upconv2/upconv3 are created but not used in forward (the
        attention blocks upsample instead); they are kept so that the parameter
        set and state_dict keys stay the same. The value stored by load_p is
        not used in forward either.
    """

    def __init__(self,in_channels,nfeat,mid_channels = 256):

        # NOTE(review): assigned before nn.Module.__init__ runs and never used below
        self.drop = nn.Dropout

        super().__init__()

        f1,f2,f4,f8 = nfeat,2 * nfeat,4 * nfeat,8 * nfeat

        ## Encoder, level 1 (512 x 512)
        self.conv1a = nn.Conv2d(in_channels = in_channels,out_channels = f1,kernel_size = (3,3),padding = 1)
        self.bn1a = nn.BatchNorm2d(f1)
        self.conv1b = dice(f1,f1,height = 512,width = 512)
        self.bn1b = nn.BatchNorm2d(f1)

        self.maxpool1 = nn.MaxPool2d(kernel_size = (2,2), stride = 2)

        ## Encoder, level 2 (256 x 256)
        self.conv2a = dice(f1,f2,width = 256,height = 256)
        self.bn2a = nn.BatchNorm2d(f2)
        self.conv2b = dice(f2,f2,width = 256,height = 256)
        self.bn2b = nn.BatchNorm2d(f2)

        self.maxpool2 = nn.MaxPool2d(kernel_size = (2,2), stride = 2)

        ## Encoder, level 3 (128 x 128)
        self.conv3a = dice(f2,f4,width = 128,height = 128)
        self.bn3a = nn.BatchNorm2d(f4)
        self.conv3b = dice(f4,f4,width = 128,height = 128)
        self.bn3b = nn.BatchNorm2d(f4)

        self.maxpool3 = nn.MaxPool2d(kernel_size = (2,2),stride = 2)

        ## Encoder, level 4 (64 x 64)
        self.conv4a = dice(f4,f8,width = 64,height = 64)
        self.bn4a = nn.BatchNorm2d(f8)
        self.conv4b = dice(f8,f8,width = 64,height = 64)
        self.bn4b = nn.BatchNorm2d(f8)

        ## Decoder, level 3 (input is the concatenation of gated skip and skip: 8 * nfeat)
        self.upconv1 = nn.ConvTranspose2d(in_channels = f8,out_channels = f4, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv6a = dice(f8,f4,width = 128,height = 128)
        self.bn6a = nn.BatchNorm2d(f4)
        self.conv6b = dice(f4,f4,width = 128,height = 128)
        self.bn6b = nn.BatchNorm2d(f4)

        ## Decoder, level 2
        self.upconv2 = nn.ConvTranspose2d(in_channels = f4,out_channels = f2, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv7a = dice(f4,f2,width = 256,height = 256)
        self.bn7a = nn.BatchNorm2d(f2)
        self.conv7b = dice(f2,f2,width = 256,height = 256)
        self.bn7b = nn.BatchNorm2d(f2)

        ## Decoder, level 1
        self.upconv3 = nn.ConvTranspose2d(in_channels = f2,out_channels = f1, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv8a = dice(f2,f1,width = 512,height = 512)
        self.bn8a = nn.BatchNorm2d(f1)
        self.conv8b = dice(f1,f1,width = 512,height = 512)
        self.bn8b = nn.BatchNorm2d(f1)

        ## Single-channel output head
        self.seg_map_conv = nn.Conv2d(in_channels = f1,out_channels = 1, kernel_size = (1,1))

        self.load_p(0)

        self.aspp = aspp(in_channels = f8,mid_channels = mid_channels,prev_dim = (64,64),rates = [4,6,8])

        self.att1 = att_block(in_chnx = f4,in_chng = f8,mid_chn = f4)
        self.att2 = att_block(in_chnx = f2,in_chng = f4,mid_chn = f2)
        self.att3 = att_block(in_chnx = f1,in_chng = f2,mid_chn = f1)

    def load_p(self,p):
        """Store p on the model (it is only stored, forward does not read it)."""
        self.p = p

    def forward(self,inp):
        """
            Run the network on a batch of images and return a single-channel,
            ReLU-activated (non-negative) map per image.
        """

        act = F.relu

        ## Encoder
        enc1 = act(self.bn1a(self.conv1a(inp)))
        enc1 = act(self.bn1b(self.conv1b(enc1)))

        enc2 = self.maxpool1(enc1)
        enc2 = act(self.bn2a(self.conv2a(enc2)))
        enc2 = act(self.bn2b(self.conv2b(enc2)))

        enc3 = self.maxpool2(enc2)
        enc3 = act(self.bn3a(self.conv3a(enc3)))
        enc3 = act(self.bn3b(self.conv3b(enc3)))

        enc4 = self.maxpool3(enc3)
        enc4 = act(self.bn4a(self.conv4a(enc4)))
        enc4 = act(self.bn4b(self.conv4b(enc4)))

        bottleneck = self.aspp(enc4)

        ## Decoder level 3: attention-gated skip concatenated with the raw skip
        dec3 = self.att1(enc3,bottleneck)
        dec3 = tor.cat((dec3,enc3),dim = 1)
        del enc4

        dec3 = act(self.bn6a(self.conv6a(dec3)))
        dec3 = act(self.bn6b(self.conv6b(dec3)))

        ## Decoder level 2
        dec2 = self.att2(enc2,dec3)
        dec2 = tor.cat((dec2,enc2),dim = 1)
        del enc3
        del dec3

        dec2 = act(self.bn7a(self.conv7a(dec2)))
        dec2 = act(self.bn7b(self.conv7b(dec2)))

        ## Decoder level 1
        dec1 = self.att3(enc1,dec2)
        dec1 = tor.cat((dec1,enc1),dim = 1)
        del enc2
        del dec2

        dec1 = act(self.bn8a(self.conv8a(dec1)))
        dec1 = act(self.bn8b(self.conv8b(dec1)))

        return act(self.seg_map_conv(dec1))

In [None]:
def get_f1(gt,pred):

    """
        Mean per-sample F1 score between ground-truth and predicted masks

        Every sample is min-max scaled to [0,1] and truncated to uint8, so only
        the elements equal to the sample maximum become 1 (0/1 masks pass
        through unchanged). A constant sample has zero range; it is binarised
        as (value != 0) instead of being divided by zero, which used to yield
        NaNs and an undefined uint8 cast.

        Args:
            gt - batch of ground-truth masks (numpy array or torch tensor)
            pred - batch of predicted masks (numpy array or torch tensor)

        Returns:
            mean of the f1_score values of all samples in the batch
    """

    def to_binary(mask):
        # flattened uint8 labels for one sample
        mask = np.asarray(mask)
        low,high = mask.min(),mask.max()
        if(high == low):
            return np.uint8(mask != 0).flatten()
        return np.uint8((mask - low) / (high - low)).flatten()

    if(not isinstance(gt,np.ndarray)):
        gt = gt.detach().cpu().numpy()

    if(not isinstance(pred,np.ndarray)):
        pred = pred.detach().cpu().numpy()

    f1 = [f1_score(to_binary(ground_truth),to_binary(predicted)) for predicted,ground_truth in zip(pred,gt)]

    return np.mean(f1)

In [None]:
def get_ji(gt,pred):

    """
        Mean per-sample Jaccard index between ground-truth and predicted masks

        Every sample is min-max scaled to [0,1] and truncated to uint8, so only
        the elements equal to the sample maximum become 1 (0/1 masks pass
        through unchanged). A constant sample has zero range; it is binarised
        as (value != 0) instead of being divided by zero, which used to yield
        NaNs and an undefined uint8 cast.

        Args:
            gt - batch of ground-truth masks (numpy array or torch tensor)
            pred - batch of predicted masks (numpy array or torch tensor)

        Returns:
            mean of the jaccard_score values of all samples in the batch
    """

    def to_binary(mask):
        # flattened uint8 labels for one sample
        mask = np.asarray(mask)
        low,high = mask.min(),mask.max()
        if(high == low):
            return np.uint8(mask != 0).flatten()
        return np.uint8((mask - low) / (high - low)).flatten()

    if(not isinstance(gt,np.ndarray)):
        gt = gt.detach().cpu().numpy()

    if(not isinstance(pred,np.ndarray)):
        pred = pred.detach().cpu().numpy()

    ji = [jaccard_score(to_binary(ground_truth),to_binary(predicted)) for predicted,ground_truth in zip(pred,gt)]

    return np.mean(ji)

Dataset file structure:

    [TRAIN/VAL/TEST DIR]
            |
            |
            |___________________
            |         |        |
            data     labels    gts

In [None]:
class dataset(torch.utils.data.Dataset):

    """
        Dataset definition class

        Expects files_dir to contain the sub-directories "data", "labels" and
        "gts", holding files with identical names (the names are taken from
        "gts"). Each item is a (data, label, gt) triple of tensors obtained from
        512 x 512 resized images:
            data  - first 3 channels, rescaled to [0,255], minus dataset.total_mean
            label - float tensor rescaled to [0,255]
            gt    - grayscale image converted to a long tensor

        Args:
            files_dir - directory with the data / labels / gts sub-directories
            data_size - number of samples exposed through __len__ (-1 = all)
            phase - "train" makes this instance compute the shared mean image
            apply_transforms - apply the random flip / rotation augmentation
    """

    # Mean image shared by ALL instances; it is set when a dataset is created
    # with phase == "train" and stays 0 until then.
    total_mean = 0

    def __init__(self,files_dir,data_size = -1,phase = "",apply_transforms = True):

        data_dir = os.path.join(files_dir,"data")
        label_dir = os.path.join(files_dir,"labels")
        gt_dir = os.path.join(files_dir,"gts")

        # os.listdir returns names in arbitrary order; sort them so that the
        # index -> file mapping is reproducible between runs
        files = sorted(os.listdir(gt_dir))

        self.data_files = [os.path.join(data_dir,x) for x in files]
        self.label_files = [os.path.join(label_dir,x) for x in files]
        self.gt_files = [os.path.join(gt_dir,x) for x in files]

        # never advertise more samples than there are files
        if(data_size == -1 or data_size > len(files)):
            data_size = len(files)

        self.data_size = data_size
        self.apply_transforms = apply_transforms

        if(phase == "train"):
            # NOTE(review): calc_total_mean reads the images with cv2.imread
            # (BGR channel order per the OpenCV docs) while __getitem__ loads
            # them with PIL; the channel order of the mean may therefore not
            # match the data. Left unchanged to stay compatible with
            # checkpoints trained with this preprocessing - confirm.
            dataset.total_mean = tor.from_numpy(calc_total_mean(self.data_files))
            dataset.total_mean = dataset.total_mean.permute(2,0,1)

        if(tor.is_tensor(dataset.total_mean)):
            print("shape of dataset.total_mean: ",dataset.total_mean.size())
        else:
            # no training dataset has been built yet, total_mean is still the scalar default
            print("dataset.total_mean has not been computed yet (create the dataset with phase = \"train\" first)")

    def __len__(self):
        return self.data_size

    @staticmethod
    def _rescale_255(tensor):
        """
            Min-max rescale a tensor to [0,255]. A constant tensor has zero
            range and is mapped to all zeros instead of producing NaNs.
        """
        low,high = tensor.min(),tensor.max()
        if(high > low):
            return ((tensor - low) / (high - low)) * 255
        return tor.zeros_like(tensor)

    def transforms(self,data,label,gt):
        """
            Resize the three PIL images to 512 x 512 and apply the same random
            horizontal flip, vertical flip and rotation (multiple of 90
            degrees) to all of them.
        """

        data = data.resize((512,512))
        label = label.resize((512,512))
        gt = gt.resize((512,512))

        # Random horizontal flip
        if(random.random() > 0.5):
            data = TF.hflip(data)
            label = TF.hflip(label)
            gt = TF.hflip(gt)

        ## Random Vertical Flip
        if(random.random() > 0.5):
            data = TF.vflip(data)
            label = TF.vflip(label)
            gt = TF.vflip(gt)

        ## Random rotate in multiples of 90
        range_of_angles = [0,90,180,270]
        angle = random.choice(range_of_angles)
        data = TF.rotate(data,angle)
        label = TF.rotate(label,angle,fill = (0,))
        gt = TF.rotate(gt,angle,fill = (0,))

        ## Applying color-jitter to data
        # data = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)(data)

        return data,label,gt

    def __getitem__(self,idx):

        data = im.open(self.data_files[idx])
        label = im.open(self.label_files[idx])
        gt = im.open(self.gt_files[idx])
        gt = gt.convert('L')

        if(self.apply_transforms):
            data,label,gt = self.transforms(data,label,gt)
        else:
            data = data.resize((512,512))
            label = label.resize((512,512))
            gt = gt.resize((512,512))

        ## Convert PILs to tensors
        data = transforms.ToTensor()(data)[:3,:,:]
        label = transforms.ToTensor()(label)
        gt = transforms.ToTensor()(gt)

        label = label.type(tor.FloatTensor)
        gt = gt.type(tor.LongTensor)

        data = self._rescale_255(data)
        label = self._rescale_255(label)

        # report the file that actually produced the NaN
        if(tor.sum(tor.isnan(label))):
            print(self.label_files[idx],"nan")
        if(tor.sum(tor.isnan(data))):
            print(self.data_files[idx],"nan")
        if(tor.sum(tor.isnan(gt))):
            print(self.gt_files[idx],"nan")

        ## subtracting the mean
        data = data - dataset.total_mean

        return data,label,gt

In [None]:
## Extract training dataset
## (must run before the validation / test cells: phase = "train" is what
##  computes dataset.total_mean, which every dataset instance subtracts)

# Train directory containing the data, label and gt directories
train_dir = "ENTER TRAIN DIR NAME"

trainset = dataset(files_dir = train_dir,phase = "train",apply_transforms=True)
trainloader = torch.utils.data.DataLoader(trainset,batch_size = 2,shuffle=True)

In [None]:
## Extract validation dataset

# Validation data directory containing the data, label and gt directories
val_dir = "ENTER VAL DIR NAME"

# No augmentation and no shuffling: validation metrics are averaged per batch,
# so a fixed batch composition keeps them comparable between epochs and runs.
valset = dataset(files_dir=val_dir,apply_transforms=False)
valloader = torch.utils.data.DataLoader(valset,batch_size=2,shuffle = False)

In [None]:
def get_segments(distmaps,param = 7,thresh1 = 0.5,thresh2 = 5):

    """
        Method to extract segmentations from distance map regressions

        Each map is min-max scaled to [0,255], converted to uint8, values below
        thresh2 are zeroed, the result is passed to pp.PostProcess and every
        non-zero element of its output is set to 1.

        A map with zero range (constant, e.g. an all-zero network output) or
        containing NaNs cannot be min-max scaled; it is treated as an empty
        (all-zero) map instead of dividing by zero.

        Args:
            distmaps - distance map regressions (iterable of numpy arrays or
                       torch tensors; tensors are detached, moved to CPU and squeezed)
            param, thresh1 - forwarded to pp.PostProcess as param / thresh
            thresh2 - values of the rescaled uint8 map below this are set to 0

        Returns:
            numpy array with one segmentation output per distance map
    """

    segments = []

    for res in distmaps:

        if(not isinstance(res,np.ndarray)):
            res = res.detach().cpu().numpy().squeeze()

        low,high = res.min(),res.max()
        if(high > low):
            res = ((res - low) / (high - low)) * 255
        else:
            # zero range (or NaN extrema): nothing to segment
            res = np.zeros_like(res)

        res = np.uint8(res)
        res[res<thresh2] = 0

        res = pp.PostProcess(res,param = param,thresh = thresh1)
        res[res!=0] = 1

        segments.append(res)

    return np.array(segments)

In [None]:
def train(seg,epochs,dataloaders,hyper_params,reset = True,save = False):

    """
        Method to perform training

        Args:
            seg - segmentation object
            epochs - number of epochs to be trained
            dataloaders - [trainloader, valloader, bridgeloader]; the third
                          entry is unpacked but not used
            hyper_params - [lr, reg, p, nfeat, postproc_params] where
                           postproc_params is [param, thresh1, thresh2] for
                           get_segments (falsy -> 7, 0.5, 20)
            reset - toggle to reset weights of model. The weights of `seg` are
                    re-initialised IN PLACE so that the object held by the
                    caller is the one being trained.
            save - accepted for compatibility; it is not consulted, the best
                   model is always written to the checkpoint path below

        Returns:
            Training history (list with the mean training loss of every epoch)

        Uses the module-level names `device` and `past` (best harmonic mean of
        validation F1 and Jaccard so far; created as 0.0 when missing and
        updated whenever a better model is saved).
    """

    global past
    try:
        past
    except NameError:
        # no previous best score defined in the notebook yet
        past = 0.0

    trainloader,valloader,bridgeloader = dataloaders

    lr,reg,p,nfeat,postproc_params = hyper_params

    if(postproc_params):
        param,thresh1,thresh2 = postproc_params
    else:
        param = 7
        thresh1 = 0.5
        thresh2 = 20

    if(reset):
        # A freshly built model only supplies new initial weights. Copying
        # them into `seg` (instead of rebinding the local name) keeps the
        # caller's reference pointing at the model that is trained here.
        fresh = dist_dice(in_channels=3,nfeat = nfeat)
        try:
            seg.load_state_dict(fresh.state_dict())
            seg = seg.to(device)
        except RuntimeError:
            # `seg` has a different layout (e.g. another nfeat): fall back to
            # training the new model; the caller's object is NOT updated then.
            print("WARNING: could not reset the given model in place, training a new model instance")
            seg = fresh.to(device)
        del fresh
        print("/////////////////// Weights have been reset \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")

    # applied after the reset so that it ends up on the model that is trained
    seg.load_p(p)
    seg.train()

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(seg.parameters(),lr = lr,weight_decay = reg)

    epoch_losses = []
    for epoch in range(epochs):

        batch_losses = []
        batch_f1 = []
        batch_ji = []

        for batch_idx,(data,label,gt) in enumerate(trainloader):

            data,label = data.to(device),label.to(device)

            optimizer.zero_grad()

            out = seg(data)
            loss = criterion(out,label)
            loss.backward()
            optimizer.step()

            segment_maps = get_segments(out,param = param,thresh1 = thresh1,thresh2 = thresh2)

            batch_f1.append(get_f1(gt,segment_maps))
            batch_ji.append(get_ji(gt,segment_maps))

            batch_losses.append(loss.item())

            if(math.isnan(loss.item())):
                # diagnostics: which tensor introduced the NaN
                print("batch_loss nan")
                print("data: ",tor.sum(tor.isnan(data)))
                print("out: ",tor.sum(tor.isnan(out)))
                print("label: ",tor.sum(tor.isnan(label)))
                print("gt: ",tor.sum(tor.isnan(gt)))

        epoch_losses.append(np.mean(batch_losses))

        train_f1,train_ji = np.mean(batch_f1),np.mean(batch_ji)
        print("Epoch: ",epoch,"\tEpoch Loss: ",epoch_losses[-1],"Mean f1: ",train_f1)
        print("Mean ji: ",train_ji,"HM: ",stats.harmonic_mean([train_f1,train_ji]))

        if(math.isnan(epoch_losses[-1])):
            # training diverged, stop
            print("batch_losses: ",batch_losses)
            break

        ## Validation
        seg.eval()
        with tor.no_grad():

            batch_losses = []
            batch_f1 = []
            batch_ji = []

            for batch_idx,(valdata,vallabel,valgt) in enumerate(valloader):

                valdata,vallabel = valdata.to(device),vallabel.to(device)

                valout = seg(valdata)
                loss = criterion(valout,vallabel)

                segment_maps = get_segments(valout,param = param,thresh1 = thresh1,thresh2 = thresh2)
                batch_losses.append(loss.item())
                batch_f1.append(get_f1(valgt,segment_maps))
                batch_ji.append(get_ji(valgt,segment_maps))

            val_f1,val_ji = np.mean(batch_f1),np.mean(batch_ji)
            val_hm = stats.harmonic_mean([val_f1,val_ji])

            print("Val Loss: ",np.mean(batch_losses),"Mean f1: ",val_f1)
            print("Mean ji: ",val_ji,"HM: ",val_hm)

        ## Keep the model with the best validation harmonic mean
        if(val_hm > past):
            state = [seg.state_dict(),hyper_params,[val_f1,val_ji,val_hm]]
            tor.save(state,"ENTER FILE PATH WHERE WEIGHTS SHOULD BE SAVED")
            print("************************best model saved*********************************")
            past = val_hm

        print("-------------------------------------------------------------------------")
        seg.train()

    return epoch_losses

In [None]:
## Function to count the number of parameters in a pytorch model

def count_parameters(model):
    """
        Count the parameters of a pytorch model.

        Returns:
            (number of elements in parameters with requires_grad set,
             number of elements in all parameters)
    """
    trainable = 0
    total = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    return trainable,total

In [None]:
## setting up training

# Use the first CUDA device when one is available, otherwise the CPU
device = tor.device("cuda:0" if tor.cuda.is_available() else "cpu")
print("using: ",device)

# nfeat is forwarded to dist_dice as its `nfeat` argument
nfeat = 32
seg = dist_dice(in_channels=3,nfeat = nfeat).to(device)
# add_flops_counting_methods is not defined in this notebook (presumably from
# the dist_utils.flops_compute star import - confirm what it attaches to the model)
seg = add_flops_counting_methods(seg)

In [None]:
## Calculate the number of parameters
## (count_parameters returns the trainable count first, then the total count)

train_params,total_params = count_parameters(seg)
print("Number of trainable params: ",train_params)
print("total params: ",total_params)

## The next 3 cells are optional: they load a previously saved model checkpoint (weights and best score). Skip them when training from scratch.
___________________________________________________

In [None]:
## Load previous past best model score

# file path to model weights
# The checkpoint is the list written by train():
#   [state_dict, hyper_params, [mean f1, mean ji, harmonic mean]]
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only runtime.
state = tor.load("ENTER FILE PATH TO MODEL WEIGHTS HERE",map_location = device)
state3 = state[2]
past = state3[2]
print("past best = ",past,state3)

In [None]:
## make segmentation model weights the same as previous best model weights
## (state[0] is the state_dict stored as the first element of the checkpoint)

seg.load_state_dict(state[0])

In [None]:
## training mode
## (the result is assigned to _ so the cell does not echo the model)

_ = seg.train()

________________

# The training process

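### The following cell trains the Kidney-SegNet framework. To train the network effectively, tune the hyper-parameters below. Note that `reset = True` re-initialises the model weights and `past = 0.0` clears the best score, so change both when resuming from a loaded checkpoint.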
___

In [None]:
epochs = 1000
nfeat = 32
lr = 7 * 1e-4
reg = 0.7 * 1e-3
reset = True
p = 0.0
past = 0.0
# Model selection must use the validation loader: the test loader is only
# created further down (it does not exist yet on a fresh Run All) and
# selecting the best model on test data would leak the test set.
# The third entry is unpacked by train() but not used.
dataloaders = [trainloader,valloader,0.0]
# [param, thresh1, thresh2] forwarded to get_segments
postproc_params = [7,0.5,24]
# postproc_params = [65,0.5,18]
hyper_params = [lr,reg,p,nfeat,postproc_params]

_ = train(seg,epochs,dataloaders,hyper_params = hyper_params,reset = reset)

___

# Testing the network
___

In [None]:
## Extract test dataset

# test data directory containing the data, label and gt directories
test_dir = "ENTER TEST DIR HERE"

# No augmentation and no shuffling: the reported metrics are means over
# per-batch means, so a fixed batch composition makes the evaluation reproducible.
testset = dataset(files_dir=test_dir,apply_transforms=False)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle = False)

In [None]:
## Evaluate `seg` on the test loader: MSE loss against the labels plus F1 /
## Jaccard of the post-processed segmentations against the ground truths.
batch_losses = []
batch_f1 = []
batch_ji = []

criterion = nn.MSELoss()
# post-processing hyper-parameters forwarded to get_segments
param = 7
thresh1 = 0.5
thresh2 = 28

# predicted segmentations and ground truths, one entry per batch
imgs = []
gts = []

seg.eval()
with torch.no_grad():
  for batch_idx,(testdata,testlabel,testgt) in enumerate(testloader):

      testdata,testlabel = testdata.to(device),testlabel.to(device)

      testout = seg(testdata)
      loss = criterion(testout,testlabel)

      segment_maps = get_segments(testout,param = param,thresh1 = thresh1,thresh2 = thresh2)
      batch_losses.append(loss.item())
      batch_f1.append(get_f1(testgt,segment_maps))
      batch_ji.append(get_ji(testgt,segment_maps))

      imgs.append(segment_maps)
      gts.append(testgt)

  # means over batches; HM is the harmonic mean of mean F1 and mean Jaccard
  print("Test Loss: ",np.mean(batch_losses),"Mean f1: ",np.mean(batch_f1))
  print("Mean ji: ",np.mean(batch_ji),"HM: ",stats.harmonic_mean([np.mean(batch_f1),np.mean(batch_ji)]))

___