In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torchio as tio
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
import os
import shutil
import csv

from segmenter import Segmenter

In [3]:
# --------------------- CMD arguments
# TODO Use the latest checkpoint
lightning_checkpoint = "C:/Users/denni/Documents/MedicalDecathlon/Logs/lightning_logs/version_1/checkpoints/epoch=4-step=2570.ckpt"
device = "cuda"
# This is some path for the preprocessed data
input_path = "C:/Users/denni/Documents/fallstudie-ss2024/data/image-repository/1-test2-uni/16-Mr_Hirn-9346A15227666_01/preprocessed/301_314_313_319/"
# This is some path for the outputs.
# TODO Find out which file extension to use
output_path = "C:/Users/denni/Documents/MedicalDecathlon/Outputs/version_1/"
output_channel_mapping = {"background": 0, "edema": 1, "non_enhancing_and_necrosis": 2, "enhancing_tumor": 3}
# --------------------- CMD arguments

In [4]:
tio.ScalarImage("C:/Users/denni/Documents/fallstudie-ss2024/data/image-repository/1-test2-uni/16-Mr_Hirn-9346A15227666_01/preprocessed/301_314_313_319/nifti_t1_norm_register.nii.gz").data.shape

torch.Size([1, 256, 256, 180])

In [43]:
# Iterate all saved checkpoints, i.e., the model with the metadata and find out which one performs best on the
# validation data. That best model will be used to perform more in-depth analysis on the loss and to visualize
# the segmented tumors of the model.
loss_per_model = []
all_losses = []

# Switch the model to eval mode
# Get the model from the Pytorch Lightning checkpoint
# TODO Remove the arguments learning_rate_decay and dropout_probability
model = Segmenter.load_from_checkpoint(lightning_checkpoint, learning_rate_decay=1, dropout_probability=0)
# Switch to eval mode
model.eval()
# Switch to the given device
model.to(device)

# Some minor preprocessing of the images before using our model on it to ensure that they are the same size
# as the training data and have the same pixel range.
process = tio.Compose([
    tio.CropOrPad((240, 240, 155)),
    tio.RescaleIntensity((-1, 1))
])

# Make a tio.Subject out of the images. The other tio utilities are for extracting patches
# from the image and then reassembling them to create a valid image. Our model only takes
# 96 by 96 by 96 input images, so we have to use this sampling strategy. There is an overlap
# of 8 x 8 x 8 pixels. Since we get the images by sequence, we separately apply the process
# function per sequence and then concatenate the rescaled tensors into one multi-channel tensor.
paths = []
# This is the order of sequences as written in dataset.json.
for index, seq in enumerate(["flair", "t1_norm", "t1c", "t2"]):
    paths.append(os.path.join(input_path, f'nifti_{seq}_register.nii.gz'))

tensors = [tio.ScalarImage(path).data for path in paths]
full_tensor = torch.cat(tensors)
raw_subject = tio.Subject({"MRI": tio.ScalarImage(tensor=full_tensor)})
subject = tio.SubjectsDataset([raw_subject], transform=process)[0]
print("Subject dimension:", subject["MRI"].shape)

grid_sampler = tio.inference.GridSampler(subject, 96, (8, 8, 8))
aggregator = tio.inference.GridAggregator(grid_sampler)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)

# This is the actual prediction of the segmentation
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch["MRI"]["data"].to(device)
        print("Input tensor dimension:", input_tensor.shape)
        locations = patches_batch[tio.LOCATION]
        pred = model(input_tensor)
        print("Output dimension:", pred.shape)
        print(pred.min(), pred.max(), end="\n\n")
        # We keep adding batches to the aggregator to later collect all the data.
        aggregator.add_batch(pred, locations)

# The prediction is composed of the patches we have generated before
pred = torch.swapaxes(F.one_hot(aggregator.get_output_tensor().argmax(0)).unsqueeze(dim=0), 0, 4).squeeze()

print("Prediction shape:", pred.shape)
print(torch.bincount(pred.flatten()))

# If the first dimension is not a channel dimension, add that dimension. We assume that everything larger than 6
# is not a channel dimension.
if pred.shape[0] > 6:
    pred = pred.unsqueeze(dim=0)
    
# Extend the number of channels of the prediction to 4 if necessary.
for it in range(pred.shape[0], label.shape[0]):
    pred = torch.cat((pred, torch.zeros_like(pred[0]).unsqueeze(0)), 0)


C:\Users\denni\AppData\Roaming\Python\Python39\site-packages\pytorch_lightning\utilities\parsing.py:199: Attribute 'activation_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['activation_fn'])`.


Subject dimension: (4, 240, 240, 155)
Input tensor dimension: torch.Size([4, 4, 96, 96, 96])
Output dimension: torch.Size([4, 4, 96, 96, 96])
tensor(-5.2684, device='cuda:0') tensor(11.4493, device='cuda:0')

Input tensor dimension: torch.Size([4, 4, 96, 96, 96])
Output dimension: torch.Size([4, 4, 96, 96, 96])
tensor(-5.2685, device='cuda:0') tensor(11.4629, device='cuda:0')

Input tensor dimension: torch.Size([4, 4, 96, 96, 96])
Output dimension: torch.Size([4, 4, 96, 96, 96])
tensor(-5.2685, device='cuda:0') tensor(11.4510, device='cuda:0')

Input tensor dimension: torch.Size([4, 4, 96, 96, 96])
Output dimension: torch.Size([4, 4, 96, 96, 96])
tensor(-5.2685, device='cuda:0') tensor(11.4493, device='cuda:0')

Input tensor dimension: torch.Size([2, 4, 96, 96, 96])
Output dimension: torch.Size([2, 4, 96, 96, 96])
tensor(-5.2685, device='cuda:0') tensor(11.4493, device='cuda:0')

Prediction shape: torch.Size([3, 240, 240, 155])
tensor([17856000,  8928000])


NameError: name 'label' is not defined