In [2]:
%cd /content/drive/MyDrive/753_hacker

/content/drive/MyDrive/753_hacker


Create numpy files from the .wav files in the dataset

In [None]:
!python data_preprocess.py

In [3]:
import os
import random
import time
import torch
from torch.utils.data import DataLoader
import librosa
import numpy as np
import torch
import torch.autograd as grad
import torch.nn.functional as F

In [9]:
from hparam import hparam as hp
from data_load import SpeakerDatasetTIMIT, SpeakerDatasetTIMITPreprocessed

from speech_embedder_net import SpeechEmbedder

import torch
import torch.nn as nn

get_centroid and get_cossim 

In [10]:
def get_utterance_centroids(embeddings):
    """
    Returns the centroids for each utterance of a speaker, where
    the utterance centroid is the speaker centroid without considering
    this utterance

    Shape of embeddings should be:
        (speaker_ct, utterance_per_speaker_ct, embedding_size)
    """
    sum_centroids = embeddings.sum(dim=1)
    # we want to subtract out each utterance, prior to calculating the
    # the utterance centroid
    sum_centroids = sum_centroids.reshape(
        sum_centroids.shape[0], 1, sum_centroids.shape[-1]
    )
    # we want the mean but not including the utterance itself, so -1
    num_utterances = embeddings.shape[1] - 1
    centroids = (sum_centroids - embeddings) / num_utterances
    return centroids

In [11]:
def get_centroids(embeddings):
    centroids = embeddings.mean(dim=1)
    return centroids



def get_cossim(embeddings, centroids):
    # number of utterances per speaker
    num_utterances = embeddings.shape[1]
    utterance_centroids = get_utterance_centroids(embeddings)

    # flatten the embeddings and utterance centroids to just utterance,
    # so we can do cosine similarity
    utterance_centroids_flat = utterance_centroids.view(
        utterance_centroids.shape[0] * utterance_centroids.shape[1],
        -1
    )
    embeddings_flat = embeddings.view(
        embeddings.shape[0] * num_utterances,
        -1
    )
    # the cosine distance between utterance and the associated centroids
    # for that utterance
    # this is each speaker's utterances against his own centroid, but each
    # comparison centroid has the current utterance removed
    cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)

    # now we get the cosine distance between each utterance and the other speakers'
    # centroids
    # to do so requires comparing each utterance to each centroid. To keep the
    # operation fast, we vectorize by using matrices L (embeddings) and
    # R (centroids) where L has each utterance repeated sequentially for all
    # comparisons and R has the entire centroids frame repeated for each utterance
    centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
    embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
    embeddings_expand = embeddings_expand.view(
        embeddings_expand.shape[0] * embeddings_expand.shape[1],
        embeddings_expand.shape[-1]
    )
    cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
    cos_diff = cos_diff.view(
        embeddings.size(0),
        num_utterances,
        centroids.size(0)
    )
    # assign the cosine distance for same speakers to the proper idx
    same_idx = list(range(embeddings.size(0)))
    cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
    cos_diff = cos_diff + 1e-6
    return cos_diff

def calc_loss(sim_matrix):
    same_idx = list(range(sim_matrix.size(0)))
    pos = sim_matrix[same_idx, :, same_idx]
    neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
    per_embedding_loss = -1 * (pos - neg)
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss

In [31]:
batch_size = 4
num_workers=1
lr = 0.01
epochs=10
utt = 4

GE2E LOSS

In [14]:
class GE2ELoss(nn.Module):
    
    def __init__(self, device):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
        self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
        self.device = device
        
    def forward(self, embeddings):
        torch.clamp(self.w, 1e-6)
        centroids = get_centroids(embeddings)
        cossim = get_cossim(embeddings, centroids)
        sim_matrix = self.w*cossim.to(self.device) + self.b
        loss, _ = calc_loss(sim_matrix)
        return loss

