## 加载模型

In [1]:
import os
import ast
import json
# Expose only GPU 0 to this process. Keep this assignment *before* importing
# `client`: CUDA device visibility is generally fixed once the CUDA runtime
# initialises, which presumably happens when `client` (or a library it
# imports) loads the model backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from client import HFClient
# Passed as `max_length` to `client.generate_stream` in the generation cells.
MAX_LENGTH = 4096
# NOTE(review): not referenced by any cell shown here — confirm it is still
# needed (e.g. by a later cell) or drop it.
TRUNCATE_LENGTH = 1024

# Model weights and tokenizer are loaded from the same location
# ("THUDM/chatglm3-6b", presumably a Hugging Face Hub repo id).
MODEL_PATH = TOKENIZER_PATH = "THUDM/chatglm3-6b"
# NOTE(review): the meaning of the third positional argument (None) is defined
# by `client.HFClient`; check its signature before changing it.
client = HFClient(MODEL_PATH, TOKENIZER_PATH, None)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

## Step1, 提问，生成工具调用

In [2]:
from tqdm import tqdm
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
tools = get_tools()

# The conversation starts with a single user turn.
query = '北京天气怎么样？'
history: list[Conversation] = [Conversation(Role.USER, query)]

# First round: stream the model's reply and concatenate the token texts.
# The stop sequences are the string forms of the USER and OBSERVATION roles;
# presumably the model halts on one of them, so a reply that contains a tool
# call ends the round and hands control back to us to run the tool.
out_text = ''
stop_tags = [str(role) for role in (Role.USER, Role.OBSERVATION)]
token_stream = client.generate_stream(
    system=None,
    tools=tools,
    history=history,
    do_sample=True,
    max_length=MAX_LENGTH,
    temperature=1.0,
    top_p=0.95,
    stop_sequences=stop_tags,
)
for chunk in token_stream:
    out_text += chunk.token.text
print(out_text)

[registered tool] {'description': 'Generates a random number x, s.t. range[0] <= x < range[1]',
 'name': 'random_number_generator',
 'params': [{'description': 'The random seed used by the generator',
             'name': 'seed',
             'required': True,
             'type': 'int'},
            {'description': 'The range of the generated numbers',
             'name': 'range',
             'required': True,
             'type': 'tuple[int, int]'}]}
[registered tool] {'description': 'Get the current weather for `city_name`',
 'name': 'get_weather',
 'params': [{'description': 'The name of the city to be queried',
             'name': 'city_name',
             'required': True,
             'type': 'str'}]}
get_weather
 ```python
tool_call(city_name='北京')
```<|observation|>


## Step2, 解析生成的代码

In [3]:
# The first line of the model's reply names the tool; everything after the
# first newline is the call itself (a fenced code block), kept verbatim.
tool, _sep, output_text = out_text.strip().partition('\n')
history.append(Conversation(Role.TOOL, postprocess_text(output_text), tool))
history

[Conversation(role=<Role.USER: 2>, content='北京天气怎么样？', tool=None, image=None),
 Conversation(role=<Role.TOOL: 4>, content="```python\ntool_call(city_name='北京')\n```", tool='get_weather', image=None)]

## Step3, 执行 Function

In [4]:
import re

def extract_code(text: str) -> str:
    """Return the body of the last fenced (```) code block in ``text``.

    A fenced block is an opening ``` followed by an optional info string on
    the same line (e.g. ``python``), a newline, and then the body up to the
    next ```.

    Args:
        text: Model output expected to contain at least one fenced block.

    Returns:
        The body of the last block. It excludes the opening info-string line
        and the closing fence; a newline just before the closing fence is kept.

    Raises:
        ValueError: If ``text`` contains no fenced code block.
    """
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    if not matches:
        # E.g. the model answered in prose instead of emitting a tool call;
        # report that clearly instead of raising an opaque IndexError.
        raise ValueError(f'no fenced code block found in model output: {text!r}')
    # Each match is an (info_string, body) tuple; keep the body of the last one.
    _info, body = matches[-1]
    return body

def tool_call(*args, **kwargs) -> dict:
    """Stand-in for the ``tool_call(...)`` function used by model-generated code.

    Prints a header line, then the positional arguments tuple, then the
    keyword arguments dict. Only the keyword arguments are returned;
    positional arguments are shown but otherwise discarded.

    Returns:
        The keyword arguments, as a dict.
    """
    for item in ("=== Tool call:", args, kwargs):
        print(item)
    return kwargs

code = extract_code(output_text)
print(code)

# `code` is model-generated and therefore untrusted. eval() would execute
# arbitrary Python: a globals dict without '__builtins__' still gets the
# builtins injected, so names like __import__ remain reachable. Instead,
# parse the snippet and accept only a single `tool_call(...)` expression
# whose arguments are plain Python literals.
call = ast.parse(code.strip(), mode='eval').body
if not (isinstance(call, ast.Call)
        and isinstance(call.func, ast.Name)
        and call.func.id == 'tool_call'):
    raise ValueError(f'expected a single tool_call(...) expression, got: {code!r}')
if any(kw.arg is None for kw in call.keywords):
    raise ValueError(f'**kwargs unpacking is not allowed in a tool call: {code!r}')
# Going through tool_call keeps its debug printout; it returns the kwargs dict.
args = tool_call(
    *[ast.literal_eval(arg) for arg in call.args],
    **{kw.arg: ast.literal_eval(kw.value) for kw in call.keywords},
)

from tool_registry import dispatch_tool
observation = dispatch_tool(tool, args)
print(observation)

history.append(Conversation(Role.OBSERVATION, observation))

tool_call(city_name='北京')

=== Tool call:
()
{'city_name': '北京'}
{'lives': {'province': '北京', 'city': '北京市', 'temperature': '0', 'humidity': '51', 'weather': '多云', 'winddirection': '东', 'windpower': '≤3', 'reporttime': '2024-02-24 09:38:27'}}


## Step4, 根据用户问题及API结果，返回最终回答

In [5]:
out_text = ''
# Second round: `history` now also holds the model's tool call and the tool's
# observation, so the model can compose its final answer from them. The
# settings are the same as in the first round.
generation_kwargs = dict(
    system=None,
    tools=tools,
    history=history,
    do_sample=True,
    max_length=MAX_LENGTH,
    temperature=1.0,
    top_p=0.95,
    stop_sequences=[str(role) for role in (Role.USER, Role.OBSERVATION)],
)
for piece in client.generate_stream(**generation_kwargs):
    out_text += piece.token.text
print(out_text)


 北京目前的天气情况为多云，气温为0，相对湿度为51%。这个信息是在2024年2月24日9点38分27秒时更新的。
