diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 5db0f8a318..0d2f7fd8a8 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -4,16 +4,15 @@ import torch from pytorch_lightning import LightningModule, Trainer, seed_everything +from torch import Tensor from torch.nn import functional as F from torch.optim import Adam from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate -from pl_bolts.models.self_supervised.byol.models import SiameseArm +from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from pl_bolts.utils.stability import under_review -@under_review() class BYOL(LightningModule): """PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_ @@ -21,16 +20,20 @@ class BYOL(LightningModule): Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \ Bilal Piot, Koray Kavukcuoglu, RĂ©mi Munos, Michal Valko. + Args: + learning_rate (float, optional): optimizer learning rate. Defaults to 0.2. + weight_decay (float, optional): optimizer weight decay. Defaults to 1.5e-6. + warmup_epochs (int, optional): number of epochs for scheduler warmup. Defaults to 10. + max_epochs (int, optional): maximum number of epochs for scheduler. Defaults to 1000. + base_encoder (Union[str, torch.nn.Module], optional): base encoder architecture. Defaults to "resnet50". + encoder_out_dim (int, optional): base encoder output dimension. Defaults to 2048. + projector_hidden_dim (int, optional): projector MLP hidden dimension. Defaults to 4096. + projector_out_dim (int, optional): projector MLP output dimension. Defaults to 256. + initial_tau (float, optional): initial value of target decay rate used. Defaults to 0.996. + Model implemented by: - `Annika Brundyn `_ - .. warning:: Work in progress. This implementation is still being verified. - - TODOs: - - verify on CIFAR-10 - - verify on STL-10 - - pre-train on imagenet - Example:: model = BYOL(num_classes=10) @@ -42,11 +45,6 @@ class BYOL(LightningModule): trainer = pl.Trainer() trainer.fit(model, datamodule=dm) - Train:: - - trainer = Trainer() - trainer.fit(model) - CLI command:: # cifar10 @@ -65,87 +63,82 @@ class BYOL(LightningModule): def __init__( self, - num_classes, learning_rate: float = 0.2, weight_decay: float = 1.5e-6, - input_height: int = 32, - batch_size: int = 32, - num_workers: int = 0, warmup_epochs: int = 10, max_epochs: int = 1000, base_encoder: Union[str, torch.nn.Module] = "resnet50", encoder_out_dim: int = 2048, - projector_hidden_size: int = 4096, + projector_hidden_dim: int = 4096, projector_out_dim: int = 256, - **kwargs - ): - """ - Args: - datamodule: The datamodule - learning_rate: the learning rate - weight_decay: optimizer weight decay - input_height: image input height - batch_size: the batch size - num_workers: number of workers - warmup_epochs: num of epochs for scheduler warm up - max_epochs: max epochs for scheduler - base_encoder: the base encoder module or resnet name - encoder_out_dim: output dimension of base_encoder - projector_hidden_size: hidden layer size of projector MLP - projector_out_dim: output size of projector MLP - """ + initial_tau: float = 0.996, + **kwargs: Any, + ) -> None: + super().__init__() self.save_hyperparameters(ignore="base_encoder") - self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim) + self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_dim, projector_out_dim) self.target_network = deepcopy(self.online_network) - self.weight_callback = BYOLMAWeightUpdate() + self.predictor = MLP(projector_out_dim, projector_hidden_dim, projector_out_dim) - def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None: - # Add callback for user automatically since it's key to BYOL weight update + self.weight_callback = BYOLMAWeightUpdate(initial_tau=initial_tau) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: + """Add callback to perform exponential moving average weight update on target network.""" self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx) - def forward(self, x): - y, _, _ = self.online_network(x) - return y + def forward(self, x: Tensor) -> Tensor: + """Returns the encoded representation of a view. - def shared_step(self, batch, batch_idx): - imgs, y = batch - img_1, img_2 = imgs[:2] + Args: + x (Tensor): sample to be encoded + """ + return self.online_network.encode(x) - # Image 1 to image 2 loss - y1, z1, h1 = self.online_network(img_1) - with torch.no_grad(): - y2, z2, h2 = self.target_network(img_2) - loss_a = -2 * F.cosine_similarity(h1, z2).mean() + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Complete training loop.""" + return self._shared_step(batch, batch_idx, "train") - # Image 2 to image 1 loss - y1, z1, h1 = self.online_network(img_2) - with torch.no_grad(): - y2, z2, h2 = self.target_network(img_1) - # L2 normalize - loss_b = -2 * F.cosine_similarity(h1, z2).mean() + def validation_step(self, batch: Any, batch_idx: int) -> Tensor: + """Complete validation loop.""" + return self._shared_step(batch, batch_idx, "val") - # Final loss - total_loss = loss_a + loss_b + def _shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor: + """Shared evaluation step for training and validation loop.""" + imgs, _ = batch + img1, img2 = imgs[:2] - return loss_a, loss_b, total_loss + # Calculate similarity loss in each direction + loss_12 = self.calculate_loss(img1, img2) + loss_21 = self.calculate_loss(img2, img1) - def training_step(self, batch, batch_idx): - loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + # Calculate total loss + total_loss = loss_12 + loss_21 - # log results - self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss}) + # Log losses + if step == "train": + self.log_dict({"train_loss_12": loss_12, "train_loss_21": loss_21, "train_loss": total_loss}) + elif step == "val": + self.log_dict({"val_loss_12": loss_12, "val_loss_21": loss_21, "val_loss": total_loss}) + else: + raise ValueError(f"Step '{step}' is invalid. Must be 'train' or 'val'.") return total_loss - def validation_step(self, batch, batch_idx): - loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) + def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor: + """Calculates similarity loss between the online network prediction of target network projection. - # log results - self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "val_loss": total_loss}) - - return total_loss + Args: + v_online (Tensor): Online network view + v_target (Tensor): Target network view + """ + _, z1 = self.online_network(v_online) + h1 = self.predictor(z1) + with torch.no_grad(): + _, z2 = self.target_network(v_target) + loss = -2 * F.cosine_similarity(h1, z2).mean() + return loss def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) @@ -155,30 +148,23 @@ def configure_optimizers(self): return [optimizer], [scheduler] @staticmethod - def add_model_specific_args(parent_parser): + def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument("--online_ft", action="store_true", help="run online finetuner") - parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"]) - - (args, _) = parser.parse_known_args() + args = parser.parse_args([]) - # Data - parser.add_argument("--data_dir", type=str, default=".") - parser.add_argument("--num_workers", default=8, type=int) + if "max_epochs" in args: + parser.set_defaults(max_epochs=1000) + else: + parser.add_argument("--max_epochs", type=int, default=1000) - # optim - parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--learning_rate", type=float, default=1e-3) + parser.add_argument("--learning_rate", type=float, default=0.2) parser.add_argument("--weight_decay", type=float, default=1.5e-6) - parser.add_argument("--warmup_epochs", type=float, default=10) - - # Model + parser.add_argument("--warmup_epochs", type=int, default=10) parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet") return parser -@under_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule @@ -188,23 +174,19 @@ def cli_main(): parser = ArgumentParser() - # trainer args parser = Trainer.add_argparse_args(parser) - - # model args parser = BYOL.add_model_specific_args(parser) - args = parser.parse_args() + parser = CIFAR10DataModule.add_dataset_specific_args(parser) + parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"]) - # pick data - dm = None + args = parser.parse_args() - # init default datamodule + # Initialize datamodule if args.dataset == "cifar10": dm = CIFAR10DataModule.from_argparse_args(args) dm.train_transforms = SimCLRTrainDataTransform(32) dm.val_transforms = SimCLREvalDataTransform(32) args.num_classes = dm.num_classes - elif args.dataset == "stl10": dm = STL10DataModule.from_argparse_args(args) dm.train_dataloader = dm.train_dataloader_mixed @@ -214,20 +196,24 @@ def cli_main(): dm.train_transforms = SimCLRTrainDataTransform(h) dm.val_transforms = SimCLREvalDataTransform(h) args.num_classes = dm.num_classes - elif args.dataset == "imagenet2012": dm = ImagenetDataModule.from_argparse_args(args, image_size=196) (c, h, w) = dm.dims dm.train_transforms = SimCLRTrainDataTransform(h) dm.val_transforms = SimCLREvalDataTransform(h) args.num_classes = dm.num_classes + else: + raise ValueError( + f"{args.dataset} is not a valid dataset. Dataset must be 'cifar10', 'stl10', or 'imagenet2012'." + ) - model = BYOL(**args.__dict__) + # Initialize BYOL module + model = BYOL(**vars(args)) # finetune in real-time online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes) - trainer = Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval]) + trainer = Trainer.from_argparse_args(args, callbacks=[online_eval]) trainer.fit(model, datamodule=dm) diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py index fe6c0b856b..7d6168cc34 100644 --- a/pl_bolts/models/self_supervised/byol/models.py +++ b/pl_bolts/models/self_supervised/byol/models.py @@ -1,43 +1,78 @@ -from torch import nn +from typing import Tuple, Union + +from torch import Tensor, nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder -from pl_bolts.utils.stability import under_review -@under_review() class MLP(nn.Module): - def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): + """MLP architecture used as projectors in online and target networks and predictors in the online network. + + Args: + input_dim (int, optional): Input dimension. Defaults to 2048. + hidden_dim (int, optional): Hidden layer dimension. Defaults to 4096. + output_dim (int, optional): Output dimension. Defaults to 256. + + Note: + Default values for input, hidden, and output dimensions are based on values used in BYOL. + """ + + def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None: + super().__init__() - self.output_dim = output_dim - self.input_dim = input_dim + self.model = nn.Sequential( - nn.Linear(input_dim, hidden_size, bias=False), - nn.BatchNorm1d(hidden_size), + nn.Linear(input_dim, hidden_dim, bias=False), + nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True), - nn.Linear(hidden_size, output_dim, bias=True), + nn.Linear(hidden_dim, output_dim, bias=True), ) - def forward(self, x): - x = self.model(x) - return x + def forward(self, x: Tensor) -> Tensor: + return self.model(x) -@under_review() class SiameseArm(nn.Module): - def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256): + """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single + class. + + Args: + encoder (Union[str, nn.Module], optional): Online and target network encoder architecture. + Defaults to "resnet50". + encoder_out_dim (int, optional): Output dimension of encoder. Defaults to 2048. + projector_hidden_dim (int, optional): Online and target network projector network hidden dimension. + Defaults to 4096. + projector_out_dim (int, optional): Online and target network projector network output dimension. + Defaults to 256. + """ + + def __init__( + self, + encoder: Union[str, nn.Module] = "resnet50", + encoder_out_dim: int = 2048, + projector_hidden_dim: int = 4096, + projector_out_dim: int = 256, + ) -> None: + super().__init__() if isinstance(encoder, str): - encoder = torchvision_ssl_encoder(encoder) - # Encoder - self.encoder = encoder - # Projector - self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim) - # Predictor - self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim) - - def forward(self, x): + self.encoder = torchvision_ssl_encoder(encoder) + else: + self.encoder = encoder + + self.projector = MLP(encoder_out_dim, projector_hidden_dim, projector_out_dim) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: y = self.encoder(x)[0] z = self.projector(y) - h = self.predictor(z) - return y, z, h + return y, z + + def encode(self, x: Tensor) -> Tensor: + """Returns the encoded representation of a view. This method does not calculate the projection as in the + forward method. + + Args: + x (Tensor): sample to be encoded + """ + return self.encoder(x)[0] diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 59f0ca0519..a94882110c 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -1,8 +1,9 @@ -from distutils.version import LooseVersion +import warnings import pytest import torch from pytorch_lightning import Trainer +from pytorch_lightning.utilities.warnings import PossibleUserWarning from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.models.self_supervised import AMDIM, BYOL, CPC_v2, Moco_v2, SimCLR, SimSiam, SwAV @@ -37,16 +38,26 @@ def test_cpcv2(tmpdir, datadir): trainer.fit(model, datamodule=datamodule) -# todo: some pickling issue with min config -@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.7.0"), reason="Pickling issue") -def test_byol(tmpdir, datadir): - datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2) - datamodule.train_transforms = CPCTrainTransformsCIFAR10() - datamodule.val_transforms = CPCEvalTransformsCIFAR10() - - model = BYOL(data_dir=datadir, num_classes=datamodule) - trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - trainer.fit(model, datamodule=datamodule) +def test_byol(tmpdir, datadir, catch_warnings): + """Test BYOL on CIFAR-10.""" + warnings.filterwarnings( + "ignore", + message=".+does not have many workers which may be a bottleneck.+", + category=PossibleUserWarning, + ) + dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2) + dm.train_transforms = SimCLRTrainDataTransform(32) + dm.val_transforms = SimCLREvalDataTransform(32) + + model = BYOL(data_dir=datadir) + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + max_epochs=1, + accelerator="auto", + log_every_n_steps=1, + ) + trainer.fit(model, datamodule=dm) def test_amdim(tmpdir, datadir): diff --git a/tests/models/self_supervised/test_ssl_scripts.py b/tests/models/self_supervised/test_ssl_scripts.py index e58fad9765..f080903cdc 100644 --- a/tests/models/self_supervised/test_ssl_scripts.py +++ b/tests/models/self_supervised/test_ssl_scripts.py @@ -76,12 +76,10 @@ def test_cli_run_ssl_simclr(cli_args): cli_main() -# todo: seems to take too long -@pytest.mark.skip(reason="FIXME: seems to take too long") @pytest.mark.parametrize( "cli_args", [ - _DEFAULT_ARGS + " --online_ft" " --gpus 1", + _DEFAULT_ARGS + " --gpus 1", ], ) @pytest.mark.skipif(**_MARK_REQUIRE_GPU)