In [1]:
import os
import random
from numbers import Real
from typing import Dict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.linalg
import scipy.stats
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.mixture import GaussianMixture
from torch import nn, Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_load import SpeakerDatasetPreprocessed
from hparam import Hparam
from speech_embedder_net import SpeechEmbedder, SpeechEmbedder_Softmax, GE2ELoss_, AAMSoftmax,SubcenterArcMarginProduct, get_cossim
from utils import fit_bmm

# Preparatory Work

In [2]:
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# CUDA) with the same value so repeated runs give consistent output.
SEED = 1
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(SEED)

In [3]:
# Load the hyperparameter configuration from the YAML file.
hp = Hparam(file='config/config.yaml')
# The checkpoint path and the torch device are both taken from the configuration.
model_path = hp.nld.model_path
device = torch.device(hp.device)
# Echo the settings this run depends on (the f-string `=` form prints both
# the expression text and its value).
print(f'{hp.stage = }, {hp.nld.noise_type = }')
print(f'{hp.nld.model_path = }')

hp.stage = 'nld', hp.nld.noise_type = 'Permute'
hp.nld.model_path = '/home/yrb/code/speechbrain/data/models/Permute/GE2E/20%/m8_bs128/ckpt_epoch_200.pth'


In [4]:
def get_n_params(model):
    """Count the trainable parameters of a model.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total number of elements over all parameters whose
        ``requires_grad`` flag is set; ``0`` when there are none.
    """
    # `numel()` is the tensor's own element count. Unlike `np.prod(p.size())`
    # it stays an integer for 0-dim parameters, and summing Python ints
    # yields a plain `int` rather than a NumPy scalar.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def extract_emb(embedder_net, batch):
    """Compute embeddings for ``batch`` with the given embedder network.

    A 4-D batch is first flattened over its two leading dimensions so the
    network receives a 3-D tensor. A network whose class is named
    ``SpeechEmbedder_Softmax`` is queried through ``get_embedding``; any
    other network is called directly.

    Args:
        embedder_net: The embedding network.
        batch: Input tensor with 3 or 4 dimensions.

    Returns:
        Whatever the selected network entry point returns for the batch.
    """
    flat_batch = batch.reshape(-1, batch.size(2), batch.size(3)) if batch.ndim == 4 else batch
    has_softmax_head = embedder_net.__class__.__name__ == 'SpeechEmbedder_Softmax'
    forward = embedder_net.get_embedding if has_softmax_head else embedder_net
    return forward(flat_batch)

# TODO finish all function docstrings
def get_hyperparameters_from_path(model_absolute_path: str) -> Dict[str, Real]:
    """Parse hyperparameters encoded in a checkpoint's parent folder name.

    The name of the folder directly containing the checkpoint is split on
    ``_``. Each token is cut at its first non-alphabetic character: the part
    before becomes the key and the rest is parsed as a number, e.g.
    ``.../m8_bs128/ckpt.pth`` -> ``{'m': 8, 'bs': 128}``.

    Args:
        model_absolute_path: Path to the checkpoint file.

    Returns:
        Mapping from each token's alphabetic prefix to its numeric value
        (``int`` when the value is integral, ``float`` otherwise). Tokens
        that are purely alphabetic, or whose remainder is not a number, are
        skipped. A token without an alphabetic prefix is stored under ``''``.
    """
    # basename/dirname follow the platform's path separator instead of
    # assuming '/'.
    folder_name = os.path.basename(os.path.dirname(model_absolute_path))
    ret: Dict[str, Real] = dict()
    for param in folder_name.split('_'):
        split_at = next((i for i, char in enumerate(param) if not char.isalpha()), None)
        if split_at is None:
            continue
        try:
            value = float(param[split_at:])
        except ValueError:
            # Not a `<name><number>` token (e.g. "run-a"); ignore it rather
            # than aborting the whole parse.
            continue
        if value.is_integer():
            value = int(value)
        ret[param[:split_at]] = value
    return ret


def get_loss_type(model_absolute_path: str):
    """Infer the loss family from a directory component of the model path.

    The path is searched, in this order, for the components ``/Softmax/``,
    ``/GE2E/``, ``/AAM/`` and ``/AAMSC/``.

    Args:
        model_absolute_path: Path to the checkpoint file.

    Returns:
        The name of the first component found (without slashes), or
        ``NotImplemented`` when none of them occurs in the path.
    """
    for loss_name in ('Softmax', 'GE2E', 'AAM', 'AAMSC'):
        if f'/{loss_name}/' in model_absolute_path:
            return loss_name
    return NotImplemented


def get_criterion(device, model_path, num_classes=5994):
    """Build the loss module matching the loss family encoded in ``model_path``.

    The family is determined with ``get_loss_type``. For the ``AAM`` and
    ``AAMSC`` families the projection size, scale, margin (and ``K``) are
    read from the notebook-level ``hp`` object.

    Args:
        device: Device the criterion is moved to. The ``Softmax`` criterion
            (``NLLLoss``) is returned without being moved.
        model_path: Checkpoint path containing the loss-family directory.
        num_classes: Number of classes for the ``AAM``/``AAMSC`` heads.
            Defaults to the value that used to be hard-coded here.

    Returns:
        Tuple ``(criterion, loss_type)``.

    Raises:
        ValueError: If no known loss family is found in ``model_path``.
    """
    loss_type = get_loss_type(model_path)
    # TODO read hyperparameter from `get_hyperparameters_from_path`
    if loss_type == 'Softmax':
        criterion = torch.nn.NLLLoss()
    elif loss_type == 'GE2E':
        criterion = GE2ELoss_(init_w=10.0, init_b=-5.0, loss_method='softmax').to(device)
    elif loss_type == 'AAM':
        criterion = AAMSoftmax(hp.model.proj, num_classes, scale=hp.train.s, margin=hp.train.m, easy_margin=True).to(device)
    elif loss_type == 'AAMSC':
        criterion = SubcenterArcMarginProduct(hp.model.proj, num_classes, s=hp.train.s, m=hp.train.m, K=hp.train.K).to(device)
    else:
        # get_loss_type returned NotImplemented: the path names no known loss.
        raise ValueError('Unknown loss')
    return criterion, loss_type

In [5]:
# Read the embedder checkpoint once. `map_location` lets a checkpoint that was
# saved on another device (e.g. a GPU) be restored on the configured one.
embedder_state = torch.load(model_path, map_location=device)
try:
    embedder_net = SpeechEmbedder(hp).to(device)
    embedder_net.load_state_dict(embedder_state)
except RuntimeError:
    # `load_state_dict` raises RuntimeError when the checkpoint does not match
    # the architecture; fall back to the softmax variant. Other errors
    # (including KeyboardInterrupt) are no longer swallowed.
    embedder_net = SpeechEmbedder_Softmax(hp=hp, num_classes=5994).to(device)
    embedder_net.load_state_dict(embedder_state)

# The criterion weights live next to the model checkpoint under a parallel name.
criterion = get_criterion(hp.device, model_path)[0]
criterion.load_state_dict(
    torch.load(model_path.replace('ckpt_epoch', 'ckpt_criterion_epoch'), map_location=device)
)
# Inference only from here on: switch both modules to evaluation mode.
embedder_net.eval()
criterion.eval()

print(f"Number of params: {get_n_params(embedder_net)}")

Initialised GE2E
Number of params: 12134656


# Distance Ranking

## $\mathbf{c}_{s}=\frac{1}{\sum_{u} 1} \sum_{u} \mathbf{f}_{s, u}$
## $\mathbf{d(f_{s, u}, c_s)} = 1 - \dfrac{\mathbf{f}_{s, u} \cdot \mathbf{c}_s}{\Vert \mathbf{f}_{s, u} \Vert _2 \cdot \Vert \mathbf{c}_s \Vert _2}$

In [None]:
# Per-utterance cosine distance to the centroid of its group (ypreds) and the
# matching ground-truth noisy flags (ylabels).
ypreds = []
ylabels = []

cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
nld_dataset = SpeakerDatasetPreprocessed(hp)
nld_loader = DataLoader(nld_dataset, batch_size=hp.nld.N, shuffle=False, num_workers=hp.nld.num_workers, drop_last=False)
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for batch_id, (mel_db_batch, labels, is_noisy, utterance_ids) in enumerate(tqdm(nld_loader)):
        mel_db_batch = mel_db_batch.to(device)

        embeddings = extract_emb(embedder_net, mel_db_batch)
        # -1 rather than hp.nld.N: with drop_last=False the final batch may be
        # smaller than N. Result is (batch, M, emb_dim).
        embeddings = embeddings.reshape(-1, hp.nld.M, embeddings.size(1))
        # The centroid is the mean of the raw embeddings of each group; both
        # are L2-normalised afterwards.
        centroid = embeddings.mean(dim=1, keepdim=True)
        embeddings = embeddings / embeddings.norm(dim=2, keepdim=True)
        centroid = centroid / centroid.norm(dim=2, keepdim=True)
        cos_sim = get_cossim(embeddings, centroid, cos)
        ypreds.extend((1 - cos_sim).reshape(-1).cpu().detach().numpy().tolist())
        ylabels.extend(is_noisy.reshape(-1).cpu().detach().numpy().tolist())
        # Only the first 501 batches are evaluated.
        if batch_id == 500:
            break

In [None]:
# Turn the accumulated distances/flags into arrays and report the configured
# noise level, i.e. the top percentage of distances that will be selected.
ypreds, ylabels = np.array(ypreds), np.array(ylabels)
noise_level = hp.nld.noise_level
print("noise level: ", noise_level)

def compute_precision(ypreds, ylabels, noise_level):
    """Precision of flagging the highest-scoring ``noise_level`` percent as noisy.

    Args:
        ypreds: 1-D NumPy array of scores; higher scores are selected first.
        ylabels: 1-D NumPy array aligned with ``ypreds``; its sum over the
            selection counts the correctly flagged samples (1 = noisy).
        noise_level: Percentage (0-100) of samples to select.

    Returns:
        Fraction of the selected samples whose label is set, or ``nan`` when
        the percentage selects no sample at all.
    """
    n_selected = int(len(ypreds) * noise_level / 100)
    if n_selected <= 0:
        # Slicing with `[-0:]` would select every sample instead of none, so
        # an empty selection is reported as undefined precision.
        return float('nan')
    selected = np.argsort(ypreds)[-n_selected:]
    # len(selected), not n_selected: the slice is capped at the array length.
    return ylabels[selected].sum() / len(selected)

# Precision when the top `noise_level` percent of distances are flagged as noisy.
print("top noise level precision: ", compute_precision(ypreds, ylabels, noise_level))

In [None]:
# Histogram of the centroid distances, coloured by the ground-truth noisy flag.
df = pd.DataFrame({"Distance": ypreds, "isNoisy": ylabels})
print("plot distance distribution")
sns.set()
p = sns.displot(df, x="Distance", hue="isNoisy")
p.fig.set_dpi(100)

## Predict with BMM

In [None]:
# Fit a two-component beta mixture to the distances and plot it.
bmm_model, bmm_model_max, bmm_model_min = fit_bmm(ypreds, max_iters=50)
bmm_model.plot()

# Noise level estimation
# Component 0 is taken as the noisy one when the configured noise level is at
# least 70, otherwise component 1.
# NOTE(review): assumes `bmm_model.weight` holds mixture proportions in
# [0, 1] - confirm against utils.fit_bmm. It is scaled to percent because
# `compute_precision` and the "Noise level" line below both work in percent.
if hp.nld.noise_level >= 70:
    estimated_noise_level = bmm_model.weight[0] * 100
else:
    estimated_noise_level = bmm_model.weight[1] * 100
print("Estimated noise level: ", estimated_noise_level)
print("Noise level: ", ylabels.sum() / len(ylabels) * 100)
print("top estimated noise level precision: ", compute_precision(ypreds, ylabels, estimated_noise_level))
print("top noise level precision: ", compute_precision(ypreds, ylabels, hp.nld.noise_level))

