In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
from collections import defaultdict

# Reproducibility: fix both the numpy and the torch RNG seeds.
np.random.seed(42)
torch.manual_seed(42)

# Input CSVs holding the pre-extracted prosodic features.
# Adjust these to wherever the files live on your system.
TRAIN_DATA_PATH = "/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/prosodic_features_train.csv"
DEV_DATA_PATH = "/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/prosodic_features_dev.csv"

# Static (one value per sample) feature columns read from the CSVs.
FEATURE_COLS = [
    'mean_f0',
    'std_f0',
    'jitter',
    'shimmer',
    'mean_hnr',
    'std_hnr',
]

# Hyperparameters of the single-hidden-layer model.
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50                # upper bound; early stopping may end training sooner
HIDDEN_SIZE = 64               # width of the one hidden layer
EARLY_STOPPING_PATIENCE = 5    # epochs without validation-loss improvement before stopping

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class StaticFeatureDataset(Dataset):
    """Map-style dataset over a CSV of static audio features (one row per sample).

    The CSV must provide a ``filename`` column, a ``label`` column and every
    column named in ``feature_cols``. Each item is a
    ``(features_tensor, label_tensor)`` pair: ``features_tensor`` is float32
    (shape ``(len(feature_cols),)`` unless ``transform`` changes the width)
    and ``label_tensor`` is a float32 tensor of shape ``(1,)``.
    """

    def __init__(self, csv_file, feature_cols=FEATURE_COLS, transform=None):
        """
        Args:
            csv_file (str): Path to the CSV file with static audio features.
            feature_cols (list): List of feature column names to use.
            transform (callable, optional): Applied to each sample in
                ``__getitem__``. It receives a float32 array of shape
                ``(1, n_features)`` (the 2D layout ``StandardScaler.transform``
                expects) and its result is flattened.
        """
        self.data = pd.read_csv(csv_file)
        self.feature_cols = feature_cols
        self.transform = transform

        # For static features, each row is already a sample
        self.file_ids = self.data['filename'].tolist()

        raw_labels = self.data['label']
        if pd.api.types.is_numeric_dtype(raw_labels):
            # Labels are assumed to already be numeric (0 or 1).
            label_values = raw_labels.tolist()
        else:
            # String labels: 'bonafide' -> 1, anything else (e.g. 'spoof') -> 0.
            # Testing for "not numeric" rather than `dtype == 'object'` also covers
            # pandas' dedicated string and categorical dtypes.
            label_values = (raw_labels == 'bonafide').astype(int).tolist()

        # Labels indexed by row position, so rows that share a filename keep their own label.
        self._row_labels = label_values
        # filename -> label lookup kept for existing callers; for a repeated filename the last row wins.
        self.labels = dict(zip(self.file_ids, label_values))

        # Extract and cast the feature matrix once instead of slicing a pandas row per item.
        # copy=True guarantees the in-place imputation below never writes back into self.data.
        features = self.data[self.feature_cols].to_numpy(dtype=np.float32, copy=True)
        # Simple imputation; consider mean/median imputation if NaNs are common
        features[np.isnan(features)] = 0.0
        self._features = features

    def __len__(self):
        return len(self.data)  # Number of rows is the number of samples

    def __getitem__(self, idx):
        # Copy the cached row so a transform can never modify the cache in place.
        features = self._features[idx].copy()

        # Apply transformation if provided
        if self.transform:
            # StandardScaler expects a 2D array (n_samples, n_features)
            features = self.transform(features.reshape(1, -1)).flatten()

        features_tensor = torch.FloatTensor(features)
        label_tensor = torch.FloatTensor([self._row_labels[idx]])  # Keep as float tensor for BCELoss

        return features_tensor, label_tensor

