# Load Library

In [1]:
# imports for neural network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset

# imports for vision tasks
import torchvision
import torchvision.models as models
import torchvision.transforms.v2 as transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import pydicom

# imports for preparing dataset
import os
import shutil
import zipfile
import pandas as pd
from skimage import io
from PIL import Image
import numpy as np
import timm

# imports for visualizations
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from tqdm import tqdm
from pathlib import Path
import json
from typing import Dict, Tuple, List, Optional, Union, Any
import platform
from enum import Enum
import timm
import time

from sklearn.metrics import (
    roc_auc_score, 
    precision_recall_curve, 
    auc,
    accuracy_score,
    confusion_matrix
)

# Load Data

## Define Data Structure Class

In [2]:
class SpinalConditionDataset(Dataset):
    """Custom Dataset class with optimized device handling"""
    def __init__(
        self, 
        images: torch.Tensor, 
        labels: torch.Tensor, 
        metadata: pd.DataFrame,
        device: Optional[torch.device] = None
    ):
        if device is None:
            self.device, self.device_type = get_best_available_device()
        else:
            self.device = device
            self.device_type = device.type
        
        # Store as CPU tensors initially
        self.images = images
        self.labels = labels
        self.metadata = metadata
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Move individual items to device when accessed
        image = self.images[idx].to(self.device)
        label = self.labels[idx].to(self.device)
        return image, label

## Define Premade Function

### Getting Best Devices

In [3]:
def get_best_available_device() -> Tuple[torch.device, str]:
    """
    Detect and return the best available device for tensor operations.
    
    Returns:
        Tuple[torch.device, str]: Device object and device type string
    """
    device_type = "cpu"
    
    # Check for CUDA
    if torch.cuda.is_available():
        device_type = "cuda"
        
    # Check for Apple M1/M2 MPS (Metal Performance Shaders)
    elif platform.processor().startswith('arm') and platform.system() == 'Darwin' and \
         hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device_type = "mps"
        
    # Check for TPU (if torch_xla is available)
    try:
        import torch_xla.core.xla_model as xm
        device_type = "tpu"
    except ImportError:
        pass
        
    # Create device object based on type
    if device_type == "tpu":
        try:
            device = xm.xla_device()
        except NameError:
            device = torch.device("cpu")
            device_type = "cpu"
    else:
        device = torch.device(device_type)
        
    return device, device_type

### Data Loading Function

In [4]:
def load_processed_dataset(base_load_path, verbose=False):
    """
    Load the processed dataset from disk.
    
    Args:
        base_load_path (str): Path where the dataset is saved
    
    Returns:
        dict: Loaded condition data
        dict: Configuration data
    """
    base_path = Path(base_load_path)
    
    if not base_path.exists():
        raise ValueError(f"Dataset path {base_path} does not exist")
    
    # Load configuration
    with open(base_path / 'config.json', 'r') as f:
        config = json.load(f)
    
    # Initialize condition data dictionary
    loaded_condition_data = {}
    
    # Load data for each condition
    for condition in config['core_conditions']:
        condition_path = base_path / condition
        
        if condition_path.exists():
            loaded_condition_data[condition] = {
                'images': torch.load(condition_path / 'images.pt', weights_only=False),
                'labels': torch.load(condition_path / 'labels.pt', weights_only=False),
                'metadata': pd.read_pickle(condition_path / 'metadata.pkl')
            }
    
    if verbose is True :
        print("\nDataset successfully loaded")
        print("\nDataset Summary:")
        for condition, data in loaded_condition_data.items():
            print(f"\n{condition}:")
            print(f"Total samples: {len(data['images'])}")
            label_dist = torch.bincount(data['labels'])
            for severity, idx in config['severity_mapping'].items():
                if idx < len(label_dist):
                    print(f"  {severity}: {label_dist[idx].item()}")
    
    return loaded_condition_data, config

### Data Splitting Function

