In [None]:
# # DO 

# !pip uninstall Albumentations
# !pip install Albumentations==0.5.2

In [None]:
# Dice link: https://towardsdatascience.com/how-accurate-is-image-segmentation-dd448f896388

In [8]:
!pip uninstall ovencv-python
!pip install opencv-python-headless





In [9]:
import numpy as np
import pandas as pd

import os
import cv2
import matplotlib.pyplot as plt
# import matplotlib.pylab as plt

import numpy as np
import seaborn as sns

from tqdm.notebook import tqdm
import time
import random

plt.style.use("dark_background")
%matplotlib inline


In [None]:
# from prepareData.prepareData import get_dataset_dataframe
# from util.helper import pos_neg_diagnosis, show_aug, train_model, plot_model_history, viz_pred_output
# from prepareData import augmentData, customDatasetObject
# from model.unet3p_attention import UNet_3Plus_attn
# from model.unet3p import UNet3Plus
# from model.unet_attention import AttentionUNet
# from metrics.diceMetrics import dice_coef_metric, DiceLoss, compute_iou


# Prepare Data

In [10]:
##prepareData

#augmentData
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensor

from sklearn.model_selection import train_test_split


PATCH_SIZE = 128

transform = A.Compose([
    A.Resize(width = PATCH_SIZE, height = PATCH_SIZE, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Transpose(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.01, scale_limit=0.04, rotate_limit=0, p=0.25),
    A.Normalize(p=1.0),
    ToTensor(),
])



In [11]:
#customDatasetObject

from torch.utils.data import Dataset
import cv2 

class MRImagingDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image = cv2.imread(self.df.iloc[idx, 1])
        mask = cv2.imread(self.df.iloc[idx, 2], 0)
        
        augmented = self.transform(image=image,
                                   mask=mask)
        
        image = augmented["image"]
        mask = augmented["mask"]
#         mask = np.expand_dims(augmented["mask"], axis=0)# Do not use this
        
        return image, mask
    

In [12]:
#prepareData

import pandas as pd
import os

def get_dataset_dataframe(base_path:str):
    data = []

    for dir_ in os.listdir(base_path):
        dir_path = os.path.join(base_path, dir_)
        if os.path.isdir(dir_path):
            for filename in os.listdir(dir_path):
                img_path = os.path.join(dir_path, filename)
                data.append([dir_, img_path])
        else:
            print(f"[INFO] This is not a dir --> {dir_path}")
            
    return pd.DataFrame(data, columns=["dir_name", "image_path"])


# Util

In [14]:
##Util

#helper
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import sys
sys.path.append('../')
# from unet3p_attention_capstone.metrics.diceMetrics import dice_coef_metric, compute_iou
from tqdm.notebook import tqdm
import time
import shutil

from torch.optim.lr_scheduler import ReduceLROnPlateau

def pos_neg_diagnosis(mask_path):
    """
    To assign 0 or 1 based on the presence of tumor.
    """
    val = np.max(cv2.imread(mask_path))
    if val > 0: return 1
    else: return 0


def show_aug(inputs, nrows=5, ncols=5, norm=False):
    plt.figure(figsize=(10, 10))
    plt.subplots_adjust(wspace=0., hspace=0.)
    i_ = 0
    
    if len(inputs) > 25:
        inputs = inputs[:25]
        
    for idx in range(len(inputs)):
    
        # normalization
        if norm:           
            img = inputs[idx].numpy().transpose(1,2,0)
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225] 
            img = (img*std+mean).astype(np.float32)
            
        else:
            img = inputs[idx].numpy().astype(np.float32)
            img = img[0,:,:]
        
        plt.subplot(nrows, ncols, i_+1)
        plt.imshow(img); 
        plt.axis('off')
 
        i_ += 1
        
    return plt.show()

def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    f_path = checkpoint_dir + '\\checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_dir + '\\best_model.pt'
        shutil.copyfile(f_path, best_fpath)

def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['model_state_dict'])
    print('Previously trained model weights state_dict loaded...')
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Previously trained optimizer state_dict loaded...')
    last_epoch = checkpoint['epoch']
    print(f"Previously trained for {last_epoch} number of epochs...")
    return model, optimizer, last_epoch

def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler, num_epochs, device, ckp_path:str=None):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"total params of {model_name} model: {pytorch_total_params}")
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable params of {model_name} model: {pytorch_total_params}")

    scheduler = ReduceLROnPlateau(optimizer, 'min')

    start_epoch=0

    if ckp_path is not None:
        model, optimizer, last_epoch = load_ckp(ckp_path, model, optimizer)
        start_epoch = last_epoch + 1
        print(f"Train for {num_epochs} more epochs...")

    print(f"[INFO] Model is initializing... {model_name}")

    checkpoint_dir = f"C:\\Users\\Ryan\\Documents\\checkpoints\\{model_name}"
    best_model_dir = f"{checkpoint_dir}\\{model_name}_best"
    #best_model_dir = f"{checkpoint_dir}\\best"

    loss_history = []
    train_history = []
    val_history = []
    
    mean_loss_ = 999
    
    for epoch in range(start_epoch, start_epoch+num_epochs):
        model.train()
        
        losses = []
        train_iou = []
        
        for i_step, (data, target) in enumerate(tqdm(train_loader)):
            data = data.to(device)
            target = target.to(device)
            
            outputs = model(data)
            
            out_cut = np.copy(outputs.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0
            
            train_dice = dice_coef_metric(out_cut, target.data.cpu().numpy())
            
            loss = train_loss(outputs, target)
            
            losses.append(loss.item())
            train_iou.append(train_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            
            
        val_mean_iou = compute_iou(model, val_loader, device=device)
        
        mean_loss = np.array(losses).mean()
        scheduler.step(mean_loss)

        loss_history.append(mean_loss)
        train_history.append(np.array(train_iou).mean())
        val_history.append(val_mean_iou)

        checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': np.array(losses).mean(),
                    }
        save_ckp(checkpoint, False, checkpoint_dir, best_model_dir)

        if loss<mean_loss_:
            save_ckp(checkpoint, True, checkpoint_dir, best_model_dir)
            mean_loss_ = loss
                # torch.save({
                #     'epoch': epoch,
                #     'model_state_dict': model.state_dict(),
                #     'optimizer_state_dict': optimizer.state_dict(),
                #     'loss': np.array(losses).mean(),
                #     }, f"{path[:-3]}_best.pt")
        
        

#         print("losses:", np.array(losses))
#         print("iou:", np.array(train_iou))
        print("Epoch [%d]" % (epoch))
        print("Mean loss on train:", np.array(losses).mean(), 
              "\nMean DICE on train:", np.array(train_iou).mean(), 
              "\nMean DICE on validation:", val_mean_iou)
        
    return loss_history, train_history, val_history

def plot_model_history(model_name,
                    train_history, val_history, 
                    num_epochs):

    x = np.arange(num_epochs)

    fig = plt.figure(figsize=(10, 6))
    plt.plot(x, train_history, label='train dice', lw=3, c="springgreen")
    plt.plot(x, val_history, label='validation dice', lw=3, c="deeppink")

    plt.title(f"{model_name}", fontsize=15)
    plt.legend(fontsize=12)
    plt.xlabel("Epoch", fontsize=15)
    plt.ylabel("DICE", fontsize=15)

    fn = str(int(time.time())) + ".png"
    plt.savefig(f'{model_name}/{model_name}_dice.png', bbox_inches='tight')
    plt.show()

def viz_pred_output(model, loader, idx, test_dataset, device="mps", threshold=0.3):
    valloss = 0
    
    with torch.no_grad():

#         for i_step, (data, target) in enumerate(loader):
        target = torch.tensor(test_dataset[idx][1])
        data = torch.tensor(test_dataset[idx][0])

        data = data.to(device).unsqueeze(0)
        target = target.to(device).unsqueeze(0)

        outputs = model(data)

        out_cut = np.copy(outputs.data.cpu().numpy())
        out_cut[np.nonzero(out_cut < threshold)] = 0.0
        out_cut[np.nonzero(out_cut >= threshold)] = 1.0

        f, axarr = plt.subplots(1,2)
#             axarr[0,0].imshow(image_datas[0])
#             axarr[0,1].imshow(image_datas[1])

        targ = target.data.cpu().numpy()[0][0]
        target_img = cv2.merge((targ,targ,targ))
        axarr[0].imshow(target_img)

        op = out_cut[0][0]
        axarr[1].imshow(op)








# <font color='red'>Models</font>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlockWithAttention(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlockWithAttention, self).__init__()

        # Convolutional layers for the main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Batch normalization for the main path
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),  # Adjust channels for attention
            nn.Sigmoid()  # Sigmoid activation for attention map
        )

    def forward(self, x):
        # Main path
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Attention mechanism
        attention_map = self.attention(x)
        out = out * attention_map

        # Skip connection
        out += residual

        return out


