In [None]:
## Imports

In [3]:
import numpy as np
import time
from tqdm import tqdm
import os
import sys
import torchio as tio
import logging

# import torch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

# Import MONAI libraries                <--- CLEAN UP THESE IMPORTS ONCE WE KNOW WHAT libraries are used
import monai
from monai.config import print_config
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.handlers import (
    CheckpointLoader,
    IgniteMetric,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)
from monai.metrics import DiceMetric, LossMetric, HausdorffDistanceMetric
from monai.losses import DiceLoss, DiceFocalLoss
from monai.networks import nets as monNets
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose
)
from monai.inferers import sliding_window_inference
from monai.utils import first
from monai.utils.misc import set_determinism

# Other imports (unsure)
# import ignite
import nibabel

import torch
import os
import torch.utils.data as data_utils
import json
from subprocess import call
from sklearn.model_selection import train_test_split
# from utils.utils import get_main_args
import pickle
import glob

import matplotlib.pyplot as plt
from monai.data import Dataset

## Data class

In [17]:
class MRIDataset(Dataset):
    """
    Given a set of images and corresponding labels (i.e, will give it all training images + labels, and same for val and test)
    folder structure: subjectID/subjectID-stk.npy, -lbl.npy (i.e. contains 2 files)
    """

    def __init__(self, data_dir, data_folders, transform=None, SSAtransform=None):
            self.data_folders = data_folders                            # path for each data folder in the set
            self.transform = transform
            self.SSAtransform = SSAtransform
            self.imgs = []                                              # store images to load (paths)
            self.lbls = []                                              # store corresponding labels (paths)
            for img_folder in self.data_folders:                        # run through each subjectID folder
                folder_path = os.path.join(data_dir, img_folder)                                                            
                self.SSA = True if 'SSA' in img_folder else False       # check if current file is from SSA dataset
                for file in os.listdir(folder_path):                    # check folder contents
                    if os.path.isfile(os.path.join(folder_path, file)):
                        if file.endswith("-lbl.npy"):
                            self.lbls.append(os.path.join(folder_path, file))   # Save segmentation mask (file path)
                            self.mode = "labels"
                        elif file.endswith("-stk.npy"):
                            self.imgs.append(os.path.join(folder_path, file))   # Save image (file path)

    def __len__(self):
        # Return the amount of images in this set
        return len(self.imgs)
    
    def __getitem__(self, idx):
        name = os.path.dirname(self.imgs[idx])
        # Load files
        if self.mode == "labels":
            mask = np.load(self.lbls[idx])
            mask = torch.from_numpy(mask) # 240, 240, 155

        # print(self.imgs[idx] )
        # print("========================")
        # print(self.lbls[idx] )
        # print("========================")           

        if self.transform is not None: # Apply general transformations
        # transforms such as crop, flip, rotate etc will be applied to both the image and the mask
            if self.mode == "labels":
                subject = tio.Subject(
                    image=tio.ScalarImage(tensor=image),
                    mask=tio.LabelMap(tensor=mask)
                    )
                tranformed_subject = self.transform(subject)
                # Apply transformation to GLI data to reduce quality (creating fake SSA data)
                if self.SSA == False and self.SSAtransform is not None:
                    tranformed_subject = self.SSAtransform(tranformed_subject)
            
                print("Tranformed_subject: ", tranformed_subject)
                image = tranformed_subject["image"].data
                mask = tranformed_subject["mask"].data
                return image, mask, self.imgs[idx]
            else:
                subject = tio.Subject(
                    image=tio.ScalarImage(tensor=image),
                    )
                tranformed_subject = self.transform(subject)           
                print("Tranformed_subject: ", tranformed_subject)
                image = tranformed_subject["image"].data
                return image, self.imgs[idx]

        return image, mask, self.imgs[idx]
    
    def get_paths(self):
        return self.img_pth, self.seg_pth
    
    def get_subj_info(self):
        return self.subj_dir_pths, self.subj_dirs
        #, self.SSA
    
    def get_transforms(self):
        return self.transform


