In [1]:
import os
import torch
import numpy as np
import torchaudio
from tqdm import tqdm
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import pandas as pd


In [2]:
# Configuration: every value a reader may want to tune lives in this cell.
model_name = "facebook/wav2vec2-large-xlsr-53"
sample_rate = 16000  # Hz; load_audio resamples every file to this rate

# Dataset locations (Kaggle input mount)
trial_file = "/kaggle/input/speech-assn-2-task-1-data/veri_test2.txt"
wav_root = "/kaggle/input/speech-assn-2-task-1-data/vox1-20250326T135514Z-001/vox1/vox1_test_wav/wav"

# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Load model: the pretrained feature extractor and wav2vec2 encoder for `model_name`.
# The encoder is moved to `device` and switched to eval mode; as the cell's last
# expression, model.eval() also displays the module tree below.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
encoder = Wav2Vec2Model.from_pretrained(model_name)
model = encoder.to(device)
model.eval()

preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2LayerNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (activation): GELUActivation()
      )
      (1-4): 4 x Wav2Vec2LayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2LayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=1024, bias=True)
    (dropout)

In [4]:
def load_audio(path):
    """Read an audio file as a mono waveform at ``sample_rate`` Hz.

    The full signal is returned: no truncation or padding is applied.

    Args:
        path: audio file path understood by ``torchaudio.load``.

    Returns:
        The waveform tensor with singleton dimensions squeezed away, or
        ``None`` if anything goes wrong (the error is printed, not raised).
    """
    try:
        signal, native_rate = torchaudio.load(path)

        # Bring the signal to the notebook-wide sampling rate when it differs
        if native_rate != sample_rate:
            to_target_rate = torchaudio.transforms.Resample(orig_freq=native_rate, new_freq=sample_rate)
            signal = to_target_rate(signal)

        # Average channels into one when the file has more than one (dim 0 = channels)
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)

        return signal.squeeze()
    except Exception as err:
        # Best-effort loading: report and let the caller skip this file
        print(f"Error loading {path}: {err}")
        return None

In [5]:
def get_embedding(audio_path, layer=None):
    """Extract a speaker embedding by mean-pooling wav2vec2 frame features.

    The whole (untruncated, unpadded) utterance is encoded in one forward pass
    and the frame-level features are averaged over time.

    Args:
        audio_path: path to an audio file, loaded with ``load_audio``.
        layer: optional index into the model's ``hidden_states`` tuple (per the
            transformers docs, index 0 is the embedding output and index i the
            output of transformer layer i). ``None`` (default) pools
            ``last_hidden_state``, which is the original behaviour.

    Returns:
        1-D numpy array with one value per hidden unit, or ``None`` when the
        audio could not be loaded.
    """
    waveform = load_audio(audio_path)
    if waveform is None:
        return None

    # The feature extractor works on numpy arrays (tensors are converted to numpy
    # internally), so hand it a host array directly. Moving the waveform to the
    # GPU first only forced a device->host copy inside the extractor.
    # A single utterance needs no padding.
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sample_rate,
        return_tensors="pt",
    ).input_values.to(device)

    with torch.no_grad():
        # Only materialise all hidden states when an intermediate layer is requested
        outputs = model(inputs, output_hidden_states=layer is not None)

    frames = outputs.last_hidden_state if layer is None else outputs.hidden_states[layer]
    # Mean over the time axis (dim 1), then take the single batch row
    return frames.mean(dim=1)[0].cpu().numpy()

