In [1]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import accuracy
from icecream import ic

In [7]:
# Stand-in model scores: 18 rows x 5 columns, sampled uniformly from [0, 1).
output = torch.rand(18, 5)
output

tensor([[0.4140, 0.0044, 0.4308, 0.9380, 0.4694],
        [0.9582, 0.6054, 0.6241, 0.4901, 0.3987],
        [0.4205, 0.8947, 0.3452, 0.4983, 0.3301],
        [0.8823, 0.7114, 0.9835, 0.7917, 0.4293],
        [0.3558, 0.2678, 0.7717, 0.9408, 0.4591],
        [0.0773, 0.1100, 0.2650, 0.9001, 0.0201],
        [0.0532, 0.6323, 0.2431, 0.2888, 0.1640],
        [0.6956, 0.2009, 0.6033, 0.3993, 0.5911],
        [0.8188, 0.8890, 0.7809, 0.6861, 0.2592],
        [0.5465, 0.8402, 0.6783, 0.3311, 0.7982],
        [0.7412, 0.5911, 0.1430, 0.7581, 0.0299],
        [0.2052, 0.9731, 0.4294, 0.3211, 0.8431],
        [0.5247, 0.7191, 0.6029, 0.2741, 0.2892],
        [0.3945, 0.1238, 0.6583, 0.7615, 0.4105],
        [0.3897, 0.0051, 0.0924, 0.7370, 0.7422],
        [0.3087, 0.0970, 0.7310, 0.1690, 0.3536],
        [0.9116, 0.3506, 0.8710, 0.2031, 0.7993],
        [0.5961, 0.9943, 0.8310, 0.8022, 0.2497]])

In [15]:
# Random integer labels drawn from {0, 1, 2, 3, 4}; note the leading dimension of size 1.
trg = torch.randint(low=0, high=5, size=(1, 18))
trg

tensor([[2, 3, 4, 0, 4, 1, 4, 4, 1, 4, 2, 4, 4, 3, 0, 1, 3, 4]])

In [16]:
# Predicted class for each row: index of the largest score along dim 1.
predictions = torch.argmax(output, dim=1)
predictions

tensor([3, 0, 1, 2, 3, 3, 1, 0, 1, 1, 3, 1, 1, 3, 4, 2, 0, 1])

In [19]:
# Number of positions where prediction and label agree.
# `predictions` is 1-D and `trg` has a leading dim of size 1, so the comparison broadcasts.
(predictions == trg).sum()

tensor(2)

In [None]:
def _get_loss_acc(seq, seq_len, mask, trg, reduction):
    """Cross-entropy loss and accuracy over the frames of one sequence selected by ``mask``.

    Args:
        seq: scores of one sequence, indexed as (frame, class).
        seq_len: number of leading frames to keep; everything after it is dropped.
        mask: boolean tensor over frames; only frames where it is True are scored.
        trg: per-frame class indices.
        reduction: forwarded to ``F.cross_entropy`` (e.g. ``'mean'`` or ``'sum'``).

    Returns:
        Tuple ``(loss, acc, n)`` where ``n`` is the number of selected frames,
        ``loss`` is the cross-entropy over them and ``acc`` is their accuracy
        multiplied by ``n`` (i.e. a count of correct frames, so the caller can
        pool it across sequences). When no frame is selected, ``(0, 0, 0)`` is
        returned (the first two as zero tensors).
    """
    # keep only the first seq_len frames
    new_mask = mask[:seq_len]
    # boolean row indexing keeps the class dimension, so no unsqueeze / view round-trip is needed
    new_seq = seq[:seq_len][new_mask]
    new_trg = trg[:seq_len][new_mask]
    n_selected = new_seq.shape[0]

    if n_selected == 0:
        # a mean-reduced cross entropy over zero frames is NaN and would poison the
        # running sums of the caller; report an empty contribution instead
        return new_seq.new_zeros(()), new_seq.new_zeros(()), 0

    # was `self.reduction`: there is no `self` in a free function, use the argument
    loss = F.cross_entropy(new_seq, new_trg, reduction=reduction)
    # NOTE(review): recent torchmetrics releases expect a `task` argument here -
    # confirm against the installed version.
    # multiply acc by the number of elements in the sequence
    acc = accuracy(new_seq, new_trg) * n_selected
    return loss, acc, n_selected


def _compute_loss_acc(similarity_scores, frames_cnt, targets, batch_mask_indices, batch_masks, reduction):
    """Aggregate masked / unmasked cross-entropy loss and accuracy over all clustering models.

    Args:
        similarity_scores: mapping ``k -> scores``; scores.shape = [batch, n_frames, k].
        frames_cnt: number of frames to score per sequence, shape [batch].
        targets: targets.shape = [n_clusterings, batch, n_frames]; iterated in
            lockstep with ``similarity_scores``.
        batch_mask_indices: list of index tensors where the span mask was applied.
            Not used by this function; kept so existing callers keep working.
        batch_masks: one boolean frame mask per sequence. Assumed to be True on
            frames where the span mask was applied, which is how the masking cell
            of this notebook builds it - confirm for other callers.
        reduction: forwarded to ``_get_loss_acc`` (and from there to cross entropy).

    Returns:
        ``(mask_loss, unmask_loss, mask_acc, unmask_acc, total_acc)``. The losses
        are sums over clustering models of the per-batch average; the accuracies
        are averaged over clustering models.
    """
    # cross_entropy recap
    #   The input is expected to contain raw, unnormalized scores for each class.
    #   input has to be a Tensor of size either (minibatch, C).
    total_mask_loss, total_unmask_loss = 0, 0
    total_mask_acc, total_unmask_acc, total_acc = 0, 0, 0

    # iterate over different clustering models
    for scores, k_target in zip(similarity_scores.values(), targets):
        # scores.shape = [batch, n_frames, k]

        # metrics for a specific clustering model k
        clustering_mask_loss, clustering_unmask_loss = 0, 0
        clustering_mask_acc, clustering_unmask_acc, clustering_total_acc = 0, 0, 0
        cnt_mask, cnt_unmask, cnt_total = 0, 0, 0

        # iterate over sequences in the batch
        for seq_score, target, seq_len, mask in zip(scores, k_target, frames_cnt, batch_masks):
            # seq_score.shape = [n_frames, k]
            # target.shape = [n_frames]
            # mask.shape = [n_frames]; the number of True entries differs per sequence,
            # that is why each sequence is processed separately
            # seq_len is an int

            # cross entropy loss and acc over frames with mask.
            # _get_loss_acc scores the frames where its mask argument is True, so the
            # masked metrics take `mask` itself (the two calls used to be swapped).
            mask_loss, mask_acc, mask_size = _get_loss_acc(seq_score, seq_len, mask, target, reduction)
            clustering_mask_loss += mask_loss
            clustering_mask_acc += mask_acc
            cnt_mask += mask_size

            # cross entropy loss and acc over frames without mask
            unmask_loss, unmask_acc, unmask_size = _get_loss_acc(seq_score, seq_len, ~mask, target, reduction)
            clustering_unmask_loss += unmask_loss
            clustering_unmask_acc += unmask_acc
            cnt_unmask += unmask_size

            # total accuracy
            clustering_total_acc += accuracy(seq_score[:seq_len], target[:seq_len]) * seq_len
            cnt_total += seq_len

        # average across batch
        total_mask_loss += clustering_mask_loss / scores.shape[0]
        total_unmask_loss += clustering_unmask_loss / scores.shape[0]

        # accuracies were accumulated as "accuracy * n_frames", so divide by the frame counts
        total_mask_acc += clustering_mask_acc / cnt_mask
        total_unmask_acc += clustering_unmask_acc / cnt_unmask
        total_acc += clustering_total_acc / cnt_total

    # total_{mask,unmask}_loss are sums of losses for different clustering models,
    # the accuracies are averaged over the clustering models
    n_clusterings = len(similarity_scores)
    return (
        total_mask_loss,
        total_unmask_loss,
        total_mask_acc / n_clusterings,
        total_unmask_acc / n_clusterings,
        total_acc / n_clusterings,
    )

In [5]:
# Toy batch of shape (2, 20, 3) filled with 0..119, so every row is recognisable by its values.
batch = torch.arange(120).reshape(2, 20, 3)
# One frame count per sequence in the batch.
frames_cnt = torch.tensor([20, 14])
ic(batch)
ic(frames_cnt)


