<a href="https://colab.research.google.com/github/Ash-0154/A-Machine-Learning-Approach-to-Predicting-Mental-Health-State-from-Smartwatches/blob/main/Untitled22.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch torchvision torchaudio transformers datasets kagglehub librosa soundfile scikit-learn pandas numpy matplotlib seaborn psutil tqdm

In [None]:
!pip install kagglehub --quiet
# Download (or reuse the cached copy of) each speech-emotion corpus through
# kagglehub. Each call returns the local directory of the dataset.
import kagglehub
crema_d_path = kagglehub.dataset_download("ejlok1/cremad")
ravdess_path  = kagglehub.dataset_download("uwrfkaggler/ravdess-emotional-speech-audio")
tess_path = kagglehub.dataset_download("ejlok1/toronto-emotional-speech-set-tess")
shemo_path = kagglehub.dataset_download("mansourehk/shemo-persian-speech-emotion-detection-database")
subesco_path=kagglehub.dataset_download("sushmit0109/subescobangla-speech-emotion-dataset")

# Show two of the resolved locations so they can be compared with the
# hard-coded AUDIO_DATASET_PATHS used later in the notebook.
print("Path to subesco:", subesco_path)
print(shemo_path)

Using Colab cache for faster access to the 'cremad' dataset.
Using Colab cache for faster access to the 'ravdess-emotional-speech-audio' dataset.
Using Colab cache for faster access to the 'toronto-emotional-speech-set-tess' dataset.
Using Colab cache for faster access to the 'shemo-persian-speech-emotion-detection-database' dataset.
Using Colab cache for faster access to the 'subescobangla-speech-emotion-dataset' dataset.
Path to subesco: /kaggle/input/subescobangla-speech-emotion-dataset
/kaggle/input/shemo-persian-speech-emotion-detection-database


In [None]:
#!/usr/bin/env python3
"""
unified_federated_emotion_recognition.py

Combined federated learning system for text and speech emotion recognition.
Supports multiple datasets across both modalities with federated averaging.
Optimized for A100 GPU with enhanced architectures and regularization.
"""

# ============================================================================
# PACKAGE INSTALLATION (Run this cell first in Jupyter Notebook)
# ============================================================================
# Banner only: nothing is installed here. The actual installation happens in
# install_packages(), whose call is commented out by default further below.
print("Installing required packages...")
print("="*80)

import sys
import subprocess

# pip distribution names that install_packages() installs one by one.
packages = [
    'transformers',
    'datasets',
    'kagglehub',
    'torch',
    'torchvision',
    'torchaudio',
    'librosa',
    'scikit-learn',
    'pandas',
    'numpy',
    'matplotlib',
    'seaborn',
    'psutil',
    'tqdm'
]

def install_packages():
    """Pip-install every entry of the module-level ``packages`` list.

    Each package is installed quietly through the running interpreter's pip.
    A failing package is reported on stdout and the loop continues with the
    next one, so a single broken package never aborts the whole run.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q"]
    for pkg in packages:
        try:
            print(f"Installing {pkg}...")
            subprocess.check_call([*pip_command, pkg])
            print(f"  ✓ {pkg} installed successfully")
        except subprocess.CalledProcessError:
            # pip itself exited with a non-zero status.
            print(f"  ✗ Failed to install {pkg}")
        except Exception as err:
            # Anything else (e.g. the interpreter could not be launched).
            print(f"  ✗ Error installing {pkg}: {err}")

# Uncomment the line below to install packages
# install_packages()

# NOTE(review): this banner is printed unconditionally, i.e. also when the
# install_packages() call above is left commented out and nothing was installed.
print("\n" + "="*80)
print("Package installation complete!")
print("="*80 + "\n")

# ============================================================================
# IMPORTS
# ============================================================================

import os
import random
import hashlib
import gc
import psutil
from copy import deepcopy
from pathlib import Path
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (accuracy_score, f1_score, recall_score,
                            precision_score, confusion_matrix,
                            precision_recall_fscore_support)
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Text processing
from transformers import (DistilBertTokenizerFast, DistilBertModel,
                         get_linear_schedule_with_warmup)
from torch.optim import AdamW

# Audio processing
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Dataset loaders
import kagglehub
from datasets import load_dataset

# ============================================================================
# CONFIGURATION
# ============================================================================

# Seed the Python, NumPy and PyTorch random number generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

# Training hyperparameters
BATCH_SIZE = 32 if torch.cuda.is_available() else 16
ROUNDS = 30  # Default total_rounds of get_lr_multiplier()
LOCAL_EPOCHS = 3
MAX_LEN = 128  # Text sequence length
SAMPLE_LIMIT = 20000  # Text samples per dataset

# Audio parameters
SAMPLE_RATE = 16000
TARGET_LEN = SAMPLE_RATE * 4  # Waveforms are padded/truncated to this many samples (4 s)
CACHE_DIR = "/tmp/emo_cache_fl"  # Where cache_audio_file_features() stores .npz files
os.makedirs(CACHE_DIR, exist_ok=True)

# Learning rates
LR_TEXT_CLASSIFIER = 3e-5
LR_TEXT_BERT = 2e-5
LR_AUDIO_NEW = 8e-4
LR_AUDIO_WAV2VEC = 2e-5

# Regularization
WEIGHT_DECAY = 1e-3
GRAD_CLIP = 2.0
LABEL_SMOOTHING = 0.1
WARMUP_RATIO = 0.1
WARMUP_ROUNDS = 2  # Default warmup_rounds of get_lr_multiplier()

# Audio augmentation
MIXUP_PROB = 0.65
MIXUP_ALPHA = 0.45
SPEC_AUG_MASK = 0.3
PITCH_SHIFT_PROB = 0.35
TIME_STRETCH_PROB = 0.35
NOISE_PROB = 0.25

# Federated learning
FEDPROX_MU = 1.5e-3
EARLY_STOPPING_PATIENCE = 5
CONFUSION_BOOST_FACTOR = 2.8

# NOTE: NUM_CLASSES is reassigned from the fitted label encoder later on.
NUM_CLASSES = 7  # Unified emotion classes
NUM_WORKERS = 4 if torch.cuda.is_available() else 0

# Mixed precision is enabled only when CUDA is available; scaler is None on CPU.
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler() if use_amp else None

print(f"Batch Size: {BATCH_SIZE}, Mixed Precision: {use_amp}")

# ============================================================================
# UNIFIED EMOTION MAPPING
# ============================================================================

# Unified emotion classes (7 core emotions)
# Labels that do not normalize to one of these names are discarded by the
# data-preparation code.
UNIFIED_EMOTIONS = ['anger', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']

# Text emotion mappings
# Keys are lower-cased raw labels; values are the unified class names.
# The '0'..'5' keys are hit when normalize_text_label() receives a numeric
# label and converts it to its integer string.
# NOTE(review): the numeric keys presumably follow one dataset's id order
# (sadness, joy, love, anger, fear, surprise) — confirm for every dataset used.
TEXT_EMOTION_MAP = {
    'anger': 'anger', 'fear': 'fear', 'joy': 'happy',
    'sadness': 'sad', 'love': 'happy', 'surprise': 'surprise',
    '0': 'sad', '1': 'happy', '2': 'happy',
    '3': 'anger', '4': 'fear', '5': 'surprise',
    'happiness': 'happy', 'neutral': 'neutral', 'worry': 'fear',
    'hate': 'anger', 'boredom': 'neutral', 'enthusiasm': 'happy',
    'fun': 'happy', 'relief': 'happy', 'empty': 'sad',
    'disgust': 'disgust', 'shame': 'sad', 'guilt': 'sad',
    'confusion': 'neutral', 'desire': 'happy', 'sarcasm': 'neutral'
}

# Audio emotion mappings
# One sub-map per corpus: raw emotion code/token (as found in file or folder
# names) -> unified class name. RAVDESS code "02" is folded into "neutral".
# SUBESCO now also maps "disgust": gather_subesco() only consults this map for
# filename tokens, so without the entry those files were silently dropped even
# though "disgust" is one of the unified classes.
AUDIO_LABEL_MAP = {
    "RAVDESS": {"01":"neutral","02":"neutral","03":"happy","04":"sad",
                "05":"anger","06":"fear","07":"disgust","08":"surprise"},
    "CREMA-D": {"ANG":"anger","DIS":"disgust","FEA":"fear","HAP":"happy",
                "NEU":"neutral","SAD":"sad"},
    "TESS": {"angry":"anger","disgust":"disgust","fear":"fear","happy":"happy",
             "neutral":"neutral","pleasant surprise":"surprise","ps":"surprise","sad":"sad"},
    "ShEMO": {"A":"anger","H":"happy","S":"sad","N":"neutral","F":"fear","W":"surprise"},
    "SUBESCO": {"angry":"anger","anger":"anger","happy":"happy","happiness":"happy",
                "sad":"sad","sadness":"sad","neutral":"neutral","surprise":"surprise",
                "surprised":"surprise","fear":"fear","fearful":"fear",
                "disgust":"disgust"}
}

# Dataset-independent fallback used by normalize_audio_label(): the first key
# (in insertion order) that occurs as a substring of the lower-cased token wins.
GENERIC_TOKENS = {"angry":"anger","anger":"anger","happy":"happy","joy":"happy",
                  "sad":"sad","sadness":"sad","neutral":"neutral","surprise":"surprise",
                  "fear":"fear","disgust":"disgust"}

def normalize_text_label(label):
    """Normalize a raw text-emotion label to the unified naming scheme.

    The label is stringified, stripped and lower-cased. Numeric-looking labels
    (e.g. ``3``, ``"3"``, ``3.0``) are reduced to their integer string before
    the lookup in ``TEXT_EMOTION_MAP``.

    Args:
        label: Raw label of any type (str, int, float, ...).

    Returns:
        The mapped unified emotion name, or the cleaned label string itself
        when no mapping exists (callers filter on ``UNIFIED_EMOTIONS``).
    """
    label_str = str(label).strip().lower()
    if label_str.replace('.', '').isdigit():
        try:
            label_str = str(int(float(label_str)))
        except (ValueError, OverflowError):
            # Digits-and-dots that are not a real number (e.g. "1.2.3") or that
            # overflow to infinity: keep the cleaned string instead of crashing.
            pass
    return TEXT_EMOTION_MAP.get(label_str, label_str)

def normalize_audio_label(token, dataset_name=None):
    """Translate a raw audio label token into a unified emotion name.

    Lookup order:
      1. the dataset-specific map in ``AUDIO_LABEL_MAP`` (exact token, then
         lower-cased token);
      2. for RAVDESS, a purely numeric token is resolved against the RAVDESS
         map only (unknown codes give ``None``);
      3. the first ``GENERIC_TOKENS`` key contained in the lower-cased token.

    Args:
        token: Raw label string, or ``None``.
        dataset_name: Optional key of ``AUDIO_LABEL_MAP``.

    Returns:
        The unified emotion name, or ``None`` when nothing matches.
    """
    if token is None:
        return None
    lowered = token.lower()

    dataset_map = None
    if dataset_name and dataset_name in AUDIO_LABEL_MAP:
        dataset_map = AUDIO_LABEL_MAP[dataset_name]

    if dataset_map is not None:
        for candidate in (token, lowered):
            if candidate in dataset_map:
                return dataset_map[candidate]
        if dataset_name == "RAVDESS" and token.isdigit():
            return dataset_map.get(token)

    # Generic substring fallback; insertion order of GENERIC_TOKENS decides ties.
    return next(
        (unified for fragment, unified in GENERIC_TOKENS.items() if fragment in lowered),
        None,
    )

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def cleanup_model(model):
    """Best-effort release of a model's GPU memory.

    Moves the model to the CPU when it exposes a ``cpu()`` method, drops the
    local reference and empties the CUDA caching allocator when CUDA is
    available. The caller's own references to the model are not affected.

    Args:
        model: Any object; typically a ``torch.nn.Module``.
    """
    if hasattr(model, 'cpu'):
        try:
            model.cpu()
        except Exception:
            # Cleanup must never fail the training loop, but KeyboardInterrupt
            # and SystemExit are deliberately allowed to propagate.
            pass
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def aggressive_cleanup():
    """Run the garbage collector, then empty the CUDA cache if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def print_memory_usage():
    """Print system RAM usage (percent) and allocated CUDA memory (GB)."""
    ram_percent = psutil.virtual_memory().percent
    if torch.cuda.is_available():
        vram_gb = torch.cuda.memory_allocated() / 1e9
    else:
        vram_gb = 0
    print(f"   RAM: {ram_percent:.1f}% | VRAM: {vram_gb:.2f}GB")

def compute_metrics(trues, preds):
    """Return ``(accuracy, f1, recall, precision)`` for the given predictions.

    F1, recall and precision are weighted averages with ``zero_division=0``.
    An empty ``trues`` yields four zeros without calling scikit-learn.
    """
    if len(trues) == 0:
        return 0.0, 0.0, 0.0, 0.0

    weighted_kwargs = {"average": "weighted", "zero_division": 0}
    accuracy = accuracy_score(trues, preds)
    f1_weighted = f1_score(trues, preds, **weighted_kwargs)
    recall_weighted = recall_score(trues, preds, **weighted_kwargs)
    precision_weighted = precision_score(trues, preds, **weighted_kwargs)
    return accuracy, f1_weighted, recall_weighted, precision_weighted

def get_lr_multiplier(round_num, warmup_rounds=WARMUP_ROUNDS, total_rounds=ROUNDS):
    """Learning-rate multiplier: linear warmup followed by cosine annealing.

    Args:
        round_num: Zero-based index of the current round.
        warmup_rounds: Number of warmup rounds; the multiplier ramps linearly
            up to 1.0 over them.
        total_rounds: Total number of rounds in the schedule.

    Returns:
        ``(round_num + 1) / warmup_rounds`` during warmup, otherwise
        ``0.5 * (1 + cos(pi * progress))``, with ``progress`` the fraction of
        the post-warmup rounds already completed.
    """
    if round_num < warmup_rounds:
        return (round_num + 1) / warmup_rounds

    decay_rounds = total_rounds - warmup_rounds
    if decay_rounds <= 0:
        # No annealing phase configured: avoid dividing by zero and keep the
        # full learning rate once warmup is over.
        return 1.0

    # Clamp so that rounds beyond total_rounds stay at the schedule's end value
    # instead of riding the cosine back up.
    progress = min(max((round_num - warmup_rounds) / decay_rounds, 0.0), 1.0)
    return 0.5 * (1 + np.cos(np.pi * progress))

def mixup_batch_audio(batch, alpha=0.5):
    """Mix every sample of an audio batch with a randomly chosen partner.

    Args:
        batch: ``(wav, mel, wav2vec_emb, labels)``; ``wav2vec_emb`` may be
            ``None``.
        alpha: Beta-distribution parameter for the mixing coefficient; values
            ``<= 0`` disable mixing (coefficient 1.0).

    Returns:
        ``(wav_mixed, mel_mixed, wav2vec_mixed, labels, partner_labels, lam)``
        where ``wav2vec_mixed`` is ``None`` when no embeddings were given.
    """
    wav, mel, wav2vec_emb, labels = batch
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    partner = torch.randperm(wav.size(0))

    def _blend(tensor):
        # Convex combination of each sample with its shuffled partner.
        return lam * tensor + (1 - lam) * tensor[partner]

    mixed_wav = _blend(wav)
    mixed_mel = _blend(mel)
    mixed_emb = _blend(wav2vec_emb) if wav2vec_emb is not None else None

    return mixed_wav, mixed_mel, mixed_emb, labels, labels[partner], lam

# ============================================================================
# TEXT DATA PREPARATION
# ============================================================================

print("\n" + "="*80)
print("LOADING TEXT DATASETS")
print("="*80)

# Load HuggingFace text datasets
# Each dataset is loaded in its own try-block so that one failing download
# only drops that dataset instead of discarding every dataset already loaded.
hf_clients = {}
for hf_name in ("dair-ai/emotion", "boltuix/emotions-dataset",
                "mteb/emotion", "go_emotions"):
    try:
        hf_clients[hf_name] = load_dataset(hf_name)
    except Exception as e:
        print(f"Warning: HF dataset '{hf_name}' failed to load: {e}")

# Load Kaggle text dataset
# The first CSV found in the downloaded folder is read and exposed in the same
# {"train": [...records...]} shape that prepare_text_client_data() expects.
# kaggle_client stays {} when no CSV exists or the download/read fails.
try:
    dataset_path = kagglehub.dataset_download("pashupatigupta/emotion-detection-from-text")
    csv_files = [f for f in os.listdir(dataset_path) if f.endswith(".csv")]
    if csv_files:
        csv_path = os.path.join(dataset_path, csv_files[0])
        df_kaggle = pd.read_csv(csv_path)
        kaggle_client = {"train": df_kaggle.to_dict("records")}
        print(f"✓ Loaded Kaggle text dataset: {len(df_kaggle)} samples")
    else:
        kaggle_client = {}
except Exception as e:
    print(f"Warning: Kaggle text dataset not loaded: {e}")
    kaggle_client = {}

# One entry per text client; an empty kaggle_client is skipped later because
# the preparation loop only processes truthy datasets.
text_clients_datasets = {**hf_clients, "pashupatigupta/emotion": kaggle_client}

def prepare_text_client_data(ds, dataset_name):
    """Extract ``(texts, labels)`` from a text dataset of unknown schema.

    The text and label columns are guessed from the first training record:
    exact (case-insensitive) matches against a list of preferred names first,
    then substring matches, then the first / second column as a last resort.
    Records whose text is shorter than 10 characters, or whose normalized
    label is not one of ``UNIFIED_EMOTIONS``, are skipped. At most
    ``SAMPLE_LIMIT`` samples are collected.

    Integer class ids are translated through the dataset's own class names
    when the label column exposes them (an object with a ``names`` attribute,
    optionally wrapped in a sequence feature), because the same id means
    different emotions in different datasets. Datasets without such metadata
    keep the numeric fallback of ``normalize_text_label``.

    Args:
        ds: Mapping with a ``"train"`` entry that supports ``len``, indexing
            and iteration over dict-like records.
        dataset_name: Name used in the progress message.

    Returns:
        ``(texts, labels)`` as two parallel lists; both empty when the dataset
        has no usable training split or no identifiable label column.
    """
    if not (isinstance(ds, dict) and "train" in ds):
        return [], []

    data_iter = ds["train"]
    if len(data_iter) == 0:
        # Nothing to inspect; indexing [0] would raise IndexError.
        return [], []
    sample_item = data_iter[0]
    columns = list(sample_item.keys())
    lower_keys = {k.lower(): k for k in columns}

    # Find text column
    preferred_text_cols = ['text', 'sentence', 'content', 'tweet', 'utterance']
    text_candidates = [lower_keys[col] for col in preferred_text_cols if col in lower_keys]
    if not text_candidates:
        text_candidates = [k for k in columns
                           if any(tok in k.lower() for tok in preferred_text_cols)]
    text_col = text_candidates[0] if text_candidates else columns[0]

    # Find label column
    preferred_label_cols = ['label', 'emotion', 'sentiment', 'category']
    label_candidates = [lower_keys[col] for col in preferred_label_cols if col in lower_keys]
    if not label_candidates:
        label_candidates = [k for k in columns
                            if any(tok in k.lower() for tok in preferred_label_cols)]
    if label_candidates:
        label_col = label_candidates[0]
    elif len(columns) > 1:
        label_col = columns[1]
    else:
        # A single-column dataset cannot provide labels.
        return [], []

    # Resolve the dataset's own class names, if it publishes any.
    label_names = None
    features = getattr(data_iter, "features", None)
    if features is not None and hasattr(features, "get"):
        feature = features.get(label_col)
        if isinstance(feature, (list, tuple)) and feature:
            feature = feature[0]
        # Multi-label columns wrap the class-label feature in a sequence feature.
        feature = getattr(feature, "feature", feature)
        label_names = getattr(feature, "names", None)

    texts, labels = [], []
    for item in data_iter:
        if len(texts) >= SAMPLE_LIMIT:
            break

        text = str(item[text_col]).strip()
        if len(text) < 10:
            continue

        lbl = item[label_col]
        if isinstance(lbl, (list, tuple)):
            if not lbl:
                # Multi-label record without any label.
                continue
            lbl = lbl[0]

        if (label_names and isinstance(lbl, int) and not isinstance(lbl, bool)
                and 0 <= lbl < len(label_names)):
            lbl = label_names[lbl]

        lbl = normalize_text_label(lbl)
        if lbl in UNIFIED_EMOTIONS:
            texts.append(text)
            labels.append(lbl)

    print(f"  ✓ {dataset_name}: {len(texts)} samples")
    return texts, labels

# Per-client text samples and their unified labels, keyed by dataset name.
# Datasets that are empty or yield no usable sample are left out.
text_client_texts, text_client_labels = {}, {}
for name, ds in text_clients_datasets.items():
    if ds:
        texts, labels = prepare_text_client_data(ds, name)
        if texts:
            text_client_texts[name] = texts
            text_client_labels[name] = labels

print(f"Loaded {len(text_client_texts)} text datasets")

# ============================================================================
# AUDIO DATA PREPARATION
# ============================================================================

print("\n" + "="*80)
print("LOADING AUDIO DATASETS")
print("="*80)

# Audio dataset paths (update these for your environment)
# Keys select the gather_* function by substring match in the loading loop;
# the two ShEMO entries are merged into a single "ShEMO" client there.
# NOTE(review): these are hard-coded locations; the paths returned by the
# kagglehub.dataset_download() calls at the top of the notebook are not used.
AUDIO_DATASET_PATHS = {
    "CREMA-D": "/kaggle/input/cremad/AudioWAV",
    "RAVDESS": "/kaggle/input/ravdess-emotional-speech-audio",
    "TESS": "/kaggle/input/toronto-emotional-speech-set-tess/TESS Toronto emotional speech set data",
    "ShEMO-Male": "/kaggle/input/shemo-persian-speech-emotion-detection-database/male",
    "ShEMO-Female": "/kaggle/input/shemo-persian-speech-emotion-detection-database/female",
    "SUBESCO": "/kaggle/input/subescobangla-speech-emotion-dataset/SUBESCO"
}

def gather_ravdess(path):
    """Collect RAVDESS ``.wav`` files below ``path`` with their unified labels.

    The emotion code is the third ``-``-separated field of the file name.
    Files with fewer than three fields or an unmapped code are skipped.

    Returns:
        ``(files, labels)`` as two parallel lists.
    """
    files, labels = [], []
    for dirpath, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            if not filename.lower().endswith('.wav'):
                continue
            fields = filename.split('-')
            if len(fields) < 3:
                continue
            emotion = normalize_audio_label(fields[2], "RAVDESS")
            if not emotion:
                continue
            files.append(os.path.join(dirpath, filename))
            labels.append(emotion)
    return files, labels

def gather_cremad(path):
    """Collect CREMA-D ``.wav`` files below ``path`` with their unified labels.

    The emotion code is the first three characters (upper-cased) of the third
    ``_``-separated field of the file name, looked up in
    ``AUDIO_LABEL_MAP["CREMA-D"]``. Unmatched files are skipped.

    Returns:
        ``(files, labels)`` as two parallel lists.
    """
    files, labels = [], []
    for dirpath, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            if not filename.lower().endswith('.wav'):
                continue
            fields = filename.split('_')
            if len(fields) < 3:
                continue
            code = fields[2][:3].upper()
            emotion = AUDIO_LABEL_MAP["CREMA-D"].get(code)
            if not emotion:
                continue
            files.append(os.path.join(dirpath, filename))
            labels.append(emotion)
    return files, labels

def gather_tess(path):
    """Collect TESS ``.wav`` files below ``path`` with their unified labels.

    The label is derived from the lower-cased file name; when that yields
    nothing, the label derived from the containing folder's name is used.
    Files without any resolvable label are skipped.

    Returns:
        ``(files, labels)`` as two parallel lists.
    """
    files, labels = [], []
    for dirpath, _subdirs, filenames in os.walk(path):
        folder_label = normalize_audio_label(os.path.basename(dirpath).lower(), "TESS")
        wav_names = [name for name in filenames if name.lower().endswith('.wav')]
        for wav_name in wav_names:
            label = normalize_audio_label(wav_name.lower(), "TESS") or folder_label
            if not label:
                continue
            files.append(os.path.join(dirpath, wav_name))
            labels.append(label)
    return files, labels

def gather_shemo(path):
    """Collect ShEMO ``.wav`` files below ``path`` with their unified labels.

    The emotion code is the fourth character of the file name (upper-cased),
    looked up in ``AUDIO_LABEL_MAP["ShEMO"]``. Unmatched files are skipped.

    Returns:
        ``(files, labels)`` as two parallel lists.
    """
    files, labels = [], []
    for dirpath, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            is_candidate = filename.lower().endswith('.wav') and len(filename) >= 4
            if not is_candidate:
                continue
            code = filename[3].upper()
            label = AUDIO_LABEL_MAP["ShEMO"].get(code)
            if not label:
                continue
            files.append(os.path.join(dirpath, filename))
            labels.append(label)
    return files, labels

def gather_subesco(path):
    """Collect SUBESCO ``.wav`` files below ``path`` with their unified labels.

    For each file, the first key of ``AUDIO_LABEL_MAP["SUBESCO"]`` (in map
    order) that occurs in the lower-cased file name decides the label. When no
    key matches, the label derived from the containing folder's name is used.
    Files without any resolvable label are skipped.

    Returns:
        ``(files, labels)`` as two parallel lists.
    """
    files, labels = [], []
    for dirpath, _subdirs, filenames in os.walk(path):
        folder_label = normalize_audio_label(os.path.basename(dirpath).lower(), "SUBESCO")

        for filename in filenames:
            if not filename.lower().endswith('.wav'):
                continue
            wav_path = os.path.join(dirpath, filename)
            lowered_name = filename.lower()

            # First matching map key wins, mirroring a loop-with-break.
            name_label = next(
                (unified for key, unified in AUDIO_LABEL_MAP["SUBESCO"].items()
                 if key in lowered_name),
                None,
            )

            label = name_label or folder_label
            if label:
                files.append(wav_path)
                labels.append(label)
    return files, labels

