In [None]:
import glob
import os
import sys
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from IPython.display import HTML, display
from speechbrain.utils.metric_stats import EER
from tqdm import tqdm

sys.path.append('/home/' + os.getenv('USER') + '/adversarial-robustness-for-sr')

from src.modules.metrics.metrics import VerificationMetrics
from src.modules.metrics.metrics import AS_norm

## Load and Prepare Trial Scores

In [None]:
# Configuration
# Root directory holding one sub-directory per experiment. Set $MODELS_PATH to point the
# notebook at another results tree without editing code; the default is the original location.
MODELS_PATH = Path(os.environ.get("MODELS_PATH", "/dataHDD/ahmad/hpc_results/second_run"))
EVAL_MODE = "SINGLE"  # Options: "SINGLE", "EVAL_ALL"
SINGLE_EXPERIMENT = "vpc_amm_cyclic-available-max_dur10-32"  # used when EVAL_MODE == "SINGLE"
EXPERIMENT_PATTERN = "vpc_amm_cyclic-*avail*-max_dur10-*32"  # glob used when EVAL_MODE == "EVAL_ALL"
EVAL_TEST = False  # True: evaluate the test split; False: evaluate the validation split

# Function to get experiments to evaluate
def get_experiments_to_evaluate(mode=EVAL_MODE,
                               single_exp=SINGLE_EXPERIMENT,
                               pattern=EXPERIMENT_PATTERN,
                               models_path=MODELS_PATH):
    """Return the list of experiment names to evaluate.

    Args:
        mode: "SINGLE" to evaluate only ``single_exp``; "EVAL_ALL" to evaluate every
            directory under ``models_path`` whose name matches the glob ``pattern``.
        single_exp: experiment name used in "SINGLE" mode.
        pattern: glob pattern, relative to ``models_path``, used in "EVAL_ALL" mode.
        models_path: root directory that holds the experiment directories.

    Returns:
        List of experiment (directory) names. In "EVAL_ALL" mode the list is sorted
        alphabetically, because ``glob`` returns entries in arbitrary filesystem order.

    Raises:
        ValueError: if ``mode`` is not one of the supported options.
    """
    if mode == "SINGLE":
        return [single_exp]
    if mode == "EVAL_ALL":
        matched = glob.glob(str(Path(models_path) / pattern))
        # Keep directories only: a broad pattern can also match stray files (logs, archives).
        experiments = sorted(Path(d).name for d in matched if os.path.isdir(d))
        print(f"Found {len(experiments)} matching experiments")
        return experiments
    raise ValueError(f"Unknown EVAL_MODE: {mode}")


def get_df_path(experiment_name, eval_test, model_path_dir=MODELS_PATH):
    """Locate the single trial-scores CSV of an experiment.

    Args:
        experiment_name: experiment directory name under ``model_path_dir``.
        eval_test: True selects ``test*/test_scores.csv``; False selects
            ``valid*/valid_best_scores.csv``.
        model_path_dir: root directory containing the experiments (a ``Path``).

    Returns:
        Path of the matching CSV file.

    Raises:
        ValueError: if no matching file exists.
        AssertionError: if more than one matching file exists.
    """
    if eval_test:
        split_dir, csv_name = 'test', 'test_scores.csv'
    else:
        split_dir, csv_name = 'valid', 'valid_best_scores.csv'

    experiment_dir = model_path_dir / experiment_name
    matches = list(experiment_dir.rglob(f'{split_dir}*/{csv_name}'))

    if len(matches) == 0:
        raise ValueError(f"⚠️ No {csv_name} found for {experiment_name}")
    assert len(matches) == 1, f'Expected one file called {csv_name}, found: {len(matches)}'

    return matches[0]


def process_experiment(experiment_name, eval_test=EVAL_TEST, model_path_dir=MODELS_PATH):
    """Load one experiment's trial scores and attach per-utterance metadata.

    Args:
        experiment_name: experiment directory name under ``model_path_dir``.
        eval_test: if True use the test split (test_scores.csv + test.csv), otherwise the
            validation split (valid_best_scores.csv + dev.csv).
        model_path_dir: root directory containing the experiments (str or Path).

    Returns:
        The scores DataFrame with an added 'rel_filepath' column, left-merged with the
        'speaker_id', 'gender', 'recording_duration' and 'text' metadata columns. Trials
        without a metadata match keep NaN in those columns.
    """
    print(f"\nProcessing experiment: {experiment_name}")
    model_path_dir = Path(model_path_dir)

    # Bug fix: forward the caller's model_path_dir. It was hard-wired to MODELS_PATH, so the
    # scores and the metadata could be read from different roots.
    scores_path = get_df_path(experiment_name, eval_test, model_path_dir=model_path_dir)
    df = pd.read_csv(scores_path)

    data_df_file = 'test.csv' if eval_test else 'dev.csv'
    metadata_path = model_path_dir / experiment_name / "vpc2025_artifacts" / data_df_file
    df_meta = pd.read_csv(metadata_path, sep="|")

    # The metadata is keyed by the part of the audio path after 'vpc2025_official/'.
    df['rel_filepath'] = df['audio_path'].str.split('vpc2025_official/').str[-1]
    df = df.merge(df_meta[['speaker_id', 'rel_filepath', 'gender', 'recording_duration', 'text']],
                  on='rel_filepath', how='left')

    print(f"✓ Successfully processed data for {experiment_name}")
    return df


