In [None]:
#FER 2013

!pip install -U gdown

import gdown

# Google Drive file ID from the shareable link
file_id = "1FxpKWPlrDARxR_Po9Nyf34rQS9Vt0DZn"
# Generate the direct download URL
url = f"https://drive.google.com/uc?id={file_id}"

# Download the file
gdown.download(url, quiet=False)

!unzip /content/fer2013.csv.zip



Downloading...
From (original): https://drive.google.com/uc?id=1FxpKWPlrDARxR_Po9Nyf34rQS9Vt0DZn
From (redirected): https://drive.google.com/uc?id=1FxpKWPlrDARxR_Po9Nyf34rQS9Vt0DZn&confirm=t&uuid=b41441ac-14ee-4560-9bc2-def3e679b054
To: /content/fer2013.csv.zip
100%|██████████| 101M/101M [00:02<00:00, 37.0MB/s] 


Archive:  /content/fer2013.csv.zip
  inflating: fer2013.csv             


In [None]:
# !pip install -q pandas numpy torch torchvision torchaudio transformers scikit-learn tqdm Pillow datasets accelerate sentencepiece # Added datasets, accelerate, sentencepiece

import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
# Use AutoModel for flexibility, ViT specific if sticking only to ViT variants
from transformers import (
    ViTModel, ViTConfig, ViTFeatureExtractor,
    AutoImageProcessor, AutoModelForImageClassification # Use Auto classes for ensemble
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from tqdm.notebook import tqdm # Use tqdm.notebook for Colab/Jupyter
# from tqdm import tqdm # Use this for standard Python scripts
import random
from collections import Counter
import math
import time
import gc # Garbage collector
import warnings

# Ignore specific warnings
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
warnings.filterwarnings("ignore", "Using the model-agnostic default `max_length`")


# --- Configuration ---
DATA_PATH = '/content/fer2013.csv'
NUM_CLASSES = 7
# Main model's expected input size
MAIN_IMG_SIZE = 224
BATCH_SIZE = 32 # Adjust based on GPU memory
EPOCHS = 5 # Adjust as needed, start small
LR = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Ensemble Configuration ---
# !! Critical: Replace these with ACTUAL diverse pre-trained FER models if possible !!
# Using readily available image classification models for demonstration.
# Ensure these models are compatible with the number of classes (or adapt output)
# For FER2013 (7 classes), models pretrained on ImageNet (1000 classes) need adaptation
# or models specifically fine-tuned on FER2013 should be used.
# Let's assume we *found* models fine-tuned on FER2013 for this example:
ENSEMBLE_MODEL_NAMES = [
    "trpakov/vit-face-expression", # Example ViT fine-tuned on FER
    "nateraw/vit-base-patch16-224-fer", # Example ViT fine-tuned on FER
    "samhitaambati/vit-base-patch16-224-in21k-finetuned-fer", # Example ViT fine-tuned
    "Rajaram1996/FacialEmoRecog", # Example based on ResNet, might need input adjustments
    # Add more diverse models if available
]
NUM_ENSEMBLE_MODELS = len(ENSEMBLE_MODEL_NAMES)
# Require strict majority if >= 4 models, otherwise simple majority
ENSEMBLE_AGREEMENT_THRESHOLD = math.ceil(NUM_ENSEMBLE_MODELS / 2) if NUM_ENSEMBLE_MODELS < 4 else math.floor(NUM_ENSEMBLE_MODELS / 2) + 1
print(f"Ensemble: {NUM_ENSEMBLE_MODELS} models, Threshold: {ENSEMBLE_AGREEMENT_THRESHOLD}")

# For MoE (part of the main model)
NUM_EXPERTS = 4
TOP_K_EXPERTS = 2

# For Attention (part of the main model)
NUM_ATTENTION_HEADS = 8
# Will be determined by the main ViT model loaded later
# LATENT_DIM = 768 # Base ViT latent dimension (placeholder)


# --- Utility Functions ---
def parse_pixels(pixel_string):
    try:
        pixels = np.array(pixel_string.split(), dtype='uint8')
        img = pixels.reshape(48, 48)
        pil_img = Image.fromarray(img).convert('L') # Start as grayscale PIL
        return pil_img
    except Exception as e:
        print(f"Error parsing pixel string: {e}")
        return Image.new('L', (48, 48)) # Return dummy black image

def clear_gpu_memory():
    torch.cuda.empty_cache()
    gc.collect()

# --- 1. Data Loading ---
print("--- Loading Data ---")
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"Error: Dataset not found at {DATA_PATH}. Please upload/update path.")

