In [None]:
import os
import numpy as np
import pandas as pd
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import pydicom
import argparse
import warnings
from glob import glob
import nibabel as nib
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from monai.transforms import Resize
import monai.transforms as transforms

import segmentation_models_pytorch as smp

# Set matplotlib to display inline
%matplotlib inline
rcParams['figure.figsize'] = 20, 8


In [None]:
## USE THIS CELL FOR GOOGLE DRIVE COLAB SETUP
!pip install -q monai
!pip install -q segmentation-models-pytorch==0.2.1
!pip install pylibjpeg==1.4.0
!pip install python-gdcm==3.0.17.1


#
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Define the path to the "data" folder
data_folder = '/content/drive/My Drive/data'

# Check if the folder exists
if os.path.exists(data_folder):
    print(f"Successfully accessed: {data_folder}")
    print("Files in the folder:")
    print(os.listdir(data_folder))  # List files in the folder
else:
    print("The folder 'data' was not found in Google Drive.")


In [7]:
# Debug mode: when True, run_fold skips saving the per-epoch "_last" checkpoint
DEBUG = True

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# cuDNN autotuning: faster for fixed input sizes, but not bit-for-bit reproducible
torch.backends.cudnn.benchmark = True

# Reproducibility: seed the Python, NumPy and PyTorch RNGs used by mixup, batch
# shuffling and weight initialization.
# NOTE(review): MONAI random transforms may keep their own RNG state; use
# Compose.set_random_state if exact augmentation replay is needed. TODO confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Experiment name, used as the prefix of the log file and checkpoint file names.
# NOTE(review): the name encodes "lr3e-4", but init_lr below is 3e-3. Confirm which is intended.
kernel_type = 'timm3d_resnet18d_unet4blocks_128cube_dsv2_flip12_shift3.7_gd1.5_bs4_lr3e-4_20x50ep'
load_kernel = None  # not referenced by the visible code (no resume logic yet)
load_last = True    # not referenced by the visible code
n_blocks = 4        # number of encoder stages fed to the U-Net decoder
n_folds = 5         # cross-validation folds
backbone = 'resnet18d'  # timm encoder name

# Volume size: [in-plane size 0, in-plane size 1, number of slices]
image_sizes = [128, 128, 128]
resize_transform = Resize(image_sizes)

# Training hyperparameters
init_lr = 3e-3
batch_size = 4
drop_rate = 0.0
drop_path_rate = 0.0
loss_weights = [1, 1]  # [BCE weight, Dice weight]
p_mixup = 0.1          # probability of applying mixup to a training batch

# Data directories (data_folder is defined by the Colab setup cell)
data_dir = data_folder

# Other configurations
use_amp = True
num_workers = 4
out_dim = 7       # number of segmentation classes / mask channels
n_epochs = 1000

# Directories for logs and models
log_dir = './logs'
model_dir = './models'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)


In [None]:
# Load the training label table
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))

# Index the available segmentation masks. The StudyInstanceUID is the file name minus ".nii".
mask_dir = os.path.join(data_dir, 'segmentations')
mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith('.nii'))
df_mask = pd.DataFrame({
    'mask_file': mask_files,
})
df_mask['StudyInstanceUID'] = df_mask['mask_file'].apply(lambda x: x[:-4])
df_mask['mask_file'] = df_mask['mask_file'].apply(lambda x: os.path.join(mask_dir, x))

# Attach mask paths and image folders to every study (studies without a mask get '')
df = df_train.merge(df_mask, on='StudyInstanceUID', how='left')
df['image_folder'] = df['StudyInstanceUID'].apply(lambda x: os.path.join(data_dir, 'train_images', x))
# Plain assignment instead of `df['mask_file'].fillna('', inplace=True)`.
# That chained in-place call does not write back to `df` under pandas Copy-on-Write.
# It would leave NaN paths, which pass the `!= ""` filter below.
df['mask_file'] = df['mask_file'].fillna('')

# Keep only studies that have a segmentation mask
df_seg = df.query('mask_file != ""').reset_index(drop=True)

