## Contour extraction

In [None]:
import numpy as np
import cv2

In [None]:
def get_contour(mask, threshold=0.2):
    """Extract the boundary band of a single-channel 8-bit mask.

    For every foreground pixel the Euclidean distance to the nearest background
    pixel is computed and min-max normalised to [0, 1] over the whole image.
    Foreground pixels whose normalised distance is <= ``threshold`` form the band.

    Args:
        mask: 8-bit single-channel array (non-zero = foreground), e.g. the result
            of ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``.
        threshold: fraction of the largest distance in the image up to which a
            foreground pixel counts as boundary; larger values give thicker bands.
            Defaults to 0.2.

    Returns:
        float32 array of the same shape holding the original mask value on the
        boundary band and 0 everywhere else.
    """
    dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)

    # Normalise in place so the band width is relative to the largest distance
    # in the image rather than a fixed number of pixels.
    cv2.normalize(dist, dist, 0, 1, cv2.NORM_MINMAX)

    # THRESH_BINARY: interior is 1 where the normalised distance exceeds
    # `threshold`, 0 otherwise.
    _, interior = cv2.threshold(dist, threshold, 1, cv2.THRESH_BINARY)

    # Remove the interior from the mask, leaving only the boundary band.
    return mask - interior * mask

## Custom dataset

In [None]:
from skimage.io import imread
from torch.utils import data

In [None]:
class SegDataset(data.Dataset):
    """Image/mask dataset that also yields the mask's contour band.

    Each item is ``(image, mask, edge)`` as float tensors divided by 255:
    image ``(C, H, W)``, mask and edge ``(1, H, W)``. The edge target is
    derived from the mask with ``get_contour`` *before* augmentation, so a
    joint ``transform`` (an albumentations-style callable accepting
    ``image=``, ``mask=`` and ``edge=`` keywords) moves all three together.

    Args:
        input_paths: image file paths.
        target_paths: mask file paths, aligned index-by-index with ``input_paths``.
        transform: optional joint transform. Any output that is still a NumPy
            array afterwards (or every array, when no transform is given) is
            converted to a tensor with the same layout ToTensorV2 produces.

    Raises:
        ValueError: if ``input_paths`` and ``target_paths`` differ in length.
    """

    def __init__(self, input_paths: list, target_paths: list, transform=None,
    ):
        super().__init__()
        # Pairs are matched by position, so a length mismatch means misaligned data.
        if len(input_paths) != len(target_paths):
            raise ValueError(
                f"Got {len(input_paths)} input paths but {len(target_paths)} target paths"
            )
        self.input_paths = input_paths
        self.target_paths = target_paths
        self.transform = transform

    def __len__(self):
        return len(self.input_paths)

    @staticmethod
    def _image_to_tensor(image):
        """Convert an HWC (or HW) NumPy image to a CHW tensor, keeping its dtype."""
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

    def __getitem__(self, index: int):
        input_ID = self.input_paths[index]
        target_ID = self.target_paths[index]

        x = imread(input_ID)
        y = cv2.imread(target_ID, cv2.IMREAD_GRAYSCALE)
        if y is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f"Could not read mask file: {target_ID}")
        e = get_contour(y)

        if self.transform is not None:
            augmentations = self.transform(image=x, mask=y, edge=e)
            x = augmentations["image"]
            y = augmentations["mask"]
            e = augmentations["edge"]

        # Without ToTensorV2 (or without any transform) the arrays are still
        # NumPy; convert them so the tensor operations below work.
        if isinstance(x, np.ndarray):
            x = self._image_to_tensor(x)
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(np.ascontiguousarray(y))
        if isinstance(e, np.ndarray):
            e = torch.from_numpy(np.ascontiguousarray(e))

        x = x / 255
        y = y.unsqueeze(dim=0) / 255
        e = e.unsqueeze(dim=0) / 255

        return x.float(), y.float(), e.float()

## Model

In [None]:
import timm
import torch
from torch import nn
import torch.nn.functional as F

### Edge decoder

