In [133]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
import wandb
from tqdm import tqdm
import os
import eval_metrics as em
from sklearn.metrics import classification_report
from models_v2 import LSTM_branch, FFN_branch, CNN_branch, SpoofEnsemble, LSTM_FFN_classifer, LSTM_classifier, FFN_classifier, CNN_classifer

### Configurations

#### Get configs from the training run

In [134]:
# Training runs that can be evaluated, labelled as in the original notes
# (model or feature name -> W&B run id).
TRAIN_RUN_IDS = {
    "LSTM": "4ljqj7sn",
    "FFN": "hdtay323",
    "pitch": "sv31dkzd",
    "hnr": "wuahdjeu",
    "jitter": "24gbjh7s",
    "shimmer": "3y1j9rcw",
}
#NOTE: remember to change -- this key is the only thing to edit when switching runs
EVALUATED_RUN = "shimmer"

# W&B "entity/project" that owns the training runs.
WANDB_PROJECT_PATH = "qianyue-university-of-stuttgart/teamlab_deepfake"

api = wandb.Api()
run_id = TRAIN_RUN_IDS[EVALUATED_RUN]
run_path = f"{WANDB_PROJECT_PATH}/runs/{run_id}"
train_run = api.run(run_path)
# Hyperparameters logged by the training run; reused below to rebuild the model.
train_config = train_run.config

print(train_config)
print(type(train_config))

{'model': 'FFN_classifier', 'seeds': [0], 'epochs': 50, 'dataset': 'ASVSpoof19_LA', 'feature': 'shimmer', 'ffn_dims': [6, 32], 'scheduler': True, 'batch_size': 32, 'attack_type': 'all', 'cnn_padding': 1, 'conv_kernel': [3, 3], 'pool_kernel': [2, 2], 'cnn_channels': [1, 32, 64, 128], 'dropout_rate': 0.3, 'oversampling': True, 'bidirectional': True, 'learning_rate': 0.0005, 'loss_function': 'weighted_CE', 'lstm_n_layers': 1, 'lstm_input_dim': 2, 'lstm_hidden_dim': 64, 'scheduler_factor': 0.5, 'scheduler_patience': 4}
<class 'dict'>


#### Config settings for the current run

In [135]:
# Hyperparameters inherited from the training run, listed in logging order as
# (key, fallback used when the training config does not contain the key).
INHERITED_HYPERPARAMETERS = (
    # general
    ("model", None),
    ("dataset", None),
    ("feature", None),
    ("attack_type", None),
    ("loss_function", None),
    ("scheduler", False),
    ("scheduler_factor", 0.5),
    ("scheduler_patience", 4),
    ("epochs", None),
    ("batch_size", None),
    ("oversampling", None),
    ("learning_rate", None),
    ("dropout_rate", None),
    # lstm layer
    ("lstm_input_dim", None),
    ("lstm_hidden_dim", None),
    ("bidirectional", None),
    ("lstm_n_layers", None),
    # ffn layer
    ("ffn_dims", None),
    # cnn_layer
    ("cnn_channels", None),
    ("conv_kernel", None),
    ("pool_kernel", None),
    ("cnn_padding", None),
)

# for testing
evaluation_config = {"test_data": "test"}  #NOTE: dev/test
# Copy every inherited setting after the evaluation-only entry.
evaluation_config.update(
    (key, train_config.get(key, fallback)) for key, fallback in INHERITED_HYPERPARAMETERS
)

run = wandb.init(
    project="teamlab_deepfake",
    job_type="evaluation",
    name="EvaluationShimmer_12",     #NOTE: EvaluationLSTM_12/...
    notes=None,
    config=evaluation_config,
)

# All later cells read their settings from the run's config object.
config = run.config

print(config)

{'test_data': 'test', 'model': 'FFN_classifier', 'dataset': 'ASVSpoof19_LA', 'feature': 'shimmer', 'attack_type': 'all', 'loss_function': 'weighted_CE', 'scheduler': True, 'scheduler_factor': 0.5, 'scheduler_patience': 4, 'epochs': 50, 'batch_size': 32, 'oversampling': True, 'learning_rate': 0.0005, 'dropout_rate': 0.3, 'lstm_input_dim': 2, 'lstm_hidden_dim': 64, 'bidirectional': True, 'lstm_n_layers': 1, 'ffn_dims': [6, 32], 'cnn_channels': [1, 32, 64, 128], 'conv_kernel': [3, 3], 'pool_kernel': [2, 2], 'cnn_padding': 1}


