In [1]:
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import glob

from monai.inferers import SlidingWindowInferer
from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import *
from monai.utils import first

In [2]:
# Kaggle copy of the BraTS 2020 training split: one sub-directory per patient.
data_dir = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

# Every entry whose name starts with "BraTS20_", in lexicographic order.
_patient_glob = os.path.join(data_dir, "BraTS20_*")
patient_folders = sorted(glob.glob(_patient_glob))

In [4]:
def _modality_paths(folder):
    """Return the four MRI modality file paths for one patient folder.

    The folder's base name is the patient id; each modality file is named
    ``<patient_id>_<modality>.nii`` inside that folder. Keys match the
    dictionary keys used by the inference transforms.
    """
    pid = os.path.basename(folder)
    return {
        "image_t1": os.path.join(folder, f"{pid}_t1.nii"),
        "image_t1ce": os.path.join(folder, f"{pid}_t1ce.nii"),
        "image_t2": os.path.join(folder, f"{pid}_t2.nii"),
        "image_flair": os.path.join(folder, f"{pid}_flair.nii"),
    }


# One dict of input paths per patient, in the same order as patient_folders.
infer_files = [_modality_paths(folder) for folder in patient_folders]

In [6]:
# The four modalities are stacked as channels in exactly this order.
# NOTE(review): the pretrained network must have been trained with the same
# channel order — this notebook does not show how model.ts was trained; confirm.
keys = ["image_t1", "image_t1ce", "image_t2", "image_flair"]

# Stage 1: read each volume, make it channel-first, reorient to RAS and
# resample to 1 mm isotropic spacing with bilinear interpolation.
_spatial_steps = [
    LoadImaged(keys=keys),
    EnsureChannelFirstd(keys=keys),
    Orientationd(keys=keys, axcodes="RAS"),
    Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
]

# Stage 2: linearly map raw intensities [0, 4000] -> [0, 1] (values outside are
# clipped), then concatenate the modalities along dim 0 into a single "image".
# NOTE(review): verify this normalisation matches the model's training pipeline.
_intensity_and_stack_steps = [
    ScaleIntensityRanged(keys=keys, a_min=0.0, a_max=4000.0, b_min=0.0, b_max=1.0, clip=True),
    ConcatItemsd(keys=keys, name="image", dim=0),
]

infer_transforms = Compose(_spatial_steps + _intensity_and_stack_steps)

In [7]:
# How many patients (taken from the start of infer_files) to run inference on.
# Kept at 1 for a quick demo; raise it to process more cases.
NUM_INFER_CASES = 1

if not infer_files:
    # An empty list means the glob under data_dir matched nothing (wrong path or
    # dataset not attached). Fail here with a readable message instead of an
    # IndexError later on.
    raise FileNotFoundError(f"No BraTS20_* patient folders found under {data_dir}")

infer_ds_transformed = Dataset(data=infer_files[:NUM_INFER_CASES], transform=infer_transforms)
infer_loader = DataLoader(infer_ds_transformed, batch_size=1, num_workers=4)

In [8]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
device

device(type='cpu')

In [9]:
# TorchScript export of the pretrained segmentation network; weights are mapped
# onto `device` while loading.
MODEL_PATH = "/kaggle/input/monai_brats_pretrained/pytorch/pretrained/1/model.ts"
model = torch.jit.load(MODEL_PATH, map_location=device)

# Directory for the saved segmentations (created if missing, reused otherwise).
output_dir = "./outputs_brats"
os.makedirs(output_dir, exist_ok=True)

In [10]:
# Patch-based 3D inference: 128^3 windows, 4 windows per forward pass, 50% overlap.
inferer = SlidingWindowInferer(roi_size=(128, 128, 128), sw_batch_size=4, overlap=0.5)

# Post-processing is applied to a dict {"pred": ..., "image_flair_meta_dict": ...},
# so every step must be a *dictionary* transform. The array transform `AsDiscrete`
# would receive the whole dict and fail; `AsDiscreted(keys="pred")` touches only
# the prediction.
# NOTE(review): argmax + one-hot assumes mutually exclusive classes with channel 0
# as background. If model.ts is a multi-label network (independent sigmoid
# channels), Activationsd(sigmoid=True) + AsDiscreted(threshold=0.5) would be the
# right post-processing instead — confirm against the model's documentation.
post_transforms = Compose([
    AsDiscreted(keys="pred", argmax=True, to_onehot=3),
    SaveImaged(
        keys="pred",
        meta_keys="image_flair_meta_dict",
        output_dir=output_dir,
        output_postfix="seg",
        resample=False,
    ),
])