# Sort key: experiments are ordered by the anonymization model encoded in their name
def extract_anon_model(experiment_name):
    """Return the anonymization-model field of an experiment name.

    Names are split on '-'. The second field is returned when there are at least three
    fields; otherwise the name is returned unchanged.
    """
    fields = experiment_name.split('-')
    return fields[1] if len(fields) >= 3 else experiment_name


# Resolve the experiments for the configured mode, ordered by anonymization model name
experiments = sorted(get_experiments_to_evaluate(), key=extract_anon_model)
print(f"Will evaluate: {experiments}")

In [None]:
def load_valid_embeddings(model_exp: str, eval_test: bool, device='cuda') -> Tuple[Any, Any, Optional[Any]]:
    """Load the best-checkpoint validation embeddings and the test cohort embeddings.

    Args:
        model_exp: experiment directory (str or path-like).
        eval_test: must be False; only the validation split is supported.
        device: device passed to ``torch.load(map_location=...)``.

    Returns:
        ``(valid_enrol_embeds, valid_embeds, cohort_embeds)`` exactly as stored with
        ``torch.save``. ``cohort_embeds`` is None when no ``test_artifacts*`` directory
        exists. The validation embeddings are still returned in that case, so raw scores
        can be recomputed (they were previously discarded as well).

    Raises:
        NotImplementedError: if ``eval_test`` is True.
    """
    if eval_test:
        raise NotImplementedError("This function is not implemented for test set")

    VALID_BEST_ENROL_EMBEDS = "valid_artifacts/valid_best_enrol_embeds.pt"
    VALID_BEST_EMBEDS = "valid_artifacts/valid_best_embeds.pt"
    map_location = torch.device(device)

    valid_enrol_embeds = torch.load(os.path.join(model_exp, VALID_BEST_ENROL_EMBEDS), map_location=map_location)
    valid_embeds = torch.load(os.path.join(model_exp, VALID_BEST_EMBEDS), map_location=map_location)

    # Cohort embeddings come from the most recently created test artifacts directory.
    test_dirs = glob.glob(os.path.join(model_exp, "test_artifacts*"))
    if not test_dirs:
        print("WARNING: No test cohort embeddings found")
        cohort_embeds = None
    else:
        latest_test_dir = max(test_dirs, key=os.path.getctime)
        cohort_embeds = torch.load(os.path.join(latest_test_dir, "test_cohort_embeds.pt"), map_location=map_location)

    print(f"✓ Loaded embeddings for {model_exp}")
    return valid_enrol_embeds, valid_embeds, cohort_embeds

## Plotting Utils

In [None]:
def create_radar_plot(eer_results_speaker,
                      output_path,
                      title="Speaker-specific EER",
                      description="Equal Error Rate (EER) across different speakers"):
    """Save a polar ("radar") chart of per-speaker EER values as a transparent PNG.

    Args:
        eer_results_speaker: mapping speaker label -> EER value. Any value ``float()``
            accepts works, e.g. Python floats or 0-dim tensors.
        output_path: destination PNG file.
        title: figure title.
        description: currently unused; kept so existing callers keep working.

    Returns:
        ``(fig, ax)`` of the saved (already closed) figure, or ``(None, None)`` when
        ``eer_results_speaker`` is empty and nothing was drawn.
    """
    if not eer_results_speaker:
        print(f"WARNING: no EER values to plot; skipping {output_path}")
        return None, None

    labels = list(eer_results_speaker.keys())
    values = [float(v) for v in eer_results_speaker.values()]
    num_vars = len(labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()

    # Repeat the first point so the polygon closes.
    values += values[:1]
    angles += angles[:1]

    max_value = max(values)
    # All-zero EERs would give a degenerate (0, 0) radial range; fall back to a unit radius.
    r_max = max_value * 1.2 if max_value > 0 else 1.0
    label_offset = r_max / 12  # equals 0.1 * max_value whenever max_value > 0

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_axes([0.1, 0.1, 0.6, 0.8], polar=True)

    ax.plot(angles, values, color='#FF6B6B', linewidth=2, marker='o',
            markersize=8, label='EER Values')

    ax.grid(color='gray', alpha=0.2, linestyle='--', linewidth=1)
    ax.set_ylim(0, r_max)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, fontsize=10, fontweight='bold')

    for angle_rad, value in zip(angles[:-1], values[:-1]):
        angle_deg = np.degrees(angle_rad)
        # Rotate value labels on the left half by 180 degrees so they read upright.
        rotation = angle_deg + 180 if 90 < angle_deg < 270 else angle_deg
        ax.text(angle_rad, value + label_offset,
                f'{value:.2f}',
                ha='center', va='center',
                rotation=rotation,
                fontsize=9,
                bbox=dict(facecolor='white',
                          edgecolor='none',
                          alpha=0.8,
                          pad=2))

    ax.set_title(title, pad=20, fontsize=14, fontweight='bold')
    ax.legend(loc='center left', bbox_to_anchor=(1.1, 0.5))

    fig.savefig(output_path,
                dpi=300,
                bbox_inches='tight',
                pad_inches=0.5,
                format='png',
                transparent=True)
    plt.close(fig)

    return fig, ax

