### Integrated Pipeline for Speaker Separation and Identification
The final part of the assignment involves designing and evaluating a novel pipeline that combines speaker separation and identification into a unified system, trained and tested on the custom multi-speaker dataset.

In [None]:
pip install speechbrain

In [2]:
pip install pesq

Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install mir-eval

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import glob
import random
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from speechbrain.pretrained import SepformerSeparation
from mir_eval.separation import bss_eval_sources
from pesq import pesq
from tqdm import tqdm
from itertools import permutations
import torch.nn.functional as F
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
from scipy.spatial.distance import cosine
import torchaudio.transforms as T
from torch.cuda.amp import autocast, GradScaler
import librosa

  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
  from speechbrain.pretrained import SepformerSeparation


In [5]:
torchaudio.set_audio_backend("ffmpeg") 

  torchaudio.set_audio_backend("ffmpeg")


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [7]:
VOXCELEB2_TXT_BASE_DIR = "Dataset/vox2_test_txt/txt"
VOXCELEB2_AUDIO_BASE_DIR = "Converted_WAVs"

In [None]:
def load_voxceleb_metadata(txt_base_dir, audio_base_dir):
    
    speaker_dict = {}
    pattern = os.path.join(txt_base_dir, "**", "*.txt")
    txt_files = glob.glob(pattern, recursive=True)
    
    for txt_file in txt_files:
        rel_path = os.path.relpath(txt_file, txt_base_dir)
        parts = rel_path.split(os.sep)
        if len(parts) < 3:
            continue  
        
        speaker_id = parts[0]
        recording_id = parts[1]
        file_name = parts[2]
        audio_file = os.path.join(audio_base_dir, speaker_id, recording_id, file_name.replace('.txt', '.wav'))
        if not os.path.exists(audio_file):
            continue
        
        if speaker_id not in speaker_dict:
            speaker_dict[speaker_id] = []
        speaker_dict[speaker_id].append(audio_file)
    
    return speaker_dict

In [None]:
def load_audio_file(file_path, target_sr):
    audio, sr = torchaudio.load(file_path)
    if sr != target_sr:
        transform = torchaudio.transforms.Resample(sr, target_sr)
        audio = transform(audio)
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    return audio

In [10]:
def pad_or_truncate(aud1, aud2):
    len1 = aud1.shape[1]
    len2 = aud2.shape[1]
    if len1 < len2:
        pad = torch.zeros(1, len2 - len1)
        aud1 = torch.cat([aud1, pad], dim=1)
    elif len2 < len1:
        pad = torch.zeros(1, len1 - len2)
        aud2 = torch.cat([aud2, pad], dim=1)
    return aud1, aud2

In [11]:
def mix_utterances(utt1, utt2, snr_dB=0):
    utt1, utt2 = pad_or_truncate(utt1, utt2)
    power1 = utt1.pow(2).mean()
    power2 = utt2.pow(2).mean()
    scale = torch.sqrt(power1 / (10**(snr_dB/10) * power2 + 1e-8))
    utt2_scaled = utt2 * scale
    mixture = utt1 + utt2_scaled
    return mixture, utt1, utt2_scaled

In [None]:

FIXED_DURATION = 3  

import torch

def fix_audio_length(audio, sr, target_duration=3):
    target_length = int(target_duration * sr)  
    audio_length = len(audio)

    if audio_length > target_length:
        audio = audio[:target_length]
    elif audio_length < target_length:
        pad_length = target_length - audio_length
        audio = np.pad(audio, (0, pad_length), mode='constant')

    return torch.tensor(audio, dtype=torch.float32) 


