This bolts module houses a collection of self-supervised learning models.
Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.
Self-supervised models are trained with unlabeled datasets.
Here are some use cases for the self-supervised package.
The models in this module are trained without labels and can therefore learn general-purpose image representations (features).
In this example, we'll load a resnet50 which was pretrained on ImageNet using SimCLR as the pretext task.
from pl_bolts.models.self_supervised import SimCLR
# load resnet50 pretrained using SimCLR on imagenet
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

simclr_resnet50 = simclr.encoder
simclr_resnet50.eval()
This means you can now extract image representations that were pretrained via unsupervised learning.
Example:
my_dataset = SomeDataset()
for batch in my_dataset:
x, y = batch
out = simclr_resnet50(x)
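In practice you would typically batch the data with a DataLoader and disable gradient tracking while extracting features. A minimal sketch of that pattern (my_dataset is the same placeholder dataset as above):

import torch
from torch.utils.data import DataLoader

loader = DataLoader(my_dataset, batch_size=32)

with torch.no_grad():
    for x, y in loader:
        feats = simclr_resnet50(x)  # pretrained representations for the batch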
These models are perfect for training from scratch when you have a huge set of unlabeled images.
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform
train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
# simclr needs a lot of compute!
model = SimCLR()
trainer = Trainer(tpu_cores=128)
trainer.fit(
model,
DataLoader(train_dataset),
DataLoader(val_dataset),
)
Mix and match any part, or subclass to create your own new method.
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
# compare feature map pairs (0, 1), (1, 1) and (0, 2) across the two augmented views
amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPC_v2(contrastive_task=amdim_task)
Contrastive self-supervised learning (CSL) is a self-supervised learning approach where we learn representations of instances such that similar instances are near each other and far from dissimilar ones. This is often done by comparing triplets of anchor, positive and negative representations.
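As a toy illustration of this idea (not the loss used by any particular model in this package), a margin-based triplet objective over representation vectors could look like the following sketch, assuming 128-dimensional embeddings:

import torch
import torch.nn.functional as F

def triplet_contrastive_loss(anchor, positive, negative, margin=1.0):
    # pull the anchor towards the positive, push it away from the negative
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()

# hypothetical embeddings: two views of the same image (anchor/positive) and a different image (negative)
anchor, positive, negative = (torch.randn(8, 128) for _ in range(3))
loss = triplet_contrastive_loss(anchor, positive, negative)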
In this section, we list Lightning implementations of popular contrastive learning approaches.
pl_bolts.models.self_supervised.AMDIM
pl_bolts.models.self_supervised.BYOL
PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding
Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.
Model implemented by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc import (
CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()
# model
model = CPC_v2()
# fit
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
To finetune:
python cpc_finetuner.py
--ckpt_path path/to/checkpoint.ckpt
--dataset cifar10
--gpus 1
The CPCv2 paper does not report baselines on the CIFAR-10 and STL-10 datasets. The results in the table below are reported from the YADIM paper.
CPCv2 implementation results

| Dataset | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | 84.52 | CPCresnet101 | Adam | 64 | 1000 (up to 24 hours) | 1 V100 (32GB) | 4e-5 |
| STL-10 | 78.36 | CPCresnet101 | Adam | 144 | 1000 (up to 72 hours) | 4 V100 (32GB) | 1e-4 |
| ImageNet | 54.82 | CPCresnet101 | Adam | 3072 | 1000 (up to 21 days) | 64 V100 (32GB) | 4e-5 |
CIFAR-10 pretrained model:
from pl_bolts.models.self_supervised import CPC_v2
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt'
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)
cpc_v2.freeze()
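With the checkpoint frozen you can use it as a fixed feature extractor. A minimal sketch, assuming the backbone is exposed as cpc_v2.encoder (mirroring the SimCLR example above):

import torch

encoder = cpc_v2.encoder
x = torch.rand(4, 3, 32, 32)  # a batch of CIFAR-10-sized images
with torch.no_grad():
    representations = encoder(x)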
Pre-training:
Fine-tuning:
STL-10 pretrained model:
from pl_bolts.models.self_supervised import CPC_v2
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt'
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)
cpc_v2.freeze()
Pre-training:
Fine-tuning:
pl_bolts.models.self_supervised.CPC_v2
pl_bolts.models.self_supervised.Moco_v2
PyTorch Lightning implementation of SimCLR
Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.
Model implemented by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')
# fit
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
CIFAR-10 implementation results

| Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR |
|---|---|---|---|---|---|---|---|
| Original | ~94.00 | resnet50 | LARS | 2048 | 800 | TPUs | 1.0/1.5 |
| Ours | 88.50 | resnet50 | LARS | 2048 | 800 (4 hours) | 8 V100 (16GB) | 1.5 |
CIFAR-10 pretrained model:
from pl_bolts.models.self_supervised import SimCLR
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-sgd/simclr-cifar10-sgd.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr.freeze()
Pre-training:
Fine-tuning (Single layer MLP, 1024 hidden units):
To reproduce:
# pretrain
python simclr_module.py
--gpus 8
--dataset cifar10
--batch_size 256
--num_workers 16
--optimizer sgd
--learning_rate 1.5
--exclude_bn_bias
--max_epochs 800
--online_ft
# finetune
python simclr_finetuner.py
--gpus 4
--ckpt_path path/to/simclr/ckpt
--dataset cifar10
--batch_size 64
--num_workers 8
--learning_rate 0.3
--num_epochs 100
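Conceptually, the finetuner trains a small classification head (the "Single layer MLP, 1024 hidden units" noted above) on top of the frozen encoder. A rough sketch of that idea, assuming the encoder returns a (batch, 2048) feature tensor; this is not the exact simclr_finetuner.py code:

import torch
from torch import nn

encoder = simclr.encoder  # frozen SimCLR backbone loaded above
head = nn.Sequential(
    nn.Linear(2048, 1024),  # resnet50 features -> 1024 hidden units
    nn.ReLU(inplace=True),
    nn.Linear(1024, 10),    # 10 CIFAR-10 classes
)

x = torch.rand(4, 3, 32, 32)
with torch.no_grad():
    feats = encoder(x)      # representations stay fixed
logits = head(feats)        # only the head is trained, with a standard cross-entropy loss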
ImageNet implementation results

| Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR |
|---|---|---|---|---|---|---|---|
| Original | ~69.3 | resnet50 | LARS | 4096 | 800 | TPUs | 4.8 |
| Ours | 68.4 | resnet50 | LARS | 4096 | 800 | 64 V100 (16GB) | 4.8 |

ImageNet pretrained model:
from pl_bolts.models.self_supervised import SimCLR
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr.freeze()
To reproduce:
# pretrain
python simclr_module.py
--dataset imagenet
--data_path path/to/imagenet
# finetune
python simclr_finetuner.py
--gpus 8
--ckpt_path path/to/simclr/ckpt
--dataset imagenet
--data_dir path/to/imagenet/dataset
--batch_size 256
--num_workers 16
--learning_rate 0.8
--nesterov True
--num_epochs 90
pl_bolts.models.self_supervised.SimCLR
PyTorch Lightning implementation of SwAV, adapted from the official implementation.
Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.
Implementation adapted by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
SwAVTrainDataTransform, SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization
# data
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed
dm.train_transforms = SwAVTrainDataTransform(
normalize=stl10_normalization()
)
dm.val_transforms = SwAVEvalDataTransform(
normalize=stl10_normalization()
)
# model
model = SwAV(
gpus=1,
num_samples=dm.num_unlabeled_samples,
dataset='stl10',
batch_size=batch_size
)
# fit
trainer = pl.Trainer(precision=16)
trainer.fit(model, datamodule=dm)
We have included an option to directly load ImageNet weights provided by FAIR into bolts.
You can load the pretrained model using:
ImageNet pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=True)
swav.freeze()
The original paper does not provide baselines on STL-10.
STL-10 implementation results

| Implementation | test acc | Encoder | Optimizer | Batch | Queue used | Epochs | Hardware | LR |
|---|---|---|---|---|---|---|---|---|
| Ours | 86.72 | SwAV resnet50 | LARS | 128 | No | 100 (~9 hr) | 1 V100 (16GB) | 1e-3 |
STL-10 pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)
swav.freeze()
Pre-training:
Fine-tuning (Single layer MLP, 1024 hidden units):
To reproduce:
# pretrain
python swav_module.py
--online_ft
--gpus 1
--batch_size 128
--learning_rate 1e-3
--gaussian_blur
--queue_length 0
--jitter_strength 1.
--nmb_prototypes 512
# finetune
python swav_finetuner.py
--gpus 8
--ckpt_path path/to/swav/ckpt
--dataset imagenet
--data_dir path/to/imagenet/dataset
--batch_size 256
--num_workers 16
--learning_rate 0.8
--nesterov True
--num_epochs 90
ImageNet implementation results

| Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR |
|---|---|---|---|---|---|---|---|
| Original | 75.3 | resnet50 | LARS | 4096 | 800 | 64 V100s | 4.8 |
| Ours | 74 | resnet50 | LARS | 4096 | 800 | 64 V100 (16GB) | 4.8 |

ImageNet pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)
swav.freeze()
pl_bolts.models.self_supervised.SwAV
pl_bolts.models.self_supervised.SimSiam