## Compute gender/speaker metrics

In [None]:
def compute_gender_metrics(df: pd.DataFrame, scores_col: Literal['score', 'norm_score'] = 'score') -> Tuple[Dict, Dict, Dict]:
    """
    Compute verification metrics for each gender.

    Rows whose 'gender' is missing (NaN, e.g. trials without a metadata match after a left
    merge) are ignored. Previously they produced an empty selection that was fed to the
    metric.

    Args:
        df: DataFrame with columns ['gender', 'label'] and ``scores_col``
        scores_col: analyze raw scores or normalized scores ('score' or 'norm_score')

    Returns:
        gender_metrics: Dict of metrics per gender (result of VerificationMetrics.compute())
        gender_curves: Dict of curve data per gender (numpy arrays from the metric's curve data)
        gender_eer: Dict of EER values per gender
    """
    gender_metrics = {}
    gender_curves = {}
    gender_eer = {}
    metric = VerificationMetrics()

    # groupby drops NaN keys by default; sort=False keeps first-appearance order, like unique().
    for gender, gender_data in df.groupby('gender', sort=False):
        metric.reset()
        metric.update(
            torch.from_numpy(gender_data[scores_col].values),
            torch.from_numpy(gender_data['label'].values)
        )
        metrics = metric.compute()

        gender_metrics[gender] = metrics
        gender_curves[gender] = {
            k: v.detach().cpu().numpy()
            for k, v in metric._curve_data.items()
        }
        gender_eer[gender] = metrics['eer']

    return gender_metrics, gender_curves, gender_eer


def plot_gender_det_curves(gender_curves: Dict, gender_metrics: Dict, eps: float = 1e-8) -> plt.Figure:
    """
    Draw one DET curve per gender on log-log axes.

    Args:
        gender_curves: gender -> curve arrays; each must provide 'far' and 'frr'
        gender_metrics: gender -> metric dict; each must provide 'eer'
        eps: lower clip applied to FAR/FRR so zero rates survive the log scale

    Returns:
        The newly created matplotlib Figure (the global style is switched to
        'seaborn-v0_8-paper' as a side effect)
    """
    plt.style.use('seaborn-v0_8-paper')
    fig, ax = plt.subplots(figsize=(10, 10))

    # Fixed palette so each gender keeps its color across figures; unknown labels are gray.
    palette = {'female': '#FF69B4', 'male': '#4169E1'}

    eer_values = []
    for gender in sorted(gender_curves.keys()):
        curve = gender_curves[gender]
        eer = gender_metrics[gender]['eer']
        eer_values.append(eer)

        ax.plot(np.maximum(curve['far'], eps),
                np.maximum(curve['frr'], eps),
                '-',
                color=palette.get(gender, 'gray'),
                alpha=0.5,
                linewidth=1,
                label=f'{gender} (EER: {eer:.5f})')

    # Mark the average EER on the FAR == FRR line.
    avg_eer = np.mean(eer_values)
    ax.plot(avg_eer, avg_eer, 'ko', markersize=8, label=f'Mean EER: {avg_eer:.4f}')

    ax.plot([eps, 1], [eps, 1], 'k--', alpha=0.3, linewidth=1)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('False Acceptance Rate (FAR)')
    ax.set_ylabel('False Rejection Rate (FRR)')
    ax.set_title('Detection Error Tradeoff (DET) Curves by Gender')
    ax.grid(True, which='both', linestyle='--', alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
              frameon=True, fancybox=False, edgecolor='black')

    fig.tight_layout()
    return fig


