<a href="https://colab.research.google.com/github/LC1332/Luotuo-Chinese-LLM/blob/main/notebook/improvedTranslation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 一个升级后的批量翻译代码

这个代码最初由黄泓森进行开发，由李鲁鲁转到colab并进行了更改

[骆驼项目主页](https://github.com/LC1332/Luotuo-Chinese-LLM)

如果你使用我们的代码获取了有用的数据，也欢迎分享给我们，或者告诉我们你公开后的github/huggingface链接

如果你使用我们的代码获取数据并发表了论文或者tech report，欢迎cite我们的github repo

## 安装环境

In [None]:
# The code below relies on the pre-1.0 openai SDK interface
# (openai.Embedding.acreate, openai.ChatCompletion.acreate, openai.proxy),
# which was removed in openai 1.0, so the version is pinned here.
!pip install "openai<1"
!pip install aiofiles

In [5]:
import os
import json
import time
import openai
import asyncio
import aiohttp
import aiofiles
from functools import partial
from tqdm.asyncio import tqdm as tqdm

## 输入你的 OpenAI API key

In [None]:
#!/usr/bin/env python3



# Change the working directory so that the relative input/output paths
# configured further down ("./captions/...", "./embeddings/...") resolve
# against this directory.
# NOTE(review): this is a hard-coded, machine-specific path — adjust it for
# your own environment (it presumably does not exist on Colab; confirm).
os.chdir("/data/workspace/coco2017")

class KeyPool:
    """Rotate through a set of API keys, least recently used first.

    Each call to :meth:`getKey` returns the key that was handed out longest
    ago (keys that were never used come first, in their original order).

    Attributes:
        pool: The keys as a list, in the order they were supplied.
        last_used: Maps each key to the sequence number of its most recent
            hand-out, or ``-1`` if it has never been handed out.
    """

    def __init__(self, strings):
        """Build the pool from any iterable of keys (lists and generators)."""
        # Materialise once: iterating ``strings`` a second time would yield
        # nothing for one-shot iterables such as generators.
        self.pool = list(strings)
        self.last_used = {s: -1 for s in self.pool}
        # A monotonically increasing counter instead of a wall-clock
        # timestamp: calls made within the same millisecond would otherwise
        # tie, and ``min`` would keep returning the first key.
        self._tick = 0

    def getKey(self):
        """Return the least recently used key and mark it as just used.

        Raises:
            ValueError: If the pool contains no keys.
        """
        if not self.last_used:
            raise ValueError("KeyPool is empty: no API key available")
        result = min(self.last_used, key=self.last_used.get)
        self.last_used[result] = self._tick
        self._tick += 1
        return result


# Pause, in seconds, applied after each finished task in the progress loop.
delay = 0.05

# Upper bound on the number of items being processed at the same time.
concurrency_limit = 64

# JSON file holding the items to translate and embed.
input_file = "./captions/val2017.json"

# Directory and file-name prefix for the per-item result files; the item
# index and ".json" are appended to the prefix.
output_path = "./embeddings/val/"

output_prefix = f"{output_path}val2017_embed_"

entries = ["caption", "caption_zh"]

# SECURITY: API keys must never be committed to source control. They are read
# from the OPENAI_API_KEY environment variable instead; several keys can be
# supplied as a comma-separated list and will be rotated by the pool.
api_key = [
    key.strip()
    for key in os.environ.get("OPENAI_API_KEY", "").split(",")
    if key.strip()
]
if not api_key:
    raise RuntimeError(
        "No OpenAI API key configured: set the OPENAI_API_KEY environment "
        "variable (comma-separated for several keys)."
    )

pool = KeyPool(api_key)

# openai.api_type = "openai"
# openai.api_base = "https://api.openai-proxy.com/"
# NOTE(review): local proxy address is machine-specific — adjust or remove it
# for your own environment.
openai.proxy = "http://127.0.0.1:1450"

async def getEmbedding(item, entries: list = None):
    """Attach OpenAI embeddings for selected text fields of ``item``.

    For every name in ``entries`` the text stored under ``item[name]`` is
    embedded with ``text-embedding-ada-002`` and the vector is stored back on
    the same dict under ``"<name>_embedding"``.

    Args:
        item: Dict holding the text fields; it is modified in place.
        entries: Names of the fields to embed. ``None`` (the default) or an
            empty list means nothing is embedded.

    Returns:
        The same ``item`` dict once every field has been embedded, or ``None``
        as soon as one request fails. In the failure case, fields embedded
        before the failing one have already been written to ``item``.
    """
    async def get(text):
        # Newlines are flattened to spaces before the text is sent.
        text = text.replace("\n", " ")
        openai.api_key = pool.getKey()
        try:
            resp = await openai.Embedding.acreate(
                model="text-embedding-ada-002", input=[text]
            )
            if "data" in resp:
                return resp["data"][0]["embedding"]
            else:
                raise Exception(f"Invalid API response: {resp}")
        except Exception as e:
            # Best effort: report the failure and signal it with None so the
            # caller can skip this item without aborting the whole run.
            print(f"[Error] {e}")
            return None

    # ``None`` replaces the former mutable ``[]`` default.
    for entry in entries or ():
        embedding = await get(item[entry])
        if embedding is None:
            return None
        item[f"{entry}_embedding"] = embedding

    return item


async def getTranslation(item, entries: list = None):
    """Attach Simplified-Chinese translations for selected fields of ``item``.

    For every name in ``entries`` the text stored under ``item[name]`` is sent
    to ``gpt-3.5-turbo`` with a translate-to-Chinese system prompt, and the
    reply is stored back on the same dict under ``"<name>_zh"``.

    Args:
        item: Dict holding the text fields; it is modified in place.
        entries: Names of the fields to translate. ``None`` (the default) or
            an empty list means nothing is translated.

    Returns:
        The same ``item`` dict once every field has been translated, or
        ``None`` as soon as one request fails. In the failure case, fields
        translated before the failing one have already been written to
        ``item``.
    """
    async def get(text):
        # Newlines are flattened to spaces before the text is sent.
        text = text.replace("\n", " ")
        openai.api_key = pool.getKey()
        try:
            resp = await openai.ChatCompletion.acreate(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "这是一个能够将文本翻译成中文的AI助手。请将引号中的文本翻译成简体中文。",
                    },
                    # The text is wrapped in triple quotes, matching the
                    # "text in quotes" wording of the system prompt.
                    {"role": "user", "content": f'"""\n{text}\n"""'},
                ],
                temperature=0.3,
                max_tokens=100,
                top_p=1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0,
            )
            if "choices" in resp:
                return resp['choices'][0]['message']['content']
            else:
                raise Exception(f"Invalid API response: {resp}")
        except Exception as e:
            # Best effort: report the failure and signal it with None so the
            # caller can skip this item without aborting the whole run.
            print(f"[Error] {e}")
            return None

    # ``None`` replaces the former mutable ``[]`` default.
    for entry in entries or ():
        trans = await get(item[entry])
        if trans is None:
            return None
        item[f"{entry}_zh"] = trans

    return item


