# 1. Requirements

In [None]:
import argparse
import os
import sys
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from einops import rearrange, repeat
from einops import rearrange
from scipy.ndimage import rotate, zoom
from tqdm import tqdm

# 2. Data

## 2.1. Synapse dataset

In [None]:
class SynapseDataset(Dataset):
  """Dataset of pre-sliced Synapse samples stored as ``.npz`` files.

  Every entry directly inside ``root_path`` is treated as one sample and must
  be an npz archive containing an ``image`` and a ``label`` array.

  Args:
    root_path: directory holding the slice files.
    transform: optional callable applied to the ``{'image', 'mask'}`` dict.
  """

  def __init__(self, root_path, transform=None):
    self.root_path = root_path
    self.transform = transform
    # os.listdir order is arbitrary and filesystem dependent; sort it so that
    # sample indices (and therefore evaluation order) are reproducible.
    self.sample_list = sorted(os.listdir(root_path))

  def __len__(self):
    return len(self.sample_list)

  def __getitem__(self, idx):
    """Load sample ``idx``, apply the transform and tag it with its file name."""
    slice_name = self.sample_list[idx].strip()
    data_path = os.path.join(self.root_path, slice_name)
    # An NpzFile keeps its file open until closed; the context manager
    # releases the handle once both arrays have been read into memory.
    with np.load(data_path) as data:
      image, mask = data['image'], data['label']

    sample = {'image': image, 'mask': mask}
    if self.transform:
      sample = self.transform(sample)
    sample['case_name'] = slice_name
    return sample

## 2.2. Data transforms

In [None]:
class RandomGenerator:
  """Training-time augmentation applied jointly to a 2-D image and its mask.

  With probability 1/3 each, the pair is
    (a) rotated by a random multiple of 90 degrees and flipped along a random
        axis,
    (b) rotated by a random integer angle in [-20, 20] degrees, or
    (c) left unchanged.
  Both arrays receive exactly the same geometric transform.
  """

  def __random_rot_flip(self, image, mask):
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    mask = np.rot90(mask, k)
    axis = np.random.randint(0, 2)
    # .copy() gives contiguous arrays (flip/rot90 return strided views).
    image = np.flip(image, axis=axis).copy()
    mask = np.flip(mask, axis=axis).copy()
    return image, mask

  def __random_rotate(self, image, mask):
    # randint's upper bound is exclusive; 21 makes the range symmetric.
    angle = np.random.randint(-20, 21)
    # order=0 (nearest-neighbour) is used for both arrays, which keeps the
    # mask's label values unchanged; reshape=False preserves the shape.
    image = rotate(image, angle, order=0, reshape=False)
    mask = rotate(mask, angle, order=0, reshape=False)
    return image, mask

  def __call__(self, sample):
    """Return a new ``{'image', 'mask'}`` dict with the augmentation applied."""
    image, mask = sample['image'], sample['mask']
    rand = random.random()
    if rand > 2/3:
      image, mask = self.__random_rot_flip(image, mask)
    elif rand > 1/3:
      image, mask = self.__random_rotate(image, mask)
    return {'image': image, 'mask': mask}
  

class Zoomer:
  """Resize a 2-D image/mask pair to ``output_size`` x ``output_size`` and tensorize.

  The image is resampled with cubic splines (order=3) and the mask with
  nearest-neighbour (order=0). The image comes back as a ``(1, H, W)``
  float32 tensor, the mask as an ``(H, W)`` float32 tensor.
  """

  def __init__(self, output_size):
    self.output_size = output_size

  def __zoom(self, image, mask):
    height, width = image.shape
    target = self.output_size
    if not (height == target and width == target):
      factors = (target / height, target / width)
      image = zoom(image, factors, order=3)
      mask = zoom(mask, factors, order=0)
    image_tensor = torch.from_numpy(image.astype(np.float32))[None, ...]
    mask_tensor = torch.from_numpy(mask.astype(np.float32))
    return image_tensor, mask_tensor

  def __call__(self, sample):
    resized_image, resized_mask = self.__zoom(sample['image'], sample['mask'])
    return {'image': resized_image, 'mask': resized_mask}