In [None]:
model.eval()
with torch.no_grad():
    for infer_data in infer_loader:
        inputs = infer_data["image"].to(device)
        outputs = inferer(inputs, model)

        # Pair each prediction with the metadata of *its own* FLAIR volume, taken
        # from the already-transformed batch. A Dataset built without a transform
        # only yields the raw file-path dict, so it has no 'image_flair_meta_dict'
        # entry; re-reading infer_files[0] would also mislabel every later case.
        for item, output in zip(decollate_batch(infer_data), decollate_batch(outputs)):
            flair = item["image_flair"]
            # MetaTensor inputs carry their (post-Orientation/Spacing) metadata in
            # `.meta`; otherwise fall back to a separate '<key>_meta_dict' entry.
            flair_meta = flair.meta if hasattr(flair, "meta") else item["image_flair_meta_dict"]
            # NOTE(review): if `output` is itself a MetaTensor, some MONAI versions
            # let SaveImage use its metadata (inherited from the concatenated
            # "image") instead of flair_meta — the written filename may then follow
            # the first concatenated modality. Confirm for the installed version.
            post_transforms({"pred": output, "image_flair_meta_dict": flair_meta})


print(f"Inference complete. Segmentation saved to: {output_dir}")

In [None]:
original_flair_path = infer_files[0]["image_flair"]
base_name = os.path.basename(original_flair_path).replace(".nii", "")
case_id = os.path.basename(os.path.dirname(original_flair_path))

# SaveImaged defaults to separate_folder=True and output_ext=".nii.gz", so the file
# lands at <output_dir>/<name>/<name>_seg.nii.gz, not <output_dir>/<flair>_seg.nii.
# Search recursively by patient id so the lookup works whatever modality the saved
# filename was derived from.
seg_matches = sorted(
    glob.glob(os.path.join(output_dir, "**", f"{case_id}_*_seg.nii*"), recursive=True)
)
if not seg_matches:
    raise FileNotFoundError(f"No saved segmentation for {case_id} found under {output_dir}")
saved_seg_path = seg_matches[0]

# Put FLAIR on the same voxel grid the model saw (channel-first, RAS, 1 mm) so the
# overlay lines up with the prediction, which was saved on that grid (resample=False).
flair_img = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes="RAS"),
    Spacing(pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
])(original_flair_path)[0]
# The multi-channel prediction may be stored channel-last on disk; EnsureChannelFirst
# moves the channel axis to 0 so segmentation[c] selects one class.
segmentation = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    Orientation(axcodes="RAS"),
])(saved_seg_path)
assert tuple(flair_img.shape) == tuple(segmentation.shape[1:]), (
    f"FLAIR grid {tuple(flair_img.shape)} != segmentation grid {tuple(segmentation.shape[1:])}"
)

slice_idx = flair_img.shape[2] // 2
flair_slice = np.asarray(flair_img[:, :, slice_idx])

fig, (ax_flair, ax_overlay) = plt.subplots(1, 2, figsize=(12, 6), num="BraTS Segmentation Overlay")
ax_flair.set_title("FLAIR Image Slice")
ax_flair.imshow(flair_slice, cmap="gray", origin="lower")
ax_flair.axis("off")

ax_overlay.set_title("Segmentation Overlay on FLAIR")
ax_overlay.imshow(flair_slice, cmap="gray", origin="lower")

# Channel 0 is skipped as background, matching the argmax/one-hot post-processing.
for channel in range(1, segmentation.shape[0]):
    seg_slice = np.asarray(segmentation[channel, :, :, slice_idx])
    ax_overlay.imshow(
        # Hide background voxels; the data array must have the slice's shape
        # (passing a scalar makes masked_where raise an inconsistent-shape error).
        np.ma.masked_where(seg_slice == 0, seg_slice),
        cmap=plt.cm.Blues if channel == 1 else plt.cm.Reds,
        # Fix the colour scale so label value 1 maps to the strong end of the map
        # instead of an auto-scaled (near-white) colour for constant data.
        vmin=0.0,
        vmax=1.0,
        alpha=0.5,
        origin="lower",
    )

ax_overlay.axis("off")
fig.tight_layout()
plt.show()