# Load Library

In [1]:
# imports for neural network
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset

# imports for vision tasks
import torchvision
import torchvision.models as models
import torchvision.transforms.v2 as transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import pydicom

# imports for preparing dataset
import os
import shutil
import zipfile
import pandas as pd
from skimage import io
from PIL import Image
import numpy as np
import timm

# imports for visualizations
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from tqdm import tqdm
from pathlib import Path
import json
from typing import Dict, Tuple, List, Optional, Union, Any
import platform
from enum import Enum
import timm
import time

from sklearn.metrics import (
    roc_auc_score, 
    precision_recall_curve, 
    auc,
    accuracy_score
)

# Load Data

## Define Data Structure Class

In [2]:
class SpinalConditionDataset(Dataset):
    """Custom Dataset class with optimized device handling"""
    def __init__(
        self, 
        images: torch.Tensor, 
        labels: torch.Tensor, 
        metadata: pd.DataFrame,
        device: Optional[torch.device] = None
    ):
        if device is None:
            self.device, self.device_type = get_best_available_device()
        else:
            self.device = device
            self.device_type = device.type
        
        # Store as CPU tensors initially
        self.images = images
        self.labels = labels
        self.metadata = metadata
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Move individual items to device when accessed
        image = self.images[idx].to(self.device)
        label = self.labels[idx].to(self.device)
        return image, label

## Define Premade Function

### Getting Best Devices

In [3]:
def get_best_available_device() -> Tuple[torch.device, str]:
    """
    Detect and return the best available device for tensor operations.
    
    Returns:
        Tuple[torch.device, str]: Device object and device type string
    """
    device_type = "cpu"
    
    # Check for CUDA
    if torch.cuda.is_available():
        device_type = "cuda"
        
    # Check for Apple M1/M2 MPS (Metal Performance Shaders)
    elif platform.processor().startswith('arm') and platform.system() == 'Darwin' and \
         hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device_type = "mps"
        
    # Check for TPU (if torch_xla is available)
    try:
        import torch_xla.core.xla_model as xm
        device_type = "tpu"
    except ImportError:
        pass
        
    # Create device object based on type
    if device_type == "tpu":
        try:
            device = xm.xla_device()
        except NameError:
            device = torch.device("cpu")
            device_type = "cpu"
    else:
        device = torch.device(device_type)
        
    return device, device_type

### Data Loading Function

In [4]:
def load_processed_dataset(base_load_path, verbose=False):
    """
    Load the processed dataset from disk.
    
    Args:
        base_load_path (str): Path where the dataset is saved
    
    Returns:
        dict: Loaded condition data
        dict: Configuration data
    """
    base_path = Path(base_load_path)
    
    if not base_path.exists():
        raise ValueError(f"Dataset path {base_path} does not exist")
    
    # Load configuration
    with open(base_path / 'config.json', 'r') as f:
        config = json.load(f)
    
    # Initialize condition data dictionary
    loaded_condition_data = {}
    
    # Load data for each condition
    for condition in config['core_conditions']:
        condition_path = base_path / condition
        
        if condition_path.exists():
            loaded_condition_data[condition] = {
                'images': torch.load(condition_path / 'images.pt', weights_only=False),
                'labels': torch.load(condition_path / 'labels.pt', weights_only=False),
                'metadata': pd.read_pickle(condition_path / 'metadata.pkl')
            }
    
    if verbose is True :
        print("\nDataset successfully loaded")
        print("\nDataset Summary:")
        for condition, data in loaded_condition_data.items():
            print(f"\n{condition}:")
            print(f"Total samples: {len(data['images'])}")
            label_dist = torch.bincount(data['labels'])
            for severity, idx in config['severity_mapping'].items():
                if idx < len(label_dist):
                    print(f"  {severity}: {label_dist[idx].item()}")
    
    return loaded_condition_data, config

### Data Splitting Function