### Load data

In [136]:
# Column names of the pickled feature DataFrame.
PITCH_COLUMN = 'PITCH'
HNR_COLUMN = 'HNR'
JITTER_COLUMN = 'JITTER'
SHIMMER_COLUMN = 'SHIMMER'
MFCC_COLUMN = 'MFCC'
LABEL_COLUMN = 'LABEL'
ATTACK_TYPE_COLUMN = 'ATTACK_TYPE'
AUDIO_ID_COLUMN = 'AUDIO_ID'

# Value substituted for NaNs inside the per-utterance feature arrays.
NAN_REPLACEMENT_VALUE = 0.0
# Value used to pad variable-length sequences inside a batch.
PADDING_VALUE = 0.0
# Integer class ids.
LABEL_BONAFIDE = 1
LABEL_SPOOF = 0

# Feature tables selectable through the `test_data` entry of the run config.
TEST_FEATURE_PATHS = {
    "test": '/home/users1/liqe/TeamLab_phonetics/merged_eval_com.pkl',   #NOTE: tbc
    "dev": '/home/users1/liqe/TeamLab_phonetics/merged_dev_com.pkl',
}

# Fail with a clear message instead of a NameError on an undefined path.
if config.test_data not in TEST_FEATURE_PATHS:
    raise ValueError(
        f"Invalid test data {config.test_data!r}; expected one of {sorted(TEST_FEATURE_PATHS)}."
    )
test_features_path = TEST_FEATURE_PATHS[config.test_data]
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files here.
df_test = pd.read_pickle(test_features_path)

#### Data processing