In [None]:
class EdgeDecoder(nn.Module):
    """Top-down decoder over the three shallowest encoder feature maps.

    The deepest input ``x2`` is refined first. Each result is then bilinearly
    resized to the next shallower level, concatenated with that level's
    1x1-projected features, and fused by BatchNorm -> 3x3 conv -> ReLU.

    Args:
        dims: encoder stage channel counts. ``dims[0]``, ``dims[1]`` and
            ``dims[2]`` are the channel counts of ``x0``, ``x1`` and ``x2``;
            ``dims[3]`` is not used here but accepted so one tuple can be shared
            across modules.
    """

    def __init__(self, dims=(96, 192, 384, 576)):
        super().__init__()
        # Per-level 1x1 projections applied before fusion.
        self.con0 = nn.Conv2d(dims[0], dims[0], 1, 1, 0)
        self.con1 = nn.Conv2d(dims[1], dims[1], 1, 1, 0)
        self.con2 = nn.Conv2d(dims[2], dims[2], 1, 1, 0)

        # decode_k maps level k (plus the upsampled level k+1, if any) to dims[k] channels.
        self.decode0 = nn.Sequential(
            nn.BatchNorm2d(dims[0] + dims[1]),
            nn.Conv2d(dims[0] + dims[1], dims[0], 3, 1, 1),
            nn.ReLU()
        )
        self.decode1 = nn.Sequential(
            nn.BatchNorm2d(dims[1] + dims[2]),
            nn.Conv2d(dims[1] + dims[2], dims[1], 3, 1, 1),
            nn.ReLU()
        )
        self.decode2 = nn.Sequential(
            nn.BatchNorm2d(dims[2]),
            nn.Conv2d(dims[2], dims[2], 3, 1, 1),
            nn.ReLU()
        )

    def forward(self, x0, x1, x2):
        """Return ``(e0, e1, e2)`` with the same shapes as ``(x0, x1, x2)``.

        Each coarser result is resized to the exact spatial size of the next skip
        connection instead of by a fixed factor of 2. For an exact 2x ratio the
        output is identical, and the concatenation still works when a stage is
        not exactly twice the size of the next one (e.g. odd input sizes).
        """
        e2 = self.decode2(self.con2(x2))
        e1 = self.decode1(torch.cat(
            [self.con1(x1), F.interpolate(e2, size=x1.shape[-2:], mode='bilinear')], dim=1
        ))
        e0 = self.decode0(torch.cat(
            [self.con0(x0), F.interpolate(e1, size=x0.shape[-2:], mode='bilinear')], dim=1
        ))

        return e0, e1, e2

### Edge detector