# 1. Attention Mechanism

In [None]:
# #attention_mechanism

# import torch
# import torch.nn as nn

# class AttentionBlock(nn.Module):
#     def __init__(self, f_g, f_l, f_int):
#         super().__init__()
        
#         self.w_g = nn.Sequential(
#                                 nn.Conv2d(f_g, f_int,
#                                          kernel_size=1, stride=1,
#                                          padding=0, bias=True),
#                                 nn.BatchNorm2d(f_int)
#         )
        
#         self.w_x = nn.Sequential(
#                                 nn.Conv2d(f_l, f_int,
#                                          kernel_size=1, stride=1,
#                                          padding=0, bias=True),
#                                 nn.BatchNorm2d(f_int)
#         )
        
#         self.psi = nn.Sequential(
#                                 nn.Conv2d(f_int, 1,
#                                          kernel_size=1, stride=1,
#                                          padding=0,  bias=True),
#                                 nn.BatchNorm2d(1),
#                                 nn.Sigmoid(),
#         )
        
#         self.relu = nn.ReLU(inplace=True)
        
#     def forward(self, g, x):
#         g1 = self.w_g(g)
#         x1 = self.w_x(x)
#         psi = self.relu(g1+x1)
#         psi = self.psi(psi)
        
#         return psi*x
    




# 2. UNet with Attention Mechanism

In [None]:
# #UNet_with_Attention_Mechanism

# import torch
# import torch.nn as nn
# from torch.nn import init
# import torch.nn.functional as F
# from . import attention

# import torch
# import torch.nn as nn
# from torch.nn import init

# class ConvBlock(nn.Module):
#     def __init__(self, ch_in, ch_out):
#         super().__init__()
#         self.conv = nn.Sequential(
#                                   nn.Conv2d(ch_in, ch_out,
#                                             kernel_size=3, stride=1,
#                                             padding=1, bias=True),
#                                   nn.BatchNorm2d(ch_out),
#                                   nn.ReLU(inplace=True),
#                                   nn.Conv2d(ch_out, ch_out,
#                                             kernel_size=3, stride=1,
#                                             padding=1, bias=True),
#                                   nn.BatchNorm2d(ch_out),
#                                   nn.ReLU(inplace=True),
#         )
        
#     def forward(self, x):
#         x = self.conv(x)
#         return x

# class UpConvBlock(nn.Module):
#     def __init__(self, ch_in, ch_out):
#         super().__init__()
#         self.up = nn.Sequential(
#                                 nn.Upsample(scale_factor=2),
#                                 nn.Conv2d(ch_in, ch_out,
#                                          kernel_size=3,stride=1,
#                                          padding=1, bias=True),
#                                 nn.BatchNorm2d(ch_out),
#                                 nn.ReLU(inplace=True),
#         )
        
#     def forward(self, x):
#         x = x = self.up(x)
#         return x

# class AttentionUNet(nn.Module):
#     def __init__(self, n_classes=1, in_channel=3, out_channel=1):
#         super().__init__() 
        
#         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        
#         self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
#         self.conv2 = ConvBlock(ch_in=64, ch_out=128)
#         self.conv3 = ConvBlock(ch_in=128, ch_out=256)
#         self.conv4 = ConvBlock(ch_in=256, ch_out=512)
#         self.conv5 = ConvBlock(ch_in=512, ch_out=1024)
        
#         self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
#         self.att5 = attention.AttentionBlock(f_g=512, f_l=512, f_int=256)
#         self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)
        
#         self.up4 = UpConvBlock(ch_in=512, ch_out=256)
#         self.att4 = attention.AttentionBlock(f_g=256, f_l=256, f_int=128)
#         self.upconv4 = ConvBlock(ch_in=512, ch_out=256)
        
#         self.up3 = UpConvBlock(ch_in=256, ch_out=128)
#         self.att3 = attention.AttentionBlock(f_g=128, f_l=128, f_int=64)
#         self.upconv3 = ConvBlock(ch_in=256, ch_out=128)
        
#         self.up2 = UpConvBlock(ch_in=128, ch_out=64)
#         self.att2 = attention.AttentionBlock(f_g=64, f_l=64, f_int=32)
#         self.upconv2 = ConvBlock(ch_in=128, ch_out=64)
        
#         self.conv_1x1 = nn.Conv2d(64, out_channel,
#                                   kernel_size=1, stride=1, padding=0)
        
#     def forward(self, x):
#         # encoder
#         x1 = self.conv1(x)
        
#         x2 = self.maxpool(x1)
#         x2 = self.conv2(x2)
        
#         x3 = self.maxpool(x2)
#         x3 = self.conv3(x3)
        
#         x4 = self.maxpool(x3)
#         x4 = self.conv4(x4)
        
#         x5 = self.maxpool(x4)
#         x5 = self.conv5(x5)
        
