# Minimal Risk Decomposition Code

This notebook contains a minimal pipeline for computing the risk decomposition from [...]. The focus is on simplicity and understandability.

**Make sure that you use a GPU** (on Colab: Runtime -> Change runtime type -> Hardware accelerator: GPU)

## Environment

In [1]:
!pip install torch torchvision tqdm pytorch-lightning pandas scikit-learn git+https://github.com/openai/CLIP.git --quiet



## Pretrain
First, we download the desired pretrained model. The following command returns the compressor as well as the transform that must be applied to the images before compression.

In [14]:
import torch

# loads the desired pretrained model and preprocessing pipeline
name = "dino_rn50" # example
model, preprocessor = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', name, trust_repo=True)

# loads all results and hyperparameters
results_df = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', "results_df")

Using cache found in /root/.cache/torch/hub/YannDubs_SSL-Risk-Decomposition_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/YannDubs_SSL-Risk-Decomposition_main


In [15]:
from torchvision.datasets import Food101
DATA_DIR = "data/"

# Load some data to compute the SSL risk decomposition.
# Ideally this is the data on which the model was pretrained (i.e. ImageNet), but that requires a large download, so we use Food101 as an example.
train = Food101(DATA_DIR, download=True, split="train", transform=preprocessor)
test = Food101(DATA_DIR, download=True, split="test", transform=preprocessor)

## Computing the SSL risk decomposition 


In [16]:
def compute_risk_components(model_ssl, data_train, data_test, model_sup=None, n_sub=10000, **kwargs):
    """Computes the SSL risk decomposition for `model_ssl` using a given training and testing set.
    
    If we are given a supervised `model_sup` of the same architecture as `model_ssl`, we compute the 
    approximation error. Otherwise we merge it into the usability error, since the approximation error is typically negligible.
    """
    errors = dict()
    
    # featurize data to make probing much faster. Optional.
    D_train = featurize_data(model_ssl, data_train)
    D_test = featurize_data(model_ssl, data_test)
    
    D_comp, D_sub = data_split(D_train, n=n_sub)
    
    r_A_F = train_eval_probe(D_train, D_train, **kwargs)
    r_A_S = train_eval_probe(D_comp, D_sub, **kwargs)
    r_U_S = train_eval_probe(D_train, D_test, **kwargs)
    
    if model_sup is not None:
        D_train_sup = featurize_data(model_sup, data_train)
        errors["approx"] = train_eval_probe(D_train_sup, D_train_sup, **kwargs)
        errors["usability"] = r_A_F - errors["approx"]
    else:
        errors["usability"] = r_A_F # merges both errors; the approximation error is assumed negligible
        
    errors["probe_gen"] = r_A_S - r_A_F
    errors["encoder_gen"] = r_U_S - r_A_S 
    errors["agg_risk"] = r_U_S
    return errors

The above function is the general risk decomposition, which is agnostic to the specific implementation of the linear probing and of the data.
Below we give a specific implementation using PyTorch. These functions can easily be modified for different choices (e.g. to use sklearn or to tune the probe).

In [20]:
import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import os

def featurize_data(model, dataset):
    """Featurize a dataset using the model."""
    model = model.eval().cuda().half()
    with torch.no_grad():
        Z, Y = [], []
        for x, y in tqdm.tqdm(DataLoader(dataset, batch_size=512, num_workers=8)):
            Z += [model(x.to("cuda").half()).cpu().numpy()]
            Y += [y.cpu().numpy()]
    return SklearnDataset(np.concatenate(Z), np.concatenate(Y))


def train_eval_probe(D_train, D_test, max_epochs=100, batch_size=4096, n_workers=os.cpu_count(), lr=1e-3):
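    """Trains a linear probe on the featurized D_train and returns its error on D_test."""
    # NB: the number of classes is read from the global `train` dataset defined above.
    # The learning rate is scaled linearly with the batch size (reference batch size: 256).
    probe = LogisticRegression(in_dim=len(D_train[0][0]), out_dim=len(train.classes), 
                               max_epochs=max_epochs, lr=batch_size/256*lr)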
    loader_train = DataLoader(D_train, batch_size=batch_size, shuffle=True, 
                              num_workers=n_workers, pin_memory=True)
    trainer = pl.Trainer(logger=False, accelerator="auto", enable_checkpointing=False, 
                         max_epochs=max_epochs, precision=16)
    trainer.fit(probe, train_dataloaders=loader_train)
    
    loader_test = DataLoader(D_test, batch_size=batch_size*2, shuffle=False, 
                             num_workers=n_workers, pin_memory=True)
    logs = trainer.test(dataloaders=loader_test, ckpt_path=None, model=probe)[0]
    return logs["err"]

def data_split(D, n, seed=123):
    """Split a dataset into a set of size n and its complement"""
    complement_idcs, subset_idcs = train_test_split(range(len(D)), stratify=D.Y, test_size=n, random_state=seed)
    return Subset(D, indices=complement_idcs), Subset(D, indices=subset_idcs)
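
As mentioned above, the probe can be swapped for a scikit-learn one. The following is a minimal sketch of such a replacement, not the implementation used in this notebook: the names `to_arrays` and `train_eval_probe_sklearn`, and the default values of `C` and `max_iter`, are assumptions chosen for illustration. It accepts any indexable dataset of `(features, label)` pairs and returns the error rate, like `train_eval_probe`.

```python
import numpy as np
# Aliased so that it does not clash with the Lightning `LogisticRegression` class of this notebook.
from sklearn.linear_model import LogisticRegression as SkLogisticRegression


def to_arrays(D):
    """Stack an indexable dataset of (features, label) pairs into numpy arrays."""
    X = np.stack([np.asarray(D[i][0]) for i in range(len(D))])
    y = np.array([int(D[i][1]) for i in range(len(D))])
    return X, y


def train_eval_probe_sklearn(D_train, D_test, C=1.0, max_iter=1000):
    """Hypothetical sklearn probe: fit on D_train and return the error rate on D_test."""
    X_train, y_train = to_arrays(D_train)
    X_test, y_test = to_arrays(D_test)
    clf = SkLogisticRegression(C=C, max_iter=max_iter)
    clf.fit(X_train, y_train)
    # `score` returns the mean accuracy, so the error is its complement.
    return 1.0 - clf.score(X_test, y_test)
```

Because it returns a plain error rate, this function could be called in place of `train_eval_probe` inside `compute_risk_components`, provided that the keyword arguments passed through `**kwargs` are adapted accordingly.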

In [21]:
class SklearnDataset(Dataset):
    def __init__(self, X, y):
        super().__init__()
        self.X, self.Y = X.astype(np.float32), y.astype(np.int64)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


class LogisticRegression(pl.LightningModule):

    def __init__(self, in_dim, out_dim, max_epochs=100, lr=1e-3):
        super().__init__()
        self.probe = torch.nn.Linear(in_dim, out_dim)
        self.max_epochs = max_epochs
        self.lr = lr

    def forward(self, x):
        return self.probe(x)

    def step(self, batch):
        x, y = batch
        Y_hat = self(x)
        acc = (Y_hat.argmax(dim=-1) == y).sum() / y.shape[0]
        self.log("err", 1-acc, prog_bar=True)
        return F.cross_entropy(Y_hat, y.squeeze().long())

    def training_step(self, batch, batch_idx):
        return self.step(batch)
    
    def test_step(self, batch, batch_idx):
        return self.step(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.probe.parameters(), lr=self.lr, weight_decay=1e-7)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.max_epochs)
        return [optimizer], [scheduler]
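
To make the learning-rate settings above concrete, here is a small sketch in pure Python. It assumes the closed form of cosine annealing given in the PyTorch documentation with `eta_min=0`, and one scheduler step per epoch; the helper names `scaled_lr` and `cosine_lr` are hypothetical and only serve the illustration.

```python
import math


def scaled_lr(base_lr=1e-3, batch_size=4096, reference_batch=256):
    """Linear scaling rule used in `train_eval_probe`: lr = batch_size / 256 * lr."""
    return batch_size / reference_batch * base_lr


def cosine_lr(epoch, peak_lr, max_epochs=100, eta_min=0.0):
    """Closed-form cosine annealing: decays from peak_lr at epoch 0 to eta_min at max_epochs."""
    cosine = math.cos(math.pi * epoch / max_epochs)
    return eta_min + (peak_lr - eta_min) * (1 + cosine) / 2


peak = scaled_lr()                                     # peak learning rate with the default arguments
schedule = [cosine_lr(e, peak) for e in range(101)]    # learning rate at epochs 0..100
print(f"start={schedule[0]:.4f}  middle={schedule[50]:.4f}  end={schedule[100]:.4f}")
```

With the default arguments the probe therefore starts at a learning rate 16 times larger than the base value of `1e-3` and decays smoothly towards zero over the 100 epochs.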

In [22]:
errors = compute_risk_components(model, train, test, max_epochs=100)

100%|██████████| 148/148 [02:26<00:00,  1.01it/s]
100%|██████████| 50/50 [00:54<00:00,  1.08s/it]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | probe | Linear | 206 K 
---------------------------------
206 K     Trainable params
0         Non-trainable params
206 K     Total params
0.414     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           err              0.02899010293185711
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           err              0.2696000039577484
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           err              0.22605940699577332
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [29]:
import pandas as pd

pd.Series(errors) * 100

usability       2.899010
probe_gen      24.060990
encoder_gen    -4.354060
agg_risk       22.605941
dtype: float64
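
As a sanity check, the components are differences of the three probe errors, so they telescope back to the aggregated risk. The sketch below redoes this arithmetic by hand with the three `err` values printed in the test tables above; the variable names mirror those of `compute_risk_components`, and `components`, `total` and `key` are introduced here only for illustration.

```python
# Probe errors copied from the three test tables printed above.
r_A_F = 0.02899010293185711   # probe trained on D_train, evaluated on D_train
r_A_S = 0.2696000039577484    # probe trained on D_comp, evaluated on D_sub
r_U_S = 0.22605940699577332   # probe trained on D_train, evaluated on D_test

components = {
    "usability": r_A_F,            # includes the approximation error since model_sup is None
    "probe_gen": r_A_S - r_A_F,
    "encoder_gen": r_U_S - r_A_S,
}

# The intermediate terms cancel, so the components sum to the aggregated risk.
total = sum(components.values())
assert abs(total - r_U_S) < 1e-12

for key, value in components.items():
    print(f"{key:12s} {100 * value:8.3f}")
```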

Note that the estimate of encoder generalization here is not meaningful because the model was not pretrained on the dataset we are using.
Despite this issue, the risk components are surprisingly similar to the reference values, given that we computed them on a different dataset and did not tune the probe's hyperparameters. For comparison, the results on ImageNet with a tuned probe are:

In [30]:
results_df.loc[name, "risk_decomposition"]

agg_risk     25.828001
approx        0.845089
enc_gen          3.336
probe_gen    21.420243
usability     0.226668
Name: dino_rn50, dtype: object
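
To compare the two outputs side by side, the sketch below copies both sets of numbers (in percent) by hand and aligns their keys. Two things are assumptions of this illustration rather than part of the pipeline: `enc_gen` is treated as the same quantity as `encoder_gen`, and the ImageNet approximation error is added to the usability error to match what this notebook does when no supervised model is given.

```python
# Values copied from the two outputs above, in percent.
food101 = {"usability": 2.899010, "probe_gen": 24.060990,
           "encoder_gen": -4.354060, "agg_risk": 22.605941}
imagenet = {"agg_risk": 25.828001, "approx": 0.845089, "enc_gen": 3.336,
            "probe_gen": 21.420243, "usability": 0.226668}

# Align the ImageNet keys with the ones produced by `compute_risk_components`
# when `model_sup` is None (approximation error merged into usability).
imagenet_merged = {
    "usability": imagenet["approx"] + imagenet["usability"],
    "probe_gen": imagenet["probe_gen"],
    "encoder_gen": imagenet["enc_gen"],
    "agg_risk": imagenet["agg_risk"],
}

# Difference in percentage points: Food101 (untuned probe) minus ImageNet (tuned probe).
diff = {key: food101[key] - imagenet_merged[key] for key in food101}
for key, value in diff.items():
    print(f"{key:12s} {value:+7.3f}")
```

The probe generalization term dominates in both cases, while the encoder generalization terms differ in sign, which is consistent with the caveat above about using a dataset other than the pretraining one.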