Overfit batches parameter gives a validation batch #15021

Open
HekpoMaH opened this issue Oct 6, 2022 · 1 comment
Labels
bug (Something isn't working), help wanted (Open to be worked on), trainer: fit

Comments

HekpoMaH commented Oct 6, 2022

Bug description

When overfitting on a single batch (`overfit_batches=1`) with the dataloaders defined inside the LightningModule, the batch passed to the validation step differs from the batch passed to the training step. I was told in the Slack community that this is NOT the intended behaviour.

How to reproduce the bug

import pytorch_lightning as pl
import torch_geometric
import torch

# Toy datasets: train holds 0..9 and val holds 10..19, so a printed batch reveals its origin.
dataset = [torch_geometric.data.Data(x=torch.tensor([i])) for i in range(10)]
val_dataset = [torch_geometric.data.Data(x=torch.tensor([j])) for j in range(10, 20)]


class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor([0.]))
    def train_dataloader(self):
        return torch_geometric.loader.DataLoader(dataset, batch_size=2)
    def val_dataloader(self):
        return torch_geometric.loader.DataLoader(val_dataset, batch_size=2)

    def training_step(self, batch, batch_idx):
        print('train', batch.x)
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def validation_step(self, batch, batch_idx):
        print('val', batch.x)
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=.0001)
        return optimizer

litmod = LitModule()
trainer = pl.Trainer(
    overfit_batches=1,
    accelerator='cuda',
    max_epochs=20,
    check_val_every_n_epoch=10,
)
trainer.fit(litmod)
print(litmod)

The val batch is the [10, 11] tensor, whereas the train batch is the [0, 1] tensor.
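
To make the mismatch concrete without needing a GPU or Lightning installed, here is a plain-Python stand-in for the repro. It only mimics "take the first batch of each dataset in order"; the helper `first_batches` is made up for illustration and is not a Lightning API.

```python
# Plain-Python stand-in for the repro above; no Lightning or PyTorch involved.

def first_batches(dataset, batch_size, n_batches):
    """Return the first `n_batches` batches of `dataset`, in order (no shuffling)."""
    batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
    return batches[:n_batches]


train_data = list(range(10))     # same values as `dataset` in the repro
val_data = list(range(10, 20))   # same values as `val_dataset` in the repro

train_batch = first_batches(train_data, batch_size=2, n_batches=1)
observed_val_batch = first_batches(val_data, batch_size=2, n_batches=1)
expected_val_batch = train_batch  # what this report expects: validate on the overfit batch

print("train:", train_batch)                   # train: [[0, 1]]
print("val (observed):", observed_val_batch)   # val (observed): [[10, 11]]
print("val (expected):", expected_val_batch)   # val (expected): [[0, 1]]
```

The "observed" line corresponds to what the repro prints; the "expected" line is what I understood the intended behaviour to be.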

Environment

  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3080 Laptop GPU
    • available: True
    • version: 11.6
  • Lightning:
    • pytorch-lightning: 1.7.7
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
  • Packages:
    • absl-py: 1.2.0
    • aiohttp: 3.8.3
    • aiosignal: 1.2.0
    • anndata: 0.8.0
    • astroid: 2.12.10
    • astunparse: 1.6.3
    • async-timeout: 4.0.2
    • attrs: 22.1.0
    • blinker: 1.4
    • brotlipy: 0.7.0
    • cachetools: 5.2.0
    • certifi: 2022.9.24
    • cffi: 1.15.1
    • charset-normalizer: 2.1.1
    • chex: 0.1.5
    • click: 8.0.4
    • colorama: 0.4.5
    • contourpy: 1.0.5
    • cryptography: 37.0.2
    • cycler: 0.11.0
    • dill: 0.3.5.1
    • distlib: 0.3.6
    • dm-clrs: 1.0.0
    • dm-haiku: 0.0.8
    • dm-tree: 0.1.7
    • etils: 0.8.0
    • filelock: 3.8.0
    • flatbuffers: 1.12
    • fonttools: 4.37.3
    • frozenlist: 1.3.1
    • fsspec: 2022.8.2
    • gast: 0.4.0
    • google-auth: 2.12.0
    • google-auth-oauthlib: 0.4.6
    • google-pasta: 0.2.0
    • googleapis-common-protos: 1.56.4
    • grpcio: 1.49.1
    • h5py: 3.7.0
    • idna: 3.4
    • importlib-metadata: 4.11.4
    • importlib-resources: 5.9.0
    • isort: 5.10.1
    • jax: 0.3.21
    • jaxlib: 0.3.20
    • jinja2: 3.1.2
    • jmp: 0.0.2
    • joblib: 1.2.0
    • jsonschema: 4.16.0
    • keras: 2.9.0
    • keras-preprocessing: 1.1.2
    • kiwisolver: 1.4.4
    • lazy-object-proxy: 1.7.1
    • libclang: 14.0.6
    • llvmlite: 0.39.1
    • markdown: 3.4.1
    • markupsafe: 2.1.1
    • matplotlib: 3.6.0
    • mccabe: 0.7.0
    • mkl-fft: 1.3.1
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • msgpack: 1.0.4
    • multidict: 6.0.2
    • natsort: 8.2.0
    • networkx: 2.8.6
    • numba: 0.56.2
    • numexpr: 2.8.3
    • numpy: 1.23.3
    • oauthlib: 3.2.1
    • opt-einsum: 3.3.0
    • optax: 0.1.3
    • packaging: 21.3
    • pandas: 1.5.0
    • patsy: 0.5.2
    • pillow: 9.2.0
    • pip: 22.1.2
    • platformdirs: 2.5.2
    • promise: 2.3
    • protobuf: 3.19.6
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.21
    • pydeprecate: 0.3.2
    • pyjwt: 2.5.0
    • pylint: 2.15.3
    • pynndescent: 0.5.7
    • pyopenssl: 22.0.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.18.1
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • pytorch-lightning: 1.7.7
    • pytz: 2022.2.1
    • pyu2f: 0.1.5
    • pyyaml: 6.0
    • ray: 2.0.0
    • requests: 2.28.1
    • requests-oauthlib: 1.3.1
    • rsa: 4.9
    • scanpy: 1.9.1
    • scikit-learn: 1.1.2
    • scikit-misc: 0.1.4
    • scipy: 1.9.1
    • seaborn: 0.12.0
    • session-info: 1.0.0
    • setuptools: 65.4.1
    • six: 1.16.0
    • statsmodels: 0.13.2
    • stdlib-list: 0.8.0
    • tables: 3.7.0
    • tabulate: 0.8.10
    • tensorboard: 2.9.1
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • tensorboardx: 2.5.1
    • tensorflow: 2.9.1
    • tensorflow-estimator: 2.9.0
    • tensorflow-io-gcs-filesystem: 0.27.0
    • tensorflow-metadata: 1.10.0
    • termcolor: 2.0.1
    • tfds-nightly: 4.5.2.dev202204190046
    • threadpoolctl: 3.1.0
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.11.5
    • toolz: 0.12.0
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
    • tqdm: 4.64.1
    • typing-extensions: 4.3.0
    • umap-learn: 0.5.3
    • urllib3: 1.26.12
    • virtualenv: 20.16.5
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • wrapt: 1.14.1
    • yapf: 0.32.0
    • yarl: 1.8.1
    • zipp: 3.8.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.4
    • version: #202203181321-Ubuntu SMP PREEMPT Fri Mar 18 13:28:32 UTC 2022


More info

No response

cc @justusschock @awaelchli
HekpoMaH added the needs triage (Waiting to be triaged by maintainers) label Oct 6, 2022
awaelchli added the bug and trainer: fit labels and removed the needs triage label Oct 9, 2022
awaelchli added this to the pl:1.7.x milestone Oct 9, 2022
carmocca modified the milestones: pl:1.7.x, v1.8.x Oct 13, 2022
Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023
awaelchli added the help wanted label Dec 31, 2023
awaelchli removed this from the v1.9.x milestone Dec 31, 2023
@israfelsr

I had the same problem. I was going crazy because, according to the documentation, they are supposed to be the same 😅.
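
For anyone else landing here, a possible stopgap is to point the validation loader at the training data yourself whenever overfitting is switched on. This is only a sketch: `pick_val_dataset` is a hypothetical helper, not part of Lightning, and it assumes you pass it the same value you gave to `Trainer(overfit_batches=...)`.

```python
def pick_val_dataset(train_dataset, val_dataset, overfit_batches):
    """Hypothetical helper: choose which dataset the validation loader should wrap.

    `overfit_batches` mirrors the Trainer argument; 0 (or 0.0) means overfitting is off.
    """
    if overfit_batches:
        # Overfitting is on: validate on the training data so both steps see the same samples.
        return train_dataset
    return val_dataset


# Same toy values as the repro in this issue.
train_dataset = list(range(10))
val_dataset = list(range(10, 20))

chosen = pick_val_dataset(train_dataset, val_dataset, overfit_batches=1)
print(chosen[:2])  # [0, 1]
```

Inside `val_dataloader` you would wrap the returned dataset in the usual `DataLoader` with the same `batch_size` and no shuffling. With `overfit_batches=1` that should give the validation step the [0, 1] batch from the repro, but I have not verified this against every Lightning version.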

5 participants