In [5]:
def create_train_val_split(
    condition_data: Dict,
    device: Optional[torch.device] = None,
    val_ratio: float = 0.2,
    seed: int = 42
) -> Dict[str, Dict[str, Dict]]:
    """
    Split the dataset into training and validation sets with automatic device support.
    
    Args:
        condition_data (Dict): Dictionary containing data for each condition
        device (Optional[torch.device]): Device to store the tensors on
        val_ratio (float): Ratio of validation set size to total dataset size
        seed (int): Random seed for reproducibility
    
    Returns:
        Dict: Dictionary containing train and val splits for each condition
    """
    torch.manual_seed(seed)
    
    split_data = {
        'train': {},
        'val': {}
    }
    
    for condition, data in condition_data.items():
        dataset_size = len(data['images'])
        val_size = int(dataset_size * val_ratio)
        train_size = dataset_size - val_size
        
        # Create full dataset with device specification
        full_dataset = SpinalConditionDataset(
            data['images'],
            data['labels'],
            data['metadata'],
            device=device
        )
        
        # Split dataset
        train_dataset, val_dataset = random_split(
            full_dataset, 
            [train_size, val_size],
            generator=torch.Generator().manual_seed(seed)
        )
        
        split_data['train'][condition] = train_dataset
        split_data['val'][condition] = val_dataset
        
    return split_data

### Data Loader Creator

In [6]:
def create_dataloaders(
    split_data: Dict[str, Dict[str, Dataset]],
    device_type: str,
    batch_size: int = 32,
    num_workers: int = 4,
    shuffle_train: bool = True
) -> Dict[str, Dict[str, DataLoader]]:
    """
    Create DataLoader objects optimized for the detected device.
    
    Args:
        split_data (Dict): Dictionary containing train and val splits
        device_type (str): Type of device being used
        batch_size (int): Batch size for DataLoader
        num_workers (int): Number of worker processes
        shuffle_train (bool): Whether to shuffle training data
    
    Returns:
        Dict: Dictionary containing DataLoader objects for each split and condition
    """
    dataloaders = {
        'train': {},
        'val': {}
    }
    
    # Optimize DataLoader settings based on device
    if device_type == "cuda":
        pin_memory = False  # Data is already on GPU
    elif device_type == "tpu":
        pin_memory = False
        num_workers = 0  # TPU often works better with synchronous loading
    else:  # CPU or MPS
        pin_memory = True
    
    for split in ['train', 'val']:
        for condition, dataset in split_data[split].items():
            dataloaders[split][condition] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=(shuffle_train and split == 'train'),
                num_workers=num_workers,
                pin_memory=pin_memory
            )
    
    return dataloaders

### Getting Batch Size

In [7]:
def get_device_specific_batch_size(device_type: str, base_batch_size: int) -> int:
    """
    Adjust batch size based on device type and available memory.
    
    Args:
        device_type (str): Type of device being used
        base_batch_size (int): Requested batch size
        
    Returns:
        int: Adjusted batch size
    """
    if device_type == "cuda":
        # Get available GPU memory and adjust batch size if needed
        gpu_memory = torch.cuda.get_device_properties(0).total_memory
        if gpu_memory < 8 * (1024**3):  # Less than 8GB
            return min(base_batch_size, 16)
    elif device_type == "tpu":
        # TPUs often work better with larger batch sizes
        return max(base_batch_size, 128)
    
    return base_batch_size

### Usage Function

