In [None]:
import torch
import torch.nn.functional as F
import torchio as tio
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from model import UNet

In [None]:
# Soft multi-class Dice loss (summed with weighted cross-entropy in `Segmenter`)
class DiceLoss(torch.nn.Module):
    """Soft multi-class Dice loss computed on softmax probabilities.

    For every sample and class, the soft Dice coefficient
    ``(2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)`` is computed over
    all spatial positions. The loss is ``1 - mean`` over batch and classes.

    Args:
        smooth: Constant added to numerator and denominator, so a class absent
            from both prediction and target does not produce 0/0.
        num_classes: Number of classes, i.e. channels of ``predict``. Defaults
            to 3, the value this loss previously hard-coded.
    """

    def __init__(self, smooth=1e-6, num_classes=3):
        super().__init__()
        self.smooth = smooth
        self.num_classes = num_classes

    def forward(self, predict, target):
        """Compute the Dice loss.

        Args:
            predict: Logits shaped ``[B, C, *spatial]`` (any number of spatial
                dims, e.g. ``[B, C, H, W, D]``) with ``C == num_classes``.
            target: Integer labels shaped ``[B, *spatial]`` with values in
                ``[0, num_classes)``.

        Returns:
            Scalar tensor ``1 - mean Dice``.

        Raises:
            ValueError: If ``predict`` does not have ``num_classes`` channels.
        """
        if predict.shape[1] != self.num_classes:
            raise ValueError(
                f"predict has {predict.shape[1]} channels, "
                f"expected num_classes={self.num_classes}"
            )
        predict_soft = F.softmax(predict, dim=1)
        # one_hot appends the class axis last; move it to dim 1 to line up with predict.
        target_oh = (
            F.one_hot(target.long(), num_classes=self.num_classes)
            .movedim(-1, 1)
            .float()
        )

        # Reduce over every spatial axis, keeping a [B, C] result.
        spatial_dims = tuple(range(2, predict_soft.ndim))
        intersection = torch.sum(predict_soft * target_oh, dim=spatial_dims)
        union = torch.sum(predict_soft, dim=spatial_dims) + torch.sum(target_oh, dim=spatial_dims)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

