
Skip Sanity Check if Validation Dataloader is None #15703

Closed
CCInc opened this issue Nov 16, 2022 · 2 comments
Labels
bug Something isn't working

Comments

@CCInc

CCInc commented Nov 16, 2022

Bug description

I am making a generic base DataModule for all of my other DataModules to inherit from. The base datamodule always defines the train, test, and validation dataloader functions to abstract away their construction, but a concrete datamodule may choose not to implement a particular split.

For val_dataloader in particular, if it is defined but returns None, I would expect the sanity check to be skipped; instead, training carries on and causes an error in the progress bar callback.

How to reproduce the bug

from typing import Optional

from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule

class BaseDataModule(LightningDataModule):
    def __init__(
        self,
        config: DataModuleConfig = DataModuleConfig(),
    ):
        super().__init__()

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
        )

    def val_dataloader(self):
        if self.data_val:
            return DataLoader(
                dataset=self.data_val,
            )
        else:
            return None

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
        )

class InheritedDataModule(BaseDataModule):
    def setup(self, stage: Optional[str] = None):
        if not self.data_train and not self.data_test:
            self.data_train = MyDataset()
            self.data_test = MyDataset()

Error messages and logs


/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:366: UserWarning: One of given dataloaders is
None and it will be skipped.
  rank_zero_warn("One of given dataloaders is None and it will be skipped.")

[2022-11-16 11:30:38,255][src.utils.utils][ERROR] - 
Traceback (most recent call last):
  File "/home/chris/lightning-hydra-template/src/utils/utils.py", line 40, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
  File "/home/chris/lightning-hydra-template/src/train.py", line 98, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 582, in fit
    call._call_and_handle_interrupt(
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 624, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _run_stage
    self._run_train()
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1153, in _run_train
    self._run_sanity_check()
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1227, in _run_sanity_check
    self._call_callback_hooks("on_sanity_check_end")
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1343, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/chris/miniconda3/envs/pyl/lib/python3.9/site-packages/pytorch_lightning/callbacks/progress/rich_progress.py", line 358, in on_sanity_check_end
    assert self.val_sanity_progress_bar_id is not None
AssertionError

Environment

No response

More info

No response

@CCInc CCInc added the needs triage Waiting to be triaged by maintainers label Nov 16, 2022
@awaelchli
Member

The following code reproduces the issue:

import os
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BaseDataModule(LightningDataModule):
    def __init__(
        self,
    ):
        super().__init__()

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
        )

    def val_dataloader(self):
        if self.data_val:
            return DataLoader(
                dataset=self.data_val,
            )
        else:
            return None

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
        )


class InheritedDataModule(BaseDataModule):
    def setup(self, stage: Optional[str] = None):
        if not self.data_train and not self.data_test:
            self.data_train = RandomDataset(32, 64)
            self.data_test = RandomDataset(32, 64)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, InheritedDataModule())
    trainer.test(model, InheritedDataModule())


if __name__ == "__main__":
    run()

Lightning 1.8.1

@awaelchli awaelchli added bug Something isn't working progress bar: rich and removed needs triage Waiting to be triaged by maintainers labels Nov 19, 2022
@awaelchli awaelchli added this to the v1.8.x milestone Nov 19, 2022
@Borda Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
@Borda Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023
@awaelchli
Member

In recent versions of Lightning, this now raises an error:

TypeError: An invalid dataloader was returned from `InheritedDataModule.val_dataloader()`. Found None.

To the best of my knowledge, this is intentional: it prevents the user from accidentally returning None, which could be hard to debug.

The desired behavior requested here can still be achieved by simply returning an empty iterable:

def val_dataloader(self):
    if self.data_val:
        return DataLoader(
            dataset=self.data_val,
        )
    else:
        return []  # <--- return empty iterable here

This will effectively skip any validation.
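As a minimal illustration of the guard above, here is a sketch using plain Python stand-ins instead of `DataLoader` and `LightningDataModule` (the class and names below are hypothetical, chosen only to show the pattern without a Lightning dependency):

```python
from typing import Optional

class FakeDataModule:
    """Hypothetical stand-in for a LightningDataModule; shows the guard only."""

    def __init__(self, data_val: Optional[list] = None):
        self.data_val = data_val

    def val_dataloader(self):
        # Return the "dataloader" when validation data exists,
        # otherwise an empty iterable so there is nothing to iterate over.
        if self.data_val:
            return list(self.data_val)
        return []

# With no validation data, iteration yields zero batches, which is what
# lets the Trainer skip the sanity check and the validation loop.
empty = FakeDataModule().val_dataloader()
full = FakeDataModule(data_val=[1, 2, 3]).val_dataloader()
print(len(empty), len(full))  # 0 3
```

The same shape applies to the Lightning datamodule in the snippet above: the only change from the original bug report is returning `[]` instead of `None`.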

@awaelchli awaelchli removed this from the v1.9.x milestone Dec 31, 2023