# Siamese CNN training for speaker verification

## Imports and constants

In [88]:
import os
import glob

import torch
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchaudio.datasets import VoxCeleb1Verification
from tqdm import tqdm
from torch import nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import timm

## Data preparation

In [2]:
# Verification trial list of VoxCeleb1, read from ../data without downloading anything.
test_dataset = VoxCeleb1Verification(root='../data', download=False)

### Data split

We need to split the data carefully: some speakers are held out for validation, and neither those speakers nor the speakers that appear in the predefined test set are used for training.

In [3]:
# occs = dict()
# for _, _, _, _, f1, f2 in tqdm(test_dataset):
#     id1 = f1.split('-')[0]
#     id2 = f2.split('-')[0]
#     if id1 not in occs:
#         occs[id1] = 0
#     if id2 not in occs:
#         occs[id2] = 0
#     occs[id1] += 1
#     occs[id2] += 1

Speakers who are used in test:

In [4]:
# len(occs.keys())

In [5]:
# test_ids = sorted(list(occs.keys()))

Extracting validation speakers:

In [6]:
# val_ids = np.random.choice(list(set(os.listdir('../data/wav/')) - set(test_ids)), size=40, replace=False)
# val_ids  = sorted(list(val_ids))
# len(set(os.listdir('../data/wav/')) - set(test_ids) - set(val_ids))

Saving all the test and validation speakers to be reused throughout the project:

In [7]:
# with open('../data/test_ids.txt', 'w') as f:
#     f.write(str(test_ids))

# with open('../data/val_ids.txt', 'w') as f:
#     f.write(str(val_ids))

### Data split load

In [8]:
# Load the speaker-ID lists that the (commented-out) split cells above wrote with str(list).
# NOTE(review): eval() executes whatever Python expression the file contains. If the files
# hold plain string literals only, ast.literal_eval (or a JSON file) would parse them without
# that risk — confirm the on-disk format before switching.
with open('../data/test_ids.txt', 'r') as f:
    test_ids = eval(f.read())

with open('../data/val_ids.txt', 'r') as f:
    val_ids = eval(f.read())

In [9]:
# Index every clip on disk; the layout is ../data/wav/<speaker_id>/<utterance_id>/<clip>.wav.
# Separators are normalised with plain str.replace (no dependence on the pandas `regex`
# default, which differs between pandas versions), and the list is sorted because glob.glob
# returns files in arbitrary order, which would make row positions differ between runs.
samples_files = sorted(
    path.replace('\\', '/') for path in glob.glob('../data/wav/**/**/*.wav')
)
# An explicit string dtype keeps the .str accessor usable even when no file was found.
samples_df = pd.DataFrame({'path': pd.Series(samples_files, dtype=str)})
path_parts = samples_df['path'].str.split('/')
samples_df['speaker_id'] = path_parts.str[-3]
samples_df['utterance_id'] = path_parts.str[-2]

In [10]:
# Preview of the utterance index: one row per .wav file with its speaker and utterance ids.
samples_df

Unnamed: 0,path,speaker_id,utterance_id
0,../data/wav/id10001/1zcIwhmdeo4/00001.wav,id10001,1zcIwhmdeo4
1,../data/wav/id10001/1zcIwhmdeo4/00002.wav,id10001,1zcIwhmdeo4
2,../data/wav/id10001/1zcIwhmdeo4/00003.wav,id10001,1zcIwhmdeo4
3,../data/wav/id10001/7gWzIy6yIIk/00001.wav,id10001,7gWzIy6yIIk
4,../data/wav/id10001/7gWzIy6yIIk/00002.wav,id10001,7gWzIy6yIIk
...,...,...,...
153511,../data/wav/id11251/Tmh87G_cDZo/00004.wav,id11251,Tmh87G_cDZo
153512,../data/wav/id11251/WbB8m9-wlIQ/00001.wav,id11251,WbB8m9-wlIQ
153513,../data/wav/id11251/WbB8m9-wlIQ/00002.wav,id11251,WbB8m9-wlIQ
153514,../data/wav/id11251/XHCSVYEZvlM/00001.wav,id11251,XHCSVYEZvlM


In [11]:
# Speaker-disjoint split: membership masks are computed once and reused for all three frames.
is_test_speaker = samples_df['speaker_id'].isin(test_ids)
is_val_speaker = samples_df['speaker_id'].isin(val_ids)

test_df = samples_df.loc[is_test_speaker]
val_df = samples_df.loc[is_val_speaker]
# Training uses only speakers that appear in neither the test list nor the validation list.
train_df = samples_df.loc[~is_test_speaker & ~is_val_speaker]

Sanity check:

In [12]:
# for filepath in tqdm(samples_df['path']):
#     _, sr = torchaudio.load(filepath)
#     assert sr == 16000, sr

In [29]:
# Number of samples (second tensor dimension) of each of the first 2000 indexed files.
audio_lengths = [
    torchaudio.load(filepath)[0].size()[1]
    for filepath in tqdm(samples_df['path'][:2000])
]

100%|██████████| 2000/2000 [00:00<00:00, 2125.33it/s]


Dataset for training:

In [50]:
class VoxCeleb1Triplet(torch.utils.data.Dataset):
    """Triplet dataset: each item is ``(anchor, positive, negative)`` waveforms.

    The anchor is the row at position ``id`` of the frame, the positive is a random
    *different* row of the same speaker, and the negative is a random row of any other
    speaker. All three waveforms are cropped or zero-padded to ``max_length`` samples.

    Args:
        train_df: frame with at least the columns ``path`` and ``speaker_id``. Its index
            may be arbitrary (e.g. a filtered subset of a larger frame).
        transforms: optional callable applied to every loaded waveform before cropping.
        max_length: number of samples every returned waveform has along dimension 1.
    """

    def __init__(self, train_df: pd.DataFrame, transforms = None, max_length: int = 240000):
        self.df = train_df
        self.transforms = transforms
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def _crop_or_extend(self, sample):
        """Return ``sample`` truncated or right-padded with zeros to ``max_length`` samples."""
        n_samples = sample.size(1)
        if n_samples > self.max_length:
            return sample[:, :self.max_length]
        # Padding follows the channel count, dtype and device of the input instead of
        # assuming a single-channel CPU tensor.
        padding = sample.new_zeros((sample.size(0), self.max_length - n_samples))
        return torch.cat((sample, padding), dim=1)

    def _sample_positive_negative(self, id):
        """Draw the positive and negative rows for the anchor at position ``id``.

        Candidates are selected by *position*, so the anchor itself is never returned as
        its own positive even when the frame's index labels are not 0..n-1.

        Raises:
            ValueError: if the anchor's speaker has no other row, or if the frame
                contains no other speaker.
        """
        if id < 0:
            id += len(self.df)
        speaker_ids = self.df['speaker_id'].to_numpy()
        anchor_speaker = speaker_ids[id]

        same_speaker = speaker_ids == anchor_speaker
        negative_positions = np.flatnonzero(~same_speaker)
        same_speaker[id] = False  # exclude the anchor row itself
        positive_positions = np.flatnonzero(same_speaker)

        if positive_positions.size == 0:
            raise ValueError(f"There are no samples for the same speaker {anchor_speaker}, row {id}")
        if negative_positions.size == 0:
            raise ValueError("There are no negative samples")

        positive_row = self.df.iloc[np.random.choice(positive_positions)]
        negative_row = self.df.iloc[np.random.choice(negative_positions)]
        return positive_row, negative_row

    def __getitem__(self, id):
        anchor_row = self.df.iloc[id]
        positive_row, negative_row = self._sample_positive_negative(id)

        waveforms = []
        for row in (anchor_row, positive_row, negative_row):
            audio, sample_rate = torchaudio.load(row['path'])
            assert sample_rate == 16000, f"Unexpected sample rate {sample_rate} for {row['path']}"
            if self.transforms is not None:
                audio = self.transforms(audio)
            waveforms.append(self._crop_or_extend(audio))

        anchor_audio, pos_audio, neg_audio = waveforms
        return anchor_audio, pos_audio, neg_audio

In [51]:
# Triplet dataset over the training speakers, with default cropping and no transforms.
train_dataset = VoxCeleb1Triplet(train_df=train_df)

