In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt


# --- 1. CONFIGURATION ---


# --- Paths ---
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "cross_attention_model_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)


# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 64 # The dimension of the feature embeddings


# --- 2. DATASET CLASS ---


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset for the fusion model."""
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)


    def __len__(self):
        return len(self.labels)


    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]


# --- 3. CROSS-ATTENTION MODEL DEFINITION ---


class CrossAttentionFusionCNN(nn.Module):
    """Fusion model using Cross-Modal Attention."""
    def __init__(self, cqcc_shape, egmaps_shape, embedding_dim):
        super(CrossAttentionFusionCNN, self).__init__()
       
        # --- CQCC Branch (Processes spectral features) ---
        self.cqcc_branch = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d((2, 2)),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d((2, 2)),
            nn.Flatten(),
        )
        with torch.no_grad():
            self.cqcc_flat_size = self.cqcc_branch(torch.zeros(1, 1, *cqcc_shape)).numel()
        self.cqcc_fc = nn.Linear(self.cqcc_flat_size, embedding_dim)


        # --- eGeMAPS LLD Branch (Processes prosodic features over time) ---
        self.egmaps_branch = nn.Sequential(
            nn.Conv1d(in_channels=egmaps_shape[0], out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Flatten(),
        )
        with torch.no_grad():
            self.egmaps_flat_size = self.egmaps_branch(torch.zeros(1, *egmaps_shape)).numel()
        self.egmaps_fc = nn.Linear(self.egmaps_flat_size, embedding_dim)
       
        # --- Cross-Attention Mechanism ---
        self.cross_attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=4, batch_first=True)


        # --- Final Classifier ---
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 1)
        )


    def forward(self, cqcc_x, egmaps_x):
        # Get embeddings from both branches
        cqcc_embedding = self.cqcc_fc(self.cqcc_branch(cqcc_x.unsqueeze(1)))
       
        # --- FIX: The LLD data is already in the correct (batch, features, time) format. ---
        # --- The transpose operation was incorrect and has been removed. ---
        egmaps_embedding = self.egmaps_fc(self.egmaps_branch(egmaps_x))
       
        # Apply Cross-Attention
        cqcc_seq = cqcc_embedding.unsqueeze(1)
        egmaps_seq = egmaps_embedding.unsqueeze(1)
       
        attended_output, _ = self.cross_attention(query=egmaps_seq, key=cqcc_seq, value=cqcc_seq)
       
        # Classify the attended output
        output = self.classifier(attended_output.squeeze(1))
       
        return torch.sigmoid(output)


def plot_training_history(history, save_path):
    # This function is for plotting and saving the training graph
    fig, ax1 = plt.subplots(figsize=(10, 7))
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color='tab:red')
    ax1.plot(history['train_loss'], color='tab:red', linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color='tab:red', linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.legend(loc='upper left')
    ax2 = ax1.twinx()
    ax2.set_ylabel('EER (%)', color='tab:blue')
    ax2.plot(history['eer'], color='tab:blue', linestyle='-', label='Val EER (%)')
    ax2.tick_params(axis='y', labelcolor='tab:blue')
    ax2.legend(loc='upper right')
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    plt.close()


# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")


    # Load Data
    try:
        print("\n--- Loading Preprocessed Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))


        # --- FIX: Changed filenames to consistently use '_dev' for the validation set ---
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}\nPlease run the feature extraction script first.")
        exit()


    # Scale Features
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(X_lld_train.shape[0], -1))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(X_lld_train.shape[0], -1)).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(X_lld_val.shape[0], -1)).reshape(X_lld_val.shape)


    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(X_cqcc_train.shape[0], -1))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(X_cqcc_train.shape[0], -1)).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(X_cqcc_val.shape[0], -1)).reshape(X_cqcc_val.shape)


    # Create Datasets and DataLoaders
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


    # Initialize Model, Loss, and Optimizer
    model = CrossAttentionFusionCNN(
        cqcc_shape=CQCC_SHAPE,
        egmaps_shape=EGMAPS_LLD_SHAPE,
        embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
   
    # Training Loop
    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'eer': []}
    print(f"\n--- Starting Training for Cross-Attention Model ---")
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
    # Also move labels_batch to the correct device
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            train_loss += loss.item()


        # Validation Loop
        model.eval()
        val_loss, all_labels, all_scores = 0.0, [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in val_loader:
        # Also move labels_batch to the correct device
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
        # ...
                outputs = model(cqcc_batch, lld_batch)
                val_loss += criterion(outputs, labels_batch.unsqueeze(1)).item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())
       
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
       
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        try:
            eer = brentq(lambda x: 1. - x - interp1d(*roc_curve(all_labels, all_scores, pos_label=1)[:2])(x), 0., 1.) * 100
        except Exception:
            eer = -1.0 # Handle case where EER cannot be calculated
        history['eer'].append(eer)


        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val EER: {eer:.2f}%")
       
        if eer > 0 and eer < best_val_eer:
            best_val_eer = eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_cross_attention_model.pth"))
            print(f"  -> New best model saved with EER: {best_val_eer:.2f}%")


    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_metrics.png"))

Using device: cuda

--- Loading Preprocessed Data ---

--- Starting Training for Cross-Attention Model ---


Epoch 1/10: 100%|██████████| 360/360 [00:18<00:00, 19.86it/s]


Epoch 1 | Train Loss: 0.2166 | Val Loss: 0.1552 | Val EER: 7.16%
  -> New best model saved with EER: 7.16%


Epoch 2/10: 100%|██████████| 360/360 [00:18<00:00, 19.78it/s]


Epoch 2 | Train Loss: 0.0676 | Val Loss: 0.1292 | Val EER: 5.44%
  -> New best model saved with EER: 5.44%


Epoch 3/10: 100%|██████████| 360/360 [00:18<00:00, 19.71it/s]


Epoch 3 | Train Loss: 0.0337 | Val Loss: 0.1049 | Val EER: 6.36%


Epoch 4/10: 100%|██████████| 360/360 [00:18<00:00, 19.62it/s]


Epoch 4 | Train Loss: 0.0194 | Val Loss: 0.0818 | Val EER: 5.85%


Epoch 5/10: 100%|██████████| 360/360 [00:18<00:00, 19.79it/s]


Epoch 5 | Train Loss: 0.0126 | Val Loss: 0.0800 | Val EER: 6.01%


Epoch 6/10: 100%|██████████| 360/360 [00:18<00:00, 19.20it/s]


Epoch 6 | Train Loss: 0.0087 | Val Loss: 0.0998 | Val EER: 5.73%


Epoch 7/10: 100%|██████████| 360/360 [00:18<00:00, 19.27it/s]


Epoch 7 | Train Loss: 0.0063 | Val Loss: 0.1250 | Val EER: 6.87%


Epoch 8/10: 100%|██████████| 360/360 [00:18<00:00, 19.03it/s]


Epoch 8 | Train Loss: 0.0048 | Val Loss: 0.1134 | Val EER: 7.30%


Epoch 9/10: 100%|██████████| 360/360 [00:19<00:00, 18.86it/s]


Epoch 9 | Train Loss: 0.0036 | Val Loss: 0.1278 | Val EER: 6.48%


Epoch 10/10: 100%|██████████| 360/360 [00:18<00:00, 19.09it/s]


Epoch 10 | Train Loss: 0.0029 | Val Loss: 0.0989 | Val EER: 6.61%

--- Training Complete ---