In [5]:
def create_train_val_split(
    condition_data: Dict,
    device: Optional[torch.device] = None,
    val_ratio: float = 0.2,
    seed: int = 42
) -> Dict[str, Dict[str, Dict]]:
    """
    Split the dataset into training and validation sets with automatic device support.
    
    Args:
        condition_data (Dict): Dictionary containing data for each condition
        device (Optional[torch.device]): Device to store the tensors on
        val_ratio (float): Ratio of validation set size to total dataset size
        seed (int): Random seed for reproducibility
    
    Returns:
        Dict: Dictionary containing train and val splits for each condition
    """
    torch.manual_seed(seed)
    
    split_data = {
        'train': {},
        'val': {}
    }
    
    for condition, data in condition_data.items():
        dataset_size = len(data['images'])
        val_size = int(dataset_size * val_ratio)
        train_size = dataset_size - val_size
        
        # Create full dataset with device specification
        full_dataset = SpinalConditionDataset(
            data['images'],
            data['labels'],
            data['metadata'],
            device=device
        )
        
        # Split dataset
        train_dataset, val_dataset = random_split(
            full_dataset, 
            [train_size, val_size],
            generator=torch.Generator().manual_seed(seed)
        )
        
        split_data['train'][condition] = train_dataset
        split_data['val'][condition] = val_dataset
        
    return split_data

### Data Loader Creator

In [6]:
def create_dataloaders(
    split_data: Dict[str, Dict[str, Dataset]],
    device_type: str,
    batch_size: int = 32,
    num_workers: int = 4,
    shuffle_train: bool = True
) -> Dict[str, Dict[str, DataLoader]]:
    """
    Create DataLoader objects optimized for the detected device.
    
    Args:
        split_data (Dict): Dictionary containing train and val splits
        device_type (str): Type of device being used
        batch_size (int): Batch size for DataLoader
        num_workers (int): Number of worker processes
        shuffle_train (bool): Whether to shuffle training data
    
    Returns:
        Dict: Dictionary containing DataLoader objects for each split and condition
    """
    dataloaders = {
        'train': {},
        'val': {}
    }
    
    # Optimize DataLoader settings based on device
    if device_type == "cuda":
        pin_memory = False  # Data is already on GPU
    elif device_type == "tpu":
        pin_memory = False
        num_workers = 0  # TPU often works better with synchronous loading
    else:  # CPU or MPS
        pin_memory = True
    
    for split in ['train', 'val']:
        for condition, dataset in split_data[split].items():
            dataloaders[split][condition] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=(shuffle_train and split == 'train'),
                num_workers=num_workers,
                pin_memory=pin_memory
            )
    
    return dataloaders

### Getting Batch Size

In [7]:
def get_device_specific_batch_size(device_type: str, base_batch_size: int) -> int:
    """
    Adjust batch size based on device type and available memory.
    
    Args:
        device_type (str): Type of device being used
        base_batch_size (int): Requested batch size
        
    Returns:
        int: Adjusted batch size
    """
    if device_type == "cuda":
        # Get available GPU memory and adjust batch size if needed
        gpu_memory = torch.cuda.get_device_properties(0).total_memory
        if gpu_memory < 8 * (1024**3):  # Less than 8GB
            return min(base_batch_size, 16)
    elif device_type == "tpu":
        # TPUs often work better with larger batch sizes
        return max(base_batch_size, 128)
    
    return base_batch_size

### Usage Function

