diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py
deleted file mode 100644
index b0c7fb2290..0000000000
--- a/pl_bolts/models/self_supervised/simsiam/models.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from typing import Optional, Tuple
-
-from torch import Tensor, nn
-
-from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
-from pl_bolts.utils.stability import under_review
-
-
-@under_review()
-class MLP(nn.Module):
-    def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None:
-        super().__init__()
-        self.output_dim = output_dim
-        self.input_dim = input_dim
-        self.model = nn.Sequential(
-            nn.Linear(input_dim, hidden_size, bias=False),
-            nn.BatchNorm1d(hidden_size),
-            nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, output_dim, bias=True),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.model(x)
-        return x
-
-
-@under_review()
-class SiameseArm(nn.Module):
-    def __init__(
-        self,
-        encoder: Optional[nn.Module] = None,
-        input_dim: int = 2048,
-        hidden_size: int = 4096,
-        output_dim: int = 256,
-    ) -> None:
-        super().__init__()
-
-        if encoder is None:
-            encoder = torchvision_ssl_encoder("resnet50")
-        # Encoder
-        self.encoder = encoder
-        # Projector
-        self.projector = MLP(input_dim, hidden_size, output_dim)
-        # Predictor
-        self.predictor = MLP(output_dim, hidden_size, output_dim)
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-        y = self.encoder(x)[0]
-        z = self.projector(y)
-        h = self.predictor(z)
-        return y, z, h
diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
index 7554f0cf80..4ea1afbecc 100644
--- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
+++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
@@ -1,38 +1,39 @@
 from argparse import ArgumentParser
+from copy import deepcopy
+from typing import Any, Dict, List, Union
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from pytorch_lightning import LightningModule, Trainer, seed_everything
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from torch.nn import functional as F
-
-from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
-from pl_bolts.models.self_supervised.simsiam.models import SiameseArm
-from pl_bolts.optimizers.lars import LARS
-from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-from pl_bolts.transforms.dataset_normalizations import (
-    cifar10_normalization,
-    imagenet_normalization,
-    stl10_normalization,
-)
-from pl_bolts.utils.stability import under_review
-
-
-@under_review()
+from torch import Tensor
+
+from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+
 class SimSiam(LightningModule):
-    """PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam_)
+    """PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam_)_
 
     Paper authors: Xinlei Chen, Kaiming He.
 
+    Args:
+        learning_rate (float, optional): optimizer learning rate. Defaults to 0.05.
+        weight_decay (float, optional): optimizer weight decay. Defaults to 1e-4.
+        momentum (float, optional): optimizer momentum. Defaults to 0.9.
+        warmup_epochs (int, optional): number of epochs for scheduler warmup. Defaults to 10.
+        max_epochs (int, optional): maximum number of epochs for scheduler. Defaults to 100.
+        base_encoder (Union[str, nn.Module], optional): base encoder architecture. Defaults to "resnet50".
+        encoder_out_dim (int, optional): base encoder output dimension. Defaults to 2048.
+        projector_hidden_dim (int, optional): projector MLP hidden dimension. Defaults to 2048.
+        projector_out_dim (int, optional): projector MLP output dimension. Defaults to 2048.
+        predictor_hidden_dim (int, optional): predictor MLP hidden dimension. Defaults to 512.
+        exclude_bn_bias (bool, optional): option to exclude batchnorm and bias terms from weight decay.
+            Defaults to False.
+
     Model implemented by:
         - `Zvi Lapp `_
 
-    .. warning:: Work in progress. This implementation is still being verified.
-
-    TODOs:
-        - verify on CIFAR-10
-        - verify on STL-10
-        - pre-train on imagenet
-
     Example::
 
         model = SimSiam()
@@ -44,11 +45,6 @@ class SimSiam(LightningModule):
         trainer = Trainer()
         trainer.fit(model, datamodule=dm)
 
-    Train::
-
-        trainer = Trainer()
-        trainer.fit(model)
-
     CLI command::
 
         # cifar10
         python simsiam_module.py
@@ -58,7 +54,6 @@ class SimSiam(LightningModule):
         python simsiam_module.py
             --gpus 8
            --dataset imagenet2012
-           --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32
@@ -67,128 +62,104 @@ def __init__(
         self,
-        gpus: int,
-        num_samples: int,
-        batch_size: int,
-        dataset: str,
-        num_nodes: int = 1,
-        arch: str = "resnet50",
-        hidden_mlp: int = 2048,
-        feat_dim: int = 128,
+        learning_rate: float = 0.05,
+        weight_decay: float = 1e-4,
+        momentum: float = 0.9,
         warmup_epochs: int = 10,
         max_epochs: int = 100,
-        temperature: float = 0.1,
-        first_conv: bool = True,
-        maxpool1: bool = True,
-        optimizer: str = "adam",
+        base_encoder: Union[str, nn.Module] = "resnet50",
+        encoder_out_dim: int = 2048,
+        projector_hidden_dim: int = 2048,
+        projector_out_dim: int = 2048,
+        predictor_hidden_dim: int = 512,
         exclude_bn_bias: bool = False,
-        start_lr: float = 0.0,
-        learning_rate: float = 1e-3,
-        final_lr: float = 0.0,
-        weight_decay: float = 1e-6,
-        **kwargs
-    ):
-        """
-        Args:
-            datamodule: The datamodule
-            learning_rate: the learning rate
-            weight_decay: optimizer weight decay
-            input_height: image input height
-            batch_size: the batch size
-            num_workers: number of workers
-            warmup_epochs: num of epochs for scheduler warm up
-            max_epochs: max epochs for scheduler
-        """
+        **kwargs,
+    ) -> None:
+
         super().__init__()
-        self.save_hyperparameters()
-
-        self.gpus = gpus
-        self.num_nodes = num_nodes
-        self.arch = arch
-        self.dataset = dataset
-        self.num_samples = num_samples
-        self.batch_size = batch_size
-
-        self.hidden_mlp = hidden_mlp
-        self.feat_dim = feat_dim
-        self.first_conv = first_conv
-        self.maxpool1 = maxpool1
-
-        self.optim = optimizer
-        self.exclude_bn_bias = exclude_bn_bias
-        self.weight_decay = weight_decay
-        self.temperature = temperature
-
-        self.start_lr = start_lr
-        self.final_lr = final_lr
-        self.learning_rate = learning_rate
-        self.warmup_epochs = warmup_epochs
-        self.max_epochs = max_epochs
-
-        self.init_model()
-
-        # compute iters per epoch
-        nb_gpus = len(self.gpus) if isinstance(gpus, (list, tuple)) else self.gpus
-        assert isinstance(nb_gpus, int)
-        global_batch_size = self.num_nodes * nb_gpus * self.batch_size if nb_gpus > 0 else self.batch_size
-        self.train_iters_per_epoch = self.num_samples // global_batch_size
-
-    def init_model(self):
-        if self.arch == "resnet18":
-            backbone = resnet18
-        elif self.arch == "resnet50":
-            backbone = resnet50
-
-        encoder = backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False)
-        self.online_network = SiameseArm(
-            encoder, input_dim=self.hidden_mlp, hidden_size=self.hidden_mlp, output_dim=self.feat_dim
-        )
+        self.save_hyperparameters(ignore="base_encoder")
 
-    def forward(self, x):
-        y, _, _ = self.online_network(x)
-        return y
+        self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_dim, projector_out_dim)
+        self.target_network = deepcopy(self.online_network)
+        self.predictor = MLP(projector_out_dim, predictor_hidden_dim, projector_out_dim)
 
-    def cosine_similarity(self, a, b):
-        b = b.detach()  # stop gradient of backbone + projection mlp
-        a = F.normalize(a, dim=-1)
-        b = F.normalize(b, dim=-1)
-        sim = -1 * (a * b).sum(-1).mean()
-        return sim
+    def forward(self, x: Tensor) -> Tensor:
+        """Returns encoded representation of a view."""
+        return self.online_network.encode(x)
 
-    def training_step(self, batch, batch_idx):
-        (img_1, img_2, _), y = batch
+    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
+        """Complete training loop."""
+        return self._shared_step(batch, batch_idx, "train")
 
-        # Image 1 to image 2 loss
-        _, z1, h1 = self.online_network(img_1)
-        _, z2, h2 = self.online_network(img_2)
-        loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2
+    def validation_step(self, batch: Any, batch_idx: int) -> Tensor:
+        """Complete validation loop."""
+        return self._shared_step(batch, batch_idx, "val")
 
-        # log results
-        self.log_dict({"train_loss": loss})
+    def _shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor:
+        """Shared evaluation step for training and validation loops."""
+        imgs, _ = batch
+        img1, img2 = imgs[:2]
 
-        return loss
+        # Calculate similarity loss in each direction
+        loss_12 = self.calculate_loss(img1, img2)
+        loss_21 = self.calculate_loss(img2, img1)
+
+        # Calculate total loss
+        total_loss = loss_12 + loss_21
 
-    def validation_step(self, batch, batch_idx):
-        (img_1, img_2, _), y = batch
+        # Log loss
+        if step == "train":
+            self.log_dict({"train_loss_12": loss_12, "train_loss_21": loss_21, "train_loss": total_loss})
+        elif step == "val":
+            self.log_dict({"val_loss_12": loss_12, "val_loss_21": loss_21, "val_loss": total_loss})
+        else:
+            raise ValueError(f"Step '{step}' is invalid. Must be 'train' or 'val'.")
 
-        # Image 1 to image 2 loss
-        _, z1, h1 = self.online_network(img_1)
-        _, z2, h2 = self.online_network(img_2)
-        loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2
+        return total_loss
 
-        # log results
-        self.log_dict({"val_loss": loss})
+    def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
+        """Calculates the similarity loss between the online network prediction and the target network projection.
+
+        Args:
+            v_online (Tensor): Online network view
+            v_target (Tensor): Target network view
+        """
+        _, z1 = self.online_network(v_online)
+        h1 = self.predictor(z1)
+        with torch.no_grad():
+            _, z2 = self.target_network(v_target)
+        loss = -0.5 * F.cosine_similarity(h1, z2).mean()
         return loss
 
-    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
+    def configure_optimizers(self):
+        """Configure optimizer and learning rate scheduler."""
+        if self.hparams.exclude_bn_bias:
+            params = self.exclude_from_weight_decay(self.named_parameters(), weight_decay=self.hparams.weight_decay)
+        else:
+            params = self.parameters()
+
+        optimizer = torch.optim.SGD(
+            params,
+            lr=self.hparams.learning_rate,
+            momentum=self.hparams.momentum,
+            weight_decay=self.hparams.weight_decay,
+        )
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
+        )
+
+        return [optimizer], [scheduler]
+
+    @staticmethod
+    def exclude_from_weight_decay(named_params, weight_decay, skip_list=("bias", "bn")) -> List[Dict]:
+        """Exclude parameters from weight decay."""
         params = []
         excluded_params = []
         for name, param in named_params:
             if not param.requires_grad:
                 continue
-            elif any(layer_name in name for layer_name in skip_list):
+            elif param.ndim == 1 or any(layer_name in name for layer_name in skip_list):
                 excluded_params.append(param)
             else:
                 params.append(param)
@@ -198,73 +169,25 @@ def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
             {"params": excluded_params, "weight_decay": 0.0},
         ]
 
-    def configure_optimizers(self):
-        if self.exclude_bn_bias:
-            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
-        else:
-            params = self.parameters()
-
-        if self.optim == "lars":
-            optimizer = LARS(
-                params,
-                lr=self.learning_rate,
-                momentum=0.9,
-                weight_decay=self.weight_decay,
-                trust_coefficient=0.001,
-            )
-        elif self.optim == "adam":
-            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
-
-        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
-        total_steps = self.train_iters_per_epoch * self.max_epochs
-
-        scheduler = {
-            "scheduler": torch.optim.lr_scheduler.LambdaLR(
-                optimizer,
-                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
-            ),
-            "interval": "step",
-            "frequency": 1,
-        }
-
-        return [optimizer], [scheduler]
-
     @staticmethod
-    def add_model_specific_args(parent_parser):
+    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
 
-        # model params
-        parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture")
-        # specify flags to store false
-        parser.add_argument("--first_conv", action="store_false")
-        parser.add_argument("--maxpool1", action="store_false")
-        parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head")
-        parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension")
-        parser.add_argument("--online_ft", action="store_true")
-        parser.add_argument("--fp32", action="store_true")
-
-        # transform params
-        parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur")
-        parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength")
-        parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10")
-        parser.add_argument("--data_dir", type=str, default=".", help="path to download data")
-
-        # training params
-        parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
-        parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars")
-        parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay")
-        parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
-        parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")
+        args = parser.parse_args([])
 
-        parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss")
-        parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay")
-        parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate")
-        parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate")
-        parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate")
+        if "max_epochs" in args:
+            parser.set_defaults(max_epochs=100)
+        else:
+            parser.add_argument("--max_epochs", type=int, default=100)
+
+        parser.add_argument("--learning_rate", default=0.05, type=float, help="base learning rate")
+        parser.add_argument("--weight_decay", default=1e-4, type=float, help="weight decay")
+        parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
+        parser.add_argument("--base_encoder", default="resnet50", type=str, help="encoder backbone")
+        parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
 
         return parser
 
 
-@under_review()
 def cli_main():
     from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
     from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
@@ -274,120 +197,46 @@ def cli_main():
     parser = ArgumentParser()
 
-    # trainer args
     parser = Trainer.add_argparse_args(parser)
-
-    # model args
     parser = SimSiam.add_model_specific_args(parser)
-    args = parser.parse_args()
+    parser = CIFAR10DataModule.add_dataset_specific_args(parser)
+    parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])
 
-    # pick data
-    dm = None
-
-    # init datamodule
-    if args.dataset == "stl10":
-        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
+    args = parser.parse_args()
+    # Initialize datamodule
+    if args.dataset == "cifar10":
+        dm = CIFAR10DataModule.from_argparse_args(args)
+        dm.train_transforms = SimCLRTrainDataTransform(32)
+        dm.val_transforms = SimCLREvalDataTransform(32)
+        args.num_classes = dm.num_classes
+    elif args.dataset == "stl10":
+        dm = STL10DataModule.from_argparse_args(args)
         dm.train_dataloader = dm.train_dataloader_mixed
         dm.val_dataloader = dm.val_dataloader_mixed
-        args.num_samples = dm.num_unlabeled_samples
-
-        args.maxpool1 = False
-        args.first_conv = True
-        args.input_height = dm.dims[-1]
-
-        normalization = stl10_normalization()
-
-        args.gaussian_blur = True
-        args.jitter_strength = 1.0
-    elif args.dataset == "cifar10":
-        val_split = 5000
-        if args.num_nodes * args.gpus * args.batch_size > val_split:
-            val_split = args.num_nodes * args.gpus * args.batch_size
-
-        dm = CIFAR10DataModule(
-            data_dir=args.data_dir,
-            batch_size=args.batch_size,
-            num_workers=args.num_workers,
-            val_split=val_split,
-        )
-
-        args.num_samples = dm.num_samples
-
-        args.maxpool1 = False
-        args.first_conv = False
-        args.input_height = dm.dims[-1]
-        args.temperature = 0.5
-        normalization = cifar10_normalization()
-
-        args.gaussian_blur = False
-        args.jitter_strength = 0.5
-    elif args.dataset == "imagenet":
-        args.maxpool1 = True
-        args.first_conv = True
-        normalization = imagenet_normalization()
-
-        args.gaussian_blur = True
-        args.jitter_strength = 1.0
-
-        args.batch_size = 64
-        args.num_nodes = 8
-        args.gpus = 8  # per-node
-        args.max_epochs = 800
-
-        args.optimizer = "lars"
-        args.lars_wrapper = True
-        args.learning_rate = 4.8
-        args.final_lr = 0.0048
-        args.start_lr = 0.3
-        args.online_ft = True
-
-        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-        args.input_height = dm.dims[-1]
+        (c, h, w) = dm.dims
+        dm.train_transforms = SimCLRTrainDataTransform(h)
+        dm.val_transforms = SimCLREvalDataTransform(h)
+        args.num_classes = dm.num_classes
+    elif args.dataset == "imagenet2012":
+        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
+        (c, h, w) = dm.dims
+        dm.train_transforms = SimCLRTrainDataTransform(h)
+        dm.val_transforms = SimCLREvalDataTransform(h)
+        args.num_classes = dm.num_classes
     else:
-        raise NotImplementedError("other datasets have not been implemented till now")
-
-    dm.train_transforms = SimCLRTrainDataTransform(
-        input_height=args.input_height,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-        normalize=normalization,
-    )
-
-    dm.val_transforms = SimCLREvalDataTransform(
-        input_height=args.input_height,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-        normalize=normalization,
-    )
-
-    model = SimSiam(**args.__dict__)
-
-    # finetune in real-time
-    online_evaluator = None
-    if args.online_ft:
-        # online eval
-        online_evaluator = SSLOnlineEvaluator(
-            drop_p=0.0,
-            hidden_dim=None,
-            z_dim=args.hidden_mlp,
-            num_classes=dm.num_classes,
-            dataset=args.dataset,
+        raise ValueError(
+            f"{args.dataset} is not a valid dataset. Dataset must be 'cifar10', 'stl10', or 'imagenet2012'."
         )
 
-    lr_monitor = LearningRateMonitor(logging_interval="step")
-    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss")
-    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
-    callbacks.append(lr_monitor)
+    # Initialize SimSiam module
+    model = SimSiam(**vars(args))
+
+    # Finetune in real-time
+    online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes)
 
-    trainer = Trainer.from_argparse_args(
-        args,
-        sync_batchnorm=True if args.gpus > 1 else False,
-        callbacks=callbacks,
-    )
+    trainer = Trainer.from_argparse_args(args, callbacks=[online_eval])
 
     trainer.fit(model, datamodule=dm)
diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py
index a94882110c..bd8957928d 100644
--- a/tests/models/self_supervised/test_models.py
+++ b/tests/models/self_supervised/test_models.py
@@ -118,11 +118,23 @@ def test_swav(tmpdir, datadir, batch_size=2):
     trainer.fit(model, datamodule=datamodule)
 
 
-def test_simsiam(tmpdir, datadir):
-    datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
-    datamodule.train_transforms = SimCLRTrainDataTransform(32)
-    datamodule.val_transforms = SimCLREvalDataTransform(32)
+def test_simsiam(tmpdir, datadir, catch_warnings):
+    """Test SimSiam on CIFAR-10."""
+    warnings.filterwarnings(
+        "ignore",
+        message=".+does not have many workers which may be a bottleneck.+",
+        category=PossibleUserWarning,
+    )
+    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
+    dm.train_transforms = SimCLRTrainDataTransform(32)
+    dm.val_transforms = SimCLREvalDataTransform(32)
 
-    model = SimSiam(batch_size=2, num_samples=datamodule.num_samples, gpus=0, nodes=1, dataset="cifar10")
-    trainer = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
-    trainer.fit(model, datamodule=datamodule)
+    model = SimSiam()
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        accelerator="auto",
+        log_every_n_steps=1,
+    )
+    trainer.fit(model, datamodule=dm)
diff --git a/tests/models/self_supervised/test_ssl_scripts.py b/tests/models/self_supervised/test_ssl_scripts.py
index f080903cdc..eb70e65482 100644
--- a/tests/models/self_supervised/test_ssl_scripts.py
+++ b/tests/models/self_supervised/test_ssl_scripts.py
@@ -118,7 +118,7 @@ def test_cli_run_ssl_swav(cli_args):
 @pytest.mark.parametrize(
     "cli_args",
     [
-        _DEFAULT_ARGS + " --dataset cifar10" " --online_ft" " --gpus 1" " --fp32",
+        _DEFAULT_ARGS + " --gpus 1",
     ],
 )
 @pytest.mark.skipif(**_MARK_REQUIRE_GPU)
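For readers skimming the diff, the core of the new module is the symmetric negative cosine similarity loss in `calculate_loss`/`_shared_step`. Note that this diff routes the target view through a `deepcopy` of the online arm under `torch.no_grad()`; the sketch below collapses that to a `detach()` on a single toy encoder, which is the classic SimSiam stop-gradient. The `Linear` "encoder" and "predictor" here are toy stand-ins, not the `SiameseArm`/`MLP` used above:

```python
# Minimal sketch of the symmetric stop-gradient loss, assuming toy stand-in
# networks in place of the pl_bolts SiameseArm and predictor MLP.
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Linear(32, 16)    # stands in for backbone + projector
predictor = nn.Linear(16, 16)  # stands in for the prediction head

def negative_cosine(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Stop-gradient on the target projection, as in calculate_loss() above
    return -0.5 * F.cosine_similarity(p, z.detach()).mean()

x1, x2 = torch.randn(8, 32), torch.randn(8, 32)  # two augmented views
z1, z2 = encoder(x1), encoder(x2)
p1, p2 = predictor(z1), predictor(z2)

# Symmetric total, mirroring loss_12 + loss_21 in _shared_step()
loss = negative_cosine(p1, z2) + negative_cosine(p2, z1)
loss.backward()
```

The two `-0.5`-weighted terms sum to the same total as the paper's `L = D(p1, z2)/2 + D(p2, z1)/2` formulation.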
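Likewise, `exclude_from_weight_decay` in `configure_optimizers` splits parameters into two optimizer groups so that 1-D tensors (biases, BatchNorm scales) receive no weight decay. A rough standalone illustration of that grouping, using a hypothetical two-layer model rather than the real network:

```python
# Rough sketch of the weight-decay exclusion pattern, assuming a toy model.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))  # hypothetical stand-in

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Biases and BatchNorm weights are 1-D tensors, so ndim == 1 catches both
    (no_decay if param.ndim == 1 else decay).append(param)

optimizer = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.05,
    momentum=0.9,
)
```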