In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 1. CONFIGURATION ---

# --- Paths ---
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "lstm_cross_attention_model_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64 # Smaller batch size might be needed for more complex models
EPOCHS = 10
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128 # Dimension for embeddings and attention

# --- 2. DATASET CLASS ---

class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset for the fusion model."""
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. LSTM CROSS-ATTENTION MODEL DEFINITION ---

class LSTMCrossAttentionFusion(nn.Module):
    """
    Fuses CQCC features (processed by a CNN) with eGeMAPS LLDs (processed by an LSTM)
    using cross-modal attention.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, embedding_dim):
        super(LSTMCrossAttentionFusion, self).__init__()
        
        # --- CQCC Branch (Key and Value) ---
        # A simple 1D CNN to process CQCCs as a sequence
        self.cqcc_cnn = nn.Sequential(
            nn.Conv1d(cqcc_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64, embedding_dim, kernel_size=3, padding=1),
        )

        # --- eGeMAPS LLD Branch (Query) ---
        self.lstm = nn.LSTM(
            input_size=egmaps_features,
            hidden_size=embedding_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True # Bidirectional LSTM is often more powerful
        )
        # The output of the bidirectional LSTM is 2 * embedding_dim
        self.lstm_fc = nn.Linear(embedding_dim * 2, embedding_dim)
        
        # --- Cross-Attention Mechanism ---
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=4,
            batch_first=True
        )

        # --- Final Classifier ---
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        # Reshape for sequence processing: (batch, features, time) -> (batch, time, features)
        cqcc_x_seq = cqcc_x.transpose(1, 2)
        egmaps_x_seq = egmaps_x.transpose(1, 2)
        
        # 1. Get CQCC sequence from CNN branch
        # Input to Conv1d is (batch, features, time)
        cqcc_out_cnn = self.cqcc_cnn(cqcc_x).transpose(1, 2)

        # 2. Get prosodic sequence from LSTM branch
        lstm_out, _ = self.lstm(egmaps_x_seq)
        prosody_query = torch.tanh(self.lstm_fc(lstm_out))
        
        # 3. Apply Cross-Attention
        # Query: prosody_query, Key/Value: cqcc_out_cnn
        attended_output, _ = self.cross_attention(
            query=prosody_query,
            key=cqcc_out_cnn,
            value=cqcc_out_cnn
        )
        
        # 4. Classify the attended output
        # We can average the features over the time dimension before classifying
        pooled_output = attended_output.mean(dim=1)
        output = self.classifier(pooled_output)
        
        return torch.sigmoid(output)

# --- (The main execution script remains the same as the previous one) ---
# --- You just need to initialize this new model instead. ---

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    # Load Data (same as before)
    try:
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    # Scale Features (same as before)
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    # Create Datasets and DataLoaders (same as before)
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- Initialize the NEW Model ---
    model = LSTMCrossAttentionFusion(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # --- Training Loop (same as before) ---
    best_val_eer = float('inf')
    print(f"\n--- Starting Training: LSTM Cross-Attention Model ---")
    for epoch in range(EPOCHS):
        model.train()
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()

        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in val_loader:
                cqcc_batch, lld_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())
        
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        try:
            eer = brentq(lambda x: 1. - x - interp1d(*roc_curve(all_labels, all_scores, pos_label=1)[:2])(x), 0., 1.) * 100
        except Exception: eer = -1.0

        print(f"Epoch {epoch+1} | Validation EER: {eer:.2f}%")
        
        if eer > 0 and eer < best_val_eer:
            best_val_eer = eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_lstm_cross_attention_model.pth"))
            print(f"  -> New best model saved with EER: {best_val_eer:.2f}%")

Using device: cuda

--- Starting Training: LSTM Cross-Attention Model ---


Epoch 1/10: 100%|██████████| 720/720 [00:43<00:00, 16.46it/s]


Epoch 1 | Validation EER: 6.36%
  -> New best model saved with EER: 6.36%


Epoch 2/10: 100%|██████████| 720/720 [00:42<00:00, 17.01it/s]


Epoch 2 | Validation EER: 6.58%


Epoch 3/10: 100%|██████████| 720/720 [00:42<00:00, 17.02it/s]


Epoch 3 | Validation EER: 5.86%
  -> New best model saved with EER: 5.86%


Epoch 4/10: 100%|██████████| 720/720 [00:43<00:00, 16.62it/s]


Epoch 4 | Validation EER: 5.18%
  -> New best model saved with EER: 5.18%


Epoch 5/10: 100%|██████████| 720/720 [00:41<00:00, 17.43it/s]


Epoch 5 | Validation EER: 8.36%


Epoch 6/10: 100%|██████████| 720/720 [00:41<00:00, 17.33it/s]


Epoch 6 | Validation EER: 9.69%


Epoch 7/10: 100%|██████████| 720/720 [00:42<00:00, 17.02it/s]


Epoch 7 | Validation EER: 4.63%
  -> New best model saved with EER: 4.63%


Epoch 8/10: 100%|██████████| 720/720 [00:43<00:00, 16.67it/s]


Epoch 8 | Validation EER: 5.42%


Epoch 9/10: 100%|██████████| 720/720 [00:42<00:00, 16.75it/s]


Epoch 9 | Validation EER: 4.63%
  -> New best model saved with EER: 4.63%


Epoch 10/10: 100%|██████████| 720/720 [00:43<00:00, 16.71it/s]


Epoch 10 | Validation EER: 4.71%