def compute_speaker_metrics(df: pd.DataFrame, scores_col: Literal['score', 'norm_score'] = 'score') -> Tuple[Dict, Dict, Dict]:
    """
    Compute verification metrics for each speaker.

    Speakers are enumerated from the distinct 'enrollment_id' values. For each one, the
    trials whose 'speaker_id' equals that value are scored. Ids with no such trials are
    skipped with a warning instead of feeding empty tensors to the metric.
    NOTE(review): this only works when 'enrollment_id' and 'speaker_id' share the same
    identifier format. Confirm whether trials should instead be grouped by 'enrollment_id'.

    Args:
        df: DataFrame with columns ['enrollment_id', 'speaker_id', 'label'] and ``scores_col``
        scores_col: analyze raw scores ('score', the default and previous behaviour) or
            normalized scores ('norm_score')

    Returns:
        speaker_metrics: Dict of metrics per enrollment id
        speaker_curves: Dict of curve data per enrollment id
        speaker_eer: Dict of EER values keyed by the last '_'-separated field of the id
    """
    speaker_metrics = {}
    speaker_curves = {}
    speaker_eer = {}
    metric = VerificationMetrics()

    for speaker in df['enrollment_id'].unique():
        speaker_data = df[df['speaker_id'] == speaker]
        if speaker_data.empty:
            print(f"WARNING: no trials with speaker_id == {speaker!r}; skipping")
            continue

        metric.reset()
        metric.update(
            torch.from_numpy(speaker_data[scores_col].values),
            torch.from_numpy(speaker_data['label'].values)
        )
        metrics = metric.compute()

        speaker_metrics[speaker] = metrics
        speaker_curves[speaker] = {
            k: v.detach().cpu().numpy()
            for k, v in metric._curve_data.items()
        }
        # str() so numeric ids (e.g. integers parsed by read_csv) do not break key extraction.
        speaker_eer[str(speaker).split('_')[-1]] = metrics['eer']

    return speaker_metrics, speaker_curves, speaker_eer


def plot_speaker_det_curves(speaker_curves: Dict,
                            speaker_metrics: Dict,
                            eps: float = 1e-8,
                            max_speakers: int = None) -> plt.Figure:
    """
    Plot DET curves for multiple speakers.

    Args:
        speaker_curves: Dict of curve data per speaker (each with 'far' and 'frr' arrays)
        speaker_metrics: Dict of metrics per speaker (each with an 'eer' entry)
        eps: Small value for numerical stability (lower clip before the log scale)
        max_speakers: Maximum number of speakers to plot (None or 0 for all)

    Returns:
        Matplotlib figure
    """
    plt.style.use('seaborn-v0_8-paper')
    fig, ax = plt.subplots(figsize=(10, 10))

    # Select speakers to plot
    speakers = sorted(speaker_curves.keys())
    if max_speakers:
        speakers = speakers[:max_speakers]

    colors = plt.cm.tab20(np.linspace(0, 1, len(speakers)))

    mean_eer = []
    for speaker, color in zip(speakers, colors):
        curves = speaker_curves[speaker]
        metrics = speaker_metrics[speaker]

        far = np.maximum(curves['far'], eps)
        frr = np.maximum(curves['frr'], eps)
        eer = metrics['eer']
        mean_eer.append(eer)

        # str() so numeric speaker ids (e.g. integers parsed by read_csv) can be labelled too.
        ax.plot(far, frr, '-', color=color, alpha=0.5, linewidth=1,
                label=f'Speaker {str(speaker).split("_")[-1]} (EER: {eer:.5f})')

    # Mean EER marker; skipped when nothing was plotted (np.mean([]) is NaN and warns).
    if mean_eer:
        mean_eer_value = np.mean(mean_eer)
        ax.plot(mean_eer_value, mean_eer_value, 'ko', markersize=8,
                label=f'Mean EER: {mean_eer_value:.5f}')

    # Plot diagonal (FAR == FRR)
    ax.plot([eps, 1], [eps, 1], 'k--', alpha=0.3, linewidth=1)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('False Acceptance Rate (FAR)')
    ax.set_ylabel('False Rejection Rate (FRR)')
    ax.set_title('Detection Error Tradeoff (DET) Curves by Speaker')
    ax.grid(True, which='both', linestyle='--', alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
              frameon=True, fancybox=False, edgecolor='black')

    fig.tight_layout()
    return fig

def analyze_speaker_results(df: pd.DataFrame, output_path: str, max_speakers: int = None):
    """
    Compute per-speaker verification metrics, draw their DET curves and save the figure.

    Args:
        df: trial DataFrame (see compute_speaker_metrics for the required columns)
        output_path: PNG path the DET figure is written to
        max_speakers: maximum number of speakers to draw (None for all)

    Returns:
        (det_fig, speaker_metrics, speaker_eer)
    """
    speaker_metrics, speaker_curves, speaker_eer = compute_speaker_metrics(df)

    det_fig = plot_speaker_det_curves(speaker_curves, speaker_metrics, max_speakers=max_speakers)

    save_options = dict(dpi=300, bbox_inches='tight', pad_inches=0.5, format='png', transparent=True)
    plt.savefig(output_path, **save_options)
    plt.show()

    return det_fig, speaker_metrics, speaker_eer


