# Fill in the folder details
Assign each dataset folder path to the matching variable below; the values shown are an example.

In [None]:
# Dataset folders, one (images, masks) pair per split; edit these for another layout.
train_img_folder, train_mask_folder = (
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/train_Images",
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/train_Masks",
)
test_img_folder, test_mask_folder = (
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/test_Images",
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/test_Masks",
)
val_img_folder, val_mask_folder = (
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/valid_Images",
    "/kaggle/input/stanford-thyroid-cine-clips-train-test-val-splits/ThyroidNodule_Segmentation/valid_Masks",
)

In [1]:
from __future__ import print_function, division
import glob
import torch
import os
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as standard_transforms


In [2]:

class REBNCONV(nn.Module):
    """3x3 convolution -> BatchNorm -> ReLU, the basic unit of every RSU block.

    ``dirate`` is used as both the dilation and the padding of the 3x3
    convolution (stride 1), so the output has the same spatial size as the
    input and ``out_ch`` channels.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=dirate, dilation=dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return relu(bn(conv(x)))."""
        out = self.conv_s1(x)
        out = self.bn_s1(out)
        return self.relu_s1(out)

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):
    """Residual U-block of height 7 (RSU-7).

    An input REBNCONV projects ``in_ch`` -> ``out_ch``. A small U-Net with
    ``mid_ch`` channels refines that projection: 6 encoder REBNCONVs with
    5 max-pool downsamplings, a dilation-2 bottom layer, and 6 decoder
    REBNCONVs fed with [upsampled deeper features, skip]. The refinement is
    then added back to the projection.

    Sub-modules are named rebnconvin, rebnconv1..7, pool1..5 and
    rebnconv6d..1d, and are registered in that order.
    """

    _height = 7

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        top = self._height
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder: rebnconv{k} followed by pool{k} for k = 1 .. top-2.
        for k in range(1, top - 1):
            src_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}", REBNCONV(src_ch, mid_ch, dirate=1))
            setattr(self, f"pool{k}", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        # Deepest encoder layer (no pooling), then the dilated bottom layer.
        setattr(self, f"rebnconv{top - 1}", REBNCONV(mid_ch, mid_ch, dirate=1))
        setattr(self, f"rebnconv{top}", REBNCONV(mid_ch, mid_ch, dirate=2))
        # Decoder: every layer consumes a concatenation of two mid_ch maps.
        for k in range(top - 1, 0, -1):
            dst_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}d", REBNCONV(mid_ch * 2, dst_ch, dirate=1))

    def forward(self, x):
        top = self._height
        hxin = self.rebnconvin(x)

        skips = []
        hx = hxin
        for k in range(1, top - 1):
            hx = getattr(self, f"rebnconv{k}")(hx)
            skips.append(hx)
            hx = getattr(self, f"pool{k}")(hx)

        deepest = getattr(self, f"rebnconv{top - 1}")(hx)
        bottom = getattr(self, f"rebnconv{top}")(deepest)

        hxd = getattr(self, f"rebnconv{top - 1}d")(torch.cat((bottom, deepest), 1))
        for k in range(top - 2, 0, -1):
            skip = skips[k - 1]
            hxd = getattr(self, f"rebnconv{k}d")(
                torch.cat((_upsample_like(hxd, skip), skip), 1))

        return hxd + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):
    """Residual U-block of height 6 (RSU-6).

    An input REBNCONV projects ``in_ch`` -> ``out_ch``. A small U-Net with
    ``mid_ch`` channels refines that projection: 5 encoder REBNCONVs with
    4 max-pool downsamplings, a dilation-2 bottom layer, and 5 decoder
    REBNCONVs fed with [upsampled deeper features, skip]. The refinement is
    then added back to the projection.

    Sub-modules are named rebnconvin, rebnconv1..6, pool1..4 and
    rebnconv5d..1d, and are registered in that order.
    """

    _height = 6

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()
        top = self._height
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder: rebnconv{k} followed by pool{k} for k = 1 .. top-2.
        for k in range(1, top - 1):
            src_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}", REBNCONV(src_ch, mid_ch, dirate=1))
            setattr(self, f"pool{k}", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        # Deepest encoder layer (no pooling), then the dilated bottom layer.
        setattr(self, f"rebnconv{top - 1}", REBNCONV(mid_ch, mid_ch, dirate=1))
        setattr(self, f"rebnconv{top}", REBNCONV(mid_ch, mid_ch, dirate=2))
        # Decoder: every layer consumes a concatenation of two mid_ch maps.
        for k in range(top - 1, 0, -1):
            dst_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}d", REBNCONV(mid_ch * 2, dst_ch, dirate=1))

    def forward(self, x):
        top = self._height
        hxin = self.rebnconvin(x)

        skips = []
        hx = hxin
        for k in range(1, top - 1):
            hx = getattr(self, f"rebnconv{k}")(hx)
            skips.append(hx)
            hx = getattr(self, f"pool{k}")(hx)

        deepest = getattr(self, f"rebnconv{top - 1}")(hx)
        bottom = getattr(self, f"rebnconv{top}")(deepest)

        hxd = getattr(self, f"rebnconv{top - 1}d")(torch.cat((bottom, deepest), 1))
        for k in range(top - 2, 0, -1):
            skip = skips[k - 1]
            hxd = getattr(self, f"rebnconv{k}d")(
                torch.cat((_upsample_like(hxd, skip), skip), 1))

        return hxd + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):
    """Residual U-block of height 5 (RSU-5).

    An input REBNCONV projects ``in_ch`` -> ``out_ch``. A small U-Net with
    ``mid_ch`` channels refines that projection: 4 encoder REBNCONVs with
    3 max-pool downsamplings, a dilation-2 bottom layer, and 4 decoder
    REBNCONVs fed with [upsampled deeper features, skip]. The refinement is
    then added back to the projection.

    Sub-modules are named rebnconvin, rebnconv1..5, pool1..3 and
    rebnconv4d..1d, and are registered in that order.
    """

    _height = 5

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        top = self._height
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder: rebnconv{k} followed by pool{k} for k = 1 .. top-2.
        for k in range(1, top - 1):
            src_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}", REBNCONV(src_ch, mid_ch, dirate=1))
            setattr(self, f"pool{k}", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        # Deepest encoder layer (no pooling), then the dilated bottom layer.
        setattr(self, f"rebnconv{top - 1}", REBNCONV(mid_ch, mid_ch, dirate=1))
        setattr(self, f"rebnconv{top}", REBNCONV(mid_ch, mid_ch, dirate=2))
        # Decoder: every layer consumes a concatenation of two mid_ch maps.
        for k in range(top - 1, 0, -1):
            dst_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}d", REBNCONV(mid_ch * 2, dst_ch, dirate=1))

    def forward(self, x):
        top = self._height
        hxin = self.rebnconvin(x)

        skips = []
        hx = hxin
        for k in range(1, top - 1):
            hx = getattr(self, f"rebnconv{k}")(hx)
            skips.append(hx)
            hx = getattr(self, f"pool{k}")(hx)

        deepest = getattr(self, f"rebnconv{top - 1}")(hx)
        bottom = getattr(self, f"rebnconv{top}")(deepest)

        hxd = getattr(self, f"rebnconv{top - 1}d")(torch.cat((bottom, deepest), 1))
        for k in range(top - 2, 0, -1):
            skip = skips[k - 1]
            hxd = getattr(self, f"rebnconv{k}d")(
                torch.cat((_upsample_like(hxd, skip), skip), 1))

        return hxd + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):
    """Residual U-block of height 4 (RSU-4).

    An input REBNCONV projects ``in_ch`` -> ``out_ch``. A small U-Net with
    ``mid_ch`` channels refines that projection: 3 encoder REBNCONVs with
    2 max-pool downsamplings, a dilation-2 bottom layer, and 3 decoder
    REBNCONVs fed with [upsampled deeper features, skip]. The refinement is
    then added back to the projection.

    Sub-modules are named rebnconvin, rebnconv1..4, pool1..2 and
    rebnconv3d..1d, and are registered in that order.
    """

    _height = 4

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        top = self._height
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder: rebnconv{k} followed by pool{k} for k = 1 .. top-2.
        for k in range(1, top - 1):
            src_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}", REBNCONV(src_ch, mid_ch, dirate=1))
            setattr(self, f"pool{k}", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        # Deepest encoder layer (no pooling), then the dilated bottom layer.
        setattr(self, f"rebnconv{top - 1}", REBNCONV(mid_ch, mid_ch, dirate=1))
        setattr(self, f"rebnconv{top}", REBNCONV(mid_ch, mid_ch, dirate=2))
        # Decoder: every layer consumes a concatenation of two mid_ch maps.
        for k in range(top - 1, 0, -1):
            dst_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}d", REBNCONV(mid_ch * 2, dst_ch, dirate=1))

    def forward(self, x):
        top = self._height
        hxin = self.rebnconvin(x)

        skips = []
        hx = hxin
        for k in range(1, top - 1):
            hx = getattr(self, f"rebnconv{k}")(hx)
            skips.append(hx)
            hx = getattr(self, f"pool{k}")(hx)

        deepest = getattr(self, f"rebnconv{top - 1}")(hx)
        bottom = getattr(self, f"rebnconv{top}")(deepest)

        hxd = getattr(self, f"rebnconv{top - 1}d")(torch.cat((bottom, deepest), 1))
        for k in range(top - 2, 0, -1):
            skip = skips[k - 1]
            hxd = getattr(self, f"rebnconv{k}d")(
                torch.cat((_upsample_like(hxd, skip), skip), 1))

        return hxd + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    """Residual U-block of height 4 with dilation in place of pooling (RSU-4F).

    There is no down- or up-sampling. The encoder REBNCONVs use dilation
    1, 2, 4, 8, and the decoder uses 4, 2, 1 on [deeper features, skip]
    concatenations. The result is added back to the input projection.
    Sub-modules are named rebnconvin, rebnconv1..4 and rebnconv3d..1d, and
    are registered in that order.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder: the receptive field grows through doubling dilation.
        for k, rate in zip((1, 2, 3, 4), (1, 2, 4, 8)):
            src_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}", REBNCONV(src_ch, mid_ch, dirate=rate))
        # Decoder mirrors the encoder dilations on concatenated features.
        for k, rate in zip((3, 2, 1), (4, 2, 1)):
            dst_ch = out_ch if k == 1 else mid_ch
            setattr(self, f"rebnconv{k}d", REBNCONV(mid_ch * 2, dst_ch, dirate=rate))

    def forward(self, x):
        hxin = self.rebnconvin(x)

        feats = []
        hx = hxin
        for k in range(1, 5):
            hx = getattr(self, f"rebnconv{k}")(hx)
            feats.append(hx)

        # feats holds hx1..hx4; decode from the deepest one upwards.
        hxd = feats[3]
        for k in (3, 2, 1):
            hxd = getattr(self, f"rebnconv{k}d")(torch.cat((hxd, feats[k - 1]), 1))

        return hxd + hxin


##### U^2-Net ####
class U2NET(nn.Module):
    """U^2-Net: a two-level nested U-structure built from RSU blocks.

    Encoder stages 1-6 (RSU7, RSU6, RSU5, RSU4, RSU4F, RSU4F) are separated by
    2x max-pooling. Decoder stages 5d-1d fuse the upsampled deeper features
    with the matching encoder output. Six 3x3 side heads each produce an
    ``out_ch`` map. All maps are upsampled to the stage-1 resolution and fused
    by a 1x1 convolution.

    ``forward(x)`` returns seven sigmoid probability maps ``(d0, d1, ..., d6)``.
    d0 is the fused prediction and d1..d6 are the side outputs. Each has
    ``out_ch`` channels at the spatial size of the stage-1 features.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder: inputs are [upsampled deeper output, encoder skip]
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side heads (one per decoder stage plus the bottom stage)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # fuses the six upsampled side maps into d0
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, all brought to the resolution of d1
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return tuple(torch.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))
        

