position_idx missing in state_dict when loading from checkpoint #12

Closed
gabriele-bani opened this issue Sep 12, 2020 · 2 comments


gabriele-bani commented Sep 12, 2020

When using ExtractiveSummarizer.load_from_checkpoint or ExtractiveSummarizer.load_weights to load most of the models, I find that the position_ids field is not saved in the checkpoint file, which causes an error. The only model that loads correctly is distilbert-base-uncased-ext-sum.

This is the error I get when running predictions_website.py and trying to use bert-base-uncased-ext-sum checkpoints.


Exception happened during processing of request from ('127.0.0.1', 34450)
Traceback (most recent call last):
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/socketserver.py", line 320, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/socketserver.py", line 351, in process_request
    self.finish_request(request, client_address)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/socketserver.py", line 364, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/socketserver.py", line 724, in __init__
    self.handle()
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/http/server.py", line 418, in handle
    self.handle_one_request()
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/http/server.py", line 406, in handle_one_request
    method()
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/site-packages/gradio/networking.py", line 158, in do_POST
    prediction, durations = interface.process(raw_input)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/site-packages/gradio/interface.py", line 220, in process
    prediction = predict_fn(*processed_input)
  File "/home/myuser/intrical/repos/Intrical-Transformers/predictions_website.py", line 11, in summarize_text
    summarizer = ExtractiveSummarizer.load_from_checkpoint(model_choice)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 192, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/home/myuser/anaconda3/envs/transformersum/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ExtractiveSummarizer:
	Missing key(s) in state_dict: "word_embedding_model.embeddings.position_ids". 
----------------------------------------

How can I correctly load any type of model?

Owner

HHousen commented Sep 13, 2020

@gabriele-bani Good catch. This is a problem with the update from 3.0.2 to 3.1.0 of huggingface/transformers. The problem is being discussed at huggingface/transformers#6882. While they work on a fix, you can install the previous version of huggingface/transformers by running pip install -U transformers==3.0.2. Let me know if this solves the problem for you.

@gabriele-bani
Author

@HHousen Thank you for your answer! Indeed, with 3.0.2 the checkpoints load correctly, so I'll be using that for now.
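
For anyone who wants to confirm which keys a checkpoint lacks before downgrading, here is a minimal sketch of the comparison that a strict load performs. It is plain Python with no torch dependency. The helper name diff_state_dict_keys and the two key lists are hypothetical and only for illustration; in practice the lists would presumably come from the model's state_dict() keys and from the 'state_dict' entry of the loaded checkpoint, as seen in the traceback above.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring what a strict load reports.

    missing:    keys the model expects but the checkpoint does not contain
    unexpected: keys the checkpoint contains but the model does not expect
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key lists for illustration only. With real objects these would
# be list(model.state_dict().keys()) and list(checkpoint["state_dict"].keys()).
model_keys = [
    "word_embedding_model.embeddings.position_ids",
    "word_embedding_model.embeddings.word_embeddings.weight",
]
checkpoint_keys = [
    "word_embedding_model.embeddings.word_embeddings.weight",
]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If position_ids is the only missing key, the traceback suggests another possible route: load_from_checkpoint forwards a strict argument to load_state_dict, so passing strict=False may let the load proceed without downgrading. That is an untested assumption here, not something confirmed in this thread.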
