In [1]:
import json
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_is_available
from torch.nn.functional import normalize, one_hot, softmax
from tqdm.auto import tqdm

from nld.constant.defaults import *
from nld.nld.beta_mixture import fit_bmm
from nld.constant.config import DataConfig, TrainConfig
from nld.model.loss import AAMSoftmax, GE2ELoss, SubcenterArcMarginProduct
from nld.process_data.dataset import (VOX2_CLASS_NUM, SpeakerLabelDataset,
                                    SpeakerUtteranceDataset)
from nld.process_data.mislabel import find_mislabeled_json
from nld.utils import clean_memory

In [2]:
model_dir: Path = Path('data/training-models/Permute-75-CE-bs128-seed0')
selected_iteration: str = 'final'
vox1_mel_spectrogram_dir: Path = DEFAULT_VOX1_MEL_SPECTROGRAM_DIR
vox2_mel_spectrogram_dir: Path = DEFAULT_VOX2_MEL_SPECTROGRAM_DIR
mislabeled_json_dir: Path = DEFAULT_VOXCELEB_MISLABELED_JSON_DIR

In [3]:
device = torch.device('cuda' if cuda_is_available() else 'cpu')
train_config = TrainConfig.from_json(model_dir / 'config.json')
data_processing_config = DataConfig.from_json(
    vox2_mel_spectrogram_dir / 'data-processing-config.json'
)
mislabeled_json_file = find_mislabeled_json(
    mislabeled_json_dir, train_config.noise_type, train_config.noise_level
)

model = train_config.forge_model(
    data_processing_config.nmels, VOX2_CLASS_NUM,
).to(device).eval()
model.load_state_dict(torch.load(
    model_dir / f'model-{selected_iteration}.pth', map_location=device,
))
with open(vox2_mel_spectrogram_dir / 'speaker-label-to-id.json', 'r') as f:
    utterance_classes_num = len(json.load(f))
criterion = train_config.forge_criterion(utterance_classes_num).to(device).eval()
criterion.load_state_dict(torch.load(
    model_dir / f'loss-{selected_iteration}.pth', map_location=device,
))

label_dataset = SpeakerLabelDataset(
    vox1_mel_spectrogram_dir, vox2_mel_spectrogram_dir, mislabeled_json_file
)

In [4]:
clean_memory()
inconsistencies = []
noise = []

In [5]:
with torch.no_grad():
    if train_config.loss == 'GE2E':
        assert isinstance(criterion, GE2ELoss)
        w = criterion.w
        b = criterion.b

        norm_centroids: List[Tensor] = []
        for i in range(len(label_dataset)):
            mel, _, label = label_dataset[i]
            assert i == label
            mel: Tensor = mel.to(device)
            centroid = normalize(model(mel).mean(dim=0), dim=0)
            norm_centroids.append(centroid)
        norm_centroids: Tensor = torch.stack(norm_centroids)

        utterance_dataset = SpeakerUtteranceDataset(
            vox1_mel_spectrogram_dir, vox2_mel_spectrogram_dir, mislabeled_json_file,
        )
        for i in range(len(utterance_dataset)):
            # TODO use dataloader!
            mel, is_noisy, selected_id, _, _ = utterance_dataset[i]
            mel: Tensor = mel.to(device)
            norm_embedding: Tensor = normalize(model(mel), dim=0)
            all_similarities = w * (norm_centroids * norm_embedding).sum(dim=-1) + b
            y = one_hot(torch.tensor(selected_id), VOX2_CLASS_NUM)
            inconsistencies.append(torch.max(y * all_similarities).item())
            noise.append(is_noisy)

In [6]:
with torch.no_grad():
    if train_config.loss != 'GE2E':
        for i in tqdm(range(len(label_dataset))):
            mel, is_noisy, label = label_dataset[i]
            mel: Tensor = mel.to(device)
            y = one_hot(torch.tensor(label), VOX2_CLASS_NUM).to(device)
            model_output: Tensor = model(mel)
            noise.extend(is_noisy)
            if train_config.loss in ('AAM', 'AAMSC'):
                assert isinstance(criterion, (AAMSoftmax, SubcenterArcMarginProduct))
                model_output = criterion.directly_predict(model_output)
            model_output = softmax(model_output, dim=-1)
            inconsistencies.extend(
                ((1 - y) * model_output[j, ...]).max().item() for j in range(model_output.size(0))
            )
            if i == 20:
                break

  0%|          | 0/5994 [00:00<?, ?it/s]

In [7]:
sorted(zip(inconsistencies, noise), key=lambda t: t[0], reverse=True)

[(0.11985166370868683, True),
 (0.04880097135901451, False),
 (0.0340949185192585, True),
 (0.032978955656290054, True),
 (0.02815568633377552, True),
 (0.027473842725157738, True),
 (0.02450169436633587, False),
 (0.024380328133702278, True),
 (0.02141707018017769, True),
 (0.019786382094025612, True),
 (0.019293108955025673, False),
 (0.018080227077007294, True),
 (0.01711065135896206, True),
 (0.014976497739553452, True),
 (0.014643417671322823, True),
 (0.014039408415555954, False),
 (0.013740978203713894, True),
 (0.013395494781434536, False),
 (0.01334099005907774, True),
 (0.012772751972079277, True),
 (0.012347716838121414, True),
 (0.012092556804418564, False),
 (0.012067468836903572, False),
 (0.011944062076508999, True),
 (0.011915608309209347, True),
 (0.011692020110785961, True),
 (0.011670933105051517, True),
 (0.011637278832495213, False),
 (0.011515110731124878, True),
 (0.01140650175511837, True),
 (0.011370711959898472, True),
 (0.011288524605333805, True),
 (0.011278

In [8]:
bmm_model, bmm_model_max, bmm_model_min = fit_bmm(np.array(inconsistencies), max_iters=50)

  return _boost._beta_pdf(x, a, b)


In [12]:
bmm_model.weight

array([0.99623935, 0.00376065])