In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from monai.config import print_config
from monai.transforms import SpatialPadd
from monai.transforms import Lambdad
from monai.transforms import RandCropByLabelClassesd
from monai.transforms import ScaleIntensityRanged
from monai.data import CacheDataset, DataLoader, Dataset
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd,
    Orientationd, Spacingd, 
    ScaleIntensityd, RandSpatialCropd, ToTensord,
    Activations, AsDiscrete
)
from monai.networks.nets import UNet 
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, MeanIoU
from monai.utils import set_determinism
from monai.transforms import Activations, AsDiscrete, Compose
from sklearn.model_selection import train_test_split
import glob

In [2]:
# Fix the random seed through MONAI's helper so repeated runs are comparable.
set_determinism(seed=11)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


# Data Preparation

In [3]:
# Root folder of the dataset. Set the LIVER_DATA_DIR environment variable to
# point somewhere else without editing the notebook; the original location is
# kept as the default.
data_dir = os.environ.get("LIVER_DATA_DIR", "D:/monai-project/asa/data")
image_files = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "liver_*.nii.gz")))
label_files = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))

# Images and labels are paired purely by their position in the sorted lists,
# so fail early with a clear message when the listings cannot line up.
if not image_files:
    raise FileNotFoundError(
        f"No 'liver_*.nii.gz' images found in {os.path.join(data_dir, 'imagesTr')}"
    )
if len(image_files) != len(label_files):
    raise ValueError(
        f"Found {len(image_files)} images but {len(label_files)} labels under {data_dir}; "
        "every image needs exactly one label file."
    )

In [4]:
# Sanity check: how many label volumes were found on disk.
len(label_files)

131

# Split data (80% train, 20% test)

In [5]:
# Hold out 20% of the cases. Passing both lists to one call keeps every image
# aligned with its label, and random_state pins the shuffle.
split_parts = train_test_split(image_files, label_files, test_size=0.2, random_state=42)
train_images, test_images, train_labels, test_labels = split_parts

In [6]:
# Sizes in the order: train images, train labels, test images, test labels.
print(*map(len, (train_images, train_labels, test_images, test_labels)))

104 104 27 27


# Create data dictionaries

In [7]:
# One {"image": path, "label": path} record per case, the layout the
# dictionary-based transforms below look up by key.
train_files = [
    dict(image=img_path, label=lbl_path)
    for img_path, lbl_path in zip(train_images, train_labels)
]
test_files = [
    dict(image=img_path, label=lbl_path)
    for img_path, lbl_path in zip(test_images, test_labels)
]

In [8]:
def binarize_label(x):
    """Return a uint8 mask that is 1 where ``x`` equals 1 and 0 everywhere else.

    ``x`` must support elementwise ``==`` and the result must offer
    ``.astype`` (a NumPy array does).

    NOTE(review): any other label value is mapped to background. If the
    dataset carries further classes inside the organ (e.g. a lesion class),
    confirm that excluding them from the mask is intended.
    """
    is_class_one = x == 1
    return is_class_one.astype(np.uint8)

# Transforms & Cropping

In [9]:
# Patch size used for training crops, minimum padded volume size and the
# sliding-window ROI at inference time.
crop_size = (128, 128, 64)


def _shared_preprocessing():
    """Build fresh instances of the deterministic steps both pipelines start with."""
    both = ["image", "label"]
    return [
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Orientationd(keys=both, axcodes="RAS"),
        # Resample to a common voxel spacing; nearest-neighbour keeps label values discrete.
        Spacingd(keys=both, pixdim=(1, 1, 3), mode=("bilinear", "nearest")),
        # Clip image intensities to [-200, 200] and rescale them to [0, 1].
        ScaleIntensityRanged(keys="image", a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        # Guarantee every volume is at least crop_size along each axis.
        SpatialPadd(keys=both, spatial_size=crop_size),
    ]


def _label_and_tensor_steps():
    """Build the closing steps: binarize the label, then convert both entries to tensors."""
    return [
        Lambdad(keys="label", func=binarize_label),
        ToTensord(keys=["image", "label"]),
    ]


# Training: shared preprocessing, then one class-guided random crop per volume.
train_transforms = Compose(
    _shared_preprocessing()
    + [
        RandCropByLabelClassesd(
            keys=["image", "label"],
            label_key="label",
            spatial_size=crop_size,
            num_classes=2,
            num_samples=1,
        ),
    ]
    + _label_and_tensor_steps()
)

# Testing: same preprocessing but no cropping; whole volumes are evaluated
# with sliding-window inference.
test_transforms = Compose(_shared_preprocessing() + _label_and_tensor_steps())

# Create datasets

In [10]:
# Plain Dataset objects that pair each file list with its transform pipeline.
train_ds = Dataset(
    data=train_files,
    transform=train_transforms,
)
test_ds = Dataset(
    data=test_files,
    transform=test_transforms,
)

In [11]:
# Sanity check: number of held-out cases.
print(len(test_ds))

27


# Data loaders

In [12]:
# Training batches hold two shuffled crops; test volumes are fed one at a
# time in their original order. Both load in the main process (num_workers=0).
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=0,
)

