## 1. Load data

In [1]:
import json
# Read the policy corpus from disk; the records used downstream sit under
# the top-level "data" key of the JSON file.
with open('work_policy_data.json', mode='r', encoding='utf-8') as json_file:
    raw_data = json.load(json_file)
data = raw_data['data']

## 2. Generate policy questions

### 2.1 Split documents into text chunks

In [2]:
from tqdm.auto import tqdm
from langchain.docstore.document import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter

def split_doc(doc, chunk_size=300, chunk_overlap=30):
    """Turn one policy record into a list of LangChain text chunks.

    The title, the body and (when the record has a QA page) its question/answer
    pairs are joined into one text, wrapped in a LangChain document and split
    with a ``RecursiveCharacterTextSplitter``.

    Args:
        doc: Policy record (dict). ``doc['metadata']`` must contain ``'title'``
            and is attached unchanged to the document that gets split;
            ``doc['content']`` is the body text. ``doc['T_qa_pairs']`` is an
            optional list of dicts with ``'question'`` and ``'answer'`` keys.
        chunk_size: ``chunk_size`` forwarded to the splitter.
        chunk_overlap: ``chunk_overlap`` forwarded to the splitter.

    Returns:
        The list returned by ``split_documents`` for this single record.
    """
    metadata = doc['metadata']

    # QA pairs are appended only when the record has a non-empty QA url.
    # The check does not depend on the individual pair, so it is done once
    # here instead of once per pair. A missing 'qa_url' or 'T_qa_pairs' key is
    # treated as "no QA pairs" rather than raising KeyError.
    qa_lines = []
    if metadata.get('qa_url', '') != '':
        qa_lines = [
            f"问：{qa['question']} 答：{qa['answer']}"
            for qa in (doc.get('T_qa_pairs') or [])
        ]

    # The trailing "\n" after the body is kept even when there are no QA lines.
    combined_content = f"{metadata['title']}\n{doc['content']}\n" + "\n".join(qa_lines)
    langchain_doc = LangchainDocument(page_content=combined_content, metadata=metadata)

    # Separators are tried in order: paragraph, line, Chinese full stop, space,
    # and finally single characters.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        add_start_index=True,
        separators=["\n\n", "\n", "。", " ", ""],
    )

    return text_splitter.split_documents([langchain_doc])



### 2.2 Load LLM API and prompts

In [3]:
import os
import time
import jwt
import requests

# API key for the GLM service, taken from the GLM_KEY environment variable
# (None when the variable is not set).
ZP_key = os.environ.get("GLM_KEY")
def generate_token(apikey: str, exp_seconds: int):
    """Build the HS256-signed JWT that ``ask_glm`` sends as its Authorization header.

    Args:
        apikey: API key of the form ``"<id>.<secret>"``.
        exp_seconds: Token lifetime in seconds.

    Returns:
        The value returned by ``jwt.encode`` for the signed payload.

    Raises:
        Exception: with args ``("invalid apikey", cause)`` when ``apikey`` cannot
            be split into exactly two dot-separated parts. This includes
            ``None``, e.g. when the environment variable holding the key is unset.
    """
    try:
        key_id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e) from e

    # Read the clock once so that exp - timestamp is exactly the requested
    # lifetime; both fields are expressed in milliseconds.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )

def ask_glm(content, model="glm-3-turbo", timeout=60):
    """Send one user message to the GLM chat-completions endpoint.

    Args:
        content: Text of the single user message.
        model: Model name placed in the request body.
        timeout: Seconds passed to ``requests.post`` as ``timeout`` so a stalled
            connection raises instead of blocking the notebook indefinitely.
            Pass ``None`` to wait without limit.

    Returns:
        The decoded JSON body of the HTTP response. The status code is not
        checked here, so an error reply is returned as-is and callers must
        handle a body without ``'choices'``.
    """
    url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    headers = {
      'Content-Type': 'application/json',
      # A fresh token is generated for every request (lifetime argument: 1000 s).
      'Authorization': generate_token(ZP_key, 1000)
    }

    request_body = {
        "model": model,
        "messages": [{"role": "user", "content": content}]
    }

    response = requests.post(url, headers=headers, json=request_body, timeout=timeout)
    return response.json()

# Prompt drafted with GPT-4 and then merged with our own edits.
# NOTE(review): example_question is not referenced anywhere in the visible code;
# presumably it is a sample of the desired question style. Confirm before removing.
example_question = "本办法的施行时间和有效期限是多久？"
# Template for the QA-generation request. `{context}` is filled in with one text
# chunk via str.format. The template asks the model to reply with the markers
# "事实陈述问题-" and "答案-", which llm_generate_questions splits on.
prompt = """
你的任务是根据上下文写出一个事实问题和答案。
你的事实陈述问题应该用来自上下文的具体、简洁的事实信息来回答。
你的事实陈述问题应该采用与用户在进行政策问答提出的问题相同的风格。
这意味着你的事实问题不能提及“根据段落”或“上下文”之类的内容。

提供你的答案如下：
输出:::
事实陈述问题-(你的事实陈述问题)
答案-(你对事实陈述问题的答案)

这是上下文。

上下文：{context}\n
输出:::
"""



