Named tuples converted to regular tuples when sent to the GPU #1588

Closed
nathanbreitsch opened this issue Apr 24, 2020 · 2 comments · Fixed by #1589
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments

@nathanbreitsch (Contributor)

🐛 Bug

Named tuples returned from a Dataset get converted to regular tuples when sent to the GPU.
This happens because isinstance(instance_of_a_named_tuple, tuple) evaluates to True in distrib_parts.py:
https://github.com/PyTorchLightning/pytorch-lightning/blob/67d5f4dc392250d23bfeb11aba45e919a99ff1c0/pytorch_lightning/trainer/distrib_parts.py#L463
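
For illustration, here is the pitfall in isolation (a standalone snippet, not Lightning code):

from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])
p = Point(1.0, 2.0)

# Namedtuples subclass tuple, so a plain isinstance check matches them:
print(isinstance(p, tuple))  # True

# Rebuilding the container as a generic tuple silently drops the field names:
rebuilt = tuple(v for v in p)
print(type(rebuilt))         # <class 'tuple'> -- rebuilt.x raises AttributeError

# Rebuilding with the original type keeps them:
kept = type(p)(*(v for v in p))
print(kept.x, kept.y)        # 1.0 2.0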

To Reproduce

import pytorch_lightning as pl
from collections import namedtuple
import torch
import numpy

NamedTupleDemoInput = namedtuple('DemoInput', ['x1', 'x2', 'y'])

class NamedTupleDemoDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns a namedtuple per sample."""

    def __len__(self):
        return 30000

    def __getitem__(self, index):
        x1 = numpy.random.uniform(0, 100)
        x2 = numpy.random.uniform(0, 100)
        y = 2 * x1 + 3 * x2 + numpy.random.normal(0, 0.05)
        return NamedTupleDemoInput(x1, x2, y)

class WeightedSum(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.zeros(1))
        self.b = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        return self.a * x1 + self.b * x2

class NamedTupleDemo(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = WeightedSum()

    def forward(self, x1, x2):
        return self.model(x1, x2)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(NamedTupleDemoDataset(), batch_size=128)

    def training_step(self, batch, batch_index):
        # Fails on GPU: by this point `batch` has been turned into a plain tuple.
        yhat = self.forward(batch.x1, batch.x2)
        return {'loss': torch.nn.functional.mse_loss(batch.y, yhat)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)

if __name__ == '__main__':
    module = NamedTupleDemo()
    pl.Trainer(max_epochs=20, gpus=1).fit(module)
    print(f'a={float(module.model.a)} b={float(module.model.b)}')

Traceback (most recent call last):
  File "demo.py", line 48, in <module>
    pl.Trainer(max_epochs=20, gpus=1).fit(module)
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 749, in fit
    self.single_gpu_train(model)
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/distrib_parts.py", line 491, in single_gpu_train
    self.run_pretrain_routine(model)
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 910, in run_pretrain_routine
    self.train()
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 384, in train
    self.run_training_epoch()
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 456, in run_training_epoch
    _outputs = self.run_training_batch(batch, batch_idx)
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 633, in run_training_batch
    loss, batch_output = optimizer_closure()
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 597, in optimizer_closure
    output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)
  File "/home/n/repos/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 770, in training_forward
    output = self.model.training_step(*args)
  File "demo.py", line 40, in training_step
    yhat = self.forward(batch.x1, batch.x2)
AttributeError: 'tuple' object has no attribute 'x1'

Expected behavior

Namedtuples returned from the dataset should keep their original fields.
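
For reference, here is a minimal sketch of what a namedtuple-aware transfer could look like (a hypothetical helper, not the actual code from #1589; it assumes recursion over common containers and uses the _fields attribute that namedtuples expose):

import torch

def move_to_device(batch, device):
    # Hypothetical sketch; names and structure are illustrative only.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    # Check for namedtuples *before* plain tuples: they subclass tuple but
    # must be rebuilt with their own type to keep their field names.
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        return type(batch)(*(move_to_device(x, device) for x in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(x, device) for x in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    return batch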

Environment

  • CUDA:
    - GPU: GeForce RTX 2080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.18.3
    - pyTorch_debug: False
    - pyTorch_version: 1.5.0
    - pytorch-lightning: 0.7.4rc5
    - tensorboard: 2.2.1
    - tqdm: 4.45.0
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor:
    - python: 3.8.2
    - version: #1 SMP PREEMPT Sun, 05 Apr 2020 05:13:14 +0000
@nathanbreitsch added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Apr 24, 2020
@github-actions (Contributor)

Hi! Thanks for your contribution, great first issue!

@Vozf (Contributor)

Vozf commented Mar 1, 2021

I am having similar trouble with a multi-GPU setup. Was this fixed for multiple GPUs in the PR? If not, I believe this issue should be reopened.
In my case everything works fine on a single GPU, but with 2 GPUs I get the error
AttributeError: 'tuple' object has no attribute 'image'
The object on the failing line should be a namedtuple, not a plain tuple.
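
A possible explanation (an assumption on my part, not verified against the DataParallel source): if the multi-GPU scatter splits batches with a zip-based rebuild, the namedtuple type is lost, because zip always yields plain tuples regardless of the input type:

from collections import namedtuple

Batch = namedtuple('Batch', ['image', 'label'])
b = Batch(image=['img0', 'img1'], label=[0, 1])

# Splitting the batch per sample with zip(*...) produces plain tuples,
# so shard.image would raise AttributeError on each device:
shards = list(zip(*b))
print(type(shards[0]))  # <class 'tuple'>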