In [8]:
def load_and_prepare_data(
    base_load_path: str,
    batch_size: int = 32,
    val_ratio: float = 0.2,
    num_workers: int = 4,
    seed: int = 42,
    verbose: bool = False
) -> Tuple[Dict, Dict, Dict, torch.device]:
    """
    Load and prepare data loaders with automatic device detection and optimization.
    
    Args:
        base_load_path (str): Path to dataset
        batch_size (int): Base batch size for DataLoader
        val_ratio (float): Validation set ratio
        num_workers (int): Number of worker processes
        seed (int): Random seed
        verbose (bool): Whether to print dataset information
    
    Returns:
        Tuple containing:
        - Dictionary of dataloaders
        - Dictionary of split datasets
        - Configuration dictionary
        - Device being used
    """
    # Detect best available device
    device, device_type = get_best_available_device()
    
    if verbose:
        print(f"\nUsing device: {device} ({device_type})")
    
    # Adjust batch size for device
    adjusted_batch_size = get_device_specific_batch_size(device_type, batch_size)
    if verbose and adjusted_batch_size != batch_size:
        print(f"Adjusted batch size from {batch_size} to {adjusted_batch_size} for {device_type}")
    
    # Load the dataset
    condition_data, config = load_processed_dataset(base_load_path, verbose=verbose)
    
    # Create train/val split
    split_data = create_train_val_split(
        condition_data,
        device=device,
        val_ratio=val_ratio,
        seed=seed
    )
    
    # Create dataloaders
    dataloaders = create_dataloaders(
        split_data,
        device_type=device_type,
        batch_size=adjusted_batch_size,
        num_workers=num_workers
    )
    
    if verbose:
        print("\nDataset Summary:")
        for split in ['train', 'val']:
            print(f"\n{split.capitalize()} set sizes:")
            for condition in config['core_conditions']:
                if condition in split_data[split]:
                    print(f"{condition}: {len(split_data[split][condition])}")
        
        # Print memory usage information
        if device_type == "cuda":
            print("\nGPU Memory Usage:")
            print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
            print(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        elif device_type == "tpu":
            try:
                import torch_xla.debug.metrics as met
                print("\nTPU Memory Usage:")
                print(met.metrics_report())
            except ImportError:
                pass
    
    return dataloaders, split_data, config, device

## Load & Create Dataloader

In [9]:
dataloaders, split_data, config, device = load_and_prepare_data(
    base_load_path="/kaggle/input/degenerative-spine-image-classificaton/processed_spine_dataset",
    batch_size=32,
    val_ratio=0.2,
    verbose=True,
    num_workers=0
)


Using device: cuda (cuda)

Dataset successfully loaded

Dataset Summary:

spinal_canal_stenosis:
Total samples: 9753
  Normal/Mild: 8552
  Moderate: 732
  Severe: 469

left_neural_foraminal_narrowing:
Total samples: 9860
  Normal/Mild: 7671
  Moderate: 1792
  Severe: 397

right_neural_foraminal_narrowing:
Total samples: 9829
  Normal/Mild: 7684
  Moderate: 1767
  Severe: 378

left_subarticular_stenosis:
Total samples: 9603
  Normal/Mild: 6857
  Moderate: 1834
  Severe: 912

right_subarticular_stenosis:
Total samples: 9612
  Normal/Mild: 6862
  Moderate: 1825
  Severe: 925

Dataset Summary:

Train set sizes:
spinal_canal_stenosis: 7803
left_neural_foraminal_narrowing: 7888
right_neural_foraminal_narrowing: 7864
left_subarticular_stenosis: 7683
right_subarticular_stenosis: 7690

Val set sizes:
spinal_canal_stenosis: 1950
left_neural_foraminal_narrowing: 1972
right_neural_foraminal_narrowing: 1965
left_subarticular_stenosis: 1920
right_subarticular_stenosis: 1922

GPU Memory Usage:
Alloc

# Spinal Model

In [10]:
import torch
import torch.nn as nn
import timm
from enum import Enum
from typing import Any, Dict, Tuple
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau


class ModelArchitecture(str, Enum):
    """Supported model architectures"""
    INCEPTION_V4 = "inception_v4"
    EFFICIENTNET = "efficientnet_b0"
    EFFICIENTNET_V2 = "efficientnetv2_s"
    VGG16 = "vgg16"


class SpinalModel(nn.Module):
    """Enhanced neural network model for spinal condition classification"""
    def __init__(
        self,
        architecture: ModelArchitecture,
        num_classes: int,
        pretrained: bool = True,
        dropout_rate: float = 0.5,
        weight_decay: float = 5e-3,
        unfreeze_layers: int = 0,
    ):
        super().__init__()
        
        self.architecture = architecture
        self.num_classes = num_classes
        self.weight_decay = weight_decay
        
        # Get the best available device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Create the backbone model
        self.backbone = self._create_backbone(pretrained)
        
        # Apply partial freezing of layers
        self.partial_freeze_backbone(unfreeze_layers)
        
        # Get the number of features from the backbone
        num_features = self._get_num_features()
        
        # Enhanced classifier with residual connections
        self.classifier = self._create_classifier(num_features, dropout_rate)
        
        # Initialize weights
        self._initialize_weights()
        
        # Move model to device
        self.to(self.device)

    def _create_backbone(self, pretrained: bool) -> nn.Module:
        """Create the backbone model"""
        if self.architecture == ModelArchitecture.INCEPTION_V4:
            model = timm.create_model('inception_v4.tf_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.EFFICIENTNET:
            model = timm.create_model('efficientnet_b0.ra4_e3600_r224_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.EFFICIENTNET_V2:
            model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=pretrained, num_classes=0)
        elif self.architecture == ModelArchitecture.VGG16:
            model = timm.create_model('vgg16.tv_in1k', pretrained=pretrained, num_classes=0)
        else:
            raise ValueError(f"Unsupported architecture: {self.architecture}")
        return model

    def _get_num_features(self) -> int:
        """Get number of features from backbone"""
        if self.architecture == ModelArchitecture.INCEPTION_V4:
            return 1536
        elif self.architecture in [ModelArchitecture.EFFICIENTNET, ModelArchitecture.EFFICIENTNET_V2]:
            return 1280
        elif self.architecture == ModelArchitecture.VGG16:
            return 4096
        raise ValueError(f"Unsupported architecture: {self.architecture}")

    def _create_classifier(self, num_features: int, dropout_rate: float) -> nn.Sequential:
        """Create classifier with residual connections"""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1) if self.architecture == ModelArchitecture.VGG16 else nn.Identity(),
            nn.Flatten(),
            nn.BatchNorm1d(num_features),
            nn.Dropout(p=dropout_rate),
            
            # First dense block
            self._create_dense_block(num_features, 512, dropout_rate),
            
            # Second dense block
            self._create_dense_block(512, 256, dropout_rate),
            
            # Final classification
            nn.BatchNorm1d(256),
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, self.num_classes)
        )

    def _create_dense_block(self, in_features: int, out_features: int, dropout_rate: float) -> nn.Sequential:
        """Create a dense block with batch norm and dropout"""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.BatchNorm1d(out_features),
            nn.Dropout(p=dropout_rate)
        )

    def _initialize_weights(self):
        """Initialize model weights"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def partial_freeze_backbone(self, unfreeze_layers: int):
        """Partially freeze backbone layers"""
        if unfreeze_layers == 0:
            for param in self.backbone.parameters():
                param.requires_grad = False
            return
            
        parameters = list(self.backbone.named_parameters())
        total_layers = len(parameters)
        
        # Freeze all layers first
        for param in self.backbone.parameters():
            param.requires_grad = False
            
        # Unfreeze the specified number of layers from the bottom
        for i in range(max(0, total_layers - unfreeze_layers), total_layers):
            parameters[i][1].requires_grad = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass"""
        x = x.to(self.device)
        features = self.backbone(x)
        return self.classifier(features)