df = pd.read_csv(DATA_PATH)
print(f"Dataset loaded: {len(df)} samples")
print(df['Usage'].value_counts())
emotion_map = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}

# --- 2. Load Ensemble Models ---
print("\n--- Loading Ensemble Models ---")

def load_ensemble_models(model_names, device):
    models = []
    processors = []
    loaded_names = []
    for name in tqdm(model_names, desc="Loading Ensemble Models"):
        try:
            print(f"Loading {name}...")
            processor = AutoImageProcessor.from_pretrained(name, cache_dir='./hf_cache')
            model = AutoModelForImageClassification.from_pretrained(
                name,
                num_labels=NUM_CLASSES, # Assume models are fine-tuned or specify ignore_mismatched_sizes=True
                ignore_mismatched_sizes=True, # Important if loading ImageNet models directly
                cache_dir='./hf_cache'
            )
            model.eval()
            model.to(device)
            models.append(model)
            processors.append(processor)
            loaded_names.append(name)
            print(f"Loaded {name} successfully.")
            clear_gpu_memory() # Clear cache after loading each model
        except Exception as e:
            print(f"Warning: Failed to load ensemble model {name}. Error: {e}. Skipping.")
            clear_gpu_memory()
    if not models:
        raise RuntimeError("Could not load any ensemble models. Check model names and connectivity.")
    print(f"Successfully loaded {len(models)} ensemble models: {loaded_names}")
    return models, processors, loaded_names

# Load the actual ensemble models
ensemble_models, ensemble_processors, ensemble_model_names = load_ensemble_models(ENSEMBLE_MODEL_NAMES, DEVICE)
# Update actual number of models loaded
NUM_ENSEMBLE_MODELS = len(ensemble_models)
ENSEMBLE_AGREEMENT_THRESHOLD = math.ceil(NUM_ENSEMBLE_MODELS / 2) if NUM_ENSEMBLE_MODELS < 4 else math.floor(NUM_ENSEMBLE_MODELS / 2) + 1
print(f"Using {NUM_ENSEMBLE_MODELS} models for ensemble voting. Threshold: {ENSEMBLE_AGREEMENT_THRESHOLD}")


# --- 3. Ensemble Inference & Refinement ---
print("\n--- Refining ALL Data with Actual Ensemble Voting ---")
print("--- This will take a significant amount of time! ---")

# Dataset specifically for ensemble inference (minimal processing)
class EnsembleInferenceDataset(Dataset):
    def __init__(self, df, processors):
        self.df = df
        self.processors = processors # List of processors for the ensemble

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = parse_pixels(row['pixels']).convert('RGB') # Ensure RGB for most models
        original_label = row['emotion']

        # Process image for each model - assumes processors have similar basic needs (resize, tensor)
        # NOTE: If processors have vastly different needs, this needs adjustment
        processed_images = []
        for processor in self.processors:
             try:
                 inputs = processor(images=img, return_tensors="pt", do_rescale=True) # Check if rescale needed
                 processed_images.append(inputs['pixel_values'].squeeze(0))
             except Exception as e:
                 print(f"Error processing image {idx} with a processor: {e}. Using zeros.")
                 # Find expected size from processor if possible, else use default
                 try:
                    size = processor.size['height'] if isinstance(processor.size, dict) else processor.size
                 except:
                    size = 224 # Fallback size
                 processed_images.append(torch.zeros((3, size, size)))


        return processed_images, original_label

