In [15]:
import os
from time import time
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from monai.networks.nets import AttentionUnet
from monai.inferers import sliding_window_inference
from monai.data import Dataset
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    EnsureType,
    Orientation,
    Spacing,
    NormalizeIntensity,
    Activations,
    AsDiscrete
)
import torch

# Root folder of the BraTS 2020 validation set. Set the BRATS_VAL_ROOT
# environment variable to point elsewhere instead of editing this local path.
DATA_ROOT = os.environ.get("BRATS_VAL_ROOT", 'D:\\github\\MICCAI_BraTS2020_ValidationData')
MRI_TYPE = 't1ce'  # modality suffix of the files to segment
EXTENSION = '.nii.gz'
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
VAL_AMP = True  # use mixed precision during inference (only takes effect on CUDA)
# Images are expected one directory below DATA_ROOT, named "*_<MRI_TYPE><EXTENSION>".
# glob() returns matches in arbitrary, filesystem-dependent order; sort them so
# CHOSEN_IDX selects the same case on every machine and every run.
test_image_paths = sorted(glob(os.path.join(DATA_ROOT, '*', '*_' + MRI_TYPE + EXTENSION)))
print(f"Found {len(test_image_paths)} test images")
CHOSEN_IDX = 25  # index into test_image_paths of the case to segment

Found 125 test images


In [16]:
# Preprocessing for a single image path (no label): load the volume as a
# channel-first tensor, reorient it to RAS, resample to 1 mm isotropic voxels
# and normalise intensities over the non-zero voxels of each channel.
test_transform = Compose(
    [
        # image_only=True makes LoadImage return just the image tensor (not an
        # (image, metadata) tuple), which prepare_data relies on.
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        EnsureType(),
        Orientation(axcodes="RAS"),
        # The array transform Spacing takes ONE interpolation mode. The
        # ("bilinear", "nearest") pair is the per-key form of the dictionary
        # transform Spacingd (image + label); an image is resampled bilinearly.
        Spacing(pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        NormalizeIntensity(nonzero=True, channel_wise=True),
    ]
)
# Turn raw network logits into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

class BrainTumorDataset(Dataset):
    """Map-style dataset over image file paths (images only, no labels).

    Item ``idx`` is ``transform(image_paths[idx])`` when a transform is given,
    otherwise the raw path itself.
    """

    def __init__(self, image_paths, transform=None):
        # Initialise the base Dataset as well, so the attributes it defines
        # (e.g. ``data``) exist for any behaviour inherited from it.
        super().__init__(data=image_paths, transform=transform)
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Only the image path is stored; there is no label to load.
        image = self.image_paths[idx]
        # Test against None explicitly: the truthiness of a transform object is
        # not a meaningful check (monai's Compose defines __len__).
        if self.transform is not None:
            image = self.transform(image)
        return image

def prepare_data(image_path: str):
    """Preprocess a single image file into a batched tensor on DEVICE.

    The path goes through ``test_transform`` via a one-element
    ``BrainTumorDataset``, and a leading batch axis of size 1 is added.
    """
    single_case = BrainTumorDataset([image_path], transform=test_transform)
    volume = single_case[0]
    batched = volume.unsqueeze(0)  # (C, ...) -> (1, C, ...)
    return batched.to(DEVICE)

In [17]:
def inference(input, model, roi_size=(240, 240, 160), sw_batch_size=1, overlap=0.5):
    """Run sliding-window inference of ``model`` over a batched volume.

    Args:
        input: batched input tensor, e.g. the result of ``prepare_data``.
        model: network applied to each window.
        roi_size: spatial size of each window. The default is at least as large
            as the 240x240x155 volumes printed in this notebook, so presumably a
            single (padded) window covers each volume; TODO confirm for other data.
        sw_batch_size: number of windows evaluated per forward pass.
        overlap: fractional overlap between neighbouring windows.

    Returns:
        The stitched output of ``sliding_window_inference``.

    Mixed precision is used only when VAL_AMP is set and ``input`` is on a CUDA
    device. CUDA autocast has no effect on CPU tensors anyway.
    """
    def _run():
        return sliding_window_inference(
            inputs=input,
            roi_size=roi_size,
            sw_batch_size=sw_batch_size,
            predictor=model,
            overlap=overlap,
        )

    if VAL_AMP and input.device.type == "cuda":
        with torch.amp.autocast(device_type="cuda"):
            return _run()
    return _run()

def _load_model(checkpoint_path: str) -> torch.nn.Module:
    """Build the 3-D AttentionUnet, load its weights and put it in eval mode.

    The architecture arguments must match those of the network that produced
    the checkpoint, otherwise ``load_state_dict`` will fail.
    """
    model = AttentionUnet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        dropout=0.2,
    ).to(DEVICE)
    # map_location lets a checkpoint saved on GPU load on the current DEVICE
    # (including CPU); weights_only=True refuses to unpickle arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _affine_of(tensor) -> np.ndarray:
    """Return the 4x4 voxel-to-world affine carried by ``tensor`` (identity if none).

    Reads the ``affine`` attribute that MONAI MetaTensors expose; a leading batch
    axis on the affine, if present, is dropped.
    """
    affine = getattr(tensor, "affine", None)
    if affine is None:
        return np.eye(4)
    if torch.is_tensor(affine):
        affine = affine.detach().cpu()
    affine = np.asarray(affine, dtype=np.float64)
    return affine[0] if affine.ndim == 3 else affine


