## Environment Setup <a name="environment-setup"></a>

In [None]:
# Install Kaggle API for dataset access
!pip install kaggle --quiet

In [2]:
# Install required libraries for medical image processing and deep learning
try:
    import pydicom, nibabel, monai, SimpleITK, torchio
except ImportError:
    !pip install -q pydicom nibabel monai SimpleITK torchio
    
# Import standard libraries
import os
import json
import random
from glob import glob
from pathlib import Path
from collections import defaultdict, Counter

# Import scientific / data libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Import medical image processing libraries
import pydicom
import nibabel as nib
import SimpleITK as sitk

# Import PyTorch and MONAI for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Compose, OneOf, EnsureChannelFirst, EnsureType,
    Orientation, Spacing, RandAffine, RandFlip,
    NormalizeIntensity, RandScaleIntensity, RandShiftIntensity,
    RandGaussianNoise, RandGaussianSmooth, RandAdjustContrast,
    Resize, RandBiasField, ToTensor
)
from monai.data import MetaTensor

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run;
    # disable it so `deterministic = True` is actually effective.
    torch.backends.cudnn.benchmark = False
set_seed()

# Pick the compute device: CUDA when a GPU is visible to PyTorch, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print("Environment setup complete.")

# Additional imports for file operations
import shutil
import zipfile
from datetime import datetime
from kaggle.api.kaggle_api_extended import KaggleApi

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.2/193.2 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m93.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m36.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5

2025-08-04 10:22:19.163699: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754302939.537587      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754302939.642771      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


cuda
Environment setup complete.


## Data Preparation <a name="data-preparation"></a>

In [3]:
# Kaggle authentication: copy the API token from the attached input dataset
# into the config directory and restrict it to owner read/write.
kaggle_config_dir = "/root/.config/kaggle"
kaggle_token_path = "/root/.config/kaggle/kaggle.json"

os.makedirs(kaggle_config_dir, exist_ok=True)
shutil.copy("/kaggle/input/kaggle-json/kaggle.json", kaggle_token_path)
os.chmod(kaggle_token_path, 0o600)

# Create the API client and authenticate it
api = KaggleApi()
api.authenticate()

# Define dataset paths
# Each entry is a Kaggle input directory that is searched recursively for
# NIfTI volumes when the metadata list is built.
# NOTE(review): several directory names suggest overlapping ranges
# (e.g. 2230-2360 vs 2300-2400, 4200-4300 vs 4290-4400, 4290-4400 vs 4380-4470),
# so the same scan may be present under more than one root — confirm.
DATASET_ROOTS = [
'/kaggle/input/abdominal-nifti-0-100',
    '/kaggle/input/abdominal-nifti-100-200',
    '/kaggle/input/abdominal-trauma-nifti-200-300',
    '/kaggle/input/abdominal-trauma-nifti-300-400',
    '/kaggle/input/abdominal-trauma-nifti-400-500',
    '/kaggle/input/abdominal-trauma-nifti-500-600',
    '/kaggle/input/abdominal-trauma-nifti-600-700',
    '/kaggle/input/abdominal-trauma-nifti-700-800',
    '/kaggle/input/abdominal-trauma-nifti-800-900',
    '/kaggle/input/abdominal-trauma-nifti-900-1000',
    
    '/kaggle/input/abdominal-nifti-1000-1100',
    '/kaggle/input/abdominal-nifti-1100-1200',
    '/kaggle/input/abdominal-nifti-1200-1300',
    '/kaggle/input/abdominal-nifti-1300-1400',
    '/kaggle/input/abdominal-nifti-1400-1500',
    '/kaggle/input/abdominal-nifti-1500-1600',
    '/kaggle/input/abdominal-nifti-1600-1700',
    '/kaggle/input/abdominal-nifti-1700-1800',
    '/kaggle/input/abdominal-nifti-1800-1900',
    '/kaggle/input/abdominal-nifti-1900-2000',
    
    '/kaggle/input/abdominal-trauma-nifti-2000-above',
    '/kaggle/input/abdominal-nifti-2100-2150',
    '/kaggle/input/abdominal-trauma-nifti-2230-2360',
    '/kaggle/input/abdominal-trauma-nifti-2300-2400',
    '/kaggle/input/abdominal-trauma-nifti-2400-2500',
    '/kaggle/input/abdominal-trauma-nifti-2500-2600',
    '/kaggle/input/abdominal-trauma-nifti-2600-2700',
    '/kaggle/input/abdominal-trauma-nifti-2700-2800',
    '/kaggle/input/abdominal-trauma-nifti-2800-2900',
    '/kaggle/input/abdominal-trauma-nifti-2900-3000',

    '/kaggle/input/abdominal-trauma-nifti-3000-3100',
    '/kaggle/input/abdominal-trauma-nifti-3100-3200',
    '/kaggle/input/abdominal-trauma-nifti-3200-3300',
    '/kaggle/input/abdominal-trauma-nifti-3300-3400',
    '/kaggle/input/abdominal-trauma-nifti-3400-3500',
    '/kaggle/input/abdominal-trauma-nifti-3500-3600',
    '/kaggle/input/abdominal-trauma-nifti-3600-3700',
    '/kaggle/input/abdominal-trauma-nifti-3700-3800',
    '/kaggle/input/abdominal-trauma-nifti-3800-3900',
    '/kaggle/input/abdominal-trauma-nifti-3900-4000',

    '/kaggle/input/abdominal-trauma-nifti-4000-4100',
    '/kaggle/input/abdominal-trauma-nifti-4100-4200',
    '/kaggle/input/abdominal-trauma-nifti-4200-4300',
    '/kaggle/input/abdominal-nifti-4290-4400',
    '/kaggle/input/abdominal-nifti-4380-4470',
    '/kaggle/input/abdominal-nifti-4470-4560',
    '/kaggle/input/abdominal-nifti-4560-4650',
    '/kaggle/input/abdominal-nifti-4650-4710',  
]

LABELS_CSV_PATH = '/kaggle/input/rsna-2023-abdominal-trauma-detection/train_2024.csv'
OUTPUT_JSON_PATH = '/kaggle/working/train_metadata.json'


def _nifti_base_name(nii_path):
    """Return the file name of `nii_path` without its `.nii` / `.nii.gz` extension.

    `Path.stem` only strips the last suffix, so "12345_67890.nii.gz" would
    become "12345_67890.nii" and the study id would keep a trailing ".nii".
    """
    name = Path(nii_path).name
    for extension in (".nii.gz", ".nii"):
        if name.endswith(extension):
            return name[:-len(extension)]
    # Anything else matched by the "*.nii*" glob: fall back to the plain stem
    return Path(nii_path).stem


# Load and process labels (patient_id is compared as a string against file names)
labels_df = pd.read_csv(LABELS_CSV_PATH)
labels_df['patient_id'] = labels_df['patient_id'].astype(str)
labels_dict_map = labels_df.set_index('patient_id').to_dict(orient='index')
label_cols = [col for col in labels_df.columns if col != 'patient_id']

# Create metadata dictionary for all NIfTI files
metadata_list = []
seen_scans = set()      # (patient_id, study_id) pairs already recorded
duplicate_count = 0     # scans found again under another dataset root

for dataset_root in DATASET_ROOTS:
    nifti_files = sorted(Path(dataset_root).rglob("*.nii*"))  # .nii or .nii.gz

    print(f"Found {len(nifti_files)} NIfTI files in {dataset_root}")

    for nii_path in nifti_files:
        stem = _nifti_base_name(nii_path)  # e.g. "12345_67890"
        try:
            patient_id, study_id = stem.split("_")
        except ValueError:
            print(f"Skipping malformed filename: {stem}")
            continue

        if patient_id not in labels_dict_map:
            print(f"No label for patient {patient_id}, skipping...")
            continue

        # Keep only the first copy of a scan; a second copy under a different
        # path could otherwise end up in both the training and evaluation splits.
        scan_key = (patient_id, study_id)
        if scan_key in seen_scans:
            duplicate_count += 1
            continue
        seen_scans.add(scan_key)

        labels = {col: int(labels_dict_map[patient_id][col]) for col in label_cols}

        metadata_list.append({
            "patient_id": patient_id,
            "study_id": study_id,
            "nifti_path": str(nii_path),
            "labels": labels
        })

print(f"Skipped {duplicate_count} duplicate (patient_id, study_id) files")
print(f"Total metadata entries: {len(metadata_list)}")

# Save metadata to JSON
with open(OUTPUT_JSON_PATH, 'w') as f:
    json.dump(metadata_list, f, indent=2)

print(f"Metadata saved to {OUTPUT_JSON_PATH}")

## Data Exploration <a name="data-exploration"></a>

In [None]:
# Re-read the saved metadata and report any (patient_id, study_id) pair that occurs more than once
with open(OUTPUT_JSON_PATH, "r") as f:
    data = json.load(f)

id_pairs = [(entry["patient_id"], entry["study_id"]) for entry in data]
pair_counts = Counter(id_pairs)
duplicates = [pair for pair, count in pair_counts.items() if count > 1]

if not duplicates:
    print("No duplicate (patient_id, study_id) entries found.")
else:
    print(f"Found {len(duplicates)} duplicate entries:")
    for dup_patient, dup_study in duplicates:
        print(f" - patient_id: {dup_patient}, study_id: {dup_study}")

# Analyze class distribution: one row per scan, one column per label
label_rows = [entry["labels"] for entry in metadata_list]
labels_df = pd.DataFrame(label_rows)

# Labels shown in the distribution plot, in display order
distribution_cols = [
    'bowel_healthy', 'bowel_injury',
    'extravasation_healthy', 'extravasation_injury',
    'kidney_healthy', 'kidney_low', 'kidney_high',
    'liver_healthy', 'liver_low', 'liver_high',
    'spleen_healthy', 'spleen_low', 'spleen_high',
]

# Column-wise positive counts, reshaped to one (label, count) row per label
labels_agg = (
    labels_df[distribution_cols]
    .sum()
    .rename_axis('label')
    .reset_index(name='count')
)
# Split e.g. "kidney_low" into organ="kidney", status="low"
labels_agg[['organ', 'status']] = labels_agg['label'].str.rsplit('_', n=1, expand=True)

# Print counts
print("Counts per class:")
print(labels_agg[['label', 'count']].to_string(index=False))

# Organ x status table; statuses that do not exist for an organ become 0
status_order = ['healthy', 'injury', 'low', 'high']
pivot_df = (
    labels_agg
    .pivot(index='organ', columns='status', values='count')
    .fillna(0)
    .reindex(columns=status_order, fill_value=0)
)

# Plot class distribution
color_map = {
    'healthy': 'skyblue',
    'low': 'orange',
    'high': 'salmon',
    'injury': 'red'
}
colors = [color_map[status_name] for status_name in pivot_df.columns]

fig, ax = plt.subplots(figsize=(10, 6))
pivot_df.plot(kind='bar', color=colors, ax=ax)
ax.set_title("Organ Injury Severity Distribution")
ax.set_ylabel("Number of Samples")
ax.set_xlabel("Organ")
ax.tick_params(axis='x', rotation=0)
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.legend(title='Status')
fig.tight_layout()
plt.show()

# Analyze injury vs healthy distribution
# (relies on the label CSV providing an `any_injury` column)
labels_df['injury_status'] = labels_df['any_injury'].apply(lambda x: 'injured' if x == 1 else 'healthy')
status_counts = labels_df['injury_status'].value_counts()

print("Counts by injury status:")
print(status_counts)

# `value_counts()` orders by frequency, so colours are looked up by status name;
# a positional colour list would swap colours if "injured" were the larger group.
status_colors = {'healthy': 'skyblue', 'injured': 'salmon'}

# Plot pie chart
status_counts.plot(
    kind='pie',
    colors=[status_colors[status_name] for status_name in status_counts.index],
    autopct='%1.1f%%',
    startangle=90,
    ylabel='',
    title='Proportion of Healthy vs Injured Samples'
)
plt.tight_layout()
plt.show()

