In [1]:
import torch
import numpy as np

# dataset
from twaidata.torchdatasets.in_ram_ds import MRISegmentation2DDataset, MRISegmentation3DDataset
from torch.utils.data import DataLoader, random_split, ConcatDataset

# model
from trustworthai.models.base_models.source_kinet import kiunet, reskiunet, densekiunet, kiunet3d
from trustworthai.models.base_models.torchUNet import UNet, UNet3D
from trustworthai.models.base_models.deepmedic import DeepMedic

# augmentation and pretrain processing
from trustworthai.utils.augmentation.standard_transforms import RandomFlip, GaussianBlur, GaussianNoise, \
                                                            RandomResizeCrop, RandomAffine, \
                                                            NormalizeImg, PairedCompose, LabelSelect, \
                                                            PairedCentreCrop, CropZDim
# loss function
from trustworthai.utils.losses_and_metrics.tversky_loss import TverskyLoss
from trustworthai.utils.losses_and_metrics.misc_metrics import IOU
from trustworthai.utils.losses_and_metrics.dice import dice, DiceMetric
from trustworthai.utils.losses_and_metrics.dice_losses import DiceLoss, GeneralizedDiceLoss
from trustworthai.utils.losses_and_metrics.power_jaccard_loss import PowerJaccardLoss
from torch.nn import BCELoss, MSELoss

# fitter
from trustworthai.utils.fitting_and_inference.fitters.basic_lightning_fitter import StandardLitModelWrapper
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import pytorch_lightning as pl

# misc
import os
import torch
import matplotlib.pyplot as plt
import torch
from torchinfo import summary

In [2]:
# Seed the global RNGs up front so repeated runs of the notebook draw the same random numbers.
seed = 3407
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is not available
# NumPy's global RNG is seeded as well, in case any NumPy-based code draws from it.
np.random.seed(seed)

In [3]:
# Locations of the collated data (machine-specific absolute path; adjust for your environment).
# Every directory string keeps its trailing slash because sub-paths are appended by plain concatenation.
root_dir = "/disk/scratch/s2208943/ipdis/preprep/out_data/collated/"
wmh_dir = f"{root_dir}WMH_challenge_dataset/"
ed_dir = f"{root_dir}EdData/"

In [43]:
# Domains to load: the three subdirectories of the WMH challenge dataset.
domains = [f"{wmh_dir}{site}" for site in ("Singapore", "Utrecht", "GE3T")]

# Alternative: WMH challenge domains plus the four EdData domains.
# domains = [
#             wmh_dir + d for d in ["Singapore", "Utrecht", "GE3T"]
#           ] + [
#             ed_dir + d for d in ["domainA", "domainB", "domainC", "domainD"]
#           ]

In [44]:
# function to do train validate test split
# fraction of each dataset held out for the test and validation subsets
test_proportion = 0.1
validation_proportion = 0.2

def train_val_test_split(dataset, val_prop, test_prop, seed):
    """Randomly partition ``dataset`` into train, validation and test subsets.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        val_prop: fraction of the samples assigned to the validation subset.
        test_prop: fraction of the samples assigned to the test subset.
        seed: integer seed for the generator that drives the shuffle, so the
            same seed and dataset length always give the same partition.

    Returns:
        Tuple ``(train, val, test)`` of ``random_split`` subsets. Validation and
        test sizes are rounded down; the train subset receives the remainder.

    Raises:
        ValueError: if the proportions produce a negative subset size
            (a negative proportion, or ``val_prop + test_prop`` greater than 1).
    """
    # I think the sklearn version might be preferable for determinism and things
    # but that involves fiddling with the dataset implementation I think....
    size = len(dataset)
    test_size = int(test_prop * size)
    val_size = int(val_prop * size)
    train_size = size - val_size - test_size

    # The three sizes always sum to `size`, even when one of them is negative, so a
    # sum check alone cannot catch bad proportions; reject them explicitly here.
    if min(train_size, val_size, test_size) < 0:
        raise ValueError(
            f"invalid split proportions val_prop={val_prop}, test_prop={test_prop}: "
            f"sizes (train, val, test) = ({train_size}, {val_size}, {test_size}) for {size} samples"
        )

    split_rng = torch.Generator().manual_seed(seed)
    train, val, test = random_split(dataset, [train_size, val_size, test_size], generator=split_rng)
    return train, val, test

In [71]:
# Switch between the 2D and 3D pipelines: it selects the transforms, dataset class,
# model architecture and checkpoint file in the cells below.
is3D = False

In [72]:
# augmentation definition
def get_transforms(is3D):
    """Build the transform pipeline handed to the datasets.

    The pipeline always applies ``LabelSelect(label_id=1)`` followed by a
    horizontal ``RandomFlip`` with probability 0.5. When ``is3D`` is truthy a
    ``CropZDim(size=32, minimum=0, maximum=-1)`` step is appended.

    Returns:
        A ``PairedCompose`` wrapping the list of transforms.
    """
    pipeline = [
        LabelSelect(label_id=1),
        RandomFlip(p=0.5, orientation="horizontal"),
        # Further augmentations, currently switched off:
        # GaussianBlur(p=0.5, kernel_size=7, sigma=(.1, 1.5)),
        # GaussianNoise(p=0.2, mean=0, sigma=0.2),
        # RandomAffine(p=0.2, shear=(.1,3.)),
        # RandomAffine(p=0.2, degrees=5),
        # RandomResizeCrop(p=1., scale=(0.6, 1.), ratio=(3./4., 4./3.))
    ]
    if is3D:
        # only the 3D pipeline crops along the z dimension
        pipeline.append(CropZDim(size=32, minimum=0, maximum=-1))
    return PairedCompose(pipeline)

