# TinyRawNet: A Student Model of the RawNet Speaker Recognition Model

Necessary Imports:

In [130]:
import csv
import glob
import os
from pathlib import Path

import pandas as pd
import torch
import torchaudio
from torch.utils import data

Directories and locations:

In [131]:
# Directories are assumed to have a trailing '/' or '\\' in all the subsequent code,
# because they are concatenated directly with file names and glob patterns.

# Project root. It can be overridden with the SPEAKER_RECOGNITION_ROOT environment
# variable so the notebook also runs on machines without a "W:" drive; the default
# keeps the original location.
CURRENT_WORKING_DIRECTORY = os.environ.get("SPEAKER_RECOGNITION_ROOT", "W:/SpeakerRecognitionResearch")

# Both paths are relative to the project root.
BANGLA_ASR_DATASET_DIRECTORY = "data/BanglaASR/WavFiles/"
BANGLA_ASR_TSV_LOCATION = "data/BanglaASR/utt_spk_text.tsv"

# The dataset paths above are relative, so to avoid file-location errors we make the
# "SpeakerRecognitionResearch" root folder the current working directory.
os.chdir(CURRENT_WORKING_DIRECTORY)
os.getcwd()

'W:\\SpeakerRecognitionResearch'

Constants:

In [132]:
# Every clip is resampled to 16 kHz. Two seconds at that rate is 32000 samples,
# which is the fixed length each returned signal tensor has.
SAMPLE_RATE = 16_000
NUMBER_OF_SAMPLES = 2 * SAMPLE_RATE

# The Bangla ASR recordings begin with roughly half a second of silence.
# This many seconds are cut from the left (start) of the audio when possible.
TRIM_AMOUNT_TIME = 0.5

## Custom dataset for Bangla ASR

This custom dataset assumes that the dataset has already been converted to WAV format. See the evaluate_asr_ds.ipynb notebook for the conversion method.

