## Init setup

In [4]:
from ast import mod
import pathlib
import numpy as np
from omegaconf import DictConfig, OmegaConf
from typing import Dict, Tuple

import torch
import matplotlib.pyplot as plt

import data
from data.datasets import MedicalDecathlonDataset, BrainTumourDataset
from utils.assertions import ensure

from pathlib import Path


def model_params(model_dir_str):
	model_dir = Path(model_dir_str)
	if not model_dir.is_dir():
		
		raise FileNotFoundError(f"Model directory not found: {model_dir}")
	
	model_path = f"{model_dir}/best_model.pth"
	cfg = OmegaConf.load(f"{model_dir}/config.yaml")

	try:
		path_parts = model_dir.parts
		task_name = path_parts[-2]
	except IndexError:
		raise ValueError(f"Could not parse architecture/task/name from model_dir: {model_dir}. Expected structure like 'trained_models/arch/task/name'")
	
	if not isinstance(cfg, DictConfig):
		raise TypeError("cfg must be a DictConfig.")
	
	
	return model_path, cfg

## Datasimulation

In [5]:
import nibabel as nib
from utils.metrics import dice_coefficient, dice_coefficient_classes
from utils.utils import setup_seed
from hydra.utils import instantiate
import os
from data.datasets import MedicalDecathlonDataset

model_dir_str = "trained_models/ms-unet3d/Task04_Hippocampus/2025-05-05_12-17-39"

model_path, cfg = model_params(model_dir_str)
scales = [0, 1, 2, 3, 4]

base_dir = Path(cfg.dataset.base_path)

datasets = {}

for scale in scales:
	img_dir = base_dir / "imagesTs" / f"scale{scale}"
	lbl_dir = base_dir / "labelsTs" / f"scale{scale}"
	img_dir_sort = sorted(os.listdir(img_dir))
	lbl_dir_sort = sorted(os.listdir(lbl_dir))
	dataset = MedicalDecathlonDataset(cfg, "test", img_dir_sort, lbl_dir_sort)
	datasets[scale] = dataset


setup_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = instantiate(cfg.architecture.path, cfg)

checkpoint = torch.load(model_path, map_location=device)
model_state_dict = checkpoint["model_state_dict"]
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()

results_per_scale = {scale: [] for scale in scales}
for scale in scales:
	print(f"{datasets[scale]}")

	with torch.no_grad():
		for image, label in datasets[scale]:
			print(label.shape)
			image_tensor = image.unsqueeze(0).to(device)
			
			output = model(image_tensor)
			
			label = label.squeeze(0).cpu()

			pred = torch.argmax(output, dim=1)
			d = dice_coefficient(
				pred, label, num_classes=cfg.dataset.num_classes, ignore_index=0
			)
			results_per_scale[scale] += [d]


print("Results per scale:")
for scale, results in results_per_scale.items():
    print(f"Scale {scale}: {results}")
    mean_dice = np.mean(results)
    std_dice = np.std(results)
    print(f"Mean Dice: {mean_dice:.4f}, Std Dice: {std_dice:.4f}")

<data.datasets.MedicalDecathlonDataset object at 0x7e9279949840>
torch.Size([32, 64, 32])


TypeError: argmax(): argument 'input' (position 1) must be Tensor, not tuple

## Show the segmentation

In [2]:
from ast import mod
import pathlib
import numpy as np
from omegaconf import DictConfig, OmegaConf
from typing import Dict, Tuple

import torch
import matplotlib.pyplot as plt

import data
from data.datasets import MedicalDecathlonDataset, BrainTumourDataset
from models.unet3d import UNet3D

from models.factory import create_model
from utils.assertions import ensure


# ------ Change only this for test of a trained model ------
model_dir_str = "trained_models/ms-unet3d/Task04_Hippocampus/2025-04-24_14-52-40"
# ----------------------------------------------------------


model_dir = pathlib.Path(model_dir_str)
if not model_dir.is_dir():
    raise FileNotFoundError(f"Model directory not found: {model_dir}")

model_path = f"{model_dir}/best_model.pth"
cfg = OmegaConf.load(f"{model_dir}/config.yaml")


try:
    path_parts = model_dir.parts
    # model_architechture = path_parts[-3]
    task_name = path_parts[-2]
    inference_model_name = path_parts[-1]
except IndexError:
    raise ValueError(f"Could not parse architecture/task/name from model_dir: {model_dir}. Expected structure like 'trained_models/arch/task/name'")

if not isinstance(cfg, DictConfig):
    raise TypeError("cfg must be a DictConfig.")

if task_name == "Task04_Hippocampus":
    dataset = MedicalDecathlonDataset(cfg, "test")
elif task_name == "Task01_BrainTumour":
    dataset = BrainTumourDataset(cfg, "test") 
