## Init setup

In [2]:
from ast import mod
import pathlib
import numpy as np
from omegaconf import DictConfig, OmegaConf
from typing import Dict, Tuple

import torch
import matplotlib.pyplot as plt

import data
from data.datasets import MedicalDecathlonDataset, BrainTumourDataset
from utils.assertions import ensure

from pathlib import Path


def model_params(model_dir_str):
	model_dir = Path(model_dir_str)
	if not model_dir.is_dir():
		raise FileNotFoundError(f"Model directory not found: {model_dir}")
	
	model_path = f"{model_dir}/best_model.pth"
	cfg = OmegaConf.load(f"{model_dir}/config.yaml")

	if not isinstance(cfg, DictConfig):
		raise TypeError("cfg must be a DictConfig.")

	return model_path, cfg


## Datasimulation 
### DataSimulation -- Backbone

In [66]:
from time import time
import nibabel as nib
from torch import tensor
from utils.metrics import dice_coefficient, dice_coefficient_classes
from utils.utils import setup_seed
from hydra.utils import instantiate
import os
from data.datasets import MedicalDecathlonDataset
import torch.nn.functional as F
import torchio as tio

model_dir_str = "trained_models/unet3d/Task04_Hippocampus/2025-05-07_08-51-44_best"

model_path, cfg = model_params(model_dir_str)
scales = [0, 1, 2, 3, 4]

base_dir = Path(cfg.dataset.base_path)

datasets = {}

for scale in scales:
	img_dir = base_dir / "imagesTs" / f"scale{scale}"
	lbl_dir = base_dir / "labelsTs" / f"scale{scale}"
	img_dir_sort = sorted(os.listdir(img_dir))
	lbl_dir_sort = sorted(os.listdir(lbl_dir))
	dataset = MedicalDecathlonDataset(cfg, "test", img_dir_sort, lbl_dir_sort, str(img_dir), str(lbl_dir))
	datasets[scale] = dataset

setup_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = instantiate(cfg.architecture.path, cfg)

checkpoint = torch.load(model_path, map_location=device)
model_state_dict = checkpoint["model_state_dict"]
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()

results_per_scale = {scale: [] for scale in scales}