In [None]:
# Registry mapping the --dataset_name CLI choice to its Dataset class.
datasets = dict(Synapse=SynapseDataset)

# 3. Model

## 3.1. Criterion

In [None]:
class DiceLoss(nn.Module):
  """Multi-class soft Dice loss.

  ``forward`` one-hot encodes an integer label map into ``n_classes``
  channels, computes ``1 - dice`` for each class, and returns the
  (optionally weighted) sum divided by ``n_classes``.

  Args:
    n_classes: number of class channels expected on dim 1 of ``pred``.
  """

  def __init__(self, n_classes):
    super(DiceLoss, self).__init__()
    self.n_classes = n_classes

  def _one_hot_encoder(self, input_tensor):
    """Stack ``input_tensor == i`` for every class along a new dim 1, as floats."""
    tensor_list = [(input_tensor == i).unsqueeze(1) for i in range(self.n_classes)]
    return torch.cat(tensor_list, dim=1).float()

  def _dice_loss(self, pred, target):
    """Return ``1 - (2*sum(p*t) + s) / (sum(p*p) + sum(t*t) + s)`` with s=1e-5."""
    target = target.float()
    smooth = 1e-5
    intersect = torch.sum(pred * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(pred * pred)
    return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

  def forward(self, pred, target, weight=None, softmax=False):
    """Compute the mean per-class Dice loss.

    Args:
      pred: per-class scores with classes on dim 1.
      target: integer label map without the class dim; after one-hot
        encoding it must have the same shape as ``pred``.
      weight: optional per-class multipliers (defaults to 1 for every class).
      softmax: if True, apply softmax over dim 1 to ``pred`` first (pass raw
        logits in that case).

    Returns:
      Scalar loss tensor.

    Raises:
      ValueError: if ``pred`` and the one-hot encoded ``target`` differ in shape.
    """
    if softmax:
      pred = torch.softmax(pred, dim=1)
    target = self._one_hot_encoder(target)
    if weight is None:
      weight = [1] * self.n_classes
    # Without this check a mismatch could silently broadcast or ignore classes.
    if pred.size() != target.size():
      raise ValueError(
        f'pred shape {tuple(pred.size())} does not match one-hot target shape {tuple(target.size())}'
      )
    loss = 0
    for i in range(self.n_classes):
      loss += self._dice_loss(pred[:, i], target[:, i]) * weight[i]
    return loss / self.n_classes

## 3.2. Vision Transformer

In [None]:
class MultiHeadAttention(nn.Module):
  """Bias-free multi-head scaled dot-product self-attention.

  Args:
    embedding_dim: token feature size; expected to be divisible by ``head_num``.
    head_num: number of attention heads.
  """

  def __init__(self, embedding_dim, head_num):
    super().__init__()
    self.head_num = head_num
    # sqrt(d_head): attention logits are divided by this value.
    self.dk = (embedding_dim // head_num) ** 0.5
    self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
    self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

  def forward(self, x, mask=None):
    """Attend over tokens.

    Args:
      x: (batch, tokens, embedding_dim) tensor.
      mask: optional boolean tensor broadcastable to
        (batch, heads, tokens, tokens); True positions are excluded.

    Returns:
      Tensor with the same shape as ``x``.
    """
    batch, tokens, _ = x.shape
    qkv = self.qkv_layer(x)
    # Split the 3*E features as (d, k, h) with h varying fastest, i.e. the
    # einops layout 'b t (d k h) -> k b h t d', so the learned projection
    # weights map to the same q/k/v heads as before.
    qkv = qkv.reshape(batch, tokens, -1, 3, self.head_num).permute(3, 0, 4, 1, 2)
    query, key, value = qkv[0], qkv[1], qkv[2]

    # Scaled dot-product attention: divide (not multiply) by sqrt(d_head) so
    # the softmax does not saturate as the head dimension grows.
    energy = torch.einsum("... i d , ... j d -> ... i j", query, key) / self.dk
    if mask is not None:
      energy = energy.masked_fill(mask, float('-inf'))

    attention = torch.softmax(energy, dim=-1)
    x = torch.einsum("... i j , ... j d -> ... i d", attention, value)  # (b, h, t, d)

    # Merge heads back: (b, h, t, d) -> (b, t, h*d).
    x = x.permute(0, 2, 1, 3).reshape(batch, tokens, -1)
    return self.out_attention(x)
  

class MLP(nn.Module):
  """Transformer feed-forward sub-layer.

  Linear(embedding_dim -> mlp_dim) -> GELU -> Dropout(0.1)
  -> Linear(mlp_dim -> embedding_dim) -> Dropout(0.1).
  """

  def __init__(self, embedding_dim, mlp_dim):
    super().__init__()
    expand = nn.Linear(embedding_dim, mlp_dim)
    contract = nn.Linear(mlp_dim, embedding_dim)
    self.mlp_layers = nn.Sequential(expand, nn.GELU(), nn.Dropout(0.1), contract, nn.Dropout(0.1))

  def forward(self, x):
    return self.mlp_layers(x)
  

class EncoderBlock(nn.Module):
  """Post-norm Transformer layer.

  x = LayerNorm1(x + Dropout(MultiHeadAttention(x)))
  x = LayerNorm2(x + MLP(x))
  """

  def __init__(self, embedding_dim, head_num, mlp_dim):
    super().__init__()
    self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
    self.mlp = MLP(embedding_dim, mlp_dim)
    self.layer_norm1 = nn.LayerNorm(embedding_dim)
    self.layer_norm2 = nn.LayerNorm(embedding_dim)
    self.dropout = nn.Dropout(0.1)

  def forward(self, x):
    attended = self.dropout(self.multi_head_attention(x))
    hidden = self.layer_norm1(x + attended)
    return self.layer_norm2(hidden + self.mlp(hidden))
  

class TransformerEncoder(nn.Module):
  """Stack of ``block_num`` EncoderBlocks applied sequentially."""

  def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):
    super().__init__()
    blocks = [EncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)]
    self.layer_blocks = nn.ModuleList(blocks)

  def forward(self, x):
    out = x
    for block in self.layer_blocks:
      out = block(out)
    return out
  

