In [79]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d

In [6]:
speakers_per_batch, utterances_per_speaker = 5, 10
embeds = torch.rand(size=(speakers_per_batch, utterances_per_speaker, 10))

In [40]:
def similarity_matrix(embeds):
    """
    Computes the similarity matrix according the section 2.1 of GE2E.

    :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 
    utterances_per_speaker, embedding_size)
    :return: the similarity matrix as a tensor of shape (speakers_per_batch,
    utterances_per_speaker, speakers_per_batch)
    """
    similarity_weight = nn.Parameter(torch.tensor([10.]))
    similarity_bias = nn.Parameter(torch.tensor([-5.]))

    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    
    # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
    centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
    centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

    # Exclusive centroids (1 per utterance)
    centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
    centroids_excl /= (utterances_per_speaker - 1)
    centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

    # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
    # product of these vectors (which is just an element-wise multiplication reduced by a sum).
    # We vectorize the computation for efficiency.
    sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                speakers_per_batch)
    mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
    for j in range(speakers_per_batch):
        # each row in mask_matrix represents 1 speaker in batch
        mask = np.where(mask_matrix[j])[0] # indexes of 1s in mask_matrix
        # compute cosine sim via dot product
        sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
        sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
       
    sim_matrix = sim_matrix * similarity_weight + similarity_bias
    return sim_matrix

In [50]:
sim_matrix = similarity_matrix(embeds)
sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 
                                    speakers_per_batch))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)


In [52]:
sim_matrix.shape

torch.Size([50, 5])

In [60]:
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)

In [63]:
target = torch.from_numpy(ground_truth).long()
target.shape

torch.Size([50])

In [64]:
sim_matrix.shape

torch.Size([50, 5])

In [69]:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(sim_matrix, target)
loss

tensor(1.8022, grad_fn=<NllLossBackward0>)

In [70]:
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]

In [73]:
labels = np.array([inv_argmax(i) for i in ground_truth])
labels.shape

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]


(50, 5)

In [74]:
preds = sim_matrix.detach().cpu().numpy()

In [76]:
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())

In [80]:
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

In [81]:
eer

0.52