## Setup Environment

In [None]:
!git clone https://github.com/microsoft/UniSpeech.git

In [None]:
!git clone https://github.com/pytorch/fairseq.git

In [None]:
!pip install --force pip==24.0

In [None]:
!pip install s3prl fire omegaconf==2.2.0

In [11]:
import os
os.chdir("/kaggle/working/fairseq")

In [None]:
!pip install --editable ./

In [7]:
!wget https://mm.kaist.ac.kr/datasets/voxceleb/meta/veri_test2.txt

--2025-04-06 15:50:29--  https://mm.kaist.ac.kr/datasets/voxceleb/meta/veri_test2.txt
Resolving mm.kaist.ac.kr (mm.kaist.ac.kr)... 143.248.39.47
Connecting to mm.kaist.ac.kr (mm.kaist.ac.kr)|143.248.39.47|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2331882 (2.2M) [text/plain]
Saving to: ‘veri_test2.txt’


2025-04-06 15:50:38 (309 KB/s) - ‘veri_test2.txt’ saved [2331882/2331882]



In [8]:
!mv /kaggle/working/fairseq/veri_test2.txt /kaggle/working/UniSpeech/downstreams/speaker_verification/veri_test2.txt

In [3]:
import pandas as pd
df= pd.read_csv('/kaggle/working/UniSpeech/downstreams/speaker_verification/veri_test2.txt', sep=" ", header=None)
df.columns

Index([0, 1, 2], dtype='int64')

In [33]:
os.chdir("/kaggle/working")

In [None]:
!python verification.py --model_name wavlm_base_plus --wav1 /kaggle/input/vox-celeb/vox_celeb/vox1/vox1_test_wav/wav/id10270/x6uYqmx31kE/00001.wav --wav2 /kaggle/input/vox-celeb/vox_celeb/vox1/vox1_test_wav/wav/id10270/8jEAjG6SegY/00008.wav --checkpoint /kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth 

## Speaker Verification on VoxCeleb using pretrained WavLM-Base-Plus

In [52]:
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
import soundfile as sf
import fire
from torchaudio.transforms import Resample



''' Res2Conv1d + BatchNorm1d + ReLU
'''


class Res2Conv1dReluBn(nn.Module):
    '''
    in_channels == out_channels == channels
    '''

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


''' Conv1d + BatchNorm1d + ReLU
'''


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


''' The SE connection of 1D case.
'''


class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


''' SE-Res2Block of the ECAPA-TDNN architecture.
'''


class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


''' Attentive weighted mean and standard deviation pooling.
'''


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type == "fbank" or feat_type == "mfcc":
            self.update_extract = False

        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
                                                        hop_length=hop_len, f_min=0.0, f_max=sr // 2,
                                                        pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            melkwargs = {
                'n_fft': 512,
                'win_length': win_len,
                'hop_length': hop_len,
                'f_min': 0.0,
                'f_max': sr // 2,
                'pad': 0
            }
            self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                              melkwargs=melkwargs)
        else:
            if config_path is None:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != 'fbank' and feat_type != 'mfcc':
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
        self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
        self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)


    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == 'fbank':
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
                      feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)


def init_model(model_name, checkpoint=None):
    config_path = None
    model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=config_path)
    if checkpoint is not None:
        state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['model'], strict=False)
    return model

### Utils

In [53]:
import os
import torch
import torchaudio
from torchaudio.transforms import Resample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_audio(wav_path, target_sr=16000):
    audio, sr = torchaudio.load(wav_path)
    if sr != target_sr:
        audio = Resample(orig_freq=sr, new_freq=target_sr)(audio)
    return audio.squeeze(0)  # shape: (samples,)


def load_batch(wav_paths, target_sr=16000, max_len_sec=None):
    batch = []
    for path in wav_paths:
        audio = load_audio(path, target_sr)
        if max_len_sec:
            max_len = int(target_sr * max_len_sec)
            audio = audio[:max_len]  # truncate to fixed length
        batch.append(audio)
    return batch


def verify_batch(model, wav_paths, sr=16000, max_len_sec=None):
    model.eval()
    model.to(device)
    
    # Load and batch audio
    wavs = load_batch(wav_paths, target_sr=sr, max_len_sec=max_len_sec)
    wavs = [w.to(device) for w in wavs]

    with torch.no_grad():
        embeddings = model(wavs)
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu()

In [54]:
import gc
import torch

# Clear cache
torch.cuda.empty_cache()

# Collect garbage
gc.collect()

# Optionally clear any tensors stored on GPU
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            del obj
    except:
        pass

# Final memory cleanup
torch.cuda.empty_cache()


  return isinstance(obj, torch.Tensor)


In [55]:
model = init_model(model_name="wavlm_base_plus", checkpoint="/kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth")

Downloading: "https://github.com/s3prl/s3prl/zipball/main" to /root/.cache/torch/hub/main.zip
  torchaudio.set_audio_backend("sox_io")
