In [None]:
import os, glob, torch
from tqdm import tqdm
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityd,
    SpatialPadd, RandCropByPosNegLabeld, RandFlipd, RandRotate90d, ToTensord, Compose
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset path
data_dir = "kits19/data"


def pair_kits_cases(root):
    """Return ``{"image": ..., "label": ...}`` dicts for every ``case_*`` dir under ``root``.

    Image and label are paired *per case directory* instead of zipping two
    independently sorted globs. With the zip approach, one case directory
    lacking ``segmentation.nii.gz`` would shift every later image onto the
    wrong label (or silently drop cases). Here such a case is simply skipped.

    Args:
        root: directory containing ``case_*/imaging.nii.gz`` and
            ``case_*/segmentation.nii.gz``.

    Returns:
        List of dicts sorted by case directory name; empty if nothing matches.
    """
    pairs = []
    for case_dir in sorted(glob.glob(os.path.join(root, "case_*"))):
        image_path = os.path.join(case_dir, "imaging.nii.gz")
        label_path = os.path.join(case_dir, "segmentation.nii.gz")
        if os.path.isfile(image_path) and os.path.isfile(label_path):
            pairs.append({"image": image_path, "label": label_path})
    return pairs


data_dicts = pair_kits_cases(data_dir)
# Kept for backward compatibility with code that reads the flat path lists.
images = [d["image"] for d in data_dicts]
labels = [d["label"] for d in data_dicts]

# Transforms
# load -> channel-first -> resample to 1.5 x 1.5 x 2.0 -> RAS orientation ->
# intensity scaling -> pad to at least the crop size -> random pos/neg crops ->
# flip / rot90 augmentation -> tensors.
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityd(keys="image"),
    # After resampling, a volume can be shorter than 96 voxels along some axis
    # (e.g. a scan with few slices at 2.0 spacing). RandCropByPosNegLabeld
    # (default allow_smaller=False) cannot take a 96^3 ROI from it. Padding
    # keeps every crop exactly 96^3, matching the model's img_size, and is a
    # no-op for volumes that are already large enough. Constant zero padding
    # maps to label 0 (background).
    SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96),
        pos=1, neg=1, num_samples=2, image_key="image", image_threshold=0,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image", "label"], prob=0.5),
    ToTensord(keys=["image", "label"]),
])

# Train on the first 50 image/label pairs; each item yields num_samples crops.
train_cases = data_dicts[:50]
train_ds = Dataset(transform=train_transforms, data=train_cases)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=1)

# SwinUNETR model: 1 input channel (CT), 3 output classes
# (background, kidney, tumor), built for 96^3 patches.
swin_config = dict(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=3,
    feature_size=48,
    use_checkpoint=True,  # activation checkpointing: recompute in backward to save memory
)
model = SwinUNETR(**swin_config).to(device)


# Loss + Optimizer: Dice + cross-entropy on softmaxed logits; integer labels are
# one-hot encoded inside the loss.
loss_fn = DiceCELoss(softmax=True, to_onehot_y=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

# Train
max_epochs = 5
if len(train_loader) == 0:
    # Otherwise the mean-loss division below fails with a bare ZeroDivisionError.
    raise RuntimeError("train_loader is empty - check data_dir and the dataset split")

for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{max_epochs}"):
        # Batch tensors get their own names so the module-level `labels`
        # path list is not silently replaced by a tensor (notebook hidden state).
        image_batch = batch["image"].to(device)
        label_batch = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(image_batch)
        loss = loss_fn(logits, label_batch)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{max_epochs} - Loss: {epoch_loss / len(train_loader):.4f}")

# Persist only the weights; the architecture is rebuilt explicitly at load time.
checkpoint_path = "swinunetr_kits19.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"✅ Model saved as {checkpoint_path}")


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
100%|██████████| 50/50 [3:38:41<00:00, 262.44s/it]  


Epoch 1/5 - Loss: 1.4249


  2%|▏         | 1/50 [04:56<4:02:10, 296.55s/it]

In [None]:
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import glob

from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImage, EnsureChannelFirst, Spacing, Orientation,
    ScaleIntensity, Compose
)
from monai.networks.nets import SwinUNETR
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing for inference: same spacing, orientation and intensity scaling
# as the image branch of the training pipeline, without augmentation/cropping.
preprocess_steps = [
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    Orientation(axcodes="RAS"),
    ScaleIntensity(),
]
val_transforms = Compose(preprocess_steps)

# Load the trained model; the constructor arguments must match training exactly
# or the state dict keys/shapes will not line up.
model = SwinUNETR(
    in_channels=1,
    out_channels=3,  # 0=background, 1=kidney, 2=tumor
    img_size=(96, 96, 96),
    feature_size=48,
    use_checkpoint=True,
).to(device)

state_dict = torch.load("swinunetr_kits19.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("✅ Model loaded and ready for inference.")
# Change this to another case to try others
image_path = "kits19/data/case_00123/imaging.nii.gz"
label_path = "kits19/data/case_00123/segmentation.nii.gz"

# Load CT image and apply transforms -> (1, 1, H, W, D) on the resampled RAS grid
image = val_transforms(image_path).unsqueeze(0).to(device)

# Load the ground-truth mask through the SAME spatial resampling as the image
# (nearest-neighbour so class ids stay integral). Reading it raw with nibabel
# would leave it on the original voxel grid and orientation. Its slices would
# then not correspond to `pred`, and a slice index taken from `pred` could fall
# outside it.
label_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Spacing(pixdim=(1.5, 1.5, 2.0), mode="nearest"),
    Orientation(axcodes="RAS"),
])
label = label_transforms(label_path)[0].numpy().astype(np.uint8)  # (H, W, D)

with torch.no_grad():
    output = sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),
        sw_batch_size=1,
        predictor=model
    )
    pred = torch.argmax(output, dim=1).cpu().numpy()[0]  # (H, W, D)

# A slice-wise comparison is only meaningful if both volumes share one grid.
assert label.shape == pred.shape, f"label {label.shape} vs prediction {pred.shape}"
print("✅ Inference completed.")
# Choose a central slice along the last (z) axis
slice_idx = pred.shape[-1] // 2

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

axs[0].imshow(image.cpu().numpy()[0, 0, :, :, slice_idx], cmap="gray")
axs[0].set_title("CT Slice")

# Pin the colour range to the class ids 0..2 (out_channels=3) so one colour
# means the same class in both mask panels. Without vmin/vmax, imshow rescales
# each panel to its own min/max: e.g. a GT slice with only {0, 1} would paint
# kidney in the colour the prediction panel uses for tumor. Nearest
# interpolation avoids blending class ids into fake intermediate colours.
mask_style = dict(vmin=0, vmax=2, interpolation="nearest")

axs[1].imshow(label[:, :, slice_idx], **mask_style)
axs[1].set_title("Ground Truth Mask")

axs[2].imshow(pred[:, :, slice_idx], **mask_style)
axs[2].set_title("Predicted Mask")

for ax in axs:
    ax.axis("off")
plt.tight_layout()
plt.show()
