## Set up imports

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import monai
from monai.data import list_data_collate, decollate_batch, partition_dataset, DatasetSummary
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AsChannelLastd,
    AsDiscrete,
    Compose,
    ConcatItemsd,
    LoadImaged,
    RandRotated,
    EnsureTyped,
    AddChanneld,
    CropForegroundd,
    EnsureType,
    Invertd,
    Spacingd,
    CenterSpatialCropd,
    SpatialPadd,
    SaveImaged,
    ScaleIntensityRanged,
    SqueezeDimd,
    ScaleIntensityd,
    NormalizeIntensityd
)
from monai.inferers import SimpleInferer



## Define data directory and filenames 

In [None]:
root_dir = '/path/oasis/data/'

file_name_img = 'T1_oasis_reg.nii.gz'  # co-registered T1 image from the OASIS dataset
file_name_seg = 'brain_lesion_oasis_fu.nii.gz'  # merged segmentation map

# One sub-directory per patient under root_dir. Keep only directories, so stray
# files (README, .DS_Store, ...) do not become bogus "patients". Sort because
# os.listdir returns entries in arbitrary order, which would make the processing
# order differ between file systems.
all_patients = sorted(
    name for name in os.listdir(root_dir)
    if os.path.isdir(os.path.join(root_dir, name))
)

## Build the list of image/segmentation file pairs for inference

In [None]:
# Per-patient paths to the T1 image and its segmentation map.
test_images = [os.path.join(root_dir, patient, file_name_img) for patient in all_patients]
test_segs = [os.path.join(root_dir, patient, file_name_seg) for patient in all_patients]
# MONAI dictionary-style samples: one {"img": ..., "seg": ...} dict per patient.
test_files = [dict(img=img_path, seg=seg_path) for img_path, seg_path in zip(test_images, test_segs)]

## Define the transformations to be applied on the image data

In [None]:
# Pre-processing pipeline: load image + segmentation, resample, crop/pad to a
# fixed 192x240x160 volume, normalize, and stack both into a 2-channel "input".
test_transforms = Compose(
    [
        LoadImaged(keys=['img', 'seg'], reader='ITKReader'),
        AddChanneld(keys=['img', 'seg']),
        # The original pixdim=(0.8, 1) was short for 3-D data. MONAI pads it
        # (with 1.0 in older releases, with the last value in newer ones); both
        # yield (0.8, 1.0, 1.0), so spell that out to be version-independent.
        # NOTE(review): confirm the anisotropic 0.8 x 1 x 1 target spacing is intended.
        Spacingd(keys=['img', 'seg'], pixdim=(0.8, 1.0, 1.0), mode=['bilinear', 'nearest']),
        # Crop to the segmentation foreground plus a 10-voxel margin, then force a
        # fixed spatial size: center-crop when larger, pad when smaller.
        CropForegroundd(keys=['img', 'seg'], source_key='seg', margin=10),
        CenterSpatialCropd(keys=['img', 'seg'], roi_size=(192, 240, 160)),
        SpatialPadd(keys=['img', 'seg'], spatial_size=(192, 240, 160)),
        NormalizeIntensityd(keys=['img']),
        # Linearly rescale segmentation values so that 0 -> 0 and 5 -> 1.
        ScaleIntensityRanged(keys='seg', a_min=0, a_max=5, b_min=0, b_max=1),
        # Concatenate image and segmentation along the channel axis; the network
        # (in_channels=2) consumes this "input" key.
        ConcatItemsd(keys=['img', 'seg'], name='input'),
        # Include "input" so the tensor actually fed to the network is typed too.
        EnsureTyped(keys=['img', 'seg', 'input']),
    ]
)

## Define data loader 

In [None]:
# Dataset applies test_transforms lazily, per sample, when it is indexed.
test_ds = monai.data.Dataset(transform=test_transforms, data=test_files)
test_loader = DataLoader(
    test_ds,
    batch_size=2,
    num_workers=4,
    collate_fn=list_data_collate,
)
# Separate iterator so the visualization cell can pull a single batch.
test_loader_iter = iter(test_loader)

## Visualize a sample batch (image and segmentation)

In [None]:
# Pull one batch from the test loader (the original referenced the undefined
# `train_loader_iter`, which raised NameError).
batch = next(test_loader_iter)
img, seg = batch['img'], batch['seg']

slice_idx = 90  # index along the last spatial axis (size 160 after crop/pad)
n_show = min(2, img.shape[0])  # a batch may hold fewer than batch_size samples

plt.figure(figsize=(16, 10))
for j in range(n_show):
    # Top row: image channel; bottom row: matching segmentation slice.
    plt.subplot(2, 4, j + 1)
    plt.imshow(img[j, 0, :, :, slice_idx], cmap="gray")

    plt.subplot(2, 4, j + 5)
    plt.imshow(seg[j, 0, :, :, slice_idx], cmap="gray")
plt.show()

## Define network and load model 

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3-D UNet: 2 input channels (image + segmentation), 1 output channel.
generator = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
generator.load_state_dict(torch.load("./models/generator.pth", map_location=device))
# Inference only: switch normalization/dropout layers to evaluation behavior.
generator.eval()


## Define transforms for postprocessing

In [None]:
# Post-processing applied to each decollated sample:
#  1. Invertd maps "pred" back to the original image space by undoing the
#     invertible steps of `test_transforms` recorded on "img"
#     (e.g. resampling, cropping, padding).
#  2. SaveImaged writes the inverted prediction to disk.
post_transforms = Compose([
    Invertd(
        keys="pred",  # field to invert
        transform=test_transforms,  # the pre-processing pipeline to undo
        orig_keys="img",  # replay the transform trace recorded on `img`; `pred` shares its spatial layout
        meta_keys="pred_meta_dict",  # where the inverted meta data for `pred` is stored
        orig_meta_keys="img_meta_dict",  # meta data of the source image used during inversion
        meta_key_postfix="meta_dict",  # only consulted when `meta_keys` is None
        # Keep the pipeline's own interpolation modes rather than forcing nearest;
        # presumably `pred` is a continuous-valued (synthesized) image — TODO confirm.
        nearest_interp=False,
        to_tensor=True,  # return a PyTorch tensor after inverting
    ),
    SaveImaged(
        keys="pred",
        meta_keys="pred_meta_dict",
        output_dir="/path/oasis/predictions/",
        # Same root the inputs were listed from, so per-patient sub-paths are
        # reproduced under output_dir and stay in sync if root_dir changes.
        data_root_dir=root_dir,
        output_postfix="mc_fu_syn",
        separate_folder=False,
        resample=False,  # no extra resampling on save; Invertd already restored the original grid
    ),
])


## Inference 

In [None]:
inf = SimpleInferer()
generator.eval()  # ensure inference-mode behavior even if this cell is run on its own
with torch.no_grad():
    for d in test_loader:
        # "input" is the 2-channel image+segmentation stack built by ConcatItemsd.
        inputs = d["input"].to(device)
        # SimpleInferer runs the network on the whole volume in a single forward
        # pass (no sliding window).
        d["pred"] = inf(inputs=inputs, network=generator)
        # Decollate into per-sample dicts, then invert the preprocessing and save
        # each prediction via post_transforms.
        outputs = [post_transforms(item) for item in decollate_batch(d)]