# Video Emotion Recognition on MELD Dataset Using Masked SlowFast with Adaptive Pathway Scaling

---

# Understanding Loss Values in Mixed Precision Training
### Core Concept

#### In PyTorch's mixed precision training, `scaler.scale(loss)` `creates a new tensor without modifying the original loss tensor`, which is why using `loss.item()` for metrics remains correct. This design exists because gradient scaling is purely a numerical stability technique - it amplifies gradients to prevent underflow in FP16 precision, then automatically unscales them during optimization. The original loss value must remain unchanged because it represents the true objective function value needed for monitoring convergence, comparing models, and making training decisions. If we used scaled loss values for metrics, our training curves would be artificially inflated and incomparable across different scaling factors. The accumulation steps multiplier is theoretically incorrect because it distorts the mathematical expectation of the loss function - when accumulating gradients over multiple batches, we want the average loss per sample, not an inflated version. Therefore, loss.item() always gives the mathematically correct, unscaled loss value that should be used for all metric calculations, while the scaled version exists only temporarily during backpropagation to ensure numerical stability in gradient computation.

---

In [1]:
# !pip install torch torchvision pytorchvideo > /dev/null
# !pip install torch torchvision pytorchvideo > /dev/null 2>&1

# > sends standard output (stdout) to /dev/null (hides it).

# 2>&1 sends standard error (stderr) to the same place, so you see nothing at all, even if there are warnings.

In [2]:
# =====================================
# IMPORTS AND SETUP
# =====================================

!pip install torch torchvision pytorchvideo > /dev/null 2>&1

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
from tqdm import tqdm
import wandb
from pytorchvideo.models.hub import slowfast_r50
import torchvision.transforms as transforms
from kaggle_secrets import UserSecretsClient
import random
from huggingface_hub import HfApi, hf_hub_download, create_repo
import torch.nn as nn
from typing import Dict, Any, Optional
from torchvision import transforms
import pytorchvideo.models.hub as models
from sklearn.metrics import f1_score, confusion_matrix, classification_report
import torch.nn.functional as F
import os
import shutil
import torch.optim.lr_scheduler as lr_scheduler
import re
import glob
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2  # Fixed import
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_exponential

  check_for_updates()


In [3]:
# =====================================
# API AND SECRETS SETUP
# =====================================

def setup_api_tokens():
    """Setup API tokens for HuggingFace and Weights & Biases"""
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("HF_TOKEN")
    wandb_token = user_secrets.get_secret("WANDB_API_KEY")
    
    # Initialize W&B and HF API
    wandb.login(key=wandb_token)
    hf_api = HfApi(token=hf_token)
    
    return hf_token, wandb_token, hf_api

In [4]:
# =====================================
# REPRODUCIBILITY UTILITIES
# =====================================

def set_seed(seed=42):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
# =====================================
# CONFIGURATION
# =====================================

class Config:
    """Configuration parameters for the training pipeline"""
    num_frames = 32
    crop_size = (224, 224)
    dataset_mean = [0.45, 0.45, 0.45]
    dataset_std = [0.225, 0.225, 0.225]
    seed = 42
    batch_size = 16
    max_epochs = 100
    min_epochs = 25    # Minimum 20 epochs for video models
    no_improvement_threshold = 0.005  # 0.5% weighted F1 improvement
    patience = 10     # Allow 10 epochs without significant improvement
    checkpoint_frequency = 1
    base_lr = 1e-5
    max_lr = 1e-3
    grad_clip = 1.0
    accumulation_steps = 4
    weight_decay = 1e-4
    resize_size = (256, 256)
    log_samples_freq = 50  # Log sample predictions every N batches

In [6]:
# =====================================
# VISUALIZATION UTILITIES
# =====================================

def plot_confusion_matrix(cm, class_names, split='val'):
    """Enhanced confusion matrix visualization"""
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {split}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    filename = f'confusion_matrix_{split}.png'
    plt.savefig(filename, dpi=300)
    plt.close()
    return filename

In [7]:
def debug_transformation_pipeline():
    """Run this function before creating datasets to verify transformations"""
    print("\n===== TRANSFORMATION DEBUG =====")
    
    # Create a dummy image
    dummy_img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    print(f"Dummy image min: {dummy_img.min()}, max: {dummy_img.max()}")
    print(f"Dummy image mean: {dummy_img.mean(axis=(0,1))}")
    
    # Create train and test transforms
    train_transform = A.Compose([
        A.Resize(Config.resize_size[0], Config.resize_size[1]),
        A.RandomCrop(height=Config.crop_size[0], width=Config.crop_size[1], p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.4),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.Normalize(
            mean=Config.dataset_mean,
            std=Config.dataset_std,
            max_pixel_value=255.0  # CRITICAL
        ),
        ToTensorV2()
    ])
    
    test_transform = A.Compose([
        A.Resize(Config.resize_size[0], Config.resize_size[1]),
        A.CenterCrop(height=Config.crop_size[0], width=Config.crop_size[1], p=1.0),
        A.Normalize(
            mean=Config.dataset_mean,
            std=Config.dataset_std,
            max_pixel_value=255.0  # CRITICAL
        ),
        ToTensorV2()
    ])
    
    # Apply train transform
    train_transformed = train_transform(image=dummy_img)["image"]
    print("\nTrain transform results:")
    print(f"Min: {train_transformed.min().item():.4f}")
    print(f"Max: {train_transformed.max().item():.4f}")
    print(f"Mean: {train_transformed.mean(axis=(1,2)).tolist()}")
    
    # Apply test transform
    test_transformed = test_transform(image=dummy_img)["image"]
    print("\nTest transform results:")
    print(f"Min: {test_transformed.min().item():.4f}")
    print(f"Max: {test_transformed.max().item():.4f}")
    print(f"Mean: {test_transformed.mean(axis=(1,2)).tolist()}")
    
    print("============================\n")

In [8]:
# =====================================
# DATASET CLASS
# =====================================

