# TrainRawNet2: On BD ASR Dataset

Necessary Imports:

In [1]:
import torch
from torch.utils import data

import torchaudio

import csv
import pandas as pd

import glob
from pathlib import Path

import os

Directories and locations:

In [2]:
# Directories are assumed to have a trailing '/' or '\\' in all the subsequent code

# Absolute, machine-specific location of the project root; adjust when running elsewhere.
CURRENT_WORKING_DIRECTORY = "W:/SpeakerRecognitionResearch"

# Dataset locations, relative to the project root.
BANGLA_ASR_DATASET_DIRECTORY = "data/BanglaASR/WavFiles/"
BANGLA_ASR_TSV_LOCATION = "data/BanglaASR/utt_spk_text.tsv"

# To avoid file-location errors, make the "SpeakerRecognitionResearch" root folder the current working directory.
os.chdir(CURRENT_WORKING_DIRECTORY)
# Last expression of the cell: displays the working directory that is now in effect.
os.getcwd()

'W:\\SpeakerRecognitionResearch'

In [3]:
# Output locations, relative to the project root.
# The directory ends with '/' (the notebook's convention for directory constants) so that
# `MODEL_SAVE_DIRECTORY + file_name` yields a path inside the directory rather than a
# sibling file named "out<file_name>".
MODEL_SAVE_DIRECTORY = "notebooks/TrainBdAsrOnRawNet2/out/"
MODEL_LOSS_EER_OUTPUT_FILE = "notebooks/TrainBdAsrOnRawNet2/out/loss_eer.txt"

Constants:

In [4]:
# With SAMPLE_RATE = 16000 and NUMBER_OF_SAMPLES = 32000, each tensor holds 2 seconds of audio
SAMPLE_RATE = 16000
NUMBER_OF_SAMPLES = 32000

# The Bangla ASR dataset has around half a second of silence at the beginning of a recording.
# This constant (in seconds) is used to cut samples from the left of the audio.
TRIM_AMOUNT_TIME = 0.5

In [5]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Using {}.".format(device))

# When running on the GPU, also report the name of CUDA device 0.
if device == "cuda":
    print(torch.cuda.get_device_name(0))

Using cuda.
NVIDIA GeForce GTX 1050


## Load Dataset preparation files

These files were generated by the prepareDataset notebook.

In [6]:
# Text files listing the utterances of each split (read line by line further down).
TRAINSET_LIST_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/trainset_list.txt"
EVALSET_LIST_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/evalset_list.txt"
TESTSET_LIST_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/testset_list.txt"
# Presumably the class (speaker) ordering of the dev and test sets -- TODO confirm against the preparation notebook.
DEVSET_CLASS_ORDER_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/devset_classes.txt"
TESTSET_CLASS_ORDER_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/testset_classes.txt"

# Trial file for evaluation: three space-separated fields per line.
EVAL_TRIALS_LOCATION = "notebooks/TrainBdAsrOnRawNet2/DatasetPreparations/eval_trials.txt"

In [7]:
def load_list_from_file(file_loc, trial_file=False):
    """Read a dataset-preparation text file into a list.

    Parameters
    ----------
    file_loc : str
        Path of the text file to read.
    trial_file : bool, optional
        If False (default), every line is one entry and the result is a list of
        stripped strings. If True, every line must hold exactly three
        space-separated fields and the result is a list of
        ``(expected, utt1, utt2)`` tuples of strings.

    Returns
    -------
    list
        The entries in file order. Blank lines are skipped.

    Raises
    ------
    ValueError
        If a non-blank line of a trial file does not have exactly three fields.
    """
    entries = []
    with open(file_loc, "r") as file:
        # Iterate the file lazily instead of materialising it with readlines().
        for line_number, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            if not line:
                # A blank (typically trailing) line is not an entry; previously it
                # produced an empty stem or crashed the three-field unpacking.
                continue

            if not trial_file:
                entries.append(line)
                continue

            # The trial file has 3 elements in each line
            fields = line.split(" ")
            if len(fields) != 3:
                raise ValueError(
                    "{}:{}: expected 3 space-separated fields, got {}".format(
                        file_loc, line_number, len(fields)))
            expected, utt1, utt2 = fields
            entries.append((expected, utt1, utt2))
    return entries

In [8]:
# Utterance stems of each split, one list entry per line of the corresponding file.
trainset_stems = load_list_from_file(TRAINSET_LIST_LOCATION)
evalset_stems = load_list_from_file(EVALSET_LIST_LOCATION)
testset_stems = load_list_from_file(TESTSET_LIST_LOCATION)

