make DataModule compatible with Python dataclass #8272

Closed
awaelchli opened this issue Jul 4, 2021 · 6 comments · Fixed by #9039
Labels
data handling Generic data-related topic feature Is an improvement or enhancement good first issue Good for newcomers help wanted Open to be worked on

Comments

@awaelchli
Member

awaelchli commented Jul 4, 2021

🚀 Feature

Support the following:

@dataclass
class MyDataModule(LightningDataModule):
    pass

Motivation

Reducing boilerplate code is at the core of Lightning's philosophy, so the datamodule should be compatible with dataclasses.

Code sample

Here is an example. It currently does not work because the datamodule has some internal attributes that don't play well with dataclasses.

import os
from dataclasses import dataclass

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


@dataclass
class BoringDataModule(LightningDataModule):

    batch_size: int = 2

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, datamodule=BoringDataModule())


if __name__ == '__main__':
    run()
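
The failure mode can be reproduced without Lightning in a minimal sketch. It assumes, as this thread later concludes, that the problem is the dataclass-generated __init__ replacing the parent's __init__ instead of calling it. BaseDataModule and _internal_flag are hypothetical stand-ins, not Lightning's actual class or attributes.

```python
from dataclasses import dataclass


class BaseDataModule:
    """Hypothetical stand-in for LightningDataModule."""

    def __init__(self):
        # Hypothetical internal state that the base class relies on
        self._internal_flag = False

    def is_ready(self):
        return self._internal_flag


@dataclass
class MyDataModule(BaseDataModule):
    batch_size: int = 2


dm = MyDataModule(batch_size=4)
print(dm.batch_size)  # prints 4: the generated __init__ assigned the field

try:
    dm.is_ready()
except AttributeError as err:
    # BaseDataModule.__init__ never ran, so _internal_flag was never set
    print(f"AttributeError: {err}")
```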

Alternatives

#3792 introduces save_hyperparameters() for the datamodule. However, I believe the dataclass approach does not conflict with that, since both could be useful at the same time.

@awaelchli awaelchli added feature Is an improvement or enhancement help wanted Open to be worked on good first issue Good for newcomers labels Jul 4, 2021
@awaelchli awaelchli added this to the v1.5 milestone Jul 4, 2021
@awaelchli awaelchli added the data handling Generic data-related topic label Jul 4, 2021
@QueshAnmak

Hi, I would like to work on this issue. Could you please assign it to me?

@awaelchli
Member Author

Hey! Yes, you can take it if you want. I haven't really had the time to look into why it is not working, so I don't know how difficult it will be.
Give it a try. We appreciate the help!

@QueshAnmak

Hey, is there a Slack or Discord server I could join to ask my questions?

@awaelchli
Member Author

Yes, feel free to join the PyTorch Lightning slack: https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ

@QueshAnmak

[Screenshot 2021-07-10 025106]
I set init=False in the dataclass decorator (the default is init=True). This seems to fix the issue. I think the @property decorator on the attributes might be the cause.
[Screenshot 2021-07-10 025441]
How should I proceed?
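
A minimal sketch of what init=False does, using a hypothetical BaseDataModule as a stand-in for LightningDataModule (not its real implementation): with no generated __init__, the parent's __init__ runs and its internal state exists, but the dataclass fields can no longer be passed to the constructor.

```python
from dataclasses import dataclass


class BaseDataModule:
    """Hypothetical stand-in for LightningDataModule."""

    def __init__(self):
        self._internal_flag = False  # internal state the base class relies on


@dataclass(init=False)
class MyDataModule(BaseDataModule):
    batch_size: int = 2


dm = MyDataModule()       # BaseDataModule.__init__ runs, so the state exists
print(dm._internal_flag)  # prints False
print(dm.batch_size)      # prints 2, the class-level default

try:
    # No generated __init__, so fields are not constructor arguments
    MyDataModule(batch_size=4)
except TypeError as err:
    print(f"TypeError: {err}")
```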

@awaelchli
Member Author

Oh yes, but we actually want the dataclass to generate an __init__ for us; otherwise we don't get much value from a dataclass here.
All we have to do is this:

    def __post_init__(self):
        super().__init__()

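This works because the __init__ generated by dataclass does not call super().__init__(), which the datamodule requires. The generated __init__ does call __post_init__, so that is where the parent initializer can run.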
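
A self-contained sketch of the __post_init__ approach, again with a hypothetical BaseDataModule standing in for LightningDataModule (the real class sets up more state than this): the generated __init__ assigns the fields and then calls __post_init__, which runs the parent initializer that would otherwise be skipped.

```python
from dataclasses import dataclass


class BaseDataModule:
    """Hypothetical stand-in for LightningDataModule."""

    def __init__(self):
        self._internal_flag = False  # internal state the base class relies on


@dataclass
class MyDataModule(BaseDataModule):
    batch_size: int = 2

    def __post_init__(self):
        # The generated __init__ assigns the fields, then calls __post_init__;
        # run the parent initializer it skipped.
        super().__init__()


dm = MyDataModule(batch_size=4)
print(dm.batch_size)      # prints 4
print(dm._internal_flag)  # prints False
```

Because super().__init__() runs after the fields are assigned, a parent __init__ that assigns an attribute with the same name as a field would overwrite the field's value; this sketch assumes no such overlap.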
