In [1]:
import json
from transformers import AutoTokenizer

import json
import ast
import astunparse
from transformers import PreTrainedTokenizer
from torch.utils.data import Dataset
from copy import deepcopy
from typing import Dict, List

# text constants
FUNCTION_CALL_NAME = 'tool_call'
FUNCTION_CALL_PREFIX = '```python\n'
FUNCTION_CALL_POSTFIX = '\n```'
TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n'
CONVERSATOIN_KEY = 'conversations'
TOOL_DESC_KEY = 'tools'


def format_function_call(function_name: str, parameters: Dict[str, str]):
    function_name = ast.Name(id=function_name)
    keywords = [
        ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
        for arg_name, arg_value in parameters.items()
    ]
    func_call = ast.Call(func=function_name, args=[], keywords=keywords)
    return astunparse.unparse(func_call).strip()


def format_conversation(item, tokenizer, conversation_key: str, tool_key: str):
    conversations = deepcopy(item[conversation_key])

    # Note: `loss_mask` here means whether *the prediction* of the token should take loss
    tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]

    def _update(_tokens: List[int], value: int = 1):
        value = int(value)
        tokens.extend(_tokens)
        loss_masks.extend([value] * len(_tokens))

    # insert system prompt for tools
    if tool_key in item:
        conversations.insert(0,
                             {
                                 "role": "system",
                                 "content": TOOL_DEFINITION_PREFIX + json.dumps(item[tool_key], indent=4,
                                                                                ensure_ascii=False)
                             }
                             )

    for idx, conv in enumerate(conversations):
        loss = conv.get("loss", True)
        if conv['role'] in {'system', 'user'}:
            loss = False
        if conv['role'] == 'tool':
            # function call python code
            value = FUNCTION_CALL_PREFIX + format_function_call(FUNCTION_CALL_NAME,
                                                                conv["parameters"]) + FUNCTION_CALL_POSTFIX
            text = tokenizer.build_single_message("assistant", conv["name"], value)
            _update(text, loss)

            # function call result
            value = conv.get('observation', None)
            if not isinstance(value, str):
                value = json.dumps(value, ensure_ascii=False)
            text = tokenizer.build_single_message("observation", "", value)
            _update(text, False)
        else:
            text = tokenizer.build_single_message(conv['role'], "", conv["content"])
            _update(text, loss)

    _update([tokenizer.eos_token_id], False)

    assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
    return tokens, loss_masks

In [2]:
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)


In [3]:
with open("./formatted_data/tool_alpaca.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f]

In [8]:
tokens, loss_masks = format_conversation(data[0], tokenizer, conversation_key=CONVERSATOIN_KEY, tool_key=TOOL_DESC_KEY)

In [7]:
print(tokenizer.decode(tokens))

[gMASK]sop<|system|> 
 Answer the following questions as best as you can. You have access to the following tools:
[
    "sendHttpRequest: Send an HTTP request with the specified method, headers, and data to the Httpbin API for testing purposes.\nParameters: {\"method\": \"Required. string. One of: [GET, POST, PUT, DELETE, HEAD, PATCH]. The HTTP method to use (GET, POST, PUT, DELETE, HEAD, or PATCH).\", \"url\": \"Required. string. The endpoint URL to send the request to.\", \"headers\": \"Object.  A key-value pair of headers to include in the request.\", \"data\": \"Object.  A key-value pair of data to include in the request body.\"}\nOutput: Successful response.\n - Format: application/json\n - Structure: Object{response: Object{status_code, headers: Object, body}}\ngetClientRequestData: Retrieve the client's request data, including headers, form data, uploaded files, and cookies.\nParameters: {\"url\": \"Required. string. The endpoint URL to send the request to.\"}\nOutput: Successfu

In [9]:
loss_masks

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
