In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import audioread

In [2]:
speakers_per_batch, utterances_per_speaker = 5, 10
embeds = torch.rand(size=(speakers_per_batch, utterances_per_speaker, 20))

In [21]:
similarity_weight = nn.Parameter(torch.tensor([10.]))
similarity_bias = nn.Parameter(torch.tensor([-5.]))

speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
# create inc centroid for each speaker in batch by finding mean of utterances AT EACH EMBEDDING IDX
# then norm
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

# Exclusive centroids (1 per utterance)
# (mean of utterance at each embedding idx) - (each utterance)
centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
centroids_excl /= (utterances_per_speaker - 1)
centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

# Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
# product of these vectors (which is just an element-wise multiplication reduced by a sum).
# We vectorize the computation for efficiency.
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                            speakers_per_batch)
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
for j in range(speakers_per_batch):
    # each row in mask_matrix represents 1 speaker in batch
    mask = np.where(mask_matrix[j])[0] # indexes of 1s in mask_matrix
    # compute cosine sim via dot product
    sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) # otherwise
    sim_matrix[mask, :, j] *= similarity_weight
    sim_matrix[j, :, j] = (embeds[j] * centroids_incl[j] + similarity_bias).sum(dim=1) # i = k
    sim_matrix[j, :, j] *= similarity_weight

In [38]:
ce = torch.nn.CrossEntropyLoss()
x1 = torch.tensor([[-0.1, -0.2, -0.3, -0.4], [-0.1, -0.2, -0.3, -0.4]]).float()
x2 = torch.tensor([0,1])
ce(x1,x2)

tensor(1.2925)

In [None]:
torch.log(torch.exp())

In [22]:
sim_matrix[0]

tensor([[-981.6509,   17.1950,   17.3330,   17.0617,   17.2566],
        [-978.8950,   21.1419,   20.7072,   21.3241,   19.3361],
        [-979.3747,   20.3317,   20.5422,   20.6956,   20.0121],
        [-977.5345,   20.8301,   21.5688,   21.6655,   22.4786],
        [-975.2383,   23.2504,   24.1748,   24.3249,   23.6548],
        [-979.3206,   19.6952,   18.3881,   20.2115,   19.3078],
        [-976.4357,   23.6502,   21.1381,   22.7274,   22.1186],
        [-972.7823,   25.9305,   25.3570,   26.2298,   25.9515],
        [-974.3808,   25.1884,   25.2018,   24.9270,   25.1980],
        [-978.8405,   19.0207,   20.3922,   20.9197,   19.4741]],
       grad_fn=<SelectBackward0>)

In [18]:
sim_matrix[0]

tensor([[1.8349, 1.7195, 1.7333, 1.7062, 1.7257],
        [2.1105, 2.1142, 2.0707, 2.1324, 1.9336],
        [2.0625, 2.0332, 2.0542, 2.0696, 2.0012],
        [2.2465, 2.0830, 2.1569, 2.1665, 2.2479],
        [2.4762, 2.3250, 2.4175, 2.4325, 2.3655],
        [2.0679, 1.9695, 1.8388, 2.0211, 1.9308],
        [2.3564, 2.3650, 2.1138, 2.2727, 2.2119],
        [2.7218, 2.5930, 2.5357, 2.6230, 2.5952],
        [2.5619, 2.5188, 2.5202, 2.4927, 2.5198],
        [2.1160, 1.9021, 2.0392, 2.0920, 1.9474]])

In [9]:
sim_matrix[0]

tensor([[1.7839, 1.7195, 1.7333, 1.7062, 1.7257],
        [2.0339, 2.1142, 2.0707, 2.1324, 1.9336],
        [2.0011, 2.0332, 2.0542, 2.0696, 2.0012],
        [2.1588, 2.0830, 2.1569, 2.1665, 2.2479],
        [2.4046, 2.3250, 2.4175, 2.4325, 2.3655],
        [2.0107, 1.9695, 1.8388, 2.0211, 1.9308],
        [2.2771, 2.3650, 2.1138, 2.2727, 2.2119],
        [2.6675, 2.5930, 2.5357, 2.6230, 2.5952],
        [2.4705, 2.5188, 2.5202, 2.4927, 2.5198],
        [2.0493, 1.9021, 2.0392, 2.0920, 1.9474]])

