In [None]:
import time

def loadbar(iteration, total, prefix="", suffix="", decimal=0, 
            length=100, fill="=", extras="", delay=0.1):
    """Draw a single-line text progress bar on stdout.

    The line is redrawn in place (carriage return) until ``iteration`` equals
    ``total``; the final frame is followed by ``extras`` and a newline.

    Args:
        iteration: Number of completed steps.
        total: Number of steps that make up 100%. A ``total`` of 0 is drawn
            as a complete bar instead of raising ``ZeroDivisionError``.
        prefix: Text printed before the bar.
        suffix: Text printed after the percentage.
        decimal: Number of decimals shown in the percentage.
        length: Width of the bar in characters.
        fill: Character used for the completed part of the bar.
        extras: Text appended to the final frame only.
        delay: Seconds to sleep after drawing. Defaults to 0.1, the pause the
            function has always applied; pass 0 to avoid slowing down the
            loop being monitored.
    """
    if total:
        per_val = iteration * 100 / float(total)
        filled_len = int(length * iteration // total)
    else:
        # Nothing to process: show the bar as finished rather than dividing by zero.
        per_val = 100.0
        filled_len = length
    # A negative count must not stretch the bar past its configured width.
    filled_len = max(filled_len, 0)

    percent = ("{0:." + str(decimal) + "f}").format(per_val)
    # Left-pad so the integer part of the percentage always occupies three columns.
    cur_percent = " " * (3 - len(str(round(per_val)))) + percent

    if filled_len >= length:
        # Finished (or overshooting): solid bar, no ">" head, never wider than `length`.
        bar = fill * length
    else:
        bar = fill * filled_len + ">" + "." * (length - filled_len - 1)

    line = f"\r{prefix} [{bar}] {cur_percent}% {suffix}"
    if iteration == total:
        print(f"{line} {extras}", end="\n")
    else:
        print(line, end="\r")

    if delay and delay > 0:
        time.sleep(delay)

In [2]:
# Libraries
import os
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq

from sklearn.model_selection import train_test_split

In [3]:
# Global constants
# Default number of frames in a partial utterance (default `n_frames` of the
# `random_partial` samplers and of `SpeakerBatch`).
PAR_N_FRAMES = 160
# Output size of the encoder's final linear layer, i.e. the embedding size.
EM_SIZE = 256
# Input feature size of the encoder's LSTM.
# NOTE(review): presumably the number of mel filterbank channels per frame - confirm against the data.
MEL_N_CHANNELS = 40

# Device hosting the LSTM / linear layers and the input batches: CUDA when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Device on which the similarity matrix and the loss are computed.
LOSS_DEVICE = torch.device("cpu")

In [4]:
# Global variables (run configuration)
# Speakers per batch (DataLoader batch size).
nspeaker = 8
# Partial utterances drawn per speaker in a batch; also the minimum number of
# files a speaker needs to be kept in the dataframe.
nutter = 10

# Number of DataLoader worker processes.
nworker = 0
# nworker = os.cpu_count()

# LSTM hidden size, number of stacked LSTM layers and Adam learning rate of the encoder.
hlsize = 256
nlayer = 3
learnrate =1e-4

# Number of training epochs.
epochs = 5

# Batches per training epoch (limt) and per validation pass (limv);
# -1 means "use len() of the corresponding DataLoader".
limt = 30
limv = 20

In [6]:
# Utils
class RandomCycler:
    """Endless sampler that hands out the items of a collection in shuffled passes.

    A request at least as large as the collection is served with complete,
    freshly shuffled copies of it; any remainder is taken from a shuffled
    queue that is refilled with a new shuffle once it runs empty.
    """

    def __init__(self, source):
        # Private copy so later changes to `source` do not affect the cycler.
        self.all_items = list(source)
        # Shuffled items not handed out yet.
        self.next_items = []

    def sample(self, count: int):
        """Return a list of `count` items.

        Raises:
            ValueError: if `count` is positive but the cycler holds no items
                (the sampling loop could otherwise never make progress).
        """
        if count > 0 and not self.all_items:
            raise ValueError("Can't sample from an empty RandomCycler")

        shuffle = lambda l: random.sample(l, len(l))

        out = []
        while count > 0:
            if count >= len(self.all_items):
                # Serve a whole shuffled pass without touching the queue.
                out.extend(shuffle(list(self.all_items)))
                count -= len(self.all_items)
                continue
            n = min(count, len(self.next_items))
            out.extend(self.next_items[:n])
            count -= n
            self.next_items = self.next_items[n:]
            if len(self.next_items) == 0:
                # Queue exhausted: start a new shuffled pass.
                self.next_items = shuffle(list(self.all_items))
        return out

    def __next__(self):
        """Return a single item, so that `next(cycler)` works."""
        return self.sample(1)[0]

class Utterance:
    """A single utterance: an identifier plus its frame array."""

    def __init__(self, uid, frames):
        self.id = uid
        # Array whose first axis indexes frames (only `shape[0]` and slicing are used here).
        self.frames = frames

    def random_partial(self, n_frames=PAR_N_FRAMES):
        """Cut a random window of `n_frames` consecutive frames.

        Returns:
            A tuple `(frames[start:end], (start, end))`.

        Raises:
            ValueError: if the utterance has fewer than `n_frames` frames.
        """
        total_frames = self.frames.shape[0]
        if total_frames < n_frames:
            raise ValueError(
                f"Utterance {self.id} has {total_frames} frames, "
                f"cannot cut a partial of {n_frames} frames"
            )
        if total_frames == n_frames:
            start = 0
        else:
            # randint's upper bound is exclusive: +1 keeps the last valid
            # window (start == total_frames - n_frames) reachable.
            start = np.random.randint(0, total_frames - n_frames + 1)
        end = start + n_frames
        return self.frames[start:end], (start, end)
    
class Speaker:
    """A speaker id together with its raw utterance records.

    `Utterance` objects and their `RandomCycler` are built lazily on the
    first call to `random_partial`.
    """

    def __init__(self, sid, data):
        self.id = sid
        # Iterable of (utterance id, frames) pairs.
        self.data = data
        self.utterances = None
        self.utterance_cycler = None

    def _load_utterances(self):
        """Wrap every record in an `Utterance` and set up the cycler over them."""
        loaded = []
        for idx, arr in self.data:
            loaded.append(Utterance(idx, arr))
        self.utterances = loaded
        self.utterance_cycler = RandomCycler(self.utterances)

    def random_partial(self, count, n_frames=PAR_N_FRAMES):
        """Sample `count` utterances and cut one random partial from each.

        Returns:
            A list of `(utterance, frames, (start, end))` tuples.
        """
        if self.utterances is None:
            self._load_utterances()

        partials = []
        for utterance in self.utterance_cycler.sample(count):
            partials.append((utterance, *utterance.random_partial(n_frames)))
        return partials

class SpeakerBatch:
    """A batch of partial utterances drawn from a list of speakers.

    Attributes set in the constructor:
        partials: dict mapping each speaker to its list of sampled partials.
        data: array stacking the frames of every partial, speaker by speaker.
    """

    def __init__(self, speakers, utterances_per_speaker, n_frames=PAR_N_FRAMES):
        self.partials = {}
        for speaker in speakers:
            self.partials[speaker] = speaker.random_partial(utterances_per_speaker, n_frames)

        # Flatten in speaker order so each speaker's partials stay contiguous.
        stacked = []
        for speaker in speakers:
            for _, frames, _ in self.partials[speaker]:
                stacked.append(frames)
        self.data = np.array(stacked)
        
class SpeakerVerificationDataset(Dataset):
    """Dataset that yields a randomly cycled `Speaker`, ignoring the requested index."""

    def __init__(self, dataframe):
        self.df = dataframe

        # One Speaker per distinct speaker_id, holding that speaker's (file_id, frames) rows.
        speakers = []
        for speaker_id in self.df.speaker_id.unique():
            speaker_rows = self.df[self.df.speaker_id == speaker_id]
            records = speaker_rows[["file_id", "frames"]].to_numpy()
            speakers.append(Speaker(speaker_id, records))
        self.speakers = speakers
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        # Fixed, very large length: items come from the cycler, not from an index.
        return 10 ** 10

    def __getitem__(self, index):
        return next(self.speaker_cycler)
    
class SpeakerVerificationDataLoader(DataLoader):
    """DataLoader whose batches are `SpeakerBatch` objects.

    `speakers_per_batch` becomes the DataLoader batch size; each batch of
    speakers is collated into a `SpeakerBatch` holding
    `utterances_per_speaker` partials per speaker.
    """

    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, 
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, 
                 worker_init_fn=None):
        # Must be set before the parent constructor, which receives `self.collate`.
        self.utterances_per_speaker = utterances_per_speaker

        loader_options = dict(
            batch_size=speakers_per_batch,
            shuffle=False,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=self.collate,
            pin_memory=pin_memory,
            drop_last=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
        super().__init__(dataset, **loader_options)

    def collate(self, speakers):
        """Turn a list of speakers into a `SpeakerBatch` of partial utterances."""
        return SpeakerBatch(speakers, self.utterances_per_speaker, PAR_N_FRAMES)

In [7]:
# Model
class SpeakerEncoder(nn.Module):
    """LSTM speaker encoder bundled with its similarity-based loss and optimizer.

    The network (LSTM -> linear -> ReLU -> L2 normalisation) lives on `DEVICE`;
    the similarity matrix and the cross-entropy loss are computed on
    `LOSS_DEVICE`.
    NOTE(review): the loss looks like the GE2E softmax formulation - confirm.

    Args:
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        learning_rate: Learning rate of the Adam optimizer stored in `self.optimizer`.
    """

    def __init__(
        self, 
        hidden_size,
        num_layers,
        learning_rate,
        ):
        super().__init__()
        self.loss_device = LOSS_DEVICE
        
        self.lstm = nn.LSTM(input_size=MEL_N_CHANNELS,
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True).to(DEVICE)
        self.linear = nn.Linear(in_features=hidden_size, 
                                out_features=EM_SIZE).to(DEVICE)
        self.relu = torch.nn.ReLU().to(DEVICE)
        
        # Create the tensors on the loss device *before* wrapping them.
        # `nn.Parameter(...).to(device)` returns a plain non-leaf tensor whenever
        # a copy is needed, which would silently remove the scale/offset from
        # `self.parameters()` (and therefore from the optimizer).
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=self.loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=self.loss_device))

        self.loss_fn = nn.CrossEntropyLoss().to(self.loss_device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        
    def do_gradient_ops(self):
        """Scale down the similarity scale/offset gradients, then clip all gradients.

        Must be called after `backward()`, once the gradients exist.
        """
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01
            
        # Clip the global L2 norm of all gradients to 3.
        clip_grad_norm_(self.parameters(), 3, norm_type=2)
    
    def forward(self, utterances, hidden_init=None):
        """Embed a batch of partial utterances.

        Args:
            utterances: Batch-first input of the LSTM.
            hidden_init: Optional initial LSTM state.

        Returns:
            L2-normalised embeddings computed from the last layer's final hidden state.
        """
        out, (hidden, cell) = self.lstm(utterances, hidden_init)
        
        # Only the final hidden state of the top LSTM layer is used.
        embeds_raw = self.relu(self.linear(hidden[-1]))
        
        # The epsilon guards against a zero vector after the ReLU.
        return embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)        
    
    def similarity_matrix(self, embeds):
        """Compute scaled cosine similarities between utterances and speaker centroids.

        Args:
            embeds: Tensor of shape (speakers_per_batch, utterances_per_speaker, embedding_size).
                At least two utterances per speaker are needed, because the
                exclusive centroid divides by `utterances_per_speaker - 1`.

        Returns:
            Tensor of shape (speakers_per_batch, utterances_per_speaker, speakers_per_batch),
            already multiplied by `similarity_weight` and shifted by `similarity_bias`.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
        
        # Inclusive centroids: mean over all utterances of a speaker, re-normalised.
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids: for every utterance, the mean of the speaker's *other* utterances.
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            # Indices of every speaker except j.
            mask = np.where(mask_matrix[j])[0]
            # Other speakers' utterances are compared with j's inclusive centroid ...
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            # ... while j's own utterances use the centroid that excludes themselves.
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
        
        return sim_matrix * self.similarity_weight + self.similarity_bias
    
    def loss(self, embeds):
        """Compute the training loss and the equal error rate of a batch.

        Args:
            embeds: Tensor of shape (speakers_per_batch, utterances_per_speaker, embedding_size).

        Returns:
            `(loss, eer)`: the cross-entropy of the similarity matrix against the
            true speaker index of every utterance, and the ROC operating point
            where the false-positive rate equals `1 - tpr`.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
        
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 
                                         speakers_per_batch))
        # Utterances are grouped by speaker, so speaker i owns rows
        # [i * utterances_per_speaker, (i + 1) * utterances_per_speaker).
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)
        
        with torch.no_grad():
            # One-hot encoding of the true speaker of every utterance.
            labels = np.eye(speakers_per_batch, dtype=int)[ground_truth]
            preds = sim_matrix.detach().cpu().numpy()

            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
            # Root of 1 - x - tpr(x): the point where fpr equals the false-negative rate.
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
            
        return loss, eer
    
def sync(device: torch.device):
    """Wait for pending CUDA work on `device`; does nothing for non-CUDA devices."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)

In [8]:
# Load the pickled (file_id, speaker_id, frames) records into a dataframe.
# NOTE(review): allow_pickle=True executes arbitrary code from the archive -
# only load files from a trusted source.
df = pd.DataFrame.from_records(
    np.load("data/encoder_librispeech_valid.npz", allow_pickle=True)["data"],
    columns=["file_id", "speaker_id", "frames"],
)

# Attach the number of files each speaker has as `file_count`.
df = df.merge(
    df.groupby("speaker_id")["file_id"]
      .count()
      .reset_index()
      .rename(columns={"file_id": "file_count"}),
    on="speaker_id",
)

# Keep only speakers with at least `nutter` files.
df = df[df["file_count"] >= nutter]

df.head()

Unnamed: 0,file_id,speaker_id,frames,file_count
0,5536-43363-0006,5536,"[[5.1189186e-06, 8.7852095e-06, 1.9026568e-06,...",59
1,5536-43363-0010,5536,"[[2.8548186e-06, 2.4248714e-06, 3.9052784e-06,...",59
2,5536-43363-0007,5536,"[[1.979813e-05, 1.698165e-05, 6.14011e-06, 1.5...",59
3,5536-43363-0014,5536,"[[1.6595812e-05, 7.0819824e-06, 5.835367e-06, ...",59
4,5536-43363-0001,5536,"[[8.575194e-06, 3.4151853e-06, 1.0975212e-06, ...",59


In [9]:
# Hold out 20% of the rows (one row per file) for validation; the seed is fixed for reproducibility.
X_train, X_valid = train_test_split(df, random_state=42, test_size=0.2)
X_train.shape, X_valid.shape

((2144, 4), (536, 4))

In [10]:
# Each batch holds `nspeaker` speakers with `nutter` partial utterances per speaker.
train_dl = SpeakerVerificationDataLoader(
    SpeakerVerificationDataset(X_train),
    speakers_per_batch=nspeaker,
    utterances_per_speaker=nutter,
    num_workers=nworker,
)
valid_dl = SpeakerVerificationDataLoader(
    SpeakerVerificationDataset(X_valid),
    speakers_per_batch=nspeaker,
    utterances_per_speaker=nutter,
    num_workers=nworker,
)

In [12]:
class EncoderTrainer:
    """Minimal training / validation loop around a `SpeakerEncoder`.

    The number of batches per epoch is taken from the module-level `limt`
    (training) and `limv` (validation); -1 means "use `len()` of the loader".
    """

    def __init__(self, model):
        self.model = model
    
    def common_step(self, batch_idx, batch, stage=None):
        """Run the forward pass on one batch and return `(loss, eer)`.

        `batch_idx` and `stage` are accepted for symmetry between the
        training and validation steps but are not used.
        """
        inputs = torch.from_numpy(batch.data).to(DEVICE)
        sync(DEVICE)

        embeds = self.model(inputs)
        sync(DEVICE)

        # Regroup the flat batch into (speaker, utterance, embedding) on the loss device.
        embeds_loss = embeds.view((nspeaker, nutter, -1)).to(LOSS_DEVICE)
        return self.model.loss(embeds_loss)
    
    def training_step(self, batch_idx, batch):
        """Forward, backward and optimizer update for one batch; returns `(loss, eer)`."""
        loss, eer = self.common_step(batch_idx, batch, "train")

        self.model.zero_grad()
        loss.backward()

        self.model.do_gradient_ops()
        self.model.optimizer.step()
        return loss, eer
    
    def validation_step(self, batch_idx, batch):
        """Evaluate one batch without building an autograd graph; returns `(loss, eer)`."""
        with torch.no_grad():
            return self.common_step(batch_idx, batch, "val")
    
    def validation_epoch(self, valid_dl):
        """Average loss and EER over `limv` validation batches (or the whole loader)."""
        bsize = len(valid_dl) if limv == -1 else limv
        total_loss, total_eer = 0, 0
        # zip with range() stops before fetching a batch that would be thrown away.
        for idx, batch in zip(range(bsize), valid_dl):
            loss, eer = self.validation_step(idx, batch)
            total_loss += loss
            total_eer += eer
        return total_loss / bsize, total_eer / bsize
    
    def fit(self, epochs, train_dl, valid_dl=None):
        """Train for `epochs` epochs, drawing a progress bar for every batch.

        Args:
            epochs: Number of epochs to run.
            train_dl: Loader yielding training batches.
            valid_dl: Optional loader; when given, the last bar of each epoch
                shows the epoch-averaged training metrics followed by the
                validation metrics.
        """
        bsize = len(train_dl) if limt == -1 else limt
        epoch_width = len(str(epochs))
        batch_width = len(str(bsize))

        for epoch in range(epochs):
            # Running sums are per epoch, so they restart at every epoch.
            running_loss, running_eer = 0.0, 0.0
            cur_epoch = str(epoch + 1).zfill(epoch_width)

            first_batch = str(1).zfill(batch_width)
            loadbar(0, bsize, f"Epoch [{cur_epoch}/{epochs}] {first_batch}/{bsize}", length=50)

            # zip with range() stops before fetching a batch that would be thrown away.
            for idx, batch in zip(range(bsize), train_dl):
                loss, eer = self.training_step(idx, batch)

                # Accumulate a plain float so the autograd graph of the batch can be freed.
                loss_value = loss.item()
                running_loss += loss_value
                running_eer += eer

                done = idx + 1
                cur_batch = str(done).zfill(batch_width)
                p = f"Epoch [{cur_epoch}/{epochs}] {cur_batch}/{bsize}"

                if done == bsize and valid_dl is not None:
                    s = f"- loss: {running_loss/bsize:0.4f} - eer: {running_eer/bsize:0.4f}"

                    val_loss, val_eer = self.validation_epoch(valid_dl)
                    e = f" - val_loss: {val_loss:0.4f} - val_eer: {val_eer:0.4f}"

                    loadbar(done, bsize, p, s, extras=e, length=50)
                else:
                    s = f"- loss: {loss_value:0.4f} - eer: {eer:0.4f}"
                    loadbar(done, bsize, p, s, length=50)
                
#                 break
#             break

# Build the encoder from the run configuration and wrap it in a trainer.
model = SpeakerEncoder(hidden_size=hlsize, num_layers=nlayer, learning_rate=learnrate)
trainer = EncoderTrainer(model=model)

# Train without a validation pass.
trainer.fit(epochs, train_dl)

# Variant that also reports validation metrics at the end of every epoch:
# trainer.fit(epochs, train_dl, valid_dl)