# Model Setup

In [13]:
# 3D residual U-Net: one input channel, two output channels
# (2 classes: background and liver).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# Dice loss that applies softmax to the raw logits itself.
loss_function = DiceLoss(softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Both metrics include the background channel and report the mean.
dice_metric = DiceMetric(include_background=True, reduction="mean")
iou_metric = MeanIoU(include_background=True, reduction="mean")


# Training Loop

In [14]:
# Post-processing for network output: softmax over the logits, then argmax
# to obtain a discrete class map.
post_trans = Compose(
    [
        Activations(softmax=True),
        AsDiscrete(argmax=True),
    ]
)


In [15]:
# for i in range(10):
#     sample = train_ds[44][0]
#     print(np.unique(sample["label"]))

In [16]:
# max_epochs = 75
# train_losses = []
# val_dice_scores = []
# val_iou_scores = []

# for epoch in range(max_epochs):
#     model.train()
#     epoch_loss = 0
#     for batch_data in train_loader:
#         inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
#         labels = labels.squeeze(1).long()   # shape (B, H, W, D)
#         optimizer.zero_grad()
#         outputs = model(inputs)  # Raw logits, shape (B, 2, H, W, D)
#         # --- One-hot encode for loss ---
#         labels_onehot = torch.nn.functional.one_hot(labels, num_classes=2)  # [B, H, W, D, 2]
#         labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()        # [B, 2, H, W, D]
#         loss = loss_function(outputs, labels_onehot)
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()
    
#     epoch_loss /= len(train_loader)
#     train_losses.append(epoch_loss)
    
#     # Validation
#     model.eval()
#     dice_vals, iou_vals = [], []
    
#     with torch.no_grad():
#         for test_data in test_loader:
#             inputs, labels = test_data["image"].to(device), test_data["label"].to(device)
#             labels = labels.squeeze(1).long()
#             outputs = sliding_window_inference(
#                 inputs, 
#                 roi_size=crop_size, 
#                 sw_batch_size=4,
#                 predictor=model,
#             )
#             # For metrics, use class indices (not one-hot)
#             dice_metric(y_pred=outputs, y=labels)
#             iou_metric(y_pred=outputs, y=labels)
#             dice_vals.append(dice_metric.aggregate().item())
#             iou_vals.append(iou_metric.aggregate().item())
#             dice_metric.reset()
#             iou_metric.reset()
    
#     avg_dice = np.mean(dice_vals)
#     avg_iou = np.mean(iou_vals)
#     val_dice_scores.append(avg_dice)
#     val_iou_scores.append(avg_iou)
    
#     print(f"Epoch {epoch+1}/{max_epochs}")
#     print(f"Train Loss: {epoch_loss:.4f}")
#     print(f"Val Dice: {avg_dice:.4f}, Jaccard: {avg_iou:.4f}")


In [None]:
max_epochs = 50
num_classes = 2  # background + liver; must match the UNet's out_channels
train_losses = []
val_dice_scores = []
val_iou_scores = []


def _to_onehot(class_map, n_classes):
    """Turn an integer class map (B, H, W, D) into a float one-hot tensor (B, C, H, W, D)."""
    onehot = torch.nn.functional.one_hot(class_map, num_classes=n_classes)  # [B, H, W, D, C]
    return onehot.permute(0, 4, 1, 2, 3).float()


for epoch in range(max_epochs):
    print(f"\n--- Epoch {epoch+1}/{max_epochs} ---")  # Start of epoch

    # ---- Training ----
    model.train()
    epoch_loss = 0.0
    for batch_idx, batch_data in enumerate(train_loader):
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        labels = labels.squeeze(1).long()   # shape (B, H, W, D)
        optimizer.zero_grad()
        outputs = model(inputs)  # Raw logits, shape (B, 2, H, W, D)

        # The loss applies softmax to the logits itself and is given a
        # one-hot target of the same shape.
        loss = loss_function(outputs, _to_onehot(labels, num_classes))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx + 1) % 5 == 0 or (batch_idx + 1) == len(train_loader):
            print(f"  Batch {batch_idx+1}/{len(train_loader)} - Loss: {loss.item():.4f}")

    epoch_loss /= len(train_loader)
    train_losses.append(epoch_loss)

    # ---- Validation ----
    model.eval()
    with torch.no_grad():
        for test_data in test_loader:
            inputs, labels = test_data["image"].to(device), test_data["label"].to(device)
            labels = labels.squeeze(1).long()   # shape (B, H, W, D)
            outputs = sliding_window_inference(
                inputs,
                roi_size=crop_size,
                sw_batch_size=4,
                predictor=model,
            )
            # DiceMetric / MeanIoU compare binarized one-hot tensors of
            # identical shape, so convert the logits to a hard prediction and
            # encode prediction and ground truth the same way. Feeding raw
            # logits and a (B, H, W, D) index map gives a shape mismatch or
            # meaningless scores.
            pred_onehot = _to_onehot(outputs.argmax(dim=1), num_classes)
            labels_onehot = _to_onehot(labels, num_classes)
            dice_metric(y_pred=pred_onehot, y=labels_onehot)
            iou_metric(y_pred=pred_onehot, y=labels_onehot)

    # The metric objects buffer per-case scores; reduce them once per epoch
    # and clear the buffers so the next epoch starts fresh.
    avg_dice = dice_metric.aggregate().item()
    avg_iou = iou_metric.aggregate().item()
    dice_metric.reset()
    iou_metric.reset()
    val_dice_scores.append(avg_dice)
    val_iou_scores.append(avg_iou)

    print(f"Epoch {epoch+1} Results:")
    print(f"  Train Loss: {epoch_loss:.4f}")
    print(f"  Val Dice:  {avg_dice:.4f}")
    print(f"  Jaccard:   {avg_iou:.4f}")

    # Periodic checkpoint so a long run can be resumed or inspected.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f"checkpoint_epoch{epoch+1}.pth")
        print(f"Checkpoint saved at epoch {epoch+1}")


