# SimCLR implementation #

Implementation following: https://theaisummer.com/simclr/

In [1]:
!pip install ssl_remote_sensing@git+https://github.com/AlexanderLontke/ssl-remote-sensing.git@feature/pipeline

In [1]:
# Log in to your W&B account
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malontke[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import os

import torch
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from ssl_remote_sensing.pretext_tasks.simclr.utils import reproducibility
from ssl_remote_sensing.pretext_tasks.simclr.training import SimCLRTraining
from ssl_remote_sensing.pretext_tasks.simclr.augmentation import Augment
from ssl_remote_sensing.pretext_tasks.simclr.config import get_simclr_config
from ssl_remote_sensing.data.get_eurosat import get_eurosat_normalizer, get_eurosat_dataloader

In [5]:
# Machine setup
available_gpus = torch.cuda.device_count()
save_model_path = os.path.join(os.getcwd(), "saved_models/")
print("available_gpus:", available_gpus)

# Model Setup
train_config = get_simclr_config()
train_config.batch_size = 2250
train_config.epochs = 40

# Run setup
filename = f"SimCLR_ResNet18_adam_bs{train_config.batch_size}"
save_name = filename + ".ckpt"
resume_from_checkpoint = False
reproducibility(train_config)

model = SimCLRTraining(
    config=train_config,
    feat_dim=512,
)

# Setup data loading and augments
eurosat_normalizer = get_eurosat_normalizer()
transform = Augment(train_config.img_size, normalizer=eurosat_normalizer)
data_loader = get_eurosat_dataloader(
    root="./",
    transform=transform,
    batchsize=train_config.batch_size,
    numworkers=os.cpu_count(),
    split=False,
)


# Needed to get simulate a large batch size
accumulator = GradientAccumulationScheduler(scheduling={0: 1})

checkpoint_callback = ModelCheckpoint(
    filename=filename,
    dirpath=save_model_path,
    every_n_epochs=2,
    save_last=True,
    save_top_k=2,
    monitor="train/InfoNCE",
    mode="min",
)

# Setup WandB logging
wandb_logger = WandbLogger(
    project="ssl-remote-sensing-simclr",
    config=train_config.__dict__
)
shared_trainer_kwargs = {
    "callbacks": [accumulator, checkpoint_callback],
    "max_epochs": train_config.epochs,
    "logger": wandb_logger,
    "log_every_n_steps": 1,
    "accelerator": "gpu",
}
if resume_from_checkpoint:
    trainer = Trainer(
        **shared_trainer_kwargs,
        resume_from_checkpoint=train_config.checkpoint_path,
    )
else:
    trainer = Trainer(
         **shared_trainer_kwargs,
    )

trainer.fit(model, data_loader)
trainer.save_checkpoint(save_name)
wandb.save(checkpoint_callback.best_model_path)
wandb.finish()
print(f"Best model is stored under {checkpoint_callback.best_model_path}")

available_gpus: 0
[LOG] Total number of images: 27000
[LOG] Batch size is 64
[LOG] Total images in the train set is: 27000
[LOG] Total number of batches in the trainloader: 422


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type          | Params
----------------------------------------
0 | model | AddProjection | 11.5 M
1 | loss  | InfoNCE       | 0     
----------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.024    Total estimated model params size (MB)


Optimizer Adam, Learning Rate 0.001, Effective batch size 64


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


ValueError: no path specified

In [4]:
wandb.finish()