In [133]:
class BanglaAsrDataset(data.Dataset):
    """Map-style PyTorch dataset over the Bangla ASR WAV files.

    Each item is a ``(signal, speaker_id)`` pair. ``signal`` has been:

    1. resampled to ``target_sample_rate``;
    2. mixed down to a single channel;
    3. trimmed or zero-padded to exactly ``target_num_samples`` samples;
    4. divided by its peak absolute value.

    Args:
        dataset_dir: Directory searched recursively for ``*.wav`` files. It must
            end with a path separator, because it is concatenated directly with
            the glob pattern.
        tsv_loc: Path to a headerless, tab-separated file. Its first column is
            the WAV file stem and its second column the speaker id. The last
            column (the speech annotation) is dropped.
        target_sample_rate: Sample rate (Hz) every signal is resampled to.
        target_num_samples: Number of samples in every returned signal.
        trim_amount_time: Seconds cut from the start of a signal, but only when
            the signal is long enough to still provide ``target_num_samples``
            samples afterwards.
    """

    def __init__(self, dataset_dir, tsv_loc, target_sample_rate, target_num_samples, trim_amount_time):
        # Read every column as text. Utterance and speaker ids are alphanumeric
        # strings (e.g. '000020a912', '16cfb'); letting pandas infer dtypes could
        # turn all-digit ids into integers that no longer match the file stems.
        tsv_dataframe = pd.read_csv(
            tsv_loc, quoting=csv.QUOTE_NONE, sep='\t', header=None, dtype=str
        )

        # The TSV file contains speech annotations in its last (third) column.
        # We don't need the annotations, so we drop that column.
        tsv_dataframe = tsv_dataframe.iloc[:, :-1]

        self.dataset_dir = dataset_dir
        # WAV file stem -> speaker id.
        self.wav_to_spk_mapping = dict(sorted(tsv_dataframe.values.tolist()))
        self.wav_path_list = self._get_audio_path_list()
        self.target_sample_rate = target_sample_rate
        self.target_num_samples = target_num_samples
        self.trim_amount_time = trim_amount_time

    def _get_audio_path_list(self):
        """Return the sorted, normalized paths of all ``*.wav`` files under ``dataset_dir``."""
        pattern = '**/*.wav'
        files = glob.glob(self.dataset_dir + pattern, recursive=True)

        # Normalize the file paths so they use '/' or '\\' consistently, depending on the OS.
        # glob returns files in an arbitrary, filesystem-dependent order. Sorting keeps
        # dataset indices reproducible across runs and machines.
        return sorted(os.path.normpath(path) for path in files)

    def _resample_to_target_sr(self, signal, sample_rate):
        """Resample ``signal`` from ``sample_rate`` to ``target_sample_rate`` if they differ."""
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            signal = resampler(signal)
        return signal

    def _mix_down_to_mono(self, signal):
        """Average a channels-first signal over its channels, keeping a channel axis of size 1."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _trim(self, signal):
        """Cut or zero-pad the last (time) axis of ``signal`` to ``target_num_samples``."""
        total_samples = signal.shape[-1]

        # We cut a fixed amount on the left side (the leading silence), but only if the
        # signal is big enough to still fill target_num_samples afterwards.
        trim_samples_amount = int(self.target_sample_rate * self.trim_amount_time)

        if total_samples >= trim_samples_amount + self.target_num_samples:
            signal = signal[..., trim_samples_amount:]
            total_samples = signal.shape[-1]

        # We cut from the right side if the signal is too big.
        if total_samples > self.target_num_samples:
            signal = signal[..., :self.target_num_samples]

        # We add zero padding on the right if the signal is too small.
        elif total_samples < self.target_num_samples:
            num_missing_samples = self.target_num_samples - total_samples
            last_dim_padding = (0, num_missing_samples)
            signal = torch.nn.functional.pad(signal, last_dim_padding)

        return signal

    def _normalize_like_sincnet(self, signal):
        """Scale ``signal`` so its peak absolute value is 1; silent signals are returned unchanged."""
        peak = torch.max(torch.abs(signal))
        # An all-zero clip would otherwise become NaN (0 / 0).
        if peak == 0:
            return signal
        return signal / peak

    def __len__(self):
        # __getitem__ indexes into wav_path_list, so its length (not the number
        # of TSV rows) is the number of valid indices.
        return len(self.wav_path_list)

    def __getitem__(self, index):
        """Load, preprocess and return ``(signal, speaker_id)`` for the ``index``-th WAV file."""
        wav_path = self.wav_path_list[index]
        wav_name = Path(wav_path).stem
        try:
            label = self.wav_to_spk_mapping[wav_name]
        except KeyError:
            raise KeyError(
                f"No speaker label in the TSV for utterance '{wav_name}' ({wav_path})"
            ) from None

        signal, sample_rate = torchaudio.load(wav_path)

        signal = self._resample_to_target_sr(signal, sample_rate)
        signal = self._mix_down_to_mono(signal)

        signal = self._trim(signal)
        signal = self._normalize_like_sincnet(signal)

        return signal, label

In [134]:
# Build the dataset from the directory locations and constants defined above.
bangla_asr_dataset = BanglaAsrDataset(
    dataset_dir=BANGLA_ASR_DATASET_DIRECTORY,
    tsv_loc=BANGLA_ASR_TSV_LOCATION,
    target_sample_rate=SAMPLE_RATE,
    target_num_samples=NUMBER_OF_SAMPLES,
    trim_amount_time=TRIM_AMOUNT_TIME,
)

# Sanity check against a known TSV row: utterance '000020a912' belongs to speaker '16cfb'.
assert (
    bangla_asr_dataset.wav_to_spk_mapping['000020a912'] == '16cfb'
), "The dictionary returned wrong mapping"

In [135]:
# Load the first item once and show it together with the shape of its signal.
# Indexing the dataset twice would decode and preprocess the same WAV file twice.
first_item = bangla_asr_dataset[0]
first_item, first_item[0].shape

((tensor([[-0.0076, -0.0041,  0.0149,  ..., -0.1191, -0.1366, -0.1302]]),
  '16cfb'),
 torch.Size([1, 32000]))

800