From 8c559aba0b35867dfb0b79dc3ef6c572916e0f60 Mon Sep 17 00:00:00 2001 From: hirwa Date: Mon, 18 Mar 2024 16:07:08 +0530 Subject: [PATCH 1/2] move the pretrained model to the same device as model in unit test --- tests/tests_pytorch/models/test_restore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index c7a28f25d938a..1f5d36117fc11 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -465,6 +465,7 @@ def test_load_model_from_checkpoint(tmp_path, model_template): # Ensure that model can be correctly restored from checkpoint pretrained_model = model_template.load_from_checkpoint(last_checkpoint) + pretrained_model.to(model.device) # Move pretrained_model to the same device as model # test that hparams loaded correctly for k, v in model.hparams.items(): From c1b8265a5d9d650f64ac0e7d024fc7c365e1e10d Mon Sep 17 00:00:00 2001 From: hirwa Date: Mon, 18 Mar 2024 21:13:55 +0530 Subject: [PATCH 2/2] added cpu as traner accelerator --- tests/tests_pytorch/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 1f5d36117fc11..211dc49d42eda 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -446,6 +446,7 @@ def test_load_model_from_checkpoint(tmp_path, model_template): "limit_test_batches": 2, "callbacks": [ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=-1)], "default_root_dir": tmp_path, + "accelerator": "cpu", } # fit model @@ -465,7 +466,6 @@ def test_load_model_from_checkpoint(tmp_path, model_template): # Ensure that model can be correctly restored from checkpoint pretrained_model = model_template.load_from_checkpoint(last_checkpoint) - pretrained_model.to(model.device) # Move pretrained_model to the same device as model # test that hparams loaded correctly for k, v in model.hparams.items():