class ViT(nn.Module):
  """Vision Transformer over non-overlapping ``patch_dim`` x ``patch_dim`` patches.

  Args:
    image_dim: side length of the square input; should be divisible by
      ``patch_dim``.
    in_channels: channels of the input.
    embedding_dim: token width after the linear patch projection.
    head_num, mlp_dim, block_num: forwarded to ``TransformerEncoder``.
    patch_dim: patch side length.
    classification: if True, ``forward`` returns ``mlp_head`` logits for the
      class token; otherwise it returns the patch tokens.
    num_classes: output size of ``mlp_head`` (classification only).
  """

  def __init__(
    self, image_dim, in_channels, embedding_dim, head_num, mlp_dim,
    block_num, patch_dim, classification=True, num_classes=1
  ):
    super().__init__()

    self.patch_dim = patch_dim
    self.classification = classification
    self.num_tokens = (image_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)

    self.projection = nn.Linear(self.token_dim, embedding_dim)
    # Learnable positional embedding: one row per patch token plus the class token.
    self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
    self.dropout = nn.Dropout(0.1)
    self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)

    if self.classification:
      self.mlp_head = nn.Linear(embedding_dim, num_classes)

  @staticmethod
  def _patchify(x, patch_dim):
    """Split ``(b, c, H, W)`` into ``(b, (H/p)*(W/p), p*p*c)`` contiguous patches.

    Tokens are ordered row-major over the patch grid; each token's features
    are ordered (row within patch, column within patch, channel).
    """
    b, c, h, w = x.shape
    p = patch_dim
    # The grid index must be the OUTER factor of each spatial axis
    # (H = grid_rows * p) so that every token covers p adjacent pixels.
    x = x.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(b, (h // p) * (w // p), p * p * c)

  def forward(self, x):
    """x: (b, in_channels, H, W) -> (b, num_classes) logits if classification,
    else patch tokens (b, num_tokens, embedding_dim)."""
    img_patches = self._patchify(x, self.patch_dim)
    batch_size, tokens, _ = img_patches.shape

    project = self.projection(img_patches)
    # One shared learnable class token, broadcast across the batch.
    token = self.cls_token.expand(batch_size, -1, -1)

    patches = torch.cat([token, project], dim=1)
    patches = patches + self.embedding[:tokens + 1, :]

    x = self.dropout(patches)
    x = self.transformer(x)
    x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]

    return x

## 3.3. TransUNet

In [None]:
class EncoderBottleneck(nn.Module):
  """ResNet-style bottleneck (1x1 -> 3x3 -> 1x1) with a projected shortcut.

  Args:
    in_channels: input channels.
    out_channels: output channels.
    stride: spatial stride applied by both the 3x3 conv and the shortcut.
    base_width: scales the inner width as ``out_channels * base_width / 64``.
  """

  def __init__(self, in_channels, out_channels, stride=1, base_width=64):
    super().__init__()
    # 1x1 projection so the residual matches the main path's channels/stride.
    self.downsample = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
      nn.BatchNorm2d(out_channels)
    )
    width = int(out_channels * (base_width / 64))
    self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)
    self.norm1 = nn.BatchNorm2d(width)
    # Use the same stride as the shortcut; a hard-coded 2 broke stride != 2.
    self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, groups=1, padding=1, dilation=1, bias=False)
    self.norm2 = nn.BatchNorm2d(width)
    self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)
    self.norm3 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    x_down = self.downsample(x)
    x = self.conv1(x)
    x = self.norm1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.norm2(x)
    x = self.relu(x)
    x = self.conv3(x)
    x = self.norm3(x)
    x = x + x_down
    x = self.relu(x)
    return x


