# Reptrix Tutorial 

In this tutorial, we show how to use reptrix to assess representations of a model pretrained using self-supervised learning. 

Here we use the [STL-10](https://cs.stanford.edu/~acoates/stl10/) dataset, which contains 10 classes of images. 

We will use the [SimCLR](https://arxiv.org/abs/2002.05709), [Barlow Twins](https://arxiv.org/abs/2103.03230), and [BYOL](https://arxiv.org/abs/2006.07733) models as examples.

To assess the quality of the learned representations, we will use various metrics, including:

- **Alpha**: This metric computes the eigenvalues of the covariance matrix of the representations and fits a power-law distribution to them. The exponent of the power-law distribution is called the alpha exponent, which measures the heavy-tailedness of the distribution. A lower alpha exponent indicates that the representations are more discriminative.

- **RankMe**: This metric computes the rank of the covariance matrix of the representations. A higher rank indicates representations of higher capacity.

- **Lidar**: 

We will compute these metrics using the Reptrix library, which provides a convenient interface for representation analysis. Let's dive into the code and explore the evaluation process in detail.



## Import everything we need

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

Function to get the features of from the pretrained encoder model using a dataset. 

In [28]:
from tqdm import tqdm
def get_features(encoder_function, dataloader):
    # Loop over the dataset and collect the representations
    all_features = []

    # Loop over the dataset and collect the representations
    for i, data in tqdm(enumerate(dataloader, 0)):
        inputs, labels = data
        with torch.no_grad():
            features = encoder_function(inputs)
            all_features.append(features)
            
    # Concatenate all the features
    all_features = torch.cat(all_features, dim=0)
    return all_features

## Get the STL-10 dataset and the pretrained models

In [9]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2242, 0.2215, 0.2239))
])

# Get the STL10 test dataset to measure the quality of the representations learned by the model
testset = torchvision.datasets.STL10(root='./data', split='test', download=False, transform=torchvision.transforms.ToTensor())

# Define a dataloader to load the test dataset
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

### 1. Encoder trained with Barlow Twins

In this section, we will evaluate the representations learned by an encoder trained with the Barlow Twins method.

In [10]:
# Define a resnet encoder using pytorch where you can load the weights from a pre-trained model
# We will use encoder that is trained using Barlow Twins
encoder = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# Remove the final fully connected layer so that the model outputs the 2048 feature vector
encoder = torch.nn.Sequential(*(list(encoder.children())[:-1]))

Using cache found in /home/mila/a/arnab.mondal/.cache/torch/hub/facebookresearch_barlowtwins_main


In [12]:
# Set the model to evaluation mode
encoder.eval()
    
all_representations = get_features(encoder, testloader)

In [14]:
from reptrix import alpha, rankme

metric_alpha = alpha.get_alpha(all_representations)[0]
metric_rankme = rankme.get_rankme(all_representations)
print(f'Values of different metrics: Alpha: {metric_alpha}, Rankme: {metric_rankme}')



### 2. Vit Encoder trained with DINO

In [24]:
# Define a vit encoder using pytorch where you can load the weights from a pre-trained model
# We will use encoder that is trained using Dino
encoder = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')




Using cache found in /home/mila/a/arnab.mondal/.cache/torch/hub/facebookresearch_dino_main


In [33]:
# Set the model to evaluation mode
encoder.eval()
    
all_representations = get_features(encoder.forward, testloader)

print('Shape of the representations:', all_representations.shape) 

In [36]:
metric_alpha = alpha.get_alpha(all_representations)[0]
metric_rankme = rankme.get_rankme(all_representations)
print(f'Values of different metrics: Alpha: {metric_alpha}, Rankme: {metric_rankme}')

Values of different metrics: Alpha: 0.8250624571053066, Rankme: 152.54171752929688
