In [None]:
# =============================================================== #
#                    3D Spleen Segmentation                       #
#                      Mustansir Verdawala                        #
# =============================================================== #

In [None]:
#%% Install necessary libraries

!pip install "monai[all]" nibabel SimpleITK torch torchvision torchaudio matplotlib numpy scikit-learn tqdm huggingface_hub

In [None]:
#%% Downloading dataset

import os

import nibabel as nib
from datasets import load_dataset
from huggingface_hub import snapshot_download

# Hugging Face repository that hosts the dataset.
HF_REPO_ID = "Angelou0516/msd-spleen"

# Load the dataset through the `datasets` API.
# NOTE(review): `ds` is not referenced by any other cell visible in this
# notebook; the later cells read the files mirrored by `snapshot_download`.
ds = load_dataset(HF_REPO_ID)

# Download the full dataset repository into a local directory. `local_path`
# holds whatever `snapshot_download` returns for that directory.
local_path = snapshot_download(
    repo_type="dataset",
    repo_id=HF_REPO_ID,
    local_dir="/content/msd-spleen",
)

In [None]:
#%% Train / validation / test split

import nibabel as nib
import os
import random
import shutil
from glob import glob

# Source folders of the downloaded dataset.
root_dir = "/content/msd-spleen"
img_dir = os.path.join(root_dir, "imagesTr")
lbl_dir = os.path.join(root_dir, "labelsTr")

# Destination layout: <split_root>/<split>/{images,labels}
split_root = "/content/msd-spleen_split"
splits = ["train", "val", "test"]

for s in splits:
    os.makedirs(os.path.join(split_root, s, "images"), exist_ok=True)
    os.makedirs(os.path.join(split_root, s, "labels"), exist_ok=True)

images = sorted(glob(os.path.join(img_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(lbl_dir, "*.nii.gz")))

# Images and labels are paired purely by their position in the two sorted
# lists, so fail early if the listings cannot be paired one-to-one.
if not images:
    raise FileNotFoundError(f"No '*.nii.gz' volumes found in {img_dir}")
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} labels; "
        "cannot pair them by position"
    )

random.seed(42)  # set for reproducibility
indices = list(range(len(images)))
random.shuffle(indices)

# Fixed-size split of the shuffled indices: 20 train, 10 validation, rest test.
train_idx = indices[:20]
val_idx = indices[20:30]
test_idx = indices[30:]

def copy_split(indices, split_name):
    """Copy the selected image/label files into the folders of one split.

    Args:
        indices: positions into the module-level ``images`` and ``labels``
            lists; the same position is used for both lists.
        split_name: sub-folder of ``split_root`` that receives the files
            (its ``images`` and ``labels`` sub-folders are the destinations).
    """
    img_dst = os.path.join(split_root, split_name, "images")
    lbl_dst = os.path.join(split_root, split_name, "labels")
    for pos in indices:
        shutil.copy(images[pos], img_dst)
        shutil.copy(labels[pos], lbl_dst)

# Materialise each split on disk.
for split_name, split_indices in (
    ("train", train_idx),
    ("val", val_idx),
    ("test", test_idx),
):
    copy_split(split_indices, split_name)

# Report how many entries each split's image folder now contains.
print(f"Split complete. Files saved under: {split_root}")
for s in splits:
    split_img_dir = os.path.join(split_root, s, "images")
    n_imgs = len(os.listdir(split_img_dir))
    print(f"{s.capitalize()}: {n_imgs} images")

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import os
from glob import glob

local_path = "/content/msd-spleen"

# Sorted listings of the volumes and of their label maps.
image_files, label_files = (
    sorted(glob(os.path.join(local_path, folder, "*.nii.gz")))
    for folder in ("imagesTr", "labelsTr")
)

# Sample to inspect: the same position is used in both sorted lists.
idx = 18
image_path, mask_path = image_files[idx], label_files[idx]

image, mask = nib.load(image_path), nib.load(mask_path)

# Voxel data as floating-point arrays.
image_data, mask_data = image.get_fdata(), mask.get_fdata()

print(f"Loaded: {os.path.basename(image_path)}")
print(f"Image shape: {image_data.shape}")
print(f"Mask shape: {mask_data.shape}")

# Show the middle slice along the third array axis.
slice_idx = image_data.shape[2] // 2

fig, (ax_ct, ax_mask) = plt.subplots(1, 2, figsize=(10,5))

# CT image
ax_ct.imshow(image_data[:, :, slice_idx], cmap='gray')
ax_ct.set_title(f"CT Slice {slice_idx}")
ax_ct.axis('off')

# Label mask of the same slice (drawn on its own, not overlaid on the CT)
ax_mask.imshow(mask_data[:, :, slice_idx], cmap='gray', alpha=1)
ax_mask.set_title("Mask")
ax_mask.axis('off')

fig.tight_layout()
plt.show()

img_flat = image_data.flatten()
mask_flat = mask_data.flatten()

# Intensities of the voxels whose label value is greater than zero.
spleen_voxels = image_data[mask_data > 0]

# One histogram per voxel population; both figures share size, binning,
# axis labels and title and differ only in data, colour and legend entry.
for values, color, label in (
    (img_flat, 'gray', 'All Voxels'),
    (spleen_voxels, 'red', 'Spleen Region'),
):
    plt.figure(figsize=(10,5))
    plt.hist(values, bins=100, color=color, alpha=0.6, label=label)
    plt.xlabel("Intensity")
    plt.ylabel("Voxel Count")
    plt.legend()
    plt.title("Voxel Intensity Distribution")
    plt.show()

# Distinct values present in the label volume.
print(np.unique(mask_data))

In [None]:
import os
import nibabel as nib
from glob import glob

image_dir = "/content/msd-spleen/imagesTr"
label_dir = "/content/msd-spleen/labelsTr"

image_files = sorted(glob(os.path.join(image_dir, "*.nii.gz")))
label_files = sorted(glob(os.path.join(label_dir, "*.nii.gz")))

# Print the array shape of every training volume.
print("Image Shapes")
for f in image_files:
    img = nib.load(f)
    # `img.shape` reports the shape without materialising the voxel array;
    # calling `get_fdata()` here would decode every full volume into memory
    # only to read the same tuple.
    print(f"{os.path.basename(f)}: {img.shape}")

In [None]:
import os
import torch
import numpy as np
import nibabel as nib
from glob import glob
from scipy.ndimage import binary_closing, generate_binary_structure
from monai.data import Dataset, DataLoader
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from skimage.transform import resize
import matplotlib.pyplot as plt

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device set: {DEVICE}")

# Root of the train/val/test folder layout read by this cell.
root_dir = "/content/msd-spleen_split"


def _list_niftis(relative_dir):
    """Return the sorted ``*.nii.gz`` paths found in ``root_dir/relative_dir``."""
    return sorted(glob(os.path.join(root_dir, f"{relative_dir}/*.nii.gz")))


train_images = _list_niftis("train/images")
train_labels = _list_niftis("train/labels")
val_images = _list_niftis("val/images")
val_labels = _list_niftis("val/labels")

# Patch geometry: slices per patch along the third axis, and the in-plane
# edge length every patch is resized to.
PATCH_DEPTH = 128
IMG_SIZE = 256

# Preprocessing
def preprocess(img_path, lbl_path, hu_min=-190, hu_max=300, closing_radius=3):
    """Load one image/label pair, clean the label and normalise the image.

    Steps:
      1. Label voxels whose image intensity is below ``hu_min`` or above
         ``hu_max`` are set to 0.
      2. The remaining label is morphologically closed with a 3-D structuring
         element of radius ``closing_radius``.
      3. The image is clipped to [-1024, 1023] and mapped linearly to [0, 1].

    Args:
        img_path: path of the image volume (read with nibabel).
        lbl_path: path of the matching label volume.
        hu_min: lowest image intensity at which a label voxel is kept.
        hu_max: highest image intensity at which a label voxel is kept.
        closing_radius: radius, in voxels, of the closing element. A value
            of 0 (or less) skips the closing step.

    Returns:
        ``(img, lbl)``: the normalised image as ``float32`` and the cleaned
        binary label as ``uint8``, both with the shape of the input volume.
    """
    # Local import so this function does not depend on a change to the
    # notebook's import cell.
    from scipy.ndimage import iterate_structure

    img = nib.load(img_path).get_fdata()
    lbl = nib.load(lbl_path).get_fdata()

    # Remove label voxels whose intensity lies outside the accepted window.
    lbl[img < hu_min] = 0
    lbl[img > hu_max] = 0
    mask = lbl != 0

    radius = max(int(closing_radius), 0)
    if radius > 0:
        # Zero-padding a structuring element does not change its footprint,
        # so it cannot enlarge the closing radius. Iterating the 6-connected
        # element `radius` times produces an element that really spans
        # `radius` voxels in each direction.
        struct = iterate_structure(generate_binary_structure(3, 1), radius)
        # Pad the mask by the radius and crop afterwards: `binary_closing`
        # treats everything outside the array as background, which would
        # otherwise erode label voxels that touch a face of the volume.
        padded = np.pad(mask, radius, mode='constant', constant_values=False)
        closed = binary_closing(padded, structure=struct)
        core = (slice(radius, -radius),) * 3
        lbl_closed = closed[core].astype(np.uint8)
    else:
        lbl_closed = mask.astype(np.uint8)

    # Clip to a fixed intensity range and rescale it to [0, 1].
    img = np.clip(img, -1024, 1023)
    img = (img + 1024) / 2047
    return img.astype(np.float32), lbl_closed