print("******************* TEST 2 *******************")
# Case 1) padding 
for scale in scales:
	with torch.no_grad():
		print(f"{'-'*20} Scale {scale} {'-'*20}")
		for image, label in datasets[scale]:
			target_shape = cfg.dataset.target_shape
			output_shape = tuple(v // (2**scale)  for v in target_shape)

			# pad to target shape
			tio_img = tio.ScalarImage(tensor=image)
			cp = tio.CropOrPad(target_shape, padding_mode="constant")
			tio_img = cp(tio_img)
			tensor_img = tio_img.data.to(device)

			x = tensor_img.unsqueeze(0).to(device)

			# print(f"\n\n[Scale {scale} | Image Shape {image.shape} | Image Dtype: {image.dtype}") # ADD THIS
			# print(f"[Scale {scale} | Label Shape nig {label.shape}")
			output = model.forward(x)

			pred = torch.argmax(output, dim=1)

			# crop to output shape
			tio_img = tio.ScalarImage(tensor=pred)
			cp = tio.CropOrPad(output_shape, padding_mode="constant")
			pred = cp(tio_img).data.to(device)

			d = dice_coefficient(
				pred, label.to(device), num_classes=cfg.dataset.num_classes, ignore_index=0
			)
			results_per_scale[scale] += [d.item()]
	mean_dice = np.mean(results_per_scale[scale])
	std_dice = np.std(results_per_scale[scale])
	print(f"Scale {scale} | Mean Dice: {mean_dice:.4f}, Std Dice: {std_dice:.4f}")


print("Results per scale:")
for scale, results in results_per_scale.items():
    print(f"Scale {scale}: {results}")
    mean_dice = np.mean(results)
    std_dice = np.std(results)
    print(f"Mean Dice: {mean_dice:.4f}, Std Dice: {std_dice:.4f}")

******************* TEST 2 *******************
-------------------- Scale 0 --------------------
Scale 0 | Mean Dice: 0.9005, Std Dice: 0.0255
-------------------- Scale 1 --------------------
Scale 1 | Mean Dice: 0.1433, Std Dice: 0.0626
-------------------- Scale 2 --------------------
Scale 2 | Mean Dice: 0.0000, Std Dice: 0.0000
-------------------- Scale 3 --------------------
Scale 3 | Mean Dice: 0.0000, Std Dice: 0.0000
-------------------- Scale 4 --------------------
Scale 4 | Mean Dice: 0.8846, Std Dice: 0.2324
Results per scale:
Scale 0: [0.8814674615859985, 0.890953779220581, 0.9086867570877075, 0.922738790512085, 0.8895153403282166, 0.9227004051208496, 0.9229305982589722, 0.9363595247268677, 0.9318417310714722, 0.9297301769256592, 0.9086819887161255, 0.9163600206375122, 0.9016448855400085, 0.9296976327896118, 0.8816892504692078, 0.9333398342132568, 0.9232256412506104, 0.8857079148292542, 0.8980642557144165, 0.9187278747558594, 0.911818265914917, 0.8596487045288086, 0.89219

In [65]:
from time import time
import nibabel as nib
from scipy import interpolate
from torch import tensor
from utils.metrics import dice_coefficient, dice_coefficient_classes
from utils.utils import setup_seed
from hydra.utils import instantiate
import os
from data.datasets import MedicalDecathlonDataset
import torch.nn.functional as F
import torchio as tio

model_dir_str = "trained_models/unet3d/Task04_Hippocampus/2025-05-07_08-51-44_best"

model_path, cfg = model_params(model_dir_str)
scales = [0, 1, 2, 3, 4]

base_dir = Path(cfg.dataset.base_path)

datasets = {}

for scale in scales:
	img_dir = base_dir / "imagesTs" / f"scale{scale}"
	lbl_dir = base_dir / "labelsTs" / f"scale{scale}"
	img_dir_sort = sorted(os.listdir(img_dir))
	lbl_dir_sort = sorted(os.listdir(lbl_dir))
	dataset = MedicalDecathlonDataset(cfg, "test", img_dir_sort, lbl_dir_sort, str(img_dir), str(lbl_dir))
	datasets[scale] = dataset

setup_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = instantiate(cfg.architecture.path, cfg)

checkpoint = torch.load(model_path, map_location=device)
model_state_dict = checkpoint["model_state_dict"]
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()

results_per_scale = {scale: [] for scale in scales}


print("******************* TEST 2 *******************")
# Case 2) interpolation 
for scale in scales:
	with torch.no_grad():
		print(f"{'-'*20} Scale {scale} {'-'*20}")
		for image, label in datasets[scale]:
			target_shape = tuple(cfg.dataset.target_shape)
			output_shape = tuple(v // (2**scale)  for v in target_shape)
			image = image.unsqueeze(0).float().to(device)
			# interpolate to target shape
			x = interpolatex = F.interpolate(
				image, size=target_shape, mode="trilinear", align_corners=False
			)

			output = model.forward(x)

			pred = torch.argmax(output, dim=1)

			# interpolate to output shape
			pred = F.interpolate(
                pred.unsqueeze(0).float(), size=output_shape, mode="trilinear", align_corners=False
            ).squeeze(0).squeeze(0).to(device)

			d = dice_coefficient(
				pred, label.to(device), num_classes=cfg.dataset.num_classes, ignore_index=0
			)
			results_per_scale[scale] += [d.item()]
	mean_dice = np.mean(results_per_scale[scale])
	std_dice = np.std(results_per_scale[scale])
	print(f"Scale {scale} | Mean Dice: {mean_dice:.4f}, Std Dice: {std_dice:.4f}")


print("Results per scale:")
for scale, results in results_per_scale.items():
    print(f"Scale {scale}: {results}")
    mean_dice = np.mean(results)
    std_dice = np.std(results)
    print(f"Mean Dice: {mean_dice:.4f}, Std Dice: {std_dice:.4f}")

******************* TEST 2 *******************
-------------------- Scale 0 --------------------
Scale 0 | Mean Dice: 0.9005, Std Dice: 0.0255
-------------------- Scale 1 --------------------
Scale 1 | Mean Dice: 0.7399, Std Dice: 0.0368
-------------------- Scale 2 --------------------
Scale 2 | Mean Dice: 0.4701, Std Dice: 0.0968
-------------------- Scale 3 --------------------
Scale 3 | Mean Dice: 0.0567, Std Dice: 0.0866
-------------------- Scale 4 --------------------
Scale 4 | Mean Dice: 0.8654, Std Dice: 0.2425
Results per scale:
Scale 0: [0.8814674615859985, 0.890953779220581, 0.9086867570877075, 0.922738790512085, 0.8895153403282166, 0.9227004051208496, 0.9229305982589722, 0.9363595247268677, 0.9318417310714722, 0.9297301769256592, 0.9086819887161255, 0.9163600206375122, 0.9016448855400085, 0.9296976327896118, 0.8816892504692078, 0.9333398342132568, 0.9232256412506104, 0.8857079148292542, 0.8980642557144165, 0.9187278747558594, 0.911818265914917, 0.8596487045288086, 0.89219