In [1]:
!pip install torchio
!pip install focal_loss_torch

Collecting torchio
  Downloading torchio-0.21.0-py3-none-any.whl.metadata (52 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m52.6/52.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9->torchio)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from to

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from PIL import Image
import seaborn as sns
import math
import os
import warnings
import logging
import time

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, einsum
import torchio as tio
import torch.nn.functional as F
from scipy.ndimage import rotate, zoom
import random
from focal_loss.focal_loss import FocalLoss

from sklearn.metrics import (
    confusion_matrix, 
    classification_report, 
    cohen_kappa_score,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve
)

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

warnings.filterwarnings("ignore", module="torchio")

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # n·∫øu mu·ªën reproducible tuy·ªát ƒë·ªëi: deterministic=True + benchmark=False
    # Ph√°t tri·ªÉn ki·∫øn tr√∫c: deterministic=False + benchmark=True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

<torch._C.Generator at 0x7c80b5415990>

In [5]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, transforms=None):
        self.df = pd.read_csv(dataframe)
        self.transforms = transforms

    def __getitem__(self, index):
        path_object = self.df.loc[index]['mri_path']
        mri_file = '/workspace/data/SAG_3D_DESS_v2_full/MRI_Numpy/' + path_object
        mri_dict = np.load(mri_file)
        mri_object = mri_dict['data']

        mri_object = np.expand_dims(mri_object, 0) # (1 x 120 x 160 x 160)
        mri_object = self.transforms(mri_object)
        mri_tensor = torch.tensor(mri_object)

        label = self.df.loc[index]['kl_grade']

        return mri_tensor, label

    def __len__(self):
        return len(self.df)

In [6]:
spatial_augment = [
    tio.RandomAffine(degrees=15, p=0.5),
    tio.RandomFlip(axes=(0,), flip_probability=0.5),
]

intensity_augment = {
    tio.RandomNoise(): 0.25,
    tio.RandomBiasField(): 0.25,
    tio.RandomBlur(std=(0,1.5)): 0.25,
    tio.RandomMotion(): 0.25,
}


train_transforms = tio.Compose([
    tio.Compose(spatial_augment, p=1),
    tio.OneOf(intensity_augment, p=0.75),
    tio.RescaleIntensity(out_min_max=(0,1)),
])

val_transforms = tio.Compose([
    tio.RescaleIntensity(out_min_max=(0,1)),
])

test_transforms = tio.Compose([
    tio.RescaleIntensity(out_min_max=(0,1)),
])


df = pd.read_csv('/workspace/data/unified_xray_mri_label.csv')

train_df = df[df['subset'] == 'train'].reset_index(drop=True)
val_df = df[df['subset'] == 'val'].reset_index(drop=True)
test_df = df[df['subset'] == 'test'].reset_index(drop=True)

train_ds = CustomDataset(train_df, transforms=train_transforms)
val_ds = CustomDataset(val_df, transforms=val_transforms)
test_ds = CustomDataset(test_df, transforms=test_transforms)

In [7]:
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

val_loader = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [8]:
# for mri, label in train_loader:
#     print(mri.shape)
#     print(label.shape)
#     print(torch.max(mri), torch.min(mri))
#     break

# Vision Transformer

In [20]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout,),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class VisionTransformer(nn.Module):
    def __init__(self,
                 *,
                 image_size,
                 image_patch_size,
                 frames,
                 frame_patch_size,
                 num_classes,
                 dim, depth,
                 heads,
                 mlp_dim,
                 pool = 'cls',
                 channels = 3,
                 dim_head = 64,
                 dropout = 0.,
                 emb_dropout = 0.,
                 pretrain_path=None):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size
        self.num_patches = num_patches
        self.image_size = image_size
        self.image_patch_size = image_patch_size
        self.frames = frames
        self.frame_patch_size=frame_patch_size

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.conv_proj = nn.Sequential(
            nn.Conv3d(channels, dim, kernel_size=(frame_patch_size, image_patch_size, image_patch_size), stride=(frame_patch_size, image_patch_size, image_patch_size),
        ))

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

        if pretrain_path is not None:
            self.load_pretrain(pretrain_path)
            print(f'Load pretrained {pretrain_path} sucessfully!')

    def load_pretrain(self, pretrain_path):
        jax_dict = torch.load(pretrain_path, map_location='cpu')
        new_dict = {}

        def interpolate_pos_embedding(pre_pos_embed):
            cls_token, pretrained_pos_embed = pre_pos_embed[:, :1, :], pre_pos_embed[:, 1:, :]  # [1, 1, 768], [1, 196, 768]
            new_num_patches = self.num_patches # 1000
            old_num_patches = int(pretrained_pos_embed.shape[1] ** 0.5) # 14
            pretrained_pos_embed = pretrained_pos_embed.reshape(1, old_num_patches, old_num_patches, -1).permute(0, 3, 1, 2)  # [1, 768, 14, 14]
            pretrained_pos_embed = pretrained_pos_embed.unsqueeze(2)  # [1, 768, 1, 14, 14]
            new_size = round(new_num_patches ** (1/3))
            pretrained_pos_embed = F.interpolate(pretrained_pos_embed, size=(new_size, new_size, new_size), mode='trilinear', align_corners=False)  # [1, 768, 10, 10, 10]
            pretrained_pos_embed = pretrained_pos_embed.permute(0, 2, 3, 4, 1).reshape(1, new_size*new_size*new_size, -1) # [1,1000, 768]
            new_pos_embed = torch.cat([cls_token, pretrained_pos_embed], dim=1)
            return new_pos_embed

        def mean_kernel(patch_emb_weight):
            patch_emb_weight = patch_emb_weight.mean(dim=1, keepdim=True)  # Shape: [768, 1, 16, 16]
            depth = self.conv_proj[0].weight.shape[2]
            patch_emb_weight = patch_emb_weight.unsqueeze(2).repeat(1, 1, depth, 1, 1)  # Shape: [768, 1, 12, 16, 16]
            return patch_emb_weight

        def add_item(key, value):
            key = key.replace('blocks', 'transformer.layers')
            new_dict[key] = value

        for key, value in jax_dict.items():
            if key == 'cls_token':
                new_dict[key] = value

            elif 'norm1' in key:
                new_key = key.replace('norm1', '0.norm')
                add_item(new_key, value)
            elif 'attn.qkv' in key:
                new_key = key.replace('attn.qkv', '0.to_qkv')
                add_item(new_key, value)
            elif 'attn.proj' in key:
                new_key = key.replace('attn.proj', '0.to_out.0')
                add_item(new_key, value)
            elif 'norm2' in key:
                new_key = key.replace('norm2', '1.net.0')
                add_item(new_key, value)
            elif 'mlp.fc1' in key:
                new_key = key.replace('mlp.fc1', '1.net.1')
                add_item(new_key, value)
            elif 'mlp.fc2' in key:
                new_key = key.replace('mlp.fc2', '1.net.4')
                add_item(new_key, value)
            elif 'patch_embed.proj.weight' in key:
                new_key = key.replace('patch_embed.proj.weight', 'conv_proj.0.weight')
                value = mean_kernel(value)
                add_item(new_key, value)
            elif 'patch_embed.proj.bias' in key:
                new_key = key.replace('patch_embed.proj.bias', 'conv_proj.0.bias')
                add_item(new_key, value)
            elif key == 'pos_embed':
                value = interpolate_pos_embedding(value)
                add_item('pos_embedding', value)
            elif key == 'norm.weight':
                add_item('transformer.norm.weight', value)
            elif key == 'norm.bias':
                add_item('transformer.norm.bias', value)

        self.load_state_dict(new_dict, strict=False)


    def forward(self, img):
        x = self.conv_proj(img)
        x = x.flatten(2).transpose(1,2) # [B, N, C]
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [21]:
import timm
import torch

