# Google ViT 
## Initialization

In [1]:
# Cell 0: imports, project-path bootstrap, seeding, and run-wide constants
import sys
from pathlib import Path

# Make the repository root importable so that `src.*` resolves whether the
# kernel was started in the project root or in its `notebooks/` subfolder.
notebook_dir = Path.cwd()
project_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from transformers import ViTForImageClassification
import random
from torchvision import transforms
import albumentations as A
from src.transforms import base_transform
from src.fer2013 import FER2013Dataset
from src.config import (
    DEVICE, 
    NUM_LABELS, 
    EMOTION_LABELS,
    DEFAULT_BATCH_SIZE,
    DEFAULT_LEARNING_RATE
)
from tqdm.notebook import tqdm
import torch
import numpy as np
from torch.optim import AdamW
from src.train import train_model

# Seed every RNG this notebook can reach so that Restart & Run All is
# repeatable (random sample picks, weight init of the new classifier head,
# shuffling). NOTE(review): full determinism on CUDA may additionally need
# the cuDNN / deterministic-algorithm flags; the helpers in `src` may also
# seed on their own — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

# Hugging Face checkpoint id used by every experiment below.
MODEL_NAME = "google/vit-base-patch16-224-in21k"

Using device: cuda


### Weights and Biases 

In [2]:
# Cell 1: log in to Weights & Biases for this notebook session
from src.wandb_utils import login, check_wandb_mode, sync_offline_runs

# Tracking mode: "online", "offline", or "disabled".
# If set to "offline", don't forget to sync the runs afterwards
# (e.g. with sync_offline_runs once a connection is available).
WANDB_MODE = "online" 

print("Initializing Weights & Biases...")
# login() hands back a string (it is upper-cased below); presumably the mode
# that is actually in effect — confirm against src.wandb_utils.
current_mode = login(
    project="emotion-classifier-vit",
    mode=WANDB_MODE
)

print(f"W&B initialized successfully in {current_mode.upper()} mode!")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Ray\_netrc


Initializing Weights & Biases...
WandB mode set to: ONLINE


[34m[1mwandb[0m: Currently logged in as: [33mraycaringal[0m ([33mraycaringal-university-of-texas-austin[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


WandB initialized in ONLINE mode for project: emotion-classifier-vit
Current WandB mode: ONLINE
W&B initialized successfully in ONLINE mode!


In [3]:

# NOTE(review): the wildcard import hides which helpers this notebook depends
# on and can silently shadow existing names. Consider importing the helpers
# listed below explicitly once the export list of src.wandb_utils is confirmed.
from src.wandb_utils import *

# Weights & Biases utility commands — uncomment a single line to run it.

# Check current mode
# check_wandb_mode()

# Sync offline runs (when you have internet)
# sync_offline_runs(all_runs=True)

# List available offline runs
# list_offline_runs()

# Change mode 
# set_wandb_mode("offline")  

# Set confirm=False for a dry run
# clear_offline_runs(confirm=True)


### Transformations

In [4]:

# Augmentation presets, keyed by strength. Every preset ends with the steps of
# base_transform() (defined in src/transforms.py) so that all variants share
# the same final preprocessing.
def _with_base(*augmentations):
    """Compose the given augmentations followed by the base_transform() steps."""
    pipeline = list(augmentations)
    pipeline.extend(base_transform())
    return A.Compose(pipeline)


transform_configs = {}

# No augmentation: the base pipeline on its own.
transform_configs["none"] = base_transform()

transform_configs["light"] = _with_base(
    A.HorizontalFlip(p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
    A.Affine(translate_percent=0.05, scale=(0.95, 1.05), rotate=(-10, 10), p=0.3),
)

transform_configs["medium"] = _with_base(
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.Affine(translate_percent=0.1, scale=(0.9, 1.1), rotate=(-15, 15), p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.3),
)

transform_configs["heavy"] = _with_base(
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.Affine(translate_percent=0.15, scale=(0.85, 1.15), rotate=(-20, 20), p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.4),
    A.GridDropout(ratio=0.1, p=0.3),
)

print("Transformation Configs Loaded")

Transformation Configs Loaded


---
##  Fine Tuning Section
Using FER2013 dataset.

### Hyper Parameter Queue

In [None]:
# Experiment queue: one entry per run. Every run shares the epoch count, the
# default learning rate / batch size from src.config and the same weight
# decay; only the run name and the augmentation preset differ.
EPOCHS = 7

# (run name, key into transform_configs)
_experiment_plan = [
    ("baseline_none_linear_probe", "none"),
    # ("baseline_light_long", "light"),   # currently disabled
    ("baseline_medium_linear_probe", "medium"),
    ("baseline_heavy_linear_probe", "heavy"),
]

experiment_configs = [
    {
        "name": run_name,
        "transform_key": preset_key,
        "epochs": EPOCHS,
        "learning_rate": DEFAULT_LEARNING_RATE,
        "batch_size": DEFAULT_BATCH_SIZE,
        "weight_decay": 0.01
    }
    for run_name, preset_key in _experiment_plan
]

print(f"{len(experiment_configs)} Experiment Configs Loaded")

1 Experiment Configs Loaded


### Training Loop

In [None]:
import gc

from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from src.wandb_utils import cleanup_wandb_run

# Full fine-tuning: every ViT parameter is handed to the optimizer.
# Runs each entry of `experiment_configs` in turn, records a summary per run
# in `all_results`, and records failures in `failed_experiments` so that one
# bad configuration does not abort the whole queue.
# NOTE(review): the Linear Probe cell rebinds `all_results` and
# `failed_experiments`, so running it afterwards discards these results.
all_results = {}
failed_experiments = []

print(f"Starting training for {len(experiment_configs)} experiments")
print("=" * 70)

for i, config in enumerate(tqdm(experiment_configs, desc="Training Experiments")):
    print(f"\n{'='*70}")
    print(f"🔬 Experiment {i+1}/{len(experiment_configs)}: {config['name']}")
    print(f"   Transform: {config['transform_key']}")
    print(f"   LR: {config['learning_rate']}")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch Size: {config['batch_size']}")
    print(f"{'='*70}")
    
    # Ensure any previous WandB run is cleaned up
    cleanup_wandb_run()
    
    # Pre-bind the per-run objects so the `finally` block can drop them
    # unconditionally, however far the `try` body got.
    model = model_exp = optimizer = train = valid = None
    
    try:
        # Create datasets: augmentation on the training split only,
        # the plain base pipeline for validation.
        transform = transform_configs[config['transform_key']]
        
        train = FER2013Dataset(
            split="train",
            transform=transform
        )
        valid = FER2013Dataset(
            split="valid", 
            transform=base_transform()
        )
        
        # Initialize model; the classification head is re-created for
        # NUM_LABELS classes, hence ignore_mismatched_sizes.
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        ).to(DEVICE)
        
        # Initialize optimizer over all parameters (full fine-tuning)
        optimizer = AdamW(
            model.parameters(), 
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        print(f"✅ Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
        
        # Train model
        model_exp, history_exp, run_folder_exp = train_model(
            model=model,
            optimizer=optimizer,
            train_dataset=train,
            val_dataset=valid,
            num_epochs=config['epochs'],
            batch_size=config['batch_size'],
            device=DEVICE,
            model_name=config['name'],  
            use_wandb=True,
            wandb_config={
                "learning_rate": config['learning_rate'],
                "batch_size": config['batch_size'],
                "epochs": config['epochs'],
                "weight_decay": config['weight_decay'],
                "model_name": "vit_base_patch16_224",
                "architecture": "ViT", 
                "dataset": "FER2013",
                "transform_set": config['transform_key'],
                "experiment_name": config['name']
            }
        )
        
        # Store results.
        # NOTE(review): keeping the model object here keeps its weights
        # alive, so the cleanup below cannot return that memory to the GPU.
        all_results[config['name']] = {
            'model': model_exp,
            'history': history_exp,
            'run_folder': run_folder_exp,
            'config': config,
            'best_val_accuracy': max(history_exp['val_acc']),      
            'best_val_loss': min(history_exp['val_loss']),
            'final_train_accuracy': history_exp['train_acc'][-1],  
            'final_train_loss': history_exp['train_loss'][-1]
        }
        
        print(f"\n COMPLETED: {config['name']}")
        print(f"   Best Val Accuracy: {all_results[config['name']]['best_val_accuracy']:.4f}")
        print(f"   Best Val Loss: {all_results[config['name']]['best_val_loss']:.4f}")
        print(f"   Run folder: {run_folder_exp}")
        
    except KeyboardInterrupt:
        # Manual stop: close the tracking run and abandon the rest of the queue.
        print(f"\n  Training interrupted by user at experiment: {config['name']}")
        cleanup_wandb_run()
        break
        
    except Exception as e:
        print(f"\n ERROR in experiment {config['name']}: {str(e)}")
        print(f"   Exception type: {type(e).__name__}")
        
        # Store failed experiment info
        failed_experiments.append({
            'name': config['name'],
            'error': str(e),
            'error_type': type(e).__name__
        })
        
        # Clean up WandB
        cleanup_wandb_run()
        
        # A failed run does not stop the queue
        print(f"   Continuing to next experiment...")
        
    finally:
        # Drop this run's references regardless of success/failure
        # (also executes on the `break` above).
        del model, model_exp, optimizer, train, valid
        
        # Collect unreachable objects first so their CUDA blocks become
        # free, then hand the cached blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        print(f"   Memory cleaned up")

# Final cleanup
cleanup_wandb_run()

# Print summary
print("\n" + "="*70)
print(" TRAINING COMPLETE - SUMMARY")
print("="*70)
print(f" Successful experiments: {len(all_results)}/{len(experiment_configs)}")
print(f" Failed experiments: {len(failed_experiments)}/{len(experiment_configs)}")

if all_results:
    print("\n Results:")
    for name, result in all_results.items():
        print(f"   {name}: Val Acc = {result['best_val_accuracy']:.4f}, Val Loss = {result['best_val_loss']:.4f}")

if failed_experiments:
    print("\n  Failed Experiments:")
    for failed in failed_experiments:
        print(f"   {failed['name']}: {failed['error_type']} - {failed['error']}")

print("\n" + "="*70)

### Linear Probe

In [6]:
import gc

from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from src.wandb_utils import cleanup_wandb_run
from src.linear_probe import train_linear_probe
from transformers import ViTForImageClassification

# Linear probe: only the classification head is optimized.
# Runs each entry of `experiment_configs` in turn, records a summary per run
# in `all_results`, and records failures in `failed_experiments` so that one
# bad configuration does not abort the whole queue.
# NOTE(review): this rebinds `all_results` / `failed_experiments`, discarding
# whatever the full fine-tuning cell stored under the same names.
all_results = {}
failed_experiments = []

print(f"Starting LINEAR PROBE training for {len(experiment_configs)} experiments")
print("=" * 70)

for i, config in enumerate(tqdm(experiment_configs, desc="Training Experiments")):
    print(f"\n{'='*70}")
    print(f"🔬 Experiment {i+1}/{len(experiment_configs)}: {config['name']}")
    print(f"   Training Type: LINEAR PROBE (frozen encoder)")
    print(f"   Transform: {config['transform_key']}")
    print(f"   LR: {config['learning_rate']}")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch Size: {config['batch_size']}")
    print(f"{'='*70}")
    
    # Ensure any previous WandB run is cleaned up
    cleanup_wandb_run()
    
    # Pre-bind the per-run objects so the `finally` block can drop them
    # unconditionally, however far the `try` body got.
    model = model_exp = optimizer = train = valid = None
    
    try:
        # Create datasets: augmentation on the training split only,
        # the plain base pipeline for validation.
        transform = transform_configs[config['transform_key']]
        
        train = FER2013Dataset(
            split="train",
            transform=transform
        )
        valid = FER2013Dataset(
            split="valid", 
            transform=base_transform()
        )
        
        # Initialize model; the classification head is re-created for
        # NUM_LABELS classes, hence ignore_mismatched_sizes.
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        ).to(DEVICE)
        
        # IMPORTANT: For linear probe, only optimize classifier parameters
        # The encoder will be frozen inside train_linear_probe()
        # NOTE(review): the shared DEFAULT_LEARNING_RATE is reused here; a
        # head-only probe may want a larger step size — confirm.
        optimizer = AdamW(
            model.classifier.parameters(),  # Only classifier, not model.parameters()
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        print(f"✅ Model initialized")
        print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"   Classifier parameters: {sum(p.numel() for p in model.classifier.parameters()):,}")
        
        # Train model with LINEAR PROBE
        model_exp, history_exp, run_folder_exp = train_linear_probe(
            model=model,
            optimizer=optimizer,
            train_dataset=train,
            val_dataset=valid,
            num_epochs=config['epochs'],
            batch_size=config['batch_size'],
            device=DEVICE,
            model_name=config['name'],  
            use_wandb=True,
            wandb_config={
                "learning_rate": config['learning_rate'],
                "batch_size": config['batch_size'],
                "epochs": config['epochs'],
                "weight_decay": config['weight_decay'],
                "model_name": "vit_base_patch16_224",
                "architecture": "ViT", 
                "dataset": "FER2013",
                "transform_set": config['transform_key'],
                "experiment_name": config['name'],
                "training_type": "linear_probe"
            }
        )
        
        # Store results.
        # NOTE(review): keeping the model object here keeps its weights
        # alive, so the cleanup below cannot return that memory to the GPU.
        all_results[config['name']] = {
            'model': model_exp,
            'history': history_exp,
            'run_folder': run_folder_exp,
            'config': config,
            'best_val_accuracy': max(history_exp['val_acc']),      
            'best_val_loss': min(history_exp['val_loss']),
            'final_train_accuracy': history_exp['train_acc'][-1],  
            'final_train_loss': history_exp['train_loss'][-1]
        }
        
        print(f"\n✅ COMPLETED: {config['name']}")
        print(f"   Best Val Accuracy: {all_results[config['name']]['best_val_accuracy']:.4f}")
        print(f"   Best Val Loss: {all_results[config['name']]['best_val_loss']:.4f}")
        print(f"   Run folder: {run_folder_exp}")
        
    except KeyboardInterrupt:
        # Manual stop: close the tracking run and abandon the rest of the queue.
        print(f"\n⚠️  Training interrupted by user at experiment: {config['name']}")
        cleanup_wandb_run()
        break
        
    except Exception as e:
        print(f"\n❌ ERROR in experiment {config['name']}: {str(e)}")
        print(f"   Exception type: {type(e).__name__}")
        
        # Store failed experiment info
        failed_experiments.append({
            'name': config['name'],
            'error': str(e),
            'error_type': type(e).__name__
        })
        
        # Clean up WandB
        cleanup_wandb_run()
        
        # A failed run does not stop the queue
        print(f"   Continuing to next experiment...")
        
    finally:
        # Drop this run's references regardless of success/failure
        # (also executes on the `break` above).
        del model, model_exp, optimizer, train, valid
        
        # Collect unreachable objects first so their CUDA blocks become
        # free, then hand the cached blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        print(f"   Memory cleaned up")

# Final cleanup
cleanup_wandb_run()

# Print summary
print("\n" + "="*70)
print("📊 LINEAR PROBE TRAINING COMPLETE - SUMMARY")
print("="*70)
print(f"✅ Successful experiments: {len(all_results)}/{len(experiment_configs)}")
print(f"❌ Failed experiments: {len(failed_experiments)}/{len(experiment_configs)}")

if all_results:
    print("\n📈 Results:")
    for name, result in all_results.items():
        print(f"   {name}: Val Acc = {result['best_val_accuracy']:.4f}, Val Loss = {result['best_val_loss']:.4f}")

if failed_experiments:
    print("\n⚠️  Failed Experiments:")
    for failed in failed_experiments:
        print(f"   {failed['name']}: {failed['error_type']} - {failed['error']}")

print("\n" + "="*70)

Starting LINEAR PROBE training for 1 experiments


Training Experiments:   0%|          | 0/1 [00:00<?, ?it/s]


üî¨ Experiment 1/1: baseline_light_long
   Training Type: LINEAR PROBE (frozen encoder)
   Transform: light
   LR: 2e-05
   Epochs: 20
   Batch Size: 32
WandB run cleaned up


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Model initialized
   Total parameters: 85,804,039
   Classifier parameters: 5,383
Freezing encoder parameters...
Trainable parameters: 5,383 / 85,804,039
Percentage trainable: 0.01%
Created run folder: baseline_light_long1




WandB run started: baseline_light_long1
WandB Dashboard: https://wandb.ai/raycaringal-university-of-texas-austin/emotion-classification/runs/qatag4qj
Training parameters saved to: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\training_parameters.json
Training baseline_light_long for 20 epochs...
Training type: LINEAR PROBE (frozen encoder)
Run name (WandB): baseline_light_long1
Total training steps: 17960
Device: cuda
Batch size: 32
Train batches: 898
Val batches: 113
Run folder: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1
Best model will be saved to: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
Backup interval: every 5 epochs
W&B tracking: https://wandb.ai/raycaringal-university-of-texas-austin/emotion-classification/runs/qatag4qj

Epoch 1/20
----------------------------------------------------------------------


Training Epoch 0:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 0:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.8614 | Train Acc: 0.2797 | Train F1: 0.1981
Val Loss:   1.7752 | Val Acc:   0.3070 | Val F1:   0.2053
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.3070, Val F1: 0.2053)
Backup created: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups\backup_epoch_000_20251207_131815.pth

Epoch 2/20
----------------------------------------------------------------------


Training Epoch 1:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 1:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.7199 | Train Acc: 0.3392 | Train F1: 0.2563
Val Loss:   1.6695 | Val Acc:   0.3778 | Val F1:   0.3052
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.3778, Val F1: 0.3052)

Epoch 3/20
----------------------------------------------------------------------


Training Epoch 2:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 2:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.6378 | Train Acc: 0.3953 | Train F1: 0.3362
Val Loss:   1.6029 | Val Acc:   0.4157 | Val F1:   0.3566
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4157, Val F1: 0.3566)

Epoch 4/20
----------------------------------------------------------------------


Training Epoch 3:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 3:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.5839 | Train Acc: 0.4260 | Train F1: 0.3772
Val Loss:   1.5557 | Val Acc:   0.4416 | Val F1:   0.3918
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4416, Val F1: 0.3918)

Epoch 5/20
----------------------------------------------------------------------


Training Epoch 4:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 4:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.5434 | Train Acc: 0.4456 | Train F1: 0.4034
Val Loss:   1.5208 | Val Acc:   0.4503 | Val F1:   0.4058
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4503, Val F1: 0.4058)

Epoch 6/20
----------------------------------------------------------------------


Training Epoch 5:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 5:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.5123 | Train Acc: 0.4556 | Train F1: 0.4171
Val Loss:   1.4937 | Val Acc:   0.4606 | Val F1:   0.4198
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4606, Val F1: 0.4198)
Backup created: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups\backup_epoch_005_20251207_132645.pth

Epoch 7/20
----------------------------------------------------------------------


Training Epoch 6:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 6:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4886 | Train Acc: 0.4679 | Train F1: 0.4330
Val Loss:   1.4724 | Val Acc:   0.4664 | Val F1:   0.4279
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4664, Val F1: 0.4279)

Epoch 8/20
----------------------------------------------------------------------


Training Epoch 7:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 7:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4708 | Train Acc: 0.4713 | Train F1: 0.4385
Val Loss:   1.4555 | Val Acc:   0.4714 | Val F1:   0.4351
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4714, Val F1: 0.4351)

Epoch 9/20
----------------------------------------------------------------------


Training Epoch 8:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 8:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4563 | Train Acc: 0.4765 | Train F1: 0.4451
Val Loss:   1.4416 | Val Acc:   0.4748 | Val F1:   0.4394
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4748, Val F1: 0.4394)