def analyze_gender_results(df: pd.DataFrame, output_path: str, scores_col: Literal['score', 'norm_score'] = 'score'):
    """
    Compute per-gender verification metrics, plot their DET curves and save the figure.

    Args:
        df: trial DataFrame with 'gender', 'label' and ``scores_col`` columns
        output_path: path of the PNG file to write
        scores_col: analyze raw scores ('score') or normalized scores ('norm_score');
            defaults to 'score', matching compute_gender_metrics

    Returns:
        (det_fig, gender_metrics, gender_eer)
    """
    gender_metrics, gender_curves, gender_eer = compute_gender_metrics(df, scores_col=scores_col)

    det_fig = plot_gender_det_curves(gender_curves, gender_metrics)

    # Save the figure built above explicitly instead of relying on pyplot's "current figure".
    det_fig.savefig(output_path,
                    dpi=300,
                    bbox_inches='tight',
                    pad_inches=0.5,
                    format='png',
                    transparent=True)
    plt.show()

    return det_fig, gender_metrics, gender_eer

## Plot Trial Score Distributions

In [None]:
def get_bin_count(data, default_bins=30, max_bins=None):
    """Number of histogram bins for ``data`` using the Freedman-Diaconis rule.

    bin width = 2 * IQR * n^(-1/3); bins = ceil((max - min) / bin width).

    Args:
        data: 1-D array-like of numbers (numpy array, pandas Series, list). NaN and
            infinite values are ignored.
        default_bins: returned when the rule is undefined, i.e. fewer than two finite
            values or a zero inter-quartile range.
        max_bins: optional upper bound on the result (None = unbounded, as before).

    Returns:
        int number of bins.
    """
    values = np.asarray(data, dtype=float).ravel()
    values = values[np.isfinite(values)]
    # np.percentile raises on empty input, and the rule is meaningless for a single value.
    if values.size < 2:
        return default_bins

    q75, q25 = np.percentile(values, [75, 25])
    bin_width = 2 * (q75 - q25) * values.size ** (-1 / 3)
    if bin_width <= 0:
        return default_bins

    bins = int(np.ceil((values.max() - values.min()) / bin_width))
    return min(bins, max_bins) if max_bins is not None else bins


def _resolve_score_schemes(df_test, score_schemes):
    """Return the validated list of score columns of ``df_test`` that should be plotted."""
    if score_schemes is None:
        default_schemes = ['score', 'norm_score']
        resolved = [scheme for scheme in default_schemes if scheme in df_test.columns]
        if resolved:
            return resolved
        # Fall back to numeric columns only: id/path/text columns cannot be histogrammed.
        numeric_columns = df_test.select_dtypes(include='number').columns
        potential_scores = [col for col in numeric_columns if col != 'label']
        if not potential_scores:
            raise ValueError("No score columns found in the dataframe.")
        resolved = potential_scores[:2]  # Take at most 2 columns
        print(f"Using detected score columns: {resolved}")
        return resolved

    # Filter user-provided schemes to only keep those that exist in the dataframe
    valid_schemes = [scheme for scheme in score_schemes if scheme in df_test.columns]
    if not valid_schemes:
        raise ValueError(f"None of the provided score schemes {score_schemes} exist in the dataframe.")
    if len(valid_schemes) < len(score_schemes):
        print(f"Warning: Only using valid columns: {valid_schemes}. Skipping non-existent columns.")
    return valid_schemes