## Configuration Setup

In [None]:
# Configuration Class
# This class stores all important constants and settings in one centralized location
class Config:
    """Single place for the constants and settings shared by the notebook."""

    # --- Reproducibility ---
    SEED = 42

    # --- Volume geometry ---
    IMAGE_SIZE = (128, 128, 128)      # target size every 3D volume is resized to
    VOXEL_SPACING = (1.0, 1.0, 1.0)   # voxel spacing used when resampling CT scans

    # --- Optimisation ---
    BATCH_SIZE = 16   # samples per batch
    EPOCHS = 100      # total training epochs
    LR = 3e-4         # initial learning rate

    # --- Labels to predict ---
    TARGET_COLS = [
        "bowel_healthy", "extravasation_healthy",
        "bowel_injury", "extravasation_injury",
        "kidney_healthy", "kidney_low", "kidney_high",
        "liver_healthy", "liver_low", "liver_high",
        "spleen_healthy", "spleen_low", "spleen_high",
    ]
    NUM_CLASSES = len(TARGET_COLS)

    # --- Data splitting: "group" or "random" ---
    SPLIT_MODE = "group"


# Shared instance used throughout the notebook
config = Config()

## Training Data Transforms

In [None]:
# Training pipeline: deterministic preprocessing followed by random augmentation.
train_transforms = Compose([
    EnsureChannelFirst(),  # Ensure input has channel dimension
    EnsureType(),  # Convert to MetaTensor
    Orientation(axcodes="RAS"),  # Standardize orientation
    # Resample to the spacing configured in Config instead of a second hard-coded copy
    Spacing(pixdim=config.VOXEL_SPACING, mode="bilinear"),
    Resize(spatial_size=config.IMAGE_SIZE),  # Resize to target dimensions

    # Spatial augmentation: OneOf picks exactly one candidate per sample.
    # Every candidate has prob=0.5, so the chosen one may still leave the volume unchanged.
    OneOf([
        RandAffine(
            rotate_range=(0, 0, np.pi/12),  # rotation limited to the third spatial axis
            shear_range=(0.1, 0.1, 0.1),
            translate_range=(10, 10, 5),
            scale_range=(0.1, 0.1, 0.1),
            prob=0.5,
            mode="bilinear"
        ),
        RandFlip(prob=0.5, spatial_axis=0),  # Random flip along spatial axis 0
        RandFlip(prob=0.5, spatial_axis=1),  # Random flip along spatial axis 1
        RandFlip(prob=0.5, spatial_axis=2),  # Random flip along spatial axis 2
    ]),

    # Intensity normalization and augmentation
    NormalizeIntensity(nonzero=True, channel_wise=True),
    RandScaleIntensity(factors=0.1, prob=1.0),
    RandShiftIntensity(offsets=0.1, prob=1.0),
    RandGaussianNoise(prob=0.3, mean=0.0, std=0.1),
    RandAdjustContrast(prob=0.3, gamma=(0.7, 1.5)),

    # Additional mild augmentations
    RandGaussianSmooth(
        prob=0.2,
        sigma_x=(0.25, 0.5),  # Mild smoothing
        sigma_y=(0.25, 0.5),
        sigma_z=(0.25, 0.5)
    ),
    RandBiasField(
        prob=0.2,
        coeff_range=(0.1, 0.3)  # Subtle intensity variations
    ),

    ToTensor()  # Convert to PyTorch tensor
])

## Validation Transforms

In [None]:
# Validation pipeline: deterministic preprocessing only - no random augmentations
val_transforms = Compose([
    EnsureChannelFirst(),  # Add channel dimension
    EnsureType(),  # Convert to MetaTensor
    Orientation(axcodes="RAS"),  # Standard orientation
    # Resample to the spacing configured in Config instead of a second hard-coded copy
    Spacing(pixdim=config.VOXEL_SPACING, mode="bilinear"),
    Resize(spatial_size=config.IMAGE_SIZE),  # Resize to target dimensions
    NormalizeIntensity(nonzero=True, channel_wise=True),  # Normalize intensities
    ToTensor()  # Convert to tensor
])

## Test Transforms

In [None]:
# Test pipeline: same deterministic steps as validation - no randomness for consistent evaluation
test_transforms = Compose([
    EnsureChannelFirst(),  # Add channel dimension
    EnsureType(),  # Convert to MetaTensor
    Orientation(axcodes="RAS"),  # Standard orientation
    # Resample to the spacing configured in Config instead of a second hard-coded copy
    Spacing(pixdim=config.VOXEL_SPACING, mode="bilinear"),
    Resize(spatial_size=config.IMAGE_SIZE),  # Resize to target dimensions
    NormalizeIntensity(nonzero=True, channel_wise=True),  # Normalize intensities
    ToTensor()  # Convert to tensor
])

## Dataset Preparation

### Label Conversion and Target Formatting

In [None]:
def convert_labels_to_targets(label_dict):
    """
    Collapse the one-hot label dictionary into one value per prediction target.

    Binary targets (bowel, extra) are the `*_injury` flag as a float: 0.0=healthy, 1.0=injury.
    Graded targets (kidney, liver, spleen) are an int: 0=healthy, 1=low, 2=high,
    taken from the first grade column whose value is 1 (0 when none is set).

    Args:
        label_dict: Mapping from label column name to its 0/1 value.

    Returns:
        Dict with keys "bowel", "extra", "kidney", "liver", "spleen".
    """
    graded_columns = {
        "kidney": ('kidney_healthy', 'kidney_low', 'kidney_high'),
        "liver": ('liver_healthy', 'liver_low', 'liver_high'),
        "spleen": ('spleen_healthy', 'spleen_low', 'spleen_high'),
    }

    def grade_of(columns):
        # Position of the first column flagged with 1; healthy (0) when nothing is flagged
        return next(
            (grade for grade, column in enumerate(columns) if label_dict[column] == 1),
            0,
        )

    targets = {
        "bowel": float(label_dict['bowel_injury']),
        "extra": float(label_dict['extravasation_injury']),
    }
    for organ, columns in graded_columns.items():
        targets[organ] = grade_of(columns)
    return targets

### Custom Dataset Class Implementation

In [None]:
class RSNADataset(Dataset):
    """
    Dataset that reads one NIfTI CT volume per metadata entry and, optionally, its targets.

    Each item is a dict with key "image" (the transformed volume) and, when
    `has_labels` is true, key "label": a float32 array ordered as
    [bowel, extra, kidney, liver, spleen].
    """

    # Order in which converted targets are packed into the label array
    _TARGET_ORDER = ("bowel", "extra", "kidney", "liver", "spleen")

    def __init__(self, metadata_list, transforms=None, has_labels=True):
        """
        Args:
            metadata_list: List of dictionaries with 'nifti_path' and 'labels'
            transforms: Callable applied to the loaded volume (e.g. a MONAI Compose)
            has_labels: Whether items include labels (True for train/val, False for test)
        """
        self.metadata_list = metadata_list
        self.transforms = transforms
        self.has_labels = has_labels

    def __len__(self):
        """Number of metadata entries."""
        return len(self.metadata_list)

    def __getitem__(self, idx):
        """Load, reshape and transform the volume at position `idx`."""
        record = self.metadata_list[idx]

        # Read the volume from disk as float32
        ct = nib.load(record["nifti_path"]).get_fdata().astype(np.float32)

        # Reverse the axis order (X, Y, Z) -> (Z, Y, X), then prepend a channel axis: (1, Z, Y, X)
        ct = ct.transpose(2, 1, 0)[np.newaxis, ...]

        # Wrap in a MetaTensor that marks axis 0 as the channel dimension.
        # NOTE(review): the NIfTI affine is not passed along here - confirm that the
        # Orientation/Spacing transforms downstream behave as intended without it.
        image = MetaTensor(ct, meta={"original_channel_dim": 0})

        if self.transforms:
            image = self.transforms(image)

        if not self.has_labels:
            return {"image": image}

        targets = convert_labels_to_targets(record['labels'])
        label_array = np.array(
            [targets[name] for name in self._TARGET_ORDER],
            dtype=np.float32,
        )
        return {
            "image": image,
            "label": label_array,
        }

### Balanced DataLoader Creation

In [None]:
def prepare_balanced_dataloaders(metadata_list, train_transforms, val_transforms, test_transforms, config):
    """
    Build a class-balanced subset of the data and split it into train/val/test.

    Balancing is done by undersampling the healthy class: for every target all
    injured scans are kept and an equally sized random subset of healthy scans is
    drawn without replacement (capped at the number of healthy scans available).
    Up to 100 scans with at least two injury labels are added on top. Scans are
    de-duplicated by `nifti_path`.

    NOTE(review): neither split mode groups scans by patient, so two scans of the
    same patient may land in different splits - confirm this is acceptable.

    Args:
        metadata_list: List of dicts with 'nifti_path' and 'labels'.
        train_transforms / val_transforms / test_transforms: Transforms for each split.
        config: Object providing TARGET_COLS, SPLIT_MODE and BATCH_SIZE.

    Returns:
        Tuple (train_ds, val_ds, test_ds, train_loader, val_loader, test_loader, train_meta).

    Raises:
        ValueError: If `config.SPLIT_MODE` is neither "random" nor "group".
    """
    # Only non-"healthy" columns describe an injury. Summing every target column
    # would count the one-hot "*_healthy" flags as well and give the same total
    # for every scan, which makes an injury-count filter meaningless.
    injury_cols = [col for col in config.TARGET_COLS if not col.endswith("_healthy")]

    def count_injuries(m):
        return sum(m["labels"].get(k, 0) for k in injury_cols)

    def filter_by_condition(label_key, value=1):
        return [m for m in metadata_list if m["labels"].get(label_key, 0) == value]

    def filter_multi_organ_injury(min_injuries=2):
        return [m for m in metadata_list if count_injuries(m) >= min_injuries]

    def draw_healthy(label_key, size):
        # Random subset of healthy scans, never asking for more than exist
        pool = filter_by_condition(label_key)
        if not pool:
            return []
        picked = np.random.choice(len(pool), size=min(size, len(pool)), replace=False)
        return [pool[i] for i in picked]

    # Use a dict to avoid duplicates (keyed by nifti_path)
    balanced_dict = {}

    def add_to_balanced(samples):
        for sample in samples:
            balanced_dict[sample["nifti_path"]] = sample

    # Binary targets: all injured scans plus as many healthy ones
    for organ in ("bowel", "extravasation"):
        injured = filter_by_condition(f"{organ}_injury")
        healthy = draw_healthy(f"{organ}_healthy", len(injured))
        add_to_balanced(injured)
        add_to_balanced(healthy)

    # 3-class targets: all low/high scans plus as many healthy ones as both together
    for organ in ("kidney", "liver", "spleen"):
        low = filter_by_condition(f"{organ}_low")
        high = filter_by_condition(f"{organ}_high")
        healthy = draw_healthy(f"{organ}_healthy", len(low) + len(high))
        add_to_balanced(low)
        add_to_balanced(high)
        add_to_balanced(healthy)

    # Add multi-organ injury samples
    multi_organ_samples = filter_multi_organ_injury(min_injuries=2)
    np.random.shuffle(multi_organ_samples)
    multi_sample_limit = min(100, len(multi_organ_samples))  # Limit to 100 or available
    add_to_balanced(multi_organ_samples[:multi_sample_limit])

    # Convert to unique list
    balanced_metadata = list(balanced_dict.values())

    # Split the data
    if config.SPLIT_MODE == "random":
        def stratify_keys(items):
            # Stratify on 0 / 1 / 2+ injuries. train_test_split needs at least two
            # members per stratum, so fall back to an unstratified split otherwise.
            keys = [min(count_injuries(m), 2) for m in items]
            if not keys or min(Counter(keys).values()) < 2:
                return None
            return keys

        train_meta, temp_meta = train_test_split(
            balanced_metadata,
            test_size=0.3,
            random_state=42,
            stratify=stratify_keys(balanced_metadata)
        )

        val_meta, test_meta = train_test_split(
            temp_meta,
            test_size=0.5,
            random_state=42,
            stratify=stratify_keys(temp_meta)
        )

    elif config.SPLIT_MODE == "group":
        train_meta, val_meta, test_meta = split_metadata_train_val_test(
            balanced_metadata,
            target_cols=config.TARGET_COLS,
            val_size=0.15,
            test_size=0.15,
            seed=42
        )

    else:
        raise ValueError(
            f"Unknown SPLIT_MODE {config.SPLIT_MODE!r}; expected 'random' or 'group'"
        )

    # Create Dataset and DataLoaders
    train_ds = RSNADataset(train_meta, transforms=train_transforms)
    val_ds = RSNADataset(val_meta, transforms=val_transforms)
    test_ds = RSNADataset(test_meta, transforms=test_transforms)

    return (
        train_ds,
        val_ds,
        test_ds,
        DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True),
        DataLoader(val_ds, batch_size=config.BATCH_SIZE, shuffle=False),
        DataLoader(test_ds, batch_size=1, shuffle=False),
        train_meta
    )

