Fix hparams.encoder forgotten rename in CPCv2 #773

Merged 10 commits on Nov 26, 2021
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -124,7 +124,7 @@ def forward(self, img_1):
         Z = self.encoder(img_1)

         # non cpc resnets return a list
-        if self.hparams.encoder != "cpc_encoder":
+        if self.hparams.encoder_name != "cpc_encoder":
             Z = Z[0]

         # (?) -> (b, -1, nb_feats, nb_feats)
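Context for the one-line fix: Lightning's save_hyperparameters() records the constructor arguments under self.hparams, keyed by argument name, so renaming the encoder argument to encoder_name silently renamed the hparams key and left this stale lookup behind. A minimal sketch of the mechanism (Toy is an illustrative module, not the real CPC_v2 signature):

    import pytorch_lightning as pl

    class Toy(pl.LightningModule):
        def __init__(self, encoder_name: str = "cpc_encoder"):
            super().__init__()
            # Records __init__ arguments under self.hparams, keyed by name:
            # rename the argument and you rename the key.
            self.save_hyperparameters()

    model = Toy()
    print(model.hparams.encoder_name)  # "cpc_encoder"
    # model.hparams.encoder would now raise AttributeError: that key no longer exists.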
9 changes: 6 additions & 3 deletions tests/models/self_supervised/test_models.py
@@ -15,9 +15,7 @@
 from tests import _MARK_REQUIRE_GPU


-# todo: seems to be failing on GH Actions for min config
 @pytest.mark.skipif(**_MARK_REQUIRE_GPU)
-@pytest.mark.skip(reason="RuntimeError: Given groups=1, weight of size [256, 2048, 1, 1], expected input[2, 1, 32, 32]")
 def test_cpcv2(tmpdir, datadir):
     datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
     datamodule.train_transforms = CPCTrainTransformsCIFAR10()
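For reference, _MARK_REQUIRE_GPU is unpacked straight into pytest.mark.skipif(**...), so it must be a mapping supplying skipif's condition/reason parameters. A presumed reconstruction (the actual definition lives in the repo's tests package and may differ):

    import torch

    # Presumed shape; consumed as @pytest.mark.skipif(**_MARK_REQUIRE_GPU).
    _MARK_REQUIRE_GPU = dict(
        condition=not torch.cuda.is_available(),
        reason="test requires GPU machine",
    )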
@@ -30,7 +28,12 @@ def test_cpcv2(tmpdir, datadir):
         online_ft=True,
         num_classes=datamodule.num_classes,
     )
-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
+
+    # FIXME: workaround for bug caused by
+    # https://github.com/PyTorchLightning/lightning-bolts/commit/2e903c333c37ea83394c7da2ce826de1b82fb356
+    model.datamodule = datamodule
+
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, gpus=1 if torch.cuda.device_count() > 0 else 0)
     trainer.fit(model, datamodule=datamodule)


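Why the FIXME workaround is needed: the linked commit removed the datamodule constructor argument from CPC_v2, but the module apparently still dereferences self.datamodule at train time, so the test re-attaches the attribute by hand before calling trainer.fit. A toy illustration of that failure mode (AfterCommit and shared_step are illustrative names, not CPC_v2's actual internals):

    class AfterCommit:
        """Toy stand-in: the ctor arg is gone, but training code still reads it."""

        def shared_step(self):
            return self.datamodule.num_classes  # AttributeError unless attached by hand

    m = AfterCommit()
    # m.shared_step()  -> AttributeError: 'AfterCommit' object has no attribute 'datamodule'
    m.datamodule = type("DM", (), {"num_classes": 10})()  # the manual re-attach, as in the test
    assert m.shared_step() == 10

The gpus=1 if torch.cuda.device_count() > 0 else 0 argument keeps the test runnable on CPU-only machines while using the GPU that the skipif marker otherwise expects; this relies on torch being imported at the top of the test module.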