# Siamese CNN training for speaker verification

## Imports and constants

In [1]:
import os
import glob

import torch
import torchaudio
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import timm
import gc

from torchaudio.datasets import VoxCeleb1Verification
from tqdm import tqdm
from torch import nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from IPython.display import display, Audio
import torch.nn.functional as F
from torch.cuda import empty_cache

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_SAVE_PATH = './models/efficientnet_1.pt'

## Data preparation

In [3]:
test_dataset = VoxCeleb1Verification('../data', download=False)

### Data split

We need to split data in the clever way, leaving some speakers for validation, and not using them as well as some speakers from pre-defined test.

In [4]:
# occs = dict()
# for _, _, _, _, f1, f2 in tqdm(test_dataset):
#     id1 = f1.split('-')[0]
#     id2 = f2.split('-')[0]
#     if id1 not in occs:
#         occs[id1] = 0
#     if id2 not in occs:
#         occs[id2] = 0
#     occs[id1] += 1
#     occs[id2] += 1

Speakers who are used in test:

In [5]:
# len(occs.keys())

In [6]:
# test_ids = sorted(list(occs.keys()))

Extracting validation speakers:

In [7]:
# val_ids = np.random.choice(list(set(os.listdir('../data/wav/')) - set(test_ids)), size=40, replace=False)
# val_ids  = sorted(list(val_ids))
# len(set(os.listdir('../data/wav/')) - set(test_ids) - set(val_ids))

Saving all the test and validation speakers to be reused throughout the project:

In [8]:
# with open('../data/test_ids.txt', 'w') as f:
#     f.write(str(test_ids))

# with open('../data/val_ids.txt', 'w') as f:
#     f.write(str(val_ids))

### Data split load

In [9]:
with open('../data/test_ids.txt', 'r') as f:
    test_ids = eval(f.read())

with open('../data/val_ids.txt', 'r') as f:
    val_ids = eval(f.read())

In [10]:
samples_files = glob.glob('../data/wav/**/**/*.wav')
samples_df = pd.DataFrame({'path': samples_files})
samples_df['path'] = samples_df['path'].str.replace('\\', '/')
samples_df['speaker_id'] = samples_df['path'].apply(lambda path: path.split('/')[-3])
samples_df['utterance_id'] = samples_df['path'].apply(lambda path: path.split('/')[-2])
samples_df['sample_id'] = samples_df['path'].apply(lambda path: path.split('/')[-1])

In [11]:
samples_df

Unnamed: 0,path,speaker_id,utterance_id,sample_id
0,../data/wav/id10001/1zcIwhmdeo4/00001.wav,id10001,1zcIwhmdeo4,00001.wav
1,../data/wav/id10001/1zcIwhmdeo4/00002.wav,id10001,1zcIwhmdeo4,00002.wav
2,../data/wav/id10001/1zcIwhmdeo4/00003.wav,id10001,1zcIwhmdeo4,00003.wav
3,../data/wav/id10001/7gWzIy6yIIk/00001.wav,id10001,7gWzIy6yIIk,00001.wav
4,../data/wav/id10001/7gWzIy6yIIk/00002.wav,id10001,7gWzIy6yIIk,00002.wav
...,...,...,...,...
153511,../data/wav/id11251/Tmh87G_cDZo/00004.wav,id11251,Tmh87G_cDZo,00004.wav
153512,../data/wav/id11251/WbB8m9-wlIQ/00001.wav,id11251,WbB8m9-wlIQ,00001.wav
153513,../data/wav/id11251/WbB8m9-wlIQ/00002.wav,id11251,WbB8m9-wlIQ,00002.wav
153514,../data/wav/id11251/XHCSVYEZvlM/00001.wav,id11251,XHCSVYEZvlM,00001.wav


In [12]:
train_df = samples_df.loc[~samples_df['speaker_id'].isin(test_ids) & ~samples_df['speaker_id'].isin(val_ids)]
test_df = samples_df.loc[samples_df['speaker_id'].isin(test_ids)]
val_df = samples_df.loc[samples_df['speaker_id'].isin(val_ids)]

Sanity check:

In [13]:
# for filepath in tqdm(samples_df['path']):
#     _, sr = torchaudio.load(filepath)
#     assert sr == 16000, sr

In [14]:
# audio_lengths = []
# for filepath in tqdm(samples_df['path'][:2000]):
#     audio, _ = torchaudio.load(filepath)
#     audio_lengths.append(audio.size()[1])

Saving in HDF5 format (created with ChatGPT):

In [15]:
HDF5_FILE = "../data/dataset.hdf5"

In [16]:
# # Initialize HDF5 file
# with h5py.File(HDF5_FILE, "w") as hf:
#     # Iterate over speaker folders
#     for speaker_id in tqdm(os.listdir('..\\data\\wav\\'), desc='Speakers'):
#         speaker_group = hf.create_group(speaker_id)
#         speaker_path = os.path.join('..\\data\\wav\\', speaker_id)
#         # Iterate over utterance folders
#         for utterance_id in os.listdir(speaker_path):
#             utterance_group = speaker_group.create_group(utterance_id)
#             utterance_path = os.path.join(speaker_path, utterance_id)
#             # Iterate over sample files
#             for sample_id in os.listdir(utterance_path):
#                 sample_path = os.path.join(utterance_path, sample_id)
#                 if not os.path.isfile(sample_path):
#                     continue
#                 # Read the sample data
#                 data, sr = torchaudio.load(sample_path)
#                 # Create a dataset in the HDF5 file and write the sample data
#                 utterance_group.create_dataset(sample_id, data=data, compression="gzip")

In [17]:
def load_sample(speaker_id, utterance_id, sample_id):
    with h5py.File(HDF5_FILE, "r") as hf:
        try:
            sample_data = hf[speaker_id][utterance_id][sample_id][:]
            return torch.tensor(sample_data)
        except KeyError:
            print("Sample not found.")
            return None

Dataset for training:

In [18]:
class VoxCeleb1Triplet(torch.utils.data.Dataset):
    def __init__(self, train_df: pd.DataFrame, transforms = None, max_length: int = 240000):
        self.df = train_df
        self.transforms = transforms
        self.max_length = max_length
    
    def __len__(self):
        return len(self.df)
    
    def _crop_or_extend(self, sample):
        if sample.size()[1] > self.max_length:
            sample = sample[:, :self.max_length]
        else:
            sample = torch.cat((sample, torch.zeros((1, self.max_length - sample.size()[1]))), dim=1)
        return sample

    def __getitem__(self, id):
        selected_row = self.df.iloc[id]
        selected_speaker_id = selected_row['speaker_id']

        positive_row = self.df.loc[(self.df['speaker_id'] == selected_speaker_id) & (self.df.index != id)].sample(1).iloc[0]

        assert positive_row is not None, f"There are now samples for the same speaker {selected_speaker_id}, row {id}"

        negative_row = self.df.loc[(self.df['speaker_id'] != selected_speaker_id)].sample(1).iloc[0]

        assert negative_row is not None, f"There are no negative samples"

        anchor_audio, anchor_sr = torchaudio.load(selected_row['path'])
        pos_audio, pos_sr = torchaudio.load(positive_row['path'])
        neg_audio, neg_sr = torchaudio.load(negative_row['path'])
        assert anchor_sr == 16000 and pos_sr == 16000 and neg_sr == 16000
        # anchor_audio = load_sample(selected_row['speaker_id'], selected_row['utterance_id'], selected_row['sample_id'])
        # pos_audio = load_sample(positive_row['speaker_id'], positive_row['utterance_id'], positive_row['sample_id'])
        # neg_audio = load_sample(negative_row['speaker_id'], negative_row['utterance_id'], negative_row['sample_id'])

        if self.transforms is not None:
            anchor_audio = self.transforms(anchor_audio)
            pos_audio = self.transforms(pos_audio)
            neg_audio = self.transforms(neg_audio)
        
        anchor_audio = self._crop_or_extend(anchor_audio)
        pos_audio = self._crop_or_extend(pos_audio)
        neg_audio = self._crop_or_extend(neg_audio)
        
        return anchor_audio, pos_audio, neg_audio