### 2.3 Batch generate QA pairs

In [4]:
import random 
import time 
import pandas as pd

def generate_qa_couples(data, batch_size=5, chunksize=300, chunkoverlap=30, max_generate=15):
    """Generate QA couples for every policy document in ``data``.

    Documents are walked in consecutive batches of ``batch_size``; the final,
    possibly shorter, batch is processed as well. Each document is split with
    ``split_doc`` and handed to ``llm_generate_questions`` together with its
    position inside the batch and the batch's start offset, so the two add up
    to the document's position in ``data``.

    Args:
        data: Sequence of policy records accepted by ``split_doc``.
        batch_size: Number of documents per batch (must be >= 1).
        chunksize: Chunk size forwarded to ``split_doc``.
        chunkoverlap: Chunk overlap forwarded to ``split_doc``.
        max_generate: Upper bound on the QA couples requested per document.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Stepping by batch_size covers the trailing partial batch too.
    for start_index in range(0, len(data), batch_size):
        for index, doc in enumerate(data[start_index:start_index + batch_size]):
            docs_processed = split_doc(doc, chunk_size=chunksize, chunk_overlap=chunkoverlap)

            # Discount the overlap to estimate how many chunks carry distinct
            # content. The estimate truncates to 0 for a single-chunk document,
            # so it is raised to 1, then capped by max_generate and by the
            # number of chunks actually available to sample from.
            estimated = int(len(docs_processed) * (1 - chunkoverlap / chunksize))
            N_Generations = min(max(estimated, 1), max_generate, len(docs_processed))

            llm_generate_questions(docs_processed, N_Generations, index, start_index)


def llm_generate_questions(docs_processed, N_Generations, index, start_index=0):
    """Ask the LLM for QA couples on randomly sampled chunks and save them as CSV.

    Args:
        docs_processed: Chunks of one document; each must expose ``page_content``
            and a ``metadata`` mapping with ``'url'``, ``'topic'``, ``'title'``
            and ``'source'``.
        N_Generations: Number of chunks to sample. Values above the number of
            available chunks are reduced to that number.
        index: Position of the document inside its batch.
        start_index: Offset of the batch; ``index + start_index`` names the file.

    Side effects:
        Writes ``./batchQA_output/<index+start_index>doc.csv`` (the directory is
        created when missing). Chunks whose request or reply cannot be used are
        reported on stdout and skipped.
    """
    # random.sample raises ValueError when asked for more items than exist.
    sample_size = max(0, min(N_Generations, len(docs_processed)))
    print(f"Generating {sample_size} QA couples...")

    outputs = []
    for sampled_context in tqdm(random.sample(docs_processed, sample_size)):
        # Random pause of up to 3 seconds between requests.
        time.sleep(random.random()*3)
        try:
            # The request and the reply lookup are inside the try so that one
            # failed call (network error, error body without 'choices') skips
            # this chunk instead of aborting the whole run.
            reply = ask_glm(prompt.format(context=sampled_context.page_content))
            output_QA_couple = reply['choices'][0]['message']['content']

            # Without the answer marker the question and answer would both be
            # the entire reply, so such replies are rejected.
            if '答案-' not in output_QA_couple:
                raise ValueError("reply does not contain the '答案-' marker")

            question = output_QA_couple.split('事实陈述问题-')[-1].split('答案-')[0].strip()
            answer = output_QA_couple.split('答案-')[-1].strip()
            if not question or not answer:
                raise ValueError("reply has an empty question or answer")

            record = {
                "context": sampled_context.page_content,
                "question": question,
                "answer": answer,
                "source_url": sampled_context.metadata['url'],
                "category": sampled_context.metadata['topic'],
                "title": sampled_context.metadata['title'],
                "org": sampled_context.metadata['source']
            }
        except Exception as exc:
            # Best effort: report what went wrong and move on to the next chunk.
            print(f"Skipping one chunk: {exc!r}")
            continue
        outputs.append(record)

    # Create the target directory up front so the results of the API calls
    # above are not lost to a missing folder.
    output_dir = './batchQA_output'
    os.makedirs(output_dir, exist_ok=True)

    df = pd.DataFrame(outputs)
    df.to_csv(f'{output_dir}/{index+start_index}doc.csv', index=False, encoding='utf-8-sig')