Downloading: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base_plus.pt
Destination: /root/.cache/s3prl/download/72cb34edf8a3724c720467cf40b77ad20b1b714b5f694e9db57f521467f9006b.wavlm_base_plus.pt
100%|██████████| 360M/360M [00:04<00:00, 85.9MB/s] 
  checkpoint = torch.load(ckpt)
  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)


In [52]:
import pandas as pd
import torch
import os
import torch.nn.functional as F
from tqdm import tqdm

# --- Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_path = "/kaggle/input/vox-celeb/vox_celeb/vox1/vox1_test_wav/wav"

# --- Load trial file ---
df = pd.read_csv("/kaggle/working/UniSpeech/downstreams/speaker_verification/veri_test2.txt", sep=" ", header=None)
df.columns = ['label', 'wav1', 'wav2']
df['wav1_path'] = df['wav1'].apply(lambda x: os.path.join(base_path, x))
df['wav2_path'] = df['wav2'].apply(lambda x: os.path.join(base_path, x))

# --- Get unique paths ---
unique_files = pd.unique(df[['wav1_path', 'wav2_path']].values.ravel())

# --- Init model ---
model = init_model(model_name="wavlm_base_plus", checkpoint="/kaggle/input/wavelm_base_plus/pytorch/default/1/wavlm_base_plus_nofinetune.pth")
model.eval().to(device)

# --- Compute embeddings with batching ---
embedding_dict = {}
batch_size = 4  # Safe choice for limited VRAM

with torch.no_grad():
    for i in tqdm(range(0, len(unique_files), batch_size)):
        batch_paths = unique_files[i:i + batch_size]
        
        # Load and pad audio to max length in batch
        batch_audio = [load_audio(p) for p in batch_paths]
        max_len = max(w.shape[0] for w in batch_audio)
        batch_padded = [F.pad(w, (0, max_len - w.shape[0])) for w in batch_audio]
        batch_tensor = torch.stack(batch_padded).to(device)

        # Inference
        batch_emb = model(batch_tensor)
        batch_emb = F.normalize(batch_emb, p=2, dim=1).cpu()

        for path, emb in zip(batch_paths, batch_emb):
            embedding_dict[path] = emb

        # Free memory
        del batch_tensor, batch_emb
        torch.cuda.empty_cache()

# --- Score trials ---
scores = []
labels = []

for _, row in tqdm(df.iterrows(), total=len(df)):
    emb1 = embedding_dict[row['wav1_path']]
    emb2 = embedding_dict[row['wav2_path']]
    sim = F.cosine_similarity(emb1.unsqueeze(0), emb2.unsqueeze(0)).item()
    scores.append(sim)
    labels.append(int(row['label']))

# --- Output ---
df['score'] = scores
print(df[['label', 'score']].head())


Using cache found in /root/.cache/torch/hub/s3prl_s3prl_main
  checkpoint = torch.load(ckpt)
  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
100%|██████████| 1177/1177 [08:23<00:00,  2.34it/s]
100%|██████████| 37611/37611 [00:04<00:00, 7662.46it/s]

   label     score
0      1  0.636597
1      0  0.355888
2      1  0.564463
3      0  0.071327
4      1  0.636857





In [56]:
import numpy as np
output_df = pd.read_csv('/kaggle/working/final_output.csv')

best_acc = 0
best_thresh = 0
thresholds = np.linspace(0, 1, 1001)  # thresholds from 0.000 to 1.000

for thresh in thresholds:
    preds = (output_df['score'] >= thresh).astype(int)
    acc = (preds == output_df['label']).mean()
    if acc > best_acc:
        best_acc = acc
        best_thresh = thresh

print(f"Best Threshold: {best_thresh:.4f}")
print(f"Accuracy at Best Threshold: {best_acc * 100:.2f}%")

# Optionally, apply the best threshold to get predictions
output_df['predicted_label'] = (output_df['score'] >= best_thresh).astype(int)


Best Threshold: 0.3610
Accuracy at Best Threshold: 91.99%


In [57]:
output_df.to_csv('/kaggle/working/final_output_with_prediction.csv')

## LoRA Finetune

In [1]:
import os

vox2_path = "/kaggle/input/vox-celeb/vox2_test_aac"  
all_speakers = (os.listdir(vox2_path))
print(all_speakers)



['aac']


In [None]:
!pip install peft

In [13]:
import os
import torch
from glob import glob
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torch.nn.functional as F

