In [63]:
import PIL
import os
import logging
import pickle as pk
from collections import defaultdict

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from patchify import patchify,unpatchify

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

PIL.Image.MAX_IMAGE_PIXELS = 933120000

import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode
from skimage.measure import block_reduce
from torch.utils.data import Dataset,DataLoader

from torchvision.models.resnet import resnet50, ResNet50_Weights

### Computing Contrastive Loss

Really helped: https://gombru.github.io/2018/05/23/cross_entropy_loss/

The loss NT-XENT loss is a categorical-crossentropy like loss whereby:

$$
\ell(i,j) = -\log\frac{\exp(sim(\vec{z}_i, \vec{z}_j)/\tau)}{\sum_{k = 1, k \neq i}^{2N} \exp(sim(\vec{z}_i, \vec{z}_k)/\tau)}
$$
where:
$$
sim(\vec{u}, \vec{v}) = \vec{u}^T\vec{v}/\|\vec{u}\|\|\vec{v}\|
$$
Then, the total contrastive loss will be:
$$
\mathcal{L} = \frac{1}{2N}\sum_{k = 1}^N  [\ell(2k-1, 2k) + \ell(2k, 2k-1)]
$$

Here, $N$ represents the mini-batch size. In standard contrastive loss, for each mini-batch sample $\vec{x}_k, k \in [1,N]$, we sample 2 random transformations $\mathcal{T}_1, \mathcal{T}_2$, to derive augmented samples $\vec{z}_{2k-1}, \vec{z}_{2k}$.

We store the augmented samples in a matrix:

$$
M = \begin{pmatrix}
\vec{z}_1^T \\
\vec{z}_2^T \\
\vdots \\
\vec{z}_{2k-1}^T \\
\vec{z}_{2k}^T
\\
\vdots
\\
\vec{z}_{2N-1}^T \\
\vec{z}_{2N}^T
\end{pmatrix}
$$

We can then normalise the samples to have unit length (define $\vec{a}_k = \vec{z}_k/\|\vec{z}_k\|$):

$$
\bar{M} = \begin{pmatrix}
\vec{a}_1^T \\
\vec{a}_2^T \\
\vdots \\
\vec{a}_{2k-1}^T \\
\vec{a}_{2k}^T
\\
\vdots
\\
\vec{a}_{2N-1}^T \\
\vec{a}_{2N}^T
\end{pmatrix}
$$

Then, we can easily compute our similarity scores:
$$
M^* = \bar{M}\bar{M}^T 
=
\begin{pmatrix}
\vec{a}_1^T\vec{a}_1 & \vec{a}_1^T\vec{a}_2 & \ldots & \vec{a}_1^T\vec{a}_{2k} & \ldots & \vec{a}_1^T\vec{a}_{2N} \\
\vec{a}_2^T\vec{a}_1 & \vec{a}_2^T\vec{a}_2 & \ldots & \vec{a}_2^T\vec{a}_{2k} & \ldots & \vec{a}_2^T\vec{a}_{2N} \\
\vdots & \vdots & \ddots & \vdots & \ddots & \vdots \\
\vec{a}_{2N}^T\vec{a}_1 & \vec{a}_{2N}^T\vec{a}_2 & \ldots & \vec{a}_{2N}^T\vec{a}_{2k} & \ldots & \vec{a}_{2N}^T\vec{a}_{2N} \\
\end{pmatrix} \in \mathbb{R}^{2N \times 2N}
$$
such that:
$$
sim(\vec{z}_i, \vec{z}_j) = (M^*)_{ij}
$$

Now, we recall crossentropy loss for categories. Say we have a vector:
$$
\vec{x}_n = \begin{pmatrix}
x_n^1 \\
x_n^2 \\
\vdots \\
x_n^{D}
\end{pmatrix}
$$
and a total of $C$ categories (indexed $[0,C)$).

We can define the cross-entropy loss of $\vec{x}_n$ (assuming it is expected to belong to class $p \in [0,C)$) by:
$$
\ell_n = -\log \frac{\exp(x_n^p)}{\sum_{c = 1}^C \exp(x_n^c)}
$$
and for **all** the training samples (say we have $N$ many) we will have:
$$
\mathcal{L} = \sum_{n = 1}^N \ell_n
$$

The above is the PyTorch formulation, whereby we need to provide:
- a matrix of samples $M$
- a vector of labels $L$ (such that if $\vec{x}_n$ is in class $p$, then $L[n] = p$)
Hence, to compute NT-XENT efficinetly, we can use PyTorch's `CrossEntropyLoss` directly, after slightly tweking our matrix $M^*$.