def generate_patches(img, lbl, patch_depth=PATCH_DEPTH):
    """Cut a volume and its label into patches along the third axis.

    Consecutive windows of ``patch_depth`` slices are taken along axis 2.
    When the remaining slices do not fill a whole window, the last
    ``patch_depth`` slices of the volume are used instead (for a volume with
    fewer than ``patch_depth`` slices that slice covers the whole volume).
    Every patch is then resized to ``(IMG_SIZE, IMG_SIZE, patch_depth)``.

    Args:
        img: image volume, indexed as ``img[:, :, z]``.
        lbl: label volume with the same layout as ``img``.
        patch_depth: number of slices per patch along axis 2.

    Returns:
        List of ``(patch_img, patch_lbl)`` tuples; ``patch_img`` is
        ``float32`` and ``patch_lbl`` is ``uint8``.
    """
    target_shape = (IMG_SIZE, IMG_SIZE, patch_depth)
    depth = img.shape[2]
    patches = []
    for z0 in range(0, depth, patch_depth):
        if z0 + patch_depth <= depth:
            window = slice(z0, z0 + patch_depth)
        else:
            # Incomplete final window: fall back to the trailing slices.
            window = slice(-patch_depth, None)

        # Resize to IMG_SIZE x IMG_SIZE x patch_depth. The image is
        # interpolated linearly with anti-aliasing; the label uses order 0
        # without anti-aliasing so no intermediate label values are created.
        resized_img = resize(
            img[:, :, window], target_shape,
            order=1, preserve_range=True, anti_aliasing=True
        ).astype(np.float32)
        resized_lbl = resize(
            lbl[:, :, window], target_shape,
            order=0, preserve_range=True, anti_aliasing=False
        ).astype(np.uint8)
        patches.append((resized_img, resized_lbl))
    return patches

# Dataset
class SpleenPatchDataset(Dataset):
    """In-memory dataset of preprocessed, resized patches.

    Every image/label pair is run through ``preprocess`` and
    ``generate_patches`` when the dataset is constructed, and the resulting
    patches are kept in ``self.patches``. The base-class constructor is not
    invoked; ``__len__`` and ``__getitem__`` are implemented here directly.
    """

    def __init__(self, img_files, lbl_files):
        """Build all patches for the paired ``img_files`` / ``lbl_files``."""
        self.patches = [
            patch
            for img_path, lbl_path in zip(img_files, lbl_files)
            for patch in generate_patches(*preprocess(img_path, lbl_path))
        ]

    def __len__(self):
        """Number of stored patches."""
        return len(self.patches)

    def __getitem__(self, idx):
        """Return ``(image, target)`` tensors for the patch at ``idx``.

        ``image`` is the patch with a leading channel axis added; ``target``
        has two channels: ``1 - label`` in channel 0 and ``label`` in
        channel 1. Both are ``float32`` tensors.
        """
        img, lbl = self.patches[idx]
        # One-hot labels
        onehot = np.stack((1 - lbl, lbl)).astype(np.float32)
        image_tensor = torch.tensor(img[np.newaxis, ...], dtype=torch.float32)
        target_tensor = torch.tensor(onehot, dtype=torch.float32)
        return image_tensor, target_tensor