class Encoder(nn.Module):
  """TransUNet encoder: CNN stem and three bottlenecks, then a ViT.

  ``forward`` returns ``(x, x1, x2, x3)``:
    x  -- ViT output projected to ``out_channels * 4`` channels (1/16 scale),
    x1 -- stem output, ``out_channels`` channels (1/2 scale),
    x2 -- ``out_channels * 2`` channels (1/4 scale),
    x3 -- ``out_channels * 4`` channels (1/8 scale).
  """

  def __init__(self, image_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
    self.norm1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)
    self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)
    self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)
    # NOTE(review): conv1 plus three stride-2 bottlenecks downsample by 16 in
    # total, so the reshape in forward assumes patch_dim == 16 and image_dim
    # divisible by 16 — confirm before using other values.
    self.vit_image_dim = image_dim // patch_dim
    # patch_dim=1: every feature-map pixel becomes one ViT token.
    self.vit = ViT(
      self.vit_image_dim, out_channels * 8, out_channels * 8,
      head_num, mlp_dim, block_num, patch_dim=1, classification=False
    )
    # The decoder's first stage takes out_channels * 8 inputs: this output
    # concatenated with x3 (out_channels * 4). Produce out_channels * 4 here
    # instead of a fixed 512, which only matched out_channels == 128.
    bottleneck_channels = out_channels * 4
    self.conv2 = nn.Conv2d(out_channels * 8, bottleneck_channels, kernel_size=3, stride=1, padding=1)
    self.norm2 = nn.BatchNorm2d(bottleneck_channels)

  def forward(self, x):
    x = self.conv1(x)
    x = self.norm1(x)
    x1 = self.relu(x)
    x2 = self.encoder1(x1)
    x3 = self.encoder2(x2)
    x = self.encoder3(x3)
    x = self.vit(x)
    # Tokens back to a spatial map: (b, x*y, c) -> (b, c, x, y).
    x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_image_dim, y=self.vit_image_dim)
    x = self.conv2(x)
    x = self.norm2(x)
    x = self.relu(x)
    return x, x1, x2, x3
  

