In [None]:
import os
import glob
import random
import tempfile
import numpy as np
import librosa
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

# For evaluation metrics (EER, TAR) we use scikit‑learn
from sklearn.metrics import roc_curve

# For SDR, SIR, SAR evaluation (using mir_eval; see https://github.com/craffel/mir_eval)
import mir_eval

# For PESQ (see https://pypi.org/project/pesq/)
from pesq import pesq
import gc
import copy
import math

In [2]:
# =============================================================================
# 1. ArcFace Loss Implementation 
#    (Reference: Deng et al., “ArcFace: Additive Angular Margin Loss for Deep Face Recognition”, https://arxiv.org/abs/1801.07698)
# =============================================================================
class ArcFaceLoss(nn.Module):
    def __init__(self, embedding_size, num_classes, margin=0.5, scale=64):
        """
        embedding_size: dimension of speaker embeddings
        num_classes: number of speaker classes
        margin: angular margin penalty
        scale: scaling factor for logits
        """
        super(ArcFaceLoss, self).__init__()
        self.margin = margin
        self.scale = scale
        self.weight = nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, embeddings, labels):
        # Normalize embeddings and weights
        normed_embeddings = F.normalize(embeddings, p=2, dim=1)
        normed_weight = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(normed_embeddings, normed_weight)  # [batch, num_classes]
        # Add angular margin
        theta = torch.acos(torch.clamp(cosine, -1+1e-7, 1-1e-7))
        target_logits = torch.cos(theta + self.margin)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)
        output = cosine * (1 - one_hot) + target_logits * one_hot
        output = output * self.scale
        loss = F.cross_entropy(output, labels)
        return loss

In [3]:
# =============================================================================
# 2. LoRA (Low-Rank Adaptation) Implementation
#    (Reference: Hu et al., “LoRA: Low-Rank Adaptation of Large Language Models”, https://arxiv.org/abs/2106.09685)
# =============================================================================
class LoRALinear(nn.Module):
    def __init__(self, original_linear, r=4, alpha=1.0):
        """
        Wraps an existing nn.Linear module with a low‑rank update.
        r: rank of the low‑rank adaptation
        alpha: scaling factor
        """
        super(LoRALinear, self).__init__()
        self.original_linear = original_linear
        self.r = r
        self.alpha = alpha
        self.lora_A = nn.Parameter(torch.randn(original_linear.out_features, r) * 0.01)
        self.lora_B = nn.Parameter(torch.randn(r, original_linear.in_features) * 0.01)
        # Freeze original weights
        for param in self.original_linear.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Standard linear output plus low‑rank update
        lora_A = self.lora_A.to(x.device)
        lora_B = self.lora_B.to(x.device)
        return self.original_linear(x) + self.alpha * (x @ lora_B.t() @ lora_A.t())
        # return self.original_linear(x) + self.alpha * (x @ self.lora_B.t() @ self.lora_A.t())

def get_parent_module(model, module_name):
    names = module_name.split('.')
    parent = model
    for n in names[:-1]:
        parent = getattr(parent, n)
    return parent