Epoch 10/20
----------------------------------------------------------------------


Training Epoch 9:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 9:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4438 | Train Acc: 0.4764 | Train F1: 0.4457
Val Loss:   1.4302 | Val Acc:   0.4784 | Val F1:   0.4446
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4784, Val F1: 0.4446)

Epoch 11/20
----------------------------------------------------------------------


Training Epoch 10:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 10:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4322 | Train Acc: 0.4808 | Train F1: 0.4517
Val Loss:   1.4207 | Val Acc:   0.4801 | Val F1:   0.4470
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4801, Val F1: 0.4470)
Backup created: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups\backup_epoch_010_20251207_133546.pth

Epoch 12/20
----------------------------------------------------------------------


Training Epoch 11:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 11:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4236 | Train Acc: 0.4825 | Train F1: 0.4540
Val Loss:   1.4130 | Val Acc:   0.4801 | Val F1:   0.4476

Epoch 13/20
----------------------------------------------------------------------


Training Epoch 12:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 12:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4152 | Train Acc: 0.4854 | Train F1: 0.4568
Val Loss:   1.4067 | Val Acc:   0.4806 | Val F1:   0.4490
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4806, Val F1: 0.4490)

Epoch 14/20
----------------------------------------------------------------------


Training Epoch 13:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 13:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4107 | Train Acc: 0.4871 | Train F1: 0.4595
Val Loss:   1.4013 | Val Acc:   0.4837 | Val F1:   0.4525
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4837, Val F1: 0.4525)

Epoch 15/20
----------------------------------------------------------------------


Training Epoch 14:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 14:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4048 | Train Acc: 0.4891 | Train F1: 0.4614
Val Loss:   1.3971 | Val Acc:   0.4834 | Val F1:   0.4528

Epoch 16/20
----------------------------------------------------------------------


Training Epoch 15:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 15:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.4041 | Train Acc: 0.4886 | Train F1: 0.4616
Val Loss:   1.3937 | Val Acc:   0.4840 | Val F1:   0.4537
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4840, Val F1: 0.4537)
Backup created: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups\backup_epoch_015_20251207_134408.pth

Epoch 17/20
----------------------------------------------------------------------


Training Epoch 16:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 16:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.3992 | Train Acc: 0.4906 | Train F1: 0.4637
Val Loss:   1.3912 | Val Acc:   0.4840 | Val F1:   0.4539

Epoch 18/20
----------------------------------------------------------------------


Training Epoch 17:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 17:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.3970 | Train Acc: 0.4905 | Train F1: 0.4638
Val Loss:   1.3895 | Val Acc:   0.4845 | Val F1:   0.4549
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4845, Val F1: 0.4549)

Epoch 19/20
----------------------------------------------------------------------


Training Epoch 18:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 18:   0%|          | 0/113 [00:40<?, ?it/s]


Train Loss: 1.3955 | Train Acc: 0.4912 | Train F1: 0.4646
Val Loss:   1.3884 | Val Acc:   0.4845 | Val F1:   0.4550

Epoch 20/20
----------------------------------------------------------------------


Training Epoch 19:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 19:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.3946 | Train Acc: 0.4911 | Train F1: 0.4648
Val Loss:   1.3881 | Val Acc:   0.4848 | Val F1:   0.4553
Checkpoint saved: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\best_baseline_light_long.pth
‚úÖ New best model saved! (Val Acc: 0.4848, Val F1: 0.4553)

‚úÖ Training completed!
Best validation accuracy: 0.4848
Training parameters saved to: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\training_parameters.json
Training history saved to: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\history_baseline_light_long1.json




Final backup created: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups\backup_final_20251207_135136.pth
Training completed successfully - cleaning up all backups...
üßπ Deleted backup: backup_epoch_000_20251207_131815.pth
üßπ Deleted backup: backup_epoch_005_20251207_132645.pth
üßπ Deleted backup: backup_epoch_010_20251207_133546.pth
üßπ Deleted backup: backup_epoch_015_20251207_134408.pth
üßπ Deleted backup: backup_final_20251207_135136.pth
Deleted all 5 backup files after successful training
Deleted 5 backup files
  Removed empty backups directory: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1\backups
Backups folder successfully removed


0,1
batch,‚ñÅ‚ñÉ‚ñÜ‚ñÅ‚ñÉ‚ñà‚ñÖ‚ñà‚ñÅ‚ñÖ‚ñà‚ñÅ‚ñÖ‚ñÖ‚ñÜ‚ñÉ‚ñÖ‚ñÜ‚ñà‚ñÅ‚ñà‚ñÜ‚ñà‚ñÉ‚ñÉ‚ñÅ‚ñÖ‚ñà‚ñÉ‚ñÖ‚ñà‚ñÅ‚ñÉ‚ñÖ‚ñÅ‚ñÖ‚ñà‚ñÅ‚ñÖ‚ñÜ
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
train/batch_loss,‚ñà‚ñá‚ñá‚ñá‚ñÖ‚ñÑ‚ñÉ‚ñÜ‚ñÑ‚ñÉ‚ñÖ‚ñÉ‚ñÖ‚ñÑ‚ñÑ‚ñÉ‚ñÑ‚ñÑ‚ñÜ‚ñÇ‚ñÉ‚ñÉ‚ñÅ‚ñÇ‚ñÉ‚ñÑ‚ñÑ‚ñÖ‚ñÇ‚ñÇ‚ñÉ‚ñÖ‚ñÉ‚ñÅ‚ñÖ‚ñÉ‚ñÉ‚ñÑ‚ñÇ‚ñÉ
train/learning_rate,‚ñÅ‚ñÑ‚ñá‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ
train/loss,‚ñà‚ñÜ‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val/accuracy,‚ñÅ‚ñÑ‚ñÖ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
val/f1,‚ñÅ‚ñÑ‚ñÖ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
val/loss,‚ñà‚ñÜ‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val/precision,‚ñÅ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
val/recall,‚ñÅ‚ñÑ‚ñÖ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà

0,1
batch,800.0
epoch,19.0
train/batch_loss,1.46164
train/learning_rate,0.0
train/loss,1.39462
val/accuracy,0.48481
val/f1,0.45527
val/loss,1.3881
val/precision,0.4504
val/recall,0.48481


WandB run cleaned up

‚úÖ COMPLETED: baseline_light_long
   Best Val Accuracy: 0.4848
   Best Val Loss: 1.3881
   Run folder: c:\Users\Ray\Documents\AI\Emotion-Classification\checkpoints\baseline_light_long1
   Memory cleaned up
WandB run cleaned up

üìä LINEAR PROBE TRAINING COMPLETE - SUMMARY
‚úÖ Successful experiments: 1/1
‚ùå Failed experiments: 0/1

üìà Results:
   baseline_light_long: Val Acc = 0.4848, Val Loss = 1.3881



---
### Metrics

In [None]:
# Cell 9: Independent evaluation (can run after kernel restart)
# NOTE(review): the `src` imports still need the project root on sys.path
# (Cell 0 sets that up) when the kernel starts outside the project root.
from src.evaluate import evaluate_all_saved_models
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
import matplotlib.pyplot as plt

print("🧪 Starting INDEPENDENT evaluation of all saved models...")

# Load test dataset (base pipeline only, no augmentation)
test_ds = FER2013Dataset(
    split="test", 
    transform=base_transform()
)

print(f"Test dataset size: {len(test_ds)}")

# Evaluate all saved models (no need for all_results in memory)
summary_data = evaluate_all_saved_models(test_ds)

# Only report success when something was actually evaluated; previously the
# success lines were printed even when the summary came back empty.
if summary_data:
    print("\n✅ All saved models evaluated and summarized!")
    print(f"📊 Performance plot saved to: experiment_performance_comparison.png")

    # Show best model details.
    # Assumes the summary is ordered best-first — TODO confirm in src.evaluate.
    best_exp = summary_data[0]
    print(f"\n🏆 Best model: {best_exp['experiment']}")
    print(f"   Test Accuracy: {best_exp['test_accuracy']:.4f}")
    print(f"   Transform: {best_exp['transform']}")
    print(f"   Run Folder: {best_exp['run_folder']}")
else:
    print("❌ No models were successfully evaluated")

In [None]:
# Cell 9A: Evaluate specific experiments using your experiment_configs
# Requires `experiment_configs` from the Hyper Parameter Queue cell.
from src.evaluate import evaluate_from_experiment_configs
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
import matplotlib.pyplot as plt

print("🧪 Evaluating specific experiments from config...")

# Load test dataset (base pipeline only, no augmentation)
test_ds = FER2013Dataset(
    split="test", 
    transform=base_transform()
)

print(f"Test dataset size: {len(test_ds)}")

# Evaluate using your experiment_configs (finds latest runs automatically)
summary_data = evaluate_from_experiment_configs(experiment_configs, test_ds)

# Only report success when something was actually evaluated, and say so
# explicitly when nothing was (mirrors the handling in Cell 9).
if summary_data:
    print("\n✅ Specific experiments evaluated!")
    print(f"📊 Performance plot saved to: experiment_performance_comparison.png")

    # Show best model details.
    # Assumes the summary is ordered best-first — TODO confirm in src.evaluate.
    best_exp = summary_data[0]
    print(f"\n🏆 Best model: {best_exp['experiment']}")
    print(f"   Run: {best_exp['run_name']}")
    print(f"   Test Accuracy: {best_exp['test_accuracy']:.4f}")
    print(f"   Transform: {best_exp['transform']}")
else:
    print("❌ No models were successfully evaluated")

---
###  Test Predictions

In [None]:
# Visualize predictions from multiple models
from src.metadata import find_latest_run_for_experiment, load_training_parameters
from src.checkpoint_utils import load_model_from_checkpoint
from transformers import ViTForImageClassification, ViTImageProcessor
import random
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

# Checkpoint location. The original absolute path only exists on one machine;
# keep using it when present, otherwise fall back to <project root>/checkpoints,
# resolving the project root the same way the notebook's first cell does
# (parent of the working directory when the notebook runs from "notebooks/").
_LEGACY_CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")
if _LEGACY_CHECKPOINTS_DIR.exists():
    CHECKPOINTS_DIR = _LEGACY_CHECKPOINTS_DIR
else:
    _cwd = Path.cwd()
    CHECKPOINTS_DIR = (_cwd.parent if _cwd.name == 'notebooks' else _cwd) / "checkpoints"

# Experiment names whose latest run will be loaded and compared.
MODELS_TO_TEST = [
    "baseline_none",
    "baseline_light", 
]

NUM_SAMPLES = 3  # Number of random test samples per model

def predict_and_visualize_batch(dataset, indices, model, processor, model_name):
    """Run predictions on multiple samples for a single model.

    Args:
        dataset: Indexable dataset; ``dataset[idx]`` must yield ``(image, label)``.
        indices: Iterable of dataset indices to run through the model.
        model: Model called as ``model(**inputs)`` whose output exposes ``.logits``.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``.
        model_name: Accepted for interface compatibility; not used here.

    Returns:
        A list with one dict per index containing ``sample_index``,
        ``true_label``, ``pred_label``, ``confidence``, ``correct``,
        ``top3_predictions`` (``(label name, probability)`` pairs, at most
        three) and ``image`` (the PIL image fed to the processor).
    """
    indices = list(indices)
    if not indices:
        # Nothing to predict: avoid touching the model or device at all.
        return []

    # Loop-invariant setup: switch to eval mode and move to the device once,
    # not once per sample.
    model.eval()
    model.to(DEVICE)
    to_pil = transforms.ToPILImage()

    results = []
    for sample_idx in indices:
        img, true_label = dataset[sample_idx]
        img_pil = to_pil(img)

        # Run model
        inputs = processor(images=img_pil, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process: probabilities for the single image in the batch
        probs = torch.softmax(outputs.logits, dim=-1)[0]
        pred_label = torch.argmax(probs).item()
        confidence = probs[pred_label].item()

        # Get top 3 predictions (clamped so models with fewer than 3 classes
        # do not make topk raise)
        top_k = min(3, probs.numel())
        top_probs, top_classes = torch.topk(probs, top_k)
        top3_predictions = [
            (EMOTION_LABELS[class_idx.item()], prob.item())
            for prob, class_idx in zip(top_probs, top_classes)
        ]

        results.append({
            'sample_index': sample_idx,
            'true_label': true_label,
            'pred_label': pred_label,
            'confidence': confidence,
            'correct': true_label == pred_label,
            'top3_predictions': top3_predictions,
            'image': img_pil
        })

    return results

def display_model_predictions(model_name, results, sample_indices):
    """Display predictions for a single model.

    Prints a per-sample report (true label, predicted label, confidence and
    the top predictions) and shows the sample images in a one-row grid, titled
    green for correct and red for incorrect predictions.

    Args:
        model_name: Name shown in the printed header and the figure title.
        results: List of per-sample dicts with the keys ``sample_index``,
            ``true_label``, ``pred_label``, ``confidence``, ``correct``,
            ``top3_predictions`` and ``image``.
        sample_indices: Indices that were tested; only echoed in the report.

    Returns:
        Fraction of entries in ``results`` whose ``correct`` flag is truthy,
        or ``0.0`` when ``results`` is empty.
    """
    print(f"\n{'='*70}")
    print(f"Model: {model_name}")
    print(f"{'='*70}")

    if not results:
        # Guard against division by zero and a zero-column subplot grid.
        print("No predictions to display.")
        return 0.0

    correct_count = sum(1 for r in results if r['correct'])
    accuracy = correct_count / len(results)

    print(f"Batch Accuracy: {correct_count}/{len(results)} ({accuracy:.1%})")
    print(f"Samples tested: {sample_indices}")
    print()

    for i, result in enumerate(results):
        print(f"Sample {i+1} (Index {result['sample_index']}):")
        print(f"  True: {EMOTION_LABELS[result['true_label']]:<12}", end="")
        print(f"  Predicted: {EMOTION_LABELS[result['pred_label']]:<12}", end="")
        print(f"  Confidence: {result['confidence']:.1%}", end="")
        print(f"  {'‚úì' if result['correct'] else '‚úó'}")

        # Show top predictions; join() places separators correctly for any
        # number of entries instead of assuming exactly three.
        top_text = ", ".join(
            f"{emotion}: {prob:.1%}" for emotion, prob in result['top3_predictions']
        )
        print(f"  Top 3: {top_text}")

    # Visualize all samples in a grid
    fig, axes = plt.subplots(1, len(results), figsize=(4*len(results), 4))
    if len(results) == 1:
        # A single subplot is returned as a bare Axes, not a sequence.
        axes = [axes]

    for i, (result, ax) in enumerate(zip(results, axes)):
        ax.imshow(result['image'], cmap='gray')
        correct = result['correct']
        color = 'green' if correct else 'red'
        title = f"Sample {i+1}\n"
        title += f"True: {EMOTION_LABELS[result['true_label']]}\n"
        title += f"Pred: {EMOTION_LABELS[result['pred_label']]}\n"
        title += f"Conf: {result['confidence']:.1%}"
        ax.set_title(title, color=color, fontsize=10)
        ax.axis('off')

    plt.suptitle(f"Model: {model_name} (Accuracy: {accuracy:.1%})", fontsize=12)
    plt.tight_layout()
    plt.show()

    return accuracy

import traceback

# Load processor (same for all models)
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

print(f"Testing {len(MODELS_TO_TEST)} models on {NUM_SAMPLES} random samples each")
print(f"Test dataset size: {len(test_ds)}")
print()

# Get random sample indices (same for all models for fair comparison).
# Clamp the sample count so a small dataset does not make random.sample raise.
sample_indices = random.sample(range(len(test_ds)), min(NUM_SAMPLES, len(test_ds)))
print(f"Random sample indices: {sample_indices}")

# Per-model results keyed by experiment name; filled only for models that
# load and predict successfully.
model_results = {}

for model_name in MODELS_TO_TEST:
    try:
        print(f"\n{'='*70}")
        print(f"Loading model: {model_name}")
        print(f"{'='*70}")
        
        # Find the latest run for this model
        run_folder = find_latest_run_for_experiment(model_name, CHECKPOINTS_DIR)
        if not run_folder:
            # Report a missing run explicitly rather than failing later with a
            # TypeError while building the checkpoint path.
            print(f"No run folder found for: {model_name}")
            continue
        
        # Load the best model checkpoint
        checkpoint_path = run_folder / f"best_{run_folder.name}.pth"
        
        if not checkpoint_path.exists():
            print(f"Checkpoint not found: {checkpoint_path}")
            continue
            
        # Load model
        model = load_model_from_checkpoint(checkpoint_path)
        
        # Get model info
        params = load_training_parameters(run_folder)
        print(f"Loaded: {run_folder.name}")
        print(f"Transform: {model_name.split('_')[-1]}")
        print(f"Epochs: {params.get('num_epochs', 'N/A')}")
        # Only apply the scientific-notation format to real numbers: applying
        # ':.2e' to the 'N/A' fallback raises ValueError and would skip the model.
        learning_rate = params.get('learning_rate')
        if isinstance(learning_rate, (int, float)):
            learning_rate_text = f"{learning_rate:.2e}"
        else:
            learning_rate_text = 'N/A' if learning_rate is None else str(learning_rate)
        print(f"Learning rate: {learning_rate_text}")
        
        # Run predictions
        results = predict_and_visualize_batch(
            dataset=test_ds,
            indices=sample_indices,
            model=model,
            processor=processor,
            model_name=model_name
        )
        
        # Display results
        accuracy = display_model_predictions(model_name, results, sample_indices)
        model_results[model_name] = {
            'accuracy': accuracy,
            'correct': sum(1 for r in results if r['correct']),
            'total': len(results),
            'run_folder': run_folder.name
        }
        
    except Exception as e:
        # Best-effort comparison: report the failure and move on to the next model.
        print(f"Failed to test {model_name}: {e}")
        traceback.print_exc()

# Summary banner for the cross-model comparison
print(f"\n{'='*70}")
print("SUMMARY: Model Comparison")
print(f"{'='*70}")

if not model_results:
    print("No models were successfully tested.")
else:
    # Order the (name, stats) pairs from highest to lowest sampled accuracy
    sorted_results = sorted(
        model_results.items(),
        key=lambda entry: entry[1]['accuracy'],
        reverse=True,
    )

    print("\nPerformance Ranking:")
    for i, (model_name, result) in enumerate(sorted_results):
        print(f"{i+1}. {model_name:<20} {result['correct']}/{result['total']} ({result['accuracy']:.1%})")

    # First and last entries of the ranking are the extremes
    best_model, worst_model = sorted_results[0], sorted_results[-1]

    print(f"\nBest: {best_model[0]} ({best_model[1]['accuracy']:.1%})")
    print(f"Worst: {worst_model[0]} ({worst_model[1]['accuracy']:.1%})")

    # Bar chart of the ranking, one bar per model
    models = [entry[0] for entry in sorted_results]
    accuracies = [entry[1]['accuracy'] for entry in sorted_results]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(models, accuracies, color=['green', 'lightgreen', 'orange', 'red'])
    ax.set(
        xlabel='Model',
        ylabel='Accuracy',
        title=f'Model Comparison on {NUM_SAMPLES} Samples',
        ylim=(0, 1),
    )

    # Annotate each bar with its accuracy, slightly above the bar top
    for bar, acc in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width()/2,
            bar.get_height() + 0.02,
            f'{acc:.1%}',
            ha='center',
            va='bottom',
        )

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

## LoRA Section
### Hyperparameter Queue

In [None]:
from src.lora import (
    create_lora_config,
    get_lora_model,
    train_lora_model,
    load_lora_model,
    merge_and_save_lora_model,
)

print("‚úÖ LoRA utilities imported successfully")


# ============================================================================
# CELL: Define LoRA Experiment Configurations
# ============================================================================
# Define LoRA experiment configurations
# Shared epoch budget for the LoRA experiments defined below.
LORA_EPOCHS = 8

# Active sweep: rank, alpha, learning rate and target modules are held fixed
# while only the LoRA dropout varies; the dropout value is embedded in the name.
lora_experiment_configs = [
    {
        "name": f"lora_r4_dropout_{dropout}",
        "transform_key": "light",
        "epochs": LORA_EPOCHS,  # use the shared constant instead of repeating the literal
        "learning_rate": 5e-4,
        "batch_size": DEFAULT_BATCH_SIZE,
        "weight_decay": 0.01,
        "lora_r": 4,
        "lora_alpha": 8,
        "lora_dropout": dropout,
        "target_modules": ["query", "value"],
    }
    for dropout in [0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5]  # Test wide range
]

# lora_experiment_configs = []
# for r in [8, 16, 32]:
#     for alpha_ratio in [0.5, 1, 2, 4]:  # alpha = r * ratio
#         alpha = r * alpha_ratio
#         lora_experiment_configs.append({
#             "name": f"lora_r{r}_alpha{alpha}",
#             "transform_key": "light",
#             "epochs": 8,
#             "learning_rate": 5e-4,
#             "batch_size": DEFAULT_BATCH_SIZE,
#             "weight_decay": 0.01,
#             "lora_r": r,
#             "lora_alpha": alpha,
#             "lora_dropout": 0.1,
#             "target_modules": ["query", "value"],
#         })


# lora_experiment_configs = [
#     # Fix r=4, vary alpha
#     {
#         "name": f"lora_r4_alpha_{alpha}",
#         "transform_key": "light",
#         "epochs": 8,
#         "learning_rate": 5e-4,  # Your best LR
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 4,
#         "lora_alpha": alpha,
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "value"],
#     }
#     for alpha in [2, 4, 8, 16, 32, 64]  # Test extreme values
# ]

# lora_experiment_configs = [
#     # Base config with varying LRs
#     {
#         "name": f"lora_r4_lr_{lr}",
#         "transform_key": "light",
#         "epochs": 8,
#         "learning_rate": lr,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 4,
#         "lora_alpha": 8,
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "value"],
#     }
#     for lr in [1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 2e-3, 3e-3]
# ]


# lora_experiment_configs = [
#     {
#         "name": "lora_r32_light_short",
#         "transform_key": "light",  # Stick with what works
#         "epochs": 8,  # Shorter training - you plateau early
#         "learning_rate": 1e-3,
#         "batch_size": DEFAULT_BATCH_SIZE,  # Keep your original batch size
#         "weight_decay": 0.01,
#         "lora_r": 32,
#         "lora_alpha": 64,
#         "lora_dropout": 0.1,  # Lower dropout for facial features
#         "target_modules": ["query", "value", "output.dense"],  # Start with these
#     },
#     {
#         "name": "lora_r48_light_focused",
#         "transform_key": "light",
#         "epochs": 8,
#         "learning_rate": 8e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 48,
#         "lora_alpha": 96,
#         "lora_dropout": 0.15,
#         "target_modules": ["query", "value"],  # Just Q, V might be better
#     },
# ]

# 2nd 
# lora_experiment_configs = [
#     # Experiment 1: Baseline (your best performer)
#     {
#         "name": "lora_r4_light_lr",
#         "transform_key": "light",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 5e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 4,
#         "lora_alpha": 8,
#         "lora_dropout": 0.05,  # Lower dropout for FER2013
#         "target_modules": ["query", "value"],
#     },

#     # Experiment 2: Query-only attention (simpler, fewer params)
#     {
#         "name": "lora_r4_light_query_only",
#         "transform_key": "light",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 3e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 4,
#         "lora_alpha": 8,
#         "lora_dropout": 0.05,
#         "target_modules": ["query"],  # Only query projections
#     },
#     # Experiment 3: Even lower dropout for subtle facial features
#     {
#         "name": "lora_r4_light_low_dropout",
#         "transform_key": "light",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 3e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 4,
#         "lora_alpha": 8,
#         "lora_dropout": 0.01,  # Minimal dropout
#         "target_modules": ["query", "value"],
#     },
#     # Experiment 4: Slightly higher rank for facial detail
#     {
#         "name": "lora_r6_light",
#         "transform_key": "light",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 3e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         "lora_r": 6,  # Slightly higher than r=4
#         "lora_alpha": 12,  # 2 * r
#         "lora_dropout": 0.05,
#         "target_modules": ["query", "value"],
#     },
# ]


# First 
# lora_experiment_configs = [
#     # # Experiment 1: Conservative LoRA (low rank, efficient)
#     {
#         "name": "lora_r4_light",
#         "transform_key": "light",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 3e-4,  # Higher LR for LoRA
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         # LoRA specific parameters
#         "lora_r": 4,
#         "lora_alpha": 8,  # 2 * r
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "value"],
#     },
#     # Experiment 2: Balanced LoRA (medium rank)
#     {
#         "name": "lora_r8_medium",
#         "transform_key": "medium",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 2e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         # LoRA specific parameters
#         "lora_r": 8,
#         "lora_alpha": 16,
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "value"],
#     },
#     # Experiment 3: Higher capacity LoRA
#     {
#         "name": "lora_r16_medium",
#         "transform_key": "medium",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 1e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         # LoRA specific parameters
#         "lora_r": 16,
#         "lora_alpha": 32,
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "value"],
#     },
#     # Experiment 4: Full attention LoRA (Q, K, V)
#     {
#         "name": "lora_r8_qkv_heavy",
#         "transform_key": "heavy",
#         "epochs": LORA_EPOCHS,
#         "learning_rate": 2e-4,
#         "batch_size": DEFAULT_BATCH_SIZE,
#         "weight_decay": 0.01,
#         # LoRA specific parameters
#         "lora_r": 8,
#         "lora_alpha": 16,
#         "lora_dropout": 0.1,
#         "target_modules": ["query", "key", "value"],
#     },
# ]

print(f"{len(lora_experiment_configs)} LoRA Experiment Configs Loaded")

# Print summary: one short paragraph per configured experiment
print("\nüìä LoRA Configuration Summary:")
for config in lora_experiment_configs:
    print(f"\n{config['name']}:")
    print(f"  LoRA rank: {config['lora_r']}, alpha: {config['lora_alpha']}")
    # Dropout is part of every config and is read by the training loop; show
    # it so sweeps that vary only dropout are distinguishable in this summary.
    print(f"  LoRA dropout: {config['lora_dropout']}")
    print(f"  Target modules: {config['target_modules']}")
    print(f"  Learning rate: {config['learning_rate']}")
    print(f"  Transform: {config['transform_key']}")


### Training Loop

In [None]:
# ============================================================================
# CELL: LoRA Training Loop (Fixed Version with Memory Management)
# ============================================================================
from tqdm.notebook import tqdm
import torch
import gc
from torch.optim import AdamW
from src.wandb_utils import cleanup_wandb_run
from peft import get_peft_model  # Direct import for debugging

import time
import traceback

# Results of successful experiments keyed by experiment name, and a list of
# failure records (name, error text, error type, traceback) for the rest.
all_lora_results = {}
failed_lora_experiments = []

print(f"Starting LoRA training for {len(lora_experiment_configs)} experiments")
print("=" * 70)