In [73]:
# One dataset per domain, using the dataset class that matches the 2D/3D switch.
# get_transforms(is3D) already appends the z-dimension crop in the 3D case, so no
# separate CropZDim instance is needed here.
dataset_cls = MRISegmentation3DDataset if is3D else MRISegmentation2DDataset
datasets_domains = [dataset_cls(root_dir, domain, transforms=get_transforms(is3D)) for domain in domains]

# NOTE(review): the transforms (including RandomFlip) are attached to each dataset before
# it is split, so the val and test subsets share the same random augmentation as the train
# subset -- confirm this is intended for evaluation.

# split each domain into train, val, test datasets (same seed for every domain)
datasets = [train_val_test_split(dataset, validation_proportion, test_proportion, seed) for dataset in datasets_domains]

# concat the per-domain train, val and test datasets
train_dataset = ConcatDataset([ds[0] for ds in datasets])
val_dataset = ConcatDataset([ds[1] for ds in datasets])
test_dataset = ConcatDataset([ds[2] for ds in datasets])

In [74]:
# number of samples in the train, val and test datasets
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

(2506, 716, 358)

In [75]:
# define dataloaders: identical batch size and worker count for all three,
# shuffling enabled for the training loader only
shared_loader_args = {"batch_size": 16, "num_workers": 4}
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)
test_dataloader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

### setup model

In [76]:
in_channels = 3
out_channels = 1
# TODO: put in the path to the assets on your machine here
checkpoint_dir = "/home/s2208943/ipdis/Trustworthai-MRI-WMH/assets/"

# The architecture and the checkpoint file both follow the 2D/3D switch.
if not is3D:
    model = UNet(in_channels, out_channels, init_features=32, dropout_p=0., softmax=False)
    checkpoint = f"{checkpoint_dir}WMHCHAL_unet2d_32_no_softmax_dropout0.ckpt"
else:
    model = UNet3D(in_channels, out_channels, init_features=16, dropout_p=0., softmax=False)
    checkpoint = f"{checkpoint_dir}WMHCHAL_unet3d_16_no_softmax_dropout0.ckpt"

# The networks are constructed with softmax=False; the loss is configured with
# normalization='sigmoid'.
loss = GeneralizedDiceLoss(normalization='sigmoid')

In [77]:
# Restore the Lightning wrapper from the checkpoint file, handing it the freshly built
# network and the loss.
# NOTE: this rebinds `model` from the bare network to the wrapper, so re-running this cell
# without first re-running the model-definition cell passes the wrapper itself as `model=`.
model = StandardLitModelWrapper.load_from_checkpoint(checkpoint, model=model, loss=loss)

In [78]:
# Trainer settings: one GPU device, 16-bit precision.
accelerator, devices, precision = "gpu", 1, 16

trainer = pl.Trainer(accelerator=accelerator, devices=devices, precision=precision)


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### evaluate

In [79]:
# Run the validation loop of the restored model over the validation loader; the returned
# list of metric dicts is shown as the cell output.
trainer.validate(model, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            0.22884611785411835
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.22884611785411835}]

In [80]:
# Run the test loop of the restored model over the held-out test loader; the returned
# list of metric dicts is shown as the cell output.
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.24728116393089294
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.24728116393089294}]

In [None]:
# Evaluate on a single split: datasets[0][0] is the train subset of the first domain.
# It is a Dataset rather than a DataLoader, so it is wrapped in a loader with the same
# batching settings as the other evaluation loaders before being handed to the trainer.
trainer.test(model, DataLoader(datasets[0][0], batch_size=16, shuffle=False, num_workers=4))

In [None]:
# To get some predictions, use trainer.predict(model, dataloader).
# Remember to apply the relevant normalization (e.g. sigmoid or softmax) if it is not included in the model itself.

# Also note that the CropZDim transform by default takes a random range of slices from the 3D image.

In [None]:
# Fetch the first batch from the validation loader (no prediction happens here).
# Presumably x holds the input images and y the matching labels -- verify against the dataset class.
x, y = next(iter(val_dataloader))

In [None]:
# Run prediction over the whole validation loader; the result is indexed per batch below.
preds = trainer.predict(model, val_dataloader)

In [None]:
# NOTE(review): the network was built with softmax=False, so these may be un-normalized
# outputs -- confirm whether a sigmoid is needed before computing metrics.
y_hat = preds[0] # select first batch from predictions

In [None]:
# Dice score between the first batch of predictions and the labels y fetched earlier.
# NOTE(review): y and y_hat come from two separate passes over val_dataloader, whose
# datasets carry RandomFlip(p=0.5); if the flip is re-sampled on each access, the two
# may not be aligned -- confirm before trusting this number.
dice(y_hat, y)