# Define cv folds
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
df_seg['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(kf.split(df_seg)):
    df_seg.loc[valid_idx, 'fold'] = fold
assert (df_seg['fold'] >= 0).all(), "every segmented study must be assigned to a fold"

# Display the last few entries as a sanity check
df_seg.tail()


In [10]:
# StudyInstanceUIDs whose segmentation masks are flipped along the last axis
# (mask[:, :, :, ::-1]) in SEGDataset.__getitem__ before augmentation.
# NOTE(review): presumably these masks are stored with a reversed slice order
# relative to the DICOM series. Confirm against the data.
revert_list = [
    '1.2.826.0.1.3680043.1363',
    '1.2.826.0.1.3680043.20120',
    '1.2.826.0.1.3680043.2243',
    '1.2.826.0.1.3680043.24606',
    '1.2.826.0.1.3680043.32071'
]

def load_dicom(path):
    """Read one DICOM file and return its pixel array resized in-plane.

    Uses `pydicom.dcmread`; the older `pydicom.read_file` alias was deprecated and
    removed in pydicom 3.0.

    Args:
        path: path to a single DICOM file.

    Returns:
        The resized pixel array. cv2.resize takes dsize as (width, height), so the
        result has shape (image_sizes[1], image_sizes[0]).
    """
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = cv2.resize(data, (image_sizes[0], image_sizes[1]), interpolation=cv2.INTER_LINEAR)
    return data

def load_dicom_line_par(path):
    """Load a DICOM series folder as a uint8 volume with a fixed number of slices.

    Files are ordered by the integer prefix of their file name. Then image_sizes[2]
    of them are picked at evenly spaced positions (rounded indices, so slices may
    repeat when the series is short). Intensities are min-max scaled over the whole
    volume to [0, 255].

    Args:
        path: folder containing files named "<int>.<ext>".

    Returns:
        uint8 array stacked along the last axis: (H, W, image_sizes[2]).

    Raises:
        FileNotFoundError: if the folder contains no files.
    """
    t_paths = sorted(glob(os.path.join(path, "*")),
                     key=lambda x: int(os.path.basename(x).split(".")[0]))

    n_scans = len(t_paths)
    if n_scans == 0:
        raise FileNotFoundError(f"No DICOM files found in {path}")
    # Quantiles of 0..n_scans-1 give evenly spaced (fractional) positions, rounded to slice indices
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., image_sizes[2])).round().astype(int)
    t_paths = [t_paths[i] for i in indices]

    images = np.stack([load_dicom(filename) for filename in t_paths], -1)

    # Min-max normalize to uint8. Casting first keeps the subtraction from wrapping
    # around for signed integer pixel data whose range exceeds the dtype.
    images = images.astype(np.float64)
    images = images - np.min(images)
    images = images / (np.max(images) + 1e-4)
    images = (images * 255).astype(np.uint8)

    return images

def load_sample(row, has_mask=True):
    """Load the CT volume for `row` and, optionally, its per-class mask.

    Args:
        row: record with `image_folder` and, when has_mask, `mask_file` attributes.
        has_mask: whether to also load the NIfTI segmentation.

    Returns:
        image: uint8 array (3, H, W, D). The grayscale volume is repeated over 3 channels.
        mask (only when has_mask): 7-channel mask, one channel per label id 1..7
            (0/255 before resizing), resized with `resize_transform` and returned as numpy.
    """
    image = load_dicom_line_par(row.image_folder)
    if image.ndim < 4:
        image = np.expand_dims(image, 0).repeat(3, 0)  # to 3 channels

    if not has_mask:
        return image

    mask_org = nib.load(row.mask_file).get_fdata()
    # Reorient the label volume: swap the first two axes, then flip the (new) first and last axes
    mask_org = mask_org.transpose(1, 0, 2)[::-1, :, ::-1]
    # One uint8 channel per label id 1..7 (label 0 is not given a channel).
    # Allocating uint8 directly avoids a full-resolution float64 buffer.
    # Taking the shape from the *transposed* volume means non-square slices no longer break the assignment.
    mask = np.zeros((7,) + mask_org.shape, dtype=np.uint8)
    for cid in range(7):
        mask[cid][mask_org == (cid + 1)] = 255
    mask = resize_transform(mask).numpy()

    return image, mask