In [18]:
def define_transforms(n_channels):
    # Initialise data transforms
    data_transforms = {
        'train': tio.Compose([
            tio.CropOrPad((192, 192, 124)),
            tio.OneOf([
                tio.Compose([
                    tio.RandomFlip(axes=0, p=0.3),
                    tio.RandomFlip(axes=1, p=0.3),
                    tio.RandomFlip(axes=2, p=0.3)]),
                tio.RandomAffine(degrees=15,p=0.3)
            ], p=0.8),
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ]),
        'fakeSSA': tio.OneOf({
            tio.OneOf({
                tio.Compose([
                    tio.Resample((1.2, 1.2, 6), scalars_only=True),
                    tio.Resample(1)
                ]):0.50,
                tio.Compose([
                    tio.RandomAnisotropy(axes=(1, 2), downsampling=(1.2), scalars_only=True),
                    tio.RandomAnisotropy(axes=0, downsampling=(6), scalars_only=True)
                ]):0.5,                
            },p=0.80),
            tio.Compose([            
                tio.OneOf({
                    tio.RandomBlur(std=(0.5, 1.5)) : 0.3,
                    tio.RandomNoise(mean=3, std=(0, 0.33)) : 0.7
                },p=0.50),
                tio.OneOf({
                    tio.RandomMotion(num_transforms=3, image_interpolation='nearest') : 0.5,
                    tio.RandomBiasField(coefficients=1) : 0.2,
                    tio.RandomGhosting(intensity=1.5) : 0.3
                }, p=0.50)
            ])
        }, p=0.8), # randomly apply ONE of these given transforms with prob 0.5 
        'val': tio.Compose([
            tio.CropOrPad((192, 192, 124)),
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ]),
        'test' : tio.Compose([
            tio.EnsureShapeMultiple(2**n_channels, method='pad')
        ])
    }

    return data_transforms

In [25]:
def load_data(args, data_transforms):

    '''
    This function is called during training after define_transforms(n_channels)

    It takes as input
        args: argparsers from the utils script 
            args.seed
            args.data_used: 'all', 'GLI', 'SSA'
        data_transforms: a dictionary of transformations to apply to the data during training

    Returns dataloaders ready to be fed into model
    '''
    logger = logging.getLogger(__name__)

    # Set a seed for reproducibility if you want the same split - optional
    if args.seed != None:
        seed=args.seed
        logger.info(f"Seed set to {seed}.")
    else:
        seed=None
        logger.info("No seed has been set")
    
    # Locate data based on which dataset is being used
    if args.data_used == 'all':
        data_folders = glob.glob(os.path.join(args.data, "BraTS*"))
    elif args.data_used == "GLI":
        data_folders = [folder for folder in os.listdir(args.data) if 'GLI' in folder]
    elif args.data_used == 'SSA':
        data_folders = [folder for folder in os.listdir(args.data) if 'SSA' in folder]

    # Split data files
    train_files, val_files = split_data(data_folders, seed) 
    logger.info(f"Number of training files: {len(train_files)}\nNumber of validation files: {len(val_files)}")
    
    image_datasets = {
        'train': MRIDataset(args.data, train_files, transform=data_transforms['train']),
        'val': MRIDataset(args.data, val_files, transform=data_transforms['val']),
    }

    # Create dataloaders
    # can set num_workers for running sub-processes
    dataloaders = {
        'train': data_utils.DataLoader(image_datasets['train'], batch_size=args.batch_size, shuffle=True, drop_last=True),
        'val': data_utils.DataLoader(image_datasets['val'], batch_size=args.val_batch_size, shuffle=True)
        # 'test': data_utils.DataLoader(image_datasets['test'], batch_size=args.val_batch_size, shuffle=True)
    }

    # Save data split
    splitData = {
        'subjsTr' : train_files,
        'subjsVal' : val_files,
        # 'subjsTest' : test_files    
    }
    with open(args.data + str(args.data_used) + ".json", "w") as file:
        json.dump(splitData, file)

    return dataloaders

def split_data(data_folders, seed):
    '''
    Function to split dataset into train/val/test splits, given all avilable data.
    Input:
        list of paths to numpy files
    Returns:
        lists for each train and val/test sets, where each list contains the file names to be used in the set
    '''
    #-----------------------------
    # originally we split as 3: train-test-val train (70), val (15), test (15):
        # train_files, test_files = train_test_split(data_folders, test_size=0.7, random_state=seed)
        # val_files, test_files = train_test_split(test_files, test_size=0.5, random_state=seed)

    #-----------------------------
    # training loop split is train-val (70-30)
    train_files, val_files = train_test_split(data_folders, test_size=0.7, random_state=seed)

    # ??? validation/testing???

    return train_files, val_files

## Helper Functions

In [26]:
"""Setup transforms, dataset"""
def define_dataloaders(n_channels):
    # Define transforms
    data_transform = define_transforms(n_channels)
    # Load data
    dataloaders = load_data(args, data_transform)                      # this also saves a json splitData
    # train_loader, val_loader = dataloaders['train'], dataloaders['val']
    return dataloaders