--- Epoch 1/50 ---
  Batch 5/52 - Loss: 0.4011
  Batch 10/52 - Loss: 0.4894
  Batch 15/52 - Loss: 0.5334
  Batch 20/52 - Loss: 0.6202
  Batch 25/52 - Loss: 0.4319
  Batch 30/52 - Loss: 0.2848
  Batch 35/52 - Loss: 0.4460


# Visualization of loss and metrics


In [None]:
# Two side-by-side panels: training loss on the left, validation metrics on
# the right. Each panel gets its own title, axis labels and legend so the
# saved figure can be read on its own.
fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 6))

loss_ax.plot(train_losses, label="Train Loss")
loss_ax.set(title="Training loss", xlabel="Epoch", ylabel="Loss")
loss_ax.legend()

metric_ax.plot(val_dice_scores, label="Dice")
metric_ax.plot(val_iou_scores, label="Jaccard")
metric_ax.set(title="Validation metrics", xlabel="Epoch", ylabel="Score")
metric_ax.legend()

fig.tight_layout()
fig.savefig("training_metrics.png")
plt.show()

# Visualize test results

In [None]:
# Qualitative check: run the model on the first test volume and show the
# middle slice of input, prediction and ground truth side by side.
model.eval()
with torch.no_grad():
    test_data = next(iter(test_loader))
    input_img = test_data["image"].to(device)
    label_img = test_data["label"].cpu().numpy()[0, 0]

    # Get prediction (raw logits, batch dimension first)
    logits = sliding_window_inference(
        input_img,
        roi_size=crop_size,
        sw_batch_size=4,
        predictor=model,
    )
    # The post transforms operate along dim 0, which must be the channel
    # axis. Strip the batch dimension first (logits[0] is (2, H, W, D)),
    # otherwise softmax/argmax run over the size-1 batch axis and return all
    # zeros. The argmax keeps a leading singleton channel; the final [0] drops it.
    pred = post_trans(logits[0])[0].cpu().numpy()  # shape: (H, W, D)

    # Display middle slices along the last axis
    slice_idx = pred.shape[2] // 2
    input_slice = input_img[0, 0, :, :, slice_idx].cpu().numpy()
    pred_slice = pred[:, :, slice_idx]
    label_slice = label_img[:, :, slice_idx]

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(input_slice, cmap="gray")
    plt.title("Input Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(pred_slice, cmap="jet", vmin=0, vmax=1)
    plt.title("Prediction (0=bg, 1=liver)")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(label_slice, cmap="jet", vmin=0, vmax=1)
    plt.title("Ground Truth")
    plt.axis("off")

    plt.savefig("segmentation_comparison.png")
    plt.show()


# Save model


# Persist only the learned weights (state_dict); the UNet must be rebuilt
# with the same arguments before these weights can be loaded back.
torch.save(model.state_dict(), "liver_seg_unet_2class.pth")