def create_model(
    architecture: str,
    num_classes: int,
    pretrained: bool = True,
    dropout_rate: float = 0.5,
    weight_decay: float = 5e-3,
    unfreeze_layers: int = 0,
    verbose: bool = False
) -> Tuple[SpinalModel, Dict[str, Any]]:
    """Create a model instance"""
    try:
        arch = ModelArchitecture(architecture.lower())
    except ValueError:
        raise ValueError(
            f"Unsupported architecture: {architecture}. "
            f"Supported architectures: {[a.value for a in ModelArchitecture]}"
        )
    
    # Create model
    model = SpinalModel(
        architecture=arch,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        weight_decay=weight_decay,
        unfreeze_layers=unfreeze_layers
    )
    
    # Get preprocessing parameters
    preprocess_params = {
        'image_size': 299 if arch == ModelArchitecture.INCEPTION_V4 else 224,
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'device': model.device
    }

    if verbose:
        print(f"\nModel Configuration:")
        print(f"Architecture: {arch.value}")
        # print(f"Device: {preprocess_params['device']} ({preprocess_params['device_type']})")
        print(f"Input size: {preprocess_params['image_size']}x{preprocess_params['image_size']}")
        print(f"Number of classes: {num_classes}")
        print(f"Pretrained: {pretrained}")
        
        # Print model summary if torchinfo is available
        try:
            from torchsummary import summary
            # Assuming 'model' is your SpinalModel
            input_size = (3, preprocess_params['image_size'], preprocess_params['image_size'])
            summary(model, input_size=input_size)
        except ImportError:
            if verbose:
                print("\nInstall torchsummary for detailed model summary")
    
    return model, preprocess_params

In [11]:
model, preprocess_params = create_model(
    architecture="efficientnet_b0",  # or "efficientnet" or "efficientnet_v2"
    num_classes=3,
    pretrained=True,
    verbose=True,
    unfreeze_layers=23
)

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]


Model Configuration:
Architecture: efficientnet_b0
Input size: 224x224
Number of classes: 3
Pretrained: True
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
          Identity-2         [-1, 32, 112, 112]               0
              SiLU-3         [-1, 32, 112, 112]               0
    BatchNormAct2d-4         [-1, 32, 112, 112]              64
            Conv2d-5         [-1, 32, 112, 112]             288
          Identity-6         [-1, 32, 112, 112]               0
              SiLU-7         [-1, 32, 112, 112]               0
    BatchNormAct2d-8         [-1, 32, 112, 112]              64
          Identity-9         [-1, 32, 112, 112]               0
           Conv2d-10              [-1, 8, 1, 1]             264
             SiLU-11              [-1, 8, 1, 1]               0
           Conv2d-12             [-1, 32, 1, 1]          