In [None]:
# Per-sample BMM assignment; it is inverted when the configured noise level
# is at least 70, so that 1 stands for "flagged as noisy" in both cases.
bmm_assignments = bmm_model.predict(ypreds)
ypreds_bmm = 1 - bmm_assignments if hp.nld.noise_level >= 70 else bmm_assignments
# Precision = flagged samples that are truly noisy / all flagged samples.
bmm_true_positives = (ypreds_bmm * ylabels).sum()
bmm_precision = bmm_true_positives / ypreds_bmm.sum()
print("BMM precision: ", bmm_precision)

## Predict with GMM

In [None]:
# Two-component Gaussian mixture over the distances (one feature per sample).
distance_column = ypreds.reshape(-1, 1)
gmm_model = GaussianMixture(n_components=2, random_state=1)
gmm_model.fit(distance_column)

# Noise level estimation
ypreds_gmm = gmm_model.predict(distance_column)
component_means = gmm_model.means_[:, 0]
if component_means[0] > component_means[1]:
    # Relabel so that 1 always marks the component with the larger mean distance.
    ypreds_gmm = 1 - ypreds_gmm
flagged_share = ypreds_gmm.sum() / len(ypreds)
estimated_noise_level = flagged_share * 100
print("Estimated noise level: ", estimated_noise_level)
print("Noise level: ", ylabels.sum() / len(ylabels) * 100)
print("top estimated noise level precision: ", compute_precision(ypreds, ylabels, estimated_noise_level))

In [None]:
# Precision = GMM-flagged samples that are truly noisy / all GMM-flagged samples.
gmm_true_positives = (ypreds_gmm * ylabels).sum()
gmm_precision = gmm_true_positives / ypreds_gmm.sum()
print("GMM precision: ", gmm_precision)

# Loss Ranking

# Confidence Ranking

$\mathbf{I}_c = \max\left((1 - \mathbf{y}) * \mathbf{CLS}(\mathbf{x}; \theta)\right)$