# Presumably the saved class (speaker) orderings -- TODO confirm; loaded here as plain lists of lines.
devset_classes = load_list_from_file(DEVSET_CLASS_ORDER_LOCATION)
testset_classes = load_list_from_file(TESTSET_CLASS_ORDER_LOCATION)

# Evaluation trials as (expected, utt1, utt2) tuples of strings.
eval_trials = load_list_from_file(EVAL_TRIALS_LOCATION, trial_file=True)

# Sanity check: total number of entries across the three splits.
len(trainset_stems) + len(evalset_stems) + len(testset_stems)


218703

## Custom dataset for Bangla ASR

This custom dataset is written with the assumption that the dataset has already been converted into wav format. Check the evaluate_asr_ds.ipynb notebook for the conversion method.

Strategy:
1. train_set will be strictly used for training.
2. eval_set will be used for evaluation during training.
3. test_set will be used to assess the model after training.

"Dev" refers to all the data used during the training process, i.e. the train and eval sets combined.

Each set will have entirely different speakers (classes).

A stem is the name of a file without its directory and extension, e.g. "a/b/c/bat.jpg" --stem--> "bat".

In [9]:
class BanglaAsrDataset(data.Dataset):
    """Map-style dataset of fixed-length, mono, peak-normalised waveforms.

    The dataset contains every ``*.wav`` file found (recursively) under
    ``dataset_dir`` whose stem is listed in ``wav_stem_list``. Items are returned
    in the order produced by ``glob``.

    Each item is ``(signal, label_index)``; when ``is_evalset`` is True it is
    ``(signal, label_index, wav_path)``. ``signal`` has shape
    ``(1, target_num_samples)`` and ``label_index`` is the position of the
    utterance's speaker in ``speakers_list``.

    Parameters
    ----------
    dataset_dir : str
        Directory searched recursively for wav files.
    wav_stem_list : list of str
        Stems (file names without extension) of the utterances to include.
    tsv_loc : str
        Tab-separated file without header; after dropping its last column the
        remaining two columns are read as (stem, speaker) pairs.
    target_sample_rate : int
        Sample rate every signal is resampled to.
    target_num_samples : int
        Exact number of samples of every returned signal.
    trim_amount_time : float
        Seconds cut from the start of signals that are long enough.
    device : str
        Stored on the instance only; signals are not moved by the dataset.
    is_evalset : bool, optional
        If True, ``__getitem__`` also returns the wav path.
    """

    def __init__(self, dataset_dir, wav_stem_list, tsv_loc, target_sample_rate, target_num_samples, trim_amount_time, device, is_evalset=False):

        # wav_stem_list will have a list of stems that will be included in this dataset

        tsv_dataframe = pd.read_csv(tsv_loc, quoting=csv.QUOTE_NONE, sep='\t', header=None)

        # The TSV file contains speech annotations in the third column.
        # We don't need the annotations, so we drop the column
        tsv_dataframe = tsv_dataframe.iloc[:,:-1]

        self.dataset_dir = dataset_dir
        self.stem_to_spk_mapping = dict(sorted(tsv_dataframe.values.tolist()))
        self.wav_path_list = self._stem_list_to_path_list(wav_stem_list)
        self.target_sample_rate = target_sample_rate
        self.target_num_samples = target_num_samples
        self.trim_amount_time = trim_amount_time
        self.device = device
        self.is_evalset=is_evalset

        # Set will ensure uniqueness
        # sorted list will make the order consistent
        # Tuple will make sure the order doesn't change
        self.speakers_list = tuple(sorted(set(self.stem_to_spk_mapping[stem] for stem in wav_stem_list)))

        # Constant-time speaker -> class index lookup (replaces a linear tuple.index per item).
        self._speaker_to_index = {speaker: index for index, speaker in enumerate(self.speakers_list)}

    def _get_audio_path_list(self, dataset_dir):
        """Return the normalised paths of all wav files found recursively under ``dataset_dir``."""
        # os.path.join keeps this working whether or not dataset_dir ends with a separator.
        pattern = os.path.join(dataset_dir, '**/*.wav')
        files = glob.glob(pattern, recursive=True)

        # Normalize the file paths. To get file paths with '/' or '\\' consistently depending on OS
        return [os.path.normpath(path) for path in files]

    def _stem_list_to_path_list(self, stem_list):
        """Return the wav paths under ``self.dataset_dir`` whose stem is in ``stem_list``."""
        # Sets are faster to search
        wanted_stems = set(stem_list)
        all_path_list = self._get_audio_path_list(self.dataset_dir)
        return [path for path in all_path_list if Path(path).stem in wanted_stems]

    def _resample_to_target_sr(self, signal, sample_rate):
        """Resample ``signal`` to ``self.target_sample_rate`` if its rate differs."""
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            signal = resampler(signal)
        return signal

    def _mix_down_to_mono(self, signal):
        """Average the channels (dim 0) of a multi-channel signal into a single channel."""
        if signal.shape[0] > 1:
            # Fixed: the original referenced the misspelled name `siggnal` (NameError).
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _trim(self, signal):
        """Return ``signal`` cut or zero-padded to exactly ``self.target_num_samples`` samples."""
        # We cut a fixed amount on the left side if the signal is big enough
        trim_samples_amount = int(self.target_sample_rate * self.trim_amount_time)
        if signal.shape[-1] >= trim_samples_amount + self.target_num_samples:
            signal = signal[:, trim_samples_amount:]

        total_samples = signal.shape[-1]

        if total_samples > self.target_num_samples:
            # We cut from the right side if the signal is too big
            signal = signal[:, :self.target_num_samples]
        elif total_samples < self.target_num_samples:
            # We add zero padding on the right if signal is too small
            num_missing_samples = self.target_num_samples - total_samples
            last_dim_padding = (0, num_missing_samples)
            signal = torch.nn.functional.pad(signal, last_dim_padding)

        return signal

    def _normalize_like_sincnet(self, signal):
        """Scale ``signal`` so that its largest absolute value is 1."""
        peak = torch.max(torch.abs(signal))
        if peak > 0:
            return signal / peak
        # An all-zero (silent) signal is returned unchanged instead of dividing by
        # zero, which would fill the tensor with NaN.
        return signal

    def __len__(self):
        return len(self.wav_path_list)

    def __getitem__(self, index):
        wav_path = self.wav_path_list[index]
        wav_name = Path(wav_path).stem
        label = self.stem_to_spk_mapping[wav_name]

        signal, sample_rate = torchaudio.load(wav_path)

        # moving to CPU/CUDA is now done in the training phase
        # signal = signal.to(self.device)

        signal = self._resample_to_target_sr(signal, sample_rate)
        signal = self._mix_down_to_mono(signal)

        signal = self._trim(signal)
        signal = self._normalize_like_sincnet(signal)

        # signal = signal.squeeze(0)

        label_index = self._speaker_to_index[label]

        if self.is_evalset:
            return signal, label_index, wav_path
        return signal, label_index

        

