🐛 Bug
Hello,

I was facing a few issues while using the trainer.test() function. While debugging, I found that the problem lies in the _load_model_state class method, which is called by load_from_checkpoint.
Code for reference:

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        model_args = {}

        # add some back compatibility, the actual one shall be last
        for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
            if hparam_key in checkpoint:
                model_args.update(checkpoint[hparam_key])

        if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
            model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        init_args_name = inspect.signature(cls).parameters.keys()

        if args_name == 'kwargs':
            cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
            kwargs.update(**cls_kwargs)
        elif args_name:
            if args_name in init_args_name:
                kwargs.update({args_name: model_args})
            else:
                args = (model_args, ) + args

    # load the state_dict on the model automatically
    model = cls(*args, **kwargs)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model
Consider the case where the model takes no arguments, which corresponds to LightModel.load_from_checkpoint('path'). Here the else clause of the if-elif executes, and the args variable is updated from an empty tuple to a tuple containing an empty dictionary by args = (model_args, ) + args (since model_args == {}). Consequently, when args and kwargs are unpacked (model = cls(*args, **kwargs)), an extra positional argument is passed, which raises TypeError: __init__() takes 1 positional argument but 2 were given. #2364 The standalone sketch below walks through this path.
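For illustration only, here is a minimal, self-contained sketch of what the else branch effectively does in this case. NoArgModel is a hypothetical stand-in for a user model, not a name from the library:

# Hypothetical reproduction of the failure path for a no-argument model
class NoArgModel:
    def __init__(self):  # accepts no positional arguments besides self
        pass

model_args = {}                # no hyperparameters were saved
args, kwargs = (), {}          # load_from_checkpoint was called with no extras
args = (model_args, ) + args   # the else branch still prepends the empty dict
model = NoArgModel(*args, **kwargs)
# TypeError: __init__() takes 1 positional argument but 2 were given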
In other cases, if the model takes an argument and the user has forgotten to pass it to load_from_checkpoint, then an empty dictionary is passed instead, which raises different errors depending on the code. For example, in issue #2359 an empty dict is passed while loading the model, which raises RuntimeError: Error(s) in loading state_dict for Model:. The second sketch below makes this failure mode concrete.
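To illustrate, here is a hypothetical sketch (OneArgModel and in_dim are illustrative names, not from the library): a model rebuilt from an empty hparams dict falls back to default dimensions, so its weights no longer match the checkpointed ones and load_state_dict fails.

import torch

class OneArgModel(torch.nn.Module):
    def __init__(self, hparams):
        super().__init__()
        # with an empty dict, the fallback dimension will not match the
        # dimensions of the checkpointed weights
        self.layer = torch.nn.Linear(hparams.get('in_dim', 8), 2)

checkpointed = OneArgModel({'in_dim': 4}).state_dict()
restored = OneArgModel({})           # the empty dict the loader passed through
restored.load_state_dict(checkpointed)
# RuntimeError: Error(s) in loading state_dict for OneArgModel: ...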
I do not fully understand everything that is happening in this function. It would be great if someone could suggest changes in the comments, so that I can start working on them in my forked repo.
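One direction I could imagine (purely a rough sketch of my assumption about the intent, not a vetted fix, and _maybe_prepend is a hypothetical helper name) is to only forward model_args positionally when there is actually something to forward:

def _maybe_prepend(model_args, args):
    """Sketch: skip prepending when model_args is empty, so a no-argument
    __init__ is not handed a spurious empty dict."""
    return (model_args, ) + args if model_args else args

assert _maybe_prepend({}, ()) == ()                    # no spurious argument
assert _maybe_prepend({'lr': 0.02}, ()) == ({'lr': 0.02},)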
Steps to reproduce
!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'test_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
mnist_model = MNISTModel()

trainer = pl.Trainer(gpus=1, max_epochs=3)
trainer.fit(mnist_model, train_loader)

test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
trainer.test(test_dataloaders=test_loader)
Which returns:
TypeError Traceback (most recent call last)
<ipython-input-5-50449ee4f6cc> in <module>()
1 test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
----> 2 trainer.test(test_dataloaders=test_loader)
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path)
1168 if ckpt_path == 'best':
1169 ckpt_path = self.checkpoint_callback.best_model_path
-> 1170 model = self.get_model().load_from_checkpoint(ckpt_path)
1171
1172 self.testing = True
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
167 checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
168
--> 169 model = cls._load_model_state(checkpoint, *args, **kwargs)
170 return model
171
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
201
202 # load the state_dict on the model automatically
--> 203 model = cls(*cls_args, **cls_kwargs)
204 model.load_state_dict(checkpoint['state_dict'])
205
TypeError: __init__() takes 1 positional argument but 2 were given
Expected behavior
Testing should start and run without raising an error.