In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 1. CONFIGURATION ---

# --- Paths ---
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "lstm_cross_attention_model_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 10 # Reduced for a quick test, you can set it back to 50
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128

# --- 2. DATASET CLASS & MODEL DEFINITION (Unchanged) ---

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

class LSTMCrossAttentionFusion(nn.Module):
    def __init__(self, cqcc_features, egmaps_features, time_steps, embedding_dim):
        super(LSTMCrossAttentionFusion, self).__init__()
        self.cqcc_cnn = nn.Sequential(
            nn.Conv1d(cqcc_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64, embedding_dim, kernel_size=3, padding=1),
        )
        self.lstm = nn.LSTM(
            input_size=egmaps_features, hidden_size=embedding_dim,
            num_layers=2, batch_first=True, bidirectional=True
        )
        self.lstm_fc = nn.Linear(embedding_dim * 2, embedding_dim)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=4, batch_first=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 64), nn.BatchNorm1d(64),
            nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, 1)
        )
    def forward(self, cqcc_x, egmaps_x):
        cqcc_x_seq = cqcc_x.transpose(1, 2)
        egmaps_x_seq = egmaps_x.transpose(1, 2)
        cqcc_out_cnn = self.cqcc_cnn(cqcc_x).transpose(1, 2)
        lstm_out, _ = self.lstm(egmaps_x_seq)
        prosody_query = torch.tanh(self.lstm_fc(lstm_out))
        attended_output, _ = self.cross_attention(
            query=prosody_query, key=cqcc_out_cnn, value=cqcc_out_cnn
        )
        pooled_output = attended_output.mean(dim=1)
        output = self.classifier(pooled_output)
        return torch.sigmoid(output)

# --- 3. MAIN EXECUTION SCRIPT ---

if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    # Load Data
    try:
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    # Scale Features
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    # Create Datasets and DataLoaders
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize Model, Loss, and Optimizer
    model = LSTMCrossAttentionFusion(
        cqcc_features=CQCC_SHAPE[0], egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1], embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # --- Training Loop with Enhanced Evaluation ---
    best_val_eer = float('inf')
    print(f"\n--- Starting Training ---")
    for epoch in range(EPOCHS):
        model.train()
        # Training pass (no changes here)
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()

        # --- Validation pass with more metrics ---
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in val_loader:
                cqcc_batch, lld_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        
        # --- NEW: Calculate more metrics ---
        all_preds = (all_scores > 0.5).astype(int)
        val_accuracy = accuracy_score(all_labels, all_preds)
        val_f1 = f1_score(all_labels, all_preds, pos_label=1) # F1 for the 'bonafide' class
        cm = confusion_matrix(all_labels, all_preds)
        
        try:
            eer = brentq(lambda x: 1. - x - interp1d(*roc_curve(all_labels, all_scores, pos_label=1)[:2])(x), 0., 1.) * 100
        except Exception: 
            eer = -1.0

        # --- NEW: Print all metrics ---
        print(f"\n--- Epoch {epoch+1}/{EPOCHS} Results ---")
        print(f"Accuracy: {val_accuracy:.4f} | F1-Score (Bonafide): {val_f1:.4f} | EER: {eer:.2f}%")
        print("Confusion Matrix:")
        print("         Predicted Spoof | Predicted Bonafide")
        print(f"True Spoof | {cm[0][0]:<15} | {cm[0][1]:<18}")
        print(f"True Bonafide| {cm[1][0]:<15} | {cm[1][1]:<18}")
        print("------------------------------------------")

        if eer > 0 and eer < best_val_eer:
            best_val_eer = eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_lstm_cross_attention_model.pth"))
            print(f"  -> New best model saved with EER: {best_val_eer:.2f}%")

Using device: cuda

--- Starting Training ---


Epoch 1/10 [Train]: 100%|██████████| 720/720 [00:38<00:00, 18.85it/s]



--- Epoch 1/10 Results ---
Accuracy: 0.9653 | F1-Score (Bonafide): 0.8154 | EER: 6.30%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22081           | 215               
True Bonafide| 646             | 1902              
------------------------------------------
  -> New best model saved with EER: 6.30%


Epoch 2/10 [Train]: 100%|██████████| 720/720 [00:42<00:00, 17.07it/s]



--- Epoch 2/10 Results ---
Accuracy: 0.9712 | F1-Score (Bonafide): 0.8425 | EER: 5.53%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22217           | 79                
True Bonafide| 636             | 1912              
------------------------------------------
  -> New best model saved with EER: 5.53%


Epoch 3/10 [Train]: 100%|██████████| 720/720 [00:41<00:00, 17.40it/s]



--- Epoch 3/10 Results ---
Accuracy: 0.9681 | F1-Score (Bonafide): 0.8186 | EER: 6.44%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22265           | 31                
True Bonafide| 761             | 1787              
------------------------------------------


Epoch 4/10 [Train]: 100%|██████████| 720/720 [00:41<00:00, 17.30it/s]



--- Epoch 4/10 Results ---
Accuracy: 0.9787 | F1-Score (Bonafide): 0.8888 | EER: 5.85%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22196           | 100               
True Bonafide| 430             | 2118              
------------------------------------------


Epoch 5/10 [Train]: 100%|██████████| 720/720 [00:42<00:00, 16.85it/s]



--- Epoch 5/10 Results ---
Accuracy: 0.9611 | F1-Score (Bonafide): 0.7668 | EER: 8.13%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22290           | 6                 
True Bonafide| 960             | 1588              
------------------------------------------


Epoch 6/10 [Train]: 100%|██████████| 720/720 [00:40<00:00, 17.76it/s]



--- Epoch 6/10 Results ---
Accuracy: 0.9705 | F1-Score (Bonafide): 0.8346 | EER: 9.57%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22261           | 35                
True Bonafide| 698             | 1850              
------------------------------------------


Epoch 7/10 [Train]: 100%|██████████| 720/720 [00:43<00:00, 16.71it/s]



--- Epoch 7/10 Results ---
Accuracy: 0.9746 | F1-Score (Bonafide): 0.8610 | EER: 5.30%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22254           | 42                
True Bonafide| 590             | 1958              
------------------------------------------
  -> New best model saved with EER: 5.30%


Epoch 8/10 [Train]: 100%|██████████| 720/720 [00:42<00:00, 16.93it/s]



--- Epoch 8/10 Results ---
Accuracy: 0.9681 | F1-Score (Bonafide): 0.8171 | EER: 5.77%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22280           | 16                
True Bonafide| 777             | 1771              
------------------------------------------


Epoch 9/10 [Train]: 100%|██████████| 720/720 [00:41<00:00, 17.16it/s]



--- Epoch 9/10 Results ---
Accuracy: 0.9730 | F1-Score (Bonafide): 0.8498 | EER: 5.53%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22275           | 21                
True Bonafide| 650             | 1898              
------------------------------------------


Epoch 10/10 [Train]: 100%|██████████| 720/720 [00:41<00:00, 17.15it/s]



--- Epoch 10/10 Results ---
Accuracy: 0.9684 | F1-Score (Bonafide): 0.8197 | EER: 6.08%
Confusion Matrix:
         Predicted Spoof | Predicted Bonafide
True Spoof | 22278           | 18                
True Bonafide| 766             | 1782              
------------------------------------------