In [None]:
class EdgeDetector(nn.Module):
    """timm feature-pyramid encoder followed by an EdgeDecoder.

    Args:
        dims: per-stage channel counts of the encoder. They are forwarded to
            EdgeDecoder and must match what ``encoder_name`` produces (the default
            is what the rest of this notebook assumes).
        encoder_name: timm model name, created with ``features_only=True``.
        pretrained: whether timm should load pretrained weights. Passing False
            avoids a download when the weights are overwritten from a checkpoint.
    """

    def __init__(self, dims=(96, 192, 384, 576),
                 encoder_name='caformer_m36.sail_in22k_ft_in1k_384', pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
        # Forward dims so a custom value is honoured by the decoder as well.
        self.edge_decoder = EdgeDecoder(dims)
        # Edge prediction head. It is not used by forward(); NOTE(review): presumably
        # kept so checkpoints containing `edge.*` weights still load with the
        # default strict=True — confirm before removing.
        self.edge = nn.Sequential(
            nn.BatchNorm2d(dims[0]),
            nn.Conv2d(dims[0], 1, 1, 1, 0)
        )

    def forward(self, x):
        """Return ``(e0, e1, e2, x3)``.

        ``e0``-``e2`` are the decoded features for the three shallowest encoder
        stages; ``x3`` is the deepest encoder feature map, passed through unchanged.
        """
        x0, x1, x2, x3 = self.encoder(x)
        e0, e1, e2 = self.edge_decoder(x0, x1, x2)

        return e0, e1, e2, x3

### Fuse module and mask decoder

In [None]:
class FuseModule(nn.Module):
    """Fuse a coarser feature map into a finer "context" map with attention gating.

    ``x`` is upsampled 2x by a transposed conv to ``context_dim`` channels. The
    context map is squeezed to one channel of spatial logits, which are used in
    two ways:

    * softmax over all pixels gives attention weights that pool ``x`` into one
      value per channel, applied as a sigmoid channel gate;
    * sigmoid applied directly gives a spatial gate.

    Both gates are residual (``x + x * gate``) and are followed by
    3x3 conv -> ReLU -> BatchNorm.

    Args:
        context_dim: channel count of ``context`` and of the output.
        dim: channel count of ``x``.
    """

    def __init__(self, context_dim, dim):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
        self.deconv = nn.ConvTranspose2d(dim, context_dim, 4, stride=2, padding=1)
        self.context_conv = nn.Conv2d(context_dim, 1, 1)

        self.conv = nn.Sequential(
            nn.Conv2d(context_dim, context_dim, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(context_dim)
        )

    def forward(self, context, x):
        """Fuse ``x`` into ``context``.

        Args:
            context: ``(B, context_dim, H, W)`` finer feature map.
            x: ``(B, dim, h, w)`` coarser map with ``2h == H`` and ``2w == W``.

        Returns:
            ``(B, context_dim, H, W)`` fused features.

        Raises:
            ValueError: if the upsampled ``x`` does not match ``context`` spatially.
        """
        x = self.deconv(x)
        if x.shape[-2:] != context.shape[-2:]:
            raise ValueError(
                f"upsampled input has spatial size {tuple(x.shape[-2:])}, "
                f"expected {tuple(context.shape[-2:])} to match context"
            )

        # Single-channel spatial logits, (B, 1, H, W).
        logits = self.context_conv(context)

        # Channel gate. flatten() of these contiguous tensors is a view, so no
        # copy of the feature map is needed.
        weights = F.softmax(logits.flatten(2), dim=2)                       # (B, 1, H*W)
        channel_vec = torch.matmul(x.flatten(2), weights.transpose(1, 2))    # (B, C, 1)
        channel_vec = channel_vec.unsqueeze(-1)                              # (B, C, 1, 1)

        # Residual gating: channel gate first, then the spatial gate.
        x = x + x * torch.sigmoid(channel_vec)
        x = x + x * torch.sigmoid(logits)

        return self.conv(x)

In [None]:
class MaskDecoder(nn.Module):
    """Decode segmentation features from the deepest map upwards.

    ``x3`` is refined by 1x1 conv -> BN -> 3x3 conv -> ReLU and then fused
    successively into ``x2``, ``x1`` and ``x0`` with FuseModule. Each fusion
    upsamples the coarser map to the finer map's resolution and channel count.

    Args:
        dims: channel counts of ``x0``..``x3``.
    """

    def __init__(self, dims=(96, 192, 384, 576)):
        super().__init__()
        # Refine the deepest feature map before it enters the fusion chain.
        self.decode3 = nn.Sequential(
            nn.Conv2d(dims[3], dims[3], 1, 1, 0),
            nn.BatchNorm2d(dims[3]),
            nn.Conv2d(dims[3], dims[3], 3, 1, 1),
            nn.ReLU()
        )

        # fuse_ab(context=level a, x=level b): brings level b into level a's shape.
        self.fuse_01 = FuseModule(dims[0], dims[1])
        self.fuse_12 = FuseModule(dims[1], dims[2])
        self.fuse_23 = FuseModule(dims[2], dims[3])

    def forward(self, x0, x1, x2, x3):
        """Return a ``dims[0]``-channel map at the spatial resolution of ``x0``."""
        x3 = self.decode3(x3)
        x2 = self.fuse_23(x2, x3)
        x1 = self.fuse_12(x1, x2)
        x0 = self.fuse_01(x0, x1)

        return x0

### Contour-enhanced Segmentation (CES)

In [None]:
class CES(nn.Module):
    """Contour-enhanced segmentation network.

    The EdgeDetector's multi-scale edge features and its deepest encoder map
    are decoded by MaskDecoder into a ``dims[0]``-channel map. A 1x1 conv turns
    that map into one channel of mask logits, which is bilinearly resized to
    the input's spatial size.

    Args:
        dims: encoder stage channel counts, shared by all sub-modules.
    """

    def __init__(self, dims=(96, 192, 384, 576)):
        super().__init__()
        # Pass dims through so a custom value reaches every sub-module consistently.
        self.edge_detector = EdgeDetector(dims)
        self.mask_decoder = MaskDecoder(dims)
        self.mask = nn.Conv2d(dims[0], 1, 1, 1, 0)

    def forward(self, x):
        """Return mask logits of shape ``(B, 1, H, W)`` for an input ``(B, C, H, W)``."""
        e0, e1, e2, x3 = self.edge_detector(x)
        o = self.mask_decoder(e0, e1, e2, x3)

        # Apply the 1x1 conv before upsampling. Bilinear interpolation is linear
        # with weights summing to 1, so this equals conv(upsample(o)) while
        # interpolating one channel instead of dims[0]. Resizing to x's size
        # (rather than a fixed 4x) keeps logits aligned with the input/target.
        return F.interpolate(self.mask(o), size=x.shape[-2:], mode='bilinear')

In [None]:
# Fill in before running: path of the pretrained edge-detector checkpoint.
# It is read with torch.load and its "model_state_dict" entry is loaded into
# model.edge_detector.
PRETRAIN_EDGE_DETECTOR_PATH = 

In [None]:
import torchinfo
from torchinfo import summary

# Build the full CES model and initialise its edge detector from the pretrained
# checkpoint. map_location="cpu" lets a checkpoint saved on a GPU machine load
# anywhere; load_state_dict then copies the tensors into the model.
checkpoint = torch.load(PRETRAIN_EDGE_DETECTOR_PATH, map_location="cpu")
model = CES()
model.edge_detector.load_state_dict(checkpoint["model_state_dict"])

# Freeze the whole pretrained edge detector...
for param in model.edge_detector.parameters():
    param.requires_grad = False
# ...except the deepest encoder stage, which stays trainable.
for param in model.edge_detector.encoder.stages_3.parameters():
    param.requires_grad = True

with torch.inference_mode():
    print(summary(model, input_size=(8, 3, 384, 384)))

## Losses & Metrics

### Losses

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, averaged over the batch.

    For each sample, with ``P = sigmoid(logits)`` and ``T = targets`` flattened::

        dice = (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)

    The loss is ``1 - mean(dice)``, which lies in [0, 1] for targets in [0, 1].

    Args:
        smooth: additive smoothing that keeps the ratio defined for empty masks.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the scalar loss.

        ``logits`` and ``targets`` must have the same batch size and the same
        number of elements per sample.
        """
        num = targets.size(0)

        # reshape (not view) so non-contiguous inputs are accepted too.
        probs = torch.sigmoid(logits).reshape(num, -1)
        truth = targets.reshape(num, -1)
        intersection = (probs * truth).sum(1)

        # smooth enters the numerator and denominator once each, so a perfect
        # match (including an empty prediction on an empty target) scores
        # exactly 1 and the loss never goes negative.
        score = (2.0 * intersection + self.smooth) / (probs.sum(1) + truth.sum(1) + self.smooth)

        return 1 - score.mean()

### Metrics

In [None]:
class DiceScore(torch.nn.Module):
    """Hard Dice coefficient on logits, averaged over the batch.

    Predictions are ``sigmoid(logits) > 0.5`` and targets are ``targets > 0.5``.
    For each sample::

        dice = (2 * |pred & truth| + smooth) / (|pred| + |truth| + smooth)

    The result lies in [0, 1]; a perfect match (including empty vs. empty)
    scores exactly 1.

    Args:
        smooth: additive smoothing that keeps the ratio defined for empty masks.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the batch-mean Dice score as a 0-d tensor."""
        num = targets.size(0)

        # reshape (not view) so non-contiguous inputs are accepted too.
        pred = torch.sigmoid(logits).reshape(num, -1) > 0.5
        truth = targets.reshape(num, -1) > 0.5
        intersection = (pred & truth).sum(1)

        # smooth is added once to numerator and denominator so the score cannot exceed 1.
        score = (2.0 * intersection + self.smooth) / (pred.sum(1) + truth.sum(1) + self.smooth)

        return score.mean()

## Train

### Configuration

In [None]:
class Args:
    """Minimal attribute bag: each keyword argument becomes an instance attribute.

    Example: ``Args(lr=1e-4).lr == 1e-4``.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

In [None]:
# Fill in before running.
# DATASET_NAME is used in log labels and in the checkpoint filename
# ("Trained models/model_<DATASET_NAME>.pt").
# Each *_DATASET_PATH is a directory whose "images/" and "masks/"
# sub-directories are globbed by get_paths.
DATASET_NAME =
TRAIN_DATASET_PATH =
VALID_DATASET_PATH =
TEST_DATASET_PATH = 

In [None]:
# Hyper-parameters and paths consumed by train() / build_train().
train_args_dict = dict(
    dataset=DATASET_NAME,
    # Split roots; each contains "images/" and "masks/" sub-directories.
    train_root=TRAIN_DATASET_PATH,
    valid_root=VALID_DATASET_PATH,
    test_root=TEST_DATASET_PATH,
    epochs=100,
    batch_size=8,
    lr=1e-4,
    # "true" enables ReduceLROnPlateau on the validation score; lrs_min is its LR floor.
    lrs="true",
    lrs_min=1e-7,
)

args = Args(**train_args_dict)

### Data preparation

In [None]:
import glob
import multiprocessing

In [None]:
def get_paths(data_root):
    """List the image and mask files of one dataset split.

    Args:
        data_root: split directory containing ``images/`` and ``masks/``.

    Returns:
        ``(input_paths, target_paths)``: two lexicographically sorted lists.
        Pairs are matched by position, so both folders should use matching
        file names.
    """
    def sorted_files(subdir):
        # Sort because glob order is filesystem-dependent.
        return sorted(glob.glob(data_root + "/" + subdir + "/*"))

    return sorted_files("images"), sorted_files("masks")

In [None]:
def get_train_dataloaders(input_paths, target_paths, transform_train, batch_size):
    """Build the shuffled training DataLoader over a SegDataset.

    Args:
        input_paths: image file paths.
        target_paths: mask file paths aligned with ``input_paths``.
        transform_train: joint augmentation pipeline passed to SegDataset.
        batch_size: samples per batch; an incomplete final batch is dropped.

    Returns:
        torch DataLoader yielding ``(image, mask, edge)`` batches.
    """
    import os

    train_dataset = SegDataset(
        input_paths=input_paths,
        target_paths=target_paths,
        transform=transform_train
    )

    train_dataloader = data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        # One worker per CPU. Read the count directly: instantiating a
        # multiprocessing.Pool just to inspect it spawns processes that are never closed.
        num_workers=os.cpu_count() or 1,
    )

    return train_dataloader

def get_test_dataloaders(input_paths, target_paths, transform_test):
    """Build an evaluation DataLoader (batch size 1, fixed order) over a SegDataset.

    Args:
        input_paths: image file paths.
        target_paths: mask file paths aligned with ``input_paths``.
        transform_test: deterministic pipeline passed to SegDataset.

    Returns:
        torch DataLoader yielding one ``(image, mask, edge)`` sample per batch.
    """
    import os

    test_dataset = SegDataset(
        input_paths=input_paths,
        target_paths=target_paths,
        transform=transform_test
    )

    test_dataloader = data.DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        # One worker per CPU, read directly instead of via a leaked multiprocessing.Pool.
        num_workers=os.cpu_count() or 1,
    )

    return test_dataloader