In [137]:
# Process dataset
# If stored in a separate .py -> need to inherit from the training code + attack type&id
class ASVDataset(Dataset):
    """
    In-memory dataset that turns every labelled row of a feature DataFrame into tensors.

    The module-level run config (`config.feature`) selects what is stored per sample:
      * "pitch" / "hnr": the frame-level sequence holds only that contour;
        any other value stacks pitch and HNR along a new last axis
        (time, 2) -- assuming both contours are 1-D.
      * "jitter" / "shimmer": the global feature vector holds only that measure;
        any other value concatenates jitter followed by shimmer.
    MFCCs are always stored, transposed relative to their layout in the DataFrame.

    Rows whose label is NaN are skipped.

    Args:
        dataframe: DataFrame with one row per utterance.
        pitch_col, hnr_col: names of the columns holding the pitch / HNR contours.
        jitter_col, shimmer_col: names of the columns holding the jitter / shimmer values.
        mfcc_col: name of the column holding the MFCC matrix.
        label_col: name of the column holding the numeric class label.
        attack_type_col, audio_id_col: names of the metadata columns.
        nan_replacement: value written in place of NaNs in the feature arrays.
    """

    def __init__(self, dataframe, pitch_col, hnr_col, jitter_col, shimmer_col, mfcc_col, label_col,
                 attack_type_col, audio_id_col, nan_replacement=NAN_REPLACEMENT_VALUE):

        self.labels = []
        self.attack_type = []
        self.audio_id = []
        self.processed_pitchhnr = []
        self.global_features = []
        self.processed_mfcc = []

        # The feature selection does not change while the rows are processed,
        # so read it from the run config once instead of once per row.
        feature = config.feature

        print(f"Attempting to process {len(dataframe)} entries from DataFrame")
        found_count = 0
        # Iterate through the DataFrame, process and pad the features
        for _, row in dataframe.iterrows():
            if np.isnan(row[label_col]):
                continue

            self.labels.append(row[label_col])
            self.attack_type.append(row[attack_type_col])
            self.audio_id.append(row[audio_id_col])

            processed_pitch = np.nan_to_num(row[pitch_col], nan=nan_replacement)
            processed_hnr = np.nan_to_num(row[hnr_col], nan=nan_replacement)

            ### NOTE: need to pad the two sequences to the same length
            max_length = max(len(processed_pitch), len(processed_hnr))
            processed_pitch = self._zero_pad(processed_pitch, max_length)
            processed_hnr = self._zero_pad(processed_hnr, max_length)

            if feature == "pitch":
                sequence = processed_pitch
            elif feature == "hnr":
                sequence = processed_hnr
            else:
                # Both contours side by side: one (pitch, hnr) pair per frame.
                sequence = np.stack((processed_pitch, processed_hnr), axis=-1)
            self.processed_pitchhnr.append(torch.tensor(sequence, dtype=torch.float32))

            # process jitter and shimmer; they are only combined when neither is singled out
            processed_jitter = np.nan_to_num(row[jitter_col], nan=nan_replacement)
            processed_shimmer = np.nan_to_num(row[shimmer_col], nan=nan_replacement)
            if feature == "jitter":
                global_feature = processed_jitter
            elif feature == "shimmer":
                global_feature = processed_shimmer
            else:
                global_feature = np.concatenate((processed_jitter, processed_shimmer))
            self.global_features.append(torch.tensor(global_feature, dtype=torch.float32))

            # process mfcc
            # NOTE: need transpose for padding (time, feature_dim)
            self.processed_mfcc.append(torch.tensor(row[mfcc_col], dtype=torch.float32).T)

            found_count += 1

        self.labels = torch.tensor(self.labels, dtype=torch.long)
        print(f"Successfully processed {found_count} samples out of {len(dataframe)} DataFrame entries.")

    @staticmethod
    def _zero_pad(sequence, length):
        """Returns `sequence` extended with trailing zeros (same dtype) up to `length` entries."""
        padding = np.zeros(length - len(sequence), dtype=sequence.dtype)
        return np.concatenate((sequence, padding))

    def __len__(self):
        """Returns the total number of matched samples in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Returns the sample at position `idx` as the tuple
        (label, pitch_hnr, global_feature, mfcc, audio_id, attack_type).
        """
        label = self.labels[idx]
        audio_id = self.audio_id[idx]
        attack_type = self.attack_type[idx]
        pitch_hnr = self.processed_pitchhnr[idx]
        global_feature = self.global_features[idx]
        mfcc = self.processed_mfcc[idx]
        return label, pitch_hnr, global_feature, mfcc, audio_id, attack_type

In [138]:
def collate_fn(batch, padding_value=PADDING_VALUE):
    """
    Merges a list of dataset samples into one batch.

    Each sample is the 6-tuple (label, pitch_hnr, global_feature, mfcc,
    audio_id, attack_type). Variable-length sequences are padded with
    `padding_value` to the longest sequence of the batch.

    Returns:
        (labels, audio_ids, attack_types, pitchhnr_lengths, padded_pitchhnrs,
        global_features, padded_mfccs); ids and attack types stay plain lists.
    """
    # Transpose the list of samples into one list per sample field.
    label_list, sequence_list, global_list, mfcc_list, id_list, type_list = (
        [sample[field] for sample in batch] for field in range(6)
    )

    stacked_labels = torch.stack(label_list)

    # Unpadded lengths are returned alongside the padded sequences.
    sequence_lengths = torch.tensor([len(sequence) for sequence in sequence_list], dtype=torch.long)
    padded_sequences = pad_sequence(sequence_list, batch_first=True, padding_value=padding_value)
    # lstm expects: [batch_size, sequence_length, feature_size]
    if padded_sequences.ndim == 2:
        padded_sequences = padded_sequences.unsqueeze(2)

    stacked_globals = torch.stack(global_list)
    padded_mfccs = pad_sequence(mfcc_list, batch_first=True, padding_value=padding_value)

    return (
        stacked_labels,
        id_list,
        type_list,
        sequence_lengths,
        padded_sequences,
        stacked_globals,
        padded_mfccs,
    )

#### Dataloader

