## Simple Example

In [90]:
import torch

# Pick the fastest available backend: CUDA first, then MPS, falling back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device

'mps'

### Prepare the data

In [60]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [61]:
# MNIST training split; every sample is converted to a tensor by ToTensor().
dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)

# Reshuffle each epoch and drop the trailing partial batch so all batches hold 128 samples.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

In [82]:
# Inspect the first training sample. Conv2dNet below assumes 1x28x28 inputs (single input
# channel, LayerNorm shapes hard-coded to 28x28), so fail loudly here if that ever changes.
image, label = dataset[0]
assert image.shape == (1, 28, 28), f'unexpected image shape {tuple(image.shape)}'
image.shape

torch.Size([1, 28, 28])

### Create a simple CNN

In [63]:
import torch.nn as nn
import torch.nn.functional as F

In [208]:
class Conv2dNet(nn.Module):
    """Simple 2D ConvNet that projects 1x28x28 MNIST images into a small latent space.

    Args:
        embedding_dim: Size of the output embedding. Defaults to 8.
        target_device: Device to place the parameters on. When ``None`` (the default) the
            notebook-level ``device`` variable is used, matching the original behaviour.

    Expects input of shape ``(batch_size, 1, 28, 28)``: the first conv has a single input
    channel and the LayerNorm shapes hard-code the 28x28 spatial size, so other image sizes
    are not supported. Produces output of shape ``(batch_size, embedding_dim)``.
    """

    def __init__(self, embedding_dim: int = 8, target_device=None):
        super().__init__()
        self.net = nn.Sequential(
            # ConvBlock 1: 1 -> 2 channels; padding=k//2 with an odd kernel keeps 28x28
            nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, padding=3 // 2),
            nn.LayerNorm((2, 28, 28)),  # normalise over each sample's (C, H, W)
            nn.ReLU(),

            # ConvBlock 2: 2 -> 4 channels, spatial size still 28x28
            nn.Conv2d(2, 4, 7, padding=7 // 2),
            nn.LayerNorm((4, 28, 28)),
            nn.ReLU(),

            # Linear projection, output is of shape (batch_size, embedding_dim)
            nn.Flatten(),
            nn.Linear(4 * 28 * 28, embedding_dim),
        )
        # Only fall back to the notebook-global `device` when no explicit device is given,
        # so the class can also be constructed where that global is not defined.
        self.net.to(device if target_device is None else target_device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images; returns a tensor of shape (batch_size, embedding_dim)."""
        return self.net(x)
    

class ContrastiveLoss(nn.Module):
    """Mean cosine similarity between non-matching rows of two embedding batches.

    Given ``x`` and ``y`` of shape ``(batch_size, embedding_dim)``, this computes the cosine
    similarity of every pair ``(x[i], y[j])`` and averages over the ``i != j`` pairs only.
    The matching pairs ``(x[i], y[i])`` on the diagonal are excluded.

    A smaller angle (larger cosine) between non-matching rows gives a larger loss, so
    minimising it pushes non-matching embeddings apart. The value lies in [-1, 1].
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            x, y: tensors of shape (batch_size, embedding_dim) with batch_size >= 2.

        Returns:
            A scalar tensor: the mean off-diagonal cosine similarity.

        Raises:
            AssertionError: if the shapes differ, the inputs are not 2-D, or batch_size < 2.
                With fewer than two rows there are no off-diagonal pairs, and the mean
                would be 0 / 0 = NaN.
        """
        assert (
            x.shape == y.shape
        ), f'x (shape {x.shape}) and y (shape {y.shape}) must have the same shape'
        assert x.dim() == 2, (
            f'x and y must be 2-D (batch_size, embedding_dim), got shape {x.shape}'
        )
        batch_size = x.shape[0]
        assert batch_size >= 2, (
            f'batch_size must be at least 2 to have off-diagonal pairs, got {batch_size}'
        )

        # The cosine of the angle between two vectors equals the dot product of the two
        # L2-normalised vectors.
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)

        # sim[i][j] = cos(x[i], y[j]) for every pair; shape (batch_size, batch_size).
        # The diagonal is only guaranteed to be 1 when x and y are identical.
        sim = x @ y.T

        # Zero the diagonal, then average over the batch_size * (batch_size - 1) remaining
        # off-diagonal entries.
        diagonal = torch.eye(batch_size, dtype=torch.bool, device=x.device)
        sim = sim.masked_fill(diagonal, 0)
        return sim.sum() / (batch_size * (batch_size - 1))
