In [None]:
import os
import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
# Folders of cleaned vocal tracks, keyed by class name. The key "stress" is
# mapped to label 1 by MultimodalDataset; any other key is mapped to label 0.
# NOTE(review): paths look like a Colab-mounted Google Drive — confirm the
# drive is mounted before running this cell.
AUDIO_DIRS = {
    "stress": "/content/drive/MyDrive/Dataset/Cleaned_Vocals/stress",
    "nonstress": "/content/drive/MyDrive/Dataset/Cleaned_Vocals/nonstress"
}
# Folder scanned for per-clip visual feature tables (*.csv).
VISUAL_DIR = "/content/drive/MyDrive/Dataset/Visual_Features"

MAX_LEN = 200        # number of time steps every audio/visual sequence is padded or truncated to
SAMPLE_RATE = 16000  # sampling rate (Hz) passed to librosa.load
N_MFCC = 40          # number of MFCC coefficients extracted per audio frame

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================
# Utilities
# ==============================
def extract_digits(name: str):
    """Return the digit characters of ``name`` concatenated in their original order.

    Characters are kept when ``str.isdigit`` is true for them; everything else
    is dropped. A name without digits yields the empty string.
    """
    return "".join(ch for ch in name if ch.isdigit())

def resample_sequence(seq, target_len):
    """Linearly resample a (T, D) sequence along time to ``target_len`` steps.

    Source and destination time axes are both normalised to [0, 1], and each
    of the D feature columns is interpolated independently with ``np.interp``.

    Args:
        seq: 2-D array of shape (T, D).
        target_len: number of time steps in the output.

    Returns:
        float32 array of shape (target_len, D).
    """
    n_steps, n_feats = seq.shape
    src_grid = np.linspace(0, 1, n_steps)
    dst_grid = np.linspace(0, 1, target_len)

    out = np.zeros((target_len, n_feats))
    # seq.T iterates over feature columns, each of length T.
    for col, channel in enumerate(seq.T):
        out[:, col] = np.interp(dst_grid, src_grid, channel)
    return out.astype(np.float32)

# ==============================
# Feature loaders
# ==============================
def load_audio_features(audio_path, max_len=MAX_LEN, target_len=None):
    """Load an audio file and return a fixed-length MFCC sequence.

    The file is decoded at ``SAMPLE_RATE`` and ``N_MFCC`` coefficients are
    computed per frame. If ``target_len`` is given, the time axis is first
    resampled to that many frames; the result is then zero-padded at the end
    or truncated so that it has exactly ``max_len`` frames.

    Args:
        audio_path: path of the audio file to read.
        max_len: number of frames in the returned array.
        target_len: optional frame count to resample to before pad/truncate.

    Returns:
        float32 array of shape (max_len, N_MFCC).
    """
    waveform, rate = librosa.load(audio_path, sr=SAMPLE_RATE)
    # Transpose so that time is the first axis: (frames, N_MFCC).
    frames = librosa.feature.mfcc(y=waveform, sr=rate, n_mfcc=N_MFCC).T

    if target_len is not None:
        frames = resample_sequence(frames, target_len)

    n_frames = frames.shape[0]
    if n_frames >= max_len:
        frames = frames[:max_len, :]
    else:
        # Pad only at the end of the time axis; feature axis is untouched.
        frames = np.pad(frames, ((0, max_len - n_frames), (0, 0)))

    return frames.astype(np.float32)

def load_visual_features(visual_path, max_len=MAX_LEN):
    """Read a visual-feature CSV and return a fixed-length numeric sequence.

    Malformed rows are skipped by the CSV reader, non-numeric columns are
    discarded, and the remaining rows are zero-padded at the end or truncated
    to exactly ``max_len`` rows.

    Args:
        visual_path: path of the CSV file (read as ISO-8859-1).
        max_len: number of rows in the returned array.

    Returns:
        float32 array of shape (max_len, n_numeric_columns).
    """
    table = pd.read_csv(visual_path, on_bad_lines="skip", encoding="ISO-8859-1")
    feats = table.select_dtypes(include=[np.number]).values

    n_rows = feats.shape[0]
    if n_rows >= max_len:
        feats = feats[:max_len, :]
    else:
        # Pad only at the end of the row (time) axis.
        feats = np.pad(feats, ((0, max_len - n_rows), (0, 0)))

    return feats.astype(np.float32)