In [10]:
# Training split: is_evalset keeps its default (False), so items are (signal, label_index).
trainset = BanglaAsrDataset(
    dataset_dir=BANGLA_ASR_DATASET_DIRECTORY,
    wav_stem_list=trainset_stems,
    tsv_loc = BANGLA_ASR_TSV_LOCATION,
    target_sample_rate=SAMPLE_RATE,
    target_num_samples = NUMBER_OF_SAMPLES,
    trim_amount_time = TRIM_AMOUNT_TIME,
    device = device
)

In [12]:
# Evaluation split: is_evalset=True makes every item carry its wav path as a third element.
evalset = BanglaAsrDataset(
    dataset_dir=BANGLA_ASR_DATASET_DIRECTORY,
    wav_stem_list=evalset_stems,
    tsv_loc = BANGLA_ASR_TSV_LOCATION,
    target_sample_rate=SAMPLE_RATE,
    target_num_samples = NUMBER_OF_SAMPLES,
    trim_amount_time = TRIM_AMOUNT_TIME,
    device = device,
    is_evalset=True
)

In [13]:
trainset[0]

(tensor([[ 0.0045,  0.0053,  0.0050,  ..., -0.0040, -0.0044, -0.0065]]), 250)

## Model !

In [14]:
from torch import nn
from torchsummary import summary

import numpy as np
import math

import torch.nn.functional as F
from tqdm import tqdm

## FRM

In [15]:
class FRM(nn.Module):
    """Per-channel gating of a (batch, channels, time) feature map.

    The input is averaged over its last dimension, passed through a linear layer
    and a sigmoid, and the resulting per-channel value is applied to the input
    multiplicatively (``do_mul``) and/or additively (``do_add``), in that order.

    Parameters
    ----------
    nb_dim : int
        Number of channels of the input (in- and out-features of the linear layer).
    do_add : bool, optional
        Add the gate to the input. Defaults to True.
    do_mul : bool, optional
        Multiply the input by the gate. Defaults to True.
    """

    def __init__(self, nb_dim, do_add = True, do_mul = True):
        super().__init__()
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()
        self.do_add = do_add
        self.do_mul = do_mul

    def forward(self, x):
        batch_size, num_channels = x.size(0), x.size(1)

        # One averaged value per (sample, channel).
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)

        # Gate shaped (batch, channels, 1) so it broadcasts over the last dimension.
        gate = self.sig(self.fc(pooled)).view(batch_size, num_channels, -1)

        if self.do_mul:
            x = x * gate
        if self.do_add:
            x = x + gate
        return x

