
size mismatch for loading model #3

Open

marcos452 opened this issue May 2, 2024 · 1 comment

marcos452 commented May 2, 2024

Thanks for your great work.

I am trying to load the large model ./icassp_sasb_ckpts/SpeechCLIP+/large/flickr/cascaded/model.ckpt with example.py (loading the base model works without error), but it fails with the following error:

```
Using cache found in /home/marco/.cache/torch/hub/s3prl_cache/4a54d64fa42b41e39db994c958d8107d5785a100f38c6eba680b6a3cc79babb3
for https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt
WARNING:avssl.module.clip_official:Reduce text embedding to size of 8112
Traceback (most recent call last):
  File "/home/marco/Documents/human-gesture-generation/Bechmark/SpeechCLIP_plus/example.py", line 10, in <module>
    model = avssl.model.KWClip_GeneralTransformer.load_from_checkpoint(model_fp)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 204, in _load_model_state
    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for KWClip_GeneralTransformer:
    size mismatch for criterion.eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for criterion.neg_eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for criterion.eye_mat_fl: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
```
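For reference, the mismatched shapes can be confirmed straight from the checkpoint file, without constructing the model (a minimal sketch; the path is the large checkpoint above):

```python
import torch

# Load only the checkpoint dict and print the shapes stored for the mismatched keys
ckpt = torch.load(
    "./icassp_sasb_ckpts/SpeechCLIP+/large/flickr/cascaded/model.ckpt",
    map_location="cpu",
)
for name in ("criterion.eye_mat", "criterion.neg_eye_mat", "criterion.eye_mat_fl"):
    # The traceback above reports (1024, 1024) here, vs. (256, 256) in the freshly built model
    print(name, tuple(ckpt["state_dict"][name].shape))
```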

Any insights or suggestions you can provide would be greatly appreciated.

Thank you!

ShampooWang (Owner) commented May 4, 2024

Hi,

In avssl/module/losses.py, line 126, there is a variable called MAX_EYE that must be modified by hand to match the size of the model you load: MAX_EYE = 256 for the base models and MAX_EYE = 1024 for the large models. Thanks!
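Concretely, the edit looks like this (a sketch; it assumes MAX_EYE is a module-level constant in avssl/module/losses.py, as described above):

```python
# avssl/module/losses.py, around line 126
MAX_EYE = 1024  # 256 for the base models, 1024 for the large models
```

With MAX_EYE = 1024, the model builds its criterion.eye_mat buffers at (1024, 1024), so load_from_checkpoint on the large checkpoint no longer hits the size mismatch.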
