from argparse import ArgumentParser
from copy import deepcopy
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch.optim import Adam

from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


class BYOL(pl.LightningModule):
def __init__(self,
num_classes,
learning_rate: float = 0.2,
weight_decay: float = 1.5e-6,
input_height: int = 32,
batch_size: int = 32,
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
**kwargs):
"""
PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
<https://arxiv.org/pdf/2006.07733.pdf>`_
Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.
Model implemented by:
- `Annika Brundyn <https://github.com/annikabrundyn>`_
.. warning:: Work in progress. This implementation is still being verified.
TODOs:
- verify on CIFAR-10
- verify on STL-10
- pre-train on imagenet
Example::
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)
# model
model = BYOL(num_classes=10)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
trainer = pl.Trainer()
trainer.fit(model, dm)
Train::
trainer = Trainer()
trainer.fit(model)
CLI command::
# cifar10
python byol_module.py --gpus 1
# imagenet
python byol_module.py
--gpus 8
--dataset imagenet2012
--data_dir /path/to/imagenet/
--meta_dir /path/to/folder/with/meta.bin/
--batch_size 32
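        Inference (a minimal sketch; ``'path/to/checkpoint.ckpt'`` and ``images`` are placeholders)::

            model = BYOL.load_from_checkpoint('path/to/checkpoint.ckpt')
            model.eval()
            with torch.no_grad():
                embeddings = model(images)  # backbone features only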
Args:
            num_classes: number of classes in the dataset (stored in hparams for the online evaluator)
learning_rate: the learning rate
weight_decay: optimizer weight decay
input_height: image input height
batch_size: the batch size
num_workers: number of workers
warmup_epochs: num of epochs for scheduler warm up
max_epochs: max epochs for scheduler
"""
super().__init__()
self.save_hyperparameters()
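        # SiameseArm maps an image through encoder (y) -> projector (z) -> predictor (h).
        # The target network starts as a copy of the online network and is updated
        # only by the EMA weight callback below, never by gradients.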
self.online_network = SiameseArm()
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()
def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
# Add callback for user automatically since it's key to BYOL weight update
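        # The callback applies an exponential moving average update, roughly:
        #   target_param = tau * target_param + (1 - tau) * online_param
        # with tau annealed from its initial value (0.996) toward 1 over training.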
self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx)
def forward(self, x):
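        # Inference returns only the backbone embedding y; the projector and
        # predictor outputs are used during training only.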
y, _, _ = self.online_network(x)
return y
def cosine_similarity(self, a, b):
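        # Row-wise cosine similarity between a and b, averaged over the batch.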
a = F.normalize(a, dim=-1)
b = F.normalize(b, dim=-1)
sim = (a * b).sum(-1).mean()
return sim
def shared_step(self, batch, batch_idx):
(img_1, img_2), y = batch
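        # Each batch yields two augmented views of the same images; the loss is
        # symmetrized by swapping which view goes through the online network.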
# Image 1 to image 2 loss
y1, z1, h1 = self.online_network(img_1)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_2)
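        # For unit-normalized vectors, ||h1 - z2||^2 = 2 - 2 * cos(h1, z2), so
        # minimizing -2 * cosine_similarity matches the paper's normalized MSE
        # loss up to an additive constant.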
        loss_a = -2 * self.cosine_similarity(h1, z2)
# Image 2 to image 1 loss
y1, z1, h1 = self.online_network(img_2)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_1)
        loss_b = -2 * self.cosine_similarity(h1, z2)
# Final loss
total_loss = loss_a + loss_b
return loss_a, loss_b, total_loss
def training_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
# log results
self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
return total_loss
def validation_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
# log results
        self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'val_loss': total_loss})
return total_loss
def configure_optimizers(self):
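        # Adam wrapped in LARS for layer-wise adaptive learning rates, with a
        # linear-warmup + cosine-annealing schedule over max_epochs.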
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
optimizer = LARSWrapper(optimizer)
scheduler = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.hparams.warmup_epochs,
max_epochs=self.hparams.max_epochs
)
return [optimizer], [scheduler]
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--online_ft', action='store_true', help='run online finetuner')
parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10')
(args, _) = parser.parse_known_args()
# Data
parser.add_argument('--data_dir', type=str, default='.')
parser.add_argument('--num_workers', default=0, type=int)
# optim
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=1.5e-6)
        parser.add_argument('--warmup_epochs', type=int, default=10)
# Model
parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
return parser
def cli_main():
from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform
seed_everything(1234)
parser = ArgumentParser()
# trainer args
parser = pl.Trainer.add_argparse_args(parser)
# model args
parser = BYOL.add_model_specific_args(parser)
args = parser.parse_args()
# pick data
dm = None
# init default datamodule
if args.dataset == 'cifar10':
dm = CIFAR10DataModule.from_argparse_args(args)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
args.num_classes = dm.num_classes
elif args.dataset == 'stl10':
dm = STL10DataModule.from_argparse_args(args)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed
(c, h, w) = dm.size()
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes
elif args.dataset == 'imagenet2012':
dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
(c, h, w) = dm.size()
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes
model = BYOL(**args.__dict__)
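    # The online evaluator only needs one augmented view plus the label, so the
    # second view is dropped before moving tensors to the device.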
def to_device(batch, device):
(x1, x2), y = batch
x1 = x1.to(device)
y = y.to(device)
return x1, y
# finetune in real-time
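    # z_dim=2048 corresponds to the feature dimension of the default ResNet-50 encoder.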
online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
online_eval.to_device = to_device
trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
trainer.fit(model, dm)
if __name__ == '__main__':
cli_main()