def apply_lora(model, r=4, alpha=1.0):
    """
    Iterates over all submodules and replaces nn.Linear modules with LoRALinear.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            parent = get_parent_module(model, name)
            child_name = name.split('.')[-1]
            setattr(parent, child_name, LoRALinear(module, r, alpha))

In [4]:
def trim_audio(waveform, sr, max_duration):
    """
    Trims the waveform to max_duration seconds.
    """
    max_samples = int(max_duration * sr)
    if waveform.size(1) > max_samples:
        return waveform[:, :max_samples]
    return waveform

In [None]:
# =============================================================================
# 3. Dataset Definitions
# =============================================================================
class VoxCelebDataset(Dataset):
    """
    Dataset for VoxCeleb2 data for fine‑tuning.
    Assumes directory structure like: vox2/aac/<identity>/.../<file>.m4a
    """
    def __init__(self, root_dir, identities, transform=None, file_ext='.m4a'):
        self.samples = []
        self.transform = transform
        for identity in identities:
            id_path = os.path.join(root_dir, identity)
            files = glob.glob(os.path.join(id_path, '**', f'*{file_ext}'), recursive=True)
            for file in files:
                self.samples.append((identity, file))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        identity, file_path = self.samples[idx]
        waveform, sr = torchaudio.load(file_path)
        if self.transform:
            waveform = self.transform(waveform)
        return identity, waveform, sr

class VoxCelebTrialDataset(Dataset):
    """
    Dataset for VoxCeleb1 trial pairs.
    Expects a text file with lines formatted as:
      <label> <enrollment_path> <test_path>
    where paths are relative to a given root (e.g. vox1/wav).
    """
    def __init__(self, trial_file, root_dir):
        self.trials = []
        self.root_dir = root_dir
        with open(trial_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                label = int(parts[0])
                enrollment = os.path.join(root_dir, parts[1])
                test = os.path.join(root_dir, parts[2])
                self.trials.append((label, enrollment, test))
    
    def __len__(self):
        return len(self.trials)
    
    def __getitem__(self, idx):
        label, enroll_path, test_path = self.trials[idx]
        enroll_waveform, sr1 = torchaudio.load(enroll_path)
        test_waveform, sr2 = torchaudio.load(test_path)
        return label, enroll_waveform, test_waveform, sr1

In [6]:
# =============================================================================
# 4. Pre-trained Speaker Verification Model Wrapper
#    (Using Hugging Face’s transformers – https://github.com/huggingface/transformers)
# =============================================================================
from transformers import WavLMModel, AutoFeatureExtractor

class SpeakerVerificationModel(nn.Module):
    def __init__(self, model_name='microsoft/wavlm-base-plus'):
        """
        Loads a pre‑trained model and adds a projection head.
        """
        super(SpeakerVerificationModel, self).__init__()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = WavLMModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size
        self.fc = nn.Linear(self.embedding_dim, self.embedding_dim)  # Projection head
       
    def forward(self, waveforms):
        """
        waveforms: can be a list of raw audio signals or a tensor of shape [batch, channels, time].
        If waveforms is a tensor with a singleton channel dimension, it is squeezed to shape [batch, time].
        """
        # If input is a tensor with shape [batch, 1, time], squeeze to [batch, time]
        if isinstance(waveforms, torch.Tensor):
            if waveforms.dim() == 3 and waveforms.size(1) == 1:
                waveforms = waveforms.squeeze(1)  # shape now: [batch, time]
            # Optionally, if waveforms are batched as [batch, time], convert to list of arrays if needed.
            # waveforms = waveforms.tolist() if waveforms.ndim == 2 else waveforms
            # Convert each sample to a numpy array
            waveforms = [w.cpu().numpy() for w in waveforms]
        
        inputs = self.feature_extractor(waveforms, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # Mean pooling over time dimension
        hidden_states = outputs.last_hidden_state.mean(dim=1)
        embeddings = self.fc(hidden_states)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

In [24]:
# =============================================================================
# 5. Evaluation Metrics for Speaker Verification
# =============================================================================
def compute_eer(scores, labels):
    """
    Computes Equal Error Rate (EER) given similarity scores and binary labels.
    """
    fpr, tpr, thresholds = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    return eer * 100

def compute_tar_at_far(scores, labels, target_far=0.01):
    """
    Computes True Accept Rate (TAR) at a given False Accept Rate (FAR).
    """
    fpr, tpr, thresholds = roc_curve(labels, scores)
    idx = np.argmin(np.abs(fpr - target_far))
    tar = tpr[idx]
    return tar * 100

def compute_identification_accuracy(labels, scores, threshold=0.5):
    predictions = [1 if score >= threshold else 0 for score in scores]
    correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
    accuracy = correct / len(labels) * 100  # Convert to percentage
    return accuracy

In [8]:
# =============================================================================
# 6. Training Function for Speaker Verification Model (Fine-tuning with LoRA + ArcFace)
# =============================================================================
def train_speaker_verification(model, train_loader, optimizer, arcface_loss, id2label, device):
    model.train()
    total_loss = 0
    for batch in train_loader:
        identities, waveforms, sr = batch
        # Assume waveforms are padded tensors; if not, additional collate_fn is needed.
        waveforms = waveforms.to(device)
        # Map identities to integer labels using id2label dict
        labels = torch.tensor([id2label[i] for i in identities]).to(device)
        optimizer.zero_grad()
        embeddings = model(waveforms)
        loss = arcface_loss(embeddings, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        del waveforms, embeddings, labels, loss
        # Call garbage collection and empty the CUDA cache
        gc.collect()
        torch.cuda.empty_cache()
    return total_loss / len(train_loader)

In [9]:
# =============================================================================
# 7. Multi-Speaker Dataset Creation (Mixing Utterances)
#    (Inspired by LibriMix: https://github.com/JorisCos/LibriMix/blob/master/generate_librimix.sh)
# =============================================================================
def mix_utterances(file1, file2, target_sr=16000, snr_db=0):
    """
    Loads two audio files and creates a mixture.
    Here we simply sum the waveforms (with equal energy).
    """
    y1, sr1 = librosa.load(file1, sr=target_sr)
    y2, sr2 = librosa.load(file2, sr=target_sr)
    # Truncate to same length
    min_len = min(len(y1), len(y2))
    y1 = y1[:min_len]
    y2 = y2[:min_len]
    # (Optional: adjust relative energy based on desired SNR.)
    mixture = y1 + y2
    mixture = mixture / np.max(np.abs(mixture) + 1e-7)
    return mixture, y1, y2

In [11]:
class MultiSpeakerDataset(Dataset):
    """
    Creates a multi-speaker dataset by pairing utterances from two speakers.
    For VoxCeleb2, we assume that for each identity there is a metadata txt file
    (in vox2/txt) and corresponding audio in vox2/aac.
    """
    def __init__(self, identity_list, metadata_dir, audio_dir, transform=None, max_duration=3.0, target_sr=16000):
        self.samples_by_id = {}
        self.transform = transform
        self.max_duration = max_duration  # in seconds, e.g., 3.0
        self.target_sr = target_sr
        # Collect all available utterances for each identity
        for identity in identity_list:
            meta_files = glob.glob(os.path.join(metadata_dir, identity, '*.txt'))
            files = []
            for meta_file in meta_files:
                with open(meta_file, 'r') as f:
                    for line in f:
                        audio_file = line.strip()  # assumed to be relative path
                        full_audio_path = os.path.join(audio_dir, identity, audio_file)
                        files.append(full_audio_path)
            if files:
                self.samples_by_id[identity] = files
        # Randomly pair utterances for mixing
        # random.shuffle(self.samples)
        # self.pairs = []
        # for i in range(0, len(self.samples) - 1, 2):
        #     self.pairs.append((self.samples[i], self.samples[i+1]))
        self.identities = list(self.samples_by_id.keys())
    
    def __len__(self):
        # return len(self.pairs)
        return math.comb(len(self.identities), 2)  # arbitraryly chosen for pairing
    
    def __getitem__(self, idx):
        # (id1, file1), (id2, file2) = self.pairs[idx]
        # mixture, source1, source2 = mix_utterances(file1, file2)
        # return (id1, id2), mixture, source1, source2
        #####################################################################
        # id1, id2 = random.sample(self.identities, 2)
        # file1 = random.choice(self.samples_by_id[id1])
        # file2 = random.choice(self.samples_by_id[id2])
        # mixture, source1, source2 = mix_utterances(file1, file2)
        # return (id1, id2), mixture, source1, source2
        ######################################################################
        # Randomly choose two different speakers
        id1, id2 = random.sample(self.identities, 2)
        file1 = random.choice(self.samples_by_id[id1])
        file2 = random.choice(self.samples_by_id[id2])
        # Create a mixture from the two audio files.
        mixture, source1, source2 = mix_utterances(file1, file2, target_sr=self.target_sr)
        # If a maximum duration is specified, trim the signals.
        if self.max_duration is not None:
            max_samples = int(self.max_duration * self.target_sr)
            mixture = mixture[:max_samples]
            source1 = source1[:max_samples]
            source2 = source2[:max_samples]
        return (id1, id2), mixture, source1, source2

In [12]:
# =============================================================================
# 8. Pre-trained SepFormer Wrapper for Speech Separation & Enhancement
#    (Using SpeechBrain’s pre-trained SepFormer model: https://huggingface.co/speechbrain/sepformer-whamr)
# =============================================================================
class SepFormerWrapper:
    def __init__(self, model_name="speechbrain/sepformer-whamr"):
        from speechbrain.pretrained import SepformerSeparation as SepFormer
        self.model = SepFormer.from_hparams(source=model_name, savedir="pretrained_sepformer")
    
    def separate(self, mixture, sample_rate=8000):
        """
        Writes the mixture to a temporary file, calls the SepFormer separation,
        and returns the estimated sources.
        """
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, mixture, sample_rate)
            temp_filename = temp_file.name
        est_sources = self.model.separate_file(temp_filename)
        os.remove(temp_filename)
        return est_sources  # list of separated sources (as numpy arrays)

In [13]:
# =============================================================================
# 9. Novel Pipeline Combining Speaker Identification and SepFormer for
#    Joint Separation & Enhancement.
# =============================================================================
class SpeakerSeparationPipeline(nn.Module):
    def __init__(self, sepformer_wrapper, speaker_verification_model):
        """
        sepformer_wrapper: an instance of SepFormerWrapper for inference-only enhancement.
        speaker_verification_model: the speaker identification model.
        """
        super(SpeakerSeparationPipeline, self).__init__()
        # The SepFormer is used in inference mode only (it is frozen and not part of the backpropagation).
        self.sepformer = sepformer_wrapper  
        self.speaker_model = speaker_verification_model

    def forward(self, mixture):
        """
        mixture: Tensor of shape (1, T)
        Returns:
            enhanced_sources_tensors: list of tensors of enhanced speech signals.
            embeddings: list of speaker embeddings computed from each enhanced source.
        """
        # Convert the mixture tensor to a numpy array (SepFormerWrapper expects numpy input).
        mixture_np = mixture.squeeze(0).cpu().numpy()
        # Call the real pre-trained SepFormer model for speech enhancement.
        enhanced_sources = self.sepformer.separate(mixture_np, sample_rate=8000)
        enhanced_sources_tensors = []
        embeddings = []
        for src in enhanced_sources:
            # Convert each enhanced source back to a tensor.
            src_tensor = torch.tensor(src).float().to(next(self.speaker_model.parameters()).device).unsqueeze(0)
            enhanced_sources_tensors.append(src_tensor)
            # Compute the speaker embedding using the speaker identification model.
            emb = self.speaker_model(src_tensor)
            embeddings.append(emb)
        return enhanced_sources_tensors, embeddings

In [14]:
# =============================================================================
# 10. Training and Evaluation Functions for the New Pipeline
# =============================================================================
def train_pipeline(pipeline, train_loader, optimizer, criterion, id2label, device):
    pipeline.train()
    total_loss = 0
    for batch in train_loader:
        speaker_ids, mixture, source1, source2 = batch
        mixture_tensor = torch.tensor(mixture).float().to(device).unsqueeze(0)  # shape (1, T)
        optimizer.zero_grad()
        separated_sources, embeddings = pipeline(mixture_tensor)
        # Create ground truth labels for each separated source using speaker_ids
        # Here, we assume a mapping (id2label) exists; both speakers are used.
        labels = torch.tensor([id2label[speaker_ids[0]], id2label[speaker_ids[1]]]).to(device)
        # Concatenate embeddings (assume each embedding is 1xD)
        emb_concat = torch.cat(embeddings, dim=0)
        loss = criterion(emb_concat, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        del labels, emb_concat, loss, embeddings, separated_sources
        # Call garbage collection and empty the CUDA cache
        gc.collect()
        torch.cuda.empty_cache()
    return total_loss / len(train_loader)

def evaluate_sepformer_baseline(test_loader, sepformer_wrapper, sv_model, enrollment_db, device):
    """
    Evaluate the multi-speaker test set using the pre-trained SepFormer alone (baseline).
    For each mixture, the function:
      - Uses SepFormer to separate and enhance the mixture.
      - Computes SDR, SIR, SAR (via mir_eval) and PESQ.
      - Computes Rank-1 speaker identification accuracy by comparing
        each separated source embedding (from sv_model) to the enrollment database.
        
    Args:
        test_loader (DataLoader): DataLoader for the multi-speaker test set.
        sepformer_wrapper (SepFormerWrapper): Pre-trained SepFormer for enhancement.
        sv_model (nn.Module): Speaker verification model.
        enrollment_db (dict): Mapping from speaker_id to enrollment embedding.
        device (torch.device): Computation device.
        
    Returns:
        dict: A dictionary with keys "SDR", "SIR", "SAR", "PESQ", "Rank1_Identification_Accuracy".
    """
    sv_model.eval()
    all_sdr, all_sir, all_sar, all_pesq = [], [], [], []
    correct_count = 0
    total_count = 0
    
    for batch in test_loader:
        # Each batch: (speaker_ids, mixture, source1, source2)
        speaker_ids, mixture, source1, source2 = batch  # speaker_ids is a tuple/list, e.g., (spk1, spk2)
        # Convert mixture (assumed as a numpy array or list) to tensor
        mixture_tensor = torch.tensor(mixture).float().to(device).unsqueeze(0)  # shape: [1, T]
        # Convert mixture to numpy array for SepFormerWrapper (expects numpy input)
        mixture_np = mixture_tensor.squeeze(0).cpu().numpy()
        # Use pre-trained SepFormer to separate/enhance the mixture
        enhanced_sources = sepformer_wrapper.separate(mixture_np, sample_rate=8000)  # list of np arrays
        
        # Ground truth sources (assumed to be numpy arrays or lists)
        gt_sources = [np.array(source1), np.array(source2)]
        # Align lengths: truncate signals to the minimum length among estimates and ground truths
        min_len = min(len(gt_sources[0]), len(enhanced_sources[0]), len(gt_sources[1]), len(enhanced_sources[1]))
        gt_sources = [src[:min_len] for src in gt_sources]
        est_sources = [src[:min_len] for src in enhanced_sources]
        
        # Compute separation metrics using mir_eval
        try:
            sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(np.stack(gt_sources), np.stack(est_sources))
            all_sdr.append(np.mean(sdr))
            all_sir.append(np.mean(sir))
            all_sar.append(np.mean(sar))
        except Exception as e:
            print("Error computing mir_eval metrics:", e)
        
        # Compute PESQ for each source (using sample_rate=8000 and 'wb' mode)
        pesq_scores = []
        for i in range(len(gt_sources)):
            try:
                score = pesq(8000, gt_sources[i], est_sources[i], 'wb')
                pesq_scores.append(score)
            except Exception as e:
                print("Error computing PESQ:", e)
        if pesq_scores:
            all_pesq.append(np.mean(pesq_scores))
        else:
            all_pesq.append(0.0)
        
        # Speaker Identification: For each enhanced source, compute its embedding using sv_model
        enhanced_embeddings = []
        for est in est_sources:
            est_tensor = torch.tensor(est).float().to(device).unsqueeze(0)  # shape: [1, T]
            emb = sv_model(est_tensor)
            enhanced_embeddings.append(emb)
        
        # For each computed embedding, perform nearest neighbor search against the enrollment_db
        for emb, true_spk in zip(enhanced_embeddings, speaker_ids):
            emb = emb.squeeze(0)  # shape: [embedding_dim]
            max_sim = -1
            pred_spk = None
            for spk, enroll_emb in enrollment_db.items():
                sim = F.cosine_similarity(emb.unsqueeze(0), enroll_emb.unsqueeze(0)).item()
                if sim > max_sim:
                    max_sim = sim
                    pred_spk = spk
            if pred_spk == true_spk:
                correct_count += 1
            total_count += 1
    
    rank1_accuracy = (correct_count / total_count * 100) if total_count > 0 else 0.0
    results = {
        "SDR": np.mean(all_sdr) if all_sdr else None,
        "SIR": np.mean(all_sir) if all_sir else None,
        "SAR": np.mean(all_sar) if all_sar else None,
        "PESQ": np.mean(all_pesq) if all_pesq else None,
        "Rank1_Identification_Accuracy": rank1_accuracy
    }
    return results

In [15]:
def compute_enrollment_db(enrollment_data, speaker_model, device):
    """
    Computes enrollment embeddings for each speaker.
    
    Args:
        enrollment_data (dict): Mapping from speaker_id (str) to enrollment file path.
        speaker_model (nn.Module): Pre-trained/fine-tuned speaker verification model.
        device (torch.device): Device to run computations on.
        
    Returns:
        dict: Mapping from speaker_id to enrollment embedding (torch.Tensor).
    """
    enrollment_db = {}
    for spk, file_path in enrollment_data.items():
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform.to(device)
        # Ensure waveform shape is [batch, time]
        if waveform.dim() == 3 and waveform.size(1) == 1:
            waveform = waveform.squeeze(1)
        # Compute embedding (unsqueeze for batch dimension)
        emb = speaker_model(waveform.unsqueeze(0))
        enrollment_db[spk] = emb.squeeze(0)  # store as [embedding_dim]
    return enrollment_db

def evaluate_pipeline(pipeline, test_loader, enrollment_db, device):
    """
    Evaluate the pipeline on the multi-speaker test dataset.
    Computes SDR, SIR, SAR using mir_eval, PESQ using the pesq package,
    and actual Rank‑1 identification accuracy by comparing each enhanced source
    embedding to an enrollment database.
    
    Args:
        pipeline (nn.Module): The combined speaker separation/enhancement pipeline.
        test_loader (DataLoader): DataLoader for the multi-speaker test set.
        enrollment_db (dict): Mapping from speaker_id (str) to enrollment embedding.
        device (torch.device): Device to run computations on.
        
    Returns:
        dict: Evaluation metrics (SDR, SIR, SAR, PESQ, Rank‑1 Identification Accuracy).
    """
    pipeline.eval()
    all_sdr, all_sir, all_sar, all_pesq = [], [], [], []
    correct_count = 0
    total_count = 0
    
    for batch in test_loader:
        # Assume batch returns: (speaker_ids, mixture, source1, source2)
        speaker_ids, mixture, source1, source2 = batch  # speaker_ids is a tuple/list, e.g., (spk1, spk2)
        # Convert mixture (assumed as numpy array or list) to tensor.
        mixture_tensor = torch.tensor(mixture).float().to(device).unsqueeze(0)
        
        with torch.no_grad():
            enhanced_sources_tensors, embeddings = pipeline(mixture_tensor)
        
        # Convert enhanced outputs to numpy arrays for mir_eval.
        est_sources_np = [src.squeeze(0).detach().cpu().numpy() for src in enhanced_sources_tensors]
        # Ground truth sources from dataset (convert to numpy arrays if not already)
        gt_sources = [np.array(source1), np.array(source2)]
        # Align lengths: truncate all signals to the minimum length.
        min_len = min(len(gt_sources[0]), len(est_sources_np[0]), len(gt_sources[1]), len(est_sources_np[1]))
        gt_sources = [src[:min_len] for src in gt_sources]
        est_sources_np = [src[:min_len] for src in est_sources_np]
        
        # Compute separation metrics using mir_eval
        try:
            sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(np.stack(gt_sources), np.stack(est_sources_np))
            all_sdr.append(np.mean(sdr))
            all_sir.append(np.mean(sir))
            all_sar.append(np.mean(sar))
        except Exception as e:
            print("Error computing mir_eval metrics:", e)
        
        # Compute PESQ (sample_rate=8000, mode 'wb' for wideband)
        pesq_scores = []
        for i in range(len(gt_sources)):
            try:
                score = pesq(8000, gt_sources[i], est_sources_np[i], 'wb')
                pesq_scores.append(score)
            except Exception as e:
                print("Error computing PESQ:", e)
        if pesq_scores:
            all_pesq.append(np.mean(pesq_scores))
        else:
            all_pesq.append(0.0)
        
        # Compute actual Rank-1 Identification Accuracy:
        # For each enhanced source embedding, compare to enrollment_db using cosine similarity.
        # Assume that ordering of embeddings corresponds to ordering in speaker_ids.
        for emb, true_spk in zip(embeddings, speaker_ids):
            emb = emb.squeeze(0)  # shape: [embedding_dim]
            max_sim = -1
            pred_spk = None
            for spk, enroll_emb in enrollment_db.items():
                sim = F.cosine_similarity(emb.unsqueeze(0), enroll_emb.unsqueeze(0)).item()
                if sim > max_sim:
                    max_sim = sim
                    pred_spk = spk
            if pred_spk == true_spk:
                correct_count += 1
            total_count += 1

    rank1_accuracy = (correct_count / total_count * 100) if total_count > 0 else 0.0

    results = {
        "SDR": np.mean(all_sdr) if all_sdr else None,
        "SIR": np.mean(all_sir) if all_sir else None,
        "SAR": np.mean(all_sar) if all_sar else None,
        "PESQ": np.mean(all_pesq) if all_pesq else None,
        "Rank1_Identification_Accuracy": rank1_accuracy
    }
    return results

In [16]:
# =============================================================================
# 11. Main Execution Pipeline
# =============================================================================

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [18]:
def trial_collate_fn(batch):
    """
    Custom collate function for VoxCelebTrialDataset.
    Each item in batch: (label, enroll_waveform, test_waveform, sr)
    Pads enrollment and test waveforms to the max length in the batch.
    """
    labels = []
    enroll_waveforms = []
    test_waveforms = []
    srs = []
    
    # Determine max lengths for enrollment and test waveforms separately
    max_len_enroll = max(item[1].size(1) for item in batch)
    max_len_test = max(item[2].size(1) for item in batch)
    
    for label, enroll_waveform, test_waveform, sr in batch:
        labels.append(label)
        # Pad enrollment waveform (assumed shape: [channels, time])
        if enroll_waveform.size(1) < max_len_enroll:
            pad_amt = max_len_enroll - enroll_waveform.size(1)
            enroll_waveform = F.pad(enroll_waveform, (0, pad_amt))
        enroll_waveforms.append(enroll_waveform)
        
        # Pad test waveform
        if test_waveform.size(1) < max_len_test:
            pad_amt = max_len_test - test_waveform.size(1)
            test_waveform = F.pad(test_waveform, (0, pad_amt))
        test_waveforms.append(test_waveform)
        
        srs.append(sr)
    
    # Stack padded tensors and convert labels to tensor
    labels = torch.tensor(labels)
    enroll_waveforms = torch.stack(enroll_waveforms)
    test_waveforms = torch.stack(test_waveforms)
    
    return labels, enroll_waveforms, test_waveforms, srs

In [19]:
# ----------------------------
# I. Speaker Verification Evaluation (Pre-trained)
# ----------------------------
sv_model = SpeakerVerificationModel(model_name='microsoft/wavlm-base-plus').to(device)
trial_dataset = VoxCelebTrialDataset(trial_file="vox1/trial_pairs.txt", root_dir="vox1/wav")


In [20]:
trial_loader = DataLoader(trial_dataset, batch_size=4, shuffle=False,collate_fn=trial_collate_fn)

In [21]:
sv_model_pretrained = copy.deepcopy(sv_model)

In [22]:
scores, labels_list = [], []
sv_model_pretrained.eval()
with torch.no_grad():
    for batch in trial_loader:  # Limit to first 2 batches for demonstration
        label, enroll_waveform, test_waveform, sr = batch
        enroll_waveform = enroll_waveform.to(device)
        test_waveform = test_waveform.to(device)
        enroll_emb = sv_model_pretrained(enroll_waveform)
        test_emb = sv_model_pretrained(test_waveform)
        cosine_sim = F.cosine_similarity(enroll_emb, test_emb)
        scores.extend(cosine_sim.cpu().numpy().tolist())
        labels_list.extend(label)
        del enroll_waveform, test_waveform, enroll_emb, test_emb, cosine_sim
        # Call garbage collection and empty the CUDA cache
        gc.collect()
        # Free unused memory after each batch
        torch.cuda.empty_cache()

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


In [25]:
eer = compute_eer(np.array(scores), np.array(labels_list))
tar = compute_tar_at_far(np.array(scores), np.array(labels_list), target_far=0.01)
compute_identification_accuracy(labels_list, scores, threshold=0.5)
# Print results
print(f"Pre-trained Speaker Verification: EER = {eer:.2f}%, TAR@1%FAR = {tar:.2f}%, Speaker Identification Accuracy = {compute_identification_accuracy(labels_list, scores, threshold=0.5):.2f}%")

Pre-trained Speaker Verification: EER = 49.28%, TAR@1%FAR = 2.79%, Speaker Identification Accuracy = 50.48%


In [26]:
import torch.nn.functional as F

def voxceleb_collate_fn(batch):
    """
    Custom collate function for VoxCelebDataset.
    Each item in the batch is (identity, waveform, sr).
    Pads each waveform to the maximum length in the batch.
    """
    identities = []
    waveforms = []
    srs = []
    
    # Determine maximum length across the batch.
    max_len = max(item[1].size(1) for item in batch)
    
    for identity, waveform, sr in batch:
        identities.append(identity)
        # Pad waveform on the time dimension if needed.
        if waveform.size(1) < max_len:
            pad_amt = max_len - waveform.size(1)
            waveform = F.pad(waveform, (0, pad_amt))
        waveforms.append(waveform)
        srs.append(sr)
    
    waveforms = torch.stack(waveforms, dim=0)
    return identities, waveforms, srs


In [27]:
import os
import glob

def create_enrollment_data(root_dir):
    """
    Creates a dictionary mapping speaker IDs to an enrollment audio file path.
    
    Assumes structure: root_dir/<speaker_id>/<subfolder>/<utterance_file>.wav.
    For each speaker, the function recursively searches through subdirectories and selects the first found .wav file.
    
    Args:
        root_dir (str): Path to the directory containing speaker folders.
        
    Returns:
        dict: Mapping from speaker_id (str) to enrollment audio file path (str).
    """
    enrollment_data = {}
    # List speaker directories in sorted order
    speaker_dirs = sorted(os.listdir(root_dir))
    for speaker in speaker_dirs:
        speaker_path = os.path.join(root_dir, speaker)
        if os.path.isdir(speaker_path):
            # Recursively search for .m4a files in all subfolders
            m4a_files = glob.glob(os.path.join(speaker_path, '**', '*.m4a'), recursive=True)
            if m4a_files:
                # Select the first m4a file as the enrollment utterance
                enrollment_data[speaker] = m4a_files[0]
    return enrollment_data


In [28]:
root_audio_dir = "vox2/aac"
enrollment_data = create_enrollment_data(root_audio_dir)

In [None]:
# print("Enrollment Data:")
# for spk, file_path in enrollment_data.items():
#     print(f"Speaker {spk}: {file_path}")

In [33]:
torch.cuda.empty_cache()

In [34]:
enrollment_db = compute_enrollment_db(enrollment_data, sv_model_pretrained, device)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
enrollment_db_pretrained = compute_enrollment_db(enrollment_data, sv_model_pretrained, device)

In [30]:
# ---------------------------------------------------------------------------
# II. Fine-tuning Speaker Verification with VoxCeleb2 using LoRA and ArcFace
# ---------------------------------------------------------------------------
identities = sorted(os.listdir("vox2/txt"))
train_ids = identities[:100]
test_ids = identities[100:]
print(f"Fine-tuning Speaker Verification with {len(train_ids)} train identities, testing on {len(test_ids)} identities.")
# Create a mapping from identity to label index
trim_transform = lambda x: trim_audio(x, sr=16000, max_duration=3.0)
train_dataset = VoxCelebDataset(root_dir="vox2/aac", identities=train_ids,transform=trim_transform, file_ext='.m4a')
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,collate_fn=voxceleb_collate_fn)

# Create a mapping from identity to label index
id2label = {identity: idx for idx, identity in enumerate(train_ids)}

# Apply LoRA to the speaker verification model
apply_lora(sv_model, r=4, alpha=1.0)

# Instantiate ArcFace loss (using embedding dimension from model)
arcface_loss = ArcFaceLoss(embedding_size=sv_model.embedding_dim, num_classes=len(train_ids)).to(device)
optimizer = torch.optim.Adam(list(sv_model.parameters()) + list(arcface_loss.parameters()), lr=1e-4)

num_epochs = 10
for epoch in range(num_epochs):
    loss = train_speaker_verification(sv_model, train_loader, optimizer, arcface_loss, id2label, device)
    print(f"Fine-tuning Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

# Evaluate fine-tuned model on VoxCeleb1 trial pairs
scores_ft, labels_list_ft = [], []
sv_model.eval()
with torch.no_grad():
    for batch in trial_loader:
        label, enroll_waveform, test_waveform, sr = batch
        enroll_waveform = enroll_waveform.to(device)
        test_waveform = test_waveform.to(device)
        enroll_emb = sv_model(enroll_waveform)
        test_emb = sv_model(test_waveform)
        cosine_sim = F.cosine_similarity(enroll_emb, test_emb)
        scores_ft.extend(cosine_sim.cpu().numpy().tolist())
        labels_list_ft.extend(label)
        del enroll_waveform, test_waveform, enroll_emb, test_emb, cosine_sim
        # Call garbage collection and empty the CUDA cache
        gc.collect()
        # Free unused memory after each batch
        torch.cuda.empty_cache()


Fine-tuning Speaker Verification with 42 train identities, testing on 0 identities.


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
eer_ft = compute_eer(np.array(scores_ft), np.array(labels_list_ft))
tar_ft = compute_tar_at_far(np.array(scores_ft), np.array(labels_list_ft), target_far=0.01)
compute_identification_accuracy(labels_list_ft, scores_ft, threshold=0.5)
print(f"Fine-tuned Speaker Verification: EER = {eer_ft:.2f}%, TAR@1%FAR = {tar_ft:.2f}%, Fine Tuned Speaker Identification Accuracy = {compute_identification_accuracy(labels_list, scores, threshold=0.5):.2f}%")

In [None]:
enrollment_db_finetuned = compute_enrollment_db(enrollment_data, sv_model, device)

In [None]:
# ----------------------------
# III. Create Multi-Speaker Scenario Dataset (Mixing VoxCeleb2 Utterances)
# ----------------------------
identities_multi = sorted(os.listdir("vox2/txt"))
train_multi_ids = identities_multi[:50]
test_multi_ids = identities_multi[50:100]

train_multi_dataset = MultiSpeakerDataset(identity_list=train_multi_ids,
                                            metadata_dir="vox2/txt",
                                            audio_dir="vox2/aac", max_duration=3.0, target_sr=16000)
test_multi_dataset = MultiSpeakerDataset(identity_list=test_multi_ids,
                                            metadata_dir="vox2/txt",
                                            audio_dir="vox2/aac", max_duration=3.0, target_sr=16000)
print(f"Multi-Speaker Dataset: {len(train_multi_dataset)} training samples, {len(test_multi_dataset)} testing samples.")
train_multi_loader = DataLoader(train_multi_dataset, batch_size=4, shuffle=True)
test_multi_loader = DataLoader(test_multi_dataset, batch_size=4, shuffle=False)

In [None]:
# ----------------------------
# IV. Speech Separation & Enhancement using SepFormer
# ----------------------------
sepformer_wrapper = SepFormerWrapper(model_name="speechbrain/sepformer-whamr")

In [None]:
# ----------------------------
# Baseline Evaluation using pre-trained SepFormer
# ----------------------------
baseline_results = evaluate_sepformer_baseline(test_multi_loader, sepformer_wrapper, sv_model_pretrained, enrollment_db_pretrained, device)
print("Baseline Evaluation with pre-trained SepFormer and Speaker Verification:")
for metric, value in baseline_results.items():
    print(f"{metric}: {value:.2f}")

In [None]:
# ----------------------------
# V. Novel Pipeline: Combining Speaker Identification with SepFormer
# ----------------------------
pipeline = SpeakerSeparationPipeline(sepformer_wrapper=sepformer_wrapper,speaker_verification_model=sv_model).to(device)
pipeline_optimizer = torch.optim.Adam(pipeline.parameters(), lr=1e-4)
pipeline_criterion = ArcFaceLoss(embedding_size=sv_model.embedding_dim, num_classes=len(test_multi_ids)).to(device)

num_pipeline_epochs = 5
# For the pipeline training, we create a dummy id2label mapping for the test multi-speaker set.
id2label_multi = {identity: idx for idx, identity in enumerate(test_multi_ids)}
for epoch in range(num_pipeline_epochs):
    loss = train_pipeline(pipeline, train_multi_loader, pipeline_optimizer, pipeline_criterion, id2label_multi, device)
    print(f"Pipeline Training Epoch {epoch+1}/{num_pipeline_epochs}, Loss: {loss:.4f}")

# Evaluate the novel pipeline on the multi-speaker test set
pipeline_results = evaluate_pipeline(pipeline, test_multi_loader, enrollment_db_finetuned, device)
print("Final Evaluation Results for Novel Pipeline on Multi-Speaker Test Set:")
for metric, value in pipeline_results.items():
    print(f"{metric}: {value:.2f}")