In [13]:
import numpy as np
import time
from tqdm import tqdm
import os
import sys
import torchio as tio
import logging

# import torch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

# Import MONAI libraries                <--- CLEAN UP THESE IMPORTS ONCE WE KNOW WHAT libraries are used
import monai
from monai.config import print_config
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.handlers import (
    CheckpointLoader,
    IgniteMetric,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)
from monai.metrics import DiceMetric, LossMetric, HausdorffDistanceMetric
from monai.losses import DiceLoss, DiceFocalLoss
from monai.networks import nets as monNets
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose
)
from monai.inferers import sliding_window_inference
from monai.utils import first
from monai.utils.misc import set_determinism

# Other imports (unsure)
# import ignite
import nibabel

import torch
import os
import torch.utils.data as data_utils
import json
from subprocess import call
from sklearn.model_selection import train_test_split
# from utils.utils import get_main_args
import pickle
import glob

import matplotlib.pyplot as plt
from monai.data import Dataset

# Dataset class

In [93]:
class MRIDataset(Dataset):
    """
    Given a set of images and corresponding labels (i.e, will give it all training images + labels, and same for val and test)
    folder structure: subjectID/subjectID-stk.npy, -lbl.npy (i.e. contains 2 files)
    """

    def __init__(self, data_dir, data_folders, transform=None, SSAtransform=None):
            self.data_folders = data_folders                            # path for each data folder in the set
            self.transform = transform
            self.SSAtransform = SSAtransform
            self.mode = 'val'
            self.imgs = []                                              # store images to load (paths)
            self.lbls = []                                              # store corresponding labels (paths)
            for img_folder in self.data_folders:                        # run through each subjectID folder
                folder_path = os.path.join(data_dir, img_folder)                                                            
                self.SSA = True if 'SSA' in img_folder else False       # check if current file is from SSA dataset
                for file in os.listdir(folder_path):                    # check folder contents
                    if os.path.isfile(os.path.join(folder_path, file)):
                        if file.endswith("-lbl.npy"):
                            self.lbls.append(os.path.join(folder_path, file))   # Save segmentation mask (file path)
                        elif file.endswith("-stk.npy"):
                            self.imgs.append(os.path.join(folder_path, file))   # Save image (file path)

    def __len__(self):
        # Return the amount of images in this set
        return len(self.imgs)
    
    def __getitem__(self, idx):
        name = os.path.dirname(self.imgs[idx])
        # Load files
        image = np.load(self.imgs[idx])
        image = torch.from_numpy(image) # 4, 240, 240, 155
        if self.mode == 'train':
            mask = np.load(self.lbls[idx])
            mask = torch.from_numpy(mask) # 240, 240, 155

        # print(self.imgs[idx] )
        # print("========================")
        # print(self.lbls[idx] )
        # print("========================")           

        if self.transform is not None: # Apply general transformations
        # transforms such as crop, flip, rotate etc will be applied to both the image and the mask
            if self.mode == 'train':
                subject = tio.Subject(
                    image=tio.ScalarImage(tensor=image),
                    mask=tio.LabelMap(tensor=mask)
                    )
                tranformed_subject = self.transform(subject)
                # Apply transformation to GLI data to reduce quality (creating fake SSA data)
                if self.SSA == False and self.SSAtransform is not None:
                    tranformed_subject = self.SSAtransform(tranformed_subject)
            
                print("Tranformed_subject: ", tranformed_subject)
                image = tranformed_subject["image"].data
                mask = tranformed_subject["mask"].data
                return image, mask, self.imgs[idx]
            else:
                subject = tio.Subject(
                    image=tio.ScalarImage(tensor=image),
                    )
                tranformed_subject = self.transform(subject)           
                print("Tranformed_subject: ", tranformed_subject)
                image = tranformed_subject["image"].data
                return image, self.imgs[idx]

        return image, mask, self.imgs[idx]
    
    def get_paths(self):
        return self.img_pth, self.seg_pth
    
    def get_subj_info(self):
        return self.subj_dir_pths, self.subj_dirs
        #, self.SSA
    
    def get_transforms(self):
        return self.transform


# Data Transforms