In [None]:
def create_scenario(speaker_dict, speaker_ids, num_mixtures=1000, target_sr=16000):
    
    mixtures = []
    available_speakers = [s for s in speaker_ids if s in speaker_dict and len(speaker_dict[s]) > 0]
    if len(available_speakers) < 2:
        raise ValueError("Need at least two speakers to create mixtures.")
    
    for _ in range(num_mixtures):
        spk1, spk2 = random.sample(available_speakers, 2)
        utt1_path = random.choice(speaker_dict[spk1])
        utt2_path = random.choice(speaker_dict[spk2])
        utt1 = load_audio_file(utt1_path, target_sr)
        utt2 = load_audio_file(utt2_path, target_sr)
        utt1 = fix_audio_length(utt1, target_sr, FIXED_DURATION)
        utt2 = fix_audio_length(utt2, target_sr, FIXED_DURATION)
        mixture, ref1, ref2 = mix_utterances(utt1, utt2, snr_dB=0)
        mixture = fix_audio_length(mixture, target_sr, FIXED_DURATION)
        mixtures.append((mixture, [ref1, ref2], [spk1, spk2]))
    return mixtures

In [None]:
class JointSeparationIdentification(nn.Module):
    def __init__(self, sepformer_model, id_model, num_speakers, embedding_dim, id_loss_weight=0.5):
        super(JointSeparationIdentification, self).__init__()
        self.sepformer = sepformer_model  
        for param in self.sepformer.parameters():
            param.requires_grad = False
        self.id_model = id_model  
        for layer in self.id_model.encoder.layers[:10]:
            for param in layer.parameters():
                param.requires_grad = False
        self.classifier = nn.Linear(embedding_dim, num_speakers)
        self.id_loss_weight = id_loss_weight
        self.resampler = T.Resample(orig_freq=8000, new_freq=16000)
    
    def forward(self, mixture, labels=None): 
        mixture_input = mixture.squeeze(1)  
        separated = self.sepformer.separate_batch(mixture_input)
        separated = separated.squeeze(1).transpose(1, 2)
    
        B, n, T = separated.shape
        separated_16k = []
        for i in range(n):
            src = separated[:, i, :].unsqueeze(1)  
            upsampled = self.resampler(src)         
            separated_16k.append(upsampled.squeeze(1))
        separated_16k = torch.stack(separated_16k, dim=1)  
        embeddings = []
        for i in range(n):
            source_i = separated_16k[:, i, :]  
            id_outputs = self.id_model(source_i)
            emb = id_outputs.last_hidden_state.mean(dim=1)  
            embeddings.append(emb)
        embeddings = torch.stack(embeddings, dim=1)  

        if labels is not None:
            B, n, emb_dim = embeddings.shape
            flat_embeddings = embeddings.reshape(-1, emb_dim)  
            logits = self.classifier(flat_embeddings)  
            logits = logits.view(B, n, -1)  
            return separated_16k, embeddings, logits
        else:
            return separated_16k, embeddings



In [None]:
def joint_loss(separated, references, logits, labels, id_loss_weight=0.5):
    min_length = min(separated.size(2), references.size(2))
    separated = separated[:, :, :min_length]
    references = references[:, :, :min_length]
    sep_loss = F.mse_loss(separated, references)
    B, n, _ = logits.shape
    id_loss = F.cross_entropy(logits.view(B * n, -1), labels.view(B * n))
    
    return sep_loss + id_loss_weight * id_loss, sep_loss, id_loss


In [None]:
def evaluate_metrics(ref_sources, est_sources, sample_rate):
    SDR, SIR, SAR, _ = bss_eval_sources(ref_sources, est_sources)
    pesq_mode = "nb" if sample_rate == 8000 else "wb"

    pesq_scores = []
    for i in range(ref_sources.shape[0]):
        score = pesq(sample_rate, ref_sources[i], est_sources[i], mode=pesq_mode)
        pesq_scores.append(score)

    return SDR, SIR, SAR, pesq_scores

