e2e-funsd-best.pt Error(s) in loading state_dict #13

Closed
bilykigor opened this issue Jul 12, 2023 · 3 comments

@bilykigor

```python
import torch
from src.models.graphs import SetModel
from src.paths import CHECKPOINTS

# `device` and `chunks` come from earlier cells in the notebook
sm = SetModel(name='e2e', device=device)
model = sm.get_model(4, 2, chunks, False)  # 4 and 2 refer to node and edge classes, check the paper for details!
model.load_state_dict(torch.load(CHECKPOINTS / 'e2e-funsd-best.pt', map_location=torch.device('cpu')))  # load the pretrained model
model.eval()  # set the model for inference only
```

```
MODEL
-> Using E2E
-> Total params: 7674914
-> Device: False
```


```
RuntimeError                              Traceback (most recent call last)
Cell In[19], line 7
      5 sm = SetModel(name='e2e', device=device)
      6 model = sm.get_model(4, 2, chunks, False) # 4 and 2 refer to node and edge classes, check the paper for details!
----> 7 model.load_state_dict(torch.load(CHECKPOINTS / 'e2e-funsd-best.pt', map_location=torch.device('cpu'))) # load pretrained model
      8 model.eval() # set the model for inference only

File /Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666 error_msgs.insert(
   1667     0, 'Missing key(s) in state_dict: {}. '.format(
   1668         ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672         self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for E2E:
	Missing key(s) in state_dict: "projector.modalities.3.0.weight", "projector.modalities.3.0.bias", "projector.modalities.3.1.weight", "projector.modalities.3.1.bias", "projector.modalities.4.0.weight", "projector.modalities.4.0.bias", "projector.modalities.4.1.weight", "projector.modalities.4.1.bias", "projector.modalities.5.0.weight", "projector.modalities.5.0.bias", "projector.modalities.5.1.weight", "projector.modalities.5.1.bias".
	size mismatch for projector.modalities.0.0.weight: copying a param with shape torch.Size([300, 4]) from checkpoint, the shape in current model is torch.Size([300, 0]).
	size mismatch for projector.modalities.1.0.weight: copying a param with shape torch.Size([300, 300]) from checkpoint, the shape in current model is torch.Size([300, 0]).
	size mismatch for projector.modalities.2.0.weight: copying a param with shape torch.Size([300, 1448]) from checkpoint, the shape in current model is torch.Size([300, 0]).
	size mismatch for message_passing.linear.weight: copying a param with shape torch.Size([900, 1800]) from checkpoint, the shape in current model is torch.Size([1800, 3600]).
	size mismatch for message_passing.linear.bias: copying a param with shape torch.Size([900]) from checkpoint, the shape in current model is torch.Size([1800]).
	size mismatch for message_passing.lynorm.weight: copying a param with shape torch.Size([900]) from checkpoint, the shape in current model is torch.Size([1800]).
	size mismatch for message_passing.lynorm.bias: copying a param with shape torch.Size([900]) from checkpoint, the shape in current model is torch.Size([1800]).
	size mismatch for edge_pred.W1.weight: copying a param with shape torch.Size([300, 1814]) from checkpoint, the shape in current model is torch.Size([300, 3614]).
	size mismatch for node_pred.0.weight: copying a param with shape torch.Size([4, 900]) from checkpoint, the shape in current model is torch.Size([4, 1800]).
```
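A quick way to see exactly which entries differ is to diff the checkpoint against the freshly built model (a minimal sketch, assuming the `model` and `CHECKPOINTS` objects from the snippet above):

```python
import torch

# Compare the pretrained checkpoint with the model built from the current config.
ckpt = torch.load(CHECKPOINTS / 'e2e-funsd-best.pt', map_location='cpu')
own = model.state_dict()

for key in sorted(set(ckpt) | set(own)):
    if key not in own:
        print(f'only in checkpoint: {key} {tuple(ckpt[key].shape)}')
    elif key not in ckpt:
        print(f'only in model:      {key} {tuple(own[key].shape)}')
    elif ckpt[key].shape != own[key].shape:
        print(f'shape mismatch:     {key} {tuple(ckpt[key].shape)} vs {tuple(own[key].shape)}')
```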

@andreagemelli
Owner

Hi @bilykigor,
Did you check issue #8? It seems similar to me.
If it is not of any help, I will dig deeper.

@bilykigor
Author

Thanks @andreagemelli, that was the same reason.

Fixed it with the following preprocessing.yaml:

```yaml
FEATURES:
  add_embs: true
  add_eweights: true
  add_fudge: false
  add_geom: true
  add_hist: false
  add_visual: true
  num_polar_bins: 8
GRAPHS:
  data_type: img
  edge_type: fully
  node_granularity: gt
LOADER:
  src_data: FUNSD
```
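After re-running the preprocessing with this config and rebuilding the model, the weights load cleanly. A small sanity check (a sketch; `strict=False` just returns the key lists instead of raising, so the assertion makes any leftover mismatch visible):

```python
import torch

# Load the pretrained weights and verify that nothing is missing or unexpected.
state = torch.load(CHECKPOINTS / 'e2e-funsd-best.pt', map_location='cpu')
missing, unexpected = model.load_state_dict(state, strict=False)
assert not missing and not unexpected, (missing, unexpected)
model.eval()
```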

@andreagemelli
Owner

Glad to hear that!
If you want to get rid of the histograms, simply retrain the network. 🤗
A.