In [84]:
def define_transforms(n_channels):
    # Initialise data transforms
    data_transforms = {
        'train': tio.Compose([
            tio.CropOrPad((192, 192, 124)),
            tio.OneOf([
                tio.Compose([
                    tio.RandomFlip(axes=0, p=0.3),
                    tio.RandomFlip(axes=1, p=0.3),
                    tio.RandomFlip(axes=2, p=0.3)]),
                tio.RandomAffine(degrees=15,p=0.3)
            ], p=0.8),
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ]),
        'fakeSSA': tio.OneOf({
            tio.OneOf({
                tio.Compose([
                    tio.Resample((1.2, 1.2, 6), scalars_only=True),
                    tio.Resample(1)
                ]):0.50,
                tio.Compose([
                    tio.RandomAnisotropy(axes=(1, 2), downsampling=(1.2), scalars_only=True),
                    tio.RandomAnisotropy(axes=0, downsampling=(6), scalars_only=True)
                ]):0.5,                
            },p=0.80),
            tio.Compose([            
                tio.OneOf({
                    tio.RandomBlur(std=(0.5, 1.5)) : 0.3,
                    tio.RandomNoise(mean=3, std=(0, 0.33)) : 0.7
                },p=0.50),
                tio.OneOf({
                    tio.RandomMotion(num_transforms=3, image_interpolation='nearest') : 0.5,
                    tio.RandomBiasField(coefficients=1) : 0.2,
                    tio.RandomGhosting(intensity=1.5) : 0.3
                }, p=0.50)
            ])
        }, p=0.8), # randomly apply ONE of these given transforms with prob 0.5 
        'val': tio.Compose([
            tio.CropOrPad((192, 192, 124)),
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ]),
        'test' : tio.Compose([
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ])
    }

    return data_transforms

In [85]:
# pth = '/scratch/guest187/Data/train_gli_hack/results_hack_finetune_ssa/checkpoints/f0_finetune/epoch=39-dice=91.38.ckpt'
pth = '/scratch/guest187/Data/train_all/results/test_fullRunThrough/best_metric_model_fullTest.pth'
checkpoint = torch.load(pth, map_location=torch.device('cpu'))
keys = checkpoint.keys()

for key in checkpoint.keys():
    print(key)

model.0.conv.weight
model.0.conv.bias
model.0.adn.A.weight
model.1.submodule.0.conv.weight
model.1.submodule.0.conv.bias
model.1.submodule.0.adn.A.weight
model.1.submodule.1.submodule.0.conv.weight
model.1.submodule.1.submodule.0.conv.bias
model.1.submodule.1.submodule.0.adn.A.weight
model.1.submodule.1.submodule.1.submodule.0.conv.weight
model.1.submodule.1.submodule.1.submodule.0.conv.bias
model.1.submodule.1.submodule.1.submodule.0.adn.A.weight
model.1.submodule.1.submodule.1.submodule.1.submodule.conv.weight
model.1.submodule.1.submodule.1.submodule.1.submodule.conv.bias
model.1.submodule.1.submodule.1.submodule.1.submodule.adn.A.weight
model.1.submodule.1.submodule.1.submodule.2.conv.weight
model.1.submodule.1.submodule.1.submodule.2.conv.bias
model.1.submodule.1.submodule.1.submodule.2.adn.A.weight
model.1.submodule.1.submodule.2.conv.weight
model.1.submodule.1.submodule.2.conv.bias
model.1.submodule.1.submodule.2.adn.A.weight
model.1.submodule.2.conv.weight
model.1.submodule.2.c

In [86]:
print(checkpoint["state_dict"])

KeyError: 'state_dict'

In [87]:
import pickle
cnfgP = '/scratch/guest187/Data/train_all_nnUNET/train_all_data/results/11_3d/config.pkl'
config = pickle.load(open(cnfgP, "rb"))
print(config)

{'patch_size': [128, 128, 128], 'spacings': [1.0, 1.0, 1.0], 'n_class': 4, 'in_channels': 5}


In [88]:
patch_size, spacings = config["patch_size"], config["spacings"]
strides, kernels, sizes = [], [], patch_size[:]
while True:
    spacing_ratio = [spacing / min(spacings) for spacing in spacings]
    stride = [
        2 if ratio <= 2 and size >= 2 * 2 else 1 for (ratio, size) in zip(spacing_ratio, sizes)
    ]
    kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
    if all(s == 1 for s in stride):
        break
    sizes = [i / j for i, j in zip(sizes, stride)]
    spacings = [i * j for i, j in zip(spacings, stride)]
    kernels.append(kernel)
    strides.append(stride)
    if len(strides) == 6:
        break
strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])

print(strides)
print(kernels)

[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]


# Setup

