In [None]:
import os
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

from torchvision.datasets import CIFAR10

In [None]:
# ImageNet-pretrained ResNet-50 used as the backbone handed to BYOL below.
# torchvision's multi-weight API deprecates the boolean `pretrained` flag in
# favour of `weights=`; IMAGENET1K_V1 is the checkpoint `pretrained=True`
# loaded. Older torchvision releases lack the enum, so keep the legacy call
# as a fallback.
if hasattr(models, "ResNet50_Weights"):
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    resnet = models.resnet50(pretrained=True)

In [None]:
# Run configuration shared by the cells below.
BATCH_SIZE = 4096   # batch size for both DataLoaders
EPOCHS = 1000       # passed to the Trainer as max_epochs
LR = 3e-4           # learning rate of the Adam optimizer
NUM_GPUS = 1        # passed to the Trainer as devices
IMAGE_SIZE = 32     # image_size given to the BYOL learner
# Use the SLURM CPU allocation when the variable is set and non-empty;
# otherwise fall back to the local CPU count (at least 1) so the notebook
# also runs outside a SLURM job instead of failing with KeyError.
NUM_WORKERS = int(os.environ.get("SLURM_CPUS_PER_TASK") or os.cpu_count() or 1)

In [None]:
class SelfSupervisedLearner(pl.LightningModule):
    """Lightning module that trains a network through a ``BYOL`` learner.

    The ``BYOL`` object is kept on ``self.learner``; the value it returns for
    a batch of images is used directly as the training loss.
    """

    def __init__(self, net, **kwargs):
        """Wrap ``net`` in a ``BYOL`` learner, forwarding ``kwargs`` unchanged."""
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        """Run the learner on ``images`` and return whatever it produces."""
        learner_output = self.learner(images)
        return learner_output

    def training_step(self, images, _):
        """Return the loss for one batch; the second argument is ignored."""
        return {'loss': self.forward(images)}

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters using the global ``LR``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=LR)
        return optimizer

    def on_before_zero_grad(self, _):
        """Call the learner's moving-average update when ``use_momentum`` is set."""
        if not self.learner.use_momentum:
            return
        self.learner.update_moving_average()

In [None]:
class CIFAR10_Wrapper(Dataset):
    """Dataset view that yields only the first element of each 2-item sample.

    The wrapped dataset is expected to return ``(image, label)`` pairs; the
    second element is discarded so each item is just the image.
    """

    def __init__(self, original_dataset):
        """Keep a reference to the dataset being wrapped."""
        self.original_dataset = original_dataset

    def __len__(self):
        """Report the length of the wrapped dataset."""
        n_items = len(self.original_dataset)
        return n_items

    def __getitem__(self, idx):
        """Fetch sample ``idx`` from the wrapped dataset and return its image."""
        sample = self.original_dataset[idx]
        picture, _unused_label = sample  # samples must unpack into exactly two values
        return picture

In [None]:
# Preprocessing shared by both splits: tensor conversion followed by
# per-channel normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training split, read from disk without downloading. It is wrapped so each
# item is the image only, then served in shuffled batches.
ds_train = CIFAR10(
    root='/scratch/gpfs/eh0560/data',
    train=True,
    download=False,
    transform=transform,
)
wrapper_train = CIFAR10_Wrapper(ds_train)
train_loader = DataLoader(
    wrapper_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Test split: same location and preprocessing, kept unwrapped so labels are
# retained, and served in dataset order.
ds_test = CIFAR10(
    root='/scratch/gpfs/eh0560/data',
    train=False,
    download=False,
    transform=transform,
)
test_loader = DataLoader(
    ds_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# Lightning module wrapping the ResNet backbone; every keyword option is
# forwarded unchanged to the BYOL learner.
model = SelfSupervisedLearner(
    resnet,
    hidden_layer='avgpool',
    image_size=IMAGE_SIZE,
    projection_hidden_size=4096,
    projection_size=256,
    moving_average_decay=0.99,
)

# Trainer configured from the constants defined in the configuration cell.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices=NUM_GPUS,
    sync_batchnorm=True,
    accumulate_grad_batches=1,
)

In [None]:
# Walk the whole test split once without doing anything per sample; the only
# lasting effect is that `i`, `images` and `labels` hold the final sample.
# NOTE(review): looks like a leftover smoke test — confirm it is still needed.
for i, (images, labels) in enumerate(ds_test, start=0):
    continue

In [None]:
# Start training: fit the BYOL Lightning module on the label-free training loader.
trainer.fit(model, train_loader)

In [None]:
# Show the backbone's module repr as this cell's output.
resnet