In [28]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d

In [105]:
speakers_per_batch, utterances_per_speaker = 5, 10
embeds = torch.rand(size=(speakers_per_batch, utterances_per_speaker, 20))

# GE2E Sim Mat

In [8]:
def similarity_matrix(embeds):
    """
    Computes the similarity matrix according the section 2.1 of GE2E.

    :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 
    utterances_per_speaker, embedding_size)
    :return: the similarity matrix as a tensor of shape (speakers_per_batch,
    utterances_per_speaker, speakers_per_batch)
    """
    similarity_weight = nn.Parameter(torch.tensor([10.]))
    similarity_bias = nn.Parameter(torch.tensor([-5.]))

    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    
    # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
    centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
    centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

    # Exclusive centroids (1 per utterance)
    centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
    centroids_excl /= (utterances_per_speaker - 1)
    centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

    # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
    # product of these vectors (which is just an element-wise multiplication reduced by a sum).
    # We vectorize the computation for efficiency.
    sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                speakers_per_batch)
    mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
    for j in range(speakers_per_batch):
        # each row in mask_matrix represents 1 speaker in batch
        mask = np.where(mask_matrix[j])[0] # indexes of 1s in mask_matrix
        # compute cosine sim via dot product
        sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
        sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
       
    sim_matrix = sim_matrix * similarity_weight + similarity_bias
    return sim_matrix

# GE2E Loss

In [86]:
def cor_ge2e_loss(embeds):
    """
    Computes the softmax loss according the section 2.1 of GE2E.
    
    :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 
    utterances_per_speaker, embedding_size)
    :return: the loss and the EER for this batch of embeddings.
    """
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    ce_loss = nn.CrossEntropyLoss()

    # Loss
    sim_matrix = similarity_matrix(embeds)
    temp = sim_matrix
    sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,  
                                        speakers_per_batch))
    ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
    target = torch.from_numpy(ground_truth).long()
    loss = ce_loss(sim_matrix, target)
    
    # EER (not backpropagated)
    with torch.no_grad():
        inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
        labels = np.array([inv_argmax(i) for i in ground_truth])
        preds = sim_matrix.detach().cpu().numpy()

        # Snippet from https://yangcha.github.io/EER-ROC/
        fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        
    return temp

In [103]:
embeds.size()[1]

torch.Size([5, 10, 10])

In [108]:
def joon_ge2e_loss(x):
    
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    w = nn.Parameter(torch.tensor(10.0))
    b = nn.Parameter(torch.tensor(-5.0))
    criterion  = nn.CrossEntropyLoss()

    assert x.size()[1] >= 2

    gsize = x.size()[1] # utterances per speaker
    centroids = torch.mean(x, 1) # inc centroids without norm
    stepsize = x.size()[0] # speakers per batch

    cos_sim_matrix = []

    for ii in range(0,gsize): 
        idx = [*range(0,gsize)]
        idx.remove(ii)
        exc_centroids = torch.mean(x[:,idx,:], 1)
        cos_sim_diag    = F.cosine_similarity(x[:,ii,:],exc_centroids)
        cos_sim         = F.cosine_similarity(x[:,ii,:].unsqueeze(-1),centroids.unsqueeze(-1).transpose(0,2))
        cos_sim[range(0,stepsize),range(0,stepsize)] = cos_sim_diag
        cos_sim_matrix.append(torch.clamp(cos_sim,1e-6))
    temp = cos_sim_matrix
    cos_sim_matrix = torch.stack(cos_sim_matrix,dim=1)
    
    torch.clamp(w, 1e-6)
    cos_sim_matrix = torch.tensor(cos_sim_matrix)
    cos_sim_matrix = cos_sim_matrix * w + b
    

    ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
    target = torch.from_numpy(ground_truth).long()

    label   = torch.from_numpy(np.asarray(range(0,stepsize)))
    nloss = criterion(cos_sim_matrix.view(-1,stepsize), target)

    # with torch.no_grad():
    #     inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
    #     labels = np.array([inv_argmax(i) for i in ground_truth])
    #     preds = sim_matrix.detach().cpu().numpy()

    #     # Snippet from https://yangcha.github.io/EER-ROC/
    #     fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
    #     eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

    return label, target

In [118]:
out_anchor = torch.mean(embeds[:,1:,:], 1)
out_pos = embeds[:,0,:]
step = out_anchor.size()[0]
cos_sim_matrix  = F.cosine_similarity(out_pos.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
cos_sim_matrix

tensor([[0.8616, 0.8481, 0.8584, 0.8956, 0.8388],
        [0.8352, 0.8375, 0.8238, 0.8065, 0.8069],
        [0.8888, 0.8414, 0.8727, 0.8883, 0.9196],
        [0.7981, 0.7690, 0.7929, 0.8008, 0.7798],
        [0.8443, 0.8539, 0.8443, 0.8443, 0.8568]])

In [121]:
out_anchor.shape, out_pos.shape

(torch.Size([5, 20]), torch.Size([5, 20]))

In [109]:
joon_ge2e_loss(embeds)

  cos_sim_matrix = torch.tensor(cos_sim_matrix)


(tensor([0, 1, 2, 3, 4], dtype=torch.int32),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4]))

In [None]:
# 1.6649

In [39]:
gemat = similarity_matrix(embeds)

# Angular Proto Loss

In [41]:
def angular_proto_loss(embeds, init_w = 10.0, init_b = -5.0, label=None):
    """Computes angular variant of prototypical loss.

    Args:
        embeds : 
        label (_type_, optional): _description_. Defaults to None.
    """
    assert embeds.size()[1] >= 2
    ce_loss = nn.CrossEntropyLoss()
    w = nn.Parameter(torch.tensor(init_w))
    b = nn.Parameter(torch.tensor(init_b))

    out_anchor = torch.mean(embeds[:,1:,:], 1)
    out_pos = embeds[:,0,:]
    step = out_anchor.size()[0]

    sim_matrix = F.cosine_similarity(out_pos.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
    torch.clamp(w, 1e-6)
    sim_matrix = sim_matrix * w + b

    return sim_matrix

In [44]:
sim_matrix = angular_proto_loss(embeds)
sim_matrix

tensor([[3.4136, 4.2625, 3.2234, 4.1107, 4.1484],
        [3.7012, 4.0986, 3.7119, 3.8583, 3.7990],
        [3.0190, 2.9903, 3.0917, 2.7785, 2.8525],
        [3.6497, 3.8180, 3.3977, 3.7257, 3.7587],
        [3.5589, 3.3426, 3.7194, 3.9845, 3.8502]], grad_fn=<AddBackward0>)

In [40]:
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)