# Run every configured experiment in sequence. A failure in one experiment is
# recorded and the loop continues; Ctrl+C stops the whole sweep.
for i, config in enumerate(tqdm(lora_experiment_configs, desc="LoRA Experiments")):
    print(f"\n{'='*70}")
    print(f"LoRA Experiment {i+1}/{len(lora_experiment_configs)}: {config['name']}")
    print(f"   Transform: {config['transform_key']}")
    print(f"   LR: {config['learning_rate']}")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch Size: {config['batch_size']}")
    print(f"   LoRA Rank: {config['lora_r']}, Alpha: {config['lora_alpha']}")
    print(f"   Target Modules: {config['target_modules']}")
    print(f"{'='*70}")
    
    # Ensure any previous WandB run is cleaned up
    cleanup_wandb_run()
    
    try:
        # Clear memory before starting new experiment
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            print(f"GPU Memory cleared: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
        
        # Create datasets: the configured transform for training, the base
        # transform for validation
        transform = transform_configs[config['transform_key']]
        
        train = FER2013Dataset(
            split="train",
            transform=transform
        )
        valid = FER2013Dataset(
            split="valid", 
            transform=base_transform()
        )
        
        # Initialize base model
        print("Loading base model...")
        base_model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        )
        
        # Create LoRA configuration
        lora_config = create_lora_config(
            r=config['lora_r'],
            lora_alpha=config['lora_alpha'],
            lora_dropout=config['lora_dropout'],
            target_modules=config['target_modules'],
        )
        
        print(f"DEBUG: LoRA config created: r={lora_config.r}, target_modules={lora_config.target_modules}")
        
        # Apply LoRA with classifier unfreezing
        model = get_lora_model(base_model, lora_config, unfreeze_classifier=True)
        
        # Check trainable parameters BEFORE moving to device
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"DEBUG: Trainable parameters: {trainable_params:,} / {total_params:,} ({(trainable_params/total_params)*100:.2f}%)")
        
        if trainable_params == 0:
            # Fail fast with a full parameter dump: training would be a no-op.
            print("‚ùå CRITICAL ERROR: No trainable parameters found!")
            print("DEBUG: Checking parameter names...")
            for name, param in model.named_parameters():
                print(f"  {name}: requires_grad={param.requires_grad}, shape={param.shape}")
            raise RuntimeError("No trainable parameters in model. LoRA adapters not applied properly.")
        
        # Move model to device AFTER parameter counting
        model.to(DEVICE)
        model.print_trainable_parameters()
        
        # Print detailed parameter info, bucketing trainable parameters by name
        print("\nDetailed parameter breakdown:")
        lora_params = 0
        classifier_params = 0
        other_params = 0
        
        for name, param in model.named_parameters():
            if param.requires_grad:
                if 'lora' in name.lower():
                    lora_params += param.numel()
                    print(f"  üîµ LoRA: {name} - {param.numel():,} params")
                elif 'classifier' in name.lower() or 'head' in name.lower():
                    classifier_params += param.numel()
                    print(f"  üü¢ Classifier: {name} - {param.numel():,} params")
                else:
                    other_params += param.numel()
                    print(f"  ‚ö´ Other: {name} - {param.numel():,} params")
        
        print(f"\nParameter summary:")
        print(f"  LoRA parameters: {lora_params:,}")
        print(f"  Classifier parameters: {classifier_params:,}")
        print(f"  Other trainable parameters: {other_params:,}")
        print(f"  Total trainable: {trainable_params:,}")
        
        # Initialize optimizer (only trainable parameters)
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        print(f"DEBUG: Optimizer will optimize {len(optimizer_params)} parameter groups")
        
        optimizer = AdamW(
            optimizer_params,
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        print(f"\n‚úÖ LoRA Model initialized successfully")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,} ({(trainable_params/total_params)*100:.2f}%)")
        print(f"   Device: {DEVICE}")
        
        # Choose gradient accumulation steps from the batch size:
        # >= 64 -> 1 step, >= 32 -> 2 steps, smaller -> 4 steps.
        if config['batch_size'] >= 64:
            gradient_accumulation_steps = 1
        elif config['batch_size'] >= 32:
            gradient_accumulation_steps = 2
        else:
            gradient_accumulation_steps = 4
        
        print(f"   Gradient accumulation steps: {gradient_accumulation_steps}")
        
        # Clear memory before training
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Train model
        model_exp, history_exp, run_folder_exp = train_lora_model(
            model=model,
            optimizer=optimizer,
            train_dataset=train,
            val_dataset=valid,
            num_epochs=config['epochs'],
            batch_size=config['batch_size'],
            device=DEVICE,
            model_name=config['name'],
            use_wandb=True,
            wandb_config={
                "learning_rate": config['learning_rate'],
                "batch_size": config['batch_size'],
                "epochs": config['epochs'],
                "weight_decay": config['weight_decay'],
                "model_name": "vit_base_patch16_224",
                "architecture": "ViT + LoRA",
                "dataset": "FER2013",
                "transform_set": config['transform_key'],
                "experiment_name": config['name'],
                "lora_r": config['lora_r'],
                "lora_alpha": config['lora_alpha'],
                "lora_dropout": config['lora_dropout'],
                "lora_target_modules": config['target_modules'],
                "trainable_params": trainable_params,
                "trainable_percentage": (trainable_params/total_params)*100,
                "gradient_accumulation_steps": gradient_accumulation_steps,
            },
            gradient_accumulation_steps=gradient_accumulation_steps,
        )
        
        # Store results
        # NOTE(review): the 'model' entry keeps a reference to every trained
        # model, so the name deletion in the cleanup step below cannot release
        # their memory - confirm whether later cells need this entry.
        all_lora_results[config['name']] = {
            'model': model_exp,
            'history': history_exp,
            'run_folder': run_folder_exp,
            'config': config,
            'best_val_accuracy': max(history_exp['val_acc']),
            'best_val_loss': min(history_exp['val_loss']),
            'final_train_accuracy': history_exp['train_acc'][-1],
            'final_train_loss': history_exp['train_loss'][-1],
            'lora_adapter_path': run_folder_exp / "lora_adapter",
            'trainable_params': trainable_params,
            'total_params': total_params,
        }
        
        print(f"\n‚úÖ COMPLETED: {config['name']}")
        print(f"   Best Val Accuracy: {all_lora_results[config['name']]['best_val_accuracy']:.4f}")
        print(f"   Best Val Loss: {all_lora_results[config['name']]['best_val_loss']:.4f}")
        print(f"   Run folder: {run_folder_exp}")
        
        # Save summary to file
        summary_path = run_folder_exp / "training_summary.txt"
        with open(summary_path, 'w') as f:
            f.write(f"LoRA Experiment: {config['name']}\n")
            f.write(f"Transform: {config['transform_key']}\n")
            f.write(f"Learning Rate: {config['learning_rate']}\n")
            f.write(f"Epochs: {config['epochs']}\n")
            f.write(f"Batch Size: {config['batch_size']}\n")
            f.write(f"LoRA Rank: {config['lora_r']}\n")
            f.write(f"LoRA Alpha: {config['lora_alpha']}\n")
            f.write(f"LoRA Dropout: {config['lora_dropout']}\n")
            f.write(f"Target Modules: {config['target_modules']}\n")
            f.write(f"Best Val Accuracy: {all_lora_results[config['name']]['best_val_accuracy']:.4f}\n")
            f.write(f"Best Val Loss: {all_lora_results[config['name']]['best_val_loss']:.4f}\n")
            f.write(f"Final Train Accuracy: {history_exp['train_acc'][-1]:.4f}\n")
            f.write(f"Final Train Loss: {history_exp['train_loss'][-1]:.4f}\n")
            f.write(f"Trainable Parameters: {trainable_params:,}\n")
            f.write(f"Total Parameters: {total_params:,}\n")
            f.write(f"Trainable Percentage: {(trainable_params/total_params)*100:.2f}%\n")
        
        print(f"   Summary saved to: {summary_path}")
        
    except KeyboardInterrupt:
        # Stop the whole sweep; the finally block still runs before leaving.
        print(f"\n‚ö†Ô∏è Training interrupted by user at experiment: {config['name']}")
        cleanup_wandb_run()
        break
        
    except Exception as e:
        print(f"\n‚ùå ERROR in experiment {config['name']}: {str(e)}")
        print(f"Exception type: {type(e).__name__}")
        
        # Store failed experiment info. traceback is imported at the top of
        # this cell, so the formatted traceback is always captured.
        failed_lora_experiments.append({
            'name': config['name'],
            'error': str(e),
            'error_type': type(e).__name__,
            'traceback': traceback.format_exc()
        })
        
        # Clean up WandB
        cleanup_wandb_run()
        
        print(f"Continuing to next experiment...")
        
    finally:
        # Comprehensive memory cleanup
        print("\nüßπ Performing comprehensive memory cleanup...")
        
        # Drop the per-experiment names from the notebook namespace. This loop
        # runs at cell top level, so these names are globals; use globals()
        # explicitly instead of mutating the mapping returned by locals().
        for var_name in ['model', 'base_model', 'model_exp', 'optimizer', 'train', 'valid', 'lora_config']:
            if var_name in globals():
                del globals()[var_name]
                print(f"   Deleted: {var_name}")
        
        # Force garbage collection
        gc.collect()
        
        # Clear CUDA cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            
            # Print memory stats
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            print(f"   GPU Memory after cleanup: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")
        
        print("‚úÖ Memory cleaned up")
        
        # Add a small delay between experiments
        time.sleep(2)  # 2 second delay to ensure proper cleanup

import json
import os
from datetime import datetime

# Final cleanup
cleanup_wandb_run()

# Print summary
print("\n" + "="*70)
print("LORA TRAINING COMPLETE - SUMMARY")
print("="*70)
print(f"Successful experiments: {len(all_lora_results)}/{len(lora_experiment_configs)}")
print(f"Failed experiments: {len(failed_lora_experiments)}/{len(lora_experiment_configs)}")

if all_lora_results:
    print("\nüìä LoRA Results (sorted by Best Val Accuracy):")
    
    # Sort results by best validation accuracy, highest first
    sorted_results = sorted(
        all_lora_results.items(),
        key=lambda x: x[1]['best_val_accuracy'],
        reverse=True
    )
    
    for name, result in sorted_results:
        config = result['config']
        print(f"\n   {name}:")
        print(f"      Val Acc: {result['best_val_accuracy']:.4f}")
        print(f"      Val Loss: {result['best_val_loss']:.4f}")
        print(f"      Train Acc: {result['final_train_accuracy']:.4f}")
        print(f"      LR: {config['learning_rate']}")
        print(f"      LoRA Rank: {config['lora_r']}")
        print(f"      Target Modules: {config['target_modules']}")
        print(f"      Trainable Params: {result['trainable_params']:,}")
        print(f"      Trainable %: {(result['trainable_params']/result['total_params'])*100:.2f}%")
        print(f"      Run Folder: {result['run_folder'].name}")

# Print failed experiments
if failed_lora_experiments:
    print("\n‚ùå Failed Experiments:")
    for failed in failed_lora_experiments:
        print(f"\n   {failed['name']}:")
        print(f"      Error Type: {failed['error_type']}")
        print(f"      Error: {failed['error']}")

# Save overall results to file. The summary is also written when every
# experiment failed, so the failure records are not lost in that case.
if all_lora_results or failed_lora_experiments:
    results_summary = {
        'timestamp': datetime.now().isoformat(),
        'total_experiments': len(lora_experiment_configs),
        'successful': len(all_lora_results),
        'failed': len(failed_lora_experiments),
        'results': {},
        'failed_experiments': failed_lora_experiments,
    }
    
    # Keep only JSON-friendly fields; the stored model and history are omitted.
    for name, result in all_lora_results.items():
        results_summary['results'][name] = {
            'best_val_accuracy': float(result['best_val_accuracy']),
            'best_val_loss': float(result['best_val_loss']),
            'final_train_accuracy': float(result['final_train_accuracy']),
            'final_train_loss': float(result['final_train_loss']),
            'trainable_params': result['trainable_params'],
            'total_params': result['total_params'],
            'run_folder': str(result['run_folder']),
            'config': result['config'],
        }
    
    # Save to JSON
    # NOTE(review): this directory is relative to the current working
    # directory - confirm that is where the other checkpoints are stored.
    results_dir = Path("checkpoints") / "lora_experiments_summary"
    results_dir.mkdir(parents=True, exist_ok=True)
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"lora_experiments_summary_{timestamp}.json"
    
    # default=str stringifies any value json cannot encode natively
    with open(results_file, 'w') as f:
        json.dump(results_summary, f, indent=2, default=str)
    
    print(f"\nüìÅ Full results summary saved to: {results_file}")

print("\n" + "="*70)

### Inference
#### Load Model

In [None]:
from transformers import ViTImageProcessor
import torch.nn.functional as F

# Find best LoRA model
# Look the results up defensively: on a fresh kernel the training cell has not
# run, `all_lora_results` does not exist, and a direct reference would raise
# NameError before the "run the training cell first" message could be shown.
available_lora_results = globals().get("all_lora_results") or {}

if available_lora_results:
    # Pick the experiment with the highest best validation accuracy
    best_lora_name = max(
        available_lora_results,
        key=lambda x: available_lora_results[x]['best_val_accuracy']
    )
    best_lora_result = available_lora_results[best_lora_name]
    
    print(f"Loading best LoRA model: {best_lora_name}")
    print(f"Best validation accuracy: {best_lora_result['best_val_accuracy']:.4f}")
    print(f"LoRA adapter path: {best_lora_result['lora_adapter_path']}")
    
    # Load base model
    base_model_for_inference = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True
    )
    
    # Load LoRA adapters
    lora_model_for_inference = load_lora_model(
        base_model_for_inference,
        best_lora_result['lora_adapter_path'],
        device=DEVICE
    )
    
    print("\n‚úÖ LoRA model ready for inference!")
else:
    print("‚ùå No LoRA models trained yet. Run the training cell first.")

#### Test Predictions

In [None]:
from src.metadata import find_latest_run_for_experiment
from transformers import ViTImageProcessor
import random
import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.metrics import accuracy_score
from pathlib import Path
import pandas as pd

from src.evaluate import evaluate_model
from src.lora import load_lora_model
from src.config import DEVICE, EMOTION_LABELS, NUM_LABELS
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
from transformers import ViTForImageClassification

# Initialize test dataset
test_ds = FER2013Dataset(split="test", transform=base_transform())
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Define your checkpoints directory.
# The original absolute path only exists on one machine; keep using it when
# present, otherwise fall back to <project root>/checkpoints, resolving the
# project root the same way the notebook's first cell does (parent of the
# working directory when the notebook runs from "notebooks/").
_LEGACY_CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")
if _LEGACY_CHECKPOINTS_DIR.exists():
    CHECKPOINTS_DIR = _LEGACY_CHECKPOINTS_DIR
else:
    _cwd = Path.cwd()
    CHECKPOINTS_DIR = (_cwd.parent if _cwd.name == 'notebooks' else _cwd) / "checkpoints"

# Convert EMOTION_LABELS list to dict (index -> label name)
EMOTION_DICT = dict(enumerate(EMOTION_LABELS))

print("üîç LoRA Model Evaluation - UPDATED VERSION")
print("=" * 80)
print(f"Device: {DEVICE}")
print(f"Number of labels: {NUM_LABELS}")
print(f"Emotion labels: {EMOTION_LABELS}")
print("=" * 80)

def _first_matching_entry(state_dict, candidate_keys):
    """Return ``(key, value)`` for the first candidate key present in ``state_dict``, else ``(None, None)``."""
    for key in candidate_keys:
        if key in state_dict:
            return key, state_dict[key]
    return None, None


def load_lora_model_correctly_for_evaluation(model_folder):
    """
    Load a LoRA fine-tuned ViT for evaluation with the adapters merged in.

    Two artifacts of the run are combined:
    1. LoRA adapters (from lora_adapter/)
    2. Classifier weights (from the *best*.pth checkpoint)

    Args:
        model_folder: Experiment name; resolved to a run folder through
            ``find_latest_run_for_experiment(model_folder, CHECKPOINTS_DIR)``.

    Returns:
        The merged model, moved to ``DEVICE`` and set to eval mode, or ``None``
        when the run folder, the checkpoint or the adapter folder is missing,
        or when any loading step raises (including a classifier shape
        mismatch). When the checkpoint contains no recognised classifier keys
        a warning is printed and the model is still returned with the
        classifier it already had.
    """
    # Find the latest run for this experiment
    run_folder = find_latest_run_for_experiment(model_folder, CHECKPOINTS_DIR)
    if not run_folder:
        print(f"‚ùå No run folder found for {model_folder}")
        return None

    print(f"\nüìÅ Loading from: {model_folder}")
    print(f"   Run folder: {run_folder.name}")

    # Step 1: Find the .pth checkpoint (has classifier weights).
    # glob() order is filesystem-dependent, so pick deterministically:
    # newest modification time first, file name as tie-breaker.
    pth_files = sorted(
        run_folder.glob("*best*.pth"),
        key=lambda path: (path.stat().st_mtime, path.name),
        reverse=True,
    )
    if not pth_files:
        print(f"   ‚ùå No .pth checkpoint found in {run_folder}")
        return None

    pth_path = pth_files[0]
    print(f"   ‚úì Found checkpoint: {pth_path.name}")
    if len(pth_files) > 1:
        print(f"   Note: {len(pth_files)} checkpoints match '*best*.pth'; using the most recently modified one")

    # Step 2: Check if lora_adapter folder exists
    lora_adapter_path = run_folder / "lora_adapter"
    if not lora_adapter_path.exists():
        print(f"   ‚ùå lora_adapter folder not found at: {lora_adapter_path}")
        return None

    print(f"   ‚úì Found lora_adapter folder")

    try:
        # Step 3: Load base model
        print(f"   Loading base ViT model...")
        base_model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        )

        # Step 4: Load LoRA adapters using the project helper
        print(f"   Loading LoRA adapters...")
        model = load_lora_model(
            base_model=base_model,
            lora_adapter_path=str(lora_adapter_path),
            device='cpu'  # Load to CPU first
        )

        # Step 5: Merge LoRA weights into base model
        print(f"   Merging LoRA weights...")
        model = model.merge_and_unload()

        # Step 6: Load classifier weights from .pth checkpoint.
        # SECURITY: weights_only=False unpickles arbitrary objects; only use
        # this on checkpoints produced by this project.
        print(f"   Loading classifier weights from checkpoint...")
        checkpoint = torch.load(pth_path, map_location='cpu', weights_only=False)
        state_dict = checkpoint['model_state_dict']

        # The classifier may be stored under several prefixes; try each in order.
        weight_key, classifier_weight = _first_matching_entry(
            state_dict,
            [
                'base_model.model.classifier.weight',
                'classifier.weight',
                'model.classifier.weight',
            ],
        )
        if weight_key is not None:
            print(f"   Found classifier weight: {weight_key}")

        bias_key, classifier_bias = _first_matching_entry(
            state_dict,
            [
                'base_model.model.classifier.bias',
                'classifier.bias',
                'model.classifier.bias',
            ],
        )
        if bias_key is not None:
            print(f"   Found classifier bias: {bias_key}")

        if classifier_weight is not None and classifier_bias is not None:
            head = model.classifier
            # Refuse to install a head of the wrong size: assigning through
            # `.data` would silently change the number of output classes.
            if (classifier_weight.shape != head.weight.shape
                    or classifier_bias.shape != head.bias.shape):
                raise ValueError(
                    "Classifier shape mismatch: checkpoint has "
                    f"{tuple(classifier_weight.shape)}/{tuple(classifier_bias.shape)}, "
                    f"model expects {tuple(head.weight.shape)}/{tuple(head.bias.shape)}"
                )
            # copy_ keeps the model's own parameters (dtype, storage) instead
            # of aliasing tensors owned by the checkpoint.
            with torch.no_grad():
                head.weight.copy_(classifier_weight)
                head.bias.copy_(classifier_bias)
            print(f"   ‚úÖ Classifier weights loaded!")
            print(f"      Weight shape: {classifier_weight.shape}")
            print(f"      Bias shape: {classifier_bias.shape}")
        else:
            # Best effort: keep going, but list what the checkpoint does contain.
            print(f"   ‚ö†Ô∏è  WARNING: Classifier weights not found in checkpoint!")
            print(f"      Looking for keys containing 'classifier':")
            classifier_keys = [k for k in state_dict.keys() if 'classifier' in k.lower()]
            for key in classifier_keys:
                print(f"      - {key}")

        # Step 7: Move to device and set to eval mode
        model.to(DEVICE)
        model.eval()

        # Verify model structure
        print(f"   ‚úÖ Model loaded successfully!")
        print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Quick check of classifier weights
        if hasattr(model, 'classifier'):
            weight_mean = model.classifier.weight.data.mean().item()
            weight_std = model.classifier.weight.data.std().item()
            print(f"   Classifier weight stats: mean={weight_mean:.6f}, std={weight_std:.6f}")

        return model

    except Exception as e:
        # Any failure is reported and mapped to None so the caller can skip this model.
        print(f"   ‚ùå Error loading LoRA model: {e}")
        import traceback
        traceback.print_exc()
        return None

# Configuration
LORA_MODELS_TO_TEST = [
    "lora_r4_light1",
    "lora_r8_medium",
    "lora_r8_qkv_heavy",
    "lora_r16_medium",
]

NUM_SAMPLES = 3  # Number of random test samples per model
SAMPLE_SEED = 42  # Fixed seed so reruns inspect the same images; set to None for a fresh draw

# A dedicated generator keeps the draw reproducible without touching the
# global `random` state; the count is capped so a tiny dataset cannot raise.
sample_indices = random.Random(SAMPLE_SEED).sample(
    range(len(test_ds)), min(NUM_SAMPLES, len(test_ds))
)
print(f"\nRandom sample indices for visual inspection: {sample_indices}")

