### Multi-Speaker Scenario Dataset Creation and Evaluation
The second part of the assignment shifts focus to multi-speaker scenarios, involving the creation of a custom dataset and the evaluation of both speaker separation and identification tasks using advanced models.

In [None]:
pip install speechbrain


In [None]:
pip install pesq


In [None]:
pip install mir-eval

In [73]:
import os
import glob
import random
import torch
import torchaudio
import numpy as np
from speechbrain.pretrained import SepformerSeparation
from mir_eval.separation import bss_eval_sources
from pesq import pesq
from tqdm import tqdm

In [74]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [75]:
VOXCELEB2_TXT_BASE_DIR = "/kaggle/input/vox2celebtext/txt"
VOXCELEB2_AUDIO_BASE_DIR = "/kaggle/input/voxcelebdataset-su/vox2_test_aac/aac"

In [9]:
TARGET_SAMPLE_RATE = 8000

In [76]:
TARGET_SAMPLE_RATE_16 = 16000

In [None]:
def load_voxceleb_metadata(txt_base_dir, audio_base_dir):
    
    speaker_dict = {}
    pattern = os.path.join(txt_base_dir, "**", "*.txt")
    txt_files = glob.glob(pattern, recursive=True)
    
    for txt_file in txt_files:
        rel_path = os.path.relpath(txt_file, txt_base_dir)
        parts = rel_path.split(os.sep)
        if len(parts) < 3:
            continue  
        
        speaker_id = parts[0]
        recording_id = parts[1]
        file_name = parts[2]
        audio_file = os.path.join(audio_base_dir, speaker_id, recording_id, file_name.replace('.txt', '.m4a'))
        if not os.path.exists(audio_file):
            continue
        
        if speaker_id not in speaker_dict:
            speaker_dict[speaker_id] = []
        speaker_dict[speaker_id].append(audio_file)
    
    return speaker_dict

In [78]:
def resample_audio(audio, orig_sr, target_sr):
    if orig_sr != target_sr:
        transform = torchaudio.transforms.Resample(orig_sr, target_sr)
        audio = transform(audio)
    return audio

In [None]:
def load_audio_file(file_path, target_sr=TARGET_SAMPLE_RATE):
    
    audio, sr = torchaudio.load(file_path)
    audio = resample_audio(audio, sr, target_sr)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    return audio


In [80]:
def pad_or_truncate(aud1, aud2):
    len1 = aud1.shape[1]
    len2 = aud2.shape[1]
    if len1 < len2:
        pad = torch.zeros(1, len2 - len1)
        aud1 = torch.cat([aud1, pad], dim=1)
    elif len2 < len1:
        pad = torch.zeros(1, len1 - len2)
        aud2 = torch.cat([aud2, pad], dim=1)
    return aud1, aud2

In [None]:
def mix_utterances(utt1, utt2, snr_dB=0):
    
    utt1, utt2 = pad_or_truncate(utt1, utt2)
    power1 = utt1.pow(2).mean()
    power2 = utt2.pow(2).mean()
    scale = torch.sqrt(power1 / (10**(snr_dB/10) * power2 + 1e-8))
    utt2_scaled = utt2 * scale
    mixture = utt1 + utt2_scaled
    return mixture, utt1, utt2_scaled

In [None]:
def evaluate_metrics(ref_sources, est_sources, sample_rate):
    SDR, SIR, SAR, _ = bss_eval_sources(ref_sources, est_sources)
    pesq_mode = "nb" if sample_rate == 8000 else "wb"

    pesq_scores = []
    for i in range(ref_sources.shape[0]):
        score = pesq(sample_rate, ref_sources[i], est_sources[i], mode=pesq_mode)
        pesq_scores.append(score)

    return SDR, SIR, SAR, pesq_scores

In [None]:
def create_scenario(speaker_dict, speaker_ids, num_mixtures=500):
    
    mixtures = []
    available_speakers = [s for s in speaker_ids if s in speaker_dict and len(speaker_dict[s]) > 0]
    if len(available_speakers) < 2:
        raise ValueError("Need at least two speakers to create mixtures.")

    for _ in range(num_mixtures):
        spk1, spk2 = random.sample(available_speakers, 2)
        utt1_path = random.choice(speaker_dict[spk1])
        utt2_path = random.choice(speaker_dict[spk2])
        utt1 = load_audio_file(utt1_path)
        utt2 = load_audio_file(utt2_path)
        mixture, ref1, ref2 = mix_utterances(utt1, utt2, snr_dB=0)

        mixtures.append((mixture, [ref1, ref2], [spk1, spk2]))

    return mixtures


In [None]:
from itertools import permutations