In [3]:
# 3-channel input, single-channel (binary nodule mask) output — the U2NET defaults.
model = U2NET(in_ch=3, out_ch=1)

In [4]:
from torchinfo import summary
# summary(model)

In [5]:
# Notebook echo: shows whether PyTorch can see a CUDA device (the captured output below is True).
torch.cuda.is_available()

True

## Data Loading

In [6]:
class SalObjDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel path lists.

    Parameters
    ----------
    img_name_list : list of str
        Image file paths.
    lbl_name_list : list of str
        Mask file paths, paired index-by-index with ``img_name_list``. May be
        empty, in which case an all-zero mask is produced for every image.
    transform : callable, optional
        Applied to the sample dict ``{'imidx', 'image', 'label'}``.

    Raises
    ------
    ValueError
        If ``lbl_name_list`` is non-empty but has a different length from
        ``img_name_list``, because the index pairing would then be meaningless.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        if lbl_name_list and len(lbl_name_list) != len(img_name_list):
            raise ValueError(
                "img_name_list and lbl_name_list must have the same length "
                f"(got {len(img_name_list)} and {len(lbl_name_list)})")
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        # The dataset index travels with the sample, so a (possibly shuffled)
        # batch can be mapped back to its source files.
        imidx = np.array([idx])

        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Reduce the mask to one 2-D channel (the first one for colour masks).
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3
        else:
            label = np.zeros(label_3.shape[0:2])

        # Give image and mask an explicit trailing channel axis: (H, W, C).
        if label.ndim == 2 and image.ndim in (2, 3):
            if image.ndim == 2:
                image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

