
Missing key(s) in state_dict when testing using predict_downstream_condition.py #17

Open · wangpichao opened this issue Mar 14, 2023 · 7 comments

@wangpichao

python predict_downstream_condition.py --ckpt_path model_name_roberta-base_taskname_qqp_lr_3e-05_seed_42_numsteps_2000_sample_Categorical_schedule_mutual_hybridlambda_0.0003_wordfreqlambda_0.0_fromscratch_False_timestep_none_ckpts/best(38899).th
using standard schedule with num_steps: 2000.
Traceback (most recent call last):
File "predict_downstream_condition.py", line 101, in
model.load_state_dict(ckpt['model'])
File "/opt/conda/envs/diff/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1672, in load_state_dict
self.class.name, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RobertaForMaskedLM:
Missing key(s) in state_dict: "roberta.embeddings.position_ids", "roberta.embeddings.word_embeddings.weight", "roberta.embeddings.position_embeddings.weight", "roberta.embeddings.token_type_embeddings.weight", "roberta.embeddings.LayerNorm.weight", "roberta.embeddings.LayerNorm.bias", "roberta.encoder.layer.0.attention.self.query.weight", "roberta.encoder.layer.0.attention.self.query.bias".........................

@Hzfinfdu (Owner)

Hi,

Did you train the model with DDP? If so, the state dict keys may be different.
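
For context: DDP wraps the model and registers it under the attribute module, so every state_dict key gains a module. prefix. A minimal standalone sketch (using DataParallel, which exhibits the same prefixing and runs without a process group):

from torch import nn

model = nn.Linear(4, 2)
print(list(model.state_dict().keys()))    # ['weight', 'bias']

# Wrapping registers the original model as the submodule "module",
# so every parameter key gains a "module." prefix.
wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

A checkpoint saved from the wrapped model therefore fails to load into a bare RobertaForMaskedLM; save wrapped.module.state_dict() instead, or strip the prefix at load time.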

@bansky-cl commented May 29, 2023

I met this problem too.
After running DDP_main_conditional.sh to get the checkpoint (xxx.th), running predict_downstream_condition.py raises the same error.

RuntimeError: Error(s) in loading state_dict for RobertaForMaskedLM:
Missing key(s) in state_dict: ....sd1
Unexpected key(s) in state_dict: .....sd2

The sd2 keys look like module.roberta.... while the sd1 keys look like roberta...., so they differ by a module. prefix.

It runs when I change model.load_state_dict(ckpt['model']) to model.load_state_dict(ckpt['model'], strict=False), but I worry this may hurt model performance to some extent.
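
Note: with a module. prefix on every checkpoint key, strict=False matches nothing and silently loads no weights at all. A minimal sketch to check what actually loaded:

result = model.load_state_dict(ckpt['model'], strict=False)
# load_state_dict returns the keys it could not match; if missing_keys
# spans the whole model, no checkpoint weights were loaded.
print(result.missing_keys, result.unexpected_keys)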

I'd also like to know how long you trained and on what device. I find the model often gets stuck at each eval_step while saving the checkpoint, which takes most of the time when I train the model.

@xiang-xiang-zhu

same problem


@Hyunseung-Kim commented Dec 29, 2023

Thank you for your great work.

I have the same problem.

As explained in this GitHub repo, I executed run.sh, which in turn executed DDP_main.py.
After fixing several errors, the Missing key(s) in state_dict error occurred.
I cannot be sure whether my modifications to fix those errors caused this issue.

However, I'm confused: @Hzfinfdu said the state dict keys from DDP may be different.
Then which file should be executed to reproduce DiffusionBERT?

Thank you.


@xiang-xiang-zhu

I found that this was because the keys in the checkpoint saved after training carry an extra module. prefix, so removing it fixes the mismatch.
In predict_downstream_condition.py at line 101:

model.load_state_dict(ckpt['model'])

change to

ckpt_model = ckpt['model']
new_ckpt = {}
for key, value in ckpt_model.items():
    # DDP prefixes every key with "module." (7 characters); strip it so the
    # keys match the bare RobertaForMaskedLM.
    new_key = key[7:] if key.startswith('module.') else key
    new_ckpt[new_key] = value
model.load_state_dict(new_ckpt)

I don't know if this is the intended fix, but the program runs and outputs results correctly.
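
An alternative sketch (assuming the training script saves a dict with a 'model' key, as the load code suggests) is to avoid the prefix at save time by unwrapping the DDP model before saving:

import torch

# model here is the (possibly DDP-wrapped) training-time model; unwrap it
# before saving so the checkpoint keys carry no "module." prefix.
state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
torch.save({'model': state}, 'best.th')

The path 'best.th' is a placeholder for wherever the training script writes its checkpoint.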

@Hyunseung-Kim commented Jan 2, 2024

Thank you for your reply. I solved the issue with your suggestion.
However, did you solve issue #25? It seems to be the last error to fix before predict.py can be used.

Or could you share your modified code? It would be really helpful for me!
