In [1]:
!pip install pytorch-lightning==1.7.2 scikit-learn==1.0.2 lightning-bolts=0.5.0

You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.[0m


Basic variables depending on whether using GPU or not.

In [2]:
import torch

data_dir = "data/"
if torch.cuda.is_available():
    device, precision, gpus = "cuda", 16, 1
else:
    device, precision, gpus = "cpu", 32, 0

## CISSL

In [8]:
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

class CISSL(pl.LightningModule):
    """CISSL objective."""

    def __init__(self, proj_dim=128, temperature=0.07, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        
        # ENCODER
        self.encoder = resnet18()
        self.encoder.fc = nn.Identity() # remove last linear layer
        z_dim=512
        
        # TEACHER PROJECTION HEAD
        # more expressive is better => MLP
        n_hidden = 2048
        bottleneck_size = 512 # adds a bottleneck to avoid linear layer with many parameters
        self.projector = nn.Sequential(
            nn.Linear(z_dim, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, bottleneck_size),
            nn.BatchNorm1d(bottleneck_size),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck_size, proj_dim),
            nn.BatchNorm1d(proj_dim) # for contrastive learning batchnorm is typically added
        )
        
        # STUDENT PROJECTION HEAD: 
        # needs to be linear (batchnorm is)
        self.predictor = nn.Sequential(nn.Linear(z_dim, proj_dim),
                                       nn.BatchNorm1d(proj_dim))

    def forward(self, batch):
        x1, x2 = batch
        bs, device = x1.size(0), x1.device
        
        # shape: [2*batch_size, z_dim]
        z = self.encoder(torch.cat([x1, x2], dim=0))
        
        # shape: [2*batch_size, proj_dim]
        # normalize to use cosine similarity
        z_student = F.normalize(self.predictor(z), dim=1, p=2)
        z_teacher = F.normalize(self.projector(z), dim=1, p=2)
        
        # shape: [2*batch_size, 2*batch_size]
        logits = (z_student @ z_teacher.T).float() / self.hparams.temperature
        
        # there are two positives for each example x1: x1 and x2
        # note: SimCLR removes x1-x1 as those are typically equal. 
        # But not for CISSL due to asymmetric proj heads
        # => computes cross entropy between predicted proba of positives and 0.5 for each positive
        predicted_log_q = logits.log_softmax(-1)
        select_positives = torch.eye(bs, device=device).bool().repeat(2, 2) 
        cross_entropy = - predicted_log_q[select_positives].view(bs*2, 2).sum(1) / 2 
        
        return cross_entropy.mean()

    def training_step(self, batch, batch_idx):
        return self(batch) 

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self.encoder(x).cpu().numpy(), y.cpu().numpy()

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=2e-3, weight_decay=1e-6)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

## Data
Downloads and prepare the necessary data.

In [9]:
from torchvision import transforms

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

pretrain_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=96, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

In [10]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import os

class PretrainSTL10(STL10):
    def __init__(self, data_dir, pretrain_transforms, download=True):
        super().__init__(data_dir, download=download, split="unlabeled", transform=None)
        self.pretrain_transforms = pretrain_transforms
        
    def __getitem__(self, index):
        x,y = super().__getitem__(index)
        x1 = self.pretrain_transforms(x)
        x2 = self.pretrain_transforms(x)
        return x1, x2

data_pretrain = PretrainSTL10(data_dir, pretrain_transforms=pretrain_transforms, download=True)
data_train = STL10(data_dir,  split="train", transform=val_transforms, download=True)
data_test = STL10(data_dir, split="test", transform=val_transforms, download=True)

loader_pretrain = DataLoader(data_pretrain, batch_size=256, shuffle=True,
                               num_workers=os.cpu_count(), pin_memory=True)
loader_train = DataLoader(data_train, batch_size=512, shuffle=False, num_workers=os.cpu_count(),
                          pin_memory=True, drop_last=False)
loader_test = DataLoader(data_train, batch_size=512, shuffle=False, num_workers=os.cpu_count(),
                         pin_memory=True, drop_last=False)

Files already downloaded and verified


## Train

In [11]:
EPOCHS=100
cissl= CISSL(max_epochs=EPOCHS)
trainer = pl.Trainer(gpus=gpus, precision=precision, max_epochs=EPOCHS, logger=False, enable_checkpointing=False)
trainer.fit(cissl, train_dataloaders=loader_pretrain)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | ResNet     | 11.2 M
1 | projector | Sequential | 2.2 M 
2 | predictor | Sequential | 65.9 K
-----------------------------------------
13.4 M    Trainable params
0         Non-trainable params
13.4 M    Total params
26.826    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

`Trainer.fit` stopped: `max_epochs=100` reache

## Evaluate

In [12]:
import numpy as np

Z_train, Y_train = zip(*trainer.predict(dataloaders=loader_train, model=cissl))
Z_test, Y_test = zip(*trainer.predict(dataloaders=loader_test, model=cissl))

Z_train = np.concatenate(Z_train, axis=0)
Y_train = np.concatenate(Y_train, axis=0)
Z_test = np.concatenate(Z_test, axis=0)
Y_test = np.concatenate(Y_test, axis=0)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 391it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 391it [00:00, ?it/s]

In [13]:
# Downstream evaluation. Accuracy: 95.36%
from sklearn.svm import LinearSVC
import tqdm

best_acc = 0
for C in tqdm.tqdm(np.logspace(-3,0,base=10,num=7)):
    clf = LinearSVC(C=C)
    clf.fit(Z_train, Y_train)
    acc = clf.score(Z_test, Y_test)
    best_acc = max(best_acc, acc)
print(f"Downstream STL10 accuracy: {best_acc*100:.2f}%") 

Downstream STL10 accuracy: 95.36%