# ==============================
# Dataset
# ==============================
class MultimodalDataset(Dataset):
    """Paired audio / visual samples for binary stress classification.

    Audio ``.wav`` files found under ``audio_dirs`` are paired with visual
    feature ``.csv`` files in ``visual_dir`` by the digits contained in their
    file names. Each item is ``(audio, visual, label)`` where label is 1 for
    the ``"stress"`` folder key and 0 for any other key.

    Attributes:
        pairs: list of ``(audio_path, visual_path, label)`` tuples.
        visual_dim: widest numeric column count over all paired CSVs; every
            visual sample is padded/truncated to this width.
    """

    def __init__(self, audio_dirs, visual_dir):
        """Scan the folders, build the pair list and detect the visual width.

        Args:
            audio_dirs: mapping of class name -> folder containing ``.wav`` files.
            visual_dir: folder containing visual feature ``.csv`` files.
        """
        self.pairs = []

        # os.listdir order is unspecified; sorting makes the pairing (and any
        # seeded split built on top of the pair order) reproducible.
        visual_files = sorted(f for f in os.listdir(visual_dir) if f.endswith(".csv"))
        visual_index = [(vf, extract_digits(vf)) for vf in visual_files]

        audio_files = []
        for label, folder in audio_dirs.items():
            for f in sorted(os.listdir(folder)):
                if f.endswith(".wav"):
                    audio_files.append((os.path.join(folder, f), label))

        for af, label in audio_files:
            base = os.path.splitext(os.path.basename(af))[0].lower()
            match = self._pick_visual(extract_digits(base), visual_index)
            if match is not None:
                vf = os.path.join(visual_dir, match)
                self.pairs.append((af, vf, 1 if label == "stress" else 0))

        print(f"✅ Found {len(self.pairs)} matching pairs")

        # Detect the widest numeric feature table so all samples can be
        # brought to one common width in __getitem__.
        max_dim = 0
        for _, vpath, _ in self.pairs:
            df = pd.read_csv(vpath, on_bad_lines="skip", encoding="ISO-8859-1")
            df = df.select_dtypes(include=[np.number])
            max_dim = max(max_dim, df.shape[1])
        self.visual_dim = max_dim
        print(f"📊 Standardizing visual features to {self.visual_dim} dims")

    @staticmethod
    def _pick_visual(audio_digits, visual_index):
        """Choose the visual file that belongs to an audio clip id.

        Args:
            audio_digits: digit string extracted from the audio file name.
            visual_index: list of ``(filename, digits_in_filename)`` tuples.

        Returns:
            The matching file name, or ``None`` when nothing matches.
        """
        # An empty id is a substring of every name, so it would otherwise be
        # paired with an arbitrary file.
        if not audio_digits:
            return None
        # Prefer an exact id match so that e.g. "12" does not bind to "112".
        for fname, digits in visual_index:
            if digits == audio_digits:
                return fname
        # Fall back to the looser substring rule for names that carry extra
        # digits (version suffixes etc.).
        for fname, _ in visual_index:
            if audio_digits in fname:
                return fname
        return None

    def __len__(self):
        """Return the number of matched audio/visual pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(audio, visual, label)`` of tensors: audio is
            ``(MAX_LEN, n_mfcc)``, visual is ``(MAX_LEN, visual_dim)`` and
            label is a scalar integer tensor.
        """
        af, vf, label = self.pairs[idx]

        visual = load_visual_features(vf, max_len=MAX_LEN)
        # NOTE(review): `visual` is already padded/truncated to MAX_LEN rows,
        # so target_len is always MAX_LEN and the audio is always stretched to
        # MAX_LEN frames. If the intent was to align audio with the *unpadded*
        # visual length, this needs the raw row count instead — confirm.
        target_len = visual.shape[0]
        audio = load_audio_features(af, max_len=MAX_LEN, target_len=target_len)

        # Bring the feature axis to the common width detected in __init__.
        if visual.shape[1] < self.visual_dim:
            visual = np.pad(visual, ((0, 0), (0, self.visual_dim - visual.shape[1])))
        elif visual.shape[1] > self.visual_dim:
            visual = visual[:, :self.visual_dim]

        return torch.tensor(audio), torch.tensor(visual), torch.tensor(label)

In [None]:

# ==============================
# Audio-Visual Fusion Model
# ==============================
class AudioVisualFusionModel(nn.Module):
    """Two-stream transformer with bi-directional cross-attention and a CLS head.

    Each modality is linearly projected to ``d_model``, layer-normalised and
    encoded by its own transformer encoder. The encoded streams attend to each
    other, the two attended sequences are concatenated behind a learnable
    [CLS] token, passed through a fusion transformer, and the [CLS] output is
    classified.

    Args:
        audio_dim: feature size of each audio frame.
        video_dim: feature size of each visual frame.
        d_model: shared embedding width.
        nhead: attention heads used in every attention module.
        num_layers: depth of each per-modality encoder (the fusion encoder is
            fixed at 2 layers).
        num_classes: number of output logits.
        dropout: dropout rate for the transformer layers and classifier.
    """

    def __init__(self, audio_dim=40, video_dim=2054, d_model=256, nhead=4, num_layers=2, num_classes=2, dropout=0.2):
        super().__init__()

        layer_cfg = dict(d_model=d_model, nhead=nhead, batch_first=True, dropout=dropout)
        hidden = d_model // 2

        # Input projections to the shared width, each followed by LayerNorm.
        self.audio_proj = nn.Linear(audio_dim, d_model)
        self.video_proj = nn.Linear(video_dim, d_model)
        self.audio_norm = nn.LayerNorm(d_model)
        self.video_norm = nn.LayerNorm(d_model)

        # Learnable [CLS] token prepended to the fused sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # One prototype layer feeds both per-modality encoders;
        # nn.TransformerEncoder deep-copies it, so the encoders do not share
        # parameters (they only start from identical initial values).
        modality_layer = nn.TransformerEncoderLayer(**layer_cfg)
        self.audio_encoder = nn.TransformerEncoder(modality_layer, num_layers=num_layers)
        self.video_encoder = nn.TransformerEncoder(modality_layer, num_layers=num_layers)

        # Cross-modal attention, one module per direction.
        self.cross_attn_audio_to_video = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn_video_to_audio = nn.MultiheadAttention(d_model, nhead, batch_first=True)

        # Fusion encoder over [CLS] + audio-attended + video-attended tokens.
        self.fusion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_cfg), num_layers=2
        )

        # Classification head applied to the [CLS] output.
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, audio, video):
        """Return class logits of shape (batch, num_classes).

        Args:
            audio: tensor of shape (batch, audio_steps, audio_dim).
            video: tensor of shape (batch, video_steps, video_dim).

        NOTE(review): no positional encoding is added and no padding mask is
        passed to any attention module in this forward pass.
        """
        # Per-modality embedding + self-attention encoding.
        audio_tokens = self.audio_encoder(self.audio_norm(self.audio_proj(audio)))
        video_tokens = self.video_encoder(self.video_norm(self.video_proj(video)))

        # Audio queries attend over video, and video queries over audio.
        audio_ctx, _ = self.cross_attn_audio_to_video(
            query=audio_tokens, key=video_tokens, value=video_tokens
        )
        video_ctx, _ = self.cross_attn_video_to_audio(
            query=video_tokens, key=audio_tokens, value=audio_tokens
        )

        # Sequence layout: [CLS] | audio-attended tokens | video-attended tokens.
        batch = audio_ctx.size(0)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, audio_ctx, video_ctx], dim=1)

        fused = self.fusion_transformer(tokens)

        # Position 0 holds the [CLS] summary.
        return self.classifier(fused[:, 0, :])

# ==============================
# Training + Testing
# ==============================
def evaluate_model(model, loader):
    """Score ``model`` on every batch of ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Predictions are the argmax over the class dimension.

    Args:
        model: callable taking ``(audio, visual)`` float tensors and
            returning logits.
        loader: iterable of ``(audio, visual, labels)`` batches.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` from scikit-learn, with
        ``zero_division=0`` for precision, recall and F1.
    """
    model.eval()
    y_pred, y_true = [], []

    with torch.no_grad():
        for batch_audio, batch_visual, batch_labels in loader:
            batch_audio = batch_audio.to(device)
            batch_visual = batch_visual.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_audio.float(), batch_visual.float())
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(batch_labels.cpu().numpy())

    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
    )