class DecoderBottleneck(nn.Module):
  """Bilinear upsample, optionally prepend a skip tensor, then 2x (conv3x3-BN-ReLU)."""

  def __init__(self, in_channels, out_channels, scale_factor=2):
    super().__init__()
    self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
    layers = []
    for conv_in in (in_channels, out_channels):
      layers += [
        nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
      ]
    self.layer = nn.Sequential(*layers)

  def forward(self, x, x_concat=None):
    upsampled = self.upsample(x)
    # Skip features go first along the channel axis.
    merged = upsampled if x_concat is None else torch.cat([x_concat, upsampled], dim=1)
    return self.layer(merged)
  

class Decoder(nn.Module):
  """Cascaded upsampling decoder producing ``class_num`` logit channels.

  Three stages fuse the skips x3, x2, x1 (in that order); a fourth stage
  upsamples without a skip; a 1x1 conv then maps to ``class_num`` channels.
  """

  def __init__(self, out_channels, class_num):
    super().__init__()
    half = int(out_channels * 1 / 2)
    eighth = int(out_channels * 1 / 8)
    self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2)
    self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)
    self.decoder3 = DecoderBottleneck(out_channels * 2, half)
    self.decoder4 = DecoderBottleneck(half, eighth)
    self.conv1 = nn.Conv2d(eighth, class_num, kernel_size=1)

  def forward(self, x, x1, x2, x3):
    for stage, skip in ((self.decoder1, x3), (self.decoder2, x2), (self.decoder3, x1)):
      x = stage(x, skip)
    return self.conv1(self.decoder4(x))
  

class TransUNet(nn.Module):
  """Encoder (CNN + ViT) and a skip-connected decoder returning per-pixel class logits."""

  def __init__(self, image_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):
    super().__init__()
    self.encoder = Encoder(image_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim)
    self.decoder = Decoder(out_channels, class_num)

  def forward(self, x):
    bottleneck, *skips = self.encoder(x)
    return self.decoder(bottleneck, *skips)

## 3.4. Model manager

In [None]:
class ModelManager:
  """Own the TransUNet model, its losses and SGD optimizer; run train/eval steps.

  The loss for both steps is the average of cross-entropy and softmax Dice.
  """

  def __init__(self, args):
    self.args = args
    self.model = TransUNet(
      args.image_dim, args.in_channels, args.out_channels, args.head_num,
      args.mlp_dim, args.block_num, args.patch_dim, args.class_num
    )
    self.model = nn.DataParallel(self.model)
    self.model.to(args.device)

    self.dice_loss = DiceLoss(self.args.class_num)
    self.ce_loss = CrossEntropyLoss()
    self.optimizer = SGD(
      self.model.parameters(), lr=args.learning_rate,
      momentum=args.momentum, weight_decay=args.weight_decay
    )

  def load_model(self):
    """Restore model and optimizer state from ``args.pretrain_path``.

    Returns:
      ``(epoch, loss)`` as stored in the checkpoint.
    """
    ckpt = torch.load(self.args.pretrain_path, map_location=torch.device(self.args.device), weights_only=True)
    self.model.load_state_dict(ckpt['model_state_dict'])
    self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    print(f'Checkpoint is loaded - epoch: {ckpt["epoch"]} loss: {ckpt["loss"]}')
    return ckpt['epoch'], ckpt['loss']

  def _compute_loss(self, pred_mask, mask):
    """Average of cross-entropy and softmax Dice loss for logits ``pred_mask``."""
    loss_ce = self.ce_loss(pred_mask, mask.long())
    loss_dice = self.dice_loss(pred_mask, mask, softmax=True)
    return (loss_ce + loss_dice) / 2

  def train_step(self, image, mask):
    """One SGD update on a batch; returns ``(loss as float, logits)``."""
    self.model.train()
    self.optimizer.zero_grad()
    pred_mask = self.model(image)
    loss = self._compute_loss(pred_mask, mask)
    loss.backward()
    self.optimizer.step()
    return loss.item(), pred_mask

  def test_step(self, image, mask):
    """Evaluate a batch in eval mode; returns ``(loss as float, logits)``."""
    self.model.eval()
    # No gradients are needed for evaluation; without no_grad every
    # validation batch would build and hold a full autograd graph.
    with torch.no_grad():
      pred_mask = self.model(image)
      loss = self._compute_loss(pred_mask, mask)
    return loss.item(), pred_mask

