Closed
Labels
bug (Something isn't working), help wanted (Open to be worked on)
Description
🐛 Bug
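When predicting with trainer.predict(module, DataLoader(my_dataset)), the trainer overwrites the predict_dataloader() method of module. After that call, module.predict_dataloader() returns the dataloader that was passed to trainer.predict() instead of the one defined in the module.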
To Reproduce
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn import MSELoss, Linear, ReLU
import pytorch_lightning as pl
from pytorch_lightning.loggers import NeptuneLogger


class TestDataset(Dataset):
    def __init__(self, length = 100):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        return torch.tensor(item).float(), torch.tensor(item**2).float()


tensor_dataset = TensorDataset(torch.tensor(torch.linspace(1,100)),
                               torch.tensor(torch.linspace(1,100**2)))


class Module(pl.LightningModule):
    def __init__(self):
        super(Module,self).__init__()
        self.criterion = MSELoss()
        self.dataset = TestDataset()
        self.layers = torch.nn.Sequential(Linear(1, 5), ReLU(),
                                          Linear(5, 5), ReLU(),
                                          Linear(5, 5), ReLU(),
                                          Linear(5, 1))

    def forward(self, x):
        return self.layers(self.layers(x.view(x.shape[0], -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        train_loss = self.criterion(self(x), y)
        self.log("train_loss", train_loss, on_step=False, on_epoch=True)
        return train_loss

    def predict_step(self, batch, batch_idx, dataloader_ix):
        x, y = batch
        train_loss = self.criterion(self(x), y)
        return train_loss

    def train_dataloader(self):
        return DataLoader(self.dataset,batch_size=10)

    def predict_dataloader(self):
        return DataLoader(self.dataset, batch_size=10)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


trainer = pl.Trainer( gpus=[2])
module = Module()

trainer.predict(module)
print(" ")
print("This should give TestDataset class:")
print(" ")
print(module.predict_dataloader().dataset)

trainer.predict(module, DataLoader(tensor_dataset))
print(" ")
print("This should give TestDataset class:")
print(module.predict_dataloader().dataset)
Expected behavior
The DataLoader returned by predict_dataloader() should always be the one defined in the module's own method, not a dataloader that was passed to trainer.predict() in an earlier call.
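The difference between the expected and the observed behavior can be illustrated without Lightning. The sketch below is only a toy model of the problem: ToyModule and toy_predict are hypothetical stand-ins, and it assumes that the trainer attaches a passed dataloader by assigning a new callable to the module instance, which then shadows the method defined on the class. Whether Lightning 1.3.5 does exactly this is an assumption, not something confirmed here.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule; no Lightning needed."""

    def predict_dataloader(self):
        return "loader defined on the module"


def toy_predict(module, dataloader=None):
    """Hypothetical stand-in for trainer.predict().

    Assumption: an explicitly passed dataloader is attached by assigning a
    new callable to the instance attribute `predict_dataloader`.
    """
    if dataloader is not None:
        module.predict_dataloader = lambda: dataloader
    return module.predict_dataloader()


module = ToyModule()
first = toy_predict(module)                     # uses the module's own method
second = toy_predict(module, "external loader") # shadows it on the instance
leaked = module.predict_dataloader()            # still the external loader
```

In this toy model the class method itself is never changed; only the instance carries an extra attribute that hides it, which matches the symptom in the reproduction above.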
Environment
- CUDA:
    - GPU:
        - NVIDIA GeForce RTX 2080 Ti
        - NVIDIA GeForce RTX 2080 Ti
        - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.1
- Packages:
    - numpy: 1.20.2
    - pyTorch_debug: False
    - pyTorch_version: 1.8.1
    - pytorch-lightning: 1.3.5
    - tqdm: 4.61.0
- System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.8.10
    - version: #71~20.04.1-Ubuntu SMP Thu Jul 15 17:46:08 UTC 2021
Additional context
I need to run trainer.predict() in two separate steps, because I need to do something with the returned data in between. I can't pass the dataloaders as a list in a single call.
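Until this is fixed, one possible workaround for the two-step use case is to restore the module's instance attributes after each predict call. The sketch below is a generic helper, not part of Lightning: preserve_attrs and ToyModule are hypothetical names, and it assumes the overwrite is stored as an instance attribute on the module, so that removing it makes the class method visible again.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel: the attribute was not set on the instance


@contextmanager
def preserve_attrs(obj, *names):
    """Restore the instance-level state of `names` when the block exits."""
    saved = {name: vars(obj).get(name, _MISSING) for name in names}
    try:
        yield obj
    finally:
        for name, value in saved.items():
            if value is _MISSING:
                # Attribute did not exist on the instance before: drop it.
                vars(obj).pop(name, None)
            else:
                vars(obj)[name] = value


class ToyModule:
    """Hypothetical stand-in for a LightningModule."""

    def predict_dataloader(self):
        return "own loader"


toy = ToyModule()
with preserve_attrs(toy, "predict_dataloader"):
    # Stand-in for a predict call that patches the instance.
    toy.predict_dataloader = lambda: "external loader"
    inside = toy.predict_dataloader()
after = toy.predict_dataloader()  # the class method is visible again
```

With the reproduction above, the second call would be wrapped as `with preserve_attrs(module, "predict_dataloader"): trainer.predict(module, DataLoader(tensor_dataset))`. This has not been verified against Lightning 1.3.5 and only helps if the assumption about instance-level patching holds.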