# Gather audio datasets
# audio_clients_files / audio_clients_labels map a client name to its parallel
# lists of wav paths and unified labels.
audio_clients_files, audio_clients_labels = {}, {}
# ShEMO is configured as two folders (male / female); their files are collected
# here and registered afterwards as one combined client.
shemo_files_combined, shemo_labels_combined = [], []

for name, path in AUDIO_DATASET_PATHS.items():
    if not os.path.exists(path):
        print(f"  ✗ Missing: {path}")
        continue

    # Pick the gather function by substring match on the upper-cased client name.
    if "RAVDESS" in name.upper():
        f, l = gather_ravdess(path)
    elif "CREMA" in name.upper():
        f, l = gather_cremad(path)
    elif "TESS" in name.upper():
        f, l = gather_tess(path)
    elif "SHEMO" in name.upper():
        f, l = gather_shemo(path)
        shemo_files_combined.extend(f)
        shemo_labels_combined.extend(l)
        print(f"  ✓ {name}: {len(f)} files | Dist: {dict(Counter(l))}")
        # Registered below as the combined "ShEMO" client, not per folder.
        continue
    elif "SUBESCO" in name.upper():
        f, l = gather_subesco(path)
    else:
        continue

    # Only datasets with more than 100 labelled files become a client;
    # smaller ones are dropped without a message.
    if f and len(f) > 100:
        audio_clients_files[name] = f
        audio_clients_labels[name] = l
        print(f"  ✓ {name}: {len(f)} files | Dist: {dict(Counter(l))}")

# Combine ShEMO datasets
if shemo_files_combined and len(shemo_files_combined) > 100:
    audio_clients_files["ShEMO"] = shemo_files_combined
    audio_clients_labels["ShEMO"] = shemo_labels_combined
    print(f"  ✓ ShEMO (Combined): {len(shemo_files_combined)} files | Dist: {dict(Counter(shemo_labels_combined))}")

print(f"Loaded {len(audio_clients_files)} audio datasets")

# ============================================================================
# LABEL ENCODING
# ============================================================================

print("\n" + "="*80)
print("ENCODING LABELS")
print("="*80)

# Collect all labels from both modalities
all_labels_combined = []
all_labels_combined.extend([lbl for labs in text_client_labels.values() for lbl in labs])
all_labels_combined.extend([lbl for labs in audio_clients_labels.values() for lbl in labs])

# Filter to only include unified emotions
all_labels_filtered = [lbl for lbl in all_labels_combined if lbl in UNIFIED_EMOTIONS]

# Create label encoder
# Fitted on the fixed class list (not on the data) so that class indices are
# identical for every client and both modalities.
label_encoder = LabelEncoder()
label_encoder.fit(UNIFIED_EMOTIONS)
NUM_CLASSES = len(label_encoder.classes_)

print(f"Unified emotion classes ({NUM_CLASSES}): {list(label_encoder.classes_)}")

# Compute class weights
# compute_class_weight('balanced', ...) raises ValueError when a requested
# class does not occur in y, so it is asked only for the classes actually
# present. Absent classes keep a neutral weight of 1.0. When every class is
# present the result is identical to asking for all classes at once.
class_weights = np.ones(NUM_CLASSES, dtype=np.float64)
if all_labels_filtered:
    all_label_nums = label_encoder.transform(all_labels_filtered)
    present_classes = np.unique(all_label_nums)
    class_weights[present_classes] = compute_class_weight(
        'balanced', classes=present_classes, y=all_label_nums)
else:
    # No dataset could be loaded: keep uniform weights instead of crashing.
    all_label_nums = np.array([], dtype=np.int64)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

print(f"Class weights: {dict(zip(label_encoder.classes_, class_weights.round(3)))}")

# ============================================================================
# TEXT MODEL ARCHITECTURE
# ============================================================================

class TokenCNN_BiLSTM_Attention(nn.Module):
    """CNN -> BiLSTM -> additive attention pooling over token embeddings.

    The embedding matrix of each sequence is treated as a one-channel image,
    passed through three convolution stages (the first two halve both spatial
    axes), averaged over the embedding axis, fed to a bidirectional LSTM and
    pooled with learned attention weights before a final linear projection.

    Args:
        hidden_size: Accepted for interface compatibility; not used to build
            any layer.
        conv_ch: Channels of the first convolution stage (later stages use
            twice as many).
        lstm_hidden: Hidden size per LSTM direction.
        lstm_layers: Number of stacked LSTM layers.
        out_dim: Size of the returned feature vector.
    """

    def __init__(self, hidden_size=768, conv_ch=64, lstm_hidden=128, lstm_layers=2, out_dim=128):
        super().__init__()
        wide_ch = conv_ch * 2
        lstm_out = 2 * lstm_hidden

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, conv_ch, (3, 3), padding=1),
            nn.BatchNorm2d(conv_ch),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(conv_ch, wide_ch, (3, 3), padding=1),
            nn.BatchNorm2d(wide_ch),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(wide_ch, wide_ch, (3, 3), padding=1),
            nn.BatchNorm2d(wide_ch),
            nn.ReLU(),
        )

        self.lstm = nn.LSTM(
            input_size=wide_ch,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

        self.attn_fc = nn.Linear(lstm_out, 1)
        self.out_proj = nn.Linear(lstm_out, out_dim)

    def forward(self, token_embs):
        """Map ``(batch, seq_len, hidden)`` embeddings to ``(batch, out_dim)``."""
        # Unpacking enforces a 3-D input.
        _batch, _seq_len, _hidden = token_embs.size()

        # (batch, seq, hidden) -> (batch, 1, hidden, seq): a one-channel "image".
        feats = token_embs.permute(0, 2, 1).unsqueeze(1)
        for stage in (self.conv1, self.conv2, self.conv3):
            feats = stage(feats)

        # Average over the (pooled) embedding axis, then put time first.
        sequence = feats.mean(dim=2).permute(0, 2, 1)
        outputs, _ = self.lstm(sequence)

        # Attention pooling over time steps.
        scores = self.attn_fc(outputs).squeeze(-1)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        pooled = (outputs * weights).sum(dim=1)

        return self.out_proj(pooled)

class TextEmotionModel(nn.Module):
    """DistilBERT-based text emotion classifier with a hybrid side branch.

    Two feature vectors are concatenated before the final linear classifier:
    a 128-d projection of the first token's hidden state and a 64-d vector
    produced by ``TokenCNN_BiLSTM_Attention`` from all token hidden states.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        bert_dim = self.bert.config.hidden_size
        self.text_proj = nn.Sequential(nn.Linear(bert_dim, 128), nn.ReLU())
        self.hybrid = TokenCNN_BiLSTM_Attention(out_dim=64)
        self.classifier = nn.Linear(128 + 64, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenized texts."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state

        first_token_feat = self.text_proj(token_states[:, 0, :])
        sequence_feat = self.hybrid(token_states)

        return self.classifier(torch.cat([first_token_feat, sequence_feat], dim=1))

# ============================================================================
# AUDIO MODEL ARCHITECTURE
# ============================================================================

class CRNNBranch(nn.Module):
    """CNN -> BiLSTM -> self-attention encoder for spectrogram input.

    Three double-convolution stages reduce the frequency axis by a factor of
    eight and the time axis by a factor of four. The per-frame feature size is
    fixed at ``256 * 16``, i.e. the input is expected to have 128 frequency
    bins. Frames are processed by a 3-layer bidirectional LSTM and multi-head
    self-attention, averaged over time and projected to ``out_dim``.
    """

    @staticmethod
    def _double_conv(in_ch, out_ch, pool, drop_p):
        """Two 3x3 conv + BatchNorm + ReLU layers, then max-pool and 2-D dropout."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Dropout2d(drop_p),
        )

    def __init__(self, in_channels=1, hidden_size=384, out_dim=384):
        super().__init__()

        self.conv1 = self._double_conv(in_channels, 64, 2, 0.1)
        self.conv2 = self._double_conv(64, 128, 2, 0.15)
        # The last stage pools the frequency axis only, keeping time resolution.
        self.conv3 = self._double_conv(128, 256, (2, 1), 0.2)

        # 256 channels x 16 remaining frequency bins per time frame.
        self.feature_size = 256 * 16

        self.bilstm = nn.LSTM(
            self.feature_size,
            hidden_size,
            batch_first=True,
            bidirectional=True,
            num_layers=3,
            dropout=0.3,
        )

        lstm_out = hidden_size * 2
        self.attention = nn.MultiheadAttention(
            lstm_out,
            num_heads=8,
            dropout=0.2,
            batch_first=True,
        )

        self.proj = nn.Sequential(
            nn.Linear(lstm_out, out_dim),
            nn.LayerNorm(out_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Encode a ``(batch, channels, freq, time)`` spectrogram batch to ``(batch, out_dim)``."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)

        batch, _channels, _freq, frames = x.size()
        # Time becomes the sequence axis; channels and frequency are flattened.
        seq = x.permute(0, 3, 1, 2).contiguous().view(batch, frames, -1)

        self.bilstm.flatten_parameters()
        seq, _ = self.bilstm(seq)

        attended, _ = self.attention(seq, seq, seq)
        pooled = attended.mean(dim=1)

        return self.proj(pooled)

class FocalLoss(nn.Module):
    """Focal loss for handling class imbalance.

    Per sample: ``alpha[target] * (1 - p_t) ** gamma * (-log p_t)``, where
    ``p_t`` is the softmax probability of the true class.

    Args:
        alpha: Optional per-class weight tensor (passed to
            ``F.cross_entropy`` as ``weight``), or ``None`` for no weighting.
        gamma: Focusing parameter; 0 gives (weighted) cross-entropy.
        reduction: ``'mean'`` (plain mean over samples), ``'sum'``, or any
            other value to return the unreduced per-sample loss.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss for logits ``inputs`` and class-index ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)

        # p_t must come from the *unweighted* cross-entropy: exp(-weighted CE)
        # equals p_t ** alpha, which would distort the focusing term whenever
        # class weights are used.
        if self.alpha is None:
            plain_ce = ce_loss
        else:
            plain_ce = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-plain_ce)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class AudioEmotionModel(nn.Module):
    """Two-branch audio emotion classifier.

    Branch A yields a 768-d utterance embedding from Wav2Vec2 (frozen at
    construction time); branch B encodes the spectrogram with ``CRNNBranch``
    (384-d). The concatenated features go through a fusion MLP and a
    classifier head that returns ``num_classes`` logits.
    """

    # Wav2Vec2 processor shared by all instances. It is loaded lazily on first
    # use so that forward() does not reload it for every batch, and it is kept
    # on the class so that copying a model does not copy the processor.
    _processor = None

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # Branch A: Wav2Vec2
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(DEVICE)
        for p in self.wav2vec.parameters():
            p.requires_grad = False

        # Branch B: Enhanced CRNN
        self.crnn = CRNNBranch(hidden_size=384, out_dim=384)

        # Fusion layers
        self.fusion = nn.Sequential(
            nn.Linear(768 + 384, 768),
            nn.LayerNorm(768),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(768, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.35),
            nn.Linear(512, 384),
            nn.LayerNorm(384),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(384, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.35),
            nn.Linear(128, num_classes)
        )

    @classmethod
    def _get_processor(cls):
        """Return the shared Wav2Vec2 processor, loading it on first use."""
        if cls._processor is None:
            cls._processor = Wav2Vec2Processor.from_pretrained(
                "facebook/wav2vec2-base", sampling_rate=SAMPLE_RATE)
        return cls._processor

    def forward(self, wav, mel, wav2vec_emb=None):
        """Return class logits for a batch.

        Args:
            wav: Batch of raw waveforms; only used when the Wav2Vec2 embedding
                has to be computed here.
            mel: Batch of spectrograms for the CRNN branch.
            wav2vec_emb: Optional precomputed Wav2Vec2 embeddings. They are
                used only while every Wav2Vec2 parameter is frozen.
        """
        # Evaluated once per call; decides whether precomputed embeddings may
        # be used and whether gradients flow through Wav2Vec2.
        wav2vec_trainable = any(p.requires_grad for p in self.wav2vec.parameters())

        # Process wav2vec features
        if wav2vec_emb is not None and not wav2vec_trainable:
            wav_emb = wav2vec_emb.to(DEVICE)
        else:
            input_values = self._get_processor()(
                [w.cpu().numpy() for w in wav],
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
                padding=True).input_values.to(DEVICE)

            if wav2vec_trainable:
                wav_emb = self.wav2vec(input_values).last_hidden_state.mean(dim=1)
            else:
                with torch.no_grad():
                    wav_emb = self.wav2vec(input_values).last_hidden_state.mean(dim=1)

        # Process spectrogram features
        crnn_emb = self.crnn(mel.to(DEVICE))

        # Fusion and classification
        fused = self.fusion(torch.cat([wav_emb, crnn_emb], dim=1))
        return self.classifier(fused)

# ============================================================================
# TEXT DATASET CLASS
# ============================================================================

# Shared tokenizer used by TextEmotionDataset.__getitem__ for every sample.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

class TextEmotionDataset(Dataset):
    """Dataset for text emotion recognition.

    Args:
        texts: Sequence of raw text strings.
        labels: Unified emotion names, parallel to ``texts``; encoded to
            integer class ids with the module-level ``label_encoder``.
    """

    def __init__(self, texts, labels):
        self.texts = texts
        # Encode all labels in one vectorized call instead of one transform()
        # call per sample; the stored values are the same per-sample ids.
        self.labels = list(label_encoder.transform(list(labels)))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (padded/truncated to
            ``MAX_LEN``) and the ``label`` as a long tensor.
        """
        text = self.texts[idx]
        toks = tokenizer(text, truncation=True, padding='max_length',
                        max_length=MAX_LEN, return_tensors='pt')

        # The tokenizer returns a batch of one; drop that batch dimension.
        input_ids = toks['input_ids'].squeeze(0)
        attention_mask = toks['attention_mask'].squeeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "label": label}

def text_collate_fn(batch):
    """Stack per-sample tensors of a text batch into batched tensors.

    Args:
        batch: List of dicts with ``input_ids``, ``attention_mask`` and
            ``label`` tensors.

    Returns:
        Dict with the same three keys, each holding the stacked tensor.
    """
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("input_ids", "attention_mask", "label")
    }