In [19]:
# train_dataset = VoxCeleb1Triplet(train_df)

In [20]:
def speaker_id_to_int(speaker_id: str):
    return int(speaker_id[3:])

In [21]:
class VoxCeleb1Unary(torch.utils.data.Dataset):
    def __init__(self, train_df: pd.DataFrame, transforms = None, max_length: int = 240000, supporting_count: int = 3):
        self.df = train_df
        self.transforms = transforms
        self.max_length = max_length
        self.supporting_count = supporting_count
    
    def __len__(self):
        return len(self.df)
    
    def _crop_or_extend(self, sample):
        if sample.size()[1] > self.max_length:
            sample = sample[:, :self.max_length]
        else:
            sample = torch.cat((sample, torch.zeros((1, self.max_length - sample.size()[1]))), dim=1)
        return sample

    def __getitem__(self, id):
        selected_row = self.df.iloc[id]
        selected_speaker_id = selected_row['speaker_id']
        positive_rows = self.df.loc[(self.df['speaker_id'] == selected_speaker_id) & (self.df.index != id)] \
            .sample(self.supporting_count)

        audio, sr = torchaudio.load(selected_row['path'])
        assert sr == 16000
        audios = [audio]
        for i in range(self.supporting_count):
            audio, sr = torchaudio.load(positive_rows.iloc[i]['path'])
            assert sr == 16000
            audios.append(audio)
        # audio = load_sample(selected_row['speaker_id'], selected_row['utterance_id'], selected_row['sample_id'])

        if self.transforms is not None:
            for i in range(self.supporting_count + 1):
                audios[i] = self.transforms(audios[i])
        
        for i in range(self.supporting_count + 1):
            audios[i] = self._crop_or_extend(audios[i])
        
        return torch.cat(audios), torch.ones(self.supporting_count + 1) * speaker_id_to_int(selected_speaker_id)

In [22]:
train_dataset = VoxCeleb1Unary(train_df)

In [23]:
train_dataset[0]