## Loss Function and Metrics

In [7]:
# Mean binary cross-entropy over all elements. U2NET already applies a sigmoid
# to every output, so BCELoss (not BCEWithLogitsLoss) is the matching criterion.
# `size_average` is deprecated; `reduction='mean'` is its documented replacement.
bce_loss = nn.BCELoss(reduction='mean')

def dice_coefficient(pred, target, threshold=0.5, smooth=1e-6):
    """Dice overlap between a thresholded prediction and a target mask.

    ``pred`` is binarised with ``pred > threshold``. The overlap is pooled over
    every element of the tensors (the whole batch, not per sample).
    ``smooth`` keeps the ratio finite; two empty masks score 1.0.
    """
    binary = pred.gt(threshold).float()
    overlap = (binary * target).sum()
    denom = binary.sum() + target.sum()
    return (overlap * 2. + smooth) / (denom + smooth)

def accuracy(pred, target, threshold=0.5):
    """Fraction of elements where ``pred > threshold`` matches ``target``.

    Pooled over every element of the tensors (pixel accuracy over the batch).
    """
    hits = (pred.gt(threshold).float() == target).sum()
    return hits / target.numel()



In [8]:
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """BCE of each of the seven outputs against ``labels_v``, plus their sum.

    Returns ``(loss0, ..., loss6, loss)``, where ``loss`` = loss0 + ... + loss6
    is the deep-supervision total that training back-propagates.
    """
    per_output = [bce_loss(out, labels_v) for out in (d0, d1, d2, d3, d4, d5, d6)]
    total = per_output[0]
    for term in per_output[1:]:
        total = total + term
    return tuple(per_output) + (total,)

