In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from itertools import permutations

# P: number of permutations of the speakers (C!)
# T: number of frames
# C: number of speakers (classes)
# B: mini-batch size

In [None]:
def pit_loss(pred, label):
    """
    Permutation-invariant training (PIT) binary cross entropy loss.

    Every ordering of the speaker (last) axis of ``label`` is scored against
    ``pred`` and the ordering with the smallest loss is selected.

    Args:
      pred:  (T,C)-shaped pre-activation values (logits)
      label: (T,C)-shaped labels in {0,1}; may have any dtype/device, it is
             converted to pred's dtype/device only for the loss computation

    Returns:
      min_loss: scalar tensor; the smallest per-permutation mean cross
                entropy multiplied by T (i.e. a per-frame sum, so a caller
                can normalise by a total frame count)
      perm_label: (T,C)-shaped permuted labels, where
                  perm_label[:, c] == label[:, sigma[c]]
      sigma: C-length tuple, the selected permutation of speaker indices

    Raises:
      ValueError: if pred and label are not 2-D tensors of the same shape
    """
    if pred.dim() != 2 or pred.shape != label.shape:
        raise ValueError(
            "pred and label must both be (T, C)-shaped, got {} and {}".format(
                tuple(pred.shape), tuple(label.shape)))

    T = len(label)
    C = label.shape[-1]

    # Every possible ordering of the C speakers (P = C! of them).
    perms = list(permutations(range(C)))
    P = len(perms)

    # Labels under each permutation. Plain column indexing keeps the labels'
    # own dtype and device, and avoids building a (P, T, C, C) one-hot matrix.
    allperm_label = torch.stack(
        [label[:, list(p)] for p in perms])             # (P, T, C)

    # Replicate pred P times; expand() returns a view, nothing is copied.
    broadcast_pred = pred.unsqueeze(0).expand(P, -1, -1)  # (P, T, C)

    # Binary cross entropy under each permutation. The target must match the
    # logits' dtype and device, so the cast happens here and nowhere else.
    losses = F.binary_cross_entropy_with_logits(
        broadcast_pred,
        allperm_label.to(device=pred.device, dtype=pred.dtype),
        reduction='none')                               # (P, T, C)
    mean_losses = losses.mean(dim=(1, 2))               # (P,)

    # Best permutation; scale the mean back up by the number of frames.
    min_index = int(torch.argmin(mean_losses))
    min_loss = mean_losses[min_index] * T

    # sigma - the optimal permutation
    sigma = perms[min_index]

    return min_loss, allperm_label[min_index], sigma


In [None]:
# Smoke test of pit_loss on a single random (T, C) utterance.
T = 20
C = 2

# Local, seeded generator: the printed numbers are reproducible on re-run and
# the global NumPy random state is left untouched.
rng = np.random.default_rng(0)
pred = rng.standard_normal((T, C))        # random logits
label = rng.integers(0, 2, size=(T, C))   # labels must be in {0,1}
pred = torch.from_numpy(pred).to(torch.float32)
label = torch.from_numpy(label).to(torch.float32)

print("pred.shape       = {}".format(pred.shape))
print("label.shape      = {}".format(label.shape))

min_loss, perm_label, sigma = pit_loss(pred, label)
print("min_loss         = {:.2f}".format(min_loss))
print("perm_label.shape = {}".format(perm_label.shape))
print("sigma            = {}".format(sigma))

In [None]:
def batch_pit_loss(ys, ts, ilens=None):
    """
    PIT loss over mini-batch.

    Args:
      ys: B-length list of predictions, one (T,C)-shaped tensor per utterance
      ts: B-length list of labels, one (T,C)-shaped tensor per utterance
      ilens: optional B-length sequence of frame counts (Python ints or an
             integer tensor/array); only the first ilens[b] frames of
             utterance b are scored. Defaults to the full length of each
             label.

    Returns:
      loss: scalar tensor, the sum of the per-utterance PIT losses divided
            by the total number of frames in ilens
      labels: B-length tuple of permuted labels
      sigmas: B-length tuple of permutations

    Raises:
      ValueError: if the mini-batch is empty
    """
    if ilens is None:
        ilens = [t.shape[0] for t in ts]
    # Plain Python ints, so that slicing and the frame total below do not
    # depend on the container or device the lengths arrived in.
    ilens = [int(ilen) for ilen in ilens]

    # One (loss, permuted label, permutation) triple per utterance.
    loss_w_labels_w_sigmas = [pit_loss(y[:ilen, :], t[:ilen, :])
                              for (y, t, ilen) in zip(ys, ts, ilens)]
    if not loss_w_labels_w_sigmas:
        raise ValueError("batch_pit_loss needs at least one utterance")

    losses, labels, sigmas = zip(*loss_w_labels_w_sigmas)

    # Each per-utterance loss is already scaled by its frame count, so the
    # sum divided by the total frame count is a frame-weighted batch mean.
    loss = torch.sum(torch.stack(losses))
    n_frames = sum(ilens)
    loss = loss / n_frames
    return loss, labels, sigmas

In [None]:
# Smoke test of batch_pit_loss on a random mini-batch of B utterances.
B = 64
T = 20
C = 2

# Local, seeded generator: the printed numbers are reproducible on re-run and
# the global NumPy random state is left untouched.
rng = np.random.default_rng(0)
pred = rng.standard_normal((B, T, C))        # random logits
label = rng.integers(0, 2, size=(B, T, C))   # labels must be in {0,1}
pred = torch.from_numpy(pred).to(torch.float32)
label = torch.from_numpy(label).to(torch.float32)

print("pred.shape       = {}".format(pred.shape))
print("label.shape      = {}".format(label.shape))

# Iterating a (B, T, C) tensor yields B (T, C) slices, one per utterance.
min_loss, labels, sigma = batch_pit_loss(pred, label)
print("min_loss         = {:.2f}".format(min_loss))
print("labels           = {} - {} x {}".format(type(labels), len(labels), labels[0].shape))
print("sigma            = {} - {} x {}".format(type(sigma), len(sigma), len(sigma[0])))