In [None]:
import os 
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Literal
from IPython.display import display, HTML

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import torch
from speechbrain.utils.metric_stats import EER
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy

sys.path.append('/home/ahmad/adversarial-robustness-for-sr')
from src.modules.metrics.metrics import VerificationMetrics

## Load Trial Scores and Trial List

In [None]:
# model_filename = "vpc_amm_29March" # vpc_amm_29March last_hidden_exp_comprehensive 2025-02-02_22-12-58, 2025-02-09_01-18-15 2025-02-12_14-13-10 2025-02-15_22-03-15 2025-02-21_17-50-04
# MODELS_PATH = '/home.local/aloradi/logs/train/runs'
# MODELS_PATH = "/home/ahmad/adversarial-robustness-for-sr/logs/train/runs"

# Experiment (run directory) to analyse. Other runs used with this notebook:
#   "vpc_amm_cyclic-T10-2-max_dur10-bs32", "vpc_norm_enh-available_models-max_dur10-64",
#   "vpc_ce-available_models-max_dur10-64", "vpc_amm-available_models-max_dur10-64",
#   "vpc_amm-available_models-max_dur10-with_contrastive_loss-64"
model_filename = "sv_baseline-voxceleb-max_dur3.0-bs32"
MODELS_PATH = "/dataHDD/ahmad/hpc_results/second_run"

model_exp = Path(MODELS_PATH) / model_filename
eval_test = True  # True: test-split scores; False: best-validation scores

# Alternative validation file name seen in older runs: 'best_valid_scores.csv'
scores_csv_file = 'test_scores.csv' if eval_test else 'valid_best_scores.csv'

artifacts_dir = 'voxceleb_artifacts'  # alternative: 'vpc2025_artifacts'

data_df_file = 'veri_test.csv'  # VPC runs used: 'test.csv' if eval_test else 'dev.csv'
dirname = 'test' if eval_test else 'valid'

# Exactly one scores file is expected under a `<dirname>*` directory at any depth.
score_files = sorted(model_exp.rglob(f'{dirname}*/{scores_csv_file}'))
assert len(score_files) == 1, (
    f'Expected exactly one {scores_csv_file} matching {model_exp}/**/{dirname}*/, '
    f'found {len(score_files)}: {[str(p) for p in score_files]}'
)
scores_path = score_files[0]

df = pd.read_csv(scores_path)
df_test = pd.read_csv(model_exp / artifacts_dir / data_df_file, sep="|")

# Keep only the part of each audio path after the last occurrence of
# VOXCELEB_ROOT (paths without it are left unchanged).
VOXCELEB_ROOT = 'voxceleb/voxceleb1_2/'
for path_col in ('enroll_path', 'test_path'):
    df[path_col] = df[path_col].str.rsplit(VOXCELEB_ROOT, n=1).str[-1]

## Plotting and Reporting Helpers

In [None]:
def get_bin_count(data, default_bins=30):
    """Return a histogram bin count chosen by the Freedman-Diaconis rule.

    The rule uses a bin width of ``2 * IQR * n**(-1/3)``. The bin count is
    ``ceil((max - min) / width)``.

    Parameters
    ----------
    data : array-like
        1-D numeric sample (NumPy array, pandas Series, list, ...).
        Non-finite values (NaN/inf) are ignored.
    default_bins : int, optional
        Value returned when the rule is undefined: there are no finite values,
        or the IQR is zero (e.g. constant data). Default 30.

    Returns
    -------
    int
        Number of bins.
    """
    values = np.asarray(data, dtype=float).ravel()
    values = values[np.isfinite(values)]
    if values.size == 0:
        # np.percentile cannot handle an empty sample.
        return default_bins

    q25, q75 = np.percentile(values, [25, 75])
    bin_width = 2 * (q75 - q25) * values.size ** (-1 / 3)
    if bin_width <= 0:
        return default_bins
    return int(np.ceil((values.max() - values.min()) / bin_width))


# Publication typography used for every score-distribution figure.
_PUBLICATION_RC = {
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'Palatino', 'Computer Modern Roman'],
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 18,
    'text.usetex': False,  # Set to True if you have LaTeX installed
    'axes.linewidth': 1.2,
    'xtick.major.width': 1.2,
    'ytick.major.width': 1.2,
    'xtick.major.size': 5,
    'ytick.major.size': 5,
}

# Columns that hold trial labels rather than scores.
_LABEL_COLUMNS = ('label', 'trial_label')


def _publication_style():
    """Return rcParams that combine seaborn's whitegrid style with publication fonts.

    The seaborn style dict includes ``font.family``, so it is built first.
    The serif typography is layered on top so that it takes precedence.
    """
    style = dict(sns.axes_style('whitegrid', {
        'grid.linestyle': '--',
        'axes.edgecolor': '0.2',
        'axes.grid': True,
    }))
    style.update(_PUBLICATION_RC)
    # seaborn only accepts its own style keys, so grid.alpha is set directly.
    style['grid.alpha'] = 0.7
    return style


def _resolve_score_columns(df_test, score_schemes):
    """Return the score columns to plot, validated against ``df_test``'s columns."""
    if score_schemes is None:
        found = [col for col in ('score', 'norm_score') if col in df_test.columns]
        if found:
            return found
        # Fall back to numeric, non-label columns (skips e.g. path/ID string columns).
        candidates = [
            col for col in df_test.columns
            if col not in _LABEL_COLUMNS and pd.api.types.is_numeric_dtype(df_test[col])
        ]
        if not candidates:
            raise ValueError("No score columns found in the dataframe.")
        candidates = candidates[:2]  # Take at most 2 columns
        print(f"Using detected score columns: {candidates}")
        return candidates

    valid = [col for col in score_schemes if col in df_test.columns]
    if not valid:
        raise ValueError(f"None of the provided score schemes {score_schemes} exist in the dataframe.")
    if len(valid) < len(score_schemes):
        print(f"Warning: Only using valid columns: {valid}. Skipping non-existent columns.")
    return valid


def _plot_one_distribution(df_test, score_scheme, figsize, save_dir, show_plot):
    """Draw one impostor-vs-genuine histogram figure; optionally save and/or show it."""
    # NaN scores would make np.histogram's automatic range non-finite, so drop them.
    neg_scores = df_test.loc[df_test['trial_label'] == 0, score_scheme].dropna()
    pos_scores = df_test.loc[df_test['trial_label'] == 1, score_scheme].dropna()
    if neg_scores.empty or pos_scores.empty:
        print(f"Warning: '{score_scheme}' has no impostor or no genuine scores; skipping plot.")
        return

    neg_count, pos_count = len(neg_scores), len(pos_scores)
    total_count = neg_count + pos_count
    neg_mean, pos_mean = neg_scores.mean(), pos_scores.mean()
    separation = pos_mean - neg_mean

    # Both histograms use the larger of the two Freedman-Diaconis bin counts.
    bins = max(get_bin_count(neg_scores), get_bin_count(pos_scores))

    fig, ax = plt.subplots(figsize=figsize, dpi=300)

    # Normalised (density) histograms so classes of different size are comparable.
    ax.hist(neg_scores, bins=bins, alpha=0.7, density=True,
            color='#D45E5E', edgecolor='#8B0000', linewidth=1.2,
            label=f'Impostor Trials (n={neg_count:,})')
    ax.hist(pos_scores, bins=bins, alpha=0.7, density=True,
            color='#5E81D4', edgecolor='#00008B', linewidth=1.2,
            label=f'Genuine Trials (n={pos_count:,})')

    ax.axvline(neg_mean, color='#8B0000', linestyle='dashed', linewidth=2,
               label=f'Impostor Mean: {neg_mean:.3f}')
    ax.axvline(pos_mean, color='#00008B', linestyle='dashed', linewidth=2,
               label=f'Genuine Mean: {pos_mean:.3f}')

    score_name = score_scheme.replace('_', ' ').title()
    ax.set_xlabel('Similarity Score', fontweight='bold')
    ax.set_ylabel('Normalized Frequency', fontweight='bold')
    ax.set_title(f'Distribution of Similarity Scores by Trial Type ({score_name})',
                 fontweight='bold', pad=15)

    legend = ax.legend(frameon=True, fancybox=False, framealpha=0.95,
                       edgecolor='0.2', loc='best')
    legend.get_frame().set_linewidth(1.0)

    info_text = (
        f"Separation: {separation:.3f}\n"
        f"Total Trials: {total_count:,}\n"
        f"Impostor: {neg_count:,} ({neg_count/total_count:.1%})\n"
        f"Genuine: {pos_count:,} ({pos_count/total_count:.1%})"
    )
    ax.annotate(info_text,
                xy=(0.97, 0.97), xycoords='axes fraction',
                bbox=dict(boxstyle="round,pad=0.4", fc="#F8F8F8", ec="gray",
                          alpha=0.95, linewidth=1.2),
                ha='right', va='top', fontweight='normal',
                fontsize=plt.rcParams['legend.fontsize'])  # Match legend font size

    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_linewidth(1.2)
    ax.tick_params(width=1.2, length=5, direction='out')

    fig.tight_layout()

    if save_dir:
        stem = score_scheme.replace(' ', '_')
        filepath = os.path.join(save_dir, f"{stem}_distribution.pdf")  # vector for papers
        png_filepath = os.path.join(save_dir, f"{stem}_distribution.png")  # quick viewing
        fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')
        fig.savefig(png_filepath, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {filepath} and {png_filepath}")

    if show_plot:
        plt.show()
    else:
        plt.close(fig)


def plot_score_distribution(df_test, score_schemes=None, figsize=(10, 6), save_dir=None, show_plot=True):
    """
    Plot the distribution of similarity scores for genuine and impostor trials.

    One figure is produced per score column. Styling (seaborn whitegrid plus
    serif publication fonts) is applied through ``plt.rc_context``. It is
    therefore scoped to these figures and does not change global rcParams.

    Parameters:
    -----------
    df_test : pandas.DataFrame
        DataFrame containing score columns and a ``trial_label`` column
        (0 = impostor, 1 = genuine).
    score_schemes : list, optional
        Score column names to plot. If None, uses whichever of ``['score', 'norm_score']``
        exist. Otherwise it uses up to two numeric, non-label columns.
    figsize : tuple, optional
        Figure size as (width, height) in inches (default: (10, 6))
    save_dir : str, optional
        Directory in which to save PDF and PNG copies of each plot; created if missing.
        If None, plots are not saved.
    show_plot : bool, optional
        Whether to display each plot interactively (default: True). If False,
        each figure is closed after saving.

    Raises:
    -------
    ValueError
        If ``trial_label`` is missing, or no usable score column is found.
    """
    if 'trial_label' not in df_test.columns:
        raise ValueError("df_test must contain a 'trial_label' column (0 = impostor, 1 = genuine).")

    score_schemes = _resolve_score_columns(df_test, score_schemes)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with plt.rc_context(_publication_style()):
        for score_scheme in score_schemes:
            _plot_one_distribution(df_test, score_scheme, figsize, save_dir, show_plot)

In [None]:
def print_analysis_header(model_filename: str):
    """Render a centred HTML banner that names the experiment being analysed.

    Args:
        model_filename: Experiment (run directory) name shown in the banner.
    """
    display(HTML(f"""
    <div style="background-color: #f0f2f6; padding: 10px; border-radius: 5px; margin: 10px 0;">
        <h3 style="color: #2c3e50; margin: 0; text-align: center;">
            Experiment: {model_filename}
        </h3>
    </div>
    """))

# Default table rows as (label shown in the table, key looked up in `metrics`).
DEFAULT_RESULT_METRICS = (('EER', 'eer'), ('MinDCF', 'minDCF'), ('F_score', 'f_score'))

_RESULTS_CELL_STYLE = 'padding: 6px; text-align: center; border: 1px solid #dee2e6;'


def build_results_html(metrics, caption=None, metric_labels=None):
    """Build the HTML summary table that ``print_analysis_results`` displays.

    Args:
        metrics: Mapping from metric key to a value supporting ``:.3f`` formatting
            (e.g. a float or a 0-d tensor).
        caption: Optional caption, rendered HTML-escaped above the table.
        metric_labels: Sequence of ``(display_label, metrics_key)`` pairs giving the
            table rows in order. Defaults to ``DEFAULT_RESULT_METRICS``.

    Returns:
        The HTML markup as a string.

    Raises:
        KeyError: If a requested metric key is missing from ``metrics``.
    """
    from html import escape

    if metric_labels is None:
        metric_labels = DEFAULT_RESULT_METRICS

    cell = _RESULTS_CELL_STYLE
    header_cell = f'{cell} width: 50%; color: white;'
    caption_html = (
        f'<p style="text-align: center; font-style: italic; margin: 5px 0;">{escape(str(caption))}</p>'
        if caption else ''
    )
    rows = '\n'.join(
        f'<tr><td style="{cell}"><b>{escape(str(label))}</b></td>'
        f'<td style="{cell}">{metrics[key]:.3f}</td></tr>'
        for label, key in metric_labels
    )
    return f"""
    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin: 10px 0;">
        <h4 style="color: #2c3e50; margin: 0 0 10px 0;">Results Summary:</h4>
        {caption_html}
        <div style="display: flex; justify-content: center; width: 100%;">
            <table style="width: 50%; border-collapse: collapse; margin-top: 8px;">
                <tr style="background-color: #34495e;">
                    <th style="{header_cell}"><b>Metric</b></th>
                    <th style="{header_cell}"><b>Value</b></th>
                </tr>
                {rows}
            </table>
        </div>
    </div>
    """


def print_analysis_results(metrics, caption=None, metric_labels=None):
    """Display verification metrics as a rendered HTML table in the notebook.

    Args:
        metrics: Mapping from metric key to value. The default rows read
            ``'eer'``, ``'minDCF'`` and ``'f_score'``.
        caption: Optional caption shown above the table.
        metric_labels: Optional ``(display_label, metrics_key)`` pairs overriding the
            default rows (see ``build_results_html``).
    """
    display(HTML(build_results_html(metrics, caption=caption, metric_labels=metric_labels)))

def compute_metrics(df, scores_col='norm_score', label_col='trial_label', plot=True):
    """Compute speaker-verification metrics for one score column of a trial table.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table containing a score column and a label column.
    scores_col : str
        Column holding the trial scores (default ``'norm_score'``).
    label_col : str
        Column holding the trial labels (default ``'trial_label'``).
    plot : bool
        If True (default), call ``VerificationMetrics.plot_curves()`` and return
        its result as ``curves``. If False, ``curves`` is None.

    Returns
    -------
    tuple
        ``(metrics, curves)``. ``metrics`` is the value of
        ``VerificationMetrics.compute()``; ``print_analysis_results`` reads the
        keys 'eer', 'minDCF' and 'f_score' from it. ``curves`` is whatever
        ``plot_curves()`` returns. NOTE(review): that type is not visible here; confirm.
    """
    metric = VerificationMetrics()
    metric.reset()
    metric.update(
        torch.from_numpy(df[scores_col].to_numpy()),
        torch.from_numpy(df[label_col].to_numpy()),
    )
    metrics = metric.compute()
    # Keep plot_curves' return value so callers can unpack `metrics, curves`.
    curves = metric.plot_curves() if plot else None

    return metrics, curves

## Analysis: Score Distributions and Verification Metrics

In [None]:
tot_speaker_metrics = {}
tot_gender_metrics = {}
results_dir = f'results/{model_filename}'   # model_exp (maybe)
os.makedirs(results_dir, exist_ok=True)

print_analysis_header(model_filename)

# Plot score histograms. NOTE(review): the original target `{results_dir}/{header}`
# referenced an undefined name; a fixed sub-directory is used instead — confirm intent.
plot_score_distribution(df, figsize=(12, 8),
                        save_dir=os.path.join(results_dir, 'score_distributions'),
                        show_plot=True)

# Evaluate normalised scores as well, unless they are absent or identical to the raw scores.
has_distinct_norm = 'norm_score' in df.columns and not (df['score'] == df['norm_score']).all()
scores_cols = ['norm_score', 'score'] if has_distinct_norm else ['score']
for scores_col in scores_cols:
    metrics, curves = compute_metrics(df, scores_col=scores_col)
    caption = 'RAW Scores' if scores_col == 'score' else 'Normalized Scores'
    print_analysis_results(metrics, caption=f'Table: {caption}')