def best_permutation_match(ref_sources, est_sources, sample_rate):
    best_sdr = -float('inf')
    best_perm = None
    
    for perm in permutations([0, 1]):  
        perm_est_sources = est_sources[list(perm), :]
        SDR, _, _, _ = evaluate_metrics(ref_sources, perm_est_sources, sample_rate)
        total_sdr = sum(SDR)  
        
        if total_sdr > best_sdr:
            best_sdr = total_sdr
            best_perm = perm
    
    return best_perm 

In [None]:
def run_separation_and_evaluation(mixtures, model, sample_rate, save_dir="separated_audio"):
    os.makedirs(save_dir, exist_ok=True)
    metrics_results = []
    
    for idx, (mixture, refs, spk_ids) in enumerate(tqdm(mixtures, desc="Processing mixtures")):
        temp_filename = f"temp_mixture_{idx}.wav"
        torchaudio.save(temp_filename, mixture, sample_rate)

        mix_tensor, sr = torchaudio.load(temp_filename)
        separated = model.separate_batch(mix_tensor)

        ref_sources = np.array([s.squeeze().numpy() for s in refs])  # (2, samples)
        est_sources = separated.cpu().numpy().squeeze(0).T  # (2, samples)

        best_perm = best_permutation_match(ref_sources, est_sources, sample_rate)
        est_sources = est_sources[list(best_perm), :]  # Reorder estimated sources

        for spk_idx, (est_source, spk_id) in enumerate(zip(est_sources, spk_ids)):
            spk_save_path = os.path.join(save_dir, f"spk_{spk_id}_sep_{idx}.wav")
            torchaudio.save(spk_save_path, torch.tensor(est_source).unsqueeze(0), sample_rate)

        SDR, SIR, SAR, pesq_scores = evaluate_metrics(ref_sources, est_sources, sample_rate)
        metrics_results.append({
            "SDR": SDR,
            "SIR": SIR,
            "SAR": SAR,
            "PESQ": pesq_scores
        })

        os.remove(temp_filename)

    return metrics_results

In [85]:
print("Loading metadata...")
speaker_dict = load_voxceleb_metadata(VOXCELEB2_TXT_BASE_DIR, VOXCELEB2_AUDIO_BASE_DIR)
all_speakers = sorted(list(speaker_dict.keys()))
print(f"Total speakers found: {len(all_speakers)}")

Loading metadata...
Total speakers found: 118


In [32]:
train_ids = all_speakers[:50]
test_ids = all_speakers[:100]
print(f"Training speakers: {len(train_ids)}, Testing speakers: {len(test_ids)}")

Training speakers: 50, Testing speakers: 100


In [33]:
print("Creating testing scenario mixtures...")
test_mixtures = create_scenario(speaker_dict, test_ids)
print("Created")

Creating testing scenario mixtures...
Created


In [87]:
print("Loading pre-trained SepFormer model...")
sepformer_model = SepformerSeparation.from_hparams("speechbrain/sepformer-wsj02mix",run_opts={"device":"cuda"})


Loading pre-trained SepFormer model...


In [88]:
print("Performing separation and evaluation...")
metrics_results = run_separation_and_evaluation(test_mixtures, sepformer_model, TARGET_SAMPLE_RATE)

Performing separation and evaluation...


	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  SDR, SIR, SAR, _ = bss_eval_sources(ref_sources, est_sources)
Processing mixtures: 100%|██████████| 500/500 [19:04<00:00,  2.29s/it]


In [89]:
all_SDR = np.concatenate([res["SDR"] for res in metrics_results])
all_SIR = np.concatenate([res["SIR"] for res in metrics_results])
all_SAR = np.concatenate([res["SAR"] for res in metrics_results])
all_PESQ = np.concatenate([res["PESQ"] for res in metrics_results])
    
print("\n--- Evaluation Results (Average over mixtures) ---")
print(f"Average SDR: {np.mean(all_SDR):.2f} dB")
print(f"Average SIR: {np.mean(all_SIR):.2f} dB")
print(f"Average SAR: {np.mean(all_SAR):.2f} dB")
print(f"Average PESQ: {np.mean(all_PESQ):.2f}")


--- Evaluation Results (Average over mixtures) ---
Average SDR: 8.79 dB
Average SIR: 14.83 dB
Average SAR: 11.61 dB
Average PESQ: 2.01


### TASK B

In [90]:
def resample_audio_16(audio, orig_sr, target_sr=16000):
    if orig_sr != target_sr:
        transform = torchaudio.transforms.Resample(orig_sr, target_sr)
        audio = transform(audio)
    return audio

In [None]:
def load_audio_file_16(file_path):
    audio, sr = torchaudio.load(file_path)
    target_sr = 16000
    audio = resample_audio_16(audio, sr, target_sr=16000)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    return audio