# Load processor for image preprocessing
# NOTE(review): `processor` is not referenced again in the visible part of this cell — confirm it is still needed.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

def display_model_predictions(model_name, results, sample_indices):
    """Print and plot the per-sample predictions of a single model.

    Args:
        model_name: Label used in the printed header and the figure title.
        results: List of per-sample dicts with the keys 'sample_index',
            'true_label', 'pred_label', 'confidence', 'correct',
            'top3_predictions' (iterable of ``(class_index, probability)``)
            and 'image'.
        sample_indices: Dataset indices that were tested; only printed.

    Returns:
        Fraction of entries whose 'correct' flag is truthy, or ``0.0`` when
        ``results`` is empty.
    """
    print(f"\n{'='*70}")
    print(f"Model: {model_name}")
    print(f"{'='*70}")

    if not results:
        # Without this guard the accuracy division and plt.subplots(1, 0) both raise.
        print("No prediction results to display.")
        return 0.0

    correct_count = sum(1 for r in results if r['correct'])
    accuracy = correct_count / len(results)

    print(f"Batch Accuracy: {correct_count}/{len(results)} ({accuracy:.1%})")
    print(f"Samples tested: {sample_indices}")
    print()

    for i, result in enumerate(results):
        mark = '‚úì' if result['correct'] else '‚úó'
        print(f"Sample {i+1} (Index {result['sample_index']}):")
        print(
            f"  True: {EMOTION_DICT[result['true_label']]:<12}"
            f"  Predicted: {EMOTION_DICT[result['pred_label']]:<12}"
            f"  Confidence: {result['confidence']:.1%}"
            f"  {mark}"
        )

        # Join instead of counting separators, so any number of entries prints cleanly.
        top_entries = ", ".join(
            f"{EMOTION_DICT[emotion_idx]}: {prob:.1%}"
            for emotion_idx, prob in result['top3_predictions']
        )
        print(f"  Top 3: {top_entries}")

    # Visualize all samples in a grid; squeeze=False always yields a 2-D axes array.
    fig, axes = plt.subplots(1, len(results), figsize=(4*len(results), 4), squeeze=False)

    for i, (result, ax) in enumerate(zip(results, axes[0])):
        ax.imshow(result['image'], cmap='gray')
        ax.set_title(f"Sample {i+1}", fontsize=10)
        ax.axis('off')

    plt.suptitle(f"Model: {model_name}", fontsize=12)
    plt.tight_layout()
    plt.show()

    return accuracy

def predict_lora_batch(dataset, indices, model, model_name):
    """Run single-image inference on selected samples for a LoRA model.

    Args:
        dataset: Indexable dataset whose items unpack as ``(image_tensor, label)``.
        indices: Iterable of dataset indices to evaluate.
        model: Model called as ``model(pixel_values=...)`` whose output exposes
            ``.logits``. It is put in eval mode and moved to ``DEVICE``.
        model_name: Unused; kept so existing call sites keep working.

    Returns:
        One dict per index with the keys 'sample_index', 'true_label',
        'pred_label', 'confidence', 'correct', 'top3_predictions'
        (list of ``(class_index, probability)``) and 'image' (numpy array
        prepared for display).
    """
    results = []

    # Loop-invariant setup: do it once instead of once per sample.
    model.eval()
    model.to(DEVICE)

    for sample_idx in indices:
        img, true_label = dataset[sample_idx]
        # The label type is decided by the dataset (not visible here); coerce
        # to int so dict lookups and the `correct` flag behave the same for
        # Python ints and 0-d tensors.
        true_label = int(true_label)

        # Convert tensor to numpy for display
        img_display = img.cpu().numpy()
        if len(img_display.shape) == 3:
            # assumes channel-first layout (C, H, W) — TODO confirm against the transform
            img_display = img_display.transpose(1, 2, 0)
            # If RGB, convert to grayscale for display
            if img_display.shape[2] == 3:
                img_display = img_display.mean(axis=2)

        # Add batch dimension
        img_batch = img.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            outputs = model(pixel_values=img_batch)
            probs = torch.softmax(outputs.logits, dim=-1)[0]
            pred_label = torch.argmax(probs).item()
            confidence = probs[pred_label].item()

            # Top predictions, capped so models with fewer than 3 classes do not raise.
            top_probs, top_classes = torch.topk(probs, min(3, probs.numel()))

        top3_predictions = [
            (class_idx.item(), prob.item())
            for prob, class_idx in zip(top_probs, top_classes)
        ]

        results.append({
            'sample_index': sample_idx,
            'true_label': true_label,
            'pred_label': pred_label,
            'confidence': confidence,
            'correct': true_label == pred_label,
            'top3_predictions': top3_predictions,
            'image': img_display
        })

    return results

import traceback

print(f"\nTesting {len(LORA_MODELS_TO_TEST)} LoRA models")
print(f"Test dataset size: {len(test_ds)}")

# Per-model results: a compact summary and the full metric dicts
lora_model_results = {}
full_evaluation_results = {}

for lora_name in LORA_MODELS_TO_TEST:
    # Bound before the try so the finally-clause can release it on every path.
    model = None
    try:
        print(f"\n{'='*80}")
        print(f"Processing: {lora_name}")
        print(f"{'='*80}")

        # Load the model using CORRECT method
        model = load_lora_model_correctly_for_evaluation(lora_name)

        if model is None:
            print(f"‚ùå Failed to load model")
            continue

        # 1. Quick sanity check on random samples
        print(f"\nüß™ Running quick visual inspection on {NUM_SAMPLES} samples...")
        visual_results = predict_lora_batch(
            dataset=test_ds,
            indices=sample_indices,
            model=model,
            model_name=lora_name
        )

        visual_accuracy = display_model_predictions(lora_name, visual_results, sample_indices)

        # 2. Run full evaluation
        print(f"\nüìä Running full evaluation on entire test set...")
        metrics = evaluate_model(
            model=model,
            test_dataset=test_ds,
            log_to_wandb=False,
            run_name=lora_name.replace("_", " ").title()
        )

        # Store both results
        lora_model_results[lora_name] = {
            'visual_accuracy': visual_accuracy,
            'correct': sum(1 for r in visual_results if r['correct']),
            'total': len(visual_results),
            'full_accuracy': metrics.get('accuracy', 0)
        }

        full_evaluation_results[lora_name] = {
            'metrics': metrics,
            'visual_accuracy': visual_accuracy
        }

        # Print results
        display_name = lora_name.replace("_", " ").title()
        print(f"\n‚úÖ Results for {display_name}:")
        print(f"   ‚Ä¢ Visual check ({NUM_SAMPLES} samples): {visual_accuracy:.4f}")
        print(f"   ‚Ä¢ Full test accuracy: {metrics.get('accuracy', 0):.4f}")
        print(f"   ‚Ä¢ Precision: {metrics.get('precision', 0):.4f}")
        print(f"   ‚Ä¢ Recall: {metrics.get('recall', 0):.4f}")
        print(f"   ‚Ä¢ F1 Score: {metrics.get('f1', 0):.4f}")

    except Exception as e:
        # Report and move on to the next model; one failure must not abort the sweep.
        print(f"‚ùå Error evaluating {lora_name}: {e}")
        traceback.print_exc()

    finally:
        # Clean up on success *and* failure, so a model that crashed during
        # evaluation does not stay resident while the next one loads.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Summary table comparing every model that was evaluated
print("\n" + "=" * 80)
print("SUMMARY: LoRA Model Comparison")
print("=" * 80)

if not lora_model_results:
    print("‚ùå No models were successfully evaluated.")
else:
    # One row per model
    comparison_data = []
    for model_name, results in lora_model_results.items():
        comparison_data.append(
            {
                'Model': model_name.replace("_", " ").title(),
                'Visual Accuracy': results['visual_accuracy'],
                'Full Test Accuracy': results['full_accuracy'],
                'Visual Correct': f"{results['correct']}/{results['total']}",
            }
        )

    # Best full-test accuracy first
    df_comparison = pd.DataFrame(comparison_data)
    df_comparison = df_comparison.sort_values(by='Full Test Accuracy', ascending=False)

    # String-formatted copy used only for printing
    df_display = df_comparison.copy()
    for col in ['Visual Accuracy', 'Full Test Accuracy']:
        df_display[col] = df_display[col].map("{:.4f}".format)

    print("\n" + df_display.to_string(index=False))

    # Ranking, in table order
    print(f"\nüèÜ Performance Ranking:")
    for i, (_, row) in enumerate(df_comparison.iterrows()):
        print(
            f"{i+1}. {row['Model']:<25} "
            f"Full Acc: {row['Full Test Accuracy']:.4f} | "
            f"Visual: {row['Visual Correct']} ({row['Visual Accuracy']:.1%})"
        )

    best_model = df_comparison.iloc[0]
    print(
        f"\nüèÜ Best LoRA Model: {best_model['Model']} "
        f"(Full Accuracy: {best_model['Full Test Accuracy']:.4f})"
    )

    # Persist the unformatted numbers next to the checkpoints
    csv_path = CHECKPOINTS_DIR / "lora_comparison_results.csv"
    df_comparison.to_csv(csv_path, index=False)
    print(f"\nüíæ Results saved to: {csv_path}")

    # Flag models whose full-test accuracy is below 0.3
    low_perf_models = df_comparison.loc[df_comparison['Full Test Accuracy'] < 0.3]
    if len(low_perf_models) > 0:
        print(f"\n‚ö†Ô∏è  WARNING: Models with low performance (< 30%):")
        for _, row in low_perf_models.iterrows():
            print(f"   ‚Ä¢ {row['Model']}: {row['Full Test Accuracy']:.4f}")

print("\n" + "=" * 80)
print("EVALUATION COMPLETE")
print(f"Models attempted: {len(LORA_MODELS_TO_TEST)}")
print(f"Models successfully evaluated: {len(lora_model_results)}")
print("=" * 80)

#### Full Test Set Evaluation 

In [None]:
import torch
from pathlib import Path
import pandas as pd
import numpy as np

from src.evaluate import evaluate_model
from src.lora import load_lora_model_for_inference
from src.config import DEVICE, EMOTION_LABELS, NUM_LABELS
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
from src.metadata import find_latest_run_for_experiment

# Test split, wrapped with the project's base transform
test_ds = FER2013Dataset(transform=base_transform(), split="test")

# Root folder that holds one sub-folder per training run.
# NOTE(review): machine-specific absolute path — confirm it matches the current machine.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")

print(
    "\n".join(
        [
            "FULL TEST SET EVALUATION - LoRA Models",
            "=" * 80,
            f"Device: {DEVICE}",
            f"Test dataset size: {len(test_ds)}",
            f"Number of labels: {NUM_LABELS}",
            "=" * 80,
        ]
    )
)

# ===== CONFIGURATION =====
# Experiments to evaluate; commented entries are currently switched off.
LORA_MODELS_TO_EVALUATE = [
    # "lora_r4_light_baseline_long",
    # "lora_r6_light_long"

    "lora_r4_light_lr",
    "lora_r32_light_short",
    "lora_r48_light_focused",  # trailing comma: re-enabling the lines below must not concatenate strings
    # "lora_r4_light_low_dropout",
    # "lora_r6_light"
]
# =========================

print(f"\nEvaluating {len(LORA_MODELS_TO_EVALUATE)} LoRA models on full test set")
print("-" * 80)

# Filled by the evaluation loop: metrics per model, and names that failed
all_results = dict()
failed_models = list()

for model_name in LORA_MODELS_TO_EVALUATE:
    print(f"\n{'='*60}")
    print(f"Processing: {model_name}")
    print(f"{'='*60}")

    # Bound before the try so the finally-clause can release it on every path.
    model = None
    try:
        # Find the latest run
        run_folder = find_latest_run_for_experiment(model_name, CHECKPOINTS_DIR)
        # Another cell of this notebook treats a falsy result as "no run
        # found"; fail here with a clear message instead of passing it on to
        # the loader.
        if not run_folder:
            raise FileNotFoundError(
                f"No run folder found for {model_name} in {CHECKPOINTS_DIR}"
            )

        # Load model using the new function
        print("Loading model...")
        model = load_lora_model_for_inference(run_folder, device=DEVICE)

        # Run full evaluation
        print("Running evaluation on full test set...")
        metrics = evaluate_model(
            model=model,
            test_dataset=test_ds,
            log_to_wandb=False,
            run_name=model_name.replace("_", " ").title()
        )

        # Store results
        all_results[model_name] = {
            'accuracy': metrics.get('accuracy', 0),
            'precision': metrics.get('precision', 0),
            'recall': metrics.get('recall', 0),
            'f1': metrics.get('f1', 0),
            'run_folder': str(run_folder),
        }

        # Print results
        display_name = model_name.replace("_", " ").title()
        print(f"\nResults for {display_name}:")
        print(f"  Accuracy:  {metrics.get('accuracy', 0):.4f}")
        print(f"  Precision: {metrics.get('precision', 0):.4f}")
        print(f"  Recall:    {metrics.get('recall', 0):.4f}")
        print(f"  F1 Score:  {metrics.get('f1', 0):.4f}")

    except Exception as e:
        # Record the failure and continue with the next model.
        print(f"Error evaluating {model_name}: {e}")
        failed_models.append(model_name)

    finally:
        # Clean up on success *and* failure so a crashed evaluation does not
        # keep its model resident while the next one loads.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Final report: table, best model, spread statistics, CSV export
if not all_results:
    print("No models were successfully evaluated.")
else:
    print("\n" + "=" * 80)
    print("FINAL EVALUATION REPORT - FULL TEST SET")
    print("=" * 80)

    # One row per evaluated model
    comparison_data = []
    for model_name, results in all_results.items():
        comparison_data.append(
            {
                'Model': model_name.replace("_", " ").title(),
                'Accuracy': results['accuracy'],
                'Precision': results['precision'],
                'Recall': results['recall'],
                'F1 Score': results['f1'],
                'Run Folder': Path(results['run_folder']).name,
            }
        )

    # Highest accuracy first
    df_comparison = pd.DataFrame(comparison_data)
    df_comparison = df_comparison.sort_values(by='Accuracy', ascending=False)

    # String-formatted copy used only for printing
    df_display = df_comparison.copy()
    for col in ['Accuracy', 'Precision', 'Recall', 'F1 Score']:
        df_display[col] = df_display[col].map("{:.4f}".format)

    print("\n" + df_display.to_string(index=False))

    # Top row after sorting is the best model
    best_model = df_comparison.iloc[0]
    print(f"\nBEST PERFORMING MODEL:")
    print(f"  Name: {best_model['Model']}")
    print(f"  Accuracy: {best_model['Accuracy']:.4f}")
    print(f"  F1 Score: {best_model['F1 Score']:.4f}")

    # Spread of accuracies across the evaluated models
    accuracies = df_comparison['Accuracy'].tolist()
    print(f"\nSTATISTICS:")
    print(f"  Average Accuracy: {np.mean(accuracies):.4f}")
    print(f"  Best Accuracy:    {np.max(accuracies):.4f}")
    print(f"  Worst Accuracy:   {np.min(accuracies):.4f}")
    print(f"  Std Deviation:    {np.std(accuracies):.4f}")

    # Persist the unformatted numbers next to the checkpoints
    csv_path = CHECKPOINTS_DIR / "lora_full_evaluation_results.csv"
    df_comparison.to_csv(csv_path, index=False)
    print(f"\nFull results saved to: {csv_path}")

    # Flag models whose accuracy falls below the threshold
    low_perf_threshold = 0.3
    low_perf_models = df_comparison.loc[df_comparison['Accuracy'] < low_perf_threshold]

    if len(low_perf_models) > 0:
        print(f"\nMODELS WITH LOW PERFORMANCE (<{low_perf_threshold:.0%}):")
        for _, row in low_perf_models.iterrows():
            print(f"  {row['Model']}: {row['Accuracy']:.4f}")

# Models that raised during loading or evaluation
if len(failed_models) > 0:
    print(f"\nFAILED TO EVALUATE ({len(failed_models)}):")
    for model in failed_models:
        print(f"  {model}")

print("\n" + "=" * 80)
print("SUMMARY:")
print(f"  Total models: {len(LORA_MODELS_TO_EVALUATE)}")
print(f"  Successfully evaluated: {len(all_results)}")
print(f"  Failed: {len(failed_models)}")
print("=" * 80)

#### Visual Inspection

In [None]:
import torch
from pathlib import Path
import random
import matplotlib.pyplot as plt
import numpy as np

from src.lora import load_lora_model_for_inference
from src.config import DEVICE, EMOTION_LABELS, NUM_LABELS
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
from src.metadata import find_latest_run_for_experiment

# Test split, wrapped with the project's base transform
test_ds = FER2013Dataset(transform=base_transform(), split="test")

# Root folder that holds one sub-folder per training run.
# NOTE(review): machine-specific absolute path — confirm it matches the current machine.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")

# Class index -> emotion name lookup
EMOTION_DICT = dict(enumerate(EMOTION_LABELS))