## Residual Block wFRM

In [16]:
class Residual_block_wFRM(nn.Module):
    """Residual block of two kernel-3 convolutions, followed by max-pooling and FRM.

    Parameters
    ----------
    nb_filts : sequence of int
        ``nb_filts[0]`` is the number of input channels, ``nb_filts[1]`` the
        number of output channels. When they differ, the skip connection goes
        through a 1x1 convolution.
    first : bool, optional
        If True the leading BatchNorm + LeakyReLU pair is omitted (no ``bn1`` is
        created). Defaults to False.
    """

    def __init__(self, nb_filts, first = False):
        super().__init__()
        in_channels, out_channels = nb_filts[0], nb_filts[1]

        self.first = first
        if not self.first:
            self.bn1 = nn.BatchNorm1d(num_features = in_channels)
        self.lrelu = nn.LeakyReLU()
        self.lrelu_keras = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv1d(in_channels = in_channels,
            out_channels = out_channels,
            kernel_size = 3,
            padding = 1,
            stride = 1)
        self.bn2 = nn.BatchNorm1d(num_features = out_channels)
        self.conv2 = nn.Conv1d(in_channels = out_channels,
            out_channels = out_channels,
            padding = 1,
            kernel_size = 3,
            stride = 1)

        # The skip path needs a 1x1 convolution only when the channel count changes.
        self.downsample = in_channels != out_channels
        if self.downsample:
            self.conv_downsample = nn.Conv1d(in_channels = in_channels,
                out_channels = out_channels,
                padding = 0,
                kernel_size = 1,
                stride = 1)

        self.mp = nn.MaxPool1d(3)
        self.frm = FRM(
            nb_dim = out_channels,
            do_add = True,
            do_mul = True)

    def forward(self, x):
        # Main path: (BN -> LeakyReLU ->) conv1 -> BN -> LeakyReLU -> conv2
        out = x
        if not self.first:
            out = self.lrelu_keras(self.bn1(out))
        out = self.conv1(out)
        out = self.lrelu_keras(self.bn2(out))
        out = self.conv2(out)

        # Skip path, projected to the output channel count when necessary.
        shortcut = self.conv_downsample(x) if self.downsample else x

        out += shortcut
        return self.frm(self.mp(out))

## LayerNorm

In [17]:
class LayerNorm(nn.Module):
    """Normalisation over the last dimension with a learnable gain and bias.

    Computes ``gamma * (x - mean) / (std + eps) + beta`` where ``mean`` and
    ``std`` are taken over the last dimension of ``x`` (``std`` is
    ``torch.Tensor.std`` with its default settings).

    Parameters
    ----------
    features : int
        Size of the last dimension; length of ``gamma`` and ``beta``.
    eps : float, optional
        Added to the standard deviation to avoid division by zero. Defaults to 1e-6.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / spread + self.beta

## SincConv Fast

In [18]:
class SincConv_fast(nn.Module):
    """Sinc-based convolution

    A 1-D convolution whose filters are band-pass filters rebuilt on every forward
    pass from two learnable vectors: a lower cut-off (`low_hz_`) and a band width
    (`band_hz_`) per output channel.

    Parameters
    ----------
    in_channels : `int`
        Number of input channels. Must be 1.
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length. An even value is increased by one so the filter is odd.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.
    min_low_hz : `int`, optional
        Minimum lower cut-off added to `|low_hz_|`. Defaults to 50.
    min_band_hz : `int`, optional
        Minimum band width added to `|band_hz_|`. Defaults to 50.
    Usage
    -----
    See `torch.nn.Conv1d`
    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz (inverse of `to_mel`)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50):

        super(SincConv_fast,self).__init__()

        if in_channels != 1:
            #msg = (f'SincConv only support one input channel '
            #       f'(here, in_channels = {in_channels:d}).')
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        
        # Forcing the filters to be odd (i.e., perfectly symmetric)
        if kernel_size%2==0:
            self.kernel_size=self.kernel_size+1
            
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # bias and groups are accepted for signature parity with Conv1d but are not supported.
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # initialize filterbanks such that they are equally spaced in Mel scale
        low_hz = 30
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)

        # out_channels + 1 band edges give out_channels adjacent bands.
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)
        

        # filter lower frequency (out_channels, 1)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))

        # filter frequency band (out_channels, 1)
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Hamming window
        #self.window_ = torch.hamming_window(self.kernel_size)
        n_lin=torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window
        self.window_=0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size);

        # (1, kernel_size/2)
        # NOTE: n_ and window_ are plain tensor attributes (not parameters or registered
        # buffers); forward() moves them to the input's device on every call.
        n = (self.kernel_size - 1) / 2.0
        self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes

    def forward(self, waveforms):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.
        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filters activations.
        """

        # Keep the helper tensors on the same device as the input.
        self.n_ = self.n_.to(waveforms.device)

        self.window_ = self.window_.to(waveforms.device)

        # Lower cut-off: at least min_low_hz.
        low = self.min_low_hz  + torch.abs(self.low_hz_)
        
        # Upper cut-off: at least min_band_hz above `low`, clamped to [min_low_hz, sample_rate/2].
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2)
        band=(high-low)[:,0]
        
        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET),
        # with the sinc expanded and the terms simplified to avoid redundant computation.
        # Only the left half of each filter is computed; the right half is its mirror image.
        band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_
        band_pass_center = 2*band.view(-1,1)
        band_pass_right= torch.flip(band_pass_left,dims=[1])
        
        
        band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1)

        
        # Normalise each filter by twice its band width (the centre tap becomes 1).
        band_pass = band_pass / (2*band[:,None])
        

        # The filters built for this call are also stored on the module as `self.filters`.
        self.filters = (band_pass).view(
            self.out_channels, 1, self.kernel_size)

        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                         bias=None, groups=1)

## RawNet2

In [19]:
class RawNet2(nn.Module):
    """Raw-waveform speaker network: LayerNorm -> sinc convolution -> six residual
    blocks -> GRU -> two fully-connected layers.

    Parameters
    ----------
    d_args : dict
        Model configuration. Keys read by this class:

        - ``'nb_samp'``: number of samples per input waveform (LayerNorm size).
        - ``'in_channels'``, ``'first_conv'``: input channels and kernel size of
          the sinc convolution.
        - ``'filts'``: ``[n_sinc_filters, [in, out], [in, out], ...]``; entry 1
          configures blocks 0-1 and entry 2 configures blocks 2-5.
        - ``'gru_node'``, ``'nb_gru_layer'``: GRU hidden size and layer count.
        - ``'nb_fc_node'``: embedding size (output of ``fc1_gru``).
        - ``'nb_classes'``: number of output classes.

        The dictionary is not modified.
    """

    def __init__(self, d_args):
        super(RawNet2, self).__init__()

        # Work on copies: the original assigned into d_args['filts'][2], which
        # changed the caller's configuration and made a second construction from
        # the same dict build a different network whenever in != out.
        first_stage_filts = list(d_args['filts'][1])
        second_stage_filts = list(d_args['filts'][2])

        self.ln = LayerNorm(d_args['nb_samp'])
        self.first_conv = SincConv_fast(in_channels = d_args['in_channels'],
            out_channels = d_args['filts'][0],
            kernel_size = d_args['first_conv']
            )

        self.first_bn = nn.BatchNorm1d(num_features = d_args['filts'][0])
        self.lrelu = nn.LeakyReLU()
        self.lrelu_keras = nn.LeakyReLU(negative_slope = 0.3)

        self.block0 = nn.Sequential(Residual_block_wFRM(nb_filts = first_stage_filts, first = True))
        self.block1 = nn.Sequential(Residual_block_wFRM(nb_filts = first_stage_filts))

        self.block2 = nn.Sequential(Residual_block_wFRM(nb_filts = second_stage_filts))
        # From block3 on, the input width equals block2's output width.
        second_stage_filts[0] = second_stage_filts[1]
        self.block3 = nn.Sequential(Residual_block_wFRM(nb_filts = second_stage_filts))
        self.block4 = nn.Sequential(Residual_block_wFRM(nb_filts = second_stage_filts))
        self.block5 = nn.Sequential(Residual_block_wFRM(nb_filts = second_stage_filts))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.bn_before_gru = nn.BatchNorm1d(num_features = second_stage_filts[-1])
        self.gru = nn.GRU(input_size = second_stage_filts[-1],
            hidden_size = d_args['gru_node'],
            num_layers = d_args['nb_gru_layer'],
            batch_first = True)

        self.fc1_gru = nn.Linear(in_features = d_args['gru_node'],
            out_features = d_args['nb_fc_node'])
        self.fc2_gru = nn.Linear(in_features = d_args['nb_fc_node'],
            out_features = d_args['nb_classes'],
            bias = True)

        self.sig = nn.Sigmoid()

    def forward(self, x, y = 0, is_test=False):
        """Run the network on a batch of waveforms.

        Parameters
        ----------
        x : torch.Tensor
            Waveforms of shape ``(batch, n_samples)``.
        y : optional
            Unused; kept so existing callers that pass labels keep working.
        is_test : bool, optional
            If True, return the embedding (output of ``fc1_gru``) instead of the
            class scores.

        Returns
        -------
        torch.Tensor
            ``(batch, nb_fc_node)`` embeddings when ``is_test`` is True, otherwise
            ``(batch, nb_classes)`` class scores.
        """
        #follow sincNet recipe
        nb_samp = x.shape[0]
        len_seq = x.shape[1]

        x = self.ln(x)
        x = x.view(nb_samp, 1, len_seq)  # add the channel dimension expected by the convolution
        x = F.max_pool1d(torch.abs(self.first_conv(x)), 3)
        x = self.first_bn(x)
        x = self.lrelu_keras(x)

        x = self.block0(x)
        x = self.block1(x)

        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)

        x = self.bn_before_gru(x)
        x = self.lrelu_keras(x)
        x = x.permute(0, 2, 1)  #(batch, filt, time) >> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:,-1,:]  # keep the GRU output of the last time step only
        code = self.fc1_gru(x)
        if is_test: return code

        # Rescale every embedding to an L2 norm of 10 before the classification layer.
        code_norm = code.norm(p=2,dim=1, keepdim=True) / 10.
        code = torch.div(code, code_norm)
        out = self.fc2_gru(code)
        return out

## Training

In [20]:
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d

In [21]:
# Model configuration handed to RawNet2 (and to the training function as `args`).
model_dict = {
    # Classifier size: one output per training speaker.
    'nb_classes': len(trainset.speakers_list),
    # Sinc convolution: kernel size and input channels.
    'first_conv': 251,
    'in_channels': 1,
    # [sinc filters, [in, out] per residual stage, ...]
    'filts': [128, [128,128], [128,128], [256,256]],
    'm_blocks': [2, 4],
    'nb_fc_att_node': [1],
    # Embedding size, GRU hidden size and GRU depth.
    'nb_fc_node': 1024,
    'gru_node': 1024,
    'nb_gru_layer': 1,
    # Samples per input waveform.
    'nb_samp': NUMBER_OF_SAMPLES,
    # Learning-rate decay settings.
    'lr_decay': "keras",
    'do_lr_decay': True,
}

In [22]:
# Optimiser hyper-parameters (passed to torch.optim.Adam).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
AMSGRAD = True
# Number of passes of the outer training loop and the training batch size.
EPOCHS = 5
BATCH_SIZE = 16

# DataLoader worker processes; a higher number may cause errors in the notebook
NUMBER_OF_WORKERS = 0

In [23]:
# Training loader: reshuffled every epoch so that batches differ between epochs
# (the original used shuffle=False, i.e. the same batch composition and order each epoch).
trainset_loader = data.DataLoader(trainset,
            batch_size = BATCH_SIZE,
            shuffle = True,
            drop_last = False,
            num_workers = NUMBER_OF_WORKERS)

# Evaluation loader: one utterance per batch, in a fixed order.
evalset_loader = data.DataLoader(evalset,
            batch_size = 1,
            shuffle = False,
            drop_last = False,
            num_workers = NUMBER_OF_WORKERS)

In [24]:
torch.cuda.empty_cache()


### Batch explained:

If batch size = 4

One batch = [ tensor([[[x,x,x]]], [[x,x,x]], [[x,x,x]], [[x,x,x]]])    , (label1, label2, label3, label4) ]

In [25]:
# Build the network from the configuration dictionary.
model = RawNet2(model_dict)

# Move the parameters to the selected device; as the last expression of the cell,
# this also displays the module structure.
model.to(device)

RawNet2(
  (ln): LayerNorm()
  (first_conv): SincConv_fast()
  (first_bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lrelu): LeakyReLU(negative_slope=0.01)
  (lrelu_keras): LeakyReLU(negative_slope=0.3)
  (block0): Sequential(
    (0): Residual_block_wFRM(
      (lrelu): LeakyReLU(negative_slope=0.01)
      (lrelu_keras): LeakyReLU(negative_slope=0.3)
      (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (mp): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (frm): FRM(
        (fc): Linear(in_features=128, out_features=128, bias=True)
        (sig): Sigmoid()
      )
    )
  )
  (block1): Sequential(
    (0): Residual_block_wFRM(
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [26]:
def keras_lr_decay(step, decay = 0.0001):
    """Return the learning-rate multiplier ``1 / (1 + decay * step)``.

    Parameters
    ----------
    step : int
        Number of scheduler steps taken so far.
    decay : float, optional
        Decay rate per step. Defaults to 0.0001.

    Returns
    -------
    float
        Factor in (0, 1] for non-negative ``step`` and ``decay``.
    """
    denominator = 1.0 + decay * step
    return 1.0 / denominator

In [27]:
# Two optimiser parameter groups: parameters whose name contains 'bn' get no
# weight decay; all others use the optimiser-level default.
named_params = list(model.named_parameters())
params = [
    {'params': [param for name, param in named_params if 'bn' not in name]},
    {'params': [param for name, param in named_params if 'bn' in name], 'weight_decay': 0},
]

# Losses are kept in a dict keyed by name; 'cce' is the categorical cross-entropy.
criterion = {'cce': nn.CrossEntropyLoss()}

optimizer = torch.optim.Adam(
    params,
    lr = LEARNING_RATE,
    weight_decay = WEIGHT_DECAY,
    amsgrad = AMSGRAD,
)

# The multiplier applied to the base learning rate follows keras_lr_decay(step).
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = keras_lr_decay)

In [28]:
def cos_sim(a,b):
    """Return the cosine similarity of the 1-D vectors ``a`` and ``b``."""
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / norm_product

In [68]:
def get_stem_to_path_dict(stem_list, wav_paths_list):
    """Map each requested stem to the path of the wav file carrying that stem.

    Parameters
    ----------
    stem_list : iterable of str
        Stems (file names without extension) to look for.
    wav_paths_list : iterable of str
        Candidate file paths.

    Returns
    -------
    dict
        ``{stem: path}`` for every path whose stem is in ``stem_list``. When
        several paths share a stem, the last one wins.
    """
    # Membership tests against a set are constant-time.
    wanted_stems = set(stem_list)

    stem_path_pairs = ((Path(path).stem, path) for path in wav_paths_list)
    return {stem: path for stem, path in stem_path_pairs if stem in wanted_stems}

In [69]:
def get_wav_list(dataset_dir):
    """Return the paths of all ``*.wav`` files found recursively under ``dataset_dir``.

    Parameters
    ----------
    dataset_dir : str
        Directory to search. A trailing path separator is optional.

    Returns
    -------
    list of str
        File paths normalised with ``os.path.normpath``, in the order returned by
        ``glob``.
    """
    # os.path.join inserts the separator only when it is missing; plain string
    # concatenation silently matched nothing for a directory without a trailing '/'.
    pattern = os.path.join(dataset_dir, '**/*.wav')
    files = glob.glob(pattern, recursive=True)

    # Normalize the file paths. To get file paths with '/' or '\\' consistently depending on OS
    return [os.path.normpath(path) for path in files]

In [70]:
# Resolve every eval-set stem to the path of its wav file on disk.
all_wav_paths = get_wav_list(BANGLA_ASR_DATASET_DIRECTORY)
stem_to_path_dict = get_stem_to_path_dict(evalset_stems, all_wav_paths)

# Print one resolved path as a sanity check (prints nothing if the mapping is empty).
for example_path in stem_to_path_dict.values():
    print(example_path)
    break

data\BanglaASR\WavFiles\asr_bengali_0\asr_bengali\data\00\000020a912.wav


In [71]:
def evaluate_model(model, evalset_loader, device, eval_trials, stem_to_path_dict):
    """Compute the equal error rate (EER) of ``model`` on the evaluation trials.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x=..., is_test=True)`` to obtain embeddings. It is left
        in eval mode when this function returns.
    evalset_loader : iterable
        Yields batches whose element 0 holds the signals and whose element 2
        holds the matching wav paths.
    device : str or torch.device
        Device the signals are moved to before the forward pass.
    eval_trials : iterable of (expected, utt_a, utt_b)
        ``expected`` is converted with ``int`` and used as the label (1 is the
        positive class); ``utt_a`` and ``utt_b`` must be values of
        ``stem_to_path_dict``.
    stem_to_path_dict : dict
        Maps a wav stem to the path under which its embedding is stored.

    Returns
    -------
    float
        The point on the ROC curve where the false-positive rate equals
        ``1 - true-positive rate``.

    Raises
    ------
    KeyError
        If a stem is missing from ``stem_to_path_dict`` or a trial references an
        utterance for which no embedding was extracted.
    """
    model.eval()

    # 1st, extract one embedding per utterance (no gradients needed).
    wav_path_to_embeddings_dict = {}
    with torch.no_grad():
        with tqdm(total = len(evalset_loader), ncols = 70) as pbar:
            for m_batch in evalset_loader:
                # Pair every signal with its own path so that batch sizes above 1
                # do not merge several utterances under the first path of the batch.
                for signal, loaded_path in zip(m_batch[0], m_batch[2]):
                    code = model(x = signal.to(device), is_test=True)
                    wav_stem = Path(loaded_path).stem
                    wav_path = stem_to_path_dict[wav_stem]
                    wav_path_to_embeddings_dict[wav_path] = np.mean(code.cpu().numpy(), axis=0)
                pbar.update(1)

    # 2nd, score every trial by the cosine similarity of its two embeddings.
    y_score = [] # score for each trial
    y = [] # label for each trial
    for trg, utt_a, utt_b in eval_trials:
        y.append(int(trg))
        y_score.append(cos_sim(wav_path_to_embeddings_dict[utt_a], wav_path_to_embeddings_dict[utt_b]))

    # 3rd, calculate the EER: the root of 1 - fpr - tpr(fpr) on [0, 1].
    fpr, tpr, _ = roc_curve(y, y_score, pos_label=1)
    # Build the interpolator once instead of on every root-finder evaluation.
    tpr_at = interp1d(fpr, tpr)
    eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)
    return eer

In [72]:
def train_model(model, db_gen, evalset_loader, optimizer, epoch, args, device, lr_scheduler, criterion, eval_trials, stem_to_path_dict):
    """Run one epoch: evaluate the model, log the result and save a checkpoint.

    NOTE(review): the training pass below is commented out, so this function
    currently only evaluates and checkpoints; the logged loss is always 0.0 and
    ``db_gen``, ``optimizer``, ``args``, ``lr_scheduler`` and ``criterion`` are
    unused. Confirm whether the loop was disabled on purpose before relying on
    the saved checkpoints.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate and checkpoint.
    db_gen : iterable
        Training batches (unused while the training pass is disabled).
    evalset_loader, device, eval_trials, stem_to_path_dict
        Forwarded to ``evaluate_model``.
    epoch : int
        Epoch number, written to the log and used in the checkpoint file name.
    optimizer, args, lr_scheduler, criterion
        Training components (unused while the training pass is disabled).

    Side effects
    ------------
    Appends a line to ``MODEL_LOSS_EER_OUTPUT_FILE`` and writes
    ``model_epoch_<epoch>.pth`` into ``MODEL_SAVE_DIRECTORY``; both directories
    are created if missing.
    """
    # Defining it up here to make it accessible outside the loop, i.e. for saving to file
    loss = 0.0
    # model.train()
    # with tqdm(total = len(db_gen), ncols = 70) as pbar:
    #     for m_batch, m_label in db_gen:

    #         m_batch = m_batch.squeeze(1)

    #         m_batch, m_label = m_batch.to(device), m_label.to(device)

    #         output = model(m_batch, m_label)

    #         cce_loss = criterion['cce'](output, m_label)
    #         loss = cce_loss

    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #         pbar.set_description('epoch: %d, cce:%.3f'%(epoch, cce_loss))
    #         pbar.update(1)
    #         if args['do_lr_decay']:
    #             if model_dict['lr_decay'] == 'keras': lr_scheduler.step()
    eer = evaluate_model(model, evalset_loader, device, eval_trials, stem_to_path_dict)
    print("Validation set EER:", eer)

    # Make sure the output locations exist before writing to them.
    log_directory = os.path.dirname(MODEL_LOSS_EER_OUTPUT_FILE)
    if log_directory:
        os.makedirs(log_directory, exist_ok=True)
    if MODEL_SAVE_DIRECTORY:
        os.makedirs(MODEL_SAVE_DIRECTORY, exist_ok=True)

    # Save EER and model to file
    with open(MODEL_LOSS_EER_OUTPUT_FILE, "a+") as file:
        file.write("Epoch: {}, Loss: {}, EER: {}\n".format(epoch, loss, eer))

    # os.path.join adds the path separator when the directory constant has none;
    # plain concatenation produced ".../outmodel_epoch_<n>.pth" next to the directory.
    checkpoint_path = os.path.join(MODEL_SAVE_DIRECTORY, "model_epoch_" + str(epoch) + ".pth")
    torch.save(model.state_dict(), checkpoint_path)

In [73]:
# Run train_model once per epoch; epoch numbers start at 0.
for epoch in range(EPOCHS):
    train_model(model = model,
        db_gen = trainset_loader,
        args = model_dict,
        evalset_loader=evalset_loader,
        optimizer = optimizer,
        lr_scheduler = lr_scheduler,
        criterion = criterion,
        device = device,
        epoch = epoch,
        eval_trials=eval_trials,
        stem_to_path_dict=stem_to_path_dict)

 14%|███▉                        | 2491/17662 [00:59<06:04, 41.62it/s]


KeyboardInterrupt: 