<a href="https://colab.research.google.com/github/LC1332/Luotuo-Chinese-LLM/blob/main/notebook/asyncAnswer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 批量回答的代码

代码最初由李鲁鲁开发

[骆驼项目主页](https://github.com/LC1332/Luotuo-Chinese-LLM)

如果你使用我们的代码获取了有用的数据，欢迎分享给我们，或者告诉我们数据公开后的 GitHub / Hugging Face 链接

如果你使用我们的代码获取数据并发表了论文或技术报告（tech report），欢迎引用（cite）我们的 GitHub 仓库

## 安装环境

In [None]:
!pip install openai
!pip install aiofiles
!pip install tiktoken

In [2]:
import os
import json
import time
import openai
import asyncio
import aiohttp
import aiofiles
from functools import partial
from tqdm.asyncio import tqdm as tqdm
import tiktoken

# Tokenizer used to count the tokens of prompts and reference answers.
enc = tiktoken.get_encoding("cl100k_base")
# Multiplier applied to the reference answer's token count to derive the
# max_tokens budget of the generated answer (see getTranslation).
max_zh_en_ratio = 2.3

## 输入你的 OpenAI API key

In [3]:
# Put your OpenAI API token(s) here; the list may hold several keys.
# NOTE(review): avoid committing a real key to a shared notebook.

api_key = ["sk-DfFy"]


class KeyPool:
    """Hand out API keys in least-recently-used order.

    Attributes:
        pool: The keys in the order they were supplied.
        last_used: Maps each key to the stamp of its latest hand-out
            (milliseconds since the epoch, ``-1`` if never used).
    """

    def __init__(self, strings):
        # Materialise once so a one-shot iterable (e.g. a generator) still
        # populates both ``pool`` and ``last_used``.
        self.pool = list(strings)
        self.last_used = {s: -1 for s in self.pool}
        self._last_stamp = -1

    def getKey(self):
        """Return the key that has gone unused for the longest time.

        Raises:
            ValueError: If the pool holds no keys.
        """
        if not self.last_used:
            raise ValueError("KeyPool is empty: provide at least one API key")
        result = min(self.last_used, key=self.last_used.get)
        # Wall-clock milliseconds can repeat when calls arrive in a burst; a
        # tie would make min() keep returning the first key. Forcing the stamp
        # to grow strictly keeps the rotation fair.
        stamp = max(int(time.time() * 1000), self._last_stamp + 1)
        self._last_stamp = stamp
        self.last_used[result] = stamp
        return result

# Shared key pool; request code obtains its API key via pool.getKey().
pool = KeyPool(api_key)

## 指定工作目录



In [4]:
# Work inside /content/ so the relative file names used below resolve there.
os.chdir("/content/")

## 获取需要回答的样本

这里我们使用WizardLM的样本

In [5]:
# Download WizardLM_train10k.jsonl from the LC1332/WizardLM repository into the working directory.
!wget https://raw.githubusercontent.com/LC1332/WizardLM/main/data/WizardLM_train10k.jsonl -O WizardLM_train10k.jsonl

--2023-05-12 23:56:35--  https://raw.githubusercontent.com/LC1332/WizardLM/main/data/WizardLM_train10k.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24924875 (24M) [text/plain]
Saving to: ‘WizardLM_train10k.jsonl’


2023-05-12 23:56:36 (158 MB/s) - ‘WizardLM_train10k.jsonl’ saved [24924875/24924875]



给定一个jsonl文件 WizardLM_train10k.jsonl ，每行是一个json，注意文件中有中文，读入这个文件的前20行并保存到Wizard_demo.jsonl

In [6]:
import json

# Copy the first 20 records of the downloaded file into Wizard_demo.jsonl.
# Iterating over the file (instead of calling next() a fixed number of times)
# tolerates inputs with fewer than 20 lines; blank lines are skipped.
demo_size = 20
data = []
with open('WizardLM_train10k.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        if len(data) >= demo_size:
            break
        if line.strip():
            data.append(json.loads(line))

# ensure_ascii=False keeps Chinese text readable in the output file.
with open('Wizard_demo.jsonl', 'w', encoding='utf-8') as f_out:
    for d in data:
        f_out.write(json.dumps(d, ensure_ascii=False) + '\n')

In [8]:
# Pause in seconds that main() inserts after each finished task.
delay = 0.05 * 2

# Maximum number of items processed at the same time (semaphore size).
concurrency_limit = 16

input_file = "Wizard_demo.jsonl"

# Directory for the per-item cache files
temp_path = "/content/tempAns"

# Directory for the merged output
output_path = "/content/answer"

output_prefix = "WizardLM_Ans"

# Size threshold in bytes after which the merged output rolls over to a new file.
max_file_size = 1024**3

# Fields that are sent to the model as questions
entries = ["instruction_zh"]

# Fields the answers are stored in (parallel to ``entries``)
save_entries = ["output_zh"]

# Fields whose length is used as the reference for the answer length
ref_entries = ["output"]


# Create both directories without going through a shell, so paths containing
# spaces or special characters are handled correctly.
os.makedirs(temp_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

0

In [9]:
import re

async def getTranslation(item, entries: list = None):
    """Ask the chat model every prompt field of ``item`` and store the replies.

    For the i-th name in ``entries`` the text ``item[entries[i]]`` is sent as a
    single user message and the reply is written to ``item[save_entries[i]]``.
    The reply's token budget is derived from ``item[ref_entries[i]]``.

    Args:
        item: Mapping holding the prompt and reference fields; it is updated
            in place.
        entries: Names of the fields to ask about. ``None`` (the default) or
            an empty list means nothing is asked.

    Returns:
        The same ``item`` with the answer fields filled in, or ``None`` as
        soon as one request fails.
    """
    token_budget = 4096  # total tokens allowed for prompt plus completion
    reserve = 100        # safety margin subtracted from that budget

    async def get(text, ans_en):
        """Return the stripped model reply to ``text``, or None on any failure."""
        # text = text.replace("\n", " ")
        openai.api_key = pool.getKey()
        try:
            if not ans_en:
                # No reference answer available: fall back to a fixed budget.
                max_zh_len = 150
            else:
                en_token_len = len(enc.encode(ans_en))
                max_zh_len = max(10, int(max_zh_en_ratio * en_token_len))

            # Kept as an int: max_tokens must be an integer, and the value
            # below may be derived from this count.
            prompt_len = len(enc.encode(text))

            messages = [{'role': 'user', 'content': text}]

            if max_zh_len + prompt_len + reserve > token_budget:
                max_zh_len = token_budget - reserve - prompt_len
                print('shorten answer len into ', max_zh_len, ' with ans len = ', prompt_len)

            if max_zh_len <= 0:
                # The prompt alone uses up the budget; fail locally instead of
                # sending a request that cannot succeed.
                raise ValueError(
                    f"prompt is too long ({prompt_len} tokens) to leave room for an answer"
                )

            resp = await openai.ChatCompletion.acreate(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=0,
                max_tokens=max_zh_len
            )
            if "choices" in resp:
                return resp['choices'][0]['message']['content'].strip()
            raise Exception(f"Invalid API response: {resp}")
        except Exception as e:
            # Best effort: report the failure and let the caller skip the item.
            print(f"[Error] {e}")
            return None

    for i, entry in enumerate(entries or []):
        save_entry = save_entries[i]
        ref_entry = ref_entries[i]

        ans = await get(item[entry], item[ref_entry])
        if ans is None:
            return None
        item[save_entry] = ans
    return item


async def process(id, item, semaphore):
    """Fetch the answer for one item and cache it as a JSON file.

    The result is written to ``{temp_path}/{output_prefix}_{id}.json``.
    Failures are printed and swallowed so that one bad item does not stop
    the batch; an item without a cache file is picked up again on the next run.

    Args:
        id: Index of the item, used in the cache file name.
        item: The record passed to ``getTranslation``.
        semaphore: Limits how many items are processed concurrently.
    """
    async with semaphore:
        file_name = f"{temp_path}/{output_prefix}_{id}.json"
        # Write to a side file first and rename afterwards, so an interrupted
        # write never leaves a truncated file that looks like a finished item.
        partial_name = f"{file_name}.tmp"
        try:
            print('start ', id )
            it = await getTranslation(item, entries)
            if it is None:
                raise Exception(file_name)
            # Explicit UTF-8: the JSON is dumped with ensure_ascii=False and
            # therefore contains raw non-ASCII characters.
            async with aiofiles.open(partial_name, "w", encoding="utf-8") as f:
                await f.write(json.dumps(it, ensure_ascii=False, indent=4))
            os.replace(partial_name, file_name)
            print('done ', id )
        except Exception as e:
            print(f"Error saving item: {e}")


async def main():
    """Load ``input_file`` and process every item that has no cache file yet.

    The input may be a single JSON document or a JSONL file (one JSON value
    per line). Items whose cache file already exists in ``temp_path`` are
    skipped, so calling this again only retries the missing ones.
    """
    with open(input_file, "r", encoding="utf-8") as file:
        try:
            data = json.load(file)
        except json.JSONDecodeError:
            # Not one JSON document: re-read as JSONL, ignoring blank lines.
            file.seek(0)
            data = [json.loads(line) for line in file if line.strip()]

    if isinstance(data, dict):
        # A JSONL file with a single line parses as one object; treat it as a
        # one-item list instead of enumerating its keys.
        data = [data]

    semaphore = asyncio.Semaphore(concurrency_limit)

    tasks = []
    skip_count = 0

    for idx, item in enumerate(data):
        file_name = f"{temp_path}/{output_prefix}_{idx}.json"
        if os.path.exists(file_name):
            skip_count += 1
            continue
        tasks.append(asyncio.create_task(process(idx, item, semaphore)))

    print('skip ', skip_count )
    print('rest ', len(tasks))

    async for task in tqdm(tasks, total=len(tasks), desc="Processing items"):
        await task
        # asyncio.sleep yields to the event loop; time.sleep would freeze
        # every in-flight request for the duration of the pause.
        await asyncio.sleep(delay)

网络问题或 OpenAI 的访问限制可能导致部分数据获取失败，此时脚本会跳过这部分数据

重新运行下面的单元格即可补充获取失败的数据

In [10]:
# Run the pipeline; re-running this cell only retries items that have no cache file yet.
await main()

Processing items:   0%|          | 0/20 [00:00<?, ?it/s]

start  0
start  1
start  2
start  3
start  4
start  5
start  6
start  7
start  8
start  9
start  10
start  11
start  12
start  13
start  14
start  15
done  12
start  16
done  11
start  17
done  13
start  18
done  2
start  19
done  14
done  17
done  6
done  3
done  19
done  8
done  9
done  1
done  10
done  16
done  18


Processing items:  10%|█         | 2/20 [00:43<06:30, 21.69s/it]

done  0


Processing items:  30%|███       | 6/20 [00:50<01:21,  5.84s/it]

done  4
done  15
done  7


Processing items:  35%|███▌      | 7/20 [00:57<01:19,  6.10s/it]

done  5


Processing items: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


## 合并所有回答数据

In [11]:
def _answer_sort_key(filename):
    """Order cache files by their numeric item index; other names sort last."""
    stem = filename[len(output_prefix) + 1:-len(".json")]
    if stem.isdigit():
        return (0, int(stem), filename)
    return (1, 0, filename)


# os.listdir returns names in arbitrary order; sorting by item index makes the
# merged data follow the order of the input file on every run.
answer_files = sorted(
    (
        name for name in os.listdir(temp_path)
        if name.startswith(output_prefix) and name.endswith(".json")
    ),
    key=_answer_sort_key,
)

data = []
for filename in tqdm(answer_files):
    with open(os.path.join(temp_path, filename), 'r', encoding='utf-8') as file:
        try:
            data.append(json.load(file))
        except json.JSONDecodeError as e:
            # Still skipped, but reported so missing answers are noticed.
            print(f"[Warning] skipping unreadable file {filename}: {e}")

100%|██████████| 20/20 [00:00<00:00, 4261.85it/s]


In [12]:
# Write the merged items as JSONL, starting a new numbered file once the
# current one has grown beyond max_file_size bytes.
file_counter = 1
current_file_size = 0
output_file = f"{output_path}/{output_prefix}_{file_counter}.jsonl"

out = open(output_file, 'w', encoding='utf-8')
try:
    for item in tqdm(data):
        if current_file_size > max_file_size:
            # Roll over just before the next write: the finished file is
            # closed explicitly and no empty trailing file is created.
            out.close()
            file_counter += 1
            output_file = f"{output_path}/{output_prefix}_{file_counter}.jsonl"
            out = open(output_file, 'w', encoding='utf-8')
            current_file_size = 0
        item_json = json.dumps(item, ensure_ascii=False)
        out.write(item_json + "\n")
        current_file_size += len(item_json.encode('utf-8'))
finally:
    # Closes whichever file is current, including after an error.
    out.close()

100%|██████████| 20/20 [00:00<00:00, 11826.60it/s]


In [13]:
# Show the path of the most recently opened output file.
print(output_file)

/content/answer/WizardLM_Ans_1.jsonl
