In [None]:
import pandas as pd
from pathlib import Path

import numpy as np
from speechbrain.utils.metric_stats import EER
import torch
import matplotlib.pyplot as plt

import sys
sys.path.append('../.')
from src.modules.metrics.metrics import SpeakerVerificationMetrics, VerificationMetrics

In [None]:
# Run directory: searched recursively for the scores CSV and expected to
# contain `vpc2025_artifacts/<split>.csv`.
model_exp = Path('/home.local/aloradi/logs/train/runs/2025-02-02_22-12-58')
# Marker inside `audio_path` from which the relative file path is cut.
# NOTE(review): presumably the anonymization-system tag -- confirm.
anon_model = 'B3'
# True -> use the test split files, False -> the validation/dev split files.
eval_test = True

scores_csv_file = 'test_scores.csv' if eval_test else 'valid_scores.csv'
data_df_file = 'test.csv' if eval_test else 'dev.csv'

# Exactly one scores file is expected below the run directory; the message
# names the file that was actually searched for (test or valid).
score_file_matches = list(model_exp.rglob(f'*/{scores_csv_file}'))
assert len(score_file_matches) == 1, (
    f'Expected one file called {scores_csv_file}, found: {len(score_file_matches)}'
)
scores_path = score_file_matches[0]

df = pd.read_csv(scores_path)
df_test = pd.read_csv(model_exp / 'vpc2025_artifacts' / data_df_file, sep="|")

# Everything from the first occurrence of `anon_model` to the end of the path
# becomes the join key. `expand=False` yields a Series (one capture group)
# instead of a one-column DataFrame. `anon_model` is interpolated into the
# regex unescaped, so it must not contain regex metacharacters.
df['rel_filepath'] = df['audio_path'].str.extract(fr'({anon_model}.*)', expand=False)

# Left merge keeps every scored trial; `indicator` is used only to count the
# trials that found no metadata row and is dropped again afterwards.
df = df.merge(
    df_test[['speaker_id', 'rel_filepath', 'gender', 'recording_duration', 'text']],
    on='rel_filepath',
    how='left',
    indicator=True,
)
n_unmatched = int((df['_merge'] == 'left_only').sum())
df = df.drop(columns='_merge')
if n_unmatched:
    print(f'Warning: {n_unmatched} of {len(df)} scored trials have no match in {data_df_file}')

metric = VerificationMetrics()  # alternative: SpeakerVerificationMetrics

In [None]:
# Compute verification metrics (incl. EER) per speaker and per gender.
speakers = df['enrollment_id'].unique()
# Rows without a gender value are NaN in this column; comparing against NaN
# selects nothing, so NaN is excluded from the groups to evaluate.
genders = df['gender'].dropna().unique()
eer_results_speaker = {}
eer_results_gender = {}

metric_out_speaker = {}
speakers_figures = {}
# Starts empty: placeholder entries that are never overwritten (e.g. a gender
# label absent from the data) would otherwise be mixed with the metric dicts
# when this mapping is later converted to a DataFrame.
metric_out_gender = {}
gender_figures =  {'male': [], 'female': []}


def _compute_subset_metrics(trials):
    """Reset the shared `metric`, feed it the trials' scores/labels, return `compute()`.

    `trials` is a slice of `df` providing the 'score' and 'label' columns.
    After the call `metric` still holds this subset's state, so e.g.
    `metric.plot_curves()` can be called right afterwards if figures are wanted.
    """
    metric.reset()
    metric.update(torch.from_numpy(trials['score'].values),
                  torch.from_numpy(trials['label'].values))
    return metric.compute()


######## Speaker-level EER (take with a pinch of salt) #########
# NOTE(review): speaker ids are taken from `enrollment_id` but rows are
# selected on `speaker_id` -- confirm this pairing is intended.
for speaker in speakers:
    speaker_data = df[df['speaker_id'] == speaker]
    if speaker_data.empty:
        # No row carries this id in `speaker_id`; there is nothing to score.
        print(f'Skipping speaker {speaker}: no rows with this speaker_id')
        continue

    metrics = _compute_subset_metrics(speaker_data)
    metric_out_speaker[speaker] = metrics

    # The radar plot is keyed by the last '_'-separated token of the speaker id.
    eer_results_speaker[speaker.split('_')[-1]] = metrics['eer']