class SEGDataset(Dataset):
    """Dataset yielding (image, mask) float tensors for segmentation training.

    Each item is built by `load_sample`, mask-flipped on the last axis for studies in
    `revert_list`, passed through `transform` (a dict transform over keys 'image' and
    'mask'), then scaled: image divided by 255, mask binarized at > 127.

    Args:
        df: rows with StudyInstanceUID, image_folder and mask_file columns.
        mode: split name ('train' / 'valid'); stored but not used here.
        transform: callable on {'image': ..., 'mask': ...} dicts, or None.
    """

    def __init__(self, df, mode, transform):
        self.df = df.reset_index()
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]

        image, mask = load_sample(row, has_mask=True)

        if row.StudyInstanceUID in revert_list:
            mask = mask[:, :, :, ::-1]

        # Test `is not None` rather than truthiness. MONAI's Compose defines __len__,
        # so the empty Compose([]) used for validation is falsy. That used to skip the
        # scaling below and yield images in [0, 255] and masks in {0, 255}.
        if self.transform is not None:
            res = self.transform({'image': image, 'mask': mask})
            image, mask = res['image'], res['mask']

        # Scale and binarize unconditionally so every split sees the same value ranges.
        # np.asarray accepts numpy arrays as well as CPU tensors returned by MONAI.
        # ascontiguousarray removes negative strides from the flip above, which torch cannot wrap.
        image = np.ascontiguousarray(np.asarray(image, dtype=np.float32)) / 255.
        mask = (np.asarray(mask) > 127).astype(np.float32)

        return torch.from_numpy(image), torch.from_numpy(mask)


In [11]:
# Training-time augmentation, applied jointly to the "image" and "mask" entries
transforms_train = transforms.Compose(
    [
        # Independent random flips along spatial axes 1 and 2
        transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1),
        transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=2),
        # Random translation of up to int(0.3 * size) voxels per axis, zero padded
        transforms.RandAffined(
            keys=["image", "mask"],
            translate_range=[int(size * 0.3) for size in image_sizes[:3]],
            padding_mode='zeros',
            prob=0.7,
        ),
        # Mild grid distortion with nearest-neighbour resampling
        transforms.RandGridDistortiond(
            keys=("image", "mask"),
            prob=0.5,
            distort_limit=(-0.01, 0.01),
            mode="nearest",
        ),
    ]
)

# Validation: empty pipeline (no augmentation)
transforms_valid = transforms.Compose([])


In [None]:
# Visual sanity check: overlay augmented training masks on one slice of each volume
df_show = df_seg
dataset_show = SEGDataset(df_show, 'train', transform=transforms_train)

rcParams['figure.figsize'] = 20, 8
n_rows, n_cols = 2, 4
slice_idx = 60  # index along the last (slice) axis that is displayed
n_show = min(n_rows * n_cols, len(dataset_show))  # tolerate small datasets

for i in range(n_rows):
    fig, axarr = plt.subplots(1, n_cols)
    for p in range(n_cols):
        ax = axarr[p]
        ax.axis('off')
        idx = i * n_cols + p
        if idx >= n_show:
            continue

        img, mask = dataset_show[idx]
        img_slice = img[:, :, :, slice_idx]
        mask_slice = mask[:, :, :, slice_idx]

        # Fold the 7 class channels into 3 colour channels: channel c collects classes c, c+3, c+6
        mask_rgb = torch.stack([
            mask_slice[0] + mask_slice[3] + mask_slice[6],
            mask_slice[1] + mask_slice[4],
            mask_slice[2] + mask_slice[5],
        ])

        # Overlay mask on image; (C, H, W) -> (H, W, C) for imshow
        img_overlay = img_slice * 0.7 + mask_rgb * 0.3
        ax.imshow(img_overlay.permute(1, 2, 0).numpy())
        ax.set_title(str(df_show['StudyInstanceUID'].iloc[idx]))
    plt.show()


In [1]:
# Define the segmentation model using Timm
class TimmSegModel(nn.Module):
    """2D timm encoder + smp U-Net decoder + conv head (inflated to 3D later by convert_3d).

    Args:
        backbone: timm model name used as the feature encoder.
        segtype: decoder type; only 'unet' is implemented.
        pretrained: whether timm loads pretrained encoder weights.

    Raises:
        ValueError: if `segtype` is not 'unet'.
    """

    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()
        if segtype != 'unet':
            raise ValueError(f"Unsupported segtype {segtype!r}; only 'unet' is implemented")

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        # Probe the encoder once to read the channel count of each feature stage.
        # Run it in eval mode without autograd. Otherwise the random probe input updates
        # BatchNorm running statistics, which convert_3d later copies into the 3D model.
        self.encoder.eval()
        with torch.no_grad():
            g = self.encoder(torch.rand(1, 3, 64, 64))
        self.encoder.train()
        # Leading 1 is a placeholder for the input stage (replaced by 0 in forward)
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:n_blocks+1],
            decoder_channels=decoder_channels[:n_blocks],
            n_blocks=n_blocks,
        )

        self.segmentation_head = nn.Conv3d(
            decoder_channels[n_blocks-1],
            out_dim,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
            padding=(1, 1, 1)
        )

    def forward(self, x):
        """Return per-class logits with `out_dim` channels for input volume `x`."""
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame  # Ensure this module is available or defined

def convert_3d(module):
    """Recursively convert 2D layers of `module` into their 3D equivalents.

    - BatchNorm2d -> BatchNorm3d, reusing affine parameters and running statistics.
    - Conv2dSame -> Conv3dSame and Conv2d -> Conv3d. Each 2D kernel is inflated by
      repeating it kernel_size[0] times along a new last axis (no rescaling), and
      biases are carried over.
    - MaxPool2d / AvgPool2d -> MaxPool3d / AvgPool3d.

    Other modules are kept, and their children are converted and re-registered.
    Only the first element of 2D kernel/stride/padding/dilation tuples is used, so
    square 2D settings are assumed.

    Returns:
        The converted module (a new object for the converted layer types).
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    # Conv2dSame must be checked before Conv2d
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))
        # Keep the (pretrained) bias instead of the freshly initialized one
        if module.bias is not None:
            module_output.bias = module.bias

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))
        # Keep the (pretrained) bias instead of the freshly initialized one
        if module.bias is not None:
            module_output.bias = module.bias

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )

    return module_output

# Initialize the 2D model, inflate it to 3D and move it to the target device
model = TimmSegModel(backbone, pretrained=True)
model = convert_3d(model)
model.to(device)

# Smoke test: one forward pass on a random volume to check the output shape.
# eval() + no_grad(): no activations are stored for backprop, and BatchNorm running
# statistics are not updated by the random input.
model.eval()
test_input = torch.rand(1, 3, *image_sizes).to(device)
with torch.no_grad():
    test_output = model(test_input)
print(test_output.shape)


NameError: name 'nn' is not defined

In [None]:
#LOSS FUNCTION AND METRICS
from typing import Any, Dict, Optional

def binary_dice_score(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    threshold: Optional[float] = None,
    nan_score_on_empty=False,
    eps: float = 1e-7,
) -> float:
    """Dice coefficient between a prediction and a binary target.

    Args:
        y_pred: predicted values; binarized with `> threshold` when threshold is given.
        y_true: target tensor (nonzero = foreground).
        threshold: optional binarization threshold for y_pred.
        nan_score_on_empty: when the target is empty, return NaN instead of 1/0.
        eps: added to the denominator.

    Returns:
        2*|P∩T| / (|P|+|T|+eps). If the target is empty: NaN when
        nan_score_on_empty, else 1.0 if nothing was predicted and 0.0 otherwise.
    """
    pred = y_pred if threshold is None else (y_pred > threshold).to(y_true.dtype)

    target_total = torch.sum(y_true)
    pred_total = torch.sum(pred)

    if not (target_total > 0):
        if nan_score_on_empty:
            return np.nan
        return float(not (pred_total > 0))

    overlap = torch.sum(pred * y_true).item()
    denominator = (pred_total + target_total).item()
    return (2.0 * overlap) / (denominator + eps)

def multilabel_dice_score(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    threshold=None,
    eps=1e-7,
    nan_score_on_empty=False,
):
    """Per-class Dice scores for tensors whose first dimension indexes classes.

    Returns:
        list with one `binary_dice_score` value per class, in class order.
    """
    return [
        binary_dice_score(
            y_pred=y_pred[class_idx],
            y_true=y_true[class_idx],
            threshold=threshold,
            nan_score_on_empty=nan_score_on_empty,
            eps=eps,
        )
        for class_idx in range(y_pred.size(0))
    ]

def dice_loss(input, target, smooth=1.0):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        input: raw logits (sigmoid is applied here).
        target: binary targets with the same number of elements as `input`.
        smooth: additive smoothing in numerator and denominator (default 1.0).

    Returns:
        Scalar tensor 1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth).
    """
    input = torch.sigmoid(input)
    # reshape (not view) so non-contiguous tensors are accepted
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