In [8]:
def load_and_prepare_data(
    base_load_path: str,
    batch_size: int = 32,
    val_ratio: float = 0.2,
    num_workers: int = 4,
    seed: int = 42,
    verbose: bool = False
) -> Tuple[Dict, Dict, Dict, torch.device]:
    """
    Load and prepare data loaders with automatic device detection and optimization.
    
    Args:
        base_load_path (str): Path to dataset
        batch_size (int): Base batch size for DataLoader
        val_ratio (float): Validation set ratio
        num_workers (int): Number of worker processes
        seed (int): Random seed
        verbose (bool): Whether to print dataset information
    
    Returns:
        Tuple containing:
        - Dictionary of dataloaders
        - Dictionary of split datasets
        - Configuration dictionary
        - Device being used
    """
    # Detect best available device
    device, device_type = get_best_available_device()
    
    if verbose:
        print(f"\nUsing device: {device} ({device_type})")
    
    # Adjust batch size for device
    adjusted_batch_size = get_device_specific_batch_size(device_type, batch_size)
    if verbose and adjusted_batch_size != batch_size:
        print(f"Adjusted batch size from {batch_size} to {adjusted_batch_size} for {device_type}")
    
    # Load the dataset
    condition_data, config = load_processed_dataset(base_load_path, verbose=verbose)
    
    # Create train/val split
    split_data = create_train_val_split(
        condition_data,
        device=device,
        val_ratio=val_ratio,
        seed=seed
    )
    
    # Create dataloaders
    dataloaders = create_dataloaders(
        split_data,
        device_type=device_type,
        batch_size=adjusted_batch_size,
        num_workers=num_workers
    )
    
    if verbose:
        print("\nDataset Summary:")
        for split in ['train', 'val']:
            print(f"\n{split.capitalize()} set sizes:")
            for condition in config['core_conditions']:
                if condition in split_data[split]:
                    print(f"{condition}: {len(split_data[split][condition])}")
        
        # Print memory usage information
        if device_type == "cuda":
            print("\nGPU Memory Usage:")
            print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
            print(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        elif device_type == "tpu":
            try:
                import torch_xla.debug.metrics as met
                print("\nTPU Memory Usage:")
                print(met.metrics_report())
            except ImportError:
                pass
    
    return dataloaders, split_data, config, device

## Load & Create Dataloader

In [9]:
dataloaders, split_data, config, device = load_and_prepare_data(
    base_load_path="/kaggle/input/degenerative-spine-image-classificaton/processed_spine_dataset",
    batch_size=32,
    val_ratio=0.2,
    verbose=True,
    num_workers=0
)


Using device: cpu (cpu)

Dataset successfully loaded

Dataset Summary:

spinal_canal_stenosis:
Total samples: 9753
  Normal/Mild: 8552
  Moderate: 732
  Severe: 469

left_neural_foraminal_narrowing:
Total samples: 9860
  Normal/Mild: 7671
  Moderate: 1792
  Severe: 397

right_neural_foraminal_narrowing:
Total samples: 9829
  Normal/Mild: 7684
  Moderate: 1767
  Severe: 378

left_subarticular_stenosis:
Total samples: 9603
  Normal/Mild: 6857
  Moderate: 1834
  Severe: 912

right_subarticular_stenosis:
Total samples: 9612
  Normal/Mild: 6862
  Moderate: 1825
  Severe: 925

Dataset Summary:

Train set sizes:
spinal_canal_stenosis: 7803
left_neural_foraminal_narrowing: 7888
right_neural_foraminal_narrowing: 7864
left_subarticular_stenosis: 7683
right_subarticular_stenosis: 7690

Val set sizes:
spinal_canal_stenosis: 1950
left_neural_foraminal_narrowing: 1972
right_neural_foraminal_narrowing: 1965
left_subarticular_stenosis: 1920
right_subarticular_stenosis: 1922


In [10]:
from torch.utils.data import Subset

def get_validation_subset(val_data: Dataset, subset_size: int, seed: int = 42) -> Subset:
    """
    Extract a subset from the validation dataset.

    Args:
        val_data (Dataset): The validation dataset.
        subset_size (int): The number of samples to include in the subset.
        seed (int): Random seed for reproducibility.

    Returns:
        Subset: A subset of the validation dataset.
    """
    torch.manual_seed(seed)
    indices = torch.randperm(len(val_data))[:subset_size]  # Randomly sample indices
    return Subset(val_data, indices)


In [11]:
# Define conditions
conditions = [
    "spinal_canal_stenosis",
    "left_neural_foraminal_narrowing",
    "right_neural_foraminal_narrowing",
    "left_subarticular_stenosis",
    "right_subarticular_stenosis"
]

subset_size = 100
val_subset_loaders = {}

for condition in conditions:
    val_dataset = split_data['val'][condition]
    val_subset = get_validation_subset(val_dataset, subset_size)
    val_subset_loaders[condition] = DataLoader(
        val_subset,
        batch_size=32,
        shuffle=False,
        num_workers=0,
        pin_memory=False
    )

# Spinal Model

In [12]:
class ModelArchitecture(str, Enum):
    """Supported model architectures"""
    INCEPTION_V4 = "inception_v4"
    EFFICIENTNET = "efficientnet_b0"
    EFFICIENTNET_V2 = "efficientnetv2_s"
    VGG16 = "vgg16"
    VIT = "vit"  # Added Vision Transformer (ViT)


class SpinalModel(nn.Module):
    """
    Neural network model for spinal condition classification with different backbone options.
    """
    def __init__(
        self,
        architecture: ModelArchitecture,
        num_classes: int,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
    ):
        """
        Initialize the model with specified architecture.
        
        Args:
            architecture (ModelArchitecture): Choice of model architecture
            num_classes (int): Number of output classes
            pretrained (bool): Whether to use pretrained weights
            dropout_rate (float): Dropout rate for the final layer
        """
        super().__init__()
        
        self.architecture = architecture
        self.num_classes = num_classes
        
        # Get the best available device
        self.device, self.device_type = get_best_available_device()
        
        # Create the backbone model
        self.backbone = self._create_backbone(pretrained)
        
        # Get the number of features from the backbone
        if architecture == ModelArchitecture.INCEPTION_V4:
            num_features = 1536
        elif architecture == ModelArchitecture.EFFICIENTNET:
            num_features = 1280
        elif architecture == ModelArchitecture.EFFICIENTNET_V2:
            num_features = 1280
        elif architecture == ModelArchitecture.VGG16:
            num_features = 4096  # VGG16 outputs 4096 features
        elif architecture == ModelArchitecture.VIT:
            num_features = 768  # ViT base model outputs 768 features
        
        # Create classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(512, num_classes)
        )
        
        # Move model to the best available device
        self.to(self.device)

    def _create_backbone(self, pretrained: bool) -> nn.Module:
        """Create the backbone model based on the selected architecture."""
        pretrained_str = 'imagenet' if pretrained else None
        
        if self.architecture == ModelArchitecture.INCEPTION_V4:
            model = timm.create_model('inception_v4.tf_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.EFFICIENTNET:
            model = timm.create_model('efficientnet_b0.ra4_e3600_r224_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.EFFICIENTNET_V2:
            model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.VGG16:
            model = timm.create_model('vgg16.tv_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.VIT:
            model = timm.create_model('vit_base_patch16_224.mae', pretrained=pretrained, num_classes=0)
        else:
            raise ValueError(f"Unsupported architecture: {self.architecture}")
            
        return model
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the model."""
        if x.device != self.device:
            x = x.to(self.device)

        if self.architecture == ModelArchitecture.VIT:
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
    
        # Get features from the backbone
        features = self.backbone(x)
        
        # Print the shape to debug
        # print("Shape after backbone:", features.shape)
        
        # Pass features through the classifier
        output = self.classifier(features)
        
        return output

        
    def get_preprocessing_parameters(self) -> Dict[str, Any]:
        """Get the preprocessing parameters for the selected architecture."""
        imagenet_stats = {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }
        
        image_sizes = {
            ModelArchitecture.INCEPTION_V4: 299,
            ModelArchitecture.EFFICIENTNET: 224,
            ModelArchitecture.EFFICIENTNET_V2: 256,
            ModelArchitecture.VGG16: 224,  # VGG16 uses 224x224 input size
            ModelArchitecture.VIT: 224     # ViT uses 224x224 input size
        }
        
        return {
            'image_size': image_sizes[self.architecture],
            'mean': imagenet_stats['mean'],
            'std': imagenet_stats['std'],
            'device': self.device,
            'device_type': self.device_type
        }


def create_model(
    architecture: str,
    num_classes: int,
    pretrained: bool = True,
    dropout_rate: float = 0.2,
    verbose: bool = False
) -> Tuple[SpinalModel, Dict[str, Any]]:
    """
    Create a model instance with the specified architecture.
    
    Args:
        architecture (str): Name of the architecture to use
        num_classes (int): Number of output classes
        pretrained (bool): Whether to use pretrained weights
        dropout_rate (float): Dropout rate for the final layer
        verbose (bool): Whether to print device and model information
        
    Returns:
        Tuple[SpinalModel, Dict[str, Any]]: Model instance and preprocessing parameters
    """
    # Validate and convert architecture string to enum
    try:
        arch = ModelArchitecture(architecture.lower())
    except ValueError:
        raise ValueError(
            f"Unsupported architecture: {architecture}. "
            f"Supported architectures: {[a.value for a in ModelArchitecture]}"
        )
    
    # Create model
    model = SpinalModel(
        architecture=arch,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout_rate=dropout_rate
    )
    
    # Get preprocessing parameters
    preprocess_params = model.get_preprocessing_parameters()
    
    if verbose:
        print(f"\nModel Configuration:")
        print(f"Architecture: {arch.value}")
        print(f"Device: {preprocess_params['device']} ({preprocess_params['device_type']})")
        print(f"Input size: {preprocess_params['image_size']}x{preprocess_params['image_size']}")
        print(f"Number of classes: {num_classes}")
        print(f"Pretrained: {pretrained}")
        
        # Print model summary if torchinfo is available
        try:
            from torchsummary import summary
            # Assuming 'model' is your SpinalModel
            input_size = (3, preprocess_params['image_size'], preprocess_params['image_size'])
            summary(model, input_size=input_size)
        except ImportError:
            if verbose:
                print("\nInstall torchinfo for detailed model summary")
    
    return model, preprocess_params
    

In [13]:
model, preprocess_params = create_model(
    architecture="vit",  # or "efficientnet" or "efficientnet_v2"
    num_classes=3,
    pretrained=True,
    verbose=True,
    # unfreeze_layers=15
)

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]


Model Configuration:
Architecture: vit
Device: cpu (cpu)
Input size: 224x224
Number of classes: 3
Pretrained: True
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
          Identity-2             [-1, 196, 768]               0
        PatchEmbed-3             [-1, 196, 768]               0
           Dropout-4             [-1, 197, 768]               0
          Identity-5             [-1, 197, 768]               0
          Identity-6             [-1, 197, 768]               0
         LayerNorm-7             [-1, 197, 768]           1,536
            Linear-8            [-1, 197, 2304]       1,771,776
          Identity-9          [-1, 12, 197, 64]               0
         Identity-10          [-1, 12, 197, 64]               0
           Linear-11             [-1, 197, 768]         590,592
          Dropout-12             [-1, 197, 768]    

In [14]:
checkpoint = torch.load('/kaggle/input/rsna-lumbar-dataset-vit-class-weight/checkpoints/checkpoint_best.pth', map_location=model.device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # If using for inference

  checkpoint = torch.load('/kaggle/input/rsna-lumbar-dataset-vit-class-weight/checkpoints/checkpoint_best.pth', map_location=model.device)


SpinalModel(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none

# Test

In [15]:
def test_model(
   model: nn.Module,
   test_loaders: Dict[str, DataLoader],
   num_classes: int,
   save_dir: str,
   threshold: float = 0.5,
   device: Optional[torch.device] = None
) -> Dict:
   model.eval()
   device = device or model.device
   metrics = {}
   
   Path(save_dir).mkdir(parents=True, exist_ok=True)
   
   for condition, loader in test_loaders.items():
       all_preds, all_labels, all_probs, all_images = [], [], [], []
       
       with torch.no_grad():
           for inputs, labels in loader:
               inputs = inputs.to(device, non_blocking=True)
               labels = labels.to(device, non_blocking=True)
               
               outputs = model(inputs)
               probs = torch.softmax(outputs, dim=1)
               preds = torch.argmax(probs, dim=1)
               
               all_preds.extend(preds.cpu().numpy())
               all_probs.extend(probs.cpu().numpy())
               all_labels.extend(labels.cpu().numpy())
               all_images.extend(inputs.cpu().numpy())
       
       all_preds = np.array(all_preds)
       all_probs = np.array(all_probs)
       all_labels = np.array(all_labels)
       all_images = np.array(all_images)
       
       # Visualization
       num_samples = min(10, len(all_images))
       fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
       
       for i in range(num_samples):
           img = all_images[i].transpose(1, 2, 0)
           img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
           img = np.clip(img, 0, 1)
           
           axes[i, 0].imshow(img)
           axes[i, 0].set_title(f'Image {i}')
           
           axes[i, 1].bar(range(num_classes), np.eye(num_classes)[all_labels[i]])
           axes[i, 1].set_title(f'True: {all_labels[i]}')
           
           axes[i, 2].bar(range(num_classes), all_probs[i])
           axes[i, 2].set_title(f'Pred: {all_preds[i]} ({all_probs[i][all_preds[i]]:.2f})')
           
       plt.tight_layout()
       plt.savefig(os.path.join(save_dir, f'{condition}_results.png'))
       plt.close()
       
       # Metrics
       metrics[condition] = {
           'accuracy': accuracy_score(all_labels, all_preds),
           'confusion_matrix': confusion_matrix(all_labels, all_preds),
           'per_class_accuracy': [],
           'roc_auc_scores': [],
           'pr_auc_scores': []
       }
       
       for i in range(num_classes):
           class_labels = (all_labels == i).astype(int)
           class_probs = all_probs[:, i]
           class_preds = (all_preds == i).astype(int)
           
           metrics[condition]['roc_auc_scores'].append(roc_auc_score(class_labels, class_probs))
           precision, recall, _ = precision_recall_curve(class_labels, class_probs)
           metrics[condition]['pr_auc_scores'].append(auc(recall, precision))
           metrics[condition]['per_class_accuracy'].append(accuracy_score(class_labels, class_preds))
       
       metrics[condition].update({
           'avg_roc_auc': np.mean(metrics[condition]['roc_auc_scores']),
           'avg_pr_auc': np.mean(metrics[condition]['pr_auc_scores']),
           'avg_accuracy': np.mean(metrics[condition]['per_class_accuracy'])
       })
       
       print(f"\nResults for {condition}:")
       print(f"Average ROC AUC: {metrics[condition]['avg_roc_auc']:.4f}")
       print(f"Average PR AUC: {metrics[condition]['avg_pr_auc']:.4f}")
       print(f"Average Accuracy: {metrics[condition]['avg_accuracy']:.4f}")
       print("\nConfusion Matrix:")
       print(metrics[condition]['confusion_matrix'])
   
   return metrics

In [16]:
# Usage
test_metrics = test_model(
   model=model,
   test_loaders=val_subset_loaders,
   num_classes=3,
   save_dir='./checkpoints'
)


Results for spinal_canal_stenosis:
Average ROC AUC: 0.4707
Average PR AUC: 0.3381
Average Accuracy: 0.9133

Confusion Matrix:
[[87  0  0]
 [ 9  0  0]
 [ 4  0  0]]

Results for left_neural_foraminal_narrowing:
Average ROC AUC: 0.5855
Average PR AUC: 0.4026
Average Accuracy: 0.8867

Confusion Matrix:
[[83  0  0]
 [13  0  0]
 [ 4  0  0]]

Results for right_neural_foraminal_narrowing:
Average ROC AUC: 0.5828
Average PR AUC: 0.3744
Average Accuracy: 0.8467

Confusion Matrix:
[[77  0  0]
 [18  0  0]
 [ 5  0  0]]

Results for left_subarticular_stenosis:
Average ROC AUC: 0.5918
Average PR AUC: 0.4141
Average Accuracy: 0.8000

Confusion Matrix:
[[70  0  0]
 [23  0  0]
 [ 7  0  0]]

Results for right_subarticular_stenosis:
Average ROC AUC: 0.7379
Average PR AUC: 0.5116
Average Accuracy: 0.7867

Confusion Matrix:
[[68  0  0]
 [22  0  0]
 [10  0  0]]
