Test set

Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake. Testing is performed using the trainer object's .test() method.

The full signature is documented in the API reference for pytorch_lightning.trainer.Trainer.test.


Test after fit

To run the test set after training completes, call .test() on the trainer you used for fitting.

# run full training
trainer.fit(model)

# (1) load the best checkpoint automatically (lightning tracks this for you)
trainer.test()

# (2) don't load a checkpoint, instead use the model with the latest weights
trainer.test(ckpt_path=None)

# (3) test using a specific checkpoint
trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")

# (4) test with an explicit model (will use this model and not load a checkpoint)
trainer.test(model)
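The four call forms above differ only in which weights end up being evaluated. The sketch below encodes that table as a plain Python function, which makes the difference between calling test() with no arguments and passing ckpt_path=None explicitly easy to see. It is a hypothetical model of the behavior described in this section, not Lightning's implementation: the function name weights_for_test and the _DEFAULT sentinel are illustrative assumptions, and a combination this section does not describe (a model plus an explicit checkpoint path) is rejected rather than guessed.

```python
from typing import Any, Optional

# Hypothetical sentinel meaning "argument not passed"; it is deliberately distinct from None.
_DEFAULT: Any = object()


def weights_for_test(model: Optional[object] = None, ckpt_path: Any = _DEFAULT) -> str:
    """Describe which weights trainer.test() evaluates, per the four cases in this section.

    Hypothetical helper: it models the documented behavior and never calls Lightning.
    """
    if model is not None:
        if ckpt_path is not _DEFAULT:
            # Not one of the four documented cases, so refuse to guess.
            raise ValueError("model + explicit ckpt_path is not described in this section")
        return "weights of the model passed in"  # case (4)
    if ckpt_path is _DEFAULT:
        return "best checkpoint from fit"  # case (1)
    if ckpt_path is None:
        return "latest in-memory weights"  # case (2)
    return f"checkpoint at {ckpt_path}"  # case (3)


if __name__ == "__main__":
    print(weights_for_test())
    print(weights_for_test(ckpt_path=None))
    print(weights_for_test(ckpt_path="/path/to/my_checkpoint.ckpt"))
    print(weights_for_test(model=object()))
```

The sentinel is the key design choice here: it lets the sketch tell "no ckpt_path given" apart from "ckpt_path=None given", which is exactly the distinction between cases (1) and (2).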

Test multiple models

You can run the test set on multiple models using the same trainer instance.

model1 = LitModel()
model2 = GANModel()

trainer = Trainer()
trainer.test(model1)
trainer.test(model2)
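When you compare several models this way, it helps to collect their results in one place. The sketch below is a hypothetical helper, run_test_on_models, that reuses one trainer for a mapping of named models and keys the returned metrics by name. It only relies on the trainer having a test(model) method, and it assumes, as the Lightning Trainer.test docstring describes, that the return value is a list with one metrics dict per test dataloader. FakeTrainer is a stand-in so the sketch runs without Lightning installed; in real use you would pass your Trainer instance instead.

```python
from typing import Any, Dict, List, Mapping, Protocol

Metrics = List[Dict[str, float]]


class SupportsTest(Protocol):
    """Anything with a Trainer-like test(model) method."""

    def test(self, model: Any) -> Metrics: ...


def run_test_on_models(trainer: SupportsTest, models: Mapping[str, Any]) -> Dict[str, Metrics]:
    """Test every model with the same trainer and return the results keyed by model name."""
    return {name: trainer.test(model) for name, model in models.items()}


class FakeTrainer:
    """Stand-in for pytorch_lightning.Trainer: records each model and returns a dummy metric."""

    def __init__(self) -> None:
        self.tested: List[Any] = []

    def test(self, model: Any) -> Metrics:
        self.tested.append(model)
        return [{"test_loss": float(len(self.tested))}]


if __name__ == "__main__":
    trainer = FakeTrainer()
    print(run_test_on_models(trainer, {"lit": "model1", "gan": "model2"}))
```

Because the trainer is created once and only the models vary, every model is evaluated under the same Trainer options.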

Test pre-trained model

To run the test set on a pre-trained model, load the model from its checkpoint and pass it to .test().

model = MyLightningModule.load_from_checkpoint(
    checkpoint_path="/path/to/pytorch_checkpoint.ckpt",
    hparams_file="/path/to/test_tube/experiment/version/hparams.yaml",
    map_location=None,
)

# init trainer with whatever options
trainer = Trainer(...)

# test (pass in the model)
trainer.test(model)

In this case, the options you pass to the Trainer (e.g., 16-bit precision, DP, DDP) are used when running the test set.


Test with additional data loaders

You can still run inference on a test set even if the test_dataloader method hasn't been defined in your LightningModule. This is the case, for example, when your test data is not available at the time your model is declared.

from torch.utils.data import DataLoader

# setup your data loader
test_dataloader = DataLoader(...)

# test (pass in the loader)
trainer.test(dataloaders=test_dataloader)
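When test data arrives as several separate sets, it is convenient to keep each set's name next to its metrics. The sketch below is a hypothetical helper, evaluate_named_loaders, that passes the loaders to trainer.test() as a list and zips the results back to the names. It assumes, per the Lightning Trainer.test docstring, that the returned list holds one metrics dict per dataloader in the order given. FakeTrainer and the plain lists are stand-ins for a real Trainer and real DataLoaders so the sketch runs on its own.

```python
from typing import Any, Dict, List, Mapping

Metrics = Dict[str, float]


def evaluate_named_loaders(trainer: Any, model: Any, loaders: Mapping[str, Any]) -> Dict[str, Metrics]:
    """Run trainer.test on several loaders in one call and key the results by name."""
    names = list(loaders)
    results: List[Metrics] = trainer.test(model, dataloaders=[loaders[n] for n in names])
    # Guard the assumption that there is exactly one result dict per loader.
    if len(results) != len(names):
        raise RuntimeError(f"expected {len(names)} result dicts, got {len(results)}")
    return dict(zip(names, results))


class FakeTrainer:
    """Stand-in for pytorch_lightning.Trainer: one result dict per loader, in order."""

    def test(self, model: Any, dataloaders: List[Any]) -> List[Metrics]:
        return [{"num_batches": float(len(loader))} for loader in dataloaders]


if __name__ == "__main__":
    loaders = {"clean": [1, 2, 3], "noisy": [4, 5]}  # lists stand in for DataLoaders
    print(evaluate_named_loaders(FakeTrainer(), model=None, loaders=loaders))
```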

You can pass in either a single dataloader or a list of them. This optional named parameter can be used with any of the use cases above. You can also pass in a datamodule that overrides the test_dataloader method.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MyDataModule(pl.LightningDataModule):
    ...

    def test_dataloader(self):
        return DataLoader(...)


# setup your datamodule
dm = MyDataModule(...)

# test (pass in datamodule)
trainer.test(datamodule=dm)
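This section names three places test data can come from: loaders passed to test(), a datamodule's test_dataloader(), or the LightningModule's own test_dataloader(). The framework-free sketch below spells out one way to choose between them. resolve_test_loaders is a hypothetical helper, and two details are assumptions rather than statements from this section: that dataloaders and datamodule are mutually exclusive (as far as we know, Lightning raises an error if you pass both), and that a datamodule takes precedence over the model's own hook. FakeDataModule and FakeModel are stand-ins for real Lightning classes.

```python
from typing import Any, List, Optional


def resolve_test_loaders(model: Any, dataloaders: Any = None, datamodule: Optional[Any] = None) -> List[Any]:
    """Return the test loaders as a list. Hypothetical helper; the precedence is an assumption."""
    if dataloaders is not None and datamodule is not None:
        raise ValueError("pass either dataloaders or datamodule, not both")
    if dataloaders is not None:
        source = dataloaders  # explicit loaders win
    elif datamodule is not None:
        source = datamodule.test_dataloader()  # then the datamodule's hook
    elif callable(getattr(model, "test_dataloader", None)):
        source = model.test_dataloader()  # finally the model's own hook
    else:
        raise ValueError("no test data: pass dataloaders or a datamodule, or define test_dataloader")
    # Accept a single loader or a list/tuple of loaders, as the section describes.
    return list(source) if isinstance(source, (list, tuple)) else [source]


class FakeDataModule:
    """Stand-in for a LightningDataModule that overrides test_dataloader()."""

    def test_dataloader(self) -> str:
        return "dm_loader"


class FakeModel:
    """Stand-in for a LightningModule that defines test_dataloader()."""

    def test_dataloader(self) -> str:
        return "model_loader"


if __name__ == "__main__":
    print(resolve_test_loaders(FakeModel()))
    print(resolve_test_loaders(FakeModel(), datamodule=FakeDataModule()))
    print(resolve_test_loaders(FakeModel(), dataloaders=["a", "b"]))
```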