## Setup Environment

In [None]:
!git clone https://github.com/microsoft/UniSpeech.git

In [None]:
!git clone https://github.com/pytorch/fairseq.git

In [None]:
!pip install --force pip==24.0

In [None]:
!pip install s3prl fire omegaconf==2.2.0

In [11]:
import os
os.chdir("/kaggle/working/fairseq")

In [None]:
!pip install --editable ./

In [7]:
!wget https://mm.kaist.ac.kr/datasets/voxceleb/meta/veri_test2.txt

--2025-04-06 15:50:29--  https://mm.kaist.ac.kr/datasets/voxceleb/meta/veri_test2.txt
Resolving mm.kaist.ac.kr (mm.kaist.ac.kr)... 143.248.39.47
Connecting to mm.kaist.ac.kr (mm.kaist.ac.kr)|143.248.39.47|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2331882 (2.2M) [text/plain]
Saving to: ‘veri_test2.txt’


2025-04-06 15:50:38 (309 KB/s) - ‘veri_test2.txt’ saved [2331882/2331882]



In [8]:
!mv /kaggle/working/fairseq/veri_test2.txt /kaggle/working/UniSpeech/downstreams/speaker_verification/veri_test2.txt

In [3]:
import pandas as pd
df= pd.read_csv('/kaggle/working/UniSpeech/downstreams/speaker_verification/veri_test2.txt', sep=" ", header=None)
df.columns

Index([0, 1, 2], dtype='int64')

In [33]:
os.chdir("/kaggle/working")

In [None]:
!python verification.py --model_name wavlm_base_plus --wav1 /kaggle/input/vox-celeb/vox_celeb/vox1/vox1_test_wav/wav/id10270/x6uYqmx31kE/00001.wav --wav2 /kaggle/input/vox-celeb/vox_celeb/vox1/vox1_test_wav/wav/id10270/8jEAjG6SegY/00008.wav --checkpoint /kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth 

In [55]:
model = init_model(model_name="wavlm_base_plus", checkpoint="/kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth")

Downloading: "https://github.com/s3prl/s3prl/zipball/main" to /root/.cache/torch/hub/main.zip
  torchaudio.set_audio_backend("sox_io")
Downloading: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base_plus.pt
Destination: /root/.cache/s3prl/download/72cb34edf8a3724c720467cf40b77ad20b1b714b5f694e9db57f521467f9006b.wavlm_base_plus.pt
100%|██████████| 360M/360M [00:04<00:00, 85.9MB/s] 
  checkpoint = torch.load(ckpt)
  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)


## Speaker Separation on Mixed Dataset using Sepformer

In [None]:
!pip install speechbrain pydub pesq museval torchaudio

In [26]:
import os

aac_dir = "/kaggle/input/vox-celeb/vox2_test_aac/aac"
all_ids = sorted(os.listdir(aac_dir))
train_ids = all_ids[:50]
test_ids = all_ids[50:100]

In [27]:
import random
from pydub import AudioSegment
from pathlib import Path

def mix_utterances(speaker1_path, speaker2_path, output_path):
    audio1 = AudioSegment.from_file(speaker1_path)
    audio2 = AudioSegment.from_file(speaker2_path)

    # Ensure same duration
    min_len = min(len(audio1), len(audio2))
    mixed = audio1[:min_len].overlay(audio2[:min_len])

    mixed.export(output_path / "mixture.wav", format="wav")
    audio1[:min_len].export(output_path / "source1.wav", format="wav")
    audio2[:min_len].export(output_path / "source2.wav", format="wav")

In [56]:
import random
import json
from pathlib import Path

