Decouple DataModules from Models - CPCV2 #386

Merged: 2 commits, Nov 26, 2020
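In short: instead of the datamodule being passed into the model's constructor (and stored on the model), the datamodule now goes to `trainer.fit()`. A rough before/after sketch of the usage pattern (illustrative only, mirroring the updated test below; not part of the diff):

```python
import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.models.self_supervised.cpc import (
    CPCEvalTransformsCIFAR10,
    CPCTrainTransformsCIFAR10,
)

datamodule = CIFAR10DataModule('.')
datamodule.train_transforms = CPCTrainTransformsCIFAR10()
datamodule.val_transforms = CPCEvalTransformsCIFAR10()

# Before this PR: the model owned its data
#   model = CPCV2(datamodule=datamodule, ...)
#   trainer.fit(model)

# After this PR: data is handed to the Trainer, not the model
model = CPCV2(num_classes=datamodule.num_classes)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule)
```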
pl_bolts/models/self_supervised/cpc/cpc_module.py (51 changes: 16 additions & 35 deletions)
@@ -33,65 +33,52 @@ class CPCV2(pl.LightningModule):

     def __init__(
         self,
-        datamodule: Optional[pl.LightningDataModule] = None,
         encoder_name: str = 'cpc_encoder',
         patch_size: int = 8,
         patch_overlap: int = 4,
-        online_ft: int = True,
+        online_ft: bool = True,
         task: str = 'cpc',
         num_workers: int = 4,
-        learning_rate: int = 1e-4,
-        data_dir: str = '',
-        batch_size: int = 32,
+        num_classes: int = 10,
+        learning_rate: float = 1e-4,
         pretrained: Optional[str] = None,
         **kwargs,
     ):
"""
Args:
datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder,
or a custon nn.Module encoder
patch_size: How big to make the image patches
patch_overlap: How much overlap should each patch have.
online_ft: Enable a 1024-unit MLP to fine-tune online
patch_overlap: How much overlap each patch should have
online_ft: If True, enables a 1024-unit MLP to fine-tune online
task: Which self-supervised task to use ('cpc', 'amdim', etc...)
num_workers: num dataloader worksers
learning_rate: what learning rate to use
data_dir: where to store data
batch_size: batch size
num_workers: number of dataloader workers
num_classes: number of classes
learning_rate: learning rate
pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
"""

         super().__init__()
         self.save_hyperparameters()
 
-        # HACK - datamodule not pickleable so we remove it from hparams.
-        # TODO - remove datamodule from init. data should be decoupled from models.
-        del self.hparams['datamodule']
-
         self.online_evaluator = self.hparams.online_ft
 
         if pretrained:
             self.hparams.dataset = pretrained
             self.online_evaluator = True
 
-        assert datamodule
-        self.datamodule = datamodule
-
         self.encoder = self.init_encoder()
 
         # info nce loss
         c, h = self.__compute_final_nb_c(self.hparams.patch_size)
         self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)
 
         self.z_dim = c * h * h
-        self.num_classes = self.datamodule.num_classes
+        self.num_classes = num_classes
Contributor Author:
Since self.num_classes was already defined, I'm leaving it as is, but is this variable really necessary? As far as I understand the paper, it uses only images without any labels...

Contributor Author:
Is this variable self.num_classes useful/necessary for downstream tasks...?
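A possible answer to the question above: CPC pretraining itself never touches labels, but the online evaluator (and any downstream fine-tuning) needs `num_classes` to size its classification head. A minimal sketch of that idea, with made-up dimensions:

```python
import torch
from torch import nn

# Hypothetical linear probe: maps the z_dim representation produced by the
# encoder to num_classes logits; this is the only place labels enter.
z_dim, num_classes = 2048, 10
probe = nn.Linear(z_dim, num_classes)
logits = probe(torch.randn(4, z_dim))  # shape: (4, num_classes)
```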


         if pretrained:
             self.load_pretrained(self.hparams.encoder_name)
 
         print(self.hparams)
 
     def load_pretrained(self, encoder_name):
         available_weights = {'resnet18'}
@@ -212,19 +199,9 @@ def add_model_specific_args(parent_parser):
             'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
         ]
         parser.add_argument('--encoder', default='cpc_encoder', type=str, choices=possible_resnets)
 
-        # training params
-        parser.add_argument('--batch_size', type=int, default=128)
-
         # cifar10: 1e-5, stl10: 3e-5, imagenet: 4e-4
         parser.add_argument('--learning_rate', type=float, default=1e-5)
 
-        # data
-        parser.add_argument('--dataset', default='cifar10', type=str)
-        parser.add_argument('--data_dir', default='.', type=str)
-        parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
-        parser.add_argument('--num_workers', default=8, type=int)
-
         return parser


@@ -237,9 +214,13 @@ def cli_main():
     parser = ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
     parser = CPCV2.add_model_specific_args(parser)
+    parser.add_argument('--dataset', default='cifar10', type=str)
+    parser.add_argument('--data_dir', default='.', type=str)
+    parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
+    parser.add_argument('--num_workers', default=8, type=int)
+    parser.add_argument('--batch_size', type=int, default=128)
 
     args = parser.parse_args()
-    args.online_ft = True
 
     datamodule = None

@@ -276,9 +257,9 @@ def to_device(batch, device):
         datamodule.val_transforms = CPCEvalTransformsImageNet128()
         args.patch_size = 32
 
-    model = CPCV2(**vars(args), datamodule=datamodule)
+    model = CPCV2(**vars(args))
     trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
-    trainer.fit(model)
+    trainer.fit(model, datamodule)
 
 
 if __name__ == '__main__':
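For reference, the `online_evaluator` callback passed to the Trainer above is constructed in a collapsed part of `cli_main`; the wiring is roughly this (a sketch assuming pl_bolts' `SSLOnlineEvaluator` callback; exact constructor arguments may differ between versions):

```python
from pl_bolts.callbacks import SSLOnlineEvaluator
from pl_bolts.models.self_supervised import CPCV2

model = CPCV2(num_classes=10)  # stands in for the model built in cli_main
# The callback fits a small classifier on the model's representations during
# training; z_dim sizes its input and num_classes its output layer.
online_evaluator = SSLOnlineEvaluator(z_dim=model.z_dim, num_classes=model.num_classes)
```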
tests/models/self_supervised/test_models.py (4 changes: 2 additions & 2 deletions)
@@ -22,9 +22,9 @@ def test_cpcv2(tmpdir, datadir):
     datamodule.train_transforms = CPCTrainTransformsCIFAR10()
     datamodule.val_transforms = CPCEvalTransformsCIFAR10()
 
-    model = CPCV2(encoder='resnet18', data_dir=datadir, batch_size=2, online_ft=True, datamodule=datamodule)
+    model = CPCV2(encoder='resnet18', online_ft=True, num_classes=datamodule.num_classes)
     trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
-    trainer.fit(model)
+    trainer.fit(model, datamodule)
     loss = trainer.progress_bar_dict['val_nce']
 
     assert float(loss) > 0