# Minimal Risk Decomposition Code

This notebook contains a minimal pipeline for computing the risk decomposition from [...]. The focus is on simplicity and understandability.

**Make sure that you use a GPU** (on Colab: Runtime -> Change runtime type -> Hardware accelerator: GPU).

## Environment

In [1]:
!pip install torch torchvision tqdm pytorch-lightning pandas scikit-learn git+https://github.com/openai/CLIP.git --quiet

## Standard: model and data
First, we download the desired pretrained model. The following command returns the compressor (the pretrained encoder) as well as the transform that should be applied to the images before compression.

In [5]:
import torch

# loads the desired pretrained model and preprocessing pipeline
name = "clip_rn50" # example
model, preprocessor = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', name, trust_repo=True)

# loads all results and hyperparameters
results_df = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', "results_df")
z_dim = results_df.hyperparameters.z_dim[name.lower()]

Using cache found in /Users/yanndubois/.cache/torch/hub/YannDubs_SSL-Risk-Decomposition_main
Using cache found in /Users/yanndubois/.cache/torch/hub/YannDubs_SSL-Risk-Decomposition_main


In [6]:
from torchvision.datasets import Food101
DATA_DIR = "data/"

# Load some data to compress and apply transformation
train = Food101(DATA_DIR, download=True, split="train", transform=preprocessor)
test = Food101(DATA_DIR, download=True, split="test", transform=preprocessor)

## Computing the SSL risk decomposition 


In [7]:
def compute_risk_components(model_ssl, D_train, D_test, model_sup=None):
    """Computes the SSL risk decomposition for `model_ssl` using a given training and testing set.
    
    If we are given a supervised `model_sup` of the same architecture as `model_ssl`, we compute the 
    approximation error. Else we merge it with the usability error, given that the approximation
    error is negligible.
    """
    errors = dict()
        
    # D_sub: held-out subset of the training set (same size as D_test); D_comp: its complement
    D_comp, D_sub = data_split(D_train, n=len(D_test))
    
    r_A_F = train_eval_model(model_ssl, D_train, D_train)  # probe trained and evaluated on D_train
    r_A_S = train_eval_model(model_ssl, D_comp, D_sub)  # probe trained on D_comp, evaluated on D_sub
    r_U_S = train_eval_model(model_ssl, D_train, D_test)  # probe trained on D_train, evaluated on D_test
    
    if model_sup is not None:
        errors["approx"] = eval_model(model_sup, D_train)
        errors["usability"] = r_A_F - errors["approx"]
    else:
        errors["usability"] = r_A_F  # merges both errors, but approx is negligible
        
    errors["probe_gen"] = r_A_S - r_A_F
    errors["encoder_gen"] = r_U_S - r_A_S 
    return errors

The following is one specific implementation of the evaluation, training, and data splitting. These functions can easily be modified for different choices (e.g. to use sklearn, tune the probe, or preprocess the data).
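Whatever implementation is chosen for those functions, the arithmetic in `compute_risk_components` is a telescoping sum: the components add up to `r_U_S`, the risk measured on the test set. The sketch below replays that arithmetic on precomputed risks. The helper name `decompose` and the numbers are hypothetical and only serve as an illustration; they are not results from this notebook.

```python
def decompose(r_A_F, r_A_S, r_U_S, approx=None):
    """Replays the arithmetic of compute_risk_components on precomputed risks."""
    errors = {}
    if approx is not None:
        errors["approx"] = approx
        errors["usability"] = r_A_F - approx
    else:
        # approximation error is merged into the usability error
        errors["usability"] = r_A_F
    errors["probe_gen"] = r_A_S - r_A_F
    errors["encoder_gen"] = r_U_S - r_A_S
    return errors


# Hypothetical error rates, for illustration only
errors = decompose(r_A_F=0.10, r_A_S=0.18, r_U_S=0.25, approx=0.02)

# The intermediate risks cancel, so the components sum to r_U_S
total = sum(errors.values())
print(errors)
print(abs(total - 0.25) < 1e-9)
```

This gives a cheap sanity check after any modification of the pipeline: if the returned components do not sum to the test risk, one of the intermediate risks was computed inconsistently.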

In [8]:
from torch.utils.data import DataLoader, Subset
import os
from sklearn.model_selection import train_test_split

def eval_model(model, D_test):
    """Evaluates an already trained model on D_test and returns its error rate (1 - accuracy)"""
    loader_test = DataLoader(D_test, batch_size=1024, shuffle=False, 
                             num_workers=os.cpu_count(), pin_memory=True)
    
    trainer = pl.Trainer(accelerator="auto", logger=False, enable_checkpointing=False)
    # trainer.test returns a list with one dict of logged metrics per test dataloader
    logs = trainer.test(dataloaders=loader_test, ckpt_path=None, model=model)

    # assumes the model logs its test accuracy as "test/acc", as `Model` does
    return 1 - logs[0]["test/acc"]

def train_eval_model(encoder, D_train, D_test, Probe=torch.nn.Linear, z_dim=z_dim):
    """Trains a probe on top of the encoder on D_train and returns its error rate on D_test"""
    # number of classes is read from the global `train`, since a Subset has no `.classes`
    model = Model(encoder, Probe(z_dim, len(train.classes)))
    
    loader_train = DataLoader(D_train, batch_size=512, shuffle=True, 
                              num_workers=os.cpu_count(), pin_memory=True)
    
    trainer = pl.Trainer(logger=False, accelerator="auto", enable_checkpointing=False, 
                         max_epochs=model.max_epochs)
    trainer.fit(model, train_dataloaders=loader_train)
    
    return eval_model(model, D_test)

def data_split(D, n, seed=123):
    """Splits a dataset into a stratified subset of size n and its complement.
    
    Returns (complement, subset).
    """
    
    complement_idcs, subset_idcs = train_test_split(
        range(len(D)), stratify=D._labels, test_size=n, random_state=seed
    )
    
    return Subset(D, indices=complement_idcs), Subset(D, indices=subset_idcs)
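To make explicit what `data_split` is expected to return, here is a dependency-free sketch of a stratified split over indices. It is not sklearn's implementation of `train_test_split`, only an illustration of the same idea: the subset has exactly `n` elements, each class is represented roughly in proportion to its frequency, and the complement holds every remaining index. The function name `stratified_split` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_split(labels, n, seed=123):
    """Returns (complement_idcs, subset_idcs) with len(subset_idcs) == n."""
    rng = random.Random(seed)

    # group the indices by class
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # ideal (fractional) number of samples to take from each class
    total = len(labels)
    quotas = {c: n * len(idcs) / total for c, idcs in by_class.items()}
    takes = {c: int(q) for c, q in quotas.items()}

    # hand out the remaining slots to the classes with the largest fractional parts
    remaining = n - sum(takes.values())
    by_fraction = sorted(quotas, key=lambda c: quotas[c] - takes[c], reverse=True)
    for c in by_fraction[:remaining]:
        takes[c] += 1

    complement, subset = [], []
    for c, idcs in by_class.items():
        rng.shuffle(idcs)
        subset += idcs[:takes[c]]
        complement += idcs[takes[c]:]
    return complement, subset


# Toy labels: six samples of class 0 and three of class 1
toy_labels = [0] * 6 + [1] * 3
complement_idcs, subset_idcs = stratified_split(toy_labels, n=3)
```

With these toy labels the subset contains two indices of class 0 and one of class 1, and the two index lists together cover every sample exactly once.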

In [9]:
import pytorch_lightning as pl
import torch.nn.functional as F

class Model(pl.LightningModule):
    """Encoder and Predictor."""

    def __init__(self, encoder, probe, max_epochs=100):
        super().__init__()
        self.encoder = encoder.eval()
        self.probe = probe  # e.g. nn.Linear(z_dim, n_classes)
        self.max_epochs = max_epochs

    def forward(self, x):
        # no gradients flow through the encoder: only the probe is trained
        with torch.no_grad():
            z = self.encoder(x)
        return self.probe(z)

    def step(self, batch, name):
        x, y = batch
        Y_hat = self(x)
        # fraction of correct top-1 predictions in the batch
        acc = (Y_hat.argmax(dim=-1) == y.squeeze()).float().mean()
        self.log(f"{name}/acc", acc)
        return F.cross_entropy(Y_hat, y.squeeze().long())

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, "train")
        return loss
    
    def test_step(self, batch, batch_idx):
        loss = self.step(batch, "test")
        return loss

    def configure_optimizers(self):
        
        optimizer = torch.optim.Adam(self.probe.parameters(), 
                                     lr=1e-3, weight_decay=1e-6)
            
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.max_epochs)

        return [optimizer], [scheduler]
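For intuition about the schedule configured above, the sketch below evaluates the closed-form cosine annealing formula given in the PyTorch documentation of `CosineAnnealingLR`, with the values used in `configure_optimizers` (initial learning rate `1e-3`, `T_max` equal to `max_epochs=100`, and the default `eta_min=0`). It assumes the scheduler is stepped once per epoch. The helper name `cosine_annealing_lr` is hypothetical, and the code is plain Python rather than a call into PyTorch.

```python
import math


def cosine_annealing_lr(epoch, base_lr=1e-3, t_max=100, eta_min=0.0):
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# learning rate at the start of each epoch, plus the final value
schedule = [cosine_annealing_lr(epoch) for epoch in range(101)]

for epoch in (0, 25, 50, 75, 100):
    print(epoch, schedule[epoch])
```

The learning rate starts at `1e-3`, reaches half of that value midway through training, and decays to zero at the last epoch, so changing `max_epochs` also changes how fast the probe's learning rate is annealed.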



In [None]:
compute_risk_components(model, train, test)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type           | Params
-------------------------------------------
0 | encoder | ModifiedResNet | 40.4 M
1 | probe   | Linear         | 206 K 
-------------------------------------------
40.6 M    Trainable params
0         Non-trainable params
40.6 M    Total params
162.488   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]