# ============================================================================
# AUDIO DATASET CLASS
# ============================================================================

# Module-level Wav2Vec2 processor configured for SAMPLE_RATE.
# NOTE(review): presumably passed as processor_obj to
# cache_audio_file_features() by code outside this section — confirm.
processor_audio = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base",
                                                     sampling_rate=SAMPLE_RATE)

def cache_audio_file_features(wav_path, wav2vec_model=None, processor_obj=None):
    """Compute and cache the features of one audio file as a compressed ``.npz``.

    The cache file lives in ``CACHE_DIR`` and is named after the file stem
    plus a short SHA-1 of the full path. An existing, readable cache file is
    reused; an unreadable one is deleted and rebuilt.

    Stored arrays:
        ``wav``: trimmed, peak-normalized waveform, padded/truncated to
            ``TARGET_LEN`` samples (float32).
        ``mel``: 128-band mel spectrogram in dB (float32).
        ``wav2vec``: mean-pooled Wav2Vec2 embedding (float32); only when both
            ``wav2vec_model`` and ``processor_obj`` are given.

    Args:
        wav_path: Path of the audio file (``str`` or path-like).
        wav2vec_model: Optional Wav2Vec2 model used to embed the waveform.
        processor_obj: Optional processor matching ``wav2vec_model``.

    Returns:
        The cache file path, or ``None`` when the features could not be
        computed (callers fall back to zero-filled features).
    """
    wav_path = os.fspath(wav_path)
    h = hashlib.sha1(wav_path.encode()).hexdigest()[:10]
    cache_file = os.path.join(CACHE_DIR, f"{Path(wav_path).stem}_{h}.npz")

    if os.path.exists(cache_file):
        try:
            # Opening validates the archive; the context manager closes the
            # underlying file handle right away.
            with np.load(cache_file):
                pass
            return cache_file
        except Exception:
            # Corrupt or truncated cache entry: drop it and rebuild below.
            try:
                os.remove(cache_file)
            except OSError:
                pass

    try:
        wav, _ = librosa.load(wav_path, sr=SAMPLE_RATE)
        wav = librosa.effects.trim(wav, top_db=25)[0]
        wav = librosa.util.normalize(wav)
        # Zero-pad short clips, then cut everything to exactly TARGET_LEN.
        wav = np.pad(wav, (0, max(0, TARGET_LEN - len(wav))))[:TARGET_LEN]

        mel = librosa.feature.melspectrogram(y=wav, sr=SAMPLE_RATE, n_mels=128,
                                            n_fft=2048, hop_length=256,
                                            fmin=50, fmax=8000)
        mel = librosa.power_to_db(mel, ref=np.max)

        wav2vec_arr = None
        if wav2vec_model is not None and processor_obj is not None:
            input_values = processor_obj(wav, sampling_rate=SAMPLE_RATE,
                                        return_tensors="pt", padding=True).input_values.to(DEVICE)
            with torch.no_grad():
                out = wav2vec_model(input_values)
                wav2vec_arr = out.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy().astype(np.float32)

        arrays = {"wav": wav.astype(np.float32), "mel": mel.astype(np.float32)}
        if wav2vec_arr is not None:
            arrays["wav2vec"] = wav2vec_arr
        np.savez_compressed(cache_file, **arrays)

        return cache_file
    except Exception:
        # Deliberate best effort: an unreadable audio file must not abort data
        # loading; the caller substitutes zero-filled features for None.
        return None

class AudioEmotionDataset(Dataset):
    """Audio emotion dataset backed by the on-disk feature cache.

    Each item is ``(wav_t, mel_t, wav2_t, label)``:
        ``wav_t``   float32 waveform tensor.
        ``mel_t``   float32 mel spectrogram with a leading channel axis added.
        ``wav2_t``  float32 cached wav2vec embedding, or ``None`` when the cache
                    entry has none (or the file could not be cached).
        ``label``   integer class id from the module-level ``label_encoder``.

    When ``train`` is true, random augmentation is applied. Its probabilities and
    strengths are scaled by a per-corpus multiplier chosen from ``dataset_name``.

    NOTE(review): a ``None`` ``wav2_t`` cannot be batched by the default
    DataLoader collate function - confirm every cache entry carries ``wav2vec``.
    """

    # Augmentation multiplier per corpus; unlisted corpora use the default.
    _AUG_MULTIPLIERS = {
        "ShEMO": 2.2,
        "TESS": 1.8,
        "RAVDESS": 0.9,
        "SUBESCO": 1.3,
    }
    _DEFAULT_AUG_MULTIPLIER = 1.1

    def __init__(self, files, labels, train=True, dataset_name=None):
        """Store the file list and encode the string labels.

        Args:
            files: Sequence of audio file paths.
            labels: Sequence of string labels aligned with ``files``.
            train: Whether ``__getitem__`` applies augmentation.
            dataset_name: Corpus name used to pick the augmentation multiplier.
        """
        self.files = files
        self.labels = label_encoder.transform(labels)
        self.labels_str = labels
        self.train = train
        self.dataset_name = dataset_name
        self.aug_multiplier = self._AUG_MULTIPLIERS.get(
            dataset_name, self._DEFAULT_AUG_MULTIPLIER
        )

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _fit_length(wav):
        """Zero-pad or truncate ``wav`` to exactly ``TARGET_LEN`` samples."""
        return np.pad(wav, (0, max(0, TARGET_LEN - len(wav))))[:TARGET_LEN]

    def _augment_waveform(self, wav):
        """Apply noise, gain, pitch-shift and time-stretch augmentation to ``wav``."""
        # Gaussian noise
        if random.random() < NOISE_PROB * self.aug_multiplier:
            wav += 0.015 * self.aug_multiplier * np.random.randn(len(wav))

        # Volume augmentation: random gain in [0.7, 1.3)
        wav *= (0.7 + 0.6 * np.random.random())

        # Pitch shifting (best effort: keep the un-shifted signal on failure)
        if random.random() < PITCH_SHIFT_PROB * self.aug_multiplier:
            try:
                n_steps = random.choice([-2, -1, 1, 2])
                wav = librosa.effects.pitch_shift(wav, sr=SAMPLE_RATE, n_steps=n_steps)
                wav = self._fit_length(wav)
            except Exception:
                pass

        # Time stretching changes the length, so re-fit afterwards
        if random.random() < TIME_STRETCH_PROB * self.aug_multiplier:
            try:
                rate = random.uniform(0.85, 1.15)
                wav = librosa.effects.time_stretch(wav, rate=rate)
                wav = self._fit_length(wav)
            except Exception:
                pass

        return wav

    def __getitem__(self, idx):
        """Return ``(wav_t, mel_t, wav2_t, label)`` for sample ``idx``."""
        cache = cache_audio_file_features(self.files[idx])

        if cache:
            # The context manager closes the archive even if a key is missing.
            with np.load(cache) as data:
                wav = data["wav"]
                mel = data["mel"]
                wav2vec_arr = data["wav2vec"] if 'wav2vec' in data else None
        else:
            # Features unavailable: fall back to all-zero inputs.
            wav = np.zeros(TARGET_LEN, dtype=np.float32)
            mel = np.zeros((128, int(TARGET_LEN / 256) + 1), dtype=np.float32)
            wav2vec_arr = None

        # Apply augmentation during training. The waveform and the cached mel
        # spectrogram are augmented independently; the mel is not recomputed
        # from the augmented waveform.
        if self.train:
            wav = self._augment_waveform(wav)
            mel = self.spec_augment(mel, mask_percent=SPEC_AUG_MASK * self.aug_multiplier)

        wav_t = torch.tensor(wav, dtype=torch.float32)
        mel_t = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)
        wav2_t = torch.tensor(wav2vec_arr, dtype=torch.float32) if wav2vec_arr is not None else None
        label = int(self.labels[idx])

        return wav_t, mel_t, wav2_t, label

    def spec_augment(self, spec, mask_percent=SPEC_AUG_MASK):
        """Return a copy of ``spec`` with one time band and one frequency band masked.

        Args:
            spec: 2-D array indexed as ``[frequency, time]``.
            mask_percent: Fraction of each axis covered by its mask.

        Returns:
            The masked copy; masked cells are set to the spectrogram minimum.
        """
        spec = spec.copy()
        n_freq, n_time = spec.shape[0], spec.shape[1]

        # Time masking
        if n_time > 0:
            width = int(n_time * mask_percent)
            start = random.randint(0, max(0, n_time - width))
            stop = min(n_time, start + width)
            spec[:, start:stop] = spec.min()

        # Frequency masking
        if n_freq > 0:
            width = int(n_freq * mask_percent)
            start = random.randint(0, max(0, n_freq - width))
            stop = min(n_freq, start + width)
            spec[start:stop, :] = spec.min()

        return spec

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