# 4. Training

In [None]:
class EpochCallback:
  """Report epoch metrics, checkpoint the model and decide on early stopping.

  Args:
    save_path: file the checkpoint dict is written to with ``torch.save``.
    epochs: total epoch count (only used in the progress line).
    model, optimizer: objects whose ``state_dict()`` is checkpointed.
    monitor: metric key to minimise; None saves after every epoch.
    patience: stop after this many consecutive non-improving epochs
      (None disables early stopping).
    init_loss: best monitored value to beat initially (e.g. from a resumed
      checkpoint).
  """
  # Class-level defaults kept for compatibility; __init__ also sets them
  # per instance so callbacks never share state.
  end_training = False
  not_improved_epoch = 0

  def __init__(self, save_path, epochs, model, optimizer, monitor=None, patience=None, init_loss=np.inf):
    self.save_path = save_path
    self.epochs = epochs
    self.monitor = monitor
    self.patience = patience
    self.model = model
    self.optimizer = optimizer
    self.monitor_value = init_loss
    self.end_training = False
    self.not_improved_epoch = 0

  def __save_model(self, epoch, loss):
    """Write epoch, loss, model and optimizer state to ``save_path``."""
    torch.save({
      'epoch': epoch,
      'loss': loss,
      'model_state_dict': self.model.state_dict(),
      'optimizer_state_dict': self.optimizer.state_dict()
    }, self.save_path)
    print(f'Model saved to {self.save_path}')

  def epoch_end(self, epoch_num, hash):
    """Print metrics for ``epoch_num`` and save / early-stop as configured.

    Args:
      epoch_num: epoch number used in logs and stored in the checkpoint.
      hash: metric name -> numeric value for this epoch.
    """
    epoch_end_str = f'Epoch {epoch_num}/{self.epochs} - '
    for name, value in hash.items():
      epoch_end_str += f'{name}: {round(value, 4)} '
    print(epoch_end_str)

    if self.monitor is None:
      # Nothing to compare against: always save. hash[None] would raise
      # KeyError, so store the training loss when it was reported.
      self.__save_model(epoch_num, hash.get('loss', np.inf))
    elif hash[self.monitor] < self.monitor_value:
      print(f'{self.monitor} decreased from {round(self.monitor_value, 4)} to {round(hash[self.monitor], 4)}')
      self.not_improved_epoch = 0
      self.monitor_value = hash[self.monitor]
      self.__save_model(epoch_num, hash[self.monitor])
    else:
      print(f'{self.monitor} did not decrease from {round(self.monitor_value, 4)}, model did not save!')
      self.not_improved_epoch += 1
      if self.patience is not None and self.not_improved_epoch >= self.patience:
        print("Training was stopped by callback!")
        self.end_training = True