def create_mixtures(id_list, base_dir, out_dir, num_samples=100):
    speaker_map = {}
    mix_idx = 0  # Ensures sequential naming for skipped mixes

    for _ in range(num_samples):
        spk1, spk2 = random.sample(id_list, 2)

        spk1_files = list(Path(base_dir, spk1).rglob("*.m4a"))
        spk2_files = list(Path(base_dir, spk2).rglob("*.m4a"))

        if not spk1_files or not spk2_files:
            continue  # Skip if no valid audio files

        f1 = random.choice(spk1_files)
        f2 = random.choice(spk2_files)

        mix_id = f"mix_{mix_idx}"
        out_path = Path(out_dir, mix_id)
        out_path.mkdir(parents=True, exist_ok=True)

        # Call your mixing function (assumed defined elsewhere)
        mix_utterances(f1, f2, out_path)

        # Update speaker map
        speaker_map[mix_id] = [spk1, spk2]
        mix_idx += 1

    # Save speaker map JSON
    speaker_map_path = Path(out_dir) / "test_mix_speaker_map.json"
    with open(speaker_map_path, "w") as f:
        json.dump(speaker_map, f, indent=2)

    print(f"Created {mix_idx} mixtures and saved speaker map to {speaker_map_path}")


In [57]:
create_mixtures(train_ids, "/kaggle/input/vox-celeb/vox2_test_aac/aac", "/kaggle/working/train_mixes")
create_mixtures(test_ids, "/kaggle/input/vox-celeb/vox2_test_aac/aac", "/kaggle/working/test_mixes")

Created 100 mixtures and saved speaker map to /kaggle/working/train_mixes/test_mix_speaker_map.json
Created 100 mixtures and saved speaker map to /kaggle/working/test_mixes/test_mix_speaker_map.json


In [42]:
import os
from speechbrain.pretrained import SepformerSeparation as separator
from scipy.io import wavfile
import torchaudio
import numpy as np
from pesq import pesq
import museval
from tqdm import tqdm

# Load SepFormer
sepformer = separator.from_hparams(source="speechbrain/sepformer-whamr", savedir="/kaggle/working/tmpdir_sepformer")

# Your test mix directory
test_mix_root = "/kaggle/working/test_mixes"

# Storage for results
all_metrics = []

def evaluate(ref1_path, ref2_path, sep1_path, sep2_path):
    rate_ref, ref1 = wavfile.read(ref1_path)
    _, ref2 = wavfile.read(ref2_path)
    _, sep1 = wavfile.read(sep1_path)
    _, sep2 = wavfile.read(sep2_path)

    # Squeeze & align lengths
    ref1 = ref1.squeeze()[:min(len(ref1), len(ref2), len(sep1), len(sep2))]
    ref2 = ref2.squeeze()[:len(ref1)]
    sep1 = sep1.squeeze()[:len(ref1)]
    sep2 = sep2.squeeze()[:len(ref1)]

    refs = np.stack([ref1, ref2], axis=0).astype(np.float32)
    ests = np.stack([sep1, sep2], axis=0).astype(np.float32)

    sdr, sir, sar, _ = museval.metrics.bss_eval_sources(refs, ests)

    pesq1 = pesq(rate_ref, ref1.astype(np.int16), sep1.astype(np.int16), 'wb')
    pesq2 = pesq(rate_ref, ref2.astype(np.int16), sep2.astype(np.int16), 'wb')

    return {
        "SDR": sdr.tolist(),
        "SIR": sir.tolist(),
        "SAR": sar.tolist(),
        "PESQ": [pesq1, pesq2]
    }