print(
    "\n".join(
        [
            "VISUAL INSPECTION - LoRA Models on N Samples",
            "=" * 80,
            f"Device: {DEVICE}",
            f"Test dataset size: {len(test_ds)}",
            f"Number of labels: {NUM_LABELS}",
            "=" * 80,
        ]
    )
)

def display_predictions(model_name, results, sample_indices):
    """Print the per-sample predictions of one model in a clean format.

    Args:
        model_name: Label used in the printed header.
        results: List of per-sample dicts with the keys 'sample_index',
            'true_label', 'pred_label', 'confidence', 'correct' and
            'top3_predictions' (sequence of ``(class_index, probability)``).
        sample_indices: Dataset indices that were tested; only printed.

    Returns:
        Fraction of entries whose 'correct' flag is truthy, or ``0.0`` when
        ``results`` is empty.
    """
    print(f"\n{'='*70}")
    print(f"Model: {model_name}")
    print(f"{'='*70}")

    if not results:
        # Without this guard the accuracy division raises ZeroDivisionError.
        print("No prediction results to display.")
        return 0.0

    correct_count = sum(1 for r in results if r['correct'])
    accuracy = correct_count / len(results)

    print(f"Accuracy on {len(results)} samples: {correct_count}/{len(results)} ({accuracy:.1%})")
    print(f"Sample indices: {sample_indices}")
    print()

    for i, result in enumerate(results):
        status = "‚úì" if result['correct'] else "‚úó"
        print(f"Sample {i+1} (Index {result['sample_index']}): {status}")
        print(f"  True:      {EMOTION_DICT[result['true_label']]:<12}")
        print(f"  Predicted: {EMOTION_DICT[result['pred_label']]:<12} ({result['confidence']:.1%})")

        # Join instead of counting separators, so fewer than three entries
        # do not leave a dangling comma.
        top_entries = ",".join(
            f" {EMOTION_DICT[emotion_idx]}: {prob:.1%}"
            for emotion_idx, prob in result['top3_predictions'][:3]
        )
        print(f"  Top predictions:{top_entries}")

    return accuracy

