# Import & config

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

In [2]:
import torch
from scripts.neweval import load_trained_model
from src.utils.config import load_config
import csv
from tqdm.notebook import tqdm
from src.data.dataset2d import PancreasDataset2D
from src.metrics.sma import SegmentationMetricsAccumulator as SMA
from src.training.setup.transforms_factory import get_transforms
from src.training.setup.dataset_factory import get_dataset
from torch.utils.data import DataLoader

# __Load trained model__

In [3]:
# Checkpoint file and experiment configuration, both relative to the working directory.
model_path = 'experiments/deep_aug/deep_aug_20250415_215856/checkpoints/best_model_epoch60.pth'
config = load_config('configs/experiments/deep_aug_5.yaml')

# Build the model from the configuration and the checkpoint path.
model = load_trained_model(config, model_path)

In [4]:
# Use the device named in the training config only when CUDA is available;
# otherwise run on the CPU.
config_device = config['training']['device']
if torch.cuda.is_available():
    device = torch.device(config_device)
else:
    device = torch.device("cpu")

In [5]:
# Move the model to the selected device. The trailing ';' keeps the notebook
# from echoing the value returned by `.to()` as cell output.
model.to(device);

# __Evaluation function: 2D vs. 3D scores__

In [27]:
def check_2d_3d_scores(model, config, device, patient_ids=None):
    """Compare slice-wise (2D) and volume-wise (3D) segmentation scores per patient.

    For every patient the model is run over that patient's 2D slices and three
    sets of metrics / confusion-matrix counts are reported with ``tqdm.write``:

    1. slice-wise: the accumulator is updated once per batch of slices;
    2. stacking: all batch outputs are concatenated into a single volume and
       scored against the concatenated batch masks in one update;
    3. reconstruction: the same stacked prediction is scored against the mask
       returned by ``PancreasDataset2D.get_patient_volume``.

    All three accumulators are reset after each patient, so every reported
    block refers to that patient only.

    Args:
        model: Segmentation network, called as ``model(images)``. When it
            returns a dict, the ``"out"`` entry is used.
        config: Experiment configuration. ``config["data"]["processed_dir"]``,
            ``config["data"]["batch_size"]`` and ``config["data"]["num_workers"]``
            are read here, and the whole object is passed to ``get_transforms``.
        device: Device the model and every batch are moved to.
        patient_ids: Optional iterable of patient IDs (or a single ID string)
            to evaluate. Defaults to ``["rtum79"]``, the patient that used to
            be hard-coded in the function body.

    Returns:
        None. Results are written to the notebook output.
    """
    if patient_ids is None:
        patient_ids = ["rtum79"]
    elif isinstance(patient_ids, str):
        # A bare string would otherwise be iterated character by character.
        patient_ids = [patient_ids]
    else:
        patient_ids = list(patient_ids)

    model.to(device)
    model.eval()

    # Slices are read from the "train" split of the processed data directory.
    data_dir = os.path.join(config["data"]["processed_dir"], "train")

    # Get the transforms
    transform = get_transforms(config)

    # One accumulator per evaluation mode (background class excluded).
    sma = SMA(include_background=False)
    sma_3d = SMA(include_background=False)
    sma_3d_recon = SMA(include_background=False)

    loop = tqdm(
        patient_ids,
        colour="red",
        leave=True
    )
    loop.set_description("Evaluating patients")

    for pid in loop:
        p_dataset = PancreasDataset2D(
            data_dir=data_dir,
            transform=transform,
            load_into_memory=False,
            patient_ids=[pid]
        )

        # shuffle=False keeps the slices in dataset order, which the stacked
        # volume below relies on.
        p_dl = DataLoader(
            p_dataset,
            batch_size=config['data']['batch_size'],
            shuffle=False,
            num_workers=config['data']['num_workers'],
            pin_memory=True
        )

        # Evaluate
        patient_loop = tqdm(
            p_dl,
            leave=True,
            colour="blue"
        )
        patient_loop.set_description(f"Patient {pid}")

        batch_preds = []
        batch_gts = []

        with torch.no_grad():
            for images, masks, _ in patient_loop:
                images, masks = images.to(device), masks.to(device)

                outputs = model(images)

                if isinstance(outputs, dict):
                    outputs = outputs["out"]

                batch_preds.append(outputs)
                batch_gts.append(masks)

                # Slice-wise update: one call per batch.
                sma.update(outputs, masks)

            if not batch_preds:
                # torch.cat() raises on an empty list; nothing was accumulated
                # for this patient, so there is nothing to reset either.
                tqdm.write(f"Patient {pid}: no slices found, skipping.")
                continue

            # Get aggregated scores and confusion matrix
            p_metrics = sma.aggregate()
            p_cm = sma.aggregate_global_cm()

            # Stacking: concatenate the batches along dim 0, swap the first two
            # dims of the predictions and add a leading dim of size 1. In the
            # recorded run this gave preds [1, 5, 103, 256, 256] and
            # masks [1, 103, 256, 256].
            all_preds = torch.cat(batch_preds, dim=0).permute(1, 0, 2, 3).unsqueeze(0)
            print("all_preds:", all_preds.shape)
            all_gts = torch.cat(batch_gts, dim=0).unsqueeze(0)
            print("all_gts:", all_gts.shape)
            sma_3d.update(all_preds, all_gts)
            ps_metrics = sma_3d.aggregate()
            ps_cm = sma_3d.aggregate_global_cm()

            # Reconstruction: score against the mask volume rebuilt by the
            # dataset. Only the mask is needed; it is moved next to the
            # predictions, as the batch masks were above.
            _, recon_mask = p_dataset.get_patient_volume(pid)
            recon_mask = recon_mask.to(all_preds.device)
            print("recon mask:", recon_mask.shape)
            sma_3d_recon.update(all_preds, recon_mask)
            pr_metrics = sma_3d_recon.aggregate()
            pr_cm = sma_3d_recon.aggregate_global_cm()

            # Reset every accumulator so the next patient starts clean.
            # Previously only `sma` was reset, so the stacking / recon numbers
            # silently accumulated over all patients seen so far.
            sma.reset()
            sma_3d.reset()
            sma_3d_recon.reset()

            tqdm.write(f"Metrics:\n{p_metrics}")
            tqdm.write(f"CM:\n{p_cm}")
            tqdm.write('-'*65)
            tqdm.write(f"Metrics stacking:\n{ps_metrics}")
            tqdm.write(f"CM stacking:\n{ps_cm}")
            tqdm.write('-'*65)
            tqdm.write(f"Metrics recon:\n{pr_metrics}")
            tqdm.write(f"CM recon:\n{pr_cm}")
            tqdm.write('-'*65)