In [6]:
def read_trials(trial_path, root=None):
    """Read verification trial pairs from a VoxCeleb1-style trial list.

    Each usable line holds three whitespace-separated fields:
    ``<label> <relative_path_1> <relative_path_2>``. Lines with any other
    number of fields (including blank lines) are skipped.

    Args:
        trial_path: path to the trial list file.
        root: directory the relative paths are joined onto. Defaults to the
            notebook-level ``wav_root`` (looked up at call time), so existing
            calls behave exactly as before.

    Returns:
        list of ``(path1, path2, label)`` tuples with ``label`` as ``int``.

    Raises:
        ValueError: if a three-field line's label is not an integer.
    """
    if root is None:
        root = wav_root

    trials = []
    with open(trial_path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                continue
            label_text, rel1, rel2 = parts
            trials.append((os.path.join(root, rel1), os.path.join(root, rel2), int(label_text)))
    return trials


In [7]:
def compute_metrics(labels, scores):
    """Compute verification metrics from trial labels and similarity scores.

    Args:
        labels: binary ground-truth labels, one per trial; 1 is treated as the
            positive class (``roc_curve``'s default ``pos_label`` for {0, 1}).
        scores: similarity scores aligned with ``labels``; higher means "more
            likely positive".

    Returns:
        dict with keys
          'eer': equal error rate in [0, 1], i.e. the FPR where FPR == FNR on the
                 linearly interpolated ROC curve;
          'tar_at_1far': the highest TPR among ROC operating points whose FPR
                 does not exceed 1%;
          'identification_accuracy': always ``None`` here; the caller fills it in.
    """
    fpr, tpr, _ = roc_curve(labels, scores)

    # EER: find x with 1 - x == TPR(x), i.e. FNR == FPR. np.interp performs the same
    # piecewise-linear interpolation without rebuilding an interpolator on every
    # brentq evaluation.
    eer = brentq(lambda x: 1. - x - np.interp(x, fpr, tpr), 0., 1.)

    # TAR@1%FAR: fpr and tpr are both non-decreasing along the ROC curve, so the last
    # point with FPR <= 1% carries the best TPR within the FAR budget. A nearest-point
    # lookup could select a point above 1% FAR, or the TPR=0 point on ties at FPR=0.
    far_target = 0.01
    far_idx = np.searchsorted(fpr, far_target, side='right') - 1  # fpr[0] == 0, so >= 0
    tar_at_1far = tpr[far_idx]

    # Speaker identification accuracy needs per-speaker embeddings; it is computed
    # separately by calculate_identification_accuracy.
    return {
        'eer': eer,
        'tar_at_1far': tar_at_1far,
        'identification_accuracy': None  # Will be calculated separately
    }

def calculate_identification_accuracy(embeddings_dict, chunk_size=1024):
    """Leave-one-out nearest-neighbour speaker identification accuracy.

    Each embedding is matched, by cosine similarity, against every *other*
    embedding. The match counts as correct when the most similar one belongs to
    the same speaker.

    Args:
        embeddings_dict: mapping ``path -> 1-D embedding``. The speaker ID is
            taken as the third-from-last ``/``-separated path component
            (assumes ``.../<speaker>/<video>/<file>`` layout; confirm for other data).
        chunk_size: number of query rows scored per matrix product, which bounds
            peak memory to roughly ``chunk_size * len(embeddings_dict)`` floats.

    Returns:
        Fraction of embeddings whose nearest neighbour shares their speaker.
        Returns ``0`` for an empty dict and ``0.0`` for a single embedding
        (nothing to match against, so it counts as a miss).
    """
    paths = list(embeddings_dict)
    n = len(paths)
    if n == 0:
        return 0
    if n == 1:
        return 0.0

    speakers = np.array([p.split('/')[-3] for p in paths])
    emb = np.stack([np.asarray(embeddings_dict[p], dtype=np.float64) for p in paths])

    # L2-normalise once so a dot product is a cosine similarity. Zero-norm vectors
    # are left as zeros (similarity 0) instead of producing NaN.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.where(norms == 0, 1.0, norms)

    chunk_size = max(1, int(chunk_size))
    correct = 0
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        sims = unit[start:stop] @ unit.T  # (stop - start, n) cosine similarities
        rows = np.arange(stop - start)
        # Exclude each query's own entry by index. Value equality would also
        # drop genuine duplicate embeddings coming from other files.
        sims[rows, start + rows] = -np.inf
        nearest = sims.argmax(axis=1)
        correct += int(np.count_nonzero(speakers[nearest] == speakers[start:stop]))

    return correct / n

In [8]:
# Read the trial list once; every later cell iterates over `trials`.
print("Loading trial pairs...")
trials = read_trials(trial_file)
num_loaded = len(trials)
print(f"Loaded {num_loaded} trial pairs")

Loading trial pairs...
Loaded 37611 trial pairs


In [9]:
# Score every trial pair by cosine similarity of utterance embeddings, then report
# EER, TAR@1%FAR and identification accuracy, and save per-trial scores to CSV.
# Each file is embedded at most once. Successes are cached in embeddings_dict (also
# reused for identification accuracy). Failures are cached in failed_paths, so a
# broken file is not re-decoded, with its error re-printed, on every trial.
scores = []
labels = []
enroll_paths = []  # paths of the trials actually scored, aligned with scores/labels
test_paths = []
embeddings_dict = {}  # To store all embeddings for identification accuracy
failed_paths = set()


def _cached_embedding(path):
    """Return the embedding for ``path``, computing and caching it on first use; ``None`` if loading failed."""
    if path in embeddings_dict:
        return embeddings_dict[path]
    if path in failed_paths:
        return None
    emb = get_embedding(path)
    if emb is None:
        failed_paths.add(path)
    else:
        embeddings_dict[path] = emb
    return emb


for path1, path2, label in tqdm(trials, desc="Processing trials"):
    # Skip if files don't exist
    if not os.path.exists(path1) or not os.path.exists(path2):
        continue

    emb1 = _cached_embedding(path1)
    if emb1 is None:
        continue
    emb2 = _cached_embedding(path2)
    if emb2 is None:
        continue

    # Calculate cosine similarity
    similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
    scores.append(similarity)
    labels.append(label)
    enroll_paths.append(path1)
    test_paths.append(path2)

# Calculate verification metrics
if len(scores) > 0:
    metrics = compute_metrics(labels, scores)

    # Calculate identification accuracy
    metrics['identification_accuracy'] = calculate_identification_accuracy(embeddings_dict)

    print(f"\nEvaluation Results:")
    print(f"Processed {len(scores)} valid trial pairs")
    print(f"Equal Error Rate (EER): {metrics['eer']*100:.2f}%")
    print(f"True Acceptance Rate @ 1% FAR: {metrics['tar_at_1far']*100:.2f}%")
    print(f"Speaker Identification Accuracy: {metrics['identification_accuracy']*100:.2f}%")

    # Save results with all metrics. The paths come from the scored trials
    # themselves: slicing `trials` by len(scores) misaligns rows as soon as any
    # trial is skipped. Scalar metrics are broadcast to every row.
    results = pd.DataFrame({
        'enroll_path': enroll_paths,
        'test_path': test_paths,
        'label': labels,
        'score': scores,
        'eer': metrics['eer'],
        'tar_at_1far': metrics['tar_at_1far'],
        'identification_accuracy': metrics['identification_accuracy']
    })
    results.to_csv("verification_results_with_metrics.csv", index=False)
    print("Results with all metrics saved to verification_results_with_metrics.csv")
else:
    print("No valid trial pairs processed")

Processing trials: 100%|██████████| 37611/37611 [07:16<00:00, 86.18it/s]  



Evaluation Results:
Processed 37611 valid trial pairs
Equal Error Rate (EER): 47.93%
True Acceptance Rate @ 1% FAR: 2.46%
Speaker Identification Accuracy: 34.56%
Results with all metrics saved to verification_results_with_metrics.csv
