In [16]:
import numpy
import torch
from pytorch_metric_learning.losses import BaseMetricLossFunction

In [17]:
class deep_clustering_loss(BaseMetricLossFunction):
    """Deep clustering (affinity) loss, weighted by a time-frequency mask.

    For every batch item this computes
    ``||V^T V||_F^2 + ||Y^T Y||_F^2 - 2 * ||V^T Y||_F^2``
    where ``V`` is the masked estimated embedding and ``Y`` the masked one-hot
    source assignment of each TF bin. The result is divided by the number of
    active bins (the sum of the mask).

    NOTE(review): ``compute_loss`` is overridden with a signature different
    from what ``BaseMetricLossFunction.forward`` (pytorch-metric-learning)
    presumably passes, and it returns a tensor rather than a loss dict, so
    calling the instance directly (``loss(...)``) likely fails. Call
    ``compute_loss`` explicitly — TODO confirm against the installed version.
    """

    def compute_loss(self, embedding, tgt_index, binary_mask=None):
        """Return the per-item deep clustering loss.

        Args:
            embedding: Estimated embeddings. The einsums below index it as
                ``(batch, bins * frames, emb_dim)``.
            tgt_index: Integer tensor of shape ``(batch, bins, frames)``
                holding the (non-negative) dominant source index of each TF bin.
            binary_mask: Optional VAD / weight mask, bool or float. A 3-D mask
                (e.g. ``(batch, bins, frames)``) is reshaped to
                ``(batch, bins * frames, 1)``. Defaults to all ones.

        Returns:
            Tensor of shape ``(batch,)``: each item's loss divided by its
            number of active bins. An item whose mask sums to zero gives NaN.
        """
        batch, bins, frames = tgt_index.shape
        n_bins = bins * frames
        device = tgt_index.device

        # One column per label *value*: max + 1 (rather than the number of
        # distinct labels) keeps every label a valid scatter_ index even when
        # some source is absent from the batch. Extra all-zero columns leave
        # every Frobenius norm below unchanged.
        spk_cnt = int(tgt_index.max()) + 1 if tgt_index.numel() > 0 else 0

        if binary_mask is None:
            binary_mask = torch.ones(batch, n_bins, 1, device=device)
        # If boolean mask, make it float.
        binary_mask = binary_mask.float()
        if binary_mask.dim() == 3:
            # reshape (unlike view) also accepts non-contiguous masks.
            binary_mask = binary_mask.reshape(batch, n_bins, 1)
        binary_mask = binary_mask.to(device)

        # Fill in one-hot vector for each TF bin
        tgt_embedding = torch.zeros(batch, n_bins, spk_cnt, device=device)
        tgt_embedding.scatter_(2, tgt_index.reshape(batch, n_bins, 1), 1)

        # Compute VAD-weighted DC loss
        embedding = embedding * binary_mask
        # Align dtypes: einsum does not mix e.g. float32 with float64.
        tgt_embedding = (tgt_embedding * binary_mask).to(embedding.dtype)
        est_proj = torch.einsum("ijk,ijl->ikl", embedding, embedding)
        true_proj = torch.einsum("ijk,ijl->ikl", tgt_embedding, tgt_embedding)
        true_est_proj = torch.einsum("ijk,ijl->ikl", embedding, tgt_embedding)
        # Equation (1) in [1] — NOTE(review): reference [1] is not listed in
        # this notebook; presumably the original deep clustering paper.
        cost = batch_matrix_norm(est_proj) + batch_matrix_norm(true_proj)
        cost = cost - 2 * batch_matrix_norm(true_est_proj)
        # Divide by number of active bins, for each element in batch
        return cost / torch.sum(binary_mask, dim=[1, 2])


def batch_matrix_norm(matrix, norm_order=2):
    """Return each batch item's p-norm raised to the power p.

    The norm is taken jointly over every dimension except the first one, so
    for ``norm_order=2`` and a 3-D input this is the squared Frobenius norm of
    each matrix in the batch.

    Args:
        matrix: Tensor whose first dimension indexes the batch.
        norm_order: Order ``p`` of the norm (default 2).

    Returns:
        For an input with at least two dimensions, a tensor of shape
        ``(matrix.shape[0],)``.
    """
    reduce_dims = [axis for axis in range(1, matrix.dim())]
    per_item_norm = matrix.norm(p=norm_order, dim=reduce_dims)
    return per_item_norm.pow(norm_order)

In [18]:
# Instance of the deep clustering loss defined above.
loss = deep_clustering_loss()