ssl_finetuner.py
import pytorch_lightning as pl
import pytorch_lightning.metrics.functional as plm
import torch
import torch.nn.functional as F

from pl_bolts.models.self_supervised import SSLEvaluator
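
# --- Illustrative sketch (not part of the original module) ------------------
# SSLEvaluator, imported above, is the classification head that SSLFineTuner
# trains on top of the frozen backbone. The class below is only a rough mental
# model of such a head (a dropout-regularized single-hidden-layer MLP, matching
# the hidden_dim=1024 / p=0.2 settings used further down); it is an assumption
# for illustration, not pl_bolts' actual SSLEvaluator implementation.
import torch.nn as nn


class MLPHeadSketch(nn.Module):
    def __init__(self, n_input, n_classes, n_hidden=1024, p=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                    # collapse feature maps to vectors
            nn.Dropout(p=p),                 # regularize the frozen features
            nn.Linear(n_input, n_hidden),    # hidden projection (1024 units by default)
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
            nn.Linear(n_hidden, n_classes),  # per-class logits
        )

    def forward(self, x):
        return self.net(x)
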
class SSLFineTuner(pl.LightningModule):

    def __init__(self, backbone, in_features, num_classes, hidden_dim=1024):
        """
        Fine-tunes a self-supervised learning backbone using the standard evaluation protocol of a single-layer
        MLP with 1024 units

        Example::

            from pl_bolts.utils.self_supervised import SSLFineTuner
            from pl_bolts.models.self_supervised import CPCV2
            from pl_bolts.datamodules import CIFAR10DataModule
            from pl_bolts.models.self_supervised.cpc.transforms import (
                CPCTrainTransformsCIFAR10,
                CPCEvalTransformsCIFAR10,
            )

            # pretrained model
            backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

            # dataset + transforms
            dm = CIFAR10DataModule(data_dir='.')
            dm.train_transforms = CPCTrainTransformsCIFAR10()
            dm.val_transforms = CPCEvalTransformsCIFAR10()

            # finetuner
            finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

            # train
            trainer = pl.Trainer()
            trainer.fit(finetuner, dm)

            # test
            trainer.test(datamodule=dm)

        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()
        self.backbone = backbone

        # MLP head trained on top of the (frozen) backbone features
        self.ft_network = SSLEvaluator(
            n_input=in_features,
            n_classes=num_classes,
            p=0.2,
            n_hidden=hidden_dim,
        )
    def on_train_epoch_start(self) -> None:
        # keep the backbone in eval mode so dropout / batch-norm statistics stay frozen
        self.backbone.eval()
    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.TrainResult(loss)
        result.log('train_acc', acc, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
        result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
        return result

    def test_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.EvalResult()
        result.log_dict({'test_acc': acc, 'test_loss': loss})
        return result
    def shared_step(self, batch):
        x, y = batch

        # extract features without tracking gradients: the backbone stays frozen
        with torch.no_grad():
            feats = self.backbone(x)

        feats = feats.view(feats.size(0), -1)
        logits = self.ft_network(feats)
        loss = F.cross_entropy(logits, y)
        acc = plm.accuracy(logits, y)

        return loss, acc
    def configure_optimizers(self):
        # only the finetune head is optimized; the backbone receives no updates
        return torch.optim.Adam(self.ft_network.parameters(), lr=0.0002)
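
# --- Usage sketch (not part of the original module) --------------------------
# A minimal end-to-end run mirroring the docstring example above. The checkpoint
# path and max_epochs value are hypothetical placeholders; the `z_dim` and
# `num_classes` attributes are assumed to exist on the pretrained CPCV2 model,
# as the docstring suggests.
if __name__ == '__main__':
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised import CPCV2
    from pl_bolts.models.self_supervised.cpc.transforms import (
        CPCTrainTransformsCIFAR10,
        CPCEvalTransformsCIFAR10,
    )

    # hypothetical checkpoint path -- replace with a real CPCV2 checkpoint
    backbone = CPCV2.load_from_checkpoint('path/to/cpc.ckpt', strict=False)

    dm = CIFAR10DataModule(data_dir='.')
    dm.train_transforms = CPCTrainTransformsCIFAR10()
    dm.val_transforms = CPCEvalTransformsCIFAR10()

    finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(finetuner, dm)
    trainer.test(datamodule=dm)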