ic| batch: tensor([[[  0,   1,   2],
                    [  3,   4,   5],
                    [  6,   7,   8],
                    [  9,  10,  11],
                    [ 12,  13,  14],
                    [ 15,  16,  17],
                    [ 18,  19,  20],
                    [ 21,  22,  23],
                    [ 24,  25,  26],
                    [ 27,  28,  29],
                    [ 30,  31,  32],
                    [ 33,  34,  35],
                    [ 36,  37,  38],
                    [ 39,  40,  41],
                    [ 42,  43,  44],
                    [ 45,  46,  47],
                    [ 48,  49,  50],
                    [ 51,  52,  53],
                    [ 54,  55,  56],
                    [ 57,  58,  59]],
           
                   [[ 60,  61,  62],
                    [ 63,  64,  65],
                    [ 66,  67,  68],
                    [ 69,  70,  71],
                    [ 72,  73,  74],
                    [ 75,  76,  77],
                    [ 78,

tensor([20, 14])

In [47]:
# Two sequences of three 3-dimensional frames; float literals give the default float dtype.
batch = torch.tensor([
    [[1.0, 1.0, 1.0],
     [0.0, 1.0, 1.0],
     [0.0, 0.0, 1.0]],

    [[1.0, 1.0, 1.0],
     [1.0, 1.0, 0.0],
     [0.0, 0.0, 0.0]],
])

# A single group of five 3-dimensional reference vectors.
targets = torch.tensor([[
    [1.0, 1.0, 1.0],
    [0.0, 1.0, 1.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.5, 0.5, 0.5],
]])

ic(batch.shape)
ic(targets.shape)


ic| batch.shape: torch.Size([2, 3, 3])
ic| targets.shape: torch.Size([1, 5, 3])


torch.Size([1, 5, 3])

In [48]:
# For every group in `targets`: cosine similarity between each frame of `batch`
# and each reference vector of the group.
for t in targets:
    ic(t.shape)
    # batch gets a singleton axis before the feature dim and t two leading singleton axes,
    # so the two broadcast against each other and the feature dim is reduced away
    sim = F.cosine_similarity(batch.unsqueeze(2), t[None, None], dim=-1)
    ic(sim)
    ic(sim.shape)

ic| t.shape: torch.Size([5, 3])
ic| sim: tensor([[[1.0000, 0.8165, 0.5774, 0.8165, 1.0000],
                  [0.8165, 1.0000, 0.7071, 0.5000, 0.8165],
                  [0.5774, 0.7071, 1.0000, 0.0000, 0.5774]],
         
                 [[1.0000, 0.8165, 0.5774, 0.8165, 1.0000],
                  [0.8165, 0.5000, 0.0000, 1.0000, 0.8165],
                  [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
ic| sim.shape: torch.Size([2, 3, 5])


In [33]:
# Span-masking demo: in every sequence, mask runs of `l` consecutive frames.
l = 3                           # span length in frames
p = 0.2                         # int(length * p) span starts are drawn per sequence
mask_value = torch.tensor(-1)   # value written into masked frames
masks = []

# `i` is not used inside the loop body
for i, (features, length) in enumerate(zip(batch, frames_cnt)):
    masked_cnt = int(length * p)
    # distinct start positions taken from a random permutation of range(length - l), ascending
    mask_starts = torch.randperm(length - l)[:masked_cnt].sort().values
    # expand every start into l consecutive frame indices (start, start + 1, ..., start + l - 1);
    # overlapping spans simply produce duplicate indices
    index_mask = (mask_starts.unsqueeze(1) + torch.arange(l)).reshape(-1)

    ic(mask_starts)
    ic(index_mask)

    # boolean mask over all frames of the sequence, True on the frames to overwrite
    mask = torch.zeros(features.shape[0], dtype=torch.bool)
    mask[index_mask] = True
    ic(mask)
    ic(mask.shape)
    # unsqueeze so the per-frame mask broadcasts over the feature dimension
    masked_features = features.masked_fill(mask.unsqueeze(1), mask_value)
    ic(masked_features)
    masks.append(mask)

    ic('----------------------------------')

# one mask row per sequence; apply all of them to the whole batch at once (returns a new tensor)
masks = torch.stack(masks)
batch.masked_fill(masks.unsqueeze(-1), mask_value)

ic| mask_starts: tensor([ 3,  5,  9, 15])
ic| index_mask: tensor([ 3,  4,  5,  5,  6,  7,  9, 10, 11, 15, 16, 17])
ic| mask: tensor([False, False, False,  True,  True,  True,  True,  True, False,  True,
                   True,  True, False, False, False,  True,  True,  True, False, False])
ic| mask.shape: torch.Size([20])
ic| masked_features: tensor([[ 0,  1,  2],
                             [ 3,  4,  5],
                             [ 6,  7,  8],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [24, 25, 26],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [-1, -1, -1],
                             [36, 37, 38],
                             [39, 40, 41],
                             [42, 43, 44],
                             [

tensor([[[  0,   1,   2],
         [  3,   4,   5],
         [  6,   7,   8],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ 24,  25,  26],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ 36,  37,  38],
         [ 39,  40,  41],
         [ 42,  43,  44],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ 54,  55,  56],
         [ 57,  58,  59]],

        [[ 60,  61,  62],
         [ 63,  64,  65],
         [ 66,  67,  68],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ 78,  79,  80],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ -1,  -1,  -1],
         [ 90,  91,  92],
         [ 93,  94,  95],
         [ 96,  97,  98],
         [ 99, 100, 101],
         [102, 103, 104],
         [105, 106, 107],
         [108, 109, 110],
         [111, 112, 113],
         [