Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix _reset_eval_dataloader() for IterableDataset #1560

Merged
merged 3 commits into from May 5, 2020

Conversation

ybrovman
Copy link
Contributor

@ybrovman ybrovman commented Apr 22, 2020

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

I encountered the following error TypeError: object of type 'MyIterableDataset' has no len() from line 190 in _reset_eval_dataloader() in data_loading.py file when using an IterableDataset for the validation dataset.

The if dl caused the issue. if dl is equivalent to bool(dl) = dataloader.__bool__, but there is no dataloader.__bool__ so bool() uses dataloader.__len__ > 0. But... dataloader.__len__ uses IterableDataset.__len__ for IterableDatasets for which __len__ is undefined.

Resolving the issue by comparing to None, if dl is not None, in _reset_eval_dataloader().

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

👍

@mergify mergify bot requested a review from a team April 22, 2020 15:27
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
@mergify mergify bot requested a review from a team April 22, 2020 21:38
@pep8speaks
Copy link

pep8speaks commented Apr 23, 2020

Hello @ybrovman! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-05-05 15:48:37 UTC

@Borda Borda requested review from tullie and Borda April 23, 2020 11:31
@Borda Borda added the bug Something isn't working label Apr 23, 2020
@Borda Borda added this to the 0.7.4 milestone Apr 23, 2020
@mergify mergify bot requested a review from a team April 23, 2020 11:34
@mergify mergify bot requested a review from a team April 24, 2020 07:55
@Borda Borda modified the milestones: 0.7.4, 0.7.5 Apr 24, 2020
Copy link
Member

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, but I don't understand why the test test_inf_val_dataloader passes if this bug exists. How do I reproduce this actually? The description of this PR makes it look like this should fail with any IterableDataset, but I don't get this error.

@mergify mergify bot requested a review from a team April 25, 2020 23:21
@ybrovman
Copy link
Contributor Author

Good catch, but I don't understand why the test test_inf_val_dataloader passes if this bug exists. How do I reproduce this actually? The description of this PR makes it look like this should fail with any IterableDataset, but I don't get this error.

I am not too familiar with the testing details here, however, I think the issue might be with the CustomInfDataloader. Perhaps since if dl = bool(dl) and CustomInfDataloader.__bool__ or CustomInfDataloader.__len__ does not exist, bool(CustomInfDataloader) is always True, so the test_inf_val_dataloader passes.

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 🤖

@mergify mergify bot requested a review from a team May 5, 2020 15:32
@Borda
Copy link
Member

Borda commented May 5, 2020

@ybrovman mind add a test for this iter dataset?

@Borda Borda added the ready PRs ready to be merged label May 5, 2020
@codecov
Copy link

codecov bot commented May 5, 2020

Codecov Report

Merging #1560 into master will not change coverage.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #1560   +/-   ##
======================================
  Coverage      88%     88%           
======================================
  Files          69      69           
  Lines        4151    4151           
======================================
  Hits         3661    3661           
  Misses        490     490           

Copy link
Member

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case is dl actually None? Have you considered removing the if statement completely?

@mergify mergify bot requested a review from a team May 5, 2020 17:22
@ybrovman
Copy link
Contributor Author

ybrovman commented May 5, 2020

In which case is dl actually None? Have you considered removing the if statement completely?

@awaelchli I did remove the if statement in the original commit, however, added back the None check after @tullie 's comment.

@tullie
Copy link
Contributor

tullie commented May 5, 2020

I tried to look at cases where it could be None and it's hard to track down exactly. However, i'm fairly sure it's only None if the val_dataloader or test_dataloader returns None.

@williamFalcon williamFalcon merged commit 35bbe17 into Lightning-AI:master May 5, 2020
@Borda
Copy link
Member

Borda commented May 5, 2020

I tried to look at cases where it could be None and it's hard to track down exactly. However, i'm fairly sure it's only None if the val_dataloader or test_dataloader returns None.

shall we raise a warning if any dataloader is none or just skip it...

@tullie
Copy link
Contributor

tullie commented May 5, 2020

Yeah ideally we should raise a warning but not the biggest issue.

@Borda
Copy link
Member

Borda commented May 5, 2020

Yeah ideally we should raise a warning but not the biggest issue.

so just a warning to logging or standard Runtime warning... and just once and then remove the None dataloader from the list

@ybrovman mind send a followup PR?

@ybrovman ybrovman mentioned this pull request May 5, 2020
5 tasks
@ybrovman
Copy link
Contributor Author

ybrovman commented May 5, 2020

@Borda I created PR #1745 to address your comment.

SiddhantRanade added a commit to SiddhantRanade/pytorch-lightning that referenced this pull request Aug 13, 2020
This function has the if statement `if (train_dataloader or val_dataloaders) and datamodule:`.


The issue is similar to that in Lightning-AI#1560. The problem is that the `if(dl)` translates to `if(bool(dl))`, but there's no dataloader.__bool__ so bool() uses dataloader.__len__ > 0. But... dataloader.__len__ uses IterableDataset.__len__ for IterableDatasets for which __len__ is undefined.

The fix is also the same, the `if dl` should be replaced by `if dl is not None`.
Borda pushed a commit that referenced this pull request Aug 13, 2020
…2957)

This function has the if statement `if (train_dataloader or val_dataloaders) and datamodule:`.


The issue is similar to that in #1560. The problem is that the `if(dl)` translates to `if(bool(dl))`, but there's no dataloader.__bool__ so bool() uses dataloader.__len__ > 0. But... dataloader.__len__ uses IterableDataset.__len__ for IterableDatasets for which __len__ is undefined.

The fix is also the same, the `if dl` should be replaced by `if dl is not None`.

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
ameliatqy pushed a commit to ameliatqy/pytorch-lightning that referenced this pull request Aug 17, 2020
…ightning-AI#2957)

This function has the if statement `if (train_dataloader or val_dataloaders) and datamodule:`.


The issue is similar to that in Lightning-AI#1560. The problem is that the `if(dl)` translates to `if(bool(dl))`, but there's no dataloader.__bool__ so bool() uses dataloader.__len__ > 0. But... dataloader.__len__ uses IterableDataset.__len__ for IterableDatasets for which __len__ is undefined.

The fix is also the same, the `if dl` should be replaced by `if dl is not None`.

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

IterableDataset does not work in validation
7 participants