#         # decoder + concat
#         d5 = self.up5(x5)
#         x4 = self.att5(g=d5, x=x4)
#         d5 = torch.concat((x4, d5), dim=1)
#         d5 = self.upconv5(d5)
        
#         d4 = self.up4(d5)
#         x3 = self.att4(g=d4, x=x3)
#         d4 = torch.concat((x3, d4), dim=1)
#         d4 = self.upconv4(d4)
        
#         d3 = self.up3(d4)
#         x2 = self.att3(g=d3, x=x2)
#         d3 = torch.concat((x2, d3), dim=1)
#         d3 = self.upconv3(d3)
        
#         d2 = self.up2(d3)
#         x1 = self.att2(g=d2, x=x1)
#         d2 = torch.concat((x1, d2), dim=1)
#         d2 = self.upconv2(d2)
        
#         d1 = self.conv_1x1(d2)
        
#         return d1
    


# 3. UNet3+

In [None]:
# #UNet3+

# import torch
# import torch.nn as nn
# from torch.nn import init
# import torch.nn.functional as F

# import torch
# import torch.nn as nn
# from torch.nn import init

# def weights_init_normal(m):
#     classname = m.__class__.__name__
#     #print(classname)
#     if classname.find('Conv') != -1:
#         init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('Linear') != -1:
#         init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('BatchNorm') != -1:
#         init.normal_(m.weight.data, 1.0, 0.02)
#         init.constant_(m.bias.data, 0.0)


# def weights_init_xavier(m):
#     classname = m.__class__.__name__
#     #print(classname)
#     if classname.find('Conv') != -1:
#         init.xavier_normal_(m.weight.data, gain=1)
#     elif classname.find('Linear') != -1:
#         init.xavier_normal_(m.weight.data, gain=1)
#     elif classname.find('BatchNorm') != -1:
#         init.normal_(m.weight.data, 1.0, 0.02)
#         init.constant_(m.bias.data, 0.0)


# def weights_init_kaiming(m):
#     classname = m.__class__.__name__
#     #print(classname)
#     if classname.find('Conv') != -1:
#         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
#     elif classname.find('Linear') != -1:
#         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
#     elif classname.find('BatchNorm') != -1:
#         init.normal_(m.weight.data, 1.0, 0.02)
#         init.constant_(m.bias.data, 0.0)


# def weights_init_orthogonal(m):
#     classname = m.__class__.__name__
#     #print(classname)
#     if classname.find('Conv') != -1:
#         init.orthogonal_(m.weight.data, gain=1)
#     elif classname.find('Linear') != -1:
#         init.orthogonal_(m.weight.data, gain=1)
#     elif classname.find('BatchNorm') != -1:
#         init.normal_(m.weight.data, 1.0, 0.02)
#         init.constant_(m.bias.data, 0.0)


# def init_weights(net, init_type='normal'):
#     #print('initialization method [%s]' % init_type)
#     if init_type == 'normal':
#         net.apply(weights_init_normal)
#     elif init_type == 'xavier':
#         net.apply(weights_init_xavier)
#     elif init_type == 'kaiming':
#         net.apply(weights_init_kaiming)
#     elif init_type == 'orthogonal':
#         net.apply(weights_init_orthogonal)
#     else:
#         raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

# class unetConv2(nn.Module):
#     def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
#         super(unetConv2, self).__init__()
#         self.n = n
#         self.ks = ks
#         self.stride = stride
#         self.padding = padding
#         s = stride
#         p = padding
#         if is_batchnorm:
#             for i in range(1, n + 1):
#                 conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
#                                      nn.BatchNorm2d(out_size),
#                                      nn.ReLU(inplace=True), )
#                 setattr(self, 'conv%d' % i, conv)
#                 in_size = out_size

#         else:
#             for i in range(1, n + 1):
#                 conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
#                                      nn.ReLU(inplace=True), )
#                 setattr(self, 'conv%d' % i, conv)
#                 in_size = out_size

#         # initialise the blocks
#         for m in self.children():
#             init_weights(m, init_type='kaiming')

#     def forward(self, inputs):
#         x = inputs
#         for i in range(1, self.n + 1):
#             conv = getattr(self, 'conv%d' % i)
#             x = conv(x)

#         return x

# class UNet3Plus(nn.Module):
#     def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
#                  is_deconv=True, is_batchnorm=True):
#         super(UNet3Plus, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes
#         self.bilinear = bilinear
#         self.feature_scale = feature_scale
#         self.is_deconv = is_deconv
#         self.is_batchnorm = is_batchnorm
#         filters = [64, 128, 256, 512, 1024]

#         ## -------------Encoder--------------
#         self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm)
#         self.maxpool1 = nn.MaxPool2d(kernel_size=2)

#         self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
#         self.maxpool2 = nn.MaxPool2d(kernel_size=2)

#         self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
#         self.maxpool3 = nn.MaxPool2d(kernel_size=2)

#         self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
#         self.maxpool4 = nn.MaxPool2d(kernel_size=2)

#         self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

#         ## -------------Decoder--------------
#         self.CatChannels = filters[0]
#         self.CatBlocks = 5
#         self.UpChannels = self.CatChannels * self.CatBlocks

#         '''stage 4d'''
#         # h1->320*320, hd4->40*40, Pooling 8 times
#         self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
#         self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
#         self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

#         # h2->160*160, hd4->40*40, Pooling 4 times
#         self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
#         self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
#         self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

#         # h3->80*80, hd4->40*40, Pooling 2 times
#         self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
#         self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
#         self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

#         # h4->40*40, hd4->40*40, Concatenation
#         self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
#         self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

#         # hd5->20*20, hd4->40*40, Upsample 2 times
#         self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
#         self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
#         self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

#         # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
#         self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
#         self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
#         self.relu4d_1 = nn.ReLU(inplace=True)

#         '''stage 3d'''
#         # h1->320*320, hd3->80*80, Pooling 4 times
#         self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
#         self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
#         self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

#         # h2->160*160, hd3->80*80, Pooling 2 times
#         self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
#         self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
#         self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

#         # h3->80*80, hd3->80*80, Concatenation
#         self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
#         self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

#         # hd4->40*40, hd4->80*80, Upsample 2 times
#         self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
#         self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

#         # hd5->20*20, hd4->80*80, Upsample 4 times
#         self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
#         self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
#         self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

#         # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
#         self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
#         self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
#         self.relu3d_1 = nn.ReLU(inplace=True)

#         '''stage 2d '''
#         # h1->320*320, hd2->160*160, Pooling 2 times
#         self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
#         self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
#         self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

#         # h2->160*160, hd2->160*160, Concatenation
#         self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
#         self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

#         # hd3->80*80, hd2->160*160, Upsample 2 times
#         self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
#         self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

#         # hd4->40*40, hd2->160*160, Upsample 4 times
#         self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
#         self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

#         # hd5->20*20, hd2->160*160, Upsample 8 times
#         self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
#         self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
#         self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

