# SimCLR implementation

Self-supervised SimCLR pretraining of a ResNet-18 encoder on the EuroSAT dataset, following the tutorial at https://theaisummer.com/simclr/.

In [1]:
!pip install torch torchvision pytorch-lightning lightning-bolts wandb



In [2]:
# Log in to your W&B account so the training cell's `wandb.init(...)` can start a run.
# `wandb.login()` returns True on success; as the last expression of the cell it is
# displayed as the cell output. No credentials are hard-coded here.
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malontke[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import os

import torch
import torchvision.models as models
from torchvision.datasets import EuroSAT
from torch.utils.data import DataLoader
from torch.multiprocessing import cpu_count
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint
from pytorch_lightning import Trainer

from ssl_remote_sensing.pretext_tasks.simclr.utils import reproducibility
from ssl_remote_sensing.pretext_tasks.simclr.training import SimCLRTraining
from ssl_remote_sensing.pretext_tasks.simclr.augmentation import Augment

In [None]:
# Per-channel normalisation statistics, passed to both SimCLRTraining and Augment.
# NOTE(review): presumably RGB statistics of EuroSAT on a 0-255 scale; confirm how
# they were computed. This cell must be run before the training cell.
means = [
    87.81586935763889,
    96.97416420717593,
    103.98142336697049,
]
stds = [
    51.67849701591506,
    34.908630837585186,
    29.465280593587384,
]

In [9]:
# Machine setup: checkpoints go to <cwd>/saved_models/, and every visible
# CUDA device is counted so the Trainer below can use them.
save_model_path = os.path.join(os.getcwd(), "saved_models/")
available_gpus = torch.cuda.device_count()
print("available_gpus:", available_gpus)

# Model Setup
class Hparams:
    """Hyperparameters for SimCLR pretraining.

    An instance is passed to ``reproducibility`` and ``SimCLRTraining``. Its
    ``__dict__`` is logged as the W&B run config, so every attribute set here
    appears in the run.
    """

    def __init__(self):
        self.epochs = 10  # number of training epochs
        self.seed = 1234  # randomness seed
        # Follow the hardware the Trainer actually uses (all visible CUDA devices)
        # rather than a hard-coded False that disagrees with it on GPU machines.
        self.cuda = torch.cuda.is_available()  # use nvidia gpu
        self.img_size = 64  # image side length passed to Augment
        self.save = "./saved_models/"  # save checkpoint
        # Checkpoint used when resuming training. ModelCheckpoint(save_last=True)
        # writes `last.ckpt` into the checkpoint directory by default.
        self.checkpoint_path = os.path.join(self.save, "last.ckpt")
        self.gradient_accumulation_steps = 1  # gradient accumulation steps
        self.batch_size = 64
        self.lr = 1e-3
        self.embedding_size = 128  # papers value is 128
        self.temperature = 0.5  # 0.1 or 0.5
        self.weight_decay = 1e-6


train_config = Hparams()

# Run setup
filename = "SimCLR_ResNet18_adam"
save_name = filename + ".ckpt"
resume_from_checkpoint = False
wandb.init(project="ssl-remote-sensing", config=train_config.__dict__)

# Seed before the model is built so weight initialisation is reproducible.
reproducibility(train_config)

model = SimCLRTraining(
    config=train_config,
    model=models.resnet18(weights=None),
    feat_dim=512,  # input width of the projection MLP ("DIM MLP input: 512")
    norm_means=means,
    norm_stds=stds,
)

transform = Augment(train_config.img_size, norm_means=means, norm_stds=stds)

dataset = EuroSAT("./", transform=transform, download=True)
# shuffle=True: torchvision's EuroSAT is an ImageFolder whose samples are grouped
# by class. Without shuffling, each batch would hold (almost) a single class and
# the contrastive loss would see few informative negatives.
# drop_last=True: keep every batch at exactly `batch_size`, in case the loss
# builds its masks from the configured batch size.
# NOTE(review): the worker traceback in this cell's output ("Can't get attribute
# 'Augment' on <module '__main__'>") most likely means `Augment` was bound to a
# class defined in the kernel's __main__ (e.g. a since-deleted cell), not the
# imported one. Restart the kernel and Run All; fall back to num_workers=0 if
# workers still cannot unpickle the transform.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=train_config.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=cpu_count(),
)

# Gradient accumulation simulates a larger batch. Read the step count from the
# config so it cannot drift from the value logged to W&B and given to the model.
accumulator = GradientAccumulationScheduler(
    scheduling={0: train_config.gradient_accumulation_steps}
)

checkpoint_callback = ModelCheckpoint(
    filename=filename,
    dirpath=save_model_path,
    every_n_epochs=2,
    save_last=True,
    save_top_k=2,
    monitor="InfoNCE loss_epoch",
    mode="min",
)

# `gpus=` and `resume_from_checkpoint=` are deprecated Trainer arguments (removed
# in Lightning 2.0). `accelerator`/`devices` and `fit(ckpt_path=...)` replace
# them, so a single Trainer serves both the fresh and the resumed run.
use_gpu = available_gpus > 0
trainer = Trainer(
    callbacks=[accumulator, checkpoint_callback],
    accelerator="gpu" if use_gpu else "cpu",
    devices=available_gpus if use_gpu else 1,
    max_epochs=train_config.epochs,
)

# When resuming, prefer a checkpoint path from the config. Otherwise use the
# `last.ckpt` that ModelCheckpoint(save_last=True) writes into save_model_path.
resume_ckpt_path = None
if resume_from_checkpoint:
    resume_ckpt_path = getattr(train_config, "checkpoint_path", None) or os.path.join(
        save_model_path, "last.ckpt"
    )

trainer.fit(model, data_loader, ckpt_path=resume_ckpt_path)
trainer.save_checkpoint(save_name)
print(f"Best model is stored under {checkpoint_callback.best_model_path}")

available_gpus: 0


DIM MLP input: 512


  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name  | Type          | Params
----------------------------------------
0 | model | AddProjection | 11.5 M
1 | loss  | InfoNceLoss   | 0     
----------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.024    Total estimated model params size (MB)
  rank_zero_warn(


Optimizer Adam, Learning Rate 0.001, Effective batch size 64


Training: 0it [00:00, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/alexanderlontke/.conda/envs/ssl-remote-sensing/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/alexanderlontke/.conda/envs/ssl-remote-sensing/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'Augment' on <module '__main__' (built-in)>
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
