In [14]:
import os
import torch
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from loguru import logger

# Data root: <cwd>/../../data/formatted, resolved from the kernel's working directory.
DATA_DIR = os.path.join(
    Path().resolve().parent.parent,
    'data', 'formatted'
)

x_train_dir = os.path.join(DATA_DIR, "train_images")
y_train_dir = os.path.join(DATA_DIR, "train_labels")

x_val_dir = os.path.join(DATA_DIR, "val_images")
y_val_dir = os.path.join(DATA_DIR, "val_labels")

# Fail fast here rather than deep inside dataset construction if the data
# layout is not where the notebook expects it.
_missing_dirs = [
    d for d in (x_train_dir, y_train_dir, x_val_dir, y_val_dir)
    if not os.path.isdir(d)
]
if _missing_dirs:
    raise FileNotFoundError(f"Expected data directories not found: {_missing_dirs}")

# Seed the RNGs used later (model weight init, DataLoader shuffling) so that
# Restart & Run All reproduces the same training run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [15]:
class MedicalImageDataset(Dataset):
    """Image/mask segmentation dataset whose files are read with SimpleITK.

    Each file in ``images_dir`` is paired with the file of the same name in
    ``masks_dir``. Raw mask values are treated as indices into ``CLASSES`` and
    are remapped to contiguous indices ``0 .. len(classes) - 1`` in the order
    the selected ``classes`` are given.

    Args:
        images_dir: directory containing the image files.
        masks_dir: directory containing the mask files (same file names).
        classes: class names to keep (case-insensitive). Defaults to all of
            ``CLASSES``.

    Raises:
        ValueError: if a requested class name is not in ``CLASSES``.
    """

    CLASSES = [
        "background",
        "gallbladder",
        "stomach",
        "esophagus",
        "right kidney",
        "right adrenal gland",
        "left adrenal gland",
        "liver",
        "left kidney",
        "aorta",
        "spleen",
        "inferior vena cava",
        "pancreas"
    ]

    def __init__(
        self,
        images_dir,
        masks_dir,
        classes=None
    ):
        # os.listdir order is arbitrary; sort so sample indices are deterministic.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        if classes is None:
            classes = self.CLASSES

        # Convert class names to their raw label values (indices into CLASSES).
        self.class_values = []
        for cls in classes:
            name = cls.lower()
            if name not in self.CLASSES:
                raise ValueError(f"Unknown class {cls!r}; expected one of {self.CLASSES}")
            self.class_values.append(self.CLASSES.index(name))

        # Lookup table: raw label value -> contiguous training index (-1 = not selected).
        # Built once here instead of rebuilding a dict for every sample.
        self._label_lut = np.full(max(self.class_values, default=-1) + 1, -1, dtype=np.int64)
        self._label_lut[self.class_values] = np.arange(len(self.class_values), dtype=np.int64)

    def _remap_mask(self, mask):
        """Map raw label values in ``mask`` to contiguous class indices (int64).

        Raises:
            ValueError: if the mask holds non-integer values or labels that are
                not among the selected classes.
        """
        labels = np.asarray(mask)
        as_int = labels.astype(np.int64)
        if not np.array_equal(as_int, labels):
            raise ValueError("Mask contains non-integer label values")

        unique = np.unique(as_int)
        in_range = (unique >= 0) & (unique < self._label_lut.size)
        mapped = np.full(unique.shape, -1, dtype=np.int64)
        mapped[in_range] = self._label_lut[unique[in_range]]
        unknown = unique[mapped < 0]
        if unknown.size:
            raise ValueError(
                f"Mask contains labels not in the selected classes: {unknown.tolist()}"
            )
        return self._label_lut[as_int]

    def __getitem__(self, i):
        """Return ``(image, mask)``: a float tensor with a leading channel axis
        and an int64 tensor of contiguous class indices."""
        image = sitk.GetArrayFromImage(sitk.ReadImage(self.images_fps[i]))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(self.masks_fps[i]))

        mask = self._remap_mask(mask)

        # Add a channel dimension to the image, e.g. (H, W) -> (1, H, W).
        image = np.expand_dims(image, axis=0)
        return torch.from_numpy(image).float(), torch.from_numpy(mask).long()

    def __len__(self):
        return len(self.ids)


---