def train_text_model_one_epoch(model, loader, optimizer, scheduler, criterion):
    """Train the text model for one pass over ``loader``.

    Each batch is a mapping with ``"input_ids"``, ``"attention_mask"`` and
    ``"label"`` tensors. The forward pass runs under autocast when the
    module-level ``use_amp`` flag is set; gradients are clipped to ``GRAD_CLIP``
    before every optimizer step, and ``scheduler`` (if given) is stepped once
    per batch.

    Args:
        model: Text model called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches.
        optimizer: Optimizer over the model parameters.
        scheduler: Optional learning-rate scheduler stepped per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["label"].to(DEVICE)

        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

        if use_amp and scaler:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        batch_n = labels.size(0)
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += batch_n
        total_loss += loss.item() * batch_n

    # Guard the division so an empty loader yields zeros instead of raising.
    denom = max(1, total_samples)
    return total_loss / denom, total_correct / denom

def train_audio_model_one_epoch(model, loader, optimizer, criterion,
                                global_state_for_prox=None, mu=0.0, use_mixup=True):
    """Train the audio model for one pass over ``loader``.

    Each batch is ``(wav, mel, wav2vec_emb, y)``. With probability
    ``MIXUP_PROB`` a batch is mixed via ``mixup_batch_audio`` and the loss
    becomes the ``lam``-weighted sum over both label sets. When ``mu > 0`` and a
    global state is supplied, the FedProx proximal term
    ``(mu / 2) * ||w - w_global||^2`` is added to the loss.

    The proximal term is built from ``model.named_parameters()`` rather than
    ``model.state_dict()``: state-dict tensors are detached from autograd, so a
    term built from them adds to the loss value but never to the gradients.
    Buffers and frozen parameters are excluded for the same reason.

    Args:
        model: Audio model called as ``model(wav, mel, wav2vec_emb)``.
        loader: Iterable of batches.
        optimizer: Optimizer over the trainable parameters.
        criterion: Loss called as ``criterion(logits, targets)``.
        global_state_for_prox: Optional mapping of parameter name to the global
            model tensor used as the FedProx anchor.
        mu: FedProx coefficient; ``0`` disables the proximal term.
        use_mixup: Whether mixup may be applied to a batch.

    Returns:
        ``(avg_loss, acc, f1, recall, precision, trues, preds)``. The reported
        loss includes the proximal term; with mixup, accuracy is measured
        against the un-mixed labels ``y``.
    """
    model.train()
    total_loss = 0.0
    preds, trues = [], []
    total_samples = 0

    amp_enabled = bool(use_amp and scaler)

    # Pair every trainable parameter with its global anchor once, so the global
    # weights are moved to the parameter's device a single time per call
    # instead of on every batch.
    prox_pairs = []
    if mu > 0 and global_state_for_prox is not None:
        for pname, param in model.named_parameters():
            if param.requires_grad:
                anchor = global_state_for_prox[pname].detach().to(
                    device=param.device, dtype=torch.float32
                )
                prox_pairs.append((param, anchor))

    for batch in loader:
        wav, mel, wav2vec_emb, y = batch
        wav, mel, y = wav.to(DEVICE), mel.to(DEVICE), y.to(DEVICE)
        if wav2vec_emb is not None:
            wav2vec_emb = wav2vec_emb.to(DEVICE)

        optimizer.zero_grad()

        # Apply mixup
        if use_mixup and random.random() < MIXUP_PROB:
            wav, mel, wav2vec_emb, y1, y2, lam = mixup_batch_audio((wav, mel, wav2vec_emb, y))
            y1, y2 = y1.to(DEVICE), y2.to(DEVICE)
        else:
            y1 = y2 = None
            lam = None

        # One forward path for both precisions; autocast is a no-op when disabled.
        with torch.amp.autocast(device_type='cuda', enabled=amp_enabled):
            logits = model(wav, mel, wav2vec_emb)
            if lam is None:
                loss = criterion(logits, y)
            else:
                loss = lam * criterion(logits, y1) + (1 - lam) * criterion(logits, y2)

        # FedProx regularization, accumulated in float32 outside autocast.
        if prox_pairs:
            prox = 0.0
            for param, anchor in prox_pairs:
                prox = prox + torch.sum((param.float() - anchor) ** 2)
            loss = loss + (mu / 2.0) * prox

        if amp_enabled:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

        batch_n = wav.size(0)
        total_loss += loss.item() * batch_n
        total_samples += batch_n

        preds.extend(logits.argmax(1).cpu().numpy())
        trues.extend(y.cpu().numpy())

    avg_loss = total_loss / max(1, total_samples)
    acc, f1, recall, precision = compute_metrics(trues, preds)
    return avg_loss, acc, f1, recall, precision, trues, preds

def evaluate_text_model(model, loader, criterion):
    """Evaluate the text model on ``loader`` without computing gradients.

    Args:
        model: Text model called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with ``"input_ids"``, ``"attention_mask"``
            and ``"label"`` tensors.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(acc, loss, precision, recall, f1, labels_list, preds_list)`` where
        precision, recall and F1 are macro-averaged. When the loader yields no
        samples, all metrics are ``0.0`` and both lists are empty.
    """
    model.eval()
    total_correct = 0
    total_samples = 0
    preds_list, labels_list = [], []
    total_loss = 0.0

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            batch_n = labels.size(0)
            total_correct += (preds == labels).sum().item()
            total_samples += batch_n
            total_loss += loss.item() * batch_n
            preds_list.extend(preds.cpu().numpy())
            labels_list.extend(labels.cpu().numpy())

    # A very small client can end up with an empty validation split; report
    # zeros instead of dividing by zero.
    if total_samples == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0, labels_list, preds_list

    acc = total_correct / total_samples
    loss = total_loss / total_samples
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_list, preds_list, average='macro', zero_division=0
    )

    return acc, loss, precision, recall, f1, labels_list, preds_list

def evaluate_audio_model(model, loader, criterion, device):
    """Evaluate the audio model on ``loader`` without computing gradients.

    Args:
        model: Audio model called as ``model(wav, mel, wav2vec_emb)``.
        loader: Iterable of ``(wav, mel, wav2vec_emb, y)`` batches.
        criterion: Loss called as ``criterion(logits, y)``.
        device: Device (``torch.device`` or string) the batches are moved to.

    Returns:
        ``(avg_loss, acc, f1, recall, precision, trues, preds)`` with the
        metrics computed by ``compute_metrics``.
    """
    model.eval()
    total_loss = 0.0
    preds, trues = [], []
    total_samples = 0

    # torch.cuda.amp.autocast is deprecated in favour of torch.amp.autocast.
    # Autocast is enabled only for CUDA devices, which avoids the warning the
    # CUDA autocast context emits when CUDA is unavailable.
    autocast_on = torch.device(device).type == "cuda"

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=autocast_on):
        for wav, mel, wav2vec_emb, y in loader:
            wav, mel, y = wav.to(device), mel.to(device), y.to(device)
            if wav2vec_emb is not None:
                wav2vec_emb = wav2vec_emb.to(device)

            logits = model(wav, mel, wav2vec_emb)
            loss = criterion(logits, y)

            batch_n = wav.size(0)
            total_loss += loss.item() * batch_n
            total_samples += batch_n
            preds.extend(logits.argmax(1).cpu().numpy())
            trues.extend(y.cpu().numpy())

    avg_loss = total_loss / max(1, total_samples)
    acc, f1, recall, precision = compute_metrics(trues, preds)
    return avg_loss, acc, f1, recall, precision, trues, preds

# ============================================================================
# PREPARE DATALOADERS
# ============================================================================

print("\n" + "=" * 80)
print("PREPARING DATALOADERS")
print("=" * 80)

# Text dataloaders: every client gets a random 80/20 train/validation split.
text_client_loaders, text_val_loaders = {}, {}

for name, client_texts in text_client_texts.items():
    ds = TextEmotionDataset(client_texts, text_client_labels[name])

    # Validation takes the floor of 20%; training keeps the remainder.
    val_size = int(0.2 * len(ds))
    train_size = len(ds) - val_size
    train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])

    # Training batches are shuffled; validation uses double-size batches in order.
    text_client_loaders[name] = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=text_collate_fn,
    )
    text_val_loaders[name] = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE * 2,
        shuffle=False,
        collate_fn=text_collate_fn,
    )
    print(f"  ✓ Text - {name}: Train={len(train_ds)}, Val={len(val_ds)}")

# Audio dataloaders
# Pre-compute and cache the features of every audio file (including the
# wav2vec embedding) before any DataLoader is built.
print("\nCaching audio features...")
if audio_clients_files:
    wav2vec_cache = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(DEVICE)
    wav2vec_cache.eval()
    # Freeze in one call; a module-level `for p in ...parameters()` loop would
    # leave `p` behind as a global that keeps a parameter alive after cleanup.
    wav2vec_cache.requires_grad_(False)

    all_audio_files = [f for files in audio_clients_files.values() for f in files]

    def cache_file_wrapper(fpath):
        return cache_audio_file_features(fpath, wav2vec_cache, processor_audio)

    max_workers = min(8, os.cpu_count() or 1)
    failed_audio_files = []
    with tqdm(total=len(all_audio_files), desc="Caching", unit="file") as pbar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(cache_file_wrapper, fpath): fpath
                      for fpath in all_audio_files}
            for future in as_completed(futures):
                # A None result and a raised exception both mean "not cached".
                try:
                    cached_ok = future.result() is not None
                except Exception:
                    cached_ok = False
                if not cached_ok:
                    failed_audio_files.append(futures[future])
                pbar.update(1)

    # Report failures instead of dropping them silently: the dataset falls back
    # to all-zero features for any file without a cache entry.
    if failed_audio_files:
        print(f"  WARNING: could not cache {len(failed_audio_files)} of "
              f"{len(all_audio_files)} audio files (first: {failed_audio_files[0]})")

    cleanup_model(wav2vec_cache)
    aggressive_cleanup()

audio_client_loaders = {}
audio_val_loaders = {}

for name in audio_clients_files.keys():
    # Build two dataset objects over the same files. random_split() returns
    # Subsets that share ONE underlying dataset, so toggling `.dataset.train`
    # on the validation subset would also disable augmentation for training.
    ds = AudioEmotionDataset(audio_clients_files[name], audio_clients_labels[name],
                            train=True, dataset_name=name)
    eval_ds = AudioEmotionDataset(audio_clients_files[name], audio_clients_labels[name],
                                 train=False, dataset_name=name)

    # At least 15 validation samples, but never the whole dataset.
    val_size = min(max(15, int(len(ds) * 0.18)), max(0, len(ds) - 1))
    train_size = len(ds) - val_size

    # Split once, then point the validation indices at the non-augmenting copy.
    train_ds, val_split = random_split(ds, [train_size, val_size])
    val_ds = torch.utils.data.Subset(eval_ds, val_split.indices)

    audio_client_loaders[name] = DataLoader(train_ds, batch_size=BATCH_SIZE,
                                           shuffle=True, num_workers=NUM_WORKERS,
                                           pin_memory=True)
    audio_val_loaders[name] = DataLoader(val_ds, batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=NUM_WORKERS,
                                        pin_memory=True)
    print(f"  ✓ Audio - {name}: Train={len(train_ds)}, Val={len(val_ds)}")

# ============================================================================
# FEDERATED LEARNING MAIN LOOP
# ============================================================================

print("\n" + "=" * 80)
print("STARTING UNIFIED FEDERATED LEARNING")
print("=" * 80)
print(f"Text clients: {list(text_client_texts.keys())}")
print(f"Audio clients: {list(audio_clients_files.keys())}")
print(f"Rounds: {ROUNDS}, Local Epochs: {LOCAL_EPOCHS}")
print(f"Batch Size: {BATCH_SIZE}")

# Initialize global models; a modality without any client gets no model.
text_global_model = None
if text_client_texts:
    text_global_model = TextEmotionModel().to(DEVICE)

audio_global_model = None
if audio_clients_files:
    audio_global_model = AudioEmotionModel().to(DEVICE)

# Loss functions (both use the shared class weights)
text_criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
audio_criterion = FocalLoss(alpha=class_weights_tensor, gamma=2.0)

# Track best models
best_text_acc, best_audio_acc = 0.0, 0.0
best_text_state, best_audio_state = None, None


def _new_history(client_names):
    """Return an empty history dict: global metric lists plus one list per client."""
    names = list(client_names)
    history = {metric: [] for metric in ('global_acc', 'global_loss', 'global_f1')}
    for metric in ('client_acc', 'client_loss', 'client_f1'):
        history[metric] = {client: [] for client in names}
    return history


# History tracking for plots
text_history = _new_history(text_client_texts)
audio_history = _new_history(audio_clients_files)

# Training loop
def _fedavg_into(target_state, local_states, client_sizes):
    """Overwrite ``target_state`` in place with the size-weighted client average.

    Floating-point tensors are averaged with weights ``size / sum(sizes)``.
    Non-floating tensors (integer buffers such as step counters or index
    tables) are copied from the first client instead: multiplying them by float
    weights and casting back on load could truncate them to a wrong value.
    """
    total = sum(client_sizes)
    weights = [size / total for size in client_sizes]

    for key in target_state:
        first = local_states[0][key]
        if not torch.is_floating_point(first):
            target_state[key] = first.clone()
            continue

        averaged = None
        for state, weight in zip(local_states, weights):
            part = state[key].float() * weight
            averaged = part if averaged is None else (averaged + part)
        target_state[key] = averaged

    return target_state


for rnd in range(ROUNDS):
    print(f"\n{'='*80}")
    print(f"ROUND {rnd + 1}/{ROUNDS}")
    print(f"{'='*80}")

    lr_mult = get_lr_multiplier(rnd, WARMUP_ROUNDS, ROUNDS)
    print(f"LR multiplier: {lr_mult:.4f}")

    # ========================= TEXT TRAINING =========================
    if text_global_model and text_client_loaders:
        print("\n--- TEXT MODALITY ---")
        text_local_states = []
        text_client_sizes = []

        for cname in text_client_texts.keys():
            print(f"\nTraining text client: {cname}")
            # Every client starts the round from the current global weights.
            local_model = TextEmotionModel().to(DEVICE)
            local_model.load_state_dict(text_global_model.state_dict())

            optimizer = AdamW(
                local_model.parameters(),
                lr=LR_TEXT_BERT * lr_mult,
                weight_decay=WEIGHT_DECAY
            )

            steps_this_round = len(text_client_loaders[cname]) * LOCAL_EPOCHS
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(WARMUP_RATIO * steps_this_round),
                num_training_steps=steps_this_round
            )

            for ep in range(LOCAL_EPOCHS):
                loss, acc = train_text_model_one_epoch(
                    local_model, text_client_loaders[cname],
                    optimizer, scheduler, text_criterion
                )
                print(f"  Epoch {ep+1}: Loss={loss:.4f}, Acc={acc:.4f}")

            # Keep the client weights on the CPU (as the audio branch does) so
            # the accelerator does not hold one full model copy per client.
            text_local_states.append({k: v.detach().to('cpu', copy=True)
                                      for k, v in local_model.state_dict().items()})
            text_client_sizes.append(len(text_client_loaders[cname].dataset))
            cleanup_model(local_model)

        # FedAvg for text
        print("\nAggregating text models...")
        new_text_state = deepcopy(text_global_model.state_dict())
        _fedavg_into(new_text_state, text_local_states, text_client_sizes)
        text_global_model.load_state_dict(new_text_state)

        # Evaluate text model
        print("\nText validation results:")
        text_round_accs = []
        text_round_losses = []
        text_round_f1s = []

        for cname in text_client_texts.keys():
            acc, loss, prec, rec, f1, _, _ = evaluate_text_model(
                text_global_model, text_val_loaders[cname], text_criterion
            )
            print(f"  {cname}: Acc={acc:.4f}, F1={f1:.4f}, Loss={loss:.4f}")

            # Track client history
            text_history['client_acc'][cname].append(acc)
            text_history['client_loss'][cname].append(loss)
            text_history['client_f1'][cname].append(f1)

            text_round_accs.append(acc)
            text_round_losses.append(loss)
            text_round_f1s.append(f1)

        # Track global (average) history
        text_history['global_acc'].append(np.mean(text_round_accs))
        text_history['global_loss'].append(np.mean(text_round_losses))
        text_history['global_f1'].append(np.mean(text_round_f1s))

        # Select the best global model on the accuracy averaged over all
        # clients; a single client's score would favour whichever client
        # happens to be easiest.
        text_round_acc = float(np.mean(text_round_accs))
        if text_round_acc > best_text_acc:
            best_text_acc = text_round_acc
            best_text_state = deepcopy(text_global_model.state_dict())

        print(f"\n  Global Text Avg: Acc={np.mean(text_round_accs):.4f}, "
              f"F1={np.mean(text_round_f1s):.4f}, Loss={np.mean(text_round_losses):.4f}")

    # ========================= AUDIO TRAINING =========================
    if audio_global_model and audio_client_loaders:
        print("\n--- AUDIO MODALITY ---")
        audio_local_states = []
        audio_client_sizes = []
        # Snapshot of the global weights handed to local training together with
        # FEDPROX_MU as the proximal reference.
        audio_global_state_for_prox = {k: v.clone().detach().to('cpu')
                                      for k, v in audio_global_model.state_dict().items()}

        for cname in audio_clients_files.keys():
            print(f"\nTraining audio client: {cname}")
            local_model = deepcopy(audio_global_model).to(DEVICE)

            head_params = (list(local_model.crnn.parameters()) +
                           list(local_model.fusion.parameters()) +
                           list(local_model.classifier.parameters()))

            if any(p.requires_grad for p in local_model.wav2vec.parameters()):
                # wav2vec backbone and the remaining modules get separate rates.
                optimizer = torch.optim.AdamW([
                    {"params": local_model.wav2vec.parameters(), "lr": LR_AUDIO_WAV2VEC * lr_mult},
                    {"params": head_params, "lr": LR_AUDIO_NEW * lr_mult}
                ], weight_decay=WEIGHT_DECAY)
            else:
                optimizer = torch.optim.AdamW(
                    head_params,
                    lr=LR_AUDIO_NEW * lr_mult,
                    weight_decay=WEIGHT_DECAY
                )

            for ep in range(LOCAL_EPOCHS):
                loss, acc, f1, rec, prec, _, _ = train_audio_model_one_epoch(
                    local_model, audio_client_loaders[cname], optimizer,
                    audio_criterion, audio_global_state_for_prox, FEDPROX_MU
                )
                print(f"  Epoch {ep+1}: Loss={loss:.4f}, Acc={acc:.4f}, F1={f1:.4f}")

            audio_local_states.append({k: v.cpu() for k, v in local_model.state_dict().items()})
            audio_client_sizes.append(len(audio_client_loaders[cname].dataset))
            cleanup_model(local_model)

        # FedAvg for audio
        print("\nAggregating audio models...")
        new_audio_state = deepcopy(audio_global_model.state_dict())
        _fedavg_into(new_audio_state, audio_local_states, audio_client_sizes)
        audio_global_model.load_state_dict(new_audio_state)

        # Evaluate audio model
        print("\nAudio validation results:")
        audio_round_accs = []
        audio_round_losses = []
        audio_round_f1s = []

        for cname in audio_clients_files.keys():
            loss, acc, f1, rec, prec, _, _ = evaluate_audio_model(
                audio_global_model, audio_val_loaders[cname], audio_criterion, DEVICE
            )
            print(f"  {cname}: Acc={acc:.4f}, F1={f1:.4f}, Loss={loss:.4f}")

            # Track client history
            audio_history['client_acc'][cname].append(acc)
            audio_history['client_loss'][cname].append(loss)
            audio_history['client_f1'][cname].append(f1)

            audio_round_accs.append(acc)
            audio_round_losses.append(loss)
            audio_round_f1s.append(f1)

        # Track global (average) history
        audio_history['global_acc'].append(np.mean(audio_round_accs))
        audio_history['global_loss'].append(np.mean(audio_round_losses))
        audio_history['global_f1'].append(np.mean(audio_round_f1s))

        # Same selection rule as for text: best round-average accuracy.
        audio_round_acc = float(np.mean(audio_round_accs))
        if audio_round_acc > best_audio_acc:
            best_audio_acc = audio_round_acc
            best_audio_state = deepcopy(audio_global_model.state_dict())

        print(f"\n  Global Audio Avg: Acc={np.mean(audio_round_accs):.4f}, "
              f"F1={np.mean(audio_round_f1s):.4f}, Loss={np.mean(audio_round_losses):.4f}")

    aggressive_cleanup()

# ============================================================================
# FINAL RESULTS
# ============================================================================

print("\n" + "=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

SAVE_DIR = "/content/saved_models"
os.makedirs(SAVE_DIR, exist_ok=True)

print("\n" + "=" * 80)
print("SAVING MODELS")
print("=" * 80)


def _save_best_model(model, best_state, best_acc, modality, filename, modality_config):
    """Restore ``best_state`` into ``model`` and write its checkpoint to ``SAVE_DIR``.

    The checkpoint holds the weights, the best accuracy, the label encoder, the
    class list and a config dict made of the batch size, the entries of
    ``modality_config``, the round count and the class count (in that order).
    """
    model.load_state_dict(best_state)

    config = {'batch_size': BATCH_SIZE}
    config.update(modality_config)
    config['rounds'] = ROUNDS
    config['num_classes'] = NUM_CLASSES

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'best_accuracy': best_acc,
        'label_encoder': label_encoder,
        'num_classes': NUM_CLASSES,
        'emotion_classes': list(label_encoder.classes_),
        'config': config,
    }
    torch.save(checkpoint, os.path.join(SAVE_DIR, filename))

    print(f"✓ {modality} Model Saved!")
    print(f"  Accuracy: {best_acc:.4f}")
    print(f"  Path: {SAVE_DIR}/{filename}")


# Save Text Model (if it exists)
if text_global_model and best_text_state:
    _save_best_model(
        text_global_model, best_text_state, best_text_acc,
        'Text', 'text_emotion_model.pth',
        {'max_len': MAX_LEN},
    )

# Save Audio Model (if it exists)
if audio_global_model and best_audio_state:
    _save_best_model(
        audio_global_model, best_audio_state, best_audio_acc,
        'Audio', 'audio_emotion_model.pth',
        {'sample_rate': SAMPLE_RATE, 'target_len': TARGET_LEN},
    )

print(f"\n✓ Models saved to: {SAVE_DIR}")
print("=" * 80)

# Closing summary of the run.
_text_tag = 'Text' if text_global_model else ''
_audio_tag = 'Audio' if audio_global_model else ''

print(f"\n{'=' * 80}")
print("UNIFIED FEDERATED LEARNING SUMMARY")
print(f"{'=' * 80}")
print(f"Modalities: {_text_tag} {_audio_tag}")
print(f"Total rounds: {ROUNDS}")
print(f"Unified emotion classes: {list(label_encoder.classes_)}")
print(f"{'=' * 80}")

# ============================================================================
# PLOT ACCURACY GRAPHS
# ============================================================================

def plot_training_history(history, modality_name, save_dir=CACHE_DIR):
    """Plot and save the federated training curves for one modality.

    Two figures are produced, saved as PNG files in ``save_dir`` and shown:

    1. ``<modality>_training_history.png`` - accuracy, loss and F1 per round
       for the global model and for each client.
    2. ``<modality>_accuracy_comparison.png`` - global vs. client accuracy
       with an 80% target line and an annotation on the best global round.

    Args:
        history: dict with the per-round lists ``'global_acc'``,
            ``'global_loss'``, ``'global_f1'`` and the per-client dicts
            ``'client_acc'``, ``'client_loss'``, ``'client_f1'`` (client name
            -> list of per-round values).
        modality_name: label used in titles and (lower-cased) in file names.
        save_dir: directory the PNG files are written to.

    Returns:
        None. Prints a message and returns early when ``history['global_acc']``
        is empty.
    """
    if not history['global_acc']:
        print(f"No history to plot for {modality_name}")
        return

    rounds = list(range(1, len(history['global_acc']) + 1))

    # Try the seaborn style names in order; fall back to the default style if
    # neither is available. `except Exception` (not a bare except) so that
    # KeyboardInterrupt / SystemExit are not swallowed here.
    for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
        try:
            plt.style.use(style_name)
        except Exception:
            continue
        else:
            break
    sns.set_palette("husl")

    # ---- Figure 1: accuracy / loss / F1 side by side ----------------------
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))
    fig.suptitle(f'{modality_name} Emotion Recognition - Federated Learning',
                 fontsize=16, fontweight='bold')

    # (global key, client key, line format, y-label, title, clamp y to [0, 1])
    panels = (
        ('global_acc', 'client_acc', 'b-', 'Accuracy', 'Accuracy Progression', True),
        ('global_loss', 'client_loss', 'r-', 'Loss', 'Loss Progression', False),
        ('global_f1', 'client_f1', 'g-', 'F1-Score', 'F1-Score Progression', True),
    )
    for panel, (global_key, client_key, line_fmt, y_label, title, clamp) in zip(axes, panels):
        panel.plot(rounds, history[global_key], line_fmt, linewidth=3,
                   marker='o', markersize=8, label='Global Server', alpha=0.8)

        for client_name, client_values in history[client_key].items():
            panel.plot(rounds, client_values, '--', linewidth=2,
                       marker='s', markersize=5, label=f'Client: {client_name}', alpha=0.7)

        panel.set_xlabel('Round', fontsize=12, fontweight='bold')
        panel.set_ylabel(y_label, fontsize=12, fontweight='bold')
        panel.set_title(title, fontsize=13, fontweight='bold')
        panel.legend(loc='best', fontsize=9)
        panel.grid(True, alpha=0.3)
        if clamp:
            panel.set_ylim([0, 1])

    plt.tight_layout()

    save_path = os.path.join(save_dir, f'{modality_name.lower()}_training_history.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ {modality_name} training plot saved to: {save_path}")
    plt.show()

    # ---- Figure 2: server vs. client accuracy ------------------------------
    fig2, ax = plt.subplots(figsize=(12, 7))

    # Global line drawn thick and on top (zorder) so it stands out.
    ax.plot(rounds, history['global_acc'], 'b-', linewidth=4,
            marker='o', markersize=10, label='Global Server (FedAvg)',
            alpha=0.9, zorder=10)

    colors = plt.cm.tab10(np.linspace(0, 1, len(history['client_acc'])))
    for idx, (client_name, client_accs) in enumerate(history['client_acc'].items()):
        ax.plot(rounds, client_accs, '--', linewidth=2.5,
                marker='D', markersize=6, label=f'{client_name}',
                alpha=0.7, color=colors[idx])

    # The target line must exist before legend() is called, otherwise its
    # '80% Target' label is never picked up by the legend.
    ax.axhline(y=0.8, color='red', linestyle=':', linewidth=2,
               label='80% Target', alpha=0.5)

    ax.set_xlabel('Federated Round', fontsize=14, fontweight='bold')
    ax.set_ylabel('Validation Accuracy', fontsize=14, fontweight='bold')
    ax.set_title(f'{modality_name} - Server vs Client Accuracy Comparison',
                 fontsize=15, fontweight='bold')
    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim([0, 1])

    # Annotate the round with the best global accuracy (rounds are 1-based).
    best_global_idx = np.argmax(history['global_acc'])
    best_global_acc = history['global_acc'][best_global_idx]
    ax.annotate(f'Best: {best_global_acc:.3f}',
                xy=(best_global_idx + 1, best_global_acc),
                xytext=(10, 10), textcoords='offset points',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.7),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
                fontsize=10, fontweight='bold')

    plt.tight_layout()

    save_path2 = os.path.join(save_dir, f'{modality_name.lower()}_accuracy_comparison.png')
    plt.savefig(save_path2, dpi=300, bbox_inches='tight')
    print(f"✓ {modality_name} comparison plot saved to: {save_path2}")
    plt.show()

_viz_rule = "=" * 80

print("\n" + _viz_rule)
print("GENERATING TRAINING VISUALIZATIONS")
print(_viz_rule)

# A modality is plotted only when its global model exists and at least one
# round of global accuracy was recorded.
_text_ready = text_global_model and text_history['global_acc']
_audio_ready = audio_global_model and audio_history['global_acc']

# Per-modality figures
if _text_ready:
    plot_training_history(text_history, 'Text')

if _audio_ready:
    plot_training_history(audio_history, 'Audio')

# Side-by-side accuracy comparison, only when both modalities were trained
if _text_ready and _audio_ready:

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle('Multimodal Federated Learning - Text vs Audio',
                 fontsize=16, fontweight='bold')

    # Both panels are drawn identically; only the data and the title differ.
    for _panel, _modality_history, _panel_title in (
        (ax1, text_history, 'Text Emotion Recognition'),
        (ax2, audio_history, 'Audio Emotion Recognition'),
    ):
        _panel_rounds = list(range(1, len(_modality_history['global_acc']) + 1))

        _panel.plot(_panel_rounds, _modality_history['global_acc'], 'b-', linewidth=3,
                    marker='o', markersize=8, label='Server', alpha=0.8)
        for client_name, client_accs in _modality_history['client_acc'].items():
            _panel.plot(_panel_rounds, client_accs, '--', linewidth=2,
                        marker='s', markersize=5, label=client_name, alpha=0.7)

        _panel.set_xlabel('Round', fontsize=12, fontweight='bold')
        _panel.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
        _panel.set_title(_panel_title, fontsize=13, fontweight='bold')
        _panel.legend(loc='best', fontsize=9)
        _panel.grid(True, alpha=0.3)
        _panel.set_ylim([0, 1])
        # Unlabelled 80% reference line
        _panel.axhline(y=0.8, color='red', linestyle=':', linewidth=2, alpha=0.5)

    plt.tight_layout()
    save_path = os.path.join(CACHE_DIR, 'multimodal_comparison.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Multimodal comparison plot saved to: {save_path}")
    plt.show()

print("\n" + _viz_rule)
print("ALL VISUALIZATIONS GENERATED SUCCESSFULLY")
print(_viz_rule)

Installing required packages...

Package installation complete!

Device: cuda
Batch Size: 32, Mixed Precision: True

LOADING TEXT DATASETS


  scaler = torch.cuda.amp.GradScaler() if use_amp else None


Using Colab cache for faster access to the 'emotion-detection-from-text' dataset.
✓ Loaded Kaggle text dataset: 40000 samples
  ✓ dair-ai/emotion: 15996 samples
  ✓ boltuix/emotions-dataset: 20000 samples
  ✓ mteb/emotion: 15956 samples
  ✓ go_emotions: 13542 samples
  ✓ pashupatigupta/emotion: 20000 samples
Loaded 5 text datasets

LOADING AUDIO DATASETS
  ✓ CREMA-D: 7442 files | Dist: {'disgust': 1271, 'happy': 1271, 'sad': 1271, 'neutral': 1087, 'fear': 1271, 'anger': 1271}
  ✓ RAVDESS: 2880 files | Dist: {'surprise': 384, 'neutral': 576, 'disgust': 384, 'fear': 384, 'sad': 384, 'happy': 384, 'anger': 384}
  ✓ TESS: 2800 files | Dist: {'fear': 400, 'anger': 400, 'disgust': 400, 'neutral': 400, 'sad': 400, 'surprise': 400, 'happy': 400}
  ✓ ShEMO-Male: 1737 files | Dist: {'neutral': 744, 'anger': 604, 'sad': 178, 'fear': 16, 'surprise': 105, 'happy': 90}
  ✓ ShEMO-Female: 1263 files | Dist: {'anger': 455, 'neutral': 284, 'sad': 271, 'surprise': 120, 'fear': 22, 'happy': 111}
  ✓ SUBES




PREPARING DATALOADERS
  ✓ Text - dair-ai/emotion: Train=12797, Val=3199
  ✓ Text - boltuix/emotions-dataset: Train=16000, Val=4000
  ✓ Text - mteb/emotion: Train=12765, Val=3191
  ✓ Text - go_emotions: Train=10834, Val=2708
  ✓ Text - pashupatigupta/emotion: Train=16000, Val=4000

Caching audio features...




pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

Caching:   0%|          | 0/22122 [00:00<?, ?file/s]

model.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

Caching: 100%|██████████| 22122/22122 [16:50<00:00, 21.89file/s]


  ✓ Audio - CREMA-D: Train=6103, Val=1339
  ✓ Audio - RAVDESS: Train=2362, Val=518
  ✓ Audio - TESS: Train=2296, Val=504
  ✓ Audio - SUBESCO: Train=4920, Val=1080
  ✓ Audio - ShEMO: Train=2460, Val=540

STARTING UNIFIED FEDERATED LEARNING
Text clients: ['dair-ai/emotion', 'boltuix/emotions-dataset', 'mteb/emotion', 'go_emotions', 'pashupatigupta/emotion']
Audio clients: ['CREMA-D', 'RAVDESS', 'TESS', 'SUBESCO', 'ShEMO']
Rounds: 30, Local Epochs: 3
Batch Size: 32





ROUND 1/30
LR multiplier: 0.5000

--- TEXT MODALITY ---

Training text client: dair-ai/emotion
  Epoch 1: Loss=1.2440, Acc=0.6344
  Epoch 2: Loss=0.3799, Acc=0.9203
  Epoch 3: Loss=0.2345, Acc=0.9492

Training text client: boltuix/emotions-dataset
  Epoch 1: Loss=1.5388, Acc=0.3805
  Epoch 2: Loss=1.1403, Acc=0.5897
  Epoch 3: Loss=1.0013, Acc=0.6486

Training text client: mteb/emotion
  Epoch 1: Loss=1.2325, Acc=0.6410
  Epoch 2: Loss=0.3664, Acc=0.9246
  Epoch 3: Loss=0.2266, Acc=0.9528

Training text client: go_emotions
  Epoch 1: Loss=1.4009, Acc=0.4497
  Epoch 2: Loss=0.8550, Acc=0.6963
  Epoch 3: Loss=0.7215, Acc=0.7464

Training text client: pashupatigupta/emotion
  Epoch 1: Loss=1.6678, Acc=0.3064
  Epoch 2: Loss=1.4850, Acc=0.3747
  Epoch 3: Loss=1.4203, Acc=0.4053

Aggregating text models...

Text validation results:
  dair-ai/emotion: Acc=0.7931, F1=0.6745, Loss=0.9156
  boltuix/emotions-dataset: Acc=0.5270, F1=0.3439, Loss=2.1077
  mteb/emotion: Acc=0.7991, F1=0.7061, Loss

  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=1102.7804, Acc=0.1711, F1=0.0563


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=1537.0781, Acc=0.1847, F1=0.0795


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=2166.4069, Acc=0.2027, F1=0.1045

Training audio client: RAVDESS


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=1034.2343, Acc=0.1406, F1=0.0780


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=1295.0557, Acc=0.1363, F1=0.0487


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=1420.1609, Acc=0.1389, F1=0.0534

Training audio client: TESS


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=728.6101, Acc=0.1363, F1=0.0619


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=933.8780, Acc=0.1524, F1=0.0578


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=1048.0684, Acc=0.1468, F1=0.0520

Training audio client: SUBESCO


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=1259.8497, Acc=0.1665, F1=0.0802


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=1585.1594, Acc=0.2083, F1=0.1155


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=2051.6057, Acc=0.2606, F1=0.1387

Training audio client: ShEMO


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=913.5021, Acc=0.3407, F1=0.2819


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=1217.8276, Acc=0.3524, F1=0.2928


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=1293.3707, Acc=0.3476, F1=0.2915

Aggregating audio models...

Audio validation results:


  with torch.no_grad(), torch.cuda.amp.autocast():


  CREMA-D: Acc=0.2651, F1=0.1326, Loss=2.1553


  with torch.no_grad(), torch.cuda.amp.autocast():


  RAVDESS: Acc=0.1486, F1=0.0385, Loss=3.1520


  with torch.no_grad(), torch.cuda.amp.autocast():


  TESS: Acc=0.1587, F1=0.0494, Loss=3.4297


  with torch.no_grad(), torch.cuda.amp.autocast():


  SUBESCO: Acc=0.1759, F1=0.0541, Loss=2.2326


  with torch.no_grad(), torch.cuda.amp.autocast():


  ShEMO: Acc=0.3685, F1=0.2048, Loss=1.4333

  Global Audio Avg: Acc=0.2234, F1=0.0959, Loss=2.4806

ROUND 2/30
LR multiplier: 1.0000

--- TEXT MODALITY ---

Training text client: dair-ai/emotion
  Epoch 1: Loss=0.3837, Acc=0.9104
  Epoch 2: Loss=0.1389, Acc=0.9616
  Epoch 3: Loss=0.0911, Acc=0.9736

Training text client: boltuix/emotions-dataset
  Epoch 1: Loss=1.1706, Acc=0.6102
  Epoch 2: Loss=0.8083, Acc=0.7119
  Epoch 3: Loss=0.6485, Acc=0.7620

Training text client: mteb/emotion
  Epoch 1: Loss=0.3812, Acc=0.9121
  Epoch 2: Loss=0.1405, Acc=0.9593
  Epoch 3: Loss=0.0936, Acc=0.9709

Training text client: go_emotions
  Epoch 1: Loss=1.0411, Acc=0.6036
  Epoch 2: Loss=0.6288, Acc=0.7729
  Epoch 3: Loss=0.4961, Acc=0.8164

Training text client: pashupatigupta/emotion
  Epoch 1: Loss=1.5198, Acc=0.3721
  Epoch 2: Loss=1.3650, Acc=0.4309
  Epoch 3: Loss=1.2416, Acc=0.4853

Aggregating text models...

Text validation results:
  dair-ai/emotion: Acc=0.9237, F1=0.8911, Loss=0.2774
  bolt

  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=62.6772, Acc=0.2107, F1=0.1129


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=393.2032, Acc=0.2106, F1=0.1105


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=1050.5714, Acc=0.2160, F1=0.1159

Training audio client: RAVDESS


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=16.4206, Acc=0.1401, F1=0.0733


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=70.1221, Acc=0.1338, F1=0.0392


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=170.7522, Acc=0.1507, F1=0.0778

Training audio client: TESS


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=35.1719, Acc=0.1424, F1=0.0618


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=89.6600, Acc=0.1481, F1=0.0557


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=189.7083, Acc=0.1555, F1=0.0617

Training audio client: SUBESCO


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=50.1514, Acc=0.2240, F1=0.1257


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=274.3457, Acc=0.2657, F1=0.1386


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=696.9549, Acc=0.2622, F1=0.1375

Training audio client: ShEMO


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 1: Loss=16.3443, Acc=0.3813, F1=0.3242


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 2: Loss=72.7662, Acc=0.4358, F1=0.3770


  with torch.cuda.amp.autocast(enabled=True):


  Epoch 3: Loss=178.5129, Acc=0.4813, F1=0.4211

Aggregating audio models...

Audio validation results:


  with torch.no_grad(), torch.cuda.amp.autocast():


  CREMA-D: Acc=0.2808, F1=0.1460, Loss=2.1390


  with torch.no_grad(), torch.cuda.amp.autocast():


  RAVDESS: Acc=0.1988, F1=0.1153, Loss=3.1695


  with torch.no_grad(), torch.cuda.amp.autocast():


  TESS: Acc=0.1964, F1=0.0995, Loss=3.2935


  with torch.no_grad(), torch.cuda.amp.autocast():


  SUBESCO: Acc=0.1889, F1=0.0784, Loss=2.0640


  with torch.no_grad(), torch.cuda.amp.autocast():


  ShEMO: Acc=0.4759, F1=0.3865, Loss=1.4483

  Global Audio Avg: Acc=0.2682, F1=0.1652, Loss=2.4229

ROUND 3/30
LR multiplier: 1.0000

--- TEXT MODALITY ---

Training text client: dair-ai/emotion
  Epoch 1: Loss=0.1789, Acc=0.9526
  Epoch 2: Loss=0.1028, Acc=0.9687
  Epoch 3: Loss=0.0733, Acc=0.9777

Training text client: boltuix/emotions-dataset
  Epoch 1: Loss=1.0249, Acc=0.6623
  Epoch 2: Loss=0.6870, Acc=0.7471
  Epoch 3: Loss=0.5292, Acc=0.7974

Training text client: mteb/emotion
  Epoch 1: Loss=0.1875, Acc=0.9486
  Epoch 2: Loss=0.1061, Acc=0.9679
  Epoch 3: Loss=0.0685, Acc=0.9798

Training text client: go_emotions
  Epoch 1: Loss=0.9565, Acc=0.6429
  Epoch 2: Loss=0.5674, Acc=0.7878


KeyboardInterrupt: 

In [None]:
import os
import uuid
import faiss
import torch
import librosa
import soundfile as sf
import numpy as np
import networkx as nx
import google.generativeai as genai
import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModel, Wav2Vec2Processor, Wav2Vec2Model
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv

# Load environment variables (including the Gemini API key) from a local .env
# file, if one is present.
load_dotenv()

nltk.download('punkt')

app = FastAPI()

# Trained model .pth files must be saved in
#   saved_models/text_emotion_model.pth
#   saved_models/audio_emotion_model.pth

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Configure Gemini.
# SECURITY: the API key must never be committed to source control. It is read
# from the environment (GOOGLE_API_KEY, or GEMINI_API_KEY as a fallback). Any
# key that was previously hard-coded here should be treated as compromised and
# rotated.
_gemini_api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
if _gemini_api_key:
    genai.configure(api_key=_gemini_api_key)
else:
    # Stay best-effort: the service still starts, and /ask_question/ already
    # reports a failure message when the Gemini call cannot be made.
    print("WARNING: GOOGLE_API_KEY / GEMINI_API_KEY is not set; Gemini answers will be unavailable.")

# Paths
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
curriculum_graph = nx.Graph()

class TextEmotionModel(torch.nn.Module):
    """Transformer encoder with a single linear emotion-classification head.

    The tokenizer is kept on the instance so callers can tokenise with the
    vocabulary that matches the encoder.
    """

    def __init__(self, model_name="distilbert-base-uncased", num_labels=6):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return classification logits computed from the first token's hidden state."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the sequence is used as the pooled representation.
        first_token_state = encoded.last_hidden_state[:, 0]
        return self.classifier(first_token_state)
class AudioEmotionModel(torch.nn.Module):
    """Wav2Vec2 encoder with a single linear emotion-classification head.

    The processor is kept on the instance so callers can prepare raw audio
    the way the encoder expects.
    """

    def __init__(self, model_name="facebook/wav2vec2-base", num_labels=6):
        super().__init__()
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_values):
        """Return classification logits from the mean over dim 1 of the encoder output."""
        frame_states = self.model(input_values).last_hidden_state
        clip_embedding = frame_states.mean(dim=1)
        return self.classifier(clip_embedding)
def load_text_model(path):
    """Load the trained text emotion model from ``path`` onto the CPU in eval mode.

    The training cell of this notebook saves a checkpoint *dict* in which the
    weights live under ``'model_state_dict'`` and the head size under
    ``'num_classes'``. Passing that dict straight to ``load_state_dict`` fails,
    so the weights are unwrapped first. A file that contains a bare
    ``state_dict`` is still accepted.

    NOTE(review): the training checkpoint also contains a pickled label
    encoder; on PyTorch releases where ``torch.load`` defaults to
    ``weights_only=True`` this may require ``weights_only=False`` - confirm
    against the installed version, and only load trusted files.
    NOTE(review): parameter names in the checkpoint must match this module's
    ``model.*`` / ``classifier.*`` layout - verify against the training model.

    Args:
        path: location of the ``.pth`` file.

    Returns:
        A ``TextEmotionModel`` with the weights loaded, in eval mode.
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        num_labels = checkpoint.get("num_classes")
    else:
        state_dict, num_labels = checkpoint, None

    # Size the classification head from the checkpoint when it says how many
    # classes were trained; otherwise keep the class default.
    model = TextEmotionModel(num_labels=num_labels) if num_labels else TextEmotionModel()
    model.load_state_dict(state_dict)
    model.eval()
    return model