### Stratified Data Splitting

In [None]:
def split_metadata_train_val_test(metadata_list, target_cols, val_size=0.1, test_size=0.1, seed=42):
    """
    Split metadata into train/val/test while keeping label combinations balanced.

    Entries are grouped by their exact combination of `target_cols` values and
    every group is split separately. A group with a single entry is assigned to
    one split at random, using a generator seeded with `seed`, so the result does
    not depend on the global NumPy random state.

    NOTE(review): entries are not grouped by patient, so scans of one patient can
    end up in different splits - confirm this is intended.

    Args:
        metadata_list: List of dicts with 'nifti_path' and 'labels'.
        target_cols: Label names that define the groups.
        val_size: Fraction of each group assigned to validation.
        test_size: Fraction of each group assigned to test.
        seed: Seed for every random decision made here.

    Returns:
        (train_list, val_list, test_list); each item is a dict with 'nifti_path'
        and 'labels' (restricted to `target_cols`).
    """
    if len(metadata_list) == 0:
        return [], [], []

    target_cols = list(target_cols)

    # Flatten the nested label dicts into one column per label
    df = pd.DataFrame(metadata_list)
    label_df = pd.json_normalize(df['labels'].tolist())
    df = pd.concat([df.drop(columns='labels'), label_df], axis=1)

    rng = np.random.RandomState(seed)
    val_test_size = val_size + test_size

    def to_metadata(frame):
        # Convert DataFrame rows back to the metadata format
        return [
            {
                "nifti_path": row["nifti_path"],
                "labels": {col: row[col] for col in target_cols}
            }
            for _, row in frame.iterrows()
        ]

    # Collect records in plain lists; growing a DataFrame with concat on every
    # iteration copies all previously collected rows each time.
    train_list, val_list, test_list = [], [], []

    for _, group in df.groupby(target_cols):
        if len(group) == 1:
            # Single-sample group: assign it to one split at random
            r = rng.rand()
            if r < test_size:
                test_list.extend(to_metadata(group))
            elif r < val_test_size:
                val_list.extend(to_metadata(group))
            else:
                train_list.extend(to_metadata(group))
            continue

        if val_test_size <= 0:
            # Nothing is held out at all
            train_list.extend(to_metadata(group))
            continue

        train_split, holdout = train_test_split(group, test_size=val_test_size, random_state=seed)

        empty = holdout.iloc[0:0]
        if test_size <= 0 or len(holdout) < 2:
            # No test share requested, or too few rows to divide: all go to validation
            val_split, test_split = holdout, empty
        elif val_size <= 0:
            val_split, test_split = empty, holdout
        else:
            val_split, test_split = train_test_split(
                holdout, test_size=test_size / val_test_size, random_state=seed
            )

        train_list.extend(to_metadata(train_split))
        val_list.extend(to_metadata(val_split))
        test_list.extend(to_metadata(test_split))

    return train_list, val_list, test_list


In [None]:
# Build the balanced datasets and loaders for all three splits
(
    train_ds, val_ds, test_ds,
    train_loader, val_loader, test_loader,
    train_meta,
) = prepare_balanced_dataloaders(
    metadata_list=metadata_list,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    test_transforms=test_transforms,
    config=config,
)

# Report how many scans each split contains
train_size = len(train_loader.dataset)
val_size_count = len(val_loader.dataset)
test_size_count = len(test_loader.dataset)
print(f"Train size: {train_size}")
print(f"Val size:   {val_size_count}")
print(f"Test size:  {test_size_count}")

## Class Balancing and Weighted Sampling

In [None]:
def compute_sample_weights(metadata_list, target_cols, power=1.0):
    """
    Weight every sample by the rarity of the labels it carries.

    For each target column the weight is 1 / (frequency + 1e-6) ** power, where
    frequency is the column mean over `metadata_list`. A sample's weight is the
    sum of the column weights multiplied by its own label values.

    Args:
        metadata_list: List of dictionaries containing label information
        target_cols: Target columns that take part in the weighting
        power: Exponent on the frequency (higher values emphasise rare classes more)

    Returns:
        List with one weight per entry of metadata_list
    """
    # Rows = samples, columns = target_cols; labels missing from an entry count as 0
    label_matrix = np.array([
        [entry["labels"].get(col, 0) for col in target_cols]
        for entry in metadata_list
    ])

    # Per-column frequency; the epsilon keeps never-positive columns from dividing by zero
    column_freq = label_matrix.mean(axis=0) + 1e-6
    column_weights = 1.0 / (column_freq ** power)

    # Sum the weights of the columns each sample is positive for
    per_sample = (label_matrix * column_weights).sum(axis=1)
    return per_sample.tolist()

# Per-sample weights for the training split; power=2.0 strongly favours rare classes
sample_weights = compute_sample_weights(train_meta, config.TARGET_COLS, power=2.0)

# Sampler that draws (with replacement) as many indices as there are training samples
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# Training DataLoader driven by the weighted sampler, loading in 4 worker processes
train_loader = DataLoader(
    train_ds,
    sampler=sampler,
    batch_size=config.BATCH_SIZE,
    num_workers=4,
)

# Draw one pass of indices and look at which labels they carry
sample_indices = list(sampler)
sampled_labels = [train_meta[i]["labels"] for i in sample_indices]

# Number of drawn samples that are positive for each label
counts = Counter(
    label_name
    for label_dict in sampled_labels
    for label_name, label_value in label_dict.items()
    if label_value == 1
)

print("Class distribution after weighted sampling:")
print(dict(counts))

# Per-target class counts of the training set itself (before sampling)
converted_labels = [convert_labels_to_targets(entry['labels']) for entry in train_ds.metadata_list]

bowel_dist, extra_dist, kidney_dist, liver_dist, spleen_dist = (
    Counter([targets[target_name] for targets in converted_labels])
    for target_name in ('bowel', 'extra', 'kidney', 'liver', 'spleen')
)

print("\nTraining Set Label Distribution:")
print("Bowel:", dict(bowel_dist))
print("Extravasation:", dict(extra_dist))
print("Kidney:", dict(kidney_dist))
print("Liver:", dict(liver_dist))
print("Spleen:", dict(spleen_dist))

## Model Architecture Implementation

In [None]:
class DenseNet121model(nn.Module):
    """3D DenseNet121 backbone with one classification head per target.

    Features:
    - MONAI's 3D DenseNet121 backbone producing a 512-dimensional feature vector
    - Separate heads for each prediction task (binary and 3-class); all heads
      return raw logits
    - Hooks that record the activations of the last backbone feature layer and
      their gradient, for Grad-CAM style visualization
    """

    def __init__(self, in_channels=1, pretrained=False):
        """Initialize the model architecture.

        Args:
            in_channels: Number of input channels (1 for grayscale CT scans)
            pretrained: Forwarded to MONAI's DenseNet121.
                NOTE(review): confirm pretrained weights exist for spatial_dims=3.
        """
        super().__init__()

        # State captured by the hooks below
        self.activations = None  # output of the last backbone feature layer
        self.gradients = None    # gradient of the loss w.r.t. that output

        self.backbone = DenseNet121(
            spatial_dims=3,       # 3D version for volumetric data
            in_channels=in_channels,
            out_channels=512,     # size of the feature vector fed to the heads
            pretrained=pretrained
        )

        # Capture the output of the last module in `backbone.features` on every forward pass
        self.backbone.features[-1].register_forward_hook(self.save_activation)

        # Task-specific heads (creation order kept stable so parameter initialisation is unchanged)
        self.bowel_head = self._create_binary_head()      # 1 logit
        self.extra_head = self._create_binary_head()      # 1 logit
        self.liver_head = self._create_multiclass_head()  # 3 logits
        self.kidney_head = self._create_multiclass_head() # 3 logits
        self.spleen_head = self._create_multiclass_head() # 3 logits

    @staticmethod
    def _create_head(out_features):
        """Build a 512 -> 256 -> `out_features` head that outputs raw logits."""
        return nn.Sequential(
            nn.Linear(512, 256),          # FC layer
            nn.BatchNorm1d(256),          # Batch normalization
            nn.SiLU(),                    # Swish activation
            nn.Dropout(0.3),              # Regularization
            nn.Linear(256, out_features)  # Final logits
        )

    def _create_binary_head(self):
        """Create a binary classification head (1 output logit; no sigmoid is applied here)."""
        return DenseNet121model._create_head(1)

    def _create_multiclass_head(self):
        """Create a 3-class classification head (3 output logits; no softmax is applied here)."""
        return DenseNet121model._create_head(3)

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output and, if it tracks gradients, hook its gradient."""
        self.activations = output
        if output.requires_grad:
            output.register_hook(self.save_gradient)

    def save_gradient(self, grad):
        """Tensor hook: store the gradient flowing into the hooked activations."""
        self.gradients = grad

    def forward(self, x):
        """Forward pass through the network.

        Args:
            x: Input tensor (batch of 3D CT scans)

        Returns:
            Dictionary with the logits of every head, keyed by
            "bowel", "extra", "liver", "kidney" and "spleen".
        """
        # No hook is attached to the input here: a gradient hook on `x` fires after
        # the feature-layer hook during backward and would replace the stored
        # feature-map gradient with the input gradient, which no longer matches
        # `self.activations`. The backbone forward hook records everything needed.
        features = self.backbone(x)  # [batch_size, 512]

        return {
            "bowel": self.bowel_head(features),
            "extra": self.extra_head(features),
            "liver": self.liver_head(features),
            "kidney": self.kidney_head(features),
            "spleen": self.spleen_head(features)
        }

    def get_activations_gradient(self):
        """Return the most recently stored gradient (None before any backward pass)."""
        return self.gradients

    def get_activations(self):
        """Return the most recently stored activations (None before any forward pass)."""
        return self.activations

In [None]:
# Build the network, then place its parameters on the selected device
model = DenseNet121model()
model = model.to(device)
first_param = next(model.parameters())
print(f"Model initialized and moved to device: {first_param.device}")

## Custom Loss Functions Implementation

In [None]:
class BinaryFocalLoss(nn.Module):
    """Binary focal loss computed directly from logits.

    The element-wise BCE-with-logits term is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the probability the model assigns to the true label, so
    confidently-correct examples contribute little to the total.
    """

    def __init__(self, pos_weight=None, gamma=2.0, reduction='mean'):
        """
        Args:
            pos_weight: Optional tensor forwarded to
                ``F.binary_cross_entropy_with_logits`` as the positive-class weight.
            gamma: Exponent of the modulating factor; 0 disables the focal term.
            reduction: ``'mean'`` or ``'sum'``; any other value (e.g. ``None``)
                returns the unreduced element-wise loss.
        """
        super().__init__()
        self.pos_weight = pos_weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for a batch.

        Args:
            inputs: Logits (no sigmoid applied).
            targets: Float labels with the same shape as ``inputs`` (0 or 1).
        """
        # Per-element BCE; reduction is deferred until after focal scaling.
        elementwise_bce = F.binary_cross_entropy_with_logits(
            inputs,
            targets,
            pos_weight=self.pos_weight,
            reduction='none',
        )

        # p_t: probability mass placed on the correct label.
        prob_positive = torch.sigmoid(inputs)
        prob_true = prob_positive * targets + (1 - prob_positive) * (1 - targets)

        # Down-weight easy examples.
        modulator = torch.pow(1 - prob_true, self.gamma)
        loss = modulator * elementwise_bce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

