BYOL implementation (#144)
* byol wip

* add blank lines

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* add l2 normalization
annikabrundyn committed Aug 5, 2020
1 parent f409d5e commit 0a4f0ae
Showing 10 changed files with 390 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -20,6 +20,7 @@ PyTorch-Lightning-Bolts documentation

callbacks
info_callbacks
self_supervised_callbacks
variational_callbacks
vision_callbacks

15 changes: 15 additions & 0 deletions docs/source/self_supervised_callbacks.rst
@@ -0,0 +1,15 @@
.. role:: hidden
    :class: hidden-section

Self-supervised Callbacks
=========================
Useful callbacks for self-supervised learning models

---------------

BYOLMAWeightUpdate
------------------
The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

.. autoclass:: pl_bolts.callbacks.self_supervised.BYOLMAWeightUpdate
    :noindex:
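
A minimal usage sketch (``Model`` below is a placeholder for any LightningModule that
defines ``online_network`` and ``target_network`` attributes):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate

    model = Model()  # your LightningModule with online_network and target_network

    # tau is annealed from initial_tau to 1.0 over max_steps, so set max_steps
    trainer = Trainer(callbacks=[BYOLMAWeightUpdate(initial_tau=0.996)], max_steps=1000)
    trainer.fit(model)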
6 changes: 6 additions & 0 deletions docs/source/self_supervised_models.rst
@@ -92,6 +92,12 @@ AMDIM
.. autoclass:: pl_bolts.models.self_supervised.AMDIM
    :noindex:

BYOL
^^^^

.. autoclass:: pl_bolts.models.self_supervised.BYOL
    :noindex:

CPC (V2)
^^^^^^^^

62 changes: 62 additions & 0 deletions pl_bolts/callbacks/self_supervised.py
@@ -0,0 +1,62 @@
import torch
import math
import pytorch_lightning as pl


class BYOLMAWeightUpdate(pl.Callback):

    def __init__(self, initial_tau=0.996):
        """
        Weight update rule from BYOL.

        Your model should have:

            - ``self.online_network``
            - ``self.target_network``

        Updates the target_network params using an exponential moving average update rule
        weighted by tau. BYOL claims this keeps the online_network from collapsing.

        .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step

        Example::

            from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate

            # model must have 2 attributes
            model = Model()
            model.online_network = ...
            model.target_network = ...

            # make sure to set max_steps in Trainer
            trainer = Trainer(callbacks=[BYOLMAWeightUpdate()], max_steps=1000)

        Args:
            initial_tau: starting tau. Auto-updates with every training step
        """
        super().__init__()
        self.initial_tau = initial_tau
        self.current_tau = initial_tau

    def on_batch_end(self, trainer, pl_module):
        if pl_module.training:
            # get networks
            online_net = pl_module.online_network
            target_net = pl_module.target_network

            # update weights
            self.update_weights(online_net, target_net)

            # update tau after
            self.current_tau = self.update_tau(pl_module, trainer)

    def update_tau(self, pl_module, trainer):
        # cosine schedule: tau goes from initial_tau at step 0 to 1.0 at max_steps
        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / trainer.max_steps) + 1) / 2
        return tau

    def update_weights(self, online_net, target_net):
        # apply MA weight update (only parameters whose name contains 'weight' are averaged)
        for (name, online_p), (_, target_p) in zip(online_net.named_parameters(),
                                                   target_net.named_parameters()):
            if 'weight' in name:
                target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
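
A quick illustration (not part of the diff) of how `update_tau` anneals tau over training, using the same cosine formula as above and assuming the default `initial_tau=0.996` with `max_steps=1000`:

import math

initial_tau, max_steps = 0.996, 1000
for step in (0, 250, 500, 750, 1000):
    # identical schedule to BYOLMAWeightUpdate.update_tau
    tau = 1 - (1 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2
    print(step, round(tau, 5))
# 0 -> 0.996, 250 -> 0.99659, 500 -> 0.998, 750 -> 0.99941, 1000 -> 1.0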
1 change: 1 addition & 0 deletions pl_bolts/models/self_supervised/__init__.py
@@ -19,6 +19,7 @@
"""
from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
from pl_bolts.models.self_supervised.byol.byol_module import BYOL
from pl_bolts.models.self_supervised.cpc.cpc_module import CPCV2
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2
Empty file.
219 changes: 219 additions & 0 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -0,0 +1,219 @@
from argparse import ArgumentParser
from copy import deepcopy

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.optimizers.layer_adaptive_scaling import LARS
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate


class BYOL(pl.LightningModule):
    def __init__(self,
                 datamodule: pl.LightningDataModule = None,
                 data_dir: str = './',
                 learning_rate: float = 0.00006,
                 weight_decay: float = 0.0005,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 optimizer: str = 'lars',
                 lr_sched_step: float = 30.0,
                 lr_sched_gamma: float = 0.5,
                 lars_momentum: float = 0.9,
                 lars_eta: float = 0.001,
                 loss_temperature: float = 0.5,
                 **kwargs):
        """
        PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
        <https://arxiv.org/pdf/2006.07733.pdf>`_

        Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
        Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
        Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

        Model implemented by:

            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        .. warning:: Work in progress. This implementation is still being verified.

        TODOs:
            - add cosine scheduler
            - verify on CIFAR-10
            - verify on STL-10
            - pre-train on ImageNet

        Example:

            >>> from pl_bolts.models.self_supervised import BYOL
            ...
            >>> model = BYOL()

        Train::

            trainer = Trainer()
            trainer.fit(model)

        CLI command::

            # cifar10
            python byol_module.py --gpus 1

            # imagenet
            python byol_module.py
                --gpus 8
                --dataset imagenet2012
                --data_dir /path/to/imagenet/
                --meta_dir /path/to/folder/with/meta.bin/
                --batch_size 32

        Args:
            datamodule: the datamodule
            data_dir: directory to store data
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: image input height
            batch_size: the batch size
            num_workers: number of workers
            optimizer: optimizer name
            lr_sched_step: step for learning rate scheduler
            lr_sched_gamma: gamma for learning rate scheduler
            lars_momentum: the momentum param for the LARS optimizer
            lars_eta: eta (trust coefficient) for the LARS optimizer
            loss_temperature: the loss temperature
        """
        super().__init__()
        self.save_hyperparameters()

        # init default datamodule
        if datamodule is None:
            datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
            datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
            datamodule.val_transforms = SimCLREvalDataTransform(input_height)

        self.datamodule = datamodule

        self.online_network = SiameseArm()
        self.target_network = deepcopy(self.online_network)

        self.weight_callback = BYOLMAWeightUpdate()

    def on_batch_end(self):
        # Add callback for user automatically since it's key to BYOL weight update
        self.weight_callback.on_batch_end(self.trainer, self)

    def forward(self, x):
        y, _, _ = self.online_network(x)
        return y

    def shared_step(self, batch, batch_idx):
        (img_1, img_2), y = batch

        # Image 1 to image 2 loss
        y1, z1, h1 = self.online_network(img_1)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_2)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_a = F.mse_loss(h1_norm, z2_norm)

        # Image 2 to image 1 loss
        y1, z1, h1 = self.online_network(img_2)
        with torch.no_grad():
            y2, z2, h2 = self.target_network(img_1)
        # L2 normalize
        h1_norm = F.normalize(h1, p=2, dim=1)
        z2_norm = F.normalize(z2, p=2, dim=1)
        loss_b = F.mse_loss(h1_norm, z2_norm)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.TrainResult(minimize=total_loss)
        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})

        return result

    def validation_step(self, batch, batch_idx):
        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        # log results
        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'val_loss': total_loss})

        return result

    def configure_optimizers(self):
        optimizer = LARS(self.parameters(), lr=self.hparams.learning_rate)
        # TODO: add scheduler - cosine decay
        return optimizer

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--online_ft', action='store_true', help='run online finetuner')
        parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10')

        (args, _) = parser.parse_known_args()

        # Data
        parser.add_argument('--data_dir', type=str, default='.')

        # Training
        parser.add_argument('--optimizer', choices=['adam', 'lars'], default='lars')
        parser.add_argument('--batch_size', type=int, default=2)
        parser.add_argument('--learning_rate', type=float, default=1.0)
        parser.add_argument('--lars_momentum', type=float, default=0.9)
        parser.add_argument('--lars_eta', type=float, default=0.001)
        parser.add_argument('--lr_sched_step', type=float, default=30, help='lr scheduler step')
        parser.add_argument('--lr_sched_gamma', type=float, default=0.5, help='lr scheduler gamma')
        parser.add_argument('--weight_decay', type=float, default=1e-4)

        # Model
        parser.add_argument('--loss_temperature', type=float, default=0.5)
        parser.add_argument('--num_workers', default=4, type=int)
        parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')

        return parser


if __name__ == '__main__':
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    datamodule = None
    if args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed

        (c, h, w) = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(h)
        datamodule.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(h)
        datamodule.val_transforms = SimCLREvalDataTransform(h)

    model = BYOL(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args, max_steps=10000)
    trainer.fit(model)
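
Side note on the normalized MSE used in `shared_step` (illustrative, not part of the commit): for L2-normalized vectors, the element-wise `F.mse_loss` equals the BYOL paper loss `2 - 2 * cosine_similarity` divided by the feature dimension, so the two formulations differ only by a constant factor:

import torch
import torch.nn.functional as F

h = F.normalize(torch.randn(8, 256), dim=1)  # stands in for the online prediction h1
z = F.normalize(torch.randn(8, 256), dim=1)  # stands in for the target projection z2

mse = F.mse_loss(h, z)                                # mean over batch and feature dims
paper = (2 - 2 * F.cosine_similarity(h, z)).mean()    # mean over batch only

assert torch.allclose(mse * h.size(1), paper, atol=1e-5)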
39 changes: 39 additions & 0 deletions pl_bolts/models/self_supervised/byol/models.py
@@ -0,0 +1,39 @@
from torch import nn
from pl_bolts.utils.self_supervised import torchvision_ssl_encoder


class MLP(nn.Module):
    def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_dim, bias=True))

    def forward(self, x):
        x = self.model(x)
        return x


class SiameseArm(nn.Module):
    def __init__(self, encoder=None):
        super().__init__()

        if encoder is None:
            encoder = torchvision_ssl_encoder('resnet50')
        # Encoder
        self.encoder = encoder
        # Projector
        self.projector = MLP()
        # Predictor
        self.predictor = MLP(input_dim=256)

    def forward(self, x):
        y = self.encoder(x)[0]
        y = y.view(y.size(0), -1)
        z = self.projector(y)
        h = self.predictor(z)
        return y, z, h
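
A rough smoke test of the arm's three outputs (a sketch only; it assumes the default resnet50 encoder, under which the backbone feature y should be 2048-dimensional and the projection z and prediction h 256-dimensional):

import torch
from pl_bolts.models.self_supervised.byol.models import SiameseArm

arm = SiameseArm()
y, z, h = arm(torch.randn(2, 3, 32, 32))
print(y.shape, z.shape, h.shape)  # expected roughly: [2, 2048], [2, 256], [2, 256]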
31 changes: 31 additions & 0 deletions tests/callbacks/test_param_update_callbacks.py
@@ -0,0 +1,31 @@
import torch
from torch import nn
from copy import deepcopy
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate


def test_byol_ma_weight_update_callback(tmpdir):
    a = nn.Linear(100, 10)
    b = deepcopy(a)
    a_original = deepcopy(a)
    b_original = deepcopy(b)

    # make sure a params and b params are the same
    assert torch.equal(next(iter(a.parameters()))[0], next(iter(b.parameters()))[0])

    # fake weight update
    opt = torch.optim.SGD(a.parameters(), lr=0.1)
    y = a(torch.randn(3, 100))
    loss = y.sum()
    loss.backward()
    opt.step()
    opt.zero_grad()

    # make sure a did in fact update
    assert not torch.equal(next(iter(a_original.parameters()))[0], next(iter(a.parameters()))[0])

    # do update via callback
    cb = BYOLMAWeightUpdate(0.8)
    cb.update_weights(a, b)

    assert not torch.equal(next(iter(b_original.parameters()))[0], next(iter(b.parameters()))[0])
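
A possible follow-up check (not in this commit) that pins down the exact EMA value: since `update_weights` only averages parameters whose name contains 'weight', the target's weight should land on `tau * target + (1 - tau) * online` while its bias stays untouched. A self-contained sketch:

import torch
from torch import nn
from copy import deepcopy
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate

online, target = nn.Linear(4, 2), nn.Linear(4, 2)
target_before = deepcopy(target)

cb = BYOLMAWeightUpdate(initial_tau=0.8)
cb.update_weights(online, target)

expected = 0.8 * target_before.weight.data + 0.2 * online.weight.data
assert torch.allclose(target.weight.data, expected)
assert torch.equal(target.bias.data, target_before.bias.data)  # bias is not averaged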