def load_audio_model(path):
    """Load the trained audio emotion model from ``path`` onto the CPU in eval mode.

    The training cell of this notebook saves a checkpoint *dict* in which the
    weights live under ``'model_state_dict'`` and the head size under
    ``'num_classes'``. Passing that dict straight to ``load_state_dict`` fails,
    so the weights are unwrapped first. A file that contains a bare
    ``state_dict`` is still accepted.

    NOTE(review): the training checkpoint also contains a pickled label
    encoder; on PyTorch releases where ``torch.load`` defaults to
    ``weights_only=True`` this may require ``weights_only=False`` - confirm
    against the installed version, and only load trusted files.
    NOTE(review): parameter names in the checkpoint must match this module's
    ``model.*`` / ``classifier.*`` layout - verify against the training model.

    Args:
        path: location of the ``.pth`` file.

    Returns:
        An ``AudioEmotionModel`` with the weights loaded, in eval mode.
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        num_labels = checkpoint.get("num_classes")
    else:
        state_dict, num_labels = checkpoint, None

    # Size the classification head from the checkpoint when it says how many
    # classes were trained; otherwise keep the class default.
    model = AudioEmotionModel(num_labels=num_labels) if num_labels else AudioEmotionModel()
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict_text(model, text):
    """Classify the emotion of ``text`` and return its label as a string.

    ``model`` must expose a ``tokenizer`` attribute and be callable as
    ``model(input_ids, attention_mask)`` returning logits.

    NOTE(review): the label order below is hard-coded; verify that it matches
    the class order the model was trained with.
    """
    encoded = model.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        scores = model(encoded["input_ids"], encoded["attention_mask"])
    best_index = scores.argmax(dim=1).item()
    label_names = ("happy", "sad", "angry", "neutral", "confused", "excited")
    return label_names[best_index]

def predict_audio(model, file_path):
    """Classify the emotion of the audio file at ``file_path`` and return its label.

    The file is loaded with librosa at 16 kHz and prepared with the model's
    own ``processor`` before the forward pass.

    NOTE(review): the label order below is hard-coded; verify that it matches
    the class order the model was trained with.
    """
    waveform, sample_rate = librosa.load(file_path, sr=16000)
    features = model.processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        scores = model(features["input_values"])
    best_index = scores.argmax(dim=1).item()
    label_names = ("happy", "sad", "angry", "neutral", "confused", "excited")
    return label_names[best_index]
# ---- In-memory application state shared by the endpoints -------------------
# Text of the most recently uploaded curriculum (set by /upload_curriculum/).
curriculum_text = ""
# Study-material chunks; /upload_study_material/ appends one entry per chunk
# it adds to the FAISS index.
study_materials_texts = []
# NOTE(review): the three names below are not referenced anywhere else in the
# visible part of this file - confirm whether they are still needed.
study_materials_embeddings = []
topics_graph = nx.Graph()
roadmap = []

# ==== Load your trained emotion models ====
# Both models are loaded once, at import time, and reused by every request.
text_model_path = "saved_models/text_emotion_model.pth"   # update if needed
audio_model_path = "saved_models/audio_emotion_model.pth" # update if needed

text_model = load_text_model(text_model_path)
audio_model = load_audio_model(audio_model_path)

print("✅ Emotion models loaded successfully!")

# Load embedding model
# Tokenizer and encoder used by embed_text() for semantic retrieval.
embedder_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embedder_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def embed_text(text):
    """Return a NumPy embedding vector for ``text``.

    The text is tokenised with ``embedder_tokenizer`` and run through
    ``embedder_model``; the final hidden states are averaged over dim 1 and
    the first row of the result is returned.
    """
    encoded = embedder_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        token_states = embedder_model(**encoded).last_hidden_state
        pooled = token_states.mean(dim=1)
    return pooled[0].numpy()

# FAISS L2 index over study-material chunk embeddings. /upload_study_material/
# adds one vector per chunk and appends the chunk to study_materials_texts in
# the same order, so index position i maps to study_materials_texts[i].
# NOTE(review): 384 presumably matches the output width of the embedder used
# by embed_text() - confirm if the embedding model is ever changed.
dimension = 384
index = faiss.IndexFlatL2(dimension)
@app.post("/upload_curriculum/")
async def upload_curriculum(file: UploadFile = File(...)):
    """Store an uploaded curriculum file and keep its extracted text in memory.

    The file is written into ``UPLOAD_DIR`` and parsed with ``extract_text``;
    the result replaces the module-level ``curriculum_text``.

    Raises:
        HTTPException: 400 when the upload has no usable filename.
    """
    global curriculum_text

    # The filename is supplied by the client: keep only its final path
    # component so values like "../../x" or "/etc/x" cannot escape UPLOAD_DIR.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as f:
        f.write(await file.read())
    curriculum_text = extract_text(file_path)
    return {"message": "Curriculum uploaded and processed."}
from fastapi import HTTPException
from typing import Optional

# helper: extract text from pdf (used earlier)
from PyPDF2 import PdfReader

def extract_text(path: str) -> str:
    """Return the text of every page of the PDF at ``path``, joined by newlines.

    Pages for which the reader yields no text are skipped.
    """
    page_texts = (page.extract_text() for page in PdfReader(path).pages)
    return "\n".join(page_text for page_text in page_texts if page_text)

# chunk helper (split large context for both embeddings and for sending to LLM)
def chunk_text_for_index(text: str, max_chars: int = 1000):
    """Split ``text`` into consecutive slices of at most ``max_chars`` characters.

    The slices are returned in order; the last one may be shorter. An empty
    ``text`` yields an empty list.
    """
    return [text[start:start + max_chars] for start in range(0, len(text), max_chars)]

# store mapping from index -> chunk_text
# (If you used study_materials_texts earlier, ensure it's the same list; we'll use `study_materials_texts`)
# study_materials_texts and index were defined in PART 3

@app.post("/upload_study_material/")
async def upload_study_material(file: UploadFile = File(...)):
    """
    Accept a single PDF file, extract chunks, embed them, and add to FAISS index.

    The file is saved into ``UPLOAD_DIR``, its text is split into 800-character
    chunks, and each chunk's embedding is added to ``index`` while the chunk
    itself is appended to ``study_materials_texts`` (same order, so positions
    stay aligned).

    Raises:
        HTTPException: 400 when the upload has no usable filename.
    """
    # The filename is supplied by the client: keep only its final path
    # component so values like "../../x" or "/etc/x" cannot escape UPLOAD_DIR.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as f:
        f.write(await file.read())

    text = extract_text(file_path)
    chunks = chunk_text_for_index(text, max_chars=800)

    for ch in chunks:
        # FAISS expects a float32 matrix of shape (n_vectors, dimension).
        emb = embed_text(ch).astype("float32")
        index.add(np.expand_dims(emb, axis=0))
        study_materials_texts.append(ch)

    return {"message": "Study material uploaded", "chunks_added": len(chunks)}


@app.post("/ask_question/")
async def ask_question(
    question: Optional[str] = Form(None),
    audio: Optional[UploadFile] = File(None),
    top_k: int = Form(3)
):
    """
    Accepts text (question) OR audio file (wav/mp3). Predicts emotion and answers using Gemini.
    Provide question if you want semantic search context; audio-only allowed (question can be empty).

    Steps: (1) detect emotion from the audio if given, otherwise from the
    question text; (2) retrieve up to ``top_k`` study-material chunks closest
    to the question; (3) add roadmap text when the curriculum graph offers it;
    (4) build an emotion-aware prompt; (5) ask Gemini.

    Returns:
        dict with ``emotion_detected``, ``confidence`` (``None`` unless the
        predictor returns it) and ``answer``.

    Raises:
        HTTPException: 400 when neither ``question`` nor ``audio`` is given.
    """
    # --- check data ---
    if not question and not audio:
        raise HTTPException(status_code=400, detail="Provide either 'question' text or an 'audio' file.")

    # --- Step 1: emotion detection (best effort: failures fall back to neutral) ---
    detected_emotion = "neutral"
    confidence = None
    try:
        if audio:
            # The filename is client-supplied: keep only its last component.
            safe_name = os.path.basename(audio.filename or "audio")
            temp_path = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4().hex}_{safe_name}")
            try:
                with open(temp_path, "wb") as f:
                    f.write(await audio.read())
                res = predict_audio(audio_model, temp_path)
            finally:
                # Always remove the temp file, even when prediction fails.
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
        else:
            res = predict_text(text_model, question)

        # Predictors may return either a label or (label, confidence, probs).
        if isinstance(res, tuple):
            detected_emotion = res[0]
            confidence = res[1] if len(res) > 1 else None
        else:
            detected_emotion = res
    except Exception as e:
        # ignore model failure but log
        print("Emotion detection failed:", e)
        detected_emotion = "neutral"
        confidence = None

    # --- Step 2: semantic retrieval (if question present) ---
    retrieved_context = ""
    if question and study_materials_texts and index.ntotal > 0 and top_k > 0:
        q_emb = np.expand_dims(embed_text(question).astype("float32"), axis=0)
        k = min(top_k, index.ntotal)
        _, neighbor_ids = index.search(q_emb, k)
        # FAISS pads missing neighbours with -1; a negative id must not be
        # used as a Python index (it would silently pick from the end).
        hits = [
            study_materials_texts[idx]
            for idx in neighbor_ids[0]
            if 0 <= idx < len(study_materials_texts)
        ]
        retrieved_context = "\n\n".join(hits)

    # --- Step 3: curriculum roadmap context ---
    roadmap_text = ""
    try:
        # Only used when the curriculum graph object provides a roadmap generator.
        if hasattr(curriculum_graph, "generate_graph_roadmap"):
            rm = curriculum_graph.generate_graph_roadmap()
            roadmap_text = "\n".join(
                f"{lvl}: {', '.join(topics)}" for lvl, topics in rm.items()
            )
    except Exception:
        roadmap_text = ""

    # --- Step 4: build emotion-aware prompt ---
    emotion_instruction_map = {
        "happy": "Be encouraging and build on the user's enthusiasm.",
        "sad": "Be gentle, supportive, and concise. Offer small, clear steps.",
        "angry": "Be calm, clear, and avoid confrontational phrasing.",
        "confused": "Use simple examples, step-by-step explanations, and check understanding.",
        "excited": "Be encouraging and give engaging examples.",
        "neutral": "Be clear, concise, and instructional."
    }
    tone_instr = emotion_instruction_map.get(detected_emotion.lower(), emotion_instruction_map["neutral"])

    prompt = f"""