# Training

In [12]:
class TrainingLogger:
    """Handles logging of training metrics and saving checkpoints"""
    def __init__(self, save_dir: str):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        
        # Initialize metrics storage
        self.metrics = {
            'train': {}, 'val': {},
            'best_epoch': 0,
            'best_val_loss': float('inf')
        }
        
    def log_epoch(self, epoch: int, metrics: Dict[str, Any]):
        """Log metrics for an epoch"""
        for split in ['train', 'val']:
            if split not in self.metrics:
                self.metrics[split] = {}
            
            for condition in metrics[split]:
                if condition not in self.metrics[split]:
                    self.metrics[split][condition] = {
                        'loss': [], 
                        'avg_roc_auc': [], 
                        'avg_pr_auc': [],
                        'accuracy' : []
                    }
                
                for metric in ['loss', 'avg_roc_auc', 'avg_pr_auc', 'accuracy']:
                    self.metrics[split][condition][metric].append(
                        metrics[split][condition][metric]
                    )
        
        # Update best epoch if needed
        avg_val_loss = np.mean([
            metrics['val'][condition]['loss'] 
            for condition in metrics['val']
        ])
        if avg_val_loss < self.metrics['best_val_loss']:
            self.metrics['best_val_loss'] = avg_val_loss
            self.metrics['best_epoch'] = epoch
            return True
        return False
    
    def save_metrics(self):
        """Save metrics to JSON"""
        with open(self.save_dir / 'metrics.json', 'w') as f:
            json.dump(self.metrics, f, indent=4)
    
    def plot_metrics(self):
        """Plot training and validation metrics"""
        metrics_to_plot = ['loss', 'avg_roc_auc', 'avg_pr_auc', 'accuracy']
        splits = ['train', 'val']
        
        for metric in metrics_to_plot:
            fig, axes = plt.subplots(1, len(splits), figsize=(15, 5))
            
            for idx, split in enumerate(splits):
                ax = axes[idx]
                
                for condition in self.metrics[split]:
                    values = self.metrics[split][condition][metric]
                    ax.plot(values, label=condition)
                
                ax.set_title(f'{split.capitalize()} {metric}')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(metric.upper())
                ax.legend()
                ax.grid(True)
            
            plt.tight_layout()
            plt.savefig(self.save_dir / f'{metric}_plot.png')
            plt.close()