In [None]:
def create_scenario_with_labels(speaker_dict, speaker_ids, num_mixtures=100):
    mixtures = []
    available_speakers = [s for s in speaker_ids if s in speaker_dict and len(speaker_dict[s]) > 0]
    if len(available_speakers) < 2:
        raise ValueError("Need at least two speakers to create mixtures.")
    
    for _ in range(num_mixtures):
        spk1, spk2 = random.sample(available_speakers, 2)
        utt1_path = random.choice(speaker_dict[spk1])
        utt2_path = random.choice(speaker_dict[spk2])
        utt1 = load_audio_file_16(utt1_path)
        utt2 = load_audio_file_16(utt2_path)
        mixture, ref1, ref2 = mix_utterances(utt1, utt2, snr_dB=0)
        mixtures.append((mixture, [ref1, ref2], [spk1, spk2]))
    return mixtures

In [92]:
folder_path = "/kaggle/working/separated_audio"
print(os.listdir(folder_path))

['spk_id02542_sep_261.wav', 'spk_id07312_sep_442.wav', 'spk_id01822_sep_3.wav', 'spk_id01224_sep_318.wav', 'spk_id01224_sep_1.wav', 'spk_id04295_sep_159.wav', 'spk_id06913_sep_103.wav', 'spk_id00419_sep_112.wav', 'spk_id00419_sep_113.wav', 'spk_id01541_sep_128.wav', 'spk_id00866_sep_63.wav', 'spk_id01228_sep_306.wav', 'spk_id01106_sep_466.wav', 'spk_id05459_sep_133.wav', 'spk_id00812_sep_106.wav', 'spk_id01460_sep_379.wav', 'spk_id00017_sep_449.wav', 'spk_id02181_sep_274.wav', 'spk_id01509_sep_132.wav', 'spk_id01066_sep_499.wav', 'spk_id03030_sep_27.wav', 'spk_id05594_sep_254.wav', 'spk_id06811_sep_385.wav', 'spk_id02725_sep_308.wav', 'spk_id06310_sep_181.wav', 'spk_id00562_sep_62.wav', 'spk_id02019_sep_77.wav', 'spk_id03127_sep_184.wav', 'spk_id01333_sep_48.wav', 'spk_id01000_sep_3.wav', 'spk_id03980_sep_118.wav', 'spk_id03382_sep_26.wav', 'spk_id04627_sep_64.wav', 'spk_id00017_sep_149.wav', 'spk_id01593_sep_356.wav', 'spk_id03524_sep_53.wav', 'spk_id01333_sep_44.wav', 'spk_id05816_se

In [53]:
import shutil

shutil.make_archive('/kaggle/working/my_output', 'zip', '/kaggle/working/separated_audio')

'/kaggle/working/my_output.zip'

In [None]:
folder_path = "/kaggle/working/separated_audio" 
files_and_dirs = os.listdir(folder_path)
print(len(files_and_dirs))

1000


In [None]:
def extract_embedding(model, processor, waveform, sample_rate=TARGET_SAMPLE_RATE_16):
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt").input_values
    input_values = input_values.to(next(model.parameters()).device)
    with torch.no_grad():
        if hasattr(model, "extract_embeddings"):
            embedding = model.extract_embeddings(input_values)
        else:
            outputs = model(input_values)
            embedding = outputs.last_hidden_state.mean(dim=1)
    return embedding.squeeze()  

In [None]:
def identify_speaker(embedding, enrollment_embeddings):
    best_id = None
    best_sim = -float('inf')
    for spk, enroll_emb in enrollment_embeddings.items():
        sim = torch.cosine_similarity(embedding, enroll_emb, dim=0).item()
        if sim > best_sim:
            best_sim = sim
            best_id = spk
    return best_id

In [97]:
print(f"Total speakers: {len(all_speakers)}")

Total speakers: 118


In [None]:
new_test_ids = all_speakers[:100]

In [99]:
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
from scipy.spatial.distance import cosine

In [100]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [101]:
pretrained_model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(device)

In [102]:
import torch.nn as nn

In [103]:
class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), min=1e-6))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=input.device)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

In [104]:
class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r=4, alpha=16.0, dropout=0.1):
        super(LoRALinear, self).__init__()
        self.r = r
        if r > 0:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            self.scaling = alpha / r
            self.dropout = nn.Dropout(dropout)
        else:
            self.lora_A = None
            self.lora_B = None

    def forward(self, x):
        if self.r > 0:
            return self.lora_B(self.dropout(self.lora_A(x))) * self.scaling
        else:
            return 0

In [None]:
class FineTuneModel(nn.Module):
    def __init__(self, pretrained_model, embedding_dim, num_classes, lora_r=4, lora_alpha=16):
        super(FineTuneModel, self).__init__()
        self.pretrained = pretrained_model
        for param in self.pretrained.parameters():
            param.requires_grad = False
        self.lora = LoRALinear(embedding_dim, embedding_dim, r=lora_r, alpha=lora_alpha)
        self.arcface = ArcMarginProduct(embedding_dim, num_classes)
    
    def forward(self, input_values, labels):
        outputs = self.pretrained(input_values)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        adapted_embeddings = embeddings + self.lora(embeddings)
        logits = self.arcface(adapted_embeddings, labels)
        return logits, adapted_embeddings
    
  
    def extract_embeddings(self, input_values):
        outputs = self.pretrained(input_values)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        adapted_embeddings = embeddings + self.lora(embeddings)
        return adapted_embeddings

In [106]:
import math

In [None]:
embedding_dim = pretrained_model.config.hidden_size  
num_classes = len(test_ids)         
finetune_model = FineTuneModel(pretrained_model, embedding_dim, num_classes, lora_r=4, lora_alpha=16)
finetune_model.to(device)

FineTuneModel(
  (pretrained): WavLMModel(
    (feature_extractor): WavLMFeatureEncoder(
      (conv_layers): ModuleList(
        (0): WavLMGroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x WavLMNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x WavLMNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): WavLMFeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): WavLMEncoder(
 

In [108]:
finetune_model.load_state_dict(torch.load("/kaggle/input/finetuned-model/best_finetune_model_epoch10.pt"))

  finetune_model.load_state_dict(torch.load("/kaggle/input/finetuned-model/best_finetune_model_epoch10.pt"))


<All keys matched successfully>

In [109]:
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")

In [57]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model.to(device).eval()


WavLMModel(
  (feature_extractor): WavLMFeatureEncoder(
    (conv_layers): ModuleList(
      (0): WavLMGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): WavLMFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): WavLMEncoder(
    (pos_conv_embed): WavLMPositionalConvEmbedding(
      (conv): Parametrized

In [110]:
finetune_model.to(device).eval()

FineTuneModel(
  (pretrained): WavLMModel(
    (feature_extractor): WavLMFeatureEncoder(
      (conv_layers): ModuleList(
        (0): WavLMGroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x WavLMNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x WavLMNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): WavLMFeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): WavLMEncoder(
 

In [111]:
enrollment_embeddings_pretrained = {}
enrollment_embeddings_finetuned = {}

In [None]:
import torchaudio.transforms as T

resample_to_16k = T.Resample(orig_freq=8000, new_freq=16000)

In [None]:
for spk in test_ids:
    if spk in speaker_dict:
        enrollment_file = speaker_dict[spk][0]  
        waveform = load_audio_file_16(enrollment_file)  
        emb_pre = extract_embedding(pretrained_model, processor, waveform, TARGET_SAMPLE_RATE_16)
        emb_ft = extract_embedding(finetune_model, processor, waveform, TARGET_SAMPLE_RATE_16)
        enrollment_embeddings_pretrained[spk] = emb_pre
        enrollment_embeddings_finetuned[spk] = emb_ft

In [115]:
len(enrollment_embeddings_finetuned)

100

In [117]:
correct_pretrained = 0
correct_finetuned = 0
total = 0

In [None]:
for idx, (mixture, _, true_speakers) in enumerate(tqdm(test_mixtures, desc="Identifying speakers")):
    for spk_idx, spk_id in enumerate(true_speakers):
        sep_audio_path = os.path.join("separated_audio", f"spk_{spk_id}_sep_{idx}.wav")
        separated_waveform, sr = torchaudio.load(sep_audio_path)
        separated_waveform = resample_to_16k(separated_waveform)
        emb_pre = extract_embedding(pretrained_model, processor, separated_waveform, TARGET_SAMPLE_RATE_16)
        emb_ft = extract_embedding(finetune_model, processor, separated_waveform, TARGET_SAMPLE_RATE_16)
        pred_pre = identify_speaker(emb_pre, enrollment_embeddings_pretrained)
        pred_ft = identify_speaker(emb_ft, enrollment_embeddings_finetuned)

        if pred_pre == spk_id:
            correct_pretrained += 1
        if pred_ft == spk_id:
            correct_finetuned += 1
        total += 1

Identifying speakers: 100%|██████████| 500/500 [01:44<00:00,  4.79it/s]


In [123]:
accuracy_pretrained = correct_pretrained / total if total > 0 else 0.0
accuracy_finetuned = correct_finetuned / total if total > 0 else 0.0

print("\n--- Rank-1 Identification Accuracy ---")
print(f"Pre-trained Model: {accuracy_pretrained * 100:.2f}%")
print(f"Fine-tuned Model: {accuracy_finetuned * 100:.2f}%")


--- Rank-1 Identification Accuracy ---
Pre-trained Model: 23.42%
Fine-tuned Model: 39.9%
