# Google ViT
## Initialization

In [1]:
import sys
from pathlib import Path

# Put the project root on sys.path so the `src` package imports below resolve
# whether the kernel was started in the project root or in a `notebooks` folder.
notebook_dir = Path.cwd()
if notebook_dir.name == 'notebooks':
    project_root = notebook_dir.parent
else:
    project_root = notebook_dir

# Insert only once so re-running the cell does not pile up duplicate entries.
root_entry = str(project_root)
if root_entry not in sys.path:
    sys.path.insert(0, root_entry)


import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification
import matplotlib.pyplot as plt
import random
from torchvision import transforms
import wandb 
from src.dataset import FER2013Dataset, get_datasets
from src.config import (
    DEVICE, 
    NUM_LABELS, 
    EMOTION_LABELS,
    DEFAULT_BATCH_SIZE,
    DEFAULT_NUM_EPOCHS,
    DEFAULT_LEARNING_RATE,
    CHECKPOINTS_DIR,
    RESULTS_DIR
)
from src.train import train_model
from src.evaluate import (
    evaluate_model,
    print_classification_report,
)

# Identifier of the pretrained checkpoint; the processor and model loading
# cell below passes this name to `from_pretrained`.
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Report the compute device imported from src.config.
print(f"Using device: {DEVICE}")

Using device: cuda


In [2]:
import wandb
from src.config import WANDB_API_KEY

# NOTE(review): WANDB_API_KEY is imported from src.config — confirm it is read
# from the environment (or a secrets store) rather than hard-coded in the repo.
print("Initializing Weights & Biases...")

# wandb.login returns a bool indicating whether a key was configured, so report
# the real outcome instead of printing a success message unconditionally.
wandb_logged_in = wandb.login(key=WANDB_API_KEY)

if wandb_logged_in:
    print("W&B initialized successfully!")
else:
    print("W&B login was not confirmed; check WANDB_API_KEY before training with use_wandb=True.")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\rayrc\_netrc


Initializing Weights & Biases...


