In [None]:
!pip install pytorch-lightning einops torchmetrics lovely-tensors lightly wandb timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.0.2-py3-none-any.whl (719 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m719.0/719.0 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting lovely-tensors
  Downloading lovely_tensors-0.1.15-py3-none-any.whl (17 kB)
Collecting lightly
  Downloading lightly-1.4.3-py3-none-any.whl (650 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m650.6/650.6 kB[0m [31m30.1 MB/s[0m eta [36m0:00:

In [None]:
from einops.layers.torch import Rearrange
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.data.collate import imagenet_normalize
from lightly.loss import NTXentLoss
from lightly.models import ResNetGenerator
from lightly.models.modules import SimCLRProjectionHead
import lovely_tensors as lt
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from sklearn.linear_model import LogisticRegression
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN, StanfordCars
import torchvision.transforms as T
import timm
from tqdm.auto import tqdm
import wandb

from copy import deepcopy
from functools import partial
import inspect
import os

# Apply lovely_tensors' monkey patch (presumably swaps in its compact tensor repr — confirm in its docs).
lt.monkey_patch()
# Turn on cuDNN's deterministic flag for the whole session.
torch.backends.cudnn.deterministic = True

In [None]:
# Log in to Weights & Biases; per the captured output this prompts for an API key when none is stored.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Definitions

### MNIST and CIFAR10

#### Teachers

In [None]:
def create_mnist_cls_teacher():
  '''Build the MNIST classification teacher: a 784-1200-1200-10 MLP.

  The input image batch is flattened per sample; each hidden linear layer is
  followed by `Dropout(0.5)` and then `ReLU`. The last layer has 10 outputs.
  '''
  hidden = 1200
  layers = [Rearrange('b c h w -> b (c h w)')]
  # linear layers are instantiated in network order, first to last
  for fan_in in (784, hidden):
    layers += [nn.Linear(fan_in, hidden), nn.Dropout(0.5), nn.ReLU()]
  layers.append(nn.Linear(hidden, 10))
  return nn.Sequential(*layers)

def create_mnist_ae_teacher():
  '''Build the MNIST autoencoder teacher: a 784-128-64-128-784 MLP.

  Every linear layer except the last is followed by `Dropout(0.5)` and `ReLU`;
  the last one is followed by `Tanh` and a reshape back to `(b, 1, 28, 28)`.
  '''
  widths = (784, 128, 64, 128)
  layers = [Rearrange('b c h w -> b (c h w)')]
  # linear layers are instantiated in network order, first to last
  for fan_in, fan_out in zip(widths, widths[1:]):
    layers.extend((nn.Linear(fan_in, fan_out), nn.Dropout(0.5), nn.ReLU()))
  layers.extend((
      nn.Linear(widths[-1], 784),
      nn.Tanh(),
      Rearrange('b (c h w) -> b c h w', c=1, h=28, w=28),
  ))
  return nn.Sequential(*layers)

def create_cifar10_cls_teacher():
  '''Build the CIFAR10 classification teacher: lightly's `'resnet-18'` generator with its default arguments.'''
  teacher = ResNetGenerator('resnet-18')
  return teacher

def create_cifar10_simclr_teacher():
  '''Build the CIFAR10 SimCLR teacher backbone.

  Takes every child module of a `'resnet-18'` generator except the last one and
  appends global average pooling plus a flatten.
  '''
  resnet = ResNetGenerator('resnet-18')
  trunk = list(resnet.children())[:-1]  # drop the final child module
  return nn.Sequential(*trunk, nn.AdaptiveAvgPool2d(1), nn.Flatten())

#### Students

In [None]:
def create_mnist_cls_student():
  '''Build the MNIST classification student: a 784-32-32-10 MLP.

  Same layout as the classification teacher (flatten, then hidden linear layers
  each followed by `Dropout(0.5)` and `ReLU`) with 32 hidden units instead of 1200.
  '''
  hidden = 32
  layers = [Rearrange('b c h w -> b (c h w)')]
  # linear layers are instantiated in network order, first to last
  for fan_in in (784, hidden):
    layers += [nn.Linear(fan_in, hidden), nn.Dropout(0.5), nn.ReLU()]
  layers.append(nn.Linear(hidden, 10))
  return nn.Sequential(*layers)

def create_mnist_ae_student():
  '''Build the MNIST autoencoder student: a 784-64-32-64-784 MLP.

  Every linear layer except the last is followed by `Dropout(0.5)` and `ReLU`;
  the last one is followed by `Tanh` and a reshape back to `(b, 1, 28, 28)`.
  '''
  widths = (784, 64, 32, 64)
  layers = [Rearrange('b c h w -> b (c h w)')]
  # linear layers are instantiated in network order, first to last
  for fan_in, fan_out in zip(widths, widths[1:]):
    layers.extend((nn.Linear(fan_in, fan_out), nn.Dropout(0.5), nn.ReLU()))
  layers.extend((
      nn.Linear(widths[-1], 784),
      nn.Tanh(),
      Rearrange('b (c h w) -> b c h w', c=1, h=28, w=28),
  ))
  return nn.Sequential(*layers)

def create_cifar10_cls_student():
  '''Build the CIFAR10 classification student: lightly's `'resnet-9'` generator at `width=0.5`.'''
  student = ResNetGenerator('resnet-9', width=0.5)
  return student

def create_cifar10_simclr_student():
  '''Build the CIFAR10 SimCLR student backbone.

  Takes every child module of a `'resnet-9'` generator (`width=0.5`) except the
  last one and appends global average pooling plus a flatten.
  '''
  resnet = ResNetGenerator('resnet-9', width=0.5)
  trunk = list(resnet.children())[:-1]  # drop the final child module
  return nn.Sequential(*trunk, nn.AdaptiveAvgPool2d(1), nn.Flatten())

#### Lightning

In [None]:
class LitModel(pl.LightningModule):
  '''Lightning wrapper for a classifier trained with cross-entropy and SGD.

  Args:
    model: network mapping an input batch to class logits.
    num_classes: number of classes for the validation/test accuracy metrics.
    param_fn: callable that receives `model` and returns the parameters handed
      to the optimizer (all of them by default).
    lr: SGD learning rate.
    momentum: SGD momentum.
    lr_decay: `None` (no scheduler), `'linear'`, `'plateau'`, or a callable
      passed to `LambdaLR` as its `lr_lambda`.
  '''

  def __init__(self, model, num_classes, param_fn=lambda model : model.parameters(), lr=1e-3, momentum=0, lr_decay=None):
    super().__init__()
    self.model = model
    # one metric object per stage; `test_validation_step` looks them up by name
    self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    self.param_fn = param_fn
    self.lr, self.momentum, self.lr_decay = lr, momentum, lr_decay
    self.save_hyperparameters(ignore=['model'])

  def forward(self, x):
    return self.model(x)

  def step(self, batch, stage):
    '''Run one batch, log `{stage}/loss`, and return the loss, logits and labels in a dict.'''
    inputs, targets = batch
    logits = self(inputs)
    loss = F.cross_entropy(logits, targets)
    self.log(f'{stage}/loss', loss, prog_bar=True)
    return dict(loss=loss, logits=logits, y=targets)

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')['loss']

  def test_validation_step(self, batch, stage):
    '''Shared validation/test logic: loss plus accuracy logged under `{stage}/acc`.'''
    output = self.step(batch, stage)
    metric = getattr(self, f'{stage}_accuracy')
    metric.update(torch.argmax(output['logits'], dim=1), output['y'])
    self.log(f"{stage}/acc", metric, prog_bar=True)
    return output['loss']

  def validation_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'test')

  def predict_step(self, batch, batch_idx):
    # self.step() is not reused here because predict doesn't support logging
    inputs, targets = batch
    return dict(logits=self(inputs), y=targets)

  def configure_optimizers(self):
    '''Create the SGD optimizer and, depending on `lr_decay`, an LR scheduler.'''
    optimizer = optim.SGD(self.param_fn(self.model), lr=self.lr, momentum=self.momentum)
    decay = self.lr_decay
    if decay is None:
      return optimizer

    if decay == 'linear':
      # anneal the learning rate to zero over the trainer's epoch budget
      lr_scheduler = {
          'scheduler': optim.lr_scheduler.LinearLR(
              optimizer, start_factor=1, end_factor=0, total_iters=self.trainer.max_epochs
          )
      }
    elif decay == 'plateau':
      lr_scheduler = {
          'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2),
          'monitor': 'val/loss',
          'strict': True,
      }
    elif callable(decay):
      lr_scheduler = {'scheduler': optim.lr_scheduler.LambdaLR(optimizer, decay)}
    else:
      raise Exception(f'LR decay not supported for {self.lr_decay}')
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

In [None]:
class LitAE(pl.LightningModule):
  '''Lightning wrapper for an autoencoder trained on mean-squared reconstruction error.

  Args:
    model: network whose output is compared against its own input.
    lr: SGD learning rate.
    weight_decay: SGD weight decay.
  '''

  def __init__(self, model, lr=1e-3, weight_decay=0):
    super().__init__()
    self.model = model
    self.lr, self.weight_decay = lr, weight_decay
    self.save_hyperparameters(ignore=['model'])

  def forward(self, x):
    return self.model(x)

  def step(self, batch, stage):
    '''Reconstruct the batch inputs and log the MSE under `{stage}_loss`; the second batch item is ignored.'''
    images, _ = batch
    loss = F.mse_loss(self(images), images)
    self.log(f'{stage}_loss', loss, prog_bar=True)
    return loss

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')

  def validation_step(self, batch, batch_idx):
    return self.step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self.step(batch, 'test')

  def configure_optimizers(self):
    # SGD with a fixed momentum of 0.9 over every parameter of the module
    return optim.SGD(
        self.parameters(),
        lr=self.lr,
        momentum=0.9,
        weight_decay=self.weight_decay,
    )

In [None]:
class SimCLR(pl.LightningModule):
    '''SimCLR pre-training: a backbone plus projection head trained with the NT-Xent loss.

    Args:
        model: backbone network. It must be indexable so that
            `model[-3][-1].bn2.num_features` gives the embedding width
            (true for the `create_cifar10_simclr_*` backbones in this notebook).
        lr: SGD learning rate; `None` falls back to 0.06.
    '''

    def __init__(self, model, lr=None):
        super().__init__()
        self.backbone = model
        # embedding width = channels of the last batch-norm in the final residual stage
        dim = self.backbone[-3][-1].bn2.num_features
        self.projection_head = SimCLRProjectionHead(dim, dim, 128)
        self.criterion = NTXentLoss()
        self.lr = lr
        # Keep the backbone module out of the hyperparameters, like the other
        # Lightning modules here: its weights are already part of the state dict.
        # NOTE(review): `load_from_checkpoint` callers must therefore pass `model=`.
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        '''Embed `x` with the backbone and project it to the 128-d contrastive space.'''
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_index):
        # the batch carries two views of each image; the remaining two items are unused
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        return loss

    def configure_optimizers(self):
        # local name chosen so it does not shadow the imported `optim` module
        lr = 0.06 if self.lr is None else self.lr
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        return optimizer

In [None]:
class LitDistiller(pl.LightningModule):
  '''Distil a frozen teacher into a student by matching relative representations.

  Both networks embed the batch. Each embedding is L2-normalised and multiplied
  with L2-normalised anchor embeddings, giving one "relative representation"
  per sample for the teacher and one for the student. The distillation loss
  compares the two. An optional head on top of the student adds a
  cross-entropy term against the labels.

  Args:
    teacher: trained network; a frozen deep copy is kept.
    student: network to train.
    d_weight: weight of the distillation loss (0 disables it).
    d_type: distillation loss, one of `'cos'`, `'mse'`, `'sce'`.
    ce_weight: weight of the label cross-entropy loss (0 disables it).
    eps: constant added inside the log of the `'cos'` loss.
    dim: input width of the classification head.
    num_classes: output width of the classification head.
    non_linear_head: `True` for a `[Dropout(0.5)] -> ReLU -> Linear` head,
      `False` for a single `Linear`.
    dropout_head: whether the non-linear head starts with `Dropout(0.5)`.
    optim: `'sgd'` or `'adam'`.
    lr: learning rate.
    momentum: SGD momentum (not used by Adam).
    anchors: optional batch of inputs embedded by both networks to act as
      anchors; when `None` the current batch's embeddings are the anchors.

  Raises:
    Exception: on an unknown `d_type`, or when `dim`, `num_classes`,
      `ce_weight`, `non_linear_head` and `dropout_head` are only partly set.
  '''

  def __init__(
      self,
      teacher,
      student,
      d_weight=1,
      d_type='cos',
      ce_weight=0,
      eps=1e-8,
      dim=None,
      num_classes=0,
      non_linear_head=None,
      dropout_head=None,
      optim='sgd',
      lr=1e-1,
      momentum=0.9,
      anchors=None
  ):
    super().__init__()
    # private frozen copy, so the caller's teacher is left untouched
    self.teacher = deepcopy(teacher).eval()
    self.student = student
    for p in self.teacher.parameters():
      p.requires_grad=False

    self.d_weight = d_weight
    self.ce_weight = ce_weight
    self.eps = eps
    # buffer, so the anchors follow the module across devices
    self.register_buffer('anchors', anchors)
    self.lr = lr
    self.momentum = momentum
    self.optim = optim

    # every loss takes (student_rel, teacher_rel, cos) and uses what it needs
    if d_type == 'cos':
      self.d_loss_fn = lambda student_rel, teacher_rel, cos: -(cos/ 2 + 0.5 + self.eps).log().mean()
    elif d_type == 'mse':
      self.d_loss_fn = lambda student_rel, teacher_rel, cos: F.mse_loss(student_rel, teacher_rel)
    elif d_type == 'sce':
      self.d_loss_fn = lambda student_rel, teacher_rel, cos: F.cross_entropy(student_rel / 2, (teacher_rel / 2).softmax(dim=-1))
    else:
      raise Exception(f'`d_type` must be in `("cos", "mse", "sce")`, got {d_type}')

    if dim and num_classes and ce_weight and (non_linear_head is not None) and (dropout_head is not None):
      assert non_linear_head in (True, False)
      assert dropout_head in (True, False)
      self.head = nn.Linear(dim, num_classes)
      if non_linear_head:
        self.head = nn.Sequential(
            *([nn.Dropout(0.5)] if dropout_head else []),
            nn.ReLU(),
            nn.Linear(dim, num_classes)
        )
      self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
      self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    elif dim is None and num_classes == 0 and ce_weight == 0 and non_linear_head is None:
      self.head = None
    else:
      raise Exception('`dim`, `num_classes`, `ce_weight`, `non_linear_head`, and `dropout_head` must be either be all set or none set')

    self.save_hyperparameters(ignore=['teacher', 'student'])

  def train(self, mode=True):
    '''Switch train/eval mode as usual, but always keep the teacher in eval mode.

    Calling `.train()` on this module would otherwise flip the teacher back to
    train mode (active dropout, batch-norm statistics updates). The Lightning
    release this notebook installs does that after every validation loop, so
    `on_train_start` alone does not keep the teacher frozen.
    '''
    super().train(mode)
    self.teacher.eval()
    return self

  def forward(self, x):
    '''Return the teacher and student outputs for `x` in a dict.'''
    return {'teacher': self.teacher(x), 'student': self.student(x)}

  def step(self, batch, stage):
    '''Compute and log the losses for one batch.

    Returns a dict with `loss`, `features` (raw teacher/student outputs), `y`,
    and `student_logits` when the cross-entropy head is active.
    '''
    ret_dict = {}

    # get teacher and student embeddings
    x, y = batch
    features = self(x)
    teacher_abs = features['teacher']
    student_abs = features['student']

    # auxiliary loss can be cross entropy with labels
    # only works when distilling class probabilities
    ce_loss = 0
    if self.ce_weight:
      student_logits = self.head(student_abs)
      ce_loss = F.cross_entropy(student_logits, y)
      self.log(f'{stage}/ce_loss', ce_loss, prog_bar=True, sync_dist=True)
      ret_dict.update({'student_logits': student_logits})

    # anchors are either the embeddings themselves or predefined anchors
    teacher_anchors, student_anchors = teacher_abs, student_abs
    if self.anchors is not None:
      teacher_anchors = self.teacher(self.anchors)
      student_anchors = self.student(self.anchors)

    # normalize anchors
    teacher_anchors = teacher_anchors / teacher_anchors.norm(dim=-1, keepdim=True)
    student_anchors = student_anchors / student_anchors.norm(dim=-1, keepdim=True)

    # normalize absolute representations
    teacher_abs = teacher_abs / teacher_abs.norm(dim=-1, keepdim=True)
    student_abs = student_abs / student_abs.norm(dim=-1, keepdim=True)

    # compute relative representations (cosine similarity to every anchor)
    teacher_rel = teacher_abs.mm(teacher_anchors.T)
    student_rel = student_abs.mm(student_anchors.T)

    # actual distillation loss
    if self.d_weight:
      cos = F.cosine_similarity(student_rel, teacher_rel)
      d_loss = self.d_loss_fn(student_rel, teacher_rel, cos)
      self.log(f'{stage}/cos', cos.mean(), prog_bar=True, sync_dist=True)
      self.log(f'{stage}/d_loss', d_loss, prog_bar=True, sync_dist=True)
    else:
      d_loss = 0

    loss = self.d_weight * d_loss + self.ce_weight * ce_loss
    self.log(f'{stage}/loss', loss, sync_dist=True)

    ret_dict.update({'loss': loss, 'features': features, 'y': y})
    return ret_dict

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')['loss']

  def test_validation_step(self, batch, stage):
    '''Shared validation/test logic; logs `{stage}/acc` when a head is present.'''
    output = self.step(batch, stage)
    if self.head is not None:
      preds = torch.argmax(output['student_logits'], dim=1)
      accuracy = getattr(self, f'{stage}_accuracy')
      accuracy.update(preds, output['y'])
      # key follows the stage; test accuracy used to be reported as "val/acc"
      self.log(f"{stage}/acc", accuracy, prog_bar=True, sync_dist=True)
    return output['loss']

  def validation_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'test')

  def configure_optimizers(self):
    '''Build the optimizer named by `optim` over the student's parameters.

    Raises:
      ValueError: if `optim` is neither `'sgd'` nor `'adam'`.
    '''
    # NOTE(review): only the student's parameters are optimised, so `self.head`
    # keeps its initial weights — confirm this is intended.
    if self.optim == 'sgd':
      return optim.SGD(self.student.parameters(), lr=self.lr, momentum=self.momentum)
    if self.optim == 'adam':
      # Adam has no `momentum` argument; forwarding it raised a TypeError
      return optim.Adam(self.student.parameters(), lr=self.lr)
    raise ValueError(f'`optim` must be in `("sgd", "adam")`, got {self.optim}')

  def on_train_start(self):
    self.teacher.eval()
    print(f'd_weight: {self.d_weight}\tce_weight: {self.ce_weight}')
    print(f'anchors: {self.anchors}')
    print(f'head: {self.head}')
    print(f'optim: {self.optim}')

In [None]:
class LitSPDistiller(pl.LightningModule):
  '''Distil a frozen teacher into a student by matching batch similarity matrices.

  For each network the batch's pairwise dot-product matrix is built and
  row-normalised (L2); the distillation loss is the MSE between the teacher's
  and the student's matrix. An optional linear head on the student adds a
  cross-entropy term against the labels.

  Args:
    teacher: trained network; a frozen deep copy is kept.
    student: network to train.
    lr: SGD learning rate.
    momentum: SGD (Nesterov) momentum.
    d_weight: weight of the distillation loss.
    ce_weight: weight of the label cross-entropy loss (0 disables it).
    num_classes: output width of the classification head.
    dim: input width of the classification head.

  Raises:
    Exception: when `dim`, `num_classes` and `ce_weight` are only partly set.
  '''

  def __init__(
      self,
      teacher,
      student,
      lr=1e-3,
      momentum=0.9,
      d_weight=1,
      ce_weight=0,
      num_classes=0,
      dim=None
  ):
    super().__init__()

    # set up teacher and student; the teacher is a private frozen copy
    self.teacher = deepcopy(teacher).eval()
    self.student = student
    for p in self.teacher.parameters():
      p.requires_grad=False

    # optimization
    self.lr = lr
    self.momentum = momentum

    # loss components
    self.d_weight = d_weight
    self.ce_weight = ce_weight

    # accuracy if distilling with class labels
    if ce_weight and dim is not None and num_classes:
      self.head = nn.Linear(dim, num_classes)
      self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
      self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    elif ce_weight == 0 and dim is None and num_classes == 0:
      self.head = None
    else:
      raise Exception('`dim`, `num_classes`, and `ce_weight` must be either be all set or none set')
    self.save_hyperparameters(ignore=['teacher', 'student'])

  def train(self, mode=True):
    '''Switch train/eval mode as usual, but always keep the teacher in eval mode.

    Calling `.train()` on this module would otherwise flip the teacher back to
    train mode (active dropout, batch-norm statistics updates). The Lightning
    release this notebook installs does that after every validation loop, so
    `on_train_start` alone does not keep the teacher frozen.
    '''
    super().train(mode)
    self.teacher.eval()
    return self

  def forward(self, x):
    '''Return the teacher and student outputs for `x` in a dict.'''
    return {'teacher': self.teacher(x), 'student': self.student(x)}

  def step(self, batch, stage):
    '''Compute and log the losses for one batch.

    Returns a dict with `loss`, `y`, and `student_logits` when the
    cross-entropy head is active.
    '''
    ret_dict = {}

    # get teacher and student embeddings
    x, y = batch
    logits = self(x)
    teacher_abs = logits['teacher']
    student_abs = logits['student']

    ce_loss = 0
    if self.ce_weight > 0:
      student_logits = self.head(student_abs)
      ce_loss = F.cross_entropy(student_logits, y)
      self.log(f'{stage}/ce_loss', ce_loss, prog_bar=True)
      ret_dict.update({'student_logits': student_logits})

    # compute relative representations (batch x batch similarity matrices)
    teacher_rel = teacher_abs.mm(teacher_abs.T)
    student_rel = student_abs.mm(student_abs.T)

    # l2 normalize relative representations, row by row
    teacher_rel = teacher_rel / teacher_rel.norm(dim=1, keepdim=True)
    student_rel = student_rel / student_rel.norm(dim=1, keepdim=True)

    d_loss = F.mse_loss(teacher_rel, student_rel)
    self.log(f'{stage}/d_loss', d_loss, prog_bar=True)

    loss = self.d_weight * d_loss + self.ce_weight * ce_loss
    self.log(f'{stage}/loss', loss, prog_bar=True)

    ret_dict.update({'loss': loss, 'y': y})
    return ret_dict

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')['loss']

  def test_validation_step(self, batch, stage):
    '''Shared validation/test logic; logs `{stage}/acc` when a head is present.'''
    output = self.step(batch, stage)
    if self.head is not None:
      preds = torch.argmax(output['student_logits'], dim=1)
      accuracy = getattr(self, f'{stage}_accuracy')
      accuracy.update(preds, output['y'])
      self.log(f"{stage}/acc", accuracy, prog_bar=True)
    return output['loss']

  def validation_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'test')

  def configure_optimizers(self):
    # NOTE(review): only the student's parameters are optimised, so `self.head`
    # keeps its initial weights — confirm this is intended.
    optimizer = optim.SGD(self.student.parameters(), lr=self.lr, momentum=self.momentum, nesterov=True)
    return optimizer

  def on_train_start(self):
    self.teacher.eval()
    print(f'd_weight: {self.d_weight}\tce_weight: {self.ce_weight}')
    print(f'head: {self.head}')

In [None]:
class LitLPDistiller(pl.LightningModule):
  '''Distil a frozen teacher into a student with a locality-preserving loss.

  For every sample the teacher's `k` nearest neighbours in the batch are found
  from squared Euclidean distances and weighted by
  `exp(-dist / normalizing_constant ** 2)`. The distillation loss is the sum of
  those weights times the student's squared distances for the same pairs,
  divided by `k`. Optional extra terms: cross-entropy against the labels
  (`ce_weight`) and cross-entropy against the teacher head's softmax
  (`sce_weight`), both computed on a linear student head.

  Args:
    teacher: trained network; a frozen deep copy is kept.
    student: network to train.
    lr: SGD learning rate.
    momentum: SGD momentum.
    weight_decay: SGD weight decay.
    d_weight: weight of the locality-preserving loss (0 disables it).
    k: number of teacher neighbours per sample.
    normalizing_constant: bandwidth of the neighbour weights.
    ce_weight: weight of the label cross-entropy loss.
    num_classes: output width of the student head.
    dim: input width of the student head.
    sce_weight: weight of the soft cross-entropy against the teacher head.
    temp: stored, but not used by any loss in this class.
    teacher_head: module mapping teacher embeddings to class logits; frozen in
      place (not copied). Required exactly when `sce_weight` is non-zero.

  Raises:
    Exception: when the head-related arguments are only partly set. Note that
      the defaults (`ce_weight=1`, `sce_weight=2` with no `dim`/`num_classes`)
      fall in that case, so the head arguments must be passed explicitly.
  '''

  def __init__(
      self,
      teacher,
      student,
      lr,
      momentum=0.9,
      weight_decay=0,
      d_weight=1.5,
      k=5,
      normalizing_constant=1,
      ce_weight=1,
      num_classes=0,
      dim=None,
      sce_weight=2,
      temp=0.5,
      teacher_head=None
  ):
    super().__init__()
    # private frozen copy, so the caller's teacher is left untouched
    self.teacher = deepcopy(teacher).eval()
    self.student = student
    for p in self.teacher.parameters():
      p.requires_grad=False

    # optimization
    self.lr = lr
    self.weight_decay = weight_decay
    self.momentum = momentum

    # loss components
    self.d_weight = d_weight
    self.ce_weight = ce_weight
    self.sce_weight = sce_weight

    # lp distillation
    self.k = k
    self.normalizing_constant = normalizing_constant

    # accuracy if distilling with class labels
    if (ce_weight or sce_weight) and dim is not None and num_classes:
      self.student_head = nn.Linear(dim, num_classes)

      if sce_weight and teacher_head is not None:
        self.teacher_head = teacher_head
        for p in self.teacher_head.parameters():
          p.requires_grad=False
      elif sce_weight == 0 and teacher_head is None:
        # label-only training: define the attribute so later reads do not fail
        self.teacher_head = None
      else:
        raise Exception('`sce_weight` and `teacher_head` must be all set or none set')

      self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
      self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    elif ce_weight == 0 and sce_weight == 0 and dim is None and num_classes == 0 and teacher_head is None:
      self.student_head = None
      self.teacher_head = None
    else:
      raise Exception('`dim` and `num_classes` and either `ce_weight` or `sce_weight` must be either be all set or none set')

    # sce
    self.temp = temp

    self.save_hyperparameters(ignore=['teacher', 'student', 'teacher_head'])

  def train(self, mode=True):
    '''Switch train/eval mode as usual, but keep the teacher and its head in eval mode.

    Calling `.train()` on this module would otherwise flip them back to train
    mode (active dropout, batch-norm statistics updates). The Lightning release
    this notebook installs does that after every validation loop, so
    `on_train_start` alone does not keep them frozen.
    '''
    super().train(mode)
    self.teacher.eval()
    if self.teacher_head is not None:
      self.teacher_head.eval()
    return self

  def forward(self, x):
    '''Return the teacher and student outputs for `x` in a dict.'''
    return {'teacher': self.teacher(x), 'student': self.student(x)}

  def step(self, batch, stage):
    '''Compute and log the losses for one batch.

    Returns a dict with `loss`, `y`, and `student_logits` when a student head
    is present.
    '''
    ret_dict = {}

    # get teacher and student embeddings
    x, y = batch
    logits = self(x)
    teacher_abs = logits['teacher']
    student_abs = logits['student']

    if self.student_head is not None:
      student_logits = self.student_head(student_abs)
      ret_dict.update({'student_logits': student_logits})

    ce_loss = 0
    if self.ce_weight > 0:
      ce_loss = F.cross_entropy(student_logits, y)
      self.log(f'{stage}/ce_loss', ce_loss, prog_bar=True)

    sce_loss = 0
    if self.sce_weight > 0:
      teacher_logits = self.teacher_head(teacher_abs)
      sce_loss = F.cross_entropy(student_logits, teacher_logits.softmax(dim=1))
      self.log(f'{stage}/sce_loss', sce_loss, prog_bar=True)

    d_loss = 0
    if self.d_weight:
        # compute teacher_map: squared distances, sorted ascending per row
        teacher_rel = torch.cdist(teacher_abs, teacher_abs).pow(2)
        knn, knn_ids = teacher_rel.sort()
        # column 0 of the sort is skipped (the smallest distance in each row)
        knn, knn_ids = knn[:, 1:1+self.k], knn_ids[:, 1:1+self.k]
        knn = (-knn / self.normalizing_constant ** 2).exp()
        # sparse weight matrix: non-zero only at each row's k neighbours
        teacher_rel = torch.zeros_like(teacher_rel).scatter_(1, knn_ids, knn)

        # compute student map
        student_rel = torch.cdist(student_abs, student_abs).pow(2)

        d_loss = (teacher_rel * student_rel).sum() / self.k
        self.log(f'{stage}/d_loss', d_loss)

    loss = self.d_weight * d_loss + self.ce_weight * ce_loss + self.sce_weight * sce_loss
    self.log(f'{stage}/loss', loss, prog_bar=True)

    ret_dict.update({'loss': loss, 'y': y})
    return ret_dict

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')['loss']

  def test_validation_step(self, batch, stage):
    '''Shared validation/test logic; logs `{stage}/acc` when a student head is present.'''
    output = self.step(batch, stage)
    if self.student_head is not None:
      preds = torch.argmax(output['student_logits'], dim=1)
      accuracy = getattr(self, f'{stage}_accuracy')
      accuracy.update(preds, output['y'])
      self.log(f"{stage}/acc", accuracy, prog_bar=True)
    return output['loss']

  def validation_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    # return the loss like `validation_step` does
    return self.test_validation_step(batch, 'test')

  def configure_optimizers(self):
    # NOTE(review): only the student's parameters are optimised, so
    # `self.student_head` keeps its initial weights — confirm this is intended.
    optimizer = optim.SGD(self.student.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
    return optimizer

  def on_train_start(self):
    self.teacher.eval()
    if self.teacher_head is not None:
      self.teacher_head.eval()
    print(f'd_weight: {self.d_weight}\tce_weight: {self.ce_weight}\tsce_weight: {self.sce_weight}')
    print(f'student_head: {self.student_head}')
    print(f'teacher_head: {self.teacher_head}')

In [None]:
class LitStandardDistiller(pl.LightningModule):
  '''Logit distillation: the student is trained on the teacher's softened outputs.

  Both networks are expected to output class logits directly. The loss is
  `d_weight * soft_loss + ce_weight * cross_entropy(student_logits, labels)`,
  where `soft_loss` is the cross-entropy between the student's and the
  teacher's logits, both divided by `temp`, scaled by `temp ** 2`.

  Args:
    teacher: trained network; a frozen deep copy is kept.
    student: network to train.
    lr: SGD learning rate.
    temp: softmax temperature of the soft loss.
    d_weight: weight of the soft (distillation) loss.
    ce_weight: weight of the label cross-entropy loss.
    num_classes: number of classes for the accuracy metrics
      (NOTE(review): the default of 0 is presumably not a usable value — pass it explicitly).
    teacher_head: accepted for signature compatibility; not used by this class.
  '''

  def __init__(
      self,
      teacher,
      student,
      lr,
      temp=1,
      d_weight=1,
      ce_weight=1,
      num_classes=0,
      teacher_head=None
  ):
    super().__init__()
    # private frozen copy, so the caller's teacher is left untouched
    self.teacher = deepcopy(teacher).eval()
    self.student = student
    for p in self.teacher.parameters():
      p.requires_grad=False

    # optimization
    self.lr = lr

    # loss components
    self.d_weight = d_weight
    self.ce_weight = ce_weight

    # distillation
    self.temp = temp

    # accuracy
    self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    self.save_hyperparameters(ignore=['teacher', 'student', 'teacher_head'])

  def train(self, mode=True):
    '''Switch train/eval mode as usual, but always keep the teacher in eval mode.

    Calling `.train()` on this module would otherwise flip the teacher back to
    train mode (active dropout, batch-norm statistics updates). The Lightning
    release this notebook installs does that after every validation loop, so
    `on_train_start` alone does not keep the teacher frozen.
    '''
    super().train(mode)
    self.teacher.eval()
    return self

  def forward(self, x):
    '''Return the teacher and student logits for `x` in a dict.'''
    return {'teacher': self.teacher(x), 'student': self.student(x)}

  def step(self, batch, stage):
    '''Compute and log the losses for one batch; returns `loss`, `y` and `student_logits`.'''
    ret_dict = {}

    # get teacher and student logits
    x, y = batch
    logits = self(x)
    teacher_logits = logits['teacher']
    student_logits = logits['student']
    ret_dict.update({'student_logits': student_logits})

    ce_loss = 0
    if self.ce_weight > 0:
      ce_loss = F.cross_entropy(student_logits, y)
      self.log(f'{stage}/ce_loss', ce_loss, prog_bar=True)

    d_loss = 0
    # gated on `d_weight`: the previously referenced `sce_weight` does not exist here
    if self.d_weight > 0:
      soft_targets = (teacher_logits / self.temp).softmax(dim=1)
      # Multiply by temp**2 (Hinton et al.): soft-target gradients shrink as
      # 1/temp**2, so this keeps them on the scale of the hard-label term.
      d_loss = F.cross_entropy(student_logits / self.temp, soft_targets) * (self.temp ** 2)
      self.log(f'{stage}/d_loss', d_loss, prog_bar=True)

    loss = self.d_weight * d_loss + self.ce_weight * ce_loss
    self.log(f'{stage}/loss', loss, prog_bar=True)

    ret_dict.update({'loss': loss, 'y': y})
    return ret_dict

  def training_step(self, batch, batch_idx):
    return self.step(batch, 'train')['loss']

  def test_validation_step(self, batch, stage):
    '''Shared validation/test logic: loss plus accuracy logged under `{stage}/acc`.'''
    output = self.step(batch, stage)
    # the student emits logits itself, so accuracy is always available
    preds = torch.argmax(output['student_logits'], dim=1)
    accuracy = getattr(self, f'{stage}_accuracy')
    accuracy.update(preds, output['y'])
    self.log(f"{stage}/acc", accuracy, prog_bar=True)
    return output['loss']

  def validation_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self.test_validation_step(batch, 'test')

  def configure_optimizers(self):
    optimizer = optim.SGD(self.student.parameters(), lr=self.lr, momentum=0.9)
    return optimizer

  def on_train_start(self):
    self.teacher.eval()
    print(f'd_weight: {self.d_weight}\tce_weight: {self.ce_weight}')
    print(f'temp: {self.temp}')

#### Data

In [None]:
class GenericDataModule(pl.LightningDataModule):
  '''Data module for torchvision-style datasets with a 90/10 train/validation split.

  Args:
    ds_class: dataset class, called as `ds_class(data_dir, transform=..., **split_kwargs)`.
      Split selection uses `train=True/False` or `split='train'/'test'`,
      depending on the dataset.
    data_dir: root directory handed to the dataset.
    batch_size: batch size of every dataloader.
    timm_name: optional timm model name; when given, the train/test transforms
      are built from that model's pretrained data config instead of the defaults.

  Raises:
    Exception: if `ds_class` exposes neither a `train` nor a `split` argument
      and is not one of the explicitly listed datasets.
  '''

  def __init__(self, ds_class, data_dir='./', batch_size=512, timm_name=None):
    super().__init__()
    self.ds_class = ds_class
    self.data_dir = data_dir
    self.batch_size = batch_size
    if timm_name is not None:
      # the model is only instantiated to read its preprocessing config
      model = timm.create_model(timm_name, num_classes=0)
      data_config = timm.data.config.resolve_data_config(model.pretrained_cfg)
      self.train_transform = timm.data.create_transform(**data_config, is_training=True)
      self.test_transform = timm.data.create_transform(**data_config, is_training=False)
    else:
      self.train_transform = T.Compose(
          [
              T.ToTensor(),
              T.Normalize(0.5, 0.5),
              *([T.RandomResizedCrop(224)] if ds_class in (StanfordCars,) else []),
              *([T.RandomHorizontalFlip()] if ds_class in (CIFAR10, CIFAR100,) else [])
          ]
      )
      self.test_transform = T.Compose(
          [
              T.ToTensor(),
              T.Normalize(0.5, 0.5),
              *([T.CenterCrop(224)] if ds_class in (StanfordCars,) else [])
          ]
      )
    parameters = inspect.signature(self.ds_class).parameters
    # kaggle being kaggle doesn't make the parameters show up right
    if 'train' in parameters or ds_class in [MNIST, CIFAR10, CIFAR100]:
      self.train_split_kwargs = {'train': True}
      self.test_split_kwargs = {'train': False}
    elif 'split' in parameters or ds_class in [SVHN, StanfordCars]:
      self.train_split_kwargs = {'split': 'train'}
      self.test_split_kwargs = {'split': 'test'}
    else:
      raise Exception(f'Account for dataset {ds_class}')

  def prepare_data(self):
    '''Download both splits.'''
    self.ds_class(self.data_dir, download=True, **self.train_split_kwargs)
    self.ds_class(self.data_dir, download=True, **self.test_split_kwargs)

  def setup(self, stage=None):
    '''Create the datasets needed for `stage` (all of them when `stage` is `None`).

    - train: 90% of the train split, with the train transforms.
    - val: the remaining 10% of the train split, with the test transforms.
    - test: the full test split, with the test transforms.
    '''
    if stage == "fit" or stage is None:
      ds_full = self.ds_class(self.data_dir, transform=self.train_transform, **self.train_split_kwargs)
      self.ds_train, _ = random_split(ds_full, [0.9, 0.1], generator=torch.Generator().manual_seed(42))

    if stage in ("fit", "validate") or stage is None:
      # Same split and same seed as the training subset, so the indices are
      # exactly the 10% the training subset left out; only the transforms differ.
      # Previously this was carved out of the test split and overlapped `ds_test`.
      ds_full = self.ds_class(self.data_dir, transform=self.test_transform, **self.train_split_kwargs)
      _, self.ds_val = random_split(ds_full, [0.9, 0.1], generator=torch.Generator().manual_seed(42))

    if stage == "test" or stage is None:
      self.ds_test = self.ds_class(self.data_dir, transform=self.test_transform, **self.test_split_kwargs)

  def train_dataloader(self):
      # fixed generator seed so the shuffling order is the same on every run
      return DataLoader(
          self.ds_train,
          batch_size=self.batch_size,
          shuffle=True,
          generator=torch.Generator().manual_seed(12345678),
      )

  def val_dataloader(self):
      return DataLoader(self.ds_val, batch_size=self.batch_size, shuffle=False)

  def test_dataloader(self):
      return DataLoader(self.ds_test, batch_size=self.batch_size, shuffle=False)

In [None]:
class SimCLRCIFAR10DataModule(pl.LightningDataModule):
  '''Data module serving the CIFAR10 training split for SimCLR pre-training.

  Only a training dataloader is provided; the SimCLR collate function
  (`input_size=32`, Gaussian blur probability 0) is applied per batch.

  Args:
    data_dir: root directory handed to `CIFAR10`.
    batch_size: batch size of the training dataloader.
  '''

  def __init__(self, data_dir='./', batch_size=512):
    super().__init__()
    self.batch_size = batch_size
    self.data_dir = data_dir

  def prepare_data(self):
    CIFAR10(self.data_dir, train=True, download=True)

  def setup(self, stage=None):
    # wrap the torchvision dataset so lightly's collate function can consume it
    cifar_train = CIFAR10(self.data_dir, download=True, train=True)
    self.ssl_ds_train = LightlyDataset.from_torch_dataset(cifar_train)

  def train_dataloader(self):
    loader_kwargs = dict(
      batch_size=self.batch_size,
      collate_fn=SimCLRCollateFunction(input_size=32, gaussian_blur=0.),
      shuffle=True,
      drop_last=True,
      num_workers=2,
      # fixed generator seed so the shuffling order is the same on every run
      generator=torch.Generator().manual_seed(12345678),
    )
    return DataLoader(self.ssl_ds_train, **loader_kwargs)

In [None]:
# Keyword-argument presets (dataset class + batch size); the keys match
# `GenericDataModule.__init__`, so they are presumably passed as `dm_kwargs`.
mnist_64 = dict(ds_class=MNIST, batch_size=64)
mnist_1024 = dict(ds_class=MNIST, batch_size=1024)
cifar10_64 = dict(ds_class=CIFAR10, batch_size=64)
cifar10_1024 = dict(ds_class=CIFAR10, batch_size=1024)
cifar10_256 = dict(ds_class=CIFAR10, batch_size=256)

### Trainer

In [None]:
# Trainer factory with the notebook-wide defaults; keyword arguments given at
# call time are forwarded to `pl.Trainer` and override these.
_TRAINER_DEFAULTS = dict(accelerator='auto', max_epochs=10, deterministic=True)
create_trainer = partial(pl.Trainer, **_TRAINER_DEFAULTS)

In [None]:
def quick_train(
    model_init,
    lit_model_cls,
    dm_init,
    lit_model_kwargs={},
    dm_kwargs={},
    trainer_kwargs={},
    seed=12345678,
    project=None,
    name=None,
    id=None,
    log_model=False
):
  '''Seed everything, build data module + model, and fit with a fresh trainer.

  Args:
    model_init: zero-argument callable returning the network.
    lit_model_cls: Lightning module class, called as `lit_model_cls(model, **lit_model_kwargs)`.
    dm_init: data module factory, called as `dm_init(**dm_kwargs)`.
    lit_model_kwargs: extra keyword arguments for `lit_model_cls`.
    dm_kwargs: keyword arguments for `dm_init`.
    trainer_kwargs: keyword arguments for `create_trainer`; not modified.
    seed: seed passed to `pl.seed_everything` before anything is built.
    project: wandb project; logging to wandb needs both `project` and `name`.
    name: wandb run name.
    id: wandb run id to resume; with `log_model` set, training also restarts
      from that run's latest logged model checkpoint.
    log_model: forwarded to `WandbLogger(log_model=...)`.

  Returns:
    The fitted Lightning module.
  '''
  fit_kwargs = {}
  # Work on a copy: the dict is either the shared default or owned by the
  # caller, and writing the logger into it made it leak into later calls.
  trainer_kwargs = dict(trainer_kwargs)
  if project is not None and name is not None:
    resume_kwargs = {'id': id, 'resume': 'allow'} if id is not None else {}
    # close any run left open by a previous call
    wandb.finish()
    if id is not None and log_model:
        api = wandb.Api()
        artifact = api.artifact(f'patrickramosobf/{project}/model-{id}:latest')
        fit_kwargs['ckpt_path'] = artifact.download() + '/model.ckpt'
    wandb_logger = WandbLogger(
        project=project,
        name=name,
        log_model=log_model,
        # entity=,
        **resume_kwargs
    )
    trainer_kwargs['logger'] = wandb_logger
  # add the LR monitor to any caller-supplied callbacks instead of replacing them
  callbacks = trainer_kwargs.get('callbacks') or []
  if not isinstance(callbacks, (list, tuple)):
    callbacks = [callbacks]
  trainer_kwargs['callbacks'] = [*callbacks, LearningRateMonitor()]
  pl.seed_everything(seed)
  dm = dm_init(**dm_kwargs)
  model = model_init()
  lit_model = lit_model_cls(model, **lit_model_kwargs)
  trainer = create_trainer(**trainer_kwargs)
  trainer.fit(lit_model, dm, **fit_kwargs)
  return lit_model

def quick_distill(
    student_init,
    dm_init,
    dm_kwargs={},
    distill_cls=LitDistiller,
    distill_kwargs={},
    trainer_kwargs={},
    student_preprocess=lambda model: model,
    seed=12345678,
    project=None,
    name=None,
    id=None,
    log_model=False
):
  '''Seed everything, build data module + student, and fit a distiller.

  Args:
    student_init: zero-argument callable returning the student network.
    dm_init: data module factory, called as `dm_init(**dm_kwargs)`.
    dm_kwargs: keyword arguments for `dm_init`.
    distill_cls: distiller class, called as `distill_cls(student=..., **distill_kwargs)`.
    distill_kwargs: remaining distiller arguments (the teacher goes in here).
    trainer_kwargs: keyword arguments for `create_trainer`; not modified.
    student_preprocess: applied to the freshly built student before distillation.
    seed: seed passed to `pl.seed_everything` before anything is built.
    project: wandb project; logging to wandb needs both `project` and `name`.
    name: wandb run name.
    id: wandb run id to resume; with `log_model` set, training also restarts
      from that run's latest logged model checkpoint.
    log_model: forwarded to `WandbLogger(log_model=...)`.

  Returns:
    The fitted distiller module.
  '''
  fit_kwargs = {}
  # Work on a copy: the dict is either the shared default or owned by the
  # caller, and writing the logger into it made it leak into later calls.
  trainer_kwargs = dict(trainer_kwargs)
  if project is not None and name is not None:
    resume_kwargs = {'id': id, 'resume': 'allow'} if id is not None else {}
    # close any run left open by a previous call
    wandb.finish()
    if id is not None and log_model:
        api = wandb.Api()
        artifact = api.artifact(f'patrickramosobf/{project}/model-{id}:latest')
        fit_kwargs['ckpt_path'] = artifact.download() + '/model.ckpt'
    wandb_logger = WandbLogger(
        project=project,
        name=name,
        log_model=log_model,
        # entity=,
        **resume_kwargs
    )
    trainer_kwargs['logger'] = wandb_logger
  pl.seed_everything(seed)
  dm = dm_init(**dm_kwargs)
  distilled = student_preprocess(student_init())
  lit_distiller = distill_cls(
      student=distilled, **distill_kwargs
  )
  trainer = create_trainer(**trainer_kwargs)
  trainer.fit(lit_distiller, dm, **fit_kwargs)
  return lit_distiller

def quick_fc_probe(
    model,
    dm_init,
    lit_model_kwargs={},
    dm_kwargs={},
    model_preprocess=lambda model: model,
    seed=12345678,
):
  '''Linear-probe evaluation with a trainable fully-connected layer.

  A deep copy of `model` is passed through `model_preprocess`, extended with
  `ReLU -> Linear(out_features, 10)`, and fitted with only the new linear
  layer's parameters handed to the optimiser (via `param_fn`).

  Args:
    model: network to probe; it is copied, never modified.
    dm_init: callable building the data module, called as `dm_init(**dm_kwargs)`.
    lit_model_kwargs: extra keyword arguments for `LitModel`.
    dm_kwargs: keyword arguments for `dm_init`.
    model_preprocess: maps the copied model to the encoder to probe; the
      result must support `extend`, indexing and slicing.
    seed: passed to `pl.seed_everything`.

  Returns:
    The fitted `LitModel` wrapping encoder + probe head.
  '''
  pl.seed_everything(seed)
  datamodule = dm_init(**dm_kwargs)

  probe = model_preprocess(deepcopy(model))
  feature_dim = probe[-1].out_features
  probe.extend([nn.ReLU(), nn.Linear(feature_dim, 10)])
  probe.eval()

  # everything except the two appended layers is frozen
  for frozen_param in probe[:-2].parameters():
    frozen_param.requires_grad = False

  lit_probe = LitModel(
      probe,
      10,
      param_fn=lambda model: model[-1].parameters(),
      **lit_model_kwargs
  )
  create_trainer().fit(lit_probe, datamodule)
  return lit_probe

def quick_sk_probe(
    model,
    dm_init,
    dm_kwargs={},
    sk_kwargs={},
    model_preprocess=lambda model: model,
    seed=12345678,
    project=None,
    id=None
):
  '''Linear-probe evaluation with a logistic regression model.

  A deep copy of `model` (after `model_preprocess`) embeds the train and
  validation splits; a scikit-learn `LogisticRegression` is fitted on the train
  embeddings and scored on the validation embeddings.

  Args:
    model: encoder to probe; it is copied, never modified.
    dm_init: callable building the data module, called as `dm_init(**dm_kwargs)`.
    dm_kwargs: keyword arguments for `dm_init`.
    sk_kwargs: keyword arguments for `LogisticRegression`. They take precedence
      over the built-in `max_iter=5000`, `random_state=seed`, `verbose=True`.
    model_preprocess: maps the copied model to the encoder that is evaluated.
    seed: passed to `pl.seed_everything` and used as `random_state`.
    project: W&B project; validation accuracy is logged only when `project`
      and `id` are both given.
    id: W&B run id to attach to (`resume='allow'`).

  Returns:
    The fitted `LogisticRegression`.
  '''
  use_wandb = project is not None and id is not None
  if use_wandb:
    wandb.init(
        project=project,
        # entity=,
        id=id,
        resume='allow'
    )
  pl.seed_everything(seed)
  dm = dm_init(**dm_kwargs)
  encoder = model_preprocess(deepcopy(model))
  lit_encoder = LitModel(
      encoder, 10, param_fn=lambda model: model[-1].parameters()
  )
  trainer = create_trainer(strategy='dp' if not COLAB else 'auto')
  dm.prepare_data()
  dm.setup()

  def embed(dataloader):
    '''Run the encoder over `dataloader`; return `(embeddings, labels)`.'''
    outputs = trainer.predict(lit_encoder, dataloader)
    return (
        torch.cat([output['logits'] for output in outputs]),
        torch.cat([output['y'] for output in outputs]),
    )

  embs, labels = embed(dm.train_dataloader())
  # Merge instead of passing both explicit keywords and **sk_kwargs, so that a
  # caller overriding e.g. `max_iter` does not hit a duplicate-keyword TypeError.
  log_reg_kwargs = {
      'max_iter': 5000, 'random_state': seed, 'verbose': True, **sk_kwargs
  }
  log_reg = LogisticRegression(**log_reg_kwargs).fit(embs, labels)

  val_embs, val_labels = embed(dm.val_dataloader())
  val_acc = log_reg.score(val_embs, val_labels)
  print(val_acc)
  if use_wandb:
    wandb.log({'val/acc': val_acc})
  return log_reg

def quick_fn_probe(
    quick_fn,
    probe_fn,
    fn_kwargs={},
    probe_kwargs={},
    extract_model_fn=lambda lit_model: lit_model.model,
):
  '''Train with `quick_fn`, then fit a probe on the extracted model.

  Args:
    quick_fn: training function, called as `quick_fn(**fn_kwargs)`; returns a
      Lightning module.
    probe_fn: probing function, called as `probe_fn(model, **probe_kwargs)`.
    fn_kwargs: keyword arguments for `quick_fn`. When it holds truthy
      `'project'` and `'name'` entries, the probe also receives that project and
      the id of the current W&B run.
    probe_kwargs: keyword arguments for `probe_fn`. A copy is used, so the
      caller's dict is never modified.
    extract_model_fn: maps the trained Lightning module to the network that is
      handed to `probe_fn`.

  Returns:
    `(lit_model, probe)`: the output of `quick_fn` and the output of `probe_fn`.
  '''
  lit_model = quick_fn(**fn_kwargs)
  # Copy before adding the W&B keys: updating the argument in place would write
  # into the shared `{}` default, and a later call without a project would then
  # still carry the stale project / run id into `probe_fn`.
  probe_kwargs = dict(probe_kwargs)
  if fn_kwargs.get('project') and fn_kwargs.get('name'):
    probe_kwargs.update({'project': fn_kwargs['project'], 'id': wandb.run.id})
  lit_linear = probe_fn(extract_model_fn(lit_model), **probe_kwargs)
  return lit_model, lit_linear

def quick_sk_test(
    model,
    linear,
    dm_init,
    dm_kwargs={},
    model_preprocess=lambda model: model,
    project=None,
    id=None
):
  '''Score a fitted logistic-regression probe on the test split.

  Args:
    model: encoder; `model_preprocess(model)` is deep-copied before use.
    linear: fitted estimator exposing `score(features, labels)`.
    dm_init: callable building the data module, called as `dm_init(**dm_kwargs)`.
    dm_kwargs: keyword arguments for `dm_init`.
    model_preprocess: maps `model` to the encoder that is evaluated.
    project: W&B project; the accuracy is logged only when `project` and `id`
      are both given.
    id: W&B run id to attach to (`resume='allow'`).

  Returns:
    The test accuracy reported by `linear.score`.
  '''
  log_to_wandb = not (project is None or id is None)
  if log_to_wandb:
    wandb.init(project=project, id=id, resume='allow')

  datamodule = dm_init(**dm_kwargs)
  encoder_copy = deepcopy(model_preprocess(model))
  # 10 is only a placeholder here; the output dimension is not 10
  lit_encoder = LitModel(encoder_copy, 10)
  trainer = create_trainer()
  datamodule.prepare_data()
  datamodule.setup()

  features, targets = [], []
  for batch in trainer.predict(lit_encoder, datamodule.test_dataloader()):
    features.append(batch['logits'])
    targets.append(batch['y'])

  test_acc = linear.score(torch.cat(features), torch.cat(targets))
  print(test_acc)
  if log_to_wandb:
    wandb.log({'test/acc': test_acc})
  return test_acc

def quick_lit_test(
    lit_model,
    dm_init,
    dm_kwargs,
    model_preprocess=lambda model: model,
    project=None,
    id=None,
):
  '''Run `trainer.test` on a Lightning module and return its `test/acc`.

  Args:
    lit_model: Lightning module to evaluate.
    dm_init: callable building the data module, called as `dm_init(**dm_kwargs)`.
    dm_kwargs: keyword arguments for `dm_init`.
    model_preprocess: accepted for signature parity with the other helpers;
      it is not used by this function.
    project: W&B project; the accuracy is logged only when `project` and `id`
      are both given.
    id: W&B run id to attach to (`resume='allow'`).

  Returns:
    The `'test/acc'` entry of the first result dict from `trainer.test`.
  '''
  log_to_wandb = not (project is None or id is None)
  if log_to_wandb:
    wandb.init(project=project, id=id, resume='allow')

  datamodule = dm_init(**dm_kwargs)
  test_results = create_trainer().test(lit_model, datamodule)
  test_acc = test_results[0]['test/acc']

  if log_to_wandb:
    wandb.log({'test/acc': test_acc})
  print(test_acc)
  return test_acc

In [None]:
def extract_student(lit_model):
  '''Return the `student` attribute of a distillation Lightning module.'''
  return lit_model.student

In [None]:
def strip_resnet_cls_head(resnet):
  '''Drop the last child of `resnet` and append global average pooling + flatten.

  The remaining children are reused, not copied, so the returned
  `nn.Sequential` shares its weights with `resnet`.
  '''
  layers = [*resnet.children()][:-1]
  layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
  return nn.Sequential(*layers)

In [None]:
def experiment(
    seeds,
    teacher_kwargs,
    small_kwargs,
    distilled_kwargs,
    extract_teacher_fn,
    test_dm_init,
    test_dm_kwargs={},
    test_fn=None,
    project=None
):
  '''Run teacher / baseline / distilled-student trials, one per seed.

  Args:
    seeds: iterable of seeds; every seed trains all three arms once.
    teacher_kwargs: configuration of the teacher arm. With a `'quick_fn'` entry
      it is passed to `quick_fn_probe(**teacher_kwargs)` (train, then probe);
      otherwise it is passed to `quick_train(**teacher_kwargs, seed=seed)`.
    small_kwargs: same as `teacher_kwargs`, for the undistilled baseline.
    distilled_kwargs: `quick_fn_probe` configuration for the distilled student;
      must contain `'quick_fn'` and `'fn_kwargs'` with a `'distill_kwargs'` dict.
    extract_teacher_fn: maps the trained teacher Lightning module to the
      network used as distillation teacher.
    test_dm_init: builds the test data module, called as
      `test_dm_init(**test_dm_kwargs)`.
    test_dm_kwargs: keyword arguments for `test_dm_init`.
    test_fn: called as `test_fn(model, probe)` to score a logistic-regression
      probe; required whenever an arm is probed.
    project: W&B project written into every arm's config (`None` disables
      logging in the helpers).

  Returns:
    `(param_counts, teacher_results, small_results, distilled_results)`.
    `param_counts` holds the parameter counts of the distillation teacher, the
    preprocessed baseline and the distilled student from the last seed; it is
    empty when `seeds` is empty. The three result lists hold one test accuracy
    per seed.

  Raises:
    NotImplementedError: if a probe is not a `LogisticRegression`.

  Note:
    The three config dicts are updated in place (`project`, `name`, `seed`,
    `teacher`, `probe_kwargs`).
  '''
  def default_extract(lit_model):
    '''Default `extract_model_fn`: the wrapped network.'''
    return lit_model.model

  def seed_configs(kwargs, seed):
    '''Write `seed` into the training and probing configs of one arm.'''
    kwargs['fn_kwargs'].update({'seed': seed})
    if kwargs.get('probe_kwargs') is None:
      kwargs['probe_kwargs'] = {'seed': seed}
    else:
      kwargs['probe_kwargs'].update({'seed': seed})

  def test_probe(model, linear):
    '''Score a logistic-regression probe with `test_fn`.'''
    if not isinstance(linear, LogisticRegression):
      # Previously this case fell through and the previous arm's accuracy was
      # recorded again; fail loudly instead of reporting a wrong number.
      raise NotImplementedError(
          f'cannot test a probe of type {type(linear).__name__}; '
          'only LogisticRegression probes are supported'
      )
    return test_fn(model, linear)

  def train_and_test(kwargs, seed):
    '''Train one non-distilled arm; return `(lit_model, test_accuracy)`.'''
    if kwargs.get('quick_fn') is not None:
      # ssl: train and probe
      seed_configs(kwargs, seed)
      lit_model, linear = quick_fn_probe(**kwargs)
      extract = kwargs.get('extract_model_fn', default_extract)
      return lit_model, test_probe(extract(lit_model), linear)
    # sl: just train
    lit_model = quick_train(**kwargs, seed=seed)
    return lit_model, trainer.test(lit_model, dm)[0]['test/acc']

  # initialize empty results
  teacher_results = []
  small_results = []
  distilled_results = []

  # dm and trainer for testing the supervised arms
  dm = test_dm_init(**test_dm_kwargs)
  dm.prepare_data()
  dm.setup()
  trainer = create_trainer()

  # set wandb details
  # if project is `None`, no logging anyway, so safe to input experiment
  for kwargs, name in zip(
      (teacher_kwargs, small_kwargs, distilled_kwargs),
      ('teacher', 'baseline', 'student')
  ):
    target = kwargs['fn_kwargs'] if 'fn_kwargs' in kwargs else kwargs
    target.update({'project': project, 'name': name})

  lit_small = None
  lit_distilled = None

  # one trial for each seed
  for seed in seeds:
    # teacher
    lit_teacher, results = train_and_test(teacher_kwargs, seed)
    teacher_results.append(results)

    # student (no distillation)
    lit_small, results = train_and_test(small_kwargs, seed)
    small_results.append(results)

    # student (with distillation): always train and probe
    seed_configs(distilled_kwargs, seed)
    distilled_kwargs['fn_kwargs']['distill_kwargs'].update(
        {'teacher': extract_teacher_fn(lit_teacher)}
    )
    lit_distilled, distilled_linear = quick_fn_probe(**distilled_kwargs)
    distilled_results.append(test_probe(lit_distilled.student, distilled_linear))

  if lit_distilled is None:
    # no seed was run, so there is no model to count parameters of
    return [], teacher_results, small_results, distilled_results

  # `student_preprocess` may come from the call-time `fn_kwargs` or from the
  # keywords frozen into a `functools.partial`; call-time keywords win, which
  # matches what the partial actually does when it is invoked.
  partial_kwargs = getattr(distilled_kwargs['quick_fn'], 'keywords', None) or {}
  student_preprocess = distilled_kwargs['fn_kwargs'].get(
      'student_preprocess',
      partial_kwargs.get('student_preprocess', lambda model: model)
  )
  small_model = small_kwargs.get('extract_model_fn', default_extract)(lit_small)
  param_counts = [
      param_count(model)
      for model
      in (
          lit_distilled.teacher,
          student_preprocess(small_model),
          lit_distilled.student
      )
  ]

  return param_counts, teacher_results, small_results, distilled_results

### Utils

In [None]:
def same_params(model_0, model_1):
  '''Return True iff both models hold identical parameters, pairwise in order.

  Models with a different number of parameter tensors, or with tensors of
  different shapes, are reported as different. Comparing with a bare `zip`
  would stop at the shorter model and could report a match on a shared prefix.
  '''
  params_0 = list(model_0.parameters())
  params_1 = list(model_1.parameters())
  if len(params_0) != len(params_1):
    return False
  return all(torch.equal(p_0, p_1) for p_0, p_1 in zip(params_0, params_1))

In [None]:
def param_count(model):
  '''Total number of scalar entries across all parameters of `model`.'''
  total = 0
  for parameter in model.parameters():
    total += parameter.numel()
  return total

In [None]:
@torch.no_grad()
def extract_embeddings(models, dataloader):
  '''Embed `dataloader` with one or several models.

  Args:
    models: a single `nn.Module` or an iterable of them.
    dataloader: passed unchanged to `trainer.predict` for every model.

  Returns:
    A list with one tensor per model: the concatenated `'logits'` entries of
    that model's prediction outputs.
  '''
  model_list = [models] if isinstance(models, nn.Module) else models
  trainer = create_trainer()

  def embed(model):
    # ignore `num_classes`
    outputs = trainer.predict(LitModel(model, num_classes=10), dataloader)
    return torch.cat([output['logits'] for output in outputs])

  return [embed(model) for model in model_list]

In [None]:
def compare_cos_sim(model_0, model_1, dataloader):
  '''Mean row-wise cosine similarity between two models' similarity matrices.

  Each model's embeddings are L2-normalised and multiplied with their own
  transpose, giving one sample-by-sample cosine-similarity matrix per model.
  Corresponding rows of the two matrices are then compared with cosine
  similarity and the mean over rows is returned as a Python float.
  '''
  def self_similarity(embeddings):
    unit = embeddings / embeddings.norm(dim=1, keepdim=True)
    return unit.mm(unit.T)

  emb_0, emb_1 = extract_embeddings([model_0, model_1], dataloader)
  sim_0 = self_similarity(emb_0)
  sim_1 = self_similarity(emb_1)
  return F.cosine_similarity(sim_0, sim_1).mean().item()

## Experiments

### Self-Supervised

#### MNIST

In [None]:
# Pre-configured helpers for the MNIST self-supervised (autoencoder) experiments.
# `experiment` reads `.keywords` from the distill partial, so these stay
# `functools.partial` objects.

# Train a `LitAE` on MNIST (batch size 128) for 20 epochs.
mnist_ssl_quick_train = partial(
    quick_train,
    lit_model_cls=LitAE,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 20}
)

# Distill into the student autoencoder; only its first five sub-modules
# (`model[:5]`) are used as the student.
mnist_ssl_quick_distill = partial(
    quick_distill,
    student_init=create_mnist_ae_student,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    student_preprocess=lambda model: model[:5],
    trainer_kwargs={'max_epochs': 20}
)

# Logistic-regression probe on the `model[:5]` embeddings, using the `mnist_1024`
# data-module settings (defined elsewhere in the notebook).
mnist_ssl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=mnist_1024,
    model_preprocess=lambda model: model[:5],
)

# Test-split scoring of a fitted probe with the same preprocessing as above.
mnist_ssl_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=mnist_1024,
    model_preprocess=lambda model: model[:5]
)