In [139]:
# Column-name arguments of the dataset, gathered in one place.
dataset_columns = dict(
    pitch_col=PITCH_COLUMN,
    hnr_col=HNR_COLUMN,
    jitter_col=JITTER_COLUMN,
    shimmer_col=SHIMMER_COLUMN,
    mfcc_col=MFCC_COLUMN,
    label_col=LABEL_COLUMN,
    attack_type_col=ATTACK_TYPE_COLUMN,
    audio_id_col=AUDIO_ID_COLUMN,
)
pitch_dataset_test = ASVDataset(
    dataframe=df_test,
    nan_replacement=NAN_REPLACEMENT_VALUE,
    **dataset_columns,
)

# Evaluation order is kept fixed (no shuffling).
test_dataloader = DataLoader(
    pitch_dataset_test,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=8,
)

## For inspection: show the first batch only, then stop iterating
for i, batch_data in enumerate(test_dataloader):
    # batch_data is a tuple
    batch_labels, batch_ids, batch_types, batch_lengths, batch_pitchhnr, batch_global, batch_mfcc = batch_data
    inspection_lines = [
        f"\n--- Batch {i+1} ---",
        f"  Labels (first 5): {batch_labels[:5]}",
        f"  IDs (first 5): {batch_ids[:5]}",
        f"  Types (first 5): {batch_types[:5]}",
        f"  Padded Sequences Shape: {batch_pitchhnr.shape}",
        f"  Original Lengths (first 5): {batch_lengths[:5]}",
        f"  Global Shape: {batch_global.shape}",
        f"  MFCC Shape: {batch_mfcc.shape}",
    ]
    print("\n".join(inspection_lines))
    break


Attempting to process 71237 entries from DataFrame
Successfully processed 71237 samples out of 71237 DataFrame entries.

--- Batch 1 ---
  Labels (first 5): tensor([0, 0, 0, 0, 0])
  IDs (first 5): ['LA_E_2169831', 'LA_E_9534923', 'LA_E_8832198', 'LA_E_6890294', 'LA_E_5780214']
  Types (first 5): ['A19', 'A14', 'A12', 'A17', 'A10']
  Padded Sequences Shape: torch.Size([32, 753, 2])
  Original Lengths (first 5): tensor([581, 147, 486, 398, 201])
  Global Shape: torch.Size([32, 6])
  MFCC Shape: torch.Size([32, 236, 60])


### Evaluation

In [140]:
def _gather_predictions(model, test_loader, device):
    """Runs `model` over `test_loader` and returns a DataFrame with one row per sample.

    Columns: 'audio_id', 'attack_type', 'label_true' (int) and 'score'
    (softmax probability of the bonafide class).
    """
    rows = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            batch_labels, batch_ids, batch_types, batch_lengths, batch_pitchhnrs, batch_globals, batch_mfccs = batch

            # Only the model inputs are moved to the device; the lengths are
            # handed over exactly as the loader produced them.
            batch_mfccs = batch_mfccs.to(device)
            batch_pitchhnrs = batch_pitchhnrs.to(device)
            batch_globals = batch_globals.to(device)

            logits = model(batch_pitchhnrs, batch_lengths, batch_globals, batch_mfccs)

            scores = torch.softmax(logits, dim=1).cpu().numpy()
            bonafide_scores = scores[:, LABEL_BONAFIDE]

            # The labels are not an input of the model, so they are converted to
            # plain ints in one go rather than synchronising per element.
            true_labels = batch_labels.tolist()

            rows.extend(
                {
                    'audio_id': audio_id,
                    'attack_type': attack_type,
                    'label_true': label_true,
                    'score': score,
                }
                for audio_id, attack_type, label_true, score
                in zip(batch_ids, batch_types, true_labels, bonafide_scores)
            )

    return pd.DataFrame(rows)