In [8]:
sim_matrix[[1, 2, 3, 4], :, 0]

tensor([[2.7666, 2.4374, 1.6308, 2.0274, 1.9173, 2.1072, 1.9042, 1.7557, 2.0510,
         2.2377],
        [2.5159, 2.4360, 2.3751, 2.4724, 1.6673, 1.7300, 2.0701, 2.3316, 2.0284,
         2.5851],
        [2.2612, 2.1291, 3.0185, 2.3079, 2.6390, 2.1932, 1.8777, 2.1220, 2.2076,
         1.8930],
        [2.1561, 1.8996, 2.0605, 2.7901, 2.6472, 2.0333, 2.0961, 2.5265, 2.5153,
         1.9770]])

# GE2E Sim Mat

* why is joon's sim mat different to corj's?
    * embeds is size (speakers/batch, utts/speaker, emb_size) (5, 10, 20) going in
    * corj sim mat out (5, 10, 5) 
    * joon sim mat out (5, 5, 5)

* how can I adjust joon's to match corj's?

* use this to check their ge2e losses

* use this to check their proto angular loss

In [181]:
def similarity_matrix(embeds):
    """
    Computes the similarity matrix according the section 2.1 of GE2E.

    :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 
    utterances_per_speaker, embedding_size)
    :return: the similarity matrix as a tensor of shape (speakers_per_batch,
    utterances_per_speaker, speakers_per_batch)
    """
    similarity_weight = nn.Parameter(torch.tensor([10.]))
    similarity_bias = nn.Parameter(torch.tensor([-5.]))

    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    
    # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
    # create inc centroid for each speaker in batch by finding mean of utterances AT EACH EMBEDDING IDX
    # then norm
    centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
    centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

    # Exclusive centroids (1 per utterance)
    # (mean of utterance at each embedding idx) - (each utterance)
    centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
    centroids_excl /= (utterances_per_speaker - 1)
    centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

    # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
    # product of these vectors (which is just an element-wise multiplication reduced by a sum).
    # We vectorize the computation for efficiency.
    sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                speakers_per_batch)
    mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
    for j in range(speakers_per_batch):
        # each row in mask_matrix represents 1 speaker in batch
        mask = np.where(mask_matrix[j])[0] # indexes of 1s in mask_matrix
        # compute cosine sim via dot product
        sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
        sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        #sim_matrix[:, :, j] = (embeds * centroids_incl[j]).sum(dim=2)

       
    sim_matrix = sim_matrix * similarity_weight + similarity_bias
    return sim_matrix # (5, 1, 20) * (5, 10, 20) -> (5, 10, 5)

In [175]:
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
centroids_incl.shape

torch.Size([5, 1, 20])

In [176]:
embeds.shape

torch.Size([5, 10, 20])

In [154]:
similarity_matrix(embeds)[0] # corj sim mat

torch.Size([5, 1, 20])

In [59]:
similarity_matrix(embeds)[0] # corj without excl

tensor([[16.6742, 16.2385, 16.0487, 16.3170, 16.7079],
        [21.3579, 20.6582, 21.1883, 20.3702, 20.4534],
        [22.1896, 22.0102, 21.6861, 20.5429, 22.0982],
        [18.3206, 17.0490, 16.0587, 16.7957, 16.6575],
        [20.8070, 19.3250, 19.8346, 20.5987, 19.8000],
        [18.2206, 17.1984, 17.2436, 17.2208, 16.6799],
        [24.2771, 23.3234, 23.0234, 23.6963, 24.1507],
        [21.3176, 19.9951, 21.1612, 20.0701, 20.4435],
        [22.0853, 21.8682, 21.5481, 21.7785, 22.6327],
        [16.3744, 16.5684, 16.7081, 17.0269, 16.4539]],
       grad_fn=<SelectBackward0>)

# GE2E Loss