In [28]:
# Run the slice-wise vs. stacked vs. reconstructed comparison with the model,
# configuration and device prepared in the cells above.
check_2d_3d_scores(model, config, device)

  0%|          | 0/1 [00:00<?, ?it/s]

📊 Loading dataset... 103 slices found.


  0%|          | 0/26 [00:00<?, ?it/s]

all_preds: torch.Size([1, 5, 103, 256, 256])
all_gts: torch.Size([1, 103, 256, 256])
recon mask: torch.Size([1, 103, 256, 256])
Metrics:
{'dice_class_1': 0.5324310660362244, 'dice_class_2': 0.2273699939250946, 'dice_class_3': 0.8012505769729614, 'dice_class_4': 0.6751064658164978, 'iou_class_1': 0.45153653621673584, 'iou_class_2': 0.156370609998703, 'iou_class_3': 0.7064030170440674, 'iou_class_4': 0.5607652068138123, 'precision_class_1': 0.6438092589378357, 'precision_class_2': 0.168614000082016, 'precision_class_3': 0.8756927847862244, 'precision_class_4': 0.7729638814926147, 'recall_class_1': 0.4964034855365753, 'recall_class_2': 0.379624605178833, 'recall_class_3': 0.7697128653526306, 'recall_class_4': 0.6122469902038574, 'dice': 0.5590395331382751, 'iou': 0.468768835067749, 'precision': 0.6152700185775757, 'recall': 0.5644969940185547}
CM:
{'tp_class_1': 43973.0, 'tp_class_2': 2534.0, 'tp_class_3': 30782.0, 'tp_class_4': 38452.0, 'fp_class_1': 5177.0, 'fp_class_2': 6220.0, 'fp_cla