def evaluate_and_explain(model, test_loader, device, num_examples=1):
    """
    Scores the test set and logs EER-based metrics to the active W&B run.

    Steps:
      1. collect the bonafide score of every sample,
      2. compute the overall EER and its threshold,
      3. for every attack type, compute the EER against all bonafide samples and
         the false acceptance rate at the overall threshold,
      4. log the per-sample results, a confusion matrix and a classification
         report using predictions thresholded at the overall threshold.

    Relies on the module-level `run` (W&B run), `em` (EER helpers) and the
    LABEL_BONAFIDE / LABEL_SPOOF constants.

    Args:
        model: network called as model(pitch_hnr, lengths, global_features, mfccs).
        test_loader: iterable of batches in the layout produced by `collate_fn`.
        device: device the model inputs are moved to.
        num_examples: accepted for backward compatibility; currently unused.

    Returns:
        None. All results are printed and logged to W&B.
    """
    model.eval()

    # --- 1. GATHER MODEL PREDICTIONS ---
    print("Gathering model predictions from the test set...")
    results_df = _gather_predictions(model, test_loader, device)

    # --- 2. CALCULATE OVERALL EER (THE OPERATIONAL METRIC) ---
    print("\n--- Overall Performance ---")
    bonafide_rows = results_df[results_df['label_true'] == LABEL_BONAFIDE]
    spoof_rows = results_df[results_df['label_true'] == LABEL_SPOOF]

    scores_bonafide_np = bonafide_rows['score'].to_numpy()
    scores_spoof_overall_np = spoof_rows['score'].to_numpy()

    # This is your main, global EER and threshold
    overall_eer, overall_threshold = em.compute_eer(scores_bonafide_np, scores_spoof_overall_np)

    run.log({"eval_eer_overall": overall_eer, "eval_threshold_overall": overall_threshold})
    print(f"Overall EER: {overall_eer*100:.2f}% at threshold {overall_threshold:.4f}")

    # --- 3. CALCULATE PER-ATTACK PERFORMANCE (THE DETAILED DIAGNOSIS) ---
    print("\n--- Per-Attack Performance Analysis ---")
    attack_analysis_results = []
    unique_attacks = sorted(spoof_rows['attack_type'].unique())

    # Calculate the fixed False Rejection Rate at the global threshold
    frr_at_global_threshold = np.sum(scores_bonafide_np < overall_threshold) / len(scores_bonafide_np)
    run.log({"frr_at_global_threshold": frr_at_global_threshold})
    print(f"FRR at Global Threshold ({overall_threshold:.4f}): {frr_at_global_threshold*100:.2f}%")

    for attack_type in unique_attacks:
        current_attack_rows = spoof_rows[spoof_rows['attack_type'] == attack_type]
        scores_current_attack = current_attack_rows['score'].to_numpy()

        if len(scores_current_attack) == 0:
            continue

        # Analysis A: What is the BEST POSSIBLE EER for this attack?
        optimal_eer, optimal_threshold = em.compute_eer(scores_bonafide_np, scores_current_attack)

        # Analysis B: What is the ACTUAL error rate for this attack using the GLOBAL threshold?
        false_acceptances = np.sum(scores_current_attack >= overall_threshold)
        far_at_global_threshold = false_acceptances / len(scores_current_attack)

        # Store raw numeric values for logging and correct sorting in W&B UI
        attack_analysis_results.append({
            "attack_type": attack_type,
            "optimal_eer": optimal_eer,
            "optimal_threshold": optimal_threshold,
            "far_at_global_threshold": far_at_global_threshold,
            "num_examples": len(scores_current_attack)
        })
        # Use formatted strings only for the console printout
        print(f"  - {attack_type}: Optimal EER={optimal_eer*100:.2f}% | FAR @ Global Threshold={far_at_global_threshold*100:.2f}%")

        # Log both metrics to W&B for easier plotting
        run.log({
            f"eer_by_attack/{attack_type}": optimal_eer,
            f"far_at_global_threshold/{attack_type}": far_at_global_threshold
        })

    # Log the summary table of per-attack analysis
    per_attack_df = pd.DataFrame(attack_analysis_results)

    # Explicitly convert columns to a numeric type before logging to ensure correct sorting
    for col in ["optimal_eer", "optimal_threshold", "far_at_global_threshold"]:
        if col in per_attack_df.columns:
            per_attack_df[col] = pd.to_numeric(per_attack_df[col])

    run.log({"per_attack_analysis_table": wandb.Table(dataframe=per_attack_df)})

    # --- 4. LOG QUANTITATIVE RESULTS ---
    print("\nLogging overall quantitative metrics to W&B...")

    # Calculate predictions directly on the DataFrame to ensure correct alignment
    results_df['prediction'] = (results_df['score'] >= overall_threshold).astype(int)

    # Sort the DataFrame by score in descending order before logging
    results_df_sorted = results_df.sort_values(by='score', ascending=False)

    # Log the sorted detailed results table
    run.log({"test_results_table": wandb.Table(dataframe=results_df_sorted)})

    # Create the true and predicted labels for the confusion matrix
    labels_true = results_df['label_true'].to_numpy()
    labels_pred = results_df['prediction'].to_numpy()

    # Class names are listed in the same order as the label ids passed below.
    class_names = ['SPOOF', 'BONAFIDE']
    report_columns = ["Class", "Precision", "Recall", "F1-score", "Support"]
    class_report = classification_report(
        labels_true, labels_pred, labels=[LABEL_SPOOF, LABEL_BONAFIDE], target_names=class_names
    ).splitlines()
    # The text report starts with a header line and a blank line; the per-class
    # rows follow, one per class name.
    report_table = [line.split() for line in class_report[2:(len(class_names)+2)]]

    run.log({
        "Confusion Matrix": wandb.plot.confusion_matrix(y_true=labels_true, preds=labels_pred, class_names=class_names),
        "Classification Report": wandb.Table(data=report_table, columns=report_columns)
    })