def plot_score_distribution(df_test, score_schemes=None, figsize=(10, 6), save_dir=None, show_plot=True):
    """
    Plot the distribution of similarity scores for genuine and impostor trials.

    One figure is produced per score column. Trials with label 0 are treated as impostor
    and label 1 as genuine. Missing (NaN) scores are dropped. A column is skipped with a
    warning when either class has no scores. Global matplotlib rcParams are updated as a
    side effect.

    Parameters:
    -----------
    df_test : pandas.DataFrame
        DataFrame containing similarity scores and labels
    score_schemes : list, optional
        List of score column names to plot (default: ['score', 'norm_score'] or, failing
        that, up to two numeric non-label columns)
    figsize : tuple, optional
        Figure size as (width, height) in inches (default: (10, 6))
    save_dir : str, optional
        Directory path to save plots (PDF + PNG). If None, plots won't be saved.
    show_plot : bool, optional
        Whether to display the plot interactively (default: True)

    Raises:
    -------
    ValueError
        If no usable score column exists in ``df_test``.
    """
    # Set publication-quality font settings
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'DejaVu Serif', 'Palatino', 'Computer Modern Roman'],
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
        'figure.titlesize': 18,
        'text.usetex': False,  # Set to True if you have LaTeX installed
        'axes.linewidth': 1.2,
        'xtick.major.width': 1.2,
        'ytick.major.width': 1.2,
        'xtick.major.size': 5,
        'ytick.major.size': 5
    })

    score_schemes = _resolve_score_schemes(df_test, score_schemes)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for score_scheme in score_schemes:
        # Drop missing scores: histogram auto-ranging fails on NaN values.
        neg_scores = df_test.loc[df_test.label == 0, score_scheme].dropna()
        pos_scores = df_test.loc[df_test.label == 1, score_scheme].dropna()

        neg_count = len(neg_scores)
        pos_count = len(pos_scores)
        if neg_count == 0 or pos_count == 0:
            print(f"Warning: skipping '{score_scheme}': need both impostor and genuine trials "
                  f"(impostor={neg_count}, genuine={pos_count}).")
            continue
        total_count = neg_count + pos_count

        # Set a professional style with whitegrid for clarity
        sns.set_style('whitegrid', {
            'grid.linestyle': '--',
            'grid.alpha': 0.7,
            'axes.edgecolor': '0.2',
            'axes.grid': True
        })

        # Create figure with higher resolution for publication
        fig = plt.figure(figsize=figsize, dpi=300)
        ax = fig.add_subplot(111)

        # Shared bin count so both histograms are directly comparable (not capped).
        bins = max(get_bin_count(neg_scores), get_bin_count(pos_scores))

        neg_mean = neg_scores.mean()
        pos_mean = pos_scores.mean()
        separation = pos_mean - neg_mean

        # Normalized histograms
        ax.hist(neg_scores, bins=bins, alpha=0.7, density=True,
                color='#D45E5E', edgecolor='#8B0000', linewidth=1.2,
                label=f'Impostor Trials (n={neg_count:,})')
        ax.hist(pos_scores, bins=bins, alpha=0.7, density=True,
                color='#5E81D4', edgecolor='#00008B', linewidth=1.2,
                label=f'Genuine Trials (n={pos_count:,})')

        # Vertical lines for the class means
        ax.axvline(neg_mean, color='#8B0000', linestyle='dashed', linewidth=2,
                   label=f'Impostor Mean: {neg_mean:.5f}')
        ax.axvline(pos_mean, color='#00008B', linestyle='dashed', linewidth=2,
                   label=f'Genuine Mean: {pos_mean:.5f}')

        score_name = score_scheme.replace('_', ' ').title()
        ax.set_xlabel('Similarity Score', fontweight='bold')
        ax.set_ylabel('Normalized Frequency', fontweight='bold')
        ax.set_title(f'Distribution of Similarity Scores by Trial Type ({score_name})',
                     fontweight='bold', pad=15)

        legend = ax.legend(frameon=True, fancybox=False, framealpha=0.95,
                           edgecolor='0.2', loc='best')
        legend.get_frame().set_linewidth(1.0)

        # Annotation about the separation and trial counts
        info_text = f"Separation: {separation:.5f}\nTotal Trials: {total_count:,}\nImpostor: {neg_count:,} ({neg_count/total_count:.1%})\nGenuine: {pos_count:,} ({pos_count/total_count:.1%})"

        ax.annotate(info_text,
                    xy=(0.97, 0.97), xycoords='axes fraction',
                    bbox=dict(boxstyle="round,pad=0.4", fc="#F8F8F8", ec="gray",
                              alpha=0.95, linewidth=1.2),
                    ha='right', va='top', fontweight='normal',
                    fontsize=plt.rcParams['legend.fontsize'])  # Match legend font size

        for spine in ['top', 'right', 'bottom', 'left']:
            ax.spines[spine].set_linewidth(1.2)
        ax.tick_params(width=1.2, length=5, direction='out')

        fig.tight_layout()

        if save_dir:
            stem = score_scheme.replace(' ', '_')
            filepath = os.path.join(save_dir, f"{stem}_distribution.pdf")  # PDF for vector quality
            fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')

            # Also save a PNG version for easy viewing
            png_filepath = os.path.join(save_dir, f"{stem}_distribution.png")
            fig.savefig(png_filepath, dpi=300, bbox_inches='tight')

            print(f"Saved plot to {filepath} and {png_filepath}")

        if show_plot:
            plt.show()
        else:
            plt.close(fig)