# Create DataLoader for ensemble inference
# Use a smaller batch size for inference if GPU memory is limited
inference_batch_size = BATCH_SIZE // 2 if DEVICE.type == 'cuda' else BATCH_SIZE
inference_dataset = EnsembleInferenceDataset(df, ensemble_processors)
inference_loader = DataLoader(inference_dataset, batch_size=inference_batch_size, shuffle=False, num_workers=2, pin_memory=True if DEVICE.type == 'cuda' else False)

all_original_labels = []
all_refined_labels = []
mismatch_count = 0

inference_start_time = time.time()

# Run inference with each model
model_predictions = [[] for _ in range(NUM_ENSEMBLE_MODELS)]

with torch.no_grad():
    for model_idx, model in enumerate(ensemble_models):
        print(f"\nRunning inference with Model {model_idx+1}/{NUM_ENSEMBLE_MODELS} ({ensemble_model_names[model_idx]})")
        current_model_preds = []
        processor = ensemble_processors[model_idx] # Get the processor for this model
        # We need a way to feed the correctly processed image for *this* model
        # Let's re-process within the loop (less efficient but safer if processors differ)

        temp_dataset = EnsembleInferenceDataset(df, [processor]) # Dataset with only this model's processor
        temp_loader = DataLoader(temp_dataset, batch_size=inference_batch_size, shuffle=False, num_workers=2, pin_memory=True if DEVICE.type == 'cuda' else False)


        for batch_images_list, batch_labels in tqdm(temp_loader, desc=f"Model {model_idx+1} Inference"):
             # batch_images_list contains list of tensors, one per image. Need the one processed for this model.
             # Since temp_dataset has only one processor, batch_images_list[0] corresponds to it.
             # We need to stack the tensors correctly if batch_images_list contains single element lists per item
             # The dataloader should batch correctly if __getitem__ returns the tensor directly
             # Let's adjust EnsembleInferenceDataset slightly

             # ** Re-thinking EnsembleInferenceDataset for this loop **
             # It's better to process images just-in-time within the loop or have the dataset return raw PIL images

             # ** Approach 2: Process Just-In-Time **
             # Let EnsembleInferenceDataset return PIL images & labels
             # class EnsembleInferenceDataset(Dataset):
             #     def __init__(self, df): self.df = df
             #     def __len__(self): return len(self.df)
             #     def __getitem__(self, idx):
             #         row = self.df.iloc[idx]
             #         img = parse_pixels(row['pixels']).convert('RGB')
             #         return img, row['emotion']
             # inference_dataset = EnsembleInferenceDataset(df)
             # inference_loader = DataLoader(...) # collate_fn might be needed if returning PIL

             # Let's stick to the previous dataset structure for now, assuming processors are similar enough
             # The dataloader should give us batches like: [ ( [proc1_img1, proc2_img1,...], label1 ), ( [proc1_img2, proc2_img2,...], label2 ), ... ]
             # We need to extract the correct processed image for the current model_idx

             # ** Correcting Batch Handling **
             # The DataLoader batches the outputs of __getitem__.
             # So, batch_images will be a list (size NUM_ENSEMBLE_MODELS) where each element is a batch of tensors [B, C, H, W]
             # batch_labels will be a tensor [B]

             # *** Let's simplify EnsembleInferenceDataset again ***
             # Make it return PIL images, process in the loop - cleaner if processors differ significantly

            clear_gpu_memory() # Clear before batch

            inputs = batch_images_list[0].to(DEVICE) # Get the batch processed for model_idx
            labels = batch_labels.to(DEVICE) # Original labels for reference if needed

            outputs = model(inputs)
            predictions = torch.argmax(outputs.logits, dim=1)
            model_predictions[model_idx].extend(predictions.cpu().tolist())

            # Store original labels only once (e.g., during first model's inference)
            if model_idx == 0:
                all_original_labels.extend(labels.cpu().tolist())

        del temp_dataset, temp_loader # Clean up temporary objects
        clear_gpu_memory() # Clear after each model's full run