class MELDDataset(Dataset):
    """Enhanced Dataset with Temporal Augmentations"""
    
    def __init__(self, metadata, split, train=True):
        self.data = metadata[split]
        self.train = train
        self.transform = self._build_transforms()
        self.error_log = open("dataset_errors.log", "a")

    def _build_transforms(self):
        """Build data augmentation transforms"""
        normalize = A.Normalize(
            mean=Config.dataset_mean,
            std=Config.dataset_std,
            max_pixel_value=255.0
        )
        
        if self.train:
            return A.Compose([
                A.Resize(Config.resize_size[0], Config.resize_size[1]),
                A.RandomCrop(height=Config.crop_size[0], width=Config.crop_size[1], p=1.0),
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15, p=0.4),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                normalize,
                ToTensorV2()
            ])
        else:
            return A.Compose([
                A.Resize(Config.resize_size[0], Config.resize_size[1]),
                A.CenterCrop(height=Config.crop_size[0], width=Config.crop_size[1], p=1.0),
                normalize,
                ToTensorV2()
            ])

    def __getitem__(self, idx):
        item = self.data[idx]
        frames_dir = item['frames_dir']
        mask_info = item['mask_info']
        label = item['y']
        
        try:
            if not os.path.exists(frames_dir):
                raise FileNotFoundError(f"Directory not found: {frames_dir}")
            
            frame_files = sorted(
                [f for f in os.listdir(frames_dir) if f.endswith(('.jpg', '.png'))],
                key=lambda x: int(re.search(r'^(\d+)', x).group(1))
            )
            
            if len(frame_files) < Config.num_frames:
                raise ValueError(f"Only {len(frame_files)} frames found, need {Config.num_frames}")

            # FIXED MASK HANDLING
            if len(frame_files) > Config.num_frames:
                start_idx = random.randint(0, len(frame_files) - Config.num_frames)
                selected_files = frame_files[start_idx:start_idx+Config.num_frames]
                selected_mask_info = mask_info[start_idx:start_idx+Config.num_frames]
            else:
                selected_files = frame_files
                selected_mask_info = mask_info[:len(selected_files)]
                if len(selected_mask_info) < Config.num_frames:
                    pad_length = Config.num_frames - len(selected_mask_info)
                    selected_mask_info = selected_mask_info + [0] * pad_length

            frames = []
            for i, fname in enumerate(selected_files):
                frame_path = os.path.join(frames_dir, fname)
                frame = cv2.imread(frame_path)
                if frame is None:
                    raise IOError(f"Failed to read {frame_path}")
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                transformed = self.transform(image=frame)["image"]
                frames.append(transformed)

            video_tensor = torch.stack(frames)  # [T, C, H, W]
            slow_pathway = video_tensor[::4].permute(1, 0, 2, 3)  # [C, T/4, H, W]
            fast_pathway = video_tensor.permute(1, 0, 2, 3)       # [C, T, H, W]
            
            mask = torch.tensor(selected_mask_info[:Config.num_frames], dtype=torch.float32)
            slow_mask = mask[::4]
            fast_mask = mask
            
            return slow_pathway, fast_pathway, slow_mask, fast_mask, label
            
        except Exception as e:
            self.error_log.write(f"Error loading index {idx}: {str(e)}\n")
            slow = torch.zeros(3, Config.num_frames // 4, *Config.crop_size)
            fast = torch.zeros(3, Config.num_frames, *Config.crop_size)
            slow_mask = torch.ones(Config.num_frames // 4)
            fast_mask = torch.ones(Config.num_frames)
            return slow, fast, slow_mask, fast_mask, label

    def __len__(self):
        return len(self.data)

    def __del__(self):
        self.error_log.close()

In [9]:
# =====================================
# MODEL ARCHITECTURE
# =====================================

class MaskedSlowFast(nn.Module):
    """Simplified Model Architecture"""
    
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = slowfast_r50(pretrained=True, progress=True)
        self.backbone.blocks = self.backbone.blocks[:-1]  # Remove classification head

        # Add batch normalization before classifier
        self.feature_bn = nn.BatchNorm1d(2304)
        
        # Simplified classifier
        self.classifier = nn.Sequential(
            nn.Linear(2304, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
        
        # Freeze initial layers
        for param in self.backbone.parameters():
            param.requires_grad = False
            
        # Gradual unfreezing setup
        self.unfreeze_stages = {5: False, 4: False, 3: False}

    def unfreeze_layers(self, epoch):
        """Gradual layer unfreezing during training"""
        if epoch >= 3 and not self.unfreeze_stages[5]:
            self._unfreeze_stage(5)
        if epoch >= 6 and not self.unfreeze_stages[4]:
            self._unfreeze_stage(4)
        if epoch >= 9 and not self.unfreeze_stages[3]:
            self._unfreeze_stage(3)
            
    def _unfreeze_stage(self, stage):
        """Unfreeze a specific stage of the backbone"""
        for param in self.backbone.blocks[stage].parameters():
            param.requires_grad = True
        self.unfreeze_stages[stage] = True
        print(f"Unfroze stage {stage} layers")

    def forward(self, slow_input, fast_input, slow_mask, fast_mask):
        # Apply masks to zero out padded frames
        slow_input = slow_input * slow_mask[:, None, :, None, None]
        fast_input = fast_input * fast_mask[:, None, :, None, None]
        
        # Get features
        features = self.backbone([slow_input, fast_input])
        features = features.view(features.size(0), -1)

        # Apply feature batch normalization
        features = self.feature_bn(features)
        
        return self.classifier(features)

In [10]:
# =====================================
# MODEL ARCHITECTURE
# =====================================

class MaskedSlowFast(nn.Module):
    """Simplified Model Architecture"""
    
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = slowfast_r50(pretrained=True, progress=True)
        self.backbone.blocks = self.backbone.blocks[:-1]  # Remove classification head

        # Add batch normalization before classifier
        self.feature_bn = nn.BatchNorm1d(2304)
        
        # Simplified classifier
        self.classifier = nn.Sequential(
            nn.Linear(2304, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
        
        # Freeze initial layers
        for param in self.backbone.parameters():
            param.requires_grad = False
            
        # Gradual unfreezing setup
        self.unfreeze_stages = {5: False, 4: False, 3: False}

    def unfreeze_layers(self, epoch):
        """Gradual layer unfreezing during training"""
        if epoch >= 3 and not self.unfreeze_stages[5]:
            self._unfreeze_stage(5)
        if epoch >= 6 and not self.unfreeze_stages[4]:
            self._unfreeze_stage(4)
        if epoch >= 9 and not self.unfreeze_stages[3]:
            self._unfreeze_stage(3)
            
    def _unfreeze_stage(self, stage):
        """Unfreeze a specific stage of the backbone"""
        for param in self.backbone.blocks[stage].parameters():
            param.requires_grad = True
        self.unfreeze_stages[stage] = True
        print(f"Unfroze stage {stage} layers")

    def forward(self, slow_input, fast_input, slow_mask, fast_mask):
        # Apply masks to zero out padded frames
        slow_input = slow_input * slow_mask[:, None, :, None, None]
        fast_input = fast_input * fast_mask[:, None, :, None, None]
        
        # Get features
        features = self.backbone([slow_input, fast_input])
        features = features.view(features.size(0), -1)

        # Apply feature batch normalization
        features = self.feature_bn(features)
        
        return self.classifier(features)

In [11]:
# =====================================
# LOSS FUNCTIONS
# =====================================

class FocalLoss(nn.Module):
    """Focal Loss for Class Imbalance"""
    
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss
            
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [12]:
# =====================================
# MODEL UTILITIES
# =====================================

def print_parameter_status(model):
    """Print detailed parameter status of the model"""
    total_params = 0
    trainable_params = 0
    frozen_params = 0
    trainable_names = []
    frozen_names = []

    for name, param in model.named_parameters():
        num_params = param.numel()
        total_params += num_params
        if param.requires_grad:
            trainable_params += num_params
            trainable_names.append(name)
        else:
            frozen_params += num_params
            frozen_names.append(name)

    print(f"\n{'='*50}")
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")
    print(f"Frozen Parameters: {frozen_params:,} ({frozen_params/total_params:.2%})")
    
    print("\nTrainable Parameters:")
    for name in trainable_names:
        print(f"- {name}")
        
    print("\nFrozen Parameters:")
    for name in frozen_names:
        print(f"- {name}")
    print("="*50)

In [13]:
# =====================================
# EVALUATION FUNCTIONS
# =====================================

def evaluate_model(model, data_loader, criterion, device, class_names, split='val'):
    """Enhanced Evaluation with Confusion Matrix"""
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in data_loader:
            slow, fast, slow_mask, fast_mask, labels = (
                batch[0].to(device),
                batch[1].to(device),
                batch[2].to(device),
                batch[3].to(device),
                batch[4].to(device)
            )
            
            outputs = model(slow, fast, slow_mask, fast_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    avg_loss = total_loss / len(data_loader)
    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted')
    
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plot_file = plot_confusion_matrix(cm, class_names, split)
    
    # Classification report
    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(f'classification_report_{split}.csv')
    
    # Print classes with no positive samples for ROC
    for i, class_name in enumerate(class_names):
        class_labels = (np.array(all_labels) == i).astype(int)
        if np.sum(class_labels) == 0:
            print(f"Skipping ROC for {class_name} - no positive samples in {split} set")
    
    return avg_loss, accuracy, macro_f1, weighted_f1, cm, plot_file, np.array(all_labels), np.array(all_probs), report_df

In [14]:
# =====================================
# CHECKPOINT MANAGEMENT
# =====================================

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def save_to_huggingface(model, epoch, optimizer, scheduler, best_val_loss, best_val_weighted_f1, patience_counter, scaler, global_step: int, total_steps, hf_api, is_best=False):
    """Save full training state checkpoint to HuggingFace Hub."""
    save_dir = f"./model_checkpoint_epoch_{epoch+1}"
    os.makedirs(save_dir, exist_ok=True)

    # Capture all critical states
    checkpoint = {
        'epoch': epoch,  # Current epoch (0-based)
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        "best_val_weighted_f1": best_val_weighted_f1,
        'patience_counter': patience_counter,
        'torch_rng_state': torch.get_rng_state(),          # For PyTorch randomness
        'numpy_rng_state': np.random.get_state(),          # For NumPy randomness
        'python_rng_state': random.getstate(),             # For Python random module
        'grad_scaler_state_dict': scaler.state_dict(),     # Mixed precision training
        'global_step': global_step,
        'scheduler_total_steps': total_steps,  # Store current total steps
        'scheduler_last_epoch': scheduler.last_epoch,
        'wandb_run_id': wandb.run.id if wandb.run else None  # Store run ID
    }

    checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save(checkpoint, checkpoint_path)

    repo_name = "prakanda/hatsu-meld-emotion-recognition-new"
    
    # Ensure the repo exists (create if it doesn't)
    hf_api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
    
    try:
        hf_api.upload_file(
            path_or_fileobj=checkpoint_path,
            path_in_repo=f"checkpoint_epoch_{epoch+1}.pth",
            repo_id=repo_name,
            repo_type="model",
        )
        if is_best:
            hf_api.upload_file(
                path_or_fileobj=checkpoint_path,
                path_in_repo="best_model.pth",
                repo_id=repo_name,
                repo_type="model",
            )
    except Exception as e:
        print(f"Error uploading: {str(e)}")
    finally:
        shutil.rmtree(save_dir, ignore_errors=True)



def load_checkpoint_safely(checkpoint_path, device):
    """Safely load checkpoint with proper tensor handling"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Handle main RNG state
    if 'torch_rng_state' in checkpoint:
        rng_state = checkpoint['torch_rng_state']
        if not isinstance(rng_state, torch.Tensor):
            rng_state = torch.tensor(rng_state, dtype=torch.uint8)
        checkpoint['torch_rng_state'] = rng_state.to(dtype=torch.uint8).cpu().contiguous()
    
    # NEW: Handle data generator state
    if 'data_generator_state' in checkpoint:
        gen_state = checkpoint['data_generator_state']
        if not isinstance(gen_state, torch.Tensor):
            gen_state = torch.tensor(gen_state, dtype=torch.uint8)
        checkpoint['data_generator_state'] = gen_state.to(dtype=torch.uint8).cpu().contiguous()
    
    return checkpoint

In [15]:
# =====================================
# TRAINING LOOP UTILITIES
# =====================================

def debug_small_batch_overfitting(model, train_loader, criterion, optimizer, device):
    """Debug by overfitting on a small batch"""
    print("\n===== Debugging: Overfitting small batch =====")
    
    # Create new scaler for debugging
    debug_scaler = torch.amp.GradScaler()

    debug_batch = [tensor.to(device) for tensor in next(iter(train_loader))]
    debug_losses = []
    
    for i in range(100):  # 100 debug iterations
        slow, fast, slow_mask, fast_mask, labels = debug_batch
        
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(slow, fast, slow_mask, fast_mask)
            loss = criterion(outputs, labels)
        
        debug_scaler.scale(loss).backward()
        debug_scaler.step(optimizer)
        debug_scaler.update()

        optimizer.zero_grad()
        debug_losses.append(loss.item())
        
        if i % 10 == 0:
            print(f"Debug Iteration {i}: Loss = {loss.item():.4f}")
            
    plt.plot(debug_losses)
    plt.title("Debug Overfitting Curve")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.savefig("debug_overfit.png")
    plt.close()

    print("Debug overfitting completed. Loss should decrease steadily.")
    print("=============================================\n")
    
    return "debug_overfit.png"



def initialize_training_metrics():
    """Initialize metrics tracking dictionaries"""
    return {
        'train/loss': [],
        'val/loss': [],
        'train/macro_f1': [],
        'val/macro_f1': [],
        'train/weighted_f1': [],
        'val/weighted_f1': [],
        'train/accuracy': [],
        'val/accuracy': [],
        'learning_rate': [],
        'grad_norm/epoch_mean': [],
    }



def log_sample_predictions(model, slow, fast, slow_mask, fast_mask, labels, class_names, epoch, batch_idx):
    """Log sample predictions with visualizations"""
    sample_predictions = []
    
    with torch.no_grad():
        outputs_debug = model(slow, fast, slow_mask, fast_mask)
        _, preds_debug = torch.max(outputs_debug, 1)
        
        for i in range(min(3, len(slow))):  # First 3 samples
            # Create visualization of input frames
            fig, ax = plt.subplots(1, 4, figsize=(20, 5))
            
            # Show first frame of slow pathway
            slow_frame = slow[i, :, 0].permute(1, 2, 0).cpu().numpy()
            slow_frame = np.clip((slow_frame * Config.dataset_std + Config.dataset_mean) * 255, 0, 255).astype(np.uint8)
            ax[0].imshow(slow_frame)
            ax[0].set_title(f"Slow Frame 0\nMask: {slow_mask[i, 0].item():.1f}")
            
            # Show last frame of slow pathway
            slow_frame = slow[i, :, -1].permute(1, 2, 0).cpu().numpy()
            slow_frame = np.clip((slow_frame * Config.dataset_std + Config.dataset_mean) * 255, 0, 255).astype(np.uint8)
            ax[1].imshow(slow_frame)
            ax[1].set_title(f"Slow Frame -1\nMask: {slow_mask[i, -1].item():.1f}")
            
            # Show first frame of fast pathway
            fast_frame = fast[i, :, 0].permute(1, 2, 0).cpu().numpy()
            fast_frame = np.clip((fast_frame * Config.dataset_std + Config.dataset_mean) * 255, 0, 255).astype(np.uint8)
            ax[2].imshow(fast_frame)
            ax[2].set_title(f"Fast Frame 0\nMask: {fast_mask[i, 0].item():.1f}")
            
            # Show last frame of fast pathway
            fast_frame = fast[i, :, -1].permute(1, 2, 0).cpu().numpy()
            fast_frame = np.clip((fast_frame * Config.dataset_std + Config.dataset_mean) * 255, 0, 255).astype(np.uint8)
            ax[3].imshow(fast_frame)
            ax[3].set_title(f"Fast Frame -1\nMask: {fast_mask[i, -1].item():.1f}")
            
            plt.tight_layout()
            plt.savefig(f"sample_{epoch}_{batch_idx}_{i}.png")
            plt.close()
            
            sample_predictions.append([
                epoch,
                batch_idx,
                class_names[preds_debug[i].item()],
                class_names[labels[i].item()],
                wandb.Image(f"sample_{epoch}_{batch_idx}_{i}.png")
            ])
    
    return sample_predictions



def log_gradient_norms(model, global_step):
    """Log gradient norms for monitoring with NaN handling"""
    grad_norms = []
    nan_count = 0
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            # Check for NaN/infinite gradients
            if not torch.isfinite(param.grad).all():
                nan_count += 1
                continue
                
            grad_norm = param.grad.norm().item()
            grad_norms.append(grad_norm)
            wandb.log({f"grad_norm/{name}": grad_norm}, step=global_step)
    
    # Log NaN count
    wandb.log({"grad_norm/nan_count": nan_count}, step=global_step)
    
    if grad_norms:
        wandb.log({
            "grad_norm/mean": np.mean(grad_norms),
            "grad_norm/max": np.max(grad_norms),
            "grad_norm/min": np.min(grad_norms),
            "grad_norm/hist": wandb.Histogram(grad_norms) if grad_norms else None
        }, step=global_step)
    
    return grad_norms



def build_wandb_log_data(epoch, train_metrics, val_metrics, optimizer, val_plot_file, 
                        all_labels, val_all_labels, val_all_probs, val_report, 
                        class_names, epoch_grad_norms, sample_predictions, epoch_metrics):
    """Build comprehensive log data for W&B"""
    avg_train_loss, train_accuracy, train_f1, train_weighted_f1 = train_metrics
    val_loss, val_accuracy, val_f1, val_wf1 = val_metrics
    
    log_data = {
        "epoch": epoch,
        "train/loss": avg_train_loss,
        "train/accuracy": train_accuracy,
        "train/macro_f1": train_f1,
        "train/weighted_f1": train_weighted_f1,
        "val/loss": val_loss,
        "val/accuracy": val_accuracy,
        "val/macro_f1": val_f1,
        "val/weighted_f1": val_wf1,
        "learning_rate": optimizer.param_groups[0]['lr'],
        "val/confusion_matrix": wandb.Image(val_plot_file),
        "train/class_distribution": wandb.Histogram(np.array(all_labels)),
    }

    if epoch_grad_norms:
        grad_mean = np.mean(epoch_grad_norms)
        log_data.update({
            "grad_norm/epoch_mean": grad_mean,
            "grad_norm/epoch_max": np.max(epoch_grad_norms),
            "grad_norm/epoch_min": np.min(epoch_grad_norms),
            "grad_norm/epoch_hist": wandb.Histogram(epoch_grad_norms),
        })

    # Add ROC curve to log data
    try:
        log_data["val/roc_curve"] = wandb.plot.roc_curve(
            val_all_labels,
            val_all_probs,
            labels=class_names
        )
    except Exception as e:
        print(f"Failed to log validation ROC: {str(e)}")
    
    # Add classification report metrics
    for metric in ['precision', 'recall', 'f1-score']:
        for i, class_name in enumerate(class_names):
            log_data[f"val/{class_name}_{metric}"] = val_report.loc[class_name, metric]
    
    # Log sample predictions as table
    if sample_predictions:
        log_data["train/sample_predictions"] = wandb.Table(
            columns=["Epoch", "Batch", "Predicted", "True", "Image"],
            data=sample_predictions
        )
    
    # Log metric history curves
    for metric_name, values in epoch_metrics.items():
        if values:
            data = [[x, y] for x, y in enumerate(values)]
            table = wandb.Table(data=data, columns=["epoch", metric_name])
            log_data[f"curves/{metric_name}"] = wandb.plot.line(
                table, 
                "epoch", 
                metric_name,
                title=f"{metric_name} over Epochs"
            )
    
    return log_data



def update_early_stopping(val_wf1, best_val_weighted_f1, patience_counter, epoch):
    """Update early stopping criteria based on validation performance"""
    improvement = val_wf1 - best_val_weighted_f1
    new_best = best_val_weighted_f1
    new_patience = patience_counter
    is_improvement = False

    if improvement > Config.no_improvement_threshold:
        new_best = val_wf1
        new_patience = 0
        is_improvement = True
        print(f"🌟 Significant improvement: +{improvement:.4f}")
    elif improvement > 0:  # Small improvement
        new_best = val_wf1
        new_patience = min(patience_counter, Config.patience // 2)
        is_improvement = True
        print(f"💫 Minor improvement: +{improvement:.4f}, patience reduced to {new_patience}")
    else:
        new_patience += 1
        print(f"⏳ No improvement. Patience: {new_patience}/{Config.patience}")
    
    return new_best, new_patience, is_improvement

In [16]:
# =====================================
# MAIN TRAINING LOOP
# =====================================

def train_model(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    device: torch.device,
    class_names,
    scaler: torch.cuda.amp.GradScaler,
    hf_api,
    total_steps:int,
    wandb_run_id: Optional[str] = None,
    max_epochs: int = 50,
    patience: int = 10,
    min_epochs: int = 10,
    checkpoint_frequency: int = 5,
    start_epoch: int = 0,
    best_val_loss: float = float('inf'),
    best_val_weighted_f1: float = 0.0,
    best_model_state: Optional[Dict[str, torch.Tensor]] = None, 
    patience_counter: int = 0,
    initial_global_step: int = 0
) -> Dict[str, Any]:

    try:
        # Initialize step counter
        global_step = initial_global_step

        if start_epoch > 0:
            print("Resuming training - starting a new wandb run to avoid step conflict")
            wandb_run_id = None  # Force new run for resumed training
                
        # Initialize WandB with resume capability
        wandb_init_kwargs = {
            "project": "meld-emotion-recognition-new",
            "settings": wandb.Settings(init_timeout=360),
            "config":{
            "user": "bhandariprakanda",
            "architecture": "MaskedSlowFast_R50",
            "notebook": "MELD Emotion Recognition SlowFast MaskedPathways",
            "version": "3_Changing `class_weights` in Alpha Calculation",
            "previous_version": "1.1_Continue Training After Small Change (bhandariprakanda)",
            "dataset": "MELD",
            "num_classes": len(class_names),
            "class_names": class_names,
            "resumed": start_epoch > 0,
            "max_epochs": max_epochs,
            "batch_size": train_loader.batch_size,
            "optimizer": optimizer.__class__.__name__,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "num_frames": Config.num_frames,
            "crop_size": Config.crop_size,
            "base_lr": Config.base_lr,
            "max_lr": Config.max_lr,
            "grad_clip": Config.grad_clip,
            "accumulation_steps": Config.accumulation_steps,
            "weight_decay": Config.weight_decay,
            "resize_size": Config.resize_size,
            "resumed_global_step": initial_global_step,
        },
            "resume": "allow" if wandb_run_id else None  # Enable resuming
        }

        # Resume existing run if ID exists
        if wandb_run_id:
            wandb_init_kwargs["id"] = wandb_run_id
            print(f"\n📊 Resuming WandB run: {wandb_run_id}\n")
        
        wandb.init(**wandb_init_kwargs)
        
        # Ensure we save run ID in future checkpoints
        if not wandb_run_id and wandb.run:
            wandb_run_id = wandb.run.id
            print(f"\n📊 Started new WandB run: {wandb_run_id}\n")
    
        # Watch gradients
        wandb.watch(model, log="all", log_freq=50)

        # Initialize metrics tracking
        epoch_metrics = initialize_training_metrics()
        
        # Initialize training state
        best_val_weighted_f1 = best_val_weighted_f1
        patience_counter = patience_counter
        best_model_state = best_model_state or model.state_dict().copy()
    
        # Debugging: Overfit small batch
        if start_epoch == 0:  # Only run debug on fresh start
            debug_plot = debug_small_batch_overfitting(model, train_loader, criterion, optimizer, device)
            wandb.log({"debug/debug_overfitting_curve": wandb.Image(debug_plot)}, step=global_step)
        
        for epoch in range(start_epoch, max_epochs):
            model.train()
            model.unfreeze_layers(epoch)
            
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            all_preds = []
            all_labels = []
            epoch_grad_norms = []
            sample_predictions = []
            
            optimizer.zero_grad()
            
            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs}')):
                # Increment global step at start of batch
                global_step += 1
                
                slow, fast, slow_mask, fast_mask, labels = (
                    batch[0].to(device),
                    batch[1].to(device),
                    batch[2].to(device),
                    batch[3].to(device),
                    batch[4].to(device)
                )
    
                # Debugging: Print input statistics for first batch of first epoch
                if epoch == 0 and batch_idx == 0:
                    print("\n===== Input Debugging =====")
                    print(f"Slow input shape: {slow.shape}")
                    print(f"Normalized Slow input range: [{slow.min().item():.4f}, {slow.max().item():.4f}]")
                    print(f"Normalized Slow input mean: {slow.mean().item():.4f}")
                    print(f"Fast input shape: {fast.shape}")
                    print(f"Normalized Fast input range: [{fast.min().item():.4f}, {fast.max().item():.4f}]")
                    print(f"Normalized Fast input mean: {fast.mean().item():.4f}")
                    print(f"Slow mask mean: {slow_mask.mean().item():.4f}")
                    print(f"Fast mask mean: {fast_mask.mean().item():.4f}")
                    print(f"Labels: {labels}")
                    print(f"Class distribution: {torch.bincount(labels)}")
                    print("==========================\n")
    
                # Log sample predictions periodically
                if batch_idx % Config.log_samples_freq == 0:
                    sample_preds = log_sample_predictions(
                        model, slow, fast, slow_mask, fast_mask, labels, 
                        class_names, epoch, batch_idx
                    )
                    sample_predictions.extend(sample_preds)
    
                # Mixed precision training
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(slow, fast, slow_mask, fast_mask)
                    loss = criterion(outputs, labels)

                # Check for NaN loss before proceeding
                if not torch.isfinite(loss):
                    print(f"⚠️ NaN loss detected at step {global_step}. Skipping batch.")
                    optimizer.zero_grad()
                    continue
                    
                scaler.scale(loss / Config.accumulation_steps).backward()
                
                # Update metrics
                train_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                train_correct += (preds == labels).sum().item()
                train_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                
                # Gradient accumulation
                if (batch_idx + 1) % Config.accumulation_steps == 0 or batch_idx == len(train_loader) - 1:
                    # Check for NaN gradients before unscaling and clipping 
                    nan_found = False
                    for param in model.parameters():
                        if param.grad is not None and not torch.isfinite(param.grad).all():
                            nan_found = True
                            break
                    
                    if nan_found:
                        print(f"⚠️ NaN gradients detected at step {global_step}. Skipping update.")
                        optimizer.zero_grad()
                        continue  # Skip update block entirely
                        
                    # Unscale before gradient clipping
                    scaler.unscale_(optimizer)

                    # Add value-based gradient clipping in addition to norm clipping
                    #torch.nn.utils.clip_grad_value_(model.parameters(), 0.5)
                        
                    torch.nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
                    
                    # Log gradient norms
                    batch_grad_norms = log_gradient_norms(model, global_step)
                    if batch_grad_norms:
                        epoch_grad_norms.extend(batch_grad_norms)
                    
                    scaler.step(optimizer)

                    scaler.update()
                    
                    # Step the scheduler AFTER optimizer step
                    scheduler.step()
                    
                    # Log current learning rate
                    current_lr = optimizer.param_groups[0]['lr']
                    wandb.log({"learning_rate/step": current_lr}, step=global_step)
    
                    # Reset gradients
                    optimizer.zero_grad()
                    
                
            # Calculate training metrics
            avg_train_loss = train_loss / len(train_loader)
            train_accuracy = train_correct / train_total
            train_f1 = f1_score(all_labels, all_preds, average='macro')
            train_weighted_f1 = f1_score(all_labels, all_preds, average='weighted')
            
            # Validation phase
            val_loss, val_accuracy, val_f1, val_wf1, val_cm, val_plot_file, val_all_labels, val_all_probs, val_report = evaluate_model(
                model, val_loader, criterion, device, class_names, 'val')
            
            # Update epoch metrics
            epoch_metrics['train/loss'].append(avg_train_loss)
            epoch_metrics['val/loss'].append(val_loss)
            epoch_metrics['train/macro_f1'].append(train_f1)
            epoch_metrics['val/macro_f1'].append(val_f1)
            epoch_metrics['train/weighted_f1'].append(train_weighted_f1)
            epoch_metrics['val/weighted_f1'].append(val_wf1)
            epoch_metrics['train/accuracy'].append(train_accuracy)
            epoch_metrics['val/accuracy'].append(val_accuracy)
            epoch_metrics['learning_rate'].append(optimizer.param_groups[0]['lr'])
            epoch_metrics['grad_norm/epoch_mean'].append(np.mean(epoch_grad_norms) if epoch_grad_norms else 0)
                
            # Build log data dictionary
            log_data = build_wandb_log_data(
                epoch, 
                (avg_train_loss, train_accuracy, train_f1, train_weighted_f1),
                (val_loss, val_accuracy, val_f1, val_wf1),
                optimizer,
                val_plot_file,
                all_labels,
                val_all_labels,
                val_all_probs,
                val_report,
                class_names,
                epoch_grad_norms,
                sample_predictions,
                epoch_metrics
            )
            
            # Log everything at the current global_step
            wandb.log(log_data, step=global_step)
            
            # Print metrics
            print(f"\nEpoch [{epoch + 1}/{max_epochs}]")
            print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train Macro F1: {train_f1:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val Macro F1: {val_f1:.4f}, Val Weighted F1: {val_wf1:.4f}")
            print(f"Class distribution: {torch.bincount(torch.tensor(all_labels))}")
    
            # Save checkpoint periodically
            if (epoch + 1) % checkpoint_frequency == 0:
                save_to_huggingface(
                    model=model,
                    epoch=epoch,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    best_val_loss=best_val_loss,
                    best_val_weighted_f1=best_val_weighted_f1,
                    patience_counter=patience_counter,
                    scaler=scaler,
                    global_step=global_step,
                    total_steps = total_steps,
                    hf_api = hf_api
                )

            # Update early stopping criteria
            best_val_weighted_f1, patience_counter, is_improvement = update_early_stopping(
                val_wf1, best_val_weighted_f1, patience_counter, epoch
            )
            
            # Save best model if improvement
            if is_improvement:
                save_to_huggingface(
                    model=model,
                    epoch=epoch,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    best_val_loss=best_val_loss,
                    best_val_weighted_f1=best_val_weighted_f1,
                    patience_counter=patience_counter,
                    scaler=scaler,
                    global_step=global_step,
                    total_steps = total_steps,
                    hf_api = hf_api,
                    is_best=True
                )
            
            # Only stop if no improvement after minimum epochs
            if epoch >= Config.min_epochs and patience_counter >= Config.patience:
                print(f"⏹️ Early stopping at epoch {epoch+1}. Best was epoch {epoch+1 - patience_counter}")
                break

        # Restore best model state
        model.load_state_dict(best_model_state)
    
        # Final test evaluation
        test_loss, test_accuracy, test_f1, test_wf1, test_cm, test_plot_file, test_all_labels, test_all_probs, test_report = evaluate_model(
            model, test_loader, criterion, device, class_names, 'test')
    
        # Prepare test log data
        test_log_data = {
            "final_test/loss": test_loss,
            "final_test/accuracy": test_accuracy,
            "final_test/f1": test_f1,
            "final_test/weighted_f1": test_wf1,
            "final_test/confusion_matrix": wandb.Image(test_plot_file),
        }
        
        # Test ROC logging
        try:
            test_log_data["test/roc_curve"] = wandb.plot.roc_curve(
                test_all_labels,
                test_all_probs,
                labels=class_names
            )
        except Exception as e:
            print(f"Failed to log test ROC: {str(e)}")
        
        # Log test classification report
        test_report_table = wandb.Table(dataframe=test_report)
        test_log_data["final_test/classification_report"] = test_report_table
        
        # Log class-wise metrics
        class_metrics = []
        for i, class_name in enumerate(class_names):
            tp = test_cm[i, i]
            fp = test_cm[:, i].sum() - tp
            fn = test_cm[i, :].sum() - tp
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            
            class_metrics.append([
                class_name, 
                precision, 
                recall, 
                f1, 
                tp, 
                fp, 
                fn,
                test_report.loc[class_name, 'precision'],
                test_report.loc[class_name, 'recall'],
                test_report.loc[class_name, 'f1-score'],
                test_report.loc[class_name, 'support']
            ])
        
        class_metrics_table = wandb.Table(
            columns=["Class", "Precision", "Recall", "F1", "TP", "FP", "FN", 
                     "Report_Precision", "Report_Recall", "Report_F1", "Support"],
            data=class_metrics
        )
        test_log_data["final_test/class_metrics"] = class_metrics_table
        
        # Log all test results at next step
        wandb.log(test_log_data, step=global_step + 1)
        wandb.finish()
    
        return {
            "best_val_weighted_f1": best_val_weighted_f1,
            "final_test_f1": test_f1,
            "final_test_weighted_f1": test_wf1,
            "final_test_loss": test_loss,
            "final_test_accuracy": test_accuracy,
            "test_confusion_matrix": test_cm,
            "test_plot_file": test_plot_file,
            "test_classification_report": test_report
        }

    except Exception as e:
        import traceback
        print(f"Error in train_model: {str(e)}")
        print(traceback.format_exc())
        return {
            "error": str(e),
            "traceback": traceback.format_exc()
        }

In [17]:
# =====================================
# MAIN EXECUTION
# =====================================

def main():
    # Setup environment
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    set_seed(Config.seed)
    
    # Setup API tokens
    hf_token, wandb_token, hf_api = setup_api_tokens()

    # Initialize WandB run ID
    wandb_run_id = None
    
    # Load metadata
    with open("/kaggle/input/meld-extracted-video-frames-rgb/extraction_checkpoint.json") as f:
        extraction_checkpoint = json.load(f)

    metadata = extraction_checkpoint["metadata"]
    
    # Class names
    class_names = ["Neutral", "Surprise", "Fear", "Sadness", "Joy", "Disgust", "Anger"]

    # Run transformation debug
    debug_transformation_pipeline()

    # Create datasets
    train_data = MELDDataset(metadata, 'train')
    val_data = MELDDataset(metadata, 'dev', train=False)
    test_data = MELDDataset(metadata, 'test', train=False)
        
    # Create data loaders
    train_loader = DataLoader(
        train_data,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(val_data, batch_size=Config.batch_size)
    test_loader = DataLoader(test_data, batch_size=Config.batch_size)
    
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    

    # Initialize model
    num_classes = len(np.unique([item["y"] for item in train_data.data]))
    model = MaskedSlowFast(num_classes).to(device)

    # Print parameter status before training
    print_parameter_status(model)
    
    # Create loss function with class weighting    
    class_counts = np.array([4710, 1205, 268, 683, 1743, 271, 1109])
    
    # Inverse square root
    inv_sqrt_freq = 1.0 / np.sqrt(class_counts)
    
    # Normalize to sum to 1 (optional but often useful)
    alpha = inv_sqrt_freq / inv_sqrt_freq.sum()
    
    # Convert to torch tensor
    class_weights = torch.tensor(alpha, dtype=torch.float32).to(device)
                
    criterion = FocalLoss(alpha=class_weights, gamma=2.0)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=Config.base_lr, 
        weight_decay=Config.weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8  # Add epsilon for numerical stability
    )

    # Use OneCycleLR scheduler
    # Calculate steps per epoch correctly
    steps_per_epoch = len(train_loader) // Config.accumulation_steps
    if len(train_loader) % Config.accumulation_steps != 0:
        steps_per_epoch += 1
    
    total_steps = Config.max_epochs * steps_per_epoch
    
    # Initialize scheduler with proper parameters
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=Config.max_lr,
        total_steps=total_steps,
        pct_start=0.3,
        div_factor=25,
        final_div_factor=100,
        anneal_strategy='cos',
        cycle_momentum=False  # Important for AdamW
    )

    # Initialize with proper AMP settings
    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

    # Resume checkpoint path (set to None for fresh start)
    resume_checkpoint = None # None or "prakanda/hatsu-meld-emotion-recognition"
    checkpoint_filename = None    # None or "checkpoint_epoch_2.pth"
    
    # Initialize training variables
    start_epoch = 0
    best_val_loss = float('inf')
    best_val_weighted_f1 = 0.0
    patience_counter = 0
    initial_global_step = 0
    
    # Verify repository exists and is accessible
    if resume_checkpoint and checkpoint_filename:
        try:
            print(f"Attempting to download {checkpoint_filename} from {resume_checkpoint}")
            checkpoint_path = hf_hub_download(
                repo_id=resume_checkpoint,
                filename=checkpoint_filename,
                token=hf_token
            )
            
            checkpoint = load_checkpoint_safely(checkpoint_path, device)
            print(f"\nCheckpoint keys: {checkpoint.keys()}")

            # Enhanced RNG state validation
            if 'torch_rng_state' in checkpoint:
                rng_state = checkpoint['torch_rng_state']
                
                # Verify critical tensor properties
                if not (rng_state.is_contiguous() and 
                        rng_state.dtype == torch.uint8 and 
                        rng_state.device.type == 'cpu'):
                    raise ValueError(
                        f"Bad RNG state: "
                        f"contiguous={rng_state.is_contiguous()}, "
                        f"dtype={rng_state.dtype}, "
                        f"device={rng_state.device}"
                    )
                
                # Direct memory comparison with current state
                current_state = torch.get_rng_state()
                if rng_state.shape != current_state.shape:
                    raise ValueError("RNG state shape mismatch")
                    
                torch.set_rng_state(rng_state)
        
            # Restore other RNG states
            if 'numpy_rng_state' in checkpoint:
                np.random.set_state(checkpoint['numpy_rng_state'])
                
            if 'python_rng_state' in checkpoint:
                random.setstate(checkpoint['python_rng_state'])

            if 'grad_scaler_state_dict' in checkpoint:
                scaler.load_state_dict(checkpoint['grad_scaler_state_dict'])
                print("\nLoaded scheduler state from checkpoint")
    
            # Load model and optimizer
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint.get('epoch', 0) + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            best_val_weighted_f1 = checkpoint.get('best_val_weighted_f1', 0.0)
            patience_counter = checkpoint.get('patience_counter', 0)

            if 'scheduler_state_dict' in checkpoint:
                # Check if schedule duration changed
                old_total_steps = checkpoint.get('scheduler_total_steps', total_steps)
                old_last_epoch = checkpoint.get('scheduler_last_epoch', 0)
                
                if old_total_steps == total_steps:
                    # Compatible schedule - load full state
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                else:
                    print(f"Schedule changed: {old_total_steps} → {total_steps} steps")
                    print(f"Resuming at step {old_last_epoch} of new schedule")
                    
                    # Calculate progress percentage
                    progress = old_last_epoch / old_total_steps
                    
                    # Scale to new schedule
                    new_last_epoch = int(progress * total_steps)
                    
                    # Reinitialize scheduler with adjusted position
                    # NEW: Extend schedule if needed
                    if new_last_epoch >= total_steps - 1:
                        print("Extending scheduler cycle for resumed training")
                        additional_epochs = Config.max_epochs - start_epoch
                        additional_steps = additional_epochs * steps_per_epoch
                        
                        scheduler = torch.optim.lr_scheduler.OneCycleLR(
                            optimizer,
                            max_lr=Config.max_lr,
                            total_steps=total_steps + additional_steps,  # Extend total steps
                            pct_start=0.1,  # Shorter warmup for extension
                            div_factor=10,
                            final_div_factor=100,
                            anneal_strategy='cos',
                            last_epoch=new_last_epoch  # Start from current position
                        )

            # Get saved ID
            wandb_run_id = checkpoint.get('wandb_run_id', None)  

            # Load global step if available
            initial_global_step = checkpoint.get('global_step', start_epoch * len(train_loader))
    
            print(f"\nSuccessfully loaded checkpoint from epoch {checkpoint['epoch']}")
            print(f"\nResuming training from epoch {start_epoch}")
            print(f"\nResuming training from global step {initial_global_step}")
            print(f"\nResuming training with 'best_val_loss' {best_val_loss}")
            print(f"\nResuming training with 'best_val_weighted_f1' {best_val_weighted_f1}\n")
    
        except Exception as e:
            print(f"Error loading checkpoint: {str(e)}")
            print("Starting training from scratch")
            start_epoch = 0
            best_val_loss = float('inf')
            best_val_weighted_f1 = 0.0
            torch.manual_seed(42)  # Reset RNG to known state

            initial_global_step = 0
            wandb_run_id = None
            
    else:
        print("No checkpoint provided, starting from scratch")

    
    # Train model
    results = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        class_names=class_names,
        checkpoint_frequency=Config.checkpoint_frequency,
        max_epochs=Config.max_epochs,
        patience=Config.patience,
        min_epochs=Config.min_epochs,
        start_epoch=start_epoch,
        best_val_loss=best_val_loss,
        best_val_weighted_f1 = best_val_weighted_f1,
        scaler=scaler,
        patience_counter = patience_counter,
        initial_global_step=initial_global_step,
        total_steps = total_steps,
        wandb_run_id=wandb_run_id,
        hf_api = hf_api
    )

    # Check if results contain an error
    if "error" in results:
        print(f"\nTraining failed with error: {results['error']}")
        print(f"Traceback: {results['traceback']}")
    else:
        print("\nTraining completed!")
        print(f"Best validation macro F1 score: {results['best_val_weighted_f1']:.4f}")
        print(f"Final test accuracy: {results['final_test_accuracy']:.4f}")
        print(f"Final test macro F1 score: {results['final_test_f1']:.4f}")
        print(f"Final test weighted F1 score: {results['final_test_weighted_f1']:.4f}")
        print(f"Final test loss: {results['final_test_loss']:.4f}")

In [18]:
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mprakandabhandari[0m ([33mprakandabhandari-tribhuvan-university-institute-of-engin[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



===== TRANSFORMATION DEBUG =====
Dummy image min: 0, max: 255
Dummy image mean: [127.32973932 126.97712054 127.08350606]

Train transform results:
Min: -1.8606
Max: 1.6776
Mean: [-0.13604295253753662, -0.13617321848869324, -0.13813546299934387]

Test transform results:
Min: -1.9826
Max: 2.3922
Mean: [0.2142985463142395, 0.21584327518939972, 0.21102853119373322]



Downloading: "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOWFAST_8x8_R50.pyth" to /root/.cache/torch/hub/checkpoints/SLOWFAST_8x8_R50.pyth
100%|██████████| 264M/264M [00:02<00:00, 92.8MB/s]
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



Total Parameters: 36,540,879
Trainable Parameters: 2,896,391 (7.93%)
Frozen Parameters: 33,644,488 (92.07%)

Trainable Parameters:
- feature_bn.weight
- feature_bn.bias
- classifier.0.weight
- classifier.0.bias
- classifier.1.weight
- classifier.1.bias
- classifier.4.weight
- classifier.4.bias
- classifier.5.weight
- classifier.5.bias
- classifier.8.weight
- classifier.8.bias

Frozen Parameters:
- backbone.blocks.0.multipathway_blocks.0.conv.weight
- backbone.blocks.0.multipathway_blocks.0.norm.weight
- backbone.blocks.0.multipathway_blocks.0.norm.bias
- backbone.blocks.0.multipathway_blocks.1.conv.weight
- backbone.blocks.0.multipathway_blocks.1.norm.weight
- backbone.blocks.0.multipathway_blocks.1.norm.bias
- backbone.blocks.0.multipathway_fusion.conv_fast_to_slow.weight
- backbone.blocks.0.multipathway_fusion.norm.weight
- backbone.blocks.0.multipathway_fusion.norm.bias
- backbone.blocks.1.multipathway_blocks.0.res_blocks.0.branch1_conv.weight
- backbone.blocks.1.multipathway_block

[34m[1mwandb[0m: Tracking run with wandb version 0.19.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250613_084307-v0z2yof7[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mdistinctive-resonance-2[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/prakandabhandari-tribhuvan-university-institute-of-engin/meld-emotion-recognition-new[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/prakandabhandari-tribhuvan-university-institute-of-engin/meld-emotion-recognition-new/runs/v0z2yof7[0m



📊 Started new WandB run: v0z2yof7


===== Debugging: Overfitting small batch =====
Debug Iteration 0: Loss = 0.1434
Debug Iteration 10: Loss = 0.0163
Debug Iteration 20: Loss = 0.0036
Debug Iteration 30: Loss = 0.0013
Debug Iteration 40: Loss = 0.0012
Debug Iteration 50: Loss = 0.0005
Debug Iteration 60: Loss = 0.0005
Debug Iteration 70: Loss = 0.0005
Debug Iteration 80: Loss = 0.0005
Debug Iteration 90: Loss = 0.0004
Debug overfitting completed. Loss should decrease steadily.



Epoch 1/100:   0%|          | 0/625 [00:00<?, ?it/s]


===== Input Debugging =====
Slow input shape: torch.Size([16, 3, 8, 224, 224])
Normalized Slow input range: [-2.0000, 2.4444]
Normalized Slow input mean: -1.1010
Fast input shape: torch.Size([16, 3, 32, 224, 224])
Normalized Fast input range: [-2.0000, 2.4444]
Normalized Fast input mean: -1.1207
Slow mask mean: 0.9688
Fast mask mean: 0.9648
Labels: tensor([0, 6, 3, 4, 5, 6, 0, 0, 3, 0, 0, 4, 5, 2, 0, 0], device='cuda:0')
Class distribution: tensor([7, 0, 1, 2, 2, 2, 2], device='cuda:0')



Epoch 1/100: 100%|██████████| 625/625 [12:03<00:00,  1.16s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [1/100]
Train Loss: 0.1301, Train Acc: 0.2889, Train Macro F1: 0.1597
Val Loss: 0.1677, Val Acc: 0.3962, Val Macro F1: 0.1016, Val Weighted F1: 0.2630
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_1.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

🌟 Significant improvement: +0.2630


checkpoint_epoch_1.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

Epoch 2/100: 100%|██████████| 625/625 [09:26<00:00,  1.10it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [2/100]
Train Loss: 0.1194, Train Acc: 0.3755, Train Macro F1: 0.1751
Val Loss: 0.1554, Val Acc: 0.3917, Val Macro F1: 0.1179, Val Weighted F1: 0.2783
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_2.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

🌟 Significant improvement: +0.0153


checkpoint_epoch_2.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

Epoch 3/100: 100%|██████████| 625/625 [09:01<00:00,  1.15it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [3/100]
Train Loss: 0.1158, Train Acc: 0.3956, Train Macro F1: 0.1940
Val Loss: 0.1549, Val Acc: 0.3773, Val Macro F1: 0.1265, Val Weighted F1: 0.2825
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_3.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

💫 Minor improvement: +0.0043, patience reduced to 0


checkpoint_epoch_3.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

Unfroze stage 5 layers


Epoch 4/100: 100%|██████████| 625/625 [09:20<00:00,  1.11it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [4/100]
Train Loss: 0.1137, Train Acc: 0.4042, Train Macro F1: 0.2054
Val Loss: 0.1531, Val Acc: 0.3547, Val Macro F1: 0.1436, Val Weighted F1: 0.2882
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_4.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

🌟 Significant improvement: +0.0057


checkpoint_epoch_4.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

Epoch 5/100: 100%|██████████| 625/625 [09:44<00:00,  1.07it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [5/100]
Train Loss: 0.1115, Train Acc: 0.4103, Train Macro F1: 0.2271
Val Loss: 0.1493, Val Acc: 0.4007, Val Macro F1: 0.1380, Val Weighted F1: 0.2952
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_5.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

🌟 Significant improvement: +0.0070


checkpoint_epoch_5.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

Epoch 6/100: 100%|██████████| 625/625 [09:00<00:00,  1.16it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Epoch [6/100]
Train Loss: 0.1106, Train Acc: 0.4042, Train Macro F1: 0.2306
Val Loss: 0.1625, Val Acc: 0.3637, Val Macro F1: 0.1332, Val Weighted F1: 0.2861
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_6.pth:   0%|          | 0.00/170M [00:00<?, ?B/s]

⏳ No improvement. Patience: 1/10
Unfroze stage 4 layers


Epoch 7/100: 100%|██████████| 625/625 [09:45<00:00,  1.07it/s]



Epoch [7/100]
Train Loss: 0.1149, Train Acc: 0.3974, Train Macro F1: 0.2196
Val Loss: 0.1409, Val Acc: 0.2879, Val Macro F1: 0.1407, Val Weighted F1: 0.2660
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_7.pth:   0%|          | 0.00/342M [00:00<?, ?B/s]

⏳ No improvement. Patience: 2/10


Epoch 8/100: 100%|██████████| 625/625 [09:16<00:00,  1.12it/s]



Epoch [8/100]
Train Loss: 0.1081, Train Acc: 0.4007, Train Macro F1: 0.2575
Val Loss: 0.1487, Val Acc: 0.3818, Val Macro F1: 0.1299, Val Weighted F1: 0.2914
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_8.pth:   0%|          | 0.00/342M [00:00<?, ?B/s]

⏳ No improvement. Patience: 3/10


Epoch 9/100: 100%|██████████| 625/625 [09:02<00:00,  1.15it/s]



Epoch [9/100]
Train Loss: 0.1029, Train Acc: 0.4092, Train Macro F1: 0.2776
Val Loss: 0.1491, Val Acc: 0.2437, Val Macro F1: 0.1365, Val Weighted F1: 0.2432
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_9.pth:   0%|          | 0.00/342M [00:00<?, ?B/s]

⏳ No improvement. Patience: 4/10
Unfroze stage 3 layers


Epoch 10/100: 100%|██████████| 625/625 [12:01<00:00,  1.15s/it]



Epoch [10/100]
Train Loss: 0.1061, Train Acc: 0.3927, Train Macro F1: 0.2736
Val Loss: 0.1387, Val Acc: 0.3872, Val Macro F1: 0.1534, Val Weighted F1: 0.3000
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_10.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

💫 Minor improvement: +0.0048, patience reduced to 4


checkpoint_epoch_10.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

Error uploading: 500 Server Error: Internal Server Error for url: https://huggingface.co/api/models/prakanda/hatsu-meld-emotion-recognition-new/commit/main (Request ID: Root=1-684c044b-501e06892e4e6f1054ed602d;0bf3016d-2435-4287-8a98-c6097bff9346)

Internal Error - We're working hard to fix this as soon as possible!


Epoch 11/100: 100%|██████████| 625/625 [11:58<00:00,  1.15s/it]



Epoch [11/100]
Train Loss: 0.0980, Train Acc: 0.4124, Train Macro F1: 0.3104
Val Loss: 0.1425, Val Acc: 0.3764, Val Macro F1: 0.1376, Val Weighted F1: 0.2928
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_11.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 5/10


Epoch 12/100: 100%|██████████| 625/625 [11:53<00:00,  1.14s/it]



Epoch [12/100]
Train Loss: 0.0932, Train Acc: 0.4167, Train Macro F1: 0.3276
Val Loss: 0.1525, Val Acc: 0.4088, Val Macro F1: 0.1727, Val Weighted F1: 0.3326
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_12.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

🌟 Significant improvement: +0.0326


checkpoint_epoch_12.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

Epoch 13/100: 100%|██████████| 625/625 [12:00<00:00,  1.15s/it]



Epoch [13/100]
Train Loss: 0.0892, Train Acc: 0.4341, Train Macro F1: 0.3607
Val Loss: 0.1419, Val Acc: 0.3195, Val Macro F1: 0.1764, Val Weighted F1: 0.2899
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_13.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 1/10


Epoch 14/100: 100%|██████████| 625/625 [11:59<00:00,  1.15s/it]



Epoch [14/100]
Train Loss: 0.0847, Train Acc: 0.4498, Train Macro F1: 0.3858
Val Loss: 0.1533, Val Acc: 0.3141, Val Macro F1: 0.1601, Val Weighted F1: 0.2852
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_14.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 2/10


Epoch 15/100: 100%|██████████| 625/625 [11:59<00:00,  1.15s/it]



Epoch [15/100]
Train Loss: 0.0812, Train Acc: 0.4369, Train Macro F1: 0.3819
Val Loss: 0.1577, Val Acc: 0.3258, Val Macro F1: 0.1954, Val Weighted F1: 0.3172
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_15.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 3/10


Epoch 16/100: 100%|██████████| 625/625 [12:00<00:00,  1.15s/it]



Epoch [16/100]
Train Loss: 0.0778, Train Acc: 0.4602, Train Macro F1: 0.4054
Val Loss: 0.1514, Val Acc: 0.2329, Val Macro F1: 0.1611, Val Weighted F1: 0.2277
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_16.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 4/10


Epoch 17/100: 100%|██████████| 625/625 [11:57<00:00,  1.15s/it]



Epoch [17/100]
Train Loss: 0.0765, Train Acc: 0.4594, Train Macro F1: 0.4044
Val Loss: 0.1654, Val Acc: 0.3042, Val Macro F1: 0.1413, Val Weighted F1: 0.2660
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_17.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 5/10


Epoch 18/100: 100%|██████████| 625/625 [12:02<00:00,  1.16s/it]



Epoch [18/100]
Train Loss: 0.0729, Train Acc: 0.4747, Train Macro F1: 0.4318
Val Loss: 0.1658, Val Acc: 0.3231, Val Macro F1: 0.1982, Val Weighted F1: 0.3104
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_18.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 6/10


Epoch 19/100: 100%|██████████| 625/625 [11:57<00:00,  1.15s/it]



Epoch [19/100]
Train Loss: 0.0715, Train Acc: 0.4825, Train Macro F1: 0.4421
Val Loss: 0.1804, Val Acc: 0.3069, Val Macro F1: 0.1702, Val Weighted F1: 0.2818
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_19.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 7/10


Epoch 20/100: 100%|██████████| 625/625 [12:07<00:00,  1.16s/it]



Epoch [20/100]
Train Loss: 0.0665, Train Acc: 0.4946, Train Macro F1: 0.4628
Val Loss: 0.1715, Val Acc: 0.3394, Val Macro F1: 0.1898, Val Weighted F1: 0.3238
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_20.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 8/10


Epoch 21/100: 100%|██████████| 625/625 [12:05<00:00,  1.16s/it]



Epoch [21/100]
Train Loss: 0.0641, Train Acc: 0.5019, Train Macro F1: 0.4750
Val Loss: 0.1926, Val Acc: 0.3249, Val Macro F1: 0.1447, Val Weighted F1: 0.2902
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_21.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 9/10


Epoch 22/100: 100%|██████████| 625/625 [12:05<00:00,  1.16s/it]



Epoch [22/100]
Train Loss: 0.0614, Train Acc: 0.5188, Train Macro F1: 0.4944
Val Loss: 0.1877, Val Acc: 0.2572, Val Macro F1: 0.1600, Val Weighted F1: 0.2622
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_22.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 10/10


Epoch 23/100: 100%|██████████| 625/625 [11:59<00:00,  1.15s/it]



Epoch [23/100]
Train Loss: 0.0590, Train Acc: 0.5342, Train Macro F1: 0.5215
Val Loss: 0.1993, Val Acc: 0.3005, Val Macro F1: 0.1864, Val Weighted F1: 0.3039
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_23.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 11/10


Epoch 24/100: 100%|██████████| 625/625 [12:30<00:00,  1.20s/it]



Epoch [24/100]
Train Loss: 0.0567, Train Acc: 0.5377, Train Macro F1: 0.5154
Val Loss: 0.1931, Val Acc: 0.2121, Val Macro F1: 0.1494, Val Weighted F1: 0.2340
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_24.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 12/10


Epoch 25/100: 100%|██████████| 625/625 [12:29<00:00,  1.20s/it]



Epoch [25/100]
Train Loss: 0.0540, Train Acc: 0.5478, Train Macro F1: 0.5332
Val Loss: 0.1966, Val Acc: 0.2518, Val Macro F1: 0.1594, Val Weighted F1: 0.2472
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_25.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 13/10


Epoch 26/100: 100%|██████████| 625/625 [11:56<00:00,  1.15s/it]



Epoch [26/100]
Train Loss: 0.0525, Train Acc: 0.5555, Train Macro F1: 0.5415
Val Loss: 0.2080, Val Acc: 0.2987, Val Macro F1: 0.1888, Val Weighted F1: 0.3060
Class distribution: tensor([4709, 1205,  268,  683, 1743,  271, 1109])


checkpoint_epoch_26.pth:   0%|          | 0.00/426M [00:00<?, ?B/s]

⏳ No improvement. Patience: 14/10
⏹️ Early stopping at epoch 26. Best was epoch 12


[34m[1mwandb[0m: uploading artifact run-v0z2yof7-final_testclassification_report; uploading artifact run-v0z2yof7-final_testclass_metrics; uploading media/images/final_test/confusion_matrix_16250_3087bbef2264d04c69ce.png; uploading artifact run-v0z2yof7-testroc_curve_table; uploading media/table/test/roc_curve_table_16250_c9d69bbb0de0c06af93f.table.json (+ 3 more)
[34m[1mwandb[0m: uploading artifact run-v0z2yof7-final_testclassification_report; uploading artifact run-v0z2yof7-final_testclass_metrics; uploading artifact run-v0z2yof7-testroc_curve_table
[34m[1mwandb[0m: uploading artifact run-v0z2yof7-final_testclass_metrics; uploading artifact run-v0z2yof7-testroc_curve_table
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                                                                                epoch ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
[34m[1mwandb[0m:  


Training completed!
Best validation macro F1 score: 0.3326
Final test accuracy: 0.2533
Final test macro F1 score: 0.1493
Final test weighted F1 score: 0.2754
Final test loss: 0.1763
