In [None]:
import os 
import sys
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
import numpy as np
import torch
from speechbrain.utils.metric_stats import EER
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy

sys.path.append('../.')
from src.modules.metrics.metrics import VerificationMetrics, AS_norm
from src.datamodules.components.vpc25.vpc_dataset import VPCTestCoallate, VPC25TestDataset

In [None]:
# --- Experiment configuration and artifact loading ---------------------------
# NOTE(review): `model_exp` is an absolute, machine-specific path; consider
# moving it to a configurable constant before sharing this notebook.
model_exp = Path('/home.local/aloradi/logs/train/runs/2025-02-02_22-12-58')
anon_model = 'B3'
data_df_file = 'test.csv'

# Pipe-separated metadata table for the test set.
df_test = pd.read_csv(model_exp / 'vpc2025_artifacts' / data_df_file, sep="|")


# Load embeddings
# File names of the embedding artifacts searched for below `model_exp`.
trials_embeds = 'test_embeds.pt'
enrol_embeds_file = 'enrol_embeds.pt'  # kept separate from the loaded `enrol_embeds` object
cohort_embeds = 'valid_embeds.pt'

# Each artifact is looked up recursively; the trial and enrolment files must be
# unique so the notebook cannot silently pick up embeddings from another run.
test_path_tmp = list(model_exp.rglob(f'*/{trials_embeds}'))
assert len(test_path_tmp) == 1, f'Expected one file called {trials_embeds}, found: {len(test_path_tmp)}'
trials_path = test_path_tmp[0]

enrol_tmp_tmp = list(model_exp.rglob(f'*/{enrol_embeds_file}'))
assert len(enrol_tmp_tmp) == 1, f'Expected one file called {enrol_embeds_file}, found: {len(enrol_tmp_tmp)}'
enrol_path = enrol_tmp_tmp[0]

# The cohort file is allowed to have several matches (the first one is used),
# but a missing file is reported explicitly instead of a bare IndexError.
cohort_path_tmp = list(model_exp.rglob(f'*/{cohort_embeds}'))
assert cohort_path_tmp, f'No file called {cohort_embeds} found under {model_exp}'
cohort_path = cohort_path_tmp[0]


test_embeds = torch.load(str(trials_path))
enrol_embeds = torch.load(str(enrol_path))

# The cohort artifact is a mapping; only its values are kept, stacked into a
# single tensor along a new leading dimension.
cohort_embeddings = torch.load(str(cohort_path))
cohort_embeddings = torch.stack(list(cohort_embeddings.values()))

In [None]:
# load test data
# Locate the trial list for this experiment; exactly one match is required so
# the evaluation cannot run against the wrong trial file.
test_trials_path = 'test_trials.csv'
test_trials_tmp = list(model_exp.rglob(f'*/{test_trials_path}'))
assert len(test_trials_tmp) == 1, f'Expected one file called {test_trials_path}, found: {len(test_trials_tmp)}'
test_trials_csv = test_trials_tmp[0]

# NOTE(review): `data_dir` is an absolute, machine-specific path.
test_data = VPC25TestDataset(data_dir='/home/aloradi/adversarial-robustness-for-sr/data/vpc2025_official',
                             test_trials_path=test_trials_csv,
                             sample_rate=16000,
                             max_duration=None)

# shuffle=False keeps the batches in the order the dataset yields them.
test_loader = DataLoader(test_data,
                         shuffle=False,
                         batch_size=16,
                         collate_fn=VPCTestCoallate())

In [None]:
# Score every trial batch: cosine similarity between the enrolment and trial
# embeddings, optionally AS-normalised against the cohort, accumulated both in
# `metric` and as per-batch records in `trials_results`.
metric = VerificationMetrics()
trials_results = []