In [None]:
# Lightning Module
class Segmenter(pl.LightningModule):
    """LightningModule wrapping ``UNet`` for 3-class volumetric segmentation.

    The training objective is class-weighted cross-entropy plus soft Dice loss.
    Validation logs a mid-slice overlay of ground truth vs. prediction for the
    first batch, when the attached logger supports ``add_figure``.

    Args:
        learning_rate: Initial AdamW learning rate (stored in ``self.hparams``).
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = UNet()
        # Label 2 gets 3x weight in the cross-entropy term.
        # NOTE(review): presumably label 2 is the rarer lesion class — confirm.
        self.ce_loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0, 3.0]))
        self.dice_loss = DiceLoss()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        """Recursively move a batch to ``device``, overriding Lightning's default.

        - Mappings, including dict subclasses, are rebuilt as plain ``dict``s
          instead of being re-instantiated with their original type.
        - Lists and tuples become lists.
        - Tensors are moved with ``.to(device)``.
        - Anything else is returned unchanged.

        Per the original author this avoids a ``FileNotFoundError: data``,
        presumably raised when TorchIO Subject/Image objects are reconstructed
        from collated data (confirm).
        """
        if isinstance(batch, dict):
            return {k: self.transfer_batch_to_device(v, device, dataloader_idx) for k, v in batch.items()}
        elif isinstance(batch, (list, tuple)):
            return [self.transfer_batch_to_device(v, device, dataloader_idx) for v in batch]
        elif isinstance(batch, torch.Tensor):
            return batch.to(device)
        return batch

    def forward(self, x):
        """Return raw UNet logits for input ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute and log CE + Dice loss for one batch.

        Returns:
            ``(loss, img, logits, mask)``, where ``mask`` is the label volume
            with its channel axis dropped, cast to ``long``.
        """
        img = batch["CT"]["data"]
        mask = batch["Label"]["data"][:, 0].long()
        logits = self(img)
        loss = self.ce_loss(logits, mask) + self.dice_loss(logits, mask)
        # Pass batch_size explicitly. Otherwise Lightning infers it from the whole
        # nested TorchIO batch, which also contains non-tensor entries.
        self.log(
            f"{stage}_loss", loss, prog_bar=True, on_epoch=True, batch_size=img.shape[0]
        )
        return loss, img, logits, mask

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        loss, _, _, _ = self._shared_step(batch, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss; also log an overlay figure for the first batch."""
        loss, img, logits, mask = self._shared_step(batch, "val")
        if batch_idx == 0:
            self.log_images(img, logits, mask)
        return loss

    def log_images(self, img, pred, mask):
        """Log a ground-truth vs. prediction overlay of the middle last-axis slice.

        Only the first sample and first image channel are drawn. Logging is
        best-effort: if there is no logger, or its ``experiment`` has no
        ``add_figure`` (e.g. a non-TensorBoard logger), nothing is drawn.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_figure"):
            return

        pred_class = torch.argmax(pred, dim=1)
        axial_slice = img.shape[-1] // 2

        # .float() keeps matplotlib happy if the image arrives in a reduced-precision dtype.
        ct_slice = img[0, 0, :, :, axial_slice].detach().float().cpu().numpy()
        overlays = (
            ("Ground Truth", mask[0, :, :, axial_slice].detach().cpu().numpy()),
            ("Prediction", pred_class[0, :, :, axial_slice].detach().cpu().numpy()),
        )

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        for ax, (title, labels) in zip(axes, overlays):
            ax.imshow(ct_slice, cmap="bone")
            # Hide background (label 0) so only labelled voxels are tinted.
            ax.imshow(np.ma.masked_where(labels == 0, labels), alpha=0.5, cmap="autumn")
            ax.set_title(title)

        experiment.add_figure("Validation_Overlay", fig, self.global_step)
        plt.close(fig)

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau.

        The learning rate is multiplied by 0.1 after ``val_loss`` has not
        improved for 5 consecutive scheduler steps.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


In [None]:
# Data Pipeline
def get_dataloaders(
    image_dir="Task03_Liver_rs/imagesTr/",
    label_dir=None,
    train_fraction=0.8,
    batch_size=2,
):
    """Build TorchIO patch-queue train/validation loaders for the liver task.

    Each ``liver_*`` file in ``image_dir`` is paired with the same-named file in
    ``label_dir``; images without a matching label are skipped. Subjects are
    sorted by path, so the train/validation split is reproducible.

    Args:
        image_dir: Directory with the CT volumes.
        label_dir: Directory with the label volumes. Defaults to the sibling
            ``labelsTr`` directory of ``image_dir``.
        train_fraction: Fraction of the sorted subjects used for training;
            the rest go to validation.
        batch_size: Patches per batch for both loaders.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        FileNotFoundError: If no image/label pairs are found.
        ValueError: If the split leaves the train or validation set empty.
    """
    image_dir = Path(image_dir)
    label_dir = image_dir.parent / "labelsTr" if label_dir is None else Path(label_dir)

    subjects = []
    # Path.glob yields entries in filesystem order, which is not guaranteed to be
    # stable. Sorting keeps the train/validation split identical across runs and machines.
    for img_path in sorted(image_dir.glob("liver_*")):
        label_path = label_dir / img_path.name
        if label_path.exists():
            subjects.append(tio.Subject(
                CT=tio.ScalarImage(img_path),
                Label=tio.LabelMap(label_path),
            ))

    if not subjects:
        raise FileNotFoundError(
            f"No 'liver_*' images with matching labels found in {image_dir} (labels: {label_dir})"
        )
    train_idx = int(len(subjects) * train_fraction)
    if not 0 < train_idx < len(subjects):
        raise ValueError(
            f"train_fraction={train_fraction} with {len(subjects)} subjects leaves "
            "the training or validation set empty"
        )

    process = tio.Compose([
        tio.CropOrPad((256, 256, 200)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ])
    # Random scaling/rotation is applied to training volumes only.
    augment = tio.RandomAffine(scales=(0.9, 1.1), degrees=10)

    train_ds = tio.SubjectsDataset(subjects[:train_idx], transform=tio.Compose([process, augment]))
    val_ds = tio.SubjectsDataset(subjects[train_idx:], transform=process)

    # Patch sampling is weighted towards labels 1 and 2.
    # NOTE(review): validation reuses this same random, label-biased sampler, so
    # val_loss is measured on sampled patches, not whole volumes.
    sampler = tio.data.LabelSampler(patch_size=96, label_probabilities={0: 0.1, 1: 0.4, 2: 0.5})

    train_loader = torch.utils.data.DataLoader(
        tio.Queue(train_ds, max_length=40, samples_per_volume=5, sampler=sampler),
        batch_size=batch_size, num_workers=0,
    )
    val_loader = torch.utils.data.DataLoader(
        tio.Queue(val_ds, max_length=40, samples_per_volume=5, sampler=sampler),
        batch_size=batch_size, num_workers=0,
    )

    return train_loader, val_loader

In [None]:
# Main Execution
if __name__ == "__main__":
    # Seed Python, NumPy and torch RNGs (and dataloader workers) so patch sampling,
    # augmentation and weight init repeat across runs. GPU kernels may still be
    # nondeterministic.
    pl.seed_everything(42, workers=True)

    train_loader, val_loader = get_dataloaders()
    model = Segmenter(learning_rate=1e-4)

    # Keep the three checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints/",
        filename="liver-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min"
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=100,
        precision="16-mixed",
        callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='epoch')],
        logger=TensorBoardLogger("logs/", name="Liver_Segmentation")
    )

    trainer.fit(model, train_loader, val_loader)