# Perform Voting
print("\nPerforming ensemble voting...")
final_refined_labels = []
for i in tqdm(range(len(df)), desc="Voting"):
    votes = [model_predictions[model_idx][i] for model_idx in range(NUM_ENSEMBLE_MODELS)]
    vote_counts = Counter(votes)
    most_common = vote_counts.most_common(1)

    original_label = all_original_labels[i] # Get corresponding original label

    if most_common:
        majority_vote_label, majority_count = most_common[0]
        # Check if there's a qualifying majority
        if majority_count >= ENSEMBLE_AGREEMENT_THRESHOLD:
            final_refined_labels.append(majority_vote_label)
            if majority_vote_label != original_label:
                mismatch_count += 1
        else:
            # No clear majority, keep original label
            final_refined_labels.append(original_label)
    else:
         # Should not happen with votes, but handle defensively
         final_refined_labels.append(original_label)

inference_duration = time.time() - inference_start_time
print(f"Ensemble inference and voting took {inference_duration // 60:.0f}m {inference_duration % 60:.0f}s.")

# Add refined labels to the dataframe
df['refined_emotion'] = final_refined_labels
print(f"Refined labels generated for {len(final_refined_labels)} samples.")
print(f"Relabeled {mismatch_count} samples across the entire dataset based on ensemble majority.")
print(f"Original vs Refined counts (first 10):")
print(df[['emotion', 'refined_emotion']].head(10))

# Optional: Analyze label changes per split
# print("\nLabel change statistics per split:")
# for usage in df['Usage'].unique():
#     split_df = df[df['Usage'] == usage]
#     changes = (split_df['emotion'] != split_df['refined_emotion']).sum()
#     print(f"  {usage}: {changes} / {len(split_df)} labels changed.")


# Release ensemble models from memory
print("Releasing ensemble models from memory...")
del ensemble_models, ensemble_processors, inference_dataset, inference_loader, model_predictions
clear_gpu_memory()


# --- 4. Main Model Definition (ViT + MoE + Attention) ---
print("\n--- Defining Main Enhanced ViT Model ---")

# Load base ViT for the main model (e.g., the standard google/vit one)
main_vit_model_name = 'google/vit-base-patch16-224-in21k'
main_feature_extractor = ViTFeatureExtractor.from_pretrained(main_vit_model_name)
# Ensure main feature extractor uses the target size
if main_feature_extractor.size['height'] != MAIN_IMG_SIZE:
     print(f"Warning: Main ViT FE size mismatch. Adjusting FE size config to {MAIN_IMG_SIZE}.")
     # This might require checking the specific FE's config attributes
     main_feature_extractor.size = {"height": MAIN_IMG_SIZE, "width": MAIN_IMG_SIZE}
     main_feature_extractor.crop_size = {"height": MAIN_IMG_SIZE, "width": MAIN_IMG_SIZE}


# Get Latent Dim from the main model's config
main_vit_config = ViTConfig.from_pretrained(main_vit_model_name)
LATENT_DIM = main_vit_config.hidden_size
print(f"Main ViT Model Latent Dimension: {LATENT_DIM}")

# --- MoE Layer, LatentAttention, ViTMoEAttentionFER Class Definitions ---
# (Using the corrected versions from the previous response)

