
An extra argument is passed to the class loaded from load_from_checkpoint #2386

@nischal-sanil

🐛 Bug

Hello,
I was facing a few issues while using the trainer.test() function. While debugging, I found that the problem lies in the _load_model_state class method, which is called by load_from_checkpoint.

Code for reference:

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        model_args = {}

        # add some back compatibility, the actual one shall be last
        for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
            if hparam_key in checkpoint:
                model_args.update(checkpoint[hparam_key])

        if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
            model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        init_args_name = inspect.signature(cls).parameters.keys()

        if args_name == 'kwargs':
            cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
            kwargs.update(**cls_kwargs)
        elif args_name:
            if args_name in init_args_name:
                kwargs.update({args_name: model_args})
        else:
            args = (model_args, ) + args

    # load the state_dict on the model automatically
    model = cls(*args, **kwargs)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model

Consider the case where the model takes no arguments, which corresponds to LightModel.load_from_checkpoint('path'). Here, the else clause of the if-elif is executed, and the args variable is updated from an empty tuple to a tuple containing an empty dictionary, args = (model_args, ) + args (since model_args == {}). Therefore, while unpacking args and kwargs in model = cls(*args, **kwargs), an extra argument is passed, which raises TypeError: __init__() takes 1 positional argument but 2 were given. #2364
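
Here is a minimal sketch of that failure mode (plain Python, no Lightning; the class name is made up): an empty hparams dict is prepended to args and then unpacked into a constructor that accepts no parameters.

class NoArgModel:
    def __init__(self):
        pass

model_args = {}               # what _load_model_state builds when nothing matches
args = ()                     # the user passed no extra positional args
args = (model_args, ) + args  # the problematic line: args is now ({},)

# TypeError: __init__() takes 1 positional argument but 2 were given
NoArgModel(*args)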

In some cases, if the model has an argument and the user has forgotten to pass it to load_from_checkpoint, then an empty dictionary is passed instead, which raises other errors depending on the code. For example, in issue #2359 an empty dict is passed while loading the model, which raises RuntimeError: Error(s) in loading state_dict for Model:.
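
A hedged sketch of that second failure mode (the class and hyperparameter names here are made up): the constructor falls back to a default when the dict is empty, so the rebuilt architecture no longer matches the checkpoint.

import torch

class OneArgModel(torch.nn.Module):
    def __init__(self, hparams):
        super().__init__()
        # with hparams == {} the default kicks in, so the layer shape
        # can differ from the one that was saved
        self.l1 = torch.nn.Linear(28 * 28, hparams.get('out_dim', 10))

saved = OneArgModel({'out_dim': 20}).state_dict()  # checkpoint written with out_dim=20
restored = OneArgModel({})                         # empty dict passed positionally

# RuntimeError: Error(s) in loading state_dict for OneArgModel: size mismatch ...
restored.load_state_dict(saved)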

I do not fully understand everything that is happening in this function. It would be great if someone could suggest changes in the comments so that I can start working on a fix in my forked repo.
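
For what it's worth, one untested direction (just a sketch of the else branch above, not a verified fix) would be to skip the positional fallback entirely when the constructor accepts no parameters:

if args_name == 'kwargs':
    cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
    kwargs.update(**cls_kwargs)
elif args_name:
    if args_name in init_args_name:
        kwargs.update({args_name: model_args})
elif init_args_name:
    # only inject model_args positionally when __init__ actually
    # takes at least one parameter; a no-arg model gets nothing
    args = (model_args, ) + args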

Steps to reproduce

!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

mnist_model = MNISTModel()
trainer = pl.Trainer(gpus=1,max_epochs=3)    
trainer.fit(mnist_model, train_loader)  

test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
trainer.test(test_dataloaders=test_loader)

Which returns:

TypeError                                 Traceback (most recent call last)

<ipython-input-5-50449ee4f6cc> in <module>()
      1 test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
----> 2 trainer.test(test_dataloaders=test_loader)

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path)
   1168             if ckpt_path == 'best':
   1169                 ckpt_path = self.checkpoint_callback.best_model_path
-> 1170             model = self.get_model().load_from_checkpoint(ckpt_path)
   1171 
   1172         self.testing = True

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
    167         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    168 
--> 169         model = cls._load_model_state(checkpoint, *args, **kwargs)
    170         return model
    171 

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
    201 
    202         # load the state_dict on the model automatically
--> 203         model = cls(*cls_args, **cls_kwargs)
    204         model.load_state_dict(checkpoint['state_dict'])
    205 

TypeError: __init__() takes 1 positional argument but 2 were given

Expected behavior

Start testing
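
As a temporary workaround (assuming the trained model instance is still in scope), passing the model to trainer.test() explicitly seems to avoid the internal load_from_checkpoint call that triggers the error:

# workaround sketch: with an in-memory model, trainer.test() does not
# reload from the checkpoint, so _load_model_state is never reached
trainer.test(mnist_model, test_dataloaders=test_loader)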

Labels: bug (Something isn't working), help wanted (Open to be worked on)
