In [17]:
import os, glob, torch
from tqdm import tqdm
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, ToTensord, Compose
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset path
data_dir = "kits19/data"
images = sorted(glob.glob(f"{data_dir}/case_*/imaging.nii.gz"))
labels = sorted(glob.glob(f"{data_dir}/case_*/segmentation.nii.gz"))
data_dicts = [{"image": i, "label": l} for i, l in zip(images, labels)]

# Transforms
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityd(keys="image"),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96),
        pos=1, neg=1, num_samples=2, image_key="image", image_threshold=0,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image", "label"], prob=0.5),
    ToTensord(keys=["image", "label"]),
])

train_ds = Dataset(data=data_dicts[:50], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

# SwinUNETR model
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True
).to(device)

# Loss + Optimizer
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Train
max_epochs = 5
for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader):
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{max_epochs} - Loss: {epoch_loss / len(train_loader):.4f}")

torch.save(model.state_dict(), "swinunetr_kits19.pth")
print("✅ Model saved as swinunetr_kits19.pth")


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
import matplotlib.pyplot as plt
from monai.transforms import LoadImage, EnsureChannelFirst, Spacing, Orientation, ScaleIntensity, Compose
from monai.inferers import sliding_window_inference
import numpy as np

val_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    Orientation(axcodes="RAS"),
    ScaleIntensity()
])

test_img_path = sorted(glob.glob("kits19_preprocessed/case_00030/imaging.nii.gz"))[0]
image = val_transforms(test_img_path).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    pred = sliding_window_inference(image, roi_size=(128, 128, 64), sw_batch_size=1, predictor=model)

# Show one slice (middle slice)
pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
slice_idx = pred.shape[-1] // 2
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.imshow(image.cpu().numpy()[0, 0, :, :, slice_idx], cmap="gray")
plt.title("CT Slice")
plt.subplot(1, 2, 2)
plt.imshow(pred[:, :, slice_idx])
plt.title("Predicted Tumor Mask")
plt.show()


In [None]:
import glob

data_dir = "kits19"
images = sorted(glob.glob(f"{data_dir}/case_*/imaging.nii.gz"))
labels = sorted(glob.glob(f"{data_dir}/case_*/segmentation.nii.gz"))

print(f"✅ Found {len(images)} images and {len(labels)} labels")