In [210]:
def cor_ge2e_loss(embeds):
    """
    Computes the softmax loss according the section 2.1 of GE2E.
    
    :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 
    utterances_per_speaker, embedding_size)
    :return: the loss and the EER for this batch of embeddings.
    """
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    ce_loss = nn.CrossEntropyLoss()

    # Loss
    sim_matrix = similarity_matrix(embeds)
    sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,  
                                        speakers_per_batch))
    temp = sim_matrix
    ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
    target = torch.from_numpy(ground_truth).long()
    loss = ce_loss(sim_matrix, target)
    
    # EER (not backpropagated)
    with torch.no_grad():
        inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
        labels = np.array([inv_argmax(i) for i in ground_truth])
        preds = sim_matrix.detach().cpu().numpy()

        # Snippet from https://yangcha.github.io/EER-ROC/
        fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        
    return loss, eer

In [13]:
def joon_ge2e_loss(x):
    
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    w = nn.Parameter(torch.tensor(10.0))
    b = nn.Parameter(torch.tensor(-5.0))
    criterion  = nn.CrossEntropyLoss()

    assert x.size()[1] >= 2

    gsize = x.size()[1] # utterances per speaker
    centroids = torch.mean(x, 1) # inc centroids without norm
    stepsize = x.size()[0] # speakers per batch

    cos_sim_matrix = []

    for ii in range(0,gsize): 
        idx = [*range(0,gsize)]
        idx.remove(ii)
        exc_centroids = torch.mean(x[:,idx,:], 1)
        cos_sim_diag    = F.cosine_similarity(x[:,ii,:],exc_centroids)
        cos_sim         = F.cosine_similarity(x[:,ii,:].unsqueeze(-1),centroids.unsqueeze(-1).transpose(0,2))
        cos_sim[range(0,stepsize),range(0,stepsize)] = cos_sim_diag
        cos_sim_matrix.append(torch.clamp(cos_sim,1e-6))
    temp = cos_sim_matrix
    cos_sim_matrix = torch.stack(cos_sim_matrix,dim=1)
    
    torch.clamp(w, 1e-6)
    cos_sim_matrix = torch.tensor(cos_sim_matrix)
    cos_sim_matrix = cos_sim_matrix * w + b
    

    ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
    target = torch.from_numpy(ground_truth).long()

    label   = torch.from_numpy(np.asarray(range(0,stepsize)))
    nloss = criterion(cos_sim_matrix.view(-1,stepsize), target)

    # with torch.no_grad():
    #     inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
    #     labels = np.array([inv_argmax(i) for i in ground_truth])
    #     preds = sim_matrix.detach().cpu().numpy()

    #     # Snippet from https://yangcha.github.io/EER-ROC/
    #     fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
    #     eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

    return temp

# Angular Proto Loss

In [None]:
# embed (5, 10, 20)
# cent (5, 1, 20)

In [199]:
def ap_similarity_matrix(embeds):

    w = nn.Parameter(torch.tensor(10.0))
    b = nn.Parameter(torch.tensor(-5.0))
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

    centroids = torch.mean(embeds[:,1:,:], dim=1, keepdim=True) # eq 6
    centroids /= (utterances_per_speaker - 1) # eq 6
    centroids = centroids.clone()/(torch.norm(centroids, dim=2, keepdim=True) + 1e-5) # normalise vector

    query = embeds[:,0,:].unsqueeze(1) # should already be normalised from forward pass
    #query = query.clone()/(torch.norm(query, dim=2, keepdim=True) + 1e-5) # normalise vector

    # compute sed btwn every query and every speaker emb centroid
    sim_matrix = torch.zeros(speakers_per_batch, 1, speakers_per_batch)
    for i in range(speakers_per_batch):
        sim_matrix[:, :, i] = (query * centroids[i]).sum(dim=2)
    
    sim_matrix = w*sim_matrix + b
    return sim_matrix