######## Gender-level EER (Official evaluation) #########
for gender in genders:
    speaker_data = df[df['gender'] == gender]
    metrics = _compute_subset_metrics(speaker_data)

    metric_out_gender[gender] = metrics
    eer_results_gender[gender] = metrics['eer']

In [None]:
# Rebind both metric mappings to float-typed DataFrames built with
# `DataFrame.from_records`; from here on the names refer to tables, not dicts.
metric_out_speaker = (
    pd.DataFrame
    .from_records(metric_out_speaker)
    .astype(float)
)
metric_out_gender = (
    pd.DataFrame
    .from_records(metric_out_gender)
    .astype(float)
)

In [None]:
# Write both result tables as CSV files into the current working directory.
for results_table, csv_name in (
    (metric_out_speaker, 'per_speaker_results.csv'),
    (metric_out_gender, 'per_gender_results.csv'),
):
    results_table.to_csv(csv_name)

In [None]:
def create_radar_plot(eer_results_speaker,
                      output_path,
                      title="Speaker-specific EER",
                      description="Equal Error Rate (EER) across different speakers"):
    """Draw a closed radar (polar line) chart of per-label EER values and save it.

    Args:
        eer_results_speaker: Mapping from axis label to EER value. Each value
            is converted with ``float()`` before plotting.
        output_path: Destination handed to ``Figure.savefig``. The image is
            always written in PNG format, with a transparent background.
        title: Title drawn above the polar axes.
        description: Currently not rendered anywhere; kept so callers that
            pass it keep working.

    Returns:
        Tuple ``(fig, ax)``. The figure has already been closed with
        ``plt.close`` by the time it is returned.

    Raises:
        ValueError: If ``eer_results_speaker`` is empty.
    """
    # Validate before any figure is created so a bad call leaves no open figure.
    if len(eer_results_speaker) == 0:
        raise ValueError('eer_results_speaker is empty; nothing to plot')

    labels = list(eer_results_speaker.keys())
    # Plain floats keep max()/formatting independent of the scalar type stored.
    values = [float(value) for value in eer_results_speaker.values()]
    num_vars = len(labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()

    # Repeat the first point at the end so the plotted line forms a closed loop.
    values.append(values[0])
    angles.append(angles[0])

    peak = max(values)
    # 20% headroom above the largest value; fall back to 1.0 when the peak is
    # not positive so the radial axis never collapses to (0, 0).
    radial_limit = peak * 1.2 if peak > 0 else 1.0
    label_offset = peak * 0.1

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_axes([0.1, 0.1, 0.6, 0.8], polar=True)

    ax.plot(angles, values, color='#FF6B6B', linewidth=2, marker='o',
            markersize=8, label='EER Values')

    ax.grid(color='gray', alpha=0.2, linestyle='--', linewidth=1)
    ax.set_ylim(0, radial_limit)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, fontsize=10, fontweight='bold')

    # Annotate every data point (the duplicated closing point is skipped).
    for i in range(num_vars):
        angle_rad = angles[i]
        angle_deg = angle_rad * 180 / np.pi

        # Labels strictly between 90 and 270 degrees get an extra half turn.
        if 90 < angle_deg < 270:
            rotation = angle_deg + 180
        else:
            rotation = angle_deg

        ax.text(angle_rad, values[i] + label_offset,
                f'{values[i]:.2f}',
                ha='center', va='center',
                rotation=rotation,
                fontsize=9,
                bbox=dict(facecolor='white',
                          edgecolor='none',
                          alpha=0.8,
                          pad=2))

    ax.set_title(title, pad=20, fontsize=14, fontweight='bold')
    ax.legend(loc='center left', bbox_to_anchor=(1.1, 0.5))

    fig.savefig(output_path,
                dpi=300,
                bbox_inches='tight',
                pad_inches=0.5,
                format='png',
                transparent=True)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

    return fig, ax

In [None]:
# Render the radar chart of the collected per-speaker EER values and save it
# as 'speaker_eer.png' (path is relative to the current working directory).
create_radar_plot(eer_results_speaker, output_path='speaker_eer.png')