(tensor([[ 0.0703,  0.0703,  0.0916,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0151, -0.0161, -0.0156,  ..., -0.0096, -0.0074, -0.0042],
         [-0.0003, -0.0072,  0.0037,  ..., -0.0096, -0.0074, -0.0017],
         [-0.0203, -0.0190, -0.0181,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([1., 1., 1., 1.]))

In [24]:
class VoxCeleb1Validation(torch.utils.data.Dataset):
    def __init__(self, val_df: pd.DataFrame, transforms = None, max_length: int = 240000):
        self.df = val_df
        self.transforms = transforms
        self.max_length = max_length
    
    def __len__(self):
        return len(self.df)
    
    def _crop_or_extend(self, sample):
        if sample.size()[1] > self.max_length:
            sample = sample[:, :self.max_length]
        else:
            sample = torch.cat((sample, torch.zeros((1, self.max_length - sample.size()[1]))), dim=1)
        return sample

    def __getitem__(self, id):
        selected_row = self.df.iloc[id]
        selected_speaker_id = selected_row['speaker_id']

        res_class = np.random.choice((0, 1))

        if res_class == 1:
            rel_row = self.df.loc[(self.df['speaker_id'] == selected_speaker_id) & (self.df.index != id)].sample(1).iloc[0]
        else:
            rel_row = self.df.loc[(self.df['speaker_id'] != selected_speaker_id)].sample(1).iloc[0]

        assert rel_row is not None, f"There are now samples for the row {id}"

        anchor_audio, anchor_sr = torchaudio.load(selected_row['path'])
        rel_audio, rel_sr = torchaudio.load(rel_row['path'])
        assert anchor_sr == 16000 and rel_sr == 16000
        # anchor_audio = load_sample(selected_row['speaker_id'], selected_row['utterance_id'], selected_row['sample_id'])
        # rel_audio = load_sample(rel_row['speaker_id'], rel_row['utterance_id'], rel_row['sample_id'])

        if self.transforms is not None:
            anchor_audio = self.transforms(anchor_audio)
            rel_audio = self.transforms(rel_audio)
        
        anchor_audio = self._crop_or_extend(anchor_audio)
        rel_audio = self._crop_or_extend(rel_audio)

        return anchor_audio, rel_audio, res_class

In [25]:
val_dataset = VoxCeleb1Validation(val_df)

In [26]:
val_dataset[0]

(tensor([[0.0692, 0.3032, 0.3854,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([[-0.0972, -0.0992, -0.0992,  ...,  0.0000,  0.0000,  0.0000]]),
 1)

In [27]:
class SiameseCNN(nn.Module):
    def __init__(self, backbone_name: str, backbone_pretrained: bool, res_dim: int, n_fft: int, hop_size: int, n_mels: int, mapper_dropout_p: float, power: float = 1.0, sr: int = 16000, logmel: bool = False):
        super().__init__()

        self.melspec = nn.Sequential(torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_size,
            power=power,
        ), torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sr,
            n_stft=n_fft // 2 + 1,
            f_min=0,
        ))
        self.logmel = logmel

        self.backbone = timm.create_model(
            backbone_name,
            features_only=True,
            pretrained=backbone_pretrained,
            in_chans=1,
            exportable=True
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.mapper = nn.Sequential(
            nn.Dropout(p=mapper_dropout_p),
            nn.Linear(self.backbone.feature_info.channels()[-1], res_dim),
        )


    def forward(self, input):
        specs = self.melspec(input)
        if self.logmel:
            specs = torch.log10(torch.clamp(specs, min=torch.tensor(1e-3)))

        emb = self.backbone(specs)[-1]

        bs, ch, _, _ = emb.shape
        emb = self.pool(emb)
        emb = emb.view(bs, ch)

        emb = self.mapper(emb)
        return emb


In [28]:
model = SiameseCNN('tf_efficientnet_b0.in1k', True, 128, 1024, 512, 128, 0.25)

Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


In [None]:
class SiameseHead(nn.Module):
    def __init__(self, input_dim: int, drop_p=0.25):
        super().__init__()

        self.dropout = nn.Dropout(p=drop_p)
        self.layers = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            self.dropout,
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            self.dropout,
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            self.dropout,
            nn.Linear(32, 2)
        )


    def forward(self, input):
        return self.layers(input)


In [None]:
model_head = SiameseHead(input_dim=256)

In [29]:
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

In [30]:
BATCH_SIZE = 16

In [31]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE*4, shuffle=False)

In [32]:
model = model.to(DEVICE)
model_head = model_head.to(DEVICE)

In [33]:
class TripletLoss(nn.Module):
    def __init__(self, margin: float = 1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
    
    @staticmethod
    def calc_euclidean(x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        distance_positive = TripletLoss.calc_euclidean(anchor, positive)
        distance_negative = TripletLoss.calc_euclidean(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        return losses.mean()

https://towardsdatascience.com/triplet-loss-advanced-intro-49a07b7d8905

In [55]:
class BatchAllTtripletLoss(nn.Module):
    """Uses all valid triplets to compute Triplet loss
    Args:
        margin: Margin value in the Triplet Loss equation
    """
    def __init__(self, margin=1.0, eps=1e-8):
        super().__init__()
        self.margin = margin
        self.eps = eps
        # self.threshold = self.margin*100
    
    def euclidean_distance_matrix(self, x):
        """Efficient computation of Euclidean distance matrix
        Args:
            x: Input tensor of shape (batch_size, embedding_dim)
            
        Returns:
            Distance matrix of shape (batch_size, batch_size)
        """
        # step 1 - compute the dot product

        # shape: (batch_size, batch_size)
        dot_product = torch.mm(x, x.t())

        # step 2 - extract the squared Euclidean norm from the diagonal

        # shape: (batch_size,)
        squared_norm = torch.diag(dot_product)

        # step 3 - compute squared Euclidean distances

        # shape: (batch_size, batch_size)
        distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)

        # get rid of negative distances due to numerical instabilities
        distance_matrix = F.relu(distance_matrix)

        # step 4 - compute the non-squared distances
        
        # handle numerical stability
        # derivative of the square root operation applied to 0 is infinite
        # we need to handle by setting any 0 to eps
        mask = (distance_matrix == 0.0).float()

        # use this mask to set indices with a value of 0 to eps
        distance_matrix = distance_matrix + mask * self.eps

        # now it is safe to get the square root
        distance_matrix = torch.sqrt(distance_matrix)

        # undo the trick for numerical stability
        distance_matrix = (1.0 - mask)*distance_matrix

        return distance_matrix

    def get_triplet_mask(self, labels):
        """compute a mask for valid triplets
        Args:
            labels: Batch of integer labels. shape: (batch_size,)
        Returns:
            Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
            A triplet is valid if:
            `labels[i] == labels[j] and labels[i] != labels[k]`
            and `i`, `j`, `k` are different.
        """
        # step 1 - get a mask for distinct indices

        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # shape: (batch_size, batch_size, 1)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        # shape: (1, batch_size, batch_size)
        j_not_equal_k = indices_not_equal.unsqueeze(0)
        # Shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

        # step 2 - get a mask for valid anchor-positive-negative triplets

        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size, 1)
        i_equal_j = labels_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_equal_k = labels_equal.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

        # step 3 - combine two masks
        mask = torch.logical_and(distinct_indices, valid_indices)

        return mask
    
    def forward(self, embeddings, labels):
        """computes loss value.
        Args:
        embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
        labels: Batch of integer labels associated with embeddings. shape: (batch_size,)
        Returns:
        Scalar loss value.
        """
        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = self.euclidean_distance_matrix(embeddings)

        # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix

        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - filter out invalid or easy triplets by setting their loss values to 0

        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        
        triplet_loss = triplet_loss * mask
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)

        triplet_loss, _ = torch.max(triplet_loss, dim=1)
        triplet_loss, _ = torch.max(triplet_loss, dim=1)

        # step 4 - compute scalar loss value by averaging positive losses
        num_positive_losses = (triplet_loss > self.eps).float().sum()
        triplet_loss = triplet_loss.sum() / (num_positive_losses + self.eps)

        # pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        # pos_ids = pos_mask.nonzero()
        # pos_random_indices = pos_ids[torch.randperm(pos_ids.size(0))][:, 1]
        # pos_dists = distance_matrix[pos_mask][pos_random_indices]

        # neg_mask = labels.unsqueeze(0) != labels.unsqueeze(1)
        # neg_ids = neg_mask.nonzero()
        # neg_random_indices = neg_ids[torch.randperm(neg_ids.size(0))][:, 1]
        # neg_dists = distance_matrix[neg_mask][neg_random_indices]

        # self.threshold, optim_train_acc = 0.0, 0.0
        # for dist in torch.cat((pos_dists, neg_dists), dim=0):
        #     acc = (len(pos_dists[pos_dists < dist]) + len(neg_dists[neg_dists >= dist]))/(len(pos_dists) + len(neg_dists))
        #     if acc > optim_train_acc:
        #         optim_train_acc = acc
        #         self.threshold = float(dist)

        return triplet_loss #, (optim_train_acc, self.threshold)

In [56]:
EPOCHS = 1
EVAL_EVERY_STEPS = 20

In [57]:
optimizer = torch.optim.AdamW([
                {'params': model.mapper.parameters(), 'lr': 1e-3},
                {'params': model.backbone.parameters(), 'lr': 3e-4}
            ])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS * np.ceil(len(train_dataset) / BATCH_SIZE), eta_min=1e-6)
criterion = torch.jit.script(BatchAllTtripletLoss())
optimizer_head = torch.optim.AdamW([
                {'params': model_head.parameters(), 'lr': 3e-3},
            ])
scheduler_head = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_head, EPOCHS * np.ceil(len(train_dataset) / BATCH_SIZE), eta_min=1e-6)
criterion_head = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 4.0], dtype=torch.float))

