In [1]:
import os
import subprocess
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

from torchvision import datasets

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
resnet = models.resnet50(weights=None)

In [3]:
BATCH_SIZE = 256
EPOCHS = 2
LR = 3e-4
NUM_GPUS = int(os.environ["SLURM_GPUS_ON_NODE"])
IMAGE_SIZE = 224
NUM_WORKERS = int(os.environ['SLURM_CPUS_PER_TASK'])

In [4]:
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()

In [5]:
class PredictWrapper(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = net
        
    def forward(self, images):
        return self.learner(images)
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        images, labels = batch
        return self.forward(images), labels

In [6]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
    transforms.ToTensor(),
])

ds_train = datasets.ImageNet(root='/scratch/gpfs/DATASETS/imagenet/ilsvrc_2012_classification_localization', split='train', transform=transform)
train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

ds_test = datasets.ImageNet(root='/scratch/gpfs/DATASETS/imagenet/ilsvrc_2012_classification_localization', split='val', transform=transform)
ds_test = torch.utils.data.Subset(ds_test, list(range(500))) 
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [7]:
def get_predictions(net, data_loader):
    trainer = pl.Trainer(devices=NUM_GPUS)
    
    predictions = trainer.predict(PredictWrapper(net), data_loader)
    all_features = []
    all_labels = []
    
    for embeddings, ground_truths in predictions:
        all_features.append(embeddings.view(embeddings.size(0), -1))
        all_labels.extend(ground_truths)
        
    features = torch.cat(all_features).numpy()
    labels = torch.tensor(all_labels).numpy()
    
    return features, labels

In [8]:
def visualize(net, data_loader):
    features, labels = get_predictions(net, data_loader)
    
    pca = PCA()
    pca_features = pca.fit_transform(features)
    
    pc1_variance = pca.explained_variance_ratio_[0]
    pc2_variance = pca.explained_variance_ratio_[1]

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(pca_features[:, 0], pca_features[:, 1], c=labels, cmap='plasma')
    plt.legend(handles=scatter.legend_elements()[0])
    plt.xlabel(f'PC1 (Variance: {pc1_variance:.3f})')
    plt.ylabel(f'PC2 (Variance: {pc2_variance:.3f})')
    plt.title('Model Features Projected to 2D using PCA')
    plt.show()

In [9]:
model = SelfSupervisedLearner(
    resnet,
    image_size=IMAGE_SIZE,
    hidden_layer='avgpool',
    projection_size=256,
    projection_hidden_size=4096,
    moving_average_decay=0.99
)

trainer = pl.Trainer(
    devices=NUM_GPUS,
    max_epochs=EPOCHS,
    accumulate_grad_batches=1,
    sync_batchnorm=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [10]:
visualize(resnet, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
[W socket.cpp:46

ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 902, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 942, in _run
    self.strategy.setup_environment()
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/strategies/ddp.py", line 154, in setup_environment
    self.setup_distributed()
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/pytorch_lightning/strategies/ddp.py", line 203, in setup_distributed
    _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/lightning_fabric/utilities/distributed.py", line 291, in _init_dist_connection
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 86, in wrapper
    func_return = func(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 1177, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/torch/distributed/rendezvous.py", line 246, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eh0560/.conda/envs/byol/lib/python3.12/site-packages/torch/distributed/rendezvous.py", line 174, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
torch.distributed.DistNetworkError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:28657 (errno: 98 - Address already in use). The server socket has failed to bind to 0.0.0.0:28657 (errno: 98 - Address already in use).


In [None]:
trainer.fit(model, train_loader)

In [None]:
visualize(resnet, test_loader)