In [None]:
# Train and probe the MNIST autoencoder teacher without W&B logging.
lit_teacher, teacher_linear = quick_fn_probe(
    quick_fn=mnist_ssl_quick_train,
    probe_fn=mnist_ssl_quick_sk_probe,
    fn_kwargs={
        'model_init': create_mnist_ae_teacher,
        'lit_model_kwargs': {'lr': 1e-1},
    },
)

In [None]:
# Train and probe the MNIST autoencoder teacher, logging the run and its
# checkpoint to W&B.
lit_teacher, teacher_linear = quick_fn_probe(
    quick_fn=mnist_ssl_quick_train,
    probe_fn=mnist_ssl_quick_sk_probe,
    fn_kwargs={
        'model_init': create_mnist_ae_teacher,
        'lit_model_kwargs': {'lr': 1e-1},
        'project': 'rel-rep-dist-mnist-ssl-viz',
        'name': 'teacher',
        'log_model': True,
    },
)

In [None]:
# Train and probe the undistilled student autoencoder (the baseline).
lit_student, student_linear = quick_fn_probe(
    quick_fn=mnist_ssl_quick_train,
    probe_fn=mnist_ssl_quick_sk_probe,
    fn_kwargs={
        'model_init': create_mnist_ae_student,
        'lit_model_kwargs': {'lr': 1e-8},
        'project': 'rel-rep-dist-mnist-ssl-viz',
        'name': 'baseline',
        'log_model': True,
    },
)

In [None]:
# Distill the student from the first five sub-modules of the trained teacher,
# then probe the distilled student.
lit_distilled, distilled_linear = quick_fn_probe(
    quick_fn=mnist_ssl_quick_distill,
    probe_fn=mnist_ssl_quick_sk_probe,
    fn_kwargs={
        'distill_kwargs': {'teacher': lit_teacher.model[:5], 'lr': 1e-1},
        'project': 'rel-rep-dist-mnist-ssl-viz',
        'name': 'student',
        'log_model': True,
    },
    extract_model_fn=extract_student,
)

In [None]:
# Three-seed MNIST self-supervised experiment: teacher autoencoder, undistilled
# baseline and distilled student, each probed with logistic regression.
# NOTE(review): the baseline uses lr=1e-8 while teacher and distillation use
# lr=1e-1 -- confirm this is intentional and not a leftover.
param_counts, teacher_results, small_results, distilled_results = experiment(
    seeds=[42, 43, 44],
    teacher_kwargs={
        'quick_fn': mnist_ssl_quick_train,
        'probe_fn': mnist_ssl_quick_sk_probe,
        'fn_kwargs': {'model_init': create_mnist_ae_teacher, 'lit_model_kwargs': {'lr': 1e-1}}
    },
    small_kwargs={
        'quick_fn': mnist_ssl_quick_train,
        'probe_fn': mnist_ssl_quick_sk_probe,
        'fn_kwargs': {'model_init': create_mnist_ae_student, 'lit_model_kwargs': {'lr': 1e-8}}
    },
    distilled_kwargs={
        'quick_fn': mnist_ssl_quick_distill,
        'probe_fn': mnist_ssl_quick_sk_probe,
        # `experiment` adds the 'teacher' entry to 'distill_kwargs' for every seed
        'fn_kwargs': {'distill_kwargs': {'lr': 1e-1}},
        'extract_model_fn': extract_student
    },
    extract_teacher_fn=lambda lit_teacher: lit_teacher.model[:5],
    test_dm_init=GenericDataModule,
    test_dm_kwargs=mnist_1024,
    test_fn=mnist_ssl_sk_test,
    project='rel-rep-dist-mnist-ssl'
)

In [None]:
# param count vs test acc (mean over seeds)
# `np.mean(results)` replaces the redundant `np.mean([results])` so this cell
# matches the supervised plotting cells; the value is the same.
accs = [
    np.mean(results)
    for results in (teacher_results, small_results, distilled_results)
]

labels = ['AE-64', 'AE-32 (no distillation)', 'AE-32 (with distillation)']
for count, acc, label in zip(param_counts, accs, labels):
  plt.scatter(count, acc, label=label)
plt.xlabel('parameter count')
plt.ylabel('test accuracy')
plt.legend()
plt.title('MNIST')
plt.show()

#### CIFAR10

In [None]:
# Pre-configured helpers for the CIFAR10 self-supervised (SimCLR) experiments.

# Train a `SimCLR` module with `SimCLRCIFAR10DataModule` (batch size 256) for 10 epochs.
cifar10_ssl_quick_train = partial(
    quick_train,
    lit_model_cls=SimCLR,
    dm_init=SimCLRCIFAR10DataModule,
    dm_kwargs={'batch_size': 256},
    trainer_kwargs={'max_epochs': 10}
)

# Distill into the SimCLR student for 10 epochs, using the `cifar10_64`
# data-module settings (defined elsewhere in the notebook).
cifar10_ssl_quick_distill = partial(
    quick_distill,
    student_init=create_cifar10_simclr_student,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_64,
    trainer_kwargs={'max_epochs': 10}
)

# Logistic-regression probe with the `cifar10_1024` data-module settings.
cifar10_ssl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_1024,
)

# Test-split scoring of a fitted probe with the same data-module settings.
cifar10_ssl_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_1024,
)

In [None]:
# Three-seed CIFAR10 self-supervised experiment (SimCLR teacher, undistilled
# baseline, distilled student), each probed with logistic regression.
param_counts, teacher_results, small_results, distilled_results = experiment(
    seeds=[42, 43, 44],
    teacher_kwargs={
        'quick_fn': cifar10_ssl_quick_train,
        'probe_fn': cifar10_ssl_quick_sk_probe,
        'fn_kwargs': {'model_init': create_cifar10_simclr_teacher, 'lit_model_kwargs': {'lr': 1e-1}, 'log_model': True},
        'extract_model_fn': lambda lit_model: lit_model.backbone
    },
    small_kwargs={
        'quick_fn': cifar10_ssl_quick_train,
        'probe_fn': cifar10_ssl_quick_sk_probe,
        # NOTE(review): this 'trainer_kwargs' replaces the partial's
        # {'max_epochs': 10}, so the baseline is limited to max_steps=10 --
        # confirm this is not a debugging leftover.
        'fn_kwargs': {'model_init': create_cifar10_simclr_student, 'lit_model_kwargs': {'lr': 1e-1}, 'trainer_kwargs': {'max_steps': 10}},
        'extract_model_fn': lambda lit_model: lit_model.backbone
    },
    distilled_kwargs={
        'quick_fn': cifar10_ssl_quick_distill,
        'probe_fn': cifar10_ssl_quick_sk_probe,
        # `experiment` adds the 'teacher' entry to 'distill_kwargs' for every seed
        'fn_kwargs': {'distill_kwargs': {'lr': 1e-1}, 'log_model': True},
        'extract_model_fn': extract_student
    },
    extract_teacher_fn=lambda lit_teacher: lit_teacher.backbone,
    test_dm_init=GenericDataModule,
    test_dm_kwargs=cifar10_1024,
    test_fn=cifar10_ssl_sk_test,
    project='rel-rep-dist-cifar-ssl-tune'
)

In [None]:
# param count vs test acc (mean over seeds)
# `np.mean(results)` replaces the redundant `np.mean([results])` so this cell
# matches the supervised plotting cells; the value is the same.
accs = [
    np.mean(results)
    for results in (teacher_results, small_results, distilled_results)
]

labels = ['ResNet-18', 'ResNet-9×0.5 (no distillation)', 'ResNet-9×0.5 (with distillation)']
for count, acc, label in zip(param_counts, accs, labels):
  plt.scatter(count, acc, label=label)
plt.xlabel('parameter count')
plt.ylabel('test accuracy')
plt.legend()
plt.title('CIFAR10')
plt.show()

### Supervised

#### MNIST

In [None]:
# Pre-configured helpers for the MNIST supervised experiments.
# `experiment` reads `.keywords` from these partials, so they stay
# `functools.partial` objects.

# Train a `LitModel` classifier on MNIST (batch size 128) for 20 epochs.
mnist_sl_quick_train = partial(
    quick_train,
    lit_model_cls=LitModel,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 20}
)

# Distill into the student classifier; only its first five sub-modules
# (`model[:5]`) are used as the student.
mnist_sl_quick_distill = partial(
    quick_distill,
    student_init=create_mnist_cls_student,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 20},
    student_preprocess=lambda model: model[:5],
)

# Logistic-regression probe on the `model[:5]` embeddings.
mnist_sl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=mnist_1024,
    model_preprocess=lambda model: model[:5],
)

# Trainable fully-connected probe on the `model[:5]` embeddings.
mnist_sl_quick_fc_probe = partial(
    quick_fc_probe,
    dm_init=GenericDataModule,
    dm_kwargs=mnist_1024,
    model_preprocess=lambda model: model[:5],
)

# Test-split scoring of a fitted logistic-regression probe.
mnist_sl_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=mnist_1024,
    model_preprocess=lambda model: model[:5]
)

In [None]:
# Three-seed MNIST supervised experiment. Teacher and baseline have no
# 'quick_fn' entry, so `experiment` trains them with `quick_train` and tests
# them with `trainer.test`; the distilled student is probed instead.
param_counts, teacher_results, small_results, distilled_results = experiment(
    seeds=[42, 43, 44],
    teacher_kwargs={
        'model_init': create_mnist_cls_teacher,
        'lit_model_kwargs': {'num_classes': 10, 'lr': 1e-1},
        # reuse the settings frozen into the training partial
        **mnist_sl_quick_train.keywords
    },
    small_kwargs={
        'model_init': create_mnist_cls_student,
        'lit_model_kwargs': {'num_classes': 10, 'lr': 1e-1},
        **mnist_sl_quick_train.keywords
    },
    distilled_kwargs={
        'quick_fn': mnist_sl_quick_distill,
        'probe_fn': mnist_sl_quick_sk_probe,
        'fn_kwargs': {
          # `experiment` adds the 'teacher' entry for every seed
          'distill_kwargs': {
              'lr': 1e-1,
              'd_weight': 1,
              'dim': 32,
              'num_classes': 10,
              'ce_weight': 1,
              'non_linear_head': True,
              'dropout_head': True
          }
        },
        'extract_model_fn': extract_student
    },
    extract_teacher_fn=lambda lit_teacher: lit_teacher.model[:5],
    test_dm_init=GenericDataModule,
    test_dm_kwargs=mnist_1024,
    test_fn=mnist_sl_sk_test,
    project='rel-rep-dist-mnist-sl'
)

In [None]:
# Mean test accuracy (over seeds) against parameter count.
accs = [
    np.mean(results)
    for results in (teacher_results, small_results, distilled_results)
]
labels = ['MLP-64', 'MLP-32 (no distillation)', 'MLP-32 (with distillation)']

# both student points are plotted at the baseline's parameter count (index 1)
for count, acc, label in zip(
    (param_counts[0], param_counts[1], param_counts[1]), accs, labels
):
  plt.scatter(count, acc, label=label)

plt.title('MNIST')
plt.xlabel('parameter count')
plt.ylabel('test accuracy')
plt.legend()
plt.show()

#### CIFAR10

In [None]:
# Pre-configured helpers for the CIFAR10 supervised experiments.
# `experiment` reads `.keywords` from these partials, so they stay
# `functools.partial` objects.

# Train a `LitModel` classifier on CIFAR10 (batch size 128) for 10 epochs.
cifar10_sl_quick_train = partial(
    quick_train,
    lit_model_cls=LitModel,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 10}
)

# Distill into the student classifier after `strip_resnet_cls_head` has replaced
# its last child with pooling + flatten.
cifar10_sl_quick_distill = partial(
    quick_distill,
    student_init=create_cifar10_cls_student,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 10},
    student_preprocess=strip_resnet_cls_head
)

# Logistic-regression probe with the `cifar10_1024` data-module settings.
cifar10_sl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_1024,
)

# Trainable fully-connected probe with the same data-module settings.
cifar10_sl_quick_fc_probe = partial(
    quick_fc_probe,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_1024,
)

# Test-split scoring of a fitted logistic-regression probe.
cifar10_sl_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=cifar10_1024,
)

In [None]:
# Three-seed CIFAR10 supervised experiment. Teacher and baseline have no
# 'quick_fn' entry, so `experiment` trains them with `quick_train` and tests
# them with `trainer.test`; the distilled student is probed instead.
param_counts, teacher_results, small_results, distilled_results = experiment(
    seeds=[42, 43, 44],
    teacher_kwargs={
        'model_init': create_cifar10_cls_teacher,
        'lit_model_kwargs': {'lr': 1e-1, 'num_classes': 10},
        # reuse the settings frozen into the training partial
        **cifar10_sl_quick_train.keywords
    },
    small_kwargs={
        'model_init': create_cifar10_cls_student,
        'lit_model_kwargs': {'lr': 1e-1, 'num_classes': 10},
        **cifar10_sl_quick_train.keywords
    },
    distilled_kwargs={
        'quick_fn': cifar10_sl_quick_distill,
        'probe_fn': cifar10_sl_quick_sk_probe,
        'fn_kwargs': {
            # `experiment` adds the 'teacher' entry for every seed
            'distill_kwargs': {
                'lr': 1e-1,
                'd_weight': 1,
                'dim': 256,
                'num_classes': 10,
                'ce_weight': 1,
                'non_linear_head': False,
                'dropout_head': False
            }
        },
        'extract_model_fn': extract_student
    },
    # same transformation as `strip_resnet_cls_head`, written inline
    extract_teacher_fn=lambda lit_teacher: nn.Sequential(*list(lit_teacher.model.children())[:-1], nn.AdaptiveAvgPool2d(1), nn.Flatten()),
    test_dm_init=GenericDataModule,
    test_dm_kwargs=cifar10_1024,
    test_fn=cifar10_sl_sk_test,
    project='rel-rep-dist-cifar-sl'
)

In [None]:
# Mean test accuracy (over seeds) against parameter count.
accs = [
    np.mean(results)
    for results in (teacher_results, small_results, distilled_results)
]
labels = ['ResNet-18', 'ResNet-9×0.5 (no distillation)', 'ResNet-9×0.5 (with distillation)']

for count, acc, label in zip(param_counts, accs, labels):
  plt.scatter(count, acc, label=label)

plt.title('CIFAR10')
plt.xlabel('parameter count')
plt.ylabel('test accuracy')
plt.legend()
plt.show()

### Anchor selection

In [None]:
# Module-level MNIST data module (batch size 128), downloaded and set up.
# NOTE(review): `anchor_experiment` builds its own local `dm`, so this global
# looks unused by the cells shown here -- confirm before removing.
dm = GenericDataModule(MNIST, batch_size=128)
dm.prepare_data()
dm.setup()

In [None]:
# Training / distillation helpers for the anchor-selection experiment. They use
# the same settings as `mnist_ssl_quick_train` / `mnist_ssl_quick_distill`.

# Train a `LitAE` on MNIST (batch size 128) for 20 epochs.
anc_sel_quick_train = partial(
    quick_train,
    lit_model_cls=LitAE,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    trainer_kwargs={'max_epochs': 20}
)

# Distill into the student autoencoder; only `model[:5]` is used as the student.
anc_sel_quick_distill = partial(
    quick_distill,
    student_init=create_mnist_ae_student,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': MNIST, 'batch_size': 128},
    student_preprocess=lambda model: model[:5],
    trainer_kwargs={'max_epochs': 20}
)

In [None]:
def anchor_experiment(seeds, project=None):
  '''Compare anchor-selection strategies for distilling the MNIST autoencoder.

  For every seed a teacher is trained and probed, then four students are
  distilled from its first five sub-modules. Each student is probed with
  logistic regression and scored on the test split:

    * in-batch: no `anchors` argument is given to the distiller
    * random: 128 training images drawn with `random_split`
    * random-per-class: the first 13 images of each digit after a seeded shuffle
    * best-per-class: per digit, the 13 correctly classified training images
      to which the teacher's probe assigns the highest class probability

  Args:
    seeds: iterable of integer seeds; one full trial per seed.
    project: W&B project name, or `None` to disable logging.

  Returns:
    Four lists of test accuracies (in-batch, random, random-per-class,
    best-per-class), each with one entry per seed.
  '''
  n_random_anchors = 128
  n_anchors_per_class = 13

  distilled_results = []
  distilled_rand_results = []
  distilled_per_class_results = []
  distilled_best_per_class_results = []

  dm = GenericDataModule(MNIST, batch_size=128)
  dm.prepare_data()
  dm.setup()

  def to_model_input(images):
    '''Rescale 0..255 pixel values to [-1, 1] and add a channel axis.'''
    return ((images / 255 - 0.5) / 0.5).unsqueeze(1)

  def distill_probe_test(teacher, seed, name, anchors=None):
    '''Distill one student, probe it, and return its test accuracy.'''
    distill_kwargs = {'teacher': teacher, 'lr': 1e-1}
    if anchors is not None:
      distill_kwargs['anchors'] = anchors
    lit_distilled, linear = quick_fn_probe(
        anc_sel_quick_distill,
        mnist_ssl_quick_sk_probe,
        {'distill_kwargs': distill_kwargs, 'seed': seed, 'project': project, 'name': name},
        {'seed': seed},
        extract_model_fn=lambda lit_model: lit_model.student
    )
    # attach the test metric to the run that was just used for training
    run_kwargs = {'id': wandb.run.id} if project is not None else {}
    return mnist_ssl_sk_test(
        lit_distilled.student, linear, project=project, **run_kwargs
    )

  for seed in seeds:
    # teacher
    lit_teacher, teacher_linear = quick_fn_probe(
        anc_sel_quick_train,
        mnist_ssl_quick_sk_probe,
        {'model_init': create_mnist_ae_teacher, 'lit_model_kwargs': {'lr': 1e-1}, 'seed': seed, 'project': project, 'name': 'teacher'},
        {'seed': seed}
    )

    # in batch
    distilled_results.append(
        distill_probe_test(lit_teacher.model[:5], seed, 'in-batch')
    )

    # rand
    anchors, _ = random_split(
        dm.ds_train,
        [n_random_anchors, len(dm.ds_train) - n_random_anchors],
        generator=torch.Generator().manual_seed(seed)
    )
    # iteration, yeah, but it's only 128 items and done once
    anchors = torch.stack([a for (a, _) in anchors])
    distilled_rand_results.append(
        distill_probe_test(lit_teacher.model[:5], seed, 'random', anchors)
    )

    # per class
    ds_train = dm.ds_train
    sorter = torch.randperm(len(ds_train), generator=torch.Generator().manual_seed(seed))
    ds_train_data = ds_train.dataset.data[ds_train.indices][sorter]
    ds_train_labels = ds_train.dataset.targets[ds_train.indices][sorter]
    anchors = torch.cat([
        ds_train_data[ds_train_labels == i][:n_anchors_per_class]
        for i in range(10)
    ])
    anchors = to_model_input(anchors)
    distilled_per_class_results.append(
        distill_probe_test(lit_teacher.model[:5], seed, 'random-per-class', anchors)
    )

    # best per class
    # NOTE(review): the mask below is built from predictions in dataloader
    # order but applied to `ds_train_data` in dataset order. This is only
    # correct if `dm.train_dataloader()` neither shuffles nor drops samples --
    # confirm against `GenericDataModule`.
    preds = create_trainer().predict(
        LitModel(deepcopy(lit_teacher.model[:5]), 10),
        dm.train_dataloader()
    )
    logits = torch.tensor(
        teacher_linear.decision_function(torch.cat([pred['logits'] for pred in preds]))
    ).softmax(dim=-1)
    labels = torch.cat([pred['y'] for pred in preds])
    pred_labels = logits.argmax(dim=-1)
    correct_pred_filter = (pred_labels == labels)

    ds_train = dm.ds_train
    ds_train_data = ds_train.dataset.data[ds_train.indices]
    ds_train_labels = ds_train.dataset.targets[ds_train.indices]

    anchors = torch.cat([
        ds_train_data[correct_pred_filter & (labels == i)] \
        [logits[correct_pred_filter & (labels == i)][:, i].argsort(descending=True)[:n_anchors_per_class]]
        for i in range(10)
    ])
    anchors = to_model_input(anchors)
    # Same test helper as the other arms. The earlier code called `sk_test`,
    # which is not defined in the cells shown; `mnist_ssl_sk_test` is
    # `quick_sk_test` with exactly the arguments that call spelled out by hand.
    distilled_best_per_class_results.append(
        distill_probe_test(lit_teacher.model[:5], seed, 'best-per-class', anchors)
    )
  return distilled_results, distilled_rand_results, distilled_per_class_results, distilled_best_per_class_results

In [None]:
# Single-seed anchor-selection run, logged to W&B.
in_batch, rand, rand_per_class, best_per_class = anchor_experiment(
    seeds=[42],
    project='rel-rep-dist-anc-sel',
)

In [None]:
# mean test accuracy per strategy: in-batch, random, random-per-class, best-per-class
list(map(np.mean, (in_batch, rand, rand_per_class, best_per_class)))

### Comparison to other methods

#### SL CIFAR10

In [None]:
# Train the supervised CIFAR10 teacher, logging the run and its checkpoint to W&B.
# The `trainer_kwargs` given here replaces the partial's {'max_epochs': 10}.
lit_teacher = cifar10_sl_quick_train(
    # positional: bound to the first parameter of `quick_train`
    # (presumably `model_init` -- TODO confirm against its signature)
    create_cifar10_cls_teacher,
    lit_model_kwargs={'lr': 1e-1, 'num_classes': 10},
    trainer_kwargs={'max_epochs': 50},
    project='rel-rep-dist-comp-cifar10-tune',
    name='teacher',
    log_model=True
)


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [None]:
# Test accuracy of the supervised CIFAR10 teacher.
# NOTE(review): no `id` is passed, so `quick_lit_test` skips W&B logging and the
# `project` argument has no effect -- pass the training run's id to log it.
quick_lit_test(
    lit_teacher,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar10-tune',
)

[34m[1mwandb[0m: Currently logged in as: [33mpatrickramosobf[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.8812999725341797


0.8812999725341797

In [None]:
# Distill the CIFAR10 student from the trained teacher (last child replaced by
# pooling + flatten) for 50 epochs, then probe it with logistic regression.
# NOTE(review): the run name mentions momentum=0.9 but no momentum is set in
# these kwargs -- presumably it is configured inside the distiller; confirm.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    cifar10_sl_quick_distill,
    cifar10_sl_quick_sk_probe,
    {
        'distill_kwargs': {
          'teacher': strip_resnet_cls_head(lit_teacher.model),
          'lr': 1e-1,
          'ce_weight': 1,
          'dim': 256,
          'num_classes': 10,
          'non_linear_head': False,
          'dropout_head': False
        },
        # replaces the partial's {'max_epochs': 10}
        'trainer_kwargs': {'max_epochs': 50},
        'project': 'rel-rep-dist-comp-cifar10-tune',
        'name': 'rel-rep momentum=0.9',
        'log_model': True
    },
    extract_model_fn=extract_student,
)

[34m[1mwandb[0m: Currently logged in as: [33mpatrickramosobf[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12876494.02it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type               | Params
-----------------------------------------------------
0 | teacher       | Sequential         | 11.2 M
1 | student       | Sequential         | 1.2 M 
2 | head          | Linear             | 2.6 K 
3 | val_accuracy  | MulticlassAccuracy | 0     
4 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
1.2 M     Trainable params
11.2 M    Non-trainable params
12.4 M    Total params
49.591    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

d_weight: 1	ce_weight: 1
anchors: None
head: Linear(in_features=256, out_features=10, bias=True)
optim: sgd


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/ce_loss,█▅▅▄▃▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/cos,▁▃▃▄▆▆▆▆▇▆▇▇▇▇▇▇▇▇█▇████████████████████
train/d_loss,█▆▆▅▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▅▅▄▃▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▁▃▄▅▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇█████████████████████
val/ce_loss,█▆▄▃▃▃▁▂▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
val/cos,▁▄▅▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████
val/d_loss,█▅▄▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
train/ce_loss,0.00075
train/cos,0.99358
train/d_loss,0.00322
train/loss,0.00396
trainer/global_step,17599.0
val/acc,0.866
val/ce_loss,0.55821
val/cos,0.99277
val/d_loss,0.00362


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

[{'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 15.435] μ=2.007 σ=1.790, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.369 σ=2.951}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 14.842] μ=2.000 σ=1.765, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.310 σ=2.954}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 13.477] μ=1.997 σ=1.758, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.510 σ=2.847}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 13.555] μ=1.986 σ=1.783, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.530 σ=2.904}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 13.351] μ=1.991 σ=1.756, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.690 σ=2.850}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 13.266] μ=1.994 σ=1.779, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.419 σ=2.822}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 13.822] μ=1.998 σ=1.784, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=4.729 σ=2.818}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 14.407] μ=2.005 σ=1.791, 'y': tensor[1024] i6

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   12.0s finished
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.861


In [None]:
# Run quick_sk_test on the distilled student backbone together with the probe
# object `rr_distilled_linear`, using CIFAR10 through GenericDataModule with
# batches of 128.
# NOTE(review): `lit_rr_distilled` and `rr_distilled_linear` are bound by an
# earlier cell; on a fresh kernel this cell fails unless that cell ran first.
quick_sk_test(
    lit_rr_distilled.student,
    rr_distilled_linear,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar10-tune',
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
val/acc,▁

0,1
epoch,49.0
train/ce_loss,0.00075
train/cos,0.99358
train/d_loss,0.00322
train/loss,0.00396
trainer/global_step,17599.0
val/acc,0.861
val/ce_loss,0.55821
val/cos,0.99277
val/d_loss,0.00362


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.8621


0.8621

In [None]:
# Distillation run with LitSPDistiller, capped at 50 epochs. d_weight and
# ce_weight are both 1, lr is 0.1 and momentum 0.9. The teacher handed to the
# distiller is `lit_teacher.model` after strip_resnet_cls_head.
# NOTE(review): `lit_teacher` is rebound in the CIFAR100 and SVHN sections
# further down, so this cell is only valid while the teacher meant for this
# section is the object bound to that name.
lit_sp_distilled = cifar10_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs={
        'teacher': strip_resnet_cls_head(lit_teacher.model),
        'lr': 1e-1,
        'momentum': 0.9,
        'd_weight': 1,
        'ce_weight': 1,
        'num_classes': 10,
        'dim': 256
    },
    trainer_kwargs={'max_epochs': 50},
    project='rel-rep-dist-comp-cifar10-tune',
    name='sim-pre log model',
    log_model=True
)

[34m[1mwandb[0m: Downloading large artifact model-vukvzo6p:latest, 52.09MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.5


In [None]:
# Chain the SP-distilled student with the distiller's `head`, wrap the pair in
# a 10-class LitModel and run quick_lit_test on CIFAR10 with batches of 128.
quick_lit_test(
    LitModel(nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head), num_classes=10),
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar10-tune',
)

In [None]:
# Distillation run with LitLPDistiller, capped at 50 epochs. Besides the
# stripped teacher backbone, the teacher's own `linear` layer is passed along
# as `teacher_head`.
# NOTE(review): `lit_teacher.model.linear` is read after the same model was
# passed to strip_resnet_cls_head (the 'teacher' entry is evaluated first) —
# confirm that helper does not modify the model in place.
lit_lp_distilled = cifar10_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs={
        'teacher': strip_resnet_cls_head(lit_teacher.model),
        'lr': 1e-1,
        'normalizing_constant': 1,
        'dim': 256,
        'num_classes': 10,
        'teacher_head': lit_teacher.model.linear
    },
    trainer_kwargs={'max_epochs': 50},
    project='rel-rep-dist-comp-cifar10-tune',
    name='loc-pre reprod',
    log_model=True
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test/acc,▁

0,1
epoch,49.0
test/acc,0.8621
train/ce_loss,0.00075
train/cos,0.99358
train/d_loss,0.00322
train/loss,0.00396
trainer/global_step,17599.0
val/acc,0.861
val/ce_loss,0.55821
val/cos,0.99277


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type               | Params
-----------------------------------------------------
0 | teacher       | Sequential         | 11.2 M
1 | student       | Sequential         | 1.2 M 
2 | student_head  | Linear             | 2.6 K 
3 | teacher_head  | Linear             | 5.1 K 
4 | val_accuracy  | MulticlassAccuracy | 0     
5 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
1.2 M     Trainable params
11.2 M    Non-trainable params
12.4 M    Total params
49.612    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

d_weight: 1.5	ce_weight: 1	sce_weight: 2
student_head: Linear(in_features=256, out_features=10, bias=True)
teacher_head: Linear(in_features=512, out_features=10, bias=True)


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Chain the LP-distilled student with the distiller's `student_head`, wrap the
# pair in a 10-class LitModel and run quick_lit_test on CIFAR10 with batches
# of 128.
quick_lit_test(
    LitModel(nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head), num_classes=10),
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR10, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar10-tune',
)

VBox(children=(Label(value='0.000 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
epoch,▁
test/acc,▁
test/loss,▁
trainer/global_step,▁

0,1
epoch,0.0
test/acc,0.8585
test/ce_loss,0.5551
test/d_loss,0.00025
test/loss,0.5551
train/ce_loss,0.00087
train/d_loss,0.00043
train/loss,0.0013
trainer/global_step,0.0
val/acc,0.8585


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.8555999994277954


0.8555999994277954

#### SL CIFAR100


In [None]:
def create_cifar100_cls_teacher():
  """Build a fresh CIFAR100 teacher network.

  Returns:
    A new ``ResNetGenerator`` configured as 'resnet-18' with 100 output classes.
  """
  teacher_net = ResNetGenerator('resnet-18', num_classes=100)
  return teacher_net

def create_cifar100_cls_student():
  """Build a fresh CIFAR100 student network.

  Returns:
    A new ``ResNetGenerator`` configured as 'resnet-9' with ``width=0.5`` and
    100 output classes.
  """
  student_net = ResNetGenerator('resnet-9', width=0.5, num_classes=100)
  return student_net

In [None]:
# CIFAR100 shortcuts: each one pre-binds GenericDataModule and its CIFAR100
# settings onto a generic helper. Training and distillation use batches of
# 128; the probe and sk-test helpers use batches of 1024. Keyword arguments
# supplied at call time (e.g. trainer_kwargs) replace the ones bound here.

# Supervised LitModel training; 10 epochs unless the caller overrides trainer_kwargs.
cifar100_sl_quick_train = partial(
    quick_train,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR100, 'batch_size': 128},
    lit_model_cls=LitModel,
    trainer_kwargs={'max_epochs': 10},
)

# Distillation into the CIFAR100 student; strip_resnet_cls_head is registered
# as the student preprocessing step.
cifar100_sl_quick_distill = partial(
    quick_distill,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR100, 'batch_size': 128},
    student_init=create_cifar100_cls_student,
    student_preprocess=strip_resnet_cls_head,
    trainer_kwargs={'max_epochs': 10},
)

# Probe / test helpers: same data module, larger batches.
cifar100_sl_quick_sk_probe = partial(
    quick_sk_probe, dm_init=GenericDataModule, dm_kwargs={'ds_class': CIFAR100, 'batch_size': 1024}
)

cifar100_sl_quick_fc_probe = partial(
    quick_fc_probe, dm_init=GenericDataModule, dm_kwargs={'ds_class': CIFAR100, 'batch_size': 1024}
)

cifar100_sl_sk_test = partial(
    quick_sk_test, dm_init=GenericDataModule, dm_kwargs={'ds_class': CIFAR100, 'batch_size': 1024}
)

In [None]:
# CIFAR100 teacher: the model comes from create_cifar100_cls_teacher and is
# run through cifar100_sl_quick_train with lr 0.1. The trainer_kwargs given
# here replace the 10-epoch default bound in the partial with 50 epochs.
# NOTE(review): this rebinds `lit_teacher`, a name other dataset sections of
# the notebook also assign — cells that read it depend on execution order.
lit_teacher = cifar100_sl_quick_train(
  create_cifar100_cls_teacher,
  lit_model_kwargs={'lr': 1e-1, 'num_classes': 100},
  trainer_kwargs={'max_epochs': 50},
  project='rel-rep-dist-comp-cifar100-tune',
  name='teacher aug decay 0.2 every 10',
  log_model=True
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-3sdxkb6b:v0/model.ckpt`


In [None]:
# Run quick_lit_test on the CIFAR100 teacher (batches of 128). The data-module
# arguments are spelled as keywords, matching the other quick_lit_test calls.
quick_lit_test(
    lit_teacher,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR100, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar100-tune',
)

0,1
test/acc,▁
val/acc,▁

0,1
epoch,0.0
test/acc,0.8556
test/loss,0.59668
trainer/global_step,0.0
val/acc,0.853


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.607200026512146


0.607200026512146

In [None]:
# Hand the CIFAR100 distill helper, the CIFAR100 sk-probe helper and the
# distillation settings to quick_fn_probe; its result is unpacked into the
# distiller module and the probe object. extract_student is passed as the
# model extractor.
# NOTE(review): no `distill_cls` is given here, unlike the SP/LP cells —
# presumably quick_distill falls back to a default distiller; confirm.
# NOTE(review): the run name advertises momentum=0.9, yet distill_kwargs has
# no 'momentum' entry (the SP cells pass it explicitly) — confirm the default.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    cifar100_sl_quick_distill,
    cifar100_sl_quick_sk_probe,
    {
        'distill_kwargs': {
          'teacher': strip_resnet_cls_head(lit_teacher.model),
          'lr': 1e-1,
          'ce_weight': 1,
          'dim': 256,
          'num_classes': 100,
          'non_linear_head': False,
          'dropout_head': False,
        },
        'trainer_kwargs': {'max_epochs': 50},
        'project': 'rel-rep-dist-comp-cifar100-tune',
        'name': 'rel-rep momentum=0.9',
        'log_model': True
    },
    extract_model_fn=extract_student,
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,0.0
test/acc,0.8556
test/loss,0.59668
trainer/global_step,0.0
val/acc,0.853


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./cifar-100-python.tar.gz



  0%|          | 0/169001437 [00:00<?, ?it/s][A
  0%|          | 32768/169001437 [00:00<18:14, 154359.71it/s][A
  0%|          | 65536/169001437 [00:00<18:12, 154565.37it/s][A
  0%|          | 98304/169001437 [00:00<17:39, 159464.87it/s][A
  0%|          | 229376/169001437 [00:00<08:15, 340527.37it/s][A
  0%|          | 393216/169001437 [00:00<04:34, 613803.09it/s][A
  0%|          | 491520/169001437 [00:01<04:17, 654601.60it/s][A
  0%|          | 688128/169001437 [00:01<03:02, 919925.05it/s][A
  1%|          | 884736/169001437 [00:01<02:27, 1137728.98it/s][A
  1%|          | 1048576/169001437 [00:01<02:13, 1256816.27it/s][A
  1%|          | 1245184/169001437 [00:01<02:00, 1395666.05it/s][A
  1%|          | 1409024/169001437 [00:01<01:54, 1458050.29it/s][A
  1%|          | 1605632/169001437 [00:01<01:48, 1539593.33it/s][A
  1%|          | 1769472/169001437 [00:01<01:47, 1557515.28it/s][A
  1%|          | 1966080/169001437 [00:01<01:43, 1614486.81it/s][A
  1%|▏         |

Extracting ./cifar-100-python.tar.gz to ./
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type               | Params
-----------------------------------------------------
0 | teacher       | Sequential         | 11.2 M
1 | student       | Sequential         | 1.2 M 
2 | head          | Linear             | 25.7 K
3 | val_accuracy  | MulticlassAccuracy | 0     
4 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
1.3 M     Trainable params
11.2 M    Non-trainable params
12.4 M    Total params
49.684    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

d_weight: 1	ce_weight: 1
anchors: None
head: Linear(in_features=256, out_features=100, bias=True)
optim: sgd


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/ce_loss,█▇▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/cos,▁▄▆▇▇███████████████████████████████████
train/d_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▇▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▁▂▄▅▆▇▇▇█▇█▇▇▇█▇▇██▇▇███████████████████
val/ce_loss,█▇▅▃▂▂▁▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃
val/cos,▁▄▆▇▇███████████████████████████████████
val/d_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
train/ce_loss,0.0087
train/cos,0.99313
train/d_loss,0.00344
train/loss,0.01214
trainer/global_step,17599.0
val/acc,0.558
val/ce_loss,2.17087
val/cos,0.99376
val/d_loss,0.00313


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

[{'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 33.696] μ=4.227 σ=3.699, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=48.304 σ=28.728}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 30.454] μ=4.222 σ=3.711, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=48.074 σ=28.846}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 32.674] μ=4.240 σ=3.731, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=48.260 σ=28.277}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 30.153] μ=4.238 σ=3.701, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=50.643 σ=29.031}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 29.378] μ=4.222 σ=3.676, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=50.910 σ=28.361}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 34.463] μ=4.229 σ=3.704, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=50.051 σ=29.236}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 31.942] μ=4.228 σ=3.708, 'y': tensor[1024] i64 8Kb x∈[0, 99] μ=49.771 σ=29.620}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 36.472] μ=4.233 σ=3.713,

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  4.6min finished
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.569


In [None]:
# Run the CIFAR100 sk-test helper (quick_sk_test with the CIFAR100 data module
# pre-bound, batches of 1024) on the distilled student and its probe object.
cifar100_sl_sk_test(
    lit_rr_distilled.student,
    rr_distilled_linear,
    project='rel-rep-dist-comp-cifar100-tune',
)

0,1
val/acc,▁

0,1
epoch,49.0
train/ce_loss,0.0087
train/cos,0.99313
train/d_loss,0.00344
train/loss,0.01214
trainer/global_step,17599.0
val/acc,0.569
val/ce_loss,2.17087
val/cos,0.99376
val/d_loss,0.00313


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.5748


0.5748

In [None]:
# CIFAR100 distillation run with LitSPDistiller, 50 epochs (replacing the
# 10-epoch default bound in the partial). d_weight and ce_weight are both 1,
# lr is 0.1 and momentum 0.9.
# NOTE(review): the run name says 'no labels' while ce_weight is 1 — confirm
# the name still describes this configuration.
lit_sp_distilled = cifar100_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs={
        'teacher': strip_resnet_cls_head(lit_teacher.model),
        'lr': 1e-1,
        'momentum': 0.9,
        'd_weight': 1,
        'ce_weight': 1,
        'num_classes': 100,
        'dim': 256
    },
    trainer_kwargs={'max_epochs': 50},
    project='rel-rep-dist-comp-cifar100-tune',
    name='sim-pre no labels',
    log_model=True
)

[34m[1mwandb[0m: Downloading large artifact model-b6uxagxx:latest, 52.18MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.1
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-b6uxagxx:v0/model.ckpt`


In [None]:
# Chain the SP-distilled student with the distiller's `head`, wrap the pair in
# a 100-class LitModel and run quick_lit_test on CIFAR100 with batches of 128.
quick_lit_test(
    LitModel(nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head), num_classes=100),
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR100, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar100-tune',
)

VBox(children=(Label(value='0.000 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
test/acc,▁

0,1
epoch,49.0
test/acc,0.5336
train/ce_loss,0.59305
train/cos,0.99242
train/d_loss,0.0038
train/loss,0.59685
trainer/global_step,8799.0
val/acc,0.538
val/ce_loss,2.00035
val/cos,0.99246


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.5634999871253967


0.5634999871253967

In [None]:
# CIFAR100 distillation run with LitLPDistiller, 50 epochs (replacing the
# 10-epoch default bound in the partial). The teacher's own `linear` layer is
# passed along as `teacher_head`.
# NOTE(review): `lit_teacher.model.linear` is read after the same model was
# passed to strip_resnet_cls_head (the 'teacher' entry is evaluated first) —
# confirm that helper does not modify the model in place.
lit_lp_distilled = cifar100_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs={
        'teacher': strip_resnet_cls_head(lit_teacher.model),
        'lr': 1e-1,
        'normalizing_constant': 1,
        'dim': 256,
        'num_classes': 100,
        'teacher_head': lit_teacher.model.linear
    },
    trainer_kwargs={'max_epochs': 50},
    project='rel-rep-dist-comp-cifar100-tune',
    name='loc-pre',
    log_model=True
)

[34m[1mwandb[0m: Downloading large artifact model-zerfvbaq:latest, 52.38MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-zerfvbaq:v0/model.ckpt`


In [None]:
# Chain the LP-distilled student with the distiller's `student_head`, wrap the
# pair in a 100-class LitModel and run quick_lit_test on CIFAR100 with batches
# of 128.
quick_lit_test(
    LitModel(nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head), num_classes=100),
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': CIFAR100, 'batch_size': 128},
    project='rel-rep-dist-comp-cifar100-tune',
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test/acc,▁

0,1
epoch,19.0
test/acc,0.94491
train/ce_loss,0.00219
train/d_loss,0.0
train/loss,0.02081
train/sce_loss,0.00931
trainer/global_step,5159.0
val/acc,0.9424
val/ce_loss,0.19615
val/d_loss,0.0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.5703999996185303


0.5703999996185303

#### SL SVHN

In [None]:
def create_svhn_cls_teacher():
  """Build a fresh SVHN teacher network.

  Returns:
    A new ``ResNetGenerator`` configured as 'resnet-18' with 10 output classes.
  """
  teacher_net = ResNetGenerator('resnet-18', num_classes=10)
  return teacher_net

def create_svhn_cls_student():
  """Build a fresh SVHN student network.

  Returns:
    A new ``ResNetGenerator`` configured as 'resnet-9' with ``width=0.5`` and
    10 output classes.
  """
  student_net = ResNetGenerator('resnet-9', width=0.5, num_classes=10)
  return student_net

# SVHN shortcuts: each one pre-binds GenericDataModule and its SVHN settings
# onto a generic helper. Training and distillation use batches of 128; the
# probe and test helpers use batches of 1024. Keyword arguments supplied at
# call time (e.g. trainer_kwargs) replace the ones bound here.

# Supervised LitModel training; 50 epochs unless the caller overrides trainer_kwargs.
svhn_sl_quick_train = partial(
    quick_train,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': SVHN, 'batch_size': 128},
    lit_model_cls=LitModel,
    trainer_kwargs={'max_epochs': 50},
)

# Distillation into the SVHN student; strip_resnet_cls_head is registered as
# the student preprocessing step.
svhn_sl_quick_distill = partial(
    quick_distill,
    dm_init=GenericDataModule,
    dm_kwargs={'ds_class': SVHN, 'batch_size': 128},
    student_init=create_svhn_cls_student,
    student_preprocess=strip_resnet_cls_head,
    trainer_kwargs={'max_epochs': 50},
)

# Probe / test helpers: same data module, larger batches.
svhn_sl_quick_sk_probe = partial(
    quick_sk_probe, dm_init=GenericDataModule, dm_kwargs={'ds_class': SVHN, 'batch_size': 1024}
)

svhn_sl_quick_fc_probe = partial(
    quick_fc_probe, dm_init=GenericDataModule, dm_kwargs={'ds_class': SVHN, 'batch_size': 1024}
)

svhn_sl_sk_test = partial(
    quick_sk_test, dm_init=GenericDataModule, dm_kwargs={'ds_class': SVHN, 'batch_size': 1024}
)

svhn_sl_quick_lit_test = partial(
    quick_lit_test, dm_init=GenericDataModule, dm_kwargs={'ds_class': SVHN, 'batch_size': 1024}
)

In [None]:
# SVHN teacher: the model comes from create_svhn_cls_teacher and is run
# through svhn_sl_quick_train with lr 0.1. The trainer_kwargs given here
# replace the 50-epoch default bound in the partial with 20 epochs.
# NOTE(review): this rebinds `lit_teacher`, a name other dataset sections of
# the notebook also assign — cells that read it depend on execution order.
lit_teacher = svhn_sl_quick_train(
  create_svhn_cls_teacher,
  lit_model_kwargs={'lr': 1e-1, 'num_classes': 10},
  trainer_kwargs={'max_epochs': 20},
  project='rel-rep-dist-comp-svhn-tune',
  name='teacher',
  log_model=True
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-wnd42ihu:v0/model.ckpt`


In [None]:
# Run the SVHN lit-test helper (quick_lit_test with the SVHN data module
# pre-bound, batches of 1024) on the SVHN teacher.
svhn_sl_quick_lit_test(
    lit_teacher,
    project='rel-rep-dist-comp-svhn-tune'
)

[34m[1mwandb[0m: Currently logged in as: [33mpatrickramosobf[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Using downloaded and verified file: ./train_32x32.mat
Using downloaded and verified file: ./test_32x32.mat


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.9466041922569275


0.9466041922569275

In [None]:
# Hand the SVHN distill helper, the SVHN sk-probe helper and the distillation
# settings to quick_fn_probe; its result is unpacked into the distiller module
# and the probe object. trainer_kwargs caps the run at 20 epochs.
# NOTE(review): this rebinds `lit_rr_distilled` / `rr_distilled_linear`, which
# the CIFAR100 section also assigns — later cells depend on execution order.
# NOTE(review): the run name advertises momentum=0.9, yet distill_kwargs has
# no 'momentum' entry (the SP cells pass it explicitly) — confirm the default.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    svhn_sl_quick_distill,
    svhn_sl_quick_sk_probe,
    {
        'distill_kwargs': {
          'teacher': strip_resnet_cls_head(lit_teacher.model),
          'lr': 1e-1,
          'ce_weight': 1,
          'dim': 256,
          'num_classes': 10,
          'non_linear_head': False,
          'dropout_head': False,
        },
        'trainer_kwargs': {'max_epochs': 20},
        'project': 'rel-rep-dist-comp-svhn-tune',
        'name': 'rel-rep momentum=0.9',
        'log_model': True
    },
    extract_model_fn=extract_student,
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test/acc,▁

0,1
epoch,49.0
test/acc,0.5748
train/ce_loss,0.0087
train/cos,0.99313
train/d_loss,0.00344
train/loss,0.01214
trainer/global_step,17599.0
val/acc,0.569
val/ce_loss,2.17087
val/cos,0.99376


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./train_32x32.mat



  0%|          | 0/182040794 [00:00<?, ?it/s][A
  0%|          | 32768/182040794 [00:00<15:25, 196689.78it/s][A
  0%|          | 65536/182040794 [00:00<15:54, 190612.73it/s][A
  0%|          | 98304/182040794 [00:00<16:04, 188730.13it/s][A
  0%|          | 131072/182040794 [00:00<16:10, 187497.88it/s][A
  0%|          | 229376/182040794 [00:00<09:38, 314400.24it/s][A
  0%|          | 327680/182040794 [00:01<07:37, 396762.22it/s][A
  0%|          | 425984/182040794 [00:01<06:47, 445297.95it/s][A
  0%|          | 557056/182040794 [00:01<05:41, 531712.85it/s][A
  0%|          | 688128/182040794 [00:01<05:08, 588648.39it/s][A
  0%|          | 819200/182040794 [00:01<04:52, 620491.53it/s][A
  1%|          | 950272/182040794 [00:01<04:32, 664647.37it/s][A
  1%|          | 1114112/182040794 [00:02<04:09, 725558.29it/s][A
  1%|          | 1277952/182040794 [00:02<03:49, 786083.63it/s][A
  1%|          | 1441792/182040794 [00:02<03:38, 828011.13it/s][A
  1%|          | 1605632/1

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./test_32x32.mat



  0%|          | 0/64275384 [00:00<?, ?it/s][A
  0%|          | 32768/64275384 [00:00<05:30, 194379.79it/s][A
  0%|          | 65536/64275384 [00:00<09:02, 118316.75it/s][A
  0%|          | 98304/64275384 [00:00<07:33, 141449.85it/s][A
  0%|          | 163840/64275384 [00:00<04:54, 217694.85it/s][A
  0%|          | 229376/64275384 [00:01<03:59, 267109.96it/s][A
  1%|          | 327680/64275384 [00:01<02:58, 358473.77it/s][A
  1%|          | 458752/64275384 [00:01<02:13, 477750.77it/s][A
  1%|          | 688128/64275384 [00:01<01:27, 723206.34it/s][A
  1%|▏         | 950272/64275384 [00:01<01:05, 962965.76it/s][A
  2%|▏         | 1343488/64275384 [00:01<00:47, 1329098.52it/s][A
  3%|▎         | 1867776/64275384 [00:02<00:33, 1843834.79it/s][A
  4%|▍         | 2588672/64275384 [00:02<00:24, 2521789.95it/s][A
  6%|▌         | 3571712/64275384 [00:02<00:17, 3441699.47it/s][A
  7%|▋         | 4816896/64275384 [00:02<00:13, 4530618.49it/s][A
  9%|▉         | 5963776/64275384 

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

d_weight: 1	ce_weight: 1
anchors: None
head: Linear(in_features=256, out_features=10, bias=True)
optim: sgd


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train/ce_loss,█▃▂▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/cos,▁▃▅▅▆▆▆▆▇▇▇▇▆▆▇▇▇▇▇▇▇█████▇▇▇▇██████████
train/d_loss,█▆▄▄▃▃▃▃▂▂▂▂▃▃▂▂▂▂▂▂▂▁▁▁▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▁▂▅▆▆▅▅▆▇▇▇▇▇▆▇▇███▇
val/ce_loss,█▆▃▂▂▂▃▂▂▁▁▁▂▂▁▁▁▁▁▁
val/cos,▁▃▅▆▆▆▅▆▆▆▇▇▇▆▇▇████
val/d_loss,█▆▄▃▃▃▄▃▃▃▂▂▂▃▂▂▁▁▁▁

0,1
epoch,19.0
train/ce_loss,0.00315
train/cos,0.99631
train/d_loss,0.00185
train/loss,0.005
trainer/global_step,10319.0
val/acc,0.94775
val/ce_loss,0.19683
val/cos,0.9958
val/d_loss,0.0021


INFO:lightning_fabric.utilities.seed:Global seed set to 12345678
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Using downloaded and verified file: ./train_32x32.mat
Using downloaded and verified file: ./test_32x32.mat


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

[{'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 8.976] μ=1.613 σ=1.386, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.759 σ=2.685}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 10.491] μ=1.625 σ=1.383, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.904 σ=2.730}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 8.844] μ=1.616 σ=1.392, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.706 σ=2.666}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 10.522] μ=1.617 σ=1.383, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.809 σ=2.687}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 9.286] μ=1.614 σ=1.381, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.952 σ=2.702}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 9.452] μ=1.624 σ=1.399, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.780 σ=2.658}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 9.500] μ=1.624 σ=1.385, 'y': tensor[1024] i64 8Kb x∈[0, 9] μ=3.809 σ=2.679}, {'logits': tensor[1024, 256] n=262144 (1Mb) x∈[0., 10.210] μ=1.620 σ=1.390, 'y': tensor[1024] i64 8Kb

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  1.0min finished
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.9485209373799462


In [None]:
# Run the SVHN sk-test helper on the rel-rep distilled student together with
# the probe object that was returned alongside it. The call stays the last
# expression of the cell so its return value is still displayed.
svhn_sl_sk_test(lit_rr_distilled.student, rr_distilled_linear, project='rel-rep-dist-comp-svhn-tune')

0,1
val/acc,▁

0,1
epoch,19.0
train/ce_loss,0.00315
train/cos,0.99631
train/d_loss,0.00185
train/loss,0.005
trainer/global_step,10319.0
val/acc,0.94852
val/ce_loss,0.19683
val/cos,0.9958
val/d_loss,0.0021


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Using downloaded and verified file: ./train_32x32.mat
Using downloaded and verified file: ./test_32x32.mat


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0.9454901659496004


0.9454901659496004

In [None]:
# Keyword arguments handed to LitSPDistiller. The teacher is lit_teacher's
# model passed through strip_resnet_cls_head.
svhn_sp_distill_kwargs = {
    'teacher': strip_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'momentum': 0.9,
    'd_weight': 1,
    'ce_weight': 1,
    'num_classes': 10,
    'dim': 256,
}

# SVHN distillation run with LitSPDistiller, logged under the name 'sim-pre'.
lit_sp_distilled = svhn_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs=svhn_sp_distill_kwargs,
    trainer_kwargs={'max_epochs': 20},
    project='rel-rep-dist-comp-svhn-tune',
    name='sim-pre',
    log_model=True,
)

[34m[1mwandb[0m: Downloading large artifact model-3qw3gnhe:latest, 52.09MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.4
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-3qw3gnhe:v0/model.ckpt`


In [None]:
# Chain the distilled student with the distiller's `head` attribute into a
# single module, wrap it in LitModel, and pass it to the SVHN lit-test helper.
svhn_sp_student_with_head = nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head)
svhn_sl_quick_lit_test(
    LitModel(svhn_sp_student_with_head, num_classes=10),
    project='rel-rep-dist-comp-svhn-tune',
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test/acc,▁

0,1
epoch,19.0
lr-SGD,0.1
test/acc,0.9466
train/loss,8e-05
trainer/global_step,5159.0
val/acc,0.94585
val/loss,0.21582


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")


Using downloaded and verified file: ./train_32x32.mat
Using downloaded and verified file: ./test_32x32.mat


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.9456053972244263


0.9456053972244263

In [None]:
# Keyword arguments handed to LitLPDistiller. Besides the head-stripped
# teacher, it is also given the teacher's `linear` layer as 'teacher_head'.
svhn_lp_distill_kwargs = {
    'teacher': strip_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'normalizing_constant': 1,
    'dim': 256,
    'num_classes': 10,
    'teacher_head': lit_teacher.model.linear,
}

# SVHN distillation run with LitLPDistiller, logged as 'loc-pre fixed loss fn'.
lit_lp_distilled = svhn_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs=svhn_lp_distill_kwargs,
    trainer_kwargs={'max_epochs': 20},
    project='rel-rep-dist-comp-svhn-tune',
    name='loc-pre fixed loss fn',
    log_model=True,
)

[34m[1mwandb[0m: Downloading large artifact model-422tt7sb:latest, 52.12MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.6
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-422tt7sb:v0/model.ckpt`


In [None]:
# Chain the distilled student with the distiller's `student_head` attribute,
# wrap the result in LitModel, and pass it to the SVHN lit-test helper.
svhn_lp_student_with_head = nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head)
svhn_sl_quick_lit_test(
    LitModel(svhn_lp_student_with_head, num_classes=10),
    project='rel-rep-dist-comp-svhn-tune',
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test/acc,▁

0,1
epoch,19.0
test/acc,0.94561
train/ce_loss,0.00222
train/d_loss,0.00023
train/loss,0.00246
trainer/global_step,5159.0
val/acc,0.94931
val/ce_loss,0.18626
val/d_loss,0.00026
val/loss,0.18652


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Using downloaded and verified file: ./train_32x32.mat
Using downloaded and verified file: ./test_32x32.mat


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

0.944913923740387


0.944913923740387

#### Stanford Cars

In [None]:
!mkdir -p stanford_cars

!wget https://web.archive.org/web/20230405013536/https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz -P stanford_cars
!tar -xzf stanford_cars/car_devkit.tgz -C stanford_cars

!wget https://web.archive.org/web/20230405013536/http://ai.stanford.edu/~jkrause/car196/cars_train.tgz -P stanford_cars
!tar -xzf stanford_cars/cars_train.tgz -C stanford_cars

!wget https://web.archive.org/web/20230405013536/http://ai.stanford.edu/~jkrause/car196/cars_test.tgz -P stanford_cars
!tar -xzf stanford_cars/cars_test.tgz -C stanford_cars

!wget https://web.archive.org/web/20230405013536/http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat -P stanford_cars

In [None]:
def strip_timm_resnet_cls_head(model):
    """Return an ``nn.Sequential`` of every direct child of ``model`` except the last.

    The child modules are reused as-is (not copied), so the result shares
    parameters with ``model``. For the timm ResNets built in this notebook the
    dropped child is presumably the ``fc`` classifier -- TODO confirm.
    """
    child_modules = tuple(model.children())
    backbone_modules = child_modules[:-1]
    return nn.Sequential(*backbone_modules)

# Zero-argument factories for the Stanford Cars models (196 output classes):
# a pretrained timm resnet34 as teacher and a pretrained timm resnet18 as student.
create_cars_cls_teacher = lambda: timm.create_model(
    'resnet34',
    pretrained=True,
    num_classes=196,
)
create_cars_cls_student = lambda: timm.create_model(
    'resnet18',
    pretrained=True,
    num_classes=196,
)

In [None]:
def _cars_dm_kwargs(timm_name):
    """Return a new GenericDataModule kwargs dict for StanfordCars (batch size 128)."""
    return {'ds_class': StanfordCars, 'timm_name': timm_name, 'batch_size': 128}


# Teacher training: LitModel on the resnet34 data config, 20 epochs by default.
cars_sl_quick_train = partial(
    quick_train,
    lit_model_cls=LitModel,
    dm_init=GenericDataModule,
    dm_kwargs=_cars_dm_kwargs('resnet34'),
    trainer_kwargs={'max_epochs': 20},
)

# Distillation into the resnet18 student; the student is passed through
# strip_timm_resnet_cls_head via `student_preprocess`.
# Original author's note: pretty sure the transforms are the same for both
# resnet34 and resnet18.
cars_sl_quick_distill = partial(
    quick_distill,
    student_init=create_cars_cls_student,
    dm_init=GenericDataModule,
    dm_kwargs=_cars_dm_kwargs('resnet18'),
    trainer_kwargs={'max_epochs': 20},
    student_preprocess=strip_timm_resnet_cls_head,
)

# Probing helpers bound to the same Stanford Cars data module settings.
cars_sl_quick_fc_probe = partial(
    quick_fc_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_cars_dm_kwargs('resnet18'),
)

cars_sl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_cars_dm_kwargs('resnet18'),
)

In [None]:
# Stanford Cars teacher run (resnet34 factory), logged as 'teacher resnet-34'.
cars_teacher_lit_kwargs = {'lr': 1e-1, 'num_classes': 196}
lit_teacher = cars_sl_quick_train(
    create_cars_cls_teacher,
    lit_model_kwargs=cars_teacher_lit_kwargs,
    project='rel-rep-dist-comp-cars-tune',
    name='teacher resnet-34',
    log_model=True,
)

In [None]:
# Test the Stanford Cars teacher with the same data module settings that
# cars_sl_quick_train uses (resnet34 config, batch size 128).
cars_teacher_test_dm_kwargs = {'ds_class': StanfordCars, 'timm_name': 'resnet34', 'batch_size': 128}
quick_lit_test(
    lit_teacher,
    GenericDataModule,
    cars_teacher_test_dm_kwargs,
    project='rel-rep-dist-comp-cars-tune',
)

In [None]:
# Distiller settings for the rel-rep run on Stanford Cars.
# NOTE(review): the run name below says momentum=0.9 but no 'momentum' key is
# passed here -- confirm that matches the distiller's default.
cars_rr_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'dim': 512,
    'num_classes': 196,
    'ce_weight': 1,
    'non_linear_head': False,
    'dropout_head': False,
}

# Third positional argument of quick_fn_probe: the distiller settings plus the
# logging options for the run.
cars_rr_run_kwargs = {
    'distill_kwargs': cars_rr_distill_kwargs,
    'project': 'rel-rep-dist-comp-cars-tune',
    'name': 'rel-rep momentum=0.9',
    'log_model': True,
}

# quick_fn_probe is given the distill and sk-probe helpers and returns a pair,
# unpacked here into the distilled module and its probe.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    cars_sl_quick_distill,
    cars_sl_quick_sk_probe,
    cars_rr_run_kwargs,
    extract_model_fn=extract_student,
)

In [None]:
# Run the sk-test helper on the rel-rep student and its probe.
# NOTE(review): this uses the 'resnet34' data config while the student was
# distilled with 'resnet18' -- presumably equivalent transforms; confirm.
cars_rr_test_dm_kwargs = {'ds_class': StanfordCars, 'timm_name': 'resnet34', 'batch_size': 128}
quick_sk_test(
    lit_rr_distilled.student,
    rr_distilled_linear,
    dm_init=GenericDataModule,
    dm_kwargs=cars_rr_test_dm_kwargs,
    project='rel-rep-dist-comp-cars-tune',
)

In [None]:
# Keyword arguments handed to LitSPDistiller; the teacher is the Stanford Cars
# teacher model with its last child removed.
cars_sp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'momentum': 0.9,
    'd_weight': 1,
    'ce_weight': 1,
    'num_classes': 196,
    'dim': 512,
}

# Stanford Cars distillation run with LitSPDistiller.
lit_sp_distilled = cars_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs=cars_sp_distill_kwargs,
    project='rel-rep-dist-comp-cars-tune',
    name='sim-pre correct teacher head strip',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `head`, wrap the result in
# LitModel, and run the lit-test helper on Stanford Cars.
# NOTE(review): data config is 'resnet34' although the student is a resnet18 --
# presumably equivalent transforms; confirm.
cars_sp_student_with_head = nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head)
cars_sp_test_dm_kwargs = {'ds_class': StanfordCars, 'timm_name': 'resnet34', 'batch_size': 128}
quick_lit_test(
    LitModel(cars_sp_student_with_head, num_classes=196),
    GenericDataModule,
    cars_sp_test_dm_kwargs,
    project='rel-rep-dist-comp-cars-tune',
)

In [None]:
# Keyword arguments handed to LitLPDistiller: the head-stripped teacher plus
# the teacher's `fc` layer as 'teacher_head'.
cars_lp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'normalizing_constant': 1,
    'dim': 512,
    'num_classes': 196,
    'teacher_head': lit_teacher.model.fc,
}

# Stanford Cars distillation run with LitLPDistiller, logged as 'loc-pre'.
lit_lp_distilled = cars_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs=cars_lp_distill_kwargs,
    project='rel-rep-dist-comp-cars-tune',
    name='loc-pre',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `student_head`, wrap the
# result in LitModel, and run the lit-test helper on Stanford Cars.
# NOTE(review): data config is 'resnet34' although the student is a resnet18 --
# presumably equivalent transforms; confirm.
cars_lp_student_with_head = nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head)
cars_lp_test_dm_kwargs = {'ds_class': StanfordCars, 'timm_name': 'resnet34', 'batch_size': 128}
quick_lit_test(
    LitModel(cars_lp_student_with_head, num_classes=196),
    GenericDataModule,
    cars_lp_test_dm_kwargs,
    project='rel-rep-dist-comp-cars-tune',
)

#### Oxford Pet

In [None]:
def strip_timm_resnet_cls_head(model):
    """Return an ``nn.Sequential`` of every direct child of ``model`` except the last.

    Same function as the one defined in the Stanford Cars section; presumably
    repeated so that this section can be run on its own. The child modules are
    reused as-is (not copied), so the result shares parameters with ``model``.
    """
    child_modules = tuple(model.children())
    backbone_modules = child_modules[:-1]
    return nn.Sequential(*backbone_modules)

# Zero-argument factories for the Oxford Pet models (37 output classes):
# a pretrained timm resnet34 as teacher and a pretrained timm resnet18 as student.
create_pet_cls_teacher = lambda: timm.create_model(
    'resnet34',
    pretrained=True,
    num_classes=37,
)
create_pet_cls_student = lambda: timm.create_model(
    'resnet18',
    pretrained=True,
    num_classes=37,
)

def _pet_dm_kwargs(timm_name):
    """Return a new GenericDataModule kwargs dict for OxfordIIITPet (batch size 128)."""
    return {'ds_class': OxfordIIITPet, 'timm_name': timm_name, 'batch_size': 128}


# Teacher training: LitModel on the resnet34 data config, 20 epochs by default.
pet_sl_quick_train = partial(
    quick_train,
    lit_model_cls=LitModel,
    dm_init=GenericDataModule,
    dm_kwargs=_pet_dm_kwargs('resnet34'),
    trainer_kwargs={'max_epochs': 20},
)

# Distillation into the resnet18 student; the student is passed through
# strip_timm_resnet_cls_head via `student_preprocess`.
# Original author's note: pretty sure the transforms are the same for both
# resnet34 and resnet18.
pet_sl_quick_distill = partial(
    quick_distill,
    student_init=create_pet_cls_student,
    dm_init=GenericDataModule,
    dm_kwargs=_pet_dm_kwargs('resnet18'),
    trainer_kwargs={'max_epochs': 20},
    student_preprocess=strip_timm_resnet_cls_head,
)

# Probe and test helpers bound to the same Oxford Pet data module settings.
pet_sl_quick_fc_probe = partial(
    quick_fc_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_pet_dm_kwargs('resnet18'),
)

pet_sl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_pet_dm_kwargs('resnet18'),
)

pet_sl_quick_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=_pet_dm_kwargs('resnet18'),
)

In [None]:
# Oxford Pet teacher run (resnet34 factory), logged as 'teacher resnet-34'.
pet_teacher_lit_kwargs = {'lr': 1e-1, 'num_classes': 37}
lit_teacher = pet_sl_quick_train(
    create_pet_cls_teacher,
    lit_model_kwargs=pet_teacher_lit_kwargs,
    project='rel-rep-dist-comp-pet-tune',
    name='teacher resnet-34',
    log_model=True,
)

In [None]:
# Test the Oxford Pet teacher.
# NOTE(review): this uses the 'resnet18' data config while pet_sl_quick_train
# uses 'resnet34' -- presumably equivalent transforms; confirm.
pet_teacher_test_dm_kwargs = {'ds_class': OxfordIIITPet, 'timm_name': 'resnet18', 'batch_size': 128}
quick_lit_test(
    lit_teacher,
    GenericDataModule,
    pet_teacher_test_dm_kwargs,
    project='rel-rep-dist-comp-pet-tune',
)

In [None]:
# Distiller settings for the rel-rep run on Oxford Pet.
# NOTE(review): the run name below says momentum=0.9 but no 'momentum' key is
# passed here -- confirm that matches the distiller's default.
pet_rr_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'dim': 512,
    'num_classes': 37,
    'ce_weight': 1,
    'non_linear_head': False,
    'dropout_head': False,
}

# Third positional argument of quick_fn_probe: the distiller settings plus the
# logging options for the run.
pet_rr_run_kwargs = {
    'distill_kwargs': pet_rr_distill_kwargs,
    'project': 'rel-rep-dist-comp-pet-tune',
    'name': 'rel-rep momentum=0.9',
    'log_model': True,
}

# quick_fn_probe is given the distill and sk-probe helpers and returns a pair,
# unpacked here into the distilled module and its probe.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    pet_sl_quick_distill,
    pet_sl_quick_sk_probe,
    pet_rr_run_kwargs,
    extract_model_fn=extract_student,
)

In [None]:
# Run the Oxford Pet sk-test helper on the rel-rep student and the probe that
# was returned alongside it; the call stays the cell's final expression.
pet_sl_quick_sk_test(lit_rr_distilled.student, rr_distilled_linear, project='rel-rep-dist-comp-pet-tune')

In [None]:
# Keyword arguments handed to LitSPDistiller; the teacher is the Oxford Pet
# teacher model with its last child removed.
pet_sp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'momentum': 0.9,
    'd_weight': 1,
    'ce_weight': 1,
    'num_classes': 37,
    'dim': 512,
}

# Oxford Pet distillation run with LitSPDistiller, logged as 'sim-pre'.
lit_sp_distilled = pet_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs=pet_sp_distill_kwargs,
    project='rel-rep-dist-comp-pet-tune',
    name='sim-pre',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `head`, wrap the result in
# LitModel, and run the lit-test helper on Oxford Pet.
pet_sp_student_with_head = nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head)
pet_sp_test_dm_kwargs = {'ds_class': OxfordIIITPet, 'timm_name': 'resnet18', 'batch_size': 128}
quick_lit_test(
    LitModel(pet_sp_student_with_head, num_classes=37),
    GenericDataModule,
    pet_sp_test_dm_kwargs,
    project='rel-rep-dist-comp-pet-tune',
)

In [None]:
# Keyword arguments handed to LitLPDistiller: the head-stripped teacher plus
# the teacher's `fc` layer as 'teacher_head'.
pet_lp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'normalizing_constant': 1,
    'dim': 512,
    'num_classes': 37,
    'teacher_head': lit_teacher.model.fc,
}

# Oxford Pet distillation run with LitLPDistiller, logged as 'loc-pre'.
lit_lp_distilled = pet_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs=pet_lp_distill_kwargs,
    project='rel-rep-dist-comp-pet-tune',
    name='loc-pre',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `student_head`, wrap the
# result in LitModel, and run the lit-test helper on Oxford Pet.
pet_lp_student_with_head = nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head)
pet_lp_test_dm_kwargs = {'ds_class': OxfordIIITPet, 'timm_name': 'resnet18', 'batch_size': 128}
quick_lit_test(
    LitModel(pet_lp_student_with_head, num_classes=37),
    GenericDataModule,
    pet_lp_test_dm_kwargs,
    project='rel-rep-dist-comp-pet-tune',
)

#### Oxford Flowers

In [None]:
def strip_timm_resnet_cls_head(model):
    """Return an ``nn.Sequential`` of every direct child of ``model`` except the last.

    Same function as the one defined in the Stanford Cars and Oxford Pet
    sections; presumably repeated so that this section can be run on its own.
    The child modules are reused as-is (not copied), so the result shares
    parameters with ``model``.
    """
    child_modules = tuple(model.children())
    backbone_modules = child_modules[:-1]
    return nn.Sequential(*backbone_modules)

# Zero-argument factories for the Oxford Flowers models (102 output classes):
# a pretrained timm resnet34 as teacher and a pretrained timm resnet18 as student.
create_flowers_cls_teacher = lambda: timm.create_model(
    'resnet34',
    pretrained=True,
    num_classes=102,
)
create_flowers_cls_student = lambda: timm.create_model(
    'resnet18',
    pretrained=True,
    num_classes=102,
)

def _flowers_dm_kwargs(timm_name):
    """Return a new GenericDataModule kwargs dict for Flowers102 (batch size 128)."""
    return {'ds_class': Flowers102, 'timm_name': timm_name, 'batch_size': 128}


# Teacher training: LitModel on the resnet34 data config, 50 epochs by default.
flowers_sl_quick_train = partial(
    quick_train,
    lit_model_cls=LitModel,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet34'),
    trainer_kwargs={'max_epochs': 50},
)

# Distillation into the resnet18 student; the student is passed through
# strip_timm_resnet_cls_head via `student_preprocess`.
# Original author's note: pretty sure the transforms are the same for both
# resnet34 and resnet18.
flowers_sl_quick_distill = partial(
    quick_distill,
    student_init=create_flowers_cls_student,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet18'),
    trainer_kwargs={'max_epochs': 50},
    student_preprocess=strip_timm_resnet_cls_head,
)

# Probe and test helpers bound to the same Oxford Flowers data module settings.
flowers_sl_quick_fc_probe = partial(
    quick_fc_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet18'),
)

flowers_sl_quick_sk_probe = partial(
    quick_sk_probe,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet18'),
)

flowers_sl_quick_sk_test = partial(
    quick_sk_test,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet18'),
)

flowers_sl_quick_lit_test = partial(
    quick_lit_test,
    dm_init=GenericDataModule,
    dm_kwargs=_flowers_dm_kwargs('resnet18'),
)

In [None]:
# Oxford Flowers teacher run (resnet34 factory), logged as 'teacher'.
# `trainer_kwargs` repeats the value already bound in flowers_sl_quick_train;
# it is kept explicit here.
flowers_teacher_lit_kwargs = {'lr': 1e-1, 'num_classes': 102}
flowers_teacher_trainer_kwargs = {'max_epochs': 50}
lit_teacher = flowers_sl_quick_train(
    create_flowers_cls_teacher,
    lit_model_kwargs=flowers_teacher_lit_kwargs,
    project='rel-rep-dist-comp-flowers-tune',
    name='teacher',
    log_model=True,
    trainer_kwargs=flowers_teacher_trainer_kwargs,
)

[34m[1mwandb[0m: Downloading large artifact model-a46txjt3:latest, 81.54MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.1


In [None]:
# Test the Oxford Flowers teacher with the flowers lit-test helper.
flowers_sl_quick_lit_test(lit_teacher, project='rel-rep-dist-comp-flowers-tune')

In [None]:
# Distiller settings for the rel-rep run on Oxford Flowers.
# NOTE(review): the run name below says momentum=0.9 but no 'momentum' key is
# passed here -- confirm that matches the distiller's default.
flowers_rr_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'dim': 512,
    'num_classes': 102,
    'ce_weight': 1,
    'non_linear_head': False,
    'dropout_head': False,
}

# Third positional argument of quick_fn_probe: the distiller settings plus the
# logging options for the run.
flowers_rr_run_kwargs = {
    'distill_kwargs': flowers_rr_distill_kwargs,
    'project': 'rel-rep-dist-comp-flowers-tune',
    'name': 'rel-rep momentum=0.9',
    'log_model': True,
}

# quick_fn_probe is given the distill and sk-probe helpers and returns a pair,
# unpacked here into the distilled module and its probe.
lit_rr_distilled, rr_distilled_linear = quick_fn_probe(
    flowers_sl_quick_distill,
    flowers_sl_quick_sk_probe,
    flowers_rr_run_kwargs,
    extract_model_fn=extract_student,
)

[34m[1mwandb[0m: Currently logged in as: [33mpatrickramosobf[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flowers-102/102flowers.tgz


  0%|          | 0/344862509 [00:00<?, ?it/s]

Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flowers-102/imagelabels.mat


  0%|          | 0/502 [00:00<?, ?it/s]

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flowers-102/setid.mat


  0%|          | 0/14989 [00:00<?, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

  "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."


d_weight: 1	ce_weight: 1
anchors: None
head: Linear(in_features=512, out_features=102, bias=True)
optim: sgd


Training: 0it [00:00, ?it/s]

d_weight: 1	ce_weight: 1
anchors: None
head: Linear(in_features=512, out_features=102, bias=True)
optim: sgd




Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/ce_loss,█▄▁▁
train/cos,▄▁▇█
train/d_loss,▅█▂▁
train/loss,█▄▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▁▃▃▄▆▆▆▇▇▇▇▇█▇▇█████▇██▇▇▇█▇▇▇▇▇██▇███▇█
val/ce_loss,█▆▅▄▃▃▂▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁
val/cos,▁▃▄▅▆▇▇▆▇▇▇▇▇▇▇█▇█▇▇▇███▇▇█▇▇▇███▇▇█████
val/d_loss,█▆▅▄▃▂▂▃▂▂▂▂▂▂▂▁▂▁▂▂▂▁▁▁▂▂▁▂▂▂▁▁▁▂▂▁▁▁▁▁

0,1
epoch,49.0
train/ce_loss,0.18839
train/cos,0.95122
train/d_loss,0.02473
train/loss,0.21313
trainer/global_step,199.0
val/acc,0.81373
val/ce_loss,0.86992
val/cos,0.94076
val/d_loss,0.03014


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")


Predicting: 0it [00:00, ?it/s]

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =        52326     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  4.24573D+03    |proj g|=  5.90870D+01


 This problem is unconstrained.



At iterate   50    f=  2.31280D+01    |proj g|=  1.15666D-02

At iterate  100    f=  2.28216D+01    |proj g|=  2.81813D-03

At iterate  150    f=  2.28173D+01    |proj g|=  2.77656D-03

At iterate  200    f=  2.28168D+01    |proj g|=  7.17425D-04

At iterate  250    f=  2.28166D+01    |proj g|=  3.70045D-04

At iterate  300    f=  2.28165D+01    |proj g|=  6.83117D-04

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
52326    334    347      1     0     0   1.230D-04   2.282D+01
  F =   22.816246120842088     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    8.1s finished


Predicting: 0it [00:00, ?it/s]

0.8725490196078431


In [None]:
# Run the Oxford Flowers sk-test helper on the rel-rep student and the probe
# that was returned alongside it. The call stays the last expression of the
# cell so its return value is still displayed.
flowers_sl_quick_sk_test(lit_rr_distilled.student, rr_distilled_linear, project='rel-rep-dist-comp-flowers-tune')

0,1
val/acc,▁

0,1
epoch,49.0
train/ce_loss,0.18839
train/cos,0.95122
train/d_loss,0.02473
train/loss,0.21313
trainer/global_step,199.0
val/acc,0.87255
val/ce_loss,0.86992
val/cos,0.94076
val/d_loss,0.03014


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")


Predicting: 0it [00:00, ?it/s]

0.8461538461538461


0.8461538461538461

In [None]:
# Keyword arguments handed to LitSPDistiller; the teacher is the Oxford
# Flowers teacher model with its last child removed.
flowers_sp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'momentum': 0.9,
    'd_weight': 1,
    'ce_weight': 1,
    'num_classes': 102,
    'dim': 512,
}

# Oxford Flowers distillation run with LitSPDistiller, logged as 'sim-pre'.
lit_sp_distilled = flowers_sl_quick_distill(
    distill_cls=LitSPDistiller,
    distill_kwargs=flowers_sp_distill_kwargs,
    project='rel-rep-dist-comp-flowers-tune',
    name='sim-pre',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `head`, wrap the result in
# LitModel, and run the flowers lit-test helper on it.
flowers_sp_student_with_head = nn.Sequential(lit_sp_distilled.student, lit_sp_distilled.head)
flowers_sl_quick_lit_test(
    LitModel(flowers_sp_student_with_head, num_classes=102),
    project='rel-rep-dist-comp-flowers-tune',
)

In [None]:
# Keyword arguments handed to LitLPDistiller: the head-stripped teacher plus
# the teacher's `fc` layer as 'teacher_head'.
flowers_lp_distill_kwargs = {
    'teacher': strip_timm_resnet_cls_head(lit_teacher.model),
    'lr': 1e-1,
    'normalizing_constant': 1,
    'dim': 512,
    'num_classes': 102,
    'teacher_head': lit_teacher.model.fc,
}

# Oxford Flowers distillation run with LitLPDistiller, logged as 'loc-pre'.
lit_lp_distilled = flowers_sl_quick_distill(
    distill_cls=LitLPDistiller,
    distill_kwargs=flowers_lp_distill_kwargs,
    project='rel-rep-dist-comp-flowers-tune',
    name='loc-pre',
    log_model=True,
)

In [None]:
# Chain the distilled student with the distiller's `student_head`, wrap the
# result in LitModel, and run the flowers lit-test helper on it.
flowers_lp_student_with_head = nn.Sequential(lit_lp_distilled.student, lit_lp_distilled.student_head)
flowers_sl_quick_lit_test(
    LitModel(flowers_lp_student_with_head, num_classes=102),
    project='rel-rep-dist-comp-flowers-tune',
)