class SimpleFFNN(nn.Module):
    """Feed-forward binary classifier with a single hidden layer.

    Architecture: Linear(input_size -> hidden_size) -> ReLU
    -> Linear(hidden_size -> output_size) -> Sigmoid, so every output lies in
    (0, 1) and can be fed directly to ``nn.BCELoss``.
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        # Submodule names (fc1/relu/fc2/sigmoid) define the state_dict keys of
        # saved checkpoints, so they must stay exactly as they are.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of feature vectors to class-1 probabilities."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))

def calculate_eer(labels, scores):
    """
    Calculate the Equal Error Rate (EER) and the threshold at which it occurs.

    Args:
        labels: Binary ground-truth labels (0/1); label 1 is the positive class.
            Note that StaticFeatureDataset maps the string label 'bonafide' to 1,
            so for string-labelled CSVs class 1 is bona fide. The EER itself does
            not depend on which class is called positive, as long as ``scores``
            are scores for class 1.
        scores: Model scores for class 1 (e.g. sigmoid probabilities).

        Both arguments may be any array-like that flattens to one value per
        sample, e.g. lists of shape-(1,) arrays collected batch by batch.

    Returns:
        (eer, eer_threshold): the EER, approximated as the mean of FAR (false
        positive rate) and FRR (1 - true positive rate) at the first ROC
        operating point where FAR >= FRR, and that point's score threshold.
        Returns (1.0, 0.0) if no such point exists, e.g. when only one class
        is present and the ROC curve is undefined (NaN rates).
    """
    labels = np.ravel(np.asarray(labels))
    scores = np.ravel(np.asarray(scores))
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)

    # FAR == fpr and FRR == 1 - tpr; take the first point where FAR has caught up with FRR.
    frr = 1 - tpr
    crossings = np.flatnonzero(fpr >= frr)
    if crossings.size == 0:
        return 1.0, 0.0

    first = crossings[0]
    eer = (fpr[first] + frr[first]) / 2
    return eer, thresholds[first]

def _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one optimisation pass over ``train_loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0

    for batch_idx, (features, labels) in enumerate(train_loader):
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_idx + 1) % 100 == 0:  # Print less frequently for cleaner output
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}')

    return running_loss / len(train_loader)


def _validate_one_epoch(model, val_loader, criterion, device):
    """Score ``val_loader`` without gradients; return (mean loss, accuracy, EER, ROC AUC)."""
    model.eval()
    running_loss = 0.0
    label_batches = []
    score_batches = []

    with torch.no_grad():
        for features, labels in val_loader:
            features = features.to(device)
            labels = labels.to(device)

            outputs = model(features)
            running_loss += criterion(outputs, labels).item()

            label_batches.append(labels.cpu().numpy())
            score_batches.append(outputs.cpu().numpy())

    # Flatten the per-batch (B, 1) arrays into 1-D vectors for the sklearn metrics.
    all_labels = np.concatenate(label_batches).ravel()
    all_scores = np.concatenate(score_batches).ravel()  # raw probabilities for EER / ROC AUC
    all_preds = (all_scores > 0.5).astype(np.float32)   # hard decisions at 0.5

    avg_loss = running_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    eer, _ = calculate_eer(all_labels, all_scores)
    roc_auc = roc_auc_score(all_labels, all_scores)
    return avg_loss, accuracy, eer, roc_auc


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, patience,
                checkpoint_path='best_simple_ffnn_static_spoofing_detector.pth'):
    """
    Train the model with early stopping and record metrics per epoch.

    After every epoch the model is evaluated on ``val_loader``. Whenever the
    mean validation loss improves, the model's state_dict is written to
    ``checkpoint_path``. Training stops early once ``patience`` consecutive
    epochs pass without improvement. On return, ``model`` holds the weights of
    the LAST epoch run, not the best one; reload ``checkpoint_path`` for those.

    Args:
        model: Module producing probabilities compatible with ``criterion``.
        train_loader, val_loader: Loaders yielding (features, labels) batches.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs (int): Maximum number of epochs.
        device: Device the batches are moved to.
        patience (int): Epochs without validation-loss improvement before stopping.
        checkpoint_path (str): Where the best state_dict is saved.

    Returns:
        (train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs),
        each a list with one entry per epoch actually run.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    train_losses = []
    val_losses = []
    val_accuracies = []
    val_eers = []
    val_roc_aucs = []

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_epoch = 0

    print("\nStarting training with early stopping...")
    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, num_epochs)
        train_losses.append(avg_train_loss)

        avg_val_loss, val_accuracy, val_eer, val_roc_auc = _validate_one_epoch(model, val_loader, criterion, device)
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)
        val_eers.append(val_eer)
        val_roc_aucs.append(val_roc_auc)

        print(f'Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val EER: {val_eer:.4f} | Val ROC AUC: {val_roc_auc:.4f}')

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_epoch = epoch + 1  # Store the epoch number
            torch.save(model.state_dict(), checkpoint_path)
            print(f"--- Improved validation loss. Saving model from epoch {best_epoch}. ---")
        else:
            epochs_no_improve += 1
            print(f"--- No improvement for {epochs_no_improve}/{patience} epochs. ---")
            # >= rather than == so a patience <= 0 still stops instead of never triggering.
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}!")
                break

    print(f"\nTraining finished. Best model saved from epoch {best_epoch} with Validation Loss: {best_val_loss:.4f}")
    return train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs

def evaluate_model(model, test_loader, device, model_path=None):
    """
    Evaluate the model on a labelled loader, optionally loading saved weights first.

    Args:
        model: Binary classifier whose outputs are class-1 probabilities;
            predictions are obtained by thresholding them at 0.5.
        test_loader: Loader yielding (features, labels) batches.
        device: Device the model lives on; features are moved there.
        model_path (str, optional): state_dict checkpoint to load before evaluating.

    Returns:
        (accuracy, precision, recall, f1, eer, eer_threshold, roc_auc, conf_matrix)

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if model_path:
        print(f"Loading best model from {model_path} for final evaluation...")
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(model_path, map_location=device))

    model.eval()
    label_batches = []
    score_batches = []

    with torch.no_grad():
        for features, labels in test_loader:
            outputs = model(features.to(device))
            label_batches.append(labels.cpu().numpy())
            score_batches.append(outputs.cpu().numpy())

    if not score_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Flatten the per-batch (B, 1) arrays into 1-D vectors for the sklearn metrics.
    all_labels = np.concatenate(label_batches).ravel()
    all_scores = np.concatenate(score_batches).ravel()  # raw probabilities for EER / ROC AUC
    all_preds = (all_scores > 0.5).astype(np.float32)   # hard decisions at 0.5

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 yields the same 0.0 sklearn falls back to, minus the UndefinedMetricWarning.
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    conf_matrix = confusion_matrix(all_labels, all_preds)
    roc_auc = roc_auc_score(all_labels, all_scores)
    eer, eer_threshold = calculate_eer(all_labels, all_scores)

    print("\n===== Model Evaluation Results =====")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Equal Error Rate (EER): {eer:.4f}")
    print(f"EER Threshold: {eer_threshold:.4f}")
    print(f"ROC AUC: {roc_auc:.4f}")
    print("Confusion Matrix:")
    print(conf_matrix)

    return accuracy, precision, recall, f1, eer, eer_threshold, roc_auc, conf_matrix

def plot_metrics(train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs, num_epochs_ran, filename='simple_ffnn_static_features_metrics.png'):
    """
    Plot per-epoch training and validation metrics in a 2x2 grid, save the figure and show it.

    Args:
        train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs (list[float]):
            One value per epoch that was actually run.
        num_epochs_ran (int): Number of epochs run (may be fewer than planned due to
            early stopping); must equal the length of each metric list.
        filename (str): Path the figure is saved to.
    """
    epochs = range(1, num_epochs_ran + 1)  # Adjust epochs range for early stopping

    def _finish_panel(ylabel, title):
        # Shared axis decoration. It must run after every artist of the panel is
        # drawn, otherwise later artists are missing from the legend.
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    fig = plt.figure(figsize=(15, 10))

    # Plot Loss
    plt.subplot(2, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    _finish_panel('Loss', 'Training and Validation Loss')

    # Plot Accuracy
    plt.subplot(2, 2, 2)
    plt.plot(epochs, val_accuracies, label='Validation Accuracy', color='green')
    _finish_panel('Accuracy', 'Validation Accuracy')

    # Plot EER; the min-EER reference line is drawn before the legend so its label shows up.
    plt.subplot(2, 2, 3)
    plt.plot(epochs, val_eers, label='Validation EER', color='red')
    if val_eers:
        best_eer = min(val_eers)
        plt.axhline(y=best_eer, color='gray', linestyle='--', label=f'Min EER: {best_eer:.4f}')
    _finish_panel('EER', 'Validation Equal Error Rate (EER)')

    # Plot ROC AUC
    plt.subplot(2, 2, 4)
    plt.plot(epochs, val_roc_aucs, label='Validation ROC AUC', color='purple')
    _finish_panel('ROC AUC', 'Validation ROC AUC')

    plt.tight_layout()
    plt.savefig(filename)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def main():
    """Train the static-feature FFNN spoofing detector and evaluate its best checkpoint on the dev set."""
    print(f"Using device: {DEVICE}")

    print("Loading and preprocessing static features...")
    # Read the training CSV once; the scaler is fitted on training features only.
    train_dataset = StaticFeatureDataset(TRAIN_DATA_PATH)
    scaler = StandardScaler()
    scaler.fit(train_dataset.data[FEATURE_COLS].values)
    # NOTE(review): StandardScaler disregards NaNs when fitting, whereas the dataset replaces
    # NaNs with 0.0 (in raw units) before scaling -- consider imputing consistently.

    # The dataset applies its transform per sample in __getitem__, so attaching it after fitting is safe.
    train_dataset.transform = scaler.transform
    val_dataset = StaticFeatureDataset(DEV_DATA_PATH, transform=scaler.transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    input_size = len(FEATURE_COLS)
    print(f"Input size for FFNN: {input_size} features")

    model = SimpleFFNN(input_size=input_size, hidden_size=HIDDEN_SIZE).to(DEVICE)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train the model with early stopping
    train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs = train_model(
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        NUM_EPOCHS,
        DEVICE,
        EARLY_STOPPING_PATIENCE,
    )

    # Early stopping may end training before NUM_EPOCHS, so count the epochs actually run.
    num_epochs_ran = len(val_losses)
    plot_metrics(train_losses, val_losses, val_accuracies, val_eers, val_roc_aucs, num_epochs_ran)

    # The validation set doubles as the test set here. train_model writes its best
    # checkpoint to this file name, and that checkpoint is what gets evaluated.
    best_model_path = 'best_simple_ffnn_static_spoofing_detector.pth'
    print("\nFinal Evaluation on the Validation Set (using best saved model):")
    evaluate_model(model, val_loader, DEVICE, model_path=best_model_path)


if __name__ == "__main__":
    main()

Using device: cuda
Loading and preprocessing static features...
Input size for FFNN: 6 features

Starting training with early stopping...
Epoch [1/50], Batch [100/397], Train Loss: 0.3824
Epoch [1/50], Batch [200/397], Train Loss: 0.3709
Epoch [1/50], Batch [300/397], Train Loss: 0.3164
Epoch [1/50] | Train Loss: 0.3286 | Val Loss: 0.2937 | Val Acc: 0.8978 | Val EER: 0.3043 | Val ROC AUC: 0.7561
--- Improved validation loss. Saving model from epoch 1. ---
Epoch [2/50], Batch [100/397], Train Loss: 0.1794
Epoch [2/50], Batch [200/397], Train Loss: 0.2604
Epoch [2/50], Batch [300/397], Train Loss: 0.2522
Epoch [2/50] | Train Loss: 0.2568 | Val Loss: 0.2893 | Val Acc: 0.8969 | Val EER: 0.2926 | Val ROC AUC: 0.7730
--- Improved validation loss. Saving model from epoch 2. ---
Epoch [3/50], Batch [100/397], Train Loss: 0.3752
Epoch [3/50], Batch [200/397], Train Loss: 0.2470
Epoch [3/50], Batch [300/397], Train Loss: 0.2336
Epoch [3/50] | Train Loss: 0.2531 | Val Loss: 0.2883 | Val Acc: 0.89