class VoxCeleb2AACDataset(Dataset):
    def __init__(self, root_dir, id_list, max_len=4, sr=16000):
        self.samples = []
        self.labels = []
        self.spk_to_id = {spk: i for i, spk in enumerate(sorted(id_list))}
        self.max_len = max_len
        self.sr = sr

        for spk in id_list:
            spk_dir = os.path.join(root_dir, spk)
            audio_files = glob(f"{spk_dir}/*/*.m4a")  # one level deeper
            for audio_path in audio_files:
                self.samples.append(audio_path)
                self.labels.append(self.spk_to_id[spk])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path = self.samples[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sr:
            waveform = torchaudio.functional.resample(waveform, sr, self.sr)

        # Trim or pad to fixed length (4 seconds)
        target_len = self.max_len * self.sr
        if waveform.shape[1] > target_len:
            start = torch.randint(0, waveform.shape[1] - target_len, (1,))
            waveform = waveform[:, start:start + target_len]
        else:
            waveform = F.pad(waveform, (0, target_len - waveform.shape[1]))

        return waveform.squeeze(0), label

In [14]:
from torch.utils.data import random_split

# Define paths
root_dir = "/kaggle/input/vox-celeb/vox2_test_aac/aac"
all_speakers = sorted(os.listdir(root_dir))

# Split: First 100 → train, remaining 18 → test
train_speakers = all_speakers[:100]
test_speakers = all_speakers[100:]

# Dataset init
train_dataset = VoxCeleb2AACDataset(root_dir, train_speakers)
test_dataset = VoxCeleb2AACDataset(root_dir, test_speakers)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ArcFaceLoss(nn.Module):
    def __init__(self, emb_dim, num_classes, scale=30.0, margin=0.5):
        super().__init__()
        self.W = nn.Parameter(torch.randn(num_classes, emb_dim))
        nn.init.xavier_uniform_(self.W)
        self.s = scale
        self.m = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings, labels):
        W = F.normalize(self.W, dim=1)
        x = F.normalize(embeddings, dim=1)
        cos_theta = torch.matmul(x, W.t()).clamp(-1, 1)
        target_logit = cos_theta[torch.arange(len(labels)), labels]

        sin_theta = torch.sqrt(1.0 - target_logit ** 2)
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m

        mask = cos_theta > self.th
        final_target_logit = torch.where(mask[torch.arange(len(labels)), labels],
                                         cos_theta_m,
                                         target_logit - self.mm)

        cos_theta[torch.arange(len(labels)), labels] = final_target_logit
        return F.cross_entropy(self.s * cos_theta, labels)

In [20]:
for name, module in base_model.named_modules():
    if "attention" in name.lower():
        print(name)

encoder.layers.0.attention
encoder.layers.0.attention.k_proj
encoder.layers.0.attention.v_proj
encoder.layers.0.attention.q_proj
encoder.layers.0.attention.out_proj
encoder.layers.0.attention.gru_rel_pos_linear
encoder.layers.0.attention.rel_attn_embed
encoder.layers.1.attention
encoder.layers.1.attention.k_proj
encoder.layers.1.attention.v_proj
encoder.layers.1.attention.q_proj
encoder.layers.1.attention.out_proj
encoder.layers.1.attention.gru_rel_pos_linear
encoder.layers.2.attention
encoder.layers.2.attention.k_proj
encoder.layers.2.attention.v_proj
encoder.layers.2.attention.q_proj
encoder.layers.2.attention.out_proj
encoder.layers.2.attention.gru_rel_pos_linear
encoder.layers.3.attention
encoder.layers.3.attention.k_proj
encoder.layers.3.attention.v_proj
encoder.layers.3.attention.q_proj
encoder.layers.3.attention.out_proj
encoder.layers.3.attention.gru_rel_pos_linear
encoder.layers.4.attention
encoder.layers.4.attention.k_proj
encoder.layers.4.attention.v_proj
encoder.layers.4.at

In [21]:
from transformers import WavLMModel, WavLMConfig
from peft import get_peft_model, LoraConfig, TaskType

# Load base WavLM
model_name = "microsoft/wavlm-base-plus"
base_model = WavLMModel.from_pretrained(model_name)

lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)


base_model = get_peft_model(base_model, lora_config)

In [22]:
class SpeakerModel(nn.Module):
    def __init__(self, base_model, emb_dim=192, num_classes=100):
        super().__init__()
        self.backbone = base_model
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.project = nn.Linear(self.backbone.config.hidden_size, emb_dim)
        self.arcface = ArcFaceLoss(emb_dim, num_classes)

    def forward(self, wavs, labels=None):
        outputs = self.backbone(wavs, output_hidden_states=True)
        x = outputs.last_hidden_state.transpose(1, 2)  # B x C x T
        pooled = self.pool(x).squeeze(-1)  # B x C
        emb = self.project(pooled)

        if labels is not None:
            return self.arcface(emb, labels)
        return emb

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

# Your SpeakerModel wraps WavLM and a classification head with ArcFace
num_classes = len(train_dataset.spk_to_id)
model = SpeakerModel(base_model, emb_dim=192, num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(10):
    model.train()
    total_loss = 0.0

    for wavs, labels in train_loader:
        # Ensure mono
        if wavs.ndim == 3:
            wavs = wavs.mean(dim=1)

        # Resample if necessary (optional if already 16kHz)
        # wavs = torchaudio.transforms.Resample(orig_freq, 16000)(wavs)

        wavs = wavs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = model(wavs, labels)  # Model handles forward + ArcFace
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")