In [None]:
class Trainer:
  """Build data loaders and a ModelManager from ``args`` and run the epoch loop.

  If ``args.pretrain_path`` is set, training resumes from that checkpoint's
  epoch and uses its stored loss as the initial value to beat.
  """

  def __init__(self, args):
    self.args = args
    self.train_loader = self.__load_dataset(self.args.train_path, 'train')
    self.test_loader = self.__load_dataset(self.args.test_path, 'test')
    self.model_manager = ModelManager(args)
    if self.args.pretrain_path:
      self.init_epoch, self.init_loss = self.model_manager.load_model()
    else:
      self.init_epoch, self.init_loss = 0, np.inf

  def __load_dataset(self, path, split):
    """Build the DataLoader for ``split``; only 'train' is shuffled and augmented."""
    is_train = split == 'train'
    steps = [RandomGenerator(), Zoomer(self.args.image_dim)] if is_train else [Zoomer(self.args.image_dim)]
    dataset = datasets[self.args.dataset_name](path, transforms.Compose(steps))
    return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=is_train, num_workers=4)

  def __loop(self, loader, step_function, t):
    """Apply ``step_function`` to every batch, ticking progress bar ``t``.

    Returns:
      The sum of the per-batch losses.
    """
    total_loss = 0
    for data in loader:
      image = data['image'].to(self.args.device)
      mask = data['mask'].to(self.args.device)
      loss, _ = step_function(image=image, mask=mask)
      total_loss += loss
      t.update()
    return total_loss

  def train(self):
    """Train from ``init_epoch`` to ``args.epochs`` with checkpointing and early stopping."""
    callback = EpochCallback(
      save_path=self.args.save_path, epochs=self.args.epochs,
      model=self.model_manager.model, optimizer=self.model_manager.optimizer,
      monitor='test_loss', patience=self.args.patience, init_loss=self.init_loss
    )

    for epoch in range(self.init_epoch, self.args.epochs):
      with tqdm(total=len(self.train_loader) + len(self.test_loader)) as t:
        train_loss = self.__loop(self.train_loader, self.model_manager.train_step, t)
        # Validation never back-propagates; skip building autograd graphs.
        with torch.no_grad():
          test_loss = self.__loop(self.test_loader, self.model_manager.test_step, t)

      callback.epoch_end(epoch + 1, {
        'loss': train_loss / len(self.train_loader),
        'test_loss': test_loss / len(self.test_loader)
      })
      if callback.end_training:
        break

# 5. Inference

In [None]:
class Inference:
  """Run a trained TransUNet on the single image at ``args.image_path``.

  The checkpoint at ``args.pretrain_path`` is loaded in ``__init__``.
  """

  def __init__(self, args):
    self.args = args
    self.model_manager = ModelManager(args)
    self.model_manager.load_model()
    # load_model does not change the module mode; switch to eval so that
    # BatchNorm uses running statistics and Dropout is disabled.
    self.model_manager.model.eval()

  def __read_and_preprocess(self):
    """Read the image and build a ``(1, C, image_dim, image_dim)`` input in [0, 1].

    Returns:
      ``(image as read by OpenCV, float32 model input on args.device)``.

    Raises:
      FileNotFoundError: if OpenCV cannot read ``args.image_path``.
    """
    # Read grayscale when the model expects one channel; cv2.imread's default
    # is 3-channel BGR, which would not match in_channels == 1.
    flag = cv2.IMREAD_GRAYSCALE if self.args.in_channels == 1 else cv2.IMREAD_COLOR
    image = cv2.imread(self.args.image_path, flag)
    if image is None:
      raise FileNotFoundError(f'Could not read image: {self.args.image_path}')
    image_torch = cv2.resize(image, (self.args.image_dim, self.args.image_dim))
    image_torch = image_torch / 255.
    if image_torch.ndim == 2:
      image_torch = image_torch[..., np.newaxis]
    image_torch = image_torch.transpose((2, 0, 1))
    image_torch = np.expand_dims(image_torch, axis=0)
    image_torch = torch.from_numpy(image_torch.astype('float32')).to(self.args.device)
    return image, image_torch

  def __save(self, mask):
    """Write ``mask`` next to the input image as ``<stem>_mask.png``."""
    out_path = os.path.splitext(self.args.image_path)[0] + '_mask.png'
    cv2.imwrite(out_path, mask.astype('uint8'))
    print(f'Mask saved to {out_path}')

  def __threshold(self, mask, thresh=0.5):
    """Return a copy of ``mask`` with 1 where ``mask >= thresh`` and 0 elsewhere."""
    # A single comparison; the old two-step in-place update zeroed the 1s
    # again whenever thresh > 1.
    return (mask >= thresh).astype(mask.dtype)

  def infer(self):
    """Predict a mask for ``args.image_path`` at the image's original size.

    Single-channel models: sigmoid, then threshold at
    ``args.inference_threshold`` (0/255 mask). Multi-class models: per-pixel
    argmax label map (0..class_num-1). With ``merge_infer`` the input image is
    masked to the foreground; with ``save_infer`` the result is written to disk.
    """
    image, image_torch = self.__read_and_preprocess()
    with torch.no_grad():
      logits = self.model_manager.model(image_torch)

    orig_h, orig_w = image.shape[:2]
    if logits.shape[1] == 1:
      thresh = getattr(self.args, 'inference_threshold', None)
      if thresh is None:
        thresh = getattr(self.args, 'infer_threshold', 0.5)
      prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
      prob = cv2.resize(prob, (orig_w, orig_h))
      pred_mask = self.__threshold(prob, thresh=thresh) * 255
    else:
      # Training uses cross-entropy + softmax Dice, so classes are mutually
      # exclusive: take the argmax. uint8 assumes class_num <= 256.
      labels = torch.argmax(logits, dim=1)[0].cpu().numpy().astype('uint8')
      # Nearest-neighbour keeps label ids intact when resizing.
      pred_mask = cv2.resize(labels, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    if self.args.merge_infer:
      pred_mask = cv2.bitwise_and(image, image, mask=(pred_mask > 0).astype('uint8'))
    if self.args.save_infer:
      self.__save(pred_mask)
    return pred_mask

# 6. Main entry

In [None]:
# A notebook kernel has no real command line, so synthesize the arguments the
# argparse cell below parses; argv[0] is just a placeholder program name.
sys.argv = ['kaggle.ipynb'] + [
  token
  for pair in (
    ('--mode', 'train'),
    ('--dataset_name', 'Synapse'),
    ('--class_num', '10'),
    ('--train_path', '/kaggle/input/synapse/train'),
    ('--test_path', '/kaggle/input/synapse/val'),
    ('--save_path', '/kaggle/working/checkpoint.pth'),
    ('--epochs', '100'),
  )
  for token in pair
]

In [None]:
def _str2bool(value):
  """Parse a boolean CLI value such as 'true'/'False'/'1'/'no'.

  argparse's ``type=bool`` is wrong here because ``bool('False')`` is True.

  Raises:
    argparse.ArgumentTypeError: for unrecognised spellings.
  """
  lowered = str(value).strip().lower()
  if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
    return False
  raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, required=True, choices=['train', 'infer'])