#         # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
#         self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
#         self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
#         self.relu2d_1 = nn.ReLU(inplace=True)

#         '''stage 1d'''
#         # h1->320*320, hd1->320*320, Concatenation
#         self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
#         self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
#         self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

#         # hd2->160*160, hd1->320*320, Upsample 2 times
#         self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
#         self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

#         # hd3->80*80, hd1->320*320, Upsample 4 times
#         self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
#         self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

#         # hd4->40*40, hd1->320*320, Upsample 8 times
#         self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
#         self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
#         self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

#         # hd5->20*20, hd1->320*320, Upsample 16 times
#         self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
#         self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
#         self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
#         self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

#         # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
#         self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
#         self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
#         self.relu1d_1 = nn.ReLU(inplace=True)

#         # output
#         self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)

#         # initialise weights
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 init_weights(m, init_type='kaiming')
#             elif isinstance(m, nn.BatchNorm2d):
#                 init_weights(m, init_type='kaiming')


#     def forward(self, inputs):
#         ## -------------Encoder-------------
#         h1 = self.conv1(inputs)  # h1->320*320*64

#         h2 = self.maxpool1(h1)
#         h2 = self.conv2(h2)  # h2->160*160*128

#         h3 = self.maxpool2(h2)
#         h3 = self.conv3(h3)  # h3->80*80*256

#         h4 = self.maxpool3(h3)
#         h4 = self.conv4(h4)  # h4->40*40*512

#         h5 = self.maxpool4(h4)
#         hd5 = self.conv5(h5)  # h5->20*20*1024

#         ## -------------Decoder-------------
#         h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
#         h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
#         h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
#         h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
#         hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
#         hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels

#         h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
#         h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
#         h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
#         hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
#         hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
#         hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels

#         h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
#         h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
#         hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
#         hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
#         hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
#         hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels

#         h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
#         hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
#         hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
#         hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
#         hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
#         hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels

#         d1 = self.outconv1(hd1)  # d1->320*320*n_classes
#         return F.sigmoid(d1)







# 4. UNet3+ with Residual Attention Mechanism (UNet_3Plus_ResAttn)

In [17]:
#UNet3+_with_Attention_Mechanism

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
# from . import attention

# class ConvBlock(nn.Module):
#     def __init__(self, ch_in, ch_out):
#         super().__init__()
#         self.conv = nn.Sequential(
#                                   nn.Conv2d(ch_in, ch_out,
#                                             kernel_size=3, stride=1,
#                                             padding=1, bias=True),
#                                   nn.BatchNorm2d(ch_out),
#                                   nn.ReLU(inplace=True),
#                                   nn.Conv2d(ch_out, ch_out,
#                                             kernel_size=3, stride=1,
#                                             padding=1, bias=True),
#                                   nn.BatchNorm2d(ch_out),
#                                   nn.ReLU(inplace=True),
#         )
        
#     def forward(self, x):
#         x = self.conv(x)
#         return x

# class UpConvBlock(nn.Module):
#     def __init__(self, ch_in, ch_out):
#         super().__init__()
#         self.up = nn.Sequential(
#                                 nn.Upsample(scale_factor=2),
#                                 nn.Conv2d(ch_in, ch_out,
#                                          kernel_size=3,stride=1,
#                                          padding=1, bias=True),
#                                 nn.BatchNorm2d(ch_out),
#                                 nn.ReLU(inplace=True),
#         )
        
#     def forward(self, x):
#         x = x = self.up(x)
#         return x

class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     SwitchNorm2d(out_size),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)

        return x


def weights_init_normal(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('SwitchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('SwitchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('SwitchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('SwitchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    #print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)



class UNet_3Plus_ResAttn(nn.Module):

    def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
        super(UNet_3Plus_ResAttn, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = SwitchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
        # self.att4 = attention.AttentionBlock(f_g=256, f_l=64, f_int=256)
        self.att4 = ResidualBlockWithAttention(256,256)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = SwitchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
        

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = SwitchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = SwitchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = SwitchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
#         self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
# modified
#         self.conv4d_1 = nn.Conv2d(128, self.UpChannels, 3, padding=1)  # 16 # if only 2 inputs concatenated(each input 64 size) 
        self.conv4d_1 = nn.Conv2d(384, self.UpChannels, 3, padding=1)  # 16 (6*64 = 384)
    
        self.bn4d_1 = SwitchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)

        '''stage 3d'''
        # self.att3 = attention.AttentionBlock(f_g=256, f_l=64, f_int=256)
        self.att3 = ResidualBlockWithAttention(256,256)

        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = SwitchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd3_bn = SwitchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = SwitchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd4->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = SwitchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = SwitchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(384, self.UpChannels, 3, padding=1)  # 16
        self.bn3d_1 = SwitchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)

        '''stage 2d '''
        # self.att2 = attention.AttentionBlock(f_g=256, f_l=64, f_int=256)
        self.att2 = ResidualBlockWithAttention(256,256)
        
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = SwitchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = SwitchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = SwitchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd2_bn = SwitchNorm2d(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd2_bn = SwitchNorm2d(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(384, self.UpChannels, 3, padding=1)  # 16
        self.bn2d_1 = SwitchNorm2d(self.UpChannels)
        self.relu2d_1 = nn.ReLU(inplace=True)

        '''stage 1d'''
        # self.att1 = attention.AttentionBlock(f_g=256, f_l=64, f_int=256)
        self.att1 = ResidualBlockWithAttention(256,256)

        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_Cat_hd1_bn = SwitchNorm2d(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd2_UT_hd1_bn = SwitchNorm2d(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd1_bn = SwitchNorm2d(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd1_bn = SwitchNorm2d(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd1_bn = SwitchNorm2d(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(384, self.UpChannels, 3, padding=1)  # 16
        self.bn1d_1 = SwitchNorm2d(self.UpChannels)
        self.relu1d_1 = nn.ReLU(inplace=True)

        # output
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, SwitchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        ## -------------Encoder-------------
        h1 = self.conv1(inputs)  # h1->320*320*64

        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*128

        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*256

        h4 = self.maxpool3(h3)
        h4 = self.conv4(h4)  # h4->40*40*512

        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # h5->20*20*1024

        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
# #         modified
        temp_cat4 = torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4), 1)
        outatt4 = self.att4(g=temp_cat4, x=hd5_UT_hd4)
#         print("outatt4", outatt4.shape)
#         print("hd5_UT_hd4", hd5_UT_hd4.shape)
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, outatt4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
#         hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
#             torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
#         print("h4_Cat_hd4", h4_Cat_hd4.shape)
#         print("hd5_UT_hd4", hd5_UT_hd4.shape)
#         print("cat shape", torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1).shape)
#         print("hd4 shape", hd4.shape)
        
# h4_Cat_hd4 torch.Size([26, 64, 16, 16])
# hd5_UT_hd4 torch.Size([26, 64, 16, 16])
# cat shape torch.Size([26, 320, 16, 16])
# hd4 shape torch.Size([26, 320, 16, 16])

        
        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        temp_cat3 = torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd5_UT_hd3), 1)
        outatt3 = self.att3(g=temp_cat3, x=hd4_UT_hd3)
#         print("outatt4", outatt4.shape)
#         print("hd5_UT_hd4", hd5_UT_hd4.shape)
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, outatt3, hd5_UT_hd3), 1))))
#         hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
#             torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels

        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        temp_cat2 = torch.cat((h1_PT_hd2, h2_Cat_hd2, hd5_UT_hd2, hd4_UT_hd2), 1)
        outatt2 = self.att2(g=temp_cat2, x=hd3_UT_hd2)
