Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: support PEFT/LoRA with added tokens #1828

Merged
merged 6 commits into from May 19, 2024

Conversation

mapmeld
Copy link
Contributor

@mapmeld mapmeld commented May 12, 2024

PEFT supports adding tokens/embeddings: https://github.com/huggingface/peft/blob/main/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb

When loading a pretrained model, tokenizer, and PEFT including new tokens, the model fails because the model and adapter have a different number of embeddings:

! lm_eval --model hf \
    --model_args pretrained=gradientai/Llama-3-8B-Instruct-262k,
       tokenizer=monsoon-nlp/llama3-biotokenpretrain-kaniwa,
       load_in_4bit=True,
       peft=monsoon-nlp/llama3-biotokenpretrain-kaniwa \
    --tasks hellaswag \
    --device cuda:0 \
    --batch_size 1
  File "/content/lm-evaluation-harness/lm_eval/models/huggingface.py", line 204, in __init__
    self._create_model(
  File "/content/lm-evaluation-harness/lm_eval/models/huggingface.py", line 582, in _create_model
    self._model = PeftModel.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 356, in from_pretrained
    model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 730, in load_adapter
    load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
  File "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py", line 249, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.embed_tokens.modules_to_save.default.weight: copying a param with shape torch.Size([128260, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
	size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([128260, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).

In this case we want to run self._model.resize_token_embeddings(len(self.tokenizer))

The core fix is:

if peft:
  if self._model.config.vocab_size != len(self.tokenizer):
    # resize model for LoRAs with added tokens
    self._model.resize_token_embeddings(len(self.tokenizer))
  self._model = PeftModel.from_pretrained( ...)

The other change is moving initialization of self.tokenizer earlier in the script, so we have access to the tokenizer at this point in time.

CoLab proof of concept:
https://colab.research.google.com/drive/12NaNeDHRCMVhyIHL4_Z6AZQq6q_E2cP8

Copy link
Contributor

@haileyschoelkopf haileyschoelkopf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much for the fix @mapmeld !

Added a log message for when this resizing happens, but otherwise LGTM!

lm_eval/models/huggingface.py Show resolved Hide resolved
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
@haileyschoelkopf haileyschoelkopf merged commit 86319a9 into EleutherAI:main May 19, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants