# In [4]:
# train_emonet_emotic_full.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import pandas as pd
import os
import argparse
from pathlib import Path
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    f1_score, accuracy_score, precision_score, recall_score,
    confusion_matrix, classification_report, mean_absolute_error,
    mean_squared_error, r2_score
)
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# ============================================
# MODEL DEFINITION
# ============================================
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    """Build a 3x3 ``nn.Conv2d`` layer.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        strd: convolution stride.
        padding: zero-padding on each side (1 keeps the spatial size at stride 1).
        bias: whether the convolution has a learnable bias.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=(3, 3),
        stride=strd,
        padding=padding,
        bias=bias,
    )
    return layer

class ConvBlock(nn.Module):
    """Pre-activation residual block whose body concatenates three chained 3x3 convs.

    The three BN-ReLU-conv stages emit ``out_planes // 2``, ``out_planes // 4`` and
    ``out_planes // 4`` channels; their channel-wise concatenation is added to the
    input. When ``in_planes != out_planes`` the input is first projected with a
    BN-ReLU-1x1 conv so the shapes match.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        half = out_planes // 2
        quarter = out_planes // 4

        # Registration order (bn1, conv1, bn2, conv2, bn3, conv3, downsample) is
        # significant: it fixes the state_dict layout and weight-init order.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, half)
        self.bn2 = nn.BatchNorm2d(half)
        self.conv2 = conv3x3(half, quarter)
        self.bn3 = nn.BatchNorm2d(quarter)
        self.conv3 = conv3x3(quarter, quarter)

        self.downsample = (
            nn.Sequential(
                nn.BatchNorm2d(in_planes),
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
            )
            if in_planes != out_planes
            else None
        )

    def forward(self, x):
        # Each stage consumes the previous stage's conv output.
        stage_outputs = []
        h = x
        for bn, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            h = conv(F.relu(bn(h), True))
            stage_outputs.append(h)

        merged = torch.cat(stage_outputs, 1)
        shortcut = x if self.downsample is None else self.downsample(x)
        merged += shortcut
        return merged

class HourGlass(nn.Module):
    """Recursive hourglass module.

    At every level the input goes through a full-resolution branch (``b1``) and a
    low-resolution branch (2x2 max-pool, ``b2``, then either the next level or
    ``b2_plus`` at the bottom, then ``b3``, then nearest-neighbour upsampling); the
    two branches are summed.

    Args:
        num_modules: stored for reference; not used by this class.
        depth: number of recursive down/up-sampling levels.
        num_features: channel count of every ConvBlock (input == output channels).
    """

    def __init__(self, num_modules, depth, num_features):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self._generate_network(self.depth)

    def _generate_network(self, level):
        # Previously hard-coded to 256 regardless of num_features; every caller
        # in this file passes 256, so existing models are unchanged.
        f = self.features
        # Registration order (b1, b2, deeper levels / b2_plus, b3) is kept so the
        # state_dict keys and parameter-init order stay the same.
        self.add_module('b1_' + str(level), ConvBlock(f, f))
        self.add_module('b2_' + str(level), ConvBlock(f, f))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvBlock(f, f))

        self.add_module('b3_' + str(level), ConvBlock(f, f))

    def _forward(self, level, inp):
        up1 = self._modules['b1_' + str(level)](inp)

        low1 = F.max_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = self._modules['b2_plus_' + str(level)](low1)

        low3 = self._modules['b3_' + str(level)](low2)
        # Upsample to the skip branch's exact size. For even sizes this is identical
        # to scale_factor=2; for odd sizes (where max-pool floors) scale_factor=2
        # produced a shape mismatch in the addition below.
        up2 = F.interpolate(low3, size=up1.shape[-2:], mode='nearest')
        return up1 + up2

    def forward(self, x):
        return self._forward(self.depth, x)

class EmoNetSingleLabel26(nn.Module):
    """EmoNet-style model with a single-label emotion head and valence/arousal heads.

    A stacked-hourglass backbone (each hourglass ends in a 68-channel heatmap head,
    presumably facial landmarks) feeds an emotion branch: a 1x1 conv, ``n_blocks``
    (ConvBlock, 2x2 max-pool) stages, global average pooling and a shared MLP,
    followed by three linear heads.

    Args:
        num_modules: number of stacked ``HourGlass`` modules.
        n_expression: number of emotion classes output by ``emotion_head``.
        n_reg: stored for reference only; valence and arousal always use two
            separate single-output heads.
        n_blocks: number of (ConvBlock, MaxPool2d) stages in the emotion branch.
        attention: if True, hourglass features are gated by a sigmoid mask computed
            from the last heatmap; otherwise the heatmap is concatenated instead.
    """

    def __init__(self, num_modules=2, n_expression=26, n_reg=2, n_blocks=4, attention=True):
        super(EmoNetSingleLabel26, self).__init__()
        self.num_modules = num_modules
        self.n_expression = n_expression
        self.n_reg = n_reg
        self.attention = attention

        # Stem -> 256 channels. forward() also max-pools 2x2 after conv2.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Stacked hourglasses. load_pretrained_emonet matches checkpoint keys by
        # name, so these submodule names must not change.
        for hg_module in range(self.num_modules):
            self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
            self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
            self.add_module('conv_last' + str(hg_module),
                            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
            self.add_module('l' + str(hg_module),
                            nn.Conv2d(256, 68, kernel_size=1, stride=1, padding=0))

            # Every hourglass except the last feeds its features and heatmap
            # back into the input of the next one.
            if hg_module < self.num_modules - 1:
                self.add_module('bl' + str(hg_module),
                                nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module('al' + str(hg_module),
                                nn.Conv2d(68, 256, kernel_size=1, stride=1, padding=0))

        # Emotion-branch input = stem features + one 256-channel map per hourglass
        # (+ the raw 68-channel heatmap when attention is off).
        if self.attention:
            n_in_features = 256 * (num_modules + 1)
        else:
            n_in_features = 256 * (num_modules + 1) + 68

        n_features = [(256, 256)] * n_blocks

        self.emo_convs = []
        self.conv1x1_input_emo_2 = nn.Conv2d(n_in_features, 256, kernel_size=1, stride=1, padding=0)
        for in_f, out_f in n_features:
            self.emo_convs.append(ConvBlock(in_f, out_f))
            self.emo_convs.append(nn.MaxPool2d(2, 2))
        self.emo_net_2 = nn.Sequential(*self.emo_convs)
        # Global average pool to 1x1. For a 256x256 input the emotion branch ends at
        # 4x4, where this equals the former AvgPool2d(4); unlike that layer it does
        # not hard-code the spatial size. (No parameters, so state_dicts are unaffected.)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d(1)

        self.emo_fc_shared = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        self.emotion_head = nn.Linear(128, self.n_expression)
        self.valence_head = nn.Linear(128, 1)
        self.arousal_head = nn.Linear(128, 1)

    def forward(self, x):
        """Run the backbone and the emotion branch.

        Args:
            x: image batch (B, 3, H, W); presumably 256x256 as in EmoNet — TODO confirm.

        Returns:
            dict with
              'heatmap':    output of the last hourglass's 68-channel head,
              'expression': (B, n_expression) emotion logits,
              'valence':    (B,) tanh output in (-1, 1),
              'arousal':    (B,) tanh output in (-1, 1).
        """
        x = F.relu(self.bn1(self.conv1(x)), True)
        x = F.max_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)

        previous = x
        hg_features = []

        for i in range(self.num_modules):
            hg = self._modules['m' + str(i)](previous)
            ll = hg
            ll = self._modules['top_m_' + str(i)](ll)
            ll = F.relu(self._modules['bn_end' + str(i)]
                        (self._modules['conv_last' + str(i)](ll)), True)
            tmp_out = self._modules['l' + str(i)](ll)

            if i < self.num_modules - 1:
                ll = self._modules['bl' + str(i)](ll)
                tmp_out_ = self._modules['al' + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

            hg_features.append(ll)

        hg_features_cat = torch.cat(tuple(hg_features), dim=1)

        if self.attention:
            # Single-channel sigmoid mask from the summed heatmap channels.
            mask = torch.sum(tmp_out, dim=1, keepdim=True)
            mask = torch.sigmoid(mask)
            hg_features_cat = hg_features_cat * mask
            emo_feat = torch.cat((x, hg_features_cat), dim=1)
        else:
            emo_feat = torch.cat([x, hg_features_cat, tmp_out], dim=1)

        emo_feat_conv1D = self.conv1x1_input_emo_2(emo_feat)
        final_features = self.emo_net_2(emo_feat_conv1D)
        final_features = self.avg_pool_2(final_features)
        batch_size = final_features.shape[0]
        final_features = final_features.view(batch_size, -1)

        shared_features = self.emo_fc_shared(final_features)

        emotion_logits = self.emotion_head(shared_features)
        valence = torch.tanh(self.valence_head(shared_features)).squeeze(1)
        arousal = torch.tanh(self.arousal_head(shared_features)).squeeze(1)

        return {
            'heatmap': tmp_out,
            'expression': emotion_logits,
            'valence': valence,
            'arousal': arousal
        }

    def load_pretrained_emonet(self, pretrained_path, freeze_backbone=True):
        """Copy name- and shape-compatible weights from an EmoNet checkpoint.

        Keys are matched after removing a leading ``module.`` prefix (as added by
        nn.DataParallel / DistributedDataParallel). Keys containing ``emo_fc_2``, and
        keys missing from this model or with a different shape, are skipped.
        Parameters not present in the checkpoint keep their current values.

        Args:
            pretrained_path: file passed to ``torch.load``; either a raw state dict
                or a dict with a ``'state_dict'`` entry.
            freeze_backbone: if True and loading succeeds, call ``freeze_backbone()``.

        Returns:
            True on success, False if anything raised (the error is printed, not re-raised).
        """
        print(f"Loading pretrained EmoNet from {pretrained_path}")

        try:
            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (or pass weights_only=True where the PyTorch version allows).
            pretrained_state = torch.load(pretrained_path, map_location='cpu')

            if 'state_dict' in pretrained_state:
                pretrained_state = pretrained_state['state_dict']

            # Strip the wrapper prefix only at the start of a key; str.replace would
            # also rewrite any key that merely contains "module.".
            prefix = 'module.'
            pretrained_state = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in pretrained_state.items()
            }

            model_state = self.state_dict()

            compatible_weights = {}
            incompatible_keys = []

            for k, v in pretrained_state.items():
                # EmoNet's original combined emotion+regression head does not match
                # the separate heads of this model.
                if 'emo_fc_2' in k:
                    incompatible_keys.append(k)
                    continue

                if k in model_state and model_state[k].shape == v.shape:
                    compatible_weights[k] = v
                else:
                    incompatible_keys.append(k)

            model_state.update(compatible_weights)
            self.load_state_dict(model_state)

            print(f"✓ Loaded {len(compatible_weights)} layers from pretrained model")
            print(f"✗ Skipped {len(incompatible_keys)} incompatible layers")

            if freeze_backbone:
                self.freeze_backbone()

            return True

        except Exception as e:
            # Deliberate best-effort: report and let the caller decide via the return value.
            print(f"✗ Failed to load pretrained weights: {e}")
            return False

    def freeze_backbone(self):
        """Disable gradients for everything except the shared MLP and the three heads.

        NOTE: this only sets ``requires_grad``. BatchNorm layers in the frozen part
        still update their running statistics whenever the model is in ``train()``
        mode; put them in ``eval()`` if the statistics must stay fixed too.
        """
        frozen_params = 0
        trainable_params = 0

        for name, param in self.named_parameters():
            if any(head in name for head in ['emotion_head', 'valence_head',
                                              'arousal_head', 'emo_fc_shared']):
                param.requires_grad = True
                trainable_params += param.numel()
            else:
                param.requires_grad = False
                frozen_params += param.numel()

        print(f"✓ Frozen backbone:")
        print(f"    Frozen: {frozen_params:,} | Trainable: {trainable_params:,}")


# In [5]:
# train_emonet_emotic_fixed.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
import numpy as np
import pandas as pd
import os
import argparse
from pathlib import Path
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    f1_score, accuracy_score, precision_score, recall_score,
    confusion_matrix, classification_report, mean_absolute_error,
    mean_squared_error, r2_score
)
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
# The 26 EMOTIC emotion categories in class-index order. EMOTION_COLUMNS is the
# single source of truth (it also names the one-hot CSV columns read by
# EMOTICDataset); EMOTION_CLASSES is derived from it so the index -> name mapping
# and the column list cannot drift apart.
EMOTION_COLUMNS = [
    'Peace', 'Affection', 'Esteem', 'Anticipation', 'Engagement',
    'Confidence', 'Happiness', 'Pleasure', 'Excitement', 'Surprise',
    'Sympathy', 'Doubt/Confusion', 'Disconnection', 'Fatigue',
    'Embarrassment', 'Yearning', 'Disapproval', 'Aversion',
    'Annoyance', 'Anger', 'Sensitivity', 'Sadness', 'Disquietment',
    'Fear', 'Pain', 'Suffering'
]
EMOTION_CLASSES = dict(enumerate(EMOTION_COLUMNS))
# NOTE: EmoNetSingleLabel26, ConvBlock and HourGlass are defined in the previous
# notebook cell; EMOTION_CLASSES and EMOTION_COLUMNS are defined just above.
# MetricsTracker and the plotting helpers are defined below.
# ======================
class MetricsTracker:
    """Collect hard class predictions (plus optional valence/arousal regressions)
    over many batches and summarise them with scikit-learn metrics.

    Args:
        num_classes: number of classes; used for the default ``class_names``.
        class_names: optional display names (default: ``"0"`` .. ``str(num_classes-1)``).
    """

    def __init__(self, num_classes=26, class_names=None):
        self.num_classes = num_classes
        if class_names:
            self.class_names = class_names
        else:
            self.class_names = [str(c) for c in range(num_classes)]

        # Flat per-sample accumulators, appended to by update().
        for attr in ('all_preds', 'all_labels',
                     'all_valence_preds', 'all_valence_labels',
                     'all_arousal_preds', 'all_arousal_labels'):
            setattr(self, attr, [])

    def update(self, preds, labels, valence_preds=None, valence_labels=None,
               arousal_preds=None, arousal_labels=None):
        """Append one batch of tensors (moved to CPU) to the accumulators.

        Valence/arousal labels are only read when the matching predictions are given.
        """
        def to_host(t):
            return t.cpu().numpy()

        self.all_preds.extend(to_host(preds))
        self.all_labels.extend(to_host(labels))

        if valence_preds is not None:
            self.all_valence_preds.extend(to_host(valence_preds))
            self.all_valence_labels.extend(to_host(valence_labels))

        if arousal_preds is not None:
            self.all_arousal_preds.extend(to_host(arousal_preds))
            self.all_arousal_labels.extend(to_host(arousal_labels))

    def compute(self):
        """Return a dict of accuracy, F1 (macro/weighted), macro precision/recall and,
        for each regression target with data, MAE / RMSE / R²."""
        y_pred = np.array(self.all_preds)
        y_true = np.array(self.all_labels)

        metrics = {'accuracy': accuracy_score(y_true, y_pred)}
        for key, average in (('f1_macro', 'macro'), ('f1_weighted', 'weighted')):
            metrics[key] = f1_score(y_true, y_pred, average=average, zero_division=0)
        metrics['precision_macro'] = precision_score(y_true, y_pred, average='macro', zero_division=0)
        metrics['recall_macro'] = recall_score(y_true, y_pred, average='macro', zero_division=0)

        for task in ('valence', 'arousal'):
            task_preds = getattr(self, f'all_{task}_preds')
            if not task_preds:
                continue
            p = np.array(task_preds)
            t = np.array(getattr(self, f'all_{task}_labels'))
            metrics[f'{task}_mae'] = mean_absolute_error(t, p)
            metrics[f'{task}_rmse'] = np.sqrt(mean_squared_error(t, p))
            metrics[f'{task}_r2'] = r2_score(t, p)

        return metrics

    def print_metrics(self, metrics_dict, phase=""):
        """Pretty-print a dict produced by ``compute()``."""
        rule = '=' * 50
        print(f"\n{rule}")
        print(f"{phase} Metrics")
        print(rule)
        print(f"Accuracy:       {metrics_dict['accuracy']*100:.2f}%")
        print(f"F1 (Macro):     {metrics_dict['f1_macro']:.4f}")
        print(f"F1 (Weighted):  {metrics_dict['f1_weighted']:.4f}")
        print(f"Precision:      {metrics_dict['precision_macro']:.4f}")
        print(f"Recall:         {metrics_dict['recall_macro']:.4f}")

        for task, header in (('valence', "\nValence:"), ('arousal', "\nArousal:")):
            if f'{task}_mae' not in metrics_dict:
                continue
            print(header)
            print(f"  MAE:  {metrics_dict[f'{task}_mae']:.4f}")
            print(f"  RMSE: {metrics_dict[f'{task}_rmse']:.4f}")
            print(f"  R²:   {metrics_dict[f'{task}_r2']:.4f}")

        print(f"{rule}\n")

# ============================================
# PLOTTING FUNCTIONS
# ============================================
def plot_training_progress(history, save_path):
    """Save a 2x3 grid of train/val curves from a per-epoch ``history`` dict.

    Panels: total loss, accuracy, macro F1, emotion loss, valence MAE, arousal MAE.
    ``history`` must contain the ``train_*`` / ``val_*`` lists named in ``panels``.
    The figure is written to ``save_path`` at 300 dpi and then closed.
    """
    # (axes position, train key, val key, train label, val label, title, y-label)
    panels = [
        ((0, 0), 'train_loss', 'val_loss', 'Train', 'Val', 'Total Loss', 'Loss'),
        ((0, 1), 'train_acc', 'val_acc', 'Train', 'Val', 'Accuracy', 'Accuracy (%)'),
        ((0, 2), 'train_f1', 'val_f1', 'Train', 'Val', 'F1 Score (Macro)', 'F1 Score'),
        ((1, 0), 'train_emotion_loss', 'val_emotion_loss',
         'Train Emotion', 'Val Emotion', 'Emotion Loss', 'Loss'),
        ((1, 1), 'train_valence_mae', 'val_valence_mae', 'Train', 'Val', 'Valence MAE', 'MAE'),
        ((1, 2), 'train_arousal_mae', 'val_arousal_mae', 'Train', 'Val', 'Arousal MAE', 'MAE'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training Progress', fontsize=16, fontweight='bold')

    for pos, train_key, val_key, train_label, val_label, title, ylabel in panels:
        ax = axes[pos]
        ax.plot(history[train_key], label=train_label, linewidth=2)
        ax.plot(history[val_key], label=val_label, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# ============================================
# DATASET (with improved error handling)
# ============================================
class EMOTICDataset(Dataset):
    """Single-label EMOTIC dataset backed by pre-cropped ``.npy`` image arrays.

    Rows missing ``Crop_name``, ``Valence_norm`` or ``Arousal_norm`` are dropped.
    The class label comes from a ``dominant_emotion`` column if the CSV has one;
    otherwise it is the argmax over whichever ``EMOTION_COLUMNS`` appear in the CSV,
    expressed as an index into ``EMOTION_COLUMNS``.

    Args:
        csv_file: anything accepted by ``pd.read_csv``.
        img_arrays_dir: directory holding the ``.npy`` files named by ``Crop_name``.
        image_size: images are resized to ``image_size x image_size``.
        transform: optional callable applied to the (3, H, W) float tensor in [0, 1].
        max_samples: if truthy, keep only the first ``max_samples`` cleaned rows.

    Raises:
        ValueError: if no label source exists, or a label is missing or outside [0, 26).

    Attributes:
        emotion_counts: float array (length 26) of per-class sample counts.
        present_classes: class ids with at least one sample.
    """

    def __init__(self, csv_file, img_arrays_dir, image_size=256, transform=None, max_samples=None):
        self.data = pd.read_csv(csv_file)
        self.img_arrays_dir = Path(img_arrays_dir)
        self.image_size = image_size
        self.transform = transform
        num_classes = len(EMOTION_COLUMNS)

        # Clean data
        self.data = self.data.dropna(subset=['Crop_name', 'Valence_norm', 'Arousal_norm']).reset_index(drop=True)

        # Extract dominant emotion
        if 'dominant_emotion' not in self.data.columns:
            print("Extracting dominant_emotion from one-hot encoded columns...")
            emotion_cols = [col for col in EMOTION_COLUMNS if col in self.data.columns]

            if len(emotion_cols) == 0:
                raise ValueError(f"❌ No emotion columns found!")

            # argmax yields a position within emotion_cols; translate it back to the
            # column's index in EMOTION_COLUMNS. Without this, any missing column
            # shifted every later class id.
            col_to_class = np.array([EMOTION_COLUMNS.index(col) for col in emotion_cols])
            emotion_values = self.data[emotion_cols].values
            self.data['dominant_emotion'] = col_to_class[emotion_values.argmax(axis=1)]

        if max_samples:
            self.data = self.data.head(max_samples)

        print(f"\n✓ Dataset loaded: {len(self.data)} samples")
        print(f"\n📊 Dominant Emotion Distribution:")

        labels = self._validated_labels(self.data['dominant_emotion'], num_classes)
        # Vectorised count (was a per-row .iloc loop); float dtype kept for callers.
        self.emotion_counts = np.bincount(labels, minlength=num_classes).astype(np.float64)

        self.present_classes = []
        for idx in range(num_classes):
            count = int(self.emotion_counts[idx])
            if count > 0:
                self.present_classes.append(idx)
                print(f"  {idx:2d}. {EMOTION_CLASSES[idx]:20s}: {count:5d} samples ({100*count/len(self.data):5.1f}%)")

        # Check imbalance
        present_counts = self.emotion_counts[self.emotion_counts > 0]
        if len(present_counts) > 0:
            max_count = present_counts.max()
            min_count = present_counts.min()
            imbalance_ratio = max_count / min_count
            print(f"\n⚖️  Class Imbalance Ratio: {imbalance_ratio:.1f}:1")
            print(f"   Present classes: {len(self.present_classes)}/26")

    @staticmethod
    def _validated_labels(column, num_classes):
        """Return ``column`` as an int64 array; raise ValueError on NaN or out-of-range ids."""
        values = column.to_numpy()
        if pd.isna(values).any():
            raise ValueError("❌ dominant_emotion contains missing values")
        labels = values.astype(np.int64)
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError(
                f"❌ dominant_emotion must be in [0, {num_classes}), "
                f"got range [{labels.min()}, {labels.max()}]"
            )
        return labels

    @staticmethod
    def _to_hwc3(img_array):
        """Return ``img_array`` as a 3-channel (H, W, 3) array.

        Accepts channel-first (3, H, W), grayscale (H, W) and single-channel
        (H, W, 1) arrays; other layouts are returned unchanged.
        """
        if img_array.ndim == 3 and img_array.shape[0] == 3:
            img_array = img_array.transpose(1, 2, 0)
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)
        elif img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = np.repeat(img_array, 3, axis=-1)
        return img_array

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample as a dict with keys image, emotion, valence, arousal, filename.

        On any loading error the error is printed and a zero image labelled class 0
        with valence/arousal 0 is returned (filename 'dummy'), so training keeps
        going; such samples still count as class 0 in the loss.
        """
        row = self.data.iloc[idx]
        crop_name = row['Crop_name']
        img_path = self.img_arrays_dir / crop_name

        try:
            img_array = self._to_hwc3(np.load(str(img_path)))

            # Arrays with max <= 1 are treated as [0, 1] floats, others as [0, 255].
            if img_array.max() <= 1.0:
                img_array = (img_array * 255).astype(np.uint8)
            else:
                img_array = img_array.astype(np.uint8)

            image = cv2.resize(img_array, (self.image_size, self.image_size))
            image_tensor = torch.FloatTensor(image).permute(2, 0, 1) / 255.0

            if self.transform:
                image_tensor = self.transform(image_tensor)

            emotion_class = int(row['dominant_emotion'])
            emotion_class = torch.LongTensor([emotion_class])[0]
            valence = torch.FloatTensor([float(row['Valence_norm'])])[0]
            arousal = torch.FloatTensor([float(row['Arousal_norm'])])[0]

            return {
                'image': image_tensor,
                'emotion': emotion_class,
                'valence': valence,
                'arousal': arousal,
                'filename': crop_name
            }

        except Exception as e:
            print(f"❌ Failed to load {crop_name}: {e}")
            dummy_image = torch.zeros(3, self.image_size, self.image_size)
            return {
                'image': dummy_image,
                'emotion': torch.LongTensor([0])[0],
                'valence': torch.FloatTensor([0.0])[0],
                'arousal': torch.FloatTensor([0.0])[0],
                'filename': 'dummy'
            }

# ============================================
# CLASS WEIGHTING WITH EXTREME IMBALANCE HANDLING
# ============================================
def compute_class_weights(train_dataset, method='effective_number', beta=0.9999):
    """
    Compute per-class loss weights from ``train_dataset.emotion_counts``.

    Classes with zero samples get weight 0. The weights are rescaled to sum to the
    number of classes and then capped at 20.

    Args:
        train_dataset: object exposing ``emotion_counts`` (per-class sample counts).
        method: 'inverse_freq'     -> total / (num_classes * count)
                'effective_number' -> (1 - beta) / (1 - beta**count), from
                                      "Class-Balanced Loss Based on Effective
                                      Number of Samples"
                'sqrt_inv_freq'    -> sqrt(total / count), softer than inverse frequency
        beta: smoothing factor for 'effective_number'; must lie in [0, 1)
            (0.9999 works well for extreme imbalance).

    Returns:
        torch.FloatTensor of shape (num_classes,).

    Raises:
        ValueError: on an unknown ``method``, a ``beta`` outside [0, 1) for
            'effective_number', or when every class count is zero. (Previously
            these cases silently produced NaN weights.)
    """
    valid_methods = ('inverse_freq', 'effective_number', 'sqrt_inv_freq')
    if method not in valid_methods:
        raise ValueError(f"Unknown class-weight method {method!r}; expected one of {valid_methods}")
    if method == 'effective_number' and not (0.0 <= beta < 1.0):
        raise ValueError(f"beta must be in [0, 1) for 'effective_number', got {beta}")

    emotion_counts = np.asarray(train_dataset.emotion_counts, dtype=np.float64)
    num_classes = len(emotion_counts)
    total = emotion_counts.sum()
    present = emotion_counts > 0
    if not present.any():
        raise ValueError("Cannot compute class weights: every class has zero samples")

    counts = emotion_counts[present]
    class_weights = np.zeros(num_classes)

    if method == 'inverse_freq':
        class_weights[present] = total / (num_classes * counts)
    elif method == 'effective_number':
        effective_num = (1.0 - np.power(beta, counts)) / (1.0 - beta)
        class_weights[present] = 1.0 / effective_num
    else:  # 'sqrt_inv_freq'
        class_weights[present] = np.sqrt(total / counts)

    # Normalize to sum to num_classes, then cap extreme weights.
    class_weights = class_weights / class_weights.sum() * num_classes
    class_weights = np.clip(class_weights, 0, 20.0)

    class_weights = torch.FloatTensor(class_weights)

    print(f"\n📊 Class Weights (method={method}):")
    print(f"{'Emotion':<25} {'Count':<10} {'Weight':<10}")
    print("-" * 50)

    for i in np.flatnonzero(present):
        i = int(i)
        name = EMOTION_CLASSES.get(i, str(i))
        print(f"{name:<25} {int(emotion_counts[i]):<10} {class_weights[i]:.3f}")

    return class_weights

def create_balanced_sampler(train_dataset):
    """
    Create a WeightedRandomSampler that draws every present class equally often.

    Each sample is weighted by 1 / (training count of its class), so every class
    with samples has the same total weight. Sampling is with replacement and yields
    ``len(train_dataset)`` indices per epoch.

    Args:
        train_dataset: dataset exposing ``emotion_counts`` (per-class counts) and a
            ``data`` DataFrame with a ``dominant_emotion`` class-id column.

    Returns:
        torch.utils.data.WeightedRandomSampler

    Raises:
        ValueError: if every sample's class has a zero count (the weights would be NaN).
    """
    emotion_counts = np.asarray(train_dataset.emotion_counts, dtype=np.float64)
    labels = train_dataset.data['dominant_emotion'].to_numpy().astype(np.int64)

    # Vectorised lookup; replaces a per-row DataFrame.iloc loop.
    label_counts = emotion_counts[labels]
    sample_weights = np.zeros(len(labels))
    has_count = label_counts > 0
    sample_weights[has_count] = 1.0 / label_counts[has_count]

    total = sample_weights.sum()
    if total <= 0:
        raise ValueError("❌ Cannot build a balanced sampler: no sample has a non-zero class count")
    # Normalize (WeightedRandomSampler does not require it; kept for inspectability).
    sample_weights = sample_weights / total

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )

    print(f"\n✓ Created WeightedRandomSampler for balanced training")

    return sampler
# ============================================
# ENHANCED METRICS WITH TOP-K ACCURACY
# ============================================
class EnhancedMetricsTracker:
    """Accumulate softmax outputs and regression predictions over batches and report
    classification, top-k, probability-distribution and regression metrics.

    Args:
        num_classes: number of classes; used for the default ``class_names``.
        class_names: optional display names (default: ``"0"`` .. ``str(num_classes-1)``).
        top_k_values: k values for which top-k accuracy is computed (default ``[1, 3, 5]``).
    """

    def __init__(self, num_classes=26, class_names=None, top_k_values=None):
        self.num_classes = num_classes
        self.class_names = class_names or [str(i) for i in range(num_classes)]
        # Fresh list per instance: the former list-literal default was shared
        # (and mutable) across every tracker.
        self.top_k_values = [1, 3, 5] if top_k_values is None else list(top_k_values)

        self.all_preds = []
        self.all_labels = []
        self.all_probs = []  # one (batch, num_classes) array per update() call
        self.all_valence_preds = []
        self.all_valence_labels = []
        self.all_arousal_preds = []
        self.all_arousal_labels = []

    def update(self, probs, labels, valence_preds=None, valence_labels=None,
               arousal_preds=None, arousal_labels=None):
        """
        Args:
            probs: Softmax probabilities (N, num_classes)
            labels: Ground truth labels (N,)
            valence_preds/valence_labels, arousal_preds/arousal_labels: optional (N,)
                regression tensors; labels are only read when predictions are given.
        """
        self.all_probs.append(probs.cpu().numpy())

        preds = torch.argmax(probs, dim=1)
        self.all_preds.extend(preds.cpu().numpy())
        self.all_labels.extend(labels.cpu().numpy())

        if valence_preds is not None:
            self.all_valence_preds.extend(valence_preds.cpu().numpy())
            self.all_valence_labels.extend(valence_labels.cpu().numpy())

        if arousal_preds is not None:
            self.all_arousal_preds.extend(arousal_preds.cpu().numpy())
            self.all_arousal_labels.extend(arousal_labels.cpu().numpy())

    def compute(self):
        """Return a dict of all metrics over the accumulated data.

        Raises:
            ValueError: if nothing was recorded, or if any row of the stored
                probabilities does not sum to 1 (atol 1e-5) — e.g. logits were passed.
        """
        if not self.all_probs:
            raise ValueError("❌ No predictions recorded; call update() before compute()")

        preds = np.array(self.all_preds)
        labels = np.array(self.all_labels)
        probs = np.concatenate(self.all_probs, axis=0)  # (N, num_classes)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        prob_sums = probs.sum(axis=1)
        if not np.allclose(prob_sums, 1.0, atol=1e-5):
            raise ValueError(
                f"❌ Probabilities don't sum to 1! Range: [{prob_sums.min():.6f}, {prob_sums.max():.6f}]"
            )

        metrics = {
            'accuracy': accuracy_score(labels, preds),
            'f1_macro': f1_score(labels, preds, average='macro', zero_division=0),
            'f1_weighted': f1_score(labels, preds, average='weighted', zero_division=0),
            'precision_macro': precision_score(labels, preds, average='macro', zero_division=0),
            'recall_macro': recall_score(labels, preds, average='macro', zero_division=0)
        }

        for k in self.top_k_values:
            metrics[f'top_{k}_accuracy'] = self._compute_top_k_accuracy(probs, labels, k)

        metrics.update(self._analyze_probabilities(probs, labels))

        # Regression metrics
        if self.all_valence_preds:
            valence_preds = np.array(self.all_valence_preds)
            valence_labels = np.array(self.all_valence_labels)
            metrics['valence_mae'] = mean_absolute_error(valence_labels, valence_preds)
            metrics['valence_rmse'] = np.sqrt(mean_squared_error(valence_labels, valence_preds))
            metrics['valence_r2'] = r2_score(valence_labels, valence_preds)

        if self.all_arousal_preds:
            arousal_preds = np.array(self.all_arousal_preds)
            arousal_labels = np.array(self.all_arousal_labels)
            metrics['arousal_mae'] = mean_absolute_error(arousal_labels, arousal_preds)
            metrics['arousal_rmse'] = np.sqrt(mean_squared_error(arousal_labels, arousal_preds))
            metrics['arousal_r2'] = r2_score(arousal_labels, arousal_preds)

        return metrics

    def _compute_top_k_accuracy(self, probs, labels, k):
        """Fraction of rows whose true label is among the k highest probabilities.

        Returns NaN for zero rows. Raises ValueError for k < 1; previously k == 0
        sliced every column (``[:, -0:]``) and reported 100% accuracy.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        labels = np.asarray(labels)
        if labels.size == 0:
            return float('nan')
        top_k_preds = np.argsort(probs, axis=1)[:, -k:]  # (N, k)
        hits = (top_k_preds == labels[:, None]).any(axis=1)
        return float(hits.mean())

    def _analyze_probabilities(self, probs, labels):
        """Mean/std of the max probability, of the true-label probability and of the
        per-row entropy (natural log, with a 1e-10 epsilon)."""
        max_probs = probs.max(axis=1)

        # labels must be integer class ids for this fancy indexing.
        true_label_probs = probs[np.arange(len(labels)), labels]

        epsilon = 1e-10
        entropy = -np.sum(probs * np.log(probs + epsilon), axis=1)

        return {
            'mean_max_prob': max_probs.mean(),
            'std_max_prob': max_probs.std(),
            'mean_true_prob': true_label_probs.mean(),
            'std_true_prob': true_label_probs.std(),
            'mean_entropy': entropy.mean(),
            'std_entropy': entropy.std()
        }

    def print_metrics(self, metrics_dict, phase=""):
        """Pretty-print a dict produced by ``compute()``."""
        print(f"\n{'='*70}")
        print(f"{phase} Metrics")
        print(f"{'='*70}")

        # Classification metrics
        print(f"\n📊 Classification Performance:")
        print(f"  Top-1 Accuracy:   {metrics_dict['accuracy']*100:.2f}%")
        for k in self.top_k_values:
            if f'top_{k}_accuracy' in metrics_dict:
                print(f"  Top-{k} Accuracy:   {metrics_dict[f'top_{k}_accuracy']*100:.2f}%")

        print(f"\n  F1 (Macro):       {metrics_dict['f1_macro']:.4f}")
        print(f"  F1 (Weighted):    {metrics_dict['f1_weighted']:.4f}")
        print(f"  Precision:        {metrics_dict['precision_macro']:.4f}")
        print(f"  Recall:           {metrics_dict['recall_macro']:.4f}")

        # Probability analysis
        print(f"\n📈 Probability Analysis:")
        print(f"  Mean Max Prob:    {metrics_dict['mean_max_prob']:.4f} ± {metrics_dict['std_max_prob']:.4f}")
        print(f"  Mean True Prob:   {metrics_dict['mean_true_prob']:.4f} ± {metrics_dict['std_true_prob']:.4f}")
        print(f"  Mean Entropy:     {metrics_dict['mean_entropy']:.4f} ± {metrics_dict['std_entropy']:.4f}")

        # Regression metrics
        if 'valence_mae' in metrics_dict:
            print(f"\n📉 Valence Regression:")
            print(f"  MAE:  {metrics_dict['valence_mae']:.4f}")
            print(f"  RMSE: {metrics_dict['valence_rmse']:.4f}")
            print(f"  R²:   {metrics_dict['valence_r2']:.4f}")

        if 'arousal_mae' in metrics_dict:
            print(f"\n📉 Arousal Regression:")
            print(f"  MAE:  {metrics_dict['arousal_mae']:.4f}")
            print(f"  RMSE: {metrics_dict['arousal_rmse']:.4f}")
            print(f"  R²:   {metrics_dict['arousal_r2']:.4f}")

        print(f"{'='*70}\n")

# ============================================
# PROBABILITY VERIFICATION UTILITY
# ============================================
def verify_probabilities(probs, labels, sample_size=5, class_names=None, top_k=5):
    """
    Detailed probability verification for debugging.

    Prints whether every row of *probs* sums to 1 and, for the first few
    samples, the true label's probability and the top-k predicted classes
    (the true class is marked with ✓).

    Args:
        probs: (N, num_classes) probability tensor.
        labels: (N,) tensor of integer ground-truth class indices.
        sample_size: number of samples to print in detail.
        class_names: sequence/mapping from class index to name; defaults to
            the module-level ``EMOTION_CLASSES``.
        top_k: number of top predictions listed per sample (capped at num_classes).

    Returns:
        bool: True when every row sums to 1.0 within ``atol=1e-5``.
    """
    # Only look up the module-level default when the caller did not supply names.
    names = EMOTION_CLASSES if class_names is None else class_names
    probs_np = probs.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()

    print(f"\n{'='*70}")
    print("🔍 PROBABILITY VERIFICATION")
    print(f"{'='*70}")

    if len(probs_np) == 0:
        # min()/max() below would raise on an empty batch.
        print("\n  (empty batch - nothing to verify)")
        print(f"\n{'='*70}\n")
        return True

    # Check normalization
    prob_sums = probs_np.sum(axis=1)
    all_normalized = bool(np.allclose(prob_sums, 1.0, atol=1e-5))
    print(f"\n✓ Probability Normalization:")
    print(f"  Min sum: {prob_sums.min():.8f}")
    print(f"  Max sum: {prob_sums.max():.8f}")
    print(f"  Mean sum: {prob_sums.mean():.8f}")
    print(f"  All close to 1.0: {all_normalized}")

    # Sample analysis; report how many samples are actually shown.
    n_show = min(sample_size, len(labels_np))
    print(f"\n📋 Sample Analysis (first {n_show} samples):")
    print(f"{'='*70}")

    for i in range(n_show):
        true_label = int(labels_np[i])
        sample_probs = probs_np[i]

        # Highest-probability classes first; slicing after the reverse keeps
        # k == 0 from selecting the whole array.
        k = min(top_k, sample_probs.shape[0])
        top_indices = np.argsort(sample_probs)[::-1][:k]

        print(f"\nSample {i+1}:")
        print(f"  True Label: {true_label} ({names[true_label]})")
        print(f"  True Label Prob: {sample_probs[true_label]:.4f}")
        print(f"  Probability Sum: {sample_probs.sum():.8f}")

        print(f"\n  Top-{k} Predictions:")
        for rank, idx in enumerate(top_indices, 1):
            idx = int(idx)
            marker = "✓" if idx == true_label else " "
            print(f"    {rank}. {marker} [{idx:2d}] {names[idx]:20s}: {sample_probs[idx]:.4f}")

    print(f"\n{'='*70}\n")
    return all_normalized

# ============================================
# ENHANCED TRAINING WITH PROBABILITY ANALYSIS
# ============================================
def _emonet_loss_weights(config):
    """Return the (emotion, valence, arousal) loss weights from *config*, with script defaults."""
    return (
        config.get('emotion_loss_weight', 5.0),
        config.get('valence_loss_weight', 0.5),
        config.get('arousal_loss_weight', 0.5),
    )


def _build_emonet_optimizer(model, config, differential_lr):
    """Build Adam; with *differential_lr* the backbone gets 0.1x the head learning rate."""
    if not differential_lr:
        return optim.Adam(model.parameters(), lr=config['learning_rate'],
                          weight_decay=config['weight_decay'])

    head_keys = ('emotion_head', 'valence_head', 'arousal_head', 'emo_fc_shared')
    head_params, backbone_params = [], []
    for name, param in model.named_parameters():
        if any(key in name for key in head_keys):
            head_params.append(param)
        else:
            backbone_params.append(param)

    # Head group first, so param_groups[0]['lr'] is the head learning rate.
    return optim.Adam([
        {'params': head_params, 'lr': config['learning_rate']},
        {'params': backbone_params, 'lr': config['learning_rate'] * 0.1}
    ], weight_decay=config['weight_decay'])


def _run_emonet_epoch(model, loader, device, emotion_criterion, regression_criterion,
                      loss_weights, top_k_values, desc, optimizer=None,
                      grad_clip_norm=1.0, verify_sample_size=0, verify_header=None):
    """Run one pass over *loader*: a training pass when *optimizer* is given, else evaluation.

    Returns:
        (tracker, metrics_dict, mean_losses): the EnhancedMetricsTracker filled
        during the pass, its ``compute()`` result, and a dict with keys
        'total', 'emotion', 'valence', 'arousal' holding per-batch mean losses.
    """
    is_train = optimizer is not None
    model.train(is_train)  # model.train(False) is the same as model.eval()

    tracker = EnhancedMetricsTracker(
        num_classes=26,
        class_names=EMOTION_COLUMNS,
        top_k_values=top_k_values
    )
    w_emotion, w_valence, w_arousal = loss_weights
    loss_sums = {'total': 0.0, 'emotion': 0.0, 'valence': 0.0, 'arousal': 0.0}
    n_batches = 0

    # Autograd is only needed for the training pass.
    with torch.set_grad_enabled(is_train):
        pbar = tqdm(loader, desc=desc)
        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(device)
            emotion_labels = batch['emotion'].to(device)
            valence_labels = batch['valence'].to(device)
            arousal_labels = batch['arousal'].to(device)

            if is_train:
                optimizer.zero_grad()
            outputs = model(images)

            # Get probabilities (softmax)
            logits = outputs['expression']
            probs = F.softmax(logits, dim=1)

            # Optional sanity print of the softmax output on the first batch
            if verify_sample_size and batch_idx == 0:
                if verify_header:
                    print(verify_header)
                verify_probabilities(probs, emotion_labels, sample_size=verify_sample_size)

            e_loss = emotion_criterion(logits, emotion_labels)
            v_loss = regression_criterion(outputs['valence'], valence_labels)
            a_loss = regression_criterion(outputs['arousal'], arousal_labels)
            total_loss = w_emotion * e_loss + w_valence * v_loss + w_arousal * a_loss

            if is_train:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_norm)
                optimizer.step()

            loss_sums['total'] += total_loss.item()
            loss_sums['emotion'] += e_loss.item()
            loss_sums['valence'] += v_loss.item()
            loss_sums['arousal'] += a_loss.item()
            n_batches += 1

            tracker.update(
                probs.detach(), emotion_labels.detach(),
                outputs['valence'].detach(), valence_labels.detach(),
                outputs['arousal'].detach(), arousal_labels.detach()
            )

            if is_train:
                pbar.set_postfix({
                    'Loss': f'{total_loss.item():.4f}',
                    'E': f'{e_loss.item():.3f}',
                    'V': f'{v_loss.item():.3f}',
                    'A': f'{a_loss.item():.3f}'
                })
            else:
                pbar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

    metrics_dict = tracker.compute()
    # max(..., 1) guards against ZeroDivisionError on an empty loader.
    mean_losses = {name: total / max(n_batches, 1) for name, total in loss_sums.items()}
    return tracker, metrics_dict, mean_losses


def train_emonet_with_probability_analysis(config):
    """
    Training with comprehensive probability verification and top-k metrics.

    Required config keys: 'save_path', 'train_csv', 'val_csv', 'img_arrays_dir',
    'image_size', 'batch_size', 'num_workers', 'learning_rate', 'weight_decay',
    'num_epochs'.

    Optional config keys and their defaults:
        'max_samples' (None), 'weight_method' ('effective_number'),
        'weight_beta' (0.9999), 'pretrained_path' (None), 'freeze_backbone' (False),
        'top_k_values' ([1, 3, 5]), 'emotion_loss_weight' (5.0),
        'valence_loss_weight' (0.5), 'arousal_loss_weight' (0.5),
        'use_class_weights' (True), 'use_balanced_sampler' (True),
        'early_stopping_patience' (15), 'grad_clip_norm' (1.0).

    Side effects:
        Writes config.json, best_model.pth, checkpoint_epoch_*.pth and plots/*.png
        under config['save_path'].

    Returns:
        (model, history): the model with its LAST-epoch weights (the selected
        best weights are in best_model.pth) and the dict of per-epoch metrics.
    """
    save_dir = Path(config['save_path'])
    save_dir.mkdir(parents=True, exist_ok=True)
    plots_dir = save_dir / 'plots'
    plots_dir.mkdir(exist_ok=True)

    # Save config
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*70}")
    print(f"Training Configuration")
    print(f"{'='*70}")
    for key, value in config.items():
        print(f"  {key:20s}: {value}")
    print(f"{'='*70}\n")

    # Data augmentation
    # NOTE(review): these transforms receive whatever EMOTICDataset produces (not
    # visible here). torchvision's ColorJitter clamps float tensors to [0, 1]; if the
    # dataset normalizes (or keeps 0-255 floats) before calling the transform, train
    # inputs get clamped while val inputs do not -- confirm the dataset's value range.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
        transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.1)),
    ])

    # Datasets
    print("Loading datasets...")
    train_dataset = EMOTICDataset(
        csv_file=config['train_csv'],
        img_arrays_dir=config['img_arrays_dir'],
        image_size=config['image_size'],
        transform=train_transform,
        max_samples=config.get('max_samples')
    )

    val_dataset = EMOTICDataset(
        csv_file=config['val_csv'],
        img_arrays_dir=config['img_arrays_dir'],
        image_size=config['image_size'],
        transform=None,
        max_samples=config.get('max_samples')
    )

    # Class weights and balanced sampling
    print(f"\n{'='*70}")
    print("HANDLING CLASS IMBALANCE")
    print(f"{'='*70}")

    use_class_weights = config.get('use_class_weights', True)
    use_balanced_sampler = config.get('use_balanced_sampler', True)
    if use_class_weights and use_balanced_sampler:
        # Both mechanisms up-weight rare classes; combined they may over-correct.
        print("⚠ Using BOTH a class-weighted loss and a balanced sampler "
              "(double imbalance correction). Set 'use_class_weights' or "
              "'use_balanced_sampler' to False to keep only one.")

    class_weights = None
    if use_class_weights:
        class_weights = compute_class_weights(
            train_dataset,
            method=config.get('weight_method', 'effective_number'),
            beta=config.get('weight_beta', 0.9999)
        )

    # DataLoader rejects sampler= together with shuffle=True, so pass exactly one.
    if use_balanced_sampler:
        sampling_kwargs = {'sampler': create_balanced_sampler(train_dataset)}
    else:
        sampling_kwargs = {'shuffle': True}

    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
        **sampling_kwargs
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    # Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*70}")
    print(f"MODEL SETUP")
    print(f"{'='*70}")
    print(f"Using device: {device}")

    model = EmoNetSingleLabel26(n_expression=26, n_reg=2, attention=True).to(device)

    # Load pretrained weights. Track whether they were actually loaded: the
    # differential learning rate below only makes sense for a pretrained backbone.
    pretrained_path = config.get('pretrained_path')
    pretrained_loaded = bool(pretrained_path) and os.path.exists(pretrained_path)
    if pretrained_loaded:
        model.load_pretrained_emonet(
            pretrained_path,
            freeze_backbone=config.get('freeze_backbone', False)
        )
    elif pretrained_path:
        print(f"⚠ pretrained_path not found, training from scratch: {pretrained_path}")

    # Loss (weight=None means unweighted cross-entropy)
    if class_weights is not None:
        class_weights = class_weights.to(device)
    emotion_criterion = nn.CrossEntropyLoss(weight=class_weights)
    regression_criterion = nn.MSELoss()

    optimizer = _build_emonet_optimizer(model, config, differential_lr=pretrained_loaded)

    # verbose= is deprecated in recent PyTorch; LR reductions are printed manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-7
    )

    top_k_values = config.get('top_k_values', [1, 3, 5])
    # Secondary model-selection metric: top-5 when tracked, otherwise the largest k.
    monitor_k = 5 if 5 in top_k_values else max(top_k_values, default=1)
    monitor_key = f'top_{monitor_k}_accuracy'

    # Training history (includes top-k accuracies)
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': [],
        'train_emotion_loss': [], 'val_emotion_loss': [],
        'train_valence_loss': [], 'val_valence_loss': [],
        'train_arousal_loss': [], 'val_arousal_loss': [],
        'train_valence_mae': [], 'val_valence_mae': [],
        'train_arousal_mae': [], 'val_arousal_mae': [],
        'learning_rate': [],
        'train_mean_max_prob': [], 'val_mean_max_prob': [],
        'train_mean_true_prob': [], 'val_mean_true_prob': [],
    }
    for k in top_k_values:
        history[f'train_top_{k}_acc'] = []
        history[f'val_top_{k}_acc'] = []

    best_val_f1 = 0.0
    best_val_topk = 0.0
    best_epoch = 0
    patience_counter = 0
    max_patience = config.get('early_stopping_patience', 15)
    loss_weights = _emonet_loss_weights(config)
    grad_clip_norm = config.get('grad_clip_norm', 1.0)

    print(f"\n{'='*70}")
    print("STARTING TRAINING")
    print(f"{'='*70}\n")

    for epoch in range(config['num_epochs']):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{config['num_epochs']}")
        print(f"{'='*70}")

        first_epoch = epoch == 0

        # ========== TRAINING ==========
        train_metrics, train_metrics_dict, train_losses = _run_emonet_epoch(
            model, train_loader, device, emotion_criterion, regression_criterion,
            loss_weights, top_k_values, desc=f"Training Epoch {epoch+1}",
            optimizer=optimizer, grad_clip_norm=grad_clip_norm,
            verify_sample_size=5 if first_epoch else 0,
        )

        # ========== VALIDATION ==========
        val_metrics, val_metrics_dict, val_losses = _run_emonet_epoch(
            model, val_loader, device, emotion_criterion, regression_criterion,
            loss_weights, top_k_values, desc=f"Validation Epoch {epoch+1}",
            verify_sample_size=3 if first_epoch else 0,
            verify_header="\n🔍 VALIDATION SET PROBABILITIES:",
        )

        train_loss = train_losses['total']
        val_loss = val_losses['total']

        # Update history
        for split, losses, metrics in (('train', train_losses, train_metrics_dict),
                                       ('val', val_losses, val_metrics_dict)):
            history[f'{split}_loss'].append(losses['total'])
            history[f'{split}_emotion_loss'].append(losses['emotion'])
            history[f'{split}_valence_loss'].append(losses['valence'])
            history[f'{split}_arousal_loss'].append(losses['arousal'])
            history[f'{split}_acc'].append(metrics['accuracy'] * 100)
            history[f'{split}_f1'].append(metrics['f1_macro'])
            history[f'{split}_valence_mae'].append(metrics['valence_mae'])
            history[f'{split}_arousal_mae'].append(metrics['arousal_mae'])
            history[f'{split}_mean_max_prob'].append(metrics['mean_max_prob'])
            history[f'{split}_mean_true_prob'].append(metrics['mean_true_prob'])
            for k in top_k_values:
                history[f'{split}_top_{k}_acc'].append(metrics[f'top_{k}_accuracy'] * 100)
        history['learning_rate'].append(optimizer.param_groups[0]['lr'])

        # Print summary
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"  Train - Loss: {train_loss:.4f} | Top-1: {train_metrics_dict['accuracy']*100:.2f}% | F1: {train_metrics_dict['f1_macro']:.4f}")
        print(f"          Top-{monitor_k}: {train_metrics_dict.get(monitor_key, 0.0)*100:.2f}% | Max Prob: {train_metrics_dict['mean_max_prob']:.3f}")
        print(f"  Val   - Loss: {val_loss:.4f} | Top-1: {val_metrics_dict['accuracy']*100:.2f}% | F1: {val_metrics_dict['f1_macro']:.4f}")
        print(f"          Top-{monitor_k}: {val_metrics_dict.get(monitor_key, 0.0)*100:.2f}% | Max Prob: {val_metrics_dict['mean_max_prob']:.3f}")

        # Detailed metrics every 5 epochs
        if (epoch + 1) % 5 == 0:
            train_metrics.print_metrics(train_metrics_dict, "Training")
            val_metrics.print_metrics(val_metrics_dict, "Validation")

        # Save best model when EITHER macro-F1 (primary) or top-k (secondary) improves
        val_f1 = val_metrics_dict['f1_macro']
        val_topk = val_metrics_dict.get(monitor_key, 0.0)
        improvement = False
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            improvement = True
        if val_topk > best_val_topk:
            best_val_topk = val_topk
            improvement = True

        if improvement:
            best_epoch = epoch + 1
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'history': history,
                'best_val_f1': best_val_f1,
                f'best_val_top_{monitor_k}': best_val_topk,
                'val_metrics': val_metrics_dict,
                'class_weights': None if class_weights is None else class_weights.cpu()
            }, save_dir / 'best_model.pth')

            # Report THIS epoch's metrics: the saved weights are the current ones,
            # while best_val_f1 / best_val_topk may come from different epochs.
            print(f"\n  ✓ Saved best model (F1: {val_f1:.4f} | Top-{monitor_k}: {val_topk*100:.2f}%)")
        else:
            patience_counter += 1
            print(f"\n  ⚠ No improvement ({patience_counter}/{max_patience})")
            print(f"     Best F1: {best_val_f1:.4f} | Best Top-{monitor_k}: {best_val_topk*100:.2f}%")

        # Checkpoint
        if (epoch + 1) % 2 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history
            }, save_dir / f'checkpoint_epoch_{epoch+1}.pth')

        # Update scheduler and report any LR reduction it made
        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_f1)
        for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"  ↓ Reduced LR of group {group_idx}: {old_lr:.2e} -> {group['lr']:.2e}")

        # Plot progress
        if (epoch + 1) % 5 == 0:
            plot_training_progress_enhanced(history, plots_dir / f'progress_epoch_{epoch+1}.png', top_k_values)

        # Early stopping: previously the counter was tracked and printed but never enforced.
        if patience_counter >= max_patience:
            print(f"\n  ⏹ Early stopping: no improvement for {max_patience} epochs")
            break

    # Always leave a plot of the full run, even if it ended between 5-epoch marks.
    if history['train_loss']:
        plot_training_progress_enhanced(history, plots_dir / 'progress_final.png', top_k_values)

    print(f"\n{'='*70}")
    print(f"Training Complete!")
    print(f"Best Epoch: {best_epoch}")
    print(f"Best Val F1: {best_val_f1:.4f}")
    print(f"Best Val Top-{monitor_k}: {best_val_topk*100:.2f}%")
    print(f"{'='*70}\n")

    return model, history

# ============================================
# ENHANCED PLOTTING
# ============================================
def plot_training_progress_enhanced(history, save_path, top_k_values):
    """
    Plot with top-k accuracies and probability analysis.

    Draws a 3x4 grid with these panels and saves it to *save_path*:
    - losses: total, emotion, valence, arousal;
    - top-1 and top-k validation accuracy and macro F1;
    - mean max and true-label probability;
    - learning rate and valence/arousal MAE.

    Args:
        history: dict of per-epoch lists keyed as built by the training loop
            ('train_loss', 'val_loss', ..., 'val_top_{k}_acc', 'learning_rate').
        save_path: destination file passed to ``Figure.savefig``.
        top_k_values: k values whose 'val_top_{k}_acc' curves are drawn.
    """
    def _epochs(values):
        # Training logs number epochs from 1; use the same numbering on the x-axis.
        return range(1, len(values) + 1)

    def _finish(ax, title, ylabel, legend=True, log_y=False):
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        if log_y:
            ax.set_yscale('log')
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # (grid row, grid col, history key suffix, title, y-label) for each train/val panel
    paired_panels = [
        (0, 0, 'loss', 'Total Loss', 'Loss'),
        (0, 1, 'acc', 'Top-1 Accuracy', 'Accuracy (%)'),
        (0, 3, 'f1', 'F1 Score (Macro)', 'F1 Score'),
        (1, 0, 'mean_max_prob', 'Mean Max Probability', 'Probability'),
        (1, 1, 'mean_true_prob', 'Mean True Label Probability', 'Probability'),
        (1, 2, 'emotion_loss', 'Emotion Loss', 'Loss'),
        (2, 0, 'valence_mae', 'Valence MAE', 'MAE'),
        (2, 1, 'arousal_mae', 'Arousal MAE', 'MAE'),
        (2, 2, 'valence_loss', 'Valence Loss', 'Loss'),
        (2, 3, 'arousal_loss', 'Arousal Loss', 'Loss'),
    ]

    fig = plt.figure(figsize=(20, 12))
    try:
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)
        fig.suptitle('Training Progress - Enhanced Metrics', fontsize=16, fontweight='bold')

        for row, col, key, title, ylabel in paired_panels:
            ax = fig.add_subplot(gs[row, col])
            train_values = history[f'train_{key}']
            val_values = history[f'val_{key}']
            ax.plot(_epochs(train_values), train_values, label='Train', linewidth=2)
            ax.plot(_epochs(val_values), val_values, label='Val', linewidth=2)
            _finish(ax, title, ylabel)

        # Top-K validation accuracies
        ax_topk = fig.add_subplot(gs[0, 2])
        colors = ['blue', 'green', 'orange']
        for i, k in enumerate(top_k_values):
            key = f'val_top_{k}_acc'
            if key in history:
                ax_topk.plot(_epochs(history[key]), history[key],
                             label=f'Top-{k}', linewidth=2, color=colors[i % len(colors)])
        _finish(ax_topk, 'Top-K Validation Accuracy', 'Accuracy (%)')

        # Learning rate (log scale, single curve so no legend)
        ax_lr = fig.add_subplot(gs[1, 3])
        ax_lr.plot(_epochs(history['learning_rate']), history['learning_rate'],
                   linewidth=2, color='red')
        _finish(ax_lr, 'Learning Rate', 'LR', legend=False, log_y=True)

        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure (even if saving fails) so repeated calls
        # during training do not accumulate open figures.
        plt.close(fig)


In [None]:
   config = {
        'train_csv': '/kaggle/input/emotic/normalized_balanced_limited_dominant_emotion.csv',
    'val_csv': '/kaggle/input/emotic/normalized_dataset_val_dominant_emotion.csv',
    'img_arrays_dir': '/kaggle/input/emotic/archive_emot 2/archive_emot/img_arrs',
    'save_path': '/kaggle/working/',
       'pretrained_path': '/kaggle/input/emotic/emonet_8.pth',
    'freeze_backbone': False,
    'image_size': 256,
    
    
        
        'num_epochs': 100,
        'batch_size': 64,
        'learning_rate': 1e-4,
        'weight_decay': 1e-4,
        'image_size': 256,
        'num_workers': 4,
        
        # Loss weights
        'emotion_loss_weight': 2.0,
        'valence_loss_weight': 0.5,
        'arousal_loss_weight': 0.5,
        
        # Class weighting
        'weight_method': 'effective_number',  # or 'inverse_freq', 'sqrt_inv_freq'
        'weight_beta': 0.9999,
        
        # ✅ Top-K metrics
        'top_k_values': [1, 3, 5],
        
        'max_samples': None  # for debugging
    }

model, history = train_emonet_with_probability_analysis(config)



Training Configuration
  train_csv           : /kaggle/input/emotic/normalized_balanced_limited_dominant_emotion.csv
  val_csv             : /kaggle/input/emotic/normalized_dataset_val_dominant_emotion.csv
  img_arrays_dir      : /kaggle/input/emotic/archive_emot 2/archive_emot/img_arrs
  save_path           : /kaggle/working/
  pretrained_path     : /kaggle/input/emotic/emonet_8.pth
  freeze_backbone     : False
  image_size          : 256
  num_epochs          : 100
  batch_size          : 64
  learning_rate       : 0.0001
  weight_decay        : 0.0001
  num_workers         : 4
  emotion_loss_weight : 2.0
  valence_loss_weight : 0.5
  arousal_loss_weight : 0.5
  weight_method       : effective_number
  weight_beta         : 0.9999
  top_k_values        : [1, 3, 5]
  max_samples         : None

Loading datasets...
Extracting dominant_emotion from one-hot encoded columns...

✓ Dataset loaded: 3862 samples

📊 Dominant Emotion Distribution:
   0. Peace               :   441 samples ( 1

Training Epoch 1:   0%|          | 0/60 [00:00<?, ?it/s]


🔍 PROBABILITY VERIFICATION

✓ Probability Normalization:
  Min sum: 0.99999982
  Max sum: 1.00000012
  Mean sum: 1.00000000
  All close to 1.0: True

📋 Sample Analysis (first 5 samples):

Sample 1:
  True Label: 21 (Sadness)
  True Label Prob: 0.1000
  Probability Sum: 0.99999994

  Top-5 Predictions:
    1.   [ 7] Pleasure            : 0.1502
    2.   [ 4] Engagement          : 0.1051
    3. ✓ [21] Sadness             : 0.1000
    4.   [ 0] Peace               : 0.0923
    5.   [19] Anger               : 0.0870

Sample 2:
  True Label: 8 (Excitement)
  True Label Prob: 0.0190
  Probability Sum: 1.00000012

  Top-5 Predictions:
    1.   [ 3] Anticipation        : 0.1430
    2.   [19] Anger               : 0.0851
    3.   [20] Sensitivity         : 0.0807
    4.   [24] Pain                : 0.0557
    5.   [ 1] Affection           : 0.0545

Sample 3:
  True Label: 18 (Annoyance)
  True Label Prob: 0.0393
  Probability Sum: 1.00000000

  Top-5 Predictions:
    1.   [24] Pain            

Training Epoch 1: 100%|██████████| 60/60 [00:58<00:00,  1.03it/s, Loss=6.6414, E=3.199, V=0.201, A=0.286]
Validation Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]


🔍 VALIDATION SET PROBABILITIES:


Validation Epoch 1:   7%|▋         | 1/15 [00:01<00:16,  1.20s/it, Loss=6.6510]


🔍 PROBABILITY VERIFICATION

✓ Probability Normalization:
  Min sum: 0.99999988
  Max sum: 1.00000012
  Mean sum: 1.00000000
  All close to 1.0: True

📋 Sample Analysis (first 3 samples):

Sample 1:
  True Label: 4 (Engagement)
  True Label Prob: 0.0157
  Probability Sum: 1.00000000

  Top-5 Predictions:
    1.   [24] Pain                : 0.0889
    2.   [11] Doubt/Confusion     : 0.0685
    3.   [25] Suffering           : 0.0612
    4.   [20] Sensitivity         : 0.0555
    5.   [ 3] Anticipation        : 0.0555

Sample 2:
  True Label: 3 (Anticipation)
  True Label Prob: 0.0291
  Probability Sum: 1.00000012

  Top-5 Predictions:
    1.   [ 2] Esteem              : 0.0621
    2.   [17] Aversion            : 0.0558
    3.   [24] Pain                : 0.0532
    4.   [22] Disquietment        : 0.0509
    5.   [11] Doubt/Confusion     : 0.0493

Sample 3:
  True Label: 1 (Affection)
  True Label Prob: 0.0249
  Probability Sum: 0.99999994

  Top-5 Predictions:
    1.   [25] Suffering    

Validation Epoch 1: 100%|██████████| 15/15 [00:05<00:00,  2.79it/s, Loss=6.5957]



📊 Epoch 1 Summary:
  Train - Loss: 6.8717 | Top-1: 4.32% | F1: 0.0305
          Top-5: 20.76% | Max Prob: 0.089
  Val   - Loss: 6.7733 | Top-1: 5.22% | F1: 0.0124
          Top-5: 26.51% | Max Prob: 0.081

  ✓ Saved best model (F1: 0.0124 | Top-5: 26.51%)

Epoch 2/100


Training Epoch 2: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=5.2369, E=2.483, V=0.228, A=0.313]
Validation Epoch 2: 100%|██████████| 15/15 [00:05<00:00,  3.00it/s, Loss=6.8329]



📊 Epoch 2 Summary:
  Train - Loss: 6.1820 | Top-1: 7.53% | F1: 0.0470
          Top-5: 24.56% | Max Prob: 0.099
  Val   - Loss: 7.0970 | Top-1: 0.42% | F1: 0.0009
          Top-5: 15.66% | Max Prob: 0.108

  ⚠ No improvement (1/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 3/100


Training Epoch 3: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=6.0674, E=2.949, V=0.194, A=0.146]
Validation Epoch 3: 100%|██████████| 15/15 [00:05<00:00,  3.00it/s, Loss=7.0414]



📊 Epoch 3 Summary:
  Train - Loss: 5.7740 | Top-1: 9.84% | F1: 0.0485
          Top-5: 28.49% | Max Prob: 0.125
  Val   - Loss: 7.3911 | Top-1: 0.00% | F1: 0.0000
          Top-5: 7.31% | Max Prob: 0.139

  ⚠ No improvement (2/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 4/100


Training Epoch 4: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=5.4383, E=2.595, V=0.205, A=0.290]
Validation Epoch 4: 100%|██████████| 15/15 [00:04<00:00,  3.02it/s, Loss=7.3470]



📊 Epoch 4 Summary:
  Train - Loss: 5.4543 | Top-1: 10.89% | F1: 0.0493
          Top-5: 28.52% | Max Prob: 0.152
  Val   - Loss: 7.7263 | Top-1: 0.00% | F1: 0.0000
          Top-5: 2.51% | Max Prob: 0.173

  ⚠ No improvement (3/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 5/100


Training Epoch 5: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=4.8536, E=2.304, V=0.239, A=0.252]
Validation Epoch 5: 100%|██████████| 15/15 [00:05<00:00,  3.00it/s, Loss=7.5909]



📊 Epoch 5 Summary:
  Train - Loss: 5.1214 | Top-1: 12.73% | F1: 0.0540
          Top-5: 29.51% | Max Prob: 0.178
  Val   - Loss: 7.9928 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.94% | Max Prob: 0.193

Training Metrics

📊 Classification Performance:
  Top-1 Accuracy:   12.73%
  Top-1 Accuracy:   12.73%
  Top-3 Accuracy:   22.06%
  Top-5 Accuracy:   29.51%

  F1 (Macro):       0.0540
  F1 (Weighted):    0.0558
  Precision:        0.0830
  Recall:           0.1262

📈 Probability Analysis:
  Mean Max Prob:    0.1779 ± 0.1063
  Mean True Prob:   0.0634 ± 0.0922
  Mean Entropy:     2.9286 ± 0.3070

📉 Valence Regression:
  MAE:  0.3549
  RMSE: 0.4500
  R²:   -0.1668

📉 Arousal Regression:
  MAE:  0.3929
  RMSE: 0.4890
  R²:   -0.2677


Validation Metrics

📊 Classification Performance:
  Top-1 Accuracy:   0.00%
  Top-1 Accuracy:   0.00%
  Top-3 Accuracy:   0.21%
  Top-5 Accuracy:   0.94%

  F1 (Macro):       0.0000
  F1 (Weighted):    0.0000
  Precision:        0.0000
  Recall:          

Training Epoch 6: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=5.8683, E=2.835, V=0.169, A=0.228]
Validation Epoch 6: 100%|██████████| 15/15 [00:04<00:00,  3.02it/s, Loss=7.9861]



📊 Epoch 6 Summary:
  Train - Loss: 4.8910 | Top-1: 12.45% | F1: 0.0512
          Top-5: 29.35% | Max Prob: 0.206
  Val   - Loss: 8.3615 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.31% | Max Prob: 0.238

  ⚠ No improvement (5/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 7/100


Training Epoch 7: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=3.9352, E=1.855, V=0.198, A=0.253]
Validation Epoch 7: 100%|██████████| 15/15 [00:05<00:00,  2.98it/s, Loss=8.1228]



📊 Epoch 7 Summary:
  Train - Loss: 4.7378 | Top-1: 13.57% | F1: 0.0573
          Top-5: 30.13% | Max Prob: 0.230
  Val   - Loss: 8.5012 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.21% | Max Prob: 0.237

  ⚠ No improvement (6/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 8/100


Training Epoch 8: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=4.9817, E=2.383, V=0.227, A=0.205]
Validation Epoch 8: 100%|██████████| 15/15 [00:05<00:00,  2.99it/s, Loss=8.2963]



📊 Epoch 8 Summary:
  Train - Loss: 4.6865 | Top-1: 13.28% | F1: 0.0550
          Top-5: 29.14% | Max Prob: 0.244
  Val   - Loss: 8.7166 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.21% | Max Prob: 0.258

  ⚠ No improvement (7/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 9/100


Training Epoch 9: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=4.1729, E=1.967, V=0.197, A=0.280]
Validation Epoch 9: 100%|██████████| 15/15 [00:04<00:00,  3.05it/s, Loss=8.3119]



📊 Epoch 9 Summary:
  Train - Loss: 4.4477 | Top-1: 14.45% | F1: 0.0577
          Top-5: 31.41% | Max Prob: 0.254
  Val   - Loss: 8.7408 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.21% | Max Prob: 0.258

  ⚠ No improvement (8/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 10/100


Training Epoch 10: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=4.8531, E=2.330, V=0.200, A=0.188]
Validation Epoch 10: 100%|██████████| 15/15 [00:04<00:00,  3.04it/s, Loss=8.4026]



📊 Epoch 10 Summary:
  Train - Loss: 4.5468 | Top-1: 13.96% | F1: 0.0593
          Top-5: 29.90% | Max Prob: 0.262
  Val   - Loss: 8.8624 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.21% | Max Prob: 0.261

Training Metrics

📊 Classification Performance:
  Top-1 Accuracy:   13.96%
  Top-1 Accuracy:   13.96%
  Top-3 Accuracy:   22.19%
  Top-5 Accuracy:   29.90%

  F1 (Macro):       0.0593
  F1 (Weighted):    0.0591
  Precision:        0.0943
  Recall:           0.1431

📈 Probability Analysis:
  Mean Max Prob:    0.2619 ± 0.1595
  Mean True Prob:   0.0837 ± 0.1499
  Mean Entropy:     2.6677 ± 0.4636

📉 Valence Regression:
  MAE:  0.3398
  RMSE: 0.4308
  R²:   -0.0940

📉 Arousal Regression:
  MAE:  0.3868
  RMSE: 0.4798
  R²:   -0.2094


Validation Metrics

📊 Classification Performance:
  Top-1 Accuracy:   0.00%
  Top-1 Accuracy:   0.00%
  Top-3 Accuracy:   0.00%
  Top-5 Accuracy:   0.21%

  F1 (Macro):       0.0000
  F1 (Weighted):    0.0000
  Precision:        0.0000
  Recall:         

Training Epoch 11: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=4.2142, E=2.008, V=0.220, A=0.174]
Validation Epoch 11: 100%|██████████| 15/15 [00:04<00:00,  3.06it/s, Loss=8.4906]



📊 Epoch 11 Summary:
  Train - Loss: 4.3649 | Top-1: 15.23% | F1: 0.0595
          Top-5: 31.28% | Max Prob: 0.267
  Val   - Loss: 8.9798 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.21% | Max Prob: 0.267

  ⚠ No improvement (10/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 12/100


Training Epoch 12: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=3.7334, E=1.772, V=0.139, A=0.239]
Validation Epoch 12: 100%|██████████| 15/15 [00:04<00:00,  3.01it/s, Loss=8.4551]



📊 Epoch 12 Summary:
  Train - Loss: 4.3660 | Top-1: 14.90% | F1: 0.0586
          Top-5: 32.19% | Max Prob: 0.272
  Val   - Loss: 8.9107 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.10% | Max Prob: 0.254

  ⚠ No improvement (11/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 13/100


Training Epoch 13: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s, Loss=3.7092, E=1.758, V=0.198, A=0.186]
Validation Epoch 13: 100%|██████████| 15/15 [00:04<00:00,  3.03it/s, Loss=8.5706]



📊 Epoch 13 Summary:
  Train - Loss: 4.2699 | Top-1: 15.21% | F1: 0.0652
          Top-5: 31.95% | Max Prob: 0.278
  Val   - Loss: 9.0322 | Top-1: 0.00% | F1: 0.0000
          Top-5: 0.10% | Max Prob: 0.262

  ⚠ No improvement (12/15)
     Best F1: 0.0124 | Best Top-5: 26.51%

Epoch 14/100


Training Epoch 14:  75%|███████▌  | 45/60 [00:42<00:13,  1.13it/s, Loss=3.6797, E=1.754, V=0.184, A=0.162]

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, mean_squared_error
from tqdm import tqdm

class SimpleInference:
    """Run a trained emotion model over a dataloader and report metrics.

    The wrapped model is called as ``model(images)`` and must return a dict
    with the keys ``'expression'`` (per-class emotion logits, class axis = 1),
    ``'valence'`` and ``'arousal'``. Each batch yielded by the dataloader must
    be a dict with ``'image'``, ``'emotion'`` (one-hot / multi-hot targets,
    class axis = 1), ``'valence'`` and ``'arousal'`` tensors.
    """

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: ``torch.nn.Module`` to evaluate; it is moved to ``device``
                and switched to eval mode.
            device: Device the model and input images are placed on.
        """
        self.model = model.to(device)
        self.model.eval()
        self.device = device

    @torch.no_grad()
    def evaluate(self, dataloader):
        """
        Run inference over ``dataloader`` and compute metrics.

        The predicted emotion is the argmax over the ``'expression'`` logits;
        the true emotion is the argmax over the batch's ``'emotion'`` targets
        (for multi-hot targets this is the first active class).

        Returns:
            dict with keys ``emotion_accuracy``, ``emotion_f1_macro``,
            ``emotion_f1_weighted``, ``valence_mae``, ``valence_rmse``,
            ``arousal_mae`` and ``arousal_rmse``.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        # Re-assert eval mode on every call: the model object may have been
        # put back into train() since construction, which would change
        # BatchNorm/dropout behaviour and silently skew the metrics.
        self.model.eval()

        emotion_pred_chunks, emotion_true_chunks = [], []
        valence_pred_chunks, valence_true_chunks = [], []
        arousal_pred_chunks, arousal_true_chunks = [], []
        num_classes = None

        print("Running inference...")
        for batch in tqdm(dataloader):
            images = batch['image'].to(self.device)
            outputs = self.model(images)

            emotion_logits = outputs['expression']
            num_classes = emotion_logits.shape[1]

            # argmax(softmax(z)) == argmax(z), so the softmax is unnecessary.
            emotion_pred_chunks.append(
                torch.argmax(emotion_logits, dim=1).cpu().numpy())
            # Targets are one-hot encoded; argmax recovers the class index.
            emotion_true_chunks.append(
                torch.argmax(batch['emotion'], dim=1).cpu().numpy())

            # Flatten so that (B,) and (B, 1) heads/targets are treated alike.
            valence_pred_chunks.append(
                outputs['valence'].reshape(-1).cpu().numpy())
            valence_true_chunks.append(
                batch['valence'].reshape(-1).cpu().numpy())
            arousal_pred_chunks.append(
                outputs['arousal'].reshape(-1).cpu().numpy())
            arousal_true_chunks.append(
                batch['arousal'].reshape(-1).cpu().numpy())

        if not emotion_pred_chunks:
            raise ValueError("dataloader yielded no batches; nothing to evaluate")

        emotion_preds = np.concatenate(emotion_pred_chunks)
        emotion_labels = np.concatenate(emotion_true_chunks)
        valence_preds = np.concatenate(valence_pred_chunks)
        valence_labels = np.concatenate(valence_true_chunks)
        arousal_preds = np.concatenate(arousal_pred_chunks)
        arousal_labels = np.concatenate(arousal_true_chunks)

        print("\nComputing metrics...")

        # Emotion classification metrics
        accuracy = accuracy_score(emotion_labels, emotion_preds)
        f1_macro = f1_score(emotion_labels, emotion_preds,
                            average='macro', zero_division=0)
        f1_weighted = f1_score(emotion_labels, emotion_preds,
                               average='weighted', zero_division=0)

        # Valence / arousal regression metrics
        valence_mae = mean_absolute_error(valence_labels, valence_preds)
        valence_rmse = np.sqrt(mean_squared_error(valence_labels, valence_preds))
        arousal_mae = mean_absolute_error(arousal_labels, arousal_preds)
        arousal_rmse = np.sqrt(mean_squared_error(arousal_labels, arousal_preds))

        metrics = {
            'emotion_accuracy': accuracy,
            'emotion_f1_macro': f1_macro,
            'emotion_f1_weighted': f1_weighted,
            'valence_mae': valence_mae,
            'valence_rmse': valence_rmse,
            'arousal_mae': arousal_mae,
            'arousal_rmse': arousal_rmse
        }

        print(f"\n{'='*60}")
        print(f"{'EVALUATION RESULTS':^60}")
        print(f"{'='*60}")
        # Report the class count the model actually produced instead of a
        # hard-coded constant.
        print(f"\nEmotion Classification ({num_classes} classes):")
        print(f"  Accuracy:      {accuracy*100:6.2f}%")
        print(f"  F1 (Macro):    {f1_macro:6.4f}")
        print(f"  F1 (Weighted): {f1_weighted:6.4f}")
        print("\nValence Regression:")
        print(f"  MAE:           {valence_mae:6.4f}")
        print(f"  RMSE:          {valence_rmse:6.4f}")
        print("\nArousal Regression:")
        print(f"  MAE:           {arousal_mae:6.4f}")
        print(f"  RMSE:          {arousal_rmse:6.4f}")
        print(f"{'='*60}\n")

        return metrics

# ============================================
# USAGE
# ============================================

def run_inference(model_path, test_csv, img_arrays_dir,
                  batch_size=64, num_workers=4, image_size=256):
    """
    Load a checkpoint and evaluate it on a test CSV.

    Args:
        model_path: Path to the model checkpoint. Either a dict containing a
            ``'model_state_dict'`` entry or a bare ``state_dict``.
        test_csv: Path to the test CSV consumed by ``EMOTICDataset``.
        img_arrays_dir: Directory with the pre-extracted image arrays.
        batch_size: Batch size for the test ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.
        image_size: Image size passed to ``EMOTICDataset``.

    Returns:
        dict with all metrics, as returned by ``SimpleInference.evaluate``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    print(f"Loading model from {model_path}...")
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both the {'model_state_dict': ...} wrapper and a bare state_dict
    # saved with torch.save(model.state_dict(), path).
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model = EmoNetSingleLabel26(n_expression=26, n_reg=2, attention=True)
    model.load_state_dict(state_dict)

    print("✓ Model loaded\n")

    print(f"Loading test dataset from {test_csv}...")
    # NOTE(review): transform=None -- confirm this matches the preprocessing
    # used for validation during training, otherwise metrics will be skewed.
    test_dataset = EMOTICDataset(
        csv_file=test_csv,
        img_arrays_dir=img_arrays_dir,
        image_size=image_size,
        transform=None
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies; skip it on CPU.
        pin_memory=(device == 'cuda')
    )

    print(f"✓ Test dataset loaded: {len(test_dataset)} samples\n")

    inference = SimpleInference(model, device=device)
    metrics = inference.evaluate(test_loader)

    return metrics


In [None]:
 metrics = run_inference(
        model_path='/kaggle/working/best_model.pth',
        test_csv='/path/to/test.csv',
        img_arrays_dir='/path/to/img_arrays',
        batch_size=64,
        num_workers=4,
        image_size=256
    )
    
    print("\n✓ Inference complete!")
    print(f"  Emotion Accuracy: {metrics['emotion_accuracy']*100:.2f}%")
    print(f"  Emotion F1:       {metrics['emotion_f1_macro']:.4f}")
    print(f"  Valence MAE:      {metrics['valence_mae']:.4f}")
    print(f"  Valence RMSE:     {metrics['valence_rmse']:.4f}")
    print(f"  Arousal MAE:      {metrics['arousal_mae']:.4f}")
    print(f"  Arousal RMSE:     {metrics['arousal_rmse']:.4f}")