#         hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
#             torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, outatt2, hd5_UT_hd2), 1))))

        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
#         hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
#             torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels
        temp_cat1 = torch.cat((h1_Cat_hd1, hd5_UT_hd1, hd3_UT_hd1, hd4_UT_hd1), 1)
        outatt1 = self.att1(g=temp_cat1, x=hd2_UT_hd1)
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, outatt1, hd5_UT_hd1), 1))))
        
        d1 = self.outconv1(hd1)  # d1->320*320*n_classes
        return F.sigmoid(d1)

# 5. Residual Attention Mechanism Network

In [None]:
# ##Residual_Attention_Mechanism_Network

# import torch
# import torch.nn as nn
# from torch.nn import init
# import functools
# from torch.autograd import Variable
# import numpy as np

# from .basic_layers import ResidualBlock


# class AttentionModule_pre(nn.Module):

#     def __init__(self, in_channels, out_channels, size1, size2, size3):
#         super(AttentionModule_pre, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax3_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)

#         self.softmax4_blocks = ResidualBlock(in_channels, out_channels)

#         self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

#         self.softmax5_blocks = ResidualBlock(in_channels, out_channels)

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

#         self.softmax6_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_softmax1 = self.softmax1_blocks(out_mpool1)
#         out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
#         out_mpool2 = self.mpool2(out_softmax1)
#         out_softmax2 = self.softmax2_blocks(out_mpool2)
#         out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
#         out_mpool3 = self.mpool3(out_softmax2)
#         out_softmax3 = self.softmax3_blocks(out_mpool3)
#         #
#         out_interp3 = self.interpolation3(out_softmax3)
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out = out_interp3 + out_skip2_connection
#         out_softmax4 = self.softmax4_blocks(out)
#         out_interp2 = self.interpolation2(out_softmax4)
#         out = out_interp2 + out_skip1_connection
#         out_softmax5 = self.softmax5_blocks(out)
#         out_interp1 = self.interpolation1(out_softmax5)
#         out_softmax6 = self.softmax6_blocks(out_interp1)
#         out = (1 + out_softmax6) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage0(nn.Module):
#     # input size is 112*112
#     def __init__(self, in_channels, out_channels, size1=(112, 112), size2=(56, 56), size3=(28, 28), size4=(14, 14)):
#         super(AttentionModule_stage0, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         # 56*56
#         self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         # 28*28
#         self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         # 14*14
#         self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
#         self.skip3_connection_residual_block = ResidualBlock(in_channels, out_channels)
#         self.mpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         # 7*7
#         self.softmax4_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )
#         self.interpolation4 = nn.UpsamplingBilinear2d(size=size4)
#         self.softmax5_blocks = ResidualBlock(in_channels, out_channels)
#         self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
#         self.softmax6_blocks = ResidualBlock(in_channels, out_channels)
#         self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
#         self.softmax7_blocks = ResidualBlock(in_channels, out_channels)
#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

#         self.softmax8_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias = False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels , kernel_size=1, stride=1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         # 112*112
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         # 56*56
#         out_softmax1 = self.softmax1_blocks(out_mpool1)
#         out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
#         out_mpool2 = self.mpool2(out_softmax1)
#         # 28*28
#         out_softmax2 = self.softmax2_blocks(out_mpool2)
#         out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
#         out_mpool3 = self.mpool3(out_softmax2)
#         # 14*14
#         out_softmax3 = self.softmax3_blocks(out_mpool3)
#         out_skip3_connection = self.skip3_connection_residual_block(out_softmax3)
#         out_mpool4 = self.mpool4(out_softmax3)
#         # 7*7
#         out_softmax4 = self.softmax4_blocks(out_mpool4)
#         out_interp4 = self.interpolation4(out_softmax4) + out_softmax3
#         out = out_interp4 + out_skip3_connection
#         out_softmax5 = self.softmax5_blocks(out)
#         out_interp3 = self.interpolation3(out_softmax5) + out_softmax2
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out = out_interp3 + out_skip2_connection
#         out_softmax6 = self.softmax6_blocks(out)
#         out_interp2 = self.interpolation2(out_softmax6) + out_softmax1
#         out = out_interp2 + out_skip1_connection
#         out_softmax7 = self.softmax7_blocks(out)
#         out_interp1 = self.interpolation1(out_softmax7) + out_trunk
#         out_softmax8 = self.softmax8_blocks(out_interp1)
#         out = (1 + out_softmax8) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage1(nn.Module):
#     # input size is 56*56
#     def __init__(self, in_channels, out_channels, size1=(56, 56), size2=(28, 28), size3=(14, 14)):
#         super(AttentionModule_stage1, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax3_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)

#         self.softmax4_blocks = ResidualBlock(in_channels, out_channels)

#         self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

#         self.softmax5_blocks = ResidualBlock(in_channels, out_channels)

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

#         self.softmax6_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_softmax1 = self.softmax1_blocks(out_mpool1)
#         out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
#         out_mpool2 = self.mpool2(out_softmax1)
#         out_softmax2 = self.softmax2_blocks(out_mpool2)
#         out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
#         out_mpool3 = self.mpool3(out_softmax2)
#         out_softmax3 = self.softmax3_blocks(out_mpool3)
#         #
#         out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out = out_interp3 + out_skip2_connection
#         out_softmax4 = self.softmax4_blocks(out)
#         out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
#         out = out_interp2 + out_skip1_connection
#         out_softmax5 = self.softmax5_blocks(out)
#         out_interp1 = self.interpolation1(out_softmax5) + out_trunk
#         out_softmax6 = self.softmax6_blocks(out_interp1)
#         out = (1 + out_softmax6) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage2(nn.Module):
#     # input image size is 28*28
#     def __init__(self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)):
#         super(AttentionModule_stage2, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

#         self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.softmax2_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

#         self.softmax3_blocks = ResidualBlock(in_channels, out_channels)

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

#         self.softmax4_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_softmax1 = self.softmax1_blocks(out_mpool1)
#         out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
#         out_mpool2 = self.mpool2(out_softmax1)
#         out_softmax2 = self.softmax2_blocks(out_mpool2)