In [131]:
seed = 42
set_determinism(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def define_model(checkpoint=None):
    """Define model architecture:
        Done before data loader so that transforms has n_channels for EnsureShapeMultiple
    """
    logger = logging.getLogger(__name__)
    
    model=UNet(
        spatial_dims=3,
        in_channels=4,
        out_channels=4,
        channels=(16, 32, 64, 128, 256),
        # channels=(32, 64, 128, 256, 320, 320), #nnunet channels, depth 6
        # channels=(64, 96, 128, 192, 256, 384, 512), # optinet, depth 7
        strides=(2, 2, 2, 2), # length should = len(channels) - 1
        # kernel_size=,
        # num_res_units=,
        # dropout=0.0,
        )
#     model= monNets.DynUNet(
#                 spatial_dims=3,
#                 in_channels=5,
#                 out_channels=4,
#                 kernel_size=[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
#                 strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
#                 upsample_kernel_size=[[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
#                 filters=[64, 96, 128, 192, 256, 384, 512],
#                 norm_name=("instance", {"affine": True}),
#                 act_name=("leakyrelu", {"inplace": False, "negative_slope": 0.01}),
#                 deep_supervision=True,
#                 res_block=False,
#                 trans_bias=True)
    
    model.to(device)
    n_channels = len(model.channels)
    logger.info(f"Number of channels: {n_channels}")

    if checkpoint != None:
        ckpt = torch.load(checkpoint, map_location=torch.device('cpu'))
#         state_dict = (ckpt["state_dict"])    
        model.load_state_dict(ckpt)

    return model, n_channels

"""
Setup transforms, dataset
"""
def define_dataloaders(n_channels):
    data_transform = define_transforms(n_channels)       # Define transforms
    dataloaders = load_data(args, data_transform)        # Load data; also saves json splitData
    # train_loader, val_loader = dataloaders['train'], dataloaders['val']
    return dataloaders

"""
Setup validation stuff
    metrics
    post trans ???????
    define inference
"""
def val_params():
    VAL_AMP = True
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=True, num_classes=4)
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=True, num_classes=4)
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    return VAL_AMP, dice_metric, dice_metric_batch, post_trans

# define inference method
def inference(VAL_AMP, model, input):
    def _compute(input):
        '''
        roi_size – the spatial window size for inferences. 
        When its components have None or non-positives, the corresponding inputs dimension will be used. 
        if the components of the roi_size are non-positive values, the transform will use the corresponding components of img size.
        For example, roi_size=(32, -1) will be adapted to (32, 64) if the second spatial dimension size of img is 64
        '''
        return sliding_window_inference(
            inputs=input,
            roi_size=None,
            sw_batch_size=4, #sw_batch_size denotes the max number of windows per network inference iteration, not the batch size of inputs.
            predictor=model,
            overlap=0.2,
            mode='constant'
        )
    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

In [132]:
validation_dir='/scratch/guest187/Data/val_SSA'
validation_files = [os.path.join(validation_dir, file) for file in os.listdir(validation_dir) if not (file.endswith(".json") or file == 'images' or file == 'labels' or file == 'ATr_prepoc')]

print(validation_files)
# checkpoint to test: finetuned hackathon model brats21
# checkpoint = '/scratch/guest187/Data/train_gli_hack/results_hack_finetune_ssa/checkpoints/f0_finetune/epoch=39-dice=91.38.ckpt'
checkpoint2 = '/scratch/guest187/Data/train_all/results/test_fullRunThrough/best_metric_model_fullTest.pth'

model, n_channels = define_model(checkpoint2)


# Load validation data to dataloader
data_transforms = define_transforms(n_channels)
validation_dataset = MRIDataset(validation_dir, validation_files, transform=data_transforms['val'])

print(len(validation_dataset))
img, image_path = validation_dataset[0]
print(img.shape)

# Validation parameters
VAL_AMP, dice_metric, dice_metric_batch, post_transforms = val_params()

['/scratch/guest187/Data/val_SSA/BraTS-SSA-00139-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00227-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00126-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00192-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00143-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00169-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00129-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00180-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00158-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00132-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00198-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00210-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00188-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00218-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00148-000']
15
Tranformed_subject:  Subject(Keys: ('image',); images: 1)
torch.Size([4, 192, 192, 128])


In [133]:
model, n_channels = define_model(checkpoint2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
model.eval()

with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_input, _ = validation_dataset[0]
    print(val_input.shape)
    val_output = inference(VAL_AMP, model, val_input)
    val_output = post_transforms(val_output[0])
    plt.figure("image", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(validation_dataset[6]["image"][i, :, :, 70].detach().cpu(), cmap="gray")
    plt.show()
    # visualize the 3 channels label corresponding to this image
    plt.figure("label", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"label channel {i}")
        plt.imshow(validation_dataset[6]["label"][i, :, :, 70].detach().cpu())
    plt.show()
    # visualize the 3 channels model output corresponding to this image
    plt.figure("output", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"output channel {i}")
        plt.imshow(val_output[i, :, :, 70].detach().cpu())
    plt.show()

Tranformed_subject:  Subject(Keys: ('image',); images: 1)
torch.Size([4, 192, 192, 128])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 128 but got size 256 for tensor number 1 in the list.