Skip to content

Commit

Permalink
Fix forgotten rename of hparams.encoder in CPCv2 (#773)
Browse files Browse the repository at this point in the history
* Re-enable testing CPCv2
* Fix forgotten rename of the hparams attribute in the CPCv2 module

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Nov 29, 2021
1 parent 216fe24 commit 73945c6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/cpc/cpc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def forward(self, img_1):
Z = self.encoder(img_1)

# non cpc resnets return a list
if self.hparams.encoder != "cpc_encoder":
if self.hparams.encoder_name != "cpc_encoder":
Z = Z[0]

# (?) -> (b, -1, nb_feats, nb_feats)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ warn_no_return = "False"
# TODO: Fix typing for these modules
[[tool.mypy.overrides]]
module = [
"pl_bolts.callbacks.ssl_online",
"pl_bolts.datasets.*",
"pl_bolts.datamodules",
"pl_bolts.datamodules.experience_source",
Expand Down
9 changes: 6 additions & 3 deletions tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
from tests import _MARK_REQUIRE_GPU


# todo: seems to be failing on GH Actions for min config
@pytest.mark.skipif(**_MARK_REQUIRE_GPU)
@pytest.mark.skip(reason="RuntimeError: Given groups=1, weight of size [256, 2048, 1, 1], expected input[2, 1, 32, 32]")
def test_cpcv2(tmpdir, datadir):
datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
datamodule.train_transforms = CPCTrainTransformsCIFAR10()
Expand All @@ -30,7 +28,12 @@ def test_cpcv2(tmpdir, datadir):
online_ft=True,
num_classes=datamodule.num_classes,
)
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)

# FIXME: workaround for bug caused by
# https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
model.datamodule = datamodule

trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, gpus=1 if torch.cuda.device_count() > 0 else 0)
trainer.fit(model, datamodule=datamodule)


Expand Down

0 comments on commit 73945c6

Please sign in to comment.