#         out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out = out_interp2 + out_skip1_connection
#         out_softmax3 = self.softmax3_blocks(out)
#         out_interp1 = self.interpolation1(out_softmax3) + out_trunk
#         out_softmax4 = self.softmax4_blocks(out_interp1)
#         out = (1 + out_softmax4) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage3(nn.Module):
#     # input image size is 14*14
#     def __init__(self, in_channels, out_channels, size1=(14, 14)):
#         super(AttentionModule_stage3, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         self.softmax1_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

#         self.softmax2_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_softmax1 = self.softmax1_blocks(out_mpool1)

#         out_interp1 = self.interpolation1(out_softmax1) + out_trunk
#         out_softmax2 = self.softmax2_blocks(out_interp1)
#         out = (1 + out_softmax2) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage1_cifar(nn.Module):
#     # input size is 16*16
#     def __init__(self, in_channels, out_channels, size1=(16, 16), size2=(8, 8)):
#         super(AttentionModule_stage1_cifar, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 8*8

#         self.down_residual_blocks1 = ResidualBlock(in_channels, out_channels)

#         self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

#         self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 4*4

#         self.middle_2r_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size2)  # 8*8

#         self.up_residual_blocks1 = ResidualBlock(in_channels, out_channels)

#         self.interpolation2 = nn.UpsamplingBilinear2d(size=size1)  # 16*16

#         self.conv1_1_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_down_residual_blocks1 = self.down_residual_blocks1(out_mpool1)
#         out_skip1_connection = self.skip1_connection_residual_block(out_down_residual_blocks1)
#         out_mpool2 = self.mpool2(out_down_residual_blocks1)
#         out_middle_2r_blocks = self.middle_2r_blocks(out_mpool2)
#         #
#         out_interp = self.interpolation1(out_middle_2r_blocks) + out_down_residual_blocks1
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out = out_interp + out_skip1_connection
#         out_up_residual_blocks1 = self.up_residual_blocks1(out)
#         out_interp2 = self.interpolation2(out_up_residual_blocks1) + out_trunk
#         out_conv1_1_blocks = self.conv1_1_blocks(out_interp2)
#         out = (1 + out_conv1_1_blocks) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage2_cifar(nn.Module):
#     # input size is 8*8
#     def __init__(self, in_channels, out_channels, size=(8, 8)):
#         super(AttentionModule_stage2_cifar, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 4*4

#         self.middle_2r_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.interpolation1 = nn.UpsamplingBilinear2d(size=size)  # 8*8

#         self.conv1_1_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_mpool1 = self.mpool1(x)
#         out_middle_2r_blocks = self.middle_2r_blocks(out_mpool1)
#         #
#         out_interp = self.interpolation1(out_middle_2r_blocks) + out_trunk
#         # print(out_skip2_connection.data)
#         # print(out_interp3.data)
#         out_conv1_1_blocks = self.conv1_1_blocks(out_interp)
#         out = (1 + out_conv1_1_blocks) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last


# class AttentionModule_stage3_cifar(nn.Module):
#     # input size is 4*4
#     def __init__(self, in_channels, out_channels, size=(8, 8)):
#         super(AttentionModule_stage3_cifar, self).__init__()
#         self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

#         self.trunk_branches = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#          )

#         self.middle_2r_blocks = nn.Sequential(
#             ResidualBlock(in_channels, out_channels),
#             ResidualBlock(in_channels, out_channels)
#         )

#         self.conv1_1_blocks = nn.Sequential(
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias = False),
#             nn.Sigmoid()
#         )

#         self.last_blocks = ResidualBlock(in_channels, out_channels)

#     def forward(self, x):
#         x = self.first_residual_blocks(x)
#         out_trunk = self.trunk_branches(x)
#         out_middle_2r_blocks = self.middle_2r_blocks(x)
#         #
#         out_conv1_1_blocks = self.conv1_1_blocks(out_middle_2r_blocks)
#         out = (1 + out_conv1_1_blocks) * out_trunk
#         out_last = self.last_blocks(out)

#         return out_last




# 6. Switch Normalization

In [16]:
##Switch_Normalization

#Switchable Normalization is a normalization technique that is able to learn different normalization operations 
#for different normalization layers in a deep neural network in an end-to-end manner.

import torch
import torch.nn as nn


class SwitchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super(SwitchNorm1d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))
        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)

        if self.training:
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            if self.using_moving_average:
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            mean_bn = torch.autograd.Variable(self.running_mean)
            var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        mean = mean_weight[0] * mean_ln + mean_weight[1] * mean_bn
        var = var_weight[0] * var_ln + var_weight[1] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias

class SwitchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True,
                 last_gamma=False):
        super(SwitchNorm2d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.using_bn = using_bn
        self.last_gamma = last_gamma
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        if self.using_bn:
            self.mean_weight = nn.Parameter(torch.ones(3))
            self.var_weight = nn.Parameter(torch.ones(3))
        else:
            self.mean_weight = nn.Parameter(torch.ones(2))
            self.var_weight = nn.Parameter(torch.ones(2))
        if self.using_bn:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
            self.register_buffer('running_var', torch.zeros(1, num_features, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.using_bn:
            self.running_mean.zero_()
            self.running_var.zero_()
        if self.last_gamma:
            self.weight.data.fill_(0)
        else:
            self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        mean_in = x.mean(-1, keepdim=True)
        var_in = x.var(-1, keepdim=True)

        mean_ln = mean_in.mean(1, keepdim=True)
        temp = var_in + mean_in ** 2
        var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2

        if self.using_bn:
            if self.training:
                mean_bn = mean_in.mean(0, keepdim=True)
                var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
                if self.using_moving_average:
                    self.running_mean.mul_(self.momentum)
                    self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                    self.running_var.mul_(self.momentum)
                    self.running_var.add_((1 - self.momentum) * var_bn.data)
                else:
                    self.running_mean.add_(mean_bn.data)
                    self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
            else:
                mean_bn = torch.autograd.Variable(self.running_mean)
                var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        if self.using_bn:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
            var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
        else:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
            var = var_weight[0] * var_in + var_weight[1] * var_ln

        x = (x-mean) / (var+self.eps).sqrt()
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias


class SwitchNorm3d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True, using_bn=True,
                 last_gamma=False):
        super(SwitchNorm3d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.using_bn = using_bn
        self.last_gamma = last_gamma
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1, 1))
        if self.using_bn:
            self.mean_weight = nn.Parameter(torch.ones(3))
            self.var_weight = nn.Parameter(torch.ones(3))
        else:
            self.mean_weight = nn.Parameter(torch.ones(2))
            self.var_weight = nn.Parameter(torch.ones(2))
        if self.using_bn:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
            self.register_buffer('running_var', torch.zeros(1, num_features, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.using_bn:
            self.running_mean.zero_()
            self.running_var.zero_()
        if self.last_gamma:
            self.weight.data.fill_(0)
        else:
            self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        N, C, D, H, W = x.size()
        x = x.view(N, C, -1)
        mean_in = x.mean(-1, keepdim=True)
        var_in = x.var(-1, keepdim=True)

        mean_ln = mean_in.mean(1, keepdim=True)
        temp = var_in + mean_in ** 2
        var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2

        if self.using_bn:
            if self.training:
                mean_bn = mean_in.mean(0, keepdim=True)
                var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
                if self.using_moving_average:
                    self.running_mean.mul_(self.momentum)
                    self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                    self.running_var.mul_(self.momentum)
                    self.running_var.add_((1 - self.momentum) * var_bn.data)
                else:
                    self.running_mean.add_(mean_bn.data)
                    self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
            else:
                mean_bn = torch.autograd.Variable(self.running_mean)
                var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        if self.using_bn:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
            var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
        else:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
            var = var_weight[0] * var_in + var_weight[1] * var_ln

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.view(N, C, D, H, W)
        return x * self.weight + self.bias






# Metrics

In [15]:
##Metrics
#Dice_Metrics
import torch.nn as nn
import torch
import numpy as np

def dice_coef_metric(inputs, target): # ORIGINAL
    intersection = 2.0 * (target*inputs).sum()
    union = target.sum() + inputs.sum()
    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0 
    return intersection/union


class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
#         inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

def compute_iou(model, loader, device:str, threshold=0.3):
    valloss = 0
    
    with torch.no_grad():

        for i_step, (data, target) in enumerate(loader):
            
            data = data.to(device)
            target = target.to(device)
            
            outputs = model(data)

            out_cut = np.copy(outputs.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < threshold)] = 0.0
            out_cut[np.nonzero(out_cut >= threshold)] = 1.0
            picloss = dice_coef_metric(out_cut, target.data.cpu().numpy())
            valloss += picloss

    return valloss / i_step


In [None]:
# BASE_PATH= "C:\\Users\\jervi\\Downloads\\kaggle_3m"

In [18]:
data_location= "C:\\Users\\jervi\\Downloads\\kaggle_3m"

In [None]:
# bucket = "heartnet-sagemaker-project/dataset/1/"
# data_key = "train.csv"
# data_location = "s3://{}/{}".format(bucket,data_key)

# pd.read_csv(data_location)

In [None]:
# bucket = "heartnet-sagemaker-project/kaggle_3m/"
# data_key = "train.csv"
# data_location = "s3://{}/{}".format(bucket,data_key)

# pd.read_csv(data_location)   

In [19]:
BASE_LEN = len(data_location) + len("\\TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_")
END_LEN = len(".tif")
END_MASK_LEN = len("_mask.tif")

IMG_SIZE = 512

In [None]:
# BASE_LEN = len(BASE_PATH) + len("\\TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_")
# END_LEN = len(".tif")
# END_MASK_LEN = len("_mask.tif")

# IMG_SIZE = 512

# Prepare data

In [20]:
df = get_dataset_dataframe(data_location)

[INFO] This is not a dir --> C:\Users\jervi\Downloads\kaggle_3m\data.csv
[INFO] This is not a dir --> C:\Users\jervi\Downloads\kaggle_3m\README.md


In [None]:
# df = get_dataset_dataframe(BASE_PATH)

In [21]:
df

Unnamed: 0,dir_name,image_path
0,TCGA_CS_4942_19970222,C:\Users\jervi\Downloads\kaggle_3m\TCGA_CS_494...
1,TCGA_CS_4942_19970222,C:\Users\jervi\Downloads\kaggle_3m\TCGA_CS_494...
2,TCGA_CS_4942_19970222,C:\Users\jervi\Downloads\kaggle_3m\TCGA_CS_494...
3,TCGA_CS_4942_19970222,C:\Users\jervi\Downloads\kaggle_3m\TCGA_CS_494...
4,TCGA_CS_4942_19970222,C:\Users\jervi\Downloads\kaggle_3m\TCGA_CS_494...
...,...,...
3013,TCGA_HT_A61A_20000127,C:\Users\jervi\Downloads\kaggle_3m\TCGA_HT_A61...
3014,TCGA_HT_A61A_20000127,C:\Users\jervi\Downloads\kaggle_3m\TCGA_HT_A61...
3015,TCGA_HT_A61A_20000127,C:\Users\jervi\Downloads\kaggle_3m\TCGA_HT_A61...
3016,TCGA_HT_A61A_20000127,C:\Users\jervi\Downloads\kaggle_3m\TCGA_HT_A61...


In [22]:
df.isna().sum()

dir_name      0
image_path    0
dtype: int64

In [23]:
df_imgs = df[~df["image_path"].str.contains("mask")]
df_masks = df[df["image_path"].str.contains("mask")]

In [24]:
df_imgs.iloc[0,1][BASE_LEN: -END_LEN]

'1'

In [25]:
imgs = sorted(df_imgs["image_path"].values, key= lambda x: int((x[BASE_LEN: -END_LEN])))
masks = sorted(df_masks["image_path"].values, key=lambda x: int((x[BASE_LEN: -END_MASK_LEN])))

In [26]:
df_masks.shape

(1526, 2)

In [27]:
# sanity check
idx = random.randint(0, len(imgs)-1)
print(f"This image *{imgs[idx]}*\n Belongs to the mask *{masks[idx]}*")

This image *C:\Users\jervi\Downloads\kaggle_3m\TCGA_FG_7634_20000128\TCGA_FG_7634_20000128_12.tif*
 Belongs to the mask *C:\Users\jervi\Downloads\kaggle_3m\TCGA_DU_8162_19961029\TCGA_DU_8162_19961029_12_mask.tif*


In [28]:
df_imgs.columns

Index(['dir_name', 'image_path'], dtype='object')

In [29]:
df_imgs.dir_name

0       TCGA_CS_4942_19970222
1       TCGA_CS_4942_19970222
3       TCGA_CS_4942_19970222
5       TCGA_CS_4942_19970222
7       TCGA_CS_4942_19970222
                ...          
3006    TCGA_HT_A61A_20000127
3008    TCGA_HT_A61A_20000127
3009    TCGA_HT_A61A_20000127
3011    TCGA_HT_A61A_20000127
3016    TCGA_HT_A61A_20000127
Name: dir_name, Length: 1492, dtype: object

In [30]:
len(imgs),len(masks)

(1492, 1526)

In [36]:
# final dataframe
dff = pd.DataFrame({"patient": df_imgs.dir_name.values,
                   "image_path": imgs,
                   "mask_path": masks})

dff.head()

ValueError: All arrays must be of the same length

In [None]:
dff.iloc[0,1]

In [None]:
dff.iloc[0,2]


In [None]:
dff["diagnosis"] = dff["mask_path"].apply(lambda x: pos_neg_diagnosis(x))

dff.head()

In [None]:
dff.shape


In [None]:
dff.diagnosis.value_counts()

In [None]:
print("Amount of patients: ", len(set(dff.patient)))
print("Amount of records: ", len(dff))

# Data Augmentation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensor

from sklearn.model_selection import train_test_split

In [None]:
import platform
import sklearn as sk
import sys
has_gpu = torch.cuda.is_available()
has_mps = getattr(torch,'has_mps',False)
device = "mps" if getattr(torch,'has_mps',False) \
    else "gpu" if torch.cuda.is_available() else "cpu"

print(f"Python Platform: {platform.platform()}")
print(f"PyTorch Version: {torch.__version__}")
print()
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print(f"Scikit-Learn {sk.__version__}")
print("GPU is", "available" if has_gpu else "NOT AVAILABLE")
print("MPS (Apple Metal) is", "AVAILABLE" if has_mps else "NOT AVAILABLE")
print(f"Target device is {device}")

# Split Data and DataLoaders

In [None]:
train_df, val_df = train_test_split(dff, stratify=dff.diagnosis, test_size=0.1)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

train_df, test_df = train_test_split(train_df, stratify=train_df.diagnosis, test_size=0.12)
train_df = train_df.reset_index(drop=True)

print(f"Train: {train_df.shape} \nVal: {val_df.shape} \nTest: {test_df.shape}")

In [None]:
train_dataset = MRImagingDataset(train_df, transform=augmentData.transform)
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)

val_dataset = MRImagingDataset(val_df, transform=augmentData.transform)
val_dataloader = DataLoader(val_dataset, batch_size=10,  shuffle=True)

test_dataset = MRImagingDataset(test_df, transform=augmentData.transform)
test_dataloader = DataLoader(test_dataset, batch_size=10,shuffle=True)

In [None]:
len(train_dataset)

In [None]:
len(val_dataset)

In [None]:
len(test_dataset)

In [None]:
images, masks = next(iter(train_dataloader))
print(images.shape, masks.shape)

show_aug(images)
show_aug(masks, norm=False)

In [None]:
# check sanity
output = torch.randn(1,3,256,256).to(device)
output.shape

In [None]:
# sanity check
DiceLoss()(F.sigmoid(torch.tensor([0.7, 1., 1.])), 
              torch.tensor([1.,1.,1.]))

# U-Net 3+

In [None]:
unet3p = UNet3Plus(n_classes=1).to(device)


In [None]:
# PATH = "model_unet_3p.pt"

In [None]:
# opt_unet_3p = torch.optim.Adamax(unet3p.parameters(), lr=1e-3)

In [None]:
opt_unet_3p = torch.optim.ASGD(unet3p.parameters(), lr=1e-3, lambd= 0.04, alpha=0.05)

In [None]:
# %%time
num_ep = 30
# try until 30

# aun_lh, aun_th, aun_vh = train_model("Attention UNet", attention_unet, train_dataloader, val_dataloader, DiceLoss(), opt, False, num_ep)
aun_lh, aun_th, aun_vh = train_model("UNet3p", unet3p, train_dataloader, val_dataloader, DiceLoss(), opt_unet_3p, False, num_ep, device=device)

In [None]:
plot_model_history("U-Net 3+", aun_th, aun_vh, num_ep)

In [None]:
plt.plot(range(num_ep), aun_lh)

In [None]:
test_iou = compute_iou(unet3p, test_dataloader,device)
print(f"""U-Net 3+\nMean IoU of the test images - {np.around(test_iou, 2)*100}%""")

In [None]:
test_dataset.df[test_dataset.df.diagnosis==1]

In [None]:
test_dataset[1][1].sum()

In [None]:
idx = 1669
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx = 2032
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx = 1819
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =691
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =2243
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =737
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:

targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =2194
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =459
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =1440
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
idx =2616
plt.imshow(test_dataset.get_image_and_mask(idx)[0].T)

In [None]:
targ, op = viz_pred_output(model,test_dataloader, idx, test_dataset, device)

In [None]:
f"{np.around(dice_coef_metric(op, targ), 2)*100}%"

In [None]:
#==============================

## U-NET 3+ with ResidualBlock With Attention (UNet_3Plus_ResAttn)

In [None]:

unet_3plus_resattn = UNet_3Plus_ResAttn(n_classes=1).to(device)

In [None]:
# Training
# PATH = "model_unet_3p_attn.pt"



In [None]:
opt_unet_3plus_resattn = torch.optim.Adamax(unet_3plus_resattn.parameters(), lr=1e-4)


In [None]:
# ckp_path = "path/to/checkpoint/checkpoint.pt"


In [None]:
# %%time
num_ep = 3
# try until 30

# aun_lh, aun_th, aun_vh = train_model("Attention UNet", attention_unet, train_dataloader, val_dataloader, DiceLoss(), opt, False, num_ep)
aun_lh, aun_th, aun_vh = train_model("UNet_3p_attn", unet3p_attn, train_dataloader, val_dataloader, DiceLoss(), opt_unet_3plus_resattn, False, num_ep, device=device)

In [None]:
plot_model_history(" U-Net 3+ ResAttn", aun_th, aun_vh, num_ep)

In [None]:
plt.plot(range(num_ep), aun_lh)

In [None]:
test_iou = compute_iou(unet_3plus_resattn, test_dataloader, device)
print(f"""U-Net 3+ Residual Attention\nMean IoU of the test images - {np.around(test_iou, 2)*100}%""")

In [None]:
aun_lh_prev, aun_th_prev, aun_vh_prev = aun_lh, aun_th, aun_vh

In [None]:
last_num_ep = num_ep

In [None]:
def save_metadata():
    #Saving initial history
    np.savetxt('metadata/unet_3plus_resattn/aun_lh_prev.txt', aun_lh_prev)
    np.savetxt('metadata/unet_3plus_resattn/aun_th_prev.txt', aun_th_prev)
    np.savetxt('metadata/unet_3plus_resattn/aun_vh_prev.txt', aun_vh_prev)


In [None]:
save_metadata()

In [None]:
# CONTINUE WHERE LEFT OFF / TRAIN MORE:

In [None]:
aun_lh_prev = np.loadtxt('metadata/unet_3plus_resattn/aun_lh_prev.txt')
aun_th_prev = np.loadtxt('metadata/unet_3plus_resattn/aun_th_prev.txt')
aun_vh_prev = np.loadtxt('metadata/unet_3plus_resattn/aun_vh_prev.txt')
last_num_ep = len(aun_lh_prev)

In [None]:
for_x_more_epochs = 10

In [None]:
save_metadata()

In [None]:
plot_model_history("U-Net 3+ with Attention Gates", aun_th_prev, aun_vh_prev, last_num_ep)
# plt.savefig(f'UNet_3p_attn_dice.png', bbox_inches='tight')

In [None]:
plt.plot(range(last_num_ep), aun_lh_prev)
plt.savefig(f'UNet_3p_attn_loss.png', bbox_inches='tight')

In [None]:
test_iou = compute_iou(unet3p_attn, test_dataloader, device)
print(f"""U-Net 3+ Attention\nMean IoU of the test images - {np.around(test_iou, 2)*100}%""")