Description
Describe the bug
If the dataset size is just one batch, then line 372:
self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval)
in trainer.py evaluates to 0, because nb_tng_batches is 1, val_check_interval is 0.98 by default, and int(0.98) truncates to 0.
When the trainer then gets to the validation step on line 852:
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
The error ZeroDivisionError: integer division or modulo by zero is raised.
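The arithmetic can be checked without Lightning installed. The following is a minimal sketch in which plain variables stand in for the Trainer attributes quoted above; the helper name is_val_check_batch is made up for illustration and only mirrors the expression from line 852.

```python
# Plain variables stand in for the Trainer attributes quoted in this issue.
nb_tng_batches = 1         # the whole dataset fits in a single batch
val_check_interval = 0.98  # default value reported in this issue

# Same expression as line 372: int() truncates 0.98 down to 0.
val_check_batch = int(nb_tng_batches * val_check_interval)


def is_val_check_batch(batch_nb, val_check_batch):
    """Hypothetical helper mirroring the check on line 852."""
    return (batch_nb + 1) % val_check_batch == 0


# Capture the exception instead of crashing, so the sketch runs to completion.
try:
    is_val_check_batch(0, val_check_batch)
    error = None
except ZeroDivisionError as exc:
    error = exc
```

After running this, val_check_batch is 0 and error holds the ZeroDivisionError, which matches the failure described above.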
To Reproduce
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl


class CoolModel(pl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def configure_optimizers(self):
        # REQUIRED
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        subset = torch.utils.data.Subset(dataset, range(32))  # Dataset size = 1 batch.
        return DataLoader(subset, batch_size=32)


from pytorch_lightning import Trainer

model = CoolModel()
trainer = Trainer()
trainer.fit(model)
Possible solutions
A possible fix is to do the same as on line 364 and add
self.val_check_batch = max(1, self.val_check_batch) after line 372.
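As a rough sketch of the proposed clamp, written as a standalone function rather than as a patch to trainer.py: the name compute_val_check_batch is hypothetical and not part of the library, and the body only combines the expression from line 372 with the suggested max(1, ...).

```python
def compute_val_check_batch(nb_tng_batches, val_check_interval):
    """Hypothetical helper: line 372's expression plus the proposed clamp."""
    val_check_batch = int(nb_tng_batches * val_check_interval)
    # Never return 0, so the later modulo check cannot divide by zero.
    return max(1, val_check_batch)


# Single-batch dataset with the default interval from this issue.
single_batch = compute_val_check_batch(1, 0.98)

# The modulo check from line 852 is now safe for the only batch (batch_nb = 0).
is_val_check_batch = (0 + 1) % single_batch == 0
```

With the clamp in place, a one-batch dataset validates after its only batch instead of crashing, and larger datasets are unaffected because their product is already at least 1.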
Additional context
The default of val_check_interval=0.98 may itself be a mistake.
A default of 1.0 makes more sense, since it is more common to go through all of the training data before validating.
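To illustrate the difference between the two defaults, here is a small sketch that lists which batch indices would trigger validation in one epoch. The function val_check_positions is made up for this example; it reuses the two expressions quoted from trainer.py and assumes the product is at least 1.

```python
def val_check_positions(nb_tng_batches, val_check_interval):
    """Hypothetical helper: batch indices at which validation would run."""
    # Same expression as line 372; assumes the result is at least 1.
    val_check_batch = int(nb_tng_batches * val_check_interval)
    return [
        batch_nb
        for batch_nb in range(nb_tng_batches)
        if (batch_nb + 1) % val_check_batch == 0  # check from line 852
    ]


with_current_default = val_check_positions(10, 0.98)  # int(9.8) == 9
with_proposed_default = val_check_positions(10, 1.0)  # int(10.0) == 10
one_batch_proposed = val_check_positions(1, 1.0)      # int(1.0) == 1
```

For 10 batches, 0.98 triggers validation at batch index 8, one batch before the end of the epoch, while 1.0 triggers it at index 9, after all training data has been seen. With 1.0 the single-batch case also works, since int(1 * 1.0) is 1.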