[34m[1mwandb[0m: Currently logged in as: [33mraycaringal[0m ([33mraycaringal-university-of-texas-austin[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B initialized successfully!


In [3]:
print("Loading model and processor...")

# Image processor that matches the pretrained checkpoint.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Classification model sized for NUM_LABELS output classes.
# ignore_mismatched_sizes=True lets loading continue when a checkpoint tensor's
# shape differs from the corresponding tensor in this model.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME, num_labels=NUM_LABELS, ignore_mismatched_sizes=True
)

print("Model loaded!")

# Total element count across every parameter tensor of the model.
param_count = sum(tensor.numel() for tensor in model.parameters())
print(f"Number of parameters: {param_count:,}")

Loading model and processor...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded!
Number of parameters: 85,804,039


In [4]:
print("Loading datasets...")

# get_datasets() yields the three splits in (train, validation, test) order.
train_ds, val_ds, test_ds = get_datasets()

# Report the number of samples in each split.
split_sizes = [
    f"Train size: {len(train_ds)}",
    f"Val size: {len(val_ds)}",
    f"Test size: {len(test_ds)}",
]
print("\n".join(split_sizes))

Loading datasets...
Train size: 28709
Val size: 3589
Test size: 3589


---
## Fine-Tuning Section
Fine-tune the pretrained ViT on the FER2013 dataset.

In [5]:
# AdamW over all model parameters, using the default learning rate from src.config.
optimizer = AdamW(model.parameters(), lr=DEFAULT_LEARNING_RATE)

# Echo the training hyperparameters so they are captured in the cell output.
hyperparameter_lines = [
    "Optimizer: AdamW",
    f"Learning rate: {DEFAULT_LEARNING_RATE}",
    f"Batch size: {DEFAULT_BATCH_SIZE}",
    f"Epochs: {DEFAULT_NUM_EPOCHS}",
]
print("\n".join(hyperparameter_lines))

Optimizer: AdamW
Learning rate: 2e-05
Batch size: 32
Epochs: 10


In [None]:

# Single source of truth for the epoch count and the run's model name, so the
# values passed to train_model always match what is recorded in the W&B config.
# (Previously num_epochs was hard-coded to 3 while the config logged
# DEFAULT_NUM_EPOCHS, so the tracked hyperparameters could disagree with the run.)
num_epochs = DEFAULT_NUM_EPOCHS
run_model_name = "vit_base_patch16_224"

# train_model returns the model together with the training history.
model, history = train_model(
    model=model,
    optimizer=optimizer,
    train_dataset=train_ds,
    val_dataset=val_ds,
    num_epochs=num_epochs,
    batch_size=DEFAULT_BATCH_SIZE,
    device=DEVICE,
    model_name=run_model_name,
    use_wandb=True,
    wandb_config={
        "learning_rate": DEFAULT_LEARNING_RATE,
        "batch_size": DEFAULT_BATCH_SIZE,
        "epochs": num_epochs,
        "model_name": run_model_name,
        "architecture": "ViT",
        "dataset": "FER2013"
    }
)

Training vit_base_patch16_224 for 10 epochs...
Total training steps: 8980
Device: cuda
Batch size: 32
Train batches: 898
Val batches: 113
Best model will be saved to: c:\Users\rayrc\OneDrive\Documents\ML\Emotion Classifier ViT\checkpoints\best_vit_base_patch16_224.pth
W&B tracking: https://wandb.ai/raycaringal-university-of-texas-austin/emotion-classification/runs/y0ji4rit

Epoch 1/10
----------------------------------------------------------------------


Training Epoch 0:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 0:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 1.3355 | Train Acc: 0.5111 | Train F1: 0.4926
Val Loss:   1.0213 | Val Acc:   0.6330 | Val F1:   0.6212
✓ New best model saved! (Val Acc: 0.6330, Val F1: 0.6212)

Epoch 2/10
----------------------------------------------------------------------


Training Epoch 1:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 1:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.8943 | Train Acc: 0.6824 | Train F1: 0.6738
Val Loss:   0.9759 | Val Acc:   0.6425 | Val F1:   0.6247
✓ New best model saved! (Val Acc: 0.6425, Val F1: 0.6247)

Epoch 3/10
----------------------------------------------------------------------


Training Epoch 2:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 2:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.7137 | Train Acc: 0.7550 | Train F1: 0.7510
Val Loss:   0.9065 | Val Acc:   0.6768 | Val F1:   0.6761
✓ New best model saved! (Val Acc: 0.6768, Val F1: 0.6761)

Epoch 4/10
----------------------------------------------------------------------


Training Epoch 3:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 3:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.5471 | Train Acc: 0.8250 | Train F1: 0.8238
Val Loss:   0.9268 | Val Acc:   0.6902 | Val F1:   0.6826
✓ New best model saved! (Val Acc: 0.6902, Val F1: 0.6826)

Epoch 5/10
----------------------------------------------------------------------


Training Epoch 4:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 4:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.3853 | Train Acc: 0.8876 | Train F1: 0.8872
Val Loss:   0.9580 | Val Acc:   0.7030 | Val F1:   0.6999
✓ New best model saved! (Val Acc: 0.7030, Val F1: 0.6999)

Epoch 6/10
----------------------------------------------------------------------


Training Epoch 5:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 5:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.2485 | Train Acc: 0.9372 | Train F1: 0.9371
Val Loss:   1.0230 | Val Acc:   0.6971 | Val F1:   0.6964

Epoch 7/10
----------------------------------------------------------------------


Training Epoch 6:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 6:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.1588 | Train Acc: 0.9651 | Train F1: 0.9651
Val Loss:   1.0939 | Val Acc:   0.6969 | Val F1:   0.6962

Epoch 8/10
----------------------------------------------------------------------


Training Epoch 7:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 7:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.1075 | Train Acc: 0.9791 | Train F1: 0.9791
Val Loss:   1.1653 | Val Acc:   0.6999 | Val F1:   0.6998

Epoch 9/10
----------------------------------------------------------------------


Training Epoch 8:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 8:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.0791 | Train Acc: 0.9858 | Train F1: 0.9858
Val Loss:   1.1892 | Val Acc:   0.6980 | Val F1:   0.6982

Epoch 10/10
----------------------------------------------------------------------


Training Epoch 9:   0%|          | 0/898 [00:00<?, ?it/s]

Validating Epoch 9:   0%|          | 0/113 [00:00<?, ?it/s]


Train Loss: 0.0619 | Train Acc: 0.9894 | Train F1: 0.9894
Val Loss:   1.2120 | Val Acc:   0.6932 | Val F1:   0.6932

Training completed!
Best validation accuracy: 0.7030
Training history saved to: c:\Users\rayrc\OneDrive\Documents\ML\Emotion Classifier ViT\checkpoints\history_vit_base_patch16_224.json


0,1
batch,▁▂▃▅▁▄▅█▁▂▅▆▇▄▆▁▂▃▄▅█▁▁▂▃▇█▂▃▇▄▆▃▆▆█▁▂▅▆
batch_train_loss,█▇▆▆▅▄▄▄▄▃▃▄▃▃▄▂▃▂▂▂▃▂▂▂▂▁▁▁▂▁▂▁▁▁▁▁▂▁▁▂
epoch,▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇█
learning_rate,▁▂▇██▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▁
train_accuracy,▁▄▅▆▇▇████
train_f1,▁▄▅▆▇▇████
train_loss,█▆▅▄▃▂▂▁▁▁
train_precision,▁▃▅▆▇▇████
train_recall,▁▄▅▆▇▇████
val_accuracy,▁▂▅▇█▇▇█▇▇

0,1
batch,850
batch_train_loss,0.02003
epoch,9
learning_rate,0
train_accuracy,0.98945
train_f1,0.98944
train_loss,0.06193
train_precision,0.98945
train_recall,0.98945
val_accuracy,0.69323


---
# Evaluation
Evaluate the trained model on the test set.

In [None]:
# Test-set evaluation of the trained model, with W&B logging enabled under the
# run name given below. The returned `metrics` feed the report and summary cells.
metrics = evaluate_model(
    model=model,
    test_dataset=test_ds,
    device=DEVICE,
    batch_size=DEFAULT_BATCH_SIZE,
    run_name="vit_base_patch16_224_final",
    log_to_wandb=True,
)

In [None]:
# Print the detailed classification report for the `metrics` object returned by
# evaluate_model in the evaluation cell above.
print_classification_report(metrics)

---
# Test Predictions
Let's visualize some predictions from the trained model.

In [None]:
# Visualize random predictions from test set
def predict_and_visualize(dataset, index, model, processor):
    """Predict the emotion for one dataset sample and display the result.

    Prints the predicted label with its confidence, the true label, and the
    highest-probability classes (up to three), then shows the image in a
    matplotlib figure titled with both labels.

    Args:
        dataset: Indexable dataset whose items unpack as ``(image, label)``.
            The image must be accepted by ``transforms.ToPILImage`` and the
            label must be convertible with ``int()`` (for example a Python int
            or a single-element tensor).
        index: Position of the sample inside ``dataset``.
        model: Classifier whose forward pass returns an object exposing
            ``.logits``. It is switched to eval mode and moved to ``DEVICE``.
        processor: Image processor invoked as
            ``processor(images=..., return_tensors="pt")``; its output must be
            a mapping of tensors.

    Returns:
        tuple: ``(true_label, pred_label, confidence)``. ``true_label`` is
        returned exactly as the dataset produced it, ``pred_label`` is the
        arg-max class index as a Python int, and ``confidence`` is that
        class's softmax probability as a Python float.
    """
    img, true_label = dataset[index]
    img_pil = transforms.ToPILImage()(img)

    # Plain-int copy used only for label lookup, so the lookup behaves the same
    # whether the dataset yields a Python int or a tensor, and whether
    # EMOTION_LABELS is a list or an int-keyed mapping.
    true_idx = int(true_label)

    # Run model
    model.eval()
    model.to(DEVICE)
    inputs = processor(images=img_pil, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process: the batch holds a single image, so take row 0 of the softmax.
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    pred_label = torch.argmax(probs).item()
    confidence = probs[pred_label].item()

    # Visualize
    print(f"Predicted Label: {EMOTION_LABELS[pred_label]} (Confidence: {confidence:.2%})")
    print(f"True Label:      {EMOTION_LABELS[true_idx]}")

    # Show top 3 predictions (clamped so models with fewer than 3 classes work).
    top_k = min(3, probs.numel())
    top3_probs, top3_idx = torch.topk(probs, top_k)
    print("\nTop 3 Predictions:")
    for rank, (prob, idx) in enumerate(zip(top3_probs, top3_idx), start=1):
        # .item() converts the 0-dim tensors to Python scalars, matching how
        # pred_label and confidence are handled above.
        print(f"  {rank}. {EMOTION_LABELS[idx.item()]}: {prob.item():.2%}")

    plt.figure(figsize=(6, 6))
    plt.imshow(img_pil, cmap='gray')
    plt.title(f"Predicted: {EMOTION_LABELS[pred_label]}\nTrue: {EMOTION_LABELS[true_idx]}")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

    return true_label, pred_label, confidence


print("Testing predictions AFTER training:\n")

num_samples = 5
SAMPLE_SEED = 42  # fixed seed so the same test images are shown on every run

# Draw distinct indices from a dedicated, seeded generator: the cell is then
# reproducible under Restart & Run All, never shows the same image twice, and
# leaves the global `random` state untouched.
sample_rng = random.Random(SAMPLE_SEED)
sample_indices = sample_rng.sample(range(len(test_ds)), k=min(num_samples, len(test_ds)))

for i, idx in enumerate(sample_indices):
    print(f"\n{'='*70}")
    print(f"Sample {i+1}/{len(sample_indices)}")
    print('='*70)
    true, pred, conf = predict_and_visualize(test_ds, idx, model, processor)

In [None]:
# Save results summary
import json

# Collect the aggregate test metrics plus the per-class F1 scores into one
# JSON-serialisable dict; float() converts each metric value to a plain Python
# float so json can serialise it.
results_summary = {
    'model_name': MODEL_NAME,
    'test_accuracy': float(metrics['accuracy']),
    'test_precision': float(metrics['precision']),
    'test_recall': float(metrics['recall']),
    'test_f1': float(metrics['f1']),
    'per_class_f1': {
        emotion: float(f1) 
        for emotion, f1 in zip(EMOTION_LABELS, metrics['f1_per_class'])
    }
}

# Make sure the output folder exists so a fresh checkout does not fail on open().
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
summary_path = RESULTS_DIR / "vit_results_summary.json"

# Explicit encoding keeps the written file independent of the platform default.
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2)

print(f"Results summary saved to: {summary_path}")
print("\nFinal Results:")
print(json.dumps(results_summary, indent=2))