In [13]:
import torch
import torch.distributed as dist
import torch.nn.functional as F


def masked_mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    normalize_targets: bool = False,
):
    """MSE loss on masked patches

    Args:
        pred: B x num_patches x D tensor of predict patches
        target: B x num_patches x D tensor of target patch values
        mask: B x num_patches binary mask with masked patches marked with 1

    Return:
        loss: Masked mean square error loss
    """
    # Normalize target pixel values
    if normalize_targets:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.0e-6) ** 0.5

    # Calculate MSE loss
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)  # Per patch loss
    loss = (loss * mask).sum() / mask.sum()  # Mean of masked patches

    return loss


"""
Modified from: 
https://github.com/vturrisi/solo-learn/blob/main/solo/losses/simclr.py
https://github.com/vturrisi/solo-learn/blob/main/solo/utils/misc.py
"""


def info_nce_loss(z: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Computes SimCLR's loss given batch of projected features z
    from different views, a positive boolean mask of all positives and
    a negative boolean mask of all negatives.

    Args:
        z (torch.Tensor): (2*B) x D tensor containing features from the views.

    Return:
        torch.Tensor: SimCLR loss.
    """

    z = F.normalize(z, dim=-1)
    gathered_z = gather(z)
    
    print(gathered_z.shape)

    sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature)
    

    indexes = torch.arange(z.size(0) // 2, device=z.device).repeat(2)
    gathered_indexes = gather(indexes)
    print(gathered_indexes)

    indexes = indexes.unsqueeze(0)
    gathered_indexes = gathered_indexes.unsqueeze(0)

    # positives
    pos_mask = indexes.t() == gathered_indexes
    pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
    
    print(pos_mask.int())

    # negatives
    neg_mask = indexes.t() != gathered_indexes
    
    print(neg_mask.int())

    pos = torch.sum(sim * pos_mask, 1)
    neg = torch.sum(sim * neg_mask, 1)
    loss = -(torch.mean(torch.log(pos / (pos + neg))))
    return loss


def get_rank():
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


class GatherLayer(torch.autograd.Function):
    """
    Gathers tensors from all process and supports backward propagation
    for the gradients across processes.
    """

    @staticmethod
    def forward(ctx, x):
        if dist.is_available() and dist.is_initialized():
            output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
            dist.all_gather(output, x)
        else:
            output = [x]
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        if dist.is_available() and dist.is_initialized():
            all_gradients = torch.stack(grads)
            dist.all_reduce(all_gradients)
            grad_out = all_gradients[get_rank()]
        else:
            grad_out = grads[0]
        return grad_out


def gather(X, dim=0):
    """Gathers tensors from all processes, supporting backward propagation."""
    return torch.cat(GatherLayer.apply(X), dim=dim)

In [14]:
embeddings = torch.randn(4, 128)
info_nce_loss(embeddings)

torch.Size([4, 128])
tensor([0, 1, 0, 1])
tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0]], dtype=torch.int32)
tensor([[0, 1, 0, 1],
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [1, 0, 1, 0]], dtype=torch.int32)


tensor(1.0946)

In [18]:
def info_nce_loss(features):

        labels = torch.cat([torch.arange(4) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        

        features = F.normalize(features, dim=1)

        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape
        
        print(similarity_matrix.shape)

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape
        
        print(similarity_matrix.shape)

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        
        print(positives)

        # select only the negatives the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
        
        print(negatives)

        logits = torch.cat([positives, negatives], dim=1)
        
        print(logits)
        
        labels = torch.zeros(logits.shape[0], dtype=torch.long)
        
        print(labels)

        logits = logits / 0.1
        
        return torch.nn.functional.cross_entropy(logits, labels)

In [20]:
info_nce_loss(torch.randn(8, 128))

torch.Size([8, 8])


IndexError: The shape of the mask [16, 16] at index 0 does not match the shape of the indexed tensor [8, 8] at index 0