In [None]:
import os
import subprocess
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

from torchvision import datasets

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [None]:
# ResNet-50 backbone constructed with weights=None, i.e. no pretrained weights are requested.
# This same object is later wrapped by SelfSupervisedLearner and passed to visualize().
resnet = models.resnet50(weights=None)

In [None]:
# Training / data hyper-parameters used by the cells below.
BATCH_SIZE = 256
EPOCHS = 2
LR = 3e-4
IMAGE_SIZE = 224

# Hardware sizing. Prefer the values SLURM exports for this job; when a variable
# is absent (notebook started outside SLURM, or the job did not request that
# resource explicitly) fall back to what the machine reports instead of raising
# KeyError. The fallbacks are only evaluated when the variable is missing.
_slurm_gpus = os.environ.get("SLURM_GPUS_ON_NODE")
NUM_GPUS = int(_slurm_gpus) if _slurm_gpus is not None else torch.cuda.device_count()

_slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
# os.cpu_count() may return None, hence the `or 1`.
NUM_WORKERS = int(_slurm_cpus) if _slurm_cpus is not None else (os.cpu_count() or 1)

In [None]:
class SelfSupervisedLearner(pl.LightningModule):
    """LightningModule that trains ``net`` through a ``BYOL`` learner.

    Args:
        net: Network handed to ``BYOL`` as its first argument.
        **kwargs: Forwarded unchanged to the ``BYOL`` constructor.
    """

    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        """Return the BYOL learner's output for ``images`` (used as the loss in training)."""
        return self.learner(images)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one ``(images, labels)`` batch; the labels are not used."""
        images, _labels = batch
        return {'loss': self.forward(images)}

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters, using the module-level ``LR``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=LR)
        return optimizer

    def on_before_zero_grad(self, _):
        """Update the learner's moving average, but only when it uses momentum."""
        if not self.learner.use_momentum:
            return
        self.learner.update_moving_average()

In [None]:
class PredictWrapper(pl.LightningModule):
    """Thin LightningModule around ``net`` so it can be run with ``Trainer.predict``.

    Args:
        net: Module that is called on each batch of images.
        **kwargs: Accepted but not used.
    """

    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = net

    def forward(self, images):
        """Return ``net``'s output for ``images``."""
        return self.learner(images)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(outputs, labels)`` for one ``(images, labels)`` batch."""
        images, labels = batch
        outputs = self.forward(images)
        return outputs, labels

In [None]:
# Dataset location and the number of validation images used for the PCA plots.
# Kept as named constants so the path is written once and the subset size is
# easy to change.
IMAGENET_ROOT = '/scratch/gpfs/DATASETS/imagenet/ilsvrc_2012_classification_localization'
NUM_VIS_SAMPLES = 500

# Resize every image to IMAGE_SIZE x IMAGE_SIZE and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Training split, shuffled each epoch.
ds_train = datasets.ImageNet(root=IMAGENET_ROOT, split='train', transform=transform)
train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

# Evaluation data: the first NUM_VIS_SAMPLES items of the 'val' split, in fixed order.
ds_val_full = datasets.ImageNet(root=IMAGENET_ROOT, split='val', transform=transform)
ds_test = torch.utils.data.Subset(ds_val_full, range(NUM_VIS_SAMPLES))
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [None]:
def get_predictions(net, data_loader, devices=1):
    """Run ``net`` over ``data_loader`` and collect its outputs and the labels.

    Args:
        net: Module to evaluate; it is wrapped in ``PredictWrapper``.
        data_loader: Loader yielding ``(images, labels)`` batches.
        devices: ``devices`` argument for the prediction ``pl.Trainer``.
            Defaults to a single device: Lightning documents that returning
            predictions is not supported with strategies that spawn processes,
            which is what a multi-device run inside a notebook uses, and a
            single device also guarantees every sample is seen exactly once.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays. ``features`` has one row
        per sample, with each sample's output flattened to 1-D; ``labels`` is
        the 1-D concatenation of all label batches.
    """
    trainer = pl.Trainer(devices=devices)
    predictions = trainer.predict(PredictWrapper(net), data_loader)

    # Each prediction is the (outputs, labels) pair returned by predict_step.
    features = torch.cat([
        outputs.reshape(outputs.size(0), -1) for outputs, _ in predictions
    ])
    # Concatenate whole label batches instead of rebuilding a tensor from
    # individual 0-d elements.
    labels = torch.cat([
        torch.as_tensor(batch_labels).reshape(-1) for _, batch_labels in predictions
    ])

    # .numpy() needs CPU tensors; this relies on Lightning handing predictions
    # back on the CPU (same assumption as before this change).
    return features.numpy(), labels.numpy()

In [None]:
def visualize(net, data_loader):
    """Scatter-plot the first two principal components of ``net``'s outputs.

    Outputs and labels are obtained from ``get_predictions``; a PCA is fitted on
    the outputs and the samples are drawn in PC1/PC2 space, coloured by label.
    The axis captions report each component's explained-variance ratio.

    Args:
        net: Module whose outputs are projected.
        data_loader: Loader yielding ``(images, labels)`` batches.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    features, labels = get_predictions(net, data_loader)

    pca = PCA()
    pca_features = pca.fit_transform(features)

    pc1_variance = pca.explained_variance_ratio_[0]
    pc2_variance = pca.explained_variance_ratio_[1]

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(pca_features[:, 0], pca_features[:, 1], c=labels, cmap='plasma')
    # legend_elements() returns (handles, labels). Both must be passed on:
    # with the handles alone every legend entry is drawn without a caption.
    ax.legend(*scatter.legend_elements(), title='Label')
    ax.set_xlabel(f'PC1 (Variance: {pc1_variance:.3f})')
    ax.set_ylabel(f'PC2 (Variance: {pc2_variance:.3f})')
    ax.set_title('Model Features Projected to 2D using PCA')
    plt.show()

In [None]:
# Wrap the backbone in the BYOL training module. Everything except `resnet`
# is forwarded to the BYOL constructor via SelfSupervisedLearner's **kwargs.
model = SelfSupervisedLearner(
    resnet,
    # which layer of the backbone BYOL taps, and the input resolution
    hidden_layer='avgpool',
    image_size=IMAGE_SIZE,
    # projector sizes
    projection_hidden_size=4096,
    projection_size=256,
    # decay passed to BYOL for its moving average
    moving_average_decay=0.99,
)

In [None]:
# Trainer used for the BYOL fit below: NUM_GPUS devices for EPOCHS epochs.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices=NUM_GPUS,
    sync_batchnorm=True,        # ask Lightning to synchronise batch-norm across devices
    accumulate_grad_batches=1,  # step the optimizer on every batch
)

In [None]:
# Baseline: PCA plot of resnet's outputs on test_loader, run before trainer.fit.
visualize(resnet, test_loader)

In [None]:
# Train the BYOL module (which wraps resnet) on the training loader.
trainer.fit(model, train_loader)

In [None]:
# Same plot again after trainer.fit. `model` was built around this same resnet
# object, so presumably its weights now reflect the training run.
visualize(resnet, test_loader)