In [19]:
def train():
    device = torch.device("cuda")
    train_dataset = SpeakerDatasetTIMITPreprocessed()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) 
    
    embedder_net = SpeechEmbedder().to(device)
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([
                    {'params': embedder_net.parameters()},
                    {'params': ge2e_loss.parameters()}
                ], lr=lr)
    
    os.makedirs('speech_id_checkpoint', exist_ok=True)
    
    embedder_net.train()
    iteration = 0
    for e in range(epochs):
        print('epoch - '+str(e)+'started')
        total_loss = 0
        for batch_id, mel_db_batch in enumerate(train_loader): 
            mel_db_batch = mel_db_batch.to(device)
            
            mel_db_batch = torch.reshape(mel_db_batch, (batch_size*utt, mel_db_batch.size(2), mel_db_batch.size(3)))
            perm = random.sample(range(0, batch_size*utt), batch_size*utt)
            unperm = list(perm)
            for i,j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()
            
            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(embeddings, (batch_size,utt, embeddings.size(1)))
            
            #get loss, call backward, step optimizer
            loss = ge2e_loss(embeddings) #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()
            
            total_loss = total_loss + loss
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(time.ctime(), e+1,
                        batch_id+1, len(train_dataset)//batch_size, iteration,loss, total_loss / (batch_id + 1))
                print(mesg)
                # if hp.train.log_file is not None:
                #     with open(hp.train.log_file,'a') as f:
                #         f.write(mesg)
                    
        if (e + 1) % 10 == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth"
            ckpt_model_path = os.path.join('speech_id_checkpoint', ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    #save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model"
    save_model_path = os.path.join('speech_id_checkpoint', save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)
    
    print("\nDone, trained model saved at", save_model_path)


In [32]:
def test(model_path):
    
    test_dataset = SpeakerDatasetTIMITPreprocessed()
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    
    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()
    
    avg_EER = 0
    for e in range(1):
        batch_avg_EER = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            assert utt % 2 == 0
            print(len(torch.split(mel_db_batch, int(mel_db_batch.size(1)/2), dim=1)))
            enrollment_batch, verification_batch = torch.split(mel_db_batch, int(mel_db_batch.size(1)/2), dim=1)
            
            enrollment_batch = torch.reshape(enrollment_batch, (batch_size*utt//2, enrollment_batch.size(2), enrollment_batch.size(3)))
            verification_batch = torch.reshape(verification_batch, (batch_size*utt//2, verification_batch.size(2), verification_batch.size(3)))
            
            perm = random.sample(range(0,verification_batch.size(0)), verification_batch.size(0))
            unperm = list(perm)
            for i,j in enumerate(perm):
                unperm[j] = i
                
            verification_batch = verification_batch[perm]
            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]
            
            enrollment_embeddings = torch.reshape(enrollment_embeddings, (batch_size,utt//2, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(verification_embeddings, (batch_size,utt//2, verification_embeddings.size(1)))
            
            enrollment_centroids = get_centroids(enrollment_embeddings)
            
            sim_matrix = get_cossim(verification_embeddings, enrollment_centroids)
            
            # calculating EER
            diff = 1; EER=0; EER_thresh = 0; EER_FAR=0; EER_FRR=0
            
            for thres in [0.01*i+0.5 for i in range(50)]:
                sim_matrix_thresh = sim_matrix>thres
                
                FAR = (sum([sim_matrix_thresh[i].float().sum()-sim_matrix_thresh[i,:,i].float().sum() for i in range(int(batch_size))])
                /(batch_size-1.0)/(float(utt/2))/batch_size)
    
                FRR = (sum([utt/2-sim_matrix_thresh[i,:,i].float().sum() for i in range(int(hp.test.N))])
                /(float(utt/2))/batch_size)
                
                # Save threshold when FAR = FRR (=EER)
                if diff> abs(FAR-FRR):
                    diff = abs(FAR-FRR)
                    EER = (FAR+FRR)/2
                    EER_thresh = thres
                    EER_FAR = FAR
                    EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)"%(EER,EER_thresh,EER_FAR,EER_FRR))
        avg_EER += batch_avg_EER/(batch_id+1)
    print("\n EER across {0} epochs: {1:.4f}".format(1, avg_EER))


In [20]:
train()

epoch - 0started
Sat May  7 08:34:15 2022	Epoch:1[30/141],Iteration:30	Loss:26.4649	TLoss:26.3836	

Sat May  7 08:34:26 2022	Epoch:1[60/141],Iteration:60	Loss:14.6437	TLoss:24.6774	

Sat May  7 08:34:37 2022	Epoch:1[90/141],Iteration:90	Loss:13.0071	TLoss:23.0980	

Sat May  7 08:34:47 2022	Epoch:1[120/141],Iteration:120	Loss:20.6213	TLoss:22.3128	

epoch - 1started
Sat May  7 08:35:06 2022	Epoch:2[30/141],Iteration:171	Loss:19.0303	TLoss:19.2272	

Sat May  7 08:35:16 2022	Epoch:2[60/141],Iteration:201	Loss:15.8807	TLoss:19.5838	

Sat May  7 08:35:27 2022	Epoch:2[90/141],Iteration:231	Loss:21.3074	TLoss:19.1919	

Sat May  7 08:35:38 2022	Epoch:2[120/141],Iteration:261	Loss:11.0907	TLoss:18.9132	

epoch - 2started
Sat May  7 08:35:56 2022	Epoch:3[30/141],Iteration:312	Loss:14.8836	TLoss:17.3539	

Sat May  7 08:36:07 2022	Epoch:3[60/141],Iteration:342	Loss:20.1234	TLoss:18.4257	

Sat May  7 08:36:18 2022	Epoch:3[90/141],Iteration:372	Loss:12.0562	TLoss:18.6345	

Sat May  7 08:36:28 2022	E

In [33]:
test('./speech_id_checkpoint/final_epoch_10_batch_id_141.model')

3


ValueError: ignored