In [6]:
# Cosine-similarity module, dataset and loader for the confidence-ranking
# cells; this loader loads data in the main process (num_workers=0).
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
nld_dataset = SpeakerDatasetPreprocessed(hp)
nld_loader = DataLoader(
    nld_dataset,
    batch_size=hp.nld.N,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

Loading spkr2utter and spkr2utter_mislabel...


100%|██████████| 2918286/2918286 [00:08<00:00, 333395.30it/s]


Calculate the confidence ranking with CE loss

In [None]:
# Empty float accumulators for the CE confidence scores and the noisy flags.
i_c_collection_ce, is_noisy_collection_ce = np.empty(0), np.empty(0)

In [None]:
# Inference only: no_grad avoids keeping an autograd graph for every batch.
with torch.no_grad():
    for batch_id, (mel_db_batch, labels, is_noisy, utterance_ids) in enumerate(tqdm(nld_loader)):
        # One-hot mask of the assigned label, one row per utterance.
        # NOTE(review): uses len(nld_dataset) as the number of classes -
        # confirm that it matches the width of `confidence`.
        y = F.one_hot(labels.flatten(), len(nld_dataset)).to(device)

        mel_db_batch = mel_db_batch.to(device)
        if mel_db_batch.ndim == 4:
            # Merge the two leading dimensions so the network sees a 3-D batch.
            mel_db_batch = mel_db_batch.reshape(-1, mel_db_batch.size(2), mel_db_batch.size(3))

        embeddings = embedder_net.get_embedding(mel_db_batch)
        confidence = embedder_net.get_confidence(embeddings)

        # I_c: highest confidence among the classes other than the assigned label.
        i_c = ((1 - y) * confidence).max(dim=1)

        i_c_collection_ce = np.hstack([i_c_collection_ce, i_c.values.cpu().detach().numpy()])
        is_noisy_collection_ce = np.hstack([is_noisy_collection_ce, is_noisy.flatten().cpu().detach().numpy()])

        # Only the first 51 batches are evaluated.
        if batch_id == 50:
            break

In [None]:
# Split the CE confidence scores by ground-truth flag and summarise each group.
noisy_mask_ce = is_noisy_collection_ce == 1
clean_mask_ce = is_noisy_collection_ce == 0
i_c_noisy_ce = i_c_collection_ce[noisy_mask_ce]
i_c_not_noisy_ce = i_c_collection_ce[clean_mask_ce]
print(f'{i_c_noisy_ce.mean() = }, {i_c_noisy_ce.var() = }')
print(f'{i_c_not_noisy_ce.mean() = }, {i_c_not_noisy_ce.var() = }')

Calculate the confidence ranking using GE2E

In [7]:
# Empty float accumulators for the GE2E confidence scores and the noisy flags.
i_c_collection_ge2e, is_noisy_collection_ge2e = np.empty(0), np.empty(0)

In [8]:
# NOTE(review): the recorded run of this cell ended with
# "TypeError: only integer tensors of a single element can be converted to an
# index". The failing call is not identifiable from this cell alone -
# presumably inside `criterion.get_confidence`; verify before relying on it.
# Inference only: no_grad avoids keeping an autograd graph for every batch.
with torch.no_grad():
    for batch_id, (mel_db_batch, labels, is_noisy, utterance_ids) in enumerate(tqdm(nld_loader)):
        # One-hot mask of the assigned label, one row per utterance.
        # NOTE(review): uses len(nld_dataset) as the number of classes -
        # confirm that it matches the width of `confidence`.
        y = F.one_hot(labels.flatten(), len(nld_dataset)).to(device)

        mel_db_batch = mel_db_batch.to(device)
        if mel_db_batch.ndim == 4:
            # Merge the two leading dimensions so the network sees a 3-D batch.
            mel_db_batch = mel_db_batch.reshape(-1, mel_db_batch.size(2), mel_db_batch.size(3))

        embeddings = embedder_net.get_embedding(mel_db_batch)
        # -1 rather than hp.nld.N: with drop_last=False the final batch may be
        # smaller than N.
        embeddings = embeddings.reshape((-1, hp.nld.M, embeddings.size(1)))
        confidence = criterion.get_confidence(embeddings)

        # I_c: highest confidence among the classes other than the assigned label.
        i_c = ((1 - y) * confidence).max(dim=1)

        i_c_collection_ge2e = np.hstack([i_c_collection_ge2e, i_c.values.cpu().detach().numpy()])
        is_noisy_collection_ge2e = np.hstack([is_noisy_collection_ge2e, is_noisy.flatten().cpu().detach().numpy()])

  0%|          | 0/1499 [32:59<?, ?it/s]


TypeError: only integer tensors of a single element can be converted to an index

In [None]:
# Split the GE2E confidence scores by ground-truth flag.
noisy_mask_ge2e = is_noisy_collection_ge2e == 1
clean_mask_ge2e = is_noisy_collection_ge2e == 0
i_c_noisy_ge2e = i_c_collection_ge2e[noisy_mask_ge2e]
i_c_not_noisy_ge2e = i_c_collection_ge2e[clean_mask_ge2e]

# Plot PR curve and ROC curve

In [None]:
# Precision/recall trade-off when thresholding the centroid distance.
# matplotlib is imported in the notebook's top import cell.
precision, recall, thresholds = precision_recall_curve(ylabels, ypreds)

plt.figure("P-R Curve")
plt.title('Precision/Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.plot(recall, precision)
plt.show()

In [None]:
# plot roc curve and compute auc
fpr, tpr, thresholds = roc_curve(ylabels, ypreds)
# `roc_auc_score` integrates the same ROC curve with the trapezoidal rule.
# The earlier `scipy.integrate.trapz` call relied on a submodule that this
# notebook never imports and on an alias that newer SciPy releases dropped.
auc = roc_auc_score(ylabels, ypreds)
plt.figure("ROC Curve")
plt.title('ROC Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.plot(fpr, tpr)
plt.show()
print("AUC: ", auc)