
Several datamodules ignoring batch_size #334

Closed
hecoding opened this issue Nov 4, 2020 · 4 comments · Fixed by #331 or #344
Comments

@hecoding
Contributor

hecoding commented Nov 4, 2020

🐛 Bug

I realized MNISTDataModule was ignoring the batch_size parameter. I found a closed issue (#171) referring to this, but without a fix.

While fixing it myself (PR #331), I found that more datamodules have this problem too: MNISTDataModule, BinaryMNISTDataModule, FashionMNISTDataModule, SklearnDataModule, SSLImagenetDataModule.

I could take care of that. My question is: is this signature from MNISTDataModule still in use anywhere?

def train_dataloader(self, batch_size=64, transforms=None):

That is basically where the bug comes from. Datamodules that work fine, like CIFAR10DataModule, simply use:

def train_dataloader(self):

The same goes for val_dataloader and test_dataloader.
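As far as I can tell (just my reading, so take the snippet below as a sketch with made-up class name and defaults), the Trainer calls the hook with no arguments, so the method-level default wins and the batch_size given to the datamodule's __init__ never reaches the DataLoader:

import torch
from torch.utils.data import DataLoader, TensorDataset


class BuggyDataModule:
    """Simplified sketch of the failure mode, not the real pl_bolts code."""

    def __init__(self, batch_size=64):
        # The value passed here is stored but never read by train_dataloader below.
        self.batch_size = batch_size
        self.dataset_train = TensorDataset(torch.randn(256, 1, 28, 28),
                                           torch.zeros(256, dtype=torch.long))

    def train_dataloader(self, batch_size=32, transforms=None):
        # The Trainer calls dm.train_dataloader() with no arguments,
        # so this method-level default wins and self.batch_size is ignored.
        return DataLoader(self.dataset_train, batch_size=batch_size, shuffle=True)


dm = BuggyDataModule(batch_size=64)
print(next(iter(dm.train_dataloader()))[0].shape[0])  # prints 32, not 64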

To Reproduce

Steps to reproduce the behavior:

  1. Set up MNISTDataModule with any batch size other than 32.
  2. Run an experiment.
  3. Check that the batch sizes are 32 regardless.

Code sample

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

batch_size = 64


class SampleModel(LightningModule):
    def configure_optimizers(self):
        # No optimizer needed for this repro
        pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Expect the batch size passed to the datamodule; fails because batches are 32
        assert x.shape[0] == batch_size
        print('fine')


model = SampleModel()
dm = MNISTDataModule(data_dir='~/Datasets/', batch_size=batch_size)

trainer = Trainer()
trainer.fit(model, dm)

Expected behavior

Batch sizes of 64.
fine is printed instead of an AssertionError being raised.

Environment

  • PyTorch Version (e.g., 1.0): 1.7
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information: pytorch-lightning-bolts 0.2.5

Additional context

@hecoding added the help wanted label Nov 4, 2020

github-actions bot commented Nov 4, 2020

Hi! Thanks for your contribution, great first issue!

@ananyahjha93
Contributor

@hecoding send in the PR.

The signature should be def train_dataloader(self):. The datamodule's __init__ should take the batch size parameter, not train_dataloader.
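i.e., roughly along these lines (a sketch only; the class name, attribute names, and defaults are illustrative, not necessarily what the actual fix in #331 does):

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class FixedMNISTDataModule(LightningDataModule):
    def __init__(self, data_dir='./', batch_size=32, num_workers=16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size  # stored once, read by every dataloader
        self.num_workers = num_workers

    def train_dataloader(self):
        # self.dataset_train is assumed to be created in setup()
        return DataLoader(self.dataset_train, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.dataset_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)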

@hecoding
Contributor Author

hecoding commented Nov 5, 2020

Cool, I fixed the signatures too. @ananyahjha93, please have a look at the PR: #331.
Should I assign somebody for review myself? I'm a bit lost on that.

@hecoding
Contributor Author

hecoding commented Nov 5, 2020

I can fix the rest of the datamodules listed too. Let me know.
