In [None]:
import os
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

from torchvision import datasets

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [None]:
# ResNet-50 backbone built with weights=None, i.e. no pretrained weights are requested.
# This same object is later handed to the BYOL learner and to the visualization calls.
resnet = models.resnet50(weights=None)

In [None]:
# Training / data-loading configuration used by the cells below.
BATCH_SIZE = 512
EPOCHS = 2
LR = 3e-4
NUM_GPUS = 1
IMAGE_SIZE = 224
# Prefer the CPU allocation reported by SLURM; when the variable is missing or
# empty (e.g. running outside a SLURM job) fall back to the machine's CPU count
# instead of failing with a KeyError.
NUM_WORKERS = int(os.environ.get('SLURM_CPUS_PER_TASK') or os.cpu_count() or 1)

In [None]:
class SelfSupervisedLearner(pl.LightningModule):
    """LightningModule that trains a network through a ``BYOL`` learner.

    Args:
        net: Network passed to ``BYOL`` as its first argument.
        **kwargs: Forwarded unchanged to the ``BYOL`` constructor.
    """

    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        """Return the wrapped learner's result for ``images``."""
        learner_output = self.learner(images)
        return learner_output

    def training_step(self, images, _):
        """Treat the forward result as this batch's loss; the second argument is ignored."""
        return dict(loss=self.forward(images))

    def configure_optimizers(self):
        """Build an Adam optimizer over all module parameters with learning rate ``LR``."""
        trainable = self.parameters()
        return torch.optim.Adam(trainable, lr=LR)

    def on_before_zero_grad(self, _):
        """Refresh the learner's moving average, but only when momentum is enabled."""
        if not self.learner.use_momentum:
            return
        self.learner.update_moving_average()

In [None]:
class GenericLightningWrapper(pl.LightningModule):
    """Minimal LightningModule shell around an arbitrary network.

    Args:
        net: Network stored as ``self.learner`` and called in ``forward``.
        **kwargs: Accepted but not used.
    """

    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = net

    def forward(self, images):
        """Delegate the forward pass to the wrapped network."""
        wrapped = self.learner
        return wrapped(images)

In [None]:
class CIFAR10_Wrapper(Dataset):
    """Dataset view that yields only the first element of each two-item sample.

    Args:
        original_dataset: Indexable, sized dataset whose items unpack into
            exactly two values (the image and a second value that is dropped).
    """

    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __len__(self):
        source = self.original_dataset
        return len(source)

    def __getitem__(self, idx):
        # Unpack the pair and discard its second element.
        image, _discarded = self.original_dataset[idx]
        return image

In [None]:
# Deterministic preprocessing applied to every sample: bring each image to a
# fixed IMAGE_SIZE x IMAGE_SIZE tensor so the default collate function can
# stack samples into a batch (raw dataset images are PIL objects of varying size).
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Single definition of the dataset location shared by both splits.
IMAGENET_ROOT = "/scratch/gpfs/DATASETS/imagenet/ilsvrc_2012_classification_localization"

# Training split: labels are dropped by the wrapper, so batches are images only.
ds_train = datasets.ImageNet(root=IMAGENET_ROOT, split="train", transform=transform)
wrapper_train = CIFAR10_Wrapper(ds_train)
train_loader = DataLoader(wrapper_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

# Validation split: keeps (image, label) pairs for the visualization step.
ds_test = datasets.ImageNet(root=IMAGENET_ROOT, split="val", transform=transform)
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [None]:
def visualize(net, data_loader, classes, max_legend_entries=20):
    """Plot a network's outputs projected onto their first two principal components.

    Every batch from ``data_loader`` is passed through ``net`` (in eval mode,
    without gradients), each output is flattened to one row per sample, PCA is
    fitted on all rows, the explained-variance ratios are printed, and the
    first two components are drawn as a scatter plot coloured by label.

    Args:
        net: Module called on each image batch.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        classes: Sequence of class names indexed by integer label.
        max_legend_entries: Largest number of distinct labels for which a
            per-class legend is drawn; with more labels a colorbar of the
            label index is shown instead.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    net.eval()

    # Feed inputs on the device the network's parameters live on; a network
    # without parameters is assumed to run on CPU.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, labels in data_loader:
            output = net(images.to(device))
            # One flat feature row per sample, moved to CPU for the numpy conversion below.
            feature_batches.append(output.reshape(output.size(0), -1).cpu())
            label_batches.append(torch.as_tensor(labels).reshape(-1).cpu())

    if not feature_batches:
        raise ValueError("data_loader yielded no batches; nothing to visualize")

    features = torch.cat(feature_batches).numpy()
    labels = torch.cat(label_batches).numpy()

    pca = PCA()
    pca_features = pca.fit_transform(features)

    print(pca.explained_variance_ratio_)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(pca_features[:, 0], pca_features[:, 1], c=labels, cmap='plasma')

    # Label each legend handle with the name of the class it actually
    # represents: num=None yields one handle per distinct label value, in
    # sorted order, matching `present_labels`.
    present_labels = sorted(set(labels.tolist()))
    if len(present_labels) <= max_legend_entries:
        handles, _ = scatter.legend_elements(num=None)
        names = [
            classes[idx] if 0 <= idx < len(classes) else str(idx)
            for idx in present_labels
        ]
        plt.legend(handles=handles, labels=names)
    else:
        # Too many classes for a readable legend.
        plt.colorbar(scatter, label='class index')

    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.title('Model Features Projected to 2D using PCA')
    plt.show()

In [None]:
# Class names must come from the dataset being plotted: `test_loader` serves
# the ImageNet validation split, so take the first name of each entry in
# `ds_test.classes` rather than a hard-coded CIFAR-10 list.
classes = tuple(names[0] for names in ds_test.classes)
# Baseline plot of the network's outputs before any training has run.
visualize(GenericLightningWrapper(resnet), test_loader, classes)

In [None]:
# Keyword arguments forwarded by SelfSupervisedLearner to the BYOL constructor.
byol_settings = dict(
    image_size=IMAGE_SIZE,
    hidden_layer='avgpool',
    projection_size=256,
    projection_hidden_size=4096,
    moving_average_decay=0.99,
)
model = SelfSupervisedLearner(resnet, **byol_settings)

# Lightning trainer configuration, driven by the notebook-level constants.
trainer_settings = dict(
    devices=NUM_GPUS,
    max_epochs=EPOCHS,
    accumulate_grad_batches=1,
    sync_batchnorm=True,
)
trainer = pl.Trainer(**trainer_settings)

In [None]:
# Train `model` on the label-free training loader using the trainer configured above.
trainer.fit(model, train_loader)

In [None]:
# Repeat the visualization after training, on the same `resnet` object that was
# handed to SelfSupervisedLearner, for comparison with the earlier plot.
# NOTE(review): assumes `resnet` and the loader's tensors end up on compatible devices after fit — confirm.
visualize(GenericLightningWrapper(resnet), test_loader, classes)