This bolts module houses a collection of all self-supervised learning models.
Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.
Self-supervised models are trained with unlabeled datasets.
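To make the idea of a pretext task concrete, the sketch below builds a labeled training set from unlabeled images by rotating each one and using the number of rotations as the label. Rotation prediction is only one well-known pretext task and is used here purely for illustration; it is not one of the models in this package, and `rotate90` and `make_rotation_task` are hypothetical helper names.

```python
def rotate90(image):
    """Rotate a 2D grid (list of rows) by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def make_rotation_task(images):
    """Turn unlabeled images into (rotated_image, label) pairs.

    The label is the number of 90-degree turns applied (0 to 3), so the
    supervision signal comes from the data itself, not from annotations.
    """
    samples = []
    for image in images:
        rotated = [list(row) for row in image]
        for k in range(4):
            samples.append((rotated, k))
            rotated = rotate90(rotated)
    return samples


# toy "unlabeled dataset": a single 2x2 image
unlabeled = [[[1, 2],
              [3, 4]]]
pretext_samples = make_rotation_task(unlabeled)
```

A model trained to predict the label from the rotated image has to learn something about image content, which is the representation that is reused afterwards.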
.. note::

    We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix it!
Here are some use cases for the self-supervised package.
The models in this module are trained without labels, and their encoders can be reused to extract image representations (features).
In this example, we'll load a ResNet-50 which was pretrained on ImageNet using SimCLR as the pretext task.
.. testcode::

    from pl_bolts.models.self_supervised import SimCLR

    # load resnet50 pretrained using SimCLR on imagenet
    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()
This means you can now extract image representations that were pretrained via unsupervised learning.
Example:

.. code-block:: python

    my_dataset = SomeDataset()
    for batch in my_dataset:
        x, y = batch
        out = simclr_resnet50(x)
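Once representations are extracted, a common way to use them is to compare images by the similarity of their feature vectors. The sketch below is a minimal, framework-free illustration: the short lists stand in for encoder outputs converted to Python lists, and `cosine_similarity` and `most_similar` are hypothetical helpers, not part of bolts.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(query, gallery):
    """Return the index of the gallery vector most similar to the query."""
    scores = [cosine_similarity(query, g) for g in gallery]
    return max(range(len(gallery)), key=scores.__getitem__)


# made-up 3-dimensional features; real encoder outputs are much larger
gallery = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
query = [0.9, 0.1, 0.0]
best = most_similar(query, gallery)
```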
These models are well suited to training from scratch when you have a huge set of unlabeled images.

.. code-block:: python

    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    # MyDataset is a placeholder for your own dataset class
    train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
    val_dataset = MyDataset(transforms=SimCLREvalDataTransform())

    # simclr needs a lot of compute!
    model = SimCLR()
    trainer = Trainer(tpu_cores=128)
    trainer.fit(
        model,
        DataLoader(train_dataset),
        DataLoader(val_dataset),
    )
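Contrastive methods such as SimCLR train on several augmented views of the same image, which is why the datasets above are given dedicated transforms. The sketch below shows the general pattern with a hypothetical `TwoViewTransform` wrapper and a toy augmentation; it illustrates the idea only and is not the implementation of `SimCLRTrainDataTransform`.

```python
import random


class TwoViewTransform:
    """Apply the same random augmentation twice to get two views of a sample."""

    def __init__(self, augment):
        self.augment = augment

    def __call__(self, sample):
        # each call to augment draws fresh randomness, so the views differ
        return self.augment(sample), self.augment(sample)


def random_shift(pixels, max_shift=2):
    """Toy augmentation: cyclically shift a 1D list of pixels."""
    shift = random.randint(-max_shift, max_shift)
    return pixels[shift:] + pixels[:shift]


transform = TwoViewTransform(random_shift)
view_a, view_b = transform([0, 1, 2, 3, 4, 5])
```

Both views contain the same underlying content, so a contrastive model can treat them as a positive pair.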
Mix and match any part, or subclass to create your own new method.

.. code-block:: python

    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

    amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
    model = CPC_v2(contrastive_task=amdim_task)
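The `comparisons` string names which feature maps are contrasted with each other. As a rough illustration of how such a specification could be read, the sketch below splits it into index pairs. `parse_comparisons` is a hypothetical helper written for this example, it handles digits only, and the reading of each two-character token as a pair of feature-map indices is an assumption; consult `FeatureMapContrastiveTask` for the actual parsing rules.

```python
def parse_comparisons(spec):
    """Parse a string such as '01, 11, 02' into a list of index pairs."""
    pairs = []
    for token in spec.split(','):
        token = token.strip()
        if len(token) != 2 or not token.isdigit():
            raise ValueError(f'expected two digits, got {token!r}')
        pairs.append((int(token[0]), int(token[1])))
    return pairs


pairs = parse_comparisons('01, 11, 02')
print(pairs)  # [(0, 1), (1, 1), (0, 2)]
```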
Contrastive self-supervised learning (CSL) is a self-supervised learning approach in which we learn representations such that similar instances lie close to each other and dissimilar instances lie far apart. This is often done by comparing triplets of anchor, positive, and negative representations.
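The triplet comparison can be sketched with a generic triplet margin loss, which is zero once the negative is farther from the anchor than the positive by at least a margin. This is a minimal, framework-free illustration with made-up 2D vectors; the models listed in this section each use their own objective, and the function names are hypothetical.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two representation vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


anchor = [0.0, 0.0]
positive = [0.0, 1.0]       # similar instance, distance 1 from the anchor
far_negative = [3.0, 4.0]   # dissimilar instance, distance 5 from the anchor
near_negative = [0.0, 1.5]  # dissimilar instance, distance 1.5 from the anchor

easy = triplet_margin_loss(anchor, positive, far_negative)   # 1 - 5 + 1 < 0, so 0.0
hard = triplet_margin_loss(anchor, positive, near_negative)  # 1 - 1.5 + 1 = 0.5
```

Only the second triplet contributes a gradient signal, because its negative is not yet separated from the anchor by the margin.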
In this section, we list Lightning implementations of popular contrastive learning approaches.
.. autoclass:: pl_bolts.models.self_supervised.AMDIM
    :noindex:

.. autoclass:: pl_bolts.models.self_supervised.BYOL
    :noindex:
PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding
Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.
Model implemented by:
To Train:

.. code-block:: python

    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import CPC_v2
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.cpc import (
        CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10)

    # data
    dm = CIFAR10DataModule(num_workers=0)
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    # model
    model = CPC_v2()

    # fit
    trainer = pl.Trainer()
    trainer.fit(model, datamodule=dm)
To finetune:

.. code-block:: bash

    python cpc_finetuner.py --ckpt_path path/to/checkpoint.ckpt --dataset cifar10 --gpus 1
The CPCv2 paper does not report baselines on the CIFAR-10 and STL-10 datasets. Results in the table are reported from the YADIM paper.
.. list-table:: CPCv2 implementation results
   :header-rows: 1

   * - Dataset
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Epochs
     - Hardware
     - LR
   * - CIFAR-10
     - 84.52
     - CPCresnet101
     - Adam
     - 64
     - 1000 (up to 24 hours)
     - 1 V100 (32GB)
     - 4e-5
   * - STL-10
     - 78.36
     - CPCresnet101
     - Adam
     - 144
     - 1000 (up to 72 hours)
     - 4 V100 (32GB)
     - 1e-4
   * - ImageNet
     - 54.82
     - CPCresnet101
     - Adam
     - 3072
     - 1000 (up to 21 days)
     - 64 V100 (32GB)
     - 4e-5
CIFAR-10 pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import CPC_v2

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt'
    cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)

    cpc_v2.freeze()
Pre-training:
Fine-tuning:
STL-10 pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import CPC_v2

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt'
    cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)

    cpc_v2.freeze()
Pre-training:
Fine-tuning:
.. autoclass:: pl_bolts.models.self_supervised.CPC_v2
    :noindex:

.. autoclass:: pl_bolts.models.self_supervised.Moco_v2
    :noindex:
PyTorch Lightning implementation of SimCLR
Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.
Model implemented by:
To Train:

.. code-block:: python

    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import (
        SimCLREvalDataTransform, SimCLRTrainDataTransform)

    # data
    dm = CIFAR10DataModule(num_workers=0)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    # model
    model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')

    # fit
    trainer = pl.Trainer()
    trainer.fit(model, datamodule=dm)

.. list-table:: CIFAR-10 implementation results
   :header-rows: 1

   * - Implementation
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Epochs
     - Hardware
     - LR
   * - Original
     - ~94.00
     - resnet50
     - LARS
     - 2048
     - 800
     - TPUs
     - 1.0/1.5
   * - Ours
     - 88.50
     - resnet50
     - LARS
     - 2048
     - 800 (4 hours)
     - 8 V100 (16GB)
     - 1.5
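The batch size of 2048 in the table is consistent with the pretraining command given in the reproduction steps for this model, assuming `--batch_size` is the per-GPU batch size and the effective batch is its product with the number of devices (the usual convention for distributed data-parallel training). The helper name below is hypothetical.

```python
def global_batch_size(per_device_batch_size, devices, num_nodes=1):
    """Effective batch size when every device processes its own mini-batch."""
    return per_device_batch_size * devices * num_nodes


# values taken from the CIFAR-10 pretraining command: --gpus 8 --batch_size 256
effective = global_batch_size(per_device_batch_size=256, devices=8)
print(effective)  # 2048
```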
CIFAR-10 pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import SimCLR

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-sgd/simclr-cifar10-sgd.ckpt'
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    simclr.freeze()
Pre-training:
Fine-tuning (Single layer MLP, 1024 hidden units):
To reproduce:

.. code-block:: bash

    # pretrain
    python simclr_module.py \
        --gpus 8 \
        --dataset cifar10 \
        --batch_size 256 \
        --num_workers 16 \
        --optimizer sgd \
        --learning_rate 1.5 \
        --exclude_bn_bias \
        --max_epochs 800 \
        --online_ft

    # finetune
    python simclr_finetuner.py \
        --gpus 4 \
        --ckpt_path path/to/simclr/ckpt \
        --dataset cifar10 \
        --batch_size 64 \
        --num_workers 8 \
        --learning_rate 0.3 \
        --num_epochs 100

.. list-table:: ImageNet implementation results
   :header-rows: 1

   * - Implementation
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Epochs
     - Hardware
     - LR
   * - Original
     - ~69.3
     - resnet50
     - LARS
     - 4096
     - 800
     - TPUs
     - 4.8
   * - Ours
     - 68.4
     - resnet50
     - LARS
     - 4096
     - 800
     - 64 V100 (16GB)
     - 4.8
ImageNet pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import SimCLR

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    simclr.freeze()
To reproduce:

.. code-block:: bash

    # pretrain
    python simclr_module.py \
        --dataset imagenet \
        --data_path path/to/imagenet

    # finetune
    python simclr_finetuner.py \
        --gpus 8 \
        --ckpt_path path/to/simclr/ckpt \
        --dataset imagenet \
        --data_dir path/to/imagenet/dataset \
        --batch_size 256 \
        --num_workers 16 \
        --learning_rate 0.8 \
        --nesterov True \
        --num_epochs 90
.. autoclass:: pl_bolts.models.self_supervised.SimCLR
    :noindex:
PyTorch Lightning implementation of SwAV, adapted from the official implementation.
Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.
Implementation adapted by:
To Train:

.. code-block:: python

    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import (
        SwAVTrainDataTransform, SwAVEvalDataTransform
    )
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    # data
    batch_size = 128
    dm = STL10DataModule(data_dir='.', batch_size=batch_size)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed

    dm.train_transforms = SwAVTrainDataTransform(
        normalize=stl10_normalization()
    )

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=stl10_normalization()
    )

    # model
    model = SwAV(
        gpus=1,
        num_samples=dm.num_unlabeled_samples,
        dataset='stl10',
        batch_size=batch_size
    )

    # fit (pass the datamodule so the trainer can find the dataloaders)
    trainer = pl.Trainer(precision=16)
    trainer.fit(model, datamodule=dm)
We have included an option to directly load ImageNet weights provided by FAIR into bolts.
You can load the pretrained model using:
ImageNet pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import SwAV

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
    swav = SwAV.load_from_checkpoint(weight_path, strict=True)

    swav.freeze()
The original paper does not provide baselines on STL10.
.. list-table:: STL-10 implementation results
   :header-rows: 1

   * - Implementation
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Queue used
     - Epochs
     - Hardware
     - LR
   * - Ours
     - 86.72
     - SwAV resnet50
     - LARS
     - 128
     - No
     - 100 (~9 hr)
     - 1 V100 (16GB)
     - 1e-3
STL-10 pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import SwAV

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar'
    swav = SwAV.load_from_checkpoint(weight_path, strict=False)

    swav.freeze()
Pre-training:
Fine-tuning (Single layer MLP, 1024 hidden units):
To reproduce:

.. code-block:: bash

    # pretrain
    python swav_module.py \
        --online_ft \
        --gpus 1 \
        --batch_size 128 \
        --learning_rate 1e-3 \
        --gaussian_blur \
        --queue_length 0 \
        --jitter_strength 1. \
        --nmb_prototypes 512

    # finetune
    python swav_finetuner.py \
        --gpus 8 \
        --ckpt_path path/to/simclr/ckpt \
        --dataset imagenet \
        --data_dir path/to/imagenet/dataset \
        --batch_size 256 \
        --num_workers 16 \
        --learning_rate 0.8 \
        --nesterov True \
        --num_epochs 90

.. list-table:: ImageNet implementation results
   :header-rows: 1

   * - Implementation
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Epochs
     - Hardware
     - LR
   * - Original
     - 75.3
     - resnet50
     - LARS
     - 4096
     - 800
     - 64 V100s
     - 4.8
   * - Ours
     - 74
     - resnet50
     - LARS
     - 4096
     - 800
     - 64 V100 (16GB)
     - 4.8
ImageNet pretrained model:

.. code-block:: python

    from pl_bolts.models.self_supervised import SwAV

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt'
    swav = SwAV.load_from_checkpoint(weight_path, strict=False)

    swav.freeze()
.. autoclass:: pl_bolts.models.self_supervised.SwAV
    :noindex:

.. autoclass:: pl_bolts.models.self_supervised.SimSiam
    :noindex: