In [1]:
!pip install pytorch-lightning==1.7.2 scikit-learn==1.0.2 --no-dependencies tensorboard

You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.[0m


Basic variables depending on whether using GPU or not.

In [2]:
import torch

data_dir = "data/"
if torch.cuda.is_available():
    device, precision, gpus = "cuda", 16, 1
else:
    device, precision, gpus = "cpu", 32, 0

## DISSL

In [3]:
import pytorch_lightning as pl
import torch.nn as nn
from torchvision.models import resnet18
from torch.distributions import Categorical
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

class DISSL(pl.LightningModule):
    """DISSL objective."""

    def __init__(self, 
                 n_equivalence_classes=16384, 
                 lambda_maximality=2.3, 
                 beta_det_inv=0.8, 
                 max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        
        # ENCODER
        self.encoder = resnet18()
        self.encoder.fc = nn.Identity() # remove last linear layer
        z_dim=512
        
        # TEACHER PROJECTION HEAD
        # more expressive is better => MLP
        n_hidden = 2048
        bottleneck_size = 512 # adds a bottleneck to avoid linear layer with many parameters
        self.projector = nn.Sequential(
            nn.Linear(z_dim, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, bottleneck_size),
            nn.BatchNorm1d(bottleneck_size),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck_size, self.hparams.n_equivalence_classes)
        )
        
        # STUDENT PROJECTION HEAD: 
        # needs to be linear (could also add batchnorm as this is linear)
        self.predictor = nn.Linear(z_dim, self.hparams.n_equivalence_classes)

    def forward(self, batch):
        x1, x2 = batch
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        return (self.asymmetric_loss(z1,z2) + 
                self.asymmetric_loss(z2,z1)) / 2
    
    def asymmetric_loss(self, z1, z2, temperature_teach=0.5):
        """Computes the asymmetric DISSL loss where MC expectations over z1."""
        logits_t1 = self.projector(z1).float() / temperature_teach
        logits_t2 = self.projector(z2).float() / temperature_teach
        logits_s = self.predictor(z2).float() 
        
        # q(\hat{M}|X). batch shape: [batch_size] ; event shape: []
        q_Mlx = Categorical(logits=logits_t1)
        
        # MAXIMALITY. -H[\hat{M}]
        q_M = Categorical(probs=q_Mlx.probs.mean(0))
        mxml = -q_M.entropy() # you want to max entropy
        
        # INVARIANCE and DETERMINISM. E_{q(M|X)}[log q(M|\tilde{X})]
        det_inv = (q_Mlx.probs * logits_t2.log_softmax(-1)).sum(-1).mean()
        
        # DISTILLATION. E_{q(M|X)}[log s(M|\tilde{X})]
        dstl = (q_Mlx.probs * logits_s.log_softmax(-1)).sum(-1).mean()
        
        self.log_dict({"H[M]": -mxml, "CE": -det_inv}, prog_bar=True) 

        return self.hparams.lambda_maximality * mxml - self.hparams.beta_det_inv * det_inv - dstl

    def training_step(self, batch, batch_idx):
        return self(batch) 

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self.encoder(x).cpu().numpy(), y.cpu().numpy()

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=2e-3, weight_decay=1e-6)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

## Data
Downloads and prepare the necessary data.

In [4]:
from torchvision import transforms

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

pretrain_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=96, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

In [5]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader

class PretrainSTL10(STL10):
    def __init__(self, data_dir, pretrain_transforms, download=True):
        super().__init__(data_dir, download=download, split="unlabeled", transform=None)
        self.pretrain_transforms = pretrain_transforms
        
    def __getitem__(self, index):
        x,y = super().__getitem__(index)
        x1 = self.pretrain_transforms(x)
        x2 = self.pretrain_transforms(x)
        return x1, x2

data_pretrain = PretrainSTL10(data_dir, pretrain_transforms=pretrain_transforms, download=True)
data_train = STL10(data_dir,  split="train", transform=val_transforms, download=True)
data_test = STL10(data_dir, split="test", transform=val_transforms, download=True)

loader_pretrain = DataLoader(data_pretrain, batch_size=256, shuffle=True,
                               num_workers=8, pin_memory=True)
loader_train = DataLoader(data_train, batch_size=512, shuffle=False, num_workers=8, pin_memory=True, drop_last=False)
loader_test = DataLoader(data_train, batch_size=512, shuffle=False, num_workers=8, pin_memory=True, drop_last=False)

Files already downloaded and verified


## Train

In [6]:
EPOCHS=100 
dissl= DISSL(max_epochs=EPOCHS)
trainer = pl.Trainer(gpus=gpus, precision=precision, max_epochs=EPOCHS, logger=False, enable_checkpointing=False)
trainer.fit(dissl, train_dataloaders=loader_pretrain)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | ResNet     | 11.2 M
1 | projector | Sequential | 10.5 M
2 | predictor | Linear     | 8.4 M 
-----------------------------------------
30.1 M    Trainable params
0         Non-trainable params
30.1 M    Total params
60.183    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


## Evaluate

In [7]:
from pl_bolts.datamodules import SklearnDataModule

In [9]:
import numpy as np

Z_train, Y_train = zip(*trainer.predict(dataloaders=loader_train, model=dissl))
Z_test, Y_test = zip(*trainer.predict(dataloaders=loader_test, model=dissl))

Z_train = np.concatenate(Z_train, axis=0)
Y_train = np.concatenate(Y_train, axis=0)
Z_test = np.concatenate(Z_test, axis=0)
Y_test = np.concatenate(Y_test, axis=0)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 391it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 391it [00:00, ?it/s]

In [10]:
# Downstream evaluation. Accuracy: 95.44% 
from sklearn.svm import LinearSVC

best_acc = 0
for C in np.logspace(-3,0,base=10,num=7):
    clf = LinearSVC(C=C, dual=False)
    clf.fit(Z_train, Y_train)
    acc = clf.score(Z_test, Y_test)
    best_acc = max(best_acc, acc)
print(f"Downstream STL10 accuracy: {best_acc*100:.2f}%") 

Downstream STL10 accuracy: 95.44%