def bce_dice(input, target, loss_weights=loss_weights):
    """Weighted average of BCE-with-logits and soft Dice loss.

    Args:
        input: raw logits.
        target: binary targets, same shape as `input`.
        loss_weights: [bce_weight, dice_weight]; the result is normalized by their sum.
    """
    bce_weight, dice_weight = loss_weights[0], loss_weights[1]
    bce_term = bce_weight * nn.BCEWithLogitsLoss()(input, target)
    dice_term = dice_weight * dice_loss(input, target)
    return (bce_term + dice_term) / sum(loss_weights)

# Loss used by train_func / valid_func
criterion = bce_dice


In [None]:
# TRAINING and VALIDATION FUNCTIONS
def mixup(input, truth, clip=[0, 1]):
    """Blend every sample of a batch with a randomly chosen partner from the same batch.

    Args:
        input: batch tensor, first dimension = batch.
        truth: labels aligned with `input` along the first dimension.
        clip: [low, high] range of the uniform mixing weight.

    Returns:
        (mixed_input, truth, partner_truth, lam), where
        mixed_input = input * lam + input[perm] * (1 - lam) and partner_truth = truth[perm].
    """
    perm = torch.randperm(input.size(0))
    partner_input, partner_truth = input[perm], truth[perm]

    lam = np.random.uniform(clip[0], clip[1])
    mixed = input * lam + partner_input * (1 - lam)
    return mixed, truth, partner_truth, lam

def train_func(model, loader_train, optimizer, scaler=None):
    """Run one training epoch and return the mean batch loss.

    With probability `p_mixup` a batch is mixed up; its loss is then the
    lam-weighted sum of the losses against both label sets.

    Args:
        model: segmentation network; switched to train mode here.
        loader_train: iterable of (images, gt_masks) batches.
        optimizer: optimizer over `model`'s parameters.
        scaler: optional GradScaler for mixed precision. When None, a plain
            backward/step is performed.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader_train, desc='Training')
    for images, gt_masks in bar:
        optimizer.zero_grad()
        images = images.to(device)
        gt_masks = gt_masks.to(device)

        do_mixup = random.random() < p_mixup
        if do_mixup:
            images, gt_masks, gt_masks_sfl, lam = mixup(images, gt_masks)

        with amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, gt_masks)
            if do_mixup:
                loss2 = criterion(logits, gt_masks_sfl)
                loss = loss * lam + loss2 * (1 - lam)

        train_loss.append(loss.item())
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Running mean over the last 30 batches
        bar.set_description(f'Train Loss: {np.mean(train_loss[-30:]):.4f}')

    return np.mean(train_loss)

def valid_func(model, loader_valid):
    """Evaluate `model` on the validation loader.

    Computes the mean loss and, for each probability threshold in `ths`, the mean of
    the per-sample, per-class Dice scores. Prints the best threshold.

    Args:
        model: segmentation network; switched to eval mode here.
        loader_valid: iterable of (images, gt_masks) batches.

    Returns:
        (mean validation loss, best mean Dice across thresholds).
    """
    model.eval()
    valid_loss = []
    ths = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    batch_metrics = [[] for _ in ths]
    bar = tqdm(loader_valid, desc='Validation')
    with torch.no_grad():
        for images, gt_masks in bar:
            images = images.to(device)
            gt_masks = gt_masks.to(device)

            logits = model(images)
            loss = criterion(logits, gt_masks)
            valid_loss.append(loss.item())

            # Threshold-independent work is hoisted out of the threshold loop:
            # one sigmoid and one device-to-host copy of the targets per batch.
            probs = torch.sigmoid(logits)
            gt_masks_cpu = gt_masks.cpu()
            for thi, th in enumerate(ths):
                pred = (probs > th).float().cpu()
                for i in range(logits.shape[0]):
                    # pred is already binary, so threshold=0.5 leaves it unchanged
                    tmp = multilabel_dice_score(
                        y_pred=pred[i],
                        y_true=gt_masks_cpu[i],
                        threshold=0.5,
                    )
                    batch_metrics[thi].extend(tmp)
            bar.set_description(f'Valid Loss: {np.mean(valid_loss[-30:]):.4f}')

    metrics = [np.mean(this_metric) for this_metric in batch_metrics]
    best_th = ths[np.argmax(metrics)]
    best_dc = np.max(metrics)
    print(f'Best Threshold: {best_th}, Best Dice Coefficient: {best_dc:.4f}')

    return np.mean(valid_loss), best_dc

In [None]:
# Preview the learning-rate schedule exactly as run_fold applies it:
# CosineAnnealingWarmRestarts(T_0=n_epochs), stepped with `epoch - 1` at the start of each epoch.
# A throwaway parameter drives the optimizer, so this cell does not depend on `model`.
_lr_param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([_lr_param], lr=init_lr)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=n_epochs)
lrs = []

for epoch in range(1, n_epochs + 1):
    scheduler_cosine.step(epoch - 1)
    lrs.append(optimizer.param_groups[0]["lr"])

# Plot the learning rate used in each training epoch
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(range(1, n_epochs + 1), lrs)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Cosine LR scheduler (warm restarts, T_0 = n_epochs)')
plt.show()


In [14]:
def run_fold(fold):
    """Train and validate one cross-validation fold.

    Writes one log line per epoch to `<log_dir>/<kernel_type>.txt`. Saves the weights
    with the best validation Dice to `<model_dir>/<kernel_type>_fold<fold>_best.pth`.
    When not DEBUG, also saves a full "_last" checkpoint every epoch.

    Args:
        fold: value of df_seg['fold'] held out for validation.

    Returns:
        Best validation Dice reached (0.0 if it never improved).
    """
    log_file = os.path.join(log_dir, f'{kernel_type}.txt')
    model_file = os.path.join(model_dir, f'{kernel_type}_fold{fold}_best.pth')

    # Split data into training and validation sets
    train_df = df_seg[df_seg['fold'] != fold].reset_index(drop=True)
    valid_df = df_seg[df_seg['fold'] == fold].reset_index(drop=True)

    # Initialize datasets and dataloaders
    dataset_train = SEGDataset(train_df, 'train', transform=transforms_train)
    dataset_valid = SEGDataset(valid_df, 'valid', transform=transforms_valid)
    loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Model
    model = TimmSegModel(backbone, pretrained=True)
    model = convert_3d(model)
    model.to(device)

    # Initialize optimizer and scaler
    optimizer = optim.AdamW(model.parameters(), lr=init_lr)
    scaler = amp.GradScaler(enabled=use_amp)

    # Single cosine cycle over all epochs (T_0 = n_epochs)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=n_epochs)

    metric_best = 0.0

    print(f'Starting Fold {fold}')
    print(f'Training samples: {len(dataset_train)}, Validation samples: {len(dataset_valid)}')

    for epoch in range(1, n_epochs + 1):
        # Set the LR for this epoch before training (epoch 1 uses init_lr)
        scheduler_cosine.step(epoch - 1)

        print(f'\nEpoch {epoch}/{n_epochs}')
        print('-' * 10)

        # Training
        train_loss = train_func(model, loader_train, optimizer, scaler)

        # Validation
        valid_loss, metric = valid_func(model, loader_valid)

        # Logging
        content = f'{time.ctime()} Fold {fold}, Epoch {epoch}, LR: {optimizer.param_groups[0]["lr"]:.7f}, Train Loss: {train_loss:.5f}, Valid Loss: {valid_loss:.5f}, Metric: {metric:.6f}.'
        print(content)
        with open(log_file, 'a') as appender:
            appender.write(content + '\n')

        # Save best model
        if metric > metric_best:
            print(f'Improvement from {metric_best:.6f} to {metric:.6f}. Saving model...')
            torch.save(model.state_dict(), model_file)
            metric_best = metric

        # Save the last model
        if not DEBUG:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'score_best': metric_best,
                },
                model_file.replace('_best', '_last')
            )

    # Cleanup: the optimizer (and the scheduler holding it) also reference the
    # parameters and their moment buffers. Drop every reference before releasing cached GPU memory.
    del model, optimizer, scaler, scheduler_cosine
    gc.collect()
    torch.cuda.empty_cache()

    return metric_best


In [None]:
# Train every cross-validation fold in turn
for fold in range(n_folds):
    run_fold(fold=fold)