total_batches = len(test_loader)
with tqdm(total=total_batches, desc="Test trials") as pbar:
    for batch in test_loader:
        # Look up the pre-computed embeddings for this batch.
        trial_embeddings = torch.stack([test_embeds[path] for path in batch.audio_path])
        enroll_embeddings = torch.stack([
            enrol_embeds[model][enroll_id]
            for model, enroll_id in zip(batch.model, batch.enroll_id)
        ])

        raw_scores = torch.nn.functional.cosine_similarity(enroll_embeddings, trial_embeddings)

        if cohort_embeddings is None:
            # No cohort available: keep the raw cosine scores.
            normalized_scores = raw_scores.clone()
        else:
            # Apply AS-Norm trial by trial, then convert back to a tensor on
            # the same device as the raw scores.
            normalized_scores = torch.tensor(
                [
                    AS_norm(score=trial_score,
                            enroll_embedding=enroll_vec,
                            test_embedding=trial_vec,
                            cohort_embeddings=cohort_embeddings,
                            topk=300)
                    for trial_score, enroll_vec, trial_vec in zip(
                        raw_scores, enroll_embeddings, trial_embeddings
                    )
                ],
                device=raw_scores.device,
            )

        # Update metric with normalized scores
        metric.update(scores=normalized_scores, labels=torch.tensor(batch.trial_label))

        # Keep the batch in columnar form; it is flattened into rows later.
        trials_results.append({
            "enrollment_id": batch.enroll_id,
            "audio_path": batch.audio_path,
            "label": batch.trial_label,
            "score": normalized_scores.tolist(),
            "model": batch.model,
        })

        # Update progress bar
        pbar.update(1)

In [None]:
# Flatten the per-batch columnar results into one row per trial.
_score_columns = ("enrollment_id", "audio_path", "label", "score", "model")

_trial_rows = []
for _batch_result in trials_results:
    # Walk the batch's columns in lockstep so each tuple describes one trial.
    for _trial_values in zip(*[_batch_result[_column] for _column in _score_columns]):
        _trial_rows.append(dict(zip(_score_columns, _trial_values)))

scores = pd.DataFrame(_trial_rows)

In [None]:
# Attach test-set metadata to every scored trial. The join key is the part of
# the audio path starting at the anonymisation-model name (e.g. 'B3/...').
df = deepcopy(scores)
# expand=False returns a Series (one capture group) rather than a 1-column frame.
df['rel_filepath'] = df['audio_path'].str.extract(fr'({anon_model}.*)', expand=False)
df = df.merge(df_test[['speaker_id', 'rel_filepath', 'gender', 'recording_duration', 'text']], on='rel_filepath', how='left')

# The left merge leaves NaN metadata for trials that have no match in df_test.
# Those rows cannot be assigned to a gender, so report them and skip them below.
n_unmatched = int(df['gender'].isna().sum())
if n_unmatched:
    print(f'Warning: {n_unmatched} trials have no metadata match in df_test '
          'and are excluded from the gender-level metrics')

speakers = df['enrollment_id'].unique()
# NaN never compares equal to itself, so a NaN "gender" would select an empty
# subset; drop it before iterating.
genders = df['gender'].dropna().unique()
eer_results_gender = {}

metric_out_gender = {'male': [], 'female': []}
gender_figures = {'male': [], 'female': []}


######## Gender-level EER (Official evaluation) #########
for gender in genders:
    # Start from a clean metric state so each gender is evaluated on its own.
    metric.reset()

    gender_trials = df[df['gender'] == gender]
    # Loop-local names: the notebook-level `scores` DataFrame must not be
    # overwritten here, otherwise re-running this cell fails.
    gender_scores = gender_trials['score'].to_numpy()
    gender_labels = gender_trials['label'].to_numpy()

    metric.update(torch.from_numpy(gender_scores), torch.from_numpy(gender_labels))
    metrics = metric.compute()
    gender_figures[gender] = metric.plot_curves()

    metric_out_gender[gender] = metrics
    eer_results_gender[gender] = metrics['eer']