Let do a final transformation to $M^*$, by replacing its diagonal by -1000:

$$
M' = 
\begin{pmatrix}
-1000 & \vec{a}_1^T\vec{a}_2 & \ldots & \vec{a}_1^T\vec{a}_{2k} & \ldots & \vec{a}_1^T\vec{a}_{2N} \\
\vec{a}_2^T\vec{a}_1 & -1000 & \ldots & \vec{a}_2^T\vec{a}_{2k} & \ldots & \vec{a}_2^T\vec{a}_{2N} \\
\vdots & \vdots & \ddots & \vdots & \ddots & \vdots \\
\vec{a}_{2N}^T\vec{a}_1 & \vec{a}_{2N}^T\vec{a}_2 & \ldots & \vec{a}_{2N}^T\vec{a}_{2k} & \ldots & -1000 \\
\end{pmatrix} \in \mathbb{R}^{2N \times 2N}
$$

$\ell(i,j)$ then corresponds to feeding $(M')_{i,-}$ to `CrossEntropyLoss` alongside the label $L[i] = j$. By replacing the diagonal with -1000, we ensure that $exp(\vec{a}_i^T\vec{a}_i) \approx 0$, so it won't count towards the denominator of the cross-entropy function.

In [4]:
a = torch.randint(0,10,(3,3)).float()
print(a)
#b = a / nn.functional.normalize(a, axis = 1)
#a.T @ a
b = nn.functional.normalize(a, dim = 1)
print(b)
c = b @ b.T
print(c)

for i in range(3):
    print(torch.norm(b[i,:]))

tensor([[8., 7., 5.],
        [3., 0., 3.],
        [7., 7., 9.]])
tensor([[0.6810, 0.5959, 0.4256],
        [0.7071, 0.0000, 0.7071],
        [0.5232, 0.5232, 0.6727]])
tensor([[1.0000, 0.7825, 0.9544],
        [0.7825, 1.0000, 0.8456],
        [0.9544, 0.8456, 1.0000]])
tensor(1.)
tensor(1.)
tensor(1.)


In [55]:
a = torch.randint(0,10,(4,4))
b = -1000*torch.ones(4,1)
torch.cat((a,b), dim = 1)

tensor([[    4.,     0.,     9.,     2., -1000.],
        [    9.,     3.,     3.,     7., -1000.],
        [    4.,     2.,     8.,     2., -1000.],
        [    3.,     8.,     1.,     4., -1000.]])

In [71]:
def contrastive_loss_2(z_batch, tau):
    N = len(z_batch)
    
    # normalise to have unit length rows
    norm_z_batch = F.normalize(z_batch)
    
    # compute similarity & apply factor of tau
    sim_batch = (norm_z_batch @ norm_z_batch.T)/tau
    
    # remove diagonal
    sim_batch = sim_batch.flatten()[1:].view(N-1, N+1)[:,:-1].reshape(N, N-1)
    sim_batch = torch.cat((sim_batch, -1000*torch.ones(N,1)), dim = 1)
    print(sim_batch)
    
    # generate labels
    # here we assume that z_batch[2k-1] and z_batch[2k] are a positive pair of samples
    labels = torch.Tensor([k if k%2 == 0 else k-1 for k in range(0,N)]).long()
    print(labels)
    
    # return the NT-XENT loss
    return 1/N * F.cross_entropy(sim_batch, labels, reduction = "sum")

In [84]:
def contrastive_loss(z_batch, tau):
    N = len(z_batch)
    
    # normalise to have unit length rows
    norm_z_batch = F.normalize(z_batch)
    
    # compute similarity & apply factor of tau
    sim_batch = (norm_z_batch @ norm_z_batch.T)/tau
    
    # fill the diagonal with -1000, to make sure it is never considered in the cross entropy computations
    sim_batch.fill_diagonal_(-1000)
    
    # generate labels
    # here we assume that z_batch[2k-1] and z_batch[2k] are a positive pair of samples
    labels = torch.Tensor([k+1 if k%2 == 0 else k-1 for k in range(0,N)]).long()
    
    # return the NT-XENT loss
    return 1/N * F.cross_entropy(sim_batch, labels, reduction = "sum")

In [59]:
rand_batch = torch.randint(-10,10,(6,5)).float()
print(rand_batch)
contrastive_loss(rand_batch, tau = 1)

tensor([[  6., -10.,   9.,   0.,   6.],
        [ -4.,  -9.,   4.,   3.,   3.],
        [  3.,  -1.,  -2., -10.,  -4.],
        [ -6.,  -8.,  -7.,   3.,   1.],
        [  2.,   0.,   5.,  -4.,   0.],
        [  3.,  -6.,  -5.,  -7.,  -4.]])
tensor([[ 6.5915e-01, -7.7196e-02, -6.4816e-02,  5.3421e-01,  4.8698e-02,
         -1.0000e+03],
        [ 6.5915e-01, -4.0613e-01,  5.5431e-01,  0.0000e+00, -8.2716e-02,
         -1.0000e+03],
        [-7.7196e-02, -4.0613e-01, -2.0867e-01,  4.7068e-01,  8.3789e-01,
         -1.0000e+03],
        [-6.4816e-02,  5.5431e-01, -2.0867e-01, -6.9750e-01,  2.7302e-01,
         -1.0000e+03],
        [ 5.3421e-01,  0.0000e+00,  4.7068e-01, -6.9750e-01,  1.1547e-01,
         -1.0000e+03],
        [ 4.8698e-02, -8.2716e-02,  8.3789e-01,  2.7302e-01,  1.1547e-01,
         -1.0000e+03]])
tensor([1, 0, 3, 2, 5, 4])


tensor(168.3250)

In [68]:
contrastive_loss(rand_batch, tau = 1)

tensor([[-1.0000e+03,  6.5915e-01, -7.7196e-02, -6.4816e-02,  5.3421e-01,
          4.8698e-02],
        [ 6.5915e-01, -1.0000e+03, -4.0613e-01,  5.5431e-01,  0.0000e+00,
         -8.2716e-02],
        [-7.7196e-02, -4.0613e-01, -1.0000e+03, -2.0867e-01,  4.7068e-01,
          8.3789e-01],
        [-6.4816e-02,  5.5431e-01, -2.0867e-01, -1.0000e+03, -6.9750e-01,
          2.7302e-01],
        [ 5.3421e-01,  0.0000e+00,  4.7068e-01, -6.9750e-01, -1.0000e+03,
          1.1547e-01],
        [ 4.8698e-02, -8.2716e-02,  8.3789e-01,  2.7302e-01,  1.1547e-01,
         -1.0000e+03]])
tensor([1, 0, 3, 2, 5, 4])


tensor(1.6296)

In [72]:
contrastive_loss_2(rand_batch, tau = 1)

tensor([[ 6.5915e-01, -7.7196e-02, -6.4816e-02,  5.3421e-01,  4.8698e-02,
         -1.0000e+03],
        [ 6.5915e-01, -4.0613e-01,  5.5431e-01,  0.0000e+00, -8.2716e-02,
         -1.0000e+03],
        [-7.7196e-02, -4.0613e-01, -2.0867e-01,  4.7068e-01,  8.3789e-01,
         -1.0000e+03],
        [-6.4816e-02,  5.5431e-01, -2.0867e-01, -6.9750e-01,  2.7302e-01,
         -1.0000e+03],
        [ 5.3421e-01,  0.0000e+00,  4.7068e-01, -6.9750e-01,  1.1547e-01,
         -1.0000e+03],
        [ 4.8698e-02, -8.2716e-02,  8.3789e-01,  2.7302e-01,  1.1547e-01,
         -1.0000e+03]])
tensor([0, 0, 2, 2, 4, 4])


tensor(1.6296)

In [86]:
test_vec = torch.Tensor([
[1.1, 1.1, 1.1],
[1.1, 1.1, 1.1],
[-0.3, 0.2, -0.6],
[-0.3, 0.2, -0.6],
[22, 23, 24],
[22, 23, 24]
])

test_vec_2 = torch.Tensor([
[1.1, 1.1, 1.1],
[-0.3, 0.2, -0.6],
[22, 23, 24],
[1.1, 1.1, 1.1],
[-0.3, 0.2, -0.6],
[22, 23, 24]
])

test_vec_3 = torch.Tensor([
[1.1, 1.1, 1.1],
[-1.1, -1.1, -1.1],
[-0.3, 0.2, -0.6],
[0.3, -0.2, 0.6],
[-22, -23, -24],
[22, 23, 24]
])

print(contrastive_loss(test_vec, tau = 1))
print(contrastive_loss(test_vec_2, tau = 1))
print(contrastive_loss(test_vec_3, tau = 1))

tensor(1.0177)
tensor(2.0729)
tensor(2.7130)


In [78]:
projs = rand_batch
b = projs.shape[0]//2
n = b * 2

logits = projs @ projs.t()

mask = torch.eye(n).bool()
logits = logits[~mask].reshape(n, n - 1)

labels = torch.cat(((torch.arange(b) + b - 1), torch.arange(b)), dim=0)

loss = nn.functional.cross_entropy(logits, labels, reduction='sum')
print(loss/n)

tensor(67.1667)


In [79]:
print(labels)

tensor([2, 3, 4, 0, 1, 2])


In [64]:
class MapCLNN(nn.Module):
    def __init__(self, positive_samples):
        super(MapCLNN, self).__init__()
        
        self.MAX_PIXEL_VALUE = 255
        self.RESNET_DIM = 224
        
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        
        # dictionary containing patch alignments
        # key is a tuple, representing the patch coordinates of the patch
        # value is a list, containing MapPatch instances
        # MapPatch instances with the same key can be thought of as positive samples for the algorithm
        self.positive_samples = positive_samples
    
    def img_to_resnet(self, img, dim = 224):
        """
        Convert image into the desired format for ResNet.
        The image must have width and height of at least self.RESNET_DIM, with RGB values between 0 and 1.
        Moreover, it must be normalised, by using a mean of [0.485, 0.456, 0.406] and a standard deviation of [0.229, 0.224, 0.225]
        --------------------------------------------------------------------------------------------------------------------------------
        :param img: a numpy nd.array, with 3 colour channels (this must be stored in the last dimensions), which has to be fed to ResNet
        :param dim: the desired dimension of the image (if we want to resize img before feeding it to ResNet).
                    This should be at least self.RESTNET_DIM.
        --------------------------------------------------------------------------------------------------------------------------------
        :return a Tensor, with the first dimension corresponding to the RGB channels, and normalised to be used by ResNet.
        """
        
        assert dim >= self.RESNET_DIM, f"Provided dimension {dim} is less than the required for RESNET ({self.RESNET_DIM})"
        
        # put the colour channel in front
        norm_img = np.moveaxis(img, -1, 0)
        
        # normalise into range [0,1]
        norm_img = torch.from_numpy(norm_img)/255
        
        # resize
        if img.shape[0] < self.RESNET_DIM or img.shape[1] < self.RESNET_DIM:
            norm_img = T.Resize(self.RESNET_DIM)(norm_img)
        else:
            if dim is not None:
                norm_img = T.Resize(dim)(norm_img)     
        
        # normalise mean and variance
        mean = torch.Tensor([0.485, 0.456, 0.406])
        std = torch.Tensor([0.229, 0.224, 0.225])
        
        return T.Normalize(mean = mean, std = std)(norm_img)
    
    def sim(self, z_i, z_j):
        """
        Computes cosine similarity between 2 vectors z_i, z_j.
        """
        return torch.dot(z_i,z_j)/(torch.linalg.norm(z_i) * torch.linalg.norm(z_j))
    
    def exp_sim(self, z_i, z_j, tau):
        """
        Computes a temperature-scaled, exponential similarity between 2 vectors z_i, z_j
        """
        return torch.exp(self.sim(z_i, z_j)/tau)
    
    def sample_contrastive_loss(self, z_i, z_j, z_batch, tau):
        """
        Computes the contrastive loss for a pair of positive samples z_i, z_j i.
        -----------------------------------------------------------------------------
        :param z_i,z_j: 1D Tensors; these are positive samples whose loss we compute.
        :param z_batch: a 2D Tensor, giving the batch over which we compute the loss. 
                        Contains the negative samples for z_i,z_j.
        :param tau: a float, the "temperature"
        """
        similarity = self.exp_sim(z_i, z_j, tau)
        all_dissimilarities = [self.exp_sim(z_i, z_k, tau) for z_k in z_batch if not torch.equal(z_i,z_k)]
        dissimilarity = torch.sum(torch.Tensor(all_dissimilarities))
        
        return -torch.log(similarity/dissimilarity)
    
    def contrastive_loss(self, z_batch, tau):
        N = len(z_batch)
        all_losses = [self.sample_contrastive_loss(z_batch[2*k], z_batch[2*k+1], z_batch, tau) +
                      self.sample_contrastive_loss(z_batch[2*k+1], z_batch[2*k], z_batch, tau)
                      for k in range(1,N//2)]
        
        return 1/N * torch.sum(torch.Tensor(all_losses))
    
    def forward(self,x):
        pass
    
    def compile_model(self):
        pass
    
    def train(self):
        pass

In [66]:
mclnn = MapCLNN("a")
mclnn.contrastive_loss(logits, tau = 1)

tensor(1.0475)