def muti_dice_coef_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Dice coefficient of each of the seven outputs against ``labels_v``.

    Returns the 7-tuple ``(dice0, ..., dice6)`` in output order.
    """
    # No discarded aggregate is computed; callers only use the per-output values.
    return tuple(dice_coefficient(out, labels_v) for out in (d0, d1, d2, d3, d4, d5, d6))

def multi_accuracy_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Pixel accuracy of each of the seven outputs against ``labels_v``.

    Returns the 7-tuple ``(acc0, ..., acc6)`` in output order.
    """
    # No discarded aggregate is computed; callers only use the per-output values.
    return tuple(accuracy(out, labels_v) for out in (d0, d1, d2, d3, d4, d5, d6))

In [9]:
# Images and masks are both stored as JPEG.
image_ext = label_ext = '.jpg'
model_name = 'u2net'
# Checkpoint directory under the working directory (ends with a path separator).
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
batch_size_train, batch_size_val = 32, 1
train_num = val_num = 0

In [11]:
def names_from_folder(folder_path):
    """Return the full paths of all entries in ``folder_path``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order, so image and mask lists
    built from two folders would not be guaranteed to line up index-by-index.
    Sorting makes the pairing deterministic. NOTE(review): this assumes each
    image and its mask sort to the same position (e.g. identical file names);
    confirm for the dataset.
    """
    return [os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path))]

# Image and mask path lists for each split.
tra_img_name_list, tra_lbl_name_list = (
    names_from_folder(train_img_folder),
    names_from_folder(train_mask_folder),
)

