Commit: mlx inference
zRzRzRzRzRzRzR committed Mar 26, 2024
1 parent a1013b1 commit 9e14386
Showing 4 changed files with 63 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -2,3 +2,7 @@
*.pyc
finetune/output/*
wip.*
.idea
venv
.venv
.env
6 changes: 6 additions & 0 deletions README.md
@@ -488,6 +488,12 @@ python demo/vllm_based_demo.py --model_path <vllmcpm_repo_path>
python demo/hf_based_demo.py --model_path <hf_repo_path>
```
#### Use the following commands to launch inference with the Mac MLX acceleration framework
You need to install the `mlx_lm` library and download the corresponding pre-converted model weights [MiniCPM-2B-sft-bf16-llama-format-mlx](https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx). First install the dependency:
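```shell
pip install mlx-lm
```
Then run: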
```shell
python -m mlx_lm.generate --model mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx --prompt "hello, tell me a joke." --trust-remote-code
```
<p id="6"></p>
42 changes: 42 additions & 0 deletions demo/mlx_based_demo.py
@@ -0,0 +1,42 @@
"""
使用 MLX 快速推理 MiniCPM
如果你使用 Mac 设备进行推理,可以直接使用MLX进行推理。
由于 MiniCPM 暂时不支持 mlx 格式转换。您可以下载由 MLX 社群转换好的模型 [MiniCPM-2B-sft-bf16-llama-format-mlx](https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx)。
并安装对应的依赖包
```bash
pip install mlx-lm
```
这是一个简单的推理代码,使用 Mac 设备推理 MiniCPM-2
```python
python -m mlx_lm.generate --model mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx --prompt "hello, tell me a joke." --trust-remote-code
```
"""

from mlx_lm import load, generate
from jinja2 import Template


def chat_with_model():
    # Load the community-converted MiniCPM weights and tokenizer.
    model, tokenizer = load("mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx")
    print("Model loaded. Start chatting! (Type 'quit' to stop)")

    messages = []
    # MiniCPM prompt format: each user turn is wrapped as <用户>...<AI>;
    # assistant turns are appended verbatim.
    chat_template = Template(
        "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}")

    while True:
        user_input = input("You: ")
        if user_input.lower() == 'quit':
            break
        messages.append({"role": "user", "content": user_input})
        # Render the whole conversation so far and generate the next reply.
        response = generate(model, tokenizer, prompt=chat_template.render(messages=messages), verbose=True)
        print("Model:", response)
        messages.append({"role": "ai", "content": response})


if __name__ == "__main__":
    chat_with_model()
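For scripted, single-turn use, here is a minimal sketch calling `load` and `generate` directly, with the same model and MiniCPM prompt format as above (the prompt string is illustrative):

```python
from mlx_lm import load, generate

# Load the community-converted MiniCPM weights and tokenizer.
model, tokenizer = load("mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx")

# A single user turn in MiniCPM's <用户>...<AI> prompt format.
prompt = "<用户>hello, tell me a joke.<AI>"
print(generate(model, tokenizer, prompt=prompt))
```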
11 changes: 11 additions & 0 deletions requirements.txt
@@ -0,0 +1,11 @@
transformers>=4.38.2
torch>=2.0.0
triton>=2.2.0
httpx>=0.27.0
gradio>=4.21.0
flash_attn>=2.4.1
accelerate>=0.28.0
sentence_transformers>=2.6.0
sse_starlette>=2.0.0
tiktoken>=0.6.0
mlx_lm>=0.5.0