# Loop through test samples
for mix_dir in tqdm(sorted(os.listdir(test_mix_root))):
    mix_path = os.path.join(test_mix_root, mix_dir)
    mixture_file = os.path.join(mix_path, "mixture.wav")
    source1_file = os.path.join(mix_path, "source1.wav")
    source2_file = os.path.join(mix_path, "source2.wav")

    # Resample to 8kHz if needed
    mixture, sr = torchaudio.load(mixture_file)
    if sr != 8000:
        resampler = torchaudio.transforms.Resample(sr, 8000)
        mixture = resampler(mixture)
        torchaudio.save(mixture_file.replace(".wav", "_8k.wav"), mixture, 8000)
        mixture_file = mixture_file.replace(".wav", "_8k.wav")

    # Perform separation
    out = sepformer.separate_file(path=mixture_file)
    est_sources = out[0].transpose(0, 1)  # shape: [2, time]

    sep1_path = os.path.join(mix_path, "sep1.wav")
    sep2_path = os.path.join(mix_path, "sep2.wav")
    torchaudio.save(sep1_path, est_sources[0].unsqueeze(0), 8000)
    torchaudio.save(sep2_path, est_sources[1].unsqueeze(0), 8000)

    # Evaluate
    metrics = evaluate(source1_file, source2_file, sep1_path, sep2_path)
    all_metrics.append({
        "mix": mix_dir,
        **metrics
    })

# Print or save results
import pandas as pd
df = pd.DataFrame(all_metrics)
print(df.describe())  # Show summary stats (mean, std, etc.)


  state_dict = torch.load(path, map_location=device)
100%|██████████| 100/100 [15:21<00:00,  9.21s/it]

           mix                                           SDR  \
count      100                                           100   
unique     100                                           100   
top     mix_99  [[-16.98153023043959], [-20.36353317291269]]   
freq         1                                             1   

                                                SIR  \
count                                           100   
unique                                          100   
top     [[4.165183636201211], [1.0840597732567738]]   
freq                                              1   

                                                  SAR  \
count                                             100   
unique                                            100   
top     [[-15.539044871167116], [-17.83029501124999]]   
freq                                                1   

                                           PESQ  
count                                       100  
unique           




In [49]:
df

Unnamed: 0,mix,SDR,SIR,SAR,PESQ
0,mix_0,"[[-17.069044008415528], [-19.075894028940937]]","[[2.6752860304896875], [0.5818220955975363]]","[[-15.147255625523883], [-16.299522242165526]]","[1.0247418880462646, 1.0264379978179932]"
1,mix_1,"[[-18.337136202879897], [-13.457945461510212]]","[[1.6300755293341884], [3.05708682162188]]","[[-16.021859436201428], [-11.614584925789359]]","[1.0536715984344482, 1.0284733772277832]"
2,mix_10,"[[-15.139187247421532], [-20.111132596237624]]","[[5.166741657040355], [0.029240042038632418]]","[[-13.944691368552215], [-17.07317500983563]]","[1.0532764196395874, 1.0292764902114868]"
3,mix_11,"[[-21.24172847078424], [-18.964291162477803]]","[[0.3822814849974416], [2.126926014614753]]","[[-18.388380939904827], [-16.854616881689264]]","[1.0228641033172607, 1.0225387811660767]"
4,mix_12,"[[-20.961478873540358], [-18.56597937643512]]","[[-1.0836117722560068], [1.4838763634284122]]","[[-17.330764316931273], [-16.191400666490225]]","[1.0745307207107544, 1.0445177555084229]"
...,...,...,...,...,...
95,mix_95,"[[-22.887926256862134], [-17.8028707112885]]","[[-0.9532596572558356], [2.8820649519534]]","[[-19.34698726246016], [-15.961538827352156]]","[1.0292237997055054, 1.0773890018463135]"
96,mix_96,"[[-19.140684863118963], [-17.847187670473037]]","[[2.640839407758043], [2.7941847598760337]]","[[-17.224182282632636], [-15.975405833827896]]","[1.256512999534607, 1.0657751560211182]"
97,mix_97,"[[-20.76550731026028], [-19.145009551979907]]","[[-0.5534130170461288], [3.6685769006009012]]","[[-17.428134097432483], [-17.569849087999106]]","[1.025382399559021, 1.0720341205596924]"
98,mix_98,"[[-20.35340465330187], [-17.79552302211158]]","[[0.7223739964436847], [1.1420638474621916]]","[[-17.655256138210184], [-15.262998529353855]]","[1.1043754816055298, 1.234655499458313]"