def visualize_samples(results, model_name):
    """Plot each sample in a row, framed green (correct) or red (wrong).

    Args:
        results: List of per-sample dicts with the keys 'image', 'correct',
            'true_label', 'pred_label' and 'confidence'.
        model_name: Used in the figure title.

    Returns:
        None. Shows the figure; prints a notice and returns early when
        ``results`` is empty.
    """
    n_samples = len(results)
    if n_samples == 0:
        # plt.subplots(1, 0) raises, so there is nothing sensible to draw.
        print(f"No samples to visualize for {model_name}.")
        return

    # squeeze=False always yields a 2-D axes array, so no single-sample special case.
    fig, axes = plt.subplots(1, n_samples, figsize=(4*n_samples, 4), squeeze=False)

    for i, (result, ax) in enumerate(zip(results, axes[0])):
        ax.imshow(result['image'], cmap='gray')

        # Hide only the ticks: ax.axis('off') would also hide the spines and
        # with them the coloured border drawn below.
        ax.set_xticks([])
        ax.set_yticks([])

        border_color = 'green' if result['correct'] else 'red'
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color(border_color)
            spine.set_linewidth(3)

        true_label = EMOTION_DICT[result['true_label']]
        pred_label = EMOTION_DICT[result['pred_label']]
        confidence = result['confidence']

        title = f"Sample {i+1}\nTrue: {true_label}\nPred: {pred_label} ({confidence:.0%})"
        ax.set_title(title, fontsize=10)

    plt.suptitle(f"Model: {model_name}", fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()

def predict_samples(model, dataset, indices):
    """Run single-image inference on specific dataset samples.

    Args:
        model: Model called as ``model(pixel_values=...)`` whose output exposes
            ``.logits``. It is put in eval mode; inputs are sent to ``DEVICE``,
            so the model is expected to live there already.
        dataset: Indexable dataset whose items unpack as ``(image_tensor, label)``.
        indices: Iterable of dataset indices to evaluate.

    Returns:
        One dict per index with the keys 'sample_index', 'true_label',
        'pred_label', 'confidence', 'correct', 'top3_predictions'
        (list of ``(class_index, probability)``) and 'image' (numpy array
        prepared for display).
    """
    results = []

    # Loop-invariant: switch to eval mode once instead of once per sample.
    model.eval()

    for sample_idx in indices:
        img, true_label = dataset[sample_idx]
        # The label type is decided by the dataset (not visible here); coerce
        # to int so dict lookups and the `correct` flag behave the same for
        # Python ints and 0-d tensors.
        true_label = int(true_label)

        # Prepare image for display
        img_display = img.cpu().numpy()
        if len(img_display.shape) == 3:
            # assumes channel-first layout (C, H, W) — TODO confirm against the transform
            img_display = img_display.transpose(1, 2, 0)
            if img_display.shape[2] == 3:
                img_display = img_display.mean(axis=2)

        # Run inference on a batch of one
        img_batch = img.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            outputs = model(pixel_values=img_batch)
            probs = torch.softmax(outputs.logits, dim=-1)[0]
            pred_label = torch.argmax(probs).item()
            confidence = probs[pred_label].item()

            # Top predictions, capped so models with fewer than 3 classes do not raise.
            top_probs, top_classes = torch.topk(probs, min(3, probs.numel()))

        top3_predictions = [
            (class_idx.item(), prob.item())
            for prob, class_idx in zip(top_probs, top_classes)
        ]

        results.append({
            'sample_index': sample_idx,
            'true_label': true_label,
            'pred_label': pred_label,
            'confidence': confidence,
            'correct': true_label == pred_label,
            'top3_predictions': top3_predictions,
            'image': img_display
        })

    return results

# ===== CONFIGURATION =====
LORA_MODELS_TO_INSPECT = [
    "lora_r4_light1",
    "lora_r8_medium",
    "lora_r8_qkv_heavy",
    "lora_r16_medium",
]

NUM_SAMPLES = 3  # Number of random samples to inspect
SAMPLE_SEED = 42  # Fixed seed so reruns inspect the same images; set to None for a fresh draw
# =========================

print(f"\nVisual inspection of {len(LORA_MODELS_TO_INSPECT)} LoRA models")
print(f"Testing {NUM_SAMPLES} random samples per model")
print("-" * 80)

# Get random samples (same for all models for fair comparison).
# A dedicated generator keeps the draw reproducible without touching the
# global `random` state; the count is capped so a tiny dataset cannot raise.
sample_indices = random.Random(SAMPLE_SEED).sample(
    range(len(test_ds)), min(NUM_SAMPLES, len(test_ds))
)
print(f"Random sample indices: {sample_indices}")

inspection_results = {}

for model_name in LORA_MODELS_TO_INSPECT:
    print(f"\n{'='*60}")
    print(f"Model: {model_name}")
    print(f"{'='*60}")

    # Bound before the try so the finally-clause can release it on every path.
    model = None
    try:
        # Find the latest run
        run_folder = find_latest_run_for_experiment(model_name, CHECKPOINTS_DIR)
        # Another cell of this notebook treats a falsy result as "no run
        # found"; fail here with a clear message instead of passing it on to
        # the loader.
        if not run_folder:
            raise FileNotFoundError(
                f"No run folder found for {model_name} in {CHECKPOINTS_DIR}"
            )

        # Load model using the new function
        model = load_lora_model_for_inference(run_folder, device=DEVICE)

        # Run predictions
        print(f"Running predictions on {NUM_SAMPLES} samples...")
        results = predict_samples(model, test_ds, sample_indices)

        # Display results
        accuracy = display_predictions(model_name, results, sample_indices)

        # Visualize
        visualize_samples(results, model_name)

        # Store results
        inspection_results[model_name] = {
            'accuracy': accuracy,
            'correct': sum(1 for r in results if r['correct']),
            'total': len(results),
            'samples': sample_indices,
            'run_folder': str(run_folder)
        }

    except Exception as e:
        # Report and continue with the next model.
        print(f"Error during inspection: {e}")

    finally:
        # Clean up on success *and* failure so a crashed inspection does not
        # keep its model resident while the next one loads.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Ranking of the inspected models by spot-check accuracy
if not inspection_results:
    print("\nNo models were successfully inspected.")
else:
    print("\n" + "=" * 80)
    print("VISUAL INSPECTION SUMMARY")
    print("=" * 80)

    # Highest accuracy first; sorted() is stable, so ties keep insertion order
    sorted_models = sorted(
        inspection_results.items(),
        key=lambda entry: entry[1]['accuracy'],
        reverse=True,
    )

    print("\nPerformance Ranking:")
    print("-" * 40)
    for i, (model_name, results) in enumerate(sorted_models):
        print(
            f"{i+1}. {model_name:<25} "
            f"{results['correct']}/{results['total']} ({results['accuracy']:.1%})"
        )

    # First entry after sorting is the best performer
    best_model = sorted_models[0]
    print(f"\nBest in inspection: {best_model[0]} ({best_model[1]['accuracy']:.1%})")

print("\n" + "=" * 80)
print("INSPECTION COMPLETE")
print(f"Models attempted: {len(LORA_MODELS_TO_INSPECT)}")
print(f"Models inspected: {len(inspection_results)}")
print(f"Samples per model: {NUM_SAMPLES}")
print("=" * 80)

## Debug 
### Reload Failed Models from Backup Checkpoint

In [None]:
# Direct FER2013 dataset loading and display
from datasets import load_dataset
import random
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Fetch the three raw splits straight from the Hugging Face hub.
# NOTE(review): this rebinds `train_ds`, `val_ds`, `test_ds` and `EMOTION_LABELS`,
# names that other cells of this notebook also assign — re-run those cells
# before relying on them again.
print("Loading FER2013 dataset directly...")
train_ds, val_ds, test_ds = (
    load_dataset("AutumnQiu/fer2013", split=split_id)
    for split_id in ("train", "valid", "test")
)

print(
    f"Train samples: {len(train_ds)}\n"
    f"Validation samples: {len(val_ds)}\n"
    f"Test samples: {len(test_ds)}"
)

# Fallback label names, indexed by class id (in case not imported)
EMOTION_LABELS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

# How many random samples to show per split, and which splits to walk through
num_samples = 3
splits = [("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)]

for split_name, dataset in splits:
    print(f"\n{'='*70}")
    print(f"{split_name} Split - Random Samples")
    print(f"{'='*70}")

    # Get random indices (capped at the split size)
    sample_indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
    if not sample_indices:
        # An empty split would make plt.subplots(1, 0) raise.
        print("  (no samples to display)")
        continue

    # Size the grid by the number of samples actually drawn, so a split
    # smaller than `num_samples` does not leave blank panels.
    n_shown = len(sample_indices)
    fig, axes = plt.subplots(1, n_shown, figsize=(4*n_shown, 4))
    if n_shown == 1:
        axes = [axes]

    for i, idx in enumerate(sample_indices):
        item = dataset[idx]
        img = item["image"]
        label = item["label"]

        print(f"\nSample {i+1} (Index {idx}):")
        print(f"  Label: {label} ({EMOTION_LABELS[label]})")
        print(f"  Image type: {type(img)}")

        # Convert to PIL if needed
        if isinstance(img, Image.Image):
            img_pil = img
        else:
            # Handle numpy array or tensor
            img_pil = Image.fromarray(np.array(img))

        print(f"  PIL mode: {img_pil.mode}")
        print(f"  PIL size: {img_pil.size}")

        # Display
        axes[i].imshow(img_pil, cmap='gray')
        axes[i].set_title(f"{split_name}\n{EMOTION_LABELS[label]}\nMode: {img_pil.mode}")
        axes[i].axis('off')

    plt.suptitle(f"{split_name} Split - {len(dataset)} samples", fontsize=14)
    plt.tight_layout()
    plt.show()

# Basic properties (mode, size, value range) of the first test samples
print("\n" + "=" * 70)
print("Dataset Properties Investigation")
print("=" * 70)

print("\nFirst 5 test samples:")
for i in range(min(5, len(test_ds))):
    item = test_ds[i]
    img, label = item["image"], item["label"]

    # The array form is needed in both cases for the min/max report
    img_array = np.array(img)
    if isinstance(img, Image.Image):
        mode, size = img.mode, img.size
    else:
        mode = f"Array shape: {img_array.shape}"
        size = f"Array dtype: {img_array.dtype}"

    print(f"  Sample {i}: Label {label} ({EMOTION_LABELS[label]})")
    print(f"    Image: {mode}, {size}")
    if hasattr(img_array, 'shape'):
        print(f"    Min/Max: {img_array.min()}/{img_array.max()}")
    print()

# Look at hand-picked indices one figure at a time
print("\n" + "=" * 70)
print("Specific Sample Inspection (if you have problem indices)")
print("=" * 70)

# Indices to check, e.g. ones that looked wrong elsewhere
problem_indices = [182, 215, 178]
print(f"\nChecking indices: {problem_indices}")

for idx in problem_indices:
    # Skip indices beyond the end of the split
    if not idx < len(test_ds):
        print(f"Index {idx} out of range (test set size: {len(test_ds)})")
        continue

    item = test_ds[idx]
    img, label = item["image"], item["label"]

    # Normalise to a PIL image for display
    img_pil = img if isinstance(img, Image.Image) else Image.fromarray(np.array(img))

    # RGB images are shown as grayscale; every other mode is shown as-is
    if img_pil.mode != 'RGB':
        img_display = img_pil
    else:
        img_display = img_pil.convert('L')
        print(f"Index {idx}: Converted RGB -> Grayscale")

    plt.figure(figsize=(4, 4))
    plt.imshow(img_display, cmap='gray')
    plt.title(f"Index {idx}: {EMOTION_LABELS[label]}\nOriginal: {img_pil.mode}, Display: {img_display.mode}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

### Dataset Debug

In [None]:
# Quick debug: Show de-normalized images
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

# Per-channel ImageNet mean/std, shaped (3, 1, 1) so they broadcast over [3, H, W].
# NOTE(review): assumes base_transform normalizes with these statistics — confirm.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def denormalize(img_tensor):
    """
    Reverse the ImageNet mean/std normalization for display.

    Args:
        img_tensor: Tensor of shape [3, H, W]; it is not modified.

    Returns:
        A new tensor of the same shape, computed as ``x * std + mean`` and
        clamped to the [0, 1] range.
    """
    restored = img_tensor * IMAGENET_STD + IMAGENET_MEAN
    return restored.clamp(0, 1)


# NOTE(review): despite its name, `test_ds` holds the *train* split here and
# replaces any `test_ds` built by other cells — confirm this is intended.
test_ds = FER2013Dataset(transform=base_transform(), split="train")

print("Quick Debug: De-normalized Dataset Images\n" + "=" * 70)

problem_indices = [1253, 417, 1863]  # Your indices

# One panel per index, laid out in a single row
fig, axes = plt.subplots(
    nrows=1,
    ncols=len(problem_indices),
    figsize=(4 * len(problem_indices), 4),
)
if len(problem_indices) == 1:
    axes = [axes]

for i, idx in enumerate(problem_indices):
    img_tensor, true_label = test_ds[idx]

    # Undo the normalization first, then hand the [0, 1] tensor to ToPILImage
    img_tensor_denorm = denormalize(img_tensor)
    img_pil = transforms.ToPILImage()(img_tensor_denorm)

    print(
        f"Index {idx}:\n"
        f"  Tensor shape: {img_tensor.shape}\n"
        f"  PIL mode: {img_pil.mode}"
    )

    axes[i].imshow(img_pil)  # colormap deliberately left unset
    axes[i].set_title(f"Index {idx}\nLabel: {EMOTION_LABELS[true_label]}")
    axes[i].axis('off')

plt.suptitle("De-Normalized Dataset Images", fontsize=12)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np  # used by np.array below; imported here so the cell runs on a fresh kernel
from PIL import Image

# Quick debug without transforms - handle PIL Images
test_ds_no_transform = FER2013Dataset(split="train", transform=None)

print("Quick Debug: Raw Images WITHOUT transforms")
print("=" * 70)

problem_indices = [1253, 417, 1863]
# squeeze=False always returns a 2-D axes array, so a single index needs no special case.
fig, axes = plt.subplots(
    1, len(problem_indices),
    figsize=(4*len(problem_indices), 4),
    squeeze=False,
)
axes = axes[0]

for i, idx in enumerate(problem_indices):
    img, true_label = test_ds_no_transform[idx]

    print(f"Index {idx}:")

    # Handle PIL Image or numpy array
    if isinstance(img, Image.Image):
        print(f"  PIL Image mode: {img.mode}")
        print(f"  PIL Image size: {img.size}")
        img_array = np.array(img)
    else:
        img_array = img
        print(f"  Array shape: {img_array.shape}")
        print(f"  Array type: {type(img_array)}")

    # Display: an H x W x 3 image is shown through its first channel; anything
    # else (already grayscale, or an unknown layout) is passed to imshow as-is.
    is_rgb = len(img_array.shape) == 3 and img_array.shape[2] == 3
    img_to_show = img_array[:, :, 0] if is_rgb else img_array

    axes[i].imshow(img_to_show, cmap='gray')
    axes[i].set_title(f"Index {idx}\nLabel: {EMOTION_LABELS[true_label]}")
    axes[i].axis('off')

plt.suptitle("Images WITHOUT transforms", fontsize=12)
plt.tight_layout()
plt.show()

### LoRA Loader

In [None]:
import torch
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score

from src.evaluate import evaluate_model
from src.lora import load_lora_model
from src.config import DEVICE, EMOTION_LABELS, NUM_LABELS
from src.fer2013 import FER2013Dataset
from src.transforms import base_transform
from transformers import ViTForImageClassification

# Test split, built with the same base transform used elsewhere in this notebook.
test_ds = FER2013Dataset(transform=base_transform(), split="test")

# Where the training runs were written, and the backbone they started from.
# NOTE(review): absolute, machine-specific path — consider deriving it from the project root.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Class index -> emotion name lookup.
EMOTION_DICT = dict(enumerate(EMOTION_LABELS))

# Header describing the evaluation environment.
for header_line in (
    "üîç LoRA Model Evaluation - FIXED VERSION",
    "=" * 80,
    f"Device: {DEVICE}",
    f"Number of labels: {NUM_LABELS}",
    f"Emotion labels: {EMOTION_LABELS}",
    "=" * 80,
):
    print(header_line)

# ===== USER: Specify which LoRA folders to evaluate =====
# Each entry is a run folder name inside CHECKPOINTS_DIR.
LORA_FOLDERS_TO_EVALUATE = [
    "lora_r4_light1",
    # Add more LoRA folder names here
]
# =========================================================

def load_lora_model_correctly(model_folder):
    """
    Load a LoRA fine-tuned ViT from a run folder for evaluation.

    Two artefacts of the run are required:
    1. LoRA adapters (from lora_adapter/)
    2. Classifier weights (from the "*best*.pth" checkpoint)

    The adapters are merged into a fresh base model, the classifier head is
    overwritten with the tensors stored in the checkpoint, and the result is
    moved to DEVICE and put in eval mode.

    Args:
        model_folder: Name of the run folder inside CHECKPOINTS_DIR.

    Returns:
        The merged model, or None when the folder, the checkpoint or the
        adapter directory is missing, or when any loading step raises. If the
        checkpoint holds no classifier tensors a warning is printed and the
        model is returned with its head left untouched.
    """
    folder_path = CHECKPOINTS_DIR / model_folder
    
    if not folder_path.exists():
        print(f"‚ùå Folder not found: {folder_path}")
        return None
    
    print(f"\nüìÅ Loading from: {model_folder}")
    
    # Step 1: Find the .pth checkpoint (has classifier weights).
    # Sorted so the same file is picked on every run when several match.
    pth_files = sorted(folder_path.glob("*best*.pth"))
    if not pth_files:
        print(f"   ‚ùå No .pth checkpoint found")
        return None
    
    pth_path = pth_files[0]
    print(f"   ‚úì Found checkpoint: {pth_path.name}")
    
    # Step 2: Check if lora_adapter folder exists
    lora_adapter_path = folder_path / "lora_adapter"
    if not lora_adapter_path.exists():
        print(f"   ‚ùå lora_adapter folder not found at: {lora_adapter_path}")
        return None
    
    print(f"   ‚úì Found lora_adapter folder")
    
    try:
        # Step 3: Load base model
        print(f"   Loading base ViT model...")
        base_model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        )
        
        # Step 4: Load LoRA adapters
        print(f"   Loading LoRA adapters...")
        model = load_lora_model(
            base_model=base_model,
            lora_adapter_path=lora_adapter_path,
            device='cpu'  # Load to CPU first
        )
        
        # Step 5: Merge LoRA weights into base model
        print(f"   Merging LoRA weights...")
        model = model.merge_and_unload()
        
        # Step 6: Load classifier weights from .pth checkpoint
        print(f"   Loading classifier weights from checkpoint...")
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints produced by your own training runs.
        checkpoint = torch.load(pth_path, map_location='cpu', weights_only=False)
        state_dict = checkpoint['model_state_dict']
        
        # Extract classifier weights
        classifier_weight = state_dict.get('base_model.model.classifier.weight')
        classifier_bias = state_dict.get('base_model.model.classifier.bias')
        
        if classifier_weight is not None and classifier_bias is not None:
            head = model.classifier
            # Refuse tensors of the wrong shape instead of silently swapping
            # in a head that no longer fits the model.
            if (classifier_weight.shape != head.weight.shape
                    or classifier_bias.shape != head.bias.shape):
                raise ValueError(
                    "Classifier shape mismatch: checkpoint "
                    f"{tuple(classifier_weight.shape)}/{tuple(classifier_bias.shape)} vs model "
                    f"{tuple(head.weight.shape)}/{tuple(head.bias.shape)}"
                )
            # Copy in place so the existing parameters are kept.
            with torch.no_grad():
                head.weight.copy_(classifier_weight)
                head.bias.copy_(classifier_bias)
            print(f"   ‚úÖ Classifier weights loaded!")
            print(f"      Weight mean: {classifier_weight.mean().item():.6f}")
            print(f"      Weight std: {classifier_weight.std().item():.6f}")
        else:
            print(f"   ‚ö†Ô∏è  WARNING: Classifier weights not found in checkpoint!")
        
        # Step 7: Move to device and set to eval mode
        model.to(DEVICE)
        model.eval()
        
        print(f"   ‚úÖ Model loaded successfully!")
        return model
        
    except Exception as e:
        print(f"   ‚ùå Error loading LoRA model: {e}")
        import traceback
        traceback.print_exc()
        return None


# Main evaluation loop
#
# For every folder in LORA_FOLDERS_TO_EVALUATE: load the model, run a
# one-batch sanity check, run the full evaluation, then print a comparison
# table and save it as CSV next to the checkpoints.

# Reference accuracy used to flag a gap between training and evaluation.
# NOTE(review): one value is applied to every folder — confirm it holds for
# each run you add to the list.
EXPECTED_TEST_ACC = 0.64  # From your training_parameters.json

# Always start from an empty dict so the summary at the bottom can never
# report results left over from an earlier execution of this cell.
all_results = {}

if LORA_FOLDERS_TO_EVALUATE:
    # The loader does not depend on the model, so build it once.
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)
    
    for lora_folder in LORA_FOLDERS_TO_EVALUATE:
        print(f"\n{'='*80}")
        print(f"Processing: {lora_folder}")
        print(f"{'='*80}")
        
        model = None
        try:
            # Load the model using CORRECT method
            model = load_lora_model_correctly(lora_folder)
            
            if model is None:
                print("‚ùå Failed to load model")
                continue
            
            # Quick sanity check (32 samples)
            print("\nüß™ Running sanity check...")
            
            batch_preds = []
            batch_labels = []
            batch_confidences = []
            
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(DEVICE)
                    outputs = model(pixel_values=images)
                    probs = torch.softmax(outputs.logits, dim=-1)
                    preds = torch.argmax(outputs.logits, dim=-1)
                    
                    batch_preds.extend(preds.cpu().numpy())
                    batch_labels.extend(labels.numpy())
                    batch_confidences.extend(probs.max(dim=-1).values.cpu().numpy())
                    
                    # Only first batch
                    if len(batch_preds) >= 32:
                        break
            
            sanity_acc = accuracy_score(batch_labels, batch_preds)
            avg_confidence = np.mean(batch_confidences)
            
            print(f"   Sanity check accuracy: {sanity_acc:.4f}")
            print(f"   Average confidence: {avg_confidence:.4f}")
            
            # Check prediction distribution: a collapsed model predicts one class only
            unique_preds = np.unique(batch_preds)
            print(f"   Predicting {len(unique_preds)} different classes")
            
            if len(unique_preds) == 1:
                print(f"   ‚ö†Ô∏è  WARNING: Model predicts only class {unique_preds[0]} "
                      f"({EMOTION_DICT.get(unique_preds[0], 'Unknown')})")
            
            # The sanity check is informational only; evaluation continues either way
            if sanity_acc > 0.3:
                print("   ‚úÖ Sanity check passed! Proceeding with full evaluation...")
            else:
                print("   ‚ö†Ô∏è  Sanity check shows low accuracy, but continuing...")
            
            # Run full evaluation
            print("\nüìä Running full evaluation...")
            display_name = lora_folder.replace("_", " ").title()
            metrics = evaluate_model(
                model=model,
                test_dataset=test_ds,
                log_to_wandb=False,
                run_name=display_name
            )
            
            # Store results
            all_results[lora_folder] = {
                'metrics': metrics,
                'sanity_acc': sanity_acc,
                'avg_confidence': avg_confidence
            }
            
            # Print results
            print(f"\n‚úÖ Results for {display_name}:")
            print(f"   ‚Ä¢ Test Accuracy: {metrics.get('accuracy', 0):.4f}")
            print(f"   ‚Ä¢ Precision:     {metrics.get('precision', 0):.4f}")
            print(f"   ‚Ä¢ Recall:        {metrics.get('recall', 0):.4f}")
            print(f"   ‚Ä¢ F1 Score:      {metrics.get('f1', 0):.4f}")
            
            # Check if results match training performance
            actual_acc = metrics.get('accuracy', 0)
            diff = abs(EXPECTED_TEST_ACC - actual_acc)
            
            if diff < 0.05:
                print(f"   ‚úÖ Performance matches training! (difference: {diff:.4f})")
            else:
                print(f"   ‚ö†Ô∏è  Performance differs from training by {diff:.4f}")
            
        except Exception as e:
            print(f"‚ùå Error evaluating {lora_folder}: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up even when evaluation failed, so a broken run does not
            # keep its model in memory while the next folder is processed.
            model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Generate comparison report
    if all_results:
        print(f"\n{'='*80}")
        print("üìä FINAL COMPARISON REPORT")
        print('='*80)
        
        # Create comparison table: one row per successfully evaluated folder
        comparison_data = []
        for folder_name, results in all_results.items():
            metrics = results['metrics']
            comparison_data.append({
                'Model': folder_name.replace("_", " ").title(),
                'Test Accuracy': metrics.get('accuracy', 0),
                'Precision': metrics.get('precision', 0),
                'Recall': metrics.get('recall', 0),
                'F1 Score': metrics.get('f1', 0),
                'Sanity Check': results.get('sanity_acc', 0),
                'Avg Confidence': results.get('avg_confidence', 0)
            })
        
        df_comparison = pd.DataFrame(comparison_data)
        
        # Format for display (the numeric frame is kept for ranking and CSV export)
        df_display = df_comparison.copy()
        for col in ['Test Accuracy', 'Precision', 'Recall', 'F1 Score', 'Sanity Check', 'Avg Confidence']:
            if col in df_display.columns:
                df_display[col] = df_display[col].apply(lambda x: f"{x:.4f}")
        
        print("\n" + df_display.to_string(index=False))
        
        # Find best model
        if len(comparison_data) > 1:
            best_idx = df_comparison['Test Accuracy'].idxmax()
            worst_idx = df_comparison['Test Accuracy'].idxmin()
            
            print(f"\nüèÜ Best model: {df_comparison.loc[best_idx, 'Model']} "
                  f"(Accuracy: {df_comparison.loc[best_idx, 'Test Accuracy']:.4f})")
            print(f"üìâ Worst model: {df_comparison.loc[worst_idx, 'Model']} "
                  f"(Accuracy: {df_comparison.loc[worst_idx, 'Test Accuracy']:.4f})")
        
        # Save results
        csv_path = CHECKPOINTS_DIR / "lora_evaluation_results_fixed.csv"
        df_comparison.to_csv(csv_path, index=False)
        print(f"\nüíæ Results saved to: {csv_path}")
        
        # Check for low performance
        low_performance = df_comparison[df_comparison['Test Accuracy'] < 0.4]
        if not low_performance.empty:
            print("\n‚ö†Ô∏è  WARNING: Some models have low test accuracy:")
            for _, row in low_performance.iterrows():
                print(f"   ‚Ä¢ {row['Model']}: {row['Test Accuracy']:.4f}")
        else:
            print("\n‚úÖ All models performing as expected!")
    
    print(f"\n‚úÖ Evaluation complete. {len(all_results)} models evaluated.")
    
else:
    print("‚ÑπÔ∏è  No LoRA folders specified for evaluation.")

print("\n" + "=" * 80)
print("SUMMARY:")
print(f"Models processed: {len(LORA_FOLDERS_TO_EVALUATE)}")
print(f"Successfully evaluated: {len(all_results)}")
print("=" * 80)

### LoRA Debugger

In [None]:
import torch
from pathlib import Path
import numpy as np
from transformers import ViTForImageClassification
from peft import PeftModel

# Locations of the run under investigation.
# NOTE(review): absolute, machine-specific path — consider deriving it from the project root.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")
LORA_FOLDER = "lora_r4_light1"
LORA_ADAPTER_PATH = CHECKPOINTS_DIR / LORA_FOLDER / "lora_adapter"

# Banner with the resolved adapter location.
for banner_line in (
    "üîç DEBUGGING LoRA MODEL LOADING",
    "=" * 70,
    f"LoRA folder: {LORA_FOLDER}",
    f"Adapter path: {LORA_ADAPTER_PATH}",
    f"Path exists: {LORA_ADAPTER_PATH.exists()}",
    "=" * 70,
):
    print(banner_line)

# 1. List the adapter files with their sizes, or report the missing directory.
print("\nüìÑ Checking files in lora_adapter directory:")
if not LORA_ADAPTER_PATH.exists():
    print(f"‚ùå lora_adapter directory does not exist!")
else:
    for entry in LORA_ADAPTER_PATH.iterdir():
        entry_size = entry.stat().st_size
        print(f"  ‚Ä¢ {entry.name} (Size: {entry_size:,} bytes)")

# 2. Baseline: a base model without the adapter, scored on a single batch.
print("\nüß™ Step 1: Testing BASE MODEL performance (should be random)")
base_model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True,
    num_labels=NUM_LABELS,
)
base_model = base_model.to(DEVICE)
base_model.eval()

# Loader shared with the later steps of this cell (fixed order, 32 per batch).
test_loader = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=32)

all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = base_model(pixel_values=images)
        preds = outputs.logits.argmax(dim=-1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # One batch is enough for this check.
        break

from sklearn.metrics import accuracy_score
base_acc = accuracy_score(all_labels, all_preds)
print(f"Base model accuracy on first batch: {base_acc:.4f}")
pred_histogram = np.bincount(all_preds, minlength=NUM_LABELS)
label_histogram = np.bincount(all_labels, minlength=NUM_LABELS)
print(f"Predictions distribution: {pred_histogram}")
print(f"True labels distribution: {label_histogram}")

# Drop the baseline before the LoRA variants are loaded.
del base_model
torch.cuda.empty_cache()

# 3. Try loading LoRA with different methods
# Method 1 goes through the project's own load_lora_model helper, reports the
# parameter counts, lists the first LoRA parameters and scores one batch.
print("\nüß™ Step 2: Testing LoRA loading methods")

try:
    print("\nMethod 1: Using load_lora_model from your src.lora")
    from src.lora import load_lora_model
    
    base_model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True
    )
    
    lora_model = load_lora_model(
        base_model=base_model,
        lora_adapter_path=str(LORA_ADAPTER_PATH),
        device=DEVICE
    )
    
    # Check if weights are actually trainable
    trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in lora_model.parameters())
    
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable percentage: {(trainable_params/total_params)*100:.2f}%")
    
    # Check a few specific LoRA parameters.
    # Collect the LoRA parameters first and slice the list: the cap must apply
    # to the number of parameters shown, not to how many exist in the model.
    print("\nChecking specific LoRA parameter names and sizes:")
    max_lora_params_shown = 10
    lora_named_params = [
        (name, param)
        for name, param in lora_model.named_parameters()
        if 'lora' in name.lower()
    ]
    for name, param in lora_named_params[:max_lora_params_shown]:
        print(f"  {name}: {param.shape}, requires_grad: {param.requires_grad}")
        if param.requires_grad:
            print(f"    Mean: {param.data.mean().item():.6f}, Std: {param.data.std().item():.6f}")
    if len(lora_named_params) > max_lora_params_shown:
        print("  ... (showing first 10 LoRA parameters)")
    
    # Quick test on the first batch; eval() turns dropout off so the
    # predictions are deterministic.
    lora_model.eval()
    with torch.no_grad():
        all_preds = []
        all_labels = []
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            
            outputs = lora_model(pixel_values=images)
            preds = torch.argmax(outputs.logits, dim=-1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            break
    
    lora_acc = accuracy_score(all_labels, all_preds)
    print(f"\nLoRA model accuracy on first batch: {lora_acc:.4f}")
    print(f"Predictions distribution: {np.bincount(all_preds, minlength=NUM_LABELS)}")
    
    del lora_model, base_model
    
except Exception as e:
    print(f"‚ùå Error with Method 1: {e}")
    import traceback
    traceback.print_exc()

# 4. Print the adapter configuration, leaving out the bookkeeping fields.
print("\nüìã Step 3: Checking adapter_config.json")
config_path = LORA_ADAPTER_PATH / "adapter_config.json"
if config_path.exists():
    import json
    with config_path.open('r') as config_file:
        config = json.load(config_file)
    print("LoRA Configuration:")
    skipped_fields = ('peft_type', 'task_type', 'inference_mode')
    visible_settings = [
        (field, setting) for field, setting in config.items()
        if field not in skipped_fields
    ]
    for field, setting in visible_settings:
        print(f"  {field}: {setting}")

# 5. Try loading directly with PeftModel
# Loads the adapter without the project helper and runs two checks on one
# batch: (a) does the adapter change the output compared with the adapter
# switched off, and (b) does merging the adapter keep the predictions.
print("\nüß™ Step 4: Loading directly with PeftModel")
try:
    base_model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True
    )
    
    print(f"Loading from: {LORA_ADAPTER_PATH}")
    # `device` is not a documented from_pretrained parameter, so place the
    # model explicitly; the inputs below are sent to DEVICE as well.
    lora_model = PeftModel.from_pretrained(
        base_model,
        LORA_ADAPTER_PATH
    )
    lora_model.to(DEVICE)
    
    print("‚úÖ Loaded with PeftModel.from_pretrained")
    
    # Check if it's actually a Peft model
    print(f"Is PeftModel instance: {isinstance(lora_model, PeftModel)}")
    
    # Count trainable params
    trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Test inference
    lora_model.eval()
    with torch.no_grad():
        # Get a single batch
        test_sample = next(iter(test_loader))
        images, labels = test_sample
        images = images.to(DEVICE)
        
        # Test without merging
        outputs = lora_model(pixel_values=images)
        preds = torch.argmax(outputs.logits, dim=-1)
        print(f"\nSingle batch predictions: {preds.cpu().numpy()[:10]}")
        
        # Reference run with the adapter switched off. This has to happen
        # before merging, because merging folds the adapter into the weights.
        with lora_model.disable_adapter():
            logits_without_adapter = lora_model(pixel_values=images).logits
        adapter_changes_output = not torch.allclose(
            outputs.logits, logits_without_adapter, atol=1e-5
        )
        
        # Test with merging
        print("\nTesting with merged weights...")
        merged_model = lora_model.merge_and_unload()
        outputs_merged = merged_model(pixel_values=images)
        preds_merged = torch.argmax(outputs_merged.logits, dim=-1)
        print(f"Merged model predictions: {preds_merged.cpu().numpy()[:10]}")
        
        # (a) Adapter on vs. off tells whether LoRA has any effect.
        if adapter_changes_output:
            print("‚úì Outputs change when the adapter is disabled (LoRA is being applied)")
        else:
            print("‚ö†Ô∏è WARNING: Outputs are identical with the adapter disabled (LoRA might not be applied)")
        
        # (b) Merging is meant to be equivalent, so matching predictions are
        # the healthy outcome here.
        if torch.equal(preds, preds_merged):
            print("‚úì Merged and unmerged predictions match (merge is consistent)")
        else:
            print("‚ö†Ô∏è WARNING: Merged predictions differ from the unmerged model")
    
    del lora_model, merged_model, base_model
    
except Exception as e:
    print(f"‚ùå Error with direct PeftModel loading: {e}")
    import traceback
    traceback.print_exc()

# 6. Compare against the accuracy recorded during training, if a history file exists.
print("\nüìä Step 5: Checking training history")
history_path = CHECKPOINTS_DIR / LORA_FOLDER / f"history_{LORA_FOLDER}.json"
if history_path.exists():
    import json
    with history_path.open('r') as history_file:
        history = json.load(history_file)

    # Only report when a non-empty validation curve was recorded.
    has_val_curve = 'val_acc' in history and history['val_acc']
    if has_val_curve:
        best_val_acc = max(history['val_acc'])
        print(f"Best validation accuracy during training: {best_val_acc:.4f}")
        final_train_acc = history['train_acc'][-1] if 'train_acc' in history else 'N/A'
        print(f"Final training accuracy: {final_train_acc}")

        # A decent validation score points at loading rather than training.
        trained_well = not (best_val_acc < 0.2)
        if trained_well:
            print("‚úÖ Model trained well, but loading is the issue")
        else:
            print("‚ö†Ô∏è WARNING: Model had poor accuracy during training too!")

# 7. Check the .pth file
# Look at the same checkpoint the loader cell uses ("*best*.pth"), and fall
# back to any .pth file. Sorting keeps the choice stable between runs.
print("\nüì¶ Step 6: Checking .pth checkpoint")
run_dir = CHECKPOINTS_DIR / LORA_FOLDER
pth_files = sorted(run_dir.glob("*best*.pth")) or sorted(run_dir.glob("*.pth"))
if pth_files:
    pth_file = pth_files[0]
    print(f"Found .pth file: {pth_file.name} ({pth_file.stat().st_size:,} bytes)")
    
    # Try loading it (weights_only=True refuses arbitrary pickled objects)
    try:
        checkpoint = torch.load(pth_file, map_location='cpu', weights_only=True)
        print(f"Checkpoint keys: {list(checkpoint.keys())}")
        
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            print(f"State dict keys (first 10): {list(state_dict.keys())[:10]}")
            
            # Check for LoRA weights
            lora_keys = [k for k in state_dict.keys() if 'lora' in k.lower()]
            print(f"Number of LoRA keys in .pth: {len(lora_keys)}")
            if lora_keys:
                print(f"Sample LoRA keys: {lora_keys[:5]}")
    except Exception as e:
        print(f"‚ùå Error loading .pth file: {e}")
else:
    print(f"‚ùå No .pth file found in: {run_dir}")

print("\n" + "=" * 70)
print("üí° RECOMMENDED NEXT STEPS:")
print("1. Check if LoRA weights were actually saved during training")
print("2. Verify the adapter_config.json has correct settings")
print("3. Try loading with merge_and_unload() for inference")
print("4. Check if base model architecture matches LoRA training")
print("=" * 70)

#### LoRA Diagnostic Cell

In [None]:
# ============================================================================
# CHECKPOINT INSPECTOR: See what's actually saved in your LoRA checkpoint
# ============================================================================

import torch
from pathlib import Path
import json

# Run folder to inspect.
# NOTE(review): absolute, machine-specific path — consider deriving it from the project root.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")
LORA_FOLDER = "lora_r4_light1"

folder_path = CHECKPOINTS_DIR / LORA_FOLDER

print("üîç Inspecting LoRA Checkpoint")
print("=" * 80)
print(f"Folder: {LORA_FOLDER}")
print("=" * 80)

# 1. Check what files exist
print("\n1Ô∏è‚É£ Files in checkpoint folder:")
if folder_path.is_dir():
    for item in folder_path.iterdir():
        if item.is_file():
            size_mb = item.stat().st_size / (1024**2)
            print(f"   üìÑ {item.name} ({size_mb:.2f} MB)")
        elif item.is_dir():
            print(f"   üìÅ {item.name}/")
else:
    # Report the problem instead of raising from iterdir(); the later
    # sections check for their own files before reading them.
    print(f"   ‚ùå Folder not found: {folder_path}")

# 2. Inspect the .pth checkpoint
# Try the expected file name first, then fall back to the "*best*.pth"
# pattern the loader cell uses, so a differently named checkpoint is still
# inspected instead of this section being skipped without a word.
pth_file = folder_path / "best_lora_r4_light.pth"
if not pth_file.exists():
    best_candidates = sorted(folder_path.glob("*best*.pth"))
    if best_candidates:
        pth_file = best_candidates[0]

if pth_file.exists():
    print(f"\n2Ô∏è‚É£ Inspecting {pth_file.name}:")
    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only inspect checkpoints produced by your own training runs.
    checkpoint = torch.load(pth_file, map_location='cpu', weights_only=False)
    
    print(f"\n   Checkpoint keys: {list(checkpoint.keys())}")
    
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        
        print(f"\n   Total keys in state_dict: {len(state_dict)}")
        
        # Categorize keys
        lora_keys = [k for k in state_dict.keys() if 'lora' in k.lower()]
        classifier_keys = [k for k in state_dict.keys() if 'classifier' in k.lower()]
        other_keys = [k for k in state_dict.keys() if 'lora' not in k.lower() and 'classifier' not in k.lower()]
        
        print(f"\n   üìä Key categories:")
        print(f"   - LoRA keys: {len(lora_keys)}")
        print(f"   - Classifier keys: {len(classifier_keys)}")
        print(f"   - Other keys: {len(other_keys)}")
        
        if classifier_keys:
            print(f"\n   ‚úÖ Classifier keys found:")
            for key in classifier_keys:
                tensor = state_dict[key]
                print(f"   - {key}: shape={tensor.shape}, mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}")
        else:
            print(f"\n   ‚ùå NO classifier keys found in checkpoint!")
        
        if lora_keys:
            print(f"\n   ‚úÖ Sample LoRA keys (first 5):")
            for key in lora_keys[:5]:
                tensor = state_dict[key]
                print(f"   - {key}: shape={tensor.shape}")
        
        # Check for signs of training
        print(f"\n   üîç Checking if weights were actually trained...")
        
        # Check classifier weights
        if 'base_model.model.classifier.weight' in state_dict:
            classifier_weight = state_dict['base_model.model.classifier.weight']
            print(f"   Classifier weight stats:")
            print(f"   - Mean: {classifier_weight.mean().item():.6f}")
            print(f"   - Std: {classifier_weight.std().item():.6f}")
            print(f"   - Min: {classifier_weight.min().item():.6f}")
            print(f"   - Max: {classifier_weight.max().item():.6f}")
            
            # Check if it's close to random initialization
            # (heuristic: mean near 0 and std near 0.02)
            if abs(classifier_weight.mean().item()) < 0.01 and abs(classifier_weight.std().item() - 0.02) < 0.01:
                print(f"   ‚ö†Ô∏è  WARNING: Classifier weights look like random initialization!")
        
        # Check a LoRA weight
        if lora_keys:
            sample_lora = state_dict[lora_keys[0]]
            print(f"\n   Sample LoRA weight stats ({lora_keys[0]}):")
            print(f"   - Mean: {sample_lora.mean().item():.6f}")
            print(f"   - Std: {sample_lora.std().item():.6f}")
    
    if 'val_acc' in checkpoint:
        print(f"\n   üìà Training metrics from checkpoint:")
        print(f"   - Validation accuracy: {checkpoint['val_acc']:.4f}")
        print(f"   - Validation loss: {checkpoint.get('val_loss', 'N/A')}")
        print(f"   - Epoch: {checkpoint.get('epoch', 'N/A')}")
else:
    print(f"\n2Ô∏è‚É£ No .pth checkpoint found in: {folder_path}")

# 3. Inspect the lora_adapter folder: its config, its weight files, and
# whether the saved adapter contains any classifier tensors.
lora_adapter_path = folder_path / "lora_adapter"
if lora_adapter_path.exists():
    print(f"\n3Ô∏è‚É£ Inspecting lora_adapter folder:")
    
    # Adapter configuration: only the main LoRA settings are shown.
    config_file = lora_adapter_path / "adapter_config.json"
    if config_file.exists():
        with config_file.open('r') as config_handle:
            adapter_config = json.load(config_handle)
        
        print(f"\n   adapter_config.json:")
        for setting in ('r', 'lora_alpha', 'target_modules', 'lora_dropout'):
            print(f"   - {setting}: {adapter_config.get(setting)}")
    
    # Weight files, safetensors listed ahead of .bin.
    adapter_files = [
        *lora_adapter_path.glob("*.safetensors"),
        *lora_adapter_path.glob("*.bin"),
    ]
    if adapter_files:
        print(f"\n   Adapter weight files:")
        for weight_file in adapter_files:
            size_mb = weight_file.stat().st_size / (1024**2)
            print(f"   - {weight_file.name} ({size_mb:.2f} MB)")
        
        # Load the first file listed and summarise its tensors.
        first_file = adapter_files[0]
        try:
            if first_file.suffix == '.safetensors':
                from safetensors.torch import load_file
                adapter_weights = load_file(str(first_file))
            else:
                adapter_weights = torch.load(first_file, map_location='cpu')
            
            print(f"\n   Adapter weights:")
            print(f"   - Total keys: {len(adapter_weights)}")
            print(f"   - Sample keys (first 3):")
            for tensor_name in list(adapter_weights.keys())[:3]:
                print(f"     ‚Ä¢ {tensor_name}: {adapter_weights[tensor_name].shape}")
            
            # Does the adapter file carry the classification head?
            classifier_in_adapter = [
                tensor_name for tensor_name in adapter_weights.keys()
                if 'classifier' in tensor_name.lower()
            ]
            if not classifier_in_adapter:
                print(f"\n   ‚ùå NO classifier weights in adapter!")
                print(f"   This means the classifier was never saved with LoRA adapters!")
            else:
                print(f"\n   ‚úÖ Classifier weights in adapter: {len(classifier_in_adapter)}")
        
        except Exception as e:
            print(f"   ‚ö†Ô∏è  Could not load adapter weights: {e}")

# 4. Check training_parameters.json
def _format_param(value, spec):
    """Apply a numeric format spec to numbers; return anything else (such as the 'N/A' placeholder) as plain text."""
    if isinstance(value, (int, float)):
        return format(value, spec)
    return str(value)


params_file = folder_path / "training_parameters.json"
if params_file.exists():
    print(f"\n4Ô∏è‚É£ Training parameters:")
    with open(params_file, 'r') as f:
        params = json.load(f)
    
    # Numeric format specs raise ValueError on the 'N/A' fallback string, so
    # the values go through _format_param instead of being formatted inline.
    print(f"   - Trainable params: {_format_param(params.get('trainable_params', 'N/A'), ',')}")
    print(f"   - Total params: {_format_param(params.get('total_params', 'N/A'), ',')}")
    print(f"   - Trainable %: {_format_param(params.get('trainable_percentage', 'N/A'), '.4f')}%")
    print(f"   - Best val accuracy: {params.get('final_metrics', {}).get('best_val_accuracy', 'N/A')}")

# Closing banner and the checklist for reading the output above.
diagnosis_rule = "=" * 80
print("\n" + diagnosis_rule)
print("üìä DIAGNOSIS")
print(diagnosis_rule)

diagnosis_notes = (
    "\nüí° Key findings:",
    "   1. Check if classifier keys exist in .pth checkpoint",
    "   2. Check if classifier keys exist in lora_adapter/",
    "   3. Compare trainable params from training vs what was saved",
    "\n   If classifier is missing from lora_adapter/, that's your problem!",
    "   The LoRA adapters only save LoRA weights, not the classifier.",
    "   The classifier must be saved separately in the .pth checkpoint.",
)
for note in diagnosis_notes:
    print(note)

### Resume Training from Last Backup

In [None]:
# Resume Training from Last Backup
from src.backup import resume_training
import json
from pathlib import Path


# NOTE(review): absolute, machine-specific path — consider deriving it from the project root.
CHECKPOINTS_DIR = Path("C:/Users/rayrc/OneDrive/Documents/ML/Emotion Classifier ViT/checkpoints")

# Run folders (inside CHECKPOINTS_DIR) whose training should be continued.
MODELS_TO_RESUME = [
    "baseline_heavy",
]

# Augmentation keywords looked up in the folder name, in priority order.
_TRANSFORM_KEYWORDS = ("heavy", "medium", "light")

for model_folder in MODELS_TO_RESUME:
    print(f"\n{'='*70}")
    print(f"Resuming: {model_folder}")
    print(f"{'='*70}")
    
    try:
        run_folder = CHECKPOINTS_DIR / model_folder
        
        # Load training parameters to get original settings
        params_path = run_folder / "training_parameters.json"
        with open(params_path, 'r') as f:
            training_params = json.load(f)
        
        # Create fresh model and datasets.
        # Use the notebook-wide MODEL_NAME / NUM_LABELS / DEVICE settings
        # rather than literals, so this cell follows the configuration and
        # also runs on a machine without CUDA.
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        ).to(DEVICE)
        
        # Determine transform from the folder name; "none" when no keyword matches
        folder_lower = model_folder.lower()
        transform_key = next(
            (keyword for keyword in _TRANSFORM_KEYWORDS if keyword in folder_lower),
            "none",
        )
        
        # NOTE(review): transform_configs must already be defined by an
        # earlier cell — confirm it is executed before this one.
        transform = transform_configs[transform_key]
        
        train_ds = FER2013Dataset(split="train", transform=transform)
        val_ds = FER2013Dataset(split="valid", transform=base_transform())
        
        optimizer = AdamW(
            model.parameters(), 
            lr=training_params['learning_rate'],
            weight_decay=training_params['optimizer_params']['weight_decay']
        )
        
        # Resume training
        model_resumed, history, new_run_folder = resume_training(
            run_folder=run_folder,
            model=model,
            optimizer=optimizer,
            train_dataset=train_ds,
            val_dataset=val_ds,
            num_epochs=training_params['num_epochs'], 
            batch_size=training_params['batch_size'],
            device=str(DEVICE),  # a plain string, as before
            model_name=f"resumed_{model_folder}",
            use_wandb=False
        )
        
        print(f"Successfully resumed: {model_folder}")
        print(f"New run folder: {new_run_folder}")
        
    except Exception as e:
        print(f"Failed to resume {model_folder}: {e}")
        # Show where it failed, as the other cells in this notebook do.
        import traceback
        traceback.print_exc()