In [52]:
# Smoke test: fetch one (anchor, positive, negative) triplet; this loads three audio files.
train_dataset[0]

(tensor([[0.0703, 0.0703, 0.0916,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([[-0.0039,  0.0002, -0.0114,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([[0.0653, 0.0603, 0.0573,  ..., 0.0000, 0.0000, 0.0000]]))

In [99]:
class VoxCeleb1Validation(torch.utils.data.Dataset):
    """Pair dataset for verification: each item is ``(anchor, related, label)``.

    ``label`` is 1 when ``related`` is a different recording of the anchor's speaker and 0
    when it belongs to another speaker; the label is drawn with equal probability. Both
    waveforms are cropped or zero-padded to ``max_length`` samples.

    Args:
        val_df: frame with at least the columns ``path`` and ``speaker_id``. Its index may
            be arbitrary (e.g. a filtered subset of a larger frame).
        transforms: optional callable applied to every loaded waveform before cropping.
        max_length: number of samples every returned waveform has along dimension 1.
        seed: when ``None`` (default) pairs are re-drawn from NumPy's global random state
            on every access. When set to a non-negative int, the label and partner of item
            ``id`` depend only on ``(seed, id)``, so repeated evaluations see the same pairs.
    """

    def __init__(self, val_df: pd.DataFrame, transforms = None, max_length: int = 240000, seed = None):
        self.df = val_df
        self.transforms = transforms
        self.max_length = max_length
        self.seed = seed

    def __len__(self):
        return len(self.df)

    def _crop_or_extend(self, sample):
        """Return ``sample`` truncated or right-padded with zeros to ``max_length`` samples."""
        n_samples = sample.size(1)
        if n_samples > self.max_length:
            return sample[:, :self.max_length]
        # Padding follows the channel count, dtype and device of the input instead of
        # assuming a single-channel CPU tensor.
        padding = sample.new_zeros((sample.size(0), self.max_length - n_samples))
        return torch.cat((sample, padding), dim=1)

    def _sample_related_row(self, id, same_speaker, rng):
        """Draw the partner row for the anchor at (non-negative) position ``id``.

        Candidates are selected by *position*, so the anchor is never paired with itself
        even when the frame's index labels are not 0..n-1. ``rng`` only needs a
        ``choice`` method (``np.random`` or a ``np.random.Generator``).

        Raises:
            ValueError: if no row satisfies the requested relation.
        """
        speaker_ids = self.df['speaker_id'].to_numpy()
        matches = speaker_ids == speaker_ids[id]
        if same_speaker:
            matches[id] = False  # exclude the anchor row itself
            candidates = np.flatnonzero(matches)
        else:
            candidates = np.flatnonzero(~matches)

        if candidates.size == 0:
            raise ValueError(f"There are no samples for the row {id}")
        return self.df.iloc[rng.choice(candidates)]

    def __getitem__(self, id):
        if id < 0:
            id += len(self.df)
        selected_row = self.df.iloc[id]

        rng = np.random if self.seed is None else np.random.default_rng([self.seed, id])
        res_class = rng.choice((0, 1))
        rel_row = self._sample_related_row(id, res_class == 1, rng)

        anchor_audio, anchor_sr = torchaudio.load(selected_row['path'])
        rel_audio, rel_sr = torchaudio.load(rel_row['path'])
        assert anchor_sr == 16000 and rel_sr == 16000

        if self.transforms is not None:
            anchor_audio = self.transforms(anchor_audio)
            rel_audio = self.transforms(rel_audio)

        anchor_audio = self._crop_or_extend(anchor_audio)
        rel_audio = self._crop_or_extend(rel_audio)

        return anchor_audio, rel_audio, res_class

In [100]:
# Pair dataset over the held-out validation speakers, with default cropping and no transforms.
val_dataset = VoxCeleb1Validation(val_df=val_df)

In [101]:
# Smoke test: fetch one (anchor, related, label) item; this loads two audio files.
val_dataset[0]

(tensor([[0.0692, 0.3032, 0.3854,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([[0.0109, 0.0094, 0.0081,  ..., 0.0000, 0.0000, 0.0000]]),
 0)

In [71]:
class SiameseCNN(nn.Module):
    """Embedding network: waveform -> mel spectrogram -> CNN backbone -> linear projection.

    The same instance is applied to every element of a pair/triplet, which is what makes
    the setup "siamese" (shared weights).

    Args:
        backbone_name: timm model name used as the feature extractor.
        backbone_pretrained: whether timm should load pretrained weights.
        res_dim: dimensionality of the returned embedding.
        n_fft: FFT size of the spectrogram.
        hop_size: hop length of the spectrogram.
        n_mels: number of mel bins.
        mapper_dropout_p: dropout probability applied before the final linear layer.
        power: exponent of the magnitude spectrogram (passed to ``Spectrogram``).
        sr: sample rate used to build the mel filter bank.
        logmel: if True, apply ``log10`` to the mel spectrogram after flooring it at 1e-3.
    """

    def __init__(self, backbone_name: str, backbone_pretrained: bool, res_dim: int, n_fft: int, hop_size: int, n_mels: int, mapper_dropout_p: float, power: float = 1.0, sr: int = 16000, logmel: bool = False):
        super().__init__()

        self.melspec = nn.Sequential(torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_size,
            power=power,
        ), torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sr,
            n_stft=n_fft // 2 + 1,
            f_min=0,
        ))
        self.logmel = logmel

        # features_only=True makes the backbone return a list of feature maps;
        # in_chans=1 because the spectrogram is fed as a single-channel image.
        self.backbone = timm.create_model(
            backbone_name,
            features_only=True,
            pretrained=backbone_pretrained,
            in_chans=1,
            exportable=True
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.mapper = nn.Sequential(
            nn.Dropout(p=mapper_dropout_p),
            nn.Linear(self.backbone.feature_info.channels()[-1], res_dim),
        )


    def forward(self, input):
        """Return one ``res_dim``-dimensional embedding per batch element.

        ``input`` is presumably a batch of single-channel waveforms shaped
        (batch, 1, samples), as produced by the dataset classes — confirm against callers.
        """
        specs = self.melspec(input)
        if self.logmel:
            # A plain float floor avoids allocating a CPU tensor on every call and cannot
            # clash with the device of `specs`.
            specs = torch.log10(torch.clamp(specs, min=1e-3))

        # Only the deepest feature map is used for the embedding.
        emb = self.backbone(specs)[-1]

        # Global average pooling leaves (batch, channels, 1, 1); flatten to (batch, channels).
        emb = self.pool(emb).flatten(1)

        emb = self.mapper(emb)
        return emb


In [89]:
# Keyword arguments make the seven positional hyper-parameters self-describing.
model = SiameseCNN(
    backbone_name='tf_efficientnet_b0.in1k',
    backbone_pretrained=True,
    res_dim=128,
    n_fft=1024,
    hop_size=512,
    n_mels=128,
    mapper_dropout_p=0.25,
)

Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


In [90]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [91]:
# Number of items per batch for both the training (triplets) and validation (pairs) loaders.
BATCH_SIZE = 32

In [103]:
# Training batches are reshuffled every epoch; validation keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [104]:
# Move the model to the selected device; this happens before the optimizer cell below
# collects model.mapper / model.backbone parameters.
model = model.to(DEVICE)

In [105]:
class TripletLoss(nn.Module):
    """Hinge-style triplet loss: mean of ``relu(d(a, p) - d(a, n) + margin)``.

    ``d`` is the *squared* L2 distance computed by :meth:`calc_euclidean`.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    @staticmethod
    def calc_euclidean(x1, x2):
        """Return the sum of squared differences along dimension 1 (no square root)."""
        return torch.sum(torch.pow(x1 - x2, 2), dim=1)

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        """Return the batch-averaged triplet loss as a scalar tensor."""
        d_pos = TripletLoss.calc_euclidean(anchor, positive)
        d_neg = TripletLoss.calc_euclidean(anchor, negative)
        # A triplet contributes nothing once the negative is farther than the positive
        # by at least `margin`.
        hinge = torch.relu(d_pos - d_neg + self.margin)
        return torch.mean(hinge)

In [106]:
# Number of passes over the training loader.
EPOCHS = 1
# Validation runs whenever `step % EVAL_EVERY_STEPS == 0` inside the training loop.
# NOTE(review): the saved training log reports steps 1, 6, 11, ..., which matches a value
# of 5 rather than 20 — the stored output looks stale relative to this cell; confirm.
EVAL_EVERY_STEPS = 20

In [107]:
# The freshly initialised projection head gets a larger learning rate than the backbone.
optimizer = torch.optim.AdamW([
                {'params': model.mapper.parameters(), 'lr': 1e-3},
                {'params': model.backbone.parameters(), 'lr': 3e-4}
            ])
# The scheduler is stepped once per optimizer step, so the cosine period is the total number
# of training batches. len(train_loader) is an int, whereas np.ceil(...) yielded an
# np.float64 for a parameter that is documented as an int.
TOTAL_TRAIN_STEPS = EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_TRAIN_STEPS, eta_min=1e-6)
criterion = torch.jit.script(TripletLoss())

In [108]:
# best_val_f1 = 0

In [113]:
def find_best_threshold(pos_dists, neg_dists):
    """Pick the distance threshold that best separates one batch of training pairs.

    Every observed distance is tried as a threshold; a positive pair counts as correct when
    its distance is strictly below the threshold, a negative pair when it is at or above it.

    Returns:
        (threshold, accuracy) as Python floats; (0.0, 0.0) if no candidate scores above 0.
    """
    # Work on CPU copies so the search does not synchronise with the GPU per candidate.
    pos_dists = pos_dists.detach().cpu()
    neg_dists = neg_dists.detach().cpu()
    n_pairs = pos_dists.numel() + neg_dists.numel()

    best_thresh, best_acc = 0.0, 0.0
    for candidate in torch.cat((pos_dists, neg_dists), dim=0).tolist():
        n_correct = (pos_dists < candidate).sum().item() + (neg_dists >= candidate).sum().item()
        acc = n_correct / n_pairs
        if acc > best_acc:
            best_acc, best_thresh = acc, candidate
    return best_thresh, best_acc


def evaluate_pairs(model, loader, threshold, device):
    """Score `model` on a pair loader; a pair is predicted "same speaker" below `threshold`.

    The model is switched to eval mode for the pass and put back into its previous mode
    afterwards, even if the pass is interrupted.

    Returns:
        (accuracy, precision, recall, f1) computed with scikit-learn.
    """
    was_training = model.training
    model.eval()
    pred_list, labels_list = [], []
    try:
        with torch.no_grad():
            for val_anchor, val_rel, val_labels in loader:
                anchor_emb = model(val_anchor.to(device))
                rel_emb = model(val_rel.to(device))

                dists = TripletLoss.calc_euclidean(anchor_emb, rel_emb)
                labels_list.append(val_labels.cpu().numpy())
                pred_list.append((dists < threshold).cpu().numpy())
    finally:
        model.train(was_training)

    preds = np.concatenate(pred_list)
    labels = np.concatenate(labels_list)
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds),
        recall_score(labels, preds),
        f1_score(labels, preds),
    )


model.train()
for epoch in tqdm(range(EPOCHS), desc="Epochs"):
    running_loss = []
    for step, (anchor_audio, positive_audio, negative_audio) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
        anchor_audio = anchor_audio.to(DEVICE)
        positive_audio = positive_audio.to(DEVICE)
        negative_audio = negative_audio.to(DEVICE)

        optimizer.zero_grad()
        anchor_out = model(anchor_audio)
        positive_out = model(positive_audio)
        negative_out = model(negative_audio)

        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss.append(loss.item())

        if step % EVAL_EVERY_STEPS == 0:
            # The decision threshold is fitted on the current training batch only, using the
            # embeddings produced in train mode during the step above.
            pos_dists = TripletLoss.calc_euclidean(anchor_out.detach(), positive_out.detach())
            neg_dists = TripletLoss.calc_euclidean(anchor_out.detach(), negative_out.detach())
            optim_thresh, optim_train_acc = find_best_threshold(pos_dists, neg_dists)

            val_acc, val_prec, val_rec, val_f1 = evaluate_pairs(model, val_loader, optim_thresh, DEVICE)

            print(f"Step: {step+1} - Loss: {running_loss[-1]:.4f}, Train Acc: {optim_train_acc:.4f}, Optim distance threshold: {optim_thresh:.4f}")
            print(f"\tVal Acc: {val_acc:.4f}, Val precision: {val_prec:.4f}, Val recall: {val_rec:.4f}, Val F1: {val_f1:.4f}")

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Step: 1 - Loss: 9.4519, Train Acc: 0.5781, Optim distance threshold: 120.9663
	Val Acc: 0.5540, Val precision: 0.5417, Val recall: 0.8619, Val F1: 0.6652




Step: 6 - Loss: 5.5284, Train Acc: 0.6406, Optim distance threshold: 82.2975
	Val Acc: 0.5540, Val precision: 0.5277, Val recall: 0.8314, Val F1: 0.6456




Step: 11 - Loss: 12.5828, Train Acc: 0.5625, Optim distance threshold: 56.2323
	Val Acc: 0.5597, Val precision: 0.5765, Val recall: 0.2917, Val F1: 0.3874




Step: 16 - Loss: 4.1824, Train Acc: 0.6875, Optim distance threshold: 75.3693
	Val Acc: 0.5739, Val precision: 0.5517, Val recall: 0.7356, Val F1: 0.6305




Step: 21 - Loss: 10.2566, Train Acc: 0.5938, Optim distance threshold: 81.2796
	Val Acc: 0.5511, Val precision: 0.5083, Val recall: 0.9387, Val F1: 0.6595




Step: 26 - Loss: 5.2370, Train Acc: 0.5938, Optim distance threshold: 62.1200
	Val Acc: 0.6080, Val precision: 0.6340, Val recall: 0.6474, Val F1: 0.6406




Step: 31 - Loss: 3.9192, Train Acc: 0.6406, Optim distance threshold: 64.0749
	Val Acc: 0.6705, Val precision: 0.6717, Val recall: 0.7228, Val F1: 0.6963




Step: 36 - Loss: 6.0486, Train Acc: 0.6094, Optim distance threshold: 65.6375
	Val Acc: 0.6477, Val precision: 0.6107, Val recall: 0.8371, Val F1: 0.7062




Step: 41 - Loss: 4.5701, Train Acc: 0.6562, Optim distance threshold: 72.3815
	Val Acc: 0.5881, Val precision: 0.5284, Val recall: 0.9255, Val F1: 0.6727




Step: 46 - Loss: 9.1276, Train Acc: 0.5625, Optim distance threshold: 68.3054
	Val Acc: 0.6023, Val precision: 0.5625, Val recall: 0.9205, Val F1: 0.6983




Step: 51 - Loss: 3.3480, Train Acc: 0.6094, Optim distance threshold: 59.6725
	Val Acc: 0.6364, Val precision: 0.5939, Val recall: 0.8757, Val F1: 0.7078




Step: 56 - Loss: 2.9017, Train Acc: 0.6562, Optim distance threshold: 58.8730
	Val Acc: 0.5824, Val precision: 0.5374, Val recall: 0.9349, Val F1: 0.6825




Step: 61 - Loss: 1.9734, Train Acc: 0.7344, Optim distance threshold: 63.0717
	Val Acc: 0.5682, Val precision: 0.5224, Val recall: 0.9819, Val F1: 0.6820




Step: 66 - Loss: 5.5519, Train Acc: 0.6250, Optim distance threshold: 55.3447
	Val Acc: 0.6193, Val precision: 0.5677, Val recall: 0.8882, Val F1: 0.6927




Step: 71 - Loss: 3.2953, Train Acc: 0.7031, Optim distance threshold: 53.1777
	Val Acc: 0.6335, Val precision: 0.5802, Val recall: 0.9659, Val F1: 0.7249


Epochs:   0%|          | 0/1 [33:26<?, ?it/s]


KeyboardInterrupt: 

Add:
- Checkpoints
- Augmentations
- Use better models