In [None]:
class MultiClassFocalLoss(nn.Module):
    """Multi-class focal loss computed directly from logits.

    The per-sample cross entropy (optionally class-weighted) is scaled by
    ``(1 - p_t) ** gamma``, where ``p_t`` is the softmax probability of the
    true class.
    """

    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
        """
        Args:
            weight: Optional per-class weight tensor forwarded to ``F.cross_entropy``.
            gamma: Exponent of the modulating factor.
            reduction: ``'mean'`` or ``'sum'``; any other value (e.g. ``None``)
                returns the unreduced per-sample loss. Note that ``'mean'`` is a
                plain mean over samples, not a weight-normalised mean.
        """
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for a batch.

        Args:
            inputs: Logits with classes along dim 1 (no softmax applied).
            targets: Integer class indices.
        """
        # Per-sample cross entropy; reduction is deferred.
        per_sample_ce = F.cross_entropy(
            inputs,
            targets,
            weight=self.weight,
            reduction='none',
        )

        # Softmax probability assigned to each sample's true class.
        class_probs = F.softmax(inputs, dim=1)
        true_index = targets.unsqueeze(1)
        prob_true = torch.gather(class_probs, 1, true_index).squeeze(1)

        # Down-weight samples the model already classifies confidently.
        modulator = torch.pow(1 - prob_true, self.gamma)
        loss = modulator * per_sample_ce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

### Label Smoothing Focal Loss Implementation

In [None]:
class LabelSmoothingFocalLoss(nn.Module):
    """Focal loss evaluated against a label-smoothed target distribution.

    Each class contributes ``-q_c * (1 - p_c) ** gamma * w_c * log p_c`` where
    ``q`` is the smoothed one-hot target, ``p`` the softmax output and ``w`` the
    optional class weight; contributions are summed over classes.
    """

    def __init__(self, gamma=2.0, smoothing=0.1, weight=None, reduction='mean'):
        """
        Args:
            gamma: Exponent of the focal modulating factor.
            smoothing: Probability mass moved from the true class and spread
                evenly over the remaining ``num_classes - 1`` classes.
            weight: Optional per-class weight tensor of shape ``(C,)``.
            reduction: ``'mean'`` or ``'sum'``; any other value (e.g. ``None``)
                returns the per-sample loss.
        """
        super(LabelSmoothingFocalLoss, self).__init__()
        self.gamma = gamma
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the label-smoothed focal loss.

        Args:
            inputs: Logits with classes along dim 1.
            targets: Integer class indices, one per row of ``inputs``.
        """
        n_classes = inputs.size(1)

        # Target distribution: (1 - smoothing) on the true class, the rest
        # shared equally by the other classes. Built outside the graph.
        with torch.no_grad():
            off_value = self.smoothing / (n_classes - 1)
            smoothed = torch.full_like(inputs, off_value)
            smoothed.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        log_p = F.log_softmax(inputs, dim=1)
        p = torch.exp(log_p)

        # Per-class focal modulation, computed from the unweighted probabilities.
        modulator = (1 - p).pow(self.gamma)

        # Class weights scale the log-probabilities (broadcast over the batch).
        if self.weight is not None:
            log_p = log_p * self.weight.unsqueeze(0)

        per_sample = (-smoothed * modulator * log_p).sum(dim=1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample

## Class Weight Calculation

In [None]:
# Class counts from oversampled training data.
# Keys follow the "<target>_<class>" pattern read by
# compute_class_weights_from_counts: binary targets use "_healthy"/"_injury",
# three-class organs use "_healthy"/"_low"/"_high".
# NOTE(review): these values are hard-coded; presumably they must be refreshed
# whenever the oversampling step changes - confirm against the data pipeline.
oversampled_counts = {
    'extravasation_healthy': 1396,
    'extravasation_injury': 447,
    'bowel_healthy': 1526,
    'bowel_injury': 317,
    'kidney_healthy': 1173,
    'kidney_low': 334,
    'kidney_high': 336,
    'liver_healthy': 1167,
    'liver_low': 349,
    'liver_high': 327,
    'spleen_healthy': 1180,
    'spleen_low': 331,
    'spleen_high': 332
}

def compute_class_weights_from_counts(counts_dict, device='cpu', pos_weight_cap=10.0, balance_threshold=1.5):
    """Derive loss weights for every head from raw class counts.

    Args:
        counts_dict: Mapping of ``"<target>_<class>"`` to sample count. The
            binary entries (``bowel_*``, ``extravasation_*``) are required; the
            organ entries default to 0 when absent.
        device: Device on which the weight tensors are created.
        pos_weight_cap: Upper bound for a binary head's positive-class weight.
        balance_threshold: A negative/positive ratio within
            ``[1 / balance_threshold, balance_threshold]`` counts as balanced
            and yields a weight of 1.0.

    Returns:
        (bce_weights, ce_weights):
            bce_weights maps ``"bowel"`` / ``"extra"`` to a scalar float32 tensor;
            ce_weights maps ``"kidney"`` / ``"liver"`` / ``"spleen"`` to a
            3-element float32 tensor of inverse-frequency weights summing to 1.
    """
    floor = 1e-6  # keeps every division finite when a count is zero

    # counts_dict prefix -> key used for the corresponding model head
    binary_heads = {"bowel": "bowel", "extravasation": "extra"}

    bce_weights = {}
    for prefix, head in binary_heads.items():
        n_pos = max(counts_dict[f"{prefix}_injury"], floor)
        n_neg = max(counts_dict[f"{prefix}_healthy"], floor)
        neg_per_pos = n_neg / n_pos

        # Only re-weight when the imbalance is noticeable; cap extreme ratios.
        nearly_balanced = 1 / balance_threshold <= neg_per_pos <= balance_threshold
        value = 1.0 if nearly_balanced else min(neg_per_pos, pos_weight_cap)

        bce_weights[head] = torch.tensor(value, dtype=torch.float32, device=device)

    ce_weights = {}
    for organ in ("kidney", "liver", "spleen"):
        per_class = [counts_dict.get(f"{organ}_{level}", 0) for level in ("healthy", "low", "high")]

        # Normalised inverse frequency: rarer classes receive larger weights.
        freq = torch.clamp(
            torch.tensor(per_class, dtype=torch.float32, device=device),
            min=1e-6,
        )
        inverse = 1.0 / freq
        ce_weights[organ] = inverse / inverse.sum()

    return bce_weights, ce_weights

## Loss Function and Optimizer Setup

In [None]:
# Calculate class weights from oversampled counts; the tensors are created
# directly on `device` so they can be used by the losses without a transfer.
bce_weights, ce_weights = compute_class_weights_from_counts(
    oversampled_counts, 
    device=device
)

# Initialize loss functions for each prediction head.
# Binary heads (bowel, extra): weighted BCE on logits.
# Three-class heads (kidney, liver, spleen): label-smoothed focal loss with
# inverse-frequency class weights.
loss_fn_dict = {
    "bowel": nn.BCEWithLogitsLoss(pos_weight=bce_weights["bowel"]),
    "extra": nn.BCEWithLogitsLoss(pos_weight=bce_weights["extra"]),
    "kidney": LabelSmoothingFocalLoss(
        weight=ce_weights["kidney"], 
        gamma=1.5, 
        smoothing=0.07
    ),
    "liver": LabelSmoothingFocalLoss(
        weight=ce_weights["liver"], 
        gamma=1.5, 
        smoothing=0.07
    ),
    "spleen": LabelSmoothingFocalLoss(
        weight=ce_weights["spleen"], 
        gamma=1.5, 
        smoothing=0.07
    ),
}

# Initialize optimizer with weight decay
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=1e-4, 
    weight_decay=5e-6  # Decoupled weight decay (AdamW), not an L2 term added to the loss
)

# Learning rate scheduler with warm restarts.
# The epoch-based comments below hold because train() calls scheduler.step()
# once per epoch.
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=15,       # Number of epochs before first restart
    T_mult=2,     # Period multiplier after each restart
    eta_min=1e-6  # Minimum learning rate
)

## Training Utilities and Configuration

In [None]:
def compute_head_weights_from_losses(avg_losses, head_priority=None, normalize=True, min_clip=1e-6):
    """
    Turn per-head average losses into multi-task weights.

    Each head's raw weight is ``priority / max(loss, min_clip)``; a head with no
    entry in ``head_priority`` has priority 1.0.

    Args:
        avg_losses: Mapping head name -> average loss.
        head_priority: Optional mapping head name -> priority multiplier.
        normalize: When True, rescale the weights so that their mean is 1.0.
        min_clip: Lower bound applied to each loss before inversion.

    Returns:
        Mapping head name -> weight (same keys and order as ``avg_losses``).
    """
    priorities = head_priority if head_priority else {}

    raw_weights = {}
    for head, head_loss in avg_losses.items():
        raw_weights[head] = (1.0 / max(head_loss, min_clip)) * priorities.get(head, 1.0)

    if not normalize:
        print('Using priority-weighted task weights')
        return raw_weights

    mean_inv = sum(raw_weights.values()) / len(raw_weights)
    print('Using normalized task weights')
    return {head: weight / mean_inv for head, weight in raw_weights.items()}

## Training Loop Implementation