In [None]:
def eval_dev(df, cohort_embeddings, trial_embeds, enrol_embeds, device, batch_size=8):
    """Re-score validation trials (optionally AS-normalized) and attach aggregate metrics.

    Args:
        df: trial DataFrame with 'enrollment_id', 'audio_path', 'label' and 'model' columns.
        cohort_embeddings: per-model cohort embeddings (looked up with ``.get(model)``), or
            None to skip normalization (norm_score then equals score).
        trial_embeds: mapping audio_path -> trial embedding tensor.
        enrol_embeds: nested mapping model -> enrollment id -> enrollment embedding tensor.
        device: torch device used for scoring.
        batch_size: number of trials scored per step.

    Returns:
        DataFrame with one row per trial ('enrollment_id', 'audio_path', 'label', 'score',
        'norm_score', 'model'), plus one column per metric returned by
        ``VerificationMetrics.compute()``, repeated on every row.
    """
    metric = VerificationMetrics()
    # _trials_eval_step reads the enrollment id from an 'enroll_id' column.
    trials = df.rename({'enrollment_id': 'enroll_id'}, axis=1)

    results = []
    for start in tqdm(range(0, len(trials), batch_size)):
        batch = trials.iloc[start:start + batch_size]
        # Bug fix: honour the requested device instead of a hard-coded 'cuda'.
        batch_dict, normalized_scores = _trials_eval_step(
            batch, cohort_embeddings, trial_embeds, enrol_embeds, device=device)
        metric.update(scores=normalized_scores, labels=torch.tensor(batch.label.tolist()))
        results.append(batch_dict)

    metrics = metric.compute()

    scores = pd.DataFrame([
        {"enrollment_id": enroll_id,
         "audio_path": audio_path,
         "label": label,
         "score": score,
         "norm_score": norm_score,
         "model": model,
         }
        for batch in results
        for enroll_id, audio_path, label, score, norm_score, model in zip(
            batch["enrollment_id"],
            batch["audio_path"],
            batch["label"],
            batch["score"],
            batch["norm_score"],
            batch["model"],
        )
    ])

    # Broadcast each aggregate metric to every row; tensors become Python scalars.
    # NOTE(review): assumes every metric value is a scalar (0-dim tensor or plain number).
    for key, value in metrics.items():
        scores[key] = value.item() if torch.is_tensor(value) else value
    return scores
    

def _trials_eval_step(batch, cohort_embeddings, trial_embeds, enrol_embeds, device):
    trial_embeddings = torch.stack([trial_embeds[path] for path in batch.audio_path]).to(device)
    enroll_embeddings = torch.stack([enrol_embeds[model][enroll_id]
                                for model, enroll_id in zip(batch.model, batch.enroll_id)]).to(device)

    # Compute raw cosine similarity scores
    raw_scores = torch.nn.functional.cosine_similarity(enroll_embeddings, trial_embeddings)
    
    if cohort_embeddings is not None:
        normalized_scores = []
        for i, (enroll_emb, test_emb, model) in enumerate(zip(enroll_embeddings, trial_embeddings, batch.model)):
            
            # Get model-specific cohort embeddings
            model_cohort = cohort_embeddings.get(model)#.to(device)
            assert model_cohort is not None, f"No cohort embeddings found for model {model}"
            if isinstance(model_cohort, dict):
                model_cohort = torch.stack(list(model_cohort.values()))
            if model_cohort.ndim != 2:
                raise ValueError(f"Invalid cohort embeddings shape for model {model}: {model_cohort.shape}")
            
            # Apply AS-Norm
            norm_score = AS_norm(score=raw_scores[i],
                                 enroll_embedding=enroll_emb,
                                 test_embedding=test_emb, 
                                 cohort_embeddings=model_cohort,
                                 **{'topk': 10000, 'min_cohort_size': 3000})
            normalized_scores.append(norm_score)
        
        # Convert back to tensor
        normalized_scores = torch.tensor(normalized_scores, device=raw_scores.device)
    
    else:
        normalized_scores = raw_scores.clone()
    
    batch_dict = {
        "enrollment_id": batch.enroll_id,
        "audio_path": batch.audio_path,
        "label": batch.label,
        "score": raw_scores.detach().cpu().tolist(),
        "norm_score": normalized_scores.detach().cpu().tolist(),
        "model": batch.model,
    }

    return batch_dict, normalized_scores

In [None]:
def print_analysis_header(anon_model: str):
    """Display an HTML banner announcing which anonymization system is being analyzed.

    Args:
        anon_model: anonymization system name. It is HTML-escaped so names containing
            '<', '>' or '&' are shown literally instead of being parsed as markup.
    """
    import html  # only this helper needs it

    header = f"""
    <div style="background-color: #f0f2f6; padding: 10px; border-radius: 5px; margin: 10px 0;">
        <h3 style="color: #2c3e50; margin: 0; text-align: center;">
            Analyzing Anonymization System: {html.escape(str(anon_model))}
        </h3>
    </div>
    """
    display(HTML(header))