test_img_name_list, test_lbl_name_list = (
    names_from_folder(test_img_folder),
    names_from_folder(test_mask_folder),
)

val_img_name_list, val_lbl_name_list = (
    names_from_folder(val_img_folder),
    names_from_folder(val_mask_folder),
)

In [12]:
# Report image/mask counts per split so mismatched folders are spotted early.
print("---")
for _img_caption, _imgs, _lbl_caption, _lbls in (
    ("train images: ", tra_img_name_list, "train labels: ", tra_lbl_name_list),
    ("test images: ", test_img_name_list, "test labels: ", test_lbl_name_list),
    ("valid images: ", val_img_name_list, "valid labels: ", val_lbl_name_list),
):
    print(_img_caption, len(_imgs))
    print(_lbl_caption, len(_lbls))
    print("---")

train_num = len(tra_img_name_list)

---
train images:  13746
train labels:  13746
---
test images:  1726
test labels:  1726
---
valid images:  1940
valid labels:  1940
---


In [13]:
class Rescale2(object):
    """Randomly flip a sample horizontally, then resize its image and mask.

    Parameters
    ----------
    output_size : int or tuple of two ints
        Target (height, width); an int gives a square output.
    flip_prob : float, optional
        Probability of the horizontal flip. The default of 0.5 is the previous
        hard-coded behaviour; pass 0.0 for deterministic evaluation data.

    Raises
    ------
    ValueError
        If ``output_size`` is neither an int nor a 2-tuple.
    """

    def __init__(self, output_size, flip_prob=0.5):
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        elif isinstance(output_size, tuple) and len(output_size) == 2:
            self.output_size = output_size
        else:
            raise ValueError("output_size should be an int or a tuple of two integers")
        self.flip_prob = flip_prob

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Flip image and mask together so they stay aligned. The ">= 1 - p"
        # form reproduces the previous "random() >= 0.5" decision exactly at p=0.5.
        if random.random() >= 1.0 - self.flip_prob:
            image = image[:, ::-1]
            label = label[:, ::-1]

        new_h, new_w = self.output_size

        img = transform.resize(image, (new_h, new_w), mode='reflect', anti_aliasing=True)
        # Nearest-neighbour, range-preserving and explicitly without Gaussian
        # smoothing, so mask values are not blended regardless of skimage version.
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0,
                               preserve_range=True, anti_aliasing=False)

        return {'imidx': imidx, 'image': img, 'label': lbl}

In [14]:
import numpy as np
import torch
import torchvision.transforms as transforms
from skimage import color

