load_from_checkpoint: TypeError: __init__() missing 1 required positional argument #2909

Closed
siahuat0727 opened this issue Aug 11, 2020 · 16 comments · Fixed by #2911
Assignees: awaelchli
Labels: bug, question

Comments

@siahuat0727
Contributor

siahuat0727 commented Aug 11, 2020

❓ Questions and Help

What is your question?

load_from_checkpoint: TypeError: __init__() missing 1 required positional argument

I have read the related issues, but the difference here is that my LightningModule inherits from a self-defined LightningModule.

How can I solve this, or what is the best practice for this use case?

Code

To reproduce the error:

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from argparse import Namespace

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class LitModel(_LitModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--classes', type=int, default=10)
parser.add_argument('--checkpoint', type=str, default=None)
hparams = parser.parse_args()

mnist_train = MNIST(os.getcwd(), train=True, download=True,
                    transform=transforms.ToTensor())
mnist_train = DataLoader(mnist_train, num_workers=1)
mnist_val = MNIST(os.getcwd(), train=False, download=False,
                  transform=transforms.ToTensor())
mnist_val = DataLoader(mnist_val, num_workers=1)

# A bit contrived, but it shows that `load_from_checkpoint` will fail.
if hparams.checkpoint is None:
    model = LitModel(hparams)
else:
    model = LitModel.load_from_checkpoint(hparams.checkpoint)

trainer = Trainer(max_epochs=2, limit_train_batches=2,
                  limit_val_batches=2, progress_bar_refresh_rate=0)
trainer.fit(model, mnist_train, mnist_val)

Error msg

Traceback (most recent call last):
  File "main.py", line 64, in <module>
    model = LitModel.load_from_checkpoint(hparams.checkpoint)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 138, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 174, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "main.py", line 46, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'

How to run to get the error

$ python3 main.py 
$ python3 main.py --checkpoint lightning_logs/version_0/checkpoints/epoch\=1.ckpt

What's your environment?

  • OS: Linux
  • Packaging: pip
  • Version: 0.9.0rc12
@siahuat0727 siahuat0727 added the question Further information is requested label Aug 11, 2020
@awaelchli
Member

awaelchli commented Aug 11, 2020

Did you try to call self.save_hyperparameters() in _LitModel?
It looks like the hparams were not saved to the checkpoint.
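
For reference, a minimal sketch of that suggestion applied to the base class from the script above (same imports as the original example; assumes save_hyperparameters() accepts a Namespace, as in recent Lightning versions):

import torch
import pytorch_lightning as pl
from argparse import Namespace

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        # Persist the init arguments so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters(hparams)
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)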

@siahuat0727
Contributor Author

@awaelchli
Hi, the result is the same.
It works if I use _LitModel directly instead of LitModel, so I think it's something about the inheritance.

@siahuat0727
Contributor Author

https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html

Anything assigned to self.hparams will also be saved automatically.

@awaelchli awaelchli self-assigned this Aug 11, 2020
@awaelchli awaelchli added the bug Something isn't working label Aug 11, 2020
@awaelchli
Member

@siahuat0727 I can confirm this is a bug. I fixed it and reduced your example to a minimal test case so it won't break in the future. Thanks for providing an easy-to-reproduce script!

@siahuat0727
Contributor Author

Great job. Thanks!!

@shentianxiao

The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?

(also track_grad_norm doesn't work in 0.9.0, so I have to switch to 1.0.5...)

@ghost

ghost commented Nov 12, 2020

The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?

(also track_grad_norm doesn't work in 0.9.0, so I have to switch to 1.0.5...)

Bump

@2dot71mily

bump for version 1.0.6 as well

@stathius

same problem here on 1.0.4

@stathius

Apparently the problem is that checkpoint['hparams_name'] is empty. Maybe the issue is in how the module is saved when it is inherited?
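
One quick way to check this is to open the checkpoint with plain torch.load and inspect the hyperparameter-related keys (the path is just an example; the key names are the ones Lightning uses in its checkpoints):

import torch

ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=1.ckpt",
                  map_location="cpu")
print(ckpt.keys())
# If hyperparameters were saved, they show up under these keys:
print(ckpt.get("hparams_name"), ckpt.get("hyper_parameters"))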

@stathius

What solved it for me: instead of passing hparams, pass them as kwargs. So in your class use:

class my_pl_module(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()

Still a bug though, since the hparams method is not yet deprecated.
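
With that pattern the saved values are restored from the checkpoint at load time, so nothing needs to be passed positionally (the path below is a placeholder):

# Hypothetical usage of the class above; hyperparameters come from the checkpoint.
model = my_pl_module.load_from_checkpoint("path/to/checkpoint.ckpt")
print(model.hparams)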

@awaelchli
Member

@stathius yes, the "old" hparams method is not yet deprecated, but it has conceptual flaws in terms of typing that cannot be addressed with a simple bugfix. The solution we came up with is to decouple two things:

  1. saving hyperparameters into the checkpoint, and
  2. making hyperparameters accessible through a convenient self.hparams "namespace".

The code you posted does exactly that, and it is the recommended way today.
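
Put differently, the pattern recommended here is explicit init arguments plus save_hyperparameters(); a rough sketch (the argument names are just examples):

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):

    def __init__(self, classes: int = 10, lr: float = 1e-3):
        super().__init__()
        # 1. saves the init args into the checkpoint
        # 2. exposes them as self.hparams.classes / self.hparams.lr
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Later, no constructor arguments are needed:
# model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")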

@pietz

pietz commented Oct 26, 2021

What can I do if I already trained my models without calling self.save_hyperparameters() explicitly?

@awaelchli
Member

@pietz in this case you can instantiate your model normally, model = YourModel(...) and then load the state dict from the checkpoint:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])
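
A slightly fuller sketch of the same workaround, with map_location added in case the checkpoint was written on a different device (YourModel and the path are placeholders):

import torch

model = YourModel(...)  # instantiate with your own constructor arguments
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])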

@pietz

pietz commented Oct 26, 2021

@awaelchli Ah, thank you. Looking back I should have been able to figure this one out myself :)

@TangJiakai

@pietz in this case you can instantiate your model normally, model = YourModel(...) and then load the state dict from the checkpoint:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])

But sometimes we also have to load other state, such as the optimizer parameters, so we end up calling several load_* functions. That is not great!
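
For completeness, a hedged sketch of also restoring the optimizer state by hand from a Lightning checkpoint (the "optimizer_states" key is where recent Lightning versions keep it; verify against your own checkpoint, and YourModel/the path are placeholders):

import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")

model = YourModel(...)                 # your own constructor arguments
model.load_state_dict(ckpt["state_dict"])

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Lightning stores one state dict per configured optimizer.
optimizer.load_state_dict(ckpt["optimizer_states"][0])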
