In [1]:
import os
# os.chdir(os.path.dirname(os.path.abspath(__file__))) # change to current file path
import random
import time, shutil
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from hparam import hparam as hp
from data_load import SpeakerDatasetPreprocessed
from speech_embedder_net import SpeechEmbedder, SpeechEmbedder_Softmax, GE2ELoss, GE2ELoss_, get_centroids, get_cossim
from torch.utils.tensorboard import SummaryWriter
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
import numpy as np
from numpy.linalg import solve
import scipy.linalg
import scipy.stats
from tqdm import tqdm
from utils import compute_eer

# Seed Python's, NumPy's and PyTorch's global RNGs (in that order) with one
# shared value so that a fresh run of the notebook draws the same numbers.
_SEED = 1
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(_SEED)

def get_n_params(model):
    """Count the trainable parameters of ``model``.

    Args:
        model: Object whose ``parameters()`` yields tensors exposing
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total number of scalar elements over all parameters whose
        ``requires_grad`` flag is set; ``0`` when there are none.
    """
    # numel() is always an exact Python int. np.prod over the empty size of a
    # 0-dim parameter evaluates to the float 1.0, which would silently turn
    # the whole count into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Threshold based

In [3]:
# Checkpoint to evaluate.
# NOTE(review): absolute, machine-specific path — consider moving it into the
# hparam config so the notebook runs on other machines.
model_path = '/home/yrb/code/speechbrain/data/models/Permute/Softmax/SoftmaxExperiment1/Chance=0.2/ckpt_epoch_2_batch_id_45598.pth'

device = torch.device(hp.device)
# Re-seed right before building the loader/model so this cell is repeatable
# on its own (the loader below shuffles).
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

test_dataset = SpeakerDatasetPreprocessed()

test_loader = DataLoader(test_dataset, batch_size=hp.test.N, shuffle=True, num_workers=hp.test.num_workers, drop_last=True)

# Read the checkpoint once and remap its tensors onto `device`, so a file
# saved on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint_state = torch.load(model_path, map_location=device)

# The checkpoint may belong to either architecture: try the plain embedder
# first and fall back to the softmax variant if the weights do not fit.
embedder_net = SpeechEmbedder().to(device)
try:
    embedder_net.load_state_dict(checkpoint_state)
except RuntimeError:
    # load_state_dict raises RuntimeError on missing/unexpected keys or shape
    # mismatches; any other failure (bad path, CUDA error, Ctrl-C) propagates.
    # presumably 5994 is the number of training speakers — TODO confirm
    embedder_net = SpeechEmbedder_Softmax(num_classes=5994).to(device)
    embedder_net.load_state_dict(checkpoint_state)
del checkpoint_state  # weights are now copied into the model
embedder_net.eval()

print("Number of params: ", get_n_params(embedder_net))

Loading spkr2utter and spkr2utter_mislabel...


100%|██████████| 13706/13706 [00:00<00:00, 1586771.11it/s]


Number of params:  13687102


In [None]:
# Not updated anywhere in this cell; kept only in case a later cell reads it.
avg_EER = 0

# Similarity scores and their +1 / -1 targets for every trial, accumulated
# over all batches of all epochs.
ypreds = []
ylabels = []

# Each speaker's M utterances are split evenly into an enrollment half and a
# verification half, so M must be even.
assert hp.test.M % 2 == 0
half_M = hp.test.M // 2

for e in range(hp.test.epochs):
    batch_avg_EER = 0  # unused in this cell; kept for any later cell that may read it
    for batch_id, (mel_db_batch, labels, is_noisy, utterance_ids) in enumerate(tqdm(test_loader)):
        utterance_ids = np.array(utterance_ids).T  # not used below in this cell
        mel_db_batch = mel_db_batch.to(device)

        # Split along dim 1 into two halves, then flatten (speaker, utterance)
        # into a single batch dimension of N * M/2 items for the network.
        enrollment_batch, verification_batch = torch.split(mel_db_batch, mel_db_batch.size(1) // 2, dim=1)
        enrollment_batch = torch.reshape(enrollment_batch, (hp.test.N * half_M, enrollment_batch.size(2), enrollment_batch.size(3)))
        verification_batch = torch.reshape(verification_batch, (hp.test.N * half_M, verification_batch.size(2), verification_batch.size(3)))

        # Feed the verification items in random order; `unperm` restores the
        # original order of the embeddings afterwards.
        perm = torch.randperm(verification_batch.size(0))
        unperm = torch.argsort(perm)
        verification_batch = verification_batch[perm]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            # The softmax model exposes its embedding through get_embedding();
            # the plain embedder returns it from the forward pass.
            if embedder_net.__class__.__name__ == 'SpeechEmbedder_Softmax':
                enrollment_embeddings = embedder_net.get_embedding(enrollment_batch)
                verification_embeddings = embedder_net.get_embedding(verification_batch)
            else:
                enrollment_embeddings = embedder_net(enrollment_batch)
                verification_embeddings = embedder_net(verification_batch)

            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(enrollment_embeddings, (hp.test.N, half_M, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(verification_embeddings, (hp.test.N, half_M, verification_embeddings.size(1)))

            # assumes get_centroids returns one row per speaker — TODO confirm
            enrollment_centroids = get_centroids(enrollment_embeddings)

            # Stack verification embeddings utterance-major: row j*N + s holds
            # utterance j of speaker s. Works for any M/2 (the previous code
            # hard-coded exactly three utterances per speaker).
            veri_embed = verification_embeddings.transpose(0, 1).reshape(hp.test.N * half_M, -1)

            # Cosine similarity of every verification embedding against every
            # enrollment centroid.
            veri_embed_norm = veri_embed / torch.norm(veri_embed, dim=1, keepdim=True)
            enrl_embed_norm = enrollment_centroids / torch.norm(enrollment_centroids, dim=1, keepdim=True)
            sim_mat = torch.matmul(veri_embed_norm, enrl_embed_norm.transpose(-1, -2)).cpu().numpy()

        # Target matrix: +1 where the row's speaker (row index mod N) matches
        # the centroid column, -1 everywhere else.
        truth = np.full_like(sim_mat, -1)
        rows = np.arange(truth.shape[0])
        truth[rows, rows % hp.test.N] = 1

        ypreds.append(sim_mat.flatten())
        ylabels.append(truth.flatten())

    # Evaluated after each epoch on everything accumulated so far.
    # NOTE(review): compute_eer is handed lists of per-batch arrays — confirm
    # that is the input it expects.
    eer, thresh = compute_eer(ypreds, ylabels)
    print("eer:", eer, "threshold:", thresh)

# O2U-Net