In [23]:
"""Define model architecture:
        Done before data loader so that transforms has n_channels for EnsureShapeMultiple
"""
def define_model(checkpoint=None):
    model=UNet(
        spatial_dims=3,
        in_channels=4,
        out_channels=4,
        channels=(16, 32, 64, 128, 256),
        # channels=(32, 64, 128, 256, 320, 320), #nnunet channels, depth 6
        # channels=(64, 96, 128, 192, 256, 384, 512), # optinet, depth 7
        strides=(2, 2, 2, 2), # length should = len(channels) - 1
        # kernel_size=,
        # num_res_units=,
        # dropout=0.0,
        )
    n_channels = len(model.channels)
    print(f"Number of channels: {n_channels}")

    if checkpoint != None:
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')))

    return model, n_channels

"""
Setup validation stuff
    metrics
    post trans ???????
    define inference
"""
def val_params():
    VAL_AMP = True
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=True, num_classes=4)
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=True, num_classes=4)
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    return VAL_AMP, dice_metric, dice_metric_batch, post_trans

# define inference method
def inference(VAL_AMP, model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=None, ## best to leave as None, computes itself (else NEEDS TO BE BASED ON LAYER IN, OUT E.G., Conv3d(4, 16; Conv3d(16, 32; etc)
            # roi_size=(128, 128, 64),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
            mode='gaussian'
        )
    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

## Define model

