CombinedLoader: NoneType object is not iterable #16912

Closed
awaelchli opened this issue Mar 1, 2023 · 4 comments · Fixed by #17007
Labels
data handling Generic data-related topic · feature Is an improvement or enhancement

Comments

@awaelchli
Member

awaelchli commented Mar 1, 2023

Bug description

When the LightningModule.val_dataloader method returns None, two unexpected behaviors take place:

  1. A warning is shown that shouldn't appear:
     UserWarning: Total length of `NoneType` across ranks is zero. Please make sure this was your intention.
  2. An error is raised, because the loop tries to run validation with None as the dataloader:
     TypeError: 'NoneType' object is not iterable

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return None

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()

Error messages and logs

  File "/Users/adrian/repositories/lightning/examples/pl_bug_report/bug_report_model.py", line 62, in <module>
    run()
  File "/Users/adrian/repositories/lightning/examples/pl_bug_report/bug_report_model.py", line 58, in run
    trainer.fit(model)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/trainer.py", line 517, in fit
    call._call_and_handle_interrupt(
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/trainer.py", line 556, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/trainer.py", line 928, in _run
    results = self._run_stage()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/trainer.py", line 967, in _run_stage
    self._run_train()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/trainer.py", line 988, in _run_train
    self.fit_loop.run()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/fit_loop.py", line 192, in run
    self.advance()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/fit_loop.py", line 365, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/training_epoch_loop.py", line 134, in run
    self.on_advance_end()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/training_epoch_loop.py", line 248, in on_advance_end
    self.val_loop.run()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/utilities.py", line 167, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/evaluation_loop.py", line 93, in run
    self.reset()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/evaluation_loop.py", line 184, in reset
    iter(data_fetcher)  # creates the iterator inside the fetcher
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/fetchers.py", line 104, in __iter__
    super().__iter__()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/loops/fetchers.py", line 54, in __iter__
    self.dataloader_iter = iter(self.dataloader)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/utilities/combined_loader.py", line 242, in __iter__
    iter(iterator)
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/utilities/combined_loader.py", line 121, in __iter__
    super().__iter__()
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/utilities/combined_loader.py", line 35, in __iter__
    self.iterators = [iter(iterable) for iterable in self.iterables]
  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/utilities/combined_loader.py", line 35, in <listcomp>
    self.iterators = [iter(iterable) for iterable in self.iterables]
TypeError: 'NoneType' object is not iterable

Environment

- PyTorch Lightning Version: 2.0.0rc0
- PyTorch Version: 1.13.1
- Python version: 3.10
- OS: macOS

More info

No response

cc @justusschock @awaelchli @Borda

@awaelchli awaelchli added the bug, needs triage, and data handling labels and removed the needs triage label Mar 1, 2023
@awaelchli awaelchli added this to the v1.9.x milestone Mar 1, 2023
@awaelchli awaelchli added the breaking change label Mar 1, 2023
@awaelchli
Member Author

@carmocca Do you agree that this was an unintended change from 1.9.x?
I believe the use case is valid: you may want to conditionally return a val loader from a DataModule or LightningModule depending on whether validation data is available.
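
A minimal sketch of that use case, reusing RandomDataset and the imports from the repro above (the has_val_data flag is hypothetical):

def val_dataloader(self):
    # Hypothetical flag: only build a loader when validation data actually exists.
    if self.has_val_data:
        return DataLoader(RandomDataset(32, 64), batch_size=2)
    return None  # currently triggers the warning and TypeError described above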

@carmocca carmocca modified the milestones: v1.9.x, 2.0 Mar 1, 2023
@carmocca
Member

carmocca commented Mar 1, 2023

This breaking change was intended: #16800 (comment)
It is only in master.

Skipping None could be error-prone and lead to silent bugs. If you have no data, you'd return an empty iterable instead.
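
For reference, a minimal sketch of that suggestion, reusing RandomDataset from the repro above (an empty dataset is just one way to produce an empty iterable):

def val_dataloader(self):
    # Return an iterable that yields zero batches instead of None,
    # so the validation loop simply runs no steps.
    return DataLoader(RandomDataset(32, 0), batch_size=2)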

@awaelchli
Member Author

And what is the argument against converting the original warning to an error? Don't you think it is strange to get the warning "Total length of NoneType across ranks is zero"? It doesn't mention that this is about a dataloader, so it is unclear to the user what should be done.

@carmocca
Member

carmocca commented Mar 1, 2023

Yes. I'll add error checks in the loops for this.
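
Purely as an illustration of the kind of check being proposed (not the actual implementation that landed in #17007; the helper name and message are assumptions):

def _check_dataloader_is_not_none(dataloader, hook_name):
    # Fail early with an actionable message instead of the opaque TypeError above.
    if dataloader is None:
        raise TypeError(
            f"`{hook_name}` returned None. Return an iterable (e.g. a DataLoader), "
            "or an empty iterable if there is no data for this stage."
        )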

@carmocca carmocca self-assigned this Mar 1, 2023
@carmocca carmocca added the feature label and removed the bug and breaking change labels Mar 1, 2023