pretrained = timm.create_model('vit_base_patch16_224.orig_in21k', pretrained=True)

torch.save(pretrained.state_dict(), "vit_base_patch16_224_in21k.pth")

In [22]:
model = VisionTransformer(
    image_size=160,
    image_patch_size=16,
    frames = 120,
    frame_patch_size = 12,
    depth=12,
    heads=12,
    dim=768,
    mlp_dim=3072,
    dropout=0.2,
    emb_dropout=0.1,
    channels = 1,
    num_classes = 5,
    freeze_vit = True,
    pool = 'cls',
    pretrain_path = './pretrained/vit_base_patch16_224_in21k.pth',
)
model.to(device)

Load pretrained /kaggle/working/vit_base_patch16_224_in21k.pth sucessfully!
Freezing pretrained ViT components...
Trainable: transformer.layers.0.1.adapter_layer_norm_before.weight | Shape: [768]               
Trainable: transformer.layers.0.1.adapter_layer_norm_before.bias | Shape: [768]               
Trainable: transformer.layers.0.1.expand.weight               | Shape: [3072, 768]         
Trainable: transformer.layers.0.1.expand.bias                 | Shape: [3072]              
Trainable: transformer.layers.0.1.dw_conv.weight              | Shape: [3072, 1, 3, 3, 3]  
Trainable: transformer.layers.0.1.bn.weight                   | Shape: [3072]              
Trainable: transformer.layers.0.1.bn.bias                     | Shape: [3072]              
Trainable: transformer.layers.0.1.project.weight              | Shape: [768, 3072]         
Trainable: transformer.layers.0.1.project.bias                | Shape: [768]               
Trainable: transformer.layers.1.1.adapter_layer_no