In [58]:
best_val_f1 = 0

In [59]:
model.train()
for epoch in tqdm(range(EPOCHS), desc="Epochs"):
    running_loss = []
    for step, (audios, labels) in enumerate(tqdm(train_loader, desc="Training")):
        audios = audios.view((BATCH_SIZE*4, 1, 240000))
        labels = labels.view((BATCH_SIZE*4))
        audios = audios.to(DEVICE)
        labels = labels.to(DEVICE)
        
        optimizer.zero_grad()
        out = model(audios)
        
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss.append(loss.cpu().detach().numpy())

        out = out.detach()
        B, C = out.size()

        expanded_out1 = out.unsqueeze(0).expand(B, B, C)
        expanded_out2 = out.unsqueeze(1).expand(B, B, C)

        concatenated_pairs = torch.cat((expanded_out1, expanded_out2), dim=-1)

        concatenated_pairs = concatenated_pairs.view(-1, 2*C)

        labels1 = labels.unsqueeze(0).expand(B, B)  # Shape: (1, B, C)
        labels2 = labels.unsqueeze(1).expand(B, B)  # Shape: (B, 1, C)

        # Concatenate embeddings of all pairs
        concatenated_labels = (labels1 == labels2).reshape(B * B)

        optimizer_head.zero_grad()
        out_head = model_head(concatenated_pairs)
        
        loss_head = criterion_head(out_head, concatenated_labels)
        loss_head.backward()
        optimizer_head.step()
        scheduler_head.step()

        if step % EVAL_EVERY_STEPS == 0:
            model.eval()
            model_head.eval()

            pred_head = torch.argmax(out_head, dim=1).cpu().detach().numpy()

            train_acc = accuracy_score(concatenated_labels, pred_head)
            train_prec = precision_score(concatenated_labels, pred_head)
            train_rec = recall_score(concatenated_labels, pred_head)
            train_f1 = f1_score(concatenated_labels, pred_head)

            with torch.no_grad():
                pred_list = []
                labels_list = []
                for val_step, (anchor_audio, rel_audio, labels) in enumerate(val_loader):
                    anchor_audio = anchor_audio.to(DEVICE)
                    rel_audio = rel_audio.to(DEVICE)
                    
                    optimizer.zero_grad()
                    anchor_out = model(anchor_audio)
                    rel_out = model(rel_audio)

                    labels_list.append(labels.cpu().numpy())

                    concatenated_pairs = torch.cat((anchor_out, rel_out), dim=-1)
                    out_head = model_head(concatenated_pairs)

                    pred_list.append(torch.argmax(out_head, dim=-1).cpu().numpy())
                
                preds = np.concatenate(pred_list)
                labels = np.concatenate(labels_list)
                val_acc = accuracy_score(labels, preds)
                val_prec = precision_score(labels, preds)
                val_rec = recall_score(labels, preds)
                val_f1 = f1_score(labels, preds)

            model.train()

            print(f"Step: {step+1} - Loss: {running_loss[-1]:.4f}, Train Acc: {optim_train_acc:.4f}, Optim distance threshold: {thresh:.4f}")
            print(f"\tVal Acc: {val_acc:.4f}, Val precision: {val_prec:.4f}, Val recall: {val_rec:.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                print('\tSaving the model...')

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Step: 1 - Loss: 4.2141, Train Acc: 0.9524, Optim distance threshold: 7.7487
	Val Acc: 0.5198, Val precision: 0.9111, Val recall: 0.0358, Val F1: 0.0689
	Saving the model...




Step: 21 - Loss: 2.9065, Train Acc: 0.9524, Optim distance threshold: 5.8084
	Val Acc: 0.5005, Val precision: 1.0000, Val recall: 0.0099, Val F1: 0.0196




Step: 41 - Loss: 2.2163, Train Acc: 0.9583, Optim distance threshold: 5.5616
	Val Acc: 0.5036, Val precision: 1.0000, Val recall: 0.0091, Val F1: 0.0180




Step: 61 - Loss: 1.9381, Train Acc: 0.9539, Optim distance threshold: 4.1400
	Val Acc: 0.5083, Val precision: 1.0000, Val recall: 0.0087, Val F1: 0.0173




Step: 81 - Loss: 1.7792, Train Acc: 0.9571, Optim distance threshold: 3.2815
	Val Acc: 0.5194, Val precision: 1.0000, Val recall: 0.0134, Val F1: 0.0264




Step: 101 - Loss: 1.6140, Train Acc: 0.9489, Optim distance threshold: 2.4967
	Val Acc: 0.5081, Val precision: 0.9500, Val recall: 0.0083, Val F1: 0.0165




Step: 121 - Loss: 1.3736, Train Acc: 0.9539, Optim distance threshold: 1.6884
	Val Acc: 0.5068, Val precision: 0.9767, Val recall: 0.0181, Val F1: 0.0356




Step: 141 - Loss: 1.2795, Train Acc: 0.9576, Optim distance threshold: 1.5252
	Val Acc: 0.5328, Val precision: 0.7774, Val recall: 0.1036, Val F1: 0.1828




	Saving the model...




Step: 161 - Loss: 1.2793, Train Acc: 0.9539, Optim distance threshold: 1.1166
	Val Acc: 0.5814, Val precision: 0.7295, Val recall: 0.2614, Val F1: 0.3848
	Saving the model...




Step: 181 - Loss: 1.2022, Train Acc: 0.9554, Optim distance threshold: 0.9760
	Val Acc: 0.6375, Val precision: 0.6078, Val recall: 0.7437, Val F1: 0.6689
	Saving the model...




Step: 201 - Loss: 1.2049, Train Acc: 0.9524, Optim distance threshold: 0.8161
	Val Acc: 0.5777, Val precision: 0.5432, Val recall: 0.9397, Val F1: 0.6884
	Saving the model...




Step: 221 - Loss: 1.1589, Train Acc: 0.9524, Optim distance threshold: 0.7493
	Val Acc: 0.5198, Val precision: 0.5056, Val recall: 0.9960, Val F1: 0.6707




Step: 241 - Loss: 1.1546, Train Acc: 0.9554, Optim distance threshold: 0.6431
	Val Acc: 0.5073, Val precision: 0.5013, Val recall: 1.0000, Val F1: 0.6678




Step: 261 - Loss: 1.1306, Train Acc: 0.9524, Optim distance threshold: 0.5697
	Val Acc: 0.5047, Val precision: 0.5041, Val recall: 0.9996, Val F1: 0.6702




Step: 281 - Loss: 1.1193, Train Acc: 0.9452, Optim distance threshold: 0.5305
	Val Acc: 0.4878, Val precision: 0.4878, Val recall: 1.0000, Val F1: 0.6557




Step: 301 - Loss: 1.1017, Train Acc: 0.9568, Optim distance threshold: 0.4333
	Val Acc: 0.5090, Val precision: 0.5089, Val recall: 1.0000, Val F1: 0.6745




Step: 321 - Loss: 1.1056, Train Acc: 0.9539, Optim distance threshold: 0.4592
	Val Acc: 0.4938, Val precision: 0.4938, Val recall: 1.0000, Val F1: 0.6612


Training:   4%|▍         | 340/9002 [1:56:30<49:28:21, 20.56s/it]
Epochs:   0%|          | 0/1 [1:56:30<?, ?it/s]


KeyboardInterrupt: 

In [None]:
gc.collect()
empty_cache()

Add:
- [x] Checkpoints
- [ ] Augmentations
- [ ] Use better models
- [ ] Online triplet loss