In [None]:
def train_one_epoch(model, loader, optimizer, loss_fn_dict, scheduler=None, 
                   grad_clip=None, debug=False, task_weights=None):
    """
    Train the model for one epoch.

    Args:
        model: Multi-head model returning a dict of logits keyed by head name.
        loader: DataLoader yielding dicts with "image" and "label" entries;
            label columns are read as [bowel, extra, kidney, liver, spleen].
        optimizer: Optimizer stepped once per batch.
        loss_fn_dict: Mapping head name -> loss function.
        scheduler: Accepted for interface compatibility; it is NOT stepped here
            (the caller steps it once per epoch).
        grad_clip: If truthy, maximum gradient norm for clipping.
        debug: If True, print sample logits/targets every 20 batches.
        task_weights: Optional mapping head name -> multiplier applied to that
            head's loss in the optimised total. Missing heads (or None) use 1.0.

    Returns:
        (avg_loss, task_losses): the running average of the (weighted) total
        loss, and a dict of per-head average *unweighted* losses.
    """
    model.train()
    task_weights = task_weights or {}
    running_loss = 0.0
    avg_loss = 0.0  # stays defined even if the loader is empty or every batch is skipped
    task_losses = defaultdict(float)
    pbar = tqdm(enumerate(loader), total=len(loader), desc="Training", leave=False)

    for batch_idx, batch in pbar:
        # Prepare batch data
        inputs = batch["image"].to(device, dtype=torch.float32)
        labels = batch["label"].to(device, dtype=torch.float32)

        # Add channel dimension if needed
        if inputs.ndim == 4:
            inputs = inputs.unsqueeze(1)

        # Forward pass
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)

        # Prepare targets for each task
        targets = {
            "bowel": labels[:, 0].float(),   # Binary label
            "extra": labels[:, 1].float(),   # Binary label
            "kidney": labels[:, 2].long(),   # Class index: 0,1,2
            "liver": labels[:, 3].long(),
            "spleen": labels[:, 4].long(),
        }

        loss = torch.tensor(0.0, device=device, dtype=torch.float32)

        # Compute loss for each task
        for key in outputs:
            pred = outputs[key]
            target = targets[key]

            # Best-effort: a failing head is reported and skipped so the other
            # heads can still contribute to this batch.
            try:
                if key in ["bowel", "extra"]:
                    # Binary classification tasks
                    pred = pred.squeeze(-1) if pred.ndim > 1 else pred
                    task_loss = loss_fn_dict[key](pred, target)
                else:
                    # Multi-class classification tasks
                    task_loss = loss_fn_dict[key](pred, target.long())
            except Exception as e:
                print(f"Error in loss for {key} @ batch {batch_idx}: {e}")
                print(f"   Pred shape: {pred.shape}, Target shape: {target.shape}")
                continue

            if torch.isnan(task_loss).any():
                print(f"NaN in {key} loss @ batch {batch_idx}")
                continue

            # Log the raw loss, but optimise the task-weighted one.
            task_losses[key] += task_loss.item()
            loss = loss + task_weights.get(key, 1.0) * task_loss

        # Debug output every 20 batches (once per batch, after all heads ran)
        if debug and batch_idx % 20 == 0:
            print(f"\n[Debug Batch {batch_idx}]")
            for head in outputs:
                pred_logits = outputs[head].detach().cpu().numpy()
                target_vals = targets[head].cpu().numpy()

                if head in ["bowel", "extra"]:
                    print(f"[{head}]")
                    print(f"   Pred logits: {np.round(pred_logits[:3].squeeze(), 4)}")
                    print(f"   Target:      {np.round(target_vals[:3], 4)}")
                else:
                    print(f"[{head}]")
                    print(f"   Pred logits:\n{np.round(pred_logits[:3], 4)}")
                    print(f"   Target:      {target_vals[:3]}")

        # Skip batch if NaN loss
        if torch.isnan(loss).any():
            print(f"Skipping NaN total loss @ batch {batch_idx}")
            continue

        # Every head was skipped: nothing is attached to the graph, so
        # backward() would raise.
        if not loss.requires_grad:
            print(f"Skipping batch {batch_idx}: no task produced a differentiable loss")
            continue

        # Backward pass and optimization
        loss.backward()
        if grad_clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Update running loss
        running_loss += loss.item()
        avg_loss = running_loss / (batch_idx + 1)
        pbar.set_postfix({"loss": avg_loss})

    # Calculate average task losses
    task_losses = {k: v / len(loader) for k, v in task_losses.items()}

    # Print training statistics
    print("Organ-wise Training Losses:")
    for organ, organ_loss in task_losses.items():
        print(f"   {organ}: {organ_loss:.4f}")

    print(f'Train loss: {avg_loss}')

    # Print current learning rate
    for param_group in optimizer.param_groups:
        print(f"Current LR: {param_group['lr']:.6f}")

    return avg_loss, task_losses

## Validation and Evaluation

In [None]:
@torch.no_grad()
def validate(model, loader, loss_fn_dict, debug=False, thresholds=None, tta_fns=None):
    """
    Evaluate the model on a validation loader and compute per-head metrics.

    Args:
        model: Multi-head model returning a dict of logits keyed by head name.
        loader: DataLoader yielding dicts with "image" and "label" entries;
            label columns are read as [bowel, extra, kidney, liver, spleen].
        loss_fn_dict: Mapping head name -> loss function.
        debug: Currently unused; kept for interface compatibility.
        thresholds: Decision thresholds. Accepted per-head forms:
            * ``{"thresholds": [t0, t1, ...]}`` - one threshold per class;
            * a list/tuple ``[t0, t1, ...]``;
            * a scalar ``t`` - applied to every non-healthy class, with the
              healthy class (index 0) always eligible.
            Heads missing from the mapping use the built-in defaults; ``None``
            uses the defaults for every head.
        tta_fns: Optional list of callables applied to the input batch; logits
            are averaged over them.

    Returns:
        (val_loss, history_metrics, val_task_losses, metrics):
            val_loss: summed head losses averaged over batches;
            history_metrics: per-head scalar metrics (macro-averaged for the
                three-class heads);
            val_task_losses: per-head loss averaged over batches;
            metrics: detailed per-head metrics including confusion matrices.
    """
    model.eval()
    val_loss = 0.0
    val_task_losses = defaultdict(float)
    pbar = tqdm(loader, desc="Validation", leave=False)

    # Initialize storage for predictions and targets
    all_preds = defaultdict(list)
    all_probs = defaultdict(list)
    all_targets = defaultdict(list)

    binary_heads = ("bowel", "extra")

    default_thresholds = {
        "bowel": {"thresholds": [0.01, 0.45], "best_f1": [0.9823, 0.7059]},
        "extra": {"thresholds": [0.02, 0.84], "best_f1": [0.9493, 0.5607]},
        "kidney": {"thresholds": [0.06, 0.79, 0.67], "best_f1": [0.9503, 0.6154, 0.6667]},
        "liver": {"thresholds": [0.13, 0.60, 0.75], "best_f1": [0.9054, 0.5077, 0.6400]},
        "spleen": {"thresholds": [0.14, 0.57, 0.56], "best_f1": [0.8905, 0.5660, 0.4938]}
    }

    # Normalise user thresholds into {"thresholds": [per-class values]} per head.
    # A bare scalar must expand to one value per class, otherwise the per-class
    # indexing below would run past the end of the list.
    if thresholds is None:
        thresholds = default_thresholds
    else:
        resolved = dict(default_thresholds)
        for head, spec in thresholds.items():
            if isinstance(spec, dict):
                resolved[head] = spec
            elif isinstance(spec, (list, tuple)):
                resolved[head] = {"thresholds": list(spec)}
            elif head in binary_heads:
                resolved[head] = {"thresholds": [0.0, spec]}
            else:
                resolved[head] = {"thresholds": [0.0, spec, spec]}
        thresholds = resolved

    def tta_forward_batch(model, inputs, tta_fns):
        """Apply test-time augmentation and average predictions."""
        if not tta_fns:
            return model(inputs)
        logits_per_tta = []
        for fn in tta_fns:
            aug_inp = fn(inputs)
            logits_per_tta.append(model(aug_inp))
        # Average predictions across augmentations
        avg_logits = {}
        for k in logits_per_tta[0].keys():
            avg_logits[k] = torch.stack([d[k] for d in logits_per_tta], dim=0).mean(dim=0)
        return avg_logits

    for batch_idx, batch in enumerate(pbar):
        # Prepare batch data
        inputs = batch["image"].to(device, dtype=torch.float32)
        labels = batch["label"].to(device, dtype=torch.float32)

        if inputs.ndim == 4:
            inputs = inputs.unsqueeze(1)

        # Forward pass with optional TTA
        outputs = tta_forward_batch(model, inputs, tta_fns)

        # Prepare targets for each task
        targets = {
            "bowel": labels[:, 0].float(),
            "extra": labels[:, 1].float(),
            "kidney": labels[:, 2].long(),
            "liver": labels[:, 3].long(),
            "spleen": labels[:, 4].long(),
        }

        batch_loss = 0.0

        # Process each task separately
        for key in outputs:
            pred = outputs[key]
            target = targets[key]

            # Binary classification tasks
            if key in binary_heads:
                # reshape(-1) handles both (B, 1) and (B,) logits
                pos_probs = torch.sigmoid(pred).reshape(-1).cpu().numpy()
                all_probs[key].extend(pos_probs)  # Store positive class probabilities

                # Predict class 1 iff its probability reaches the class-1
                # threshold; class 0 is the fallback, so its threshold has no effect.
                pos_threshold = thresholds[key]["thresholds"][1]
                class_preds = (pos_probs >= pos_threshold).astype(int)

                target_np = target.cpu().numpy().astype(int)
                all_preds[key].extend(class_preds)
                all_targets[key].extend(target_np)

                loss = loss_fn_dict[key](pred.view(-1), target.float().view(-1))

            # Multi-class classification tasks
            else:
                probs = F.softmax(pred, dim=1).cpu().numpy()
                all_probs[key].extend(probs)

                # Among classes whose probability reaches their threshold,
                # pick the most probable one.
                final_preds = np.zeros(probs.shape[0], dtype=int)
                for i in range(probs.shape[0]):
                    valid_classes = [c for c in range(probs.shape[1])
                                     if probs[i, c] >= thresholds[key]["thresholds"][c]]

                    if valid_classes:
                        final_preds[i] = valid_classes[np.argmax(probs[i, valid_classes])]
                    else:
                        final_preds[i] = 0  # Default to healthy class

                target_np = target.cpu().numpy()
                all_preds[key].extend(final_preds)
                all_targets[key].extend(target_np)

                loss = loss_fn_dict[key](pred, target)

            val_task_losses[key] += loss.item()
            batch_loss += loss.item()

        # Accumulate per batch (inside the loop) so every batch is counted
        val_loss += batch_loss
        pbar.set_postfix({"val_loss": val_loss / (batch_idx + 1)})

    # Calculate comprehensive metrics
    metrics = {}
    history_metrics = {}  # Simplified metrics for training history

    for organ in all_preds:
        y_true = np.array(all_targets[organ])
        y_pred = np.array(all_preds[organ])
        y_probs = np.array(all_probs[organ])

        # Binary classification metrics
        if organ in binary_heads:
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average='binary', zero_division=0
            )
            accuracy = accuracy_score(y_true, y_pred)
            try:
                roc_auc = roc_auc_score(y_true, y_probs)
            except ValueError:
                # Raised when only one class is present in y_true
                roc_auc = 0.0

            # Fix the label set so the matrix is always 2x2, even when only one
            # class appears in y_true/y_pred.
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

            metrics[organ] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'roc_auc': roc_auc,
                'accuracy': accuracy,
                'confusion_matrix': {
                    'true_negative': tn,
                    'false_positive': fp,
                    'false_negative': fn,
                    'true_positive': tp
                }
            }

            history_metrics[organ] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'roc_auc': roc_auc,
                'accuracy': accuracy
            }

        # Multi-class classification metrics
        else:
            # Per-class metrics
            precision, recall, f1, support = precision_recall_fscore_support(
                y_true, y_pred, average=None, zero_division=0
            )

            # Macro averages
            macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average='macro', zero_division=0
            )

            accuracy = accuracy_score(y_true, y_pred)

            # Multi-class ROC AUC
            try:
                if y_probs.ndim == 2:
                    roc_auc = roc_auc_score(y_true, y_probs, multi_class='ovr')
                else:
                    roc_auc = 0.0
            except ValueError:
                roc_auc = 0.0

            # Full confusion matrix
            cm = confusion_matrix(y_true, y_pred)

            metrics[organ] = {
                'precision': precision.tolist(),
                'recall': recall.tolist(),
                'f1': f1.tolist(),
                'support': support.tolist(),
                'macro_precision': macro_precision,
                'macro_recall': macro_recall,
                'macro_f1': macro_f1,
                'accuracy': accuracy,
                'roc_auc': roc_auc,
                'confusion_matrix': cm.tolist()
            }

            history_metrics[organ] = {
                'precision': macro_precision,
                'recall': macro_recall,
                'f1': macro_f1,
                'roc_auc': roc_auc,
                'accuracy': accuracy
            }

    # Guard against an empty loader
    num_batches = max(len(loader), 1)
    val_task_losses = {k: v / num_batches for k, v in val_task_losses.items()}
    return val_loss / num_batches, history_metrics, val_task_losses, metrics

## Model Checkpoint Management

In [None]:
# Initialize the Kaggle API client used by upload_to_kaggle_model below.
# authenticate() must succeed before any dataset operation is attempted.
# NOTE(review): presumably credentials come from the environment or a
# kaggle.json file - confirm; never hard-code them in the notebook.
api = KaggleApi()
api.authenticate()

