In [1]:
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import torchaudio
#import torchaudio.functional as F
import torchaudio.transforms as T
import librosa
import torch.nn.functional as F 
import matplotlib.pyplot as plt
from scipy.fftpack import dct
from tqdm import tqdm
from sklearn.metrics import roc_curve
from tqdm import tqdm
# print(torch.__version__)
# print(torchaudio.__version__)
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import seaborn as sns
import pandas as pd
import os
import librosa # For audio loading


In [2]:
def pad_or_truncate_spectrogram(spectrogram, max_time_steps, padding_value=-80.0):
    """
    Pads or truncates a 2D spectrogram to a specified number of time steps (width).

    Padding is applied to the end of the time axis (axis 1).
    Truncation is performed by taking the initial part of the time axis.

    Args:
        spectrogram (np.ndarray): The input 2D spectrogram (num_features, num_time_frames).
        max_time_steps (int): The target number of time frames (width) for the spectrogram.
        padding_value (float, optional): The value to use for padding.
                                         Defaults to -80.0 (suitable for dB-scaled log-Mel spectrograms).

    Returns:
        np.ndarray: The padded or truncated spectrogram of shape (num_features, max_time_steps).
    """
    if not isinstance(spectrogram, np.ndarray) or spectrogram.ndim != 2:
        raise ValueError("Input spectrogram must be a 2D NumPy array.")
    if not isinstance(max_time_steps, int) or max_time_steps <= 0:
        raise ValueError("max_time_steps must be a positive integer.")
    
    #print(spectrogram.shape)

    num_features, current_time_steps = spectrogram.shape
    

    max_time_steps = 157


    if current_time_steps == max_time_steps:
        # Already the correct length
        return spectrogram
    elif current_time_steps < max_time_steps:
        # Pad if shorter
        pad_width = max_time_steps - current_time_steps
        # np.pad takes a list of tuples for padding widths: ((before_axis0, after_axis0), (before_axis1, after_axis1), ...)
        # We only want to pad axis 1 (time steps) at the end.
        padded_spectrogram = np.pad(
            spectrogram,
            pad_width=((0, 0), (0, pad_width)), # No padding on feature axis, pad_width at the end of time axis
            mode='constant',
            constant_values=padding_value
        )
        return padded_spectrogram
    else: # current_time_steps > max_time_steps
        # Truncate if longer (take from the beginning of the time axis)
        truncated_spectrogram = spectrogram[:, :max_time_steps]
        return truncated_spectrogram

In [3]:

def process_flac_files_in_folder(feature_vector,folder_path, target_sample_rate=16000, duration=10):
    """
    Reads all .flac files from a given folder, loads them using librosa,
    and prints basic information.

    Args:
        folder_path (str): The path to the folder containing .flac files.
        target_sample_rate (int, optional): The target sampling rate to resample audio to.
                                            If None, loads at original sampling rate.
                                            Defaults to None.
        duration (float, optional): Maximum duration (in seconds) to load for each audio file.
                                    If None, loads the entire file. Defaults to None.
    """
    print(f"Processing FLAC files in folder: {folder_path}\n")
    found_flac_files = False
    count = 0
    # Check if the folder exists
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at '{folder_path}'")
        return
    freq = {}
    # Iterate over all entries in the folder
    for item_name in os.listdir(folder_path):
        # Construct the full path to the item
        item_path = os.path.join(folder_path, item_name)

        # Check if it's a file and ends with .flac
        if os.path.isfile(item_path) and item_name.lower().endswith(".flac"):
            found_flac_files = True
            count = count + 1
            print(f"--- Found FLAC file: {item_name} ---")
            try:
                # Load the audio file
                # audio_data is a NumPy array containing the waveform
                # sr is the sampling rate of the loaded audio
                audio_data, sr = librosa.load(item_path, sr=target_sample_rate, duration=duration)

                print(f"  Successfully loaded.")
                print(f"  Shape of audio data: {audio_data.shape}")
                print(f"  Sampling rate: {sr} Hz")
                print(f"  Duration: {librosa.get_duration(y=audio_data, sr=sr):.2f} seconds")
                # You can add more processing here, e.g., feature extraction
                mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sr)
                print(f"  Mel spectrogram shape: {mel_spectrogram.shape}")
                padded_mel_spectrogram = pad_or_truncate_spectrogram(mel_spectrogram,157,-80)
                feature_vector.append(padded_mel_spectrogram)
                print(f"Added to the list - Shape of MFCC features: {padded_mel_spectrogram.shape}")
                if padded_mel_spectrogram.shape[1] not in freq:
                    freq[padded_mel_spectrogram.shape[1]] = 1
                else:
                    freq[padded_mel_spectrogram.shape[1]] += 1

            except Exception as e:
                print(f"  Error loading or processing file {item_name}: {e}")
            print("-" * 30) # Separator
    # print(count)
    # print(freq)
    if not found_flac_files:
        print(f"No .flac files found in '{folder_path}'.")
    return feature_vector

In [4]:
def feature_extraction_cqcc(y, sr, hop_length=512, fmin=None, n_bins=90, bins_per_octave=12, n_cqcc = 128):
    
    if fmin is None:
        # Set a default fmin similar to librosa's default if not provided
        # This corresponds to C1
        fmin = librosa.note_to_hz('C1')
    try:
        # 1. Calculate Constant Q Transform (CQT)
        # We use the magnitude CQT directly
        cqt_result = librosa.cqt(y=y, sr=sr,
                                hop_length=hop_length,
                                fmin=fmin,
                                n_bins=n_bins,
                                bins_per_octave=bins_per_octave)

        # Get the magnitude (absolute value) - CQT returns complex values
        cqt_mag = np.abs(cqt_result)

        # Handle potential zero values before log
        cqt_mag[cqt_mag == 0] = np.finfo(float).eps

        # 2. Calculate Log Power/Magnitude
        log_cqt_mag = np.log(cqt_mag)

        # 3. Apply Discrete Cosine Transform (DCT) - Type II is common
        # Apply DCT along the frequency axis (axis=0)
        cqcc = dct(log_cqt_mag, type=2, axis=0, norm='ortho')

        # 4. Keep only the first n_cqcc coefficients
        cqcc_truncated = cqcc[:n_cqcc, :]

        return cqcc_truncated
    except Exception as e:
        print(f"Error calculating CQT or CQCC: {e}")
        # Handle cases where CQT might fail (e.g., very short audio)
        return None

In [5]:
def create_label_array(audio_folder_path , label_file_path):
    count = 0
    audio_folder_path = audio_folder_path
    label_file_path = label_file_path
    audio_name_list = []
    label_dict = {}

    for audio_name in os.listdir(audio_folder_path):
        # Construct the full path to the item
        item_path = os.path.join(audio_folder_path, audio_name)

        # Check if it's a file and ends with .flac
        if os.path.isfile(item_path) and audio_name.lower().endswith(".flac"):
            found_flac_files = True
            count = count + 1
            #print(f"--- Found FLAC file: {audio_name.strip('.flac')} ---")
            audio_name_list.append(audio_name.strip('.flac'))

    with open(label_file_path, 'r') as f:
        lines = f.readlines()
        #print(lines)
    for line_num, line in enumerate(lines):
        #print(line_num)
        parts = line.strip().split()
        # Check if line has at least 4 parts (SpeakerID, Filename, Key) - adjust if format differs
        if len(parts) >= 4:
            file_name = parts[1]  # Assuming filename is the second element
            key = parts[-1]     # Assuming label key is the last element
            #print(key)

            if key == "bonafide" and file_name in audio_name_list:
                label = 1
                label_dict[file_name] = label
                
            elif key == "spoof" and file_name in audio_name_list:
                label = 0
                label_dict[file_name] = label
        else:
            print(f"Warning: Skipping malformed line {line_num + 1}: {line.strip()}")
    return label_dict,audio_name_list

In [6]:
# import soundfile as sf
# from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, OneOf, SomeOf, Gain, TimeMask

# # --- 1. Setup paths and augmentation pipeline ---
# original_bonafide_path = "bonafide_audio_train"
# spoof_path = "spoof_audio_train"
# augmented_bonafide_path = "augmented_bonafide"

# # Create a directory to save augmented files
# os.makedirs(augmented_bonafide_path, exist_ok=True)

# # Define the augmentation pipeline
# augment = Compose([
#     # Apply one of the following "major" transformations to 80% of the samples
#     OneOf([
#         TimeStretch(min_rate=0.85, max_rate=1.15),
#         PitchShift(min_semitones=-1.5, max_semitones=1.5),
#     ], p=0.8),

#     # Apply a small amount of noise to 40% of the samples
#     AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.008, p=0.4),

#     # Apply a small gain adjustment to 40% of the samples
#     Gain(min_gain_in_db=-4, max_gain_in_db=4, p=0.4),

#     # Apply a time mask to 30% of the samples
#     TimeMask(min_band_part=0.02, max_band_part=0.08, p=0.3),
# ])

# # --- 2. Calculate how many new files to generate ---
# bonafide_files = os.listdir(original_bonafide_path)
# spoof_files_count = 22799 # Your majority class count
# bonafide_files_count = len(bonafide_files)
# num_augmentations_needed = spoof_files_count - bonafide_files_count
# augmentations_per_file = round(num_augmentations_needed / bonafide_files_count)

# print(f"Generating {augmentations_per_file} augmented samples per Bonafide file.")

# # --- 3. Apply augmentation ---
# for filename in bonafide_files:
#     # Load the audio file
#     audio, sr = librosa.load(os.path.join(original_bonafide_path, filename), sr=16000)
    
#     # Create multiple augmented versions
#     for i in range(augmentations_per_file):
#         # Apply the augmentation pipeline
#         augmented_audio = augment(samples=audio, sample_rate=sr)
        
#         # Define a new filename and save the file
#         new_filename = f"aug_{i}_{filename}"
#         sf.write(os.path.join(augmented_bonafide_path, new_filename), augmented_audio, sr)

# print("Data augmentation complete.")
# # Now your training data can be loaded from the spoof, original bonafide, and augmented bonafide folders.


In [7]:
# sorted_labels = sorted(label_list.items())
# labels = []


In [8]:
# sorted_labels = sorted(label_list.items())
# labels = []
# for i in range(0,len(sorted_labels)):
#     labels.append(sorted_labels[i][1])
# labels = np.array(labels)
# feature_vector = np.array(feature_vector)

In [9]:
# labels = np.array(labels)

In [10]:
# labels

In [11]:
# feature_vector = np.array(feature_vector)

In [12]:
# feature_vector.shape

In [13]:
# list(label_list_val.keys())

In [14]:
# sorted_labels_val = sorted(label_list_val.items())
# labels_val= []
# for i in range(0,len(sorted_labels_val)):
#     labels_val.append(sorted_labels_val[i][1])
# labels_val = np.array(labels_val)
# #labels_val
# feature_vector_val = np.array(feature_vector_val)
# print(feature_vector_val.shape, labels_val.shape)

In [15]:
# ACHTUNG FOR Pre-Processing for EVALUATION ONLY ---
# your_folder_with_flac_files_val = "ASVSpoof19/LA/ASVspoof2019_LA_eval/flac"
# audio_folder_path_val = "ASVSpoof19/LA/ASVspoof2019_LA_eval/flac"
# label_file_val = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt" 
# folder_path_val = your_folder_with_flac_files_val
# val_file_sorted_path= "val_file_list.txt"
# print(f"Processing FLAC files in folder: {folder_path_val}\n")
# label_list_val, audio_name_list_val = create_label_array(audio_folder_path=audio_folder_path_val, label_file_path=label_file_val)
# sorted_labels_val = sorted(label_list_val.items())
# labels_val= []
# for i in range(0,len(sorted_labels_val)):
#     labels_val.append(sorted_labels_val[i][1])
# labels_val = np.array(labels_val)
# # print(len(label_list_val), len(audio_name_list_val),len(feature_vector_val))
# feature_vector_val = []
# found_flac_files_val = False
# count_val = 0
# target_sample_rate_val=16000
# duration_val=5.0
# # Check if the folder exists
# if not os.path.isdir(folder_path_val):
#     print(f"Error: Folder not found at '{folder_path_val}'")
# freq_val = {}
# with open(val_file_sorted_path, "r") as f:
#     lines = f.readlines()
# # Iterate over all entries in the folder
# print("Hey User!!, You are now creating the Evaluation set and their corresponding labels")
# for item_name in lines:
#     # Construct the full path to the item
#     #print(item_name)
#     item_path = str(os.path.join(folder_path_val, item_name)).strip('\n')
#     #print(item_path)
#     # Check if it's a file and ends with .flac
#     # if os.path.isfile(item_path) and item_name.lower().endswith(".flac"):
#     #     found_flac_files = True
#     # print(item_path.split('/'))
#     # print(item_name.lower())
#     #print(item_name.strip('.flac\n') in str(list(label_list_val.keys())))
#     if item_name.strip('.flac\n') in str(list(label_list_val.keys())):
#         # audio_data_val, sr_val = librosa.load(item_path, sr=target_sample_rate_val, duration=duration_val)
#         #print("My Name is Mitukk")
#         # print(f"--- Found FLAC file: {item_name} ---")
#         audio_data_val, sr_val = librosa.load(item_path, sr=target_sample_rate_val, duration=duration_val)
#         cqcc_features = feature_extraction_cqcc(audio_data_val, sr_val, n_cqcc=128)
#         if cqcc_features is not None:
#             count_val = count_val + 1
#             padded_cqcc = pad_or_truncate_spectrogram(cqcc_features,157,-80)
#             feature_vector_val.append(padded_cqcc)
#             print(f"Added to the list - Shape of CQCC features: {padded_cqcc.shape}")


# # label_list_val, audio_name_list_val = create_label_array(audio_folder_path=audio_folder_path_val, label_file_path=label_file_val)

# print(len(label_list_val), len(audio_name_list_val),len(feature_vector_val))

In [16]:
label_array

NameError: name 'label_array' is not defined

In [None]:
import os
import librosa
import numpy as np
from tqdm import tqdm
import soundfile as sf

# This script now requires the 'praat-parselmouth' library.
# Install it using: pip install praat-parselmouth
import parselmouth
from parselmouth.praat import call

# --- Feature Extraction Functions ---

def feature_extraction_cqcc(audio_data, sr, n_bins=90):
    """
    Placeholder for the CQCC feature extraction function.
    In a real scenario, this would compute actual CQCC features.
    """
    # Using librosa's CQT and cepstral mean subtraction as a stand-in for CQCC
    cqt = librosa.cqt(y=audio_data, sr=sr, n_bins=n_bins, bins_per_octave=12)
    cqt_mag = np.abs(cqt)

        # Handle potential zero values before log
    cqt_mag[cqt_mag == 0] = np.finfo(float).eps

        # 2. Calculate Log Power/Magnitude
    log_cqt_mag = np.log(cqt_mag)

        # 3. Apply Discrete Cosine Transform (DCT) - Type II is common
        # Apply DCT along the frequency axis (axis=0)
    cqcc = dct(log_cqt_mag, type=2, axis=0, norm='ortho')

        # 4. Keep only the first n_cqcc coefficients
    cqcc_truncated = cqcc[:128, :]

    return cqcc_truncated

def extract_prosodic_features(audio_path):
    """
    Extracts fundamental prosodic features using Parselmouth (Praat).
    Features include F0, jitter, shimmer, and HNR.
    
    Args:
        audio_path (str): Path to the audio file.
    
    Returns:
        np.array: A vector of the specified prosodic features.
    """
    try:
        # Load sound with parselmouth for Praat features
        sound = parselmouth.Sound(audio_path)
        
        # Create a pitch object from the sound
        pitch = call(sound, "To Pitch", 0.0, 75, 600) # time_step, min_pitch, max_pitch
        
        # Create a point process from the pitch object for jitter/shimmer calculation
        point_process = call(pitch, "To PointProcess")
        
        # Create a harmonicity object for HNR calculation
        harmonicity = call(sound, "To Harmonicity (cc)", 0.01, 75, 0.1, 1.0)
        
        # Extract the requested features
        mean_f0 = call(pitch, "Get mean", 0, 0, "Hertz")
        std_f0 = call(pitch, "Get standard deviation", 0, 0, "Hertz")
        local_jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
        local_shimmer = call([sound, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)
        mean_hnr = call(harmonicity, "Get mean", 0, 0)
        std_hnr = call(harmonicity, "Get standard deviation", 0, 0)

    except Exception:
        # Praat can fail on silent or very noisy audio. Provide default values for all features.
        mean_f0, std_f0, local_jitter, local_shimmer, mean_hnr, std_hnr = 0, 0, 0, 0, 0, 0

    # Combine all features into a single vector
    prosodic_vector = np.array([
        mean_f0, std_f0, 
        local_jitter, local_shimmer, 
        mean_hnr, std_hnr
    ], dtype=np.float32)
    
    return prosodic_vector


def pad_or_truncate_spectrogram(spectrogram, max_time_steps, padding_value=-80.0):
    """
    Pads or truncates a 2D spectrogram to a specified number of time steps (width).

    Padding is applied to the end of the time axis (axis 1).
    Truncation is performed by taking the initial part of the time axis.

    Args:
        spectrogram (np.ndarray): The input 2D spectrogram (num_features, num_time_frames).
        max_time_steps (int): The target number of time frames (width) for the spectrogram.
        padding_value (float, optional): The value to use for padding.
                                         Defaults to -80.0 (suitable for dB-scaled log-Mel spectrograms).

    Returns:
        np.ndarray: The padded or truncated spectrogram of shape (num_features, max_time_steps).
    """
    if not isinstance(spectrogram, np.ndarray) or spectrogram.ndim != 2:
        raise ValueError("Input spectrogram must be a 2D NumPy array.")
    if not isinstance(max_time_steps, int) or max_time_steps <= 0:
        raise ValueError("max_time_steps must be a positive integer.")
    
    #print(spectrogram.shape)

    num_features, current_time_steps = spectrogram.shape
    

    max_time_steps = 157


    if current_time_steps == max_time_steps:
        # Already the correct length
        return spectrogram
    elif current_time_steps < max_time_steps:
        # Pad if shorter
        pad_width = max_time_steps - current_time_steps
        # np.pad takes a list of tuples for padding widths: ((before_axis0, after_axis0), (before_axis1, after_axis1), ...)
        # We only want to pad axis 1 (time steps) at the end.
        padded_spectrogram = np.pad(
            spectrogram,
            pad_width=((0, 0), (0, pad_width)), # No padding on feature axis, pad_width at the end of time axis
            mode='constant',
            constant_values=padding_value
        )
        return padded_spectrogram
    else: # current_time_steps > max_time_steps
        # Truncate if longer (take from the beginning of the time axis)
        truncated_spectrogram = spectrogram[:, :max_time_steps]
        return truncated_spectrogram

def create_label_dictionary(label_file_path):
    """
    Reads the ASVspoof protocol file and creates a dictionary mapping
    a file's base name to its label (1 for bonafide, 0 for spoof).
    """
    label_dict = {}
    with open(label_file_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 4:
                file_basename = parts[1]
                label = 0 if parts[-1] == 'bonafide' else 1
                label_dict[file_basename] = label
    return label_dict

# --- Main Script ---

# Configuration
cqcc_output_file = "cqcc_features_aligned.npy"
prosody_output_file = "prosody_features_aligned.npy"
labels_output_file = "labels_aligned.npy"

audio_folder_path = "ASVSpoof19/LA/ASVspoof2019_LA_train/flac"
label_file = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
trn_file_sorted_path = "trn_file_list.txt"

target_sample_rate = 16000
duration = 5.0

# # --- Setup Dummy Files and Folders for Demonstration ---
# print("--- Setting up dummy files and folders for demonstration ---")
# os.makedirs(audio_folder_path, exist_ok=True)
# os.makedirs(os.path.dirname(label_file), exist_ok=True)

# dummy_labels_content = [
#     "LA_0079 LA_T_1138215 - A01 bonafide",
#     "LA_0079 LA_T_1234567 - A02 spoof",
#     "LA_0079 LA_T_7654321 - A03 bonafide",
# ]
# with open(label_file, "w") as f:
#     f.write("\n".join(dummy_labels_content))

# dummy_flac_files = ["LA_T_1138215.flac", "LA_T_1234567.flac", "LA_T_7654321.flac"]
# with open(trn_file_sorted_path, "w") as f:
#     f.write("\n".join(dummy_flac_files))

# for flac_file in dummy_flac_files:
#     dummy_audio = np.sin(np.linspace(0, 440 * 2 * np.pi, int(target_sample_rate * 2)))
#     sf.write(os.path.join(audio_folder_path, flac_file), dummy_audio, target_sample_rate)
# print("--- Dummy setup complete ---\n")
# # --- End of Setup ---


# Check if the folder exists
if not os.path.isdir(audio_folder_path):
    print(f"Error: Folder not found at '{audio_folder_path}'")
else:
    print("Step 1: Creating a lookup dictionary for labels.")
    label_dictionary = create_label_dictionary(label_file)
    print(f"Found {len(label_dictionary)} labels in the protocol file.")

    print(f"\nStep 2: Processing FLAC files in folder: {audio_folder_path}")
    cqcc_feature_vector = []
    prosody_feature_vector = []
    label_vector = []

    try:
        with open(trn_file_sorted_path, "r") as f:
            lines = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        print(f"Error: Sorted file list not found at '{trn_file_sorted_path}'")
        lines = []

    print("Extracting features and aligning with labels...")
    for item_name in tqdm(lines, desc="Extracting Features", unit="file"):
        base_name = os.path.splitext(item_name)[0]

        if base_name in label_dictionary:
            item_path = os.path.join(audio_folder_path, item_name)

            if os.path.isfile(item_path):
                try:
                    # Load audio data once for librosa features
                    audio_data, sr = librosa.load(item_path, sr=target_sample_rate, duration=duration)
                    
                    # Extract CQCC features from the loaded audio data
                    cqcc_features = feature_extraction_cqcc(audio_data, sr, n_bins=90)
                    
                    # Extract prosodic features using the file path (for Parselmouth)
                    prosody_features = extract_prosodic_features(item_path)

                    if cqcc_features is not None and prosody_features is not None:
                        # Standardize CQCC shape
                        padded_cqcc = pad_or_truncate_spectrogram(cqcc_features, 157, -80)
                        
                        # Append all features and the label to their respective lists
                        cqcc_feature_vector.append(padded_cqcc)
                        prosody_feature_vector.append(prosody_features)
                        label_vector.append(label_dictionary[base_name])

                except Exception as e:
                    print(f"\nWarning: Could not process file {item_path}. Error: {e}")

    print("\n--- Processing Complete ---")
    print(f"Number of CQCC feature vectors created: {len(cqcc_feature_vector)}")
    print(f"Number of Prosody feature vectors created: {len(prosody_feature_vector)}")
    print(f"Number of corresponding labels found: {len(label_vector)}")

    # --- Convert lists to NumPy arrays and save to .npy files ---
    if cqcc_feature_vector and label_vector:
        cqcc_array = np.array(cqcc_feature_vector, dtype=np.float32)
        prosody_array = np.array(prosody_feature_vector, dtype=np.float32)
        label_array = np.array(label_vector, dtype=np.int64)

        # Save the arrays to .npy files
        np.save(cqcc_output_file, cqcc_array)
        np.save(prosody_output_file, prosody_array)
        np.save(labels_output_file, label_array)

        print(f"\nSuccessfully saved ALIGNED data:")
        print(f"CQCC Features saved to '{cqcc_output_file}' with shape: {cqcc_array.shape}")
        print(f"Prosody Features saved to '{prosody_output_file}' with shape: {prosody_array.shape}")
        print(f"Labels saved to '{labels_output_file}' with shape: {label_array.shape}")
    else:
        print("\nNo features were extracted or labels created. Nothing to save.")


Step 1: Creating a lookup dictionary for labels.
Found 25380 labels in the protocol file.

Step 2: Processing FLAC files in folder: ASVSpoof19/LA/ASVspoof2019_LA_train/flac
Extracting features and aligning with labels...


  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
  return f(*args, **kwargs)
Extracting Features:


--- Processing Complete ---
Number of CQCC feature vectors created: 25379
Number of Prosody feature vectors created: 25379
Number of corresponding labels found: 25379

Successfully saved ALIGNED data:
CQCC Features saved to 'cqcc_features_aligned.npy' with shape: (25379, 90, 157)
Prosody Features saved to 'prosody_features_aligned.npy' with shape: (25379, 6)
Labels saved to 'labels_aligned.npy' with shape: (25379,)


In [None]:
if feature_vector and labels:
    # Convert the feature list to a NumPy array
    feature_array = np.array(feature_vector)
    # Convert the label list to a NumPy array
    label_array = np.array(labels)

    # Save the arrays to .npy files
    np.save(cqcc_output_file, feature_array)
    np.save(labels_output_file, label_array)

    print(f"\nSuccessfully saved data:")
    print(f"Features saved to '{cqcc_output_file}' with shape: {feature_array.shape}")
    print(f"Labels saved to '{labels_output_file}' with shape: {label_array.shape}")
else:
    print("\nNo features were extracted or labels created. Nothing to save.")
print(len(label_list), len(audio_name_list),len(feature_vector))

In [None]:
for i in range(0,10):
    print(feature_vector[i])

In [None]:
# import os
# import numpy as np
# import librosa
# # import cqcc # You may need to install this: pip install cqcc
# from tqdm import tqdm
# import random
# import warnings
# import parselmouth
# from parselmouth.praat import call

# # Suppress warnings from librosa and other libraries for a cleaner output
# def calculate_prosodic_features(audio_path):
#     """Calculates prosodic features for a given audio file."""
#     try:
#         sound = parselmouth.Sound(audio_path)
#         pitch = sound.to_pitch()
#         point_process = call(pitch, "To PointProcess")
#         harmonicity = call(sound, "To Harmonicity (cc)", 0.01, 75, 0.1, 1.0)
        
#         features = {
#             'mean_f0': call(pitch, "Get mean", 0, 0, "Hertz"),
#             'std_f0': call(pitch, "Get standard deviation", 0, 0, "Hertz"),
#             'jitter': call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3),
#             'shimmer': call([sound, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6),
#             'mean_hnr': call(harmonicity, "Get mean", 0, 0),
#             'std_hnr': call(harmonicity, "Get standard deviation", 0, 0)
#         }
        
#         # Replace NaN values with 0.0 for model compatibility
#         for key, value in features.items():
#             if isinstance(value, float) and np.isnan(value):
#                 features[key] = 0.0
#         return features
#     except Exception:
#         return None

# def extract_cqcc_features(audio, sr, n_cqcc=20):
#     """Extracts Constant Q-Cepstral Coefficients (CQCC)."""
#     try:
#         features = cqcc.cqcc(audio, sr, n_cqcc=n_cqcc)
#         # Return the features averaged over the time axis for a fixed-size vector
#         return np.mean(features, axis=1)
#     except Exception:
#         return None

# def process_all_features(bonafide_dirs, spoof_dir, protocol_file, output_path):
#     """
#     Processes all audio files to extract both CQCC and prosodic features in a single loop,
#     ensuring perfect alignment.
#     """
#     print("--- Starting Unified Feature Extraction Process ---")
#     os.makedirs(output_path, exist_ok=True)

#     # 1. Load protocol file to map filenames to attack IDs
#     try:
#         protocol_df = pd.read_csv(protocol_file, sep=" ", header=None, names=['speaker_id', 'filename', 'attack_type', 'system_id', 'label'])
#         attack_id_map = pd.Series(protocol_df.system_id.values, index=protocol_df.filename).to_dict()
#         print("Successfully loaded protocol file for Attack ID mapping.")
#     except FileNotFoundError:
#         print(f"Error: Protocol file not found at {protocol_file}. Cannot map attack IDs.")
#         return

#     # 2. Gather all file paths and assign preliminary labels
#     all_files_to_process = []
#     for directory in bonafide_dirs:
#         try:
#             files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('.flac', '.wav'))]
#             for f_path in files:
#                 all_files_to_process.append({'filepath': f_path, 'label': 0}) # 0 for bonafide
#         except FileNotFoundError:
#             print(f"Warning: Directory not found: {directory}. Skipping.")
    
#     try:
#         files = [os.path.join(spoof_dir, f) for f in os.listdir(spoof_dir) if f.endswith(('.flac', '.wav'))]
#         for f_path in files:
#             all_files_to_process.append({'filepath': f_path, 'label': 1}) # 1 for spoof
#     except FileNotFoundError:
#         print(f"Warning: Directory not found: {spoof_dir}. Skipping.")

#     print(f"Found {len(all_files_to_process)} total audio files to process.")

#     # 3. Main processing loop
#     cqcc_list = []
#     prosodic_list = []

#     for file_info in tqdm(all_files_to_process, desc="Extracting All Features"):
#         filepath = file_info['filepath']
#         label = file_info['label']
        
#         # Load audio once
#         try:
#             audio, sr = librosa.load(filepath, sr=16000, duration = 5.0) # Use a fixed sample rate
#         except Exception as e:
#             print(f"\nError loading {filepath}: {e}. Skipping.")
#             continue

#         # Extract features
#         cqcc_feats = feature_extraction_cqcc(audio, sr, n_cqcc=128)
#         prosodic_feats = calculate_prosodic_features(filepath)

#         # Only add the sample if both feature extractions were successful
#         if cqcc_feats is not None and prosodic_feats is not None:
#             # Get metadata
#             base_filename = os.path.basename(filepath)
#             if base_filename.startswith('aug_'):
#                 original_filename_key = '_'.join(base_filename.split('_')[2:]).replace('.flac', '')
#             else:
#                 original_filename_key = base_filename.replace('.flac', '')
            
#             attack_id = attack_id_map.get(original_filename_key, '-')
#             padded_cqcc = pad_or_truncate_spectrogram(cqcc_feats, 157, -80)
            
#             # Append data to lists
#             cqcc_list.append(padded_cqcc)

            
#             prosodic_data_row = {
#                 'filename': base_filename,
#                 'label': label,
#                 'attack_id': attack_id,
#                 **prosodic_feats
#             }
#             prosodic_list.append(prosodic_data_row)

#     # 4. Save the processed and aligned data
#     if not prosodic_list:
#         print("No features were extracted. Please check your directories and files.")
#         return

#     # Convert to final formats
#     X_cqcc = np.array(cqcc_list)
#     prosody_df = pd.DataFrame(prosodic_list)

#     # Define output paths
#     cqcc_save_path = os.path.join(output_path, "cqcc_features.npy")
#     prosody_csv_save_path = os.path.join(output_path, "prosodic_features_and_labels.csv")

#     # Save files
#     np.save(cqcc_save_path, X_cqcc)
#     prosody_df.to_csv(prosody_csv_save_path, index=False)

#     print(f"\n--- Unified Feature Extraction Complete ---")
#     print(f"Processed and aligned {len(prosody_df)} files successfully.")
#     print(f"CQCC features saved to: {cqcc_save_path} with shape {X_cqcc.shape}")
#     print(f"Prosodic features and labels saved to: {prosody_csv_save_path}")

In [None]:
# BONAFIDE_DIRS = [
#         "bonafide_audio_train",            # Original bonafide files
#         "augmented_bonafide"   # Augmented bonafide files
#     ]
# SPOOF_AUDIO_DIR = "spoof_audio_train"
    
#     # This is where the final .npy files will be saved.
# OUTPUT_DIR = "processed_data"

#     # Set 'max_files_per_class' to a small number (e.g., 100) for a quick test run.
#     # Set to 'None' to process all files.


# PROTOCOL_FILE = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"

# process_all_features(
#     bonafide_dirs=BONAFIDE_DIRS,
#     spoof_dir=SPOOF_AUDIO_DIR,
#     protocol_file=PROTOCOL_FILE,
#     output_path=OUTPUT_DIR
#     )

In [None]:
# import os
# import numpy as np
# import librosa
# # import cqcc # You may need to install this: pip install cqcc
# from tqdm import tqdm
# import random
# import warnings
# import parselmouth
# from parselmouth.praat import call

# # Suppress warnings from librosa and other libraries for a cleaner output
# def calculate_prosodic_features(audio_path):
#     """Calculates prosodic features for a given audio file."""
#     try:
#         sound = parselmouth.Sound(audio_path)
#         pitch = sound.to_pitch()
#         point_process = call(pitch, "To PointProcess")
#         harmonicity = call(sound, "To Harmonicity (cc)", 0.01, 75, 0.1, 1.0)
        
#         features = {
#             'mean_f0': call(pitch, "Get mean", 0, 0, "Hertz"),
#             'std_f0': call(pitch, "Get standard deviation", 0, 0, "Hertz"),
#             'jitter': call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3),
#             'shimmer': call([sound, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6),
#             'mean_hnr': call(harmonicity, "Get mean", 0, 0),
#             'std_hnr': call(harmonicity, "Get standard deviation", 0, 0)
#         }
        
#         # Replace NaN values with 0.0 for model compatibility
#         for key, value in features.items():
#             if isinstance(value, float) and np.isnan(value):
#                 features[key] = 0.0
#         return features
#     except Exception:
#         return None

# def extract_cqcc_features(audio, sr, n_cqcc=20):
#     """Extracts Constant Q-Cepstral Coefficients (CQCC)."""
#     try:
#         features = cqcc.cqcc(audio, sr, n_cqcc=n_cqcc)
#         # Return the features averaged over the time axis for a fixed-size vector
#         return np.mean(features, axis=1)
#     except Exception:
#         return None

# def process_all_features(bonafide_dirs, spoof_dir, protocol_file, output_path):
#     """
#     Processes all audio files to extract both CQCC and prosodic features in a single loop,
#     ensuring perfect alignment.
#     """
#     print("--- Starting Unified Feature Extraction Process ---")
#     os.makedirs(output_path, exist_ok=True)

#     # 1. Load protocol file to map filenames to attack IDs
#     try:
#         protocol_df = pd.read_csv(protocol_file, sep=" ", header=None, names=['speaker_id', 'filename', 'attack_type', 'system_id', 'label'])
#         attack_id_map = pd.Series(protocol_df.system_id.values, index=protocol_df.filename).to_dict()
#         print("Successfully loaded protocol file for Attack ID mapping.")
#     except FileNotFoundError:
#         print(f"Error: Protocol file not found at {protocol_file}. Cannot map attack IDs.")
#         return

#     # 2. Gather all file paths and assign preliminary labels
#     all_files_to_process = []
#     for directory in bonafide_dirs:
#         try:
#             files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('.flac', '.wav'))]
#             for f_path in files:
#                 all_files_to_process.append({'filepath': f_path, 'label': 0}) # 0 for bonafide
#         except FileNotFoundError:
#             print(f"Warning: Directory not found: {directory}. Skipping.")
    
#     try:
#         files = [os.path.join(spoof_dir, f) for f in os.listdir(spoof_dir) if f.endswith(('.flac', '.wav'))]
#         for f_path in files:
#             all_files_to_process.append({'filepath': f_path, 'label': 1}) # 1 for spoof
#     except FileNotFoundError:
#         print(f"Warning: Directory not found: {spoof_dir}. Skipping.")

#     print(f"Found {len(all_files_to_process)} total audio files to process.")

#     # 3. Main processing loop
#     cqcc_list = []
#     prosodic_list = []

#     for file_info in tqdm(all_files_to_process, desc="Extracting All Features"):
#         filepath = file_info['filepath']
#         label = file_info['label']
        
#         # Load audio once
#         try:
#             audio, sr = librosa.load(filepath, sr=16000, duration = 5.0) # Use a fixed sample rate
#         except Exception as e:
#             print(f"\nError loading {filepath}: {e}. Skipping.")
#             continue

#         # Extract features
#         cqcc_feats = feature_extraction_cqcc(audio, sr, n_cqcc=128)
#         prosodic_feats = calculate_prosodic_features(filepath)

#         # Only add the sample if both feature extractions were successful
#         if cqcc_feats is not None and prosodic_feats is not None:
#             # Get metadata
#             base_filename = os.path.basename(filepath)
#             if base_filename.startswith('aug_'):
#                 original_filename_key = '_'.join(base_filename.split('_')[2:]).replace('.flac', '')
#             else:
#                 original_filename_key = base_filename.replace('.flac', '')
            
#             attack_id = attack_id_map.get(original_filename_key, '-')
#             padded_cqcc = pad_or_truncate_spectrogram(cqcc_feats, 157, -80)
            
#             # Append data to lists
#             cqcc_list.append(padded_cqcc)

            
#             prosodic_data_row = {
#                 'filename': base_filename,
#                 'label': label,
#                 'attack_id': attack_id,
#                 **prosodic_feats
#             }
#             prosodic_list.append(prosodic_data_row)

#     # 4. Save the processed and aligned data
#     if not prosodic_list:
#         print("No features were extracted. Please check your directories and files.")
#         return

#     # Convert to final formats
#     X_cqcc = np.array(cqcc_list)
#     prosody_df = pd.DataFrame(prosodic_list)

#     # Define output paths
#     cqcc_save_path = os.path.join(output_path, "cqcc_features_test.npy")
#     prosody_csv_save_path = os.path.join(output_path, "prosodic_features_and_labels_test.csv")

#     # Save files
#     np.save(cqcc_save_path, X_cqcc)
#     prosody_df.to_csv(prosody_csv_save_path, index=False)

#     print(f"\n--- Unified Feature Extraction Complete ---")
#     print(f"Processed and aligned {len(prosody_df)} files successfully.")
#     print(f"CQCC features saved to: {cqcc_save_path} with shape {X_cqcc.shape}")
#     print(f"Prosodic features and labels saved to: {prosody_csv_save_path}")

# BONAFIDE_DIRS = [
#         "bonafide_audio_test",            # Original bonafide files 
#     ]
# SPOOF_AUDIO_DIR = "spoof_audio_test"
    
#     # This is where the final .npy files will be saved.
# OUTPUT_DIR = "processed_data"

#     # Set 'max_files_per_class' to a small number (e.g., 100) for a quick test run.
#     # Set to 'None' to process all files.


# PROTOCOL_FILE = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt"

# process_all_features(
#     bonafide_dirs=BONAFIDE_DIRS,
#     spoof_dir=SPOOF_AUDIO_DIR,
#     protocol_file=PROTOCOL_FILE,
#     output_path=OUTPUT_DIR
#     )

In [None]:
# # ACHTUNG FOR Pre-Processing for Development ONLY ---
# your_folder_with_flac_files_dev = "ASVSpoof19/LA/ASVspoof2019_LA_dev/flac"
# audio_folder_path_dev = "ASVSpoof19/LA/ASVspoof2019_LA_dev/flac"
# label_file_dev = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt" 
# folder_path_dev = your_folder_with_flac_files_dev
# dev_file_sorted_path= "dev_file_list.txt"
# print(f"Processing FLAC files in folder: {folder_path_dev}\n")
# label_list_dev, audio_name_list_dev = create_label_array(audio_folder_path=audio_folder_path_dev, label_file_path=label_file_dev)
# sorted_labels_dev = sorted(label_list_dev.items())
# labels_dev= []
# for i in range(0,len(sorted_labels_dev)):
#     labels_dev.append(sorted_labels_dev[i][1])
# labels_dev = np.array(labels_dev)
# # print(len(label_list_val), len(audio_name_list_val),len(feature_vector_val))
# feature_vector_dev = []
# found_flac_files_dev = False
# count_dev = 0
# target_sample_rate_dev=16000
# duration_dev=5.0
# # Check if the folder exists
# if not os.path.isdir(folder_path_dev):
#     print(f"Error: Folder not found at '{folder_path_dev}'")
# freq_dev = {}
# with open(dev_file_sorted_path, "r") as f:
#     lines = f.readlines()
# # Iterate over all entries in the folder
# print("Hey User!!, You are now creating the Development set and their corresponding labels")
# for item_name in lines:
#     # Construct the full path to the item
#     #print(item_name)
#     item_path = str(os.path.join(folder_path_dev, item_name)).strip('\n')
#     #print(item_path)
#     # Check if it's a file and ends with .flac
#     # if os.path.isfile(item_path) and item_name.lower().endswith(".flac"):
#     #     found_flac_files = True
#     # print(item_path.split('/'))
#     # print(item_name.lower())
#     #print(item_name.strip('.flac\n') in str(list(label_list_val.keys())))
#     if item_name.strip('.flac\n') in str(list(label_list_dev.keys())):
#         # audio_data_val, sr_val = librosa.load(item_path, sr=target_sample_rate_val, duration=duration_val)
#         #print("My Name is Mitukk")
#         # print(f"--- Found FLAC file: {item_name} ---")
#         audio_data_dev, sr_dev = librosa.load(item_path, sr=target_sample_rate_dev, duration=duration_dev)
#         cqcc_features = feature_extraction_cqcc(audio_data_dev, sr_dev, n_cqcc=128)
#         if cqcc_features is not None:
#             count_dev = count_dev + 1
#             padded_cqcc = pad_or_truncate_spectrogram(cqcc_features,157,-80)
#             feature_vector_dev.append(padded_cqcc)
#             print(f"Added to the list - Shape of CQCC features: {padded_cqcc.shape}")


# # label_list_val, audio_name_list_val = create_label_array(audio_folder_path=audio_folder_path_val, label_file_path=label_file_val)

# print(len(label_list_dev), len(audio_name_list_dev),len(feature_vector_dev))

In [None]:
# import os
# import librosa
# import numpy as np
# from tqdm import tqdm

# # --- ACHTUNG FOR Pre-Processing for Development ONLY ---

# # Assuming feature_extraction_cqcc, pad_or_truncate_spectrogram, and create_label_array are defined elsewhere

# # --- Configuration ---
# your_folder_with_flac_files_dev = "ASVSpoof19/LA/ASVspoof2019_LA_dev/flac"
# audio_folder_path_dev = "ASVSpoof19/LA/ASVspoof2019_LA_dev/flac"
# label_file_dev = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt"
# folder_path_dev = your_folder_with_flac_files_dev
# dev_file_sorted_path = "dev_file_list.txt"

# # --- Feature Extraction Parameters ---
# target_sample_rate_dev = 16000
# duration_dev = 5.0

# print(f"Processing FLAC files in folder: {folder_path_dev}\n")

# # --- Label Loading and Sorting ---
# label_list_dev, audio_name_list_dev = create_label_array(audio_folder_path=audio_folder_path_dev, label_file_path=label_file_dev)
# sorted_labels_dev = sorted(label_list_dev.items())
# labels_dev = np.array([label for _, label in sorted_labels_dev])
# label_keys_dev = set(label_list_dev.keys()) # Use a set for faster lookups

# # --- Feature Extraction ---
# feature_vector_dev = []

# # Check if the folder exists
# if not os.path.isdir(folder_path_dev):
#     print(f"Error: Folder not found at '{folder_path_dev}'")
# else:
#     with open(dev_file_sorted_path, "r") as f:
#         lines = f.readlines()

#     print("Hey User!!, You are now creating the Development set and their corresponding labels")
#     # Use tqdm for a progress bar over the file list
#     for item_name in tqdm(lines, desc="Processing Dev Files", unit="file"):
#         item_name_stripped = item_name.strip()
#         item_path = os.path.join(folder_path_dev, item_name_stripped)
        
#         # Check if the file's key exists in the loaded labels
#         if item_name_stripped.replace('.flac', '') in label_keys_dev:
#             # Load audio file
#             audio_data_dev, sr_dev = librosa.load(item_path, sr=target_sample_rate_dev, duration=duration_dev)
            
#             # Extract features
#             cqcc_features = feature_extraction_cqcc(audio_data_dev, sr_dev, n_cqcc=128)
            
#             if cqcc_features is not None:
#                 # Pad or truncate features to a fixed size
#                 padded_cqcc = pad_or_truncate_spectrogram(cqcc_features, 157, -80)
#                 feature_vector_dev.append(padded_cqcc)

# # --- Final Summary ---
# print("\n--- Development Set Processing Complete ---")
# print(f"Labels found: {len(label_list_dev)}")
# print(f"Audio files listed in protocol: {len(audio_name_list_dev)}")
# print(f"Feature vectors extracted: {len(feature_vector_dev)}")

# feature_vector_dev = np.array(feature_vector_dev, dtype=np.float32)
# feature_vector_dev.shape
# labels_dev = np.array(labels_dev, dtype = np.int64 )
# labels_dev.shape

In [None]:
# np.save('feature_vector_dev', feature_vector_dev)
# np.save('labels_dev', labels_dev)

In [None]:
# if __name__ == "__main__":
#     your_folder_with_flac_files = "ASVSpoof19/LA/ASVspoof2019_LA_train/flac"
#     audio_folder_path = "ASVSpoof19/LA/ASVspoof2019_LA_train/flac"
#     label_file = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt" 
#     folder_path = your_folder_with_flac_files
#     trn_file_sorted_path= "trn_file_list.txt"
#     print(f"Processing FLAC files in folder: {folder_path}\n")
#     feature_vector = []
#     found_flac_files = False
#     count = 0
#     target_sample_rate=16000
#     duration=5.0
#     # Example usage:
#     # 1. Load files at their original sampling rate and full duration
#     #process_flac_files_in_folder(your_folder_with_flac_files)

#     # 2. Load files, resample to 16000 Hz, and load max 5 seconds
#     #process_flac_files_in_folder(your_folder_with_flac_files, target_sample_rate=16000, duration=5.0)

#     # --- For the code to run, uncomment one of the above examples and set the path ---
#     # --- For demonstration, let's create a dummy folder and file if the path is not set

#     # ACHTUNG !!!!!!!!  BELOW CODE IS FOR MFCC ONLY

#     if your_folder_with_flac_files == "path/to/your/flac_files_folder":
#         print("INFO: 'your_folder_with_flac_files' is not set to a real path.")
#         print("Please update the 'your_folder_with_flac_files' variable in the script.")
#         # Create a dummy folder and a dummy flac file for demonstration if it doesn't exist
#         dummy_folder = "dummy_flac_folder"
#         if not os.path.exists(dummy_folder):
#             os.makedirs(dummy_folder)
#             # Create a very short, silent dummy FLAC file (requires soundfile and numpy)
#             try:
#                 import soundfile as sf
#                 import numpy as np
#                 dummy_file_path = os.path.join(dummy_folder, "dummy_audio.flac")
#                 sf.write(dummy_file_path, np.zeros(16000), 16000, format='FLAC') # 1 sec silence
#                 print(f"Created a dummy folder '{dummy_folder}' with 'dummy_audio.flac' for testing.")
#                 print("Running with the dummy folder:")
#                 process_flac_files_in_folder(dummy_folder, target_sample_rate=16000)
#             except ImportError:
#                 print("Skipping dummy file creation: soundfile or numpy not installed.")
#             except Exception as e:
#                 print(f"Error creating dummy file: {e}")
#     else:
#         # If the user has set a path, use it
#         feature_vector_padded = process_flac_files_in_folder(feature_vector, your_folder_with_flac_files, target_sample_rate=16000, duration=5.0)
#         # print(count)
#         # print(freq)
#         widths = sorted(freq.keys())
#         counts = [freq[w] for w in widths]
#         # print(widths)
#         # print(counts)
#         print(feature_vector_padded.shape)

#     # ACHTUNG !!!!! COMMON PLOT FOR ALL THE FEATURES
#     # Create the bar plot
#     # plt.figure(figsize=(12, 6)) # Adjust figure size as needed
#     # bars = plt.bar(widths, counts, color='skyblue', edgecolor='black', width=1.0) # width=1.0 makes bars touch for histogram feel
#     # # Add labels and title
#     # plt.xlabel("Spectrogram Width (Number of Time Frames)")
#     # plt.ylabel("Frequency (Count)")
#     # plt.title("Time Frame Distribution")
#     # plt.grid(axis='y', linestyle='--', alpha=0.7)
#     # plt.savefig("Freq Distribution.png")
#     plt.show()

In [None]:
# feature_vector_padded= np.array(feature_vector_padded)
# feature_vector_padded.shape

In [29]:
# class AttentionFusion(nn.Module):
#     def __init__(self, cqcc_dim=128, prosodic_dim=32, hidden_dim=64):
#         super().__init__()
#         # Layer to project CQCC features to the hidden dimension
#         self.cqcc_projection = nn.Linear(cqcc_dim, hidden_dim)
#         # Layer to project prosodic features (the context) to the hidden dimension
#         self.prosodic_projection = nn.Linear(prosodic_dim, hidden_dim)
#         # Final layer to compute attention scores
#         self.attention_scores = nn.Linear(hidden_dim, 1)

#     def forward(self, cqcc_features, prosodic_features):
#         # Project both feature sets
#         cqcc_proj = torch.tanh(self.cqcc_projection(cqcc_features))
#         prosodic_proj = torch.tanh(self.prosodic_projection(prosodic_features))

#         # Calculate attention weights based on the combined projected features
#         # The prosodic features act as a query on the cqcc features
#         attention_input = cqcc_proj * prosodic_proj
#         attention_weights = F.softmax(self.attention_scores(attention_input), dim=1)

#         # Apply the attention weights to the original CQCC features
#         refined_cqcc = cqcc_features * attention_weights
        
#         # Fuse by concatenation, now with an attention-refined CQCC vector
#         fused_vector = torch.cat([refined_cqcc, prosodic_features], dim=1)
#         return fused_vector
    

# class AttentionFusionCNN(nn.Module):
#     def __init__(self, input_features, input_time_steps, num_classes=2):
#         super().__init__()
#         # --- CNN Branch for CQCC features (unchanged) ---
#         self.CNN = nn.Sequential(
#             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3,3), stride=(1,1)),
#             nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2)),
#             nn.Conv2d(16, 32, kernel_size=(3,3), stride=(1,1)),
#             nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2)),
#             nn.Conv2d(32, 64, kernel_size=(3,3), stride=(1,1)),
#             nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2))
#         )
#         # Dynamic shape calculation (unchanged)
#         with torch.no_grad():
#             dummy_input = torch.randn(1, 1, input_features, input_time_steps)
#             dummy_output = self.CNN(dummy_input)
#             self.flattened_size = dummy_output.view(1, -1).size(1)
        
#         self.flatten = nn.Flatten()
        
#         # --- Processing Layers for each branch (unchanged) ---
#         self.LinearLayer = nn.Sequential(
#             nn.Linear(self.flattened_size, 128),
#             nn.ReLU(),
#             nn.Dropout(0.5), # Increased dropout for regularization
#             nn.Linear(128, 128)
#         )
#         self.ProsodicLayer = nn.Sequential(
#             nn.Linear(6, 32),
#             nn.ReLU(),
#             nn.Linear(32, 32)
#         )

#         # --- NEW: Instantiate the Attention Fusion layer ---
#         self.attention_fusion = AttentionFusion()

#         # --- Final Classifier (fusionLayer) ---
#         # The input size is still 128 + 32 = 160, because the attention
#         # layer refines the CQCC vector but doesn't change its dimension.
#         self.fusionLayer = nn.Sequential(
#             nn.Linear(160, 64),
#             nn.ReLU(),
#             nn.Linear(64, num_classes)
#         )

#     def forward(self, x, y):
#         # 1. Process CQCC features through the CNN branch
#         x = x.unsqueeze(1)
#         x = self.CNN(x)
#         flattened_x = self.flatten(x)
#         cqcc_embedding = self.LinearLayer(flattened_x)

#         # 2. Process prosodic features through the MLP branch
#         prosodic_embedding = self.ProsodicLayer(y)

#         # 3. Fuse the embeddings using the Attention mechanism
#         #    This replaces the simple torch.cat()
#         fused_vector = self.attention_fusion(cqcc_embedding, prosodic_embedding)

#         # 4. Pass the fused vector to the final classifier
#         logits = self.fusionLayer(fused_vector)
        
#         return logits
    

#CNN Network Architecture
class CNNnetwork(nn.Module):
    def __init__(self, input_features, input_time_steps, num_classes):
        super().__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(16,32, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        )
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, input_features, input_time_steps)
            dummy_output = self.CNN(dummy_input)
            self.flattened_size = dummy_output.view(1, -1).size(1)
        self.flatten = nn.Flatten()
        ##USE CROSS ATTENTION TO FUSE DIFFERENT DOMAINS(Prosodic and Highlevel CQCCs)
        self.LinearLayer = nn.Sequential(
            nn.Linear(self.flattened_size,128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 128),
            nn.Linear(128, num_classes)
        )

        self.ProsodicLayer = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32,32)
        )

        self.fusionLayer = nn.Sequential(
            nn.Linear(160,64),
            nn.ReLU(),
            nn.Linear(64,num_classes)
        )

    def forward(self, x, y):
        x = x.unsqueeze(1)
        #print(x.shape)
        x = self.CNN(x)
        flattened_x = self.flatten(x)
        output_x = self.LinearLayer(flattened_x)
        y = self.ProsodicLayer(y)
        logits = self.fusionLayer(torch.cat([output_x, y], dim=1))
        return logits
    

In [30]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class AttentionFusion(nn.Module):
#     """
#     A layer that uses attention to fuse two feature vectors (e.g., from an LSTM and an MLP).
#     """
#     def __init__(self, primary_dim, context_dim, hidden_dim=64):
#         super().__init__()
#         # Layers to project the features into a common space
#         self.primary_projection = nn.Linear(primary_dim, hidden_dim)
#         self.context_projection = nn.Linear(context_dim, hidden_dim)

#         # Final layer to compute the attention scores from the combined projection
#         self.attention_scores = nn.Linear(hidden_dim, primary_dim)

#     def forward(self, primary_features, context_features):
#         # primary_features (e.g., CQCC embedding): [batch_size, primary_dim]
#         # context_features (e.g., Prosodic embedding): [batch_size, context_dim]

#         # Project both features into the hidden dimension
#         primary_proj = torch.tanh(self.primary_projection(primary_features))
#         context_proj = torch.tanh(self.context_projection(context_features))

#         # Element-wise product of the projections to model their interaction
#         interaction = primary_proj * context_proj

#         # Compute attention weights from the interaction
#         attention_weights = torch.sigmoid(self.attention_scores(interaction))

#         # Apply the learned attention weights to the original primary features
#         refined_primary = primary_features * attention_weights

#         # Concatenate the attention-refined features with the context features
#         fused_vector = torch.cat([refined_primary, context_features], dim=1)
        
#         return fused_vector
    
# class LSTMAttentionFusionNetwork(nn.Module):
#     def __init__(self, input_features, num_classes, lstm_hidden_size=128, lstm_layers=2):
#         super().__init__()
#         # --- LSTM Branch for CQCC features (Unchanged) ---
#         self.lstm = nn.LSTM(
#             input_size=input_features, # 90
#             hidden_size=lstm_hidden_size,
#             num_layers=lstm_layers,
#             batch_first=True,
#             bidirectional=True,
#             dropout=0.3 if lstm_layers > 1 else 0 # Add dropout between LSTM layers
#         )
#         lstm_output_dim = lstm_hidden_size * 2

#         #  MLP Branch for Prosodic features
#         self.ProsodicLayer = nn.Sequential(
#             nn.Linear(6, 32),
#             nn.ReLU(),
#             nn.Linear(32, 32)
#         )
        
#         # Attention Fusion layer ---
#         self.attention_fusion = AttentionFusion(
#             primary_dim=lstm_output_dim, # 256
#             context_dim=32
#         )
        
       
#         self.fusionLayer = nn.Sequential(
#             nn.Linear(lstm_output_dim + 32, 128),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(128, num_classes)
#         )

#     def forward(self, x, y):
#         # 1. Permute CQCC data for LSTM: (batch, features, time) -> (batch, time, features)
#         x = x.permute(0, 2, 1)

#         # 2. Process CQCC features through LSTM
#         _, (hidden, _) = self.lstm(x)

#         # 3. Create CQCC embedding from the final hidden states of the BiLSTM
#         cqcc_embedding = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)

#         # 4. Process prosodic features
#         prosodic_embedding = self.ProsodicLayer(y)

#         # 5. Fuse the embeddings using the Attention mechanism
#         #    This is the only change in the forward pass!
#         fused_vector = self.attention_fusion(
#             primary_features=cqcc_embedding,
#             context_features=prosodic_embedding
#         )

#         # 6. Pass the attention-fused vector to the final classifier
#         logits = self.fusionLayer(fused_vector)
#         return logits

In [31]:
from scipy.optimize import brentq
from scipy.interpolate import interp1d
def calculate_eer(y_true, y_scores):
    """
    Calculates the Equal Error Rate (EER).

    Args:
        y_true (np.ndarray): True binary labels (0 for spoof, 1 for bona fide).
        y_scores (np.ndarray): Scores for the positive class (bona fide).
                               Higher scores indicate a higher likelihood of being bona fide.
    Returns:
        float: The EER value. Returns NaN if EER cannot be computed.
    """
    if len(np.unique(y_true)) < 2:
        print("Warning: EER calculation requires at least two classes in y_true.")
        return float('nan')
        
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    fnr = 1 - tpr

    # Find the point where abs(fpr - fnr) is minimal
    # This is a common way to approximate EER from ROC curve points
    # eer_index = np.nanargmin(np.abs(fpr - fnr))
    # eer = (fpr[eer_index] + fnr[eer_index]) / 2.0
    
    # More precise EER can be found using interpolation if needed, e.g., with scipy.optimize.brentq
    # from scipy.optimize import brentq
    # from scipy.interpolate import interp1d
    try:
        interp_fn = interp1d(fpr,fnr)
        eer = brentq(lambda x : x  - interp_fn(x), 0.0, 1.0)
    except Exception as e: # Fallback if interpolation fails (e.g. NaNs in fpr/tpr, or no intersection)
        print(f"Warning: brentq EER calculation failed ({e}), using approximation.")
        eer_index = np.nanargmin(np.abs(fpr - fnr))
        eer = (fpr[eer_index] + fnr[eer_index]) / 2.0


    return eer

In [32]:
# prosodicFeatures = pd.read_csv('prosodic_features_train.csv')
# prosodicFeatures = pd.DataFrame(prosodicFeatures)
# prosodicFeatures.tail
# prosodicFeatures = prosodicFeatures.drop(["attack_id","filename"], axis="columns") 
# prosodicFeaturesLabels = torch.Tensor(prosodicFeatures["label"].to_numpy(dtype=np.float32))
# prosodicFeatures = prosodicFeatures.drop(["label"], axis="columns")
# print(prosodicFeatures.head(10))
# prosodicFeatures = prosodicFeatures.to_numpy(dtype=np.float32)
# prosodicFeatures = torch.Tensor(prosodicFeatures)
# print(prosodicFeatures.shape)
# print(prosodicFeaturesLabels.shape)

In [None]:
prosodicFeatures.shape
X_train_tensor.shape

In [33]:
prosodicFeatures_val = pd.read_csv('processed_data/prosodic_features_and_labels_val.csv')
prosodicFeatures_val = pd.DataFrame(prosodicFeatures_val)
prosodicFeatures_val.tail
prosodicFeatures_val = prosodicFeatures_val.drop(["attack_id", "filename"], axis="columns")
prosodicFeaturesLabels_val = torch.Tensor(prosodicFeatures_val["label"].to_numpy(dtype=np.float32))
prosodicFeatures_val = prosodicFeatures_val.drop(["label"], axis="columns")
prosodicFeatures_val = prosodicFeatures_val.to_numpy(dtype=np.float32)
prosodicFeatures_val = torch.Tensor(prosodicFeatures_val)
print(prosodicFeatures_val.shape)
print(prosodicFeaturesLabels_val.shape)

torch.Size([24844, 6])
torch.Size([24844])


In [34]:
prosodicFeatures_test = pd.read_csv('processed_data/prosodic_features_and_labels_test.csv')
prosodicFeatures_test = pd.DataFrame(prosodicFeatures_test)
prosodicFeatures_test.tail
prosodicFeatures_test = prosodicFeatures_test.drop(["attack_id","filename"], axis="columns") 
prosodicFeaturesLabels_test = torch.Tensor(prosodicFeatures_test["label"].to_numpy(dtype=np.float32))
prosodicFeatures_test = prosodicFeatures_test.drop(["label"], axis="columns")
prosodicFeatures_test= prosodicFeatures_test.to_numpy(dtype=np.float32)
prosodicFeatures_test = torch.Tensor(prosodicFeatures_test)
print(prosodicFeatures_test.shape)
print(prosodicFeaturesLabels_test.shape)

torch.Size([71237, 6])
torch.Size([71237])


In [35]:
# import parselmouth
# from parselmouth.praat import call

# def calculate_prosodic_features(audio_path):
#     """
#     Calculates various prosodic features for a given audio file using Parselmouth.

#     Args:
#         audio_path (str): The full path to the audio file.

#     Returns:
#         dict: A dictionary containing the calculated features, or None if processing fails.
#     """
#     try:
#         # Parselmouth can load the file directly
#         sound = parselmouth.Sound(audio_path)

#         # --- F0 (Fundamental Frequency) ---
#         pitch = sound.to_pitch()
#         mean_f0 = call(pitch, "Get mean", 0, 0, "Hertz")
#         std_f0 = call(pitch, "Get standard deviation", 0, 0, "Hertz")

#         # --- Jitter and Shimmer ---
#         point_process = call(pitch, "To PointProcess")
#         jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
#         shimmer = call([sound, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)

#         # --- HNR (Harmonics-to-Noise Ratio) ---
#         harmonicity = call(sound, "To Harmonicity (cc)", 0.01, 75, 0.1, 1.0)
#         mean_hnr = call(harmonicity, "Get mean", 0, 0)
#         std_hnr = call(harmonicity, "Get standard deviation", 0, 0)

#         # Replace NaN values (which Praat returns as '--undefined--') with 0.0
#         features = {
#             'mean_f0': mean_f0,
#             'std_f0': std_f0,
#             'jitter': jitter,
#             'shimmer': shimmer,
#             'mean_hnr': mean_hnr,
#             'std_hnr': std_hnr
#         }
#         for key, value in features.items():
#             if isinstance(value, float) and np.isnan(value):
#                 features[key] = 0.0

#         return features

#     except Exception:
#         # This can happen if Praat fails to process a file (e.g., too short, silent)
#         return None

# def process_and_save_features(bonafide_dirs, spoof_dir, protocol_file, output_csv_path):
#     """
#     Processes audio from bonafide and spoof directories, extracts prosodic features,
#     and saves the results to a single CSV file.

#     Args:
#         bonafide_dirs (list): A list of directories containing all bonafide audio files.
#         spoof_dir (str): The directory containing spoofed audio files.
#         protocol_file (str): Path to the original ASVspoof protocol file to map attack IDs.
#         output_csv_path (str): Path to save the final CSV file.
#     """
#     print("--- Starting Prosodic Feature Extraction ---")

#     # 1. Load protocol file to create a mapping from filename to attack_id
#     try:
#         protocol_df = pd.read_csv(protocol_file, sep=" ", header=None, names=['speaker_id', 'filename', 'attack_type', 'system_id', 'label'])
#         attack_id_map = pd.Series(protocol_df.system_id.values, index=protocol_df.filename).to_dict()
#         print(f"Successfully loaded protocol file for Attack ID mapping.")
#     except FileNotFoundError:
#         print(f"Error: Protocol file not found at {protocol_file}. Cannot map attack IDs.")
#         return

#     # 2. Gather all file paths from the directories
#     all_files_to_process = []
    
#     # Add bonafide files (label = 'bonafide')
#     for directory in bonafide_dirs:
#         try:
#             files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('.flac', '.wav'))]
#             for f_path in files:
#                 all_files_to_process.append({'filepath': f_path, 'label': 'bonafide'})
#         except FileNotFoundError:
#             print(f"Warning: Directory not found: {directory}. Skipping.")

#     # Add spoof files (label = 'spoof')
#     try:
#         files = [os.path.join(spoof_dir, f) for f in os.listdir(spoof_dir) if f.endswith(('.flac', '.wav'))]
#         for f_path in files:
#             all_files_to_process.append({'filepath': f_path, 'label': 'spoof'})
#     except FileNotFoundError:
#         print(f"Warning: Directory not found: {spoof_dir}. Skipping.")

#     print(f"Found {len(all_files_to_process)} total audio files to process.")

#     # 3. Iterate, extract features, and build the results list
#     results_list = []
#     for file_info in tqdm(all_files_to_process, desc="Processing Files"):
#         filepath = file_info['filepath']
#         label = file_info['label']
        
#         prosodic_features = calculate_prosodic_features(filepath)
        
#         if prosodic_features:
#             base_filename = os.path.basename(filepath)
            
#             # Determine the original filename to look up the attack ID
#             if base_filename.startswith('aug_'):
#                 # Extracts original name from augmented files, e.g., 'aug_0_LA_T_12345.flac' -> 'LA_T_12345'
#                 original_filename_key = '_'.join(base_filename.split('_')[2:]).replace('.flac', '')
#             else:
#                 # For original files
#                 original_filename_key = base_filename.replace('.flac', '')
            
#             attack_id = attack_id_map.get(original_filename_key, '-') # Default to '-' if not found
            
#             # Combine all information into one dictionary
#             data_row = {
#                 'filename': base_filename,
#                 'label': label,
#                 'attack_id': attack_id,
#                 **prosodic_features
#             }
#             results_list.append(data_row)

#     # 4. Convert the list of results into a DataFrame and save
#     if not results_list:
#         print("No features were extracted. Please check your audio directories.")
#         return
        
#     final_df = pd.DataFrame(results_list)
#     os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
#     final_df.to_csv(output_csv_path, index=False)
    
#     print(f"\n--- Feature Extraction Complete ---")
#     print(f"Processed {len(final_df)} files successfully.")
#     print(f"Prosodic features saved to: {output_csv_path}")

# # --- HOW TO RUN THE SCRIPT ---

# # ==> IMPORTANT: Update these paths to match your folder structure!

# # Path to the original protocol file (used for mapping attack IDs).
# PROTOCOL_FILE = "ASVSpoof19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"

# # A list of all directories containing genuine (bonafide) audio.
# BONAFIDE_DIRS = [
#     "bonafide_audio_train",            # Original bonafide files
#     "augmented_bonafide"   # Augmented bonafide files
# ]

# # The directory containing all spoofed audio files.
# SPOOF_AUDIO_DIR = "spoof_audio_train"

# # Path for the final output CSV file.
# OUTPUT_CSV = "processed_data/prosodic_features_train_augmented.csv"

# process_and_save_features(
#     bonafide_dirs=BONAFIDE_DIRS,
#     spoof_dir=SPOOF_AUDIO_DIR,
#     protocol_file=PROTOCOL_FILE,
#     output_csv_path=OUTPUT_CSV
# )

In [36]:
def calculate_eer(y_true, y_score):
    """
    Calculates the Equal Error Rate (EER).

    Args:
        y_true (np.array): True labels (0 or 1).
        y_score (np.array): Prediction scores or probabilities for the positive class.

    Returns:
        float: The EER value.
    """
    # Ensure y_true and y_score are numpy arrays
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    
    # The EER is the point on the ROC curve where the false acceptance rate (FAR)
    # and false rejection rate (FRR) are equal. FAR is just FPR. FRR is 1 - TPR.
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    
    return eer

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, num_classes):
    """
    Trains a PyTorch model and calculates EER on the validation set.
    """
    print(f"Training on {device}...")
    model.to(device)

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'val_eer': []
    }

    best_val_loss = float('inf')
    best_val_eer = float('inf')
    best_f1 = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        running_train_loss = 0.0
        correct_train_predictions = 0
        total_train_samples = 0
        train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)

        for cqcc, prosodic, labels in train_progress_bar:
            cqcc = cqcc.to(device)
            prosodic = prosodic.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(cqcc, prosodic)  # Logits
            loss = criterion(outputs, labels) ##for Cross Entropy
            # loss = criterion(outputs, labels.unsqueeze(1).float()) ## for BCE
            
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item() * cqcc.size(0)
            
            if num_classes == 1:  # Binary with BCEWithLogitsLoss
                predicted = (torch.sigmoid(outputs) > 0.3).squeeze().long()
            else:  # Multi-class or binary with CrossEntropyLoss
                _, predicted = torch.max(outputs.data, 1)
                
            total_train_samples += labels.size(0)
            correct_train_predictions += (predicted == labels.squeeze()).sum().item()
            train_progress_bar.set_postfix(loss=loss.item())

        epoch_train_loss = running_train_loss / total_train_samples if total_train_samples > 0 else 0
        epoch_train_acc = correct_train_predictions / total_train_samples if total_train_samples > 0 else 0

        # --- Validation Phase ---
        model.eval()
        running_val_loss = 0.0
        correct_val_predictions = 0
        total_val_samples = 0
        all_val_labels = []
        all_val_scores = []  # Scores for the positive class (bona fide)
        all_val_preds = []
        val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)

        with torch.no_grad():
            for cqcc, prosodic, labels in val_progress_bar:
                cqcc = cqcc.to(device)
                prosodic = prosodic.to(device)
                labels = labels.to(device)
                
                outputs = model(cqcc, prosodic)  # Logits
                loss = criterion(outputs, labels) # for CrossEntropy
                # loss = criterion(outputs, labels.unsqueeze(1).float()) ##for BCE
                running_val_loss += loss.item() * cqcc.size(0)

                if num_classes == 1:  # Binary with BCEWithLogitsLoss
                    scores = torch.sigmoid(outputs).squeeze() # Probabilities for bona fide
                    predicted = (scores > 0.3).long()
                else:  # Multi-class or binary with CrossEntropyLoss
                    # Assuming class 1 is 'bona fide' (positive class for EER)
                    # If your 'bona fide' class is 0, use outputs_probs[:, 0]
                    scores = F.softmax(outputs, dim=1)[:, 1] # Probabilities for class 1
                    _, predicted = torch.max(outputs.data, 1)
                
                total_val_samples += labels.size(0)
                correct_val_predictions += (predicted == labels.squeeze()).sum().item()

                all_val_preds.extend(predicted.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
                all_val_scores.extend(scores.cpu().numpy())

                val_progress_bar.set_postfix(loss=loss.item())

        epoch_val_loss = running_val_loss / total_val_samples if total_val_samples > 0 else 0
        epoch_val_acc = correct_val_predictions / total_val_samples if total_val_samples > 0 else 0
        
        # Calculate EER
        val_eer = float('nan')
        # EER can only be calculated if there are both positive and negative samples
        if len(np.unique(all_val_labels)) > 1:
            val_eer = calculate_eer(all_val_labels, all_val_scores)
        
        # --- Epoch End ---
        
        # Log metrics
        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)
        history['val_eer'].append(val_eer)
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} EER: {val_eer:.4f}"
              )
        
        # Print confusion matrix for the epoch
        cm = confusion_matrix(all_val_labels, all_val_preds)
        tn, fp, fn, tp= cm.ravel()
        f1score = f1_score(all_val_labels, all_val_preds)
        print(f"F1 Score: {f1score}")

        print(f"Confusion Matrix for Epoch {epoch+1}:\n{cm}", tp,tn,fp,fn)

        # Save the best model based on validation loss
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_epoch = epoch + 1
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Epoch {epoch+1}: New best model saved with Val Loss: {best_val_loss:.4f}\n")

        # Step the scheduler
        if scheduler:
            scheduler.step(epoch_val_loss)

    print(f"\nFinished Training. Best model was from Epoch {best_epoch} with Val Loss: {best_val_loss:.4f}")
    
    return pd.DataFrame(history)

In [37]:
import matplotlib.pyplot as plt

def plot_metrics(history):
    """
    Plots the training and validation loss, accuracy, and EER.

    Args:
        history (pd.DataFrame): A DataFrame containing the metrics from train_model.
    """
    fig, ax1 = plt.subplots(figsize=(12, 8))

    # --- Plotting Loss on the primary y-axis (ax1) ---
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color='tab:blue')
    ax1.plot(history.index + 1, history['train_loss'], 'b-', label='Train Loss')
    ax1.plot(history.index + 1, history['val_loss'], 'b--', label='Validation Loss')
    ax1.tick_params(axis='y', labelcolor='tab:blue')
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    # --- Creating a secondary y-axis for EER and Accuracy ---
    ax2 = ax1.twinx()
    ax2.set_ylabel('EER / Accuracy', color='tab:red')
    # Plot EER
    ax2.plot(history.index + 1, history['val_eer'], 'r-', label='Validation EER')
    # Plot Accuracy
    ax2.plot(history.index + 1, history['val_acc'], 'g--', label='Validation Accuracy')
    ax2.tick_params(axis='y', labelcolor='tab:red')
    # Set EER/Accuracy y-axis limits, e.g., from 0 to 1 if they are rates
    ax2.set_ylim(0, 1.05)


    # --- Final Touches ---
    fig.tight_layout()
    plt.title('Training and Validation Metrics per Epoch')
    
    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    plt.show()

In [38]:
import torch
from torch.utils.data import Dataset, DataLoader

class FusionDataset(Dataset):
    """Custom Dataset for loading CQCC, Prosodic features, and labels."""
    def __init__(self, cqcc_data, prosodic_data, labels):
        # Ensure data is in torch.Tensor format
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosodic_data = torch.tensor(prosodic_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        # Return the total number of samples
        return len(self.labels)

    def __getitem__(self, idx):
        # Retrieve one sample from the dataset
        cqcc_sample = self.cqcc_data[idx]
        prosodic_sample = self.prosodic_data[idx]
        label_sample = self.labels[idx]
        
        # The dataloader will stack these into a batch
        return cqcc_sample, prosodic_sample, label_sample

In [39]:
torch.tensor(np.load("processed_data/cqcc_features_val.npy")).shape

torch.Size([24844, 90, 157])

In [40]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# X_train_tensor = torch.tensor(np.load("processed_data/cqcc_features.npy"))
# y_train_tensor = prosodicFeaturesLabels
# X_val_tensor = torch.tensor(np.load("processed_data/cqcc_features_val.npy"))
# y_val_tensor = prosodicFeaturesLabels_val
# print(X_train_tensor.shape, y_train_tensor.shape)

# y_train_tensor.unique(return_counts=True)

In [44]:
import torch
import torch.nn as nn

class CNNnetwork(nn.Module):
    def __init__(self, input_features, input_time_steps, num_classes):
        super().__init__()
        # --- CNN Branch for CQCC Features ---
        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(16,32, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), stride=(1,1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # --- Dynamically calculate the flattened size ---
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, input_features, input_time_steps)
            dummy_output = self.CNN(dummy_input)
            self.flattened_size = dummy_output.view(1, -1).size(1)
        self.flatten = nn.Flatten()
        
        # --- Embedding Layer for the CNN Branch ---
        # This layer now correctly outputs a 128-dimensional embedding.
        # The final classification layer was removed from here.
        self.CQCC_embedding_layer = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # --- Embedding Layer for the Prosodic Branch ---
        self.ProsodicLayer = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32,32)
        )

        # --- Fusion and Classification Layer ---
        # This layer takes the concatenated embeddings (128 + 32 = 160)
        # and performs the final classification.
        self.fusionLayer = nn.Sequential(
            nn.Linear(128 + 32, 64), # Input is 160
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, cqcc_input, prosodic_input):
        # Process the CQCC features through the CNN branch
        # Input shape: (batch_size, features, time_steps)
        x = cqcc_input.unsqueeze(1) # Add channel dimension -> (batch_size, 1, features, time_steps)
        x = self.CNN(x)
        flattened_x = self.flatten(x)
        
        # Get the 128-dimensional embedding for CQCC
        cqcc_embedding = self.CQCC_embedding_layer(flattened_x) # Shape: (batch_size, 128)
        
        # Process the prosodic features through its branch
        # Get the 32-dimensional embedding for prosodic features
        prosodic_embedding = self.ProsodicLayer(prosodic_input) # Shape: (batch_size, 32)
        
        # Concatenate the embeddings from both branches
        concatenated_features = torch.cat([cqcc_embedding, prosodic_embedding], dim=1) # Shape: (batch_size, 160)
        
        # Pass the fused features to the final classification layer
        logits = self.fusionLayer(concatenated_features)
        
        return logits

In [45]:
prosodicFeatures_train

tensor([[1.2970e+02, 1.9315e+01, 1.9433e-02, 1.1098e-01, 9.4594e+00, 6.0118e+00],
        [1.5117e+02, 1.8884e+01, 1.3721e-02, 7.7546e-02, 1.5424e+01, 8.8525e+00],
        [2.4093e+02, 1.4894e+01, 7.4816e-03, 1.1253e-01, 1.3760e+01, 6.7782e+00],
        ...,
        [1.0943e+02, 2.1005e+01, 2.0272e-02, 1.6705e-01, 5.6863e+00, 5.1124e+00],
        [1.8079e+02, 1.7326e+01, 8.5521e-03, 9.7119e-02, 1.5088e+01, 9.3377e+00],
        [1.2182e+02, 1.3714e+01, 1.4913e-02, 1.4035e-01, 1.3382e+01, 7.0352e+00]])

In [46]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_tensor = torch.tensor(np.load("cqcc_features_aligned.npy"))
y_train_tensor = torch.tensor(np.load("labels_aligned.npy"))#prosodicFeaturesLabels
X_val_tensor = torch.tensor(np.load("processed_data/cqcc_features_val.npy"))
y_val_tensor = prosodicFeaturesLabels_val
X_test_tensor = torch.tensor(np.load("processed_data/cqcc_features_test.npy"))
y_test_tensor = prosodicFeaturesLabels_test
print(X_train_tensor.shape, y_train_tensor.shape)
prosodicFeatures_train = torch.tensor(np.load("prosody_features_aligned.npy"))
train_dataset = FusionDataset(X_train_tensor, prosodicFeatures_train, y_train_tensor)
val_dataset = FusionDataset(X_val_tensor, prosodicFeatures_val, y_val_tensor)
test_dataset = FusionDataset(X_test_tensor, prosodicFeatures_test, y_test_tensor)
# --- 4. Create TensorDatasets and DataLoaders ---
# train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
# prosodic_train_set = TensorDataset(prosodicFeatures, y_train_tensor)
# val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
# prosodic_val_set = TensorDataset(prosodicFeatures_val, y_train_tensor)



train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128
, shuffle=False) # No need to shuffle validation
test_loader = DataLoader(test_dataset, batch_size=32
, shuffle=False) # No need to shuffle validation

# --- 5. Initialize Model, Criterion (Loss Function), and Optimizer ---
# model = LSTMAttentionFusionNetwork(input_features=90, num_classes=2)
# model =AttentionFusionCNN(input_features=90,input_time_steps=157, num_classes=2)
model =CNNnetwork(input_features=90,input_time_steps=157, num_classes=2)
# 1. Calculate the number of samples in each class from the training data
# Ensure your labels are integers (0 and 1) for bincount
# class_counts = torch.bincount(y_train_tensor.long())
# count_class_0 = class_counts[0]
# count_class_1 = class_counts[1]

# print(f"Training data contains {count_class_0} samples of class 0 (negative)")
# print(f"Training data contains {count_class_1} samples of class 1 (positive)")

# 2. Calculate the pos_weight
# Handle the case where a class might be missing to avoid division by zero
# if count_class_1 > 0:
#     pos_weight_value = count_class_0.float() / count_class_1.float()
# else:
#     pos_weight_value = 1.0 # Default value if no positive samples

# print(f"Calculated pos_weight for BCEWithLogitsLoss: {pos_weight_value:.4f}")

# # 3. Create the weight tensor and move it to the correct device
# pos_weight = torch.tensor([pos_weight_value], device=device)


# num_samples = len(y_train_tensor)
# num_classes = 2
# class_counts = torch.bincount(y_train_tensor)

# weights = num_samples / (num_classes * class_counts.float())
# weights = weights.to(device) # Move to the correct device

# For multi-class (or binary if NUM_CLASSES=2 and labels are 0, 1), CrossEntropyLoss is common.
# It expects raw logits from the model and integer labels.
criterion = nn.CrossEntropyLoss()

# If NUM_CLASSES = 1 (binary with sigmoid output from model), you would use:
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) # if model outputs raw logits
#criterion = nn.BCELoss() # if model outputs sigmoid probabilities
# And your y_labels_np should be float32 and possibly reshaped for BCE losses.
# For this example, we stick to NUM_CLASSES=2 and CrossEntropyLoss.

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

scheduler = ReduceLROnPlateau(
    optimizer,
    mode = 'min',
    factor = 0.1,
    patience = 3,
    verbose = True
)

# --- 6. Start Training ---
history_df = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, 50, device, num_classes=2)



if not history_df.empty:
    plot_metrics(history_df)

torch.Size([25379, 90, 157]) torch.Size([25379])


  self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
  self.prosodic_data = torch.tensor(prosodic_data, dtype=torch.float32)
  self.labels = torch.tensor(labels, dtype=torch.long)


Training on cuda...


                                                                                   

Epoch 1/50 | Train Loss: 0.2765 Acc: 0.8942 | Val Loss: 4.0315 Acc: 0.0865 EER: 0.8862
F1 Score: 0.0069137531177525914
Confusion Matrix for Epoch 1:
[[ 2070   478]
 [22217    79]] 79 2070 478 22217
Epoch 1: New best model saved with Val Loss: 4.0315



                                                                                   

Epoch 2/50 | Train Loss: 0.1415 Acc: 0.9413 | Val Loss: 6.7068 Acc: 0.0581 EER: 0.9245
F1 Score: 0.0067908832392513056
Confusion Matrix for Epoch 2:
[[ 1363  1185]
 [22216    80]] 80 1363 1185 22216


                                                                                   

Epoch 3/50 | Train Loss: 0.0967 Acc: 0.9636 | Val Loss: 8.2029 Acc: 0.0330 EER: 0.9439
F1 Score: 0.010135970333745366
Confusion Matrix for Epoch 3:
[[  697  1851]
 [22173   123]] 123 697 1851 22173


                                                                                    

Epoch 4/50 | Train Loss: 0.0749 Acc: 0.9723 | Val Loss: 9.0178 Acc: 0.0378 EER: 0.9533
F1 Score: 0.003584827011254689
Confusion Matrix for Epoch 4:
[[  897  1651]
 [22253    43]] 43 897 1651 22253


                                                                                    

Epoch 5/50 | Train Loss: 0.0578 Acc: 0.9801 | Val Loss: 9.2888 Acc: 0.0291 EER: 0.9541
F1 Score: 0.004868388480897763
Confusion Matrix for Epoch 5:
[[  665  1883]
 [22237    59]] 59 665 1883 22237


                                                                                    

Epoch 6/50 | Train Loss: 0.0386 Acc: 0.9876 | Val Loss: 9.5558 Acc: 0.0222 EER: 0.9557
F1 Score: 0.0072742133224356345
Confusion Matrix for Epoch 6:
[[  463  2085]
 [22207    89]] 89 463 2085 22207


                                                                                    

Epoch 7/50 | Train Loss: 0.0361 Acc: 0.9891 | Val Loss: 9.7196 Acc: 0.0226 EER: 0.9571
F1 Score: 0.007195715278629543
Confusion Matrix for Epoch 7:
[[  473  2075]
 [22208    88]] 88 473 2075 22208


                                                                                    

Epoch 8/50 | Train Loss: 0.0338 Acc: 0.9904 | Val Loss: 10.2573 Acc: 0.0246 EER: 0.9560
F1 Score: 0.0046005339905524755
Confusion Matrix for Epoch 8:
[[  555  1993]
 [22240    56]] 56 555 1993 22240


                                                                                    

Epoch 9/50 | Train Loss: 0.0325 Acc: 0.9903 | Val Loss: 9.8572 Acc: 0.0216 EER: 0.9586
F1 Score: 0.007917721002367154
Confusion Matrix for Epoch 9:
[[  439  2109]
 [22199    97]] 97 439 2109 22199


                                                                                     

Epoch 10/50 | Train Loss: 0.0297 Acc: 0.9910 | Val Loss: 10.3319 Acc: 0.0241 EER: 0.9578
F1 Score: 0.005006565988181221
Confusion Matrix for Epoch 10:
[[  537  2011]
 [22235    61]] 61 537 2011 22235


                                                                                     

KeyboardInterrupt: 

In [None]:
def test_model(model, test_loader, criterion, device):
    print(f"\nTesting on {device}...")
    model.to(device)
    model.eval() # Set the model to evaluation mode

    running_test_loss = 0.0
    correct_test_predictions = 0
    total_test_samples = 0
    all_test_labels = []
    all_test_scores = []
    all_test_preds = []
    test_progress_bar = tqdm(test_loader, desc="[Test]", leave=True)

    with torch.no_grad():
        for cqcc, prosodic, labels in test_progress_bar:
            cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
            
            outputs = model(cqcc, prosodic)
            loss = criterion(outputs, labels)
            running_test_loss += loss.item() * cqcc.size(0)

            scores = F.softmax(outputs, dim=1)[:, 1] # Probability for class 1
            _, predicted = torch.max(outputs.data, 1)
            
            total_test_samples += labels.size(0)
            correct_test_predictions += (predicted == labels).sum().item()

            all_test_preds.extend(predicted.cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())
            all_test_scores.extend(scores.cpu().numpy())

    # --- Calculate final metrics ---
    final_test_loss = running_test_loss / total_test_samples if total_test_samples > 0 else 0
    final_test_acc = correct_test_predictions / total_test_samples if total_test_samples > 0 else 0
    final_f1 = f1_score(all_test_labels, all_test_preds)
    final_eer = float('nan')
    if len(np.unique(all_test_labels)) > 1:
        final_eer = calculate_eer(all_test_labels, all_test_scores)
    
    cm = confusion_matrix(all_test_labels, all_test_preds)

    print("\n--- Test Results ---")
    print(f"Test Loss: {final_test_loss:.4f}")
    print(f"Test Accuracy: {final_test_acc:.4f}")
    print(f"Test F1 Score: {final_f1:.4f}")
    print(f"Test EER: {final_eer:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    print("--------------------")

test_model_instance = CNNnetwork(input_features=90,input_time_steps=157, num_classes=2)
criterion = nn.CrossEntropyLoss()       
# Load the saved model weights
test_model_instance.load_state_dict(torch.load('best_model.pth', map_location=device))
test_model(test_model_instance, test_loader, criterion, device)

In [None]:
print(X_train_tensor[0], y_train_tensor[0])

In [None]:
def plot_confusion_matrix(cm, class_names, epoch, is_final=False):
    """
    Renders a confusion matrix using Seaborn's heatmap.

    Args:
        cm (numpy.ndarray): The confusion matrix.
        class_names (list): A list of class names for the labels.
        epoch (int): The current epoch number for the title.
        is_final (bool): If True, displays the plot. Otherwise, saves it.
    """
    figure = plt.figure(figsize=(8, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap=plt.cm.Blues, xticklabels=class_names, yticklabels=class_names)
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    
    if is_final:
        plt.title(f"Final Confusion Matrix (from Best Epoch {epoch})")
        print("Displaying final confusion matrix plot...")
        plt.show()
    else:
        # Avoid showing plots for every epoch during training
        # Instead, save them to a file if needed
        filename = f"confusion_matrix_epoch_{epoch}.png"
        plt.title(f"Confusion Matrix (Epoch {epoch})")
        plt.savefig(filename)
        print(f"Saved confusion matrix for epoch {epoch} to {filename}")
    plt.close()

In [None]:
# import torch
# import torch.nn as nn
# import math

# # The PositionalEncoding class remains the same as before
# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
#         super().__init__()
#         self.dropout = nn.Dropout(p=dropout)
#         position = torch.arange(max_len).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
#         pe = torch.zeros(max_len, 1, d_model)
#         pe[:, 0, 0::2] = torch.sin(position * div_term)
#         pe[:, 0, 1::2] = torch.cos(position * div_term)
#         self.register_buffer('pe', pe)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = x + self.pe[:x.size(0)]
#         return self.dropout(x)
# def calculate_eer_from_scores(labels, scores):
#     """Calculates EER from ground truth labels and classification scores."""
#     if not isinstance(labels, np.ndarray): labels = np.array(labels)
#     if not isinstance(scores, np.ndarray): scores = np.array(scores)

#     # Scores for bona fide (positive) and spoof (negative) trials
#     bona_fide_scores = scores[labels == 1]
#     spoof_scores = scores[labels == 0]

#     if len(bona_fide_scores) == 0 or len(spoof_scores) == 0:
#         return float('nan')

#     # Calculate FAR and FRR at different thresholds
#     min_score = min(np.min(bona_fide_scores), np.min(spoof_scores))
#     max_score = max(np.max(bona_fide_scores), np.max(spoof_scores))
#     thresholds = np.linspace(min_score, max_score, 1000)
    
#     far, frr = [], [] # False Acceptance Rate, False Rejection Rate
#     for t in thresholds:
#         far.append(np.sum(spoof_scores > t) / len(spoof_scores))
#         frr.append(np.sum(bona_fide_scores <= t) / len(bona_fide_scores))
    
#     far, frr = np.array(far), np.array(frr)

#     # Find the EER using interpolation
#     try:
#         eer_threshold = brentq(lambda x: interp1d(thresholds, far - frr)(x), min_score, max_score)
#         eer = interp1d(thresholds, far)(eer_threshold)
#     except (ValueError, RuntimeError):
#         eer_index = np.nanargmin(np.abs(far - frr))
#         eer = (far[eer_index] + frr[eer_index]) / 2.0
#     return eer

# # New model that handles both CQCC and Prosodic features
# class MultiModalClassifierTransformer(nn.Module):
#     # Model definition from previous response...
#     def __init__(self, cqcc_input_dim: int, prosodic_input_dim: int, num_classes: int,
#                  embedding_dim: int = 256, d_model: int = 256, nhead: int = 8, 
#                  num_encoder_layers: int = 6, dim_feedforward: int = 1024, dropout: float = 0.3):
#         super().__init__()
#         self.d_model = d_model
#         self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
#         self.cqcc_input_projection = nn.Linear(cqcc_input_dim, d_model)
#         self.pos_encoder = PositionalEncoding(d_model, dropout)
#         encoder_layer = nn.TransformerEncoderLayer(
#             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
#             dropout=dropout, batch_first=True
#         )
#         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
#         prosody_embedding_dim = 64
#         self.prosody_encoder = nn.Sequential(
#             nn.Linear(prosodic_input_dim, 128), nn.ReLU(),
#             nn.Linear(128, prosody_embedding_dim)
#         )
#         self.fusion_head = nn.Sequential(
#             nn.LayerNorm(d_model + prosody_embedding_dim),
#             nn.Linear(d_model + prosody_embedding_dim, embedding_dim)
#         )
#         self.classifier_head = nn.Linear(embedding_dim, num_classes)
        
#     def forward(self, cqcc_src: torch.Tensor, prosodic_src: torch.Tensor, return_embedding=False):
#         # *** FIX: Automatically correct CQCC input shape if necessary ***
#         # The model expects [batch, seq_len, features]. If the input is
#         # [batch, features, seq_len], we transpose it.
#         if cqcc_src.shape[-1] != self.cqcc_input_projection.in_features:
#             cqcc_src = cqcc_src.transpose(1, 2)

#         cqcc_proj = self.cqcc_input_projection(cqcc_src) * math.sqrt(self.d_model)
#         batch_size = cqcc_src.shape[0]
#         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
#         cqcc_with_cls = torch.cat([cls_tokens, cqcc_proj], dim=1)
#         cqcc_with_cls = self.pos_encoder(cqcc_with_cls.permute(1, 0, 2)).permute(1, 0, 2)
#         transformer_output = self.transformer_encoder(cqcc_with_cls)
#         cqcc_embedding = transformer_output[:, 0]
#         prosody_embedding = self.prosody_encoder(prosodic_src)
#         combined_embedding = torch.cat([cqcc_embedding, prosody_embedding], dim=1)
#         final_embedding = self.fusion_head(combined_embedding)
#         logits = self.classifier_head(final_embedding)
#         if return_embedding:
#             return logits, final_embedding
#         return logits


# # def train_classifier_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, num_classes):
# #     # Training loop definition from previous response...
# #     print(f"Training on {device}...")
# #     model.to(device)
# #     history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_eer': []}
# #     best_val_loss = float('inf')
# #     best_epoch = 0
# #     for epoch in range(num_epochs):
# #         model.train()
# #         running_train_loss, correct_train_predictions, total_train_samples = 0.0, 0, 0
# #         train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
# #         for cqcc, prosodic, labels in train_progress_bar:
# #             cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
# #             optimizer.zero_grad()
# #             outputs = model(cqcc, prosodic)
# #             if num_classes == 1: outputs = outputs.squeeze(1)
# #             loss = criterion(outputs, labels.float() if num_classes == 1 else labels)
# #             loss.backward()
# #             optimizer.step()
# #             running_train_loss += loss.item() * cqcc.size(0)
# #             if num_classes == 1: predicted = (torch.sigmoid(outputs) > 0.5).long()
# #             else: _, predicted = torch.max(outputs.data, 1)
# #             total_train_samples += labels.size(0)
# #             correct_train_predictions += (predicted == labels).sum().item()
# #             train_progress_bar.set_postfix(loss=f"{loss.item():.4f}")
# #         epoch_train_loss = running_train_loss / total_train_samples
# #         epoch_train_acc = correct_train_predictions / total_train_samples
# #         history['train_loss'].append(epoch_train_loss)
# #         history['train_acc'].append(epoch_train_acc)
# #         model.eval()
# #         running_val_loss, correct_val_predictions, total_val_samples = 0.0, 0, 0
# #         all_val_labels, all_val_scores = [], []
# #         val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
# #         with torch.no_grad():
# #             for cqcc, prosodic, labels in val_progress_bar:
# #                 cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
# #                 outputs = model(cqcc, prosodic)
# #                 if num_classes == 1: outputs = outputs.squeeze(1)
# #                 loss = criterion(outputs, labels.float() if num_classes == 1 else labels)
# #                 running_val_loss += loss.item() * cqcc.size(0)
# #                 if num_classes == 1:
# #                     scores = torch.sigmoid(outputs)
# #                     predicted = (scores > 0.5).long()
# #                 else:
# #                     scores = F.softmax(outputs, dim=1)[:, 1]
# #                     _, predicted = torch.max(outputs.data, 1)
# #                 total_val_samples += labels.size(0)
# #                 correct_val_predictions += (predicted == labels).sum().item()
# #                 all_val_labels.extend(labels.cpu().numpy())
# #                 all_val_scores.extend(scores.cpu().numpy())
# #         epoch_val_loss = running_val_loss / total_val_samples
# #         epoch_val_acc = correct_val_predictions / total_val_samples
# #         val_eer = calculate_eer_from_scores(all_val_labels, all_val_scores)
# #         history['val_loss'].append(epoch_val_loss)
# #         history['val_acc'].append(epoch_val_acc)
# #         history['val_eer'].append(val_eer)
# #         print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} EER: {val_eer:.4f}")
# #         if epoch_val_loss < best_val_loss:
# #             best_val_loss = epoch_val_loss
# #             best_epoch = epoch + 1
# #             torch.save(model.state_dict(), 'best_classifier_model.pth')
# #             print(f"Epoch {epoch+1}: New best model saved with Val Loss: {best_val_loss:.4f}")
# #         if scheduler:
# #             scheduler.step(epoch_val_loss)
# #     print(f"\nFinished Training. Best model from Epoch {best_epoch} with Val Loss: {best_val_loss:.4f}")
# #     return pd.DataFrame(history)

# def train_classifier_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, num_classes):
#     """
#     Trains a classifier model and evaluates it, including confusion matrix calculation.
#     """
#     print(f"Training on {device}...")
#     model.to(device)
#     history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_eer': []}
#     best_val_loss = float('inf')
#     best_epoch = 0

#     # Define class names for the confusion matrix plot
#     if num_classes == 1:
#         class_names = ['bona fide', 'spoof'] # Assuming binary classification for anti-spoofing
#     else:
#         class_names = [str(i) for i in range(num_classes)]

#     for epoch in range(num_epochs):
#         model.train()
#         running_train_loss, correct_train_predictions, total_train_samples = 0.0, 0, 0
#         train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
#         for cqcc, prosodic, labels in train_progress_bar:
#             cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(cqcc, prosodic)
#             if num_classes == 1: outputs = outputs.squeeze(1)
#             loss = criterion(outputs, labels.float() if num_classes == 1 else labels)
#             loss.backward()
#             optimizer.step()
#             running_train_loss += loss.item() * cqcc.size(0)
#             if num_classes == 1: predicted = (torch.sigmoid(outputs) > 0.5).long()
#             else: _, predicted = torch.max(outputs.data, 1)
#             total_train_samples += labels.size(0)
#             correct_train_predictions += (predicted == labels).sum().item()
#             train_progress_bar.set_postfix(loss=f"{loss.item():.4f}")
            
#         epoch_train_loss = running_train_loss / total_train_samples
#         epoch_train_acc = correct_train_predictions / total_train_samples
#         history['train_loss'].append(epoch_train_loss)
#         history['train_acc'].append(epoch_train_acc)

#         # --- Validation Phase ---
#         model.eval()
#         running_val_loss, correct_val_predictions, total_val_samples = 0.0, 0, 0
#         all_val_labels, all_val_scores = [], []
        
#         # --- ADDED: List to store all validation predictions for the confusion matrix ---
#         all_val_preds = []

#         val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
#         with torch.no_grad():
#             for cqcc, prosodic, labels in val_progress_bar:
#                 cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
#                 outputs = model(cqcc, prosodic)
#                 if num_classes == 1: outputs = outputs.squeeze(1)
#                 loss = criterion(outputs, labels.float() if num_classes == 1 else labels)
#                 running_val_loss += loss.item() * cqcc.size(0)

#                 if num_classes == 1:
#                     scores = torch.sigmoid(outputs)
#                     predicted = (scores > 0.5).long()
#                 else:
#                     scores = F.softmax(outputs, dim=1)[:, 1] # Note: This assumes class 1 is the positive class for EER
#                     _, predicted = torch.max(outputs.data, 1)

#                 total_val_samples += labels.size(0)
#                 correct_val_predictions += (predicted == labels).sum().item()
#                 all_val_labels.extend(labels.cpu().numpy())
#                 all_val_scores.extend(scores.cpu().numpy())
                
#                 # --- ADDED: Store predictions for the confusion matrix ---
#                 all_val_preds.extend(predicted.cpu().numpy())
                
#         epoch_val_loss = running_val_loss / total_val_samples
#         epoch_val_acc = correct_val_predictions / total_val_samples
#         val_eer = calculate_eer_from_scores(all_val_labels, all_val_scores)
#         history['val_loss'].append(epoch_val_loss)
#         history['val_acc'].append(epoch_val_acc)
#         history['val_eer'].append(val_eer)

#         # --- ADDED: Calculate and print the confusion matrix for the current epoch ---
#         cm = confusion_matrix(all_val_labels, all_val_preds)
#         print(f"\nEpoch {epoch+1}/{num_epochs} | Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} EER: {val_eer:.4f}")
#         print(f"Confusion Matrix for Epoch {epoch+1}:\n{cm}")

#         if epoch_val_loss < best_val_loss:
#             best_val_loss = epoch_val_loss
#             best_epoch = epoch + 1
#             torch.save(model.state_dict(), 'best_classifier_model.pth')
#             print(f"Epoch {epoch+1}: New best model saved with Val Loss: {best_val_loss:.4f}")
            
#         if scheduler:
#             scheduler.step(epoch_val_loss)

#     print(f"\nFinished Training. Best model from Epoch {best_epoch} with Val Loss: {best_val_loss:.4f}")

#     # --- ADDED: Final Evaluation and Confusion Matrix Plot for the Best Model ---
#     print("\n--- Generating Final Report for the Best Model ---")
#     best_model = model
#     best_model.load_state_dict(torch.load('best_classifier_model.pth'))
#     best_model.to(device)
#     best_model.eval()

#     final_labels, final_preds = [], []
#     with torch.no_grad():
#         for cqcc, prosodic, labels in val_loader:
#             cqcc, prosodic, labels = cqcc.to(device), prosodic.to(device), labels.to(device)
#             outputs = best_model(cqcc, prosodic)
#             if num_classes == 1:
#                 predicted = (torch.sigmoid(outputs.squeeze(1)) > 0.5).long()
#             else:
#                 _, predicted = torch.max(outputs.data, 1)
#             final_labels.extend(labels.cpu().numpy())
#             final_preds.extend(predicted.cpu().numpy())

#     final_cm = confusion_matrix(final_labels, final_preds)
#     plot_confusion_matrix(final_cm, class_names=class_names, epoch=best_epoch, is_final=True)
    
#     return pd.DataFrame(history)

# #==============================================================================
# #  MAIN EXECUTION SCRIPT (HELPER FUNCTION TO RUN)
# #==============================================================================
# def run_training():
#     """
#     A helper function to demonstrate how to set up and run the training process.
#     """
#     # --- Configuration ---
#     # Model Hyperparameters
#     CQCC_DIM = 90
#     PROSODIC_DIM = 6
#     SEQ_LEN = 157
#     EMBEDDING_DIM = 128
#     NUM_CLASSES = 2 # Assuming binary classification: bona fide (1) vs. spoof (0)
#     D_MODEL = 256
#     N_HEAD = 8
#     N_LAYERS = 4
    
#     # Training Hyperparameters
#     NUM_EPOCHS = 50 # Keep it short for a demo run
#     LEARNING_RATE = 1e-4
#     BATCH_SIZE = 32
#     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
#     # --- Setup for Strategy 2: Classification ---
#     print("--- Running Training with Classification Strategy ---")
    
#     # 1. Instantiate Model
#     # Use the classifier version of the model
#     model = MultiModalClassifierTransformer(
#         cqcc_input_dim=CQCC_DIM,
#         prosodic_input_dim=PROSODIC_DIM,
#         num_classes=NUM_CLASSES if NUM_CLASSES > 1 else 1, # Output 1 logit for binary case
#         embedding_dim=EMBEDDING_DIM,
#         d_model=D_MODEL,
#         nhead=N_HEAD,
#         num_encoder_layers=N_LAYERS
#     )
    
#     # 2. Instantiate Loss Function
#     # Use BCEWithLogitsLoss for binary classification, CrossEntropy for multi-class
#     if NUM_CLASSES == 1 or NUM_CLASSES == 2:
#         criterion = nn.BCEWithLogitsLoss() if NUM_CLASSES == 1 else nn.CrossEntropyLoss()
#     else:
#         criterion = nn.CrossEntropyLoss()

#     # 3. Instantiate Optimizer and Scheduler
#     optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True)

#     # 4. Run Training
#     history_df = train_classifier_model(
#         model=model,
#         train_loader=train_loader,
#         val_loader=val_loader,
#         criterion=criterion,
#         optimizer=optimizer,
#         scheduler=scheduler,
#         num_epochs=NUM_EPOCHS,
#         device=DEVICE,
#         num_classes=NUM_CLASSES if NUM_CLASSES > 1 else 1
#     )
    
#     print("\nTraining History:")
#     print(history_df)

# if __name__ == '__main__':
#     run_training()

In [None]:
if not history_df.empty:
    plot_metrics(history_df)

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data/cqcc_features.npy"
PROSODIC_FEATURES_TRAIN_CSV_PATH = "processed_data/prosodic_features_and_labels.csv"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data/cqcc_features_val.npy"
PROSODIC_FEATURES_VAL_CSV_PATH = "processed_data/prosodic_features_and_labels_val.csv"

# --- Model and Training Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-4

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """
    Calculates the Equal Error Rate (EER) from the true labels and scores.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset for handling 2D CQCC and 1D prosodic features."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        # --- CQCC Branch (Conv2D) ---
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)

        # --- Prosodic Branch ---
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)

        # --- Fusion and Classifier ---
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        
        prosody_df_train = pd.read_csv(PROSODIC_FEATURES_TRAIN_CSV_PATH)
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        
        prosody_df_val = pd.read_csv(PROSODIC_FEATURES_VAL_CSV_PATH)
        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        
        if len(prosody_df_train) != len(X_cqcc_train) or len(prosody_df_val) != len(X_cqcc_val):
            raise ValueError("Sample count mismatch between CSV and .npy files.")

        feature_columns = ['mean_f0', 'std_f0', 'jitter', 'shimmer', 'mean_hnr', 'std_hnr']
        
        X_prosody_train = prosody_df_train[feature_columns].values
        y_train = prosody_df_train['label'].values
        
        X_prosody_val = prosody_df_val[feature_columns].values
        y_val = prosody_df_val['label'].values
        
        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
        
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("--- Scaling Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)

    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = [] # Collect raw scores for EER
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                
                outputs = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        # Convert lists to numpy arrays for metric calculations
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:")
        print(cm)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data/cqcc_features.npy"
PROSODIC_FEATURES_TRAIN_CSV_PATH = "processed_data/prosodic_features_and_labels.csv"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data/cqcc_features_val.npy"
PROSODIC_FEATURES_VAL_CSV_PATH = "processed_data/prosodic_features_and_labels_val.csv"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_CSV_PATH = "processed_data/prosodic_features_and_labels_test.csv"

# --- Model and Training Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 256
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    # Plotting losses on the primary y-axis
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    # Create a second y-axis for EER
    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        prosody_df_train = pd.read_csv(PROSODIC_FEATURES_TRAIN_CSV_PATH)
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        prosody_df_val = pd.read_csv(PROSODIC_FEATURES_VAL_CSV_PATH)
        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        
        if len(prosody_df_train) != len(X_cqcc_train) or len(prosody_df_val) != len(X_cqcc_val):
            raise ValueError("Sample count mismatch between CSV and .npy files.")

        feature_columns = ['mean_f0', 'std_f0', 'jitter', 'shimmer', 'mean_hnr', 'std_hnr']
        X_prosody_train = prosody_df_train[feature_columns].values
        y_train = prosody_df_train['label'].values
        X_prosody_val = prosody_df_val[feature_columns].values
        y_val = prosody_df_val['label'].values
        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("--- Scaling Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        # Store history for plotting
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # Plot and save the training history
CQCC_FEATURES_TEST_PATH = "processed_data/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_CSV_PATH = "processed_data/prosodic_features_and_labels_test.csv"


print("\n--- Starting Final Testing ---")
try:
    # 1. Load the full test dataset
    print("Loading full test data...")
    prosody_df_test_full = pd.read_csv(PROSODIC_FEATURES_TEST_CSV_PATH)
    X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
    X_prosody_test_full = prosody_df_test_full[feature_columns].values
    y_test_full = prosody_df_test_full['label'].values
    print(f"Loaded {len(y_test_full)} total test samples.")

    # 2. Create a balanced 70,000-sample subset using stratified sampling
    NUM_SAMPLES_TO_SELECT = 71200
    print(f"Creating a balanced subset of {NUM_SAMPLES_TO_SELECT} samples...")

    # Use train_test_split to perform a single stratified split to get indices
    _, _, _, _, _, selected_indices = train_test_split(
        X_cqcc_test_full,
        y_test_full,
        np.arange(len(y_test_full)), # Pass indices to get a split of indices
        test_size=NUM_SAMPLES_TO_SELECT,
        stratify=y_test_full,
        random_state=42 # Ensures reproducibility
    )

    # Use the selected indices to create the subset
    X_cqcc_test_subset = X_cqcc_test_full[selected_indices]
    X_prosody_test_subset = X_prosody_test_full[selected_indices]
    y_test_subset = y_test_full[selected_indices]
    
    print(f"Balanced subset created. Class distribution: {np.bincount(y_test_subset) / len(y_test_subset)}")

    # 3. Scale the test subset using the *already-fitted* scalers
    X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_subset)
    nsamples_test, nx_test, ny_test = X_cqcc_test_subset.shape
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_subset.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
    
    # 4. Create Test Dataset and DataLoader
    test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test_subset)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # 5. Load the best model and evaluate
    print("Loading best model for testing...")
    test_model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    test_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
    test_model.eval()

    all_test_labels = []
    all_test_scores = []
    with torch.no_grad():
        for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Testing"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            outputs = test_model(cqcc_batch, prosody_batch)
            all_test_scores.extend(outputs.cpu().numpy())
            all_test_labels.extend(labels_batch.cpu().numpy())
    
    # 6. Calculate and display final metrics
    all_test_labels = np.array(all_test_labels)
    all_test_scores = np.array(all_test_scores).flatten()
    all_test_preds = (all_test_scores > 0.5).astype(int)

    test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
    test_f1 = f1_score(all_test_labels, all_test_preds)
    test_eer = calculate_eer(all_test_labels, all_test_scores)
    test_cm = confusion_matrix(all_test_labels, all_test_preds)

    print("\n--- Final Test Results ---")
    print(f"Accuracy: {test_accuracy:.2f}%")
    print(f"F1-Score: {test_f1:.4f}")
    print(f"EER: {test_eer:.2f}%")
    print("Confusion Matrix:\n", test_cm)

except (FileNotFoundError, ValueError) as e:
    print(f"Error during testing: {e}")
    print("Please ensure your test data files are in the correct paths and format.")
    
    

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import math

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data/cqcc_features.npy"
PROSODIC_FEATURES_TRAIN_CSV_PATH = "processed_data/prosodic_features_and_labels.csv"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data/cqcc_features_val.npy"
PROSODIC_FEATURES_VAL_CSV_PATH = "processed_data/prosodic_features_and_labels_val.csv"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_CSV_PATH = "processed_data/prosodic_features_and_labels_test.csv"

# --- Model and Training Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionTransformer_Best.pth" # Updated model name
PLOT_SAVE_PATH = "saved_models/training_metrics_transformer.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64 # Transformers can be memory-intensive, a smaller batch size might be needed
EPOCHS = 40
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# --- NEW: Transformer Hyperparameters ---
D_MODEL = 128       # Embedding dimension
N_HEAD = 8          # Number of attention heads (must be a divisor of D_MODEL)
NUM_ENCODER_LAYERS = 4 # Number of stacked transformer layers
DIM_FEEDFORWARD = 512 # Hidden dimension in the feed-forward network
TRANSFORMER_DROPOUT = 0.1

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionTransformer(nn.Module):
    """
    A Transformer-based model to classify audio features.
    It fuses prosodic features (as a CLS token) with a sequence of CQCC features.
    """
    def __init__(self, cqcc_input_shape, prosodic_features, d_model, nhead, num_encoder_layers, dim_feedforward, dropout):
        super(AttentionFusionTransformer, self).__init__()
        
        num_cqcc_coeffs, num_frames = cqcc_input_shape
        
        # 1. Projection layers to create embeddings
        self.cqcc_projection = nn.Linear(num_cqcc_coeffs, d_model)
        self.prosody_projection = nn.Linear(prosodic_features, d_model)
        
        # 2. Positional Encoding
        self.pos_encoder = nn.Parameter(torch.zeros(1, num_frames + 1, d_model))
        
        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
            batch_first=True  # Important!
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # 4. Classifier Head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC Input shape: (batch, coeffs, frames)
        # Prosody Input shape: (batch, num_features)
        
        # Permute CQCC to treat frames as sequence: (batch, frames, coeffs)
        cqcc_x = cqcc_x.permute(0, 2, 1)

        # Project features to the embedding dimension (d_model)
        cqcc_embed = self.cqcc_projection(cqcc_x)  # -> (batch, frames, d_model)
        prosody_embed = self.prosody_projection(prosody_x).unsqueeze(1) # -> (batch, 1, d_model)
        
        # Prepend the prosody embedding as the [CLS] token
        # This token will act as the aggregate representation for classification
        full_sequence = torch.cat([prosody_embed, cqcc_embed], dim=1) # -> (batch, frames+1, d_model)
        
        # Add positional encoding
        full_sequence_with_pos = full_sequence + self.pos_encoder
        
        # Feed into the Transformer Encoder
        transformer_out = self.transformer_encoder(full_sequence_with_pos) # -> (batch, frames+1, d_model)
        
        # We only use the output of the [CLS] token (the first token) for classification
        cls_token_out = transformer_out[:, 0, :] # -> (batch, d_model)
        
        # Final classification
        logits = self.classifier(cls_token_out)
        output = torch.sigmoid(logits)
        
        return output


if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        prosody_df_train = pd.read_csv(PROSODIC_FEATURES_TRAIN_CSV_PATH)
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        prosody_df_val = pd.read_csv(PROSODIC_FEATURES_VAL_CSV_PATH)
        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        
        if len(prosody_df_train) != len(X_cqcc_train) or len(prosody_df_val) != len(X_cqcc_val):
            raise ValueError("Sample count mismatch between CSV and .npy files.")

        feature_columns = ['mean_f0', 'std_f0', 'jitter', 'shimmer', 'mean_hnr', 'std_hnr']
        X_prosody_train = prosody_df_train[feature_columns].values
        y_train = prosody_df_train['label'].values
        X_prosody_val = prosody_df_val[feature_columns].values
        y_val = prosody_df_val['label'].values
        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("--- Scaling Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # --- MODEL INSTANTIATION CHANGED ---
    # model = AttentionFusionCNN(...) # Old model commented out
    model = AttentionFusionTransformer(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1],
        d_model=D_MODEL,
        nhead=N_HEAD,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=TRANSFORMER_DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")
    
    plot_training_history(history, PLOT_SAVE_PATH)

    # --- Testing Loop ---
    print("\n--- Starting Final Testing ---")
    try:
        print("Loading full test data...")
        prosody_df_test_full = pd.read_csv(PROSODIC_FEATURES_TEST_CSV_PATH)
        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_full = prosody_df_test_full[feature_columns].values
        y_test_full = prosody_df_test_full['label'].values
        print(f"Loaded {len(y_test_full)} total test samples.")

        NUM_SAMPLES_TO_SELECT = 70000
        print(f"Creating a balanced subset of {NUM_SAMPLES_TO_SELECT} samples...")
        
        _, _, _, _, _, selected_indices = train_test_split(
            X_cqcc_test_full, y_test_full, np.arange(len(y_test_full)),
            test_size=NUM_SAMPLES_TO_SELECT,
            stratify=y_test_full,
            random_state=42
        )

        X_cqcc_test_subset = X_cqcc_test_full[selected_indices]
        X_prosody_test_subset = X_prosody_test_full[selected_indices]
        y_test_subset = y_test_full[selected_indices]
        
        print(f"Balanced subset created. Class distribution: {np.bincount(y_test_subset) / len(y_test_subset)}")

        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_subset)
        nsamples_test, nx_test, ny_test = X_cqcc_test_subset.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_subset.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test_subset)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        print("Loading best model for testing...")
        test_model = AttentionFusionTransformer(
            cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
            prosodic_features=X_prosody_train.shape[1],
            d_model=D_MODEL,
            nhead=N_HEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            dim_feedforward=DIM_FEEDFORWARD,
            dropout=TRANSFORMER_DROPOUT
        ).to(DEVICE)
        test_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        test_model.eval()

        all_test_labels = []
        all_test_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Testing"):
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = test_model(cqcc_batch, prosody_batch)
                all_test_scores.extend(outputs.cpu().numpy())
                all_test_labels.extend(labels_batch.cpu().numpy())
        
        all_test_labels = np.array(all_test_labels)
        all_test_scores = np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)

        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)

        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}%")
        print(f"F1-Score: {test_f1:.4f}")
        print(f"EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during testing: {e}")
        print("Please ensure your test data files are in the correct paths and format.")

In [None]:
# # --- Testing Loop ---
# print("\n--- Starting Final Testing ---")
# try:
#     print("Loading full test data...")
#     prosody_df_test_full = pd.read_csv(PROSODIC_FEATURES_TEST_CSV_PATH)
#     X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
#     X_prosody_test_full = prosody_df_test_full[feature_columns].values
#     y_test_full = prosody_df_test_full['label'].values
#     print(f"Loaded {len(y_test_full)} total test samples.")

#     NUM_SAMPLES_TO_SELECT = 70000
#     print(f"Creating a balanced subset of {NUM_SAMPLES_TO_SELECT} samples...")
    
#     _, _, _, _, _, selected_indices = train_test_split(
#         X_cqcc_test_full, y_test_full, np.arange(len(y_test_full)),
#         test_size=NUM_SAMPLES_TO_SELECT,
#         stratify=y_test_full,
#         random_state=42
#     )

#     X_cqcc_test_subset = X_cqcc_test_full[selected_indices]
#     X_prosody_test_subset = X_prosody_test_full[selected_indices]
#     y_test_subset = y_test_full[selected_indices]
    
#     print(f"Balanced subset created. Class distribution: {np.bincount(y_test_subset) / len(y_test_subset)}")

#     X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_subset)
#     nsamples_test, nx_test, ny_test = X_cqcc_test_subset.shape
#     X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_subset.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
    
#     test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test_subset)
#     test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
#     print("Loading best model for testing...")
#     test_model = AttentionFusionTransformer(
#         cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
#         prosodic_features=X_prosody_train.shape[1],
#         d_model=D_MODEL,
#         nhead=N_HEAD,
#         num_encoder_layers=NUM_ENCODER_LAYERS,
#         dim_feedforward=DIM_FEEDFORWARD,
#         dropout=TRANSFORMER_DROPOUT
#     ).to(DEVICE)
#     test_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
#     test_model.eval()

#     all_test_labels = []
#     all_test_scores = []
#     with torch.no_grad():
#         for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Testing"):
#             cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
#             outputs = test_model(cqcc_batch, prosody_batch)
#             all_test_scores.extend(outputs.cpu().numpy())
#             all_test_labels.extend(labels_batch.cpu().numpy())
    
#     all_test_labels = np.array(all_test_labels)
#     all_test_scores = np.array(all_test_scores).flatten()
#     all_test_preds = (all_test_scores > 0.5).astype(int)

#     test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
#     test_f1 = f1_score(all_test_labels, all_test_preds)
#     test_eer = calculate_eer(all_test_labels, all_test_scores)
#     test_cm = confusion_matrix(all_test_labels, all_test_preds)

#     print("\n--- Final Test Results ---")
#     print(f"Accuracy: {test_accuracy:.2f}%")
#     print(f"F1-Score: {test_f1:.4f}")
#     print(f"EER: {test_eer:.2f}%")
#     print("Confusion Matrix:\n", test_cm)

# except (FileNotFoundError, ValueError) as e:
#     print(f"Error during testing: {e}")
#     print("Please ensure your test data files are in the correct paths and format.")

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "lstm_cross_attention_model_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 20 # Increased for better convergence visualization
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128

# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    # Plotting losses on the primary y-axis (left)
    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Create a second y-axis for Accuracy (right)
    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    # Create a third y-axis for EER (right, further out)
    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60)) # Offset the third axis
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    # Combine legends from all axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96]) # Adjust layout to make room for legend
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset for the fusion model."""
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. LSTM CROSS-ATTENTION MODEL DEFINITION ---

class LSTMCrossAttentionFusion(nn.Module):
    """
    Fuses CQCC features (processed by a CNN) with eGeMAPS LLDs (processed by an LSTM)
    using cross-modal attention.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, embedding_dim):
        super(LSTMCrossAttentionFusion, self).__init__()
        
        self.cqcc_cnn = nn.Sequential(
            nn.Conv1d(cqcc_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64, embedding_dim, kernel_size=3, padding=1),
        )

        self.lstm = nn.LSTM(
            input_size=egmaps_features,
            hidden_size=embedding_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 # Added dropout to LSTM
        )
        self.lstm_fc = nn.Linear(embedding_dim * 2, embedding_dim)
        
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=4,
            batch_first=True,
            dropout=0.2 # Added dropout to attention
        )

        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        # Input shapes are (batch, features, time)
        cqcc_out_cnn = self.cqcc_cnn(cqcc_x).transpose(1, 2) # -> (batch, time, embed_dim)
        
        egmaps_x_seq = egmaps_x.transpose(1, 2) # -> (batch, time, features)
        lstm_out, _ = self.lstm(egmaps_x_seq)
        prosody_query = torch.tanh(self.lstm_fc(lstm_out)) # -> (batch, time, embed_dim)
        
        attended_output, _ = self.cross_attention(
            query=prosody_query, key=cqcc_out_cnn, value=cqcc_out_cnn
        )
        
        pooled_output = attended_output.mean(dim=1)
        output = self.classifier(pooled_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = LSTMCrossAttentionFusion(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: LSTM Cross-Attention Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        # --- Validation Phase ---
        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        # --- Calculate and Store Metrics ---
        # Training metrics
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        # Validation metrics
        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        # Log history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss) # Step scheduler on validation loss
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_lstm_cross_attention_model1.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history.png"))



In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "transformer_encoder_decoder_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64 # Transformers can be memory intensive
EPOCHS = 25
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 512 # d_model for the transformer
NUM_HEADS = 8       # Number of attention heads
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
DROPOUT = 0.3

# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60))
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. TRANSFORMER ENCODER-DECODER MODEL ---

class PositionalEncoding(nn.Module):
    """Adds positional information to the input embeddings."""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerEncoderDecoder(nn.Module):
    """
    Fuses CQCC and eGeMAPS LLDs using a Transformer Encoder-Decoder architecture.
    Encoder processes CQCCs. Decoder processes eGeMAPS and attends to the encoder output.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, d_model, nhead, num_encoder_layers, num_decoder_layers, dropout):
        super(TransformerEncoderDecoder, self).__init__()
        
        self.d_model = d_model
        
        # --- Feature Projection ---
        self.cqcc_projection = nn.Linear(cqcc_features, d_model)
        self.egmaps_projection = nn.Linear(egmaps_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # --- Transformer ---
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )

        # --- Classifier ---
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        # Input shapes are (batch, features, time)
        # Transpose to (batch, time, features) for sequence processing
        cqcc_x = cqcc_x.transpose(1, 2)
        egmaps_x = egmaps_x.transpose(1, 2)
        
        # 1. Project features to the embedding dimension (d_model)
        cqcc_embed = self.cqcc_projection(cqcc_x)
        egmaps_embed = self.egmaps_projection(egmaps_x)

        # 2. Add positional encoding
        # PyTorch Transformer expects (seq_len, batch, features), so we transpose
        cqcc_embed = self.pos_encoder(cqcc_embed.transpose(0, 1))
        egmaps_embed = self.pos_encoder(egmaps_embed.transpose(0, 1))
        
        # 3. Pass through the Transformer
        # Encoder gets CQCCs (src), Decoder gets eGeMAPS (tgt) and attends to CQCCs (memory)
        transformer_out = self.transformer(src=cqcc_embed, tgt=egmaps_embed)
        
        # Transpose back to (batch, time, features)
        transformer_out = transformer_out.transpose(0, 1)

        # 4. Pool and Classify
        pooled_output = transformer_out.mean(dim=1)
        output = self.classifier(pooled_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = TransformerEncoderDecoder(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: Transformer Encoder-Decoder Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss)
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_transformer_enc_dec_model.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history_transformer.png"))


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import joblib

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "single_stream_transformer_output")
MODEL_OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "single_stream_transformer_output") 
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32 # Transformers can be memory intensive
EPOCHS = 25
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128 # d_model for the transformer
NUM_HEADS = 8       # Number of attention heads
NUM_ENCODER_LAYERS = 6 # Can use a deeper single encoder
DROPOUT = 0.2

# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60))
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. SINGLE-STREAM FUSION TRANSFORMER MODEL ---

class SingleStreamFusionTransformer(nn.Module):
    """
    Fuses CQCC and eGeMAPS by concatenating them into a single sequence
    and feeding them to a single Transformer Encoder. Uses token-type embeddings
    to distinguish between the two modalities.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, d_model, nhead, num_encoder_layers, dropout):
        super(SingleStreamFusionTransformer, self).__init__()
        
        self.d_model = d_model
        
        # --- Feature Projection ---
        self.cqcc_projection = nn.Linear(cqcc_features, d_model)
        self.egmaps_projection = nn.Linear(egmaps_features, d_model)
        
        # --- Special Tokens and Embeddings ---
        # CLS token will be learned
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # Token type embeddings to distinguish CQCC from eGeMAPS
        self.token_type_embeddings = nn.Embedding(num_embeddings=2, embedding_dim=d_model) # 0 for CQCC, 1 for eGeMAPS
        
        # Positional encodings will be added to the combined sequence
        # Max length is 1 (CLS) + time_steps (CQCC) + time_steps (eGeMAPS)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1 + time_steps * 2, d_model))

        # --- Transformer Encoder ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # --- Classifier ---
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        # Input shapes are (batch, features, time)
        # Transpose to (batch, time, features) for sequence processing
        cqcc_x = cqcc_x.transpose(1, 2)
        egmaps_x = egmaps_x.transpose(1, 2)
        
        batch_size = cqcc_x.size(0)
        time_steps = cqcc_x.size(1) # Get sequence length
        
        # 1. Project features to the embedding dimension (d_model)
        cqcc_embed = self.cqcc_projection(cqcc_x)      # (batch, time, d_model)
        egmaps_embed = self.egmaps_projection(egmaps_x)  # (batch, time, d_model)
        
        # 2. Prepare token type embeddings (CORRECTED)
        # Create IDs with shape (batch_size, time_steps)
        cqcc_type_ids = torch.zeros(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        egmaps_type_ids = torch.ones(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        
        # Get embeddings from IDs. Shape will be (batch_size, time_steps, d_model)
        cqcc_type_embed = self.token_type_embeddings(cqcc_type_ids)
        egmaps_type_embed = self.token_type_embeddings(egmaps_type_ids)

        # Add token type embeddings to feature embeddings
        cqcc_embed += cqcc_type_embed
        egmaps_embed += egmaps_type_embed
        
        # 3. Create the full sequence: [CLS] + CQCC + eGeMAPS
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        full_sequence = torch.cat([cls_tokens, cqcc_embed, egmaps_embed], dim=1)
        
        # 4. Add positional encoding
        full_sequence += self.positional_encoding
        
        # 5. Pass through the Transformer Encoder
        transformer_out = self.transformer_encoder(full_sequence)
        
        # 6. Use the output of the [CLS] token for classification
        cls_output = transformer_out[:, 0, :]
        output = self.classifier(cls_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: Single-Stream Fusion Transformer Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss)
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history_single_stream.png"))


def evaluate_on_test_set(model, test_loader, device):
    print("--- Running Inference on Test Set ---")
    model.eval()
    
    test_labels, test_scores = [], []
    
    with torch.no_grad():
        for cqcc_batch, lld_batch, labels_batch in tqdm(test_loader, desc="Testing"):
            cqcc_batch, lld_batch = cqcc_batch.to(device), lld_batch.to(device)
            outputs = model(cqcc_batch, lld_batch)
            test_scores.extend(outputs.cpu().numpy())
            test_labels.extend(labels_batch.cpu().numpy())
    
    # Calculate Final Metrics
    test_labels = np.array(test_labels)
    test_scores = np.array(test_scores).flatten()
    test_preds_binary = (test_scores > 0.5).astype(int)
    
    test_acc = accuracy_score(test_labels, test_preds_binary) * 100
    test_f1 = f1_score(test_labels, test_preds_binary)
    test_eer = calculate_eer(test_labels, test_scores)
    test_cm = confusion_matrix(test_labels, test_preds_binary)
    
    # Print Results
    print("\n" + "="*40)
    print("--- Final Test Results ---")
    print(f"  Accuracy: {test_acc:.2f}%")
    print(f"  F1-Score: {test_f1:.4f}")
    print(f"  EER:      {test_eer:.2f}%")
    print("  Confusion Matrix:")
    print(test_cm)
    print("="*40)
    
    # Return Metrics
    results = {
        'accuracy': test_acc,
        'f1_score': test_f1,
        'eer': test_eer,
        'confusion_matrix': test_cm
    }
    return results

# --- 4. MAIN TEST EXECUTION ---
if __name__ == '__main__':
    print(f"--- Starting Model Evaluation on Test Set ---")
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Test Data ---")
        X_cqcc_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_test.npy"))
        X_lld_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_test.npy"))
        y_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_test.npy"))
        print(f"✅ Loaded {len(y_test)} test samples.")
    except FileNotFoundError as e:
        print(f"❌ Error loading test data files: {e}")
        exit()

    try:
        print("--- Loading Scalers ---")
        scaler_cqcc = joblib.load(os.path.join(MODEL_OUTPUT_DIR, "scaler_cqcc.joblib"))
        scaler_lld = joblib.load(os.path.join(MODEL_OUTPUT_DIR, "scaler_lld.joblib"))
        print("✅ Scalers loaded successfully.")
    except FileNotFoundError as e:
        print(f"❌ Error loading scaler files: {e}")
        exit()
        
    print("--- Scaling Test Features ---")
    X_lld_test_scaled = scaler_lld.transform(X_lld_test.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_test.shape)
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_test.shape)
    
    test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_lld_test_scaled, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print("--- Loading Best Trained Model ---")
    model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    model_path = os.path.join(MODEL_OUTPUT_DIR, "best_single_stream_transformer_model.pth")
    try:
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        print("✅ Model weights loaded successfully.")
    except FileNotFoundError:
        print(f"❌ Model file not found at {model_path}")
        exit()
    
    # --- Call the evaluation function ---
    test_results = evaluate_on_test_set(model=model, test_loader=test_loader, device=DEVICE)
    
    print("\nEvaluation complete.")
    # You can optionally do something with the results dictionary here
    # print("Returned metrics dictionary:", test_results)


In [None]:
import os
import numpy as np
import librosa
from scipy.fftpack import dct
from tqdm import tqdm
import soundfile as sf
import opensmile

# --- 1. CONFIGURATION ---

# --- Paths ---
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
os.makedirs(OUTPUT_DIR, exist_ok=True)
TEMP_AUDIO_PATH = os.path.join(OUTPUT_DIR, "temp_5s_audio.wav") # For temporary audio clips

# --- Feature Parameters ---
TARGET_SHAPE_CQCC = (128, 157)
TARGET_SHAPE_LLD = (23, 157) # eGeMAPS LLDs have 23 features
SAMPLE_RATE = 16000
DURATION = 5.0

# --- 2. HELPER FUNCTIONS ---

def extract_cqcc(y, sr, n_bins=90, n_cqcc=128):
    """Extracts CQCC features."""
    try:
        cqt = np.abs(librosa.cqt(y=y, sr=sr, n_bins=n_bins, fmin=librosa.note_to_hz('C1')))
        log_cqt = np.log(cqt + 1e-6)
        cqcc = dct(log_cqt, type=2, axis=0, norm='ortho')
        return cqcc[:n_cqcc, :]
    except Exception:
        return None

def pad_or_truncate(array, target_shape):
    """Pads or truncates a 2D array to a target shape."""
    padded_array = np.full(target_shape, 0.0, dtype=np.float32) # Pad with 0
    copy_shape = tuple(min(c, t) for c, t in zip(array.shape, target_shape))
    padded_array[:copy_shape[0], :copy_shape[1]] = array[:copy_shape[0], :copy_shape[1]]
    return padded_array

def process_data_aligned(directories, label, smile_instance):
    """
    Extracts aligned CQCC and eGeMAPS LLD features.
    """
    cqcc_list, lld_list, labels_list = [], [], []

    for directory in directories:
        full_dir_path = os.path.join(TEAMMATE_DATA_PATH, directory)
        print(f"\nProcessing directory: {full_dir_path}")
        if not os.path.isdir(full_dir_path):
            continue

        files = [f for f in os.listdir(full_dir_path) if f.endswith(('.flac', '.wav'))]
        for filename in tqdm(files, desc=f"Extracting from {directory}"):
            filepath = os.path.join(full_dir_path, filename)
            try:
                # 1. Load the 5-second audio clip once
                audio_5s, sr = librosa.load(filepath, sr=SAMPLE_RATE, duration=DURATION)

                # 2. Extract CQCC from the 5s clip
                cqcc_feats = extract_cqcc(audio_5s, sr, n_cqcc=TARGET_SHAPE_CQCC[0])
                if cqcc_feats is None: continue

                # 3. Extract LLDs from the same 5s clip
                # We need to save the clip to a temporary file for openSMILE to process
                sf.write(TEMP_AUDIO_PATH, audio_5s, sr)
                lld_df = smile_instance.process_file(TEMP_AUDIO_PATH)
                lld_feats = lld_df.values.T # Transpose to get (features, time)

                # 4. Pad both feature sets to the target shape
                padded_cqcc = pad_or_truncate(cqcc_feats, TARGET_SHAPE_CQCC)
                padded_lld = pad_or_truncate(lld_feats, TARGET_SHAPE_LLD)

                # 5. Append to lists
                cqcc_list.append(padded_cqcc)
                lld_list.append(padded_lld)
                labels_list.append(label)

            except Exception as e:
                print(f"\nError processing {filepath}: {e}")

    # Clean up the temporary audio file
    if os.path.exists(TEMP_AUDIO_PATH):
        os.remove(TEMP_AUDIO_PATH)

    return np.array(cqcc_list), np.array(lld_list), np.array(labels_list)


# --- 3. MAIN EXECUTION SCRIPT ---

if __name__ == '__main__':
    # --- Initialize openSMILE for LLDs ---
    smile = opensmile.Smile(
        feature_set=opensmile.FeatureSet.eGeMAPS,
        feature_level=opensmile.FeatureLevel.LowLevelDescriptors, # Use the full name, # Set to LLD
    )

    # # --- Process Training Data ---
    # print("--- Processing Training Set ---")
    # cqcc_bf_train, lld_bf_train, labels_bf_train = process_data_aligned(['bonafide_audio_train', 'augmented_bonafide'], 1, smile)
    # cqcc_spf_train, lld_spf_train, labels_spf_train = process_data_aligned(['spoof_audio_train'], 0, smile)

    # # --- Process Validation Data ---
    # print("\n--- Processing Validation Set ---")
    # cqcc_bf_val, lld_bf_val, labels_bf_val = process_data_aligned(['bonafide_audio_val'], 1, smile)
    # cqcc_spf_val, lld_spf_val, labels_spf_val = process_data_aligned(['spoof_audio_val'], 0, smile)

    # # --- Combine and Save Training Data ---
    # X_cqcc_train = np.concatenate((cqcc_bf_train, cqcc_spf_train), axis=0)
    # X_lld_train = np.concatenate((lld_bf_train, lld_spf_train), axis=0)
    # y_train = np.concatenate((labels_bf_train, labels_spf_train), axis=0)
    # np.save(os.path.join(OUTPUT_DIR, "cqcc_features_train.npy"), X_cqcc_train)
    # np.save(os.path.join(OUTPUT_DIR, "egmaps_lld_features_train.npy"), X_lld_train)
    # np.save(os.path.join(OUTPUT_DIR, "labels_train.npy"), y_train)
    # print(f"\n✅ Training data saved. Shapes: CQCC={X_cqcc_train.shape}, LLD={X_lld_train.shape}")

    print("\n--- Processing Test Set ---")
    cqcc_bf_test, lld_bf_test, labels_bf_test = process_data_aligned(['bonafide_audio_test'], 1, smile)
    cqcc_spf_test, lld_spf_test, labels_spf_test = process_data_aligned(['spoof_audio_test'], 0, smile)

    # --- Combine and Save Validation Data ---
    # X_cqcc_val = np.concatenate((cqcc_bf_val, cqcc_spf_val), axis=0)
    # X_lld_val = np.concatenate((lld_bf_val, lld_spf_val), axis=0)
    # y_val = np.concatenate((labels_bf_val, labels_spf_val), axis=0)
    # np.save(os.path.join(OUTPUT_DIR, "cqcc_features_val.npy"), X_cqcc_val)
    # np.save(os.path.join(OUTPUT_DIR, "egmaps_lld_features_val.npy"), X_lld_val)
    # np.save(os.path.join(OUTPUT_DIR, "labels_dev.npy"), y_val)
    # print(f"✅ Validation data saved. Shapes: CQCC={X_cqcc_val.shape}, LLD={X_lld_val.shape}")

    X_cqcc_test = np.concatenate((cqcc_bf_test, cqcc_spf_test), axis=0)
    X_lld_test = np.concatenate((lld_bf_test, lld_spf_test), axis=0)
    y_test = np.concatenate((labels_bf_test, labels_spf_test), axis=0)
    np.save(os.path.join(OUTPUT_DIR, "cqcc_features_test.npy"), X_cqcc_test)
    np.save(os.path.join(OUTPUT_DIR, "egmaps_lld_features_test.npy"), X_lld_test)
    np.save(os.path.join(OUTPUT_DIR, "labels_test.npy"), y_test)
    print(f"✅ Test data saved. Shapes: CQCC={X_cqcc_test.shape}, LLD={X_lld_test.shape}")

In [None]:
def evaluate_on_test_set(model, test_loader, device):
    """
    Evaluates the model on the provided test dataset.

    Args:
        model (nn.Module): The trained model to evaluate.
        test_loader (DataLoader): The DataLoader for the test set.
        device (str): The device to run evaluation on ('cuda' or 'cpu').

    Returns:
        dict: A dictionary containing the final test metrics.
    """
    print("--- Running Inference on Test Set ---")
    model.eval()
    
    test_labels = []
    test_scores = []
    
    with torch.no_grad():
        for cqcc_batch, lld_batch, labels_batch in tqdm(test_loader, desc="Testing"):
            cqcc_batch, lld_batch = cqcc_batch.to(device), lld_batch.to(device)
            outputs = model(cqcc_batch, lld_batch)
            test_scores.extend(outputs.cpu().numpy())
            test_labels.extend(labels_batch.cpu().numpy())
    
    # --- Calculate Final Metrics ---
    test_labels = np.array(test_labels)
    test_scores = np.array(test_scores).flatten()
    test_preds_binary = (test_scores > 0.5).astype(int)
    
    test_acc = accuracy_score(test_labels, test_preds_binary) * 100
    test_f1 = f1_score(test_labels, test_preds_binary)
    test_eer = calculate_eer(test_labels, test_scores)
    test_cm = confusion_matrix(test_labels, test_preds_binary)
    
    # --- Print Results ---
    print("\n" + "="*40)
    print("--- Final Test Results ---")
    print(f"  Accuracy: {test_acc:.2f}%")
    print(f"  F1-Score: {test_f1:.4f}")
    print(f"  EER:      {test_eer:.2f}%")
    print("  Confusion Matrix:")
    print(test_cm)
    print("="*40)
    
    # --- Return Metrics ---
    results = {
        'accuracy': test_acc,
        'f1_score': test_f1,
        'eer': test_eer,
        'confusion_matrix': test_cm
    }
    return results

# --- HOW TO USE IT IN YOUR SCRIPT ---
# In your `if __name__ == '__main__':` block, after loading the model:
# ... (code to load data, scalers, and model) ...

# model.eval() # This is now handled inside the function

test_results = evaluate_on_test_set(model=model, test_loader=test_loader, device=DEVICE)
print("\nEvaluation complete. Results dictionary:", test_results)

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import joblib

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "single_stream_transformer_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32 # Transformers can be memory intensive
EPOCHS = 25
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128 # d_model for the transformer
NUM_HEADS = 8       # Number of attention heads
NUM_ENCODER_LAYERS = 6 # Can use a deeper single encoder
DROPOUT = 0.2

# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60))
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. SINGLE-STREAM FUSION TRANSFORMER MODEL ---

class SingleStreamFusionTransformer(nn.Module):
    """
    Fuses CQCC and eGeMAPS by concatenating them into a single sequence
    and feeding them to a single Transformer Encoder. Uses token-type embeddings
    to distinguish between the two modalities.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, d_model, nhead, num_encoder_layers, dropout):
        super(SingleStreamFusionTransformer, self).__init__()
        
        self.d_model = d_model
        
        self.cqcc_projection = nn.Linear(cqcc_features, d_model)
        self.egmaps_projection = nn.Linear(egmaps_features, d_model)
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.token_type_embeddings = nn.Embedding(num_embeddings=2, embedding_dim=d_model)
        
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1 + time_steps * 2, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        cqcc_x = cqcc_x.transpose(1, 2)
        egmaps_x = egmaps_x.transpose(1, 2)
        
        batch_size = cqcc_x.size(0)
        time_steps = cqcc_x.size(1)
        
        cqcc_embed = self.cqcc_projection(cqcc_x)
        egmaps_embed = self.egmaps_projection(egmaps_x)
        
        cqcc_type_ids = torch.zeros(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        egmaps_type_ids = torch.ones(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        
        cqcc_type_embed = self.token_type_embeddings(cqcc_type_ids)
        egmaps_type_embed = self.token_type_embeddings(egmaps_type_ids)

        cqcc_embed += cqcc_type_embed
        egmaps_embed += egmaps_type_embed
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        full_sequence = torch.cat([cls_tokens, cqcc_embed, egmaps_embed], dim=1)
        
        full_sequence += self.positional_encoding
        
        transformer_out = self.transformer_encoder(full_sequence)
        
        cls_output = transformer_out[:, 0, :]
        output = self.classifier(cls_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    # --- ADDED: Save the fitted scalers for the test script ---
    print("--- Saving Scalers ---")
    joblib.dump(scaler_cqcc, os.path.join(OUTPUT_DIR, "scaler_cqcc.joblib"))
    joblib.dump(scaler_lld, os.path.join(OUTPUT_DIR, "scaler_lld.joblib"))
    print(f"✅ Scalers saved to {OUTPUT_DIR}")
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: Single-Stream Fusion Transformer Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss)
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history_single_stream.png"))


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import joblib

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "single_stream_transformer_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32 # Transformers can be memory intensive
EPOCHS = 25
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128 # d_model for the transformer
NUM_HEADS = 8       # Number of attention heads
NUM_ENCODER_LAYERS = 6 # Can use a deeper single encoder
DROPOUT = 0.2

# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60))
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- 3. SINGLE-STREAM FUSION TRANSFORMER MODEL ---

class SingleStreamFusionTransformer(nn.Module):
    """
    Fuses CQCC and eGeMAPS by concatenating them into a single sequence
    and feeding them to a single Transformer Encoder. Uses token-type embeddings
    to distinguish between the two modalities.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, d_model, nhead, num_encoder_layers, dropout):
        super(SingleStreamFusionTransformer, self).__init__()
        
        self.d_model = d_model
        
        self.cqcc_projection = nn.Linear(cqcc_features, d_model)
        self.egmaps_projection = nn.Linear(egmaps_features, d_model)
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.token_type_embeddings = nn.Embedding(num_embeddings=2, embedding_dim=d_model)
        
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1 + time_steps * 2, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        cqcc_x = cqcc_x.transpose(1, 2)
        egmaps_x = egmaps_x.transpose(1, 2)
        
        batch_size = cqcc_x.size(0)
        time_steps = cqcc_x.size(1)
        
        cqcc_embed = self.cqcc_projection(cqcc_x)
        egmaps_embed = self.egmaps_projection(egmaps_x)
        
        cqcc_type_ids = torch.zeros(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        egmaps_type_ids = torch.ones(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        
        cqcc_type_embed = self.token_type_embeddings(cqcc_type_ids)
        egmaps_type_embed = self.token_type_embeddings(egmaps_type_ids)

        cqcc_embed += cqcc_type_embed
        egmaps_embed += egmaps_type_embed
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        full_sequence = torch.cat([cls_tokens, cqcc_embed, egmaps_embed], dim=1)
        
        full_sequence += self.positional_encoding
        
        transformer_out = self.transformer_encoder(full_sequence)
        
        cls_output = transformer_out[:, 0, :]
        output = self.classifier(cls_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    print("--- Saving Scalers ---")
    joblib.dump(scaler_cqcc, os.path.join(OUTPUT_DIR, "scaler_cqcc.joblib"))
    joblib.dump(scaler_lld, os.path.join(OUTPUT_DIR, "scaler_lld.joblib"))
    print(f"✅ Scalers saved to {OUTPUT_DIR}")
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: Single-Stream Fusion Transformer Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss)
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history_single_stream.png"))
    
    # --- ADDED: TESTING PORTION ---
    print("\n" + "="*50)
    print("--- Starting Final Evaluation on Test Set ---")
    
    try:
        print("--- Loading Test Data ---")
        X_cqcc_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_test.npy"))
        X_lld_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_test.npy"))
        y_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_test.npy"))
        print(f"✅ Loaded {len(y_test)} test samples.")
    except FileNotFoundError as e:
        print(f"❌ Error loading test data files: {e}")
        exit()

    print("--- Scaling Test Features ---")
    X_lld_test_scaled = scaler_lld.transform(X_lld_test.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_test.shape)
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_test.shape)
    
    test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_lld_test_scaled, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print("--- Loading Best Trained Model for Testing ---")
    # Re-initialize the model structure
    test_model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    model_path = os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth")
    try:
        test_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        print("✅ Model weights loaded successfully.")
    except FileNotFoundError:
        print(f"❌ Model file not found at {model_path}")
        exit()

    test_model.eval()
    
    test_labels, test_scores = [], []
    with torch.no_grad():
        for cqcc_batch, lld_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
            cqcc_batch, lld_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE)
            outputs = test_model(cqcc_batch, lld_batch)
            test_scores.extend(outputs.cpu().numpy())
            test_labels.extend(labels_batch.cpu().numpy())
    
    test_labels = np.array(test_labels)
    test_scores = np.array(test_scores).flatten()
    test_preds_binary = (test_scores > 0.5).astype(int)
    
    test_acc = accuracy_score(test_labels, test_preds_binary) * 100
    test_f1 = f1_score(test_labels, test_preds_binary)
    test_eer = calculate_eer(test_labels, test_scores)
    test_cm = confusion_matrix(test_labels, test_preds_binary)
    
    print("\n" + "="*40)
    print("--- Final Test Results ---")
    print(f"  Accuracy: {test_acc:.2f}%")
    print(f"  F1-Score: {test_f1:.4f}")
    print(f"  EER:      {test_eer:.2f}%")
    print("  Confusion Matrix:")
    print(test_cm)
    print("="*40)


In [None]:
# import os
# import numpy as np
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader
# from sklearn.preprocessing import StandardScaler
# from sklearn.metrics import confusion_matrix, f1_score, roc_curve
# from scipy.optimize import brentq
# from scipy.interpolate import interp1d
# from tqdm import tqdm
# from sklearn.model_selection import train_test_split
# import matplotlib.pyplot as plt

# # --- Configuration ---
# # Paths for TRAINING data
# CQCC_FEATURES_TRAIN_PATH = "processed_data/cqcc_features.npy"
# PROSODIC_FEATURES_TRAIN_CSV_PATH = "processed_data/prosodic_features_and_labels.csv"

# # Paths for VALIDATION data
# CQCC_FEATURES_VAL_PATH = "processed_data/cqcc_features_val.npy"
# PROSODIC_FEATURES_VAL_CSV_PATH = "processed_data/prosodic_features_and_labels_val.csv"

# # Paths for TEST data
# CQCC_FEATURES_TEST_PATH = "processed_data/cqcc_features_test.npy"
# PROSODIC_FEATURES_TEST_CSV_PATH = "processed_data/prosodic_features_and_labels_test.csv"

# # --- Model and Analysis Configuration ---
# MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
# PLOT_SAVE_PATH = "saved_models/training_metrics.png"
# ATTENTION_PLOT_PATH = "saved_models/attention_importance.png"
# ABLATION_PLOT_PATH = "saved_models/ablation_importance.png"
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# BATCH_SIZE = 64
# EPOCHS = 40
# LEARNING_RATE = 1e-4 
# WEIGHT_DECAY = 1e-5

# os.makedirs("saved_models", exist_ok=True)
# print(f"Using device: {DEVICE}")


# def calculate_eer(y_true, y_score):
#     """Calculates the Equal Error Rate (EER)."""
#     fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
#     eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
#     return eer * 100

# def plot_training_history(history, save_path):
#     """Plots and saves the training history graph."""
#     fig, ax1 = plt.subplots(figsize=(12, 8))

#     color = 'tab:red'
#     ax1.set_xlabel('Epochs')
#     ax1.set_ylabel('Loss', color=color)
#     ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
#     ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
#     ax1.tick_params(axis='y', labelcolor=color)
#     ax1.legend(loc='upper left')

#     ax2 = ax1.twinx()  
#     color = 'tab:blue'
#     ax2.set_ylabel('EER (%)', color=color)
#     ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
#     ax2.tick_params(axis='y', labelcolor=color)
#     ax2.legend(loc='upper right')

#     fig.tight_layout()
#     plt.title('Training and Validation Metrics')
#     plt.savefig(save_path)
#     print(f"\nTraining plot saved to {save_path}")
#     plt.close()


# class AudioFeatureDataset(Dataset):
#     """Custom PyTorch Dataset."""
#     def __init__(self, cqcc_data, prosody_data, labels):
#         self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
#         self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
#         self.labels = torch.tensor(labels, dtype=torch.float32)

#     def __len__(self):
#         return len(self.labels)

#     def __getitem__(self, idx):
#         return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

# class AttentionFusionCNN(nn.Module):
#     """PyTorch implementation using Conv2D for CQCC features."""
#     def __init__(self, cqcc_input_shape, prosodic_features):
#         super(AttentionFusionCNN, self).__init__()
        
#         self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
#         self.cqcc_bn1 = nn.BatchNorm2d(16)
#         self.cqcc_pool1 = nn.MaxPool2d((2, 2))
#         self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
#         self.cqcc_bn2 = nn.BatchNorm2d(32)
#         self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
#         with torch.no_grad():
#             dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
#             dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
#             cqcc_flat_size = dummy_out.numel()
            
#         self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
#         self.prosody_fc1 = nn.Linear(prosodic_features, 32)
#         self.prosody_bn1 = nn.BatchNorm1d(32)
#         self.prosody_dropout = nn.Dropout(0.4)
#         self.prosody_fc2 = nn.Linear(32, 64)
#         concatenated_size = 64 + 64
#         self.attention = nn.Linear(concatenated_size, concatenated_size)
#         self.classifier_fc1 = nn.Linear(concatenated_size, 64)
#         self.classifier_bn = nn.BatchNorm1d(64)
#         self.classifier_dropout = nn.Dropout(0.5)
#         self.output_fc = nn.Linear(64, 1)

#     def forward(self, cqcc_x, prosody_x):
#         # IMPORTANT: This forward pass now returns attention weights
#         cqcc_x = cqcc_x.unsqueeze(1)
#         cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
#         cqcc_out = self.cqcc_pool1(cqcc_out)
#         cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
#         cqcc_out = self.cqcc_pool2(cqcc_out)
#         cqcc_out = torch.flatten(cqcc_out, 1)
#         cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

#         prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
#         prosody_out = self.prosody_dropout(prosody_out)
#         prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

#         concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        
#         # The attention mechanism here is self-attention on the combined features.
#         # The weights will show the importance of each part of the concatenated vector.
#         attention_weights = torch.softmax(self.attention(concatenated), dim=1)
#         fused = concatenated * attention_weights

#         x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
#         x = self.classifier_dropout(x)
#         output = torch.sigmoid(self.output_fc(x))
        
#         # Return both the final prediction and the attention weights
#         return output, attention_weights

# # ==============================================================================
# # ANALYSIS FUNCTIONS
# # ==============================================================================

# def analyze_attention_weights(model, dataloader, feature_names, device, save_path):
#     """
#     Analyzes and visualizes aggregated attention weights for prosodic features.
#     """
#     print("\n--- Running Attention Weight Analysis ---")
#     model.eval()
#     # The attention is on the concatenated vector (64 CQCC + N prosodic features expanded to 64)
#     # We are interested in the second half of the attention weights
#     num_prosodic_features = len(feature_names)
#     attention_scores = np.zeros(64) # Attention weights for the prosody branch

#     with torch.no_grad():
#         for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
#             cqcc, prosody = cqcc.to(device), prosody.to(device)
#             _, weights = model(cqcc, prosody)
#             # weights shape: [batch_size, 128 (64+64)]
#             # We only care about the weights applied to the prosodic part of the vector
#             prosody_attention_weights = weights[:, 64:]
#             attention_scores += prosody_attention_weights.sum(dim=0).cpu().numpy()

#     # Since the prosody branch is an MLP, we can't directly map these 64 weights back
#     # to the original N features. This analysis shows the importance of the learned
#     # prosodic representation, but not individual input features.
#     # For a more direct analysis, feature ablation is better suited for this model architecture.
#     print("NOTE: Attention analysis for this model shows importance of the *learned prosodic representation*.")
#     print("Feature ablation is recommended for analyzing original input feature importance.")
    
#     # We can still plot the importance of the learned 64 prosodic dimensions
#     plt.figure(figsize=(12, 8))
#     plt.bar(range(64), attention_scores, color='purple')
#     plt.xlabel('Dimension of Learned Prosodic Representation')
#     plt.ylabel('Aggregated Attention Score')
#     plt.title('Importance of Learned Prosodic Feature Dimensions')
#     plt.tight_layout()
#     plt.savefig(save_path.replace(".png", "_learned_dims.png"))
#     plt.close()


# def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
#     """
#     Performs feature ablation to measure EER increase.
#     """
#     print("\n--- Running Feature Ablation Analysis ---")
    
#     def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
#         model.eval()
#         all_labels, all_scores = [], []
#         with torch.no_grad():
#             for cqcc, prosody, labels in dataloader:
#                 cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
#                 if feature_to_ablate is not None:
#                     prosody[:, feature_to_ablate] = 0.0 # Zero out the feature
                
#                 outputs, _ = model(cqcc, prosody)
#                 all_scores.extend(outputs.cpu().numpy())
#                 all_labels.extend(labels.cpu().numpy())
#         return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

#     baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
#     print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
#     eer_increases = {}
#     for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
#         ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
#         eer_increases[name] = ablated_eer - baseline_eer
        
#     sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
#     print("\nFeature Importance based on EER Increase:")
#     for feature, increase in sorted_features:
#         print(f"- {feature}: EER increases by {increase:.2f}%")

#     names = [item[0] for item in sorted_features]
#     increases = [item[1] for item in sorted_features]
#     plt.figure(figsize=(12, 8))
#     plt.barh(names, increases, color='salmon')
#     plt.xlabel('EER Increase (%)')
#     plt.title('Prosodic Feature Importance based on Feature Ablation')
#     plt.gca().invert_yaxis()
#     plt.tight_layout()
#     plt.savefig(save_path)
#     print(f"\nAblation plot saved to {save_path}")
#     plt.close()


# # ==============================================================================
# # MAIN EXECUTION BLOCK
# # ==============================================================================

# if __name__ == '__main__':
#     try:
#         print("--- Loading and Preparing Data ---")
#         prosody_df_train = pd.read_csv(PROSODIC_FEATURES_TRAIN_CSV_PATH)
#         X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
#         prosody_df_val = pd.read_csv(PROSODIC_FEATURES_VAL_CSV_PATH)
#         X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        
#         if len(prosody_df_train) != len(X_cqcc_train) or len(prosody_df_val) != len(X_cqcc_val):
#             raise ValueError("Sample count mismatch between CSV and .npy files.")

#         feature_columns = ['mean_f0', 'std_f0', 'jitter', 'shimmer', 'mean_hnr', 'std_hnr']
#         X_prosody_train = prosody_df_train[feature_columns].values
#         y_train = prosody_df_train['label'].values
#         X_prosody_val = prosody_df_val[feature_columns].values
#         y_val = prosody_df_val['label'].values
#         print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
#     except (FileNotFoundError, ValueError) as e:
#         print(f"Error loading data: {e}")
#         exit()

#     print("--- Scaling Data ---")
#     scaler_prosody = StandardScaler()
#     X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
#     X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

#     scaler_cqcc = StandardScaler()
#     nsamples, nx, ny = X_cqcc_train.shape
#     X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
#     nsamples_val, nx_val, ny_val = X_cqcc_val.shape
#     X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
#     print("Scaling complete.")

#     train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
#     val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
#     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

#     model = AttentionFusionCNN(
#         cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
#         prosodic_features=X_prosody_train.shape[1]
#     ).to(DEVICE)
    
#     criterion = nn.BCELoss()
#     optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
#     print(model)
    
#     best_val_loss = float('inf')
#     history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
#     print("\n--- Starting Model Training ---")

#     for epoch in range(EPOCHS):
#         model.train()
#         running_loss = 0.0
        
#         for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
#             cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
#             optimizer.zero_grad()
#             outputs, _ = model(cqcc_batch, prosody_batch) # Ignore weights during training
#             loss = criterion(outputs, labels_batch.unsqueeze(1))
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()

#         model.eval()
#         val_loss = 0.0
#         all_labels = []
#         all_scores = []
#         with torch.no_grad():
#             for cqcc_batch, prosody_batch, labels_batch in val_loader:
#                 cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
#                 outputs, _ = model(cqcc_batch, prosody_batch) # Ignore weights during validation
#                 loss = criterion(outputs, labels_batch.unsqueeze(1))
#                 val_loss += loss.item()
#                 all_scores.extend(outputs.cpu().numpy())
#                 all_labels.extend(labels_batch.cpu().numpy())

#         avg_train_loss = running_loss / len(train_loader)
#         avg_val_loss = val_loss / len(val_loader)
        
#         all_labels = np.array(all_labels)
#         all_scores = np.array(all_scores).flatten()
#         all_preds = (all_scores > 0.5).astype(int)

#         val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
#         f1 = f1_score(all_labels, all_preds)
#         eer = calculate_eer(all_labels, all_scores)
#         cm = confusion_matrix(all_labels, all_preds)

#         print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
#         print("Validation Confusion Matrix:\n", cm)
        
#         history['train_loss'].append(avg_train_loss)
#         history['val_loss'].append(avg_val_loss)
#         history['val_acc'].append(val_accuracy)
#         history['f1'].append(f1)
#         history['eer'].append(eer)
        
#         scheduler.step(avg_val_loss)

#         if avg_val_loss < best_val_loss:
#             best_val_loss = avg_val_loss
#             torch.save(model.state_dict(), MODEL_SAVE_PATH)
#             print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

#     print("\n--- Training Complete ---")
#     print(f"Best validation loss achieved: {best_val_loss:.4f}")

#     plot_training_history(history, PLOT_SAVE_PATH)

#     # ==============================================================================
#     # FINAL TESTING AND ANALYSIS
#     # ==============================================================================
#     print("\n--- Starting Final Testing and Analysis ---")
#     try:
#         print("Loading test data...")
#         prosody_df_test_full = pd.read_csv(PROSODIC_FEATURES_TEST_CSV_PATH)
#         X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
#         X_prosody_test_full = prosody_df_test_full[feature_columns].values
#         y_test_full = prosody_df_test_full['label'].values
        
#         NUM_SAMPLES_TO_SELECT = min(70000, len(y_test_full))
#         print(f"Creating a balanced subset of {NUM_SAMPLES_TO_SELECT} samples for testing and analysis...")

#         _, _, _, _, _, selected_indices = train_test_split(
#             X_cqcc_test_full, y_test_full, np.arange(len(y_test_full)),
#             test_size=NUM_SAMPLES_TO_SELECT, stratify=y_test_full, random_state=42
#         )

#         X_cqcc_test_subset = X_cqcc_test_full[selected_indices]
#         X_prosody_test_subset = X_prosody_test_full[selected_indices]
#         y_test_subset = y_test_full[selected_indices]
        
#         X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_subset)
#         nsamples_test, nx_test, ny_test = X_cqcc_test_subset.shape
#         X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_subset.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
#         test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test_subset)
#         test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
#         print("Loading best model for testing and analysis...")
#         analysis_model = AttentionFusionCNN(
#             cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
#             prosodic_features=X_prosody_train.shape[1]
#         ).to(DEVICE)
#         analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
#         analysis_model.eval()

#         # --- First, get final test metrics ---
#         all_test_labels = []
#         all_test_scores = []
#         with torch.no_grad():
#             for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
#                 cqcc_batch, prosody_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE)
#                 outputs, _ = analysis_model(cqcc_batch, prosody_batch)
#                 all_test_scores.extend(outputs.cpu().numpy())
#                 all_test_labels.extend(labels_batch.cpu().numpy())
        
#         all_test_labels = np.array(all_test_labels)
#         all_test_scores = np.array(all_test_scores).flatten()
#         all_test_preds = (all_test_scores > 0.5).astype(int)

#         test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
#         test_f1 = f1_score(all_test_labels, all_test_preds)
#         test_eer = calculate_eer(all_test_labels, all_test_scores)
#         test_cm = confusion_matrix(all_test_labels, all_test_preds)

#         print("\n--- Final Test Results ---")
#         print(f"Accuracy: {test_accuracy:.2f}%")
#         print(f"F1-Score: {test_f1:.4f}")
#         print(f"EER: {test_eer:.2f}%")
#         print("Confusion Matrix:\n", test_cm)

#         # --- Now, run analysis functions ---
#         analyze_attention_weights(analysis_model, test_loader, feature_columns, DEVICE, ATTENTION_PLOT_PATH)
#         perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)

#     except (FileNotFoundError, ValueError) as e:
#         print(f"Error during testing/analysis: {e}")
#         print("Please ensure your test data files are in the correct paths and format.")


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shap

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data/cqcc_features.npy"
PROSODIC_FEATURES_TRAIN_CSV_PATH = "processed_data/prosodic_features_and_labels.csv"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data/cqcc_features_val.npy"
PROSODIC_FEATURES_VAL_CSV_PATH = "processed_data/prosodic_features_and_labels_val.csv"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_CSV_PATH = "processed_data/prosodic_features_and_labels_test.csv"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance.png"
SHAP_PLOT_PATH = "saved_models/shap_importance.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        # IMPORTANT: This forward pass now returns attention weights
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, feature_names, device, save_path):
    """
    Analyzes and visualizes aggregated attention weights for prosodic features.
    """
    print("\n--- Running Attention Weight Analysis ---")
    model.eval()
    attention_scores = np.zeros(64) # Attention weights for the prosody branch

    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            prosody_attention_weights = weights[:, 64:]
            attention_scores += prosody_attention_weights.sum(dim=0).cpu().numpy()

    print("NOTE: Attention analysis for this model shows importance of the *learned prosodic representation*.")
    print("Feature ablation is recommended for analyzing original input feature importance.")
    
    plt.figure(figsize=(12, 8))
    plt.bar(range(64), attention_scores, color='purple')
    plt.xlabel('Dimension of Learned Prosodic Representation')
    plt.ylabel('Aggregated Attention Score')
    plt.title('Importance of Learned Prosodic Feature Dimensions')
    plt.tight_layout()
    plt.savefig(save_path.replace(".png", "_learned_dims.png"))
    plt.close()


def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    """
    Performs feature ablation to measure EER increase.
    """
    print("\n--- Running Feature Ablation Analysis ---")
    
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody[:, feature_to_ablate] = 0.0 # Zero out the feature
                
                outputs, _ = model(cqcc, prosody)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
        
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")

    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 8))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    """
    Method 3: Uses SHAP to explain model predictions for prosodic features.
    This is computationally intensive and is run on a subset of data.
    """
    print("\n--- Running SHAP Analysis ---")
    model.eval()

    # Get a small batch of data for SHAP analysis
    # SHAP needs a background dataset to integrate over, and a test set to explain
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    
    # We'll explain the predictions for another batch
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    # SHAP's KernelExplainer needs a function that takes a numpy array.
    # We create a wrapper that takes only the prosodic features (as a numpy array),
    # combines them with a fixed background CQCC sample, and returns the model's prediction.
    def model_wrapper(prosodic_features_numpy):
        # Number of samples SHAP is currently testing
        num_samples = prosodic_features_numpy.shape[0]
        
        # Convert prosodic numpy array to a tensor
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        
        # Use a single CQCC sample as a fixed background for all predictions
        # and expand it to match the batch size of the prosody samples
        cqcc_background_sample = background_cqcc[0:1].to(device) # Take the first sample
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        
        with torch.no_grad():
            output, _ = model(cqcc_tensor, prosody_tensor)
        
        return output.cpu().numpy()

    # Create the explainer
    # We use the background prosody data to initialize the explainer
    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    
    print("Calculating SHAP values (this may take a while)...")
    # Calculate SHAP values for the test prosody data
    # Using a small number of samples for demonstration purposes
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    
    print("Plotting SHAP summary...")
    # The output of shap_values for a single-output model might be a list
    # We take the first element if it is
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    # Create the summary plot
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        prosody_df_train = pd.read_csv(PROSODIC_FEATURES_TRAIN_CSV_PATH)
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        prosody_df_val = pd.read_csv(PROSODIC_FEATURES_VAL_CSV_PATH)
        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        
        if len(prosody_df_train) != len(X_cqcc_train) or len(prosody_df_val) != len(X_cqcc_val):
            raise ValueError("Sample count mismatch between CSV and .npy files.")

        feature_columns = ['mean_f0', 'std_f0', 'jitter', 'shimmer', 'mean_hnr', 'std_hnr']
        X_prosody_train = prosody_df_train[feature_columns].values
        y_train = prosody_df_train['label'].values
        X_prosody_val = prosody_df_val[feature_columns].values
        y_val = prosody_df_val['label'].values
        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("--- Scaling Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs, _ = model(cqcc_batch, prosody_batch) # Ignore weights during training
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs, _ = model(cqcc_batch, prosody_batch) # Ignore weights during validation
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # ==============================================================================
    # FINAL TESTING AND ANALYSIS
    # ==============================================================================
    print("\n--- Starting Final Testing and Analysis ---")
    try:
        print("Loading test data...")
        prosody_df_test_full = pd.read_csv(PROSODIC_FEATURES_TEST_CSV_PATH)
        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_full = prosody_df_test_full[feature_columns].values
        y_test_full = prosody_df_test_full['label'].values
        
        NUM_SAMPLES_TO_SELECT = min(70000, len(y_test_full))
        print(f"Creating a balanced subset of {NUM_SAMPLES_TO_SELECT} samples for testing and analysis...")

        _, _, _, _, _, selected_indices = train_test_split(
            X_cqcc_test_full, y_test_full, np.arange(len(y_test_full)),
            test_size=NUM_SAMPLES_TO_SELECT, stratify=y_test_full, random_state=42
        )

        X_cqcc_test_subset = X_cqcc_test_full[selected_indices]
        X_prosody_test_subset = X_prosody_test_full[selected_indices]
        y_test_subset = y_test_full[selected_indices]
        
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_subset)
        nsamples_test, nx_test, ny_test = X_cqcc_test_subset.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_subset.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test_subset)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        print("Loading best model for testing and analysis...")
        analysis_model = AttentionFusionCNN(
            cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
            prosodic_features=X_prosody_train.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        # --- First, get final test metrics ---
        all_test_labels = []
        all_test_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
                cqcc_batch, prosody_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE)
                outputs, _ = analysis_model(cqcc_batch, prosody_batch)
                all_test_scores.extend(outputs.cpu().numpy())
                all_test_labels.extend(labels_batch.cpu().numpy())
        
        all_test_labels = np.array(all_test_labels)
        all_test_scores = np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)

        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)

        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}%")
        print(f"F1-Score: {test_f1:.4f}")
        print(f"EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

        # --- Now, run analysis functions ---
        analyze_attention_weights(analysis_model, test_loader, feature_columns, DEVICE, ATTENTION_PLOT_PATH)
        perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during testing/analysis: {e}")
        print("Please ensure your test data files are in the correct paths and format.")


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shap

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance.png"
SHAP_PLOT_PATH = "saved_models/shap_importance.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        # IMPORTANT: This forward pass now returns attention weights
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, feature_names, device, save_path):
    """
    Analyzes and visualizes aggregated attention weights for prosodic features.
    """
    print("\n--- Running Attention Weight Analysis ---")
    model.eval()
    attention_scores = np.zeros(64) # Attention weights for the prosody branch

    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            prosody_attention_weights = weights[:, 64:]
            attention_scores += prosody_attention_weights.sum(dim=0).cpu().numpy()

    print("NOTE: Attention analysis for this model shows importance of the *learned prosodic representation*.")
    print("Feature ablation is recommended for analyzing original input feature importance.")
    
    plt.figure(figsize=(12, 8))
    plt.bar(range(64), attention_scores, color='purple')
    plt.xlabel('Dimension of Learned Prosodic Representation')
    plt.ylabel('Aggregated Attention Score')
    plt.title('Importance of Learned Prosodic Feature Dimensions')
    plt.tight_layout()
    plt.savefig(save_path.replace(".png", "_learned_dims.png"))
    plt.close()


def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    """
    Performs feature ablation to measure EER increase.
    """
    print("\n--- Running Feature Ablation Analysis ---")
    
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody[:, feature_to_ablate] = 0.0 # Zero out the feature
                
                outputs, _ = model(cqcc, prosody)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
        
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")

    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10)) # Increased figure height for more features
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    """
    Method 3: Uses SHAP to explain model predictions for prosodic features.
    This is computationally intensive and is run on a subset of data.
    """
    print("\n--- Running SHAP Analysis ---")
    model.eval()

    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        
        with torch.no_grad():
            output, _ = model(cqcc_tensor, prosody_tensor)
        
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    plt.figure(figsize=(12, 10)) # Ensure plot is large enough
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        # Load all data from .npy files
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)

        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        # --- FIX for ValueError ---
        # The LLD prosodic features are 3D (samples, features, time).
        # This model expects 2D summary statistics for prosody.
        # We convert the 3D data to 2D by taking the mean across the time axis.
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        # Assumes shape is (samples, features, time), so we take mean over axis 2
        X_prosody_train = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val = np.mean(X_prosody_val_3d, axis=2)
        
        if not (len(X_cqcc_train) == len(X_prosody_train) == len(y_train)):
            raise ValueError("Sample count mismatch in training files.")
        if not (len(X_cqcc_val) == len(X_prosody_val) == len(y_val)):
            raise ValueError("Sample count mismatch in validation files.")

        # Generate generic feature names since they are not in the .npy file
        num_prosodic_features = X_prosody_train.shape[1]
        feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]
        
        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
        print(f"Using {num_prosodic_features} prosodic features.")

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure your .npy data files are in the correct paths.")
        exit()

    print("--- Scaling Data ---")
    # This now works because X_prosody_train is 2D
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    # This scaling method flattens feature and time dimensions together.
    # It's a form of instance-level normalization.
    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs, _ = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs, _ = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # ==============================================================================
    # FINAL TESTING AND ANALYSIS
    # ==============================================================================
    print("\n--- Starting Final Testing and Analysis ---")
    try:
        print("Loading test data...")
        X_cqcc_test = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        
        # --- FIX for ValueError ---
        # Convert 3D test prosodic features to 2D to match training
        X_prosody_test = np.mean(X_prosody_test_3d, axis=2)
        
        print(f"Loaded {len(y_test)} test samples.")
        
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test)
        nsamples_test, nx_test, ny_test = X_cqcc_test.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        print("Loading best model for testing and analysis...")
        analysis_model = AttentionFusionCNN(
            cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
            prosodic_features=X_prosody_train.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        all_test_labels = []
        all_test_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
                cqcc_batch, prosody_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE)
                outputs, _ = analysis_model(cqcc_batch, prosody_batch)
                all_test_scores.extend(outputs.cpu().numpy())
                all_test_labels.extend(labels_batch.cpu().numpy())
        
        all_test_labels = np.array(all_test_labels)
        all_test_scores = np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)

        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)

        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}%")
        print(f"F1-Score: {test_f1:.4f}")
        print(f"EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

        analyze_attention_weights(analysis_model, test_loader, feature_columns, DEVICE, ATTENTION_PLOT_PATH)
        perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during testing/analysis: {e}")
        print("Please ensure your .npy data files are in the correct paths and format.")


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, f1_score, confusion_matrix, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import joblib
import shap

# --- 1. CONFIGURATION ---

# --- Paths ---
# Ensure this path is correct for your environment
TEAMMATE_DATA_PATH = '/mount/studenten/arbeitsdaten-studenten1/team-lab-phonetics/2025/student_directories/AuFa/'
PREPROCESSED_DATA_DIR = os.path.join(TEAMMATE_DATA_PATH, "processed_data_aligned_lld")
OUTPUT_DIR = os.path.join(TEAMMATE_DATA_PATH, "single_stream_transformer_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model & Training Parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32 # Transformers can be memory intensive
EPOCHS = 25
LEARNING_RATE = 1e-4
CQCC_SHAPE = (128, 157)
EGMAPS_LLD_SHAPE = (23, 157)
EMBEDDING_DIM = 128 # d_model for the transformer
NUM_HEADS = 8      # Number of attention heads
NUM_ENCODER_LAYERS = 6 # Can use a deeper single encoder
DROPOUT = 0.2

# --- Analysis Configuration ---
ABLATION_PLOT_PATH = os.path.join(OUTPUT_DIR, "ablation_importance.png")
SHAP_PLOT_PATH = os.path.join(OUTPUT_DIR, "shap_importance.png")


# --- 2. UTILITY FUNCTIONS & DATASET CLASS ---

def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER). Returns -1 if calculation fails."""
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        return eer * 100
    except (ValueError, ZeroDivisionError):
        return -1.0

def plot_training_history(history, save_path):
    """Plots and saves a comprehensive training history graph."""
    fig, ax1 = plt.subplots(figsize=(14, 8))
    epochs_range = range(1, len(history['train_loss']) + 1)

    color = 'tab:red'
    ax1.set_xlabel('Epochs', fontsize=14)
    ax1.set_ylabel('Loss', color=color, fontsize=14)
    ax1.plot(epochs_range, history['train_loss'], color=color, linestyle='--', marker='o', label='Train Loss')
    ax1.plot(epochs_range, history['val_loss'], color=color, linestyle='-', marker='o', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Accuracy (%)', color=color, fontsize=14)
    ax2.plot(epochs_range, history['train_acc'], color=color, linestyle='--', marker='s', label='Train Accuracy')
    ax2.plot(epochs_range, history['val_acc'], color=color, linestyle='-', marker='s', label='Val Accuracy')
    ax2.tick_params(axis='y', labelcolor=color)

    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 60))
    color = 'tab:green'
    ax3.set_ylabel('EER (%)', color=color, fontsize=14)
    ax3.plot(epochs_range, history['val_eer'], color=color, linestyle=':', marker='^', label='Val EER')
    ax3.tick_params(axis='y', labelcolor=color)
    
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines3, labels3 = ax3.get_legend_handles_labels()
    ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3, loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, shadow=True, ncol=5)

    fig.suptitle('Training and Validation Metrics', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    plt.savefig(save_path)
    print(f"\n📈 Training plot saved to {save_path}")
    plt.close()

class AudioFeatureDataset(Dataset):
    def __init__(self, cqcc_data, egmaps_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.egmaps_data = torch.tensor(egmaps_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.egmaps_data[idx], self.labels[idx]

# --- ADDED: ANALYSIS FUNCTIONS ---

def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    """
    Performs feature ablation on the 3D LLD data to measure EER increase.
    """
    print("\n--- Running Feature Ablation Analysis ---")
    
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    # Zero out the specified feature channel across all time steps
                    # Input shape is (batch, features, time)
                    prosody[:, feature_to_ablate, :] = 0.0 
                
                outputs = model(cqcc, prosody)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
        
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")

    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()

def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    """
    Uses SHAP to explain model predictions based on the mean of LLD features.
    """
    print("\n--- Running SHAP Analysis ---")
    model.eval()

    background_cqcc, background_prosody_3d, _ = next(iter(dataloader))
    test_cqcc, test_prosody_3d, _ = next(iter(dataloader))
    
    # SHAP works best with 2D data, so we'll analyze the mean of the LLDs
    background_prosody_2d = np.mean(background_prosody_3d.numpy(), axis=2)
    test_prosody_2d = np.mean(test_prosody_3d.numpy(), axis=2)

    def model_wrapper(prosodic_features_2d_numpy):
        num_samples = prosodic_features_2d_numpy.shape[0]
        
        # Expand the 2D summary stats back to a 3D sequence by repeating
        # This is an approximation to make the data compatible with the model
        time_steps = EGMAPS_LLD_SHAPE[1]
        prosody_3d_numpy = np.repeat(prosodic_features_2d_numpy[:, :, np.newaxis], time_steps, axis=2)
        prosody_tensor = torch.from_numpy(prosody_3d_numpy).float().to(device)
        
        # Use a fixed CQCC background for all SHAP predictions
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        
        with torch.no_grad():
            output = model(cqcc_tensor, prosody_tensor)
        
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody_2d)
    
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody_2d, nsamples=100)
    
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    plt.figure(figsize=(12, 10))
    shap.summary_plot(shap_values, test_prosody_2d, feature_names=feature_names, show=False)
    
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# --- 3. SINGLE-STREAM FUSION TRANSFORMER MODEL ---

class SingleStreamFusionTransformer(nn.Module):
    """
    Fuses CQCC and eGeMAPS by concatenating them into a single sequence
    and feeding them to a single Transformer Encoder. Uses token-type embeddings
    to distinguish between the two modalities.
    """
    def __init__(self, cqcc_features, egmaps_features, time_steps, d_model, nhead, num_encoder_layers, dropout):
        super(SingleStreamFusionTransformer, self).__init__()
        
        self.d_model = d_model
        
        self.cqcc_projection = nn.Linear(cqcc_features, d_model)
        self.egmaps_projection = nn.Linear(egmaps_features, d_model)
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.token_type_embeddings = nn.Embedding(num_embeddings=2, embedding_dim=d_model)
        
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1 + time_steps * 2, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, egmaps_x):
        # Input shapes are expected to be (batch, features, time)
        # Transpose to (batch, time, features) for linear projection
        cqcc_x = cqcc_x.transpose(1, 2)
        egmaps_x = egmaps_x.transpose(1, 2)
        
        batch_size = cqcc_x.size(0)
        time_steps = cqcc_x.size(1)
        
        cqcc_embed = self.cqcc_projection(cqcc_x)
        egmaps_embed = self.egmaps_projection(egmaps_x)
        
        cqcc_type_ids = torch.zeros(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        egmaps_type_ids = torch.ones(batch_size, time_steps, dtype=torch.long, device=DEVICE)
        
        cqcc_type_embed = self.token_type_embeddings(cqcc_type_ids)
        egmaps_type_embed = self.token_type_embeddings(egmaps_type_ids)

        cqcc_embed += cqcc_type_embed
        egmaps_embed += egmaps_type_embed
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        full_sequence = torch.cat([cls_tokens, cqcc_embed, egmaps_embed], dim=1)
        
        full_sequence += self.positional_encoding
        
        transformer_out = self.transformer_encoder(full_sequence)
        
        cls_output = transformer_out[:, 0, :]
        output = self.classifier(cls_output)
        
        return torch.sigmoid(output)

# --- 4. MAIN EXECUTION SCRIPT ---
if __name__ == '__main__':
    print(f"Using device: {DEVICE}")

    try:
        print("--- Loading Data ---")
        X_cqcc_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_train.npy"))
        X_lld_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_train.npy"))
        y_train = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_train.npy"))
        X_cqcc_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_dev.npy"))
        X_lld_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_dev.npy"))
        y_val = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_dev.npy"))
        
        # Define the 23 eGeMAPS LLD feature names
        feature_columns = [
            'F0semitoneFrom27.5Hz_sma3nz', 'jitterLocal_sma3nz', 'shimmerLocaldB_sma3nz',
            'Loudness_sma3', 'HNRdBACF_sma3nz', 'logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz', 'F1frequency_sma3nz', 'F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz', 'F2frequency_sma3nz', 'F2bandwidth_sma3nz',
            'F2amplitudeLogRelF0_sma3nz', 'F3frequency_sma3nz', 'F3bandwidth_sma3nz',
            'F3amplitudeLogRelF0_sma3nz', 'alphaRatio_sma3', 'hammarbergIndex_sma3',
            'slope0-500_sma3', 'slope500-1500_sma3', 'spectralFlux_sma3',
            'mfcc1_sma3', 'mfcc2_sma3'
        ]
        
    except FileNotFoundError as e:
        print(f"❌ Error loading data files: {e}")
        exit()

    print("--- Scaling Features ---")
    # Reshape to 2D for scaler: (samples * time, features)
    scaler_lld = StandardScaler().fit(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0]))
    X_lld_train_scaled = scaler_lld.transform(X_lld_train.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_train.shape)
    X_lld_val_scaled = scaler_lld.transform(X_lld_val.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_val.shape)
    
    scaler_cqcc = StandardScaler().fit(X_cqcc_train.reshape(-1, CQCC_SHAPE[0]))
    X_cqcc_train_scaled = scaler_cqcc.transform(X_cqcc_train.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_train.shape)
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_val.shape)
    
    print("--- Saving Scalers ---")
    joblib.dump(scaler_cqcc, os.path.join(OUTPUT_DIR, "scaler_cqcc.joblib"))
    joblib.dump(scaler_lld, os.path.join(OUTPUT_DIR, "scaler_lld.joblib"))
    print(f"✅ Scalers saved to {OUTPUT_DIR}")
    
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_lld_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_lld_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    best_val_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_eer': []}
    
    print(f"\n--- Starting Training: Single-Stream Fusion Transformer Model ---")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        train_labels, train_preds = [], []
        
        for cqcc_batch, lld_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(cqcc_batch, lld_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_labels.extend(labels_batch.cpu().numpy())
            train_preds.extend(outputs.detach().cpu().numpy())

        model.eval()
        total_val_loss = 0
        val_labels, val_scores = [], []
        with torch.no_grad():
            for cqcc_batch, lld_batch, labels_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  "):
                cqcc_batch, lld_batch, labels_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs = model(cqcc_batch, lld_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                
                total_val_loss += loss.item()
                val_scores.extend(outputs.cpu().numpy())
                val_labels.extend(labels_batch.cpu().numpy())
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_labels = np.array(train_labels)
        train_preds_binary = (np.array(train_preds) > 0.5).astype(int).flatten()
        train_acc = accuracy_score(train_labels, train_preds_binary) * 100

        avg_val_loss = total_val_loss / len(val_loader)
        val_labels = np.array(val_labels)
        val_scores = np.array(val_scores).flatten()
        val_preds_binary = (val_scores > 0.5).astype(int)
        val_acc = accuracy_score(val_labels, val_preds_binary) * 100
        val_f1 = f1_score(val_labels, val_preds_binary)
        val_eer = calculate_eer(val_labels, val_scores)
        cm = confusion_matrix(val_labels, val_preds_binary)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['val_eer'].append(val_eer)

        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}% | Val F1: {val_f1:.4f} | Val EER: {val_eer:.2f}%")
        print("  Validation Confusion Matrix:\n", cm)

        scheduler.step(avg_val_loss)
        
        if val_eer > 0 and val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth"))
            print(f"  -> ✅ New best model saved with EER: {best_val_eer:.2f}%")

    print("\n--- Training Complete ---")
    plot_training_history(history, os.path.join(OUTPUT_DIR, "training_history_single_stream.png"))
    
    # --- FINAL EVALUATION AND ANALYSIS ---
    print("\n" + "="*50)
    print("--- Starting Final Evaluation and Analysis on Test Set ---")
    
    try:
        print("--- Loading Test Data ---")
        X_cqcc_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "cqcc_features_test.npy"))
        X_lld_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "egmaps_lld_features_test.npy"))
        y_test = np.load(os.path.join(PREPROCESSED_DATA_DIR, "labels_test.npy"))
        print(f"✅ Loaded {len(y_test)} test samples.")
    except FileNotFoundError as e:
        print(f"❌ Error loading test data files: {e}")
        exit()

    print("--- Scaling Test Features ---")
    X_lld_test_scaled = scaler_lld.transform(X_lld_test.reshape(-1, EGMAPS_LLD_SHAPE[0])).reshape(X_lld_test.shape)
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(-1, CQCC_SHAPE[0])).reshape(X_cqcc_test.shape)
    
    test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_lld_test_scaled, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print("--- Loading Best Trained Model for Testing ---")
    analysis_model = SingleStreamFusionTransformer(
        cqcc_features=CQCC_SHAPE[0],
        egmaps_features=EGMAPS_LLD_SHAPE[0],
        time_steps=CQCC_SHAPE[1],
        d_model=EMBEDDING_DIM,
        nhead=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    model_path = os.path.join(OUTPUT_DIR, "best_single_stream_transformer_model.pth")
    try:
        analysis_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        print("✅ Model weights loaded successfully.")
    except FileNotFoundError:
        print(f"❌ Model file not found at {model_path}")
        exit()

    analysis_model.eval()
    
    test_labels, test_scores = [], []
    with torch.no_grad():
        for cqcc_batch, lld_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
            cqcc_batch, lld_batch = cqcc_batch.to(DEVICE), lld_batch.to(DEVICE)
            outputs = analysis_model(cqcc_batch, lld_batch)
            test_scores.extend(outputs.cpu().numpy())
            test_labels.extend(labels_batch.cpu().numpy())
    
    test_labels = np.array(test_labels)
    test_scores = np.array(test_scores).flatten()
    test_preds_binary = (test_scores > 0.5).astype(int)
    
    test_acc = accuracy_score(test_labels, test_preds_binary) * 100
    test_f1 = f1_score(test_labels, test_preds_binary)
    test_eer = calculate_eer(test_labels, test_scores)
    test_cm = confusion_matrix(test_labels, test_preds_binary)
    
    print("\n" + "="*40)
    print("--- Final Test Results ---")
    print(f"  Accuracy: {test_acc:.2f}%")
    print(f"  F1-Score: {test_f1:.4f}")
    print(f"  EER:        {test_eer:.2f}%")
    print("  Confusion Matrix:")
    print(test_cm)
    print("="*40)

    # --- Run Analysis ---
    perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
    analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shap

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance.png"
SHAP_PLOT_PATH = "saved_models/shap_importance.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        # IMPORTANT: This forward pass now returns attention weights
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, feature_names, device, save_path):
    """
    Analyzes and visualizes aggregated attention weights for prosodic features.
    """
    print("\n--- Running Attention Weight Analysis ---")
    model.eval()
    attention_scores = np.zeros(64) # Attention weights for the prosody branch

    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            prosody_attention_weights = weights[:, 64:]
            attention_scores += prosody_attention_weights.sum(dim=0).cpu().numpy()

    print("NOTE: Attention analysis for this model shows importance of the *learned prosodic representation*.")
    print("Feature ablation is recommended for analyzing original input feature importance.")
    
    plt.figure(figsize=(12, 8))
    plt.bar(range(64), attention_scores, color='purple')
    plt.xlabel('Dimension of Learned Prosodic Representation')
    plt.ylabel('Aggregated Attention Score')
    plt.title('Importance of Learned Prosodic Feature Dimensions')
    plt.tight_layout()
    plt.savefig(save_path.replace(".png", "_learned_dims.png"))
    plt.close()


def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    """
    Performs feature ablation to measure EER increase.
    """
    print("\n--- Running Feature Ablation Analysis ---")
    
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody[:, feature_to_ablate] = 0.0 # Zero out the feature
                
                outputs, _ = model(cqcc, prosody)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
        
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")

    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10)) # Increased figure height for more features
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    """
    Method 3: Uses SHAP to explain model predictions for prosodic features.
    This is computationally intensive and is run on a subset of data.
    """
    print("\n--- Running SHAP Analysis ---")
    model.eval()

    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        
        with torch.no_grad():
            output, _ = model(cqcc_tensor, prosody_tensor)
        
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    plt.figure(figsize=(12, 10)) # Ensure plot is large enough
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        # Load all data from .npy files
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)

        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        # --- FIX for ValueError ---
        # The LLD prosodic features are 3D (samples, features, time).
        # This model expects 2D summary statistics for prosody.
        # We convert the 3D data to 2D by taking the mean across the time axis.
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        # Assumes shape is (samples, features, time), so we take mean over axis 2
        X_prosody_train = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val = np.mean(X_prosody_val_3d, axis=2)
        
        if not (len(X_cqcc_train) == len(X_prosody_train) == len(y_train)):
            raise ValueError("Sample count mismatch in training files.")
        if not (len(X_cqcc_val) == len(X_prosody_val) == len(y_val)):
            raise ValueError("Sample count mismatch in validation files.")

        # --- UPDATED: Define the 23 eGeMAPS LLD feature names ---
        # This list corresponds to the standard eGeMAPS v01b LLD set.
        # The script calculates the mean of these, so these are "mean LLDs".
        # IMPORTANT: Please verify this order matches your feature extraction script.
        feature_columns = [
            'F0semitoneFrom27.5Hz_sma3nz_amean',
            'jitterLocal_sma3nz_amean',
            'shimmerLocaldB_sma3nz_amean',
            'Loudness_sma3_amean',
            'HNRdBACF_sma3nz_amean',
            'logRelF0-H1-H2_sma3nz_amean',
            'logRelF0-H1-A3_sma3nz_amean',
            'F1frequency_sma3nz_amean',
            'F1bandwidth_sma3nz_amean',
            'F1amplitudeLogRelF0_sma3nz_amean',
            'F2frequency_sma3nz_amean',
            'F2bandwidth_sma3nz_amean',
            'F2amplitudeLogRelF0_sma3nz_amean',
            'F3frequency_sma3nz_amean',
            'F3bandwidth_sma3nz_amean',
            'F3amplitudeLogRelF0_sma3nz_amean',
            'alphaRatio_sma3_amean',
            'hammarbergIndex_sma3_amean',
            'slope0-500_sma3_amean',
            'slope500-1500_sma3_amean',
            'spectralFlux_sma3_amean',
            'mfcc1_sma3_amean',
            'mfcc2_sma3_amean'
        ]
        
        num_prosodic_features = X_prosody_train.shape[1]
        if len(feature_columns) != num_prosodic_features:
            print(f"Warning: Provided feature name count ({len(feature_columns)}) does not match data ({num_prosodic_features}). Using generic names.")
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
        print(f"Using {num_prosodic_features} prosodic features.")

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure your .npy data files are in the correct paths.")
        exit()

    print("--- Scaling Data ---")
    # This now works because X_prosody_train is 2D
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    # This scaling method flattens feature and time dimensions together.
    # It's a form of instance-level normalization.
    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs, _ = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs, _ = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # ==============================================================================
    # FINAL TESTING AND ANALYSIS
    # ==============================================================================
    print("\n--- Starting Final Testing and Analysis ---")
    try:
        print("Loading test data...")
        X_cqcc_test = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        
        # --- FIX for ValueError ---
        # Convert 3D test prosodic features to 2D to match training
        X_prosody_test = np.mean(X_prosody_test_3d, axis=2)
        
        print(f"Loaded {len(y_test)} test samples.")
        
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test)
        nsamples_test, nx_test, ny_test = X_cqcc_test.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        print("Loading best model for testing and analysis...")
        analysis_model = AttentionFusionCNN(
            cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
            prosodic_features=X_prosody_train.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        all_test_labels = []
        all_test_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
                cqcc_batch, prosody_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE)
                outputs, _ = analysis_model(cqcc_batch, prosody_batch)
                all_test_scores.extend(outputs.cpu().numpy())
                all_test_labels.extend(labels_batch.cpu().numpy())
        
        all_test_labels = np.array(all_test_labels)
        all_test_scores = np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)

        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)

        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}%")
        print(f"F1-Score: {test_f1:.4f}")
        print(f"EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

        analyze_attention_weights(analysis_model, test_loader, feature_columns, DEVICE, ATTENTION_PLOT_PATH)
        perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during testing/analysis: {e}")
        print("Please ensure your .npy data files are in the correct paths and format.")


In [None]:
# import os
# import numpy as np
# import librosa
# from scipy.fftpack import dct
# from tqdm import tqdm
# import soundfile as sf
# import opensmile

# if __name__ == '__main__':
#     # --- Initialize openSMILE for LLDs ---
#     smile = opensmile.Smile(
#         feature_set=opensmile.FeatureSet.eGeMAPS,
#         feature_level=opensmile.FeatureLevel.LowLevelDescriptors, # Use the full name, # Set to LLD
#     )


In [None]:
print(smile)

In [None]:
# for i in smile.feature_names:
#     print(i)

In [None]:
for i in smile.feature_set:
    print(i)

In [None]:
# --- Verification Snippet ---

# 1. Initialize openSMILE (as you already do)
smile = opensmile.Smile(
    feature_set=opensmile.FeatureSet.eGeMAPS,
    feature_level=opensmile.FeatureLevel.LowLevelDescriptors,
)

# 2. Get the list of feature names directly from the instance
feature_names_from_smile = smile.feature_names

# 3. Process a SINGLE audio file to get the DataFrame
#    Replace with a real path to one of your audio files.
try:
    single_audio_file_path = "audio_test/LA_E_1002903.flac" 
    lld_df = smile.process_file(single_audio_file_path)

    print("--- Verifying Feature Names and Values ---")
    
    # 4. Iterate through the feature names and get the values
    for feature_name in feature_names_from_smile:
        # Get the corresponding column (a pandas Series) from the DataFrame
        feature_values = lld_df[feature_name]
        
        # Print the name and the first 3 values for that feature
        print(f"Feature: {feature_name}")
        print(f"  First 3 values: {feature_values.head(3).values}")
        print("-" * 20)

except FileNotFoundError:
    print(f"Verification skipped: Could not find the example file at '{single_audio_file_path}'")
except Exception as e:
    print(f"An error occurred during verification: {e}")

# --- End of Verification Snippet ---

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shap

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/AttentionFusionCNN_2D_PyTorch_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_23feat.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

class AttentionFusionCNN(nn.Module):
    """PyTorch implementation using Conv2D for CQCC features."""
    def __init__(self, cqcc_input_shape, prosodic_features):
        super(AttentionFusionCNN, self).__init__()
        
        self.cqcc_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.cqcc_bn1 = nn.BatchNorm2d(16)
        self.cqcc_pool1 = nn.MaxPool2d((2, 2))
        self.cqcc_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.cqcc_bn2 = nn.BatchNorm2d(32)
        self.cqcc_pool2 = nn.MaxPool2d((2, 2))
        
        with torch.no_grad():
            dummy_cqcc = torch.zeros(1, 1, *cqcc_input_shape)
            dummy_out = self.cqcc_pool2(self.cqcc_bn2(self.cqcc_conv2(self.cqcc_pool1(self.cqcc_bn1(self.cqcc_conv1(dummy_cqcc))))))
            cqcc_flat_size = dummy_out.numel()
            
        self.cqcc_fc = nn.Linear(cqcc_flat_size, 64)
        self.prosody_fc1 = nn.Linear(prosodic_features, 32)
        self.prosody_bn1 = nn.BatchNorm1d(32)
        self.prosody_dropout = nn.Dropout(0.4)
        self.prosody_fc2 = nn.Linear(32, 64)
        concatenated_size = 64 + 64
        self.attention = nn.Linear(concatenated_size, concatenated_size)
        self.classifier_fc1 = nn.Linear(concatenated_size, 64)
        self.classifier_bn = nn.BatchNorm1d(64)
        self.classifier_dropout = nn.Dropout(0.5)
        self.output_fc = nn.Linear(64, 1)

    def forward(self, cqcc_x, prosody_x):
        # IMPORTANT: This forward pass now returns attention weights
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_out = torch.relu(self.cqcc_bn1(self.cqcc_conv1(cqcc_x)))
        cqcc_out = self.cqcc_pool1(cqcc_out)
        cqcc_out = torch.relu(self.cqcc_bn2(self.cqcc_conv2(cqcc_out)))
        cqcc_out = self.cqcc_pool2(cqcc_out)
        cqcc_out = torch.flatten(cqcc_out, 1)
        cqcc_branch_out = torch.relu(self.cqcc_fc(cqcc_out))

        prosody_out = torch.relu(self.prosody_bn1(self.prosody_fc1(prosody_x)))
        prosody_out = self.prosody_dropout(prosody_out)
        prosody_branch_out = torch.relu(self.prosody_fc2(prosody_out))

        concatenated = torch.cat([cqcc_branch_out, prosody_branch_out], dim=1)
        
        attention_weights = torch.softmax(self.attention(concatenated), dim=1)
        fused = concatenated * attention_weights

        x = torch.relu(self.classifier_bn(self.classifier_fc1(fused)))
        x = self.classifier_dropout(x)
        output = torch.sigmoid(self.output_fc(x))
        
        return output, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, feature_names, device, save_path):
    """
    Analyzes and visualizes aggregated attention weights for prosodic features.
    """
    print("\n--- Running Attention Weight Analysis ---")
    model.eval()
    attention_scores = np.zeros(64) # Attention weights for the prosody branch

    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            prosody_attention_weights = weights[:, 64:]
            attention_scores += prosody_attention_weights.sum(dim=0).cpu().numpy()

    print("NOTE: Attention analysis for this model shows importance of the *learned prosodic representation*.")
    print("Feature ablation is recommended for analyzing original input feature importance.")
    
    plt.figure(figsize=(12, 8))
    plt.bar(range(64), attention_scores, color='purple')
    plt.xlabel('Dimension of Learned Prosodic Representation')
    plt.ylabel('Aggregated Attention Score')
    plt.title('Importance of Learned Prosodic Feature Dimensions')
    plt.tight_layout()
    plt.savefig(save_path.replace(".png", "_learned_dims.png"))
    plt.close()


def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    """
    Performs feature ablation to measure EER increase.
    """
    print("\n--- Running Feature Ablation Analysis ---")
    
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody[:, feature_to_ablate] = 0.0 # Zero out the feature
                
                outputs, _ = model(cqcc, prosody)
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
        
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")

    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10)) # Increased figure height for more features
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    """
    Method 3: Uses SHAP to explain model predictions for prosodic features.
    This is computationally intensive and is run on a subset of data.
    """
    print("\n--- Running SHAP Analysis ---")
    model.eval()

    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        
        with torch.no_grad():
            output, _ = model(cqcc_tensor, prosody_tensor)
        
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    plt.figure(figsize=(12, 10)) # Ensure plot is large enough
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        # Load all data from .npy files
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)

        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        # --- FIX for ValueError ---
        # The LLD prosodic features are 3D (samples, features, time).
        # This model expects 2D summary statistics for prosody.
        # We convert the 3D data to 2D by taking the mean across the time axis.
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        # Assumes shape is (samples, features, time), so we take mean over axis 2
        X_prosody_train = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val = np.mean(X_prosody_val_3d, axis=2)
        
        if not (len(X_cqcc_train) == len(X_prosody_train) == len(y_train)):
            raise ValueError("Sample count mismatch in training files.")
        if not (len(X_cqcc_val) == len(X_prosody_val) == len(y_val)):
            raise ValueError("Sample count mismatch in validation files.")

        # --- UPDATED: Define the 23 feature names based on user verification ---
        feature_columns = [
            'Loudness_sma3',
            'alphaRatio_sma3',
            'hammarbergIndex_sma3',
            'slope0-500_sma3',
            'slope500-1500_sma3',
            'spectralFlux_sma3',
            'mfcc1_sma3',
            'mfcc2_sma3',
            'mfcc3_sma3',
            'mfcc4_sma3',
            'F0semitoneFrom27.5Hz_sma3nz',
            'jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz',
            'HNRdBACF_sma3nz',
            'logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz',
            'F1frequency_sma3nz',
            'F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz',
            'F2frequency_sma3nz',
            'F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz',
            'F3amplitudeLogRelF0_sma3nz'
        ]
        
        num_prosodic_features = X_prosody_train.shape[1]
        if len(feature_columns) != num_prosodic_features:
            print(f"Warning: Provided feature name count ({len(feature_columns)}) does not match data ({num_prosodic_features}). Using generic names.")
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

        print(f"Training samples: {len(y_train)}, Validation samples: {len(y_val)}")
        print(f"Using {num_prosodic_features} prosodic features.")

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure your .npy data files are in the correct paths.")
        exit()

    print("--- Scaling Data ---")
    # This now works because X_prosody_train is 2D
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)

    # This scaling method flattens feature and time dimensions together.
    # It's a form of instance-level normalization.
    scaler_cqcc = StandardScaler()
    nsamples, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(nsamples, -1)).reshape(nsamples, nx, ny)
    nsamples_val, nx_val, ny_val = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsamples_val, -1)).reshape(nsamples_val, nx_val, ny_val)
    print("Scaling complete.")

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AttentionFusionCNN(
        cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
        prosodic_features=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        
        for cqcc_batch, prosody_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs, _ = model(cqcc_batch, prosody_batch)
            loss = criterion(outputs, labels_batch.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels = []
        all_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in val_loader:
                cqcc_batch, prosody_batch, labels_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE), labels_batch.to(DEVICE)
                outputs, _ = model(cqcc_batch, prosody_batch)
                loss = criterion(outputs, labels_batch.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(outputs.cpu().numpy())
                all_labels.extend(labels_batch.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)

        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        print("Validation Confusion Matrix:\n", cm)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['f1'].append(f1)
        history['eer'].append(eer)
        
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    print("\n--- Training Complete ---")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # ==============================================================================
    # FINAL TESTING AND ANALYSIS
    # ==============================================================================
    print("\n--- Starting Final Testing and Analysis ---")
    try:
        print("Loading test data...")
        X_cqcc_test = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        
        # --- FIX for ValueError ---
        # Convert 3D test prosodic features to 2D to match training
        X_prosody_test = np.mean(X_prosody_test_3d, axis=2)
        
        print(f"Loaded {len(y_test)} test samples.")
        
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test)
        nsamples_test, nx_test, ny_test = X_cqcc_test.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(nsamples_test, -1)).reshape(nsamples_test, nx_test, ny_test)
        
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        print("Loading best model for testing and analysis...")
        analysis_model = AttentionFusionCNN(
            cqcc_input_shape=(X_cqcc_train.shape[1], X_cqcc_train.shape[2]),
            prosodic_features=X_prosody_train.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        all_test_labels = []
        all_test_scores = []
        with torch.no_grad():
            for cqcc_batch, prosody_batch, labels_batch in tqdm(test_loader, desc="Final Testing"):
                cqcc_batch, prosody_batch = cqcc_batch.to(DEVICE), prosody_batch.to(DEVICE)
                outputs, _ = analysis_model(cqcc_batch, prosody_batch)
                all_test_scores.extend(outputs.cpu().numpy())
                all_test_labels.extend(labels_batch.cpu().numpy())
        
        all_test_labels = np.array(all_test_labels)
        all_test_scores = np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)

        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)

        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}%")
        print(f"F1-Score: {test_f1:.4f}")
        print(f"EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

        analyze_attention_weights(analysis_model, test_loader, feature_columns, DEVICE, ATTENTION_PLOT_PATH)
        perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during testing/analysis: {e}")
        print("Please ensure your .npy data files are in the correct paths and format.")


In [None]:
#CNN + MLP with Attention
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os

# --- Paste the CrossAttentionModel class definition here ---
class CrossAttentionModel(nn.Module):
    """
    A model that uses cross-attention to fuse prosodic features and CQCC features.
    
    1. A CNN processes the CQCC spectrogram to extract spatial features.
    2. An MLP processes the 1D prosodic feature vector.
    3. The processed prosodic features act as a query to attend to the CQCC feature map.
    4. The attended CQCC features are combined with the prosody features for final classification.
    """
    def __init__(self, cqcc_shape, prosody_feature_dim, embed_dim=128):
        """
        Args:
            cqcc_shape (tuple): The shape of the input CQCC (channels, features, frames), e.g., (1, 90, 157).
            prosody_feature_dim (int): The number of prosodic features, e.g., 6.
            embed_dim (int): The dimensionality of the common embedding space for attention.
        """
        super(CrossAttentionModel, self).__init__()
        
        self.embed_dim = embed_dim
        
        # --- 1. CNN for CQCC Feature Extraction ---
        self.cnn_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=32, out_channels=embed_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(embed_dim),
        )
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
            nn.Tanh()
        )
        
        # --- 3. Classifier Head ---
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        cqcc_x = cqcc_x.unsqueeze(1)
        cqcc_map = self.cnn_extractor(cqcc_x)
        prosody_query = self.prosody_mlp(prosody_x)
        batch_size, _, H, W = cqcc_map.shape
        cqcc_seq = cqcc_map.view(batch_size, self.embed_dim, -1)
        cqcc_seq = cqcc_seq.permute(0, 2, 1)
        query = prosody_query.unsqueeze(1)
        attention_scores = torch.bmm(query, cqcc_seq.transpose(1, 2))
        attention_scores = attention_scores / (self.embed_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_context = torch.bmm(attention_weights, cqcc_seq)
        attention_context = attention_context.squeeze(1)
        fused_features = torch.cat([attention_context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        # Load the combined CSV file using pandas
        combined_data = pd.read_csv(combined_csv_file)
        
        # --- FIX STARTS HERE ---
        # A robust way to extract labels and features from the dataframe
        
        # 1. Extract labels and convert to a NumPy array
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        # 2. Identify all columns that are NOT features
        metadata_cols = ['label'] # Start with the label column
        # Add other potential metadata columns if they exist in the dataframe
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        # 3. Drop all metadata columns to get the prosodic features, then convert to a NumPy array
        self.prosody_data = combined_data.drop(columns=metadata_cols).values
        # --- FIX ENDS HERE ---

        # The assertion check is crucial for catching data generation errors
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        # .iloc is not needed here because self.prosody_data is now a NumPy array
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files (all .npy)
    TRAIN_CQCC_FILE = 'cqcc_features_aligned.npy'
    TRAIN_PROSODY_FILE = 'prosody_features_aligned.npy'
    TRAIN_LABELS_FILE = 'labels_aligned.npy'
    
    # Validation data files (1 .npy, 1 combined .csv)
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files (1 .npy, 1 combined .csv)
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.0001
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetNPY(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_shape = (1, cqcc_sample.shape[0], cqcc_sample.shape[1])
    prosody_dim = prosody_sample.shape[0]

    model = CrossAttentionModel(
        cqcc_shape=cqcc_shape,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            # Use the same dataset class for test set as for validation set
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 25379 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---
Epoch 1/20 | Train Loss: 0.2859, Train Acc: 0.8764 | Val Loss: 0.1732, Val EER: 0.1205, Val F1: 0.9556
-> New best model saved with EER: 0.1205
Epoch 2/20 | Train Loss: 0.1424, Train Acc: 0.9406 | Val Loss: 0.1453, Val EER: 0.0922, Val F1: 0.9673
-> New best model saved with EER: 0.0922
Epoch 3/20 | Train Loss: 0.1012, Train Acc: 0.9607 | Val Loss: 0.1181, Val EER: 0.0824, Val F1: 0.9748
-> New best model saved with EER: 0.0824
Epoch 4/20 | Train Loss: 0.0872, Train Acc: 0.9679 | Val Loss: 0.1073, Val EER: 0.0742, Val F1: 0.9772
-> New best model saved with EER: 0.0742
Epoch 5/20 | Train Loss: 0.0714, Train Acc: 0.9749 | Val Loss: 0.0888, Val EER: 0.0648, Val F1: 0.9816
-> New best model saved with EER: 0.0648
Epoch 6/20 | Train Loss: 0.0601, Train Acc: 0.9803 | Val Loss: 0.1081, Val EER: 0.0649, Val F1: 0.9808
Epoch 7/20 | Train Loss: 0.0547, Train Acc: 0.9803 | Va

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.4989
Test EER: 0.1180
Test F1-Score: 0.9186


In [20]:
#CNN + MLP (No Attention No Augmentation)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os

# --- CNN + MLP Fusion Model (No Attention) ---
class CNNMLPFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a CNN and an MLP, without attention.
    
    1. A CNN processes the CQCC spectrogram.
    2. The output of the CNN is flattened to a 1D vector.
    3. An MLP processes the 1D prosodic feature vector.
    4. The two resulting vectors are concatenated.
    5. A final classifier head makes the prediction.
    """
    def __init__(self, cqcc_shape, prosody_feature_dim, cnn_embed_dim=128, prosody_embed_dim=64):
        """
        Args:
            cqcc_shape (tuple): The shape of the input CQCC (channels, features, frames).
            prosody_feature_dim (int): The number of prosodic features.
            cnn_embed_dim (int): The output channels of the last CNN layer.
            prosody_embed_dim (int): The output dimension of the prosody MLP.
        """
        super(CNNMLPFusionModel, self).__init__()
        
        # --- 1. CNN for CQCC Feature Extraction ---
        self.cnn_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=32, out_channels=cnn_embed_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(cnn_embed_dim),
        )
        
        # Use Adaptive Pooling to handle variable input sizes and create a fixed-size output
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, prosody_embed_dim),
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input dimension is the sum of the flattened CNN output and the prosody MLP output
        classifier_input_dim = cnn_embed_dim + prosody_embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # Add a channel dimension for the CNN
        cqcc_x = cqcc_x.unsqueeze(1)
        
        # Process CQCC features
        cqcc_out = self.cnn_extractor(cqcc_x)
        cqcc_pooled = self.adaptive_pool(cqcc_out)
        cqcc_flat = self.flatten(cqcc_pooled)
        
        # Process prosodic features
        prosody_out = self.prosody_mlp(prosody_x)
        
        # Concatenate the feature vectors
        fused_features = torch.cat([cqcc_flat, prosody_out], dim=1)
        
        # Pass through the final classifier
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files (all .npy)
    TRAIN_CQCC_FILE = 'cqcc_features_aligned.npy'
    TRAIN_PROSODY_FILE = 'prosody_features_aligned.npy'
    TRAIN_LABELS_FILE = 'labels_aligned.npy'
    
    # Validation data files (1 .npy, 1 combined .csv)
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files (1 .npy, 1 combined .csv)
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 50
    LEARNING_RATE = 0.0001
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetNPY(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_shape = (1, cqcc_sample.shape[0], cqcc_sample.shape[1])
    prosody_dim = prosody_sample.shape[0]

    model = CNNMLPFusionModel(
        cqcc_shape=cqcc_shape,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 25379 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---
Epoch 1/50 | Train Loss: 0.4109, Train Acc: 0.8736 | Val Loss: 0.3418, Val EER: 0.4134, Val F1: 0.9459
-> New best model saved with EER: 0.4134
Epoch 2/50 | Train Loss: 0.3500, Train Acc: 0.8979 | Val Loss: 0.2936, Val EER: 0.2117, Val F1: 0.9459
-> New best model saved with EER: 0.2117
Epoch 3/50 | Train Loss: 0.2784, Train Acc: 0.8969 | Val Loss: 0.2378, Val EER: 0.1468, Val F1: 0.9459
-> New best model saved with EER: 0.1468
Epoch 4/50 | Train Loss: 0.2190, Train Acc: 0.8938 | Val Loss: 0.1803, Val EER: 0.1329, Val F1: 0.9459
-> New best model saved with EER: 0.1329
Epoch 5/50 | Train Loss: 0.1902, Train Acc: 0.8959 | Val Loss: 0.1721, Val EER: 0.1240, Val F1: 0.9459
-> New best model saved with EER: 0.1240
Epoch 6/50 | Train Loss: 0.1713, Train Acc: 0.9017 | Val Loss: 0.1955, Val EER: 0.1056, Val F1: 0.9584
-> New best model saved with EER: 0.1056
Epoch 7/50 | T

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.7309
Test EER: 0.1127
Test F1-Score: 0.9157


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os

# --- LSTM + MLP Fusion Model (No Attention) ---
class LSTMMMLPFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using an LSTM and an MLP.
    
    1. An LSTM processes the CQCC spectrogram as a sequence of feature vectors.
    2. An MLP processes the 1D prosodic feature vector.
    3. The final hidden state of the LSTM and the output of the MLP are concatenated.
    4. A final classifier head makes the prediction.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, lstm_hidden_dim=128, lstm_layers=2, prosody_embed_dim=64):
        """
        Args:
            cqcc_feature_dim (int): The number of CQCC coefficients (input feature size for LSTM).
            prosody_feature_dim (int): The number of prosodic features.
            lstm_hidden_dim (int): The hidden dimension size of the LSTM.
            lstm_layers (int): The number of LSTM layers.
            prosody_embed_dim (int): The output dimension of the prosody MLP.
        """
        super(LSTMMMLPFusionModel, self).__init__()
        
        # --- 1. LSTM for CQCC Feature Extraction ---
        self.lstm = nn.LSTM(
            input_size=cqcc_feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,  # Crucial for (batch, seq, feature) input format
            bidirectional=True, # Use a bidirectional LSTM to capture context from both directions
            dropout=0.2 if lstm_layers > 1 else 0 # Add dropout between LSTM layers
        )
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, prosody_embed_dim),
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input dimension is the sum of the LSTM output and the prosody MLP output
        # LSTM hidden_dim * 2 because it's bidirectional
        classifier_input_dim = (lstm_hidden_dim * 2) + prosody_embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames)
        # Permute for LSTM: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # Process CQCC features through LSTM
        # We only need the final hidden state, not the full output sequence
        # The `_` holds the output of all time steps
        _, (h_n, _) = self.lstm(cqcc_x)
        
        # h_n shape is (num_layers * num_directions, batch, hidden_dim)
        # We concatenate the final hidden states of the forward and backward LSTMs
        # The last forward layer is at index -2, the last backward layer is at index -1
        lstm_out = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        
        # Process prosodic features
        prosody_out = self.prosody_mlp(prosody_x)
        
        # Concatenate the feature vectors
        fused_features = torch.cat([lstm_out, prosody_out], dim=1)
        
        # Pass through the final classifier
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files (all .npy)
    TRAIN_CQCC_FILE = 'cqcc_features_aligned.npy'
    TRAIN_PROSODY_FILE = 'prosody_features_aligned.npy'
    TRAIN_LABELS_FILE = 'labels_aligned.npy'
    
    # Validation data files (1 .npy, 1 combined .csv)
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files (1 .npy, 1 combined .csv)
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.0001
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetNPY(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    # For LSTM, the input feature dimension is the number of CQCC coefficients
    cqcc_dim = cqcc_sample.shape[0] 
    prosody_dim = prosody_sample.shape[0]

    model = LSTMMMLPFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 25379 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---
Epoch 1/20 | Train Loss: 0.2983, Train Acc: 0.8854 | Val Loss: 0.1771, Val EER: 0.1328, Val F1: 0.9459
-> New best model saved with EER: 0.1328
Epoch 2/20 | Train Loss: 0.1622, Train Acc: 0.9120 | Val Loss: 0.1385, Val EER: 0.0858, Val F1: 0.9641
-> New best model saved with EER: 0.0858
Epoch 3/20 | Train Loss: 0.0913, Train Acc: 0.9655 | Val Loss: 0.1122, Val EER: 0.0714, Val F1: 0.9799
-> New best model saved with EER: 0.0714
Epoch 4/20 | Train Loss: 0.0482, Train Acc: 0.9841 | Val Loss: 0.1796, Val EER: 0.0730, Val F1: 0.9821
Epoch 5/20 | Train Loss: 0.0342, Train Acc: 0.9901 | Val Loss: 0.1498, Val EER: 0.0521, Val F1: 0.9817
-> New best model saved with EER: 0.0521
Epoch 6/20 | Train Loss: 0.0293, Train Acc: 0.9912 | Val Loss: 0.1175, Val EER: 0.0600, Val F1: 0.9860
Epoch 7/20 | Train Loss: 0.0191, Train Acc: 0.9948 | Val Loss: 0.1849, Val EER: 0.0608, Val F1: 

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.7963
Test EER: 0.1177
Test F1-Score: 0.9049


In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os

# --- LSTM + Cross-Attention Fusion Model ---
class LSTMAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using an LSTM and cross-attention.
    
    1. An LSTM processes the CQCC spectrogram as a sequence of feature vectors.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the LSTM's output sequence.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, lstm_hidden_dim=128, lstm_layers=2, prosody_embed_dim=64):
        """
        Args:
            cqcc_feature_dim (int): The number of CQCC coefficients.
            prosody_feature_dim (int): The number of prosodic features.
            lstm_hidden_dim (int): The hidden dimension size of the LSTM.
            lstm_layers (int): The number of LSTM layers.
            prosody_embed_dim (int): The output dimension of the prosody MLP.
        """
        super(LSTMAttentionFusionModel, self).__init__()
        
        # The LSTM's output dimension will be hidden_dim * 2 because it's bidirectional
        self.lstm_output_dim = lstm_hidden_dim * 2
        
        # --- 1. LSTM for CQCC Feature Extraction ---
        self.lstm = nn.LSTM(
            input_size=cqcc_feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 if lstm_layers > 1 else 0
        )
        
        # --- 2. MLP for Prosodic Feature Processing ---
        # This MLP will create the query for the attention mechanism
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, prosody_embed_dim),
            nn.ReLU()
        )
        
        # --- 3. Attention Mechanism Layers ---
        # We need linear layers to project the query and keys to a common dimension if they differ.
        # Here, we'll design it so the prosody_embed_dim and lstm_output_dim are compatible.
        # For simplicity, we'll use a linear layer to make the LSTM output match the prosody query dimension.
        self.key_projection = nn.Linear(self.lstm_output_dim, prosody_embed_dim)

        # --- 4. Classifier Head ---
        # The input is the attention context vector + the original prosody query vector
        classifier_input_dim = prosody_embed_dim + prosody_embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for LSTM: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC through LSTM to get the full sequence output
        # lstm_out shape: (batch, seq_len, num_directions * hidden_size)
        lstm_out, _ = self.lstm(cqcc_x)
        
        # 2. Process prosodic features to get the query
        # prosody_query shape: (batch, prosody_embed_dim)
        prosody_query = self.prosody_mlp(prosody_x)

        # --- 3. Cross-Attention ---
        # Project LSTM output to create keys and values for attention
        # keys shape: (batch, seq_len, prosody_embed_dim)
        keys = self.key_projection(lstm_out)
        # values are the original lstm output
        values = lstm_out
        
        # Reshape query for batch matrix multiplication
        # query shape: (batch, 1, prosody_embed_dim)
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        # Calculate attention scores: Q * K^T
        # (batch, 1, embed_dim) * (batch, embed_dim, seq_len) -> (batch, 1, seq_len)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / (keys.size(-1) ** 0.5) # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector: Weights * V
        # (batch, 1, seq_len) * (batch, seq_len, lstm_output_dim) -> (batch, 1, lstm_output_dim)
        context = torch.bmm(attention_weights, values)
        # Squeeze to remove the sequence dimension: (batch, lstm_output_dim)
        context = context.squeeze(1)

        # Project the context to match the prosody embedding dimension for concatenation
        # This step was missing in the previous attention model, let's make it more robust.
        # We'll use a linear layer for this. Let's define it in __init__
        # For simplicity in this version, we will assume the context vector can be fused directly
        # with a projected version of itself. Let's refine the classifier input.
        
        # Let's redefine the classifier input to be simpler and more direct:
        # We will concatenate the context vector with the original prosody query.
        # The context vector's size is self.lstm_output_dim.
        # The prosody query's size is prosody_embed_dim.
        
        # Re-calculating classifier input dimension based on what we have:
        # classifier_input_dim = self.lstm_output_dim + prosody_embed_dim
        # Let's adjust the classifier in __init__ for this.
        
        # For this implementation, let's stick to the original plan and adjust the classifier input dynamically
        # Let's create a new classifier head here for clarity, or better, fix the __init__ part.
        
        # Let's fix the architecture to be more robust from the start.
        # The context vector has size `lstm_output_dim`. The query has size `prosody_embed_dim`.
        # We will concatenate these two.
        
        # Let's adjust the classifier input dimension in __init__
        # classifier_input_dim = self.lstm_output_dim + prosody_embed_dim
        # And the classifier head needs to be updated accordingly.

        # Let's assume the model is re-initialized with the correct dimensions.
        # The logic below is correct if the classifier is initialized with:
        # classifier_input_dim = (lstm_hidden_dim * 2) + prosody_embed_dim
        
        fused_features = torch.cat([context, prosody_query], dim=1)
        
        # Pass through the final classifier
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files (all .npy)
    TRAIN_CQCC_FILE = 'cqcc_features_aligned.npy'
    TRAIN_PROSODY_FILE = 'prosody_features_aligned.npy'
    TRAIN_LABELS_FILE = 'labels_aligned.npy'
    
    # Validation data files (1 .npy, 1 combined .csv)
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files (1 .npy, 1 combined .csv)
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.0001
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetNPY(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    # For LSTM, the input feature dimension is the number of CQCC coefficients
    cqcc_dim = cqcc_sample.shape[0] 
    prosody_dim = prosody_sample.shape[0]

    # Use the new Attention model
    model = LSTMAttentionFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    # Let's fix the classifier input dimension inside the model's __init__
    # This was a point of confusion before. Let's make it explicit.
    lstm_hidden_dim = 128 # Must match the default in the model
    prosody_embed_dim = 64 # Must match the default in the model
    classifier_input_dim = (lstm_hidden_dim * 2) + prosody_embed_dim
    
    # Re-initializing the classifier head with the correct input dimension
    model.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        ).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 25379 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---
Epoch 1/20 | Train Loss: 0.3634, Train Acc: 0.8882 | Val Loss: 0.2445, Val EER: 0.1932, Val F1: 0.9459
-> New best model saved with EER: 0.1932
Epoch 2/20 | Train Loss: 0.2145, Train Acc: 0.9044 | Val Loss: 0.1677, Val EER: 0.1122, Val F1: 0.9616
-> New best model saved with EER: 0.1122
Epoch 3/20 | Train Loss: 0.1276, Train Acc: 0.9487 | Val Loss: 0.1380, Val EER: 0.0746, Val F1: 0.9722
-> New best model saved with EER: 0.0746
Epoch 4/20 | Train Loss: 0.0890, Train Acc: 0.9692 | Val Loss: 0.1241, Val EER: 0.0675, Val F1: 0.9793
-> New best model saved with EER: 0.0675
Epoch 5/20 | Train Loss: 0.0615, Train Acc: 0.9808 | Val Loss: 0.0983, Val EER: 0.0631, Val F1: 0.9783
-> New best model saved with EER: 0.0631
Epoch 6/20 | Train Loss: 0.0468, Train Acc: 0.9851 | Val Loss: 0.1315, Val EER: 0.0557, Val F1: 0.9827
-> New best model saved with EER: 0.0557
Epoch 7/20 | T

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.5919
Test EER: 0.1050
Test F1-Score: 0.9240


In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- Positional Encoding for Transformer ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- Transformer + MLP Fusion Model (No Attention) ---
class TransformerFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a Transformer and an MLP without attention.
    
    1. A Transformer Encoder processes the CQCC spectrogram as a sequence.
    2. The output of the Transformer is aggregated (here, by taking the first time step's output).
    3. An MLP processes the 1D prosodic feature vector.
    4. The two resulting vectors are concatenated and passed to a classifier.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, nhead=8, num_encoder_layers=3, dim_feedforward=512, dropout=0.1):
        super(TransformerFusionModel, self).__init__()
        self.d_model = d_model
        
        # --- 1. CQCC Feature Processing Path ---
        self.cqcc_projection = nn.Linear(cqcc_feature_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_model),
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the aggregated transformer output (size d_model) + the prosody MLP output (size d_model)
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for Transformer: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC features
        cqcc_embed = self.cqcc_projection(cqcc_x) * math.sqrt(self.d_model)
        cqcc_pos = self.pos_encoder(cqcc_embed.permute(1, 0, 2)).permute(1, 0, 2)
        transformer_out = self.transformer_encoder(cqcc_pos)
        
        # Aggregate the transformer output. We'll take the output of the first time step.
        transformer_aggregated = transformer_out[:, 0, :]
        
        # 2. Process prosodic features
        prosody_out = self.prosody_mlp(prosody_x)

        # 3. Fusion and Classification
        fused_features = torch.cat([transformer_aggregated, prosody_out], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files (all .npy)
    TRAIN_CQCC_FILE = 'cqcc_features_aligned.npy'
    TRAIN_PROSODY_FILE = 'prosody_features_aligned.npy'
    TRAIN_LABELS_FILE = 'labels_aligned.npy'
    
    # Validation data files (1 .npy, 1 combined .csv)
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files (1 .npy, 1 combined .csv)
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.0001
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetNPY(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    prosody_dim = prosody_sample.shape[0]

    # Use the new Transformer Fusion model (no attention)
    model = TransformerFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 25379 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---
Epoch 1/20 | Train Loss: 0.3807, Train Acc: 0.8912 | Val Loss: 0.3380, Val EER: 0.4075, Val F1: 0.9459
-> New best model saved with EER: 0.4075
Epoch 2/20 | Train Loss: 0.3514, Train Acc: 0.8980 | Val Loss: 0.3463, Val EER: 0.3831, Val F1: 0.9459
-> New best model saved with EER: 0.3831
Epoch 3/20 | Train Loss: 0.3364, Train Acc: 0.8983 | Val Loss: 0.3437, Val EER: 0.3619, Val F1: 0.9459
-> New best model saved with EER: 0.3619
Epoch 4/20 | Train Loss: 0.3290, Train Acc: 0.8983 | Val Loss: 0.3175, Val EER: 0.3787, Val F1: 0.9459
Epoch 5/20 | Train Loss: 0.3220, Train Acc: 0.8983 | Val Loss: 0.3173, Val EER: 0.3367, Val F1: 0.9459
-> New best model saved with EER: 0.3367
Epoch 6/20 | Train Loss: 0.3053, Train Acc: 0.8985 | Val Loss: 0.2688, Val EER: 0.2539, Val F1: 0.9459
-> New best model saved with EER: 0.2539
Epoch 7/20 | Train Loss: 0.2664, Train Acc: 0.8982 | Va

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.2923
Test EER: 0.1903
Test F1-Score: 0.9011


In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- Positional Encoding for Transformer ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape is (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- Transformer + Cross-Attention Fusion Model ---
class TransformerAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a Transformer and cross-attention.
    
    1. A Transformer Encoder processes the CQCC spectrogram as a sequence.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the Transformer's output sequence.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, nhead=8, num_encoder_layers=3, dim_feedforward=512, dropout=0.1):
        super(TransformerAttentionFusionModel, self).__init__()
        self.d_model = d_model
        
        # --- 1. CQCC Feature Processing Path ---
        self.cqcc_projection = nn.Linear(cqcc_feature_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_model), # Output dimension must match d_model for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector (size d_model) + the prosody query (size d_model)
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for Transformer: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC features
        cqcc_embed = self.cqcc_projection(cqcc_x) * math.sqrt(self.d_model)
        # Permute for positional encoding: (batch, seq, feature) -> (seq, batch, feature)
        cqcc_pos = self.pos_encoder(cqcc_embed.permute(1, 0, 2))
        # Permute back for transformer encoder: (seq, batch, feature) -> (batch, seq, feature)
        transformer_out = self.transformer_encoder(cqcc_pos.permute(1, 0, 2)) # This is our Key and Value
        
        # 2. Process prosodic features to get the query
        prosody_query = self.prosody_mlp(prosody_x)

        # --- 3. Cross-Attention ---
        keys = transformer_out
        values = transformer_out
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        # Calculate attention scores: Q * K^T
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / (keys.size(-1) ** 0.5) # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector by applying weights to values
        context = torch.bmm(attention_weights, values)
        context = context.squeeze(1) # Remove the sequence dimension
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data/cqcc_features.npy'
    TRAIN_COMBINED_FILE = 'processed_data/prosodic_features_and_labels.csv'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 128
    NUM_EPOCHS = 40
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5 # Added weight decay
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetCombinedCSV(TRAIN_CQCC_FILE, TRAIN_COMBINED_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    prosody_dim = prosody_sample.shape[0]

    # Use the Transformer Cross-Attention model
    model = TransformerAttentionFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    # Add weight_decay to the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Add a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        # Step the scheduler based on the validation EER
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 46019 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---




Epoch 1/40 | Train Loss: 0.6667, Train Acc: 0.5971 | Val Loss: 0.4529, Val EER: 0.3771, Val F1: 0.9410
-> New best model saved with EER: 0.3771
Epoch 2/40 | Train Loss: 0.4516, Train Acc: 0.7836 | Val Loss: 0.3912, Val EER: 0.3795, Val F1: 0.9458
Epoch 3/40 | Train Loss: 0.3993, Train Acc: 0.8165 | Val Loss: 0.3424, Val EER: 0.3699, Val F1: 0.9460
-> New best model saved with EER: 0.3699
Epoch 4/40 | Train Loss: 0.3499, Train Acc: 0.8434 | Val Loss: 0.2746, Val EER: 0.2561, Val F1: 0.9463
-> New best model saved with EER: 0.2561
Epoch 5/40 | Train Loss: 0.3102, Train Acc: 0.8606 | Val Loss: 0.2366, Val EER: 0.1966, Val F1: 0.9497
-> New best model saved with EER: 0.1966
Epoch 6/40 | Train Loss: 0.2729, Train Acc: 0.8784 | Val Loss: 0.2289, Val EER: 0.2009, Val F1: 0.9531
Epoch 7/40 | Train Loss: 0.2462, Train Acc: 0.8944 | Val Loss: 0.2610, Val EER: 0.2414, Val F1: 0.9495
Epoch 8/40 | Train Loss: 0.2203, Train Acc: 0.9044 | Val Loss: 0.1926, Val EER: 0.1515, Val F1: 0.9556
-> New best 

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.3534
Test EER: 0.0934
Test F1-Score: 0.9274


In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- CNN + Cross-Attention Fusion Model ---
class CrossAttentionModel(nn.Module):
    """
    A model that uses cross-attention to fuse prosodic features and CQCC features.
    
    1. A CNN processes the CQCC spectrogram to extract spatial features.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the CNN's output feature map.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_shape, prosody_feature_dim, embed_dim=128):
        """
        Args:
            cqcc_shape (tuple): The shape of the input CQCC (channels, features, frames).
            prosody_feature_dim (int): The number of prosodic features.
            embed_dim (int): The dimensionality of the common embedding space for attention.
        """
        super(CrossAttentionModel, self).__init__()
        
        self.embed_dim = embed_dim
        
        # --- 1. CNN for CQCC Feature Extraction ---
        self.cnn_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=32, out_channels=embed_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(embed_dim),
        )
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, embed_dim), # Output dimension must match embed_dim for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector (size embed_dim) + the prosody query (size embed_dim)
        classifier_input_dim = embed_dim + embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # Add a channel dimension to CQCC input for the CNN
        # Shape: (batch, features, frames) -> (batch, 1, features, frames)
        cqcc_x = cqcc_x.unsqueeze(1)
        
        # 1. Process inputs through their respective networks
        # cqcc_map shape: (batch, embed_dim, H', W')
        cqcc_map = self.cnn_extractor(cqcc_x)
        
        # prosody_query shape: (batch, embed_dim)
        prosody_query = self.prosody_mlp(prosody_x)

        # --- 2. Cross-Attention Mechanism ---
        batch_size, _, H, W = cqcc_map.shape
        
        # Flatten the spatial dimensions of the CQCC map to create a sequence of features
        # This will be our 'value' and 'key' in the attention mechanism.
        # Shape: (batch, embed_dim, H', W') -> (batch, embed_dim, H'*W')
        cqcc_seq = cqcc_map.view(batch_size, self.embed_dim, -1)
        
        # Reshape for attention calculation: (batch, H'*W', embed_dim)
        keys = cqcc_seq.permute(0, 2, 1)
        values = keys # In this case, keys and values are the same
        
        # The prosody vector is our 'query'. We need to expand it for matrix multiplication.
        # Shape: (batch, embed_dim) -> (batch, 1, embed_dim)
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        # Calculate attention scores (Query * Key^T)
        # Q: (batch, 1, embed_dim), K^T: (batch, embed_dim, H'*W') -> Scores: (batch, 1, H'*W')
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        
        # Scale the scores
        attention_scores = attention_scores / (self.embed_dim ** 0.5)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Apply weights to the CQCC sequence (Value) to get the context vector
        # Weights: (batch, 1, H'*W'), V: (batch, H'*W', embed_dim) -> Context: (batch, 1, embed_dim)
        context = torch.bmm(attention_weights, values)
        
        # Remove the middle dimension: (batch, 1, embed_dim) -> (batch, embed_dim)
        context = context.squeeze(1)
        
        # --- 3. Final Classification ---
        # Concatenate the attention output with the original processed prosody vector
        fused_features = torch.cat([context, prosody_query], dim=1)
        
        # Pass the fused features through the final classifier
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data/cqcc_features.npy'
    TRAIN_COMBINED_FILE = 'processed_data/prosodic_features_and_labels.csv'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 128
    NUM_EPOCHS = 40
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5 # Added weight decay
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetCombinedCSV(TRAIN_CQCC_FILE, TRAIN_COMBINED_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_shape = (1, cqcc_sample.shape[0], cqcc_sample.shape[1])
    prosody_dim = prosody_sample.shape[0]

    # Use the CNN Cross-Attention model
    model = CrossAttentionModel(
        cqcc_shape=cqcc_shape,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    # Add weight_decay to the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Add a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        # Step the scheduler based on the validation EER
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 46019 samples.
Validation data loaded: 24844 samples.





--- Starting Training ---
Epoch 1/40 | Train Loss: 0.5205, Train Acc: 0.7115 | Val Loss: 0.2788, Val EER: 0.1479, Val F1: 0.9151
-> New best model saved with EER: 0.1479
Epoch 2/40 | Train Loss: 0.2232, Train Acc: 0.8957 | Val Loss: 0.2119, Val EER: 0.1246, Val F1: 0.9216
-> New best model saved with EER: 0.1246
Epoch 3/40 | Train Loss: 0.1855, Train Acc: 0.9092 | Val Loss: 0.1841, Val EER: 0.1229, Val F1: 0.9381
-> New best model saved with EER: 0.1229
Epoch 4/40 | Train Loss: 0.1726, Train Acc: 0.9147 | Val Loss: 0.1909, Val EER: 0.1126, Val F1: 0.9256
-> New best model saved with EER: 0.1126
Epoch 5/40 | Train Loss: 0.1657, Train Acc: 0.9190 | Val Loss: 0.2011, Val EER: 0.1103, Val F1: 0.9244
-> New best model saved with EER: 0.1103
Epoch 6/40 | Train Loss: 0.1550, Train Acc: 0.9258 | Val Loss: 0.1730, Val EER: 0.1095, Val F1: 0.9459
-> New best model saved with EER: 0.1095
Epoch 7/40 | Train Loss: 0.1508, Train Acc: 0.9276 | Val Loss: 0.1705, Val EER: 0.1054, Val F1: 0.9463
-> New

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.8347
Test EER: 0.1621
Test F1-Score: 0.8691


In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- LSTM + Cross-Attention Fusion Model ---
class LSTMAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a BiLSTM and cross-attention.
    
    1. A BiLSTM processes the CQCC spectrogram as a sequence of feature vectors.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the BiLSTM's output sequence.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, lstm_hidden_dim=128, lstm_layers=2):
        """
        Args:
            cqcc_feature_dim (int): The number of CQCC coefficients.
            prosody_feature_dim (int): The number of prosodic features.
            lstm_hidden_dim (int): The hidden dimension size of the LSTM.
            lstm_layers (int): The number of LSTM layers.
        """
        super(LSTMAttentionFusionModel, self).__init__()
        
        self.lstm_output_dim = lstm_hidden_dim * 2 # Times 2 for bidirectional
        
        # --- 1. BiLSTM for CQCC Feature Extraction ---
        self.lstm = nn.LSTM(
            input_size=cqcc_feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 if lstm_layers > 1 else 0
        )
        
        # --- 2. MLP for Prosodic Feature Processing ---
        # This MLP creates the query for the attention mechanism.
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, self.lstm_output_dim), # Output must match LSTM output dim for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector + the prosody query vector.
        # Both have size self.lstm_output_dim.
        classifier_input_dim = self.lstm_output_dim + self.lstm_output_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for LSTM: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC through LSTM to get the full sequence output (keys and values)
        # lstm_out shape: (batch, seq_len, num_directions * hidden_size)
        lstm_out, _ = self.lstm(cqcc_x)
        
        # 2. Process prosodic features to get the query
        # prosody_query shape: (batch, lstm_output_dim)
        prosody_query = self.prosody_mlp(prosody_x)

        # --- 3. Cross-Attention ---
        keys = lstm_out
        values = lstm_out
        
        # Reshape query for batch matrix multiplication
        # query shape: (batch, 1, lstm_output_dim)
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        # Calculate attention scores: Q * K^T
        # (batch, 1, embed_dim) * (batch, embed_dim, seq_len) -> (batch, 1, seq_len)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / (keys.size(-1) ** 0.5) # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector: Weights * V
        # (batch, 1, seq_len) * (batch, seq_len, lstm_output_dim) -> (batch, 1, lstm_output_dim)
        context = torch.bmm(attention_weights, values)
        context = context.squeeze(1) # Remove the sequence dimension
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Datasets ---
class AudioSpoofDatasetNPY(Dataset):
    """Custom Dataset for loading all data from .npy files (for training)."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more training feature files not found.")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Training data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

class AudioSpoofDatasetCombinedCSV(Dataset):
    """Custom Dataset for val/test: CQCC from .npy, prosody and labels from one combined .csv."""
    def __init__(self, cqcc_file, combined_csv_file):
        if not all(os.path.exists(f) for f in [cqcc_file, combined_csv_file]):
            raise FileNotFoundError(f"Validation or Test feature files not found: {cqcc_file}, {combined_csv_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        combined_data = pd.read_csv(combined_csv_file)
        
        if 'label' not in combined_data.columns:
            raise ValueError("The combined CSV file must contain a 'label' column.")
        self.labels = combined_data['label'].values
        
        metadata_cols = ['label']
        if 'filename' in combined_data.columns:
            metadata_cols.append('filename')
        if 'attack_id' in combined_data.columns:
            metadata_cols.append('attack_id')
        
        self.prosody_data = combined_data.drop(columns=metadata_cols).values

        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data/cqcc_features.npy'
    TRAIN_COMBINED_FILE = 'processed_data/prosodic_features_and_labels.csv'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data/cqcc_features_val.npy'
    VAL_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_val.csv'

    # Test data files
    TEST_CQCC_FILE = 'processed_data/cqcc_features_test.npy'
    TEST_COMBINED_FILE = 'processed_data/prosodic_features_and_labels_test.csv'
    
    BATCH_SIZE = 128
    NUM_EPOCHS = 40
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5 # Added weight decay
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDatasetCombinedCSV(TRAIN_CQCC_FILE, TRAIN_COMBINED_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDatasetCombinedCSV(VAL_CQCC_FILE, VAL_COMBINED_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    prosody_dim = prosody_sample.shape[0]

    # Use the LSTM Cross-Attention model
    model = LSTMAttentionFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    # Add weight_decay to the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Add a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        # Step the scheduler based on the validation EER
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDatasetCombinedCSV(TEST_CQCC_FILE, TEST_COMBINED_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")


Using device: cuda
Training data loaded: 46019 samples.
Validation data loaded: 24844 samples.

--- Starting Training ---




Epoch 1/40 | Train Loss: 0.4166, Train Acc: 0.8100 | Val Loss: 0.2086, Val EER: 0.0789, Val F1: 0.9593
-> New best model saved with EER: 0.0789
Epoch 2/40 | Train Loss: 0.1128, Train Acc: 0.9708 | Val Loss: 0.1313, Val EER: 0.0639, Val F1: 0.9700
-> New best model saved with EER: 0.0639
Epoch 3/40 | Train Loss: 0.0696, Train Acc: 0.9834 | Val Loss: 0.1063, Val EER: 0.0722, Val F1: 0.9849
Epoch 4/40 | Train Loss: 0.0590, Train Acc: 0.9855 | Val Loss: 0.1127, Val EER: 0.0632, Val F1: 0.9821
-> New best model saved with EER: 0.0632
Epoch 5/40 | Train Loss: 0.0407, Train Acc: 0.9897 | Val Loss: 0.1549, Val EER: 0.0954, Val F1: 0.9832
Epoch 6/40 | Train Loss: 0.0323, Train Acc: 0.9914 | Val Loss: 0.1436, Val EER: 0.0836, Val F1: 0.9838
Epoch 7/40 | Train Loss: 0.0314, Train Acc: 0.9926 | Val Loss: 0.1613, Val EER: 0.0896, Val F1: 0.9850
Epoch 8/40 | Train Loss: 0.0222, Train Acc: 0.9944 | Val Loss: 0.1408, Val EER: 0.0808, Val F1: 0.9856
Epoch 9/40 | Train Loss: 0.0181, Train Acc: 0.9956 | 

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.6690
Test EER: 0.1052
Test F1-Score: 0.9066


In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- Positional Encoding for Transformer ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape is (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- Transformer + Cross-Attention Fusion Model ---
class TransformerAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a Transformer and cross-attention.
    
    1. A Transformer Encoder processes the CQCC spectrogram as a sequence.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the Transformer's output sequence.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, nhead=8, num_encoder_layers=3, dim_feedforward=512, dropout=0.1):
        super(TransformerAttentionFusionModel, self).__init__()
        self.d_model = d_model
        
        # --- 1. CQCC Feature Processing Path ---
        self.cqcc_projection = nn.Linear(cqcc_feature_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        # Updated to handle the actual prosody feature dimension
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256),  # Increased capacity for larger input
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_model), # Output dimension must match d_model for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector (size d_model) + the prosody query (size d_model)
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for Transformer: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC features
        cqcc_embed = self.cqcc_projection(cqcc_x) * math.sqrt(self.d_model)
        # Permute for positional encoding: (batch, seq, feature) -> (seq, batch, feature)
        cqcc_pos = self.pos_encoder(cqcc_embed.permute(1, 0, 2))
        # Permute back for transformer encoder: (seq, batch, feature) -> (batch, seq, feature)
        transformer_out = self.transformer_encoder(cqcc_pos.permute(1, 0, 2)) # This is our Key and Value
        
        # 2. Process prosodic features to get the query
        # Flatten prosody features if they have multiple dimensions
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        
        prosody_query = self.prosody_mlp(prosody_x)

        # --- 3. Cross-Attention ---
        keys = transformer_out
        values = transformer_out
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        # Calculate attention scores: Q * K^T
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / (keys.size(-1) ** 0.5) # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector by applying weights to values
        context = torch.bmm(attention_weights, values)
        context = context.squeeze(1) # Remove the sequence dimension
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Dataset ---
class AudioSpoofDataset(Dataset):
    """Custom Dataset for loading CQCC, prosody, and labels from .npy files."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more feature files not found: {cqcc_file}, {prosody_file}, {labels_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"
        
        # Debug: Print shapes to understand the data structure
        print(f"CQCC data shape: {self.cqcc_data.shape}")
        print(f"Prosody data shape: {self.prosody_data.shape}")
        print(f"Labels shape: {self.labels.shape}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Update these paths to your new .npy files
    
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_train.npy'
    TRAIN_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_train.npy'
    TRAIN_LABELS_FILE = 'processed_data_aligned_lld/labels_train.npy'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_dev.npy'
    VAL_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_dev.npy'
    VAL_LABELS_FILE = 'processed_data_aligned_lld/labels_dev.npy'

    # Test data files
    TEST_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_test.npy'
    TEST_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_test.npy'
    TEST_LABELS_FILE = 'processed_data_aligned_lld/labels_test.npy'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 40
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDataset(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDataset(VAL_CQCC_FILE, VAL_PROSODY_FILE, VAL_LABELS_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    
    # Handle prosody dimensions properly
    if prosody_sample.dim() > 1:
        prosody_dim = prosody_sample.numel()  # Total number of elements if multi-dimensional
    else:
        prosody_dim = prosody_sample.shape[0]

    print(f"CQCC feature dimension: {cqcc_dim}")
    print(f"Prosody feature dimension: {prosody_dim}")
    print(f"Prosody sample shape: {prosody_sample.shape}")

    # Use the Transformer Cross-Attention model
    model = TransformerAttentionFusionModel(
        cqcc_feature_dim=cqcc_dim,
        prosody_feature_dim=prosody_dim
    ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_model.pth'):
        try:
            test_dataset = AudioSpoofDataset(TEST_CQCC_FILE, TEST_PROSODY_FILE, TEST_LABELS_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_model.pth' found to test. Please run the training first.")

Using device: cuda
CQCC data shape: (46019, 128, 157)
Prosody data shape: (46019, 23, 157)
Labels shape: (46019,)
Training data loaded: 46019 samples.
CQCC data shape: (24844, 128, 157)
Prosody data shape: (24844, 23, 157)
Labels shape: (24844,)
Validation data loaded: 24844 samples.
CQCC feature dimension: 128
Prosody feature dimension: 3611
Prosody sample shape: torch.Size([23, 157])

--- Starting Training ---
Epoch 1/40 | Train Loss: 0.4842, Train Acc: 0.7736 | Val Loss: 0.5272, Val EER: 0.1462, Val F1: 0.4273
-> New best model saved with EER: 0.1462
Epoch 2/40 | Train Loss: 0.1555, Train Acc: 0.9392 | Val Loss: 0.2303, Val EER: 0.1083, Val F1: 0.6217
-> New best model saved with EER: 0.1083
Epoch 3/40 | Train Loss: 0.1001, Train Acc: 0.9625 | Val Loss: 0.5237, Val EER: 0.1122, Val F1: 0.4946
Epoch 4/40 | Train Loss: 0.0775, Train Acc: 0.9718 | Val Loss: 0.2234, Val EER: 0.1090, Val F1: 0.6642
Epoch 5/40 | Train Loss: 0.0613, Train Acc: 0.9781 | Val Loss: 0.1393, Val EER: 0.0738, Va

  model.load_state_dict(torch.load('best_model.pth'))



--- Test Results ---
Test Loss: 0.5361
Test EER: 0.0883
Test F1-Score: 0.7114


In [37]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- CNN + Cross-Attention Fusion Model ---
class CNNAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using a CNN and cross-attention.
    
    1. A CNN processes the CQCC spectrogram as a 2D feature map.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the CNN's output feature maps.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, dropout=0.1):
        super(CNNAttentionFusionModel, self).__init__()
        self.d_model = d_model
        
        # --- 1. CNN for CQCC Feature Processing ---
        self.cnn_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(cqcc_feature_dim, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            # Second conv block
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            # Fourth conv block
            nn.Conv1d(256, d_model, kernel_size=3, padding=1),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Global average pooling alternative (if needed)
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_model), # Output dimension must match d_model for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector (size d_model) + the prosody query (size d_model)
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) - perfect for Conv1d
        
        # 1. Process CQCC features with CNN
        cnn_out = self.cnn_layers(cqcc_x)  # Shape: (batch, d_model, frames)
        
        # Permute for attention: (batch, d_model, frames) -> (batch, frames, d_model)
        cnn_features = cnn_out.permute(0, 2, 1)  # This will be our Keys and Values
        
        # 2. Process prosodic features to get the query
        # Flatten prosody features if they have multiple dimensions
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        
        prosody_query = self.prosody_mlp(prosody_x)  # Shape: (batch, d_model)

        # --- 3. Cross-Attention ---
        keys = cnn_features  # Shape: (batch, frames, d_model)
        values = cnn_features  # Shape: (batch, frames, d_model)
        query_unsqueezed = prosody_query.unsqueeze(1)  # Shape: (batch, 1, d_model)
        
        # Calculate attention scores: Q * K^T
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))  # (batch, 1, frames)
        attention_scores = attention_scores / (keys.size(-1) ** 0.5)  # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector by applying weights to values
        context = torch.bmm(attention_weights, values)  # (batch, 1, d_model)
        context = context.squeeze(1)  # Remove the sequence dimension -> (batch, d_model)
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)  # (batch, 2*d_model)
        logits = self.classifier(fused_features)
        return logits

# --- Alternative CNN model with different architecture ---
class CNNAttentionFusionModelV2(nn.Module):
    """
    Alternative CNN architecture with residual connections and different structure.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, dropout=0.1):
        super(CNNAttentionFusionModelV2, self).__init__()
        self.d_model = d_model
        
        # --- 1. CNN with Residual Blocks for CQCC ---
        self.conv1 = nn.Conv1d(cqcc_feature_dim, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        
        # Residual blocks
        self.res_block1 = self._make_res_block(64, 64, kernel_size=3)
        self.res_block2 = self._make_res_block(64, 128, kernel_size=3)
        self.res_block3 = self._make_res_block(128, 256, kernel_size=3)
        
        # Final projection to d_model
        self.final_conv = nn.Conv1d(256, d_model, kernel_size=1)
        self.final_bn = nn.BatchNorm1d(d_model)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_model),
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )
        
        self.dropout = nn.Dropout(dropout)
    
    def _make_res_block(self, in_channels, out_channels, kernel_size=3):
        """Create a residual block."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
            nn.BatchNorm1d(out_channels)
        )
    
    def forward(self, cqcc_x, prosody_x):
        # 1. Process CQCC features with CNN
        x = F.relu(self.bn1(self.conv1(cqcc_x)))
        x = self.dropout(x)
        
        # Apply residual blocks
        identity = x
        x = self.res_block1(x)
        if x.size(1) == identity.size(1):  # Same channels
            x = F.relu(x + identity)
        else:
            x = F.relu(x)
        
        identity = x
        x = self.res_block2(x)
        x = F.relu(x)  # Different channels, no residual connection
        
        identity = x
        x = self.res_block3(x)
        x = F.relu(x)  # Different channels, no residual connection
        
        # Final projection
        cnn_out = F.relu(self.final_bn(self.final_conv(x)))
        
        # Permute for attention: (batch, d_model, frames) -> (batch, frames, d_model)
        cnn_features = cnn_out.permute(0, 2, 1)
        
        # 2. Process prosodic features
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        
        prosody_query = self.prosody_mlp(prosody_x)
        
        # 3. Cross-Attention (same as in the first model)
        keys = cnn_features
        values = cnn_features
        query_unsqueezed = prosody_query.unsqueeze(1)
        
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / (keys.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.bmm(attention_weights, values).squeeze(1)
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Dataset ---
class AudioSpoofDataset(Dataset):
    """Custom Dataset for loading CQCC, prosody, and labels from .npy files."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more feature files not found: {cqcc_file}, {prosody_file}, {labels_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"
        
        # Debug: Print shapes to understand the data structure
        print(f"CQCC data shape: {self.cqcc_data.shape}")
        print(f"Prosody data shape: {self.prosody_data.shape}")
        print(f"Labels shape: {self.labels.shape}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Update these paths to your new .npy files
    
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_train.npy'
    TRAIN_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_train.npy'
    TRAIN_LABELS_FILE = 'processed_data_aligned_lld/labels_train.npy'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_dev.npy'
    VAL_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_dev.npy'
    VAL_LABELS_FILE = 'processed_data_aligned_lld/labels_dev.npy'

    # Test data files
    TEST_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_test.npy'
    TEST_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_test.npy'
    TEST_LABELS_FILE = 'processed_data_aligned_lld/labels_test.npy'
    
    BATCH_SIZE = 128
    NUM_EPOCHS = 40
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5
    USE_RESNET_CNN = True  # Set to True to use the ResNet-style CNN (CNNAttentionFusionModelV2)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDataset(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDataset(VAL_CQCC_FILE, VAL_PROSODY_FILE, VAL_LABELS_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    
    # Handle prosody dimensions properly
    if prosody_sample.dim() > 1:
        prosody_dim = prosody_sample.numel()  # Total number of elements if multi-dimensional
    else:
        prosody_dim = prosody_sample.shape[0]

    print(f"CQCC feature dimension: {cqcc_dim}")
    print(f"Prosody feature dimension: {prosody_dim}")
    print(f"Prosody sample shape: {prosody_sample.shape}")

    # Choose CNN model
    if USE_RESNET_CNN:
        print("Using ResNet-style CNN model")
        model = CNNAttentionFusionModelV2(
            cqcc_feature_dim=cqcc_dim,
            prosody_feature_dim=prosody_dim
        ).to(device)
    else:
        print("Using standard CNN model")
        model = CNNAttentionFusionModel(
            cqcc_feature_dim=cqcc_dim,
            prosody_feature_dim=prosody_dim
        ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_cnn_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_cnn_model.pth'):
        try:
            test_dataset = AudioSpoofDataset(TEST_CQCC_FILE, TEST_PROSODY_FILE, TEST_LABELS_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_cnn_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_cnn_model.pth' found to test. Please run the training first.")

Using device: cuda
CQCC data shape: (46019, 128, 157)
Prosody data shape: (46019, 23, 157)
Labels shape: (46019,)
Training data loaded: 46019 samples.
CQCC data shape: (24844, 128, 157)
Prosody data shape: (24844, 23, 157)
Labels shape: (24844,)
Validation data loaded: 24844 samples.
CQCC feature dimension: 128
Prosody feature dimension: 3611
Prosody sample shape: torch.Size([23, 157])
Using ResNet-style CNN model

--- Starting Training ---




Epoch 1/40 | Train Loss: 0.5494, Train Acc: 0.7297 | Val Loss: 0.3105, Val EER: 0.1133, Val F1: 0.5731
-> New best model saved with EER: 0.1133
Epoch 2/40 | Train Loss: 0.2009, Train Acc: 0.9183 | Val Loss: 0.1604, Val EER: 0.1115, Val F1: 0.6115
-> New best model saved with EER: 0.1115
Epoch 3/40 | Train Loss: 0.1628, Train Acc: 0.9327 | Val Loss: 0.1556, Val EER: 0.1022, Val F1: 0.6588
-> New best model saved with EER: 0.1022
Epoch 4/40 | Train Loss: 0.1555, Train Acc: 0.9367 | Val Loss: 0.1521, Val EER: 0.1019, Val F1: 0.6597
-> New best model saved with EER: 0.1019
Epoch 5/40 | Train Loss: 0.1448, Train Acc: 0.9399 | Val Loss: 0.1564, Val EER: 0.1024, Val F1: 0.6594
Epoch 6/40 | Train Loss: 0.1115, Train Acc: 0.9591 | Val Loss: 0.1228, Val EER: 0.0644, Val F1: 0.8096
-> New best model saved with EER: 0.0644
Epoch 7/40 | Train Loss: 0.1374, Train Acc: 0.9465 | Val Loss: 0.1361, Val EER: 0.0852, Val F1: 0.7667
Epoch 8/40 | Train Loss: 0.0564, Train Acc: 0.9831 | Val Loss: 0.1248, Val

  model.load_state_dict(torch.load('best_cnn_model.pth'))



--- Test Results ---
Test Loss: 0.8916
Test EER: 0.1105
Test F1-Score: 0.6058


In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, f1_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn.functional as F
import os
import math

# --- BiLSTM + Cross-Attention Fusion Model ---
class BiLSTMAttentionFusionModel(nn.Module):
    """
    A model that fuses CQCC and prosodic features using BiLSTM and cross-attention.
    
    1. A BiLSTM processes the CQCC spectrogram as a temporal sequence.
    2. An MLP processes the 1D prosodic feature vector to create a query.
    3. The prosody query attends to the BiLSTM's output sequence.
    4. The resulting context vector is fused with the prosody query for classification.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, hidden_dim=128, num_layers=2, dropout=0.1):
        super(BiLSTMAttentionFusionModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # --- 1. BiLSTM for CQCC Feature Processing ---
        self.bilstm = nn.LSTM(
            input_size=cqcc_feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        
        # BiLSTM outputs 2 * hidden_dim (forward + backward)
        self.bilstm_output_dim = 2 * hidden_dim
        
        # Optional: Project BiLSTM output to a specific dimension
        self.bilstm_projection = nn.Linear(self.bilstm_output_dim, hidden_dim)
        self.bilstm_dropout = nn.Dropout(dropout)
        
        # --- 2. MLP for Prosodic Feature Processing ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, hidden_dim), # Output dimension must match hidden_dim for attention
            nn.ReLU()
        )
        
        # --- 3. Classifier Head ---
        # The input is the context vector (size hidden_dim) + the prosody query (size hidden_dim)
        classifier_input_dim = hidden_dim + hidden_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> Permute for LSTM: (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC features with BiLSTM
        bilstm_out, (hidden, cell) = self.bilstm(cqcc_x)  # Shape: (batch, frames, 2*hidden_dim)
        
        # Project BiLSTM output to desired dimension
        bilstm_features = self.bilstm_projection(bilstm_out)  # Shape: (batch, frames, hidden_dim)
        bilstm_features = self.bilstm_dropout(bilstm_features)
        
        # 2. Process prosodic features to get the query
        # Flatten prosody features if they have multiple dimensions
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        
        prosody_query = self.prosody_mlp(prosody_x)  # Shape: (batch, hidden_dim)

        # --- 3. Cross-Attention ---
        keys = bilstm_features  # Shape: (batch, frames, hidden_dim)
        values = bilstm_features  # Shape: (batch, frames, hidden_dim)
        query_unsqueezed = prosody_query.unsqueeze(1)  # Shape: (batch, 1, hidden_dim)
        
        # Calculate attention scores: Q * K^T
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))  # (batch, 1, frames)
        attention_scores = attention_scores / (keys.size(-1) ** 0.5)  # Scale
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector by applying weights to values
        context = torch.bmm(attention_weights, values)  # (batch, 1, hidden_dim)
        context = context.squeeze(1)  # Remove the sequence dimension -> (batch, hidden_dim)
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)  # (batch, 2*hidden_dim)
        logits = self.classifier(fused_features)
        return logits

# --- Alternative BiLSTM model with multiple layers and attention ---
class BiLSTMAttentionFusionModelV2(nn.Module):
    """
    Advanced BiLSTM model with multiple processing layers and self-attention.
    """
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, hidden_dim=128, num_layers=3, dropout=0.1):
        super(BiLSTMAttentionFusionModelV2, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # --- 1. Multi-layer BiLSTM for CQCC ---
        self.bilstm1 = nn.LSTM(
            input_size=cqcc_feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        
        # Additional BiLSTM layer for more complex feature extraction
        self.bilstm2 = nn.LSTM(
            input_size=2 * hidden_dim,
            hidden_size=hidden_dim // 2,
            num_layers=2,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )
        
        # Final output dimension after second BiLSTM
        self.bilstm_output_dim = hidden_dim  # (hidden_dim // 2) * 2 = hidden_dim
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(self.bilstm_output_dim)
        
        # --- 2. Enhanced MLP for Prosodic Features ---
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # --- 3. Multi-head attention (simplified) ---
        self.num_heads = 4
        self.head_dim = hidden_dim // self.num_heads
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        
        # --- 4. Enhanced Classifier ---
        classifier_input_dim = hidden_dim + hidden_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )
        
        self.dropout = nn.Dropout(dropout)

    def multi_head_attention(self, query, keys, values):
        """Simplified multi-head attention mechanism."""
        batch_size, seq_len, hidden_dim = keys.size()
        
        # Project query, keys, values
        Q = self.query_proj(query).view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key_proj(keys).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value_proj(values).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attention_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, 1, hidden_dim)
        
        # Output projection
        output = self.out_proj(context)
        return output.squeeze(1)

    def forward(self, cqcc_x, prosody_x):
        # CQCC input shape: (batch, features, frames) -> (batch, frames, features)
        cqcc_x = cqcc_x.permute(0, 2, 1)
        
        # 1. Process CQCC features with stacked BiLSTMs
        bilstm_out1, _ = self.bilstm1(cqcc_x)
        bilstm_out1 = self.dropout(bilstm_out1)
        
        bilstm_out2, _ = self.bilstm2(bilstm_out1)
        bilstm_features = self.layer_norm(bilstm_out2)
        
        # 2. Process prosodic features
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        
        prosody_query = self.prosody_mlp(prosody_x)
        
        # 3. Multi-head Cross-Attention
        context = self.multi_head_attention(
            prosody_query.unsqueeze(1), 
            bilstm_features, 
            bilstm_features
        )
        
        # 4. Fusion and Classification
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits

# --- PyTorch Dataset ---
class AudioSpoofDataset(Dataset):
    """Custom Dataset for loading CQCC, prosody, and labels from .npy files."""
    def __init__(self, cqcc_file, prosody_file, labels_file):
        if not all(os.path.exists(f) for f in [cqcc_file, prosody_file, labels_file]):
            raise FileNotFoundError(f"One or more feature files not found: {cqcc_file}, {prosody_file}, {labels_file}")
        
        self.cqcc_data = np.load(cqcc_file)
        self.prosody_data = np.load(prosody_file)
        self.labels = np.load(labels_file)
        
        assert len(self.cqcc_data) == len(self.prosody_data) == len(self.labels), "Data length mismatch!"
        
        # Debug: Print shapes to understand the data structure
        print(f"CQCC data shape: {self.cqcc_data.shape}")
        print(f"Prosody data shape: {self.prosody_data.shape}")
        print(f"Labels shape: {self.labels.shape}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        cqcc = torch.tensor(self.cqcc_data[idx], dtype=torch.float32)
        prosody = torch.tensor(self.prosody_data[idx], dtype=torch.float32)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return cqcc, prosody, label

# --- Evaluation Metric ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for cqcc, prosody, labels in dataloader:
        cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(cqcc, prosody)
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping for LSTM stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for cqcc, prosody, labels in dataloader:
            cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
            outputs = model(cqcc, prosody)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            scores = torch.sigmoid(outputs).cpu().numpy()
            all_scores.extend(scores.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    y_true = np.array(all_labels)
    y_scores = np.array(all_scores)
    
    eer = calculate_eer(y_true, y_scores)
    y_pred = (y_scores > 0.5).astype(int)
    f1 = f1_score(y_true, y_pred)
    
    return avg_loss, eer, f1

# --- Main Pipeline ---
if __name__ == '__main__':
    # --- Configuration ---
    # Update these paths to your new .npy files
    
    # Training data files
    TRAIN_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_train.npy'
    TRAIN_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_train.npy'
    TRAIN_LABELS_FILE = 'processed_data_aligned_lld/labels_train.npy'
    
    # Validation data files
    VAL_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_dev.npy'
    VAL_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_dev.npy'
    VAL_LABELS_FILE = 'processed_data_aligned_lld/labels_dev.npy'

    # Test data files
    TEST_CQCC_FILE = 'processed_data_aligned_lld/cqcc_features_test.npy'
    TEST_PROSODY_FILE = 'processed_data_aligned_lld/egmaps_lld_features_test.npy'
    TEST_LABELS_FILE = 'processed_data_aligned_lld/labels_test.npy'
    
    BATCH_SIZE = 64
    NUM_EPOCHS = 30
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5
    USE_ADVANCED_BILSTM = True  # Set to True to use the advanced BiLSTM model (BiLSTMAttentionFusionModelV2)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load Data and Create DataLoaders ---
    try:
        train_dataset = AudioSpoofDataset(TRAIN_CQCC_FILE, TRAIN_PROSODY_FILE, TRAIN_LABELS_FILE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        print(f"Training data loaded: {len(train_dataset)} samples.")
        
        val_dataset = AudioSpoofDataset(VAL_CQCC_FILE, VAL_PROSODY_FILE, VAL_LABELS_FILE)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        print(f"Validation data loaded: {len(val_dataset)} samples.")

    except (FileNotFoundError, ValueError, AssertionError) as e:
        print(f"Error loading data: {e}")
        print("Please ensure you have run the feature extraction script correctly and that all file paths are correct.")
        exit()

    # --- 2. Initialize Model, Loss, and Optimizer ---
    cqcc_sample, prosody_sample, _ = train_dataset[0]
    cqcc_dim = cqcc_sample.shape[0] 
    
    # Handle prosody dimensions properly
    if prosody_sample.dim() > 1:
        prosody_dim = prosody_sample.numel()  # Total number of elements if multi-dimensional
    else:
        prosody_dim = prosody_sample.shape[0]

    print(f"CQCC feature dimension: {cqcc_dim}")
    print(f"Prosody feature dimension: {prosody_dim}")
    print(f"Prosody sample shape: {prosody_sample.shape}")

    # Choose BiLSTM model
    if USE_ADVANCED_BILSTM:
        print("Using Advanced BiLSTM model with multi-head attention")
        model = BiLSTMAttentionFusionModelV2(
            cqcc_feature_dim=cqcc_dim,
            prosody_feature_dim=prosody_dim,
            hidden_dim=128,
            num_layers=3
        ).to(device)
    else:
        print("Using Standard BiLSTM model")
        model = BiLSTMAttentionFusionModel(
            cqcc_feature_dim=cqcc_dim,
            prosody_feature_dim=prosody_dim,
            hidden_dim=128,
            num_layers=2
        ).to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # --- 3. Training Loop ---
    print("\n--- Starting Training ---")
    best_val_eer = float('inf')
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_eer, val_f1 = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val EER: {val_eer:.4f}, Val F1: {val_f1:.4f}")
        
        scheduler.step(val_eer)
        
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            torch.save(model.state_dict(), 'best_bilstm_model.pth')
            print(f"-> New best model saved with EER: {best_val_eer:.4f}")
            
    print("\n--- Training Complete ---")
    print(f"Best validation EER achieved: {best_val_eer:.4f}")

    # --- 4. Testing Loop ---
    print("\n--- Starting Testing ---")
    if os.path.exists('best_bilstm_model.pth'):
        try:
            test_dataset = AudioSpoofDataset(TEST_CQCC_FILE, TEST_PROSODY_FILE, TEST_LABELS_FILE)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            print(f"Test data loaded: {len(test_dataset)} samples.")

            print("Loading best model for testing...")
            model.load_state_dict(torch.load('best_bilstm_model.pth'))
            
            test_loss, test_eer, test_f1 = evaluate(model, test_loader, criterion, device)
            
            print("\n--- Test Results ---")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test EER: {test_eer:.4f}")
            print(f"Test F1-Score: {test_f1:.4f}")
        
        except (FileNotFoundError, ValueError, AssertionError) as e:
            print(f"Error loading test files: {e}")
            print("Please ensure your test feature files are generated and paths are correct.")

    else:
        print("No 'best_bilstm_model.pth' found to test. Please run the training first.")

Using device: cuda
CQCC data shape: (46019, 128, 157)
Prosody data shape: (46019, 23, 157)
Labels shape: (46019,)
Training data loaded: 46019 samples.
CQCC data shape: (24844, 128, 157)
Prosody data shape: (24844, 23, 157)
Labels shape: (24844,)
Validation data loaded: 24844 samples.
CQCC feature dimension: 128
Prosody feature dimension: 3611
Prosody sample shape: torch.Size([23, 157])
Using Advanced BiLSTM model with multi-head attention

--- Starting Training ---




Epoch 1/30 | Train Loss: 0.2955, Train Acc: 0.8596 | Val Loss: 0.0937, Val EER: 0.0582, Val F1: 0.8537
-> New best model saved with EER: 0.0582
Epoch 2/30 | Train Loss: 0.0610, Train Acc: 0.9870 | Val Loss: 0.1579, Val EER: 0.0494, Val F1: 0.8725
-> New best model saved with EER: 0.0494
Epoch 3/30 | Train Loss: 0.0400, Train Acc: 0.9934 | Val Loss: 0.2637, Val EER: 0.0648, Val F1: 0.8472
Epoch 4/30 | Train Loss: 0.0288, Train Acc: 0.9953 | Val Loss: 0.2921, Val EER: 0.0805, Val F1: 0.8638
Epoch 5/30 | Train Loss: 0.0245, Train Acc: 0.9961 | Val Loss: 0.2006, Val EER: 0.0718, Val F1: 0.8887
Epoch 6/30 | Train Loss: 0.0175, Train Acc: 0.9972 | Val Loss: 0.2955, Val EER: 0.0570, Val F1: 0.8563
Epoch 7/30 | Train Loss: 0.0125, Train Acc: 0.9982 | Val Loss: 0.2658, Val EER: 0.0586, Val F1: 0.8869
Epoch 8/30 | Train Loss: 0.0112, Train Acc: 0.9982 | Val Loss: 0.1765, Val EER: 0.0440, Val F1: 0.9128
-> New best model saved with EER: 0.0440
Epoch 9/30 | Train Loss: 0.0112, Train Acc: 0.9984 | 

  model.load_state_dict(torch.load('best_bilstm_model.pth'))



--- Test Results ---
Test Loss: 1.4918
Test EER: 0.1048
Test F1-Score: 0.6306


In [23]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import roc_curve, classification_report
from tqdm import tqdm
import numpy as np

# --- 1. Setup and Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Helper Function for EER ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    fnr = 1 - tpr
    # Find the absolute difference between fnr and fpr
    eer_index = np.nanargmin(np.abs(fnr - fpr))
    eer_threshold = thresholds[eer_index]
    eer = fpr[eer_index]
    return eer * 100 # Return as a percentage

# --- 2. Load and Prepare Data ---
print("Loading training data...")
X_train_np = np.load('processed_data_aligned_lld/cqcc_features_train.npy')
y_train_np = np.load('processed_data_aligned_lld/labels_train.npy')

print("Loading validation data...")
X_val_np = np.load('processed_data_aligned_lld/cqcc_features_dev.npy')
y_val_np = np.load('processed_data_aligned_lld/labels_dev.npy')

print("Loading test data...")
X_test_np = np.load('processed_data_aligned_lld/cqcc_features_test.npy')
y_test_np = np.load('processed_data_aligned_lld/labels_test.npy')


# Add channel dimension for Conv2D
X_train_np = np.expand_dims(X_train_np, axis=1)
X_val_np = np.expand_dims(X_val_np, axis=1)
X_test_np = np.expand_dims(X_test_np, axis=1)

# Convert to PyTorch Tensors
X_train = torch.from_numpy(X_train_np).float()
y_train = torch.from_numpy(y_train_np).float().view(-1, 1)
X_val = torch.from_numpy(X_val_np).float()
y_val = torch.from_numpy(y_val_np).float().view(-1, 1)
X_test = torch.from_numpy(X_test_np).float()
y_test = torch.from_numpy(y_test_np).float().view(-1, 1)

print(f"Training data tensor shape: {X_train.shape}")
print(f"Validation data tensor shape: {X_val.shape}")
print(f"Test data tensor shape: {X_test.shape}")


# Create DataLoaders
BATCH_SIZE = 32
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- 3. Define the Dynamic 2D-CNN Model ---
class CNN_2D(nn.Module):
    def __init__(self, input_height, input_width):
        super(CNN_2D, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2)
        )
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, input_height, input_width)
            dummy_output = self.feature_extractor(dummy_input)
            flattened_size = dummy_output.flatten(1).shape[1]
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

# --- Instantiate the Model ---
input_height = X_train.shape[2]
input_width = X_train.shape[3]
model = CNN_2D(input_height=input_height, input_width=input_width).to(device)
print("\nModel Architecture:")
print(model)

# --- 4. Define Loss Function and Optimizer ---
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- 5. Training and Validation Loop ---
EPOCHS = 20

for epoch in range(EPOCHS):
    # --- Training Phase ---
    model.train()
    total_train_loss = 0
    # Wrap train_loader with tqdm
    for data, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Training]"):
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        loss = criterion(outputs, targets)
        total_train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation Phase ---
    model.eval()
    total_val_loss = 0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        # Wrap val_loader with tqdm
        for data, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Validation]"):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            total_val_loss += loss.item()
            predicted = torch.sigmoid(outputs) > 0.5
            total_val += targets.size(0)
            correct_val += (predicted == targets).sum().item()
    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = (correct_val / total_val) * 100

    print(f"Epoch [{epoch+1}/{EPOCHS}] | Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_accuracy:.2f}%")

print("\nTraining Finished!")

# --- 6. Final Evaluation on the Test Set ---
print("\nEvaluating the model on the final test set...")
model.eval()
all_labels = []
all_scores = []
all_preds = []
with torch.no_grad():
    # Wrap test_loader with tqdm
    for data, targets in tqdm(test_loader, desc="Final Evaluation"):
        data, targets = data.to(device), targets.to(device)
        
        # Get raw scores for EER
        outputs = model(data)
        scores = torch.sigmoid(outputs)
        
        # Get predictions for classification report
        predicted = scores > 0.5
        
        all_labels.extend(targets.cpu().numpy())
        all_scores.extend(scores.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Convert lists to numpy arrays
all_labels = np.array(all_labels)
all_scores = np.array(all_scores)
all_preds = np.array(all_preds)

# Calculate metrics
eer_score = calculate_eer(all_labels, all_scores)
report = classification_report(all_labels, all_preds, target_names=['Bona Fide (0)', 'Spoof (1)'])
accuracy = (all_preds == all_labels).mean() * 100

print("\n" + "="*30)
print("      Final Test Set Results")
print("="*30)
print(f"\nAccuracy: {accuracy:.2f}%")
print(f"EER (Equal Error Rate): {eer_score:.2f}%")
print("\nClassification Report:")
print(report)
print("="*30)


Using device: cuda
Loading training data...
Loading validation data...
Loading test data...
Training data tensor shape: torch.Size([46019, 1, 128, 157])
Validation data tensor shape: torch.Size([24844, 1, 128, 157])
Test data tensor shape: torch.Size([71237, 1, 128, 157])

Model Architecture:
CNN_2D(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): ReLU()
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (9): ReLU()
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1

Epoch 1/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.25it/s]
Epoch 1/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 523.95it/s]


Epoch [1/20] | Train Loss: 0.0908 | Val Loss: 0.0930 | Val Accuracy: 96.94%


Epoch 2/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 197.14it/s]
Epoch 2/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 522.09it/s]


Epoch [2/20] | Train Loss: 0.0167 | Val Loss: 0.0544 | Val Accuracy: 98.28%


Epoch 3/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 200.01it/s]
Epoch 3/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 521.65it/s]


Epoch [3/20] | Train Loss: 0.0090 | Val Loss: 0.0742 | Val Accuracy: 97.87%


Epoch 4/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.84it/s]
Epoch 4/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 522.85it/s]


Epoch [4/20] | Train Loss: 0.0050 | Val Loss: 0.0826 | Val Accuracy: 98.01%


Epoch 5/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.91it/s]
Epoch 5/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 522.19it/s]


Epoch [5/20] | Train Loss: 0.0062 | Val Loss: 0.0587 | Val Accuracy: 98.33%


Epoch 6/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 191.15it/s]
Epoch 6/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 524.88it/s]


Epoch [6/20] | Train Loss: 0.0046 | Val Loss: 0.1053 | Val Accuracy: 97.82%


Epoch 7/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.77it/s]
Epoch 7/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 528.68it/s]


Epoch [7/20] | Train Loss: 0.0035 | Val Loss: 0.0622 | Val Accuracy: 98.38%


Epoch 8/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 200.00it/s]
Epoch 8/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 527.98it/s]


Epoch [8/20] | Train Loss: 0.0058 | Val Loss: 0.2615 | Val Accuracy: 96.32%


Epoch 9/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.84it/s]
Epoch 9/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 529.15it/s]


Epoch [9/20] | Train Loss: 0.0019 | Val Loss: 0.1237 | Val Accuracy: 97.87%


Epoch 10/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.72it/s]
Epoch 10/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 528.45it/s]


Epoch [10/20] | Train Loss: 0.0046 | Val Loss: 0.0894 | Val Accuracy: 98.12%


Epoch 11/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.73it/s]
Epoch 11/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 527.80it/s]


Epoch [11/20] | Train Loss: 0.0019 | Val Loss: 0.1997 | Val Accuracy: 97.46%


Epoch 12/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.65it/s]
Epoch 12/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 527.10it/s]


Epoch [12/20] | Train Loss: 0.0038 | Val Loss: 0.1657 | Val Accuracy: 97.44%


Epoch 13/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.56it/s]
Epoch 13/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 525.19it/s]


Epoch [13/20] | Train Loss: 0.0027 | Val Loss: 0.2266 | Val Accuracy: 97.20%


Epoch 14/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.68it/s]
Epoch 14/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 528.36it/s]


Epoch [14/20] | Train Loss: 0.0020 | Val Loss: 0.1167 | Val Accuracy: 98.30%


Epoch 15/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.69it/s]
Epoch 15/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 524.16it/s]


Epoch [15/20] | Train Loss: 0.0028 | Val Loss: 0.1741 | Val Accuracy: 97.27%


Epoch 16/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.52it/s]
Epoch 16/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 529.48it/s]


Epoch [16/20] | Train Loss: 0.0033 | Val Loss: 0.1217 | Val Accuracy: 98.18%


Epoch 17/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.65it/s]
Epoch 17/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 521.83it/s]


Epoch [17/20] | Train Loss: 0.0006 | Val Loss: 0.1881 | Val Accuracy: 97.67%


Epoch 18/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.39it/s]
Epoch 18/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 522.24it/s]


Epoch [18/20] | Train Loss: 0.0021 | Val Loss: 0.1278 | Val Accuracy: 98.16%


Epoch 19/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.11it/s]
Epoch 19/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 520.43it/s]


Epoch [19/20] | Train Loss: 0.0014 | Val Loss: 0.3071 | Val Accuracy: 96.95%


Epoch 20/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 184.98it/s]
Epoch 20/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 523.48it/s]


Epoch [20/20] | Train Loss: 0.0018 | Val Loss: 0.4044 | Val Accuracy: 96.45%

Training Finished!

Evaluating the model on the final test set...


Final Evaluation: 100%|██████████| 2227/2227 [00:04<00:00, 528.79it/s]



      Final Test Set Results

Accuracy: 94.40%
EER (Equal Error Rate): 9.35%

Classification Report:
               precision    recall  f1-score   support

Bona Fide (0)       0.98      0.96      0.97     63882
    Spoof (1)       0.70      0.81      0.75      7355

     accuracy                           0.94     71237
    macro avg       0.84      0.88      0.86     71237
 weighted avg       0.95      0.94      0.95     71237



In [24]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import roc_curve, classification_report, f1_score
from tqdm import tqdm
import numpy as np

# --- 1. Setup and Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Helper Function for EER ---
def calculate_eer(y_true, y_scores):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1)
    fnr = 1 - tpr
    # Find the absolute difference between fnr and fpr
    eer_index = np.nanargmin(np.abs(fnr - fpr))
    eer_threshold = thresholds[eer_index]
    eer = fpr[eer_index]
    return eer * 100 # Return as a percentage

# --- 2. Load and Prepare Data ---
print("Loading training data...")
X_train_np = np.load('processed_data_aligned_lld/cqcc_features_train.npy')
y_train_np = np.load('processed_data_aligned_lld/labels_train.npy')

print("Loading validation data...")
X_val_np = np.load('processed_data_aligned_lld/cqcc_features_dev.npy')
y_val_np = np.load('processed_data_aligned_lld/labels_dev.npy')

print("Loading test data...")
X_test_np = np.load('processed_data_aligned_lld/cqcc_features_test.npy')
y_test_np = np.load('processed_data_aligned_lld/labels_test.npy')


# Add channel dimension for Conv2D
X_train_np = np.expand_dims(X_train_np, axis=1)
X_val_np = np.expand_dims(X_val_np, axis=1)
X_test_np = np.expand_dims(X_test_np, axis=1)

# Convert to PyTorch Tensors
X_train = torch.from_numpy(X_train_np).float()
y_train = torch.from_numpy(y_train_np).float().view(-1, 1)
X_val = torch.from_numpy(X_val_np).float()
y_val = torch.from_numpy(y_val_np).float().view(-1, 1)
X_test = torch.from_numpy(X_test_np).float()
y_test = torch.from_numpy(y_test_np).float().view(-1, 1)

print(f"Training data tensor shape: {X_train.shape}")
print(f"Validation data tensor shape: {X_val.shape}")
print(f"Test data tensor shape: {X_test.shape}")


# Create DataLoaders
BATCH_SIZE = 32
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- 3. Define the Dynamic 2D-CNN Model ---
class CNN_2D(nn.Module):
    def __init__(self, input_height, input_width):
        super(CNN_2D, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2)
        )
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, input_height, input_width)
            dummy_output = self.feature_extractor(dummy_input)
            flattened_size = dummy_output.flatten(1).shape[1]
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

# --- Instantiate the Model ---
input_height = X_train.shape[2]
input_width = X_train.shape[3]
model = CNN_2D(input_height=input_height, input_width=input_width).to(device)
print("\nModel Architecture:")
print(model)

# --- 4. Define Loss Function and Optimizer ---
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- 5. Training and Validation Loop ---
EPOCHS = 20

for epoch in range(EPOCHS):
    # --- Training Phase ---
    model.train()
    total_train_loss = 0
    # Wrap train_loader with tqdm
    for data, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Training]"):
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        loss = criterion(outputs, targets)
        total_train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation Phase ---
    model.eval()
    total_val_loss = 0
    all_val_labels = []
    all_val_preds = []
    with torch.no_grad():
        # Wrap val_loader with tqdm
        for data, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Validation]"):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            total_val_loss += loss.item()
            
            predicted = torch.sigmoid(outputs) > 0.5
            all_val_labels.extend(targets.cpu().numpy())
            all_val_preds.extend(predicted.cpu().numpy())

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = (np.array(all_val_preds) == np.array(all_val_labels)).mean() * 100
    val_f1 = f1_score(all_val_labels, all_val_preds)


    print(f"Epoch [{epoch+1}/{EPOCHS}] | Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | Val F1: {val_f1:.4f}")

print("\nTraining Finished!")

# --- 6. Final Evaluation on the Test Set ---
print("\nEvaluating the model on the final test set...")
model.eval()
all_labels = []
all_scores = []
all_preds = []
with torch.no_grad():
    # Wrap test_loader with tqdm
    for data, targets in tqdm(test_loader, desc="Final Evaluation"):
        data, targets = data.to(device), targets.to(device)
        
        # Get raw scores for EER
        outputs = model(data)
        scores = torch.sigmoid(outputs)
        
        # Get predictions for classification report
        predicted = scores > 0.5
        
        all_labels.extend(targets.cpu().numpy())
        all_scores.extend(scores.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Convert lists to numpy arrays
all_labels = np.array(all_labels)
all_scores = np.array(all_scores)
all_preds = np.array(all_preds)

# Calculate metrics
eer_score = calculate_eer(all_labels, all_scores)
report = classification_report(all_labels, all_preds, target_names=['Bona Fide (0)', 'Spoof (1)'])
accuracy = (all_preds == all_labels).mean() * 100

print("\n" + "="*30)
print("      Final Test Set Results")
print("="*30)
print(f"\nAccuracy: {accuracy:.2f}%")
print(f"EER (Equal Error Rate): {eer_score:.2f}%")
print("\nClassification Report:")
print(report)
print("="*30)


Using device: cuda
Loading training data...
Loading validation data...
Loading test data...
Training data tensor shape: torch.Size([46019, 1, 128, 157])
Validation data tensor shape: torch.Size([24844, 1, 128, 157])
Test data tensor shape: torch.Size([71237, 1, 128, 157])

Model Architecture:
CNN_2D(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): ReLU()
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (9): ReLU()
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1

Epoch 1/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.35it/s]
Epoch 1/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 521.24it/s]


Epoch [1/20] | Train Loss: 0.0796 | Val Loss: 0.0569 | Val Acc: 97.94% | Val F1: 0.8922


Epoch 2/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 186.45it/s]
Epoch 2/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 521.14it/s]


Epoch [2/20] | Train Loss: 0.0143 | Val Loss: 0.0453 | Val Acc: 98.32% | Val F1: 0.9166


Epoch 3/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.58it/s]
Epoch 3/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 522.95it/s]


Epoch [3/20] | Train Loss: 0.0086 | Val Loss: 0.0598 | Val Acc: 98.17% | Val F1: 0.9044


Epoch 4/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.48it/s]
Epoch 4/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 519.84it/s]


Epoch [4/20] | Train Loss: 0.0055 | Val Loss: 0.0872 | Val Acc: 98.06% | Val F1: 0.8972


Epoch 5/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.18it/s]
Epoch 5/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 520.25it/s]


Epoch [5/20] | Train Loss: 0.0046 | Val Loss: 0.1309 | Val Acc: 97.73% | Val F1: 0.8769


Epoch 6/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.72it/s]
Epoch 6/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 523.80it/s]


Epoch [6/20] | Train Loss: 0.0064 | Val Loss: 0.1187 | Val Acc: 98.01% | Val F1: 0.8935


Epoch 7/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.48it/s]
Epoch 7/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 524.07it/s]


Epoch [7/20] | Train Loss: 0.0035 | Val Loss: 0.0478 | Val Acc: 98.62% | Val F1: 0.9324


Epoch 8/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.59it/s]
Epoch 8/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 517.37it/s]


Epoch [8/20] | Train Loss: 0.0030 | Val Loss: 0.1144 | Val Acc: 98.16% | Val F1: 0.9020


Epoch 9/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.36it/s]
Epoch 9/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 519.13it/s]


Epoch [9/20] | Train Loss: 0.0041 | Val Loss: 0.1896 | Val Acc: 97.33% | Val F1: 0.8507


Epoch 10/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.07it/s]
Epoch 10/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 518.23it/s]


Epoch [10/20] | Train Loss: 0.0026 | Val Loss: 0.1158 | Val Acc: 98.27% | Val F1: 0.9090


Epoch 11/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.01it/s]
Epoch 11/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 519.08it/s]


Epoch [11/20] | Train Loss: 0.0022 | Val Loss: 0.3549 | Val Acc: 96.68% | Val F1: 0.8067


Epoch 12/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.32it/s]
Epoch 12/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 517.12it/s]


Epoch [12/20] | Train Loss: 0.0036 | Val Loss: 0.1159 | Val Acc: 98.10% | Val F1: 0.8991


Epoch 13/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.24it/s]
Epoch 13/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 527.32it/s]


Epoch [13/20] | Train Loss: 0.0008 | Val Loss: 0.2221 | Val Acc: 97.27% | Val F1: 0.8467


Epoch 14/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 198.88it/s]
Epoch 14/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 517.58it/s]


Epoch [14/20] | Train Loss: 0.0035 | Val Loss: 0.1377 | Val Acc: 97.71% | Val F1: 0.8745


Epoch 15/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.03it/s]
Epoch 15/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 519.90it/s]


Epoch [15/20] | Train Loss: 0.0023 | Val Loss: 0.4571 | Val Acc: 95.11% | Val F1: 0.6871


Epoch 16/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.03it/s]
Epoch 16/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 518.07it/s]


Epoch [16/20] | Train Loss: 0.0025 | Val Loss: 0.2509 | Val Acc: 97.06% | Val F1: 0.8330


Epoch 17/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.30it/s]
Epoch 17/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 523.76it/s]


Epoch [17/20] | Train Loss: 0.0029 | Val Loss: 0.0907 | Val Acc: 98.60% | Val F1: 0.9284


Epoch 18/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 199.13it/s]
Epoch 18/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 519.95it/s]


Epoch [18/20] | Train Loss: 0.0019 | Val Loss: 0.1879 | Val Acc: 97.74% | Val F1: 0.8763


Epoch 19/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 198.87it/s]
Epoch 19/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 520.67it/s]


Epoch [19/20] | Train Loss: 0.0019 | Val Loss: 0.3318 | Val Acc: 96.57% | Val F1: 0.8001


Epoch 20/20 [Training]: 100%|██████████| 1439/1439 [00:07<00:00, 185.09it/s]
Epoch 20/20 [Validation]: 100%|██████████| 777/777 [00:01<00:00, 521.03it/s]


Epoch [20/20] | Train Loss: 0.0022 | Val Loss: 0.1723 | Val Acc: 97.85% | Val F1: 0.8832

Training Finished!

Evaluating the model on the final test set...


Final Evaluation: 100%|██████████| 2227/2227 [00:04<00:00, 526.41it/s]



      Final Test Set Results

Accuracy: 93.17%
EER (Equal Error Rate): 8.44%

Classification Report:
               precision    recall  f1-score   support

Bona Fide (0)       0.99      0.94      0.96     63882
    Spoof (1)       0.62      0.89      0.73      7355

     accuracy                           0.93     71237
    macro avg       0.80      0.91      0.84     71237
 weighted avg       0.95      0.93      0.94     71237



In [30]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import shap
import math
import torch.nn.functional as F

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/TransformerFusion_PyTorch_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_transformer_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_transformer_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_transformer_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_transformer_23feat.png"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')
    fig.tight_layout()
    plt.title('Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

# --- Positional Encoding for Transformer ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is expected to be of shape (seq_len, batch_size, d_model)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- CORRECTED: Transformer + Cross-Attention Fusion Model ---
class TransformerAttentionFusionModel(nn.Module):
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, nhead=4, num_encoder_layers=3, dim_feedforward=256, dropout=0.1):
        super(TransformerAttentionFusionModel, self).__init__()
        self.d_model = d_model

        # This layer projects the input CQCC features into the model's dimension (d_model)
        self.cqcc_projection = nn.Linear(cqcc_feature_dim, d_model)
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        # TransformerEncoderLayer expects input as (N, S, E) with batch_first=True
        # N=batch_size, S=sequence_length, E=embedding_dim (d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, d_model), nn.ReLU()
        )
        
        # The classifier takes the concatenated features from attention context and prosody query
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # The input cqcc_x has shape (batch_size, seq_len, feature_dim)
        # The projection layer expects the feature_dim to be the last dimension, which is correct.
        # --- FIX: The incorrect permute operation has been removed. ---
        # INCORRECT LINE: cqcc_x = cqcc_x.permute(0, 2, 1) 
        
        # 1. Project CQCC features to the model dimension
        # Input: (batch_size, seq_len, cqcc_feature_dim) -> Output: (batch_size, seq_len, d_model)
        cqcc_embed = self.cqcc_projection(cqcc_x) * math.sqrt(self.d_model)

        # 2. Add positional encoding.
        # The PositionalEncoding module expects (seq_len, batch_size, d_model).
        # So we permute the dimensions from (N, S, E) to (S, N, E).
        cqcc_embed_permuted = cqcc_embed.permute(1, 0, 2)
        cqcc_pos = self.pos_encoder(cqcc_embed_permuted)

        # 3. Pass through the Transformer Encoder.
        # The TransformerEncoder (with batch_first=True) expects (batch_size, seq_len, d_model).
        # So we permute back from (S, N, E) to (N, S, E).
        transformer_input = cqcc_pos.permute(1, 0, 2)
        transformer_out = self.transformer_encoder(transformer_input)
        
        # 4. Process prosody features to create a query vector
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        prosody_query = self.prosody_mlp(prosody_x)

        # 5. Perform cross-attention
        # Query: prosody_query (batch_size, d_model)
        # Keys/Values: transformer_out (batch_size, seq_len, d_model)
        keys = values = transformer_out
        query_unsqueezed = prosody_query.unsqueeze(1) # Shape: (batch_size, 1, d_model)
        
        # Calculate attention scores: Q * K^T
        # (B, 1, D) @ (B, D, S) -> (B, 1, S)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2)) / (keys.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Calculate context vector: AttentionWeights * V
        # (B, 1, S) @ (B, S, D) -> (B, 1, D)
        context = torch.bmm(attention_weights, values).squeeze(1) # Shape: (batch_size, d_model)
        
        # 6. Fuse features and classify
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        
        return logits, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, device, save_path):
    print("\n--- Running Cross-Attention Weight Analysis ---")
    model.eval()
    all_weights = []
    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            all_weights.append(weights.squeeze(1).cpu().numpy())
    
    avg_weights = np.mean(np.concatenate(all_weights, axis=0), axis=0)
    
    plt.figure(figsize=(15, 6))
    plt.plot(avg_weights, color='purple')
    plt.xlabel('CQCC Time Frame')
    plt.ylabel('Average Attention Weight')
    plt.title('Cross-Attention: Importance of Acoustic Time Frames Guided by Prosody')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAttention plot saved to {save_path}")
    plt.close()


def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    print("\n--- Running Feature Ablation Analysis ---")
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody[:, feature_to_ablate] = 0.0
                logits, _ = model(cqcc, prosody)
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")
    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    print("\n--- Running SHAP Analysis ---")
    model.eval()
    # Use a larger background set for more stable SHAP values if possible
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        
        # Repeat the first CQCC sample from the background set for each prosody sample
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1) # Adjusted repeat for 3D tensor
        
        with torch.no_grad():
            logits, _ = model(cqcc_tensor, prosody_tensor)
            output = torch.sigmoid(logits)
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    
    # Use a dedicated figure for shap plot to avoid conflicts
    plt.figure() 
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        X_cqcc_train = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)
        X_cqcc_val = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        X_prosody_train = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val = np.mean(X_prosody_val_3d, axis=2)
        
        feature_columns = [
            'Loudness_sma3','alphaRatio_sma3','hammarbergIndex_sma3','slope0-500_sma3',
            'slope500-1500_sma3','spectralFlux_sma3','mfcc1_sma3','mfcc2_sma3',
            'mfcc3_sma3','mfcc4_sma3','F0semitoneFrom27.5Hz_sma3nz','jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz','HNRdBACF_sma3nz','logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz','F1frequency_sma3nz','F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz','F2frequency_sma3nz','F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz','F3amplitudeLogRelF0_sma3nz'
        ]
        num_prosodic_features = X_prosody_train.shape[1]
        if len(feature_columns) != num_prosodic_features:
            print(f"Warning: Number of feature names ({len(feature_columns)}) does not match number of features ({num_prosodic_features}). Using generic names.")
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("--- Scaling Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val)
    scaler_cqcc = StandardScaler()
    ns, nx, ny = X_cqcc_train.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train.reshape(ns, -1)).reshape(ns, nx, ny)
    nsv, nxv, nyv = X_cqcc_val.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val.reshape(nsv, -1)).reshape(nsv, nxv, nyv)

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = TransformerAttentionFusionModel(
        cqcc_feature_dim=X_cqcc_train.shape[2], # feature_dim is the last dimension
        prosody_feature_dim=X_prosody_train.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            logits, _ = model(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)
        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        cm = confusion_matrix(all_labels, all_preds)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        history.update({'train_loss': history['train_loss']+[avg_train_loss], 'val_loss': history['val_loss']+[avg_val_loss], 'val_acc': history['val_acc']+[val_accuracy], 'f1': history['f1']+[f1], 'eer': history['eer']+[eer]})
        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    plot_training_history(history, PLOT_SAVE_PATH)

    # --- FINAL TESTING AND ANALYSIS ---
    print("\n--- Starting Final Testing and Analysis ---")
    try:
        X_cqcc_test = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        X_prosody_test = np.mean(X_prosody_test_3d, axis=2)
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test)
        ns_test, nx_test, ny_test = X_cqcc_test.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test.reshape(ns_test, -1)).reshape(ns_test, nx_test, ny_test)
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        analysis_model = TransformerAttentionFusionModel(
            cqcc_feature_dim=X_cqcc_train.shape[2],
            prosody_feature_dim=X_prosody_train.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        all_test_labels, all_test_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in tqdm(test_loader, desc="Final Testing"):
                cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
                logits, _ = analysis_model(cqcc, prosody)
                all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_test_labels.extend(labels.cpu().numpy())
        
        all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
        all_test_preds = (all_test_scores > 0.5).astype(int)
        test_accuracy = 100 * np.sum(all_test_preds == all_test_labels) / len(all_test_labels)
        test_f1 = f1_score(all_test_labels, all_test_preds)
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        test_cm = confusion_matrix(all_test_labels, all_test_preds)
        print("\n--- Final Test Results ---")
        print(f"Accuracy: {test_accuracy:.2f}% | F1-Score: {test_f1:.4f} | EER: {test_eer:.2f}%")
        print("Confusion Matrix:\n", test_cm)

        analyze_attention_weights(analysis_model, test_loader, DEVICE, ATTENTION_PLOT_PATH)
        perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except Exception as e:
        import traceback
        print(f"Error during testing/analysis: {e}")
        traceback.print_exc()


Using device: cuda
--- Loading and Preparing Data ---
Converting 3D LLD prosodic features to 2D summary statistics (mean)...
--- Scaling Data ---


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


TransformerAttentionFusionModel(
  (cqcc_projection): Linear(in_features=157, out_features=128, bias=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (prosody_mlp): Sequential(
    (0): Linear(in_features=23, out_features=256, bias=True)


Epoch 1/40: 100%|██████████| 720/720 [00:09<00:00, 78.13it/s]



Epoch 1 | Train Loss: 0.3237 | Val Loss: 0.2048 | Val Acc: 88.92% | F1: 0.5934 | EER: 12.72%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 2/40: 100%|██████████| 720/720 [00:10<00:00, 68.30it/s]



Epoch 2 | Train Loss: 0.1707 | Val Loss: 0.1676 | Val Acc: 91.87% | F1: 0.6395 | EER: 12.02%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 3/40: 100%|██████████| 720/720 [00:11<00:00, 61.76it/s]



Epoch 3 | Train Loss: 0.1379 | Val Loss: 0.1472 | Val Acc: 93.25% | F1: 0.6425 | EER: 10.79%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 4/40: 100%|██████████| 720/720 [00:10<00:00, 67.24it/s]



Epoch 4 | Train Loss: 0.1151 | Val Loss: 0.1628 | Val Acc: 93.27% | F1: 0.6953 | EER: 10.91%


Epoch 5/40: 100%|██████████| 720/720 [00:11<00:00, 64.49it/s]



Epoch 5 | Train Loss: 0.0971 | Val Loss: 0.1350 | Val Acc: 94.47% | F1: 0.7225 | EER: 10.06%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 6/40: 100%|██████████| 720/720 [00:11<00:00, 64.80it/s]



Epoch 6 | Train Loss: 0.0824 | Val Loss: 0.1474 | Val Acc: 94.67% | F1: 0.7077 | EER: 9.93%


Epoch 7/40: 100%|██████████| 720/720 [00:10<00:00, 70.98it/s]



Epoch 7 | Train Loss: 0.0736 | Val Loss: 0.1640 | Val Acc: 94.36% | F1: 0.6724 | EER: 9.73%


Epoch 8/40: 100%|██████████| 720/720 [00:08<00:00, 80.55it/s] 



Epoch 8 | Train Loss: 0.0658 | Val Loss: 0.1484 | Val Acc: 94.76% | F1: 0.7003 | EER: 9.63%


Epoch 9/40: 100%|██████████| 720/720 [00:08<00:00, 81.04it/s] 



Epoch 9 | Train Loss: 0.0611 | Val Loss: 0.1574 | Val Acc: 95.22% | F1: 0.7331 | EER: 9.18%


Epoch 10/40: 100%|██████████| 720/720 [00:08<00:00, 81.45it/s] 



Epoch 10 | Train Loss: 0.0540 | Val Loss: 0.1772 | Val Acc: 94.73% | F1: 0.6871 | EER: 9.58%


Epoch 11/40: 100%|██████████| 720/720 [00:08<00:00, 87.33it/s] 



Epoch 11 | Train Loss: 0.0464 | Val Loss: 0.2107 | Val Acc: 94.59% | F1: 0.6733 | EER: 9.24%


Epoch 12/40: 100%|██████████| 720/720 [00:08<00:00, 87.04it/s] 



Epoch 12 | Train Loss: 0.0346 | Val Loss: 0.1801 | Val Acc: 95.26% | F1: 0.7350 | EER: 9.18%


Epoch 13/40: 100%|██████████| 720/720 [00:05<00:00, 138.38it/s]



Epoch 13 | Train Loss: 0.0309 | Val Loss: 0.2129 | Val Acc: 95.00% | F1: 0.7037 | EER: 9.14%


Epoch 14/40: 100%|██████████| 720/720 [00:07<00:00, 102.70it/s]



Epoch 14 | Train Loss: 0.0315 | Val Loss: 0.2098 | Val Acc: 95.01% | F1: 0.7048 | EER: 8.95%


Epoch 15/40: 100%|██████████| 720/720 [00:09<00:00, 79.12it/s] 



Epoch 15 | Train Loss: 0.0291 | Val Loss: 0.2396 | Val Acc: 94.76% | F1: 0.6769 | EER: 9.00%


Epoch 16/40: 100%|██████████| 720/720 [00:08<00:00, 80.54it/s] 



Epoch 16 | Train Loss: 0.0283 | Val Loss: 0.2357 | Val Acc: 94.85% | F1: 0.6835 | EER: 8.95%


Epoch 17/40: 100%|██████████| 720/720 [00:08<00:00, 83.23it/s] 



Epoch 17 | Train Loss: 0.0266 | Val Loss: 0.2133 | Val Acc: 95.14% | F1: 0.7145 | EER: 8.92%


Epoch 18/40: 100%|██████████| 720/720 [00:07<00:00, 102.42it/s]



Epoch 18 | Train Loss: 0.0247 | Val Loss: 0.2171 | Val Acc: 95.11% | F1: 0.7099 | EER: 8.81%


Epoch 19/40: 100%|██████████| 720/720 [00:07<00:00, 90.95it/s] 



Epoch 19 | Train Loss: 0.0235 | Val Loss: 0.2084 | Val Acc: 95.19% | F1: 0.7189 | EER: 8.82%


Epoch 20/40: 100%|██████████| 720/720 [00:05<00:00, 131.26it/s]



Epoch 20 | Train Loss: 0.0236 | Val Loss: 0.2227 | Val Acc: 95.13% | F1: 0.7109 | EER: 8.67%


Epoch 21/40: 100%|██████████| 720/720 [00:09<00:00, 73.62it/s] 



Epoch 21 | Train Loss: 0.0223 | Val Loss: 0.2408 | Val Acc: 94.94% | F1: 0.6928 | EER: 8.77%


Epoch 22/40: 100%|██████████| 720/720 [00:08<00:00, 88.99it/s] 



Epoch 22 | Train Loss: 0.0236 | Val Loss: 0.2189 | Val Acc: 95.13% | F1: 0.7137 | EER: 8.99%


Epoch 23/40: 100%|██████████| 720/720 [00:08<00:00, 86.79it/s] 



Epoch 23 | Train Loss: 0.0218 | Val Loss: 0.2309 | Val Acc: 95.07% | F1: 0.7050 | EER: 8.83%


Epoch 24/40: 100%|██████████| 720/720 [00:08<00:00, 87.21it/s] 



Epoch 24 | Train Loss: 0.0229 | Val Loss: 0.2288 | Val Acc: 95.05% | F1: 0.7043 | EER: 8.86%


Epoch 25/40: 100%|██████████| 720/720 [00:05<00:00, 139.60it/s]



Epoch 25 | Train Loss: 0.0228 | Val Loss: 0.2239 | Val Acc: 95.14% | F1: 0.7124 | EER: 8.87%


Epoch 26/40: 100%|██████████| 720/720 [00:05<00:00, 142.02it/s]



Epoch 26 | Train Loss: 0.0219 | Val Loss: 0.2284 | Val Acc: 95.08% | F1: 0.7064 | EER: 8.81%


Epoch 27/40: 100%|██████████| 720/720 [00:05<00:00, 138.84it/s]



Epoch 27 | Train Loss: 0.0216 | Val Loss: 0.2293 | Val Acc: 95.09% | F1: 0.7066 | EER: 8.86%


Epoch 28/40: 100%|██████████| 720/720 [00:05<00:00, 138.73it/s]



Epoch 28 | Train Loss: 0.0226 | Val Loss: 0.2232 | Val Acc: 95.13% | F1: 0.7122 | EER: 8.85%


Epoch 29/40: 100%|██████████| 720/720 [00:07<00:00, 100.68it/s]



Epoch 29 | Train Loss: 0.0216 | Val Loss: 0.2251 | Val Acc: 95.12% | F1: 0.7107 | EER: 8.83%


Epoch 30/40: 100%|██████████| 720/720 [00:10<00:00, 70.48it/s] 



Epoch 30 | Train Loss: 0.0211 | Val Loss: 0.2273 | Val Acc: 95.11% | F1: 0.7091 | EER: 8.83%


Epoch 31/40: 100%|██████████| 720/720 [00:09<00:00, 75.42it/s] 



Epoch 31 | Train Loss: 0.0230 | Val Loss: 0.2279 | Val Acc: 95.09% | F1: 0.7077 | EER: 8.83%


Epoch 32/40: 100%|██████████| 720/720 [00:08<00:00, 81.27it/s] 



Epoch 32 | Train Loss: 0.0223 | Val Loss: 0.2289 | Val Acc: 95.09% | F1: 0.7072 | EER: 8.83%


Epoch 33/40: 100%|██████████| 720/720 [00:08<00:00, 80.93it/s] 



Epoch 33 | Train Loss: 0.0200 | Val Loss: 0.2274 | Val Acc: 95.10% | F1: 0.7082 | EER: 8.83%


Epoch 34/40: 100%|██████████| 720/720 [00:08<00:00, 83.43it/s] 



Epoch 34 | Train Loss: 0.0217 | Val Loss: 0.2276 | Val Acc: 95.10% | F1: 0.7084 | EER: 8.84%


Epoch 35/40: 100%|██████████| 720/720 [00:08<00:00, 86.37it/s] 



Epoch 35 | Train Loss: 0.0207 | Val Loss: 0.2285 | Val Acc: 95.09% | F1: 0.7070 | EER: 8.84%


Epoch 36/40: 100%|██████████| 720/720 [00:08<00:00, 81.56it/s] 



Epoch 36 | Train Loss: 0.0217 | Val Loss: 0.2283 | Val Acc: 95.08% | F1: 0.7066 | EER: 8.83%


Epoch 37/40: 100%|██████████| 720/720 [00:07<00:00, 93.81it/s] 



Epoch 37 | Train Loss: 0.0213 | Val Loss: 0.2286 | Val Acc: 95.09% | F1: 0.7070 | EER: 8.83%


Epoch 38/40: 100%|██████████| 720/720 [00:08<00:00, 87.32it/s] 



Epoch 38 | Train Loss: 0.0216 | Val Loss: 0.2284 | Val Acc: 95.08% | F1: 0.7068 | EER: 8.83%


Epoch 39/40: 100%|██████████| 720/720 [00:05<00:00, 139.71it/s]



Epoch 39 | Train Loss: 0.0219 | Val Loss: 0.2284 | Val Acc: 95.08% | F1: 0.7068 | EER: 8.83%


Epoch 40/40: 100%|██████████| 720/720 [00:07<00:00, 90.35it/s] 



Epoch 40 | Train Loss: 0.0212 | Val Loss: 0.2285 | Val Acc: 95.08% | F1: 0.7068 | EER: 8.83%

Training plot saved to saved_models/training_metrics_transformer_23feat.png

--- Starting Final Testing and Analysis ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Final Testing: 100%|██████████| 1114/1114 [00:02<00:00, 433.58it/s]



--- Final Test Results ---
Accuracy: 86.87% | F1-Score: 0.5660 | EER: 14.14%
Confusion Matrix:
 [[55783  8099]
 [ 1255  6100]]

--- Running Cross-Attention Weight Analysis ---


Analyzing Attention: 100%|██████████| 1114/1114 [00:03<00:00, 292.22it/s]



Attention plot saved to saved_models/attention_importance_transformer_23feat.png

--- Running Feature Ablation Analysis ---
Baseline EER with all features: 14.14%


Performing Ablation: 100%|██████████| 23/23 [01:00<00:00,  2.65s/it]



Feature Importance based on EER Increase:
- slope500-1500_sma3: EER increases by 0.88%
- spectralFlux_sma3: EER increases by 0.58%
- mfcc1_sma3: EER increases by 0.49%
- hammarbergIndex_sma3: EER increases by 0.45%
- slope0-500_sma3: EER increases by 0.38%
- F3frequency_sma3nz: EER increases by 0.37%
- F1amplitudeLogRelF0_sma3nz: EER increases by 0.33%
- HNRdBACF_sma3nz: EER increases by 0.27%
- F1frequency_sma3nz: EER increases by 0.24%
- F0semitoneFrom27.5Hz_sma3nz: EER increases by 0.23%
- F2amplitudeLogRelF0_sma3nz: EER increases by 0.19%
- logRelF0-H1-A3_sma3nz: EER increases by 0.18%
- F3amplitudeLogRelF0_sma3nz: EER increases by 0.15%
- shimmerLocaldB_sma3nz: EER increases by 0.14%
- Loudness_sma3: EER increases by 0.13%
- mfcc3_sma3: EER increases by 0.12%
- mfcc4_sma3: EER increases by 0.06%
- jitterLocal_sma3nz: EER increases by 0.02%
- F2frequency_sma3nz: EER increases by 0.00%
- logRelF0-H1-H2_sma3nz: EER increases by -0.03%
- mfcc2_sma3: EER increases by -0.11%
- F1bandwi

  0%|          | 0/64 [00:00<?, ?it/s]

Plotting SHAP summary...


No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored
The figure layout has changed to tight


SHAP plot saved to saved_models/shap_importance_transformer_23feat.png


In [31]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import shap
import math
import torch.nn.functional as F

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/TransformerFusion_PyTorch_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_transformer_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_transformer_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_transformer_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_transformer_23feat.png"
# --- MODIFIED: Added paths for the new model trained on top 6 features ---
MODEL_SAVE_PATH_TOP6 = "saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth"
PLOT_SAVE_PATH_TOP6 = "saved_models/training_metrics_transformer_Top6feat.png"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path, title_prefix=""):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')
    fig.tight_layout()
    plt.title(f'{title_prefix} Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]

# --- Positional Encoding for Transformer ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- Transformer + Cross-Attention Fusion Model ---
class TransformerAttentionFusionModel(nn.Module):
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, d_model=128, nhead=4, num_encoder_layers=3, dim_feedforward=256, dropout=0.1):
        super(TransformerAttentionFusionModel, self).__init__()
        self.d_model = d_model
        self.cqcc_projection = nn.Linear(cqcc_feature_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, d_model), nn.ReLU()
        )
        classifier_input_dim = d_model + d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        cqcc_embed = self.cqcc_projection(cqcc_x) * math.sqrt(self.d_model)
        cqcc_embed_permuted = cqcc_embed.permute(1, 0, 2)
        cqcc_pos = self.pos_encoder(cqcc_embed_permuted)
        transformer_input = cqcc_pos.permute(1, 0, 2)
        transformer_out = self.transformer_encoder(transformer_input)
        if prosody_x.dim() > 2:
            prosody_x = prosody_x.view(prosody_x.size(0), -1)
        prosody_query = self.prosody_mlp(prosody_x)
        keys = values = transformer_out
        query_unsqueezed = prosody_query.unsqueeze(1)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2)) / (keys.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.bmm(attention_weights, values).squeeze(1)
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        return logits, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, device, save_path):
    print("\n--- Running Cross-Attention Weight Analysis ---")
    model.eval()
    all_weights = []
    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            all_weights.append(weights.squeeze(1).cpu().numpy())
    avg_weights = np.mean(np.concatenate(all_weights, axis=0), axis=0)
    plt.figure(figsize=(15, 6))
    plt.plot(avg_weights, color='purple')
    plt.xlabel('CQCC Time Frame')
    plt.ylabel('Average Attention Weight')
    plt.title('Cross-Attention: Importance of Acoustic Time Frames Guided by Prosody')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAttention plot saved to {save_path}")
    plt.close()

# --- MODIFIED: perform_feature_ablation now returns the sorted features ---
def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    print("\n--- Running Feature Ablation Analysis ---")
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody_clone = prosody.clone()
                    prosody_clone[:, feature_to_ablate] = 0.0
                    logits, _ = model(cqcc, prosody_clone)
                else:
                    logits, _ = model(cqcc, prosody)
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
    
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")
    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()
    return sorted_features


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    print("\n--- Running SHAP Analysis ---")
    model.eval()
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        with torch.no_grad():
            logits, _ = model(cqcc_tensor, prosody_tensor)
            output = torch.sigmoid(logits)
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    plt.figure() 
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        X_cqcc_train_full = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)
        X_cqcc_val_full = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        X_prosody_train_full = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val_full = np.mean(X_prosody_val_3d, axis=2)
        
        feature_columns = [
            'Loudness_sma3','alphaRatio_sma3','hammarbergIndex_sma3','slope0-500_sma3',
            'slope500-1500_sma3','spectralFlux_sma3','mfcc1_sma3','mfcc2_sma3',
            'mfcc3_sma3','mfcc4_sma3','F0semitoneFrom27.5Hz_sma3nz','jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz','HNRdBACF_sma3nz','logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz','F1frequency_sma3nz','F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz','F2frequency_sma3nz','F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz','F3amplitudeLogRelF0_sma3nz'
        ]
        num_prosodic_features = X_prosody_train_full.shape[1]
        if len(feature_columns) != num_prosodic_features:
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("\n--- Scaling Full Feature Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train_full)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val_full)
    scaler_cqcc = StandardScaler()
    ns, nx, ny = X_cqcc_train_full.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train_full.reshape(ns, -1)).reshape(ns, nx, ny)
    nsv, nxv, nyv = X_cqcc_val_full.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val_full.reshape(nsv, -1)).reshape(nsv, nxv, nyv)

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = TransformerAttentionFusionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=X_prosody_train_full.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (All Features) ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (All Feats)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            logits, _ = model(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        all_preds = (all_scores > 0.5).astype(int)
        val_accuracy = 100 * np.sum(all_preds == all_labels) / len(all_labels)
        f1 = f1_score(all_labels, all_preds)
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f} | EER: {eer:.2f}%")
        history.update({'train_loss': history['train_loss']+[avg_train_loss], 'val_loss': history['val_loss']+[avg_val_loss], 'val_acc': history['val_acc']+[val_accuracy], 'f1': history['f1']+[f1], 'eer': history['eer']+[eer]})
        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    plot_training_history(history, PLOT_SAVE_PATH, title_prefix="All Features")

    # --- FINAL TESTING AND ANALYSIS (ALL FEATURES) ---
    print("\n--- Starting Final Testing and Analysis (All Features) ---")
    try:
        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        X_prosody_test_full = np.mean(X_prosody_test_3d, axis=2)
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_full)
        ns_test, nx_test, ny_test = X_cqcc_test_full.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_full.reshape(ns_test, -1)).reshape(ns_test, nx_test, ny_test)
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        analysis_model = TransformerAttentionFusionModel(
            cqcc_feature_dim=X_cqcc_train_full.shape[2],
            prosody_feature_dim=X_prosody_train_full.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()

        all_test_labels, all_test_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in tqdm(test_loader, desc="Final Testing (All Feats)"):
                cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
                logits, _ = analysis_model(cqcc, prosody)
                all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_test_labels.extend(labels.cpu().numpy())
        
        all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
        test_eer = calculate_eer(all_test_labels, all_test_scores)
        print(f"\n--- Final Test Results (All Features) --- | EER: {test_eer:.2f}%")

        analyze_attention_weights(analysis_model, test_loader, DEVICE, ATTENTION_PLOT_PATH)
        # --- MODIFIED: Capture the sorted features from the ablation analysis ---
        sorted_features = perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except Exception as e:
        print(f"Error during testing/analysis: {e}")

    # ==============================================================================
    # --- NEW: RETRAINING WITH TOP 6 FEATURES ---
    # ==============================================================================
    print("\n\n--- Starting Retraining with Top 6 Features ---")
    
    # 1. Get top 6 feature names and their original indices
    top_6_feature_names = [item[0] for item in sorted_features[:6]]
    top_6_indices = [feature_columns.index(name) for name in top_6_feature_names]
    print("Top 6 features selected for retraining:", top_6_feature_names)

    # 2. Filter the original prosody datasets
    X_prosody_train_top6 = X_prosody_train_full[:, top_6_indices]
    X_prosody_val_top6 = X_prosody_val_full[:, top_6_indices]
    X_prosody_test_top6 = X_prosody_test_full[:, top_6_indices]

    # 3. Create and fit a new scaler for the top 6 features
    scaler_prosody_top6 = StandardScaler()
    X_prosody_train_scaled_top6 = scaler_prosody_top6.fit_transform(X_prosody_train_top6)
    X_prosody_val_scaled_top6 = scaler_prosody_top6.transform(X_prosody_val_top6)
    
    # 4. Create new datasets and dataloaders
    train_dataset_top6 = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled_top6, y_train)
    val_dataset_top6 = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled_top6, y_val)
    train_loader_top6 = DataLoader(train_dataset_top6, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_top6 = DataLoader(val_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)

    # 5. Create a new model instance with prosody_feature_dim=6
    model_top6 = TransformerAttentionFusionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=6  # <-- Key change: only 6 features
    ).to(DEVICE)

    optimizer_top6 = optim.Adam(model_top6.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler_top6 = optim.lr_scheduler.ReduceLROnPlateau(optimizer_top6, 'min', factor=0.2, patience=5, verbose=True)

    print(model_top6)
    best_val_loss_top6 = float('inf')
    history_top6 = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (Top 6 Features) ---")

    for epoch in range(EPOCHS):
        model_top6.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader_top6, desc=f"Epoch {epoch+1}/{EPOCHS} (Top 6)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer_top6.zero_grad()
            logits, _ = model_top6(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer_top6.step()
            running_loss += loss.item()

        model_top6.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader_top6:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model_top6(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader_top6)
        avg_val_loss = val_loss / len(val_loader_top6)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EER: {eer:.2f}%")
        history_top6.update({'train_loss': history_top6['train_loss']+[avg_train_loss], 'val_loss': history_top6['val_loss']+[avg_val_loss], 'eer': history_top6['eer']+[eer]})
        scheduler_top6.step(avg_val_loss)

        if avg_val_loss < best_val_loss_top6:
            best_val_loss_top6 = avg_val_loss
            torch.save(model_top6.state_dict(), MODEL_SAVE_PATH_TOP6)
            print(f"   -> Val loss decreased. New best model (Top 6) saved to {MODEL_SAVE_PATH_TOP6}")

    plot_training_history(history_top6, PLOT_SAVE_PATH_TOP6, title_prefix="Top 6 Features")

    # --- FINAL TESTING (TOP 6 FEATURES) ---
    print("\n--- Starting Final Testing (Top 6 Features) ---")
    X_prosody_test_scaled_top6 = scaler_prosody_top6.transform(X_prosody_test_top6)
    test_dataset_top6 = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled_top6, y_test)
    test_loader_top6 = DataLoader(test_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)
    
    model_top6.load_state_dict(torch.load(MODEL_SAVE_PATH_TOP6))
    model_top6.eval()

    all_test_labels, all_test_scores = [], []
    with torch.no_grad():
        for cqcc, prosody, labels in tqdm(test_loader_top6, desc="Final Testing (Top 6)"):
            cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
            logits, _ = model_top6(cqcc, prosody)
            all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())
    
    all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
    test_eer_top6 = calculate_eer(all_test_labels, all_test_scores)
    print(f"\n--- Final Test Results (Top 6 Features) --- | EER: {test_eer_top6:.2f}%")
    print("\n--- Experiment Complete ---")



Using device: cuda
--- Loading and Preparing Data ---
Converting 3D LLD prosodic features to 2D summary statistics (mean)...

--- Scaling Full Feature Data ---


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


TransformerAttentionFusionModel(
  (cqcc_projection): Linear(in_features=157, out_features=128, bias=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (prosody_mlp): Sequential(
    (0): Linear(in_features=23, out_features=256, bias=True)


Epoch 1/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 100.34it/s]



Epoch 1 | Train Loss: 0.3354 | Val Loss: 0.2303 | Val Acc: 87.44% | F1: 0.5833 | EER: 12.84%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 2/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 80.40it/s]



Epoch 2 | Train Loss: 0.1853 | Val Loss: 0.2121 | Val Acc: 89.45% | F1: 0.6206 | EER: 11.85%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 3/40 (All Feats): 100%|██████████| 720/720 [00:06<00:00, 110.05it/s]



Epoch 3 | Train Loss: 0.1477 | Val Loss: 0.1645 | Val Acc: 92.51% | F1: 0.6684 | EER: 11.19%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 4/40 (All Feats): 100%|██████████| 720/720 [00:09<00:00, 79.19it/s] 



Epoch 4 | Train Loss: 0.1221 | Val Loss: 0.1444 | Val Acc: 93.48% | F1: 0.7022 | EER: 9.62%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 5/40 (All Feats): 100%|██████████| 720/720 [00:06<00:00, 116.47it/s]



Epoch 5 | Train Loss: 0.1069 | Val Loss: 0.1337 | Val Acc: 94.54% | F1: 0.7138 | EER: 9.85%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 6/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 94.22it/s] 



Epoch 6 | Train Loss: 0.0917 | Val Loss: 0.1331 | Val Acc: 94.80% | F1: 0.7228 | EER: 9.35%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 7/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 86.46it/s] 



Epoch 7 | Train Loss: 0.0779 | Val Loss: 0.1312 | Val Acc: 94.52% | F1: 0.7434 | EER: 8.83%
   -> Val loss decreased. New best model saved to saved_models/TransformerFusion_PyTorch_Best_23feat.pth


Epoch 8/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 102.24it/s]



Epoch 8 | Train Loss: 0.0675 | Val Loss: 0.1419 | Val Acc: 95.25% | F1: 0.7390 | EER: 8.99%


Epoch 9/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 88.43it/s] 



Epoch 9 | Train Loss: 0.0608 | Val Loss: 0.1544 | Val Acc: 95.20% | F1: 0.7323 | EER: 9.50%


Epoch 10/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 89.49it/s] 



Epoch 10 | Train Loss: 0.0544 | Val Loss: 0.1459 | Val Acc: 95.27% | F1: 0.7373 | EER: 8.63%


Epoch 11/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 95.38it/s] 



Epoch 11 | Train Loss: 0.0515 | Val Loss: 0.2168 | Val Acc: 94.65% | F1: 0.6612 | EER: 8.63%


Epoch 12/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 86.43it/s] 



Epoch 12 | Train Loss: 0.0482 | Val Loss: 0.1741 | Val Acc: 95.18% | F1: 0.7235 | EER: 8.56%


Epoch 13/40 (All Feats): 100%|██████████| 720/720 [00:05<00:00, 139.74it/s]



Epoch 13 | Train Loss: 0.0394 | Val Loss: 0.1497 | Val Acc: 95.66% | F1: 0.7639 | EER: 8.44%


Epoch 14/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 82.01it/s] 



Epoch 14 | Train Loss: 0.0280 | Val Loss: 0.1837 | Val Acc: 95.43% | F1: 0.7385 | EER: 8.44%


Epoch 15/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 90.03it/s] 



Epoch 15 | Train Loss: 0.0263 | Val Loss: 0.1898 | Val Acc: 95.45% | F1: 0.7410 | EER: 8.28%


Epoch 16/40 (All Feats): 100%|██████████| 720/720 [00:06<00:00, 111.59it/s]



Epoch 16 | Train Loss: 0.0256 | Val Loss: 0.2245 | Val Acc: 95.19% | F1: 0.7099 | EER: 8.36%


Epoch 17/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 85.55it/s] 



Epoch 17 | Train Loss: 0.0244 | Val Loss: 0.1995 | Val Acc: 95.48% | F1: 0.7427 | EER: 8.56%


Epoch 18/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 90.00it/s] 



Epoch 18 | Train Loss: 0.0222 | Val Loss: 0.2146 | Val Acc: 95.30% | F1: 0.7270 | EER: 8.32%


Epoch 19/40 (All Feats): 100%|██████████| 720/720 [00:06<00:00, 106.29it/s]



Epoch 19 | Train Loss: 0.0218 | Val Loss: 0.2189 | Val Acc: 95.32% | F1: 0.7244 | EER: 8.46%


Epoch 20/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 81.43it/s] 



Epoch 20 | Train Loss: 0.0194 | Val Loss: 0.2075 | Val Acc: 95.54% | F1: 0.7424 | EER: 8.44%


Epoch 21/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 70.76it/s] 



Epoch 21 | Train Loss: 0.0194 | Val Loss: 0.2128 | Val Acc: 95.49% | F1: 0.7387 | EER: 8.36%


Epoch 22/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 99.54it/s] 



Epoch 22 | Train Loss: 0.0196 | Val Loss: 0.2250 | Val Acc: 95.38% | F1: 0.7284 | EER: 8.34%


Epoch 23/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 98.37it/s] 



Epoch 23 | Train Loss: 0.0193 | Val Loss: 0.2252 | Val Acc: 95.36% | F1: 0.7281 | EER: 8.38%


Epoch 24/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 82.37it/s] 



Epoch 24 | Train Loss: 0.0197 | Val Loss: 0.2152 | Val Acc: 95.54% | F1: 0.7426 | EER: 8.24%


Epoch 25/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 91.09it/s] 



Epoch 25 | Train Loss: 0.0176 | Val Loss: 0.2183 | Val Acc: 95.47% | F1: 0.7384 | EER: 8.25%


Epoch 26/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 83.12it/s] 



Epoch 26 | Train Loss: 0.0171 | Val Loss: 0.2171 | Val Acc: 95.49% | F1: 0.7400 | EER: 8.25%


Epoch 27/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.61it/s]



Epoch 27 | Train Loss: 0.0178 | Val Loss: 0.2221 | Val Acc: 95.46% | F1: 0.7369 | EER: 8.25%


Epoch 28/40 (All Feats): 100%|██████████| 720/720 [00:09<00:00, 72.35it/s]



Epoch 28 | Train Loss: 0.0179 | Val Loss: 0.2202 | Val Acc: 95.47% | F1: 0.7378 | EER: 8.24%


Epoch 29/40 (All Feats): 100%|██████████| 720/720 [00:09<00:00, 76.21it/s] 



Epoch 29 | Train Loss: 0.0161 | Val Loss: 0.2186 | Val Acc: 95.48% | F1: 0.7394 | EER: 8.30%


Epoch 30/40 (All Feats): 100%|██████████| 720/720 [00:09<00:00, 75.92it/s] 



Epoch 30 | Train Loss: 0.0167 | Val Loss: 0.2172 | Val Acc: 95.51% | F1: 0.7416 | EER: 8.28%


Epoch 31/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 71.56it/s] 



Epoch 31 | Train Loss: 0.0176 | Val Loss: 0.2195 | Val Acc: 95.48% | F1: 0.7399 | EER: 8.28%


Epoch 32/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.91it/s]



Epoch 32 | Train Loss: 0.0173 | Val Loss: 0.2217 | Val Acc: 95.46% | F1: 0.7371 | EER: 8.28%


Epoch 33/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 80.77it/s]



Epoch 33 | Train Loss: 0.0177 | Val Loss: 0.2215 | Val Acc: 95.46% | F1: 0.7374 | EER: 8.29%


Epoch 34/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 87.44it/s] 



Epoch 34 | Train Loss: 0.0178 | Val Loss: 0.2232 | Val Acc: 95.46% | F1: 0.7371 | EER: 8.28%


Epoch 35/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 86.83it/s] 



Epoch 35 | Train Loss: 0.0171 | Val Loss: 0.2225 | Val Acc: 95.46% | F1: 0.7371 | EER: 8.28%


Epoch 36/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 85.93it/s] 



Epoch 36 | Train Loss: 0.0172 | Val Loss: 0.2217 | Val Acc: 95.48% | F1: 0.7385 | EER: 8.27%


Epoch 37/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 99.44it/s] 



Epoch 37 | Train Loss: 0.0174 | Val Loss: 0.2236 | Val Acc: 95.46% | F1: 0.7370 | EER: 8.28%


Epoch 38/40 (All Feats): 100%|██████████| 720/720 [00:07<00:00, 91.22it/s] 



Epoch 38 | Train Loss: 0.0184 | Val Loss: 0.2238 | Val Acc: 95.45% | F1: 0.7361 | EER: 8.27%


Epoch 39/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 85.26it/s] 



Epoch 39 | Train Loss: 0.0165 | Val Loss: 0.2236 | Val Acc: 95.46% | F1: 0.7368 | EER: 8.27%


Epoch 40/40 (All Feats): 100%|██████████| 720/720 [00:08<00:00, 84.19it/s] 



Epoch 40 | Train Loss: 0.0169 | Val Loss: 0.2237 | Val Acc: 95.46% | F1: 0.7370 | EER: 8.28%

Training plot saved to saved_models/training_metrics_transformer_23feat.png

--- Starting Final Testing and Analysis (All Features) ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Final Testing (All Feats): 100%|██████████| 1114/1114 [00:02<00:00, 438.31it/s]



--- Final Test Results (All Features) --- | EER: 14.10%

--- Running Cross-Attention Weight Analysis ---


Analyzing Attention: 100%|██████████| 1114/1114 [00:02<00:00, 444.43it/s]



Attention plot saved to saved_models/attention_importance_transformer_23feat.png

--- Running Feature Ablation Analysis ---
Baseline EER with all features: 14.10%


Performing Ablation: 100%|██████████| 23/23 [01:03<00:00,  2.77s/it]



Feature Importance based on EER Increase:
- slope500-1500_sma3: EER increases by 0.79%
- F3frequency_sma3nz: EER increases by 0.53%
- F1frequency_sma3nz: EER increases by 0.53%
- mfcc1_sma3: EER increases by 0.42%
- slope0-500_sma3: EER increases by 0.42%
- HNRdBACF_sma3nz: EER increases by 0.40%
- F0semitoneFrom27.5Hz_sma3nz: EER increases by 0.40%
- spectralFlux_sma3: EER increases by 0.33%
- Loudness_sma3: EER increases by 0.24%
- logRelF0-H1-H2_sma3nz: EER increases by 0.22%
- F2amplitudeLogRelF0_sma3nz: EER increases by 0.22%
- F3amplitudeLogRelF0_sma3nz: EER increases by 0.20%
- F2frequency_sma3nz: EER increases by 0.16%
- mfcc4_sma3: EER increases by 0.16%
- F1amplitudeLogRelF0_sma3nz: EER increases by 0.16%
- mfcc3_sma3: EER increases by 0.14%
- mfcc2_sma3: EER increases by 0.11%
- jitterLocal_sma3nz: EER increases by 0.09%
- shimmerLocaldB_sma3nz: EER increases by 0.08%
- hammarbergIndex_sma3: EER increases by 0.03%
- alphaRatio_sma3: EER increases by -0.00%
- logRelF0-H1-A3_

  0%|          | 0/64 [00:00<?, ?it/s]

Plotting SHAP summary...


No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored
The figure layout has changed to tight


SHAP plot saved to saved_models/shap_importance_transformer_23feat.png


--- Starting Retraining with Top 6 Features ---
Top 6 features selected for retraining: ['slope500-1500_sma3', 'F3frequency_sma3nz', 'F1frequency_sma3nz', 'mfcc1_sma3', 'slope0-500_sma3', 'HNRdBACF_sma3nz']


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


TransformerAttentionFusionModel(
  (cqcc_projection): Linear(in_features=157, out_features=128, bias=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (prosody_mlp): Sequential(
    (0): Linear(in_features=6, out_features=256, bias=True)
 

Epoch 1/40 (Top 6): 100%|██████████| 720/720 [00:08<00:00, 82.63it/s] 



Epoch 1 | Train Loss: 0.3394 | Val Loss: 0.2290 | EER: 12.37%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 2/40 (Top 6): 100%|██████████| 720/720 [00:09<00:00, 76.52it/s] 



Epoch 2 | Train Loss: 0.1836 | Val Loss: 0.2199 | EER: 11.07%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 3/40 (Top 6): 100%|██████████| 720/720 [00:07<00:00, 96.09it/s]



Epoch 3 | Train Loss: 0.1510 | Val Loss: 0.1721 | EER: 10.60%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 4/40 (Top 6): 100%|██████████| 720/720 [00:08<00:00, 80.83it/s]



Epoch 4 | Train Loss: 0.1280 | Val Loss: 0.1525 | EER: 10.28%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 5/40 (Top 6): 100%|██████████| 720/720 [00:08<00:00, 83.15it/s]



Epoch 5 | Train Loss: 0.1143 | Val Loss: 0.1463 | EER: 10.68%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 6/40 (Top 6): 100%|██████████| 720/720 [00:07<00:00, 95.59it/s] 



Epoch 6 | Train Loss: 0.1001 | Val Loss: 0.1407 | EER: 9.80%
   -> Val loss decreased. New best model (Top 6) saved to saved_models/TransformerFusion_PyTorch_Best_Top6feat.pth


Epoch 7/40 (Top 6): 100%|██████████| 720/720 [00:09<00:00, 72.13it/s]



Epoch 7 | Train Loss: 0.0912 | Val Loss: 0.1450 | EER: 9.85%


Epoch 8/40 (Top 6): 100%|██████████| 720/720 [00:09<00:00, 74.11it/s]



Epoch 8 | Train Loss: 0.0818 | Val Loss: 0.1747 | EER: 11.57%


Epoch 9/40 (Top 6): 100%|██████████| 720/720 [00:10<00:00, 71.13it/s]



Epoch 9 | Train Loss: 0.0740 | Val Loss: 0.1643 | EER: 9.65%


Epoch 10/40 (Top 6): 100%|██████████| 720/720 [00:09<00:00, 74.47it/s]



Epoch 10 | Train Loss: 0.0672 | Val Loss: 0.1775 | EER: 9.50%


Epoch 11/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.01it/s]



Epoch 11 | Train Loss: 0.0587 | Val Loss: 0.1609 | EER: 9.32%


Epoch 12/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.50it/s]



Epoch 12 | Train Loss: 0.0547 | Val Loss: 0.1783 | EER: 9.22%


Epoch 13/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 110.98it/s]



Epoch 13 | Train Loss: 0.0414 | Val Loss: 0.1908 | EER: 9.38%


Epoch 14/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.56it/s]



Epoch 14 | Train Loss: 0.0362 | Val Loss: 0.2059 | EER: 9.12%


Epoch 15/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.38it/s]



Epoch 15 | Train Loss: 0.0342 | Val Loss: 0.2038 | EER: 9.46%


Epoch 16/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.14it/s]



Epoch 16 | Train Loss: 0.0324 | Val Loss: 0.2145 | EER: 9.46%


Epoch 17/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.57it/s]



Epoch 17 | Train Loss: 0.0333 | Val Loss: 0.2076 | EER: 9.28%


Epoch 18/40 (Top 6): 100%|██████████| 720/720 [00:08<00:00, 87.94it/s] 



Epoch 18 | Train Loss: 0.0299 | Val Loss: 0.2083 | EER: 9.26%


Epoch 19/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.98it/s]



Epoch 19 | Train Loss: 0.0277 | Val Loss: 0.2114 | EER: 9.30%


Epoch 20/40 (Top 6): 100%|██████████| 720/720 [00:05<00:00, 124.15it/s]



Epoch 20 | Train Loss: 0.0281 | Val Loss: 0.2162 | EER: 9.22%


Epoch 21/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.42it/s]



Epoch 21 | Train Loss: 0.0262 | Val Loss: 0.2285 | EER: 9.34%


Epoch 22/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.07it/s]



Epoch 22 | Train Loss: 0.0280 | Val Loss: 0.2251 | EER: 9.27%


Epoch 23/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 113.61it/s]



Epoch 23 | Train Loss: 0.0272 | Val Loss: 0.2222 | EER: 9.50%


Epoch 24/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.23it/s]



Epoch 24 | Train Loss: 0.0262 | Val Loss: 0.2221 | EER: 9.32%


Epoch 25/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.93it/s]



Epoch 25 | Train Loss: 0.0262 | Val Loss: 0.2243 | EER: 9.26%


Epoch 26/40 (Top 6): 100%|██████████| 720/720 [00:05<00:00, 130.90it/s]



Epoch 26 | Train Loss: 0.0241 | Val Loss: 0.2272 | EER: 9.22%


Epoch 27/40 (Top 6): 100%|██████████| 720/720 [00:05<00:00, 131.26it/s]



Epoch 27 | Train Loss: 0.0244 | Val Loss: 0.2305 | EER: 9.16%


Epoch 28/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 115.90it/s]



Epoch 28 | Train Loss: 0.0252 | Val Loss: 0.2282 | EER: 9.26%


Epoch 29/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.08it/s]



Epoch 29 | Train Loss: 0.0247 | Val Loss: 0.2277 | EER: 9.26%


Epoch 30/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.11it/s]



Epoch 30 | Train Loss: 0.0242 | Val Loss: 0.2267 | EER: 9.31%


Epoch 31/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.30it/s]



Epoch 31 | Train Loss: 0.0247 | Val Loss: 0.2289 | EER: 9.29%


Epoch 32/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.67it/s]



Epoch 32 | Train Loss: 0.0240 | Val Loss: 0.2282 | EER: 9.31%


Epoch 33/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 113.05it/s]



Epoch 33 | Train Loss: 0.0244 | Val Loss: 0.2283 | EER: 9.29%


Epoch 34/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.24it/s]



Epoch 34 | Train Loss: 0.0251 | Val Loss: 0.2289 | EER: 9.26%


Epoch 35/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.56it/s]



Epoch 35 | Train Loss: 0.0244 | Val Loss: 0.2312 | EER: 9.26%


Epoch 36/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 109.46it/s]



Epoch 36 | Train Loss: 0.0249 | Val Loss: 0.2302 | EER: 9.30%


Epoch 37/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 106.97it/s]



Epoch 37 | Train Loss: 0.0249 | Val Loss: 0.2300 | EER: 9.30%


Epoch 38/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.90it/s]



Epoch 38 | Train Loss: 0.0250 | Val Loss: 0.2294 | EER: 9.28%


Epoch 39/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 107.16it/s]



Epoch 39 | Train Loss: 0.0250 | Val Loss: 0.2292 | EER: 9.30%


Epoch 40/40 (Top 6): 100%|██████████| 720/720 [00:06<00:00, 108.19it/s]



Epoch 40 | Train Loss: 0.0230 | Val Loss: 0.2290 | EER: 9.30%

Training plot saved to saved_models/training_metrics_transformer_Top6feat.png

--- Starting Final Testing (Top 6 Features) ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Final Testing (Top 6): 100%|██████████| 1114/1114 [00:02<00:00, 396.40it/s]


--- Final Test Results (Top 6 Features) --- | EER: 15.05%

--- Experiment Complete ---





In [34]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import shap
import math
import torch.nn.functional as F

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/CNNAttentionFusion_PyTorch_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_cnn_attention_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_cnn_attention_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_cnn_attention_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_cnn_attention_23feat.png"
MODEL_SAVE_PATH_TOP10 = "saved_models/CNNAttentionFusion_PyTorch_Best_Top10feat.pth"
PLOT_SAVE_PATH_TOP10= "saved_models/training_metrics_cnn_attention_Top10feat.png"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path, title_prefix=""):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')
    fig.tight_layout()
    plt.title(f'{title_prefix} Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]


# --- NEW: CNN + Attention Fusion Model ---
class CNNAttentionFusionModel(nn.Module):
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, cnn_out_channels=128, mlp_out_dim=128):
        super(CNNAttentionFusionModel, self).__init__()
        
        # CNN branch for processing CQCC features. This will produce the Key and Value for attention.
        self.cnn_branch = nn.Sequential(
            nn.Conv1d(in_channels=cqcc_feature_dim, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=cnn_out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(cnn_out_channels),
        )
        
        # MLP branch for processing prosodic features. This will produce the Query.
        self.prosody_mlp = nn.Sequential(
            nn.Linear(prosody_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            # The output dimension must match the CNN output channels for attention
            nn.Linear(256, cnn_out_channels), 
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # Classifier to combine the attention context and the prosody query
        classifier_input_dim = cnn_out_channels + cnn_out_channels
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # 1. Process CQCC features through the CNN branch
        # Permute from (batch, seq_len, features) to (batch, features, seq_len)
        cqcc_x_permuted = cqcc_x.permute(0, 2, 1)
        cnn_out = self.cnn_branch(cqcc_x_permuted) # Shape: (batch, cnn_out_channels, seq_len)
        
        # 2. Define Key and Value from CNN output for attention
        # Permute to (batch, seq_len, cnn_out_channels) for attention calculation
        keys = values = cnn_out.permute(0, 2, 1)
        
        # 3. Process prosodic features through the MLP to get the Query
        prosody_query = self.prosody_mlp(prosody_x) # Shape: (batch, cnn_out_channels)

        # 4. Perform Cross-Attention
        query_unsqueezed = prosody_query.unsqueeze(1) # Shape: (batch, 1, cnn_out_channels)
        
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2)) / (keys.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1) # Shape: (batch, 1, seq_len)
        
        context = torch.bmm(attention_weights, values).squeeze(1) # Shape: (batch, cnn_out_channels)
        
        # 5. Fuse the attention context with the prosody query and classify
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        
        # Return logits for loss and attention_weights for analysis
        return logits, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, device, save_path):
    print("\n--- Running Cross-Attention Weight Analysis ---")
    model.eval()
    all_weights = []
    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            all_weights.append(weights.squeeze(1).cpu().numpy())
    
    avg_weights = np.mean(np.concatenate(all_weights, axis=0), axis=0)
    
    plt.figure(figsize=(15, 6))
    plt.plot(avg_weights, color='purple')
    plt.xlabel('CQCC Time Frame (after CNN processing)')
    plt.ylabel('Average Attention Weight')
    plt.title('Cross-Attention: Importance of Acoustic Time Frames Guided by Prosody')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAttention plot saved to {save_path}")
    plt.close()

def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    print("\n--- Running Feature Ablation Analysis ---")
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody_clone = prosody.clone()
                    prosody_clone[:, feature_to_ablate] = 0.0
                    logits, _ = model(cqcc, prosody_clone)
                else:
                    logits, _ = model(cqcc, prosody)
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
    
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")
    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()
    return sorted_features


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    print("\n--- Running SHAP Analysis ---")
    model.eval()
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        with torch.no_grad():
            logits, _ = model(cqcc_tensor, prosody_tensor)
            output = torch.sigmoid(logits)
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    plt.figure() 
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        X_cqcc_train_full = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)
        X_cqcc_val_full = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)
        
        print("Converting 3D LLD prosodic features to 2D summary statistics (mean)...")
        X_prosody_train_full = np.mean(X_prosody_train_3d, axis=2)
        X_prosody_val_full = np.mean(X_prosody_val_3d, axis=2)
        
        feature_columns = [
            'Loudness_sma3','alphaRatio_sma3','hammarbergIndex_sma3','slope0-500_sma3',
            'slope500-1500_sma3','spectralFlux_sma3','mfcc1_sma3','mfcc2_sma3',
            'mfcc3_sma3','mfcc4_sma3','F0semitoneFrom27.5Hz_sma3nz','jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz','HNRdBACF_sma3nz','logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz','F1frequency_sma3nz','F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz','F2frequency_sma3nz','F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz','F3amplitudeLogRelF0_sma3nz'
        ]
        num_prosodic_features = X_prosody_train_full.shape[1]
        if len(feature_columns) != num_prosodic_features:
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("\n--- Scaling Full Feature Data ---")
    scaler_prosody = StandardScaler()
    X_prosody_train_scaled = scaler_prosody.fit_transform(X_prosody_train_full)
    X_prosody_val_scaled = scaler_prosody.transform(X_prosody_val_full)
    scaler_cqcc = StandardScaler()
    ns, nx, ny = X_cqcc_train_full.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train_full.reshape(ns, -1)).reshape(ns, nx, ny)
    nsv, nxv, nyv = X_cqcc_val_full.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val_full.reshape(nsv, -1)).reshape(nsv, nxv, nyv)

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = CNNAttentionFusionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=X_prosody_train_full.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (All Features) ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (All Feats)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            logits, _ = model(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EER: {eer:.2f}%")
        history.update({'train_loss': history['train_loss']+[avg_train_loss], 'val_loss': history['val_loss']+[avg_val_loss], 'eer': history['eer']+[eer]})
        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    plot_training_history(history, PLOT_SAVE_PATH, title_prefix="All Features (CNN-Attention)")

    # --- FINAL TESTING AND ANALYSIS (ALL FEATURES) ---
    print("\n--- Starting Final Testing and Analysis (All Features) ---")
    try:
        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        X_prosody_test_full = np.mean(X_prosody_test_3d, axis=2)
        X_prosody_test_scaled = scaler_prosody.transform(X_prosody_test_full)
        ns_test, nx_test, ny_test = X_cqcc_test_full.shape
        X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_full.reshape(ns_test, -1)).reshape(ns_test, nx_test, ny_test)
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        analysis_model = CNNAttentionFusionModel(
            cqcc_feature_dim=X_cqcc_train_full.shape[2],
            prosody_feature_dim=X_prosody_train_full.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()
        
        analyze_attention_weights(analysis_model, test_loader, DEVICE, ATTENTION_PLOT_PATH)
        sorted_features = perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except Exception as e:
        print(f"Error during testing/analysis: {e}")

    # ==============================================================================
    # --- NEW: RETRAINING WITH TOP 10 FEATURES ---
    # ==============================================================================
    print("\n\n--- Starting Retraining with Top 6 Features ---")
    
    top_10_feature_names = [item[0] for item in sorted_features[:10]]
    top_10_indices = [feature_columns.index(name) for name in top_10_feature_names]
    print("Top 10 features selected for retraining:", top_10_feature_names)

    X_prosody_train_top10 = X_prosody_train_full[:, top_10_indices]
    X_prosody_val_top10 = X_prosody_val_full[:, top_10_indices]
    X_prosody_test_top10 = X_prosody_test_full[:, top_10_indices]

    scaler_prosody_top10 = StandardScaler()
    X_prosody_train_scaled_top10 = scaler_prosody_top10.fit_transform(X_prosody_train_top10)
    X_prosody_val_scaled_top10 = scaler_prosody_top10.transform(X_prosody_val_top10)
    
    train_dataset_top10 = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_scaled_top10, y_train)
    val_dataset_top10 = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_scaled_top10, y_val)
    train_loader_top10 = DataLoader(train_dataset_top10, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_top10 = DataLoader(val_dataset_top10, batch_size=BATCH_SIZE, shuffle=False)

    model_top10 = CNNAttentionFusionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=10
    ).to(DEVICE)

    optimizer_top10 = optim.Adam(model_top10.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler_top10 = optim.lr_scheduler.ReduceLROnPlateau(optimizer_top10, 'min', factor=0.2, patience=5, verbose=True)

    print(model_top10)
    best_val_loss_top10 = float('inf')
    history_top10 = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (Top 10 Features) ---")

    for epoch in range(EPOCHS):
        model_top10.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader_top10, desc=f"Epoch {epoch+1}/{EPOCHS} (Top 10)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer_top10.zero_grad()
            logits, _ = model_top10(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer_top10.step()
            running_loss += loss.item()

        model_top10.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader_top10:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model_top10(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader_top10)
        avg_val_loss = val_loss / len(val_loader_top10)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EER: {eer:.2f}%")
        history_top10.update({'train_loss': history_top10['train_loss']+[avg_train_loss], 'val_loss': history_top10['val_loss']+[avg_val_loss], 'eer': history_top10['eer']+[eer]})
        scheduler_top10.step(avg_val_loss)

        if avg_val_loss < best_val_loss_top10:
            best_val_loss_top10 = avg_val_loss
            torch.save(model_top10.state_dict(), MODEL_SAVE_PATH_TOP10)
            print(f"   -> Val loss decreased. New best model (Top 10) saved to {MODEL_SAVE_PATH_TOP10}")

    plot_training_history(history_top10, PLOT_SAVE_PATH_TOP10, title_prefix="Top 10 Features (CNN-Attention)")

    # --- FINAL TESTING (TOP 10 FEATURES) ---
    print("\n--- Starting Final Testing (Top 10 Features) ---")
    X_prosody_test_scaled_top10 = scaler_prosody_top10.transform(X_prosody_test_top10)
    test_dataset_top10 = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_scaled_top10, y_test)
    test_loader_top10 = DataLoader(test_dataset_top10, batch_size=BATCH_SIZE, shuffle=False)
    
    model_top10.load_state_dict(torch.load(MODEL_SAVE_PATH_TOP10))
    model_top10.eval()

    all_test_labels, all_test_scores = [], []
    with torch.no_grad():
        for cqcc, prosody, labels in tqdm(test_loader_top10, desc="Final Testing (Top 10)"):
            cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
            logits, _ = model_top10(cqcc, prosody)
            all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())
    
    all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
    test_eer_top10 = calculate_eer(all_test_labels, all_test_scores)
    print(f"\n--- Final Test Results (Top 10 Features) --- | EER: {test_eer_top10:.2f}%")
    print("\n--- Experiment Complete ---")


Using device: cuda
--- Loading and Preparing Data ---
Converting 3D LLD prosodic features to 2D summary statistics (mean)...

--- Scaling Full Feature Data ---


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


CNNAttentionFusionModel(
  (cnn_branch): Sequential(
    (0): Conv1d(157, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): ReLU()
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): ReLU()
    (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (prosody_mlp): Sequential(
    (0): Linear(in_features=23, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    

Epoch 1/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 172.92it/s]



Epoch 1 | Train Loss: 0.3000 | Val Loss: 0.2450 | EER: 13.05%
   -> Val loss decreased. New best model saved to saved_models/CNNAttentionFusion_PyTorch_Best_23feat.pth


Epoch 2/40 (All Feats): 100%|██████████| 720/720 [00:03<00:00, 204.08it/s]



Epoch 2 | Train Loss: 0.1319 | Val Loss: 0.1583 | EER: 10.31%
   -> Val loss decreased. New best model saved to saved_models/CNNAttentionFusion_PyTorch_Best_23feat.pth


Epoch 3/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 155.70it/s]



Epoch 3 | Train Loss: 0.0833 | Val Loss: 0.1658 | EER: 10.95%


Epoch 4/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 251.55it/s]



Epoch 4 | Train Loss: 0.0526 | Val Loss: 0.1664 | EER: 9.85%


Epoch 5/40 (All Feats): 100%|██████████| 720/720 [00:03<00:00, 224.78it/s]



Epoch 5 | Train Loss: 0.0342 | Val Loss: 0.1857 | EER: 9.87%


Epoch 6/40 (All Feats): 100%|██████████| 720/720 [00:03<00:00, 214.71it/s]



Epoch 6 | Train Loss: 0.0263 | Val Loss: 0.2113 | EER: 9.69%


Epoch 7/40 (All Feats): 100%|██████████| 720/720 [00:03<00:00, 207.56it/s]



Epoch 7 | Train Loss: 0.0189 | Val Loss: 0.3263 | EER: 10.52%


Epoch 8/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 243.09it/s]



Epoch 8 | Train Loss: 0.0191 | Val Loss: 0.2012 | EER: 9.27%


Epoch 9/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 258.21it/s]



Epoch 9 | Train Loss: 0.0079 | Val Loss: 0.3043 | EER: 10.05%


Epoch 10/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.93it/s]



Epoch 10 | Train Loss: 0.0044 | Val Loss: 0.3103 | EER: 9.84%


Epoch 11/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 255.71it/s]



Epoch 11 | Train Loss: 0.0036 | Val Loss: 0.3603 | EER: 10.52%


Epoch 12/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 261.29it/s]



Epoch 12 | Train Loss: 0.0032 | Val Loss: 0.4316 | EER: 10.45%


Epoch 13/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 258.11it/s]



Epoch 13 | Train Loss: 0.0036 | Val Loss: 0.3933 | EER: 9.83%


Epoch 14/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 256.72it/s]



Epoch 14 | Train Loss: 0.0031 | Val Loss: 0.4045 | EER: 10.52%


Epoch 15/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 273.01it/s]



Epoch 15 | Train Loss: 0.0027 | Val Loss: 0.4046 | EER: 10.22%


Epoch 16/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 277.30it/s]



Epoch 16 | Train Loss: 0.0023 | Val Loss: 0.4143 | EER: 10.05%


Epoch 17/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 269.86it/s]



Epoch 17 | Train Loss: 0.0017 | Val Loss: 0.4081 | EER: 9.97%


Epoch 18/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 264.89it/s]



Epoch 18 | Train Loss: 0.0014 | Val Loss: 0.4544 | EER: 10.20%


Epoch 19/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 261.34it/s]



Epoch 19 | Train Loss: 0.0016 | Val Loss: 0.4257 | EER: 10.11%


Epoch 20/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.90it/s]



Epoch 20 | Train Loss: 0.0010 | Val Loss: 0.4910 | EER: 10.80%


Epoch 21/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 261.43it/s]



Epoch 21 | Train Loss: 0.0013 | Val Loss: 0.4992 | EER: 10.83%


Epoch 22/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 258.50it/s]



Epoch 22 | Train Loss: 0.0012 | Val Loss: 0.4504 | EER: 10.52%


Epoch 23/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.12it/s]



Epoch 23 | Train Loss: 0.0009 | Val Loss: 0.4679 | EER: 10.68%


Epoch 24/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 257.70it/s]



Epoch 24 | Train Loss: 0.0013 | Val Loss: 0.4529 | EER: 10.64%


Epoch 25/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.02it/s]



Epoch 25 | Train Loss: 0.0009 | Val Loss: 0.5117 | EER: 10.52%


Epoch 26/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 257.81it/s]



Epoch 26 | Train Loss: 0.0014 | Val Loss: 0.4995 | EER: 10.60%


Epoch 27/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.62it/s]



Epoch 27 | Train Loss: 0.0008 | Val Loss: 0.4825 | EER: 10.75%


Epoch 28/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.10it/s]



Epoch 28 | Train Loss: 0.0012 | Val Loss: 0.4647 | EER: 10.48%


Epoch 29/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 257.68it/s]



Epoch 29 | Train Loss: 0.0012 | Val Loss: 0.4828 | EER: 10.33%


Epoch 30/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 257.44it/s]



Epoch 30 | Train Loss: 0.0009 | Val Loss: 0.4743 | EER: 10.40%


Epoch 31/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.85it/s]



Epoch 31 | Train Loss: 0.0010 | Val Loss: 0.4736 | EER: 10.48%


Epoch 32/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 170.77it/s]



Epoch 32 | Train Loss: 0.0008 | Val Loss: 0.4856 | EER: 10.48%


Epoch 33/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 170.07it/s]



Epoch 33 | Train Loss: 0.0010 | Val Loss: 0.4775 | EER: 10.44%


Epoch 34/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 156.30it/s]



Epoch 34 | Train Loss: 0.0011 | Val Loss: 0.4888 | EER: 10.60%


Epoch 35/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 164.97it/s]



Epoch 35 | Train Loss: 0.0008 | Val Loss: 0.4836 | EER: 10.48%


Epoch 36/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 163.21it/s]



Epoch 36 | Train Loss: 0.0011 | Val Loss: 0.4822 | EER: 10.40%


Epoch 37/40 (All Feats): 100%|██████████| 720/720 [00:04<00:00, 175.65it/s]



Epoch 37 | Train Loss: 0.0009 | Val Loss: 0.4704 | EER: 10.40%


Epoch 38/40 (All Feats): 100%|██████████| 720/720 [00:03<00:00, 206.07it/s]



Epoch 38 | Train Loss: 0.0014 | Val Loss: 0.4696 | EER: 10.38%


Epoch 39/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 260.13it/s]



Epoch 39 | Train Loss: 0.0012 | Val Loss: 0.4801 | EER: 10.54%


Epoch 40/40 (All Feats): 100%|██████████| 720/720 [00:02<00:00, 259.03it/s]



Epoch 40 | Train Loss: 0.0009 | Val Loss: 0.4810 | EER: 10.47%

Training plot saved to saved_models/training_metrics_cnn_attention_23feat.png

--- Starting Final Testing and Analysis (All Features) ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



--- Running Cross-Attention Weight Analysis ---


Analyzing Attention: 100%|██████████| 1114/1114 [00:02<00:00, 509.51it/s]



Attention plot saved to saved_models/attention_importance_cnn_attention_23feat.png

--- Running Feature Ablation Analysis ---
Baseline EER with all features: 13.85%


Performing Ablation: 100%|██████████| 23/23 [00:54<00:00,  2.35s/it]



Feature Importance based on EER Increase:
- slope500-1500_sma3: EER increases by 0.64%
- spectralFlux_sma3: EER increases by 0.63%
- hammarbergIndex_sma3: EER increases by 0.61%
- slope0-500_sma3: EER increases by 0.40%
- F1frequency_sma3nz: EER increases by 0.30%
- mfcc1_sma3: EER increases by 0.23%
- F2amplitudeLogRelF0_sma3nz: EER increases by 0.23%
- HNRdBACF_sma3nz: EER increases by 0.20%
- F3amplitudeLogRelF0_sma3nz: EER increases by 0.18%
- mfcc3_sma3: EER increases by 0.12%
- F3frequency_sma3nz: EER increases by 0.12%
- F2frequency_sma3nz: EER increases by 0.12%
- shimmerLocaldB_sma3nz: EER increases by 0.11%
- logRelF0-H1-H2_sma3nz: EER increases by 0.07%
- F0semitoneFrom27.5Hz_sma3nz: EER increases by 0.00%
- Loudness_sma3: EER increases by -0.01%
- logRelF0-H1-A3_sma3nz: EER increases by -0.03%
- jitterLocal_sma3nz: EER increases by -0.05%
- F1amplitudeLogRelF0_sma3nz: EER increases by -0.05%
- mfcc4_sma3: EER increases by -0.08%
- alphaRatio_sma3: EER increases by -0.12%
-

  0%|          | 0/64 [00:00<?, ?it/s]

Plotting SHAP summary...


No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored
The figure layout has changed to tight


SHAP plot saved to saved_models/shap_importance_cnn_attention_23feat.png


--- Starting Retraining with Top 6 Features ---
Top 10 features selected for retraining: ['slope500-1500_sma3', 'spectralFlux_sma3', 'hammarbergIndex_sma3', 'slope0-500_sma3', 'F1frequency_sma3nz', 'mfcc1_sma3', 'F2amplitudeLogRelF0_sma3nz', 'HNRdBACF_sma3nz', 'F3amplitudeLogRelF0_sma3nz', 'mfcc3_sma3']


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


CNNAttentionFusionModel(
  (cnn_branch): Sequential(
    (0): Conv1d(157, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): ReLU()
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): ReLU()
    (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (prosody_mlp): Sequential(
    (0): Linear(in_features=10, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    

Epoch 1/40 (Top 10): 100%|██████████| 720/720 [00:05<00:00, 136.31it/s]



Epoch 1 | Train Loss: 0.3061 | Val Loss: 0.1963 | EER: 12.91%
   -> Val loss decreased. New best model (Top 10) saved to saved_models/CNNAttentionFusion_PyTorch_Best_Top10feat.pth


Epoch 2/40 (Top 10): 100%|██████████| 720/720 [00:05<00:00, 129.00it/s]



Epoch 2 | Train Loss: 0.1451 | Val Loss: 0.1637 | EER: 12.05%
   -> Val loss decreased. New best model (Top 10) saved to saved_models/CNNAttentionFusion_PyTorch_Best_Top10feat.pth


Epoch 3/40 (Top 10): 100%|██████████| 720/720 [00:05<00:00, 132.99it/s]



Epoch 3 | Train Loss: 0.0875 | Val Loss: 0.1669 | EER: 11.23%


Epoch 4/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 151.31it/s]



Epoch 4 | Train Loss: 0.0569 | Val Loss: 0.2137 | EER: 12.95%


Epoch 5/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 227.42it/s]



Epoch 5 | Train Loss: 0.0396 | Val Loss: 0.2045 | EER: 9.87%


Epoch 6/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 227.60it/s]



Epoch 6 | Train Loss: 0.0263 | Val Loss: 0.2393 | EER: 11.11%


Epoch 7/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 224.09it/s]



Epoch 7 | Train Loss: 0.0220 | Val Loss: 0.1824 | EER: 9.89%


Epoch 8/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 228.99it/s]



Epoch 8 | Train Loss: 0.0190 | Val Loss: 0.2406 | EER: 10.69%


Epoch 9/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 229.17it/s]



Epoch 9 | Train Loss: 0.0087 | Val Loss: 0.2251 | EER: 9.31%


Epoch 10/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 237.84it/s]



Epoch 10 | Train Loss: 0.0057 | Val Loss: 0.2923 | EER: 10.69%


Epoch 11/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 239.58it/s]



Epoch 11 | Train Loss: 0.0046 | Val Loss: 0.3047 | EER: 10.33%


Epoch 12/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 227.59it/s]



Epoch 12 | Train Loss: 0.0037 | Val Loss: 0.3246 | EER: 10.56%


Epoch 13/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 228.55it/s]



Epoch 13 | Train Loss: 0.0031 | Val Loss: 0.2883 | EER: 10.01%


Epoch 14/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 234.71it/s]



Epoch 14 | Train Loss: 0.0032 | Val Loss: 0.3545 | EER: 10.28%


Epoch 15/40 (Top 10): 100%|██████████| 720/720 [00:02<00:00, 240.46it/s]



Epoch 15 | Train Loss: 0.0020 | Val Loss: 0.3459 | EER: 10.44%


Epoch 16/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 229.59it/s]



Epoch 16 | Train Loss: 0.0018 | Val Loss: 0.3898 | EER: 10.38%


Epoch 17/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 234.42it/s]



Epoch 17 | Train Loss: 0.0020 | Val Loss: 0.3756 | EER: 10.19%


Epoch 18/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 232.27it/s]



Epoch 18 | Train Loss: 0.0017 | Val Loss: 0.3731 | EER: 10.37%


Epoch 19/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 232.42it/s]



Epoch 19 | Train Loss: 0.0014 | Val Loss: 0.3814 | EER: 10.46%


Epoch 20/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 233.74it/s]



Epoch 20 | Train Loss: 0.0015 | Val Loss: 0.4044 | EER: 10.43%


Epoch 21/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 230.69it/s]



Epoch 21 | Train Loss: 0.0018 | Val Loss: 0.3961 | EER: 10.32%


Epoch 22/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 232.76it/s]



Epoch 22 | Train Loss: 0.0009 | Val Loss: 0.3938 | EER: 10.41%


Epoch 23/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 168.70it/s]



Epoch 23 | Train Loss: 0.0012 | Val Loss: 0.4236 | EER: 10.17%


Epoch 24/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 170.81it/s]



Epoch 24 | Train Loss: 0.0013 | Val Loss: 0.4041 | EER: 10.52%


Epoch 25/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 177.78it/s]



Epoch 25 | Train Loss: 0.0015 | Val Loss: 0.3989 | EER: 10.56%


Epoch 26/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 165.23it/s]



Epoch 26 | Train Loss: 0.0013 | Val Loss: 0.3929 | EER: 10.52%


Epoch 27/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 179.38it/s]



Epoch 27 | Train Loss: 0.0013 | Val Loss: 0.3982 | EER: 10.68%


Epoch 28/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 165.42it/s]



Epoch 28 | Train Loss: 0.0011 | Val Loss: 0.4012 | EER: 10.44%


Epoch 29/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 170.60it/s]



Epoch 29 | Train Loss: 0.0012 | Val Loss: 0.4055 | EER: 10.34%


Epoch 30/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 171.06it/s]



Epoch 30 | Train Loss: 0.0015 | Val Loss: 0.3876 | EER: 10.56%


Epoch 31/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 168.95it/s]



Epoch 31 | Train Loss: 0.0011 | Val Loss: 0.4126 | EER: 10.45%


Epoch 32/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 168.99it/s]



Epoch 32 | Train Loss: 0.0011 | Val Loss: 0.3751 | EER: 10.44%


Epoch 33/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 167.40it/s]



Epoch 33 | Train Loss: 0.0011 | Val Loss: 0.3856 | EER: 10.52%


Epoch 34/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 181.69it/s]



Epoch 34 | Train Loss: 0.0010 | Val Loss: 0.4126 | EER: 10.34%


Epoch 35/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 233.33it/s]



Epoch 35 | Train Loss: 0.0013 | Val Loss: 0.4070 | EER: 10.36%


Epoch 36/40 (Top 10): 100%|██████████| 720/720 [00:04<00:00, 164.76it/s]



Epoch 36 | Train Loss: 0.0011 | Val Loss: 0.4196 | EER: 10.52%


Epoch 37/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 231.63it/s]



Epoch 37 | Train Loss: 0.0010 | Val Loss: 0.4127 | EER: 10.55%


Epoch 38/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 232.45it/s]



Epoch 38 | Train Loss: 0.0009 | Val Loss: 0.4115 | EER: 10.48%


Epoch 39/40 (Top 10): 100%|██████████| 720/720 [00:02<00:00, 240.60it/s]



Epoch 39 | Train Loss: 0.0011 | Val Loss: 0.4206 | EER: 10.71%


Epoch 40/40 (Top 10): 100%|██████████| 720/720 [00:03<00:00, 234.36it/s]



Epoch 40 | Train Loss: 0.0011 | Val Loss: 0.3762 | EER: 10.52%

Training plot saved to saved_models/training_metrics_cnn_attention_Top10feat.png

--- Starting Final Testing (Top 10 Features) ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Final Testing (Top 10): 100%|██████████| 1114/1114 [00:02<00:00, 504.87it/s]



--- Final Test Results (Top 10 Features) --- | EER: 15.55%

--- Experiment Complete ---


In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import shap
import math
import torch.nn.functional as F

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/CNNBiLSTMAttention_PyTorch_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_cnn_bilstm_attention_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_cnn_bilstm_attention_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_cnn_bilstm_attention_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_cnn_bilstm_attention_23feat.png"
MODEL_SAVE_PATH_TOP6 = "saved_models/CNNBiLSTMAttention_PyTorch_Best_Top6feat.pth"
PLOT_SAVE_PATH_TOP6 = "saved_models/training_metrics_cnn_bilstm_attention_Top6feat.png"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")


def calculate_eer(y_true, y_score):
    """Calculates the Equal Error Rate (EER)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer * 100

def plot_training_history(history, save_path, title_prefix=""):
    """Plots and saves the training history graph."""
    fig, ax1 = plt.subplots(figsize=(12, 8))
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    ax2 = ax1.twinx()  
    color = 'tab:blue'
    ax2.set_ylabel('EER (%)', color=color)
    ax2.plot(history['eer'], color=color, linestyle='-', label='EER (%)')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.legend(loc='upper right')
    fig.tight_layout()
    plt.title(f'{title_prefix} Training and Validation Metrics')
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]


# --- NEW: CNN + BiLSTM + Attention Fusion Model ---
class CNNBiLSTMAttentionModel(nn.Module):
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, cnn_out_channels=128, lstm_hidden_dim=128):
        super(CNNBiLSTMAttentionModel, self).__init__()
        
        # CNN branch for processing CQCC features (Key, Value)
        self.cnn_branch = nn.Sequential(
            nn.Conv1d(in_channels=cqcc_feature_dim, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=cnn_out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(cnn_out_channels),
        )
        
        # BiLSTM branch for processing prosodic features (Query)
        self.bilstm_branch = nn.LSTM(
            input_size=prosody_feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.3
        )

        # Linear layer to transform BiLSTM output to match CNN channel dimension for attention
        # BiLSTM output is 2 * lstm_hidden_dim because it's bidirectional
        self.query_transform = nn.Linear(lstm_hidden_dim * 2, cnn_out_channels)
        
        # Classifier to combine the attention context and the prosody query
        classifier_input_dim = cnn_out_channels + cnn_out_channels
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # 1. Process CQCC features through the CNN branch
        cqcc_x_permuted = cqcc_x.permute(0, 2, 1)
        cnn_out = self.cnn_branch(cqcc_x_permuted)
        keys = values = cnn_out.permute(0, 2, 1)
        
        # 2. Process prosodic features through the BiLSTM branch
        # Add a sequence dimension: (batch, features) -> (batch, 1, features)
        prosody_x_unsqueezed = prosody_x.unsqueeze(1)
        lstm_out, _ = self.bilstm_branch(prosody_x_unsqueezed)
        # Take the output of the last time step
        lstm_out_last = lstm_out[:, -1, :]
        
        # 3. Transform BiLSTM output to create the Query
        prosody_query = self.query_transform(lstm_out_last)

        # 4. Perform Cross-Attention
        query_unsqueezed = prosody_query.unsqueeze(1)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2)) / (keys.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.bmm(attention_weights, values).squeeze(1)
        
        # 5. Fuse and classify
        fused_features = torch.cat([context, prosody_query], dim=1)
        logits = self.classifier(fused_features)
        
        return logits, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, device, save_path):
    print("\n--- Running Cross-Attention Weight Analysis ---")
    model.eval()
    all_weights = []
    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            all_weights.append(weights.squeeze(1).cpu().numpy())
    
    avg_weights = np.mean(np.concatenate(all_weights, axis=0), axis=0)
    
    plt.figure(figsize=(15, 6))
    plt.plot(avg_weights, color='purple')
    plt.xlabel('CQCC Time Frame (after CNN processing)')
    plt.ylabel('Average Attention Weight')
    plt.title('Cross-Attention: Importance of Acoustic Time Frames Guided by Prosody')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAttention plot saved to {save_path}")
    plt.close()

def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    print("\n--- Running Feature Ablation Analysis ---")
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody_clone = prosody.clone()
                    prosody_clone[:, feature_to_ablate] = 0.0
                    logits, _ = model(cqcc, prosody_clone)
                else:
                    logits, _ = model(cqcc, prosody)
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return calculate_eer(np.array(all_labels), np.array(all_scores).flatten())

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
    
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")
    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()
    return sorted_features


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    print("\n--- Running SHAP Analysis ---")
    model.eval()
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        with torch.no_grad():
            logits, _ = model(cqcc_tensor, prosody_tensor)
            output = torch.sigmoid(logits)
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    plt.figure() 
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        X_cqcc_train_full = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)
        
        X_cqcc_val_full = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)

        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        
        feature_columns = [
            'Loudness_sma3','alphaRatio_sma3','hammarbergIndex_sma3','slope0-500_sma3',
            'slope500-1500_sma3','spectralFlux_sma3','mfcc1_sma3','mfcc2_sma3',
            'mfcc3_sma3','mfcc4_sma3','F0semitoneFrom27.5Hz_sma3nz','jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz','HNRdBACF_sma3nz','logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz','F1frequency_sma3nz','F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz','F2frequency_sma3nz','F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz','F3amplitudeLogRelF0_sma3nz'
        ]
        # Assuming shape is (samples, features, timesteps)
        num_prosodic_features = X_prosody_train_3d.shape[1]
        if len(feature_columns) != num_prosodic_features:
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    # --- NEW: Scaling LLD Prosodic Features before creating summary stats ---
    print("\n--- Scaling LLD Prosodic Features ---")
    # Assuming shape is (samples, features, timesteps). We need to scale across the features.
    # To do this, we transpose to (samples, timesteps, features) and then reshape.
    ns_p, nf_p, nt_p = X_prosody_train_3d.shape
    X_prosody_train_reshaped = X_prosody_train_3d.transpose(0, 2, 1).reshape(-1, nf_p)
    
    # Fit scaler ONLY on training data
    scaler_prosody_lld = StandardScaler().fit(X_prosody_train_reshaped)

    # Transform train data and reshape back
    X_prosody_train_3d_scaled = scaler_prosody_lld.transform(X_prosody_train_reshaped).reshape(ns_p, nt_p, nf_p).transpose(0, 2, 1)

    # Transform validation data
    nsv_p, nfv_p, ntv_p = X_prosody_val_3d.shape
    X_prosody_val_reshaped = X_prosody_val_3d.transpose(0, 2, 1).reshape(-1, nfv_p)
    X_prosody_val_3d_scaled = scaler_prosody_lld.transform(X_prosody_val_reshaped).reshape(nsv_p, ntv_p, nfv_p).transpose(0, 2, 1)
    
    # Transform test data
    nst_p, nft_p, ntt_p = X_prosody_test_3d.shape
    X_prosody_test_reshaped = X_prosody_test_3d.transpose(0, 2, 1).reshape(-1, nft_p)
    X_prosody_test_3d_scaled = scaler_prosody_lld.transform(X_prosody_test_reshaped).reshape(nst_p, ntt_p, nft_p).transpose(0, 2, 1)

    print("Converting SCALED 3D LLD prosodic features to 2D summary statistics (mean)...")
    X_prosody_train_full = np.mean(X_prosody_train_3d_scaled, axis=2)
    X_prosody_val_full = np.mean(X_prosody_val_3d_scaled, axis=2)
    X_prosody_test_full = np.mean(X_prosody_test_3d_scaled, axis=2)

    # --- Scaling CQCC Data ---
    print("\n--- Scaling CQCC Data ---")
    scaler_cqcc = StandardScaler()
    ns, nx, ny = X_cqcc_train_full.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train_full.reshape(ns, -1)).reshape(ns, nx, ny)
    nsv, nxv, nyv = X_cqcc_val_full.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val_full.reshape(nsv, -1)).reshape(nsv, nxv, nyv)
    ns_test, nx_test, ny_test = X_cqcc_test_full.shape
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_full.reshape(ns_test, -1)).reshape(ns_test, nx_test, ny_test)

    # Use the scaled & summarized prosody data
    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_full, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_full, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = CNNBiLSTMAttentionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=X_prosody_train_full.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True)
    
    print(model)
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (All Features) ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (All Feats)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            logits, _ = model(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EER: {eer:.2f}%")
        history.update({'train_loss': history['train_loss']+[avg_train_loss], 'val_loss': history['val_loss']+[avg_val_loss], 'eer': history['eer']+[eer]})
        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val loss decreased. New best model saved to {MODEL_SAVE_PATH}")

    plot_training_history(history, PLOT_SAVE_PATH, title_prefix="All Features (CNN-BiLSTM-Attention)")

    # --- FINAL TESTING AND ANALYSIS (ALL FEATURES) ---
    print("\n--- Starting Final Testing and Analysis (All Features) ---")
    try:
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_full, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        analysis_model = CNNBiLSTMAttentionModel(
            cqcc_feature_dim=X_cqcc_train_full.shape[2],
            prosody_feature_dim=X_prosody_train_full.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()
        
        analyze_attention_weights(analysis_model, test_loader, DEVICE, ATTENTION_PLOT_PATH)
        sorted_features = perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except Exception as e:
        print(f"Error during testing/analysis: {e}")

    # ==============================================================================
    # --- NEW: RETRAINING WITH TOP 6 FEATURES ---
    # ==============================================================================
    print("\n\n--- Starting Retraining with Top 6 Features ---")
    
    top_6_feature_names = [item[0] for item in sorted_features[:6]]
    top_6_indices = [feature_columns.index(name) for name in top_6_feature_names]
    print("Top 6 features selected for retraining:", top_6_feature_names)

    X_prosody_train_top6 = X_prosody_train_full[:, top_6_indices]
    X_prosody_val_top6 = X_prosody_val_full[:, top_6_indices]
    X_prosody_test_top6 = X_prosody_test_full[:, top_6_indices]
    
    train_dataset_top6 = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_top6, y_train)
    val_dataset_top6 = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_top6, y_val)
    train_loader_top6 = DataLoader(train_dataset_top6, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_top6 = DataLoader(val_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)

    model_top6 = CNNBiLSTMAttentionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=6
    ).to(DEVICE)

    optimizer_top6 = optim.Adam(model_top6.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler_top6 = optim.lr_scheduler.ReduceLROnPlateau(optimizer_top6, 'min', factor=0.2, patience=5, verbose=True)

    print(model_top6)
    best_val_loss_top6 = float('inf')
    history_top6 = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'f1': [], 'eer': []}
    print("\n--- Starting Model Training (Top 6 Features) ---")

    for epoch in range(EPOCHS):
        model_top6.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader_top6, desc=f"Epoch {epoch+1}/{EPOCHS} (Top 6)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer_top6.zero_grad()
            logits, _ = model_top6(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer_top6.step()
            running_loss += loss.item()

        model_top6.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader_top6:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model_top6(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader_top6)
        avg_val_loss = val_loss / len(val_loader_top6)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        eer = calculate_eer(all_labels, all_scores)
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | EER: {eer:.2f}%")
        history_top6.update({'train_loss': history_top6['train_loss']+[avg_train_loss], 'val_loss': history_top6['val_loss']+[avg_val_loss], 'eer': history_top6['eer']+[eer]})
        scheduler_top6.step(avg_val_loss)

        if avg_val_loss < best_val_loss_top6:
            best_val_loss_top6 = avg_val_loss
            torch.save(model_top6.state_dict(), MODEL_SAVE_PATH_TOP6)
            print(f"   -> Val loss decreased. New best model (Top 6) saved to {MODEL_SAVE_PATH_TOP6}")

    plot_training_history(history_top6, PLOT_SAVE_PATH_TOP6, title_prefix="Top 6 Features (CNN-BiLSTM-Attention)")

    # --- FINAL TESTING (TOP 6 FEATURES) ---
    print("\n--- Starting Final Testing (Top 6 Features) ---")
    test_dataset_top6 = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_top6, y_test)
    test_loader_top6 = DataLoader(test_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)
    
    model_top6.load_state_dict(torch.load(MODEL_SAVE_PATH_TOP6))
    model_top6.eval()

    all_test_labels, all_test_scores = [], []
    with torch.no_grad():
        for cqcc, prosody, labels in tqdm(test_loader_top6, desc="Final Testing (Top 6)"):
            cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
            logits, _ = model_top6(cqcc, prosody)
            all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())
    
    all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
    test_eer_top6 = calculate_eer(all_test_labels, all_test_scores)
    print(f"\n--- Final Test Results (Top 6 Features) --- | EER: {test_eer_top6:.2f}%")
    print("\n--- Experiment Complete ---")

Using device: cuda
--- Loading and Preparing Data ---

--- Scaling LLD Prosodic Features ---
Converting SCALED 3D LLD prosodic features to 2D summary statistics (mean)...

--- Scaling CQCC Data ---


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


CNNBiLSTMAttentionModel(
  (cnn_branch): Sequential(
    (0): Conv1d(157, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): ReLU()
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): ReLU()
    (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (bilstm_branch): LSTM(23, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (query_transform): Linear(in_features=256, out_features=128, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.

Epoch 1/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.38it/s]



Epoch 1 | Train Loss: 0.2988 | Val Loss: 0.1552 | EER: 10.60%
   -> Val loss decreased. New best model saved to saved_models/CNNBiLSTMAttention_PyTorch_Best_23feat.pth


Epoch 2/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.40it/s]



Epoch 2 | Train Loss: 0.0847 | Val Loss: 0.1354 | EER: 10.09%
   -> Val loss decreased. New best model saved to saved_models/CNNBiLSTMAttention_PyTorch_Best_23feat.pth


Epoch 3/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.60it/s]



Epoch 3 | Train Loss: 0.0412 | Val Loss: 0.1605 | EER: 9.34%


Epoch 4/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.06it/s]



Epoch 4 | Train Loss: 0.0220 | Val Loss: 0.2041 | EER: 9.69%


Epoch 5/40 (All Feats): 100%|██████████| 720/720 [00:13<00:00, 52.74it/s]



Epoch 5 | Train Loss: 0.0153 | Val Loss: 0.2036 | EER: 9.03%


Epoch 6/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.22it/s]



Epoch 6 | Train Loss: 0.0118 | Val Loss: 0.2359 | EER: 9.26%


Epoch 7/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.37it/s]



Epoch 7 | Train Loss: 0.0092 | Val Loss: 0.1957 | EER: 8.52%


Epoch 8/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.72it/s]



Epoch 8 | Train Loss: 0.0062 | Val Loss: 0.2056 | EER: 8.24%


Epoch 9/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.44it/s]



Epoch 9 | Train Loss: 0.0021 | Val Loss: 0.3009 | EER: 8.80%


Epoch 10/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 60.80it/s]



Epoch 10 | Train Loss: 0.0005 | Val Loss: 0.3192 | EER: 8.79%


Epoch 11/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.81it/s]



Epoch 11 | Train Loss: 0.0045 | Val Loss: 0.3149 | EER: 9.06%


Epoch 12/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 58.45it/s]



Epoch 12 | Train Loss: 0.0014 | Val Loss: 0.3372 | EER: 9.34%


Epoch 13/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.28it/s]



Epoch 13 | Train Loss: 0.0004 | Val Loss: 0.3756 | EER: 9.35%


Epoch 14/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.70it/s]



Epoch 14 | Train Loss: 0.0003 | Val Loss: 0.3781 | EER: 9.22%


Epoch 15/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.84it/s]



Epoch 15 | Train Loss: 0.0002 | Val Loss: 0.3553 | EER: 8.99%


Epoch 16/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 65.02it/s]



Epoch 16 | Train Loss: 0.0001 | Val Loss: 0.3686 | EER: 9.15%


Epoch 17/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.64it/s]



Epoch 17 | Train Loss: 0.0005 | Val Loss: 0.3523 | EER: 8.93%


Epoch 18/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 62.59it/s]



Epoch 18 | Train Loss: 0.0003 | Val Loss: 0.3815 | EER: 8.96%


Epoch 19/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.10it/s]



Epoch 19 | Train Loss: 0.0001 | Val Loss: 0.3872 | EER: 9.03%


Epoch 20/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.42it/s]



Epoch 20 | Train Loss: 0.0002 | Val Loss: 0.3527 | EER: 8.99%


Epoch 21/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.38it/s]



Epoch 21 | Train Loss: 0.0001 | Val Loss: 0.3818 | EER: 9.06%


Epoch 22/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.08it/s]



Epoch 22 | Train Loss: 0.0001 | Val Loss: 0.3900 | EER: 9.02%


Epoch 23/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 65.40it/s]



Epoch 23 | Train Loss: 0.0001 | Val Loss: 0.4011 | EER: 8.94%


Epoch 24/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.87it/s]



Epoch 24 | Train Loss: 0.0002 | Val Loss: 0.3749 | EER: 8.95%


Epoch 25/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 65.84it/s]



Epoch 25 | Train Loss: 0.0002 | Val Loss: 0.3819 | EER: 9.10%


Epoch 26/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.49it/s]



Epoch 26 | Train Loss: 0.0001 | Val Loss: 0.4538 | EER: 9.17%


Epoch 27/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.34it/s]



Epoch 27 | Train Loss: 0.0001 | Val Loss: 0.3978 | EER: 9.07%


Epoch 28/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.76it/s]



Epoch 28 | Train Loss: 0.0001 | Val Loss: 0.4069 | EER: 8.96%


Epoch 29/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.31it/s]



Epoch 29 | Train Loss: 0.0001 | Val Loss: 0.4082 | EER: 9.11%


Epoch 30/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 65.49it/s]



Epoch 30 | Train Loss: 0.0001 | Val Loss: 0.4276 | EER: 9.11%


Epoch 31/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.23it/s]



Epoch 31 | Train Loss: 0.0001 | Val Loss: 0.4252 | EER: 9.14%


Epoch 32/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 65.55it/s]



Epoch 32 | Train Loss: 0.0001 | Val Loss: 0.3839 | EER: 9.00%


Epoch 33/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.92it/s]



Epoch 33 | Train Loss: 0.0002 | Val Loss: 0.4246 | EER: 9.11%


Epoch 34/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.91it/s]



Epoch 34 | Train Loss: 0.0001 | Val Loss: 0.4261 | EER: 9.18%


Epoch 35/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 69.72it/s]



Epoch 35 | Train Loss: 0.0001 | Val Loss: 0.4039 | EER: 9.10%


Epoch 36/40 (All Feats):  85%|████████▌ | 614/720 [00:09<00:01, 67.24it/s]


KeyboardInterrupt: 

In [4]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, roc_curve, accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import matplotlib.pyplot as plt
import shap
import math
import torch.nn.functional as F

# --- Configuration ---
# Paths for TRAINING data
CQCC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/cqcc_features_train.npy"
PROSODIC_FEATURES_TRAIN_PATH = "processed_data_aligned_lld/egmaps_lld_features_train.npy"
LABELS_TRAIN_PATH = "processed_data_aligned_lld/labels_train.npy"

# Paths for VALIDATION data
CQCC_FEATURES_VAL_PATH = "processed_data_aligned_lld/cqcc_features_dev.npy"
PROSODIC_FEATURES_VAL_PATH = "processed_data_aligned_lld/egmaps_lld_features_dev.npy"
LABELS_VAL_PATH = "processed_data_aligned_lld/labels_dev.npy"

# Paths for TEST data
CQCC_FEATURES_TEST_PATH = "processed_data_aligned_lld/cqcc_features_test.npy"
PROSODIC_FEATURES_TEST_PATH = "processed_data_aligned_lld/egmaps_lld_features_test.npy"
LABELS_TEST_PATH = "processed_data_aligned_lld/labels_test.npy"

# --- Model and Analysis Configuration ---
MODEL_SAVE_PATH = "saved_models/AcousticProsodicAttention_Best_23feat.pth"
PLOT_SAVE_PATH = "saved_models/training_metrics_AcousticProsodicAttention_23feat.png"
ATTENTION_PLOT_PATH = "saved_models/attention_importance_AcousticProsodicAttention_23feat.png"
ABLATION_PLOT_PATH = "saved_models/ablation_importance_AcousticProsodicAttention_23feat.png"
SHAP_PLOT_PATH = "saved_models/shap_importance_AcousticProsodicAttention_23feat.png"
MODEL_SAVE_PATH_TOP6 = "saved_models/AcousticProsodicAttention_Best_Top6feat.pth"
PLOT_SAVE_PATH_TOP6 = "saved_models/training_metrics_AcousticProsodicAttention_Top6feat.png"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 40
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 1e-5

os.makedirs("saved_models", exist_ok=True)
print(f"Using device: {DEVICE}")

# --- MODIFIED: calculate_eer now returns the threshold as well ---
def calculate_eer(y_true, y_score):
    """
    Calculates the Equal Error Rate (EER) and the threshold at which it occurs.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    
    # Calculate the EER
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    
    # Calculate the threshold at which EER occurs
    thresh = interp1d(fpr, thresholds)(eer)
    
    return eer * 100, thresh

# --- MODIFIED: plot_training_history now includes validation accuracy ---
def plot_training_history(history, save_path, title_prefix=""):
    """Plots and saves the training history graph, including validation accuracy."""
    fig, ax1 = plt.subplots(figsize=(12, 8))
    
    # Plot Loss on the primary y-axis
    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(history['train_loss'], color=color, linestyle='--', label='Train Loss')
    ax1.plot(history['val_loss'], color=color, linestyle='-', label='Val Loss')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    
    # Plot EER and Accuracy on the secondary y-axis
    ax2 = ax1.twinx()  
    color_eer = 'tab:blue'
    ax2.set_ylabel('EER (%) / Accuracy (%)', color=color_eer)
    ax2.plot(history['eer'], color=color_eer, linestyle='-', label='EER (%)')
    ax2.plot(history['val_acc'], color='tab:green', linestyle=':', label='Val Accuracy (%)')
    ax2.tick_params(axis='y', labelcolor=color_eer)
    ax2.legend(loc='upper right')
    
    fig.tight_layout()
    plt.title(f'{title_prefix} Training and Validation Metrics')
    plt.grid(True, alpha=0.3)
    plt.savefig(save_path)
    print(f"\nTraining plot saved to {save_path}")
    plt.close()


class AudioFeatureDataset(Dataset):
    """Custom PyTorch Dataset."""
    def __init__(self, cqcc_data, prosody_data, labels):
        self.cqcc_data = torch.tensor(cqcc_data, dtype=torch.float32)
        self.prosody_data = torch.tensor(prosody_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.cqcc_data[idx], self.prosody_data[idx], self.labels[idx]


# --- COMPLETELY NEW ARCHITECTURE ---
class AcousticProsodicAttentionModel(nn.Module):
    def __init__(self, cqcc_feature_dim, prosody_feature_dim, cnn_channels=128, lstm_hidden=128, attention_dim=128):
        super(AcousticProsodicAttentionModel, self).__init__()
        
        # Acoustic Branch (CNN for CQCC)
        self.acoustic_branch = nn.Sequential(
            nn.Conv1d(in_channels=cqcc_feature_dim, out_channels=64, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.MaxPool1d(2),
            nn.Conv1d(in_channels=64, out_channels=cnn_channels, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(cnn_channels)
        )
        
        # Prosodic Branch (BiLSTM for Prosody)
        self.prosodic_branch = nn.LSTM(
            input_size=prosody_feature_dim,
            hidden_size=lstm_hidden,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.4
        )
        
        # Attention Mechanism Layers
        # Project acoustic features (Key/Value)
        self.key_value_projection = nn.Linear(cnn_channels, attention_dim)
        # Project prosodic features (Query)
        self.query_projection = nn.Linear(lstm_hidden * 2, attention_dim)
        
        # Final Classifier
        # Input will be the concatenated attention context and prosodic query
        classifier_input_dim = attention_dim + attention_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )

    def forward(self, cqcc_x, prosody_x):
        # 1. Process Acoustic Features (CQCC)
        # Input: (B, Seq, Feat) -> Permute: (B, Feat, Seq)
        acoustic_out = self.acoustic_branch(cqcc_x.permute(0, 2, 1))
        # Permute back: (B, Seq, Channels)
        acoustic_out = acoustic_out.permute(0, 2, 1)

        # 2. Process Prosodic Features
        # Input: (B, Feat) -> Unsqueeze: (B, 1, Feat) for LSTM
        prosody_x_unsqueezed = prosody_x.unsqueeze(1)
        prosodic_out, _ = self.prosodic_branch(prosody_x_unsqueezed)
        # Take the last hidden state: (B, LSTM_Hidden * 2)
        prosodic_out = prosodic_out[:, -1, :]
        
        # 3. Project features for Attention
        # Keys and Values from acoustic branch
        keys = values = self.key_value_projection(acoustic_out) # (B, Seq, Attention_Dim)
        # Query from prosodic branch
        query = self.query_projection(prosodic_out) # (B, Attention_Dim)
        
        # 4. Perform Cross-Attention
        # Unsqueeze query for batch matrix multiplication: (B, 1, Attention_Dim)
        query_unsqueezed = query.unsqueeze(1)
        
        # Attention Score = (Query * Key^T) / sqrt(d_k)
        attention_scores = torch.bmm(query_unsqueezed, keys.transpose(1, 2))
        attention_scores = attention_scores / math.sqrt(keys.size(-1))
        
        # Softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1) # (B, 1, Seq)
        
        # Context = Attention Weights * Value
        context = torch.bmm(attention_weights, values).squeeze(1) # (B, Attention_Dim)
        
        # 5. Fuse and Classify
        # Concatenate the focused acoustic info (context) and the prosodic info (query)
        fused_features = torch.cat([context, query], dim=1)
        logits = self.classifier(fused_features)
        
        return logits, attention_weights

# ==============================================================================
# ANALYSIS FUNCTIONS
# ==============================================================================

def analyze_attention_weights(model, dataloader, device, save_path):
    print("\n--- Running Cross-Attention Weight Analysis ---")
    model.eval()
    all_weights = []
    with torch.no_grad():
        for cqcc, prosody, _ in tqdm(dataloader, desc="Analyzing Attention"):
            cqcc, prosody = cqcc.to(device), prosody.to(device)
            _, weights = model(cqcc, prosody)
            all_weights.append(weights.squeeze(1).cpu().numpy())
    
    avg_weights = np.mean(np.concatenate(all_weights, axis=0), axis=0)
    
    plt.figure(figsize=(15, 6))
    plt.plot(avg_weights, color='purple')
    plt.xlabel('CQCC Time Frame (after CNN processing)')
    plt.ylabel('Average Attention Weight')
    plt.title('Cross-Attention: Importance of Acoustic Time Frames Guided by Prosody')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAttention plot saved to {save_path}")
    plt.close()

def perform_feature_ablation(model, dataloader, feature_names, device, save_path):
    print("\n--- Running Feature Ablation Analysis ---")
    def evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=None):
        model.eval()
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in dataloader:
                cqcc, prosody, labels = cqcc.to(device), prosody.to(device), labels.to(device)
                if feature_to_ablate is not None:
                    prosody_clone = prosody.clone()
                    prosody_clone[:, feature_to_ablate] = 0.0
                    logits, _ = model(cqcc, prosody_clone)
                else:
                    logits, _ = model(cqcc, prosody)
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        eer, _ = calculate_eer(np.array(all_labels), np.array(all_scores).flatten())
        return eer

    baseline_eer = evaluate_eer_for_ablation(model, dataloader, device)
    print(f"Baseline EER with all features: {baseline_eer:.2f}%")
    eer_increases = {}
    for i, name in enumerate(tqdm(feature_names, desc="Performing Ablation")):
        ablated_eer = evaluate_eer_for_ablation(model, dataloader, device, feature_to_ablate=i)
        eer_increases[name] = ablated_eer - baseline_eer
    
    sorted_features = sorted(eer_increases.items(), key=lambda item: item[1], reverse=True)
    
    print("\nFeature Importance based on EER Increase:")
    for feature, increase in sorted_features:
        print(f"- {feature}: EER increases by {increase:.2f}%")
    names = [item[0] for item in sorted_features]
    increases = [item[1] for item in sorted_features]
    plt.figure(figsize=(12, 10))
    plt.barh(names, increases, color='salmon')
    plt.xlabel('EER Increase (%)')
    plt.title('Prosodic Feature Importance based on Feature Ablation')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nAblation plot saved to {save_path}")
    plt.close()
    return sorted_features


def analyze_with_shap(model, dataloader, feature_names, device, save_path):
    print("\n--- Running SHAP Analysis ---")
    model.eval()
    background_cqcc, background_prosody, _ = next(iter(dataloader))
    test_cqcc, test_prosody, _ = next(iter(dataloader))
    
    def model_wrapper(prosodic_features_numpy):
        num_samples = prosodic_features_numpy.shape[0]
        prosody_tensor = torch.from_numpy(prosodic_features_numpy).float().to(device)
        cqcc_background_sample = background_cqcc[0:1].to(device)
        cqcc_tensor = cqcc_background_sample.repeat(num_samples, 1, 1)
        with torch.no_grad():
            logits, _ = model(cqcc_tensor, prosody_tensor)
            output = torch.sigmoid(logits)
        return output.cpu().numpy()

    explainer = shap.KernelExplainer(model_wrapper, background_prosody.numpy())
    print("Calculating SHAP values (this may take a while)...")
    shap_values = explainer.shap_values(test_prosody.numpy(), nsamples=100)
    print("Plotting SHAP summary...")
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    plt.figure() 
    shap.summary_plot(shap_values, test_prosody.numpy(), feature_names=feature_names, show=False)
    plt.title('SHAP Summary for Prosodic Features')
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"SHAP plot saved to {save_path}")
    plt.close()

# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    try:
        print("--- Loading and Preparing Data ---")
        X_cqcc_train_full = np.load(CQCC_FEATURES_TRAIN_PATH)
        X_prosody_train_3d = np.load(PROSODIC_FEATURES_TRAIN_PATH)
        y_train = np.load(LABELS_TRAIN_PATH)
        
        X_cqcc_val_full = np.load(CQCC_FEATURES_VAL_PATH)
        X_prosody_val_3d = np.load(PROSODIC_FEATURES_VAL_PATH)
        y_val = np.load(LABELS_VAL_PATH)

        X_cqcc_test_full = np.load(CQCC_FEATURES_TEST_PATH)
        X_prosody_test_3d = np.load(PROSODIC_FEATURES_TEST_PATH)
        y_test = np.load(LABELS_TEST_PATH)
        
        feature_columns = [
            'Loudness_sma3','alphaRatio_sma3','hammarbergIndex_sma3','slope0-500_sma3',
            'slope500-1500_sma3','spectralFlux_sma3','mfcc1_sma3','mfcc2_sma3',
            'mfcc3_sma3','mfcc4_sma3','F0semitoneFrom27.5Hz_sma3nz','jitterLocal_sma3nz',
            'shimmerLocaldB_sma3nz','HNRdBACF_sma3nz','logRelF0-H1-H2_sma3nz',
            'logRelF0-H1-A3_sma3nz','F1frequency_sma3nz','F1bandwidth_sma3nz',
            'F1amplitudeLogRelF0_sma3nz','F2frequency_sma3nz','F2amplitudeLogRelF0_sma3nz',
            'F3frequency_sma3nz','F3amplitudeLogRelF0_sma3nz'
        ]
        num_prosodic_features = X_prosody_train_3d.shape[1]
        if len(feature_columns) != num_prosodic_features:
            feature_columns = [f'ProsodicFeat_{i+1}' for i in range(num_prosodic_features)]

    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading data: {e}")
        exit()

    print("\n--- Scaling LLD Prosodic Features ---")
    ns_p, nf_p, nt_p = X_prosody_train_3d.shape
    X_prosody_train_reshaped = X_prosody_train_3d.transpose(0, 2, 1).reshape(-1, nf_p)
    scaler_prosody_lld = StandardScaler().fit(X_prosody_train_reshaped)
    X_prosody_train_3d_scaled = scaler_prosody_lld.transform(X_prosody_train_reshaped).reshape(ns_p, nt_p, nf_p).transpose(0, 2, 1)
    nsv_p, nfv_p, ntv_p = X_prosody_val_3d.shape
    X_prosody_val_reshaped = X_prosody_val_3d.transpose(0, 2, 1).reshape(-1, nfv_p)
    X_prosody_val_3d_scaled = scaler_prosody_lld.transform(X_prosody_val_reshaped).reshape(nsv_p, ntv_p, nfv_p).transpose(0, 2, 1)
    nst_p, nft_p, ntt_p = X_prosody_test_3d.shape
    X_prosody_test_reshaped = X_prosody_test_3d.transpose(0, 2, 1).reshape(-1, nft_p)
    X_prosody_test_3d_scaled = scaler_prosody_lld.transform(X_prosody_test_reshaped).reshape(nst_p, ntt_p, nft_p).transpose(0, 2, 1)

    print("Converting SCALED 3D LLD prosodic features to 2D summary statistics (mean)...")
    X_prosody_train_full = np.mean(X_prosody_train_3d_scaled, axis=2)
    X_prosody_val_full = np.mean(X_prosody_val_3d_scaled, axis=2)
    X_prosody_test_full = np.mean(X_prosody_test_3d_scaled, axis=2)

    print("\n--- Scaling CQCC Data ---")
    scaler_cqcc = StandardScaler()
    ns, nx, ny = X_cqcc_train_full.shape
    X_cqcc_train_scaled = scaler_cqcc.fit_transform(X_cqcc_train_full.reshape(ns, -1)).reshape(ns, nx, ny)
    nsv, nxv, nyv = X_cqcc_val_full.shape
    X_cqcc_val_scaled = scaler_cqcc.transform(X_cqcc_val_full.reshape(nsv, -1)).reshape(nsv, nxv, nyv)
    ns_test, nx_test, ny_test = X_cqcc_test_full.shape
    X_cqcc_test_scaled = scaler_cqcc.transform(X_cqcc_test_full.reshape(ns_test, -1)).reshape(ns_test, nx_test, ny_test)

    train_dataset = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_full, y_train)
    val_dataset = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_full, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = AcousticProsodicAttentionModel(
        cqcc_feature_dim=X_cqcc_train_full.shape[2],
        prosody_feature_dim=X_prosody_train_full.shape[1]
    ).to(DEVICE)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
    
    print(model)
    # --- MODIFIED: Checkpointing based on best EER ---
    best_eer = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'eer': []}
    print("\n--- Starting Model Training (All Features) ---")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for cqcc, prosody, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (All Feats)"):
            cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            logits, _ = model(cqcc, prosody)
            loss = criterion(logits, labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        all_labels, all_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in val_loader:
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                logits, _ = model(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                val_loss += loss.item()
                all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
        
        # --- MODIFIED: Calculate EER, threshold, and accuracy ---
        eer, eer_thresh = calculate_eer(all_labels, all_scores)
        val_preds = (all_scores > eer_thresh).astype(int)
        val_accuracy = accuracy_score(all_labels, val_preds) * 100
        
        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | EER: {eer:.2f}%")
        history.update({'train_loss': history['train_loss']+[avg_train_loss], 'val_loss': history['val_loss']+[avg_val_loss], 'val_acc': history['val_acc']+[val_accuracy], 'eer': history['eer']+[eer]})
        scheduler.step(eer) # Schedule based on EER
        
        # --- MODIFIED: Save model if EER has improved ---
        if eer < best_eer:
            best_eer = eer
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   -> Val EER decreased to {eer:.2f}%. New best model saved to {MODEL_SAVE_PATH}")

    plot_training_history(history, PLOT_SAVE_PATH, title_prefix="All Features (New Architecture)")

    # --- FINAL TESTING AND ANALYSIS (ALL FEATURES) ---
    print("\n--- Starting Final Testing and Analysis (All Features) ---")
    try:
        test_dataset = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_full, y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        analysis_model = AcousticProsodicAttentionModel(
            cqcc_feature_dim=X_cqcc_train_full.shape[2],
            prosody_feature_dim=X_prosody_train_full.shape[1]
        ).to(DEVICE)
        analysis_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        analysis_model.eval()
        
        analyze_attention_weights(analysis_model, test_loader, DEVICE, ATTENTION_PLOT_PATH)
        sorted_features = perform_feature_ablation(analysis_model, test_loader, feature_columns, DEVICE, ABLATION_PLOT_PATH)
        analyze_with_shap(analysis_model, test_loader, feature_columns, DEVICE, SHAP_PLOT_PATH)

    except Exception as e:
        print(f"Error during testing/analysis: {e}")
        sorted_features = [] # Ensure sorted_features exists

    # ==============================================================================
    # --- RETRAINING WITH TOP 6 FEATURES ---
    # ==============================================================================
    if not sorted_features:
        print("\n--- Skipping Retraining with Top 6 Features due to previous error ---")
    else:
        print("\n\n--- Starting Retraining with Top 6 Features ---")
        
        top_6_feature_names = [item[0] for item in sorted_features[:6]]
        top_6_indices = [feature_columns.index(name) for name in top_6_feature_names]
        print("Top 6 features selected for retraining:", top_6_feature_names)

        X_prosody_train_top6 = X_prosody_train_full[:, top_6_indices]
        X_prosody_val_top6 = X_prosody_val_full[:, top_6_indices]
        X_prosody_test_top6 = X_prosody_test_full[:, top_6_indices]
        
        train_dataset_top6 = AudioFeatureDataset(X_cqcc_train_scaled, X_prosody_train_top6, y_train)
        val_dataset_top6 = AudioFeatureDataset(X_cqcc_val_scaled, X_prosody_val_top6, y_val)
        train_loader_top6 = DataLoader(train_dataset_top6, batch_size=BATCH_SIZE, shuffle=True)
        val_loader_top6 = DataLoader(val_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)

        model_top6 = AcousticProsodicAttentionModel(
            cqcc_feature_dim=X_cqcc_train_full.shape[2],
            prosody_feature_dim=6
        ).to(DEVICE)

        optimizer_top6 = optim.Adam(model_top6.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler_top6 = optim.lr_scheduler.ReduceLROnPlateau(optimizer_top6, 'min', patience=5, verbose=True)

        print(model_top6)
        best_eer_top6 = float('inf')
        history_top6 = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'eer': []}
        print("\n--- Starting Model Training (Top 6 Features) ---")

        for epoch in range(EPOCHS):
            model_top6.train()
            running_loss = 0.0
            for cqcc, prosody, labels in tqdm(train_loader_top6, desc=f"Epoch {epoch+1}/{EPOCHS} (Top 6)"):
                cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                optimizer_top6.zero_grad()
                logits, _ = model_top6(cqcc, prosody)
                loss = criterion(logits, labels.unsqueeze(1))
                loss.backward()
                optimizer_top6.step()
                running_loss += loss.item()

            model_top6.eval()
            val_loss = 0.0
            all_labels, all_scores = [], []
            with torch.no_grad():
                for cqcc, prosody, labels in val_loader_top6:
                    cqcc, prosody, labels = cqcc.to(DEVICE), prosody.to(DEVICE), labels.to(DEVICE)
                    logits, _ = model_top6(cqcc, prosody)
                    loss = criterion(logits, labels.unsqueeze(1))
                    val_loss += loss.item()
                    all_scores.extend(torch.sigmoid(logits).cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            avg_train_loss = running_loss / len(train_loader_top6)
            avg_val_loss = val_loss / len(val_loader_top6)
            all_labels, all_scores = np.array(all_labels), np.array(all_scores).flatten()
            eer, eer_thresh = calculate_eer(all_labels, all_scores)
            val_preds = (all_scores > eer_thresh).astype(int)
            val_accuracy = accuracy_score(all_labels, val_preds) * 100
            
            print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | EER: {eer:.2f}%")
            history_top6.update({'train_loss': history_top6['train_loss']+[avg_train_loss], 'val_loss': history_top6['val_loss']+[avg_val_loss], 'val_acc': history_top6['val_acc']+[val_accuracy], 'eer': history_top6['eer']+[eer]})
            scheduler_top6.step(eer)

            if eer < best_eer_top6:
                best_eer_top6 = eer
                torch.save(model_top6.state_dict(), MODEL_SAVE_PATH_TOP6)
                print(f"   -> Val EER decreased to {eer:.2f}%. New best model (Top 6) saved to {MODEL_SAVE_PATH_TOP6}")

        plot_training_history(history_top6, PLOT_SAVE_PATH_TOP6, title_prefix="Top 6 Features (New Architecture)")

        # --- FINAL TESTING (TOP 6 FEATURES) ---
        print("\n--- Starting Final Testing (Top 6 Features) ---")
        test_dataset_top6 = AudioFeatureDataset(X_cqcc_test_scaled, X_prosody_test_top6, y_test)
        test_loader_top6 = DataLoader(test_dataset_top6, batch_size=BATCH_SIZE, shuffle=False)
        
        model_top6.load_state_dict(torch.load(MODEL_SAVE_PATH_TOP6))
        model_top6.eval()

        all_test_labels, all_test_scores = [], []
        with torch.no_grad():
            for cqcc, prosody, labels in tqdm(test_loader_top6, desc="Final Testing (Top 6)"):
                cqcc, prosody = cqcc.to(DEVICE), prosody.to(DEVICE)
                logits, _ = model_top6(cqcc, prosody)
                all_test_scores.extend(torch.sigmoid(logits).cpu().numpy())
                all_test_labels.extend(labels.cpu().numpy())
        
        all_test_labels, all_test_scores = np.array(all_test_labels), np.array(all_test_scores).flatten()
        test_eer_top6, _ = calculate_eer(all_test_labels, all_test_scores)
        print(f"\n--- Final Test Results (Top 6 Features) --- | EER: {test_eer_top6:.2f}%")
        print("\n--- Experiment Complete ---")


Using device: cuda
--- Loading and Preparing Data ---

--- Scaling LLD Prosodic Features ---
Converting SCALED 3D LLD prosodic features to 2D summary statistics (mean)...

--- Scaling CQCC Data ---


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


AcousticProsodicAttentionModel(
  (acoustic_branch): Sequential(
    (0): Conv1d(157, 64, kernel_size=(7,), stride=(1,), padding=(3,))
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv1d(64, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (5): ReLU()
    (6): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (prosodic_branch): LSTM(23, 128, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (key_value_projection): Linear(in_features=128, out_features=128, bias=True)
  (query_projection): Linear(in_features=256, out_features=128, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=

Epoch 1/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.83it/s]



Epoch 1 | Train Loss: 0.2606 | Val Loss: 0.1447 | Val Acc: 89.74% | EER: 10.27%
   -> Val EER decreased to 10.27%. New best model saved to saved_models/AcousticProsodicAttention_Best_23feat.pth


Epoch 2/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.59it/s]



Epoch 2 | Train Loss: 0.0344 | Val Loss: 0.1452 | Val Acc: 91.53% | EER: 8.48%
   -> Val EER decreased to 8.48%. New best model saved to saved_models/AcousticProsodicAttention_Best_23feat.pth


Epoch 3/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.39it/s]



Epoch 3 | Train Loss: 0.0129 | Val Loss: 0.1936 | Val Acc: 91.36% | EER: 8.64%


Epoch 4/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 65.33it/s]



Epoch 4 | Train Loss: 0.0061 | Val Loss: 0.2020 | Val Acc: 91.34% | EER: 8.66%


Epoch 5/40 (All Feats): 100%|██████████| 720/720 [00:13<00:00, 52.83it/s]



Epoch 5 | Train Loss: 0.0050 | Val Loss: 0.2703 | Val Acc: 89.98% | EER: 10.02%


Epoch 6/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.40it/s]



Epoch 6 | Train Loss: 0.0046 | Val Loss: 0.2676 | Val Acc: 91.24% | EER: 8.75%


Epoch 7/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 60.00it/s]



Epoch 7 | Train Loss: 0.0043 | Val Loss: 0.3063 | Val Acc: 91.33% | EER: 8.67%


Epoch 8/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 61.28it/s]



Epoch 8 | Train Loss: 0.0045 | Val Loss: 0.3253 | Val Acc: 91.15% | EER: 8.85%


Epoch 9/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.13it/s]



Epoch 9 | Train Loss: 0.0011 | Val Loss: 0.2395 | Val Acc: 91.66% | EER: 8.34%
   -> Val EER decreased to 8.34%. New best model saved to saved_models/AcousticProsodicAttention_Best_23feat.pth


Epoch 10/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 63.27it/s]



Epoch 10 | Train Loss: 0.0006 | Val Loss: 0.2623 | Val Acc: 91.79% | EER: 8.20%
   -> Val EER decreased to 8.20%. New best model saved to saved_models/AcousticProsodicAttention_Best_23feat.pth


Epoch 11/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.61it/s]



Epoch 11 | Train Loss: 0.0005 | Val Loss: 0.3019 | Val Acc: 91.49% | EER: 8.52%


Epoch 12/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.76it/s]



Epoch 12 | Train Loss: 0.0003 | Val Loss: 0.2861 | Val Acc: 91.48% | EER: 8.52%


Epoch 13/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 61.87it/s]



Epoch 13 | Train Loss: 0.0002 | Val Loss: 0.3354 | Val Acc: 91.44% | EER: 8.56%


Epoch 14/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 69.04it/s]



Epoch 14 | Train Loss: 0.0001 | Val Loss: 0.3353 | Val Acc: 91.59% | EER: 8.41%


Epoch 15/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.49it/s]



Epoch 15 | Train Loss: 0.0001 | Val Loss: 0.3153 | Val Acc: 91.76% | EER: 8.24%


Epoch 16/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.38it/s]



Epoch 16 | Train Loss: 0.0001 | Val Loss: 0.4011 | Val Acc: 91.48% | EER: 8.52%


Epoch 17/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.80it/s]



Epoch 17 | Train Loss: 0.0001 | Val Loss: 0.3622 | Val Acc: 91.45% | EER: 8.56%


Epoch 18/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.82it/s]



Epoch 18 | Train Loss: 0.0001 | Val Loss: 0.4123 | Val Acc: 91.45% | EER: 8.56%


Epoch 19/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.54it/s]



Epoch 19 | Train Loss: 0.0001 | Val Loss: 0.3507 | Val Acc: 91.40% | EER: 8.59%


Epoch 20/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.33it/s]



Epoch 20 | Train Loss: 0.0001 | Val Loss: 0.4132 | Val Acc: 91.49% | EER: 8.52%


Epoch 21/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.89it/s]



Epoch 21 | Train Loss: 0.0001 | Val Loss: 0.3374 | Val Acc: 91.57% | EER: 8.44%


Epoch 22/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 65.98it/s]



Epoch 22 | Train Loss: 0.0000 | Val Loss: 0.3814 | Val Acc: 91.54% | EER: 8.46%


Epoch 23/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.56it/s]



Epoch 23 | Train Loss: 0.0001 | Val Loss: 0.4143 | Val Acc: 91.47% | EER: 8.53%


Epoch 24/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.91it/s]



Epoch 24 | Train Loss: 0.0000 | Val Loss: 0.3685 | Val Acc: 91.58% | EER: 8.42%


Epoch 25/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.73it/s]



Epoch 25 | Train Loss: 0.0000 | Val Loss: 0.3671 | Val Acc: 91.49% | EER: 8.51%


Epoch 26/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.92it/s]



Epoch 26 | Train Loss: 0.0000 | Val Loss: 0.3893 | Val Acc: 91.55% | EER: 8.44%


Epoch 27/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 65.86it/s]



Epoch 27 | Train Loss: 0.0000 | Val Loss: 0.4206 | Val Acc: 91.54% | EER: 8.45%


Epoch 28/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.70it/s]



Epoch 28 | Train Loss: 0.0025 | Val Loss: 0.4235 | Val Acc: 91.44% | EER: 8.56%


Epoch 29/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.37it/s]



Epoch 29 | Train Loss: 0.0000 | Val Loss: 0.4109 | Val Acc: 91.37% | EER: 8.63%


Epoch 30/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.81it/s]



Epoch 30 | Train Loss: 0.0000 | Val Loss: 0.4245 | Val Acc: 91.57% | EER: 8.44%


Epoch 31/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 67.54it/s]



Epoch 31 | Train Loss: 0.0000 | Val Loss: 0.3286 | Val Acc: 91.56% | EER: 8.44%


Epoch 32/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 68.06it/s]



Epoch 32 | Train Loss: 0.0000 | Val Loss: 0.4343 | Val Acc: 91.30% | EER: 8.70%


Epoch 33/40 (All Feats): 100%|██████████| 720/720 [00:10<00:00, 66.43it/s]



Epoch 33 | Train Loss: 0.0000 | Val Loss: 0.3679 | Val Acc: 91.49% | EER: 8.51%


Epoch 34/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.75it/s]



Epoch 34 | Train Loss: 0.0000 | Val Loss: 0.3467 | Val Acc: 91.58% | EER: 8.42%


Epoch 35/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.75it/s]



Epoch 35 | Train Loss: 0.0003 | Val Loss: 0.4036 | Val Acc: 91.44% | EER: 8.56%


Epoch 36/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.68it/s]



Epoch 36 | Train Loss: 0.0000 | Val Loss: 0.3608 | Val Acc: 91.52% | EER: 8.49%


Epoch 37/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.79it/s]



Epoch 37 | Train Loss: 0.0000 | Val Loss: 0.3850 | Val Acc: 91.33% | EER: 8.67%


Epoch 38/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.45it/s]



Epoch 38 | Train Loss: 0.0000 | Val Loss: 0.3768 | Val Acc: 91.54% | EER: 8.47%


Epoch 39/40 (All Feats): 100%|██████████| 720/720 [00:11<00:00, 60.04it/s]



Epoch 39 | Train Loss: 0.0000 | Val Loss: 0.3689 | Val Acc: 91.56% | EER: 8.44%


Epoch 40/40 (All Feats): 100%|██████████| 720/720 [00:12<00:00, 59.59it/s]



Epoch 40 | Train Loss: 0.0000 | Val Loss: 0.3745 | Val Acc: 91.48% | EER: 8.52%

Training plot saved to saved_models/training_metrics_AcousticProsodicAttention_23feat.png

--- Starting Final Testing and Analysis (All Features) ---


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



--- Running Cross-Attention Weight Analysis ---


Analyzing Attention: 100%|██████████| 1114/1114 [00:06<00:00, 168.14it/s]



Attention plot saved to saved_models/attention_importance_AcousticProsodicAttention_23feat.png

--- Running Feature Ablation Analysis ---
Baseline EER with all features: 12.11%


Performing Ablation: 100%|██████████| 23/23 [04:07<00:00, 10.76s/it]



Feature Importance based on EER Increase:
- slope0-500_sma3: EER increases by 0.72%
- F1amplitudeLogRelF0_sma3nz: EER increases by 0.21%
- spectralFlux_sma3: EER increases by 0.21%
- mfcc3_sma3: EER increases by 0.18%
- F3amplitudeLogRelF0_sma3nz: EER increases by 0.17%
- F2amplitudeLogRelF0_sma3nz: EER increases by 0.15%
- mfcc1_sma3: EER increases by 0.11%
- shimmerLocaldB_sma3nz: EER increases by 0.10%
- logRelF0-H1-A3_sma3nz: EER increases by 0.08%
- slope500-1500_sma3: EER increases by 0.08%
- F0semitoneFrom27.5Hz_sma3nz: EER increases by 0.05%
- HNRdBACF_sma3nz: EER increases by 0.02%
- mfcc4_sma3: EER increases by 0.02%
- F3frequency_sma3nz: EER increases by 0.00%
- F2frequency_sma3nz: EER increases by -0.01%
- jitterLocal_sma3nz: EER increases by -0.01%
- logRelF0-H1-H2_sma3nz: EER increases by -0.01%
- F1frequency_sma3nz: EER increases by -0.01%
- hammarbergIndex_sma3: EER increases by -0.05%
- Loudness_sma3: EER increases by -0.07%
- mfcc2_sma3: EER increases by -0.11%
- alp

  0%|          | 0/64 [00:00<?, ?it/s]

Error during testing/analysis: CUDA out of memory. Tried to allocate 492.00 MiB. GPU 0 has a total capacity of 10.57 GiB of which 378.06 MiB is free. Process 3599198 has 9.90 GiB memory in use. Including non-PyTorch memory, this process has 306.00 MiB memory in use. Of the allocated memory 58.53 MiB is allocated by PyTorch, and 47.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

--- Skipping Retraining with Top 6 Features due to previous error ---