In [50]:
# Flatten lists into columns
df[['SDR_1', 'SDR_2']] = pd.DataFrame(df['SDR'].tolist(), index=df.index)
df[['SIR_1', 'SIR_2']] = pd.DataFrame(df['SIR'].tolist(), index=df.index)
df[['SAR_1', 'SAR_2']] = pd.DataFrame(df['SAR'].tolist(), index=df.index)
df[['PESQ_1', 'PESQ_2']] = pd.DataFrame(df['PESQ'].tolist(), index=df.index)

# Drop original list-columns
df = df.drop(columns=['SDR', 'SIR', 'SAR', 'PESQ'])

# Now describe again
print(df.describe())


           PESQ_1      PESQ_2
count  100.000000  100.000000
mean     1.063066    1.064700
std      0.085249    0.090504
min      1.019078    1.017695
25%      1.026018    1.024959
50%      1.036226    1.034923
75%      1.062487    1.071523
max      1.682800    1.639908


In [58]:
import os
import json
import torch
import torchaudio
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load WavLM-based model
model = init_model(model_name="wavlm_base_plus", checkpoint="/kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth")
model.to(device)
model.eval()

# Load audio with resampling and convert to mono
def load_audio(path, target_sr=16000):
    waveform, sr = torchaudio.load(path)  # [1, T] or [2, T]
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    if waveform.shape[0] > 1:  # stereo to mono
        waveform = waveform.mean(dim=0)
    return waveform.squeeze(0)  # [T]

# Cosine similarity between two embeddings
def cosine_sim(a, b):
    return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

# Load speaker map
with open("/kaggle/working/test_mixes/test_mix_speaker_map.json", "r") as f:
    speaker_map = json.load(f)

correct = 0
total = 0
mix_root = "/kaggle/working/test_mixes"

for mix_id in tqdm(sorted(speaker_map.keys())):
    mix_dir = os.path.join(mix_root, mix_id)
    audio_paths = [
        os.path.join(mix_dir, "source1.wav"),
        os.path.join(mix_dir, "source2.wav"),
        os.path.join(mix_dir, "sep1.wav"),
        os.path.join(mix_dir, "sep2.wav")
    ]

    # Load and pad to max length
    waveforms = [load_audio(p) for p in audio_paths]
    max_len = max(w.shape[0] for w in waveforms)
    padded = [F.pad(w, (0, max_len - w.shape[0])) for w in waveforms]
    batch_tensor = torch.stack(padded).to(device)  # Shape: [4, T]

    # Get embeddings
    with torch.no_grad():
        embeddings = model(batch_tensor)
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()

    ref1_emb, ref2_emb, sep1_emb, sep2_emb = embeddings

    # Compare embeddings
    scores = {
        "sep1_ref1": cosine_sim(sep1_emb, ref1_emb),
        "sep1_ref2": cosine_sim(sep1_emb, ref2_emb),
        "sep2_ref1": cosine_sim(sep2_emb, ref1_emb),
        "sep2_ref2": cosine_sim(sep2_emb, ref2_emb),
    }

    # Assign speakers
    sep1_pred = speaker_map[mix_id][0] if scores["sep1_ref1"] > scores["sep1_ref2"] else speaker_map[mix_id][1]
    sep2_pred = speaker_map[mix_id][1] if scores["sep2_ref2"] > scores["sep2_ref1"] else speaker_map[mix_id][0]

    if sep1_pred == speaker_map[mix_id][0]:
        correct += 1
    if sep2_pred == speaker_map[mix_id][1]:
        correct += 1

    total += 2

# Final Rank-1 accuracy
acc = correct / total
print(f"Rank-1 Identification Accuracy: {acc * 100:.2f}%")


Using cache found in /root/.cache/torch/hub/s3prl_s3prl_main
  checkpoint = torch.load(ckpt)
  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
100%|██████████| 100/100 [00:20<00:00,  4.99it/s]

Rank-1 Identification Accuracy: 47.00%



