
[Bug] internlm2 cannot be converted and quantized with llama.cpp #612

Closed
gaord opened this issue Jan 18, 2024 · 10 comments

@gaord

gaord commented Jan 18, 2024

Describe the bug

Converting the model with the latest llama.cpp code (b1874) fails with the following error:
python3 convert.py ../internlm2-chat-20b --outtype f16
/Users/pom/AIGC/llama.cpp-3/gguf-py
Loading model file ../internlm2-chat-20b/pytorch_model-00001-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00001-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00002-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00003-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00004-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00005-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00006-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00007-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00008-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00009-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00010-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00011-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00012-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00013-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00014-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00015-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00016-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00017-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00018-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00019-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00020-of-00021.bin
Loading model file ../internlm2-chat-20b/pytorch_model-00021-of-00021.bin
Traceback (most recent call last):
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 1658, in
main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv
^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 1577, in main
model_plus = load_some_model(args.model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 1354, in load_some_model
model_plus = merge_multifile_models(models_plus)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 782, in merge_multifile_models
model = merge_sharded([mp.model for mp in models_plus])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 761, in merge_sharded
return {name: convert(name) for name in names}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 761, in
return {name: convert(name) for name in names}
^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 736, in convert
lazy_tensors: list[LazyTensor] = [model[name] for model in models]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pom/AIGC/llama.cpp-3/convert.py", line 736, in
lazy_tensors: list[LazyTensor] = [model[name] for model in models]
~~~~~^^^^^^
KeyError: 'model.tok_embeddings.weight'

Environment

Mac m2 ultra
pytorch-lightning 2.1.0
torch 2.1.2
torchaudio 2.1.0
torchmetrics 1.2.0
torchvision 0.16.0

Other information

No response

@lvhan028
Collaborator

The architecture of InternLM2 differs from InternLM: it adopts GQA and has no attention bias.
Unlike other GQA models, it packs the q, k, v weights into a single tensor.
I don't think llama.cpp supports InternLM2 yet. You may open a feature request in llama.cpp.
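
For reference, here is a minimal sketch of how such a packed wqkv weight could be split back into separate q, k, v projections under GQA. The layout assumed here (per kv head: a slice of query heads, then one key slice, then one value slice) follows the InternLM2 HF modeling code; the function name and shapes are illustrative, so verify against your checkpoint before relying on it:

    import torch

    def split_wqkv(wqkv: torch.Tensor, n_head: int, n_head_kv: int, n_embd: int):
        # wqkv is assumed to have shape (n_head_kv * (n_groups + 2) * head_dim, n_embd)
        head_dim = n_embd // n_head
        n_groups = n_head // n_head_kv  # query heads per kv head (GQA)
        w = wqkv.view(n_head_kv, n_groups + 2, head_dim, n_embd)
        wq = w[:, :n_groups].reshape(n_head * head_dim, n_embd)   # query projection
        wk = w[:, -2].reshape(n_head_kv * head_dim, n_embd)       # key projection
        wv = w[:, -1].reshape(n_head_kv * head_dim, n_embd)       # value projection
        return wq, wk, wv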

Our project LMDeploy supports InternLM2, including 200K context-length inference and 4-bit inference. Offline inference is pretty simple; you may give it a try.

import lmdeploy
pipe = lmdeploy.pipeline("internlm/internlm-chat-7b")
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)

@gaord
Author

gaord commented Jan 19, 2024

Thanks for the explanation. I will look into getting this fixed in llama.cpp.
I have tried LMDeploy, but it does not support the Mac environment, and as I understand the plan it will not support it in the near future either.
Users with Mac devices and MPS compute would be pleased if it were supported, the way FastChat does.
Thank you for the great models and everything provided around them.

@wyklq

wyklq commented Jan 19, 2024

@gaord I fixed the first convert error by simply skipping the non-existent LazyTensors in convert.py:

    def convert(name: str) -> LazyTensor:
    -    lazy_tensors: list[LazyTensor] = [model[name] for model in models]
    +    lazy_tensors: list[LazyTensor] = []
    +    for model in models:
    +        try:
    +            lazy_tensors.append(model[name])
    +        except KeyError:
    +            pass  # skip shards that do not contain this tensor name

Then I hit a second error: the "dynamic" rope scaling type is not supported. I faked it with the YARN type but set rope_finetuned to "False", as the Transformers documentation suggests. I have no idea whether the YARN type actually works here, though.

--- a/convert.py
+++ b/convert.py
@@ -276,6 +276,10 @@ class Params:
                 rope_scaling_type = gguf.RopeScalingType.YARN
                 n_orig_ctx = rope_scaling["original_max_position_embeddings"]
                 rope_finetuned = rope_scaling["finetuned"]
+            elif typ == "dynamic":
+                rope_scaling_type = gguf.RopeScalingType.YARN
+                n_orig_ctx = config.get("max_position_embeddings")
+                rope_finetuned = False

After these two fixes, I believe I hit the unsupported-architecture issue:
params = Params(n_vocab=92544, n_embd=4096, n_layer=32, n_ctx=32768, n_ff=14336, n_head=32, n_head_kv=8, f_norm_eps=1e-05, n_experts=None, n_experts_used=None, rope_scaling_type=<RopeScalingType.YARN: 'yarn'>, f_rope_freq_base=1000000, f_rope_scale=1.0, n_orig_ctx=32768, rope_finetuned=False, ftype=None, path_model=PosixPath('../internlm2-chat-7b'))
Found vocab files: {'tokenizer.model': PosixPath('../internlm2-chat-7b/tokenizer.model'), 'vocab.json': None, 'tokenizer.json': None}
Loading vocab file '../internlm2-chat-7b/tokenizer.model', type 'spm'
Traceback (most recent call last):
File "/home/y20wu/ClueAI/llama.cpp/convert.py", line 1670, in
main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv
File "/home/y20wu/ClueAI/llama.cpp/convert.py", line 1647, in main
model = convert_model_names(model, params)
File "/home/y20wu/ClueAI/llama.cpp/convert.py", line 1293, in convert_model_names
raise Exception(f"Unexpected tensor name: {name}")
Exception: Unexpected tensor name: model.tok_embeddings.weight

@gaoyang07
Collaborator


Use tools/convert2llama.py to convert your keys into Llama format
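
For example (a rough sketch: I am assuming the script takes the source checkpoint directory and an output directory as positional arguments, so check the script's help or source for the exact CLI before running):

    python tools/convert2llama.py ../internlm2-chat-20b ../internlm2-chat-20b-llama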

@limoncc

limoncc commented Jan 19, 2024

Today InternLM released a tool that converts the model to the llama architecture:
https://github.com/InternLM/InternLM/tree/main/tools
I successfully converted the model, but llama.cpp does not support the rope scaling type "dynamic":

python convert.py ./build/models/internlm/target
Loading model file build/models/internlm/target/pytorch_model-00001-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00001-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00002-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00003-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00004-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00005-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00006-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00007-of-00008.bin
Loading model file build/models/internlm/target/pytorch_model-00008-of-00008.bin
Traceback (most recent call last):
  File "/Users/xiaobai/dev/llama.cpp/convert.py", line 1295, in <module>
    main()
  File "/Users/xiaobai/dev/llama.cpp/convert.py", line 1234, in main
    params = Params.load(model_plus)
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xiaobai/dev/llama.cpp/convert.py", line 318, in load
    params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xiaobai/dev/llama.cpp/convert.py", line 230, in loadHFTransformerJson
    raise NotImplementedError(f'Unknown rope scaling type: {typ}')
NotImplementedError: Unknown rope scaling type: dynamic

@gaord
Author

gaord commented Jan 19, 2024

The discussion is here on the other side.

@yhcc
Collaborator

yhcc commented Jan 20, 2024

You can try manually setting the configuration (i.e., the rope_scaling parameter in config.json) to null. After the conversion, the next issue you will encounter is that llama.cpp does not support the \u0000 token, which causes the assert at https://github.com/ggerganov/llama.cpp/blob/77bc1bbd05f0c31cb45773eb5eb59b9ff2b07e1b/llama.cpp#L3005 to fail.

There is a temporary solution:

  1. First, convert to the hf llama format through https://github.com/InternLM/InternLM/tree/main/tools.
  2. Open the config.json file and change rope_scaling to null (a minimal sketch of this step is given at the end of this comment).
  3. Use the tokenizer-related files in the zip file internlm2_llamacpp_tokenizer_fix.zip to replace the contents in the folder [The reason this works is that we used an emoji symbol to replace the \u0000 token].
  4. Then use llama.cpp for the conversion.

To completely solve this issue, it will be necessary to update the code in llama.cpp. We will try to initiate a pull request to fully fix this issue later on.
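
A minimal sketch of step 2, assuming the converted model lives in ./internlm2-chat-20b-llama (the path is illustrative):

    import json
    from pathlib import Path

    cfg_path = Path("./internlm2-chat-20b-llama/config.json")  # illustrative path
    cfg = json.loads(cfg_path.read_text())
    cfg["rope_scaling"] = None  # serialized as null in JSON
    cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False))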



@gaord
Author

gaord commented Jan 22, 2024

Following the workaround hinted at above, I have uploaded the converted GGUF file (internlm2-chat-20b-no-ropescaling-q5_0.gguf) for a fresh try:
https://www.modelscope.cn/models/ruidong/internLM-20b-chat-gguf/files


This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 7 days if the stale label is not removed or if there is no further response.

@github-actions github-actions bot added the Stale label Jan 30, 2024
@ZwwWayne ZwwWayne assigned yhcc and unassigned lvhan028 Jan 30, 2024
@github-actions github-actions bot removed the Stale label Jan 31, 2024
@ZwwWayne
Collaborator

ZwwWayne commented Feb 2, 2024

Resolved in #627 after converting the model to llama format.

@ZwwWayne ZwwWayne closed this as completed Feb 2, 2024