In [13]:
def train_model(
    model: nn.Module,
    dataloaders: Dict[str, Dict[str, DataLoader]],
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    num_epochs: int,
    save_dir: str,
    num_classes: int = 3,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    early_stopping_patience: int = 10,
    verbose: bool = True,
    class_weights: Optional[torch.Tensor] = None
) -> Tuple[nn.Module, Dict]:
    """
    Train model with multiple condition-specific dataloaders for multi-class classification.
    Supports optional class weights to handle class imbalance.
    """
    from tqdm.auto import tqdm  # Use auto for better notebook/terminal compatibility
    import colorama  # For color support
    colorama.init()  # Initialize colorama for cross-platform color support

    device = model.device
    logger = TrainingLogger(save_dir)
    start_time = time.time()
    
    # Calculate class weights if not provided
    if class_weights is None:
        # Aggregate labels from all training dataloaders
        all_labels = []
        for condition_loader in dataloaders['train'].values():
            for _, labels in condition_loader:
                all_labels.extend(labels.cpu().numpy())  # Add .cpu() before .numpy()
        
        # Count label occurrences
        label_counts = {}
        for label in all_labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        
        # Calculate weights (inverse frequency)
        total_samples = len(all_labels)
        weights = [total_samples / (len(label_counts) * count) for count in label_counts.values()]
        
        # Normalize weights
        weights = np.array(weights)
        weights = weights / weights.sum()
        
        # Convert to tensor and move to device
        class_weights = torch.tensor(weights, dtype=torch.float).to(device)
        
        if verbose:
            print("Calculated Class Weights:", class_weights)
    else:
        # Ensure provided weights are on the correct device
        class_weights = class_weights.to(device)
    
    # Modify criterion to use class weights
    if class_weights is not None:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    # Customize tqdm style
    tqdm_config = {
        'colour': 'green',  # Overall progress color
        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
        'dynamic_ncols': True,  # Adjust to terminal width
        'unit_scale': True,  # Automatically scale units
    }
    
    # Early stopping setup
    patience_counter = 0
    
    # Wrap epochs with tqdm with enhanced formatting
    for epoch in tqdm(range(num_epochs), 
                      desc=f"🏋️  Training Progress", 
                      total=num_epochs,
                      **tqdm_config):
        epoch_start_time = time.time()
        epoch_metrics = {'train': {}, 'val': {}}
        
        # Training and validation for each split
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                phase_color = 'cyan'
            else:
                model.eval()
                phase_color = 'magenta'
            
            # Process each condition with tqdm
            for condition, dataloader in tqdm(
                dataloaders[phase].items(), 
                desc=f"\033[1;{phase_color}m📊 {phase.capitalize()} Conditions\033[0m", 
                leave=False,
                colour='blue',
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'
            ):
                running_loss = 0.0
                all_preds = []
                all_labels = []
                
                # Wrap dataloader with tqdm
                for inputs, labels in tqdm(
                    dataloader, 
                    desc=f"\033[1;{phase_color}m🔬 {phase.capitalize()} {condition}\033[0m", 
                    leave=False,
                    colour='white',
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}]'
                ):
                    inputs = inputs.to(device)
                    labels = labels.long().to(device)
                    
                    optimizer.zero_grad()
                    
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    
                    running_loss += loss.item() * inputs.size(0)
                    
                    # Store predictions and labels for metrics
                    # Use softmax for multi-class probability distribution
                    preds = torch.softmax(outputs, dim=1).detach().cpu().numpy()
                    labels_np = labels.cpu().numpy()
                    all_preds.extend(preds)
                    all_labels.extend(labels_np)
                
                # Calculate metrics for this condition
                all_preds = np.array(all_preds)
                all_labels = np.array(all_labels)
                
                # Calculate average loss
                epoch_loss = running_loss / len(dataloader.dataset)
                
                # Compute multi-class metrics
                # One-vs-Rest approach for ROC AUC and PR AUC
                roc_aucs = []
                pr_aucs = []
                accuracy = accuracy_score(all_labels, np.argmax(all_preds, axis=1))
                
                for i in range(num_classes):
                    # One-hot encode the labels for binary classification metrics
                    class_labels = (all_labels == i).astype(int)
                    class_preds = all_preds[:, i]
                    
                    # ROC AUC
                    try:
                        roc_aucs.append(roc_auc_score(class_labels, class_preds))
                    except ValueError:
                        roc_aucs.append(0.0)  # Handle cases with insufficient variation
                    
                    # Precision-Recall AUC
                    precision, recall, _ = precision_recall_curve(class_labels, class_preds)
                    pr_aucs.append(auc(recall, precision))
                
                # Store metrics
                epoch_metrics[phase][condition] = {
                    'loss': epoch_loss,
                    'avg_roc_auc': np.mean(roc_aucs),
                    'avg_pr_auc': np.mean(pr_aucs),
                    'accuracy' : accuracy
                }
            
            # Step scheduler if it exists and we're in training phase
            if phase == 'train' and scheduler is not None:
                scheduler.step()
        
        # Log epoch metrics and check for best model
        is_best = logger.log_epoch(epoch, epoch_metrics)
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': epoch_metrics,
        }
        torch.save(checkpoint, f"{save_dir}/checkpoint_latest.pth")
        
        if is_best:
            torch.save(checkpoint, f"{save_dir}/checkpoint_best.pth")
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Print epoch summary
        if verbose:
            epoch_time = time.time() - epoch_start_time
            print(f"\nEpoch {epoch+1}/{num_epochs} - Time: {epoch_time:.2f}s")
            
            for phase in ['train', 'val']:
                avg_loss = np.mean([m['loss'] for m in epoch_metrics[phase].values()])
                avg_roc_auc = np.mean([m['avg_roc_auc'] for m in epoch_metrics[phase].values()])
                avg_pr_auc = np.mean([m['avg_pr_auc'] for m in epoch_metrics[phase].values()])
                avg_accuracy = np.mean([m['accuracy'] for m in epoch_metrics[phase].values()])
                
                print(f"{phase.capitalize():5s} - "
                      f"Loss: {avg_loss:.4f}, "
                      f"Avg ROC AUC: {avg_roc_auc:.4f}, "
                      f"Avg PR AUC: {avg_pr_auc:.4f}, "
                      f"Avg Accuracy : {avg_accuracy:.4f}, "
                )
        
        # Early stopping check
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered!")
            break
        
        # Save metrics and plots
        logger.save_metrics()
        logger.plot_metrics()
    
    # Load best model weights
    best_checkpoint = torch.load(f"{save_dir}/checkpoint_best.pth")
    model.load_state_dict(best_checkpoint['model_state_dict'])
    
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/60:.2f} minutes")
    print(f"Best epoch: {logger.metrics['best_epoch']}")
    
    return model, logger.metrics

