# Google ViT 
## Initialization

In [None]:
# Cell 0 
import sys
from pathlib import Path

# Work out the repository root so that `src.*` imports resolve whether the
# kernel was started in the repo root or inside its `notebooks/` directory.
notebook_dir = Path.cwd()
if notebook_dir.name == 'notebooks':
    project_root = notebook_dir.parent
else:
    project_root = notebook_dir

# Put the root at the front of the import path, but only once, so re-running
# this cell does not keep growing sys.path.
if sys.path.count(str(project_root)) == 0:
    sys.path.insert(0, str(project_root))

from transformers import ViTForImageClassification
import random
from torchvision import transforms
import albumentations as A
from src.transforms import base_transform
from src.dataset import FER2013Dataset
from src.config import (
    DEVICE, 
    NUM_LABELS, 
    EMOTION_LABELS,
    DEFAULT_BATCH_SIZE,
    DEFAULT_LEARNING_RATE
)
from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from src.train import train_model

# Identifier of the pretrained ViT checkpoint; passed to `from_pretrained`
# when the model is created in the training cell.
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Show which device `src.config` exposes as DEVICE.
print(f"Using device: {DEVICE}")

### Weights and Biases 

In [None]:
# Cell 1 
from src.wandb_utils import login, check_wandb_mode, sync_offline_runs

# "online", "offline", or "disabled"
# If set to "offline", don't forget to sync the runs afterwards
WANDB_MODE = "online"

print("Initializing Weights & Biases...")

# login() is given the project name and the requested mode; whatever it
# returns is treated as the active mode string and echoed upper-cased below.
current_mode = login(project="emotion-classifier-vit", mode=WANDB_MODE)

print(f"W&B initialized successfully in {current_mode.upper()} mode!")

In [None]:

# Weights and Biases utility commands.
#
# The helpers are imported by name instead of with `import *`, so it is clear
# where each one comes from and nothing else in the notebook namespace can be
# silently shadowed by the module's other globals.
# NOTE(review): the names below are the ones this cell calls or mentions in its
# commented-out commands -- confirm they are all defined in src/wandb_utils.py.
from src.wandb_utils import (
    check_wandb_mode,
    clear_offline_runs,
    list_offline_runs,
    set_wandb_mode,
    sync_offline_runs,
)

# Check current mode
# check_wandb_mode()

# Sync offline runs (when you have internet)
sync_offline_runs(all_runs=True)

# List available offline runs
# list_offline_runs()

# Change mode
# set_wandb_mode("online")

# Set Confirm to False for a Dry Run
# clear_offline_runs(confirm=True)


---
##  Fine Tuning Section
Using FER2013 dataset.

### Transformations

In [None]:

# Augmentation presets of increasing strength, written without deprecated
# albumentations parameters. Every augmented preset ends with the shared base
# transforms from src/transforms.py.


def _with_base(*augmentations):
    """Return an A.Compose of `augmentations` followed by the base transforms."""
    return A.Compose([*augmentations, *base_transform()])