# 5.1 MoE Layer
class MoELayer(nn.Module):
    def __init__(self, input_dim, num_experts, top_k, expert_hidden_dim=None):
        super().__init__()
        self.num_experts = num_experts
        self.input_dim = input_dim
        self.top_k = min(top_k, num_experts) # Ensure top_k is not > num_experts

        if expert_hidden_dim is None:
            expert_hidden_dim = input_dim * 2 # A simple heuristic

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_hidden_dim),
                nn.GELU(), # Use GELU like in ViT
                nn.Linear(expert_hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        original_shape = x.shape
        is_sequence = len(original_shape) > 2
        x_flat = x.reshape(-1, self.input_dim) if is_sequence else x

        gate_logits = self.gate(x_flat)
        gate_scores = torch.softmax(gate_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)
        flat_batch_indices = torch.arange(x_flat.size(0), device=x.device)

        for i in range(self.top_k):
            expert_indices = top_k_indices[:, i]
            current_weights = top_k_weights[:, i].unsqueeze(-1)
            for exp_idx in range(self.num_experts):
                mask = (expert_indices == exp_idx)
                if mask.any():
                    selected_inputs = x_flat[mask]
                    expert_output = self.experts[exp_idx](selected_inputs)
                    output[mask] += expert_output * current_weights[mask]

        if is_sequence:
            output = output.reshape(original_shape)
        return output

# 5.2 Multi-head Latent Attention Layer
class LatentAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, hidden_states):
        attn_output, _ = self.attention(hidden_states, hidden_states, hidden_states)
        output = self.norm(hidden_states + attn_output) # Add & Norm
        return output

# 5.3 Combined Model
class ViTMoEAttentionFER(nn.Module):
    def __init__(self, vit_model_name, num_classes, num_experts, top_k_experts, num_attention_heads, latent_dim):
        super().__init__()
        print(f"Initializing main model using {vit_model_name}")
        self.vit = ViTModel.from_pretrained(vit_model_name)
        model_config = self.vit.config
        if latent_dim != model_config.hidden_size:
             print(f"Warning: Overriding LATENT_DIM from {latent_dim} to model's hidden size {model_config.hidden_size}")
             latent_dim = model_config.hidden_size
        self.latent_dim = latent_dim

        # Optional: Freeze base ViT
        # print("Freezing base ViT model parameters.")
        # for param in self.vit.parameters():
        #     param.requires_grad = False

        self.latent_attention = LatentAttention(self.latent_dim, num_attention_heads)
        self.norm_after_attention = nn.LayerNorm(self.latent_dim)
        self.moe = MoELayer(self.latent_dim, num_experts, top_k_experts)
        self.norm_after_moe = nn.LayerNorm(self.latent_dim)
        self.classifier = nn.Linear(self.latent_dim, num_classes)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        last_hidden_state = outputs.last_hidden_state
        attended_hidden_state = self.latent_attention(last_hidden_state)
        cls_token_attended = attended_hidden_state[:, 0]
        cls_token_attended = self.norm_after_attention(cls_token_attended)
        moe_output = self.moe(cls_token_attended)
        final_representation = cls_token_attended + moe_output
        final_representation = self.norm_after_moe(final_representation)
        logits = self.classifier(final_representation)
        return logits

# Instantiate the main model
model = ViTMoEAttentionFER(
    vit_model_name=main_vit_model_name,
    num_classes=NUM_CLASSES,
    num_experts=NUM_EXPERTS,
    top_k_experts=TOP_K_EXPERTS,
    num_attention_heads=NUM_ATTENTION_HEADS,
    latent_dim=LATENT_DIM
)
model.to(DEVICE)
LATENT_DIM = model.latent_dim # Ensure global LATENT_DIM is correct

print("Main Model created and moved to device:", DEVICE)
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


# --- 5. PyTorch Dataset and DataLoader for Main Model ---
print("\n--- Creating Datasets and DataLoaders for Main Model (using refined labels) ---")