else:
    raise ValueError(f"Unknown task name: {task_name}. Expected 'Task04_Hippocampus' or 'Task01_BrainTumour'.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = cfg.active_architecture

model = create_model(cfg, model_name).to(device)

checkpoint = torch.load(model_path, map_location=device)
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)
model.to(device)  # Move the model to the appropriate device
model.eval()


#samples idx list to check: (6,)
sample_idx = 0

# ... after dataset creation ...
print(f"Dataset size (test phase): {len(dataset)}")
ensure(len(dataset) > 0, Exception, "Dataset is empty!")

sample_idx = min(len(dataset)-1, 392) 
image, gt = dataset[sample_idx]
# ... rest of the code ...

image, gt = dataset[sample_idx]  # image: (C, D, H, W), gt: (D, H, W)

# Add batch dimension and move to device
image_batch = image.unsqueeze(0).to(device)  # shape: (1, C, D, H, W)
print(f"Image shape: {image_batch.shape}")
# Run inference
with torch.no_grad():
    output = model(image_batch)
    # If model returns deep supervision outputs, take the final prediction
    if isinstance(output, (tuple, list)):
        output = output[0]
    # Get predicted labels: (B, D, H, W)
    pred = torch.argmax(output, dim=1).squeeze(0).cpu()  

# Convert tensors to numpy arrays for visualization
# Remove channel dimension from image for visualization: (D, H, W)
image_np = image.squeeze(0).cpu().numpy()
gt_np = gt.cpu().numpy()
pred_np = pred.numpy()

# Choose 3 slices evenly spaced along the depth dimension
depth = image_np.shape[0]
num_slices = min(depth, 66)
slice_indices = np.linspace(0, depth-1, num=num_slices, dtype=int)


# Create subplots: one row per slice and 3 columns for image, ground truth, and prediction
fig, axes = plt.subplots(nrows=num_slices, ncols=3, figsize=(12, 4 * num_slices))
for i, slice_idx in enumerate(slice_indices):
    slice_2d = image_np[slice_idx]
    # slice_2d = np.rot90(slice_2d)
    axes[i, 0].imshow(slice_2d, cmap="gray")
    axes[i, 0].set_title(f"Image Slice {slice_idx}")
    
    gt_2d = gt_np[slice_idx]
    # gt_2d = np.rot90(gt_2d)
    axes[i, 1].imshow(gt_2d, cmap="grey", vmin=0, vmax=cfg.dataset.num_classes - 1)

    axes[i, 1].set_title(f"Ground Truth Slice {slice_idx}")
    pred_2d = pred_np[slice_idx]
    # pred_2d = np.rot90(pred_2d)
    axes[i, 2].imshow(pred_2d, cmap="grey", vmin=0, vmax=cfg.dataset.num_classes - 1)
    axes[i, 2].set_title(f"Prediction Slice {slice_idx}")

    # for ax in axes[i]:
    #     ax.axis("off")



plt.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'models.unet3d'

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.__config__ import show

import matplotlib.pyplot as plt
import numpy as np

def show_prediction(image, gt_mask, pred_mask, slice_idx=None, class_id=1):
    gt_binary = (gt_mask == class_id)
    pred_binary = (pred_mask == class_id)

    if slice_idx is None:
        z_coords = np.where(gt_binary)[0]
        slice_idx = z_coords[len(z_coords)//2] if len(z_coords) > 0 else gt_mask.shape[0] // 2

    plt.figure(figsize=(16, 4))

    plt.subplot(1, 4, 1)
    plt.imshow(image[slice_idx], cmap='gray')
    plt.title("MRI Slice")
    plt.axis('off')

    plt.subplot(1, 4, 2)
    plt.imshow(gt_binary[slice_idx], cmap='gray')
    plt.title("Ground Truth")
    plt.axis('off')

    plt.subplot(1, 4, 3)
    plt.imshow(pred_binary[slice_idx], cmap='gray')
    plt.title("Prediction")
    plt.axis('off')

    plt.subplot(1, 4, 4)
    error_map = np.logical_xor(gt_binary[slice_idx], pred_binary[slice_idx])
    plt.imshow(error_map, cmap='Reds')
    plt.title("Error Map")
    plt.axis('off')

    plt.suptitle(f"Class {class_id} - Slice {slice_idx}", fontsize=14)
    plt.tight_layout()
    plt.show()


show_prediction(image_np, gt_np, pred_np, class_id=1)


## Visualization of images!

### Setup

In [None]:
from ast import mod
import numpy as np
from omegaconf import DictConfig, OmegaConf
from typing import Dict, Tuple

import torch
import matplotlib.pyplot as plt

import data
from data.datasets import MedicalDecathlonDataset, BrainTumourDataset
from models.ms_unet3d import MSUNet3D

# ras+ ORIENTATION
# class imbalance

# task_name = "Task04_Hippocampus"
task_name = "Task01_BrainTumour"
inference_model_name = "2025-03-31_13-44-04"

model_dir = f"trained_models/unet3d/{task_name}/{inference_model_name}"
model_path = f"{model_dir}/best_model.pth"

cfg = OmegaConf.load(f"{model_dir}/config.yaml")

if not isinstance(cfg, DictConfig):
    raise TypeError("cfg must be a DictConfig.")

# dataset = MedicalDecathlonDataset(cfg, phase="test")
dataset = BrainTumourDataset(cfg, phase='train')

model = MSUNet3D(
    in_channels=1,
    num_classes=cfg.dataset.num_classes,
    n_filters=cfg.model.n_filters,
    dropout=cfg.training.dropout,
    batch_norm=True,
    inference_fusion_mode=cfg.model.deep_supervision.inference_fusion_mode,
    depth=cfg.model.depth,
    deep_supervision_levels=cfg.model.deep_supervision.levels
)



### Visualize predetermined sample and random slices

In [None]:
sample_idx = 300

image, gt = dataset[sample_idx]  # image: (C, D, H, W), gt: (D, H, W)
slice = 104 # or 126

# three random slices
# Pick three random slices along the depth dimension
depth = image.shape[3]
slices = np.random.choice(depth, size=3, replace=False)

# Plot the slices
for i, slice_idx in enumerate(slices):
    plt.subplot(1, 3, i + 1)
    plt.imshow(image[0, :, :, slice_idx], cmap='gray')
    plt.title(f"Slice {slice_idx}")
    plt.axis('off')
plt.show()

### Visualize predetermined sample and slice

In [None]:
sample_idx = 300

image, gt = dataset[sample_idx]  # image: (C, D, H, W), gt: (D, H, W)
slice = 104 # or 126

#Skew image


plt.imshow(image[0, :, :, slice_idx], cmap='gray')
plt.axis('off')
plt.show()

In [None]:

def run_inference(self, x: torch.Tensor):
    D, H, W = x.shape[2:]
    input_shape = (D, H, W)

    def div_shape(shape, factor):
        return tuple(s // factor for s in shape)

    # build mapping shape -> entry string
    shape_to_entry = {self.target_shape: "enc1"}
    for d in range(1, self.depth):
        shape_to_entry[div_shape(self.target_shape, 2**d)] = f"msb{d}"

    rounded = tuple(2 ** round(np.log2(s)) for s in input_shape)
    if rounded not in shape_to_entry:
        raise ValueError(
            f"Unsupported input shape {input_shape} (rounded {rounded}). "
            f"Expected one of: {list(shape_to_entry.keys())}"
        )
    entry_gateway = shape_to_entry[rounded]

    if entry_gateway == "enc1":
        # full resolution
        out = x
        encoder_feats = []
        for enc, pool, drop in zip(self.encoders, self.pools, self.enc_dropouts):
            out = enc(out)
            encoder_feats.append(out)
            out = drop(pool(out))

        # bottleneck
        out = self.bn(out)

        # Decoder pathway
        for up_conv, decoder, drop in zip(
            self.up_convs, self.decoders, self.dec_dropouts
        ):
            out = up_conv(out)
            skip = encoder_feats.pop()
            out = torch.cat([out, skip], dim=1)
            out = decoder(out)
            out = drop(out)

        final_out = self.final_conv(out)
        return final_out
    elif entry_gateway.startswith("msb"):
        # lower resolution image
        level = int(entry_gateway.replace("msb", ""))
        msb = self.msb_blocks[level - 1]
        out = msb(x)
        ms_feats = []
        ms_feats.append(out)
        out = self.pools[level](out)
        out = self.enc_dropouts[level](out)

        for enc, pool, drop in zip(
            list(self.encoders)[level + 1 :],
            list(self.pools)[level + 1 :],
            list(self.enc_dropouts)[level + 1 :],
        ):
            out = enc(out)
            ms_feats.append(out)
            out = drop(pool(out))

        # bottleneck
        out = self.bn(out)

        num_ups = self.depth - level
        # decoder up to match MS scale
        for up_conv, dec, drop in zip(
            list(self.up_convs)[:num_ups],
            list(self.decoders)[:num_ups],
            list(self.dec_dropouts)[:num_ups],
        ):
            out = up_conv(out)
            skip = ms_feats.pop()
            out = torch.cat([out, skip], dim=1)
            out = dec(out)
            out = drop(out)

        final_out = self.ms_heads[level - 1](out)  # ms_heads not final_conv
        return final_out
    else:
        raise ValueError(f"Unknown entry point in Multiscale UNet: {entry_gateway}")




#### DataSimulation 2 - backbone vs our