transform_configs = {
    # No augmentation: just the base transforms from transforms.py
    "none": base_transform(),

    "light": _with_base(
        A.HorizontalFlip(p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
        A.Affine(translate_percent=0.05, scale=(0.95, 1.05), rotate=(-10, 10), p=0.3),
    ),

    "medium": _with_base(
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.Affine(translate_percent=0.1, scale=(0.9, 1.1), rotate=(-15, 15), p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    ),

    "heavy": _with_base(
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Affine(translate_percent=0.15, scale=(0.85, 1.15), rotate=(-20, 20), p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.4),
        A.GridDropout(ratio=0.1, p=0.3),
    ),
}

print("Transformation Configs Loaded")

### Hyper Parameter Queue

In [None]:
# Experiment queue: one baseline run per augmentation preset. All runs share
# the same epochs, learning rate, batch size and weight decay, so only the
# transform key (and the name derived from it) varies.
EPOCHS = 6

experiment_configs = [
    {
        "name": f"baseline_{transform_key}",
        "transform_key": transform_key,
        "epochs": EPOCHS,
        "learning_rate": DEFAULT_LEARNING_RATE,
        "batch_size": DEFAULT_BATCH_SIZE,
        "weight_decay": 0.01,
    }
    for transform_key in ("none", "light", "medium", "heavy")
]

print(f"{len(experiment_configs)} Experiment Configs Loaded")

### Training Loop

In [None]:
from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from src.wandb_utils import cleanup_wandb_run

import gc  # used by the per-experiment cleanup in the `finally` clause below

# Fine-tune one ViT per entry in `experiment_configs`, one after another.
# Successful runs are recorded in `all_results` (keyed by experiment name);
# failures are recorded in `failed_experiments` and do not stop the queue.
# A KeyboardInterrupt stops the queue after cleaning up the current run.
all_results = {}
failed_experiments = []

print(f"Starting training for {len(experiment_configs)} experiments")
print("=" * 70)

for i, config in enumerate(tqdm(experiment_configs, desc="Training Experiments")):
    print(f"\n{'='*70}")
    print(f"🔬 Experiment {i+1}/{len(experiment_configs)}: {config['name']}")
    print(f"   Transform: {config['transform_key']}")
    print(f"   LR: {config['learning_rate']}")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch Size: {config['batch_size']}")
    print(f"{'='*70}")

    # Ensure any previous WandB run is cleaned up
    cleanup_wandb_run()

    try:
        # Training data uses the experiment's augmentation preset; validation
        # always uses the un-augmented base transforms.
        transform = transform_configs[config['transform_key']]

        train = FER2013Dataset(
            split="train",
            transform=transform
        )
        valid = FER2013Dataset(
            split="valid",
            transform=base_transform()
        )

        # Fresh pretrained model for every experiment
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        ).to(DEVICE)

        optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        print(f"✅ Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")

        # Train model
        model_exp, history_exp, run_folder_exp = train_model(
            model=model,
            optimizer=optimizer,
            train_dataset=train,
            val_dataset=valid,
            num_epochs=config['epochs'],
            batch_size=config['batch_size'],
            device=DEVICE,
            model_name=config['name'],
            use_wandb=True,
            wandb_config={
                "learning_rate": config['learning_rate'],
                "batch_size": config['batch_size'],
                "epochs": config['epochs'],
                "weight_decay": config['weight_decay'],
                "model_name": "vit_base_patch16_224",
                "architecture": "ViT",
                "dataset": "FER2013",
                "transform_set": config['transform_key'],
                "experiment_name": config['name']
            }
        )

        # Park the trained weights on the CPU before keeping a reference.
        # `all_results` holds on to the model object, so the `del` statements
        # in `finally` only drop the loop's names; left on the GPU, every
        # finished model would stay resident and memory would grow with each
        # experiment.
        if hasattr(model_exp, "to"):
            model_exp = model_exp.to("cpu")

        # Store results
        all_results[config['name']] = {
            'model': model_exp,
            'history': history_exp,
            'run_folder': run_folder_exp,
            'config': config,
            'best_val_accuracy': max(history_exp['val_acc']),
            'best_val_loss': min(history_exp['val_loss']),
            'final_train_accuracy': history_exp['train_acc'][-1],
            'final_train_loss': history_exp['train_loss'][-1]
        }

        print(f"\n COMPLETED: {config['name']}")
        print(f"   Best Val Accuracy: {all_results[config['name']]['best_val_accuracy']:.4f}")
        print(f"   Best Val Loss: {all_results[config['name']]['best_val_loss']:.4f}")
        print(f"   Run folder: {run_folder_exp}")

    except KeyboardInterrupt:
        print(f"\n  Training interrupted by user at experiment: {config['name']}")
        cleanup_wandb_run()
        break

    except Exception as e:
        print(f"\n ERROR in experiment {config['name']}: {str(e)}")
        print(f"   Exception type: {type(e).__name__}")

        # Store failed experiment info
        failed_experiments.append({
            'name': config['name'],
            'error': str(e),
            'error_type': type(e).__name__
        })

        # Clean up WandB
        cleanup_wandb_run()

        # A failed experiment does not stop the queue
        print(f"   Continuing to next experiment...")

    finally:
        # Drop this iteration's references whether it succeeded or not. At
        # notebook top level locals() is the module namespace, so the
        # membership tests guard names that were never bound (early failure).
        if 'model' in locals():
            del model
        if 'model_exp' in locals():
            del model_exp
        if 'optimizer' in locals():
            del optimizer
        if 'train' in locals():
            del train
        if 'valid' in locals():
            del valid

        # Collect first so unreachable tensors are actually released, then
        # hand PyTorch's cached CUDA blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"   Memory cleaned up")

# Final cleanup
cleanup_wandb_run()

# Print summary
print("\n" + "="*70)
print(" TRAINING COMPLETE - SUMMARY")
print("="*70)
print(f" Successful experiments: {len(all_results)}/{len(experiment_configs)}")
print(f" Failed experiments: {len(failed_experiments)}/{len(experiment_configs)}")

if all_results:
    print("\n Results:")
    for name, result in all_results.items():
        print(f"   {name}: Val Acc = {result['best_val_accuracy']:.4f}, Val Loss = {result['best_val_loss']:.4f}")

if failed_experiments:
    print("\n  Failed Experiments:")
    for failed in failed_experiments:
        print(f"   {failed['name']}: {failed['error_type']} - {failed['error']}")

print("\n" + "="*70)

---
## Evaluation
### Metrics

In [None]:
# Cell 9: Independent evaluation (can run after kernel restart)
from src.evaluate import evaluate_all_saved_models
from src.dataset import FER2013Dataset
from src.transforms import base_transform
import matplotlib.pyplot as plt

print("🧪 Starting INDEPENDENT evaluation of all saved models...")

# Load the held-out test split with the un-augmented base transforms.
# `test_ds` is also read by the prediction cell further down.
test_ds = FER2013Dataset(
    split="test",
    transform=base_transform()
)

print(f"Test dataset size: {len(test_ds)}")

# Evaluate all saved models (no need for all_results in memory)
summary_data = evaluate_all_saved_models(test_ds)

# Only report success when something was actually evaluated; previously the
# success banner was printed even when the "no models" branch followed it.
if summary_data:
    print("\n✅ All saved models evaluated and summarized!")
    print("📊 Performance plot saved to: experiment_performance_comparison.png")

    # NOTE(review): the first entry is treated as the best model -- confirm
    # that evaluate_all_saved_models returns results sorted best-first.
    best_exp = summary_data[0]
    print(f"\n🏆 Best model: {best_exp['experiment']}")
    print(f"   Test Accuracy: {best_exp['test_accuracy']:.4f}")
    print(f"   Transform: {best_exp['transform']}")
    print(f"   Run Folder: {best_exp['run_folder']}")
else:
    print("❌ No models were successfully evaluated")

In [None]:
# Cell 9A: Evaluate specific experiments using your experiment_configs
from src.evaluate import evaluate_from_experiment_configs
from src.dataset import FER2013Dataset
from src.transforms import base_transform
import matplotlib.pyplot as plt

print("🧪 Evaluating specific experiments from config...")

# Load the held-out test split with the un-augmented base transforms.
test_ds = FER2013Dataset(
    split="test",
    transform=base_transform()
)

print(f"Test dataset size: {len(test_ds)}")

# Evaluate using your experiment_configs (finds latest runs automatically)
summary_data = evaluate_from_experiment_configs(experiment_configs, test_ds)

# Report success only when there are results, and say so explicitly when there
# are none (mirrors the "evaluate all saved models" cell).
if summary_data:
    print("\n✅ Specific experiments evaluated!")
    print("📊 Performance plot saved to: experiment_performance_comparison.png")

    # NOTE(review): the first entry is treated as the best model -- confirm
    # that evaluate_from_experiment_configs returns results sorted best-first.
    best_exp = summary_data[0]
    print(f"\n🏆 Best model: {best_exp['experiment']}")
    print(f"   Run: {best_exp['run_name']}")
    print(f"   Test Accuracy: {best_exp['test_accuracy']:.4f}")
    print(f"   Transform: {best_exp['transform']}")
else:
    print("❌ No models were successfully evaluated")

---
##  Test Predictions
Let's visualize some predictions from the trained model.

In [None]:
# Visualize predictions from multiple models
from src.metadata import find_latest_run_for_experiment, load_training_parameters
from src.checkpoint_utils import load_model_from_checkpoint
from transformers import ViTForImageClassification, ViTImageProcessor
import random
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

# Look for checkpoints under the repository root computed in the first cell
# instead of an absolute, machine-specific path, so the notebook also runs on
# other machines and checkouts.
# NOTE(review): assumes the `checkpoints/` folder sits directly under the
# project root -- confirm against where train_model writes its run folders.
CHECKPOINTS_DIR = project_root / "checkpoints"

# Experiment names whose latest run is loaded and compared below.
MODELS_TO_TEST = [
    "baseline_none",
    "baseline_light",
]

NUM_SAMPLES = 3  # Number of random test samples per model

def predict_and_visualize_batch(dataset, indices, model, processor, model_name):
    """Run `model` on selected samples of `dataset` and collect the predictions.

    Args:
        dataset: Indexable dataset; `dataset[idx]` must yield `(image, label)`.
            The image is converted with `transforms.ToPILImage()`.
        indices: Iterable of sample indices to evaluate.
        model: Classifier called as `model(**inputs)`; its output must expose
            `.logits`. It is switched to eval mode and moved to DEVICE.
        processor: Called as `processor(images=..., return_tensors="pt")` and
            expected to return a mapping of tensors.
        model_name: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A list with one dict per index, holding `sample_index`, `true_label`
        (int), `pred_label` (int), `confidence` (float), `correct` (bool),
        `top3_predictions` (list of `(label_name, probability)` pairs, at most
        three) and `image` (the PIL image that was fed to the processor).
    """
    indices = list(indices)
    if not indices:
        return []

    # Loop-invariant setup: do it once instead of once per sample.
    model.eval()
    model.to(DEVICE)
    to_pil = transforms.ToPILImage()

    results = []
    for sample_idx in indices:
        img, true_label = dataset[sample_idx]
        # Normalise to a plain int so the equality below yields a real bool
        # even if the dataset returns the label as a 0-d tensor.
        true_label = int(true_label)
        img_pil = to_pil(img)

        # NOTE(review): `img` has already been through the dataset's own
        # transform and is preprocessed again by `processor` here -- confirm
        # that this double preprocessing matches how the model was trained.
        inputs = processor(images=img_pil, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process: class probabilities for the single image in the batch
        probs = torch.softmax(outputs.logits, dim=-1)[0]
        pred_label = torch.argmax(probs).item()
        confidence = probs[pred_label].item()

        # Top predictions, clamped so models with fewer than 3 classes work too
        top_k = min(3, probs.numel())
        top3_probs, top3_idx = torch.topk(probs, top_k)
        top3_predictions = [
            (EMOTION_LABELS[class_idx.item()], prob.item())
            for prob, class_idx in zip(top3_probs, top3_idx)
        ]

        results.append({
            'sample_index': sample_idx,
            'true_label': true_label,
            'pred_label': pred_label,
            'confidence': confidence,
            'correct': true_label == pred_label,
            'top3_predictions': top3_predictions,
            'image': img_pil
        })

    return results

def display_model_predictions(model_name, results, sample_indices):
    """Print and plot the predictions of one model, and return its batch accuracy.

    Args:
        model_name: Label used in the printed header and the figure title.
        results: List of per-sample dicts with the keys `sample_index`,
            `true_label`, `pred_label`, `confidence`, `correct`,
            `top3_predictions` and `image` (the format produced by
            `predict_and_visualize_batch`).
        sample_indices: The indices that were tested; only printed.

    Returns:
        float: fraction of entries in `results` whose `correct` flag is truthy,
        or 0.0 when `results` is empty (nothing is plotted in that case).
    """
    print(f"\n{'='*70}")
    print(f"Model: {model_name}")
    print(f"{'='*70}")

    if not results:
        # Nothing to score or draw; bail out before dividing by zero or asking
        # matplotlib for a grid with zero columns.
        print("No predictions to display.")
        return 0.0

    correct_count = sum(1 for r in results if r['correct'])
    accuracy = correct_count / len(results)

    print(f"Batch Accuracy: {correct_count}/{len(results)} ({accuracy:.1%})")
    print(f"Samples tested: {sample_indices}")
    print()

    for i, result in enumerate(results):
        print(f"Sample {i+1} (Index {result['sample_index']}):")
        print(f"  True: {EMOTION_LABELS[result['true_label']]:<12}", end="")
        print(f"  Predicted: {EMOTION_LABELS[result['pred_label']]:<12}", end="")
        print(f"  Confidence: {result['confidence']:.1%}", end="")
        print(f"  {'✓' if result['correct'] else '✗'}")

        # Top predictions on one line; join() gets the separators right for
        # any number of entries, not only exactly three.
        top_summary = ", ".join(
            f"{emotion}: {prob:.1%}" for emotion, prob in result['top3_predictions']
        )
        print(f"  Top 3: {top_summary}")

    # One panel per sample. squeeze=False always returns a 2-D array of axes,
    # so a single sample needs no special case.
    fig, axes_grid = plt.subplots(
        1, len(results), figsize=(4*len(results), 4), squeeze=False
    )

    for i, (result, ax) in enumerate(zip(results, axes_grid[0])):
        ax.imshow(result['image'], cmap='gray')
        # Title colour encodes whether the prediction was right
        color = 'green' if result['correct'] else 'red'
        title = f"Sample {i+1}\n"
        title += f"True: {EMOTION_LABELS[result['true_label']]}\n"
        title += f"Pred: {EMOTION_LABELS[result['pred_label']]}\n"
        title += f"Conf: {result['confidence']:.1%}"
        ax.set_title(title, color=color, fontsize=10)
        ax.axis('off')

    plt.suptitle(f"Model: {model_name} (Accuracy: {accuracy:.1%})", fontsize=12)
    plt.tight_layout()
    plt.show()

    return accuracy

# Load processor (same for all models). MODEL_NAME is the checkpoint id defined
# in the first cell, so the processor cannot drift from the model being used.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Draw the sample indices once so every model is scored on the same images.
# `test_ds` comes from the evaluation cell above. The count is clamped because
# random.sample raises ValueError when asked for more items than exist.
sample_indices = random.sample(
    range(len(test_ds)),
    min(NUM_SAMPLES, len(test_ds))
)

print(f"Testing {len(MODELS_TO_TEST)} models on {len(sample_indices)} random samples each")
print(f"Test dataset size: {len(test_ds)}")
print()
print(f"Random sample indices: {sample_indices}")

# Filled by the loop below: model name -> accuracy / counts / run folder
model_results = {}

import traceback  # used to print the full stack when a model fails below

# Load the latest run of each requested experiment, score it on the shared
# sample indices and record the outcome in `model_results`. A model that
# cannot be loaded or evaluated is reported and skipped.
for model_name in MODELS_TO_TEST:
    try:
        print(f"\n{'='*70}")
        print(f"Loading model: {model_name}")
        print(f"{'='*70}")

        # Find the latest run for this model
        run_folder = find_latest_run_for_experiment(model_name, CHECKPOINTS_DIR)
        if run_folder is None:
            # NOTE(review): guards the case where the lookup reports "nothing
            # found" by returning None -- confirm the helper's contract.
            print(f"No run folder found for: {model_name}")
            continue

        # The best checkpoint is named after its run folder
        checkpoint_path = run_folder / f"best_{run_folder.name}.pth"

        if not checkpoint_path.exists():
            print(f"Checkpoint not found: {checkpoint_path}")
            continue

        # Load model
        model = load_model_from_checkpoint(checkpoint_path)

        # Get model info
        params = load_training_parameters(run_folder)
        print(f"Loaded: {run_folder.name}")
        print(f"Transform: {model_name.split('_')[-1]}")
        print(f"Epochs: {params.get('num_epochs', 'N/A')}")

        # Only apply the scientific-notation format to real numbers; applying
        # ':.2e' to the 'N/A' fallback string raises ValueError.
        learning_rate = params.get('learning_rate', 'N/A')
        if isinstance(learning_rate, (int, float)):
            print(f"Learning rate: {learning_rate:.2e}")
        else:
            print(f"Learning rate: {learning_rate}")

        # Run predictions
        results = predict_and_visualize_batch(
            dataset=test_ds,
            indices=sample_indices,
            model=model,
            processor=processor,
            model_name=model_name
        )

        # Display results
        accuracy = display_model_predictions(model_name, results, sample_indices)
        model_results[model_name] = {
            'accuracy': accuracy,
            'correct': sum(1 for r in results if r['correct']),
            'total': len(results),
            'run_folder': run_folder.name
        }

    except Exception as e:
        print(f"Failed to test {model_name}: {e}")
        traceback.print_exc()

# Summary: rank the tested models by accuracy on the shared samples and draw
# a bar chart of the ranking.
print(f"\n{'='*70}")
print("SUMMARY: Model Comparison")
print(f"{'='*70}")

if not model_results:
    print("No models were successfully tested.")
else:
    # Highest accuracy first
    sorted_results = sorted(
        model_results.items(),
        key=lambda entry: entry[1]['accuracy'],
        reverse=True
    )

    print("\nPerformance Ranking:")
    for i, (model_name, result) in enumerate(sorted_results):
        print(f"{i+1}. {model_name:<20} {result['correct']}/{result['total']} ({result['accuracy']:.1%})")

    # First and last entries of the ranking
    best_model, worst_model = sorted_results[0], sorted_results[-1]

    print(f"\nBest: {best_model[0]} ({best_model[1]['accuracy']:.1%})")
    print(f"Worst: {worst_model[0]} ({worst_model[1]['accuracy']:.1%})")

    # Bar chart of the ranking, one bar per model
    models = [name for name, _ in sorted_results]
    accuracies = [stats['accuracy'] for _, stats in sorted_results]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(models, accuracies, color=['green', 'lightgreen', 'orange', 'red'])
    ax.set(
        xlabel='Model',
        ylabel='Accuracy',
        title=f'Model Comparison on {NUM_SAMPLES} Samples',
        ylim=(0, 1),
    )

    # Write each accuracy just above its bar
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        label_x = bar.get_x() + bar.get_width()/2
        ax.text(label_x, height + 0.02, f'{acc:.1%}', ha='center', va='bottom')

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

## Debug 
### Reload Failed Models from Backup Checkpoint

In [None]:
# Resume Training from Last Backup
from src.backup import resume_training
import json
from pathlib import Path


import traceback  # used to print the full stack when a resume attempt fails

# Look for checkpoints under the repository root computed in the first cell
# instead of an absolute, machine-specific path.
# NOTE(review): assumes the `checkpoints/` folder sits directly under the
# project root -- confirm against where train_model writes its run folders.
CHECKPOINTS_DIR = project_root / "checkpoints"

# Run-folder names (relative to CHECKPOINTS_DIR) to resume from.
MODELS_TO_RESUME = [
    "baseline_heavy",
]

for model_folder in MODELS_TO_RESUME:
    print(f"\n{'='*70}")
    print(f"Resuming: {model_folder}")
    print(f"{'='*70}")

    try:
        run_folder = CHECKPOINTS_DIR / model_folder

        # Load training parameters to get original settings
        params_path = run_folder / "training_parameters.json"
        with open(params_path, 'r') as f:
            training_params = json.load(f)

        # Build the model exactly like the training cell does (same
        # checkpoint id, label count and device) rather than with hard-coded
        # copies, so this also works on machines without CUDA.
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True
        ).to(DEVICE)

        # Pick the augmentation preset from the folder name: the first of
        # heavy / medium / light that appears in it, otherwise "none".
        folder_lower = model_folder.lower()
        transform_key = next(
            (key for key in ("heavy", "medium", "light") if key in folder_lower),
            "none"
        )

        transform = transform_configs[transform_key]

        train_ds = FER2013Dataset(split="train", transform=transform)
        val_ds = FER2013Dataset(split="valid", transform=base_transform())

        optimizer = AdamW(
            model.parameters(),
            lr=training_params['learning_rate'],
            weight_decay=training_params['optimizer_params']['weight_decay']
        )

        # Resume training
        model_resumed, history, new_run_folder = resume_training(
            run_folder=run_folder,
            model=model,
            optimizer=optimizer,
            train_dataset=train_ds,
            val_dataset=val_ds,
            num_epochs=training_params['num_epochs'],
            batch_size=training_params['batch_size'],
            device=DEVICE,
            model_name=f"resumed_{model_folder}",
            use_wandb=False
        )

        print(f"Successfully resumed: {model_folder}")
        print(f"New run folder: {new_run_folder}")

    except Exception as e:
        # Keep going with the next folder, but show where the failure came
        # from instead of only its message.
        print(f"Failed to resume {model_folder}: {e}")
        traceback.print_exc()