### Instantiate the model

In [141]:
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")

# GPU index to use when it exists on the current host.
PREFERRED_CUDA_INDEX = 4

if torch.cuda.is_available():
    device_index = PREFERRED_CUDA_INDEX
    # Hosts with fewer GPUs would otherwise fail in set_device with an invalid ordinal.
    if device_index >= torch.cuda.device_count():
        print(f"WARNING: CUDA device {device_index} does not exist. Falling back to device 0.")
        device_index = 0
    # Make the chosen GPU the current one so that plain 'cuda' refers to it.
    torch.cuda.set_device(device_index)
    DEVICE = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(DEVICE)}")
else:
    print("CUDA is not available. Using CPU.")
    DEVICE = torch.device('cpu')

CUDA available: True
CUDA device count: 9
Using CUDA device: NVIDIA GeForce RTX 2080 Ti


In [142]:
# Cross-entropy loss averaged over the batch ('mean' is PyTorch's default reduction).
criterion = torch.nn.CrossEntropyLoss()

In [143]:
def initiate_model():
    """
    Builds the classifier named by `config.model` on `DEVICE`.

    Only the branches required by the selected architecture are constructed,
    using the hyperparameters stored in the module-level `config`.

    Returns:
        The freshly initialised model, moved to `DEVICE`.

    Raises:
        ValueError: if `config.model` is not one of the supported model names.
    """
    def build_lstm_branch():
        return LSTM_branch(
            lstm_input_dim=config.lstm_input_dim,
            lstm_hidden_dim=config.lstm_hidden_dim,
            lstm_n_layers=config.lstm_n_layers,
            bidirectional=config.bidirectional,
        ).to(DEVICE)

    def build_ffn_branch():
        return FFN_branch(ffn_dims=config.ffn_dims).to(DEVICE)

    def build_cnn_branch():
        return CNN_branch(
            cnn_channels=config.cnn_channels,
            conv_kernel=config.conv_kernel,
            pool_kernel=config.pool_kernel,
            cnn_padding=config.cnn_padding,
        ).to(DEVICE)

    #NOTE: remember to change the output dim
    if config.model == "SpoofEnsemble":
        model = SpoofEnsemble(
            lstm_branch=build_lstm_branch(),
            ffn_branch=build_ffn_branch(),
            cnn_branch=build_cnn_branch(),
            output_dim=2,
            dropout=config.dropout_rate,
        )
    elif config.model == "LSTM_FFN_classifier":
        model = LSTM_FFN_classifer(
            lstm_out=build_lstm_branch(), ffn_out=build_ffn_branch(), output_dim=2, dropout=config.dropout_rate
        )
    elif config.model == "CNN_classifier":
        model = CNN_classifer(cnn_out=build_cnn_branch(), output_dim=2, dropout=config.dropout_rate)
    elif config.model == "LSTM_classifier":
        model = LSTM_classifier(lstm_hidden=build_lstm_branch(), output_dim=2, dropout=config.dropout_rate)
    elif config.model == "FFN_classifier":
        model = FFN_classifier(ffn_hidden=build_ffn_branch(), output_dim=2, dropout=config.dropout_rate)
    else:
        # Previously this only printed a warning and then failed on the
        # unassigned `model` variable; fail with an explicit message instead.
        raise ValueError(f"Invalid model name: {config.model!r}.")
    return model.to(DEVICE)