def print_analysis_results(speaker_eer, gender_eer):
    """Show an HTML card with the mean ± standard deviation of speaker and gender EERs."""
    speaker_values = list(speaker_eer.values())
    gender_values = list(gender_eer.values())
    spk_mean, spk_std = np.mean(speaker_values), np.std(speaker_values)
    gen_mean, gen_std = np.mean(gender_values), np.std(gender_values)

    summary_html = f"""
    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin: 10px 0;">
        <h4 style="color: #2c3e50; margin: 0 0 10px 0;">Results Summary:</h4>
        <p style="margin: 5px 0; color: #34495e;">
            <b>Speaker Results:</b><br>
            Mean EER: {spk_mean:.4f} ± {spk_std:.4f}
        </p>
        <p style="margin: 5px 0; color: #34495e;">
            <b>Gender Results:</b><br>
            Mean EER: {gen_mean:.4f} ± {gen_std:.4f}
        </p>
    </div>
    """
    display(HTML(summary_html))

## Analysis

In [None]:
device = 'cuda'
results = {}
data_split = 'test' if EVAL_TEST else 'dev'

for experiment in experiments:

    df = process_experiment(experiment, eval_test=EVAL_TEST, model_path_dir=MODELS_PATH)

    #################### Normalize scores for validation set ########################
    # If norm_score is identical to score, no normalization has been applied yet.
    # Recompute the scores once, overwrite the scores CSV, and reload it through
    # process_experiment so the metadata merge is exactly the same as above.
    if not EVAL_TEST and (df.score == df.norm_score).all():
        valid_enrol_embeds, valid_embeds, cohort_embeds = load_valid_embeddings(
            model_exp=MODELS_PATH / experiment, eval_test=EVAL_TEST, device=device)

        if valid_enrol_embeds is None or valid_embeds is None or cohort_embeds is None:
            # Without cohort embeddings the normalized scores would just copy the raw ones.
            print(f"⚠️ Skipping score normalization for {experiment}: embeddings not available")
        else:
            normalized_df = eval_dev(df, cohort_embeds, valid_embeds, valid_enrol_embeds, device=device)
            scores_path = get_df_path(experiment, eval_test=EVAL_TEST, model_path_dir=MODELS_PATH)
            normalized_df.to_csv(scores_path, index=False)
            df = process_experiment(experiment, eval_test=EVAL_TEST, model_path_dir=MODELS_PATH)
    ##############################################################################

    results[experiment] = df
    tot_speaker_metrics = {}
    tot_gender_metrics = {}
    results_dir = f'results/{experiment}/{data_split}'
    os.makedirs(results_dir, exist_ok=True)

    for anon_model in df['model'].unique():
        df_model = df[df['model'] == anon_model]
        model_dir = f'{results_dir}/{anon_model}'
        # Create it explicitly: the writers below must not rely on plot_score_distribution doing it.
        os.makedirs(model_dir, exist_ok=True)

        print_analysis_header(anon_model)

        # Plot scores histogram
        plot_score_distribution(df_model, figsize=(12, 8), save_dir=model_dir, show_plot=True)

        fig_speaker, speaker_metrics, speaker_eer = analyze_speaker_results(
            df_model, output_path=f'{model_dir}/speaker_DET_{anon_model}.png', max_speakers=None)

        # Analyze normalized scores only when they differ from the raw ones for this model.
        # The last column wins below, so the summary and per-gender CSV report normalized
        # scores whenever they are available.
        scores_cols = ['score'] if (df_model.score == df_model.norm_score).all() else ['score', 'norm_score']
        for scores_col in scores_cols:
            norm = 'Normalized' if scores_col == 'norm_score' else 'Raw'
            fig_gender, gender_metrics, gender_eer = analyze_gender_results(
                df_model, output_path=f'{model_dir}/gender_DET_{anon_model}_{norm}Scores.png', scores_col=scores_col)

        print_analysis_results(speaker_eer, gender_eer)

        # Save results as csvs
        tot_speaker_metrics[anon_model] = pd.DataFrame.from_records(speaker_metrics).astype(float)
        tot_gender_metrics[anon_model] = pd.DataFrame.from_records(gender_metrics).astype(float)
        tot_speaker_metrics[anon_model].to_csv(f'{model_dir}/per_speaker_results_{anon_model}.csv')
        tot_gender_metrics[anon_model].to_csv(f'{model_dir}/per_gender_results_{anon_model}.csv')

        # Create radar plot
        create_radar_plot(speaker_eer, output_path=f'{model_dir}/speaker_eer_{anon_model}.png')

In [None]:
# Compare results if in EVAL_ALL mode
if EVAL_MODE == "EVAL_ALL" and len(results) > 1:
    print("\n===== Comparison of Results =====")
    for exp, exp_df in results.items():
        # 'eer' is only present when eval_dev attached aggregate metrics (same value on every
        # row). Positional access avoids depending on the index labels; empty frames report N/A.
        if 'eer' in exp_df.columns and not exp_df.empty:
            eer = exp_df['eer'].iloc[0]
        else:
            eer = "N/A"
        print(f"{exp}: EER={eer}")
