Skip Sanity Check if Validation Dataloader is None #15703
Labels: bug (Something isn't working)

Bug description

I am making a generic base DataModule for all of my other DataModules to inherit from. The base DataModule always defines the train, validation, and test dataloader methods to abstract away their construction, but a concrete DataModule may choose not to implement a particular split.
For val_dataloader in particular, if it is defined but returns None, I would expect the sanity check to be skipped, but instead it carries on and causes an error in the progress bar callback.

How to reproduce the bug
The following code reproduces the issue:

import os
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BaseDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def train_dataloader(self):
        return DataLoader(dataset=self.data_train)

    def val_dataloader(self):
        # Returns None when no validation split is configured.
        if self.data_val:
            return DataLoader(dataset=self.data_val)
        else:
            return None

    def test_dataloader(self):
        return DataLoader(dataset=self.data_test)


class InheritedDataModule(BaseDataModule):
    def setup(self, stage: Optional[str] = None):
        # Only the train and test splits are set up; data_val stays None.
        if not self.data_train and not self.data_test:
            self.data_train = RandomDataset(32, 64)
            self.data_test = RandomDataset(32, 64)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, InheritedDataModule())
    trainer.test(model, InheritedDataModule())


if __name__ == "__main__":
    run()

Environment

Lightning 1.8.1
Comments

In recent versions of Lightning, returning None from val_dataloader raises an error.
To the best of my knowledge, this is intentional: it avoids silently accepting an accidental None, which could be hard to debug. The desired behavior asked for here can still be achieved by simply returning an empty iterable:

def val_dataloader(self):
    if self.data_val:
        return DataLoader(dataset=self.data_val)
    else:
        return []  # <--- return an empty iterable here

This will effectively skip any validation.
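As a related aside (not suggested in this thread, but a standard Trainer option), the pre-fit sanity check on its own can be disabled with num_sanity_val_steps=0. A minimal sketch reusing the model and datamodule from the reproduction above:

# Sketch: suppress only the sanity validation run before training starts;
# validation during training is still governed by the dataloaders returned.
trainer = Trainer(
    default_root_dir=os.getcwd(),
    num_sanity_val_steps=0,  # skip the pre-training sanity check entirely
    max_epochs=1,
)
trainer.fit(BoringModel(), InheritedDataModule())

Note that this only suppresses the sanity check; returning an empty iterable as shown above is what actually skips the validation loop itself.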