In [206]:
def angular_proto_loss(embeds):

    assert embeds.size()[1] >= 2

    criterion = nn.CrossEntropyLoss()

    # # (5,1,20)
    # out_anchor      = torch.mean(x[:,1:,:], dim=1, keepdim=True) # mean each utterance for each embed idx except first utt
    # out_positive    = x[:,0,:] # first embed
    # stepsize        = out_anchor.size()[0]

    #############
    # >> JOON <<
    # cos_sim_matrix  = F.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
    # torch.clamp(w, 1e-6)
    # cos_sim_matrix = cos_sim_matrix * w + b
    # label   = torch.from_numpy(np.asarray(range(0,stepsize))).long()
    #############

    ##############
    # >> CORJ <<
    ap_sm = ap_similarity_matrix(embeds)
    ap_sm = ap_sm.reshape((speakers_per_batch, speakers_per_batch))
    ground_truth = np.arange(speakers_per_batch)
    target = torch.from_numpy(ground_truth).long()
    ##############

    loss = criterion(ap_sm, target)

    # EER
    with torch.no_grad():
        inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
        labels = np.array([inv_argmax(i) for i in ground_truth])
        preds = ap_sm.detach().cpu().numpy()

        # Snippet from https://yangcha.github.io/EER-ROC/
        fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())           
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

    return loss, eer

In [207]:
angular_proto_loss(embeds)

(tensor(1.6395, grad_fn=<NllLossBackward0>), 0.4500000000011706)

In [211]:
cor_ge2e_loss(embeds)

(tensor(1.7346, grad_fn=<NllLossBackward0>), 0.49999999999997546)

In [182]:
similarity_matrix(embeds).shape

torch.Size([5, 10, 5])

In [None]:
# (5, 1, 20) * (5, 10, 20) -> (5, 10, 5)
# (5, 1, 20) * (5,  1, 20) -> (5,  1, 5)

In [158]:
angular_proto_loss(embeds) # joon sim mat

tensor([[3.6840, 3.6257, 3.5540, 3.6594, 3.9523],
        [3.3781, 3.5920, 3.2014, 3.6058, 3.6591],
        [2.9513, 3.8061, 3.3136, 3.4410, 3.1781],
        [3.6365, 3.7020, 3.6398, 3.7833, 4.0437],
        [3.4237, 4.0259, 3.9089, 3.9259, 3.9273]], grad_fn=<AddBackward0>)

In [73]:
angular_proto_loss(embeds) # corj sim mat

tensor(1.2143, grad_fn=<NllLossBackward0>)

# Proto Loss

In [None]:
# embed (5, 10, 20)
# cent (5, 1, 20)

In [232]:
def proto_loss(embeds, label=None):
    """Computes angular variant of prototypical loss.

    Args:
        embeds : 
        label (_type_, optional): _description_. Defaults to None.
    """
    assert embeds.size()[1] >= 2
    criterion = nn.CrossEntropyLoss()
    pwd = nn.PairwiseDistance(p=2)

    out_anchor = torch.mean(embeds[:,1:,:], dim=1,) # eq 6 - find centroids
    out_positive = embeds[:,0,:] # query - 0th utt embed array from each speaker
    step = out_anchor.size()[0]

    output  = -1 * (pwd(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))**2) # eq 7 
    label   = torch.from_numpy(np.asarray(range(0,step))).long()
    nloss   = criterion(output, label)

    return nloss

In [230]:
def proto_sed(embeds):
    # squared euclidian distance
    speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
    pwd = nn.PairwiseDistance(p=2)

    centroids = torch.mean(embeds[:,1:,:], dim=1, keepdim=True) # dont need to norm for sed
    query = embeds[:,0,:].unsqueeze(1)

    # compute sed btwn every query and every speaker emb centroid
    out = torch.zeros(speakers_per_batch, 1, utterances_per_speaker)
    for i in range(speakers_per_batch):
        out[:, :, i] = -1* (pwd(centroids[i], query)**2)

    return out

In [231]:
proto_sed(embeds)

tensor([[[-1.6533, -1.6840, -1.5919, -1.6411, -1.2058,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[-1.9505, -1.6500, -1.6480, -1.6320, -1.3953,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[-2.4848, -1.4254, -1.7047, -1.8469, -1.9867,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[-2.1348, -2.0365, -2.2512, -1.9185, -1.5980,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[-2.2121, -1.3681, -1.5904, -1.5039, -1.4993,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]]])

In [233]:
proto_loss(embeds)

tensor(3.1325)