async def process(id, item, semaphore):
    """Translate, embed and save one item, limited by ``semaphore``.

    The caption is translated first, then both the caption and its
    translation are embedded, and the enriched item is written as JSON to
    ``f"{output_prefix}{id}.json"``.

    Args:
        id: Index of the item; used to build the output file name.
        item: Dict with at least a ``"caption"`` field; it is modified in
            place by the translation and embedding steps.
        semaphore: ``asyncio.Semaphore`` bounding how many items are processed
            at the same time.

    Returns:
        None. Any failure is reported on stdout and the item is skipped, so a
        later run can pick it up again.
    """
    async with semaphore:
        file_name = f"{output_prefix}{id}.json"
        # Write to a sibling temp file first: the resume check in ``main``
        # only looks at whether the final file exists, so a half-written
        # final file would be mistaken for a finished item.
        tmp_name = f"{file_name}.tmp"
        try:
            tr = await getTranslation(item, ["caption"])
            if tr is None:
                raise Exception(file_name)
            it = await getEmbedding(tr, ["caption", "caption_zh"])
            if it is None:
                raise Exception(file_name)
            payload = json.dumps(it, ensure_ascii=False, indent=4)
            # ensure_ascii=False emits raw non-ASCII text, so the encoding is
            # fixed to UTF-8 rather than left to the platform default.
            async with aiofiles.open(tmp_name, "w", encoding="utf-8") as f:
                await f.write(payload)
            # Atomic rename: the final name only ever holds a complete file.
            os.replace(tmp_name, file_name)
        except Exception as e:
            print(f"Error saving item: {e}")


async def main():
    """Process every item of ``input_file`` that has no output file yet.

    Loads the JSON input, schedules one ``process`` task per item whose
    result file is still missing (so an interrupted run can be resumed), and
    awaits the tasks in order while showing a progress bar.

    NOTE(review): the input is assumed to be a JSON array of item dicts,
    since it is walked with ``enumerate`` — confirm against the data file.
    """
    with open(input_file, "r", encoding="utf-8") as file:
        data = json.load(file)

    # Make sure the target directory exists before any task tries to write.
    os.makedirs(os.path.dirname(output_prefix) or ".", exist_ok=True)

    semaphore = asyncio.Semaphore(concurrency_limit)

    # Items whose output file already exists were finished by an earlier run.
    tasks = [
        asyncio.create_task(process(index, item, semaphore))
        for index, item in enumerate(data)
        if not os.path.exists(f"{output_prefix}{index}.json")
    ]

    async for task in tqdm(tasks, total=len(tasks), desc="Processing items"):
        await task
        # Non-blocking pause: time.sleep() here would freeze the event loop
        # and stall every request that is still in flight.
        await asyncio.sleep(delay)


if __name__ == "__main__":
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running (plain script): start one.
        asyncio.run(main())
    else:
        # An event loop is already running (e.g. inside a notebook kernel),
        # where asyncio.run() would raise RuntimeError. Schedule the job on
        # that loop instead; the reference is kept so the task can be
        # awaited or inspected later and is not garbage-collected.
        main_task = asyncio.ensure_future(main())