In [16]:
class DiceLoss(torch.nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    The Dice score is computed per sample and per class, averaged over both,
    and the loss is ``1 - mean Dice``. Works for any number of spatial
    dimensions, e.g. 2-D ``(N, C, H, W)`` or 3-D ``(N, C, D, H, W)`` logits.

    Args:
        smooth: constant added to numerator and denominator. It avoids division
            by zero, and a class absent from both prediction and target scores ~1.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, outputs, targets):
        """
        Args:
            outputs: logits of shape (N, C, *spatial), where C = number of classes.
            targets: integer class indices of shape (N, *spatial), values in [0, C-1].

        Returns:
            Scalar tensor ``1 - mean soft Dice`` over samples and classes.

        Raises:
            ValueError: if ``targets`` does not match ``outputs`` without its class axis.
        """
        expected_shape = outputs.shape[:1] + outputs.shape[2:]
        if outputs.dim() < 3 or targets.shape != expected_shape:
            raise ValueError(
                f"Expected outputs (N, C, *spatial) and targets {tuple(expected_shape)}, "
                f"got outputs {tuple(outputs.shape)} and targets {tuple(targets.shape)}"
            )

        num_classes = outputs.size(1)
        # one_hot appends the class axis last; move it to dim 1 to align with outputs.
        targets_one_hot = (
            torch.nn.functional.one_hot(targets.long(), num_classes).movedim(-1, 1).float()
        )

        probs = torch.nn.functional.softmax(outputs, dim=1)

        # Reduce over every spatial axis, leaving a (N, C) score matrix.
        spatial_dims = tuple(range(2, outputs.dim()))
        intersection = (probs * targets_one_hot).sum(dim=spatial_dims)
        union = probs.sum(dim=spatial_dims) + targets_one_hot.sum(dim=spatial_dims)
        dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth)

        return 1.0 - dice_score.mean()


In [17]:
class EarlyStopping:
    """Stop training when the validation loss stops improving, checkpointing the best model.

    Args:
        patience: number of consecutive non-improving calls before ``early_stop`` is set.
        verbose: if True, report every checkpoint save via ``trace_func``.
        delta: minimum decrease in validation loss that counts as an improvement.
        path: file the best model's ``state_dict`` is saved to with ``torch.save``.
        trace_func: callable used for progress messages (e.g. ``print`` or a logger method).
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # Best validation loss seen so far
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Update the stopping state with this epoch's validation loss.

        A non-finite loss (NaN/inf) never counts as an improvement. With NaN,
        every comparison is False, so without this guard it would overwrite
        the best checkpoint and poison ``best_score`` for the rest of training.
        """
        if not np.isfinite(val_loss):
            self._register_no_improvement(f'Non-finite validation loss ({val_loss}). ')
            return

        score = -val_loss  # Higher score == lower validation loss

        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0  # Reset counter when validation loss improves
        else:
            self._register_no_improvement()

    def _register_no_improvement(self, prefix=''):
        """Count one non-improving epoch and set ``early_stop`` once patience runs out."""
        self.counter += 1
        self.trace_func(f'{prefix}EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [18]:
# Write INFO-and-above messages to training.log (in addition to any existing sinks).
# Re-running this cell would otherwise stack another identical file sink and
# duplicate every log line, so drop the sink from a previous run first.
if "TRAINING_LOG_SINK_ID" in globals():
    logger.remove(TRAINING_LOG_SINK_ID)
TRAINING_LOG_SINK_ID = logger.add("training.log", format="{time} {level} {message}", level="INFO")

3

In [19]:
# Reuse the dataset's own class list so the model head size and the label
# remapping can never drift apart from a second hand-maintained copy.
CLASSES = list(MedicalImageDataset.CLASSES)

train_dataset = MedicalImageDataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
)
val_dataset = MedicalImageDataset(
    x_val_dir,
    y_val_dir,
    classes=CLASSES,
)


def _check_split(dataset, split_name):
    """Fail fast if a split is empty or an image has no same-named mask file."""
    if len(dataset) == 0:
        raise ValueError(f"The {split_name} split contains no images")
    missing = [fp for fp in dataset.masks_fps if not os.path.exists(fp)]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} {split_name} mask file(s) missing, e.g. {missing[0]}"
        )


_check_split(train_dataset, "train")
_check_split(val_dataset, "validation")