def upload_to_kaggle_model(dataset_owner, dataset_slug, model_path, checkpoint_path=None, version_note=""):
    """
    Publish model artifacts as a new version of a Kaggle dataset.

    The given files are zipped into ``/kaggle/working/model_upload.zip``, a
    ``dataset-metadata.json`` is written next to it, and the whole
    ``/kaggle/working`` folder is pushed through the module-level ``api`` client.

    Args:
        dataset_owner: Kaggle user name owning the target dataset.
        dataset_slug: Slug of the target dataset.
        model_path: Path of the model file to include in the archive.
        checkpoint_path: Optional second file to include in the archive.
        version_note: Version notes attached to the new dataset version.
    """
    import json

    working_dir = "/kaggle/working"
    zip_path = "/kaggle/working/model_upload.zip"
    metadata_path = "/kaggle/working/dataset-metadata.json"

    # Files to archive, stored under their base names
    artifacts = [model_path]
    if checkpoint_path:
        artifacts.append(checkpoint_path)

    with zipfile.ZipFile(zip_path, 'w') as archive:
        for artifact in artifacts:
            archive.write(artifact, arcname=os.path.basename(artifact))

    print(f"Zipped model(s) to: {zip_path}")

    # Metadata file required alongside the uploaded folder
    metadata = {
        "title": f"{dataset_slug} model",
        "id": f"{dataset_owner}/{dataset_slug}",
        "licenses": [{"name": "CC0-1.0"}]
    }
    with open(metadata_path, "w") as handle:
        json.dump(metadata, handle, indent=2)

    print(f"Created metadata file at {metadata_path}")

    # Push everything in the working folder as a new dataset version
    api.dataset_create_version(
        folder=working_dir,
        version_notes=version_note,
        delete_old_versions=False,
        convert_to_csv=False
    )
    print(f"Uploaded to Kaggle Dataset: {dataset_owner}/{dataset_slug}")

def extract_checkpoint_from_dataset(dataset_input_path="/kaggle/input/bestest-dataset",
                                  extract_path="/kaggle/working"):
    """
    Copy saved model files from a (read-only) dataset folder to a working folder.

    Looks for ``checkpoint.pth`` and ``model_best.pth`` directly inside
    ``dataset_input_path`` and copies each one that exists into
    ``extract_path``, creating that directory first if necessary.

    Args:
        dataset_input_path: Folder to search for the two files.
        extract_path: Destination folder for the copies.

    Returns:
        ``extract_path`` if at least one file was copied, otherwise ``None``.
    """
    copied = False

    for filename in ("checkpoint.pth", "model_best.pth"):
        source = os.path.join(dataset_input_path, filename)
        if not os.path.exists(source):
            print(f"{filename} not found in dataset input.")
            continue

        # Create the destination lazily so nothing is created when there is
        # nothing to copy.
        os.makedirs(extract_path, exist_ok=True)
        shutil.copy(source, os.path.join(extract_path, filename))
        print(f"Copied {filename} to {extract_path}")
        copied = True

    return extract_path if copied else None

## Main Training Function

