In [None]:
tokenizer_path = 'Qwen/Qwen2.5-7B-Instruct'
# Should be the path to the output of second_pass_gen
rendered_prompts_path = './rendered_prompts.json'

In [1]:
from transformers import Qwen2TokenizerFast
import regex
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from datasets import Dataset
import copy
import torch

In [2]:
formatted = json.load(open(rendered_prompts_path, "r"))

In [3]:
tokenizer = Qwen2TokenizerFast.from_pretrained(tokenizer_path)

In [4]:
max_seq_len = 2048

In [5]:
splitted_by = "<|im_end|>\n"
splitted = (formatted[0] + "\n").split(splitted_by)

In [6]:
splitted

['<|im_start|>system\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{"type": "function", "function": {"name": "RetrievalAction", "description": "The action of retrievaling from the knowledge base.", "parameters": {"properties": {"reasoning": {"description": "The reasong why you are doing the following query.", "type": "string"}, "queries": {"description": "The list of queries to execute. Each should be a question.", "items": {"type": "string"}, "type": "array"}}, "required": ["reasoning", "queries"], "type": "object"}}}\n{"type": "function", "function": {"name": "AskAction", "description": "The action of asking the patient for extra information.", "parameters": {"properties": {"reasoning": {"description": "The reason why you are asking the following question.", "type": "string"}, "disease": {"description": "你提问的疾病对象", "items": {"type": "string"}, "type": "array"}, "s

In [7]:
def tokenize_one(example: str):
    if not example.endswith("\n"):
        example = example + "\n"
    splitted_by = "<|im_end|>\n"
    splitted = example.split(splitted_by)
    input_ids = []
    labels = []
    attention_mask = []
    ret = []
    for idx,each in enumerate(splitted):
        if each.startswith("<|im_start|>system") or each.startswith("<|im_start|>user"):
            full = each + splitted_by
            this_input_ids = tokenizer.encode(full)
            labels.extend([-100] * len(tokenizer.encode(full)))
            input_ids.extend(this_input_ids)
            attention_mask.extend([1] * len(this_input_ids))
            continue
        if each.startswith("<|im_start|>assistant"):
            full = each
            if idx != len(splitted) - 1:
                full = each + splitted_by
            this_input_ids = tokenizer.encode("\n" + full)
            input_ids_cp = copy.deepcopy(input_ids)[:-1]
            labels_cp = copy.deepcopy(labels)[:-1]
            attention_mask_cp = copy.deepcopy(attention_mask)[:-1]
            input_ids_cp.extend(this_input_ids)
            labels_cp.extend(this_input_ids)
            attention_mask_cp.extend([1] * len(this_input_ids))
            input_ids_cp.extend([tokenizer.pad_token_id])
            labels_cp.extend([tokenizer.pad_token_id])
            attention_mask_cp.extend([1])
            non_pad_len = len(input_ids_cp)
            if non_pad_len > max_seq_len:
                print("OUT OF LENGTH")
                continue
            
            ret.append({
                "input_ids": input_ids_cp + [tokenizer.pad_token_id] * (max_seq_len - non_pad_len),
                "labels": labels_cp + [-100] * (max_seq_len - non_pad_len),
                "attention_mask": attention_mask_cp + [0] * (max_seq_len - non_pad_len),
            })
            
            # ret.append({
            #     "input_ids": torch.tensor(input_ids_cp + [tokenizer.pad_token_id] * (max_seq_len - non_pad_len), dtype=torch.long),
            #     "labels": torch.tensor(labels_cp + [-100] * (max_seq_len - non_pad_len), dtype=torch.long),
            #     "attention_mask": torch.tensor(attention_mask_cp + [0] * (max_seq_len - non_pad_len), dtype=torch.long),
            # })

            input_ids.extend(this_input_ids)
            labels.extend([-100] * len(this_input_ids))
            attention_mask.extend([1] * len(this_input_ids))
            continue
        if each == '':
            continue
        raise Exception("SHOULD NOT BE HERE")
    return ret

In [8]:
r = []

In [9]:
with ProcessPoolExecutor(max_workers=16) as executor:
    futures = []
    for each in tqdm(formatted):
        futures.append(executor.submit(tokenize_one, each))
    for future in tqdm(futures):
        r.extend(future.result())

100%|██████████| 65017/65017 [00:03<00:00, 18330.29it/s]
 87%|████████▋ | 56386/65017 [01:10<00:10, 829.21it/s] 

OUT OF LENGTH


100%|██████████| 65017/65017 [01:39<00:00, 652.08it/s] 


In [10]:
ds = Dataset.from_list(r)

In [11]:
ds.save_to_disk("./train")

Saving the dataset (0/9 shards):   0%|          | 0/151203 [00:00<?, ? examples/s]

In [12]:
print(tokenizer.decode(
    [e for e in r[9]['labels'] if e != -100]
))


<|im_start|>assistant
<tool_call>
{"name": "RetrievalAction", "arguments": {"reasoning": "患者已经经历过冷冻治疗，但表示疼痛且影响行动，因此希望了解其他治疗方法。同时，患者提到跖疣复发且加重，询问原因有助于理解病情。", "queries": ["跖疣的其他治疗方法有哪些？", "跖疣复发加重的原因是什么？"]}}
</tool_call><|im_end|>
<|endoftext|>


In [13]:
len(r)

151203