In [24]:
"""General Setup: 
    logging,
    utils.args 
    seed,
    cuda, 
    root dir"""
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# args = get_main_args()
seed = 42
set_determinism(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# root_dir = args.data
# results_dir = args.results

model, n_channels = define_model()
model.to(device)

Number of channels: 5


UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): Convolution(
          (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (1): SkipConnection(
          (submodule): Sequential(
            (0): Convolution(
              (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (adn): ADN(
                (N): Inst

DONT RUN THIS FOR TESTING INFERENCE (TRAIN STUFF)

In [11]:
"""
Potentially useful functions for model tracking and checkpoint loading
"""
# def save_checkpoint(model, epoch, best_acc=0, dir_add=results_dir, args=args):
#     filename=f"chkpt_{args.run_name}_{epoch}_{best_acc}.pt"
#     state_dict = model.state_dict()
#     save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
#     filename = os.path.join(dir_add, filename)
#     torch.save(save_dict, filename)
#     print("\nSaving checkpoint", filename)

"""Setup transforms, dataset"""
def define_dataloaders(n_channels):
    # Define transforms
    data_transform = define_transforms(n_channels)
    # Load data
    dataloaders = load_data(args, data_transform)                      # this also saves a json splitData
    # train_loader, val_loader = dataloaders['train'], dataloaders['val']
    return dataloaders

"""Create Model Params:
    optimiser
    loss fn
    lr
"""
def model_params(args, model):
    # Define optimiser
    if args.optimiser == "adam":
        optimiser = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate)
        print("Adam optimizer set")
    elif args.optimiser == "sgd":
        optimiser = torch.optim.SGD(params=model.parameters())
        print("SGD optimizer set")
    elif args.optimiser == "novo":
        optimiser = monai.optimizers.Novograd(params=model.parameters(), lr=args.learning_rate)
    else:
        print("Error, no optimiser provided")

    # Define loss function
    if args.criterion == "ce":
        criterion = nn.CrossEntropyLoss()
        print("Cross Entropy Loss set")
    elif args.criterion == "dice":
        criterion = DiceFocalLoss(squared_pred=True, to_onehot_y=False, sigmoid=True)
        print("Focal-Dice Loss set")
    else:
        print("Error, no loss fn provided")

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=args.epochs)
    
    return optimiser, criterion, lr_scheduler

"""
Define training loop
    initialise empty lists for val
    Add GradScalar which uses automatic mixed precision to accelerate training
    forward and backward passes
    validate training epoch

"""
def train(args, model, device, train_loader, val_loader, optimiser, criterion, lr_scheduler):

    VAL_AMP, dice_metric, dice_metric_batch, post_trans = val_params()

    # Train model --> see MONAI notebook examples
    val_interval = 1
    epoch_loss_list = []
    val_epoch_loss_list = []

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]

    metric_values = []
    metric_values_0 = []
    metric_values_1 = []
    metric_values_2 = []
    metric_values_3 = []

    scaler = GradScaler()

    total_start = time.time()

    for epoch in range(args.epochs):
        epoch_start = time.time()
        # print("-" * 10)
        # print(f"epoch {epoch + 1}/{args.epochs}")
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True)
        progress_bar.set_description(f"Training Epoch {epoch}")

        # for step, batch in progress_bar:
        for step, batch_data in progress_bar:
            step_start = time.time()
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimiser.zero_grad()
        
            with autocast(): # cast tensor to smaller memory footprint to avoid OOM
                """ FOR USE WITH A DIFFUSION MODEL ONLY
                # Generate random noise
                noise = torch.randn_like(images).to(device)
                Create timesteps
                timesteps = torch.randint(
                    0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                ).long()
                # Get model prediction
                noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
                loss = F.mse_loss(noise_pred.float(), noise.float())
                """

                # print(inputs.shape)
                outputs = model(inputs)
                # print(outputs.shape)
                loss = criterion.forward(outputs, labels)
            
            # Calculate Loss and Update optimiser using scalar
            scaler.scale(loss).backward()
            scaler.step(optimiser)
            scaler.update()
            epoch_loss += loss.item()
            progress_bar.set_postfix({"bat_train_loss" : loss.item(), "Ave_train_loss" : epoch_loss/(step + 1)})
            
            print(
                f"\n{step}/{len(train_loader.dataset)//train_loader.batch_size}"
                f",     Batch train_loss: {loss.item():.4f}"
                f",     Step time: {(time.time() - step_start):.4f}"
            )
            epoch_loss2 = epoch_loss/(step+1)
            lr_scheduler.step()
        epoch_loss_list.append(epoch_loss2)
        print(f"\nEpoch {epoch} average loss: {epoch_loss2:.4f}")
        
        if (epoch + 1) % val_interval == 0:
            model.eval()
            val_epoch_loss = 0
            progress_bar = tqdm(enumerate(val_loader), total=len(val_loader),dynamic_ncols=True)
            progress_bar.set_description(f"Val_train Epoch {epoch}")

            for step, batch in enumerate(val_loader):
                val_inputs, val_labels = batch[0].to(device), batch[1].to(device)
                """ FOR USE WITH A DIFFUSION MODEL ONLY
                timesteps = torch.randint(
                    0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                ).long()

                # Get model prediction
                noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
                val_loss = F.mse_loss(noise_pred.float(), noise.float())
                """
                with torch.no_grad():
                    val_outputs = inference(VAL_AMP, model, val_inputs)
                    val_loss = criterion.forward(val_outputs, val_labels)
                    
                    val_labels_list = decollate_batch(val_labels)
                    val_outputs_convert = [post_trans(i) for i in decollate_batch(val_outputs)]
                    dice_metric(y_pred=val_outputs_convert, y=val_labels_list)
                    dice_metric_batch(y_pred=val_outputs_convert, y=val_labels_list)

                val_epoch_loss += val_loss.item()
                progress_bar.set_postfix({"Val_loss": val_epoch_loss / (step + 1)})
            val_epoch_loss_list.append(val_epoch_loss / (step + 1))
            
            metric = dice_metric.aggregate()[0].item()
            metric_values.append(metric)
            metric_batch = dice_metric_batch.aggregate()
            # print(metric)
            # print(metric_batch)

            metric_0 = metric_batch[0][0].item()
            metric_values_0.append(metric_0)

            metric_1 = metric_batch[0][1].item()
            metric_values_1.append(metric_1)

            metric_2 = metric_batch[0][2].item()
            metric_values_2.append(metric_2)

            metric_3 = metric_batch[0][3].item()
            metric_values_3.append(metric_3)

            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                save_checkpoint(
                        model,
                        epoch,
                        best_acc=best_metric,
                    )
                torch.save(
                    model.state_dict(),
                    os.path.join(args.result, f"best_metric_model_{args.run_name}.pth"),
                )
                print("\nsaved new best metric model")
            print(
                f"\ncurrent epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nMean Dice per Region is: label 1: {metric_1:.4f};  label 2: {metric_2:.4f} label 3: {metric_3:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
    total_time = time.time() - total_start

## Data

In [None]:
dataloaders = define_dataloaders(n_channels)
train_loader, val_loader = dataloaders['train'], dataloaders['val']

## Train

In [None]:
optimiser, criterion, lr_scheduler = model_params(args, model)

# TRAIN MODEL
train(args, model, device, train_loader, val_loader, optimiser, criterion, lr_scheduler)

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y, color="red")
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y, color="green")
plt.show()

plt.figure("train", (18, 6))
plt.subplot(1, 3, 1)
plt.title("Val Mean Dice TC")
x = [val_interval * (i + 1) for i in range(len(metric_values_tc))]
y = metric_values_tc
plt.xlabel("epoch")
plt.plot(x, y, color="blue")
plt.subplot(1, 3, 2)
plt.title("Val Mean Dice WT")
x = [val_interval * (i + 1) for i in range(len(metric_values_wt))]
y = metric_values_wt
plt.xlabel("epoch")
plt.plot(x, y, color="brown")
plt.subplot(1, 3, 3)
plt.title("Val Mean Dice ET")
x = [val_interval * (i + 1) for i in range(len(metric_values_et))]
y = metric_values_et
plt.xlabel("epoch")
plt.plot(x, y, color="purple")
plt.show()

# RUN FROM HERE

## Inference

File paths for Alex's local system

In [20]:
validation_dir= '/Users/alexandrasmith/Desktop/Workspace/Projects/UNN_BraTS23/data/val_SSA'
validation_files = [os.path.join(validation_dir, file) for file in os.listdir(validation_dir)]

# checkpoint to test: /scratch/guest187/Data/train_all/results/test_run/best_metric_model.pth 
checkpoint = '/Users/alexandrasmith/Desktop/Workspace/Projects/UNN_BraTS23/data/best_metric_model_fullTest.pth'

# Load validation data to dataloader
data_transforms = define_transforms(n_channels)
validation_dataset = MRIDataset(validation_dir, validation_files, transform=data_transforms['val'])

# Validation parameters
VAL_AMP, dice_metric, dice_metric_batch, post_transforms = val_params()

['/scratch/guest187/Data/val_SSA/BraTS-SSA-00139-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00227-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00126-000', '/scratch/guest187/Data/val_SSA/dataset.json', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00192-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00143-000', '/scratch/guest187/Data/val_SSA/images', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00169-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00129-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00180-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00158-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00132-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00198-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00210-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00188-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00218-000', '/scratch/guest187/Data/val_SSA/BraTS-SSA-00148-000']


NotADirectoryError: [Errno 20] Not a directory: '/scratch/guest187/Data/val_SSA/dataset.json'

## Check best model output with the input image and label

In [71]:
print(len(validation_dataset))
img, image_path = validation_dataset[0]
print(img.shape)

15
tranformed_subject:  Subject(Keys: ('image',); images: 1)
torch.Size([4, 192, 192, 128])


In [27]:
model, n_channels = define_model(checkpoint)
model.to(device)
model.eval()

with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_input, _ = validation_dataset[0]
    print(val_input.shape)
    # roi_size = (128, 128, 64)
    # sw_batch_size = 4
    val_output = inference(VAL_AMP, model, val_input)
    val_output = post_transforms(val_output[0])
    plt.figure("image", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(validation_dataset[6]["image"][i, :, :, 70].detach().cpu(), cmap="gray")
    plt.show()
    # visualize the 3 channels label corresponding to this image
    plt.figure("label", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"label channel {i}")
        plt.imshow(validation_dataset[6]["label"][i, :, :, 70].detach().cpu())
    plt.show()
    # visualize the 3 channels model output corresponding to this image
    plt.figure("output", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"output channel {i}")
        plt.imshow(val_output[i, :, :, 70].detach().cpu())
    plt.show()

Number of channels: 5


RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "model.0.conv.weight", "model.0.conv.bias", "model.0.adn.A.weight", "model.1.submodule.0.conv.weight", "model.1.submodule.0.conv.bias", "model.1.submodule.0.adn.A.weight", "model.1.submodule.1.submodule.0.conv.weight", "model.1.submodule.1.submodule.0.conv.bias", "model.1.submodule.1.submodule.0.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.weight", "model.1.submodule.1.submodule.1.submodule.0.conv.bias", "model.1.submodule.1.submodule.1.submodule.0.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.weight", "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.bias", "model.1.submodule.1.submodule.1.submodule.1.submodule.adn.A.weight", "model.1.submodule.1.submodule.1.submodule.2.conv.weight", "model.1.submodule.1.submodule.1.submodule.2.conv.bias", "model.1.submodule.1.submodule.1.submodule.2.adn.A.weight", "model.1.submodule.1.submodule.2.conv.weight", "model.1.submodule.1.submodule.2.conv.bias", "model.1.submodule.1.submodule.2.adn.A.weight", "model.1.submodule.2.conv.weight", "model.1.submodule.2.conv.bias", "model.1.submodule.2.adn.A.weight", "model.2.conv.weight", "model.2.conv.bias". 
	Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "loops", "callbacks", "optimizer_states", "lr_schedulers", "NativeMixedPrecisionPlugin", "hparams_name", "hyper_parameters". 