DataLoader with multiple inputs caused errors #75

Closed
Runda-Xu opened this issue Jun 1, 2021 · 2 comments · Fixed by #76
Labels
new feature Feature request to work on

Comments


Runda-Xu commented Jun 1, 2021

I have a train_loader with multiple inputs:

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(input1, input2, input3, label),
    batch_size=batch_size,
    shuffle=True,
)

These inputs are arrays with different shapes, so it is hard to concatenate them into a single tensor. They work well with plain PyTorch code. However, Ensemble-Pytorch cannot handle a loader like this.
I hope Ensemble-Pytorch can support DataLoader with multiple inputs in the future.
Thank you.
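
For readers who want to reproduce the setup, here is a minimal sketch of such a loader in plain PyTorch. The shapes, the sample count, and the batch size are made up for illustration, since the issue does not state the real ones. It only shows that each batch comes back as a sequence of four tensors (three inputs followed by the label), which is what a training loop would have to split.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical shapes: the issue only says the inputs differ in shape.
n_samples = 8
input1 = torch.randn(n_samples, 3, 4)
input2 = torch.randn(n_samples, 5)
input3 = torch.randn(n_samples, 2, 2, 2)
label = torch.randint(0, 10, (n_samples,))

batch_size = 4  # assumed value
train_loader = DataLoader(
    TensorDataset(input1, input2, input3, label),
    batch_size=batch_size,
    shuffle=True,
)

# Each batch holds one tensor per dataset argument, in the same order.
batch = next(iter(train_loader))
batch_shapes = [tuple(t.shape) for t in batch]
print(len(batch), batch_shapes)
```

Only the first dimension (the sample count) has to match across the tensors given to TensorDataset; the remaining dimensions can differ freely.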

Member

xuyxu commented Jun 1, 2021

Hi @Runda-Xu, thanks for your suggestions! I will take a look at this feature request, and get back to you soon.

xuyxu added the "new feature" label (Feature request to work on) Jun 1, 2021
Member

xuyxu commented Jun 1, 2021

Would the code snippet below meet your requirement? It creates a dataloader with three input tensors (input_1, input_2, input_3) and passes each batch to a model whose forward method takes three inputs accordingly. To make the workflow run as expected, the inputs are passed to the model as positional (non-keyword) arguments, i.e., *data. This should work as long as the inputs appear in the same order in the forward signature as in the dataset passed to the dataloader.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

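class TOY(nn.Module):
    """Toy model whose forward method takes three separate inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, input_1, input_2, input_3):
        # Only the first input is returned; the other two are accepted but unused.
        return input_1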

nb_samples = 100
input_1 = torch.randn(nb_samples, 10)
input_2 = torch.randn(nb_samples, 5)
input_3 = torch.randn(nb_samples, 7)
target = torch.empty(nb_samples, dtype=torch.long).random_(10)

dataset = TensorDataset(input_1, input_2, input_3, target)
loader = DataLoader(dataset, batch_size=2)

model = TOY()

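for batch_idx, elem in enumerate(loader):
    # The last element of each batch is the target; everything before it is an input.
    data, target = elem[:-1], elem[-1]
    # Unpack the inputs as positional arguments, in dataset order.
    print(model(*data))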
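
As a possible extension of the snippet above, here is a sketch in which the model actually uses all three inputs and is trained for one pass over the data. This is plain PyTorch, not Ensemble-Pytorch's API: the class name MultiInputNet, the choice to concatenate the inputs along the feature dimension, and the optimizer settings are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MultiInputNet(nn.Module):
    """Hypothetical model: concatenates three 2-D inputs, then classifies."""

    def __init__(self, dims=(10, 5, 7), n_classes=10):
        super().__init__()
        self.fc = nn.Linear(sum(dims), n_classes)

    def forward(self, input_1, input_2, input_3):
        # Join the inputs along the feature dimension before the linear layer.
        return self.fc(torch.cat([input_1, input_2, input_3], dim=1))


torch.manual_seed(0)
nb_samples = 100
inputs = [torch.randn(nb_samples, d) for d in (10, 5, 7)]
target = torch.randint(0, 10, (nb_samples,))
loader = DataLoader(TensorDataset(*inputs, target), batch_size=20)

model = MultiInputNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed settings

for elem in loader:
    # Same split as above: inputs first, target last.
    data, batch_target = elem[:-1], elem[-1]
    optimizer.zero_grad()
    output = model(*data)
    loss = criterion(output, batch_target)
    loss.backward()
    optimizer.step()
```

The split elem[:-1], elem[-1] assumes the target is always the last tensor passed to TensorDataset; a dataset that returns its target in another position would need a different split.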
 