In [14]:
# When calling train_model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())
num_classes = 3

trained_model, metrics = train_model(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=50,
    save_dir='./checkpoints',
    num_classes=num_classes
)

Calculated Class Weights: tensor([0.0560, 0.6794, 0.2647], device='cuda:0')


🏋️  Training Progress:   0%|          | 0.00/50.0 [00:00<?, ?it/s]

[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 1/50 - Time: 34.86s
Train - Loss: 8.1059, Avg ROC AUC: 0.5526, Avg PR AUC: 0.3856, Avg Accuracy : 0.3846, 
Val   - Loss: 3.4082, Avg ROC AUC: 0.6137, Avg PR AUC: 0.3880, Avg Accuracy : 0.2841, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 2/50 - Time: 34.04s
Train - Loss: 4.1578, Avg ROC AUC: 0.5789, Avg PR AUC: 0.3730, Avg Accuracy : 0.4052, 
Val   - Loss: 1.6234, Avg ROC AUC: 0.6962, Avg PR AUC: 0.4336, Avg Accuracy : 0.2682, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 3/50 - Time: 33.43s
Train - Loss: 2.1820, Avg ROC AUC: 0.6288, Avg PR AUC: 0.3878, Avg Accuracy : 0.4404, 
Val   - Loss: 1.2111, Avg ROC AUC: 0.7360, Avg PR AUC: 0.4523, Avg Accuracy : 0.2440, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 4/50 - Time: 33.51s
Train - Loss: 1.3097, Avg ROC AUC: 0.6648, Avg PR AUC: 0.4012, Avg Accuracy : 0.4509, 
Val   - Loss: 0.8435, Avg ROC AUC: 0.7615, Avg PR AUC: 0.5032, Avg Accuracy : 0.2780, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 5/50 - Time: 33.53s
Train - Loss: 0.9656, Avg ROC AUC: 0.7124, Avg PR AUC: 0.4265, Avg Accuracy : 0.4622, 
Val   - Loss: 0.7875, Avg ROC AUC: 0.7749, Avg PR AUC: 0.5158, Avg Accuracy : 0.3643, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 6/50 - Time: 33.25s
Train - Loss: 0.8350, Avg ROC AUC: 0.7480, Avg PR AUC: 0.4486, Avg Accuracy : 0.4816, 
Val   - Loss: 0.8189, Avg ROC AUC: 0.7442, Avg PR AUC: 0.4742, Avg Accuracy : 0.3271, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 7/50 - Time: 33.51s
Train - Loss: 0.7991, Avg ROC AUC: 0.7466, Avg PR AUC: 0.4511, Avg Accuracy : 0.4670, 
Val   - Loss: 0.8004, Avg ROC AUC: 0.7497, Avg PR AUC: 0.4839, Avg Accuracy : 0.3351, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 8/50 - Time: 33.26s
Train - Loss: 0.7288, Avg ROC AUC: 0.7905, Avg PR AUC: 0.4968, Avg Accuracy : 0.5161, 
Val   - Loss: 0.8196, Avg ROC AUC: 0.7749, Avg PR AUC: 0.5173, Avg Accuracy : 0.4021, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 9/50 - Time: 33.04s
Train - Loss: 0.7023, Avg ROC AUC: 0.8112, Avg PR AUC: 0.5116, Avg Accuracy : 0.5484, 
Val   - Loss: 0.8048, Avg ROC AUC: 0.7624, Avg PR AUC: 0.5002, Avg Accuracy : 0.3149, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 10/50 - Time: 33.52s
Train - Loss: 0.6792, Avg ROC AUC: 0.8261, Avg PR AUC: 0.5309, Avg Accuracy : 0.5662, 
Val   - Loss: 0.7862, Avg ROC AUC: 0.7841, Avg PR AUC: 0.5275, Avg Accuracy : 0.3985, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 11/50 - Time: 33.31s
Train - Loss: 0.6407, Avg ROC AUC: 0.8478, Avg PR AUC: 0.5572, Avg Accuracy : 0.5974, 
Val   - Loss: 0.8093, Avg ROC AUC: 0.7815, Avg PR AUC: 0.5166, Avg Accuracy : 0.5556, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 12/50 - Time: 33.00s
Train - Loss: 0.6193, Avg ROC AUC: 0.8595, Avg PR AUC: 0.5814, Avg Accuracy : 0.6213, 
Val   - Loss: 0.7974, Avg ROC AUC: 0.7662, Avg PR AUC: 0.4982, Avg Accuracy : 0.4432, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 13/50 - Time: 33.84s
Train - Loss: 0.5764, Avg ROC AUC: 0.8794, Avg PR AUC: 0.6196, Avg Accuracy : 0.6583, 
Val   - Loss: 0.7992, Avg ROC AUC: 0.7789, Avg PR AUC: 0.5204, Avg Accuracy : 0.5734, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 14/50 - Time: 33.71s
Train - Loss: 0.5363, Avg ROC AUC: 0.8957, Avg PR AUC: 0.6534, Avg Accuracy : 0.6883, 
Val   - Loss: 0.8999, Avg ROC AUC: 0.7775, Avg PR AUC: 0.5132, Avg Accuracy : 0.6399, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 15/50 - Time: 34.53s
Train - Loss: 0.4918, Avg ROC AUC: 0.9109, Avg PR AUC: 0.6876, Avg Accuracy : 0.7243, 
Val   - Loss: 0.8613, Avg ROC AUC: 0.7709, Avg PR AUC: 0.5148, Avg Accuracy : 0.6144, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 16/50 - Time: 34.48s
Train - Loss: 0.4480, Avg ROC AUC: 0.9246, Avg PR AUC: 0.7268, Avg Accuracy : 0.7520, 
Val   - Loss: 0.9025, Avg ROC AUC: 0.7392, Avg PR AUC: 0.4772, Avg Accuracy : 0.4706, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 17/50 - Time: 34.44s
Train - Loss: 0.4076, Avg ROC AUC: 0.9372, Avg PR AUC: 0.7743, Avg Accuracy : 0.7735, 
Val   - Loss: 0.9774, Avg ROC AUC: 0.7511, Avg PR AUC: 0.5047, Avg Accuracy : 0.5727, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 18/50 - Time: 33.93s
Train - Loss: 0.3694, Avg ROC AUC: 0.9460, Avg PR AUC: 0.8057, Avg Accuracy : 0.7970, 
Val   - Loss: 1.1242, Avg ROC AUC: 0.7382, Avg PR AUC: 0.4841, Avg Accuracy : 0.5565, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 19/50 - Time: 34.12s
Train - Loss: 0.3376, Avg ROC AUC: 0.9540, Avg PR AUC: 0.8312, Avg Accuracy : 0.8126, 
Val   - Loss: 1.0786, Avg ROC AUC: 0.7464, Avg PR AUC: 0.4947, Avg Accuracy : 0.5837, 


[1;cyanm📊 Train Conditions[0m:   0%|          | 0/5

[1;cyanm🔬 Train spinal_canal_stenosis[0m:   0%|          | 0/244 [?it/s]

[1;cyanm🔬 Train left_neural_foraminal_narrowing[0m:   0%|          | 0/247 [?it/s]

[1;cyanm🔬 Train right_neural_foraminal_narrowing[0m:   0%|          | 0/246 [?it/s]

[1;cyanm🔬 Train left_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;cyanm🔬 Train right_subarticular_stenosis[0m:   0%|          | 0/241 [?it/s]

[1;magentam📊 Val Conditions[0m:   0%|          | 0/5

[1;magentam🔬 Val spinal_canal_stenosis[0m:   0%|          | 0/61 [?it/s]

[1;magentam🔬 Val left_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val right_neural_foraminal_narrowing[0m:   0%|          | 0/62 [?it/s]

[1;magentam🔬 Val left_subarticular_stenosis[0m:   0%|          | 0/60 [?it/s]

[1;magentam🔬 Val right_subarticular_stenosis[0m:   0%|          | 0/61 [?it/s]


Epoch 20/50 - Time: 33.87s
Train - Loss: 0.3134, Avg ROC AUC: 0.9603, Avg PR AUC: 0.8499, Avg Accuracy : 0.8285, 
Val   - Loss: 1.0642, Avg ROC AUC: 0.7308, Avg PR AUC: 0.4885, Avg Accuracy : 0.5287, 
Early stopping triggered!


  best_checkpoint = torch.load(f"{save_dir}/checkpoint_best.pth")



Training completed in 12.12 minutes
Best epoch: 9
