Predicting with custom dataloader overwrites "predict_dataloader()" method of module #8868

@OlfwayAdbayIgbay

Description


🐛 Bug

When predicting with trainer.predict(module, DataLoader(my_dataset)), the trainer overwrites the module's predict_dataloader() method, so the method afterwards returns the DataLoader that was passed to predict() instead of the one defined on the module.

To Reproduce

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn import MSELoss, Linear, ReLU
import pytorch_lightning as pl



class TestDataset(Dataset):
    def __init__(self, length = 100):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        return torch.tensor(item).float(), torch.tensor(item**2).float()

# second dataset, passed directly to trainer.predict() further below
tensor_dataset = TensorDataset(torch.linspace(1, 100, steps=100),
                               torch.linspace(1, 100**2, steps=100))

class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = MSELoss()
        self.dataset = TestDataset()
        self.layers = torch.nn.Sequential(Linear(1, 5), ReLU(),
                                          Linear(5, 5), ReLU(),
                                          Linear(5, 5), ReLU(),
                                          Linear(5, 1))

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        train_loss = self.criterion(self(x), y)
        self.log("train_loss", train_loss, on_step=False, on_epoch=True)
        return train_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        return self.criterion(self(x), y)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=10)

    def predict_dataloader(self):
        return DataLoader(self.dataset, batch_size=10)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

trainer = pl.Trainer(gpus=[2])
module = Module()

# First predict() call: no dataloader is passed, so the module's own
# predict_dataloader() is used.
trainer.predict(module)
print()
print("This should give the TestDataset class:")
print(module.predict_dataloader().dataset)

# Second predict() call with an ad-hoc DataLoader. Afterwards the module's
# predict_dataloader() no longer returns its own loader.
trainer.predict(module, DataLoader(tensor_dataset))
print()
print("This should still give the TestDataset class, but prints the TensorDataset instead:")
print(module.predict_dataloader().dataset)

Expected behavior

The DataLoader returned by predict_dataloader() should be the one defined in the module's own method, not a DataLoader that happened to be passed to trainer.predict() at some earlier point.
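
A minimal check of the expected behaviour, as a sketch reusing Module, TestDataset, and tensor_dataset from the reproduction script above:

module = Module()
trainer = pl.Trainer(gpus=[2])

# Before any predict() call the hook returns the module's own dataset.
assert isinstance(module.predict_dataloader().dataset, TestDataset)

# Predict once with an externally supplied DataLoader ...
trainer.predict(module, DataLoader(tensor_dataset))

# ... the module's own hook should be untouched afterwards. With the
# behaviour reported here this assertion fails, because the hook now
# returns the TensorDataset that was passed to predict().
assert isinstance(module.predict_dataloader().dataset, TestDataset)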

Environment

  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.1
  • Packages:
    - numpy: 1.20.2
    - pyTorch_debug: False
    - pyTorch_version: 1.8.1
    - pytorch-lightning: 1.3.5
    - tqdm: 4.61.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.10
    - version: #71~20.04.1-Ubuntu SMP Thu Jul 15 17:46:08 UTC 2021

Additional context

I need to run trainer.predict() in two separate steps because I have to do something with the returned data in between, so I can't simply pass both dataloaders as a list.
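
A possible interim workaround, sketched under the assumption that predict() simply replaces the predict_dataloader attribute on the module instance (as the script above suggests): keep a reference to the original hook and restore it between the two calls.

# Workaround sketch; original_hook and the restore step are illustrative,
# assuming the overwrite is a plain attribute assignment on the module.
original_hook = module.predict_dataloader            # bound method defined by the class

first_results = trainer.predict(module, DataLoader(tensor_dataset))
# ... process first_results here ...

module.predict_dataloader = original_hook            # restore the module's own hook
second_results = trainer.predict(module)             # uses TestDataset again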


Labels

bug (Something isn't working), help wanted (Open to be worked on)
