**Import the Libraries**

In [1]:
# Import the libraries
#from Starscream import TransUNet

import pandas as pd
import numpy as np
import time
import os
import random
import math
import matplotlib.pyplot as plt
# NOTE(review): a later cell runs `from tqdm import tqdm`, which rebinds this name to the plain-console tqdm
from tqdm.notebook import tqdm
from glob import glob

from IPython import display
from matplotlib.patches import Rectangle

# NOTE(review): `time` is already imported above; this duplicate is harmless
import time
import copy
from collections import defaultdict
import h5py

# Import the machine learning libraries
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from torch.cuda import amp
# NOTE(review): duplicate of the `import sklearn` above
import sklearn
from sklearn.model_selection import train_test_split, StratifiedGroupKFold
import tensorflow as tf
import segmentation_models_pytorch as smp

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from einops import rearrange

# For exponential moving average
from ema_pytorch import EMA

# Colored terminal output
from colorama import Fore, Back, Style

# Import the monitoring libraries
import wandb
# Side effect: authenticates with Weights & Biases when this cell runs
wandb.login()

# Import the warnings module and suppress all warning messages
import warnings
warnings.filterwarnings('ignore')

# For debugging
# For RAM usage
import gc # automatic garbage collector
# Debug aid: synchronous CUDA kernel launches (clearer error locations, slower training).
# NOTE(review): only takes effect if set before CUDA is initialized — confirm ordering.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# NOTE(review): this is set *after* wandb.login() above; the captured output shows wandb
# could not detect the notebook name — consider setting it before login.
os.environ["WANDB_NOTEBOOK_NAME"] = "shockwave.ipynb"
# In case we run into an out-of-memory problem

# Import the augmentation libraries
import albumentations as A
from albumentations.pytorch import ToTensorV2

# REMEMBER TO MOVE THIS BEFORE PUBLISHING!
import sys

2023-08-21 20:33:11.673577: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myecanlee[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
#print(torch.cuda.device_count())

#if torch.cuda.device_count() > 1:
    #print("Using", torch.cuda.device_count(), "GPUs!")
    #model = nn.DataParallel(model)

**Define the Configuration**

In [3]:
class CFG:
    """Central experiment configuration, read as class attributes (e.g. ``CFG.epochs``)."""

    # --- Weights & Biases logging ---------------------------------------------
    wandb = True

    verbose = 1

    # --- Reproducibility / cross-validation ---------------------------------
    seed = 42
    fold = 10
    selected_folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    image_size = (512, 512)  # 224

    # --- Augmentation probabilities (used by the augmentation cell) ----------
    P1 = 1.0
    P2 = 0.5
    P3 = 1.0

    # --- Data loading ----------------------------------------------------------
    train_batch_size = 14
    valid_batch_size = 32
    test_batch_size = 32
    drop_reminder = False  # (sic) "drop remainder": drop the last incomplete training batch
    epochs = 50
    train_num_workers = 0
    test_num_workers = 0

    anonymous = None

    # --- Optimization ------------------------------------------------------------
    loss = 'dice_loss'
    optimizer = 'AdamW'
    # Scheduler name: 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts',
    # 'ReduceLROnPlateau', 'ExponentialLR', or None
    lr_scheduler = 'CosineAnnealingLR'
    lr_min = 1e-6  # for CosineAnnealingWarmRestarts and CosineAnnealingLR
    T_max = int(epochs * 0.7)  # for CosineAnnealingLR (in scheduler steps)
    # For CosineAnnealingWarmRestarts. PyTorch requires a positive integer here;
    # the previous value 0 raised ValueError as soon as that scheduler was selected.
    T_0 = 10
    patience = 5  # for ReduceLROnPlateau
    mode = 'min'  # for ReduceLROnPlateau
    threshod = 0.0001  # (sic) for ReduceLROnPlateau; name kept for backward compatibility
    threshold = threshod  # correctly spelled alias of ``threshod``
    gamma = 0.5  # for ExponentialLR
    weight_decay = 1e-5

    DEBUG = False
    # Number of segmentation output channels (a single foreground channel)
    num_classes = 1

    # --- Training device ---------------------------------------------------------
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    n_accumulate = 1  # gradient accumulation steps

    # Run description; previously also assigned earlier in the class, where it was
    # immediately overwritten by this value.
    comment = 'TransUNet-Tested version 512-512'
    model_name = 'TransUNet Test Version'
    MAX_PIXEL_VALUE = 255

**Reproducibility**

In [4]:
# Define the reproducibility seed
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to ``random``, ``numpy`` and ``torch`` (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)

    # The attribute was previously misspelled ``determinitic``. That only created a new,
    # unused attribute, so cuDNN stayed non-deterministic.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Only affects Python processes started after this point; hashing in the
    # already-running interpreter is not changed.
    os.environ['PYTHONHASHSEED'] = str(seed)

    print('Everything is set up for the reproducibility.')

In [5]:
# Apply the configured global seed before any data loading or model construction
seed_everything(seed=CFG.seed)

Everything is set up for the reproducibility.


**Data Augmentation**

**Augmentation Method**

In [None]:
# https://albumentations.ai/docs/getting_started/mask_augmentation/
# The training augmentation pipeline below is controlled by three probabilities (CFG.P1, CFG.P2, CFG.P3):
# P1: probability that the 'OneOf' distortion group is applied at all
# P2: probability that each individual flip / rotate-90 / transpose transform is applied
# P3: selection weight of OpticalDistortion inside the 'OneOf' group (relative to the p values of the other members)

In [6]:
# Augmentation probabilities, taken from the configuration
P1, P2, P3 = CFG.P1, CFG.P2, CFG.P3

# Normalization statistics for single-channel inputs.
# (ImageNet RGB statistics, kept for reference: mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
mean = (0.5,)
std = (0.5,)
class AlbumentationsTransform:
    """Adapt an albumentations pipeline to the ``{'image': ..., 'label': ...}`` sample dicts.

    The label is passed through albumentations' ``mask`` target. The returned dict
    contains only ``'image'`` and ``'label'``; any other keys of the input are dropped.
    """

    def __init__(self, transforms):
        # Wrap the list of transforms in a single Compose pipeline
        self.transforms = A.Compose(transforms)

    def __call__(self, sample):
        result = self.transforms(image=sample['image'], mask=sample['label'])
        return {'image': result['image'], 'label': result['mask']}

# Training pipeline, applied in list order: normalize; random flips, rotate-90 and transpose;
# then (with probability P1) exactly one non-rigid distortion; finally tensor conversion.
train_transform = AlbumentationsTransform([
        # Normalize with the single-channel mean/std above; always applied.
        # NOTE(review): max_pixel_value=255.0 assumes 0-255 inputs. If the .npz slices are
        # already scaled to [0, 1], this maps them to roughly [-1, -0.99] — confirm the data range.
        A.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        # A.Resize(224, 224, interpolation = cv2.INTER_NEAREST)

        # Horizontal and vertical flips, each with probability P2
        A.HorizontalFlip(p = P2),
        A.VerticalFlip(p = P2),

        # Rotation by a random multiple of 90 degrees with probability P2
        A.RandomRotate90(p = P2),

        # Transpose (swap height and width axes) with probability P2
        A.Transpose(p = P2),

        # With probability P1 apply exactly one of the following distortions;
        # the members' p values act as relative selection weights
        A.OneOf([
            A.ElasticTransform(alpha = 120, sigma = 120 * 0.05, alpha_affine = 120 * 0.03, p=0.5),
            A.GridDistortion(p = 0.5),
            A.OpticalDistortion(distort_limit = 2, shift_limit = 0.5, p = P3)
        ], p = P1),
        # Convert image and mask to torch tensors
        ToTensorV2()
    ])
# Validation/test pipeline: normalization and tensor conversion only (no random augmentation)
valid_transform = AlbumentationsTransform([
        A.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2()
        # A.Resize(224, 224, interpolation = cv2.INTER_NEAREST)
    ])

import os
import numpy as np
import h5py
from torch.utils.data import Dataset

class Synapse_dataset(Dataset):
    """Synapse dataset reader.

    Non-test splits load ``<base_dir>/<name>.npz``; the ``"test"`` split loads
    ``<base_dir>/<name>.npy.h5``. Both file types must contain ``'image'`` and ``'label'``.

    Args:
        base_dir: Directory containing the sample files.
        samples: Optional explicit list of sample names; used instead of the list file when truthy.
        list_dir: Directory containing ``<split>.txt`` (one sample name per line); used when
            ``samples`` is falsy.
        split: ``"test"`` selects the HDF5 reader; any other value selects the ``.npz`` reader.
        transform: Optional callable applied to the ``{'image', 'label'}`` dict.

    Each item is a dict with ``'image'``, ``'label'`` (possibly transformed) and ``'case_name'``.
    """

    def __init__(self, base_dir, samples=None, list_dir=None, split="train", transform=None):
        self.transform = transform
        self.split = split
        if samples:
            self.sample_list = samples
        else:
            # Close the list file immediately instead of leaking the handle
            with open(os.path.join(list_dir, split + '.txt')) as list_file:
                self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        # Drop the newline kept by readlines() (also '\r' from Windows-style list files)
        slice_name = self.sample_list[idx].strip('\r\n')
        if self.split != "test":
            data_path = os.path.join(self.data_dir, slice_name + '.npz')
            # NpzFile keeps the archive open until closed; read both arrays inside the context
            with np.load(data_path) as data:
                image, label = data['image'], data['label']
        else:
            filepath = self.data_dir + "/{}.npy.h5".format(slice_name)
            # Open read-only and close the handle once the arrays are copied into memory via [:]
            with h5py.File(filepath, 'r') as data:
                image, label = data['image'][:], data['label'][:]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = slice_name
        return sample


In [7]:
from sklearn.model_selection import train_test_split

def prepare_synapse_loaders(base_dir: str, list_dir: str, debug: bool = False,
                            test_base_dir: str = None) -> tuple:
    """
    Prepare the training and test data loaders for the Synapse dataset.

    Parameters:
    - base_dir: Directory containing the training ``.npz`` slices.
    - list_dir: Directory containing ``train.txt`` and ``test.txt``.
    - debug: If True, both loaders use a batch size of 4 (default False).
    - test_base_dir: Directory containing the test ``.npy.h5`` files. Defaults to
      ``base_dir`` (the previous behaviour).

    Returns:
    - Tuple ``(train_loader, test_loader)``.
    """
    # NOTE(review): the test split reads ``<dir>/<name>.npy.h5``. If those files do not live
    # next to the training slices, pass their folder via ``test_base_dir``.
    if test_base_dir is None:
        test_base_dir = base_dir

    # A separate validation split could be carved from train.txt with train_test_split.
    train_dataset = Synapse_dataset(base_dir=base_dir, list_dir=list_dir, split="train", transform=train_transform)
    test_dataset = Synapse_dataset(base_dir=test_base_dir, list_dir=list_dir, split="test", transform=valid_transform)

    # Batch sizes, worker counts and drop_last now come from CFG (previously partly hard-coded,
    # and the test loader used the validation batch size)
    train_bs = CFG.train_batch_size if not debug else 4
    test_bs = CFG.test_batch_size if not debug else 4

    train_loader = DataLoader(train_dataset, batch_size=train_bs,
                              num_workers=CFG.train_num_workers, shuffle=True, pin_memory=True,
                              drop_last=CFG.drop_reminder)
    test_loader = DataLoader(test_dataset, batch_size=test_bs,
                             num_workers=CFG.test_num_workers, shuffle=False, pin_memory=True)

    return train_loader, test_loader

**Get the train and test data loaders**

In [8]:
# Build the loaders from the Synapse training slices and the TransUNet list files
train_loader, test_loader = prepare_synapse_loaders(
    base_dir="project_TransUNet/data/Synapse/train_npz",
    list_dir="project_TransUNet/TransUNet/lists/lists_Synapse",
    debug=False,
)

**Visualize Some Images**

In [None]:
# Disabled visualization snippet (kept as an unused string literal so it does not run).
# When enabled it displays every image and every label of the first training batch.
# NOTE(review): the batch entries are already tensors, so torch.tensor(...) on them
# would only make copies (and PyTorch warns about it).
'''
batch = next(iter(train_loader))

# Get images and labels from the batch
images = batch['image']
print(images.shape)
labels = batch['label']

# Convert images to a PyTorch tensor
images_tensor = torch.tensor(images)
labels_tensor = torch.tensor(labels)

# Iterate over the images and print them one by one
for i in range(images_tensor.size(0)):
    image = images_tensor[i]
    image_np = image.permute(1, 2, 0).numpy()

    plt.imshow(image_np,cmap='gray')
    plt.axis('off')
    plt.show()
    
# Iterate over the images and print them one by one
for i in range(labels_tensor.size(0)):
    image_1 = labels_tensor[i]
    image_11 = image_1.numpy()

    plt.imshow(image_11,cmap='gray')
    plt.axis('off')
    plt.show()
    '''

In [None]:
# Disabled visualization snippet (kept as an unused string literal so it does not run).
# NOTE(review): it iterates `valid_loader`, which prepare_synapse_loaders does not create
# (it returns only train and test loaders). Re-enable only after restoring a validation loader.
'''
for i_batch, sampled_batch in enumerate(valid_loader):
    images = sampled_batch['image']

    # Convert images to a PyTorch tensor
    images_tensor = torch.tensor(images)

    # Iterate over the images and print them one by one
    for i in range(images_tensor.size(0)):
        image = images_tensor[i]

        # Select a specific channel (e.g., the first channel)
        channel_image = image[0]  # Select the first channel

        # Convert channel image to a NumPy array
        image_np = channel_image.numpy()

        # Display the image using Matplotlib
        plt.imshow(image_np,cmap='gray')  # Assuming it's a grayscale image
        plt.axis('off')

    # Show all images in this batch
    plt.show()
    '''

**Define the Model**

In [None]:
from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
from tqdm import tqdm

from einops import rearrange
from sklearn.metrics import roc_auc_score

In [None]:
class EncoderBottleneck(nn.Module):
    """ResNet-style bottleneck (1x1 -> 3x3 -> 1x1 conv, each with BatchNorm) plus a projection shortcut.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Spatial stride of both the 3x3 convolution and the 1x1 shortcut projection.
        factor: Width factor; the bottleneck width is ``int(out_channels * factor / 64)``.
    """

    def __init__(self, in_channels, out_channels, stride=1, factor=64):
        super().__init__()

        # 1x1 projection so the shortcut matches the main path's channels and resolution
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        expansion = int(out_channels * (factor / 64))

        self.conv1 = nn.Conv2d(in_channels, expansion, kernel_size=1, stride=1, bias=False)
        self.norm1 = nn.BatchNorm2d(expansion)

        # Use the same stride as the shortcut. A hard-coded stride of 2 made the residual
        # addition fail with a shape mismatch whenever stride != 2.
        self.conv2 = nn.Conv2d(expansion, expansion, kernel_size=3, stride=stride, groups=1, padding=1, dilation=1, bias=False)
        self.norm2 = nn.BatchNorm2d(expansion)

        self.conv3 = nn.Conv2d(expansion, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x_down = self.downsample(x)

        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        x = self.norm3(self.conv3(x))

        # Residual connection, then the final activation
        x = x + x_down
        x = self.relu(x)

        return x

In [None]:
# Download the vit model from torchvision

from torchvision.models import vit_b_16
from torchvision.models import ViT_B_16_Weights

torch_vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1, progress=True)  # downloads the weights on first use

class PretrainedViTEncoder(nn.Module):
    """Patch-embedding stage borrowed from a torchvision ViT.

    ``forward`` applies only ``self.project`` (initially the ViT's ``conv_proj``) and
    flattens the resulting feature map into a token sequence. The ViT itself is kept
    as the ``vit`` submodule, but its transformer encoder is not called.

    Args:
        vit: Optional ViT-like module exposing ``conv_proj``. Defaults to the module-level
            ``torch_vit``. It is deep-copied so each encoder owns independent parameters.
    """

    def __init__(self, vit=None):
        super(PretrainedViTEncoder, self).__init__()
        # Deep-copy: storing the shared ``torch_vit`` directly made every instance (e.g. one
        # model per fold) alias the same parameters. In-place changes such as replacing
        # ``vit.encoder.pos_embedding`` then leaked across models.
        self.vit = copy.deepcopy(torch_vit if vit is None else vit)
        self.project = self.vit.conv_proj

    def forward(self, x):
        """Project a 4-D feature map and return tokens of shape ``(B, H' * W', channels)``."""
        x = self.project(x)
        # Same as rearrange(x, "b c w h -> b (w h) c"); made contiguous like the einops result
        return x.flatten(2).transpose(1, 2).contiguous()

In [None]:
class FocusedLinearAttention(nn.Module):
    """Focused linear attention over a square grid of tokens.

    Keys get a learned positional encoding. Queries and keys then go through a ReLU kernel,
    are divided by a learnable softplus scale, and are raised to ``focusing_factor`` while
    keeping their original norms. Attention is computed either as ``q (k^T v)`` or
    ``(q k^T) v``, whichever is cheaper for the current shapes. A depthwise convolution over
    ``v``, reshaped to the token grid, is added to the output.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``.
        num_patches: Sequence length ``N`` expected in ``forward``; must be a perfect square.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections use a bias.
        qk_scale, attn_drop: Accepted for interface compatibility; not used in ``forward``.
        proj_drop: Dropout after the output projection.
        focusing_factor: Exponent applied to the kernelized queries/keys.
        kernel_size: Kernel size of the depthwise convolution.

    NOTE(review): a later cell imports ``FocusedLinearAttention`` from the
    ``focused_linear_attention`` module, which rebinds this name.
    """

    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.2, proj_drop=0.2,
                 focusing_factor=3, kernel_size=5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.focusing_factor = focusing_factor
        # Depthwise convolution; padding = kernel_size // 2 keeps the spatial size (odd kernels)
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        # Learned positional encoding added to the keys; its length must match N in forward
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches, dim)))

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(B, N, dim)``; returns the same shape."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        # Compute all three projections. Previously k and v were never computed, so this
        # method raised UnboundLocalError on its first use of ``k``.
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        k = k + self.positional_encoding
        focusing_factor = self.focusing_factor
        scale = F.softplus(self.scale)
        q = F.relu(q) + 1e-6
        k = F.relu(k) + 1e-6
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        q = q ** focusing_factor
        k = k ** focusing_factor
        # Restore the pre-focusing norms so only the direction is sharpened
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm

        def _split_heads(t):
            # "b n (h c) -> (b h) n c"
            return t.reshape(B, N, heads, head_dim).permute(0, 2, 1, 3).reshape(B * heads, N, head_dim)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]

        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        # Pick the cheaper association order: linear in N when N is large relative to head_dim
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            out = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            out = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)

        side = int(N ** 0.5)
        if side * side != N:
            raise ValueError(f"FocusedLinearAttention expects a square token grid, got N={N}")
        # "b (w h) c -> b c w h" on v, depthwise conv, then back to "b (w h) c".
        # This adds a rank-restoring local term to the low-rank linear attention output.
        feature_map = v.reshape(B * heads, side, side, head_dim).permute(0, 3, 1, 2)
        feature_map = self.dwc(feature_map).flatten(2).transpose(1, 2)
        out = out + feature_map

        # "(b h) n c -> b n (h c)"
        out = out.reshape(B, heads, N, head_dim).permute(0, 2, 1, 3).reshape(B, N, C)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class MLP(nn.Module):
    """Transformer feed-forward network.

    Linear(dim -> mlp_dim) -> GELU -> Dropout(0.1) -> Linear(mlp_dim -> dim) -> Dropout(0.1),
    stored as ``self.mlp_layers``.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        layers = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(0.1),
        ]
        self.mlp_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp_layers(x)

In [None]:
from performer import Attention
from longformer import Long2DSCSelfAttention
from linearformer import LinearAttention
from nystroem import NystromAttention
from efficient_attention import EfficientAttention #
from focused_linear_attention import FocusedLinearAttention
from linformer import LinformerAttention #

class Linear_ViT_EncoderBlock(nn.Module):
    """Post-norm transformer block: x = LN(x + Dropout(attn(x))); x = LN(x + MLP(x)).

    Args:
        dim: Token embedding size (also the LayerNorm width).
        num_patches: Number of tokens; used only by some of the alternative attention options.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, focusing_factor, kernel_size:
            Only used by the commented-out FocusedLinearAttention option. The active
            ``Attention`` module is built from ``dim`` alone.
        mlp_dim: Hidden width of the feed-forward network.
    """

    def __init__(self, dim, num_patches, num_heads = 8, qkv_bias = False, qk_scale = None, attn_drop = 0, proj_drop = 0,
                 focusing_factor = 3, kernel_size = 5, mlp_dim = 2048):
        super().__init__()

        # Uncomment the attention that you want to use
        self.attention = Attention(dim = dim)
        # self.attention = Long2DSCSelfAttention(dim = dim, num_patches = num_patches) # contains an error
        # self.attention = LinearAttention(dim = dim, num_patches = 1024)
        # self.attention = NystromAttention(
        #     dim = 768,
        #     dim_head = 96,
        #     heads = 8,
        #     num_landmarks = 256,    # number of landmarks
        #     pinv_iterations = 6,    # number of Moore-Penrose iterations for approximating the pseudo-inverse; 6 was recommended by the paper
        #     residual = True         # whether to add an extra residual with the value; supposedly faster convergence
        # )
        # self.attention = EfficientAttention(dim = dim, patches = num_patches)
        # self.attention = FocusedLinearAttention(dim, num_patches, num_heads = 8, qkv_bias = False,
        #                                         qk_scale = None, attn_drop = 0, proj_drop = 0, focusing_factor = 3, kernel_size = 5)
        # self.attention = LinformerAttention(dim)

        # FFN bottleneck layer: expand to mlp_dim, then shrink back to dim
        self.mlp = MLP(dim, mlp_dim)

        # Normalize over ``dim``; the previous hard-coded 768 broke every other embedding size
        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Assumes x is (batch_size, num_patches, dim); the LayerNorms act on the last axis
        _x = self.dropout(self.attention(x))
        x = self.layer_norm1(x + _x)

        _x = self.mlp(x)
        x = self.layer_norm2(x + _x)
        return x

# Encoder for the linear-attention vision transformer: a stack of Linear_ViT_EncoderBlock layers
class Linear_ViT_Encoder(nn.Module):
    """Apply ``block_num`` Linear_ViT_EncoderBlock layers sequentially to a token sequence.

    All constructor arguments are forwarded to every block. They were previously
    hard-coded, so caller-supplied values were silently ignored. ``mlp_dim`` is a new
    optional argument whose default matches the block's own default.
    """

    def __init__(self, dim, num_patches, num_heads = 8, qkv_bias = False, qk_scale = None, attn_drop = 0, proj_drop = 0,
                 focusing_factor = 3, kernel_size = 5, block_num=5, mlp_dim=2048):
        super().__init__()

        self.layer_blocks = nn.ModuleList([
            Linear_ViT_EncoderBlock(dim, num_patches, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                    attn_drop=attn_drop, proj_drop=proj_drop, focusing_factor=focusing_factor,
                                    kernel_size=kernel_size, mlp_dim=mlp_dim)
            for _ in range(block_num)
        ])

    def forward(self, x):
        # Token shape is preserved: (batch, num_patches, dim) in and out
        for layer_block in self.layer_blocks:
            x = layer_block(x)
        return x

In [None]:
class TransUNetEncoder(nn.Module):
    """Hybrid CNN + transformer encoder.

    A 7x7 stride-2 stem and three stride-2 bottlenecks downsample the input 16x. The
    feature map is 1x1-projected to ``model_dim`` tokens and passed through the
    linear-attention transformer. It is then reshaped back to a square grid and
    reduced to 512 channels.

    Args:
        in_channels: Input image channels.
        out_channels: Stem width; the bottlenecks produce 2x, 4x and 8x this many channels.
        model_dim: Transformer token embedding size.
        num_patches: Number of tokens; must be a perfect square matching the downsampled grid.
        num_heads: Attention heads, forwarded to the transformer.

    ``forward`` returns ``(x, x1, x2, x3)``: the 512-channel bottleneck features followed by
    the stem, encoder1 and encoder2 outputs used as skip connections.
    """

    def __init__(self, in_channels, out_channels, model_dim, num_patches, num_heads = 8):
        super().__init__()

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.num_patches = num_patches

        # Stem: 7x7 stride-2 convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Three stride-2 bottlenecks, doubling the channel count each time
        self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)
        self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)
        self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)

        self.pretrained_vit = PretrainedViTEncoder()
        # NOTE(review): this positional embedding is not used in the forward pass, because the
        # ViT's transformer encoder is never called.
        self.pretrained_vit.vit.encoder.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, model_dim))
        # Project CNN features to ``model_dim`` (previously hard-coded to 768, which broke
        # any other transformer width)
        self.pretrained_vit.project = nn.Conv2d(out_channels * 8, model_dim, kernel_size=(1, 1), stride=(1, 1))

        self.linear_vit = Linear_ViT_Encoder(self.model_dim, num_patches, num_heads=num_heads, block_num=5)

        self.conv2 = nn.Conv2d(model_dim, 512, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(512)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x1 = self.relu(x)

        x2 = self.encoder1(x1)
        x3 = self.encoder2(x2)
        x = self.encoder3(x3)

        # Tokens -> transformer -> back to a square (e x f) feature grid
        x = self.pretrained_vit(x)
        x = self.linear_vit(x)
        e = f = int(self.num_patches ** 0.5)
        x = rearrange(x, "b (e f) c -> b c e f", e = e, f = f)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)

        return x, x1, x2, x3

In [None]:
class DecoderBottleneck(nn.Module):
    """Decoder stage: bilinear upsample, optional skip concatenation, then two 3x3 conv-BN-ReLU layers.

    Args:
        in_channels: Channels of the tensor being upsampled.
        out_channels: Output channels.
        skip_channels: Channels of the skip tensor concatenated after upsampling (0 for none).
        scale_factor: Upsampling factor, used when ``size`` is None.
        size: Optional fixed output spatial size ``(H, W)`` for the upsampling. It takes
            precedence over ``scale_factor`` (previously it was accepted but ignored).
    """

    def __init__(self, in_channels, out_channels, skip_channels, scale_factor=2, size=None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels
        self.scale_factor = scale_factor
        self.size = size

        self.layer = nn.Sequential(
            nn.Conv2d(self.in_channels + self.skip_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, x_concat=None):
        if self.size is not None:
            x = F.interpolate(x, size=self.size, mode='bilinear', align_corners=True)
        else:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)

        # Skip features go first along the channel axis
        if x_concat is not None:
            x = torch.cat([x_concat, x], dim=1)

        x = self.layer(x)
        return x

In [None]:
class TransUNetDecoder(nn.Module):
    """U-Net style decoder with three skip connections and a 1x1 segmentation head.

    Args:
        out_channels: Base decoder width; the input to ``decoder1`` must have ``out_channels * 8`` channels.
        class_num: Number of output channels of the segmentation head.
        scale_factor: Factor used to upsample the skips in the ``deep_sup`` branch.
        skip_channels_list: Channels of the skips ``(x3, x2, x1)``, i.e. the deepest first.
            The default matches TransUNetEncoder with a stem width of 64.
    """

    def __init__(self, out_channels, class_num=1, scale_factor=2, skip_channels_list=(256, 128, 64)):
        super().__init__()
        self.out_channels = out_channels
        self.scale_factor = scale_factor

        skip3, skip2, skip1 = skip_channels_list

        self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2 + skip3, skip_channels=skip3)
        self.decoder2 = DecoderBottleneck(out_channels * 2 + skip3, out_channels + skip2, skip_channels=skip2)
        self.decoder3 = DecoderBottleneck(out_channels + skip2, int(out_channels * 1 / 2), skip_channels=skip1)

        # The last decoder stage has no skip connection
        self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 4), skip_channels=0)

        # 1x1 convolution producing the segmentation logits
        self.seg_layer = nn.Conv2d(int(out_channels * 1 / 4), class_num, kernel_size=1, stride=1, padding=0)

    def forward(self, x, x1, x2, x3, deep_sup=False):
        if deep_sup:
            # NOTE(review): this branch upsamples each skip by scale_factor before concatenation,
            # so it expects skips at half the resolution used by the default branch. With
            # TransUNetEncoder outputs the first torch.cat sees mismatched spatial sizes —
            # confirm the intended inputs before using deep_sup=True.
            x3 = F.interpolate(x3, scale_factor=self.scale_factor, mode='nearest')
            x = self.decoder1(x, x3)
            x2 = F.interpolate(x2, scale_factor=self.scale_factor, mode='nearest')
            x = self.decoder2(x, x2)
            x1 = F.interpolate(x1, scale_factor=self.scale_factor, mode='nearest')
            x = self.decoder3(x, x1)
            return x

        x = self.decoder1(x, x3)
        x = self.decoder2(x, x2)
        x = self.decoder3(x, x1)
        x = self.decoder4(x)
        x = self.seg_layer(x)
        return x

In [None]:
# Whole model: encoder + decoder

class TransUNet(nn.Module):
    """TransUNet segmentation model (hybrid CNN/transformer encoder + U-Net decoder).

    Args:
        in_channels: Input image channels.
        encoder_out_channels: NOTE(review): currently unused. The encoder stem width is
            fixed at 64, and the default of 512 here does not describe the built architecture.
        decoder_out_channels: Base width of the decoder.
        head_num: Attention heads, forwarded to the encoder.
        model_dim: Transformer embedding size.
        mlp_dim: Currently unused (the transformer blocks use their own MLP width).
        block_num: Currently unused (the encoder builds 5 transformer blocks).
        num_patches: Number of transformer tokens (must match the encoder's downsampled grid).
        class_num: Output channels of the segmentation head.

    All forwarded arguments default to the values that were previously hard-coded.
    """

    def __init__(self, in_channels = 1, encoder_out_channels = 512, decoder_out_channels = 64, head_num = 8, model_dim = 768, mlp_dim = 1536, block_num = 5, num_patches = 1024, class_num = 1):
        super().__init__()
        self.decoder_out_channels = decoder_out_channels

        # The stem width stays 64 because the decoder's default skip channels (256, 128, 64) assume it
        self.encoder = TransUNetEncoder(in_channels=in_channels, out_channels=64, model_dim=model_dim,
                                        num_patches=num_patches, num_heads=head_num)

        self.decoder = TransUNetDecoder(out_channels=self.decoder_out_channels, class_num=class_num)

    def forward(self, x):
        # The encoder returns the bottleneck features plus three skip connections
        x, x1, x2, x3 = self.encoder(x)
        return self.decoder(x, x1, x2, x3)

In [None]:
# Initialize loss functions.
# These losses come from segmentation_models_pytorch and are used directly.
# 'binary' mode: the model has one output channel, which the training/validation loops
# squeeze away, so logits and masks arrive as (B, H, W). In 'multilabel' mode smp reads
# dim 1 as the class axis, so it would treat every image row as a separate "class" and
# ignore rows without foreground.
JaccardLoss = smp.losses.JaccardLoss(mode='binary')
DiceLoss = smp.losses.DiceLoss(mode='binary')
BCELoss = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss = smp.losses.LovaszLoss(mode='binary', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='binary', log_loss=False)

def dice_coef(y_true: torch.Tensor, y_pred: torch.Tensor, thr: float = 0.5, dim: tuple = (1,2), epsilon: float = 0.001) -> torch.Tensor:
    """Dice coefficient between a binary target and thresholded predictions.

    ``y_pred`` is binarized with ``> thr``. A per-sample score is obtained by reducing over
    ``dim`` and is then averaged over axis 0. ``epsilon`` smooths numerator and denominator,
    so an empty target with an empty prediction scores 1.
    """
    target = y_true.float()
    pred_mask = y_pred.gt(thr).float()
    overlap = torch.sum(target * pred_mask, dim=dim)
    total = torch.sum(target, dim=dim) + torch.sum(pred_mask, dim=dim)
    per_sample = (2 * overlap + epsilon) / (total + epsilon)
    return per_sample.mean(dim=0)

def iou_coef(y_true: torch.Tensor, y_pred: torch.Tensor, thr: float = 0.5, dim: tuple = (1,2), epsilon: float = 0.001) -> torch.Tensor:
    """Intersection-over-Union (Jaccard) between a binary target and thresholded predictions.

    ``y_pred`` is binarized with ``> thr``. Intersection and union are reduced over ``dim``
    per sample, smoothed with ``epsilon``, and the ratio is averaged over axis 0.
    """
    target = y_true.float()
    pred_mask = y_pred.gt(thr).float()
    both = target * pred_mask
    intersection = both.sum(dim=dim)
    union = (target + pred_mask - both).sum(dim=dim)
    return ((intersection + epsilon) / (union + epsilon)).mean(dim=0)

def criterion(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Training loss: an equal-weight (0.5 / 0.5) blend of BCE-with-logits and Tversky loss.

    Args:
        y_pred: Raw logits from the model.
        y_true: Ground-truth mask, expected to match ``y_pred`` in shape.
    """
    bce_term = BCELoss(y_pred, y_true)
    tversky_term = TverskyLoss(y_pred, y_true)
    return 0.5 * bce_term + 0.5 * tversky_term

def fetch_scheduler(optimizer: torch.optim.Optimizer) -> lr_scheduler._LRScheduler:
    """Build the learning-rate scheduler named by ``CFG.lr_scheduler``.

    Hyper-parameters now come from ``CFG`` (``T_max``, ``T_0``, ``lr_min``, ``mode``,
    ``patience``, ``threshod``, ``gamma``), as the comments in ``CFG`` describe. Previously
    ReduceLROnPlateau used a hard-coded patience=7 and ExponentialLR used gamma=0.85.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.

    Returns:
        The scheduler, or ``None`` when ``CFG.lr_scheduler`` is ``None``.

    Raises:
        ValueError: If ``CFG.lr_scheduler`` names an unsupported scheduler. Previously a
            misspelled name silently returned ``None`` and disabled scheduling.
    """
    name = CFG.lr_scheduler
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.lr_min)
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, eta_min=CFG.lr_min)
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode=CFG.mode, factor=0.1, patience=CFG.patience,
                                              threshold=CFG.threshod, min_lr=CFG.lr_min)
    if name == 'ExponentialLR':
        return lr_scheduler.ExponentialLR(optimizer, gamma=CFG.gamma)
    raise ValueError(f"Unsupported CFG.lr_scheduler: {name!r}")

**Define the model training phase**

In [None]:
def train_one_epoch(model: torch.nn.Module, 
                    optimizer: torch.optim.Optimizer, 
                    scheduler, 
                    dataloader: torch.utils.data.DataLoader, 
                    device: torch.device, 
                    epoch: int, 
                    criterion) -> float:
    """
    Train the model for one epoch.

    Parameters:
    - model: The model to train; its single output channel is squeezed away before the loss.
    - optimizer: The optimizer to use.
    - scheduler: Learning-rate scheduler or None. It is stepped once, after the epoch.
    - dataloader: Training loader yielding dicts with 'image' and 'label'.
    - device: The device to train on (e.g. 'cuda').
    - epoch: The current epoch number.
    - criterion: Callable ``criterion(y_pred, y_true)`` returning a scalar loss tensor.

    Returns:
    - The sample-weighted average training loss (0.0 for an empty dataloader).
    """
    model.train()

    dataset_size = 0
    running_loss = 0.0
    # Defined up front so an empty dataloader does not raise UnboundLocalError at return
    epoch_loss = 0.0

    pbar = tqdm(dataloader, total=len(dataloader), desc='Train ')
    for data in pbar:
        images = data['image'].to(device, dtype=torch.float)
        masks = data['label'].to(device, dtype=torch.float)

        optimizer.zero_grad()

        # Drop the single channel axis so predictions line up with the label masks
        # (labels assumed to be (B, H, W) — confirm against the dataset)
        y_pred = model(images).squeeze(1)
        loss = criterion(y_pred, masks)

        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_mem=f'{mem:0.2f} GB')

    # Step once per epoch. The configured horizons (e.g. CFG.T_max = 0.7 * epochs) are in
    # epochs; stepping per batch cycled the cosine schedule many times within one epoch.
    if scheduler is not None:
        if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step() requires a metric; the epoch's training loss is used.
            # NOTE(review): a validation loss is usually preferable — consider stepping in the caller.
            scheduler.step(epoch_loss)
        else:
            scheduler.step()

    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

**Define the model validation phase**

In [None]:
# Test-time augmentation (TTA) helper, used during validation

def tta_inference(model, images):
    """Average predictions over the original, horizontally flipped and vertically flipped inputs.

    Each prediction made on a flipped input is flipped back before averaging, so all three
    are spatially aligned with the original image. Previously the flipped outputs were
    averaged unaligned, which mirrored and blurred the segmentation.

    Args:
        model: Callable whose output's dims 2 and 3 are assumed to be spatially aligned with
            the input's height (dim 2) and width (dim 3).
        images: Input batch (assumed ``(B, C, H, W)``).

    Returns:
        The mean of the three aligned predictions.
    """
    preds = model(images)
    # Flip along width (dim 3), predict, then flip the prediction back
    preds_hflip = torch.flip(model(torch.flip(images, [3])), [3])
    # Flip along height (dim 2), predict, then flip the prediction back
    preds_vflip = torch.flip(model(torch.flip(images, [2])), [2])
    return (preds + preds_hflip + preds_vflip) / 3.0

In [None]:
def valid_one_epoch(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, device: torch.device, optimizer, criterion=None) -> tuple:
    """
    Evaluate the model on the validation data with test-time augmentation.

    Every channel of ``data['image']`` is sent through ``tta_inference`` as its
    own single-channel batch and scored against the same ``data['label']``
    masks. Loss and metrics are therefore averaged over (batch, channel) passes.

    Parameters:
    - model: The model to evaluate (switched to eval mode).
    - dataloader: Yields dicts with 'image' and 'label' tensors.
    - device: Device to run inference on.
    - optimizer: Only read to display the current learning rate.
    - criterion: Loss applied to (logits, masks). Defaults to
      ``nn.BCEWithLogitsLoss()``, the same loss ``run_training`` trains with.
      (The previous version read a global ``criterion`` that this notebook
      only defines as a local inside ``run_training``.)

    Returns:
    - (epoch_loss, val_scores): the sample-weighted mean loss (0.0 for an
      empty dataloader) and a length-2 array of mean [dice, jaccard]
      (NaNs for an empty dataloader).
    """
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss().to(device)

    model.eval()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty dataloader cannot leave it unbound
    val_scores = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')

    with torch.no_grad():
        for _, data in pbar:
            masks = data['label'].to(device, dtype=torch.float)
            # Move the whole image batch once; channels are sliced on-device below.
            images = data['image'].to(device, dtype=torch.float)

            for channel in range(images.shape[1]):
                image = images[:, channel:channel + 1]  # keeps a channel dim of size 1
                batch_size = image.size(0)

                y_pred = tta_inference(model, image).squeeze(1)
                loss = criterion(y_pred, masks)

                running_loss += loss.item() * batch_size
                dataset_size += batch_size
                epoch_loss = running_loss / dataset_size

                # Metrics are computed on probabilities, the loss on raw logits.
                y_prob = torch.sigmoid(y_pred)
                val_dice = dice_coef(masks, y_prob).cpu().detach().numpy()
                val_jaccard = iou_coef(masks, y_prob).cpu().detach().numpy()
                val_scores.append([val_dice, val_jaccard])

                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
                current_lr = optimizer.param_groups[0]['lr']
                pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}', lr=f'{current_lr:0.5f}', gpu_memory=f'{mem:0.2f} GB')

    # np.mean of an empty list gives a scalar NaN that callers cannot unpack
    # into (dice, jaccard); return a NaN pair instead.
    val_scores = np.mean(val_scores, axis=0) if val_scores else np.full(2, np.nan)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

**Save the model parameters (optionally uploading the best checkpoint to wandb)**

In [None]:
def save_model(model, filename, is_best=False):
    """Write ``model.state_dict()`` to ``filename``.

    When ``is_best`` is true, the same file is also handed to ``wandb.save``
    so it is attached to the active wandb run.
    """
    weights = model.state_dict()
    torch.save(weights, filename)
    if not is_best:
        return
    wandb.save(filename)

**Define the training loop (`run_training`)**

In [None]:
def run_training(model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler, device: torch.device, num_epochs: int) -> tuple:
    """
    Train the model for a specified number of epochs and log metrics to wandb.

    Builds the Synapse train/valid loaders and runs ``train_one_epoch`` and
    ``valid_one_epoch`` each epoch. It keeps the weights with the highest
    validation Dice and writes ``best_epoch.bin`` / ``last_epoch.bin``
    checkpoints to the current directory.

    Parameters:
    - model: The model to train.
    - optimizer: The optimizer to use.
    - scheduler: The learning rate scheduler, or None (the logged LR then
      falls back to the optimizer's current LR).
    - device: The device to train on (e.g., 'cuda').
    - num_epochs: The number of epochs to train.

    Returns:
    - Tuple (model, history): the model with its best-Dice weights loaded,
      and a dict mapping metric name -> list of per-epoch values.
    """
    # Test if CUDA is available
    if torch.cuda.is_available():
        print(f"cuda: {torch.cuda.get_device_name()}\n")

    start_time = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = -np.inf
    # Initialised alongside best_dice so the final summary never reads an
    # unbound name when no epoch improves (NaN Dice, or num_epochs == 0).
    best_jaccard = -np.inf
    best_epoch = -1
    history = defaultdict(list)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)

    # Reuse a run the caller already started (train_model does). Calling
    # wandb.init again while a run is active may finish or replace that run,
    # depending on wandb's `reinit` setting.
    run = wandb.run
    if run is None:
        run = wandb.init(
            project='uw-maddison-gi-tract',
            config={k: v for k, v in dict(vars(CFG)).items() if '__' not in k},
            anonymous=CFG.anonymous,
            name=f"dim-{CFG.image_size[0]}x{CFG.image_size[1]}|model-{CFG.model_name}",
            group=CFG.comment
        )
    # wandb.watch requires an active run, so it must come after wandb.init.
    wandb.watch(model, log_freq=100)

    train_loader, valid_loader, _ = prepare_synapse_loaders(base_dir=r"project_TransUNet/data/Synapse/train_npz",
                                                            list_dir=r"project_TransUNet/TransUNet/lists/lists_Synapse", debug=False)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')

        train_loss = train_one_epoch(model, optimizer, scheduler, criterion=criterion,
                                     dataloader=train_loader, device=device, epoch=epoch)
        val_loss, val_scores = valid_one_epoch(model, dataloader=valid_loader, device=device, optimizer=optimizer)
        val_dice, val_jaccard = val_scores

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)

        # The scheduler is optional (train_one_epoch accepts None), so do not
        # call get_last_lr() on it unconditionally.
        if scheduler is not None:
            current_lr = scheduler.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']

        # Log the metrics
        wandb.log({"Train Loss": train_loss,
                   "Valid Loss": val_loss,
                   "Valid Dice": val_dice,
                   "Valid Jaccard": val_jaccard,
                   "LR": current_lr})

        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')

        # Keep (and checkpoint) the weights with the best validation Dice.
        if val_dice >= best_dice:
            print(f"Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice = val_dice
            best_jaccard = val_jaccard
            best_epoch = epoch
            run.summary["Best Dice"] = best_dice
            run.summary["Best Jaccard"] = best_jaccard
            run.summary["Best Epoch"] = best_epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            # Writes best_epoch.bin and uploads it to the wandb run.
            save_model(model, "best_epoch.bin", is_best=True)
            print("Model Saved")

        save_model(model, "last_epoch.bin")

        print(); print()

    elapsed_time = time.time() - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f'Training complete in {hours:.0f}h {minutes:.0f}m {seconds:.0f}s')
    # Selection is by Dice, so report Dice (and the Jaccard of that epoch).
    print(f"Best Dice: {best_dice:.4f} | Best Jaccard: {best_jaccard:.4f} (epoch {best_epoch})")

    wandb.log({"Training Time": elapsed_time})

    model.load_state_dict(best_model_wts)

    return model, history

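**Define Loss Function**

No separate cell is needed here: `run_training` instantiates `nn.BCEWithLogitsLoss`, which expects raw logits and applies the sigmoid internally.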

**Helper function for loading the model**

In [None]:
def build_model():
    """Instantiate the single-channel, single-class TransUNet and move it to ``CFG.device``."""
    net = TransUNet(
        in_channels=1,
        encoder_out_channels=512,
        decoder_out_channels=64,
        head_num=8,
        model_dim=768,
        mlp_dim=2048,
        block_num=5,
        num_patches=1024,
        class_num=1,
    )
    net.to(CFG.device)
    return net

def load_model(path):
    """
    Build a fresh model, load the state dict saved at ``path`` and set eval mode.

    ``map_location=CFG.device`` remaps the stored tensors onto the device that
    ``build_model`` placed the model on. Without it, a checkpoint saved on a
    GPU fails to load on a CPU-only machine.

    Note: ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    model = build_model()
    model.load_state_dict(torch.load(path, map_location=CFG.device))
    model.eval()
    return model

**Prepare to train our model!**

In [None]:
import gc

# Release memory left over from earlier cells before building the model.
# First collect unreachable Python objects so any tensors they hold are freed,
# then hand PyTorch's cached-but-unused CUDA blocks back to the driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Reuse build_model() so the architecture hyper-parameters live in one place.
# It also moves the model to CFG.device *before* the optimizer is created in the
# next cell, which is the order PyTorch recommends for optimizers.
model = build_model()

In [None]:
# AdamW over every model parameter; the LR schedule comes from fetch_scheduler.
# NOTE(review): the base learning rate (1e-6) is hard-coded here rather than
# read from CFG — confirm this is intended.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-6,
    weight_decay=CFG.weight_decay,
)
scheduler = fetch_scheduler(optimizer)

**Train the model!**

In [None]:
def train_model(CFG, model, optimizer, scheduler, device=None, num_epochs=None):
    """
    Start a wandb run, train ``model`` with ``run_training`` and close the run.

    Parameters:
    - CFG: configuration object; supplies the wandb config/name/group and the
      fallbacks for ``device`` and ``num_epochs``.
    - model, optimizer, scheduler: forwarded to ``run_training``.
    - device: training device; defaults to ``CFG.device``. It is resolved at
      call time from the CFG passed in, not frozen when this cell was defined.
    - num_epochs: number of epochs; defaults to ``CFG.epochs``.

    Returns:
    - (model, history) as returned by ``run_training``.
    """
    # Previously these arguments were accepted but ignored (CFG values were
    # always passed through); honour them now.
    device = CFG.device if device is None else device
    num_epochs = CFG.epochs if num_epochs is None else num_epochs

    print(f"\n{'#' * 15}\n### Training Model\n{'#' * 15}\n")

    # Initialize Weights & Biases run
    wandb.init(
        project='uw-maddison-gi-tract',
        config={k: v for k, v in dict(vars(CFG)).items() if '__' not in k},
        anonymous=CFG.anonymous,
        name=f"dim-{CFG.image_size[0]}x{CFG.image_size[1]}|model-{CFG.model_name}",
        group=CFG.comment
    )

    try:
        model, history = run_training(model, optimizer, scheduler, device=device, num_epochs=num_epochs)
    finally:
        # Close whichever run is currently active (no-op if none), even if
        # training raised, so the next wandb.init in this kernel starts clean.
        wandb.finish()

    return model, history

**Test the model!**

In [None]:
def evaluate_model(model, config, dataloader=None):
    """
    Run the model on the first batch of a loader and return binary masks.

    Parameters:
    - model: trained segmentation model; switched to eval mode here.
    - config: configuration object; ``config.device`` selects the inference device.
    - dataloader: optional iterable of batches. When omitted, the Synapse test
      loader from ``prepare_synapse_loaders`` is used. A batch may be a tensor
      of images or a dict with an 'image' entry (the format the train/valid
      loops read).

    Returns:
    - A float64 CPU tensor of 0/1 values, ``sigmoid(logits) > 0.5``, for the
      first batch.

    Raises:
    - ValueError: if the dataloader yields no batches.
    """
    from collections.abc import Mapping

    if dataloader is None:
        _, _, dataloader = prepare_synapse_loaders(base_dir=r"project_TransUNet/data/Synapse/train_npz",
                                                   list_dir=r"project_TransUNet/TransUNet/lists/lists_Synapse", debug=False)

    try:
        batch = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader yielded no batches") from None

    # Loaders in this notebook yield dicts; calling .to() on the dict itself fails.
    images = batch["image"] if isinstance(batch, Mapping) else batch
    images = images.to(config.device, dtype=torch.float)

    model.eval()  # inference mode for dropout / normalisation layers
    with torch.no_grad():
        logits = model(images)

    # The original stacked a single prediction and averaged it, which is the
    # identity, so the thresholded mask is returned directly.
    return (torch.sigmoid(logits) > 0.5).double().cpu()

In [None]:
# Launch training; returns the trained model and its per-epoch metric history.
trained_model, training_history = train_model(
    CFG,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)

In [None]:
from torchvision.datasets import VOCDetection

# PASCAL VOC 2012 "trainval" detection split, stored under ./data.
# With download=True torchvision fetches the archive into `root` if it is missing.
# NOTE(review): this is unrelated to the Synapse segmentation pipeline above and
# may trigger a sizeable download — confirm it belongs in this notebook.
voc_dataset = VOCDetection(
    root="data",
    year="2012",
    image_set="trainval",
    download=True,
)

In [None]:
class CustomVOCDetection(Dataset):
    """
    Wrap a VOCDetection-style dataset so each item is ``(image, bboxes)``.

    Every annotated object becomes one ``[xmin, ymin, xmax, ymax]`` box (floats)
    and is given class label 1 (single-class detection). The labels are passed
    to the transform, so it can filter them together with the boxes, but they
    are not returned.
    """

    def __init__(self, voc_dataset, transforms=None):
        # voc_dataset: indexable, yielding (image, target) where target is the
        #   parsed VOC XML dict ({"annotation": {"object": [...]}}).
        # transforms: optional albumentations-style callable taking
        #   image=, bboxes=, labels= and returning a dict with "image"/"bboxes".
        self.voc_dataset = voc_dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.voc_dataset)

    def __getitem__(self, idx):
        img, target = self.voc_dataset[idx]
        objects = target["annotation"]["object"]
        # Tolerate parsers that store a lone object as a dict rather than a
        # one-element list; iterating the dict would yield its keys.
        if isinstance(objects, dict):
            objects = [objects]

        bboxes = [
            [float(obj["bndbox"][key]) for key in ("xmin", "ymin", "xmax", "ymax")]
            for obj in objects
        ]
        labels = [1] * len(bboxes)  # assuming all objects are of the same class

        if self.transforms:
            # torchvision's VOCDetection yields PIL images, while albumentations
            # operates on NumPy arrays; np.array also gives a writable copy.
            transformed = self.transforms(image=np.array(img), bboxes=bboxes, labels=labels)
            img = transformed["image"]
            bboxes = transformed["bboxes"]

        return img, bboxes

Let's Go!