In [10]:
import os
import argparse
import multiprocessing
from pathlib import Path
from PIL import Image

import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

from torchvision.datasets import CIFAR10

In [11]:
resnet = models.resnet50(pretrained=True)

In [12]:
BATCH_SIZE = 4096
EPOCHS = 1000
LR = 3e-4
NUM_GPUS = 1
IMAGE_SIZE = 32
NUM_WORKERS = int(os.environ["SLURM_CPUS_PER_TASK"])

In [13]:
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()

In [14]:
class CIFAR10_Wrapper(Dataset):
    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        image, _ = self.original_dataset[idx]
        return image

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

ds_train = CIFAR10(root='/scratch/gpfs/eh0560/data', train=True, download=False, transform=transform)
wrapper_train = CIFAR10_Wrapper(ds_train)
train_loader = DataLoader(wrapper_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

ds_test = CIFAR10(root='/scratch/gpfs/eh0560/data', train=False, download=False, transform=transform)
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [16]:
model = SelfSupervisedLearner(
    resnet,
    image_size=IMAGE_SIZE,
    hidden_layer='avgpool',
    projection_size=256,
    projection_hidden_size=4096,
    moving_average_decay=0.99
)

trainer = pl.Trainer(
    devices=NUM_GPUS,
    max_epochs=EPOCHS,
    accumulate_grad_batches=1,
    sync_batchnorm=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
for i, (images, labels) in enumerate(ds_test):
    

3
8
8
0
6
6
1
6
3
1
0
9
5
7
9
8
5
7
8
6
7
0
4
9
5
2
4
0
9
6
6
5
4
5
9
2
4
1
9
5
4
6
5
6
0
9
3
9
7
6
9
8
0
3
8
8
7
7
4
6
7
3
6
3
6
2
1
2
3
7
2
6
8
8
0
2
9
3
3
8
8
1
1
7
2
5
2
7
8
9
0
3
8
6
4
6
6
0
0
7
4
5
6
3
1
1
3
6
8
7
4
0
6
2
1
3
0
4
2
7
8
3
1
2
8
0
8
3
5
2
4
1
8
9
1
2
9
7
2
9
6
5
6
3
8
7
6
2
5
2
8
9
6
0
0
5
2
9
5
4
2
1
6
6
8
4
8
4
5
0
9
9
9
8
9
9
3
7
5
0
0
5
2
2
3
8
6
3
4
0
5
8
0
1
7
2
8
8
7
8
5
1
8
7
1
3
0
5
7
9
7
4
5
9
8
0
7
9
8
2
7
6
9
4
3
9
6
4
7
6
5
1
5
8
8
0
4
0
5
5
1
1
8
9
0
3
1
9
2
2
5
3
9
9
4
0
3
0
0
9
8
1
5
7
0
8
2
4
7
0
2
3
6
3
8
5
0
3
4
3
9
0
6
1
0
9
1
0
7
9
1
2
6
9
3
4
6
0
0
6
6
6
3
2
6
1
8
2
1
6
8
6
8
0
4
0
7
7
5
5
3
5
2
3
4
1
7
5
4
6
1
9
3
6
6
9
3
8
0
7
2
6
2
5
8
5
4
6
8
9
9
1
0
2
2
7
3
2
8
0
9
5
8
1
9
4
1
3
8
1
4
7
9
4
2
7
0
7
0
6
6
9
0
9
2
8
7
2
2
5
1
2
6
2
9
6
2
3
0
3
9
8
7
8
8
4
0
1
8
2
7
9
3
6
1
9
0
7
3
7
4
5
0
0
2
9
3
4
0
6
2
5
3
7
3
7
2
5
3
1
1
4
9
9
5
7
5
0
2
2
2
9
7
3
9
4
3
5
4
6
5
6
1
4
3
4
4
3
7
8
3
7
8
0
5
7
6
0
5
4
8
6
8
5
5
9
9
9
5
0
1
0
8
1
1
8
0
2
2
0


2
4
3
8
3
4
7
0
2
8
4
8
3
8
8
8
3
3
5
7
7
0
4
1
5
9
7
0
6
8
4
9
0
1
8
9
6
3
9
2
4
4
0
3
3
5
4
5
1
1
8
2
2
9
3
7
8
9
2
3
1
7
3
2
3
0
1
9
5
5
3
4
5
2
0
0
3
1
3
4
7
4
2
8
4
8
9
9
4
2
4
3
6
4
6
4
6
6
3
8
7
4
8
0
5
9
3
9
0
7
5
1
9
7
9
1
8
4
2
1
6
4
3
0
1
4
8
6
1
7
1
3
5
4
3
3
9
7
8
2
5
5
4
5
4
5
7
2
5
5
1
7
7
8
3
6
0
2
5
2
3
7
4
2
1
6
8
6
8
3
2
7
7
9
7
1
4
7
4
6
1
7
3
0
0
6
6
8
6
6
0
3
4
7
4
9
4
9
9
3
4
1
4
0
3
7
1
2
2
8
4
7
8
5
5
6
5
6
0
6
4
9
3
2
7
3
8
3
9
4
1
6
9
9
4
3
9
8
1
6
9
5
9
0
9
7
2
2
4
2
6
4
8
1
9
5
6
5
1
8
6
6
7
6
5
0
2
7
6
1
2
3
1
6
4
9
0
3
9
1
0
5
3
1
6
9
8
9
0
6
1
6
2
3
5
6
9
0
7
2
4
0
3
6
8
6
9
1
9
0
6
4
5
9
5
4
1
9
2
0
7
3
5
1
8
3
0
5
2
8
8
3
9
6
0
3
5
3
4
1
5
0
7
3
9
4
5
3
1
4
2
4
9
9
7
1
7
0
1
2
1
5
3
8
4
1
5
1
9
9
7
0
1
7
6
2
6
5
0
3
1
3
8
9
9
0
8
8
7
9
2
9
1
3
6
7
6
0
8
6
2
4
2
3
5
5
4
1
6
7
1
6
1
7
1
6
2
3
0
4
2
9
7
5
5
6
0
8
8
1
1
3
5
6
5
3
8
7
5
0
7
7
5
7
1
9
9
1
9
9
3
5
5
3
0
3
0
6
6
5
6
8
9
4
5
1
1
2
7
9
0
1
2
5
4
4
7
4
8
3
3
5
3
5
7
0
1
2
9
8
1
7
3
5
9
1
9
1
1
3


0
5
6
0
4
3
2
8
8
0
6
9
5
2
8
7
0
6
5
9
7
2
3
6
9
6
2
2
4
1
0
5
0
8
9
3
5
9
3
8
1
6
3
7
5
6
2
0
2
8
2
8
7
7
8
1
0
8
9
7
0
3
8
0
5
9
5
8
4
2
0
9
2
2
4
4
9
2
2
2
5
1
3
2
0
0
4
0
6
5
8
0
5
8
6
4
8
5
2
9
7
9
7
1
0
1
9
6
9
2
7
9
4
4
0
6
2
4
1
3
7
2
8
5
9
0
3
2
3
2
7
6
3
2
5
9
0
5
9
9
8
7
7
4
8
6
5
2
3
1
0
4
1
8
8
4
9
4
4
3
3
9
2
0
1
1
8
4
4
8
3
2
9
5
7
6
2
5
4
7
3
3
9
0
1
5
9
3
7
6
0
4
2
2
5
6
3
8
9
5
6
1
4
5
4
6
7
2
1
0
2
0
4
9
8
8
9
1
1
5
0
0
8
7
1
7
4
5
4
3
3
2
0
6
6
0
1
3
9
8
3
7
8
9
4
8
9
0
9
7
1
6
2
9
5
6
9
3
9
8
7
7
1
6
5
3
1
3
1
2
7
1
8
2
0
9
7
9
8
8
6
7
3
7
1
3
9
0
9
3
6
7
2
7
3
0
5
9
7
5
5
0
6
5
1
8
2
7
5
9
0
0
0
8
8
7
3
7
8
9
3
7
9
7
8
7
9
8
5
4
8
3
7
6
3
8
2
1
9
5
7
3
9
5
5
8
7
3
5
3
5
9
7
6
7
3
6
4
3
9
4
2
1
9
6
0
2
6
7
4
7
9
0
7
4
3
5
3
1
1
2
6
8
2
1
7
8
5
9
6
1
1
5
0
6
0
9
2
6
5
8
9
5
5
6
2
9
1
5
8
8
7
1
7
3
5
4
9
7
5
2
9
9
4
7
4
1
3
8
7
9
0
4
5
7
5
2
8
7
6
9
6
9
3
8
5
6
6
9
5
7
8
0
5
0
7
4
8
2
5
1
3
2
2
6
2
1
7
4
6
3
1
3
7
2
1
3
7
0
8
4
4
5
7
9
5
4
3
9
6
8
2
3
3
1
6
1
7
0
3


In [None]:
trainer.fit(model, train_loader)

In [8]:
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 