In [None]:
def train(model, train_loader, val_loader, optimizer, scheduler, loss_fn_dict, num_epochs,
         save_dir="/kaggle/working", resume=False,
         upload_to_kaggle=False, dataset_owner=None, dataset_slug=None, 
         hyperparam_note="", grad_clip=None,
         custom_thresholds={"bowel": 0.76, "extra": 0.78}):
    """
    Main training loop with checkpointing and optional Kaggle uploads.

    Each epoch: derive task weights from the previous epoch's validation
    losses, train, validate, step the scheduler, record history, write a
    checkpoint, and save/upload the best model when validation loss improves.

    Args:
        model: Model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        optimizer: Optimization algorithm
        scheduler: Learning rate scheduler (stepped once per epoch)
        loss_fn_dict: Dictionary of loss functions per task
        num_epochs: Total number of training epochs
        save_dir: Directory to save checkpoints
        resume: Whether to resume from checkpoint
        upload_to_kaggle: Whether to upload to Kaggle
        dataset_owner: Kaggle dataset owner
        dataset_slug: Kaggle dataset name
        hyperparam_note: Notes about hyperparameters
        grad_clip: Gradient clipping value
        custom_thresholds: Decision thresholds forwarded to validate().
            The default dict is never mutated by this function.

    Returns:
        Trained model and training history
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float('inf')
    start_epoch = 0
    prev_val_task_losses = None

    # Task priority weights. Defined unconditionally: they are needed from the
    # second epoch onward whether or not training was resumed.
    head_priority = {
        "bowel": 0.9,    # More balanced
        "extra": 1.25,   # Slightly less balanced
        "kidney": 1.1,   # More imbalanced
        "liver": 1,      # Medium imbalance
        "spleen": 1.2    # Most imbalanced
    }

    # Paths for model checkpoints
    checkpoint_path = os.path.join(save_dir, "checkpoint.pth")
    best_model_path = os.path.join(save_dir, "model_best.pth")

    # Initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'metrics': {
            organ: {
                'precision': [], 'recall': [], 'f1': [], 'roc_auc': [], 'accuracy': [],
                'train_loss': [], 'val_loss': []
            }
            for organ in ['bowel', 'extra', 'kidney', 'liver', 'spleen']
        },
        'task_weights': {organ: [] for organ in ['bowel', 'extra', 'kidney', 'liver', 'spleen']}
    }

    # Resume from checkpoint if requested
    if resume:
        extracted_path = extract_checkpoint_from_dataset()
        if extracted_path:
            for file in ["checkpoint.pth", "model_best.pth"]:
                src = os.path.join(extracted_path, file)
                dst = os.path.join(save_dir, file)
                if os.path.exists(src):
                    os.replace(src, dst)
                    print(f"Copied {file} to {dst}")

    if resume and os.path.exists(checkpoint_path):
        # NOTE: weights_only=False unpickles arbitrary Python objects (needed
        # for the stored history); only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        history = checkpoint.get('history', history)
        # Keep the restored value: it seeds the task weights of the first
        # resumed epoch.
        prev_val_task_losses = checkpoint.get('prev_val_task_losses', None)
        print(f"Resumed from epoch {start_epoch}")

    # Main training loop
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch [{epoch + 1}/{num_epochs}]")

        # Compute dynamic task weights based on previous losses
        task_weights = {k: 1.0 for k in loss_fn_dict} if prev_val_task_losses is None \
            else compute_head_weights_from_losses(prev_val_task_losses, head_priority)

        # Train for one epoch
        train_loss, train_task_losses = train_one_epoch(
            model, train_loader, optimizer, loss_fn_dict,
            grad_clip=grad_clip,
            task_weights=task_weights
        )

        # Validate after training (no test-time augmentation here)
        val_loss, metrics, val_task_losses, detailed_metrics = validate(
            model, val_loader, loss_fn_dict, debug=False,
            thresholds=custom_thresholds, tta_fns=None
        )

        prev_val_task_losses = val_task_losses  # Save for next epoch

        print("Dynamic Task Weights:")
        for organ, weight in task_weights.items():
            print(f"   {organ}: {weight:.4f}")

        # Update learning rate
        scheduler.step()

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        for organ in metrics:
            for metric in metrics[organ]:
                history['metrics'][organ][metric].append(metrics[organ][metric])
            history['metrics'][organ]['train_loss'].append(train_task_losses.get(organ, None))
            history['metrics'][organ]['val_loss'].append(val_task_losses.get(organ, None))

        for organ in history['task_weights']:
            history['task_weights'][organ].append(task_weights.get(organ, None))

        # Decide on improvement BEFORE writing the checkpoint so that the
        # stored best_val_loss already reflects this epoch.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'history': history,
            'prev_val_task_losses': prev_val_task_losses,
        }, checkpoint_path)

        # Save best model if improved
        if improved:
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved! Val Loss = {val_loss:.4f}")

            if upload_to_kaggle and dataset_owner and dataset_slug:
                note = f"{hyperparam_note} | Epoch {epoch + 1}, Val Loss: {val_loss:.4f}"
                upload_to_kaggle_model(dataset_owner, dataset_slug, 
                                     best_model_path, checkpoint_path, 
                                     version_note=note)
        else:
            print(f"No improvement. Val Loss = {val_loss:.4f}")
            if upload_to_kaggle and dataset_owner and dataset_slug:
                note = f"{hyperparam_note} | Epoch {epoch + 1}, No improvement"
                upload_to_kaggle_model(dataset_owner, dataset_slug, 
                                     model_path=checkpoint_path, 
                                     version_note=note)

    return model, history

## Training Configuration and Execution

In [None]:
# Dataset information for model saving (target of the Kaggle uploads in train()).
dataset_owner = "anusapkota"
dataset_slug = "abdominal-trauma-detection-final-model-dataset"
# NOTE(review): dataset_id is not referenced by the code visible in this
# section - confirm whether a later cell uses it.
dataset_id = f"{dataset_owner}/{dataset_slug}"

# Custom prediction thresholds, one per class and head (index 0 = healthy).
# A threshold of 0.0 means the healthy class is always an eligible prediction.
custom_thresholds = {
    "bowel": {"thresholds": [0.0, 0.6]},
    "extra": {"thresholds": [0.0, 0.6]},
    "kidney": {"thresholds": [0.0, 0.45, 0.66]},
    "liver": {"thresholds": [0.0, 0.45, 0.66]},
    "spleen": {"thresholds": [0.0, 0.50, 0.66]}
}

# Main training execution; `config` is defined in an earlier cell.
NUM_EPOCHS = config.EPOCHS
save_dir = '/kaggle/working/'

# Hyperparameter notes for logging (becomes part of the Kaggle version note)
hyperparam_note = (
    f"lr={config.LR}, bs={config.BATCH_SIZE}, "
    f"image_size={config.IMAGE_SIZE}, "
    "focal_loss: fine tuned subset = 30%"
)

# Check GPU availability
print("Is CUDA available?", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

# Start training.
# resume=True: continue from checkpoint.pth when one can be found.
# upload_to_kaggle=True: push a new dataset version after every epoch.
# `model` and `history` are rebound to the values returned by train().
model, history = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    loss_fn_dict,
    NUM_EPOCHS,
    save_dir,
    resume=True,
    upload_to_kaggle=True,
    dataset_owner=dataset_owner,
    dataset_slug=dataset_slug,
    hyperparam_note=hyperparam_note, 
    custom_thresholds=custom_thresholds
)

## Model Loading

In [None]:
def load_model(model_path):
    """Build a ``DenseNet121model`` and fill it with weights from a file.

    The file may hold either a bare state dict or a checkpoint dict that
    stores the weights under the ``'model_state_dict'`` key.

    Args:
        model_path (str): Path to the weights or checkpoint file.

    Returns:
        model: The populated model, switched to eval mode. Tensors are read
        with ``map_location='cpu'``.
    """
    model = DenseNet121model()
    # NOTE(review): weights_only=False allows arbitrary objects to be
    # unpickled, so only load files that this notebook produced itself.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=False)

    # Full training checkpoints wrap the weights in a larger dict
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']

    # Drop the batch-norm step counters before loading
    tracked_suffix = 'num_batches_tracked'
    state_dict = {k: v for k, v in state_dict.items()
                  if not k.endswith(tracked_suffix)}

    # strict=False is required because the counters were just removed, but it
    # would also hide a checkpoint that does not fit this architecture at all,
    # so report every other mismatch instead of loading silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = [k for k in load_result.missing_keys
               if not k.endswith(tracked_suffix)]
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(f"Warning: {model_path} does not fully match the model "
              f"({len(missing)} missing, {len(unexpected)} unexpected keys)")
        if missing:
            print(f"  Missing keys (first 5): {missing[:5]}")
        if unexpected:
            print(f"  Unexpected keys (first 5): {unexpected[:5]}")

    model.eval()
    return model

In [None]:
# Rebuild the network from the checkpoint file in the working directory.
checkpoint_path = '/kaggle/working/checkpoint.pth'
model = load_model(checkpoint_path)
# load_model() returns only the network, so the same file is read a second
# time here to get at the 'history' entry stored alongside the weights.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
history = checkpoint['history']

## Evaluation Setup

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _tta_forward_batch(model, inputs, tta_fns):
    """Run a batch through the model with multiple deterministic TTA transforms
    and return the averaged logits dict.
    
    Args:
        model: The model to evaluate
        inputs: Input tensor
        tta_fns: List of test-time augmentation functions
        
    Returns:
        Dictionary of averaged logits
    """
    if not tta_fns:
        return model(inputs)

    logits_per_tta = []
    for fn in tta_fns:
        aug_inp = fn(inputs)
        logits_per_tta.append(model(aug_inp))

    # Average per-head
    avg_logits = {}
    for k in logits_per_tta[0].keys():
        avg_logits[k] = torch.stack([d[k] for d in logits_per_tta], dim=0).mean(dim=0)
    return avg_logits

In [None]:
def _predict_with_class_thresholds(logits, class_thresholds):
    """Turn one head's logits into class indices using per-class cut-offs.

    Args:
        logits: Tensor whose last dimension is 1 (sigmoid head) or the number
            of classes (softmax head); the first dimension is the batch
        class_thresholds: Minimum probability per class, indexed by class

    Returns:
        List of int class indices, one per sample.

    Raises:
        IndexError: If fewer thresholds than classes are supplied.
    """
    if logits.shape[-1] == 1:
        # Sigmoid head: class 1 is checked first and wins when its probability
        # clears its cut-off; every other sample ends up as class 0, either
        # through class 0's own cut-off or through the default.
        positive_prob = torch.sigmoid(logits).squeeze(-1).cpu()
        return (positive_prob >= class_thresholds[1]).long().tolist()

    probs = torch.softmax(logits, dim=1).cpu()
    num_classes = probs.shape[1]
    if len(class_thresholds) < num_classes:
        raise IndexError(
            f"expected {num_classes} thresholds, got {len(class_thresholds)}"
        )
    cutoffs = torch.tensor(list(class_thresholds[:num_classes]), dtype=probs.dtype)

    # A class is eligible when its probability reaches its own cut-off; the
    # prediction is the most probable eligible class.
    eligible = probs >= cutoffs
    best = probs.masked_fill(~eligible, float("-inf")).argmax(dim=1)
    # Rows where no class is eligible fall back to class 0
    best[~eligible.any(dim=1)] = 0
    return best.tolist()


def run_test_evaluation(
    model,
    test_ds,
    loss_fn_dict,
    thresholds=None,
    batch_size=32,
    num_workers=2,
    tta_fns=None
):
    """Evaluate ``model`` on ``test_ds`` with optional test-time augmentation.

    The model is moved to the notebook-level ``device`` and put in eval mode.
    Each batch must be a dict with an ``"image"`` tensor and a ``"label"``
    tensor whose columns follow the bowel/extra/kidney/liver/spleen order.

    Args:
        model: Model returning a dict of logits keyed by organ name
        test_ds: Test dataset
        loss_fn_dict: Dictionary of loss functions per organ; may be empty,
            in which case the reported loss is 0.0
        thresholds: ``{organ: {"thresholds": [...]}}`` with one cut-off per
            class (defaults to the values hard-coded below)
        batch_size: Evaluation batch size
        num_workers: DataLoader workers
        tta_fns: List of TTA functions, or None for a single forward pass

    Returns:
        Tuple of (test_metrics, preds_targets) where ``test_metrics`` is
        ``{"loss": mean_batch_loss}`` and ``preds_targets`` maps each organ
        to ``{"preds": [...], "labels": [...]}``.
    """
    model.to(device)
    model.eval()
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print("Running test evaluation...")

    # Set default thresholds if not provided ("best_f1" is not read here)
    if thresholds is None:
        thresholds = {
            "bowel": {"thresholds": [0.01, 0.45], "best_f1": [0.9823, 0.7059]},
            "extra": {"thresholds": [0.02, 0.84], "best_f1": [0.9493, 0.5607]},
            "liver": {"thresholds": [0.13, 0.60, 0.75], "best_f1": [0.9054, 0.5077, 0.6400]},
            "kidney": {"thresholds": [0.06, 0.79, 0.67], "best_f1": [0.9503, 0.6154, 0.6667]},
            "spleen": {"thresholds": [0.14, 0.57, 0.56], "best_f1": [0.8905, 0.5660, 0.4938]}
        }

    all_preds = defaultdict(list)
    all_labels = defaultdict(list)
    total_loss = 0.0
    total_batches = 0

    # Column of the label tensor that belongs to each organ head
    organ_indices = {
        "bowel": 0,
        "extra": 1,
        "kidney": 2,
        "liver": 3,
        "spleen": 4,
    }

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", leave=False):
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)

            # TTA-aware forward pass
            outputs = _tta_forward_batch(model, inputs, tta_fns)

            # Sum the per-organ losses for this batch
            batch_loss = 0.0
            for organ, loss_fn in loss_fn_dict.items():
                idx = organ_indices[organ]
                pred = outputs[organ]
                target = labels[:, idx]

                # Shape and cast the target the way each loss type is fed here
                if isinstance(loss_fn, nn.BCEWithLogitsLoss):
                    target = target.unsqueeze(1).float()
                elif isinstance(loss_fn, nn.CrossEntropyLoss) or hasattr(loss_fn, "smoothing"):
                    target = target.long()
                else:
                    target = target.float()

                batch_loss += loss_fn(pred, target)

            # float() handles both a scalar loss tensor and the plain 0.0 that
            # remains when loss_fn_dict is empty (.item() would fail on it).
            total_loss += float(batch_loss)
            total_batches += 1

            # Threshold every head's output and keep it with its labels
            for organ, pred_tensor in outputs.items():
                idx = organ_indices[organ]
                preds_arr = _predict_with_class_thresholds(
                    pred_tensor, thresholds[organ]["thresholds"]
                )
                all_preds[organ].extend(preds_arr)
                all_labels[organ].extend(labels[:, idx].cpu().tolist())

    # Calculate final metrics
    avg_loss = total_loss / total_batches if total_batches > 0 else 0.0
    test_metrics = {"loss": avg_loss}
    preds_targets = {
        organ: {"preds": all_preds[organ], "labels": all_labels[organ]}
        for organ in all_preds
    }

    print(f"\nTest Loss: {avg_loss:.4f}")
    return test_metrics, preds_targets


## Utility Functions

In [None]:
def save_metrics_and_preds(test_metrics, preds_targets, save_dir="/kaggle/working"):
    """Save evaluation metrics and predictions to JSON files.

    Writes ``test_metrics.json`` and ``preds_targets.json`` into ``save_dir``,
    creating the directory first when it does not exist yet.

    Args:
        test_metrics: Evaluation metrics dictionary (must be JSON-serialisable)
        preds_targets: Predictions and targets dictionary (must be
            JSON-serialisable)
        save_dir: Directory to save files

    Returns:
        Tuple of (metrics_path, preds_path)
    """
    # An empty save_dir means "current directory"; makedirs("") would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    metrics_path = os.path.join(save_dir, "test_metrics.json")
    preds_path = os.path.join(save_dir, "preds_targets.json")

    with open(metrics_path, "w") as f:
        json.dump(test_metrics, f, indent=2)

    with open(preds_path, "w") as f:
        json.dump(preds_targets, f, indent=2)

    print(f"Saved metrics to {metrics_path}")
    print(f"Saved predictions and targets to {preds_path}")

    return metrics_path, preds_path

def upload_to_kaggle_model(
    dataset_owner, 
    dataset_slug, 
    model_path, 
    checkpoint_path=None, 
    metrics_path=None, 
    preds_path=None, 
    version_note=""
):
    """Publish model files as a new version of an existing Kaggle dataset.

    The given files are bundled into ``/kaggle/working/model_upload.zip``, a
    ``dataset-metadata.json`` is written next to it, and ``/kaggle/working``
    is handed to the Kaggle API as the upload folder.

    NOTE(review): the upload folder is the whole working directory, so
    presumably other files sitting there are uploaded too. Confirm that this
    is intended.

    Args:
        dataset_owner: Owner of the target dataset
        dataset_slug: Dataset name
        model_path: Path to model file (always added to the archive)
        checkpoint_path: Path to checkpoint file (optional)
        metrics_path: Path to metrics file (optional)
        preds_path: Path to predictions file (optional)
        version_note: Notes for this version
    """
    # Authenticate before touching the filesystem
    api = KaggleApi()
    api.authenticate()

    # Bundle the model plus whichever optional files were supplied,
    # each stored under its bare file name
    zip_path = "/kaggle/working/model_upload.zip"
    optional_files = (checkpoint_path, metrics_path, preds_path)
    with zipfile.ZipFile(zip_path, 'w') as archive:
        archive.write(model_path, arcname=os.path.basename(model_path))
        for extra_file in optional_files:
            if not extra_file:
                continue
            archive.write(extra_file, arcname=os.path.basename(extra_file))

    print(f"Zipped model(s) and metrics to: {zip_path}")

    # Write the metadata file that identifies the target dataset
    metadata_path = "/kaggle/working/dataset-metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(
            {
                "title": f"{dataset_slug} model",
                "id": f"{dataset_owner}/{dataset_slug}",
                "licenses": [{"name": "CC0-1.0"}],
            },
            f,
            indent=2,
        )

    print(f"Created metadata file at {metadata_path}")

    # Push the working directory as a new dataset version
    api.dataset_create_version(
        folder="/kaggle/working",
        version_notes=version_note,
        delete_old_versions=False,
        convert_to_csv=False
    )
    print(f"Uploaded to Kaggle Dataset: {dataset_owner}/{dataset_slug}")

## TTA Transformations

In [None]:
# The flips and the rotation index dims 2, 3 and 4, so they need a tensor with
# at least five dimensions. The depth/height/width names follow the original
# labels; presumably the layout is (N, C, D, H, W) -- TODO confirm.

def id_fn(x):
    """Identity: return the batch unchanged."""
    return x

def flip0(x):
    """Reverse the batch along dim 2 (depth)."""
    return x.flip(2)

def flip1(x):
    """Reverse the batch along dim 3 (height)."""
    return x.flip(3)

def flip2(x):
    """Reverse the batch along dim 4 (width)."""
    return x.flip(4)

def rot90_hw(x):
    """Rotate the batch by one quarter turn in the plane of dims 3 and 4."""
    return x.rot90(1, (3, 4))

# Transforms applied by default at test time (the rotation is not included)
DEFAULT_TTA_FNS = [id_fn, flip0, flip1, flip2]


## Evaluation Pipeline

In [None]:
def run_test_and_upload(
    model,
    test_ds,
    loss_fn_dict,
    thresholds,
    dataset_owner,
    dataset_slug,
    model_path,
    checkpoint_path=None,
    version_note="",
    batch_size=32,
    num_workers=2,
    tta_fns=None
):
    """Run full evaluation pipeline and upload results to Kaggle.

    Chains ``run_test_evaluation``, ``save_metrics_and_preds`` (with its
    default output directory) and ``upload_to_kaggle_model``.

    Args:
        model: Model to evaluate
        test_ds: Test dataset
        loss_fn_dict: Dictionary of loss functions
        thresholds: Classification thresholds
        dataset_owner: Kaggle dataset owner
        dataset_slug: Kaggle dataset name
        model_path: Path to model file
        checkpoint_path: Path to checkpoint file (optional)
        version_note: Version notes for Kaggle
        batch_size: Evaluation batch size, forwarded to the evaluation step
        num_workers: DataLoader workers, forwarded to the evaluation step
        tta_fns: List of TTA functions forwarded to the evaluation step;
            None evaluates without test-time augmentation

    Returns:
        Test metrics dictionary
    """
    # The evaluation knobs are forwarded so this wrapper can reproduce the
    # same TTA / batch settings as a direct run_test_evaluation() call.
    test_metrics, preds_targets = run_test_evaluation(
        model=model,
        test_ds=test_ds,
        loss_fn_dict=loss_fn_dict,
        thresholds=thresholds,
        batch_size=batch_size,
        num_workers=num_workers,
        tta_fns=tta_fns
    )

    metrics_path, preds_path = save_metrics_and_preds(test_metrics, preds_targets)

    upload_to_kaggle_model(
        dataset_owner=dataset_owner,
        dataset_slug=dataset_slug,
        model_path=model_path,
        checkpoint_path=checkpoint_path,
        metrics_path=metrics_path,
        preds_path=preds_path,
        version_note=version_note
    )

    return test_metrics

In [None]:
# Score the test set with the hand-picked cut-offs and the default TTA set
test_metrics, preds_targets = run_test_evaluation(
    model,
    test_ds,
    loss_fn_dict,
    thresholds=custom_thresholds,
    batch_size=16,
    num_workers=4,
    tta_fns=DEFAULT_TTA_FNS,
)

# Write the results to disk, then publish them together with the weights
save_metrics_and_preds(test_metrics, preds_targets, save_dir="/kaggle/working")
upload_to_kaggle_model(
    dataset_owner=dataset_owner,
    dataset_slug=dataset_slug,
    model_path="/kaggle/working/model_best.pth",
    checkpoint_path="/kaggle/working/checkpoint.pth",
    metrics_path='/kaggle/working/test_metrics.json',
    preds_path='/kaggle/working/preds_targets.json',
    version_note="",
)

## Threshold Optimization

In [None]:
def find_best_thresholds_per_class(probs, true_labels, default_class=0, num_thresholds=50):
    """Search for the single probability cut-off that maximises macro F1.

    For every candidate cut-off, a sample is assigned its most probable class
    among those whose probability reaches the cut-off, or ``default_class``
    when none does. Despite the name, one cut-off shared by all classes is
    returned.

    Args:
        probs: numpy array (N_samples, num_classes) - softmax probabilities
        true_labels: numpy array (N_samples,) - true class labels
        default_class: Class to assign if no probability reaches the cut-off
        num_thresholds: Number of evenly spaced candidates in [0, 1]

    Returns:
        Tuple of (best_threshold, best_f1_score); the earliest candidate wins
        ties, and 0.5 is returned if no candidate scores above -1.0.
    """
    def _classify(sample_probs, cutoff):
        # Zero out the classes below the cut-off, then take the largest survivor
        passes = sample_probs >= cutoff
        if passes.any():
            return np.argmax(sample_probs * passes)
        return default_class

    winning_threshold = 0.5  # Fallback when nothing beats the initial score
    winning_f1 = -1.0

    for cutoff in np.linspace(0, 1, num_thresholds):
        preds = [_classify(sample_probs, cutoff) for sample_probs in probs]
        candidate_f1 = f1_score(true_labels, preds, average='macro')

        # Strict comparison keeps the earliest cut-off on ties
        if candidate_f1 > winning_f1:
            winning_f1, winning_threshold = candidate_f1, cutoff

    return winning_threshold, winning_f1

## Validation Data Collection

In [None]:
# Column of the label tensor that belongs to each organ head
organ_indices = {
    "bowel": 0,
    "extra": 1,
    "kidney": 2,
    "liver": 3,
    "spleen": 4,
}

# Per-organ accumulators; each holds one array per batch until the
# concatenation step at the end of this cell
all_probs = defaultdict(list)   # per organ, list of probability arrays
all_labels = defaultdict(list)  # per organ, list of true labels

# Every batch is moved to `device` below, so the network has to be there too.
# This is a no-op when an earlier cell already moved it, and it makes the cell
# safe to run directly after loading a model.
model.to(device)
model.eval()

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Validation Batches"):
        inputs = batch["image"].to(device)
        labels = batch["label"].to(device)

        outputs = model(inputs)

        for organ in outputs.keys():
            preds = outputs[organ]  # logits tensor

            if preds.shape[-1] == 1:
                # Binary case: sigmoid probability
                probs = torch.sigmoid(preds).squeeze(-1).cpu().numpy()
                # Convert to 2-class format [1-p, p]
                probs = np.stack([1 - probs, probs], axis=1)
            else:
                # Multi-class case: softmax probabilities
                probs = torch.softmax(preds, dim=1).cpu().numpy()

            all_probs[organ].append(probs)
            all_labels[organ].append(labels[:, organ_indices[organ]].cpu().numpy())

# Concatenate all batches per organ: values change from lists to single arrays
for organ in all_probs:
    all_probs[organ] = np.concatenate(all_probs[organ], axis=0)
    all_labels[organ] = np.concatenate(all_labels[organ], axis=0)

In [None]:
best_thresholds, best_f1s = {}, {}

# Only the kidney, liver and spleen heads are tuned in this cell
for organ in ("kidney", "liver", "spleen"):
    # Probabilities and true labels collected for this organ
    probs, labels = all_probs[organ], all_labels[organ]
    thresh, f1 = find_best_thresholds_per_class(probs, labels, default_class=0)
    best_thresholds[organ], best_f1s[organ] = thresh, f1
    print(f"Best threshold for {organ}: {thresh:.2f} with F1: {f1:.4f}")

## Alternative Threshold Optimization

In [None]:
best_thresholds = {}
metric = "f1"  # Can be "f1" or "accuracy"

# Fail fast on a typo: without this check an unknown metric would leave
# `score` undefined (or stale from a previous organ) inside the loop.
if metric not in ("f1", "accuracy"):
    raise ValueError(f"Unsupported metric {metric!r}; expected 'f1' or 'accuracy'")

for organ in all_probs:
    probs = all_probs[organ]  # shape (N, num_classes)
    true = all_labels[organ]  # shape (N,)
    num_classes = probs.shape[1]

    best_thresh = 0.0
    best_score = 0.0

    # Evaluate thresholds between 0.1 to 0.95
    thresholds = np.linspace(0.1, 0.95, 18)

    for thresh in thresholds:
        # Initialize predictions with "no prediction" class (-1)
        pred = np.full_like(true, fill_value=-1)

        # For samples where max probability reaches the threshold, take argmax
        confident = (probs.max(axis=1) >= thresh)
        pred[confident] = probs[confident].argmax(axis=1)

        # Skip evaluation if no predictions meet threshold
        valid_idx = pred != -1
        if np.sum(valid_idx) == 0:
            continue

        # NOTE: the score covers only the confident samples, so each threshold
        # is judged on a different (smaller as it rises) subset of the data.
        if metric == "f1":
            score = f1_score(true[valid_idx], pred[valid_idx], average='macro')
        else:
            score = accuracy_score(true[valid_idx], pred[valid_idx])

        if score > best_score:
            best_score = score
            best_thresh = thresh

    best_thresholds[organ] = best_thresh
    print(f"[{organ.upper()}] Best {metric}: {best_score:.4f} at threshold: {best_thresh:.2f}")

## Prediction Visualization

In [None]:
def _decode_head_prediction(head_logits, threshold):
    """Map one sample's logits for a single head to a class index.

    Args:
        head_logits: Logits of one sample for one head, any shape
        threshold: Minimum positive-class probability for binary heads

    Returns:
        int class index
    """
    flat_logits = head_logits.reshape(-1)
    num_outputs = flat_logits.numel()

    # The number of logits decides the head type. Checking the tensor rank
    # first would send a single-logit head down the softmax/argmax path,
    # which always yields class 0.
    if num_outputs == 1:
        # Binary sigmoid head
        return int(torch.sigmoid(flat_logits[0]).item() >= threshold)
    if num_outputs == 2:
        # Binary softmax head: threshold the probability of class 1
        return int(torch.softmax(flat_logits, dim=0)[1].item() >= threshold)
    # Multi-class softmax head
    return torch.argmax(torch.softmax(flat_logits, dim=0)).item()


def plot_random_predictions_all_organs(model, dataset, device, num_samples=5, slice_axis=0, threshold=0.5):
    """Plot random samples with predicted and true labels for all organs.

    Draws a 2 x ``num_samples`` grid: the top row shows the middle slice of
    each volume, the bottom row lists the true and predicted label per organ.
    The figure is saved to ``/kaggle/working/predictions`` and then shown.

    Args:
        model: Trained model returning a dict of logits keyed by organ name
        dataset: Indexable dataset whose items are dicts with "image" and
            "label" entries
        device: Torch device
        num_samples: Number of samples to plot (must not exceed len(dataset))
        slice_axis: Axis for 3D volume slicing (0,1,2)
        threshold: Classification threshold for binary predictions
    """
    organ_indices = {
        "bowel": 0,
        "extra": 1,
        "kidney": 2,
        "liver": 3,
        "spleen": 4,
    }

    # Label name mappings
    binary_label_names = {0: "healthy", 1: "injured"}
    multiclass_label_names = {
        "kidney": {0: "healthy", 1: "low injury", 2: "high injury"},
        "liver": {0: "healthy", 1: "low injury", 2: "high injury"},
        "spleen": {0: "healthy", 1: "low injury", 2: "high injury"},
        "bowel": binary_label_names,
        "extra": binary_label_names,
    }

    model.to(device)
    model.eval()

    # Select random samples
    indices = random.sample(range(len(dataset)), num_samples)

    # Setup plot grid
    fig, axes = plt.subplots(2, num_samples, figsize=(4*num_samples, 8))
    if num_samples == 1:
        # A single column comes back 1-D; restore the (row, column) layout
        axes = np.expand_dims(axes, axis=1)

    with torch.no_grad():
        for i, idx in enumerate(indices):
            sample = dataset[idx]
            img = sample["image"]
            labels = sample["label"]

            # Prepare model input: add a batch dimension of one
            input_img = img.unsqueeze(0).to(device).float()
            outputs = model(input_img)

            # Extract middle slice for visualization
            slices = img.shape[-3:] if img.ndim >= 3 else img.shape
            mid_slice_idx = slices[slice_axis] // 2

            if img.ndim == 4:
                # Leading dim is treated as channels; slice the remaining three
                if slice_axis == 0:
                    slice_img = img[:, mid_slice_idx, :, :].cpu().numpy()
                elif slice_axis == 1:
                    slice_img = img[:, :, mid_slice_idx, :].cpu().numpy()
                else:
                    slice_img = img[:, :, :, mid_slice_idx].cpu().numpy()
                slice_img = np.moveaxis(slice_img, 0, -1)
            elif img.ndim == 3:
                if slice_axis == 0:
                    slice_img = img[mid_slice_idx, :, :].cpu().numpy()
                elif slice_axis == 1:
                    slice_img = img[:, mid_slice_idx, :].cpu().numpy()
                else:
                    slice_img = img[:, :, mid_slice_idx].cpu().numpy()
            else:
                slice_img = img.cpu().numpy()

            # With several channels, show only the first one
            if slice_img.ndim == 3 and slice_img.shape[-1] > 1:
                slice_img = slice_img[..., 0]

            # Plot image
            axes[0, i].imshow(slice_img, cmap="gray")
            axes[0, i].axis("off")
            axes[0, i].set_title(f"Sample {idx}")

            # Generate prediction text
            text_lines = []
            for organ, organ_idx in organ_indices.items():
                label_names = multiclass_label_names.get(organ, binary_label_names)
                true_label = labels[organ_idx].item()
                pred_label = _decode_head_prediction(
                    outputs[organ].cpu().squeeze(0), threshold
                )

                true_label_str = label_names.get(true_label, str(true_label))
                pred_label_str = label_names.get(pred_label, str(pred_label))
                text_lines.append(f"{organ}:\n  True: {true_label_str}\n  Pred: {pred_label_str}")

            # Add prediction text
            axes[1, i].axis("off")
            axes[1, i].text(0, 0.5, "\n\n".join(text_lines), fontsize=10, va="center", ha="left")

    plt.tight_layout()
    plt.savefig('/kaggle/working/predictions')
    plt.show()


In [None]:
# Show five random test volumes with their true and predicted organ labels
plot_random_predictions_all_organs(
    model, test_ds, device,
    num_samples=5, slice_axis=0, threshold=0.5,
)