In [None]:
def train_joint_model(joint_model, train_loader, optimizer, device, num_epochs):
    joint_model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for mixtures, references, speaker_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            mixtures = mixtures.to(device)
            references = references.to(device)
            speaker_labels = speaker_labels.to(device)
            
            optimizer.zero_grad()
            with autocast():
                separated, embeddings, logits = joint_model(mixtures, speaker_labels)
                loss, sep_loss, id_loss = joint_loss(separated, references, logits, speaker_labels, joint_model.id_loss_weight)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            torch.cuda.empty_cache()
            
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

In [None]:
def evaluate_joint_model(joint_model, test_loader, device):
    joint_model.eval()
    sep_metrics = []

    def best_permutation_match(ref_sources, est_sources, sample_rate):
        best_sdr = -float('inf')
        best_perm = None
        for perm in permutations([0, 1]):
            perm_est_sources = est_sources[list(perm), :]
            SDR, _, _, _ = evaluate_metrics(ref_sources, perm_est_sources, sample_rate)
            total_sdr = sum(SDR)
            if total_sdr > best_sdr:
                best_sdr = total_sdr
                best_perm = perm
        return best_perm

    
    for mixtures, references, _ in tqdm(test_loader, desc="Evaluating"):
        mixtures = mixtures.to(device)
        separated, _ = joint_model(mixtures) 
        B, n, T = separated.shape

        for i in range(B):
            ref_sources = np.array([r.squeeze().cpu().numpy() for r in references[i]])  
            est_sources = np.array([
                librosa.resample(src, orig_sr=16000, target_sr=int(16000 * (ref_sources.shape[1] / src.shape[0])))
                for src in separated[i].detach().cpu().numpy()
                ])
            best_perm = best_permutation_match(ref_sources, est_sources, 16000)
            est_sources = est_sources[list(best_perm), :]
            SDR, SIR, SAR, pesq_scores = evaluate_metrics(ref_sources, est_sources, 16000)
            sep_metrics.append({
                "SDR": SDR,
                "SIR": SIR,
                "SAR": SAR,
                "PESQ": pesq_scores
            })

    avg_SDR = np.mean([m['SDR'][0] for m in sep_metrics])
    avg_SIR = np.mean([m['SIR'][0] for m in sep_metrics])
    avg_SAR = np.mean([m['SAR'][0] for m in sep_metrics])
    avg_PESQ = np.mean([m['PESQ'][0] for m in sep_metrics])

    print("\n--- Test Separation Metrics ---")
    print(f"Average SDR: {avg_SDR:.2f} dB")
    print(f"Average SIR: {avg_SIR:.2f} dB")
    print(f"Average SAR: {avg_SAR:.2f} dB")
    print(f"Average PESQ: {avg_PESQ:.2f}")

    return sep_metrics


In [27]:
TARGET_SR_16 = 16000
TARGET_SR_8 = 8000

In [28]:
speaker_dict = load_voxceleb_metadata(VOXCELEB2_TXT_BASE_DIR, VOXCELEB2_AUDIO_BASE_DIR)
all_speakers = sorted(list(speaker_dict.keys()))
print(f"Total speakers: {len(all_speakers)}")

Total speakers: 68


In [None]:
train_ids = all_speakers[:30]
test_ids = all_speakers[30:60]

In [None]:
print("Creating training mixtures...")
train_mixtures = create_scenario(speaker_dict, train_ids, num_mixtures=100, target_sr=TARGET_SR_16)
print("Creating testing mixtures...")
test_mixtures = create_scenario(speaker_dict, test_ids, num_mixtures=100, target_sr=TARGET_SR_16)

Creating training mixtures...
Creating testing mixtures...


In [31]:
print(f"Number of training mixtures: {len(train_mixtures)}")
print(f"Number of testing mixtures: {len(test_mixtures)}")

Number of training mixtures: 100
Number of testing mixtures: 100


In [32]:
train_ids[:10]

['id00812',
 'id00817',
 'id01066',
 'id01106',
 'id01298',
 'id01460',
 'id01509',
 'id01593',
 'id01618',
 'id01822']

In [33]:
sepformer_model = SepformerSeparation.from_hparams("speechbrain/sepformer-wsj02mix",run_opts={"device":"cuda", "precision": "16-mixed"})

INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/sepformer-wsj02mix' if not cached
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/sepformer-wsj02mix' if not cached
INFO:speechbrain.utils.fetching:Fetch masknet.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-wsj02mix' if not cached
INFO:speechbrain.utils.fetching:Fetch encoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-wsj02mix' if not cached
INFO:speechbrain.utils.fetching:Fetch decoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/sepformer-wsj02mix' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: masknet, encoder, decoder
  state_dict = torch.load(path, map_location=device)


In [34]:
pretrained_model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(device)

In [35]:
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sepformer_model.to(device)
pretrained_model.to(device)

WavLMModel(
  (feature_extractor): WavLMFeatureEncoder(
    (conv_layers): ModuleList(
      (0): WavLMGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): WavLMFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): WavLMEncoder(
    (pos_conv_embed): WavLMPositionalConvEmbedding(
      (conv): Parametrized

In [37]:
num_speakers = len(train_ids)
num_speakers

30

In [38]:
embedding_dim = pretrained_model.config.hidden_size

In [39]:
joint_model = JointSeparationIdentification(sepformer_model, pretrained_model, num_speakers, embedding_dim, id_loss_weight=0.5)
joint_model.to(device)

JointSeparationIdentification(
  (sepformer): SepformerSeparation(
    (mods): ModuleDict(
      (encoder): Encoder(
        (conv1d): Conv1d(1, 256, kernel_size=(16,), stride=(8,), bias=False)
      )
      (decoder): Decoder(256, 1, kernel_size=(16,), stride=(8,), bias=False)
      (masknet): Dual_Path_Model(
        (norm): GroupNorm(1, 256, eps=1e-08, affine=True)
        (conv1d): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
        (dual_mdl): ModuleList(
          (0-1): 2 x Dual_Computation_Block(
            (intra_mdl): SBTransformerBlock(
              (mdl): TransformerEncoder(
                (layers): ModuleList(
                  (0-7): 8 x TransformerEncoderLayer(
                    (self_att): MultiheadAttention(
                      (att): MultiheadAttention(
                        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
                      )
                    )
                    (pos_ffn): Posi

In [40]:
from torch.utils.data import DataLoader

In [None]:
def custom_collate_fn(batch):
    mixtures, references, labels = zip(*batch)
    
    max_len_mix = max(m.shape[1] for m in mixtures)
    max_len_refs = max(r.shape[1] for r in references)
    max_len = max(max_len_mix, max_len_refs)
    
    padded_mixtures = []
    for mix in mixtures:
        pad_size = max_len - mix.shape[1]
        if pad_size > 0:
            pad_tensor = torch.zeros(mix.shape[0], pad_size)
            mix_padded = torch.cat([mix, pad_tensor], dim=1)
        else:
            mix_padded = mix
        padded_mixtures.append(mix_padded)
    
    padded_references = []
    for refs in references:
        pad_size = max_len - refs.shape[1]
        if pad_size > 0:
            pad_tensor = torch.zeros(refs.shape[0], pad_size)
            refs_padded = torch.cat([refs, pad_tensor], dim=1)
        else:
            refs_padded = refs
        padded_references.append(refs_padded)
    
    mixtures_batch = torch.stack(padded_mixtures)     
    references_batch = torch.stack(padded_references)    
    labels_batch = torch.stack(labels)                   
    
    return mixtures_batch, references_batch, labels_batch


In [None]:
class MixtureDataset(torch.utils.data.Dataset):
    def __init__(self, mixtures, target_sr):
        self.mixtures = mixtures
        self.target_sr = target_sr
    def __len__(self):
        return len(self.mixtures)
    def __getitem__(self, idx):
        mixture, refs, spk_ids = self.mixtures[idx]
        if self.target_sr != 8000:
            resampler = T.Resample(self.target_sr, 8000)
            mixture = resampler(mixture)
        refs_8k = []
        for r in refs:
            if self.target_sr != 8000:
                r_ds = T.Resample(self.target_sr, 8000)(r)
            else:
                r_ds = r
            upsampler = T.Resample(8000, 16000)
            r_up = upsampler(r_ds)  
            refs_8k.append(r_up.squeeze(0))
        refs_tensor = torch.stack(refs_8k, dim=0)
        spk_to_idx = {spk: i for i, spk in enumerate(sorted(set(spk_ids)))}
        labels = torch.tensor([spk_to_idx[spk] for spk in spk_ids], dtype=torch.long)
        return mixture, refs_tensor, labels 


In [44]:
train_dataset = MixtureDataset(train_mixtures, target_sr=TARGET_SR_16)
test_dataset = MixtureDataset(test_mixtures, target_sr=TARGET_SR_16)


In [45]:
print(f"Train Dataset Size: {len(train_dataset)}")
print(f"Test Dataset Size: {len(test_dataset)}")

Train Dataset Size: 100
Test Dataset Size: 100


In [46]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

In [None]:
for batch in test_loader:
    print(len(batch))  
    break

3


In [None]:
batch = next(iter(train_loader))
mixtures_batch, references_batch, labels_batch = batch
print("Mixtures shape:", mixtures_batch.shape)        
print("References shape:", references_batch.shape)      
print("Labels shape:", labels_batch.shape)              


Mixtures shape: torch.Size([1, 1, 155648])
References shape: torch.Size([1, 2, 155648])
Labels shape: torch.Size([1, 2])


In [49]:
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of testing batches: {len(test_loader)}")

Number of training batches: 100
Number of testing batches: 100


In [None]:
optimizer = torch.optim.Adam(joint_model.parameters(), lr=1e-4)

In [51]:

scaler = GradScaler()


  scaler = GradScaler()


In [None]:
print("Starting joint training...")
train_joint_model(joint_model, train_loader, optimizer, device, num_epochs=10)

Starting joint training...


  with autocast():
Epoch 1: 100%|██████████| 100/100 [00:55<00:00,  1.81it/s]


Epoch 1/10, Loss: 325.7318


Epoch 2: 100%|██████████| 100/100 [00:54<00:00,  1.84it/s]


Epoch 2/10, Loss: 287.3443


Epoch 3: 100%|██████████| 100/100 [00:54<00:00,  1.83it/s]


Epoch 3/10, Loss: 281.3568


Epoch 4: 100%|██████████| 100/100 [00:54<00:00,  1.85it/s]


Epoch 4/10, Loss: 280.1805


Epoch 5: 100%|██████████| 100/100 [00:54<00:00,  1.83it/s]


Epoch 5/10, Loss: 279.8194


Epoch 6: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


Epoch 6/10, Loss: 279.5151


Epoch 7: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


Epoch 7/10, Loss: 279.2141


Epoch 8: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


Epoch 8/10, Loss: 278.4450


Epoch 9: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


Epoch 9/10, Loss: 280.7335


Epoch 10: 100%|██████████| 100/100 [00:53<00:00,  1.86it/s]

Epoch 10/10, Loss: 278.8762





In [None]:
print("Evaluating joint model on test set...")
sep_metrics = evaluate_joint_model(joint_model, test_loader, device)

Evaluating joint model on test set...


	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  SDR, SIR, SAR, _ = bss_eval_sources(ref_sources, est_sources)
Evaluating: 100%|██████████| 100/100 [09:27<00:00,  5.68s/it]


--- Test Separation Metrics ---
Average SDR: -22.26 dB
Average SIR: 1.45 dB
Average SAR: -19.61 dB
Average PESQ: 1.35