class ToTensorLab2(object):
    """Convert a sample's numpy arrays into CHW torch tensors.

    The mask is binarised (values > 128 become 1, all others 0, keeping its
    dtype); this assumes mask values on a 0-255 scale. The image becomes:

    * flag 0 (default): RGB scaled by its maximum, 3 channels;
    * flag 1: the image binarised at 128, 3 channels (despite the method
      name, no Lab conversion is performed);
    * flag 2: both of the above concatenated, 6 channels.

    Grayscale (H, W, 1) images are repeated to three channels. Input arrays
    are not modified.
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if self.flag == 2:
            tmpImg = self.apply_rgb_lab_transform(image)
        elif self.flag == 1:
            tmpImg = self.apply_lab_transform(image)
        else:
            tmpImg = self.apply_rgb_transform(image)

        # Binarise into a new array rather than overwriting the caller's mask.
        binary_label = (label > 128).astype(label.dtype)
        tmpLbl = binary_label.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)

        return {'imidx': torch.from_numpy(imidx),
                'image': torch.from_numpy(tmpImg),
                'label': torch.from_numpy(tmpLbl)}

    def apply_rgb_transform(self, image):
        """Scale the image by its maximum and return a contiguous CHW array.

        An all-zero image is returned as zeros instead of the NaNs a 0/0
        division would produce.
        """
        if image.shape[2] == 1:  # grayscale -> 3 channels
            image = np.repeat(image, 3, axis=2)

        peak = np.max(image)
        if peak != 0:
            image = image / peak
        else:
            image = np.zeros(image.shape)

        # torchvision's ToTensor on a float HWC ndarray only transposes to CHW
        # (it rescales uint8 input only), so a plain numpy transpose is equivalent.
        return np.ascontiguousarray(image.transpose((2, 0, 1)))

    def apply_lab_transform(self, image):
        """Binarise the image at 128 (>128 -> 1, else 0) and return it as CHW.

        NOTE(review): no Lab conversion is performed. An image already scaled
        to [0, 1] is mapped to all zeros by this threshold; confirm intent.
        """
        if image.shape[2] == 1:  # grayscale -> 3 channels
            image = np.repeat(image, 3, axis=2)
        binary = (image > 128).astype(image.dtype)
        return binary.transpose((2, 0, 1))

    def apply_rgb_lab_transform(self, image):
        """Stack the scaled-RGB and binarised versions into a 6-channel CHW array."""
        if image.shape[2] == 1:  # grayscale -> 3 channels
            image = np.repeat(image, 3, axis=2)
        rgb = self.apply_rgb_transform(image)
        binary = self.apply_lab_transform(image)
        return np.concatenate((rgb, binary), axis=0)


In [15]:
def _make_loader(img_names, lbl_names, batch_size, shuffle):
    """Build a SalObjDataset (Rescale2(256) + ToTensorLab2) and its DataLoader."""
    dataset = SalObjDataset(
        img_name_list=img_names,
        lbl_name_list=lbl_names,
        transform=transforms.Compose([
            Rescale2(256),
            ToTensorLab2()]))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1)
    return dataset, loader


train_dataset, train_dataloader = _make_loader(
    tra_img_name_list, tra_lbl_name_list, batch_size_train, shuffle=True)

# Evaluation loaders are not shuffled, so batches follow the file-list order.
# The test loader uses batch_size_val, so per-image outputs can be saved in order.
# NOTE(review): Rescale2 still applies its random horizontal flip to these splits.
test_dataset, test_dataloader = _make_loader(
    test_img_name_list, test_lbl_name_list, batch_size_val, shuffle=False)

val_dataset, val_dataloader = _make_loader(
    val_img_name_list, val_lbl_name_list, batch_size_train, shuffle=False)


In [16]:
# Move the network to the GPU when one is visible, then optimise all its parameters with Adam.
if torch.cuda.is_available():
    model = model.cuda()  # Module.cuda() moves in place and returns the same module
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)


In [None]:
import torch
from torch.autograd import Variable

import pandas as pd

# Epoch-averaged metrics of the most recent epoch. Keys: total loss, then the
# loss / dice / accuracy of each output d0 (fused) .. d6 (side outputs).
metrics = {
    'loss':None, 'loss0': None, 'loss1': None, 'loss2': None, 'loss3': None, 'loss4': None, 'loss5': None, 'loss6': None,
    'dice0': None, 'dice1': None, 'dice2': None, 'dice3': None, 'dice4': None, 'dice5': None, 'dice6': None,
    'acc0': None, 'acc1': None, 'acc2': None, 'acc3': None, 'acc4': None, 'acc5': None, 'acc6': None
}
history = []  # one row of epoch-averaged metrics per epoch, written to CSV at the end
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ite_num = 0
epoch_num = 1

for epoch in range(epoch_num):
    model.train()
    running = dict.fromkeys(metrics, 0.0)  # per-metric sums over the epoch
    total_batches = 0

    for i, data in enumerate(train_dataloader):
        ite_num += 1
        total_batches += 1

        # Plain tensors suffice; Variable has been a no-op wrapper since PyTorch 0.4.
        inputs_v = data['image'].float().to(device)
        labels_v = data['label'].float().to(device)

        optimizer.zero_grad()

        d0, d1, d2, d3, d4, d5, d6 = model(inputs_v)
        *side_losses, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
        # Dice/accuracy are for monitoring only, so no autograd graph is needed.
        with torch.no_grad():
            dices = muti_dice_coef_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
            accs = multi_accuracy_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

        loss.backward()
        optimizer.step()

        # Values line up with the metric keys: loss, loss0..6, dice0..6, acc0..6.
        for name, value in zip(metrics, (loss, *side_losses, *dices, *accs)):
            running[name] += value.item()

        del d0, d1, d2, d3, d4, d5, d6, side_losses, loss, dices, accs

    if total_batches == 0:
        raise RuntimeError("train_dataloader produced no batches")

    averages = {name: total / total_batches for name, total in running.items()}
    metrics.update(averages)
    history.append({'epoch': epoch + 1, **averages})

    print(f"[Epoch: {epoch + 1}/{epoch_num}] "
          f"Train Loss: {averages['loss']:.4f}, "
          f"Final Loss: {averages['loss0']:.4f}, "
          f"Final Dice Coefficient: {averages['dice0']:.4f}, "
          f"Final Accuracy: {averages['acc0']:.4f}")

    torch.save(model.state_dict(), f"u2net_bce_train_epoch_{epoch + 1}_loss_{averages['loss']:.4f}.pth")

# One row per epoch. A DataFrame built from a dict of bare scalars (as before)
# raises ValueError, so the per-epoch rows are used instead.
df = pd.DataFrame(history)
df.to_csv("Training_Data.csv")

## Testing

In [None]:
def normPRED(d):
    """Min-max normalise ``d`` to [0, 1] over all of its elements.

    A constant tensor (max == min) is mapped to zeros rather than the NaNs a
    0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span