class MainFERDataset(Dataset):
    def __init__(self, df, usage, feature_extractor, augment=False):
        # Filter dataframe by usage split BEFORE reset_index
        self.df = df[df['Usage'] == usage].reset_index(drop=True)
        if len(self.df) == 0:
             print(f"Warning: No data found for usage split '{usage}'")
        self.usage = usage
        self.feature_extractor = feature_extractor
        self.augment = augment

        self.augmentation_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2), # ViT expects RGB
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if idx >= len(self.df):
             raise IndexError("Index out of bounds")
        row = self.df.iloc[idx]
        img = parse_pixels(row['pixels']).convert('RGB') # Ensure RGB

        if self.augment:
             img = self.augmentation_transform(img)

        try:
            # Use the main model's feature extractor
            inputs = self.feature_extractor(images=img, return_tensors="pt")
            pixel_values = inputs['pixel_values'].squeeze(0)
        except Exception as e:
             print(f"Error processing image at index {idx} ({self.usage}) with Main FeatureExtractor: {e}")
             img_size = self.feature_extractor.size['height']
             pixel_values = torch.zeros((3, img_size, img_size)) # Return dummy tensor

        # Use the refined label column
        label = row['refined_emotion']

        return pixel_values, torch.tensor(label, dtype=torch.long)

# Create datasets using the *modified* df containing 'refined_emotion'
train_dataset = MainFERDataset(df, 'Training', main_feature_extractor, augment=True)
val_dataset = MainFERDataset(df, 'PublicTest', main_feature_extractor, augment=False)
test_dataset = MainFERDataset(df, 'PrivateTest', main_feature_extractor, augment=False)


# Create dataloaders
num_workers = min(os.cpu_count()//2, 4) if DEVICE.type == 'cuda' else 0
print(f"Using {num_workers} dataloader workers.")

# Use the main BATCH_SIZE config
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True if DEVICE.type == 'cuda' else False, persistent_workers=True if num_workers > 0 else False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True if DEVICE.type == 'cuda' else False, persistent_workers=True if num_workers > 0 else False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True if DEVICE.type == 'cuda' else False, persistent_workers=True if num_workers > 0 else False)

print(f"DataLoaders created (using refined labels for all splits):")
print(f"  Training: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f"  Validation: {len(val_dataset)} samples, {len(val_loader)} batches")
print(f"  Test: {len(test_dataset)} samples, {len(test_loader)} batches")

# --- 6. Training Pipeline Setup ---
print("\n--- Setting up Training Pipeline ---")
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.2, verbose=True)

# --- 7. Training and Evaluation Loop ---
# (Using the train_epoch and evaluate functions from the previous response)
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False, unit="batch")
    for i, (inputs, labels) in enumerate(progress_bar):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Optional: Gradient clipping
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        batch_size = labels.size(0)
        total_samples += batch_size
        correct_predictions += (predicted == labels).sum().item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", avg_loss=f"{total_loss / (i + 1):.4f}", acc=f"{correct_predictions / total_samples:.4f}")

    epoch_loss = total_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

def evaluate(model, dataloader, criterion, device, split_name="Evaluating"):
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_preds = []
    all_labels = []
    progress_bar = tqdm(dataloader, desc=split_name, leave=False, unit="batch")
    with torch.no_grad():
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            batch_size = labels.size(0)
            total_samples += batch_size
            correct_predictions += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{correct_predictions / total_samples:.4f}")

    epoch_loss = total_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc, all_labels, all_preds

print("\n--- Starting Training on Refined Data ---")
start_time = time.time()
best_val_acc = 0.0
best_epoch = -1

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, DEVICE)
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, DEVICE, split_name="Validation")
    print(f"  Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.4f}")

    epoch_duration = time.time() - epoch_start_time
    print(f"  Epoch Duration: {epoch_duration // 60:.0f}m {epoch_duration % 60:.1f}s")

    if scheduler:
        scheduler.step(val_loss)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        print(f"  *** New best validation accuracy: {best_val_acc:.4f} (Epoch {best_epoch}) ***")
        # Optional: Save best model
        # torch.save(model.state_dict(), 'vit_moe_attention_fer_REFINED_ALL_best.pth')