### Setup and Run

In [144]:
model = initiate_model()

# Download the best checkpoint of the evaluated training run.
# NOTE: remember to change
artifact = run.use_artifact(f'qianyue-university-of-stuttgart/teamlab_deepfake/{run_id}-best-model:v0', type='model')
artifact_dir = artifact.download()
model_path = os.path.join(artifact_dir, 'best_model')
# model_path = '/home/users1/liqe/TeamLab_phonetics/TeamLab/artifacts/h456ty5q-best-model:v0/best_model'

try:
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted runs.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Model loaded from {model_path}.")
except FileNotFoundError:
    # NOTE(review): evaluation continues with random weights in this case, so the
    # metrics logged below do not describe the trained model.
    print(f"WARNING: Model path not found at '{model_path}'. Using randomly initialized model.")

model.to(DEVICE)

# --- 3. Log Model as a W&B Artifact ---
if os.path.exists(model_path):
    model_artifact = wandb.Artifact(
        name=f"model-{run.id}",
        type="model",
        description="Trained model checkpoint for spoof detection."
    )
    model_artifact.add_file(model_path)
    run.log_artifact(model_artifact)

try:
    evaluate_and_explain(model, test_dataloader, DEVICE)
except BaseException:
    # Do not leave the run open when the evaluation fails or is interrupted;
    # close it with a non-zero exit code and surface the original error.
    run.finish(exit_code=1)
    raise

run.finish()
print("\nEvaluation complete. Results logged to W&B.")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Model loaded from /home/users1/liqe/TeamLab_phonetics/TeamLab/artifacts/3y1j9rcw-best-model:v0/best_model.
Gathering model predictions from the test set...


Evaluating: 100%|██████████| 2227/2227 [00:11<00:00, 189.88it/s]



--- Overall Performance ---
Overall EER: 39.78% at threshold 0.6458

--- Per-Attack Performance Analysis ---
FRR at Global Threshold (0.6458): 39.78%
  - A07: Optimal EER=44.88% | FAR @ Global Threshold=47.80%
  - A08: Optimal EER=43.18% | FAR @ Global Threshold=46.28%
  - A09: Optimal EER=40.94% | FAR @ Global Threshold=43.06%
  - A10: Optimal EER=48.53% | FAR @ Global Threshold=58.51%
  - A11: Optimal EER=45.28% | FAR @ Global Threshold=51.36%
  - A12: Optimal EER=43.39% | FAR @ Global Threshold=46.97%
  - A13: Optimal EER=21.22% | FAR @ Global Threshold=6.25%
  - A14: Optimal EER=55.25% | FAR @ Global Threshold=69.37%
  - A15: Optimal EER=45.20% | FAR @ Global Threshold=51.67%
  - A16: Optimal EER=43.12% | FAR @ Global Threshold=46.78%
  - A17: Optimal EER=23.42% | FAR @ Global Threshold=11.09%
  - A18: Optimal EER=38.50% | FAR @ Global Threshold=36.94%
  - A19: Optimal EER=7.73% | FAR @ Global Threshold=1.12%

Logging overall quantitative metrics to W&B...


0,1
eer_by_attack/A07,▁
eer_by_attack/A08,▁
eer_by_attack/A09,▁
eer_by_attack/A10,▁
eer_by_attack/A11,▁
eer_by_attack/A12,▁
eer_by_attack/A13,▁
eer_by_attack/A14,▁
eer_by_attack/A15,▁
eer_by_attack/A16,▁

0,1
eer_by_attack/A07,0.44876
eer_by_attack/A08,0.43182
eer_by_attack/A09,0.40941
eer_by_attack/A10,0.4853
eer_by_attack/A11,0.45284
eer_by_attack/A12,0.43386
eer_by_attack/A13,0.21224
eer_by_attack/A14,0.55253
eer_by_attack/A15,0.45202
eer_by_attack/A16,0.43124



Evaluation complete. Results logged to W&B.