def save_output(image_name, pred, d_dir):
    """Save ``pred`` as an RGB PNG resized to the size of the source image.

    Parameters
    ----------
    image_name : str
        Path of the source image; its size and file stem are reused.
    pred : torch.Tensor
        Prediction that squeezes to a 2-D map. It is multiplied by 255, so
        values are presumably in [0, 1] (e.g. after normPRED).
    d_dir : str
        Output directory. The file written is ``<d_dir>/<image stem>.png``.
    """
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # Only the source size is needed; PIL reads it from the header lazily,
    # without decoding the pixel data.
    with Image.open(image_name) as src:
        size = src.size  # (width, height)
    imo = im.resize(size, resample=Image.BILINEAR)

    # splitext keeps inner dots ("a.b.jpg" -> "a.b") and copes with names that
    # have no extension at all.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, stem + '.png'))


# Test-set averages of the same metrics reported during training.
test_metrics = {
    'loss':None, 'loss0': None, 'loss1': None, 'loss2': None, 'loss3': None, 'loss4': None, 'loss5': None, 'loss6': None,
    'dice0': None, 'dice1': None, 'dice2': None, 'dice3': None, 'dice4': None, 'dice5': None, 'dice6': None,
    'acc0': None, 'acc1': None, 'acc2': None, 'acc3': None, 'acc4': None, 'acc5': None, 'acc6': None
}

# Directory for the saved prediction maps (one PNG per test image).
prediction_dir = os.path.join(os.getcwd(), 'test_results', model_name + os.sep)
os.makedirs(prediction_dir, exist_ok=True)

test_device = next(model.parameters()).device
test_running = dict.fromkeys(test_metrics, 0.0)
test_batches = 0

# Inference: BatchNorm must use its running statistics, and no graph is needed.
model.eval()
with torch.no_grad():
    for i_test, data_test in enumerate(test_dataloader):
        inputs_test = data_test['image'].float().to(test_device)
        labels_test = data_test['label'].float().to(test_device)

        d0, d1, d2, d3, d4, d5, d6 = model(inputs_test)
        *losses, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_test)
        dices = muti_dice_coef_fusion(d0, d1, d2, d3, d4, d5, d6, labels_test)
        accs = multi_accuracy_fusion(d0, d1, d2, d3, d4, d5, d6, labels_test)

        # Values line up with the metric keys: loss, loss0..6, dice0..6, acc0..6.
        for name, value in zip(test_metrics, (loss, *losses, *dices, *accs)):
            test_running[name] += value.item()
        test_batches += 1

        # 'imidx' carries each sample's dataset index, so every fused prediction
        # (d0) is saved under its own source file name, whatever the batch size
        # or shuffling of the loader.
        for b, idx in enumerate(data_test['imidx'].view(-1).tolist()):
            pred = normPRED(d0[b:b + 1, 0, :, :])
            save_output(test_img_name_list[idx], pred, prediction_dir)

        del d0, d1, d2, d3, d4, d5, d6, losses, loss, dices, accs

if test_batches:
    for name, total in test_running.items():
        test_metrics[name] = total / test_batches

import pandas as pd
df2 = pd.DataFrame([test_metrics])  # a single row of test-set averages
df2.to_csv("Testing_Data.csv")
