Trainer.test's dataloader argument can't replace pre-defined dataloader #1754

Closed
davinnovation opened this issue May 7, 2020 · 2 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on)


@davinnovation
Contributor

🐛 Bug

Trainer.test accepts a test_dataloaders argument. However, if the trainer already holds a test dataloader (for example, from an earlier Trainer.test call), passing a different dataloader to Trainer.test has no effect: the previously loaded dataloader is used instead.

To Reproduce

    # ... train the model with the trainer first

    trainer.test(model, dataloader1)  # runs on dataloader1
    trainer.test(model, dataloader2)  # expected to run on dataloader2, but dataloader1 is used again

Environment

Python 3.6
pytorch-lightning 0.7.5

Additional context

trainer.test calls run_evaluation, which reloads the test dataloaders (via reset_test_dataloader) only when trainer.test_dataloaders is None:

        # select dataloaders
        if test_mode:
            if self.test_dataloaders is None:
                self.reset_test_dataloader(model)
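The effect of that guard can be illustrated with a minimal sketch. ToyModel and ToyTrainer are hypothetical, heavily simplified stand-ins (not Lightning classes), and plain lists stand in for DataLoader objects; the only behavior copied from the excerpt is "reload only when the cached test_dataloaders is None".

```python
# Hypothetical, heavily simplified model of the trainer's caching logic.
# Plain lists stand in for torch DataLoader objects.

class ToyModel:
    def __init__(self, loader):
        self._loader = loader

    def test_dataloader(self):
        return self._loader


class ToyTrainer:
    def __init__(self):
        self.test_dataloaders = None  # cache, filled on the first evaluation

    def reset_test_dataloader(self, model):
        # Ask the model for its loader, mirroring reset_test_dataloader(model)
        self.test_dataloaders = [model.test_dataloader()]

    def run_evaluation(self, model, test_mode=True):
        # Same guard as the excerpt: only reload when nothing is cached yet
        if test_mode and self.test_dataloaders is None:
            self.reset_test_dataloader(model)
        return self.test_dataloaders[0]


trainer = ToyTrainer()
model = ToyModel(["batch-from-loader-1"])
first = trainer.run_evaluation(model)

model._loader = ["batch-from-loader-2"]  # the model now offers a different loader
second = trainer.run_evaluation(model)

print(first)   # ['batch-from-loader-1']
print(second)  # ['batch-from-loader-1']; the cached loader is reused
```

Because the cache is never cleared, the second evaluation never asks the model again.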

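If a test_dataloaders argument is given to Trainer.test, it is passed to __attach_dataloaders. But that method only patches the model's test_dataloader method; it never clears trainer.test_dataloaders, so once that attribute is set, run_evaluation does not reload it and the new dataloader is never used: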

    def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
        # when dataloader is passed via fit, patch the train_dataloader
        # functions to overwrite with these implementations
        if train_dataloader is not None:
            model.train_dataloader = _PatchDataLoader(train_dataloader)

        if val_dataloaders is not None:
            model.val_dataloader = _PatchDataLoader(val_dataloaders)

        if test_dataloaders is not None:
            model.test_dataloader = _PatchDataLoader(test_dataloaders)
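The patching step can be sketched in isolation. PatchDataLoader below is a hypothetical simplified stand-in for _PatchDataLoader (assumed here to be a callable that returns the dataloader it wraps), and lists stand in for DataLoader objects. It shows that patching changes what the model returns, while anything the trainer cached earlier stays as it was.

```python
# Hypothetical simplified stand-in for _PatchDataLoader: a callable that
# always returns the dataloader it was constructed with.

class PatchDataLoader:
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


class ToyModel:
    def test_dataloader(self):
        return ["original-loader"]


model = ToyModel()
cached = [model.test_dataloader()]  # what a trainer might have cached earlier

# Assigning on the instance shadows the class method for this model only
model.test_dataloader = PatchDataLoader(["new-loader"])

print(model.test_dataloader())  # ['new-loader']
print(cached[0])                # ['original-loader']; the cache is untouched
```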

So the code in Trainer.test should be changed from

        self.testing = True

        if test_dataloaders is not None:
            if model:
                self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
            else:
                self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

to

        self.testing = True

        if test_dataloaders is not None:
            self.test_dataloaders = None
            if model:
                self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
            else:
                self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
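The proposed change can be checked with a self-contained toy. ToyTrainer, ToyModel and PatchDataLoader are hypothetical simplifications (not the real Lightning classes), with lists standing in for DataLoader objects; the one line taken from the proposal is resetting test_dataloaders to None before patching.

```python
# Hypothetical toy version of Trainer.test with the proposed fix:
# clear the cached test_dataloaders whenever new ones are passed in.

class PatchDataLoader:
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


class ToyModel:
    def test_dataloader(self):
        return ["default-loader"]


class ToyTrainer:
    def __init__(self):
        self.test_dataloaders = None

    def run_evaluation(self, model):
        # Reload from the model only when nothing is cached
        if self.test_dataloaders is None:
            self.test_dataloaders = [model.test_dataloader()]
        return self.test_dataloaders[0]

    def test(self, model, test_dataloaders=None):
        if test_dataloaders is not None:
            self.test_dataloaders = None  # the proposed fix: drop the stale cache
            model.test_dataloader = PatchDataLoader(test_dataloaders)
        return self.run_evaluation(model)


trainer = ToyTrainer()
model = ToyModel()
print(trainer.test(model, ["loader-1"]))  # ['loader-1']
print(trainer.test(model, ["loader-2"]))  # ['loader-2']
print(trainer.test(model))                # ['loader-2']; no argument keeps the last one
```

Resetting the cache only when a dataloader is actually passed keeps the existing behavior for calls without an argument.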

davinnovation added the bug and help wanted labels on May 7, 2020.
@Borda
Member

Borda commented May 11, 2020

Mind sending a fix as a PR?

@davinnovation
Contributor Author

Fixed by #1858.