# Build the patch datasets (all preprocessing happens in the constructors).
train_ds = SpleenPatchDataset(train_images, train_labels)
val_ds = SpleenPatchDataset(val_images, val_labels)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
# Validation is not shuffled: the order does not change the averaged loss,
# and a fixed order means the patch visualised during training is the same
# one at every plotting epoch, so the plots are comparable.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

# Model
# 3-D UNet configuration: one input channel, two output channels
# (matching the two-channel one-hot targets produced by the dataset).
unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2,2,2),
    num_res_units=2,
    dropout=0.5,
    act="leakyrelu",
    kernel_size=3,
)
model = UNet(**unet_kwargs).to(DEVICE)

# Dice loss configured with `softmax=True`, so the model outputs are passed
# to it as raw scores.
loss_fn = DiceLoss(softmax=True)

# Adam with an initial learning rate of 1e-2; the MultiStepLR schedule is
# configured with a factor of 0.1 at epochs 100, 200 and 350.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    gamma=0.1,
    milestones=[100, 200, 350],
)

# Training loop with periodic plotting, checkpointing and early stopping
EPOCHS = 1000
# Start from +inf so the first finite validation loss always counts as an
# improvement and a checkpoint is guaranteed to be written.
best_val_loss = float("inf")
patience = 30        # epochs without improvement tolerated before stopping
trigger_times = 0    # consecutive epochs without improvement so far

for epoch in range(EPOCHS):
    # Train
    model.train()
    train_loss = 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss_total = 0
    valid_patches = 0
    plotted = False  # at most one figure per plotting epoch
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            if masks[0, 1].sum() == 0:
                continue  # skip empty patch
            outputs = model(imgs)

            loss = loss_fn(outputs, masks)
            val_loss_total += loss.item()
            valid_patches += 1

            # Every 10 epochs, plot the first patch that was actually
            # evaluated. Keying this on the batch index would silently drop
            # the plot whenever the first batch is an empty patch.
            if epoch % 10 == 0 and not plotted:
                plotted = True
                # argmax over the channel axis of the raw outputs equals
                # argmax over their softmax, so no softmax is needed here.
                pred_mask = torch.argmax(outputs, dim=1).cpu().numpy()[0]
                true_mask = torch.argmax(masks, dim=1).cpu().numpy()[0]
                slice_idx = PATCH_DEPTH // 2
                fig = plt.figure(figsize=(10,5))
                plt.subplot(1,2,1)
                plt.imshow(true_mask[:, :, slice_idx], cmap='gray')
                plt.title("Ground Truth")
                plt.subplot(1,2,2)
                plt.imshow(pred_mask[:, :, slice_idx], cmap='gray')
                plt.title(f"Prediction Epoch {epoch+1}")
                plt.show()
                plt.close(fig)  # release the figure; up to 100 are created

    # Without this guard the division below fails with an uninformative
    # ZeroDivisionError when every validation patch was skipped.
    if valid_patches == 0:
        raise RuntimeError(
            "Validation set contains no patch with foreground voxels; "
            "cannot compute a validation loss"
        )
    val_loss = val_loss_total / valid_patches

    # Track the best model seen so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        trigger_times = 0
        torch.save(model.state_dict(), "best_patch_3dunet_spleen.pth")
    else:
        trigger_times += 1

    scheduler.step()
    # Report the epoch before deciding whether to stop, so the metrics of
    # the final epoch are not lost.
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.6e}")

    # Early stopping
    if trigger_times >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break


In [None]:
import torch
import numpy as np
from tqdm import tqdm

# Device used by the evaluation code: GPU when available, CPU otherwise.
_eval_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_eval_device_name)

# Probability at or above which a voxel is counted as foreground.
THRESH = 0.5

def dice_coefficient_manual(pred, target, eps=1e-8):
    """Dice coefficient of two tensors, smoothed by ``eps``.

    Computes ``(2 * sum(pred * target) + eps) / (sum(pred) + sum(target) + eps)``
    and returns it as a Python float. When both tensors sum to zero the
    result is ``eps / eps``, i.e. 1.0.
    """
    overlap = (pred * target).sum().item()
    total = pred.sum().item() + target.sum().item()
    return (2.0 * overlap + eps) / (total + eps)

def jaccard_index_manual(pred, target, eps=1e-8):
    """Jaccard index (intersection over union) of two tensors, smoothed by ``eps``.

    The intersection is ``sum(pred * target)`` and the union is
    ``sum(pred) + sum(target) - intersection``; the function returns
    ``(intersection + eps) / (union + eps)`` as a Python float. When both
    tensors sum to zero the result is ``eps / eps``, i.e. 1.0.
    """
    overlap = (pred * target).sum().item()
    combined = pred.sum().item() + target.sum().item() - overlap
    return (overlap + eps) / (combined + eps)

@torch.no_grad()
def evaluate_full_scan_torch(loader, model, name="Dataset"):
    """Compute Dice and Jaccard for every item of ``loader`` and average them.

    Each batch must contain exactly one item. The model output is turned
    into probabilities with a softmax over the channel axis, channel 1 is
    thresholded at ``THRESH`` and compared with channel 1 of the target.
    One score pair is printed per item, followed by the means.

    NOTE(review): the scores are per loader item. When the loader yields
    patches rather than whole scans, as the patch dataset in this notebook
    does, they are patch-level metrics.

    Args:
        loader: iterable of ``(images, targets)`` batches of size 1.
        model: network mapping an image batch to two-channel scores.
        name: label used in the progress bar and in the printed summary.

    Returns:
        ``(mean_dice, mean_jaccard)``; both are NaN when ``loader`` is empty.
    """
    model.eval()
    scan_dices = []
    scan_jaccs = []

    for scan_idx, (imgs, masks) in enumerate(tqdm(loader, desc=f"Evaluating {name}")):
        imgs, masks = imgs.to(device), masks.to(device)
        assert imgs.shape[0] == 1, "Loader should load one scan at a time for full evaluation"

        # A single forward pass per item: no overlapping predictions exist
        # to be averaged, so the probabilities are used directly instead of
        # going through volume-sized accumulation buffers.
        probs = torch.softmax(model(imgs), dim=1)[0]  # channels x spatial dims

        pred_bin = (probs[1] >= THRESH).float()
        true_bin = masks[0, 1].float()

        dice_val = dice_coefficient_manual(pred_bin, true_bin)
        jacc_val = jaccard_index_manual(pred_bin, true_bin)

        scan_dices.append(dice_val)
        scan_jaccs.append(jacc_val)
        print(f"\nScan {scan_idx} | Dice: {dice_val:.4f} | Jaccard: {jacc_val:.4f}")

    # np.mean of an empty list emits a RuntimeWarning; report NaN explicitly.
    mean_dice = np.mean(scan_dices) if scan_dices else float("nan")
    mean_jacc = np.mean(scan_jaccs) if scan_jaccs else float("nan")
    print(f"\n{name} Overall | Mean Dice: {mean_dice:.4f} | Mean Jaccard: {mean_jacc:.4f}")
    print("-" * 40)
    return mean_dice, mean_jacc

test_images = sorted(glob(os.path.join(root_dir, "test/images/*.nii.gz")))
test_labels = sorted(glob(os.path.join(root_dir, "test/labels/*.nii.gz")))

# Evaluate the checkpoint with the lowest validation loss. After early
# stopping, the in-memory `model` holds the weights of the last epoch, which
# is `patience` epochs past the best one saved by the training loop.
BEST_CKPT = "best_patch_3dunet_spleen.pth"
if os.path.exists(BEST_CKPT):
    model.load_state_dict(torch.load(BEST_CKPT, map_location=device))
else:
    print(f"Checkpoint '{BEST_CKPT}' not found; evaluating the current in-memory weights")

# One item per batch and a fixed order, as the evaluation function requires.
train_loader_full = DataLoader(train_ds, batch_size=1, shuffle=False)
val_loader_full = DataLoader(val_ds, batch_size=1, shuffle=False)
test_ds = SpleenPatchDataset(test_images, test_labels)
test_loader_full = DataLoader(test_ds, batch_size=1, shuffle=False)

train_dice, train_jacc = evaluate_full_scan_torch(train_loader_full, model, name="Train Set")
val_dice, val_jacc = evaluate_full_scan_torch(val_loader_full, model, name="Validation Set")
test_dice, test_jacc = evaluate_full_scan_torch(test_loader_full, model, name="Test Set")