### Training utils

In [None]:
import os
from tqdm import tqdm
import time
import matplotlib.pyplot as plt

In [None]:
def build_train(args, transfrom_train, transform_test, model=None):
    """Assemble the device, data loaders, losses, metric, model and optimizer.

    Args:
        args: Args namespace providing ``train_root``, ``valid_root``,
            ``test_root``, ``batch_size`` and ``lr``.
        transfrom_train: augmentation pipeline for the training split (the
            misspelt name is kept for compatibility with existing callers).
        transform_test: pipeline for the validation and test splits.
        model: optional pre-built model, e.g. a CES whose edge detector was
            loaded from a checkpoint and partially frozen. Defaults to ``CES()``.

    Returns:
        ``(device, train_dataloader, val_dataloader, test_dataloader,
        Dice_loss, BCE_loss, perf, model, optimizer)``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_input_paths, train_target_paths = get_paths(args.train_root)
    valid_input_paths, valid_target_paths = get_paths(args.valid_root)
    test_input_paths, test_target_paths = get_paths(args.test_root)

    # Use the pipeline passed in, not whatever global happens to share its name.
    train_dataloader = get_train_dataloaders(
        train_input_paths, train_target_paths, transfrom_train, batch_size=args.batch_size
    )

    val_dataloader = get_test_dataloaders(
        valid_input_paths, valid_target_paths, transform_test
    )

    test_dataloader = get_test_dataloaders(
        test_input_paths, test_target_paths, transform_test
    )

    Dice_loss = DiceLoss()
    BCE_loss = nn.BCELoss()

    perf = DiceScore()

    # train_epoch/test treat the model output as one mask-logit tensor, which is
    # what CES returns (EdgeDetector returns a 4-tuple of feature maps).
    if model is None:
        model = CES()
    model.to(device)

    # Hand only trainable parameters to the optimizer so frozen sub-modules are excluded.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=args.lr
    )

    return (
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        Dice_loss,
        BCE_loss,
        perf,
        model,
        optimizer,
    )

In [None]:
def _show_batch_predictions(images, target_mask, target_edge, mask):
    """Plot the first sample of a batch: input, GT mask, GT edge, logits and thresholded prediction."""
    plt.subplot(3, 3, 1)
    plt.imshow(images[0].permute(1, 2, 0).cpu().numpy())

    plt.subplot(3, 3, 2)
    plt.imshow(target_mask[0][0].cpu().numpy())

    plt.subplot(3, 3, 3)
    plt.imshow(target_edge[0][0].cpu().numpy())

    plt.subplot(3, 3, 5)
    plt.imshow(mask[0][0].detach().cpu().numpy())

    plt.subplot(3, 3, 8)
    plt.imshow(((mask[0][0].sigmoid() > 0.5) * 1).detach().cpu().numpy())

    plt.show()


def train_epoch(model, device, train_loader, optimizer, epoch, Dice_loss, BCE_loss, dataset):
    """Run one training epoch and return the mean batch loss.

    The loss is ``Dice_loss(logits, mask) + BCE_loss(sigmoid(logits), mask)``.
    The last batch of the epoch is plotted for visual inspection.
    """
    t = time.time()
    model.train()
    loss_accumulator = []
    last_batch = len(train_loader) - 1
    for batch_idx, (images, target_mask, target_edge) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target_mask = target_mask.to(device)
        target_edge = target_edge.to(device)
        optimizer.zero_grad()
        mask = model(images)

        if batch_idx == last_batch:
            _show_batch_predictions(images, target_mask, target_edge, mask)

        loss = Dice_loss(mask, target_mask) + BCE_loss(torch.sigmoid(mask), target_mask)

        loss.backward()
        optimizer.step()
        loss_accumulator.append(loss.item())

    mean_loss = np.mean(loss_accumulator)
    print(
        "\r[Train {} Epoch: {}]\nAverage loss: {:.6f}\nTime: {:.6f}".format(
            dataset,
            epoch,
            mean_loss,
            time.time() - t,
        )
    )

    return mean_loss


def test(model, device, test_loader, epoch, perf_measure, dataset):
    """Evaluate ``model`` and return ``(mean, std)`` of the per-batch ``perf_measure``.

    Runs under ``torch.inference_mode``; the last batch is plotted.
    """
    t = time.time()
    model.eval()
    perf_accumulator = []
    last_batch = len(test_loader) - 1
    with torch.inference_mode():
        for batch_idx, (images, target_mask, target_edge) in enumerate(tqdm(test_loader)):
            images = images.to(device)
            target_mask = target_mask.to(device)
            target_edge = target_edge.to(device)
            mask = model(images)
            perf_accumulator.append(perf_measure(mask, target_mask).item())

            if batch_idx == last_batch:
                _show_batch_predictions(images, target_mask, target_edge, mask)

    print(
        "\r[{} Epoch: {}]\nAverage performance: {:.6f}\nTime: {:.6f}".format(
            dataset,
            epoch,
            np.mean(perf_accumulator),
            time.time() - t,
        )
    )

    return np.mean(perf_accumulator), np.std(perf_accumulator)

In [None]:
def train(args, transfrom_train, transform_test):
    """Run the training loop, keeping the checkpoint with the best validation Dice.

    Each epoch trains once, then evaluates on the validation split (which drives
    the LR scheduler and checkpoint selection) and on the test split (printed
    only). Whenever the validation score improves, the model and optimizer state
    are written to ``Trained models/model_<args.dataset>.pt``.

    Args:
        args: Args namespace with dataset, train/valid/test roots, epochs,
            batch_size, lr, lrs ("true" enables the scheduler) and lrs_min.
        transfrom_train: training augmentation pipeline (name kept as-is for
            compatibility).
        transform_test: evaluation pipeline for the validation and test splits.
    """
    (
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        Dice_loss,
        BCE_loss,
        perf,
        model,
        optimizer,
    ) = build_train(args, transfrom_train, transform_test)

    os.makedirs("./Trained models", exist_ok=True)

    scheduler = None
    if args.lrs == "true":
        # ReduceLROnPlateau's own default floor is min_lr=0, so a non-positive
        # lrs_min simply means "no floor".
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=0.7,
            patience=5,
            min_lr=args.lrs_min if args.lrs_min > 0 else 0,
            verbose=True,
        )

    prev_best_valid = None
    for epoch in range(1, args.epochs + 1):
        try:
            loss = train_epoch(
                model, device, train_dataloader, optimizer, epoch, Dice_loss, BCE_loss, args.dataset
            )
            valid_measure_mean, _ = test(
                model, device, val_dataloader, epoch, perf, f"Valid {args.dataset}"
            )
            # The test-set score is only reported (printed by test()), never used for selection.
            test(model, device, test_dataloader, epoch, perf, f"Test {args.dataset}")
        except KeyboardInterrupt:
            # Stop cleanly; checkpoints saved in earlier epochs remain on disk.
            print("Training interrupted by user")
            return

        if scheduler is not None:
            scheduler.step(valid_measure_mean)
        if prev_best_valid is None or valid_measure_mean > prev_best_valid:
            print("Saving...")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                },
                f"Trained models/model_{args.dataset}.pt",
            )
            prev_best_valid = valid_measure_mean
        print("====================================================================================================")

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Spatial size every sample is resized to before entering the network.
IMAGE_SIZE = 384

# Training pipeline. albumentations applies each geometric transform identically
# to `image`, `mask` and the extra `edge` target. ToTensorV2 then converts the
# image HWC -> CHW and the masks to tensors without rescaling; SegDataset does
# the division by 255.
transform_train = A.Compose([
    A.RandomCrop(height=256, width=256, p=0.5),  # NOTE(review): may fail on inputs smaller than 256 px — confirm dataset sizes
    # NOTE(review): recent albumentations releases renamed Rotate's `value` to `fill`; check the installed version.
    A.Rotate(limit=(-90, 90), p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    ToTensorV2()
], additional_targets={'mask': 'mask', 'edge': 'mask'})  # `edge` is transformed with mask semantics

# Evaluation pipeline: resize and tensor conversion only.
transform_test = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    ToTensorV2()
], additional_targets={'mask': 'mask', 'edge': 'mask'})

In [None]:
# Entry point: build data and model, train for args.epochs epochs, and save the
# best-validation checkpoint under "Trained models/".
train(args, transform_train, transform_test)