You are an educational assistant and tutor. The student currently feels: {detected_emotion}.
Tone instruction: {tone_instr}

Curriculum Roadmap:
{roadmap_text}

Relevant Study Material (from uploaded files):
{retrieved_context}

Question:
{question or '[no explicit text question; audio used]'}
"""

    # --- Step 5: call Gemini ---
    try:
        response = genai.GenerativeModel("models/gemini-2.5-flash").generate_content(
            prompt,
            generation_config={"temperature": 0.2, "max_output_tokens": 500}
        )
        answer_text = response.text or ""
    except Exception as e:
        print("Gemini call failed:", e)
        answer_text = "Sorry — failed to generate answer right now."

    return {
        "emotion_detected": detected_emotion,
        "confidence": confidence,
        "answer": answer_text
    }
@app.get("/topics/")
def list_topics():
    """Return the curriculum topics and their count.

    Falls back to an empty result when the topics cannot be obtained.

    NOTE(review): ``curriculum_graph`` is created as a plain ``nx.Graph`` in
    this file - confirm it actually provides ``get_all_topics``.
    """
    try:
        topic_count = len(curriculum_graph.get_all_topics())
        topic_list = curriculum_graph.get_all_topics()
    except Exception:
        return {"total_topics": 0, "topics": []}
    return {"total_topics": topic_count, "topics": topic_list}


@app.get("/roadmap/")
def get_roadmap():
    """Return the curriculum roadmap, or an empty mapping when it cannot be built.

    NOTE(review): ``curriculum_graph`` is created as a plain ``nx.Graph`` in
    this file - confirm it actually provides ``generate_graph_roadmap``.
    """
    try:
        return {"roadmap": curriculum_graph.generate_graph_roadmap()}
    except Exception:
        return {"roadmap": {}}


@app.get("/")
def root():
    """Return a static message confirming that the service is running."""
    status = {"message": "Emotion-aware Curriculum QA running."}
    return status