def _to_sitk(volume: np.ndarray, affine: np.ndarray) -> sitk.Image:
    """Convert an (x, y, z)-indexed volume plus its RAS affine to a SimpleITK image.

    SimpleITK arrays are indexed (z, y, x), so the axes are reversed. The affine
    is assumed to be in MONAI's default RAS world space (TODO confirm if the
    reader changes). It is mapped to ITK's LPS convention to set spacing,
    direction and origin, so the file lines up with the original scan.
    """
    image = sitk.GetImageFromArray(np.ascontiguousarray(np.transpose(volume, (2, 1, 0))))
    lps_affine = np.diag([-1.0, -1.0, 1.0, 1.0]) @ affine
    linear = lps_affine[:3, :3]
    spacing = np.linalg.norm(linear, axis=0)  # length of each voxel-axis column
    image.SetSpacing(tuple(float(s) for s in spacing))
    # Unit column vectors, flattened row-major as SetDirection expects.
    image.SetDirection(tuple(float(v) for v in (linear / spacing).ravel()))
    image.SetOrigin(tuple(float(v) for v in lps_affine[:3, 3]))
    return image


def test(output_dir: str, image_index=None, checkpoint_path=None):
    """Segment one validation image and write the image and its mask as NIfTI.

    Args:
        output_dir: directory that receives ``<name>.nii.gz`` (the preprocessed
            image) and ``<name>_seg.nii.gz`` (the binary mask, uint8). It is
            created if missing.
        image_index: index into ``test_image_paths``; defaults to CHOSEN_IDX at
            call time.
        checkpoint_path: weights file; defaults to
            ``Checkpoints/best_metric_model.pth``.
    """
    if image_index is None:
        image_index = CHOSEN_IDX
    if checkpoint_path is None:
        checkpoint_path = os.path.join('Checkpoints', "best_metric_model.pth")

    image_path = test_image_paths[image_index]
    image_name = os.path.basename(image_path).split('.')[0]  # strip ".nii.gz"

    model = _load_model(checkpoint_path)
    with torch.no_grad():
        image = prepare_data(image_path)
        output = post_trans(inference(image, model)[0])

    # Drop the batch/channel axes explicitly; np.squeeze would also remove any
    # singleton spatial axis. There is one channel because in/out_channels=1.
    image_array = image[0, 0].detach().cpu().numpy()
    output_array = output[0].detach().cpu().numpy().astype(np.uint8)
    print(image_array.shape)
    print(output_array.shape)

    # The prediction is on the same voxel grid as the preprocessed input, so
    # both files share the input's affine.
    affine = _affine_of(image)
    os.makedirs(output_dir, exist_ok=True)
    seg_output_path = os.path.join(output_dir, image_name + "_seg.nii.gz")
    img_output_path = os.path.join(output_dir, image_name + ".nii.gz")
    sitk.WriteImage(_to_sitk(output_array, affine), seg_output_path)
    sitk.WriteImage(_to_sitk(image_array, affine), img_output_path)
        

In [18]:
test(output_dir="./")

(240, 240, 155)
(240, 240, 155)
