Skip to content

Commit

Permalink
test: replacing tmpdir with tmp_path in tests_pytorch/helpers (#…
Browse files Browse the repository at this point in the history
…19643)

* refactored tmpdir from the tests_pytorch/helpers dir
* Update tests/tests_pytorch/helpers/test_datasets.py

---------

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
  • Loading branch information
fnhirwa and awaelchli committed Mar 16, 2024
1 parent b9edd18 commit 41868ca
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/tests_pytorch/helpers/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
("dataset_cls", "args"),
[(MNIST, {"root": _PATH_DATASETS}), (TrialMNIST, {"root": _PATH_DATASETS}), (AverageDataset, {})],
)
def test_pickling_dataset_mnist(tmpdir, dataset_cls, args):
def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)

mnist_pickled = pickle.dumps(mnist)
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/helpers/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
pytest.param(RegressDataModule, RegressionModel, marks=RunIf(sklearn=True, onnx=True)),
],
)
def test_models(tmpdir, data_class, model_class):
def test_models(tmp_path, data_class, model_class):
"""Test simple models."""
dm = data_class() if data_class else data_class
model = model_class()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)

trainer.fit(model, datamodule=dm)

Expand All @@ -48,4 +48,4 @@ def test_models(tmpdir, data_class, model_class):

model.to_torchscript()
if data_class:
model.to_onnx(os.path.join(tmpdir, "my-model.onnx"), input_sample=dm.sample)
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)

0 comments on commit 41868ca

Please sign in to comment.