VisionTransformer(
  (conv_proj): Sequential(
    (0): Conv3d(1, 768, kernel_size=(12, 16, 16), stride=(12, 16, 16))
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.2, inplace=False)
          (to_qkv): Linear(in_features=768, out_features=2304, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Dropout(p=0.2, inplace=False)
          )
        )
        (1): InvertedResidualAdapter(
          (adapter_layer_norm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (expand): Linear(in_features=768, out_features=3072, bias=True)
          (dw_conv): Conv3d(3072, 3072, kernel_si

In [23]:
count_freeze = 0
count_tuning = 0

for name, param in model.named_parameters():
    if param.requires_grad == True:
        count_tuning += 1
    else:
        count_freeze += 1

print(f'There are {count_tuning} trainable params.')
print(f'There are {count_freeze} freeze params')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")

There are 110 trainable params.
There are 138 freeze params
Total trainable parameters: 57760517
GAViKO go go!


# Define Loss

In [None]:
from focal_loss.focal_loss import FocalLoss

# alpha = torch.FloatTensor([2.65, 5.39, 3.83, 7.03, 29.67]).to(device)
criterion = FocalLoss(gamma=1.2)

In [None]:
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

from torch.optim.lr_scheduler import OneCycleLR
steps_per_epoch = len(train_loader)
num_epochs = 30
total_steps = steps_per_epoch * num_epochs
scheduler = OneCycleLR(
    optimizer,
    max_lr=3e-4,  # learning rate cao nh·∫•t
    total_steps=total_steps,
    pct_start=0.3,  # % s·ªë b∆∞·ªõc d√†nh cho giai ƒëo·∫°n tƒÉng lr (warmup)
    div_factor=10.0,  # lr_start = max_lr / div_factor
    final_div_factor=1000.0,  # lr_final = lr_start / final_div_factor
    anneal_strategy='cos',  # s·ª≠ d·ª•ng cosine annealing
    three_phase=False  # kh√¥ng d√πng 3 giai ƒëo·∫°n (ch·ªâ d√πng 2: l√™n-xu·ªëng)
)

# Training Setup

In [None]:
# =============================================================================
# SETUP
# =============================================================================
save_dir = "./output"
os.makedirs(save_dir, exist_ok=True)

# Logging setup
log_file = os.path.join(save_dir, 'training_log.txt')
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)

# Also log to console
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger('').addHandler(console)

In [None]:
class TrainingState:
    """Track training state"""
    def __init__(self):
        self.val_acc_max = 0.0
        self.val_kappa_max = 0.0
        self.current_epoch = 0
        self.epoch_since_improvement = 0
        self.val_loss = 0
        
        # Metrics history
        self.train_loss_history = []
        self.train_acc_history = []
        self.val_loss_history = []
        self.val_acc_history = []
        self.val_kappa_history = []
        self.lr_history = []
        
        # Best metrics
        self.best_epoch = 0
        self.best_model_path = None

state = TrainingState()

# Training Utils

In [None]:
def atomic_save(obj, path):
    """Atomic save to prevent corruption"""
    tmp = path + ".tmp"
    torch.save(obj, tmp)
    os.replace(tmp, path)  # Atomic on Unix
    logging.info(f"Model saved to {path}")

def compute_metrics(predictions, labels, num_classes=5):
    """
    Compute comprehensive metrics
    
    Returns:
        dict with accuracy, kappa, per-class accuracy, confusion matrix
    """
    # Accuracy
    accuracy = (predictions == labels).sum().item() / len(labels)
    
    # Cohen's Kappa 
    kappa = cohen_kappa_score(labels, predictions, weights='quadratic')
    
    # Confusion matrix
    cm = confusion_matrix(labels, predictions, labels=range(num_classes))
    
    # Per-class accuracy
    per_class_acc = cm.diagonal() / cm.sum(axis=1)
    
    return {
        'accuracy': accuracy,
        'kappa': kappa,
        'confusion_matrix': cm,
        'per_class_accuracy': per_class_acc
    }

def log_metrics(phase, epoch, loss, metrics, lr=None):
    """Log metrics to console and file"""
    msg = f"\n{'='*80}\n"
    msg += f"Epoch {epoch + 1} - {phase.upper()}\n"
    msg += f"{'='*80}\n"
    msg += f"Loss: {loss:.4f}\n"
    msg += f"Accuracy: {metrics['accuracy']*100:.2f}%\n"
    msg += f"Kappa Score: {metrics['kappa']:.4f}\n"
    
    if lr is not None:
        msg += f"Learning Rate: {lr:.6f}\n"
    
    msg += "\nPer-Class Accuracy:\n"
    for i, acc in enumerate(metrics['per_class_accuracy']):
        msg += f"  KL Grade {i}: {acc*100:.2f}%\n"
    
    msg += f"{'='*80}\n"
    
    print(msg)
    logging.info(msg)

def plot_training_curves(state, save_path):
    """Plot training curves"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    epochs = range(1, len(state.train_loss_history) + 1)
    
    # Loss
    axes[0, 0].plot(epochs, state.train_loss_history, 'b-', label='Train Loss')
    axes[0, 0].plot(epochs, state.val_loss_history, 'r-', label='Val Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Loss Curves')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Accuracy
    axes[0, 1].plot(epochs, state.train_acc_history, 'b-', label='Train Acc')
    axes[0, 1].plot(epochs, state.val_acc_history, 'r-', label='Val Acc')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_title('Accuracy Curves')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # Kappa
    axes[1, 0].plot(epochs, state.val_kappa_history, 'g-', label='Val Kappa')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Kappa Score')
    axes[1, 0].set_title('Kappa Score (Validation)')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # Learning Rate
    axes[1, 1].plot(epochs, state.lr_history, 'purple', label='Learning Rate')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_title('Learning Rate Schedule')
    axes[1, 1].set_yscale('log')
    axes[1, 1].legend()
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    logging.info(f"Training curves saved to {save_path}")

def plot_confusion_matrix(cm, save_path, normalize=False):
    """Plot confusion matrix"""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='.2f' if normalize else 'd',
        cmap='Blues',
        xticklabels=[f'KL {i}' for i in range(5)],
        yticklabels=[f'KL {i}' for i in range(5)],
        cbar_kws={'label': 'Normalized Count' if normalize else 'Count'}
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix' + (' (Normalized)' if normalize else ''))
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    logging.info(f"Confusion matrix saved to {save_path}")

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch):
    """Train for one epoch"""
    model.train()
    
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1} [TRAIN]")
    
    for batch_idx, (batch, labels) in enumerate(pbar):
        # Get data
        batch = batch.to(device)
        labels = labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        logits = model(batch)
        
        m = torch.nn.Softmax(dim=-1)
        # Compute loss
        loss = criterion(m(logits), labels)

        # Backward pass
        loss.backward()

        # Gradient clipping (optional but recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Update learning rate (if using OneCycleLR or similar)
        if scheduler is not None and hasattr(scheduler, 'step') and not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step()
        
        # Track metrics
        running_loss += loss.item() * batch.size(0)
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
        all_predictions.extend(predictions)
        all_labels.extend(labels.cpu().numpy())
        
        # Update progress bar
        current_loss = running_loss / ((batch_idx + 1) * train_loader.batch_size)
        pbar.set_postfix({'loss': f'{current_loss:.4f}'})
    
    # Compute epoch metrics
    epoch_loss = running_loss / len(train_loader.dataset)
    metrics = compute_metrics(
        np.array(all_predictions),
        np.array(all_labels)
    )
    
    return epoch_loss, metrics

def validate(model, val_loader, criterion, device, epoch):
    """Validate model"""
    model.eval()
    
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    
    pbar = tqdm(val_loader, desc=f"Epoch {epoch + 1} [VAL]")
    
    with torch.no_grad():
        for batch, labels in pbar:
            # Get data
            batch = batch.to(device)
            labels = labels.to(device)
            
            # Forward pass
            logits = model(batch)
            m = torch.nn.Softmax(dim=-1)
            # Compute loss
            loss = criterion(m(logits), labels)
            
            # Track metrics
            running_loss += loss.item() * batch.size(0)
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())
            
            # Update progress bar
            current_loss = running_loss / len(all_labels)
            pbar.set_postfix({'loss': f'{current_loss:.4f}'})
    
    # Compute epoch metrics
    epoch_loss = running_loss / len(val_loader.dataset)
    metrics = compute_metrics(
        np.array(all_predictions),
        np.array(all_labels)
    )
    
    return epoch_loss, metrics

In [None]:
def train(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    num_epochs=100,
    patience=10,
    save_dir='./checkpoints',
    early_stop_metric='accuracy'  # 'accuracy' or 'kappa'
):
    """
    Main training loop
    
    Args:
        model: CVPT model
        train_loader: Training dataloader
        val_loader: Validation dataloader
        criterion: Loss function
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        device: Device to train on
        num_epochs: Maximum number of epochs
        patience: Early stopping patience
        save_dir: Directory to save checkpoints
        early_stop_metric: Metric to use for early stopping ('accuracy' or 'kappa')
    """
    
    os.makedirs(save_dir, exist_ok=True)
    
    logging.info(f"\n{'='*80}")
    logging.info("STARTING TRAINING")
    logging.info(f"{'='*80}")
    logging.info(f"Device: {device}")
    logging.info(f"Number of epochs: {num_epochs}")
    logging.info(f"Patience: {patience}")
    logging.info(f"Early stop metric: {early_stop_metric}")
    logging.info(f"Save directory: {save_dir}")
    logging.info(f"{'='*80}\n")
    
    global state
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        # ========== TRAINING ==========
        train_loss, train_metrics = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch
        )
        
        # ========== VALIDATION ==========
        val_loss, val_metrics = validate(
            model, val_loader, criterion, device, epoch
        )
        
        # ========== UPDATE SCHEDULER (if ReduceLROnPlateau) ==========
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_loss)
        
        # ========== RECORD METRICS ==========
        current_lr = optimizer.param_groups[0]['lr']
        
        state.train_loss_history.append(train_loss)
        state.train_acc_history.append(train_metrics['accuracy'])
        state.val_loss_history.append(val_loss)
        state.val_acc_history.append(val_metrics['accuracy'])
        state.val_kappa_history.append(val_metrics['kappa'])
        state.lr_history.append(current_lr)
        state.current_epoch = epoch
        
        # ========== LOG METRICS ==========
        log_metrics('TRAIN', epoch, train_loss, train_metrics, lr=current_lr)
        log_metrics('VAL', epoch, val_loss, val_metrics)
        
        # ========== SAVE BEST MODEL ==========
        if early_stop_metric == 'accuracy':
            current_metric = val_metrics['accuracy']
            best_metric = state.val_acc_max

        elif early_stop_metric == 'val_loss': 
            current_val_loss = val_loss
            best_val_loss = state.val_loss
            
        else:  # kappa
            current_metric = val_metrics['kappa']
            best_metric = state.val_kappa_max
        
        if current_metric > best_metric:
            improvement = current_metric - best_metric
            logging.info(f"\nüéâ NEW BEST MODEL! {early_stop_metric.upper()} improved by {improvement:.4f}")
            logging.info(f"   Previous best: {best_metric:.4f}")
            logging.info(f"   New best: {current_metric:.4f}\n")
            
            # Update best metrics
            if early_stop_metric == 'accuracy':
                state.val_acc_max = current_metric
            elif early_stop_metric == 'val_loss':
                state.val_loss = val_loss 
            else:
                state.val_kappa_max = current_metric
            
            state.best_epoch = epoch
            state.epoch_since_improvement = 0

            checkpoint_state_dict = OrderedDict()
            for name, param in model.named_parameters():
                if param.requires_grad: 
                    checkpoint_state_dict[name] = param.data.clone()
            
            # Save model checkpoint
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': checkpoint_state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'state': state.__dict__,
            }
            
            # Save with epoch info
            model_path = os.path.join(
                save_dir,
                f'model_epoch{epoch+1}_{early_stop_metric}{current_metric:.4f}.pt'
            )
            atomic_save(checkpoint, model_path)
            
            # Save as best.pt
            best_path = os.path.join(save_dir, "best.pt")
            atomic_save(checkpoint, best_path)
            state.best_model_path = best_path
            
            # Save confusion matrix
            cm_path = os.path.join(save_dir, f'confusion_matrix_epoch{epoch+1}.png')
            plot_confusion_matrix(val_metrics['confusion_matrix'], cm_path, normalize=True)
            
        else:
            state.epoch_since_improvement += 1
            logging.info(f"\n‚ö†Ô∏è  No improvement for {state.epoch_since_improvement} epoch(s)")
            
            # Early stopping check
            if state.epoch_since_improvement >= patience:
                logging.info(f"\nüõë EARLY STOPPING at epoch {epoch + 1}")
                logging.info(f"   Best epoch: {state.best_epoch + 1}")
                logging.info(f"   Best {early_stop_metric}: {best_metric:.4f}\n")
                break
        
        # ========== SAVE TRAINING CURVES ==========
        if (epoch + 1) % 5 == 0 or epoch == 0:
            curves_path = os.path.join(save_dir, 'training_curves.png')
            plot_training_curves(state, curves_path)
        
        # ========== LOG EPOCH TIME ==========
        epoch_time = time.time() - epoch_start_time
        logging.info(f"‚è±Ô∏è  Epoch time: {epoch_time:.2f}s\n")
        
        # ========== SAVE PERIODIC CHECKPOINT ==========
        if (epoch + 1) % 10 == 0:
            checkpoint_state_dict = OrderedDict()
            for name, param in model.named_parameters():
                if param.requires_grad: 
                    checkpoint_state_dict[name] = param.data.clone()
                    
            periodic_path = os.path.join(save_dir, f'checkpoint_epoch{epoch+1}.pt')
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': checkpoint_state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'state': state.__dict__,
            }
            atomic_save(checkpoint, periodic_path)
    
    # ========== TRAINING COMPLETE ==========
    logging.info(f"\n{'='*80}")
    logging.info("TRAINING COMPLETE")
    logging.info(f"{'='*80}")
    logging.info(f"Total epochs: {state.current_epoch + 1}")
    logging.info(f"Best epoch: {state.best_epoch + 1}")
    logging.info(f"Best validation accuracy: {state.val_acc_max:.4f}")
    logging.info(f"Best validation kappa: {state.val_kappa_max:.4f}")
    logging.info(f"Best model saved to: {state.best_model_path}")
    logging.info(f"{'='*80}\n")
    
    # Final plots
    curves_path = os.path.join(save_dir, 'final_training_curves.png')
    plot_training_curves(state, curves_path)
    
    return state

# Trainer

In [None]:
final_state = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        num_epochs=1000,
        patience=1000,
        save_dir=save_dir,
        early_stop_metric='accuracy'  # Use kappa for early stopping
    )
    
print("\nTraining completed successfully!")
print(f"Best validation accuracy: {final_state.val_acc_max:.4f}")
print(f"Best validation kappa: {final_state.val_kappa_max:.4f}")
print(f"Best model: {final_state.best_model_path}")

# Evaluation Utils

In [None]:
# =============================================================================
# EVALUATION FUNCTIONS
# =============================================================================

def evaluate_model(model, test_loader, criterion, device, return_predictions=False):
    """
    Comprehensive evaluation on test set
    
    Args:
        model: Trained CVPT model
        test_loader: Test dataloader
        criterion: Loss function
        device: Device to evaluate on
        return_predictions: Whether to return all predictions
    
    Returns:
        results: Dictionary containing all evaluation metrics
        predictions_dict: (optional) Dict with predictions, labels, etc.
    """
    model.eval()
    
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    all_logits = []
    all_probs = []
    
    print(f"\n{'='*80}")
    print("EVALUATING ON TEST SET")
    print(f"{'='*80}\n")
    
    with torch.no_grad():
        for batch, labels in tqdm(test_loader, desc="Evaluating"):
            # Get data
            batch = batch.to(device)
            labels = labels.to(device)
            
            # Forward pass
            logits = model(batch)
            m = torch.nn.Softmax(dim=-1)
            # Compute loss
            loss = criterion(m(logits), labels)
            
            # Get predictions and probabilities
            probs = torch.softmax(logits, dim=1)
            predictions = torch.argmax(logits, dim=1)
            
            # Store results
            running_loss += loss.item() * batch.size(0)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_logits.append(logits.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
    
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_logits = np.vstack(all_logits)
    all_probs = np.vstack(all_probs)
    
    # Compute loss
    test_loss = running_loss / len(test_loader.dataset)
    
    # Compute comprehensive metrics
    results = compute_comprehensive_metrics(
        predictions=all_predictions,
        labels=all_labels,
        probs=all_probs,
        test_loss=test_loss
    )
    
    if return_predictions:
        predictions_dict = {
            'predictions': all_predictions,
            'labels': all_labels,
            'logits': all_logits,
            'probs': all_probs,
        }
        return results, predictions_dict
    
    return results


def compute_comprehensive_metrics(predictions, labels, probs, test_loss):
    """
    Compute all evaluation metrics
    
    Args:
        predictions: Predicted classes [N]
        labels: True labels [N]
        probs: Predicted probabilities [N, num_classes]
        test_loss: Test loss value
    
    Returns:
        results: Dictionary with all metrics
    """
    num_classes = probs.shape[1]
    
    results = {
        'test_loss': test_loss,
    }
    
    # ========== Classification Metrics ==========
    
    # Overall accuracy
    results['accuracy'] = accuracy_score(labels, predictions)
    
    # Cohen's Kappa (weighted for ordinal data)
    results['kappa'] = cohen_kappa_score(labels, predictions, weights='quadratic')
    results['kappa_linear'] = cohen_kappa_score(labels, predictions, weights='linear')
    results['kappa_unweighted'] = cohen_kappa_score(labels, predictions)
    
    # Precision, Recall, F1-Score (macro and weighted)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )
    
    results['precision_macro'] = precision_macro
    results['recall_macro'] = recall_macro
    results['f1_macro'] = f1_macro
    results['precision_weighted'] = precision_weighted
    results['recall_weighted'] = recall_weighted
    results['f1_weighted'] = f1_weighted
    
    # Per-class metrics
    precision_per_class, recall_per_class, f1_per_class, support_per_class = \
        precision_recall_fscore_support(labels, predictions, average=None, zero_division=0)
    
    results['per_class_metrics'] = {
        f'KL_{i}': {
            'precision': precision_per_class[i],
            'recall': recall_per_class[i],
            'f1': f1_per_class[i],
            'support': int(support_per_class[i])
        }
        for i in range(num_classes)
    }
    
    # Confusion matrix
    cm = confusion_matrix(labels, predictions, labels=range(num_classes))
    results['confusion_matrix'] = cm
    
    # Per-class accuracy
    per_class_acc = cm.diagonal() / cm.sum(axis=1)
    results['per_class_accuracy'] = {
        f'KL_{i}': per_class_acc[i] for i in range(num_classes)
    }
    
    # ========== Ordinal Metrics ==========
    
    # Mean Absolute Error (MAE) - important for ordinal classification
    results['mae'] = np.mean(np.abs(predictions - labels))
    
    # Mean Squared Error (MSE)
    results['mse'] = np.mean((predictions - labels) ** 2)
    
    # Root Mean Squared Error (RMSE)
    results['rmse'] = np.sqrt(results['mse'])
    
    # Off-by-one accuracy (correct or within 1 grade)
    off_by_one = np.abs(predictions - labels) <= 1
    results['off_by_one_accuracy'] = np.mean(off_by_one)
    
    # ========== Confidence Metrics ==========
    
    # Average confidence (max probability)
    confidences = np.max(probs, axis=1)
    results['avg_confidence'] = np.mean(confidences)
    results['std_confidence'] = np.std(confidences)
    
    # Confidence for correct/incorrect predictions
    correct_mask = predictions == labels
    if correct_mask.sum() > 0:
        results['avg_confidence_correct'] = np.mean(confidences[correct_mask])
    if (~correct_mask).sum() > 0:
        results['avg_confidence_incorrect'] = np.mean(confidences[~correct_mask])
    
    # ========== Calibration Metrics ==========
    
    # Expected Calibration Error (ECE)
    results['ece'] = compute_ece(probs, labels, n_bins=10)
    
    # ========== AUC Metrics (One-vs-Rest) ==========
    
    try:
        # Binarize labels for AUC computation
        from sklearn.preprocessing import label_binarize
        labels_binarized = label_binarize(labels, classes=range(num_classes))
        
        # Compute AUC for each class (One-vs-Rest)
        auc_per_class = {}
        for i in range(num_classes):
            if labels_binarized[:, i].sum() > 0:  # Check if class exists
                auc = roc_auc_score(labels_binarized[:, i], probs[:, i])
                auc_per_class[f'KL_{i}'] = auc
        
        results['auc_per_class'] = auc_per_class
        
        # Macro AUC
        if len(auc_per_class) > 0:
            results['auc_macro'] = np.mean(list(auc_per_class.values()))
    except:
        pass
    
    return results


def compute_ece(probs, labels, n_bins=10):
    """
    Compute Expected Calibration Error
    
    Args:
        probs: Predicted probabilities [N, C]
        labels: True labels [N]
        n_bins: Number of bins
    
    Returns:
        ece: Expected Calibration Error
    """
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = (predictions == labels).astype(float)
    
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    
    for i in range(n_bins):
        bin_lower = bin_boundaries[i]
        bin_upper = bin_boundaries[i + 1]
        
        # Find samples in this bin
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = np.mean(in_bin)
        
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    
    return ece


# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================

def plot_confusion_matrix_detailed(cm, save_path, class_names=None):
    """Plot detailed confusion matrix with annotations"""
    if class_names is None:
        class_names = [f'KL {i}' for i in range(len(cm))]
    
    # Normalize
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    
    # Raw counts
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=axes[0],
        cbar_kws={'label': 'Count'}
    )
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].set_title('Confusion Matrix (Raw Counts)', fontsize=14)
    
    # Normalized
    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt='.2%',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=axes[1],
        cbar_kws={'label': 'Percentage'}
    )
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"‚úì Confusion matrix saved to {save_path}")


def plot_per_class_metrics(results, save_path):
    """Plot per-class performance metrics"""
    per_class = results['per_class_metrics']
    classes = sorted(per_class.keys())
    
    metrics = ['precision', 'recall', 'f1']
    data = {metric: [per_class[cls][metric] for cls in classes] for metric in metrics}
    
    x = np.arange(len(classes))
    width = 0.25
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    for i, metric in enumerate(metrics):
        offset = (i - 1) * width
        ax.bar(x + offset, data[metric], width, label=metric.capitalize())
    
    ax.set_xlabel('KL Grade', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Performance Metrics', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(classes)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1.05])
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"‚úì Per-class metrics plot saved to {save_path}")


def plot_error_distribution(predictions, labels, save_path):
    """Plot distribution of prediction errors"""
    errors = predictions - labels
    
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Error histogram
    axes[0].hist(errors, bins=np.arange(-4.5, 5.5, 1), edgecolor='black', alpha=0.7)
    axes[0].set_xlabel('Prediction Error (Predicted - True)', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].set_title('Distribution of Prediction Errors', fontsize=14)
    axes[0].grid(axis='y', alpha=0.3)
    axes[0].axvline(x=0, color='red', linestyle='--', linewidth=2, label='Perfect Prediction')
    axes[0].legend()
    
    # Error by true class
    unique_labels = sorted(np.unique(labels))
    error_by_class = [errors[labels == label] for label in unique_labels]
    
    axes[1].boxplot(error_by_class, labels=[f'KL {i}' for i in unique_labels])
    axes[1].set_xlabel('True KL Grade', fontsize=12)
    axes[1].set_ylabel('Prediction Error', fontsize=12)
    axes[1].set_title('Prediction Error by True Class', fontsize=14)
    axes[1].grid(axis='y', alpha=0.3)
    axes[1].axhline(y=0, color='red', linestyle='--', linewidth=2)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"‚úì Error distribution plot saved to {save_path}")


def plot_confidence_analysis(probs, predictions, labels, save_path):
    """Analyze prediction confidence"""
    confidences = np.max(probs, axis=1)
    correct = (predictions == labels)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Confidence distribution
    axes[0, 0].hist(confidences, bins=50, edgecolor='black', alpha=0.7)
    axes[0, 0].set_xlabel('Confidence (Max Probability)', fontsize=12)
    axes[0, 0].set_ylabel('Frequency', fontsize=12)
    axes[0, 0].set_title('Distribution of Prediction Confidence', fontsize=14)
    axes[0, 0].grid(axis='y', alpha=0.3)
    
    # 2. Confidence by correctness
    correct_conf = confidences[correct]
    incorrect_conf = confidences[~correct]
    
    axes[0, 1].hist([correct_conf, incorrect_conf], bins=30, label=['Correct', 'Incorrect'], 
                    edgecolor='black', alpha=0.7)
    axes[0, 1].set_xlabel('Confidence', fontsize=12)
    axes[0, 1].set_ylabel('Frequency', fontsize=12)
    axes[0, 1].set_title('Confidence: Correct vs Incorrect Predictions', fontsize=14)
    axes[0, 1].legend()
    axes[0, 1].grid(axis='y', alpha=0.3)
    
    # 3. Accuracy vs confidence (reliability diagram)
    n_bins = 10
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2
    bin_accuracies = []
    bin_confidences = []
    bin_counts = []
    
    for i in range(n_bins):
        in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
        if in_bin.sum() > 0:
            bin_accuracies.append(correct[in_bin].mean())
            bin_confidences.append(confidences[in_bin].mean())
            bin_counts.append(in_bin.sum())
        else:
            bin_accuracies.append(0)
            bin_confidences.append(bin_centers[i])
            bin_counts.append(0)
    
    axes[1, 0].plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
    axes[1, 0].plot(bin_confidences, bin_accuracies, 'bo-', label='Model')
    axes[1, 0].set_xlabel('Confidence', fontsize=12)
    axes[1, 0].set_ylabel('Accuracy', fontsize=12)
    axes[1, 0].set_title('Reliability Diagram', fontsize=14)
    axes[1, 0].legend()
    axes[1, 0].grid(alpha=0.3)
    
    # 4. Confidence by class
    unique_labels = sorted(np.unique(labels))
    conf_by_class = [confidences[labels == label] for label in unique_labels]
    
    axes[1, 1].boxplot(conf_by_class, labels=[f'KL {i}' for i in unique_labels])
    axes[1, 1].set_xlabel('True KL Grade', fontsize=12)
    axes[1, 1].set_ylabel('Confidence', fontsize=12)
    axes[1, 1].set_title('Confidence by True Class', fontsize=14)
    axes[1, 1].grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"‚úì Confidence analysis plot saved to {save_path}")


def plot_roc_curves(labels, probs, save_path):
    """Plot ROC curves for each class (One-vs-Rest)"""
    from sklearn.preprocessing import label_binarize
    
    num_classes = probs.shape[1]
    labels_binarized = label_binarize(labels, classes=range(num_classes))
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
    
    for i, color in enumerate(colors):
        if labels_binarized[:, i].sum() > 0:
            fpr, tpr, _ = roc_curve(labels_binarized[:, i], probs[:, i])
            auc = roc_auc_score(labels_binarized[:, i], probs[:, i])
            
            ax.plot(fpr, tpr, color=color, linewidth=2,
                   label=f'KL {i} (AUC = {auc:.3f})')
    
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Random')
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('ROC Curves (One-vs-Rest)', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"‚úì ROC curves saved to {save_path}")


# =============================================================================
# REPORT GENERATION
# =============================================================================

def print_evaluation_report(results):
    """Print comprehensive evaluation report"""
    print(f"\n{'='*80}")
    print("TEST SET EVALUATION RESULTS")
    print(f"{'='*80}\n")
    
    # Overall metrics
    print("OVERALL METRICS")
    print("-" * 80)
    print(f"Test Loss:                    {results['test_loss']:.4f}")
    print(f"Accuracy:                     {results['accuracy']*100:.2f}%")
    print(f"Quadratic Kappa:              {results['kappa']:.4f}")
    print(f"Linear Kappa:                 {results['kappa_linear']:.4f}")
    print(f"Unweighted Kappa:             {results['kappa_unweighted']:.4f}")
    print(f"Off-by-One Accuracy:          {results['off_by_one_accuracy']*100:.2f}%")
    print()
    
    # Macro metrics
    print(f"Precision (Macro):            {results['precision_macro']:.4f}")
    print(f"Recall (Macro):               {results['recall_macro']:.4f}")
    print(f"F1-Score (Macro):             {results['f1_macro']:.4f}")
    print()
    
    # Weighted metrics
    print(f"Precision (Weighted):         {results['precision_weighted']:.4f}")
    print(f"Recall (Weighted):            {results['recall_weighted']:.4f}")
    print(f"F1-Score (Weighted):          {results['f1_weighted']:.4f}")
    print()
    
    # Ordinal metrics
    print("ORDINAL METRICS")
    print("-" * 80)
    print(f"Mean Absolute Error (MAE):    {results['mae']:.4f}")
    print(f"Mean Squared Error (MSE):     {results['mse']:.4f}")
    print(f"Root MSE (RMSE):              {results['rmse']:.4f}")
    print()
    
    # Confidence metrics
    print("CONFIDENCE METRICS")
    print("-" * 80)
    print(f"Average Confidence:           {results['avg_confidence']:.4f} ¬± {results['std_confidence']:.4f}")
    if 'avg_confidence_correct' in results:
        print(f"Avg Confidence (Correct):     {results['avg_confidence_correct']:.4f}")
    if 'avg_confidence_incorrect' in results:
        print(f"Avg Confidence (Incorrect):   {results['avg_confidence_incorrect']:.4f}")
    print(f"Expected Calibration Error:   {results['ece']:.4f}")
    print()
    
    # Per-class metrics
    print("PER-CLASS METRICS")
    print("-" * 80)
    print(f"{'Class':<10} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
    print("-" * 80)
    
    for i in range(5):
        cls = f'KL_{i}'
        acc = results['per_class_accuracy'][cls]
        metrics = results['per_class_metrics'][cls]
        print(f"KL Grade {i:<3} {acc*100:>6.2f}%      {metrics['precision']:>6.4f}       "
              f"{metrics['recall']:>6.4f}       {metrics['f1']:>6.4f}       {metrics['support']:>6d}")
    print()
    
    # AUC metrics
    if 'auc_per_class' in results and len(results['auc_per_class']) > 0:
        print("AUC METRICS (One-vs-Rest)")
        print("-" * 80)
        for cls, auc in results['auc_per_class'].items():
            grade = cls.split('_')[1]
            print(f"KL Grade {grade}:  {auc:.4f}")
        if 'auc_macro' in results:
            print(f"Macro AUC:     {results['auc_macro']:.4f}")
        print()
    
    print(f"{'='*80}\n")

def save_results_to_csv(results, save_path):
    """Save per-class metrics to CSV for easy comparison"""
    data = []
    
    for i in range(5):
        cls = f'KL_{i}'
        row = {
            'Class': f'KL Grade {i}',
            'Accuracy': results['per_class_accuracy'][cls],
            'Precision': results['per_class_metrics'][cls]['precision'],
            'Recall': results['per_class_metrics'][cls]['recall'],
            'F1-Score': results['per_class_metrics'][cls]['f1'],
            'Support': results['per_class_metrics'][cls]['support'],
        }
        
        if 'auc_per_class' in results and cls in results['auc_per_class']:
            row['AUC'] = results['auc_per_class'][cls]
        
        data.append(row)
    
    df = pd.DataFrame(data)
    df.to_csv(save_path, index=False)
    print(f"‚úì Per-class metrics saved to {save_path}")


# =============================================================================
# MAIN EVALUATION FUNCTION
# =============================================================================

def run_full_evaluation(
    model,
    test_loader,
    criterion,
    device,
    save_dir='./evaluation_results',
    model_name='CVPT'
):
    """
    Run complete evaluation pipeline
    
    Args:
        model: Trained model
        test_loader: Test dataloader
        criterion: Loss function
        device: Device
        save_dir: Directory to save results
        model_name: Name for result files
    """
    
    os.makedirs(save_dir, exist_ok=True)
    
    print(f"\n{'='*80}")
    print(f"STARTING FULL EVALUATION - {model_name}")
    print(f"{'='*80}\n")
    
    # ========== 1. Evaluate Model ==========
    results, predictions_dict = evaluate_model(
        model, test_loader, criterion, device, return_predictions=True
    )
    
    # ========== 2. Print Report ==========
    print_evaluation_report(results)
    
    # ========== 3. Save Results ==========
    
    csv_path = os.path.join(save_dir, f'{model_name}_per_class_metrics.csv')
    save_results_to_csv(results, csv_path)
    
    # ========== 4. Generate Visualizations ==========
    
    # Confusion matrix
    cm_path = os.path.join(save_dir, f'{model_name}_confusion_matrix.png')
    plot_confusion_matrix_detailed(results['confusion_matrix'], cm_path)
    
    # Per-class metrics
    metrics_path = os.path.join(save_dir, f'{model_name}_per_class_metrics.png')
    plot_per_class_metrics(results, metrics_path)
    
    # Error distribution
    error_path = os.path.join(save_dir, f'{model_name}_error_distribution.png')
    plot_error_distribution(
        predictions_dict['predictions'],
        predictions_dict['labels'],
        error_path
    )
    
    # Confidence analysis
    conf_path = os.path.join(save_dir, f'{model_name}_confidence_analysis.png')
    plot_confidence_analysis(
        predictions_dict['probs'],
        predictions_dict['predictions'],
        predictions_dict['labels'],
        conf_path
    )
    
    # ROC curves
    roc_path = os.path.join(save_dir, f'{model_name}_roc_curves.png')
    plot_roc_curves(
        predictions_dict['labels'],
        predictions_dict['probs'],
        roc_path
    )
    
    print(f"\n{'='*80}")
    print("EVALUATION COMPLETE")
    print(f"{'='*80}")
    print(f"Results saved to: {save_dir}")
    print(f"{'='*80}\n")
    
    return results, predictions_dict

In [None]:
def load_state_dict(model, checkpoint_state_dict): 
    missing_keys = []
    shape_mismatch_keys = []
    loaded_keys = []
    
    for name, param in model.named_parameters():
        if param.requires_grad: 
            key = name
            if key in checkpoint_state_dict:
                # Check shape compatibility
                if model_state_dict[key].shape == checkpoint_state_dict[key].shape:
                    model_state_dict[key] = checkpoint_state_dict[key]
                    loaded_keys.append(key)
                else:
                    shape_mismatch_keys.append(key)
            else:
                missing_keys.append(key)

    if len(missing_keys) > 0: 
        print(f"Missing in checkpoint: {missing_keys}")

    if len(shape_mismatch_keys) > 0:
        print(f"Shape mismatch: {shape_mismatch_keys}")

    if len(missing_keys) == 0 and len(shape_mismatch_keys) == 0: 
        model.load_state_dict(model_state_dict)
        return model

In [None]:
# Load best model
print("Loading best model...")
checkpoint = torch.load(os.path.join('/kaggle/input/gaviko-model/pytorch/model_v2/1/eviko.pt'), weights_only=False)
load_state_dict(model, checkpoint['model_state_dict'])
model.to(device)

print(f"‚úì Loaded model from epoch {checkpoint['epoch'] + 1}")
print(f"  Val Accuracy: {checkpoint['val_metrics']['accuracy']*100:.2f}%")
print(f"  Val Kappa: {checkpoint['val_metrics']['kappa']:.4f}")

# Run evaluation
eval_results, predictions = run_full_evaluation(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device,
    save_dir=os.path.join(save_dir, 'test_evaluation'),
    model_name='CVPT_baseline'
)

print("\nEvaluation completed successfully!")
print(f"\nFinal Test Results:")
print(f"   Accuracy: {eval_results['accuracy']*100:.2f}%")
print(f"   Kappa: {eval_results['kappa']:.4f}")
print(f"   MAE: {eval_results['mae']:.4f}")