def train_model(audio_dirs, visual_dir, epochs=30, save_dir="/content/models", seed=42):
    """Train the fusion model, then test the best-validation checkpoint.

    The dataset is split 60/20/20 into train/val/test. A checkpoint is saved
    after every epoch; the one with the highest validation F1 (earliest wins
    on ties) is reloaded before the final test so that the reported test
    metrics and the returned model correspond to the selected epoch rather
    than to whatever the last epoch happened to produce.

    Args:
        audio_dirs: mapping of class name -> folder with ``.wav`` files.
        visual_dir: folder with visual feature ``.csv`` files.
        epochs: number of training epochs.
        save_dir: folder where ``model_epoch{N}.pt`` checkpoints are written
            (created if missing).
        seed: seed for the data split and for PyTorch's global RNG (weight
            init, batch shuffling). Pass ``None`` to leave both unseeded.

    Returns:
        The trained model, carrying the weights of the best validation epoch
        (or the initial weights when ``epochs`` is 0).

    Raises:
        ValueError: if no audio/visual pairs are found, or if a training
            batch contains NaN values.
    """
    dataset = MultimodalDataset(audio_dirs, visual_dir)
    if len(dataset) == 0:
        raise ValueError("❌ No matching audio-visual pairs found!")

    os.makedirs(save_dir, exist_ok=True)

    # Seed torch so weight initialisation and shuffling are repeatable,
    # matching the already-seeded split below.
    if seed is not None:
        torch.manual_seed(seed)

    # 60% train, 20% val, 20% test
    train_idx, temp_idx = train_test_split(range(len(dataset)), test_size=0.4, random_state=seed)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=seed)

    train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx), batch_size=8, shuffle=True)
    val_loader = DataLoader(torch.utils.data.Subset(dataset, val_idx), batch_size=8)
    test_loader = DataLoader(torch.utils.data.Subset(dataset, test_idx), batch_size=8)

    # Infer the input feature sizes from one real sample.
    sample_audio, sample_visual, _ = dataset[0]
    model = AudioVisualFusionModel(audio_dim=sample_audio.shape[1], video_dim=sample_visual.shape[1]).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    best_f1, best_path = -1.0, None

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for audio, visual, labels in train_loader:
            audio, visual, labels = audio.to(device), visual.to(device), labels.to(device)

            # Fail fast on NaN inputs instead of silently training on them.
            if torch.isnan(audio).any() or torch.isnan(visual).any():
                raise ValueError("NaN detected in input features!")

            outputs = model(audio.float(), visual.float())
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        train_acc = 100 * correct / total

        # Train accuracy is printed as a percentage; the validation metrics
        # are printed as returned by evaluate_model.
        val_acc, val_prec, val_rec, val_f1 = evaluate_model(model, val_loader)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_loader):.4f} "
              f"| Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f} "
              f"| Prec: {val_prec:.2f} | Recall: {val_rec:.2f} | F1: {val_f1:.2f}")

        # save checkpoint
        ckpt_path = os.path.join(save_dir, f"model_epoch{epoch+1}.pt")
        torch.save(model.state_dict(), ckpt_path)

        # Remember the checkpoint with the best validation F1.
        if val_f1 > best_f1:
            best_f1, best_path = val_f1, ckpt_path

    # Restore the selected checkpoint so the test numbers are not tied to the
    # last epoch.
    if best_path is not None:
        model.load_state_dict(torch.load(best_path, map_location=device))
        print(f"Restored best validation checkpoint: {best_path} (Val F1: {best_f1:.2f})")

    # ✅ Final Testing
    test_acc, test_prec, test_rec, test_f1 = evaluate_model(model, test_loader)
    print("\n🎯 Final Test Results:")
    print(f"Accuracy: {test_acc:.2f} | Precision: {test_prec:.2f} | Recall: {test_rec:.2f} | F1: {test_f1:.2f}")

    return model

# ==============================
# Run
# ==============================
# Entry point: build the dataset from the configured folders, train for 30
# epochs and keep the returned model in `model`.
if __name__ == "__main__":
    model = train_model(AUDIO_DIRS, VISUAL_DIR, epochs=30)


✅ Found 814 matching pairs
📊 Standardizing visual features to 2054 dims
Epoch 1/30 | Loss: 0.7072 | Train Acc: 52.05% | Val Acc: 0.53 | Prec: 0.00 | Recall: 0.00 | F1: 0.00
Epoch 2/30 | Loss: 0.7014 | Train Acc: 52.46% | Val Acc: 0.53 | Prec: 0.00 | Recall: 0.00 | F1: 0.00
Epoch 3/30 | Loss: 0.7030 | Train Acc: 54.10% | Val Acc: 0.55 | Prec: 0.51 | Recall: 0.79 | F1: 0.62
Epoch 4/30 | Loss: 0.6865 | Train Acc: 57.58% | Val Acc: 0.61 | Prec: 0.84 | Recall: 0.21 | F1: 0.34
Epoch 5/30 | Loss: 0.6000 | Train Acc: 67.62% | Val Acc: 0.74 | Prec: 0.75 | Recall: 0.66 | F1: 0.70
Epoch 6/30 | Loss: 0.5771 | Train Acc: 72.95% | Val Acc: 0.71 | Prec: 0.97 | Recall: 0.38 | F1: 0.55
Epoch 7/30 | Loss: 0.5183 | Train Acc: 78.28% | Val Acc: 0.74 | Prec: 0.74 | Recall: 0.70 | F1: 0.72
Epoch 8/30 | Loss: 0.4882 | Train Acc: 78.07% | Val Acc: 0.75 | Prec: 0.93 | Recall: 0.51 | F1: 0.66
Epoch 9/30 | Loss: 0.4977 | Train Acc: 79.30% | Val Acc: 0.72 | Prec: 0.97 | Recall: 0.41 | F1: 0.57
Epoch 10/30 | 