BATCH_SIZE = 2
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [20]:
# U-Net with a randomly initialised ResNet-50 encoder (encoder_weights=None),
# a single input channel, and one output channel per class.
model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None,
    in_channels=1,
    classes=len(CLASSES),
)

# The training loss (`loss_fn`) is defined in the training cell. An smp
# DiceLoss used to be assigned here, but it was overwritten there before use.

# The encoder uses a 10x smaller learning rate than the decoder and head.
# NOTE(review): a lower encoder LR is usually motivated by pretrained weights;
# with encoder_weights=None a uniform LR may be preferable — TODO confirm.
optimizer = torch.optim.Adam([
    {'params': model.encoder.parameters(), 'lr': 1e-4},
    {'params': model.decoder.parameters(), 'lr': 1e-3},
    {'params': model.segmentation_head.parameters(), 'lr': 1e-3}
])

# `verbose=True` is deprecated for PyTorch LR schedulers; current learning rates
# can be read from `optimizer.param_groups` instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)

early_stopping = EarlyStopping(
    patience=5,
    verbose=True,
    delta=0.001,
    path='best_model.pth',
    trace_func=logger.info
)



In [21]:
# Per-epoch metric history.
train_losses = []
val_losses = []
train_dice_scores = []
val_dice_scores = []

num_epochs = 20
# Train on the GPU when one is available; CPU training of this model is very
# slow (seconds per batch).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training objective: soft multi-class Dice loss computed from logits.
loss_fn = DiceLoss()


def hard_dice(preds, targets, num_classes, smooth=1e-6):
    """Mean hard Dice between two integer label maps of identical shape (N, *spatial).

    The score is computed per sample and per class, then averaged over both.
    It is returned as a Python float.
    """
    dims = tuple(range(1, preds.dim()))
    per_class = []
    for c in range(num_classes):
        pred_c = preds == c
        target_c = targets == c
        intersection = (pred_c & target_c).sum(dim=dims)
        total = pred_c.sum(dim=dims) + target_c.sum(dim=dims)
        per_class.append((2.0 * intersection + smooth) / (total + smooth))
    return torch.stack(per_class, dim=1).mean().item()


def run_epoch(model, loader, loss_fn, device, optimizer=None, desc=""):
    """Run one pass over ``loader`` and return ``(mean loss, mean hard Dice)``.

    When ``optimizer`` is given, the model is trained (backward pass and
    optimizer step). Otherwise it is evaluated in eval mode with gradients
    disabled. Both means are weighted per sample.

    Raises:
        ValueError: if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_dice = 0.0
    num_samples = 0

    with torch.set_grad_enabled(training):
        for images, masks in tqdm(loader, desc=desc):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = loss_fn(outputs, masks)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            # Metric on hard (argmax) predictions, independent of the training loss.
            preds = outputs.detach().argmax(dim=1)
            total_dice += hard_dice(preds, masks, outputs.size(1)) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError(f"Loader produced no samples ({desc})")
    return total_loss / num_samples, total_dice / num_samples


for epoch in range(num_epochs):
    epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

    train_loss, train_dice = run_epoch(
        model, train_loader, loss_fn, device, optimizer=optimizer, desc=f"{epoch_tag} [Training]"
    )
    val_loss, val_dice = run_epoch(
        model, val_loader, loss_fn, device, desc=f"{epoch_tag} [Validation]"
    )

    train_losses.append(train_loss)
    train_dice_scores.append(train_dice)
    val_losses.append(val_loss)
    val_dice_scores.append(val_dice)

    logger.info(epoch_tag)
    logger.info(f"Training Loss: {train_loss:.4f}, Training Dice Coefficient: {train_dice:.4f}")
    logger.info(f"Validation Loss: {val_loss:.4f}, Validation Dice Coefficient: {val_dice:.4f}")

    # Reduce LR on plateau, then log the learning rates for the next epoch.
    scheduler.step(val_loss)
    logger.info(f"Learning rates: {[group['lr'] for group in optimizer.param_groups]}")

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        logger.info("Early stopping triggered. Stopping training.")
        break

# Reload the best checkpoint from the same path EarlyStopping saved it to.
model.load_state_dict(
    torch.load(early_stopping.path, map_location=device, weights_only=True)
)
logger.info("Training complete. Best model loaded.")

Epoch 1/20 [Training]:   0%|          | 2/1676 [00:10<2:23:31,  5.14s/it]


KeyboardInterrupt: 