
[Bug]: Late update of Trainer current_epoch property for LightningDataModule #3974

Closed
nrupatunga opened this issue Oct 8, 2020 · 3 comments · Fixed by #3975
Labels: bug (Something isn't working) · help wanted (Open to be worked on)

Comments

@nrupatunga (Contributor) commented:

🐛 Bug

The Trainer's current_epoch property is updated late, so a LightningDataModule object reads a stale value from it.

To Reproduce

The code below reproduces the issue. Check the printed current_epoch value in train_dataloader.

Code sample

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # We take the input dimensions as parameters and use them to build the model dynamically.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes))

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        print('\n----------------------------')
        print(f'Current epoch: {self.trainer.current_epoch}')
        print('----------------------------')
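        # NOTE: both branches below return the same DataLoader; the conditional
        # only marks where an epoch-dependent loader would go.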
        if self.trainer.current_epoch > 2:
            return DataLoader(self.mnist_train, batch_size=32)
        else:
            return DataLoader(self.mnist_train, batch_size=32)


# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1, reload_dataloaders_every_epoch=True)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)

Logs

Note: current_epoch is 0 twice, which is due to the late update of this property.

----------------------------
Current epoch: 0
----------------------------
/home/nthere/2020/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 0:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1700/1719 [00:10<00:00, 157.72it/s, loss=0.284, v_num=16]
----------------------------
Current epoch: 0
----------------------------
Epoch 1:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1700/1719 [00:10<00:00, 164.23it/s, loss=0.200, v_num=16]
----------------------------
Current epoch: 1
----------------------------
Epoch 2:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1700/1719 [00:10<00:00, 164.02it/s, loss=0.168, v_num=16]
----------------------------
Current epoch: 2
----------------------------
Epoch 3:  10%|████████████████▊                                                                                                                                                | 180/1719 [00:01<00:09, 163.17it/s, loss=0.185, v_num=16]
^C/home/nthere/2020/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...
  warnings.warn(*args, **kwargs)
Epoch 3:  10%|████████████████▊                                                                                                                                                | 180/1719 [00:01<00:10, 150.43it/s, loss=0.185, v_num=16]

Expected behavior

current_epoch should already be updated when train_dataloader is called, so the LightningDataModule sees the correct epoch number.

Environment

Output from the environment collection script:

* CUDA:
        - GPU:
                - GeForce GTX 960
        - available:         True
        - version:           10.1
* Packages:
        - numpy:             1.18.5
        - pyTorch_debug:     False
        - pyTorch_version:   1.5.0+cu101
        - pytorch-lightning: 0.10.0
        - tqdm:              4.47.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.3
        - version:           #113~16.04.1-Ubuntu SMP Fri Jul 10 04:37:08 UTC 2020

@nrupatunga added the bug and help wanted labels on Oct 8, 2020
@github-actions (bot) commented on Oct 8, 2020:

Hi! Thanks for your contribution, great first issue!

@ananthsub (Contributor) commented on Oct 8, 2020:

Great catch @nrupatunga!

  • This bug happens specifically when using the flag reload_dataloaders_every_epoch.
  • With that flag, the dataloader reload happens before train_epoch_start.
  • trainer.current_epoch is only updated in train_epoch_start, so the epoch state read inside train_dataloader is stale (see the sketch after this list).
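
Here is a minimal, self-contained sketch of that ordering (illustrative only; FakeTrainer and the bare loop are simplified stand-ins, not Lightning's actual internals):

class FakeTrainer:
    current_epoch = 0

trainer = FakeTrainer()

def train_dataloader():
    # The datamodule hook reads trainer.current_epoch while rebuilding loaders.
    print(f"reload sees current_epoch = {trainer.current_epoch}")

for epoch in range(4):
    train_dataloader()             # reload runs first and sees a stale value
    trainer.current_epoch = epoch  # counter advances afterwards, in train_epoch_start

This prints 0, 0, 1, 2: the repeated 0 matches the logs in the report above.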

We could:

  • set trainer.current_epoch directly in the trainer at the beginning of the loop (sketched below),
  • move the reload_dataloaders_every_epoch handling into train_epoch_start,
  • or both.
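
A sketch of the first option, reusing the same stand-ins as above (again illustrative, not actual Lightning code):

for epoch in range(4):
    trainer.current_epoch = epoch  # advance the counter first...
    train_dataloader()             # ...so the reload observes the right epoch

This prints 0, 1, 2, 3, with each epoch seen exactly once.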

What do you think, @williamFalcon @awaelchli?

@nrupatunga (Contributor, Author) commented:

Suggested changes are in PR #3975.

Please review and let me know if there is a better way. Thank you!