training_time = time.time() - start_time
print(f"\n--- Training Finished in {training_time // 60:.0f}m {training_time % 60:.0f}s ---")
print(f"Best Validation Accuracy on Refined Data: {best_val_acc:.4f} achieved at Epoch {best_epoch}")

# --- 8. Final Evaluation on the *Refined* Test Set ---
print("\n--- Evaluating on the *Refined* Test Set ---")
# Optional: Load best model if saved
# if os.path.exists('vit_moe_attention_fer_REFINED_ALL_best.pth'):
#     print("Loading best model for final evaluation...")
#     model.load_state_dict(torch.load('vit_moe_attention_fer_REFINED_ALL_best.pth', map_location=DEVICE))

test_loss, test_acc, test_labels, test_preds = evaluate(model, test_loader, criterion, DEVICE, split_name="Test Set")

print(f"\nRefined Test Set Loss: {test_loss:.4f}")
print(f"Refined Test Set Accuracy: {test_acc:.4f}")

print("\nClassification Report (Refined Test Set):")
class_names = list(emotion_map.values())
try:
    # Use the refined labels from the dataloader (test_labels) as ground truth
    print(classification_report(test_labels, test_preds, target_names=class_names, digits=4))
except Exception as e:
    print(f"Could not generate classification report: {e}")
    print("Accuracy:", accuracy_score(test_labels, test_preds))


# --- Optional: Compare with Original Test Labels ---
# print("\n--- Comparison with Original Test Labels (For Analysis Only) ---")
# original_test_labels = df[df['Usage'] == 'PrivateTest']['emotion'].tolist()
# if len(original_test_labels) == len(test_preds):
#      print("Accuracy against ORIGINAL Test Labels:")
#      print(accuracy_score(original_test_labels, test_preds))
#      print("\nClassification Report against ORIGINAL Test Labels:")
#      try:
#          print(classification_report(original_test_labels, test_preds, target_names=class_names, digits=4))
#      except Exception as e:
#         print(f"Could not generate original classification report: {e}")
# else:
#      print("Could not compare with original labels (length mismatch).")


print("\n--- Script Complete ---")

Ensemble: 4 models, Threshold: 3
--- Loading Data ---
Dataset loaded: 35887 samples
Usage
Training       28709
PublicTest      3589
PrivateTest     3589
Name: count, dtype: int64

--- Loading Ensemble Models ---


Loading Ensemble Models:   0%|          | 0/4 [00:00<?, ?it/s]

Loading trpakov/vit-face-expression...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


config.json:   0%|          | 0.00/915 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

Loaded trpakov/vit-face-expression successfully.
Loading nateraw/vit-base-patch16-224-fer...
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`. Skipping.
Loading samhitaambati/vit-base-patch16-224-in21k-finetuned-fer...
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`. Skipping.
Loading Rajaram1996/FacialEmoRecog...


preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/881 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at Rajaram1996/FacialEmoRecog and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([8, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([8]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

Loaded Rajaram1996/FacialEmoRecog successfully.
Successfully loaded 2 ensemble models: ['trpakov/vit-face-expression', 'Rajaram1996/FacialEmoRecog']
Using 2 models for ensemble voting. Threshold: 1

--- Refining ALL Data with Actual Ensemble Voting ---
--- This will take a significant amount of time! ---

Running inference with Model 1/2 (trpakov/vit-face-expression)


Model 1 Inference:   0%|          | 0/2243 [00:00<?, ?it/s]


Running inference with Model 2/2 (Rajaram1996/FacialEmoRecog)


Model 2 Inference:   0%|          | 0/2243 [00:00<?, ?it/s]


Performing ensemble voting...


Voting:   0%|          | 0/35887 [00:00<?, ?it/s]

Ensemble inference and voting took 41m 7s.
Refined labels generated for 35887 samples.
Relabeled 17251 samples across the entire dataset based on ensemble majority.
Original vs Refined counts (first 10):
   emotion  refined_emotion
0        0                0
1        0                0
2        2                2
3        4                5
4        6                4
5        2                2
6        4                5
7        3                3
8        3                3
9        2                2
Releasing ensemble models from memory...

--- Defining Main Enhanced ViT Model ---


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Main ViT Model Latent Dimension: 768
Initializing main model using google/vit-base-patch16-224-in21k


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Main Model created and moved to device: cuda
Total Parameters: 98,211,083
Trainable Parameters: 98,211,083

--- Creating Datasets and DataLoaders for Main Model (using refined labels) ---
Using 1 dataloader workers.
DataLoaders created (using refined labels for all splits):
  Training: 28709 samples, 898 batches
  Validation: 3589 samples, 113 batches
  Test: 3589 samples, 113 batches

--- Setting up Training Pipeline ---

--- Starting Training on Refined Data ---

Epoch 1/5


Training:   0%|          | 0/898 [00:00<?, ?batch/s]

  Train Loss: 1.1230, Train Acc: 0.5753


Validation:   0%|          | 0/113 [00:00<?, ?batch/s]

  Val Loss:   0.8390, Val Acc:   0.6810
  Epoch Duration: 17m 34.6s
  *** New best validation accuracy: 0.6810 (Epoch 1) ***

Epoch 2/5


Training:   0%|          | 0/898 [00:00<?, ?batch/s]

  Train Loss: 0.9173, Train Acc: 0.6542


Validation:   0%|          | 0/113 [00:00<?, ?batch/s]

  Val Loss:   0.7785, Val Acc:   0.7016
  Epoch Duration: 17m 32.8s
  *** New best validation accuracy: 0.7016 (Epoch 2) ***

Epoch 3/5


Training:   0%|          | 0/898 [00:00<?, ?batch/s]

  Train Loss: 0.8381, Train Acc: 0.6850


Validation:   0%|          | 0/113 [00:00<?, ?batch/s]

  Val Loss:   0.7178, Val Acc:   0.7208
  Epoch Duration: 17m 33.8s
  *** New best validation accuracy: 0.7208 (Epoch 3) ***

Epoch 4/5


Training:   0%|          | 0/898 [00:00<?, ?batch/s]

  Train Loss: 0.7759, Train Acc: 0.7077


Validation:   0%|          | 0/113 [00:00<?, ?batch/s]

  Val Loss:   0.6535, Val Acc:   0.7540
  Epoch Duration: 17m 33.6s
  *** New best validation accuracy: 0.7540 (Epoch 4) ***

Epoch 5/5


Training:   0%|          | 0/898 [00:00<?, ?batch/s]

  Train Loss: 0.7149, Train Acc: 0.7337


Validation:   0%|          | 0/113 [00:00<?, ?batch/s]

  Val Loss:   0.6347, Val Acc:   0.7531
  Epoch Duration: 17m 33.2s

--- Training Finished in 87m 48s ---
Best Validation Accuracy on Refined Data: 0.7540 achieved at Epoch 4

--- Evaluating on the *Refined* Test Set ---


Test Set:   0%|          | 0/113 [00:00<?, ?batch/s]


Refined Test Set Loss: 0.6269
Refined Test Set Accuracy: 0.7654

Classification Report (Refined Test Set):
              precision    recall  f1-score   support

       Angry     0.7052    0.7038    0.7045       503
     Disgust     0.7500    0.4821    0.5870        56
        Fear     0.6446    0.6393    0.6420       488
       Happy     0.9165    0.9291    0.9227       874
         Sad     0.6927    0.8211    0.7515       615
    Surprise     0.7320    0.6214    0.6722       655
     Neutral     0.8333    0.8291    0.8312       398

    accuracy                         0.7654      3589
   macro avg     0.7535    0.7180    0.7301      3589
weighted avg     0.7661    0.7654    0.7635      3589


--- Script Complete ---