# Mode-dependent requirements are decided by scanning the raw argv.
parser.add_argument('--dataset_name', required='train' in sys.argv, type=str, choices=['Synapse'])
parser.add_argument('--train_path', required='train' in sys.argv,  type=str, default=None)
parser.add_argument('--test_path', required='train' in sys.argv, type=str, default=None)
parser.add_argument('--save_path', required='train' in sys.argv, type=str, default=None)
parser.add_argument('--pretrain_path', required='infer' in sys.argv, type=str, default=None)
parser.add_argument('--image_path', required='infer' in sys.argv, type=str, default=None)
parser.add_argument('--merge_infer', type=_str2bool, default=False)
parser.add_argument('--save_infer', type=_str2bool, default=False)

parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--learning_rate', type=float, default=1e-2)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--patience', type=int, default=25)
parser.add_argument('--inference_threshold', '--infer_threshold', type=float, default=0.75)

parser.add_argument('--image_dim', type=int, default=512)
parser.add_argument('--in_channels', type=int, default=1)
parser.add_argument('--out_channels', type=int, default=128)
parser.add_argument('--head_num', type=int, default=4)
parser.add_argument('--mlp_dim', type=int, default=512)
parser.add_argument('--block_num', type=int, default=12)
parser.add_argument('--patch_dim', type=int, default=16)
parser.add_argument('--class_num', type=int, default=1)

args = parser.parse_args()
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Expose the threshold under both spellings so code reading either
# 'inference_threshold' or 'infer_threshold' finds it.
args.infer_threshold = args.inference_threshold

if args.mode == 'train':
  trainer = Trainer(args)
  trainer.train()
elif args.mode == 'infer':
  inference = Inference(args)
  inference.infer()