In [None]:
from zhipuai import ZhipuAI
Access_Token = 'Your Team Account'  # Competition team token, used to access the competition database
MODEL_sql = "glm-4-plus"  # Alternatives noted by the author: "glm-4-plus", "glm-4-flashx"
MODEL_rag = "glm-4-plus"  # Alternatives noted by the author: "glm-4-plus", "glm-4-flashx"
client = ZhipuAI(api_key='Your ZhipuAI API_KEY')

# 1. 工具函数

## 1.1 LLM生成数据

In [None]:
from concurrent.futures import ThreadPoolExecutor
import queue


def create_chat_completion(messages, model):
    """
    Send one non-streaming chat request through the module-level ``client``.

    Parameters:
        messages (list): Message dictionaries forwarded unchanged to the endpoint.
        model (str): Name of the model that should answer.

    Returns:
        The object returned by ``client.chat.completions.create``.
    """
    request_options = {
        "model": model,
        "stream": False,
        "messages": messages,
        "temperature": 0.5,
    }
    return client.chat.completions.create(**request_options)

def _threaded(func):
    """
    Wrap ``func`` so that its outcome is reported through a queue.

    The wrapper accepts two extra keyword-only arguments, ``thread_id`` and
    ``result_queue``. It calls ``func`` with the remaining arguments and puts
    ``(thread_id, result)`` on the queue. If anything raises, the error is
    printed and ``(thread_id, None)`` is queued instead, so every call leaves
    exactly one entry behind.

    Args:
        func (Callable): The function to be wrapped.

    Returns:
        Callable: The queue-reporting wrapper.
    """
    def run_and_report(*args, thread_id, result_queue, **kwargs):
        try:
            payload = (thread_id, func(*args, **kwargs))
            result_queue.put(payload)
        except Exception as err:
            print(f"Exception in thread with kwargs: {kwargs}\n{err}")
            result_queue.put((thread_id, None))

    return run_and_report

def async_llm_chain_call(
    messages,
    model,
    sampling_count = 1,
) :
    """
    Sample the model ``sampling_count`` times in parallel and return one response.

    Each sample is produced by ``create_chat_completion`` running in a worker
    thread. With a single sample the response (or ``None`` if the call failed)
    is returned directly. With several samples, the ones that carry message
    content are numbered 0..k-1 and handed to ``get_best``; the response it
    picks is returned.

    Parameters:
        messages (list): Chat messages sent with every sample.
        model (str): Model name used both for sampling and for the selection call.
        sampling_count (int): Number of parallel samples to draw.

    Returns:
        The chosen response object, or ``None`` when no sample produced a
        usable response.
    """
    result_queue = queue.Queue()
    worker = _threaded(create_chat_completion)
    with ThreadPoolExecutor(max_workers=20) as executor:
        for idx in range(sampling_count):
            executor.submit(
                worker,
                thread_id=idx,
                result_queue=result_queue,
                messages=messages,
                model=model,
            )

    # Leaving the ``with`` block waits for every worker, so the queue is complete here.
    results = []
    while not result_queue.empty():
        results.append(result_queue.get())

    # Restore submission order; workers finish in arbitrary order.
    results.sort(key=lambda item: item[0])
    sorted_results = [response for _, response in results]

    if not sorted_results:
        return None
    if len(sorted_results) == 1:
        return sorted_results[0]

    # Keep only samples that really carry text. Failed calls are queued as None
    # and must neither be shown to the judge nor be returned as the "best" one.
    candidates = []
    for response in sorted_results:
        try:
            content = response.choices[0].message.content
        except (AttributeError, IndexError, TypeError):
            continue
        candidates.append((response, content))

    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0][0]

    # Number the candidates contiguously so that the label shown to the judge
    # is exactly the index used to pick the response afterwards.
    labelled = {label: [content] for label, (_, content) in enumerate(candidates)}
    best = get_best(labelled, messages, model)
    if not isinstance(best, int) or not 0 <= best < len(candidates):
        best = 0

    return candidates[best][0]

def get_best(sorted_results, context, model):
    """
    Ask the model which candidate reply is best and return its label.

    Parameters:
        sorted_results (dict or sequence): Candidate replies. For a dict the
            keys are the labels shown to the model; for any other sized
            container the labels are the positions 0..len-1.
        context: Conversation context, embedded in the prompt via ``str``.
        model (str): Model used for the selection call.

    Returns:
        The first number in the model's answer that is a valid label. If the
        answer names no valid label, the first label is returned (0 when there
        are no candidates at all).
    """
    import re  # local import: keeps the function independent of cell execution order

    prompt = f"""Instructions:
--------------
根据对话上下文，请从下面的回复中选择一个最好的回复，只输出回复的标号
--------------
这是上下文：{str(context)}
---------------
这是回复：
--------------
{str(sorted_results)}
--------------

请仅回复标号："""
    messages = [{"role": "user", "content": prompt}]
    response = create_chat_completion(messages, model)
    answer = response.choices[0].message.content

    if isinstance(sorted_results, dict):
        labels = list(sorted_results)
    else:
        labels = list(range(len(sorted_results)))

    # Compare whole numbers, not substrings: a substring test would accept
    # label 1 for the answer "10" and could never select a label >= len().
    for token in re.findall(r"\d+", str(answer)):
        number = int(token)
        if number in labels:
            return number
    return labels[0] if labels else 0

## 1.3 找到LLM回复中的json数据

In [None]:
import re
import json

def _extract_json_candidate(text):
    """Return the part of ``text`` most likely to hold JSON, or None if there is none."""
    fenced = re.search(r"```json\n(.*?)\n```", text, re.DOTALL)
    if fenced:
        return fenced.group(1)
    start = text.find("{")
    if start == -1:
        return None
    end = text.rfind("}")
    # Take the widest brace span so nested objects are not cut at the first "}".
    return text[start:end + 1] if end > start else text[start:]


def _log_json_event(message):
    """Print ``message`` and append it to log.txt."""
    print(message)
    with open("log.txt", "a", encoding="utf-8") as f:
        f.write(message + "\n")


def find_json(text):
    """
    Extract and parse the first JSON value found in ``text``.

    Two sources are tried, in this order:
      1. A Markdown code block of the form ```json ... ```
      2. The span from the first "{" to the last "}"

    The first complete JSON value at the start of that span is decoded; text
    after it is ignored. If decoding fails, the span and the error are sent to
    ``fix_json`` and the reply is parsed again, for up to three attempts in
    total. Every failure is printed and appended to log.txt.

    Parameters:
        text (str): The input text from which to extract JSON.

    Returns:
        The parsed JSON value if successful, otherwise the text that was
        originally passed in (never an intermediate repair reply).
    """
    if not isinstance(text, str):
        return text

    original_text = text
    max_attempts = 3
    decoder = json.JSONDecoder()

    for attempt in range(1, max_attempts + 1):
        candidate = _extract_json_candidate(text)
        if candidate is None:
            # Nothing JSON-like is present; retrying on the same text cannot help.
            _log_json_event("No matching JSON string found. Returning original text.")
            return original_text

        try:
            # raw_decode stops at the end of the first value, so trailing prose is tolerated.
            data, _ = decoder.raw_decode(candidate.lstrip())
            return data
        except json.JSONDecodeError as e:
            if attempt == max_attempts:
                _log_json_event(
                    f"All {max_attempts} attempts to parse JSON failed. Returning original text."
                )
                break
            _log_json_event(f"Attempt {attempt}: Failed to parse JSON, reason: {e}. Retrying...")
            text = fix_json(candidate, e, model=MODEL_rag)

    return original_text

def fix_json(text, json_error, model):
    """
    Ask the model to repair an invalid JSON string.

    Parameters:
        text (str): The JSON string that failed to parse.
        json_error: The parse error, embedded in the prompt as text.
        model (str): Model used for the repair request.

    Returns:
        str: The content of the model's reply. The prompt asks for the JSON
        to be wrapped in a ```json ... ``` block.
    """
    repair_prompt = f"""Instructions:
--------------
请修复JSON字符串，使其成为有效的JSON。
--------------

下面是原始的JSON字符串：
--------------
{text}
--------------
下面是的错误信息：
--------------
{json_error}
--------------

请仅回复json，用```json ... ```包裹json字符串："""

    reply = create_chat_completion(
        [{"role": "user", "content": repair_prompt}],
        model,
    )
    return reply.choices[0].message.content


## 1.4 读取题目文件相关函数

In [None]:
import jieba

def map_chinese_to_english_tables(chinese_names, english_names):
    """
    Map each name in ``chinese_names`` to the entry of ``english_names`` that
    equals it under a case-insensitive comparison.

    Parameters:
        chinese_names (list): Names to look up (used as the keys of the result).
        english_names (list): Candidate names; entries are compared via ``str(...).lower()``.

    Returns:
        name_map (dict): ``{name: matching entry of english_names}``. When several
        entries match, the first one in ``english_names`` wins.

    Raises:
        IndexError: If a name has no case-insensitive match in ``english_names``.
    """
    # Build the lookup once instead of rescanning english_names for every name.
    lookup = {}
    for en in english_names:
        lookup.setdefault(str(en).lower(), en)  # setdefault keeps the first match

    name_map = {}
    for cname in chinese_names:
        key = str(cname).lower()
        if key not in lookup:
            raise IndexError(f"No English table name matches {cname!r}")
        name_map[cname] = lookup[key]
    return name_map

def get_table_schema(table_shema, database_table_en, question=''):
    """
    Build a compact description of every table in ``table_shema``.

    For each entry the table name is resolved against ``database_table_en``
    (case-insensitively, via ``map_chinese_to_english_tables``) and the column
    list is reduced by ``get_simple_schema``. When ``question`` is non-empty,
    ``get_simple_schema`` additionally keeps column notes that share a keyword
    with the question.

    Parameters:
        table_shema (list): Dictionaries with at least the keys 'table' and 'schema'.
        database_table_en (list): Known table names used to resolve 'table'.
        question (str): The question text. If empty, no notes are attached.

    Returns:
        table_maps (list): One dictionary per table:
        {
            '数据表名': resolved table name,
            '数据表结构': simplified column list
        }
    """
    source_names = [entry['table'] for entry in table_shema]
    resolved_names = map_chinese_to_english_tables(source_names, database_table_en)

    return [
        {
            '数据表名': resolved_names.get(entry['table']),
            '数据表结构': get_simple_schema(entry['schema'], question),
        }
        for entry in table_shema
    ]


def get_simple_schema(table_schema, question):
    """
    Reduce a list of column records to the fields shown to the model.

    Every output record holds the column name, its Chinese description and the
    sample value truncated to 50 characters. When ``question`` is not the empty
    string, the column note ('注释') is added as well, but only if
    ``is_add_comment`` finds it relevant to the question.

    Parameters:
        table_schema (list): Column dictionaries with the keys '列名', '中文描述',
            '数据示例' and (when a question is given) '注释'.
        question (str): The question text, or "" to skip notes entirely.

    Returns:
        list: The simplified column dictionaries, in input order.
    """
    attach_notes = question != ""
    simplified = []
    for column in table_schema:
        # Keep only the first 50 characters of the sample value
        example = str(column["数据示例"])[:50]
        note = is_add_comment(question, column["注释"]) if attach_notes else None
        entry = {
            "列名": column["列名"],
            '中文描述': column["中文描述"],
            '数据示例': example,
        }
        if note:
            entry['注释'] = note
        simplified.append(entry)

    return simplified

def is_add_comment(question, comment):
    """
    Return ``comment`` if it shares at least one keyword with ``question``.

    The question is segmented with ``jieba.cut``. Stop words and
    whitespace-only tokens are discarded; a whitespace token would otherwise
    match almost any comment that contains a space.

    Parameters:
        question (str): The question text.
        comment: The column note. Falsy values and values whose string form is
            "nan" are treated as "no comment".

    Returns:
        The original ``comment`` when a keyword occurs in it, otherwise None.
    """
    if not comment or str(comment) == "nan":
        return None

    # Coerce once so a non-string note cannot raise in the substring test.
    comment_text = str(comment)
    stopwords = {'？', '有', '的', '多少', '人', '（', '）'}
    keywords = [
        word
        for word in jieba.cut(question, cut_all=False)
        if word.strip() and word not in stopwords
    ]

    if any(keyword in comment_text for keyword in keywords):
        return comment
    return None

def clean_text(text):
    """
    Strip parenthesised remarks from ``text`` and trim surrounding whitespace.

    Both ASCII and full-width (Chinese) parentheses are recognised, e.g.
    "This is a sentence(remark)" becomes "This is a sentence".

    Parameters:
        text (str): The text to clean.

    Returns:
        str: The text without parenthetical segments.
    """
    # An opening bracket, anything that is not a closing bracket, then a closing bracket
    parenthetical = re.compile(r'[\(（][^\)）]*[\)）]')
    without_remarks = parenthetical.sub('', text)
    return without_remarks.strip()

def find_dict_by_element(dict_list, target_element):
    """
    Select the dictionaries whose '列名中文描述' value contains ``target_element``.

    Parameters:
        dict_list (list): Dictionaries to search; a missing '列名中文描述' key
            is treated as an empty list.
        target_element (str): The element to search for.

    Returns:
        list: The matching dictionaries, in their original order.
    """
    matches = []
    for candidate in dict_list:
        descriptions = candidate.get('列名中文描述', [])
        if target_element in descriptions:
            matches.append(candidate)
    return matches


## 1.5 其他

In [None]:
def read_foreigns(foreigns_path):
    """
    Read foreign-key relations from a text file, one ``left=right`` pair per line.

    A pair and its mirror image (``a=b`` and ``b=a``) count as the same
    relation. The first spelling encountered is kept and the result follows
    file order, so the output is the same on every run. Lines that do not
    consist of exactly two distinct, non-empty sides are ignored.

    Parameters:
        foreigns_path (str): Path of the text file to read.

    Returns:
        list: Unique relations formatted as "left=right".
    """
    # Read the foreign-key information from the txt file
    with open(foreigns_path, 'r') as f:
        lines = f.readlines()

    seen = set()
    foreigns = []
    for line in lines:
        parts = line.strip().split('=')
        if len(parts) != 2:
            continue
        left, right = parts
        if not left or not right or left == right:
            continue
        pair = frozenset(parts)  # order-insensitive key for de-duplication
        if pair in seen:
            continue
        seen.add(pair)
        foreigns.append(f"{left}={right}")

    return foreigns

def dict_to_sentence(data):
    """
    Render a dictionary as "key 是 value" pairs joined by ", ".

    For example: {"name": "John", "age": 30} -> "name 是 John, age 是 30"

    Anything that is not a dictionary (or any error while formatting) is
    reported on stdout and in log.txt, and ``str(data)`` is returned instead.

    Parameters:
        data (dict): The dictionary to convert.

    Returns:
        str: The sentence, or ``str(data)`` on failure.
    """
    try:
        if isinstance(data, dict):
            pairs = [f"{key} 是 {value}" for key, value in data.items()]
            return ", ".join(pairs)
        raise ValueError("Input is not a dictionary")
    except Exception as err:
        print(f"Error in dict_to_sentence: {err}")
        # Record the failure in the log file as well
        with open("log.txt", "a") as log_file:
            log_file.write(f"Error in dict_to_sentence: {err}\n")
        return str(data)

def process_dict(d):
    """
    Flatten a (possibly nested) dictionary into a comma-separated description.

    Plain values are rendered as "key value". A nested dictionary is flattened
    first and then rendered through ``dict_to_sentence`` as "key 是 <flattened>".

    For example:
        {
            "company": {
                "name": "ABC Corp",
                "location": "New York"
            },
            "year": 2021
        }
    is rendered as:
        "company 是 name ABC Corp, location New York, year 2021"

    Parameters:
        d (dict): A dictionary, or any other object.

    Returns:
        str: The description; for non-dictionaries simply ``str(d)``.
    """
    if not isinstance(d, dict):
        # Nothing to flatten: fall back to the plain string form
        return str(d)

    def describe(mapping):
        parts = []
        for key, value in mapping.items():
            if not isinstance(value, dict):
                parts.append(f"{key} {value}")
                continue
            # Flatten the nested level, then let dict_to_sentence attach the key
            parts.append(dict_to_sentence({key: describe(value)}))
        return ", ".join(parts)

    return describe(d)

# 2. 预处理数据表结构
## 2.1 改进题目提供的数据表结构

In [None]:
import pandas as pd
import jieba
import re
import requests
import json
from collections import Counter
from tqdm import tqdm


def make_sql_schema(column_name, table_name):
    """
    Build the statement that selects one column from one table.

    The statement is surrounded by a newline and four spaces of indentation
    on either side.
    """
    return f"\n    SELECT {column_name} FROM {table_name};\n    "

def select_data_schema(sql_text):
    """
    Run ``sql_text`` against the query endpoint and return one sample value.

    The request asks for at most 100 rows. The rows are scanned in order and
    the first column of the first row whose value is truthy is returned as a
    string, truncated to 50 characters.

    Parameters:
        sql_text (str): The SQL query to be executed.

    Returns:
        str or None: The sample value, or None when the request fails, the
        response is not usable JSON, or no row holds a truthy value.
    """
    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Bearer {Access_Token}'
    }
    data = {
        "sql": sql_text,  # e.g. SELECT * FROM constantdb.secumain LIMIT 10
        "limit": 100
    }
    response = requests.post(url, headers=headers, json=data)

    # Check the status first: an error page is not guaranteed to be JSON.
    if response.status_code != 200:
        print(response.status_code)
        print(sql_text)
        return None

    try:
        payload = response.json()
    except ValueError:
        # The JSON decode errors raised by requests derive from ValueError.
        print(sql_text)
        return None

    rows = payload.get('data') if isinstance(payload, dict) else None
    for row in rows or []:
        # Only the first column of each row is inspected
        values = list(row.values())
        if values and values[0]:
            return str(values[0])[:50]
    return None

def get_value(col_name, table_name):
    """
    Fetch the distinct values of a column to decide whether it is worth keeping.

    Issues ``SELECT DISTINCT col_name FROM table_name`` with a row limit of 100.

    Parameters:
        col_name (str): Column to inspect.
        table_name (str): Table that holds the column.

    Returns:
        False when the query yields exactly one row (the column has a single
        value) or when the request fails; otherwise the list of returned rows,
        which is empty (and therefore falsy) when there is no data.
    """
    sql_text = f"SELECT DISTINCT {col_name} FROM {table_name} ;"

    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Bearer {Access_Token}'
    }
    data = {
        "sql": sql_text,  # e.g. SELECT * FROM constantdb.secumain LIMIT 10
        "limit": 100
    }
    response = requests.post(url, headers=headers, json=data)

    # Check the status first: an error page is not guaranteed to be JSON.
    if response.status_code != 200:
        print(response.status_code)
        print(sql_text)
        return False

    try:
        payload = response.json()
    except ValueError:
        # The JSON decode errors raised by requests derive from ValueError.
        print(sql_text)
        return False

    # A missing or null 'data' field is treated as "no rows".
    examples = (payload.get('data') if isinstance(payload, dict) else None) or []
    if len(examples) == 1:
        return False
    return examples

# Load the official data dictionary: one sheet with the database/table
# relations, one sheet with the per-column information

df1 = pd.read_excel(r'../../assets/data_dictionary.xlsx', sheet_name='库表关系')
df2 = pd.read_excel(r'../../assets/data_dictionary.xlsx', sheet_name='表字段信息')


# Qualified "<database>.<table>" names, built from the English and the Chinese columns
df1['库表名英文'] = df1['库名英文'] + '.' + df1['表英文']
df1['库表名中文'] = df1['库名中文'] + '.' + df1['表中文']

# Plain-list views of the df1 columns
database_name = list(df1['库名中文'])
table_name = list(df1['表中文'])
table_description = list(df1['表描述'])
table_name_en = list(df1['表英文'])
database_table_ch = list(df1['库表名中文'])
database_table_en = list(df1['库表名英文'])
# Qualified English names together with the Chinese table names and table descriptions
database_table_en_zs = {'库表名': database_table_en, 
                        '对应中文注释说明': table_name,
                        '对应中文表描述': table_description}
# Qualified Chinese name -> qualified English name
database_table_map = df1.set_index('库表名中文')['库表名英文'].to_dict()

##### 1. Convert the table information to JSON and attach a sample value to every column #####
table_schema_list = []
for table in tqdm(df1['库表名英文']):
    table_schema = {}
    table_name_i = table
    table_schema["DB"] = table.split('.')[0]
    table_schema["table"] = table_name_i
    # Find the row of df1 that belongs to this table
    table_name_anno = df1[df1['库表名英文'] == table]
    table_schema["description"] = table_name_anno['表描述'].values[0]
    # Select the df2 rows whose table_name equals the unqualified table name
    col_name_anno = df2[df2['table_name'] == table.split('.')[-1]]
    col_name_anno = col_name_anno.copy()
    # Turn every row of col_name_anno into a dictionary; select_data_schema is
    # called once per row to fetch the sample value
    col_name_anno['schema'] = col_name_anno.apply(lambda row: {"列名": row['column_name'], 
                                                     "中文描述": row['column_description'],
                                                     "注释": row['注释'],
                                                     "数据示例": select_data_schema(make_sql_schema(row['column_name'], table_name_i))}, 
                                        axis=1)
    
    # Drop duplicate column dictionaries while keeping their original order
    list_of_dicts = col_name_anno['schema'].tolist()
    unique_dicts = []
    seen = set()
    for d in list_of_dicts:
        # Convert the dictionary to a tuple so it can be stored in a set
        d_tuple = tuple(sorted(d.items()))
        if d_tuple not in seen:
            unique_dicts.append(d)
            seen.add(d_tuple)

    # Remove the entry whose column name is PresiderOfficialPost
    unique_dicts = [d for d in unique_dicts if d['列名'] != 'PresiderOfficialPost']
    # Remove entries whose column name is nan
    table_schema['schema'] = [d for d in unique_dicts if str(d['列名']) != 'nan']

    table_schema_list.append(table_schema)

# Write table_schema_list to a JSONL file (one table per line)
with open('my_table_schema.jsonl', 'w', encoding='utf-8') as f:
    for item in table_schema_list:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')


##### 2. Remove columns that hold only a single value from the schema (the author notes this is redundant and could be merged with the previous step) ######
table_schema_list = []
# Read the JSONL file written by the previous step
with open('my_table_schema.jsonl', 'r', encoding='utf-8') as file:
    content = [json.loads(line) for line in file]

for table_schema in tqdm(content):
    _table = {
    'DB': table_schema['DB'],
    'table': table_schema['table'],
    'description': table_schema['description'],
    'schema': []
    }
    table_name_i = table_schema['table']
    col_schema = table_schema['schema']
    for col_ in col_schema:
        # Columns named ID are always dropped
        if col_['列名'] == 'ID':
            continue
        # get_value returns a falsy value for single-valued columns, failed
        # queries and empty results; only columns with a truthy result are kept
        _data = get_value(col_['列名'], table_name_i)
        if _data:
            _table['schema'].append(col_)
    table_schema_list.append(_table)
        

# Overwrite the JSONL file with the pruned schema
with open('my_table_schema.jsonl', 'w', encoding='utf-8') as f:
    for item in table_schema_list:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')


## 2.2 根据题目提供的数据表结构生成表之间的关系

In [None]:
def find_foreign(text):
    """
    Extract the foreign-key target from a model reply.

    The reply is expected to contain the target between the markers
    ``【FOREIGN KEY】/`` and ``/【FOREIGN KEY】``. Any whitespace around the
    target (spaces, LF or CRLF line breaks) is ignored, so the result can be
    compared against column names directly.

    Parameters:
        text (str): The model reply.

    Returns:
        str or None: The text between the markers with surrounding whitespace
        removed, or None when the markers are absent or enclose nothing.
    """
    match = re.search(r"【FOREIGN KEY】/\s*(.*?)\s*/【FOREIGN KEY】", text, re.DOTALL)
    if not match:
        return None
    foreign = match.group(1).strip()
    return foreign or None


def get_foreign(description):
    """
    Ask the model which table column a column note refers to.

    The note is sent as the user message. The system prompt instructs the model
    to answer with the target wrapped in 【FOREIGN KEY】 markers, or with a fixed
    "no relation" phrase. The reply is then parsed by ``find_foreign``.

    Parameters:
        description (str): The column note to analyse.

    Returns:
        Whatever ``find_foreign`` extracts from the reply (None when the reply
        contains no marker pair).
    """
    system_prompt = '''
    你是一个乐于解答各种问题的助手，你的任务是为用户提供专业、准确、有见地的建议。
用户会给你一个数据库表中的FOREIGN KEY信息，你需要以下面的格式输出FOREIGN KEY：
【FOREIGN KEY】/
表名.字段名
/【FOREIGN KEY】

如果用户没有提供FOREIGN KEY信息，你需要输出：
【没有关联信息】

例子1：律师事务所企业编号(LawOfficeCode)：与机构基本资料表(LC_InstiArchive)中公司代码(CompanyCode)关联，得到预测机构的具体信息。
输出：
【FOREIGN KEY】/
LC_InstiArchive.CompanyCode
/【FOREIGN KEY】

例子2：3212-方案部分实施，3301-已注册未发行，3302-已发行有额度，3303-已发行无额度，3304-提前终止，3305-放弃，3399-其他。
输出：
【没有关联信息】
    '''

    conversation = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': description},
    ]
    completion = create_chat_completion(conversation, model=MODEL_sql)
    return find_foreign(completion.choices[0].message.content)

# Read my_table_schema.jsonl (the schema file written by the preprocessing step)
with open(r'my_table_schema.jsonl', 'r', encoding='utf-8') as f:
    table_schema = [json.loads(line) for line in f]

df1 = pd.read_excel(r'../../assets/data_dictionary.xlsx', sheet_name='库表关系')
df1['库表名英文'] = df1['库名英文'] + '.' + df1['表英文']
database_table_en = list(df1['库表名英文'])


# Table names as stored under the 'table' key of the schema file
chinese_table_names = [i['table'] for i in table_schema]

# Resolve every schema table name against the names from the data dictionary
# (case-insensitive match)
name_map = map_chinese_to_english_tables(chinese_table_names, database_table_en)

db_table_col = []
# Loop over table_schema and collect every column as "<table>.<column>"
for table in tqdm(table_schema):
    table_name = name_map.get(table['table'])
    for column in table["schema"]:
        column_name = column["列名"]
        db_table_col.append(f"{table_name}.{column_name}")

foreigns = []
# Loop over table_schema again and ask the model about every column note
for table in tqdm(table_schema):
    table_name = name_map.get(table['table'])
    for column in table["schema"]:
        column_name = column["列名"]
        description = column["注释"]
        # Columns without a note cannot describe a relation
        if str(description) == "nan":
            continue
        foreign = get_foreign(description)
        if foreign:
            # Keep the relation only if exactly one known column contains ".<foreign>"
            # NOTE(review): this is a substring test, so a column whose name merely
            # starts with the target also counts as a hit — confirm this is intended
            _database = [i for i in db_table_col if "."+foreign in i]
            if len(_database) == 1:
                _database = _database[0]
                foreigns.append(f"{table_name}.{column_name}={_database}")
            else:
                # Zero or several candidates: print the note for manual inspection
                print(description)

# Save foreigns, one "left=right" relation per line
with open("foreigns.txt", "w") as f:
    for foreign in foreigns:
        f.write(foreign + "\n")


## 2.3 LLM生成关键词的别名

In [None]:
# Reload the data dictionary and rebuild the name lists used by this cell
df1 = pd.read_excel(r'../../assets/data_dictionary.xlsx', sheet_name='库表关系')
df2 = pd.read_excel(r'../../assets/data_dictionary.xlsx', sheet_name='表字段信息')
df1['库表名英文'] = df1['库名英文'] + '.' + df1['表英文']
df1['库表名中文'] = df1['库名中文'] + '.' + df1['表中文']

database_name = list(df1['库名中文'])
table_name = list(df1['表中文'])
table_name_en = list(df1['表英文'])
database_table_ch = list(df1['库表名中文'])
database_table_en = list(df1['库表名英文'])
database_table_en_zs = {'库表名': database_table_en, '对应中文注释说明': table_name}
database_table_map = df1.set_index('库表名中文')['库表名英文'].to_dict()

# One dictionary per table: qualified name, column names, column descriptions, non-null notes
database_L_zh = []
for i in table_name_en:
    df3 = df2[df2['table_name'] == i]
    name = df1[df1['表英文'] == i]['库表名英文'].iloc[0]
    column_name = list(df3['column_name'])
    column_name_zh = list(df3['column_description'])
    # dropna() removes rows, so this list is not aligned with the two lists above
    column_name_2 = list(df3['注释'].dropna())

    dict_2 = {'数据表名': name, '列名': column_name, '列名中文描述': column_name_zh, '注释': column_name_2}
    database_L_zh.append(dict_2)

# Flatten the column descriptions of all tables into one list
L_num = []
for items in database_L_zh:
    L_num += items['列名中文描述']

# Get unique column descriptions (those that occur exactly once across all tables)
L_num_new = [item for item, count in Counter(L_num).items() if count == 1]

# Drop NaN if any
series_num = pd.Series(L_num_new)
L_num_new = list(series_num.dropna())

# Remove known irrelevant items
irrelevant_items = ['年度', '占比']
for irr in irrelevant_items:
    if irr in L_num_new:
        L_num_new.remove(irr)

# items_another maps every description to its list of aliases;
# sensitive_term collects the descriptions for which no aliases could be generated
items_another = {} 
sensitive_term = []
for item in tqdm(L_num_new):
    messages = [{'role': 'system', 'content': '你是一个乐于解答各种问题的助手，你的任务是为用户提供专业、准确、有见地的建议。'}, 
                {'role': 'user', 'content': item + """ 请把这个词在口头语的说法尽量多的写出来，可以写长、写短，用json格式输出，格式如下{"short":[], "medium":[], "long":[]}"""}]
    try:
        aa = create_chat_completion(messages, model=MODEL_sql)
    except Exception as e:
        # The request failed: fall back to the term itself as its only alias
        print(e)
        print(item)
        sensitive_term.append(item)
        items_another[item] = [item]
        continue
    bb = find_json(aa.choices[0].message.content)
    
    # find_json returns a non-dict when no JSON object could be parsed from the reply
    if not isinstance(bb, dict):
        print(item)
        sensitive_term.append(item)
        items_another[item] = [item]
        continue

    bb_list = [item]
    # Walk through the bb dictionary and merge all alias lists
    # NOTE(review): assumes every value is a list of strings; a plain string
    # value would be extended character by character — confirm against real replies
    for key, value in bb.items():
        bb_list.extend(value)
    # De-duplicate; the resulting order is not defined
    items_another[item] = list(set(bb_list))

# Add hand-written alias lists for some model-sensitive terms
# (each assignment replaces whatever the loop above stored for that key)
items_another["联系人电话"] = [
        "联系人电话",
        "联系人",
        "电话号码",
        "联系号码",
        "联系人的电话号码",
        "联系人的电话"
    ]

items_another["信息披露网址"] = [
        "信息披露网址",
        "信息披露网站"
    ]

# NOTE(review): the last two entries of this list are identical
items_another["董秘电话"] = [
        "董秘电话",
        "董秘",
        "董事会秘书",
        "董事会秘书电话",
        "董事会秘书的联系电话",
        "董事会秘书的联系电话"
    ]

items_another["预收账款/营业收入TTM(%)"] = [
        "预收账款/营业收入TTM(%)",
        "预收账款/营业收入TTM",
        "预收/营业"
    ]

items_another["信托公司持股比例(%)"] = [
        "信托公司持股比例(%)",
        "信托公司持股比例",
        "信托持股"
    ]

items_another["增发股份上市日期"] = [
        "增发股份上市日期",
        "增发股份上市日"
    ]

items_another["今开盘(元)"] = [
        "今开盘(元)",
        "今开",
        "今开盘",
        "开盘价",
        "开盘"
    ]

# NOTE(review): this assignment repeats the previous one verbatim
items_another["今开盘(元)"] = [
        "今开盘(元)",
        "今开",
        "今开盘",
        "开盘价",
        "开盘"
    ]

items_another["复牌时间"] = [
        "复牌时间",
        "复牌",
        "复牌日期",
        "复牌日",
        "复牌时刻"
    
    ]

items_another["公司类别描述"] = [
        "公司类别描述",
        "公司类别",
        "公司类别说明"
    ]

items_another["近三个月成交量(股)"] = [
        "近三个月成交量(股)",
        "近三个月成交",
        "近三个月成交股数",
        "近三个月成交数",
        "近三个月成交数量",
        "近三个月成交股",
        "近三个月成交量",
        "近三个月成交额"
    ]

items_another["开盘价"] = [
        "开盘价",
        "开盘"
    ]

items_another["基金代码"] = [
        "基金代码"
    ]

# Save the alias dictionary as JSON
with open('items_another.json', 'w', encoding='utf-8') as f:
    json.dump(items_another, f, ensure_ascii=False, indent=4)

# 3. 推理
## 3.1 读取预处理中的所有数据

In [None]:
import pandas as pd
# Path of the question file; the data dictionary is loaded again for the inference stage
question_data_path = '../../assets/question.json'
df1 = pd.read_excel('../../assets/data_dictionary.xlsx', sheet_name='库表关系')
df2 = pd.read_excel('../../assets/data_dictionary.xlsx', sheet_name='表字段信息')
df1['库表名英文'] = df1['库名英文'] + '.' + df1['表英文']
df1['库表名中文'] = df1['库名中文'] + '.' + df1['表中文']

database_name = list(df1['库名中文'])
table_name = list(df1['表中文'])
table_name_en = list(df1['表英文'])
database_table_ch = list(df1['库表名中文'])
database_table_en = list(df1['库表名英文'])
database_table_en_zs = {'库表名': database_table_en, '对应中文注释说明': table_name}
database_table_map = df1.set_index('库表名中文')['库表名英文'].to_dict()

# Per-table summaries: database_L without, database_L_zh with the Chinese column descriptions
database_L = []
database_L_zh = []
for i in table_name_en:
    df3 = df2[df2['table_name'] == i]
    name = df1[df1['表英文'] == i]['库表名英文'].iloc[0]
    column_name = list(df3['column_name'])
    column_name_zh = list(df3['column_description'])
    # dropna() removes rows, so this list is not aligned with the two lists above
    column_name_2 = list(df3['注释'].dropna())

    dict_1 = {'数据表名': name, '列名': column_name, '注释': column_name_2}
    dict_2 = {'数据表名': name, '列名': column_name, '列名中文描述': column_name_zh, '注释': column_name_2}
    database_L.append(dict_1)
    database_L_zh.append(dict_2)

# Read the table-schema file generated above
file_path = 'my_table_schema.jsonl'
with open(file_path, 'r', encoding='utf-8') as file:
    content = [json.loads(line) for line in file]
input_text = content

# Read the table-relation file generated above
foreigns_path = 'foreigns.txt'
foreigns = read_foreigns(foreigns_path)

# Read the items_another.json file generated above
with open('items_another.json', 'r', encoding='utf-8') as f:
    items_another = json.load(f)

### 3.1.2  提取SQL，查询数据函数

In [None]:
def replace_date_with_day(sql):
    """
    Wrap the column of every exact date comparison in ``date(...)``.

    A condition such as
        TradingDate = 'YYYY-MM-DD'
    becomes
        date(TradingDate) = 'YYYY-MM-DD'
    The column may be qualified with dots; whitespace around ``=`` is
    normalised to a single space on each side.

    Parameters:
        sql (str): The original SQL statement.

    Returns:
        str: The rewritten statement; unchanged when nothing matches.
    """
    # Group 1: column (word characters and dots); group 2: the ISO date literal
    date_equality = re.compile(r"([.\w]+)\s*=\s*'(\d{4}-\d{2}-\d{2})'")
    return date_equality.sub(r"date(\1) = '\2'", sql)

def select_data(sql_text):
    """
    Execute ``sql_text`` through the query endpoint.

    Backticks are removed from the statement before it is sent, and the
    request asks for at most 15 rows.

    Parameters:
        sql_text (str): The SQL query to be executed.

    Returns:
        tuple: ``(result, hint)``. ``result`` is the JSON response
        pretty-printed as a string, or ``str(response)`` when the body is not
        JSON. ``hint`` is an empty string on success, and otherwise a message
        (in Chinese) telling the caller that the query failed or that the row
        limit was reached.
    """
    max_rows = 15
    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Bearer {Access_Token}'
    }
    data = {
        "sql": sql_text.replace("`", ""),  # e.g. SELECT * FROM constantdb.secumain LIMIT 10
        "limit": max_rows
    }
    response = requests.post(url, headers=headers, json=data)

    # Parse the body once. The decode errors raised by requests derive from
    # ValueError regardless of which JSON backend is installed.
    try:
        payload = response.json()
    except ValueError:
        return str(response), "查询失败，请重新生成SQL语句"

    result = json.dumps(payload, indent=2, ensure_ascii=False)
    if "查询执行失败" in result:
        return result, "查询执行失败，请检查SQL语句是否正确"

    # A response without a usable 'data' list is returned as-is instead of raising.
    rows = payload.get('data') if isinstance(payload, dict) else None
    if isinstance(rows, list) and len(rows) >= max_rows:
        return result, "由于性能问题，数据库最多只能返回15条数据，请用聚合函数count、sum、avg、max、min等来查询"
    return result, ""
    
def extract_sql(text):
    """
    Pull the first SQL statement out of a fenced code block of the form:
        ```sql
        SELECT ...
        ```

    Parameters:
        text (str): The full text containing an SQL statement.

    Returns:
        str: The statement with surrounding whitespace removed, or the fixed
        message "No SQL statement found." when there is no such block.
    """
    fenced = re.search(r'```sql(.*?)```', text, flags=re.DOTALL)
    if fenced is None:
        return "No SQL statement found."
    return fenced.group(1).strip()

def to_select(text,):
    """
    Extract the SQL from a model reply, run it and report the outcome.

    Steps:
      1. Extract the SQL statement with ``extract_sql`` and echo it to stdout
         and log.txt.
      2. If no statement was found, return a message asking for a new one.
      3. Otherwise rewrite exact date comparisons with ``replace_date_with_day``
         and execute the statement through ``select_data``.

    Parameters:
        text (str): The input text containing an SQL statement.

    Returns:
        tuple: ``(result, hint, sql)`` — the query result (or the "no SQL"
        message), the hint from ``select_data`` (empty when no SQL was found),
        and the SQL statement that was used.
    """
    banner = '***********Extracted SQL****************'
    sql_statement = extract_sql(text)

    print(banner)
    print(sql_statement)
    with open('log.txt', 'a', encoding='utf-8') as log_file:
        log_file.write(sql_statement + '\n')
        print(banner)
        # Mirror the closing banner in the log file
        log_file.write(banner + '\n')

    if 'No SQL statement found.' in sql_statement:
        return "未找到SQL语句，请重新生成sql。或者在提示我：<全部完成，答案如下>后，直接回答问题。", "", sql_statement

    optimized_sql = replace_date_with_day(sql_statement)
    result, intro_str = select_data(optimized_sql)
    return result, intro_str, optimized_sql

## 3.2 定义对话逻辑
### 3.2.1 问题的预处理

In [None]:
import requests
# Query patterns that are hard for the model to get right.
# Each key is a short description of a question pattern (find_example asks the
# LLM which keys apply to a question and matches them by substring); each value
# is the hint text / sample SQL that is injected into the prompt for that pattern.
examples = {
    "查询美股公司信息": "查询美股公司，要同时查询USStockDB.US_CompanyInfo 和 ConstantDB.US_SecuMain中的信息,不要遗漏",
    "查询港股公司信息": "查询港股公司，要同时查询HKStockDB.HK_StockArchives 和 ConstantDB.HK_SecuMain中的信息,不要遗漏",
    "查询的表名中有DailyQuote": "查询AStockMarketQuotesDB.QT_DailyQuote, USStockDB.US_DailyQuote中的特定公司股票时，不要使用其他字段筛选特定公司，示例sql语句:SELECT * FROM AStockDailyQuote WHERE  InnerCode = 1234;",
    "查询某行业市值": "查询某行业市值,示例sql语句:SELECT TotalMV, NegotiableMV, FreeFloatMV FROM AStockIndustryDB.LC_IndustryValuation WHERE date(TradingDay) = '2020-07-02' AND IndustryName = '风电零部件';",
    "进行加减乘除数学计算": "使用sql进行加减乘除数学计算。",
    "比例/百分比是多少": "查询比例/百分比是多少，要考虑加上百分号后再进行四舍五入。",
    "近一个月最高价": "查询近一个月最高价,你写的sql语句可以优先考虑表中已有字段HighPriceRM  近一月最高价(元)",
    "近一个月最低价": "查询近一月最低价(元),你写的sql语句直接调用已有字段LowPriceRM",
    "查询某行业数量": "查询某行业某年数量 示例sql语句:SELECT count(*) as 风电零部件_2021 FROM AStockIndustryDB.LC_ExgIndustry where ThirdIndustryName like '%风电零部件%' and year(InfoPublDate)=2021 and IfPerformed = 1;",
    "某股票/公司属于哪些行业/概念板块？": "查询某股票/公司属于哪些概念板块？ 示例sql语句:SELECT ConceptCode, ConceptName from AStockIndustryDB.LC_ConceptList WHERE ConceptCode IN (SELECT DISTINCT ConceptCode  FROM AStockIndustryDB.LC_COConcept WHERE InnerCode = 1167);",
    "某行业/概念板块有哪些股票/公司？": "查询某概念板块有哪些股票/公司？ 示例sql语句:SELECT InnerCode FROM AStockIndustryDB.LC_COConcept WHERE ConceptCode = 11100021;",
    """持有无限售流通A股数量""": """特别重要一定注意，查询最新更新XXXX年年度报告，机构持有无限售流通A股数量合计InstitutionsHoldProp最多公司代码，优先使用查询sql语句，SELECT *
                            FROM AStockShareholderDB.LC_StockHoldingSt
                            WHERE date(EndDate) = 'XXXX-12-31'
                              AND UpdateTime = (
                                SELECT MAX(UpdateTime)
                                FROM AStockShareholderDB.LC_StockHoldingSt
                                WHERE date(EndDate) = 'XXXX-12-31'
                              ) order by InstitutionsHoldings desc limit 1 ，XXXX代表问题查询年度，sql语句禁止出现group by InnerCode;

                              查询最新更新XXXX年年度报告,公司机构持有无限售流通A股比例合计InstitutionsHoldProp是多少,优先使用查询sql语句，SELECT InstitutionsHoldProp
                            FROM AStockShareholderDB.LC_StockHoldingSt
                            WHERE date(EndDate) = 'XXXX-12-31'
                              AND UpdateTime = (
                                SELECT MAX(UpdateTime)
                                FROM AStockShareholderDB.LC_StockHoldingSt
                                WHERE date(EndDate) = 'XXXX-12-31'
                              ) order by InstitutionsHoldings desc limit 1 ，XXXX代表问题查询年度，sql语句禁止出现group by InnerCode;""",
    "xxx指标 新高 最多的交易日": """
    xxx指标 新高 最多的交易日 要用AStockMarketQuotesDB.CS_StockPatterns现有字段，例子中IfHighestTVRMThree字段可以根据情况灵活调整
        查询成交量创近一季度新高的证券数量和交易日，示例sql语句:
            SELECT count(*) as num, TradingDay  FROM AStockMarketQuotesDB.CS_StockPatterns where  IfHighestTVRMThree=1 group by TradingDay ORDER BY num DESC limit 1;
        查询某日成交量创近一季度新高的证券，示例sql语句:
            SELECT InnerCode, TradingDay  FROM AStockMarketQuotesDB.CS_StockPatterns where  IfHighestTVRMThree=1 and date(TradingDay) = '2021-12-23';
""",
    "新高": """新高 要用AStockMarketQuotesDB.CS_StockPatterns现有字段
        查询今天是2021年01月01日，创近半年新高的股票有几只。示例sql语句:SELECT count(*)  FROM AStockMarketQuotesDB.CS_StockPatterns
                where  IfHighestHPriceRMSix=1 and date(TradingDay)='2021-01-01';
        判断某日 YY-MM-DD  InnerCode XXXXXX 是否创近一周的新高，查询结果1代表是,IfHighestHPriceRW字段可以根据情况灵活调整  SELECT   InnerCode,TradingDay,IfHighestHPriceRW  FROM AStockMarketQuotesDB.CS_StockPatterns
where  date(TradingDay)='2021-12-20' and InnerCode = '311490'""",
    "成交额": """查询这家公司一周内成交额是多少。示例sql语句:SELECT TurnoverValueRW AS TurnoverValueWan
FROM AStockMarketQuotesDB.QT_StockPerformance
WHERE InnerCode = 1289 AND date(TradingDay) = '2021-06-17';""",
    "半年度报告": """查询XXXX年半年度报告的条件为：year(EndDate) = XXXX and InfoSource='半年度报告'""",
    
}

def exec_sql_s(sql, limit = 10):
    """
    Execute a SQL query on the remote query endpoint and return its rows.

    The request is authorised with the module-level 'Access_Token'.
    Backticks are stripped from the statement before it is sent.

    Parameters:
        sql (str): The SQL query to be executed.
        limit (int): Row limit forwarded to the endpoint (default 10).

    Returns:
        The value of the response's 'data' field (presumably a list of row
        dictionaries — confirm against the endpoint), or None when the
        response carries no 'data' field.

    Raises:
        Exception: If the endpoint answers with
            'Invalid authentication credentials'.
    """
    headers = {
        "Authorization": f'Bearer {Access_Token}',
        "Accept": "application/json"
    }
    url = "https://comm.chatglm.cn/finglm2/api/query"

    response = requests.post(url, headers=headers, json={
        "sql": sql.replace("`", ""),
        "limit": limit
    })
    response_json = response.json()

    if isinstance(response_json, dict) and 'data' in response_json:
        return response_json['data']

    # No 'data' field: keep the full response for debugging.
    print(response_json)
    with open('log.txt', 'a', encoding='utf-8') as f:
        f.write(str(response_json) + '\n')

    # Error payloads are not guaranteed to carry a 'detail' key, so read it
    # defensively instead of indexing (which raised KeyError and hid the
    # real error behind an unrelated exception).
    detail = response_json.get('detail') if isinstance(response_json, dict) else None
    if detail == 'Invalid authentication credentials':
        raise Exception('Invalid authentication credentials')

    return None

def process_question(question, model):
    """
    Run Named Entity Recognition on a question and look the entities up.

    A few-shot system prompt asks the model to emit a JSON list of entities
    (公司名称，代码，基金名称，概念名称，人名). The reply is parsed with
    find_json; the request is repeated (at most three attempts in total)
    until the parsed value is a list. Whatever the last attempt produced is
    then handed to process_items.

    Parameters:
        question (str): The user question.
        model (str): Name of the chat model used for the NER request.

    Returns:
        tuple: (res, tables, concept_list) exactly as returned by
               process_items.
    """
    ner_prompt = '''
    你将会进行命名实体识别任务，并输出实体json，主要识别以下几种实体：
    公司名称，代码，基金名称，概念名称，人名。

    其中，公司名称可以是全称，简称，拼音缩写，代码包含股票代码和基金代码，基金名称包含债券型基金，
    以下是几个示例：
    user:唐山港集团股份有限公司是什么时间上市的（回答XXXX-XX-XX）
    当年一共上市了多少家企业？
    这些企业有多少是在北京注册的？
    assistant:```json
    [{"公司名称":"唐山港集团股份有限公司"}]
    ```
    user:JD的职工总数有多少人？
    该公司披露的硕士或研究生学历（及以上）的有多少人？
    20201月1日至年底退休了多少人？
    assistant:```json
    [{"公司名称":"JD"}]
    ```
    user:600872的全称、A股简称、法人、法律顾问、会计师事务所及董秘是？
    该公司实控人是否发生改变？如果发生变化，什么时候变成了谁？是哪国人？是否有永久境外居留权？（回答时间用XXXX-XX-XX）
    assistant:```json
    [{"代码":"600872"}]
    ```
    user:华夏鼎康债券A在2019年的分红次数是多少？每次分红的派现比例是多少？
    基于上述分红数据，在2019年最后一次分红时，如果一位投资者持有1000份该基金，税后可以获得多少分红收益？
    assistant:```json
    [{"基金名称":"华夏鼎康债券A"}]
    ```
    user:化工纳入过多少个子类概念？
    assistant:```json
    [{"概念名称":"化工"}]
    ```
    user:李一硕管理的基金中，规模最大的是哪一个？
    assistant:```json
    [{"人名":"李一硕"}]
    ```
    '''

    conversation = [
        {'role': 'system', 'content': ner_prompt},
        {'role': 'user', 'content': question},
    ]

    entities = None
    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1
        completion = create_chat_completion(conversation, model)
        entities = find_json(completion.choices[0].message.content)
        # A list is the only shape the few-shot examples ask for; stop retrying.
        if isinstance(entities, list):
            break

    return process_items(entities)

def process_items(item_list):
    """
    Look up every entity extracted from the question and format the findings.

    Each item is expected to be a dict whose first key/value pair names the
    entity type and its value:
    - '基金名称' or '公司名称' -> process_company_name
    - '代码'                  -> process_code
    - '概念名称'              -> process_concept
    - '人名'                  -> process_human
    - anything else           -> reported as an unrecognised key (stdout + log.txt)

    Items that are not non-empty dicts are skipped. Entities whose value is
    None or an empty string are skipped as well (there is nothing to search
    for), and non-string values such as a numeric stock code are converted
    to str before the lookup, because the lookup helpers call str methods
    on the value.

    Parameters:
        item_list (list | str): A list of dictionaries like
            [{"公司名称": "XX公司"}, {"代码": "600872"}]. A plain string is
            returned unchanged as the first element of the result.

    Returns:
        tuple: (res, tables, concept_list)
               res (str): A formatted string showing what was found.
               tables (list): Names of the tables where matches were found.
               concept_list (list): Industry/concept hints; when non-empty a
                   closing note about the concept/industry hierarchy is
                   appended.
    """
    if isinstance(item_list, str):
        return item_list, [], []

    res_list = []
    lsh_res_list = []
    concept_list = []
    known_keys = ("基金名称", "公司名称", "代码", "概念名称", "人名")

    for item in item_list:
        try:
            key, value = list(item.items())[0]
        except (AttributeError, IndexError, TypeError, ValueError):
            # Not a non-empty dict: nothing to look up for this item.
            continue

        if key not in known_keys:
            print(f"无法识别的键：{key}")
            # Record the unrecognised key in the run log.
            with open("log.txt", "a", encoding='utf-8') as f:
                f.write(f"无法识别的键：{key}\n")
            continue

        if value is None or value == "":
            # An empty search term would match every row (LIKE '%%').
            continue
        value = str(value)

        if key in ("基金名称", "公司名称"):
            r, lsh_r = process_company_name(value)
            res_list.extend(r)
            lsh_res_list.extend(lsh_r)
        elif key == "代码":
            res_list.extend(process_code(value))
        elif key == "概念名称":
            concept_list.extend(process_concept(value))
        else:  # "人名"
            lsh_res_list.extend(process_human(value))

    if concept_list:
        concept_list.append("行业概念中，ConceptNames是SubclassName的子类，SubclassName是ClassName的子类；" +
                            "SecondIndustryName是FirstIndustryName的子类，ThirdIndustryName是SecondIndustryName的子类，FourthIndustryName是ThirdIndustryName的子类；")

    # Drop empty results and duplicates (compared by their string form),
    # keeping the first occurrence of each.
    seen = set()
    unique_results = []
    for entry in res_list:
        if not entry:
            continue
        fingerprint = str(entry)
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        unique_results.append(entry)

    tables = []
    parts = []
    for result_data, table_name in unique_results:
        tables.append(table_name)
        parts.append(f"预处理程序通过表格：{table_name} 查询到以下内容：\n {json.dumps(result_data, ensure_ascii=False, indent=1)} \n")
    parts.extend(f"\n {r}" for r in lsh_res_list)

    return "".join(parts), tables, concept_list

def process_company_name(value):
    """
    Search the three securities master tables for a company (or fund) name.

    For each of ConstantDB.SecuMain, ConstantDB.HK_SecuMain and
    ConstantDB.US_SecuMain a `LIKE '%value%'` condition is OR-ed over the
    name/code columns and the query is run through exec_sql_s. The US table
    is queried without ChiNameAbbr and EngNameAbbr. When no table yields a
    result, a "not found" message is printed and appended to log.txt.

    Parameters:
        value (str): The company name or related string to match.

    Returns:
        tuple: (matches, extra) where matches is a list of (rows, table)
               tuples, one per table that returned rows, and extra is an
               always-empty list kept for the caller's two-value unpacking.
    """
    select_columns = ('InnerCode', 'CompanyCode', 'SecuCode', 'ChiName', 'ChiNameAbbr',
                      'EngName', 'EngNameAbbr', 'SecuAbbr', 'ChiSpelling')
    match_columns = ('CompanyCode', 'SecuCode', 'ChiName', 'ChiNameAbbr',
                     'EngName', 'EngNameAbbr', 'SecuAbbr', 'ChiSpelling')
    # Columns left out of both lists when querying the US table.
    us_excluded = ('ChiNameAbbr', 'EngNameAbbr')

    # Escape single quotes to prevent SQL injection
    escaped = value.replace("'", "''")

    matches = []
    for table in ('ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain'):
        excluded = us_excluded if 'US' in table else ()
        chosen_select = [col for col in select_columns if col not in excluded]
        chosen_match = [col for col in match_columns if col not in excluded]

        # One LIKE condition per searchable column, OR-ed together.
        where_clause = ' OR '.join(f"{col} like '%{escaped}%'" for col in chosen_match)

        sql = f"""
        SELECT {', '.join(chosen_select)}
        FROM {table}
        WHERE {where_clause}
        """
        rows = exec_sql_s(sql)
        if rows:
            matches.append((rows, table))

    if not matches:
        message = f"未在任何表中找到公司名称为 {escaped} 的信息。"
        print(message)
        # Record the miss in the run log.
        with open('log.txt', 'a', encoding='utf-8') as log_file:
            log_file.write(message + "\n")

    return matches, []

def process_code(value):
    """
    Look a security code up in the three securities master tables.

    Each of ConstantDB.SecuMain, ConstantDB.HK_SecuMain and
    ConstantDB.US_SecuMain is queried for rows whose SecuCode equals the
    given code. The US table is queried without ChiNameAbbr and
    EngNameAbbr. When nothing is found anywhere, a "not found" message is
    printed and appended to log.txt.

    Parameters:
        value (str): The code to search for.

    Returns:
        list: A list of (rows, table) tuples, one per table that returned
              rows; empty when there is no match.
    """
    select_columns = ('InnerCode', 'CompanyCode', 'SecuCode', 'ChiName', 'ChiNameAbbr',
                      'EngName', 'EngNameAbbr', 'SecuAbbr', 'ChiSpelling')
    # Columns left out when querying the US table.
    us_excluded = ('ChiNameAbbr', 'EngNameAbbr')

    escaped = value.replace("'", "''")  # Escape single quotes

    found = []
    for table in ('ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain'):
        excluded = us_excluded if 'US' in table else ()
        chosen_select = [col for col in select_columns if col not in excluded]

        sql = f"""
        SELECT {', '.join(chosen_select)}
        FROM {table}
        WHERE SecuCode = '{escaped}'
        """
        rows = exec_sql_s(sql)
        if rows:
            found.append((rows, table))

    if not found:
        message = f"未在任何表中找到代码为 {escaped} 的信息。"
        print(message)
        # Record the miss in the run log.
        with open('log.txt', 'a', encoding='utf-8') as log_file:
            log_file.write(message + "\n")

    return found

def process_concept(value):
    """
    Find which industry/concept tables contain the given name verbatim.

    For every table listed below, the name is compared for equality against
    each of that table's name columns (OR-ed together) and a single sample
    row is requested through exec_sql_s.

    Parameters:
        value (str): The industry or concept name to look for.

    Returns:
        list: One hint string per table that returned a row, naming the
              table and showing the sample data; empty when nothing matched.
    """
    # NOTE: the original assigned a one-element header list here and then
    # immediately overwrote it with []; the dead assignment was removed and
    # the result still starts out empty.
    res_lst = []
    tables = {'AStockIndustryDB.LC_ExgIndustry': ['Industry', 'FirstIndustryName', 'SecondIndustryName', 'ThirdIndustryName', 'FourthIndustryName'],
              'AStockIndustryDB.LC_ExgIndChange': ['Industry', 'FirstIndustryName', 'SecondIndustryName', 'ThirdIndustryName', 'FourthIndustryName'],
              'AStockIndustryDB.LC_IndustryValuation': ['IndustryName'],
              'AStockIndustryDB.LC_IndFinIndicators': ['IndustryName'],
              'AStockIndustryDB.LC_ConceptList': ['ClassName', 'SubclassName', 'ConceptName']}

    value = value.replace("'", "''")  # Escape single quotes

    for table, cols in tables.items():
        match_conditions = [f"{col} = '{value}'" for col in cols]
        where_clause = ' OR '.join(match_conditions)
        sql = f"""
        SELECT {', '.join(cols)}
        FROM {table}
        WHERE {where_clause}
        """
        # A single row is enough to show that the name exists in this table.
        result = exec_sql_s(sql, limit = 1)
        if result:
            res_lst.append(f"{table}表中存在行业<{value}>，部分数据如下：{result} " )

    return res_lst

def process_human(value):
    """
    Find which tables mention the given person name.

    For every table listed below, a `LIKE '%value%'` condition is OR-ed over
    that table's person-related columns and the query is run through
    exec_sql_s with its default row limit.

    Parameters:
        value (str): The person name to look for.

    Returns:
        list: One hint string per table that returned rows, naming the
              table and showing the returned data; empty when nothing
              matched.
    """
    # NOTE: the original assigned a one-element header list here and then
    # immediately overwrote it with []; the dead assignment was removed and
    # the result still starts out empty.
    res_lst = []
    tables = {
            'AStockBasicInfoDB.LC_StockArchives': ['GeneralManager', 'LegalConsultant'],
            'AStockShareholderDB.LC_SHTypeClassifi': ['SHName', 'SHCode'],
              'AStockShareholderDB.LC_MainSHListNew': ['SHList', 'GDID'],
              'AStockShareholderDB.LC_Mshareholder': ['MSHName', 'GDID'],
              'AStockShareholderDB.LC_ActualController': ['ControllerName'],
              'AStockShareholderDB.LC_ShareTransfer': ['TransfererName', 'ReceiverName'],
              'AStockShareholderDB.LC_ShareFP': ['FPSHName', 'ReceiverName'],
              'AStockShareholderDB.LC_ShareFPSta': ['FPSHName'],
              'AStockShareholderDB.LC_LegalDistribution': ['InvestorName', 'StandardInvestorName', 'StandardAquirerName'],
              'AStockShareholderDB.LC_NationalStockHoldSt': ['SHName'],
              'CreditDB.LC_ViolatiParty': ['PartyName'],
              'AStockEventsDB.LC_InvestorDetail': ['PersonalName'],
              'AStockShareholderDB.LC_TransferPlan': ['SHName'],
              'PublicFundDB.MF_FundArchives': ['Manager', 'InvestAdvisorCode', 'TrusteeCode'],
              'PublicFundDB.MF_InvestAdvisorOutline': ['InvestAdvisorName', 'InvestAdvisorAbbrName', 'LegalRepr', 'GeneralManager'],
              'InstitutionDB.LC_InstiArchive': ['InvestAdvisorName', 'TrusteeName', 'LegalPersonRepr', 'GeneralManager', 'OtherManager', 'Contactman', ],
              }

    value = value.replace("'", "''")  # Escape single quotes

    for table, cols in tables.items():
        match_conditions = [f"{col} like '%{value}%'" for col in cols]
        where_clause = ' OR '.join(match_conditions)
        sql = f"""
        SELECT {', '.join(cols)}
        FROM {table}
        WHERE {where_clause}
        """
        result = exec_sql_s(sql)
        if result:
            res_lst.append(f"{table}表中存在人名<{value}>，部分数据如下：{result} " )

    return res_lst

def checkStockMarket(question, model):
    """
    Ask the LLM which stock markets (A股 / 港股 / 美股) a question concerns.

    The model is asked to answer in JSON; the '选择的股票市场' entry of the
    parsed reply is mapped to the corresponding securities master tables.

    Parameters:
        question (str): The user question.
        model (str): Name of the chat model to query.

    Returns:
        list: The master tables of the selected markets, a subset of
              ['ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain',
              'ConstantDB.US_SecuMain'] in that order. All three are
              returned when the model selects an empty list or when its
              reply cannot be interpreted.
    """
    markets = []
    prompt = (
        "请判断要回答'<<question>>'，需要查询<A股、港股、美股>中的哪几个股票市场，"
        "请使用json回答\n"
        "格式如下：\n"
        "{'原因': '为什么要选择这几个股票市场', '选择的股票市场': [一个list]}"
    ).replace("<<question>>", question)

    messages = [{"role": "user", "content": prompt}]
    response = create_chat_completion(messages, model)
    answer = response.choices[0].message.content
    try:
        StockMarket = find_json(answer)['选择的股票市场']
        if StockMarket == []:
            return ['ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain']

        if 'A股' in StockMarket:
            markets.append('ConstantDB.SecuMain')
        if '港股' in StockMarket:
            markets.append('ConstantDB.HK_SecuMain')
        if '美股' in StockMarket:
            markets.append('ConstantDB.US_SecuMain')
    except Exception:
        # Any malformed reply (missing key, wrong type, unparsable JSON)
        # falls back to searching every market. `Exception` rather than a
        # bare `except` so KeyboardInterrupt/SystemExit still propagate.
        return ['ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain']
    return markets

def find_example(question):
    """
    Pick the hard-query hints from `examples` that apply to a question.

    The LLM (MODEL_rag) is shown the list of example keys and asked which
    ones relate to the question; every key that appears as a substring of
    its reply contributes its hint text.

    Parameters:
        question (str): The text to find related examples for.

    Returns:
        str: The selected hints joined by newlines, followed by a newline
             and a fixed tip about trying other time-related columns.
    """
    candidate_keys = list(examples.keys())
    prompt = f'''
    用户问题：{question}
    你是一个金融领域的专家，请根据用户问题，从以下的示例问题中找到与用户问题相关的示例问题：
    {str(candidate_keys)}

    请输出与用户问题相关的示例问题，多个问题之间用逗号隔开，例如：key1,key2,key3
    '''
    reply = async_llm_chain_call([{"role": "user", "content": prompt}], MODEL_rag, sampling_count=1)
    reply_text = reply.choices[0].message.content

    # Substring match, in dictionary order, against the model's reply.
    hints = [hint for key, hint in examples.items() if key in reply_text]
    time_tip = "查询和时间相关的信息，如果没有返回数据，要考虑其他的时间相关列，比如InfoPublDate、EndDate、TradingDay等。"
    return "\n".join(hints) + "\n" + time_tip

def to_get_question_columns(question, database_L_zh):
    """
    Find columns whose Chinese description is mentioned in the question.

    Only descriptions that occur exactly once across all tables are
    considered, with NaN entries and the generic descriptions '年度' and
    '占比' dropped. For each remaining description, its alternatives in
    `items_another` are tested (raw and via clean_text) for being a
    substring of the question; on the first hit the owning table and
    English column name are recorded.

    Parameters:
        question (str): The input question text.
        database_L_zh (list): Table descriptors, each a dict with at least
            '列名中文描述', '列名' and '数据表名'.

    Returns:
        str: A message listing the identified columns, or an empty string
             if none were found.
    """
    all_descriptions = []
    for table_info in database_L_zh:
        all_descriptions.extend(table_info['列名中文描述'])

    # Keep descriptions that identify exactly one column, then drop NaN.
    occurring_once = [desc for desc, hits in Counter(all_descriptions).items() if hits == 1]
    occurring_once = list(pd.Series(occurring_once).dropna())

    # Too generic to be a useful signal.
    too_generic = ('年度', '占比')
    candidates = [desc for desc in occurring_once if desc not in too_generic]

    matched_columns = []
    for description in candidates:
        aliases = items_another[description]
        mentioned = any(
            alias in question or clean_text(alias) in question
            for alias in aliases
        )
        if not mentioned:
            continue

        owners = find_dict_by_element(database_L_zh, description)
        if not owners:
            continue

        owner = owners[0]
        # Map each Chinese description of the owning table to its column name.
        name_by_description = dict(zip(owner['列名中文描述'], owner['列名']))
        matched_columns.append({
            '数据库表': owner['数据表名'],
            '列名': name_by_description[description],
            '列名中文含义': description
        })

    if not matched_columns:
        return ''
    return f"已获得一部分数据库列名{matched_columns}，请充分利用获得的列名直接查询数据。"
    

def get_foreigns(foreigns, content):
    """
    Keep the join relations whose two tables are both mentioned in a text.

    Each relation is a string of the form 'db.table.col = db.table.col';
    the table of a side is its first two dot-separated parts.

    Parameters:
        foreigns (list): Relation strings, each containing exactly one '='.
        content (str): Text that is searched for the table names.

    Returns:
        list: The relations, in input order, for which both table names
              occur in `content`.
    """
    def _table_of(side):
        # 'db.table.col' -> 'db.table'
        return ".".join(side.strip().split('.')[0: 2])

    kept = []
    for relation in foreigns:
        left, right = relation.split('=')
        left_table, right_table = _table_of(left), _table_of(right)
        if left_table in content and right_table in content:
            kept.append(relation)
    return kept

def table_intro(LL):
    """
    Return a warning when the selected tables span more than one market.

    Parameters:
        LL (list): Table descriptors, each a dict with a '数据表名' entry.

    Returns:
        str: A note telling the model not to join A-share, HK and US tables
             when names from at least two of the 'HKStock', 'AStock' and
             'USStock' families are present; otherwise an empty string.
    """
    names_text = str([table['数据表名'] for table in LL])
    markets_present = sum(prefix in names_text for prefix in ('HKStock', 'AStock', 'USStock'))
    if markets_present <= 1:
        return ''
    return '港股、A股（AStock开头的表）、美股信息不通用，请分别查询，不要把不同地方上市的股票公司进行join。'

class Example:
    """
    Prompt hints derived from a piece of text (a question or an LLM reply).

    On construction the related hard-query hints (find_example) and the
    columns mentioned in the text (to_get_question_columns) are computed
    once and stored.
    """

    def __init__(self, question, database_L_zh):
        self.question = question
        self.reference = find_example(question)
        self.question_columns = to_get_question_columns(question, database_L_zh)

    def to_string(self, get_reference = True, get_col_shema = True):
        """
        Render the hints as one prompt fragment.

        Parameters:
            get_reference (bool): Append '>>查询参考：' plus the stored
                reference hints.
            get_col_shema (bool): Start the fragment with the matched-column
                message instead of the "re-check the previous SQL" header.

        Returns:
            str: The assembled fragment.
        """
        # NOTE(review): with get_col_shema=True the header is replaced by
        # the column message rather than extended by it — confirm that
        # dropping the header in that case is intended.
        if get_col_shema:
            fragment = self.question_columns
        else:
            fragment = "根据参考仔细检查上一步生成的sql，如果sql有问题，则返回修改后的sql；如果没有问题则按要求继续进行下一步。\n"

        if not get_reference:
            return fragment

        fragment += ">>查询参考："
        fragment += self.reference
        return fragment
    

### 3.2.2 对话生成sql和答案

In [None]:
def run_conversation(question, question_id, database_L_zh):
    """
    Answer one question end to end: pick tables, then iterate on SQL.

    Stage 1 narrows the table catalog: entities are extracted from the
    question (process_question), the relevant markets are taken from the
    tables those entities were found in (or asked from the LLM via
    checkStockMarket when none were found), and the LLM is asked which
    tables of the resulting catalog it needs. Stage 2 builds the system
    prompt for SQL generation and hands over to
    run_conversation_until_complete.

    Reads the module-level names MODEL_rag, MODEL_sql, database_table_en,
    input_text and foreigns.

    Parameters:
        question (str): The (already rewritten) user question.
        question_id: Identifier passed on as dialog_id for the saved dialog.
        database_L_zh (list): Table descriptors used for column matching.

    Returns:
        str: The final answer produced by run_conversation_until_complete.
    """

    # 1. Select the database tables relevant to the question
    # 1.1 Partition the table catalog into A-share, other, US and HK groups

    lc = {'库表名': ['AStockBasicInfoDB.LC_StockArchives',
'AStockBasicInfoDB.LC_NameChange',
'AStockBasicInfoDB.LC_Business',
'AStockIndustryDB.LC_ExgIndustry',
'AStockIndustryDB.LC_ExgIndChange',
'AStockIndustryDB.LC_IndustryValuation',
'AStockIndustryDB.LC_IndFinIndicators',
'AStockIndustryDB.LC_COConcept',
'AStockIndustryDB.LC_ConceptList',
'AStockOperationsDB.LC_SuppCustDetail',
'AStockShareholderDB.LC_SHTypeClassifi',
'AStockShareholderDB.LC_MainSHListNew',
'AStockShareholderDB.LC_SHNumber',
'AStockShareholderDB.LC_Mshareholder',
'AStockShareholderDB.LC_ActualController',
'AStockShareholderDB.LC_ShareStru',
'AStockShareholderDB.LC_StockHoldingSt',
'AStockShareholderDB.LC_ShareTransfer',
'AStockShareholderDB.LC_ShareFP',
'AStockShareholderDB.LC_ShareFPSta',
'AStockShareholderDB.LC_Buyback',
'AStockShareholderDB.LC_BuybackAttach',
'AStockShareholderDB.LC_LegalDistribution',
'AStockShareholderDB.LC_NationalStockHoldSt',
'AStockShareholderDB.CS_ForeignHoldingSt',
'AStockFinanceDB.LC_AShareSeasonedNewIssue',
'AStockFinanceDB.LC_ASharePlacement',
'AStockFinanceDB.LC_Dividend',
'AStockFinanceDB.LC_CapitalInvest',
'AStockMarketQuotesDB.CS_StockCapFlowIndex',
'AStockMarketQuotesDB.CS_TurnoverVolTecIndex',
'AStockMarketQuotesDB.CS_StockPatterns',
'AStockMarketQuotesDB.QT_DailyQuote',
'AStockMarketQuotesDB.QT_StockPerformance',
'AStockMarketQuotesDB.LC_SuspendResumption',
'AStockFinanceDB.LC_BalanceSheetAll',
'AStockFinanceDB.LC_IncomeStatementAll',
'AStockFinanceDB.LC_CashFlowStatementAll',
'AStockFinanceDB.LC_IntAssetsDetail',
'AStockFinanceDB.LC_MainOperIncome',
'AStockFinanceDB.LC_OperatingStatus',
'AStockFinanceDB.LC_AuditOpinion',
'AStockOperationsDB.LC_Staff',
'AStockOperationsDB.LC_RewardStat',
'AStockEventsDB.LC_Warrant',
'AStockEventsDB.LC_Credit',
'AStockEventsDB.LC_SuitArbitration',
'AStockEventsDB.LC_EntrustInv',
'AStockEventsDB.LC_Regroup',
'AStockEventsDB.LC_MajorContract',
'AStockEventsDB.LC_InvestorRa',
'AStockEventsDB.LC_InvestorDetail',
'AStockShareholderDB.LC_ESOP',
'AStockShareholderDB.LC_ESOPSummary',
'AStockShareholderDB.LC_TransferPlan',
'AStockShareholderDB.LC_SMAttendInfo',
'ConstantDB.SecuMain',
],
'对应中文注释说明': ['公司概况',
'公司名称更改状况',
'公司经营范围与行业变更',
'公司行业划分表',
'公司行业变更表',
'行业估值指标',
'行业财务指标表：各行业的成长能力、偿债能力、盈利能力和现金获取能力等',
'概念所属公司表：记录A股上市公司所属概念信息',
'概念板块常量表：记录A股热点概念板块信息',
'公司供应商与客户',
'股东类型分类表',
'股东名单(新)',
'股东户数',
'大股东介绍',
'公司实际控制人',
'公司股本结构变动',
'股东持股统计',
'股东股权变动',
'股东股权冻结和质押',
'股东股权冻结和质押统计',
'股份回购',
'股份回购关联表',
'法人配售与战略投资者',
'A股国家队持股统计',
'外资持股统计',
'A股增发',
'A股配股',
'公司分红',
'资金投向说明',
'境内股票交易资金流向指标',
'境内股票成交量技术指标',
'股票技术形态表',
'日行情表',
'股票行情表现(新)',
'停牌复牌表',
'资产负债表_新会计准则',
'利润分配表_新会计准则',
'现金流量表_新会计准则',
'公司研发投入与产出',
'公司主营业务构成',
'公司经营情况述评',
'公司历年审计意见',
'公司职工构成',
'公司管理层报酬统计',
'公司担保明细',
'公司借贷明细',
'公司诉讼仲裁明细',
'重大事项委托理财',
'公司资产重组明细',
'公司重大经营合同明细',
'投资者关系活动',
'投资者关系活动调研明细',
'员工持股计划',
'员工持股计划概况',
'股东增减持计划表',
'股东大会出席信息',
'证券主表,包含字段InnerCode,CompanyCode,SecuCode,ChiName,ChiNameAbbr 代表中文名称缩写,EngName,EngNameAbbr,SecuAbbr 代表 证券简称,ListedDate',
]}

    other = {'库表名': [
'PublicFundDB.MF_FundArchives',
'PublicFundDB.MF_FundProdName',
'PublicFundDB.MF_InvestAdvisorOutline',
'PublicFundDB.MF_Dividend',
'CreditDB.LC_ViolatiParty',
'IndexDB.LC_IndexBasicInfo',
'IndexDB.LC_IndexComponent',
'InstitutionDB.LC_InstiArchive',
'ConstantDB.CT_SystemConst',
'ConstantDB.QT_TradingDayNew',
'ConstantDB.LC_AreaCode',
'InstitutionDB.PS_EventStru',
'InstitutionDB.PS_NewsSecurity'],
'对应中文注释说明': [
'公募基金概况',
'公募基金产品名称',
'公募基金管理人概况',
'公募基金分红',
'违规当事人处罚',
'指数基本情况',
'指数成份',
'机构基本资料',
'系统常量表',
'交易日表(新)',
'国家城市代码表',
'事件体系指引表',
'证券舆情表']}

    us = {'库表名': [
'USStockDB.US_CompanyInfo',
'USStockDB.US_DailyQuote',
'ConstantDB.US_SecuMain',
],
'对应中文注释说明': [
'美股公司概况',
'美股日行情',
'美股证券主表',
]}

    hk = {'库表名': [
'HKStockDB.HK_EmployeeChange',
'HKStockDB.HK_StockArchives',
'HKStockDB.CS_HKStockPerformance',
'ConstantDB.HK_SecuMain',
],
'对应中文注释说明': [
'港股公司员工数量变动表',
'港股公司概况',
'港股行情表现',
'港股证券主表，包含字段InnerCode,CompanyCode,SecuCode,ChiName,ChiNameAbbr 代表中文名称缩写,EngName,EngNameAbbr,SecuAbbr 代表 证券简称,ListedDate',
]}

    # Prompt template for the table-selection request; <<table_schema>> and
    # <<question>> are filled in at step 1.4.
    content_p_1 = """我有如下数据库表<<table_schema>>
我想回答问题
"<<question>>"

请从上面数据库表中筛选出要回答问题，需要哪些数据库表，记得提示我：<需要查询的数据库表>,格式如下：
**逐步分析选择什么表**：为什么选择这些表
**选择的数据库表**：所选择的表。
不要输出其他内容。"""

    # 1.2 Pre-process the question: extract its entities and look them up
    res, tables, concept_list = process_question(question, MODEL_rag)

    # 1.3 Build the candidate table catalog from the markets involved.
    #     When the entity lookup found no table, ask the LLM for the markets.
    table_schema ={'库表名': [], '对应中文注释说明': []}
    if tables == []:
        StockMarket = checkStockMarket(question, MODEL_rag)
    else:
        StockMarket = tables
    if 'ConstantDB.HK_SecuMain' in StockMarket:
        table_schema['库表名'].extend(hk['库表名'])
        table_schema['对应中文注释说明'].extend(hk['对应中文注释说明'])
    if 'ConstantDB.US_SecuMain' in StockMarket:
        table_schema['库表名'].extend(us['库表名'])
        table_schema['对应中文注释说明'].extend(us['对应中文注释说明'])
    if 'ConstantDB.SecuMain' in StockMarket:
        table_schema['库表名'].extend(lc['库表名'])
        table_schema['对应中文注释说明'].extend(lc['对应中文注释说明'])
    
    # The market-independent tables are always offered.
    table_schema['库表名'].extend(other['库表名'])
    table_schema['对应中文注释说明'].extend(other['对应中文注释说明'])

    # 1.4 Let the LLM narrow the catalog down to the tables it needs.
    # NOTE: content_p_1 contains no <<fact_1>> placeholder, so that replace
    # is a no-op here.
    content_p = content_p_1.replace('<<question>>', str(question))\
        .replace('<<fact_1>>', str((res, tables)))\
        .replace('<<table_schema>>', str(table_schema))
    ref = Example(question=question, database_L_zh = database_L_zh)
    content_p = content_p + ref.to_string()
    if concept_list:
        content_p = content_p + "\n行业提示：" + str(concept_list)

    messages_rag = []
    messages_rag.append({"role": "user", "content": "您好阿"})
    messages_rag.append({"role": "user", "content": content_p})
    response = async_llm_chain_call(messages_rag, MODEL_sql, sampling_count=1)
    table_maps = get_table_schema(question = question, database_table_en = database_table_en, table_shema = input_text)
    # Keep a table when its name occurs in the LLM reply, the reference
    # hints, the concept hints or the market list (plain substring test).
    LL1 = [i for i in table_maps if i.get('数据表名') in response.choices[0].message.content + ref.to_string(get_col_shema=False) + str(concept_list) + str(StockMarket)]
    if 'ConstantDB.SecuMain' in StockMarket:
        LL = LL1
    else:
        # Without the A-share market, additionally restrict to the catalog
        # that was offered to the LLM.
        LL = [i for i in LL1 if i.get('数据表名') in str(table_schema)]
    foreigns_choice = get_foreigns(foreigns, response.choices[0].message.content + str(StockMarket))
    
    # 2. Generate the SQL and the answer
    content_p_2 = """
请写sql帮我查询问题。
问题：<<question>>
已查询获得的事实：<<fact_1>>
表结构：<<list>>
表之间的关联信息如下：<<foreigns>>
表结构中列名可以引用使用,表结构中数据示例只是参考不能引用。
我们现在开始查询当前问题，请你分步写出查询sql语句，我把查询结果告诉你，你再告诉我下一步，
注意如果我返回的结果为空或者错误影响下一步调用，请重新告诉我sql语句。
写sql时，请告诉我<这是第几步，这步做了什么事情>
等你全部回答完成，不需要进行下一步调用时，记得提示我：<全部完成，答案如下>,将答案总结以json格式给我，只需要总结当前问题。
查询技巧:sql查询年度时优先使用year()函数。sql查询语句不需要注释，不然会报错。sql中日期条件格式应参考这样date(TradingDay) = 'YYYY-MM-DD'。尽量利用表格中已有的字段。
"""
    # NOTE(review): the template uses <<name>> but the replacements target
    # <name>, so each substituted value stays wrapped in one pair of angle
    # brackets in the final prompt — confirm this is intended.
    content_p_2 = content_p_2.replace('<question>', question)\
        .replace('<list>', str(LL) + table_intro(LL))\
        .replace('<foreigns>', str(foreigns_choice))\
        .replace('<fact_1>', str((res, tables)))
    if concept_list:
        content_p_2 = content_p_2 + "\n行业提示：" + str(concept_list)

    messages_sql = []
    messages_sql.append({"role": "system", "content": content_p_2})  
    messages_sql.append({"role": "user", "content": f"下面开始解决问题：{question}"})
    messages_sql.append({"role": "user", "content": ref.to_string()})  
    # Start the dialogue; the tables in which the entities were found are
    # added to the table list passed on for error-column lookups.
    last_answer = run_conversation_until_complete(messages_sql, max_rounds=9, dialog_id = question_id, 
                                                o_q = question, 
                                                database_L_zh = database_L_zh,
                                                LL = LL +  [i for i in table_maps if i.get('数据表名') in str(tables)])
        
    return str(last_answer)

def find_error_col(text, LL):
    """
    Turn an "Unknown column" SQL error into a hint naming the right tables.

    The column name is taken from the text between "Unknown column '" and
    "' in"; a qualified name such as "t.col" is reduced to "col". The
    selected tables are then searched for a column with that name.

    Parameters:
        text (str): The error text returned for the failed query.
        LL (list): Table descriptors, each a dict with '数据表名' and a
            '数据表结构' list of column dicts carrying a '列名' entry.

    Returns:
        str: "<column> 在 [<tables>] 表中" listing every table that has the
             column, or "" when no table has it or when the error text does
             not contain a quoted column name.
    """
    match = re.search(r"Unknown column '([^']*)' in", text)
    if match is None:
        # The caller only checks for the words "Unknown column"; without
        # the quoted name there is nothing to look up.
        return ""

    # Drop any table/alias qualifier.
    col_name = match.group(1).split(".")[-1]

    right_table_name = [
        table['数据表名']
        for table in LL
        if any(column['列名'] == col_name for column in table['数据表结构'])
    ]
    if not right_table_name:
        return ""
    return f"{col_name} 在 {str(right_table_name)} 表中"

def run_conversation_until_complete(messages, dialog_id, o_q, LL, database_L_zh, max_rounds=6):
    """
    Drive the SQL dialogue with the LLM until it reports completion.

    Each round extracts the SQL from the model's reply, executes it, feeds
    the result (plus hints) back and asks for the next step. The loop ends
    when the reply contains "全部完成", when the model repeats the previous
    SQL, or when max_rounds is reached; in every case the dialogue is
    summarised into a final answer. The message list is mutated in place
    and saved to dialog/<dialog_id>.json.

    Parameters:
        messages (list): Initial chat messages; the last entry is expected
            to be a hint message, which is removed at the start of a round.
        dialog_id: Identifier used for the saved dialog file name.
        o_q (str): The original question, used when summarising.
        LL (list): Table descriptors used to resolve "Unknown column" errors.
        database_L_zh (list): Table descriptors used for column matching.
        max_rounds (int): Maximum number of follow-up rounds.

    Returns:
        The result of process_dict applied to the JSON found in the
        summarised answer.
    """
    def summarize(messages):
        """
        Summarise the dialogue into the final answer to the original question.
        """
        messages_a = [{"role": "user", "content": f"根据对话：{str(messages)}，回答问题<{o_q}>，如果答案涉及百分比，加上'%'。只输出答案，不要回答其他内容"}]
        response = async_llm_chain_call(messages_a, MODEL_sql, sampling_count=1)
        last_response = response.choices[0].message.content
        return last_response
    
    last_response = None  # Holds the most recent response
    round_count = 0  # Number of completed dialogue rounds
    # 1. Generate the first reply
    response = async_llm_chain_call(messages, MODEL_sql, sampling_count=1)
    pre_sql = ""
    while True:
        # 2. Keep refining the SQL through dialogue until the LLM reports completion
        # Remove the last message (the hint message) before the reply is recorded.
        del messages[-1]
        question = response.choices[0].message.content
        # 2.1 Select the key examples related to the LLM's reply
        ref = Example(question=question, database_L_zh = database_L_zh)
        # 2.2 Extract the SQL from the LLM's reply and execute it
        select_result, intro_str, sql = to_select(question)
        if round_count >= max_rounds:
            # Round limit reached: summarise the dialogue into the final answer
            messages.append({"role": "assistant", "content": question})
            messages.append({"role": "user", "content": str(select_result)})
            last_response = summarize(messages)  # Store the final response
            break  # Leave the loop once the round limit is reached
        # 2.3 Feed the result back; on an SQL error, add error information
        messages.append({"role": "assistant", "content": question})
        if "Unknown column" in select_result:
            col_position = find_error_col(select_result, LL)
            messages.append({"role": "user", "content": col_position})
        if pre_sql != sql:
            pre_sql = sql
            messages.append({"role": "user", "content": str(select_result) + intro_str})
        else:
            # The same SQL was produced twice in a row: stop iterating
            last_response = summarize(messages)  # Store the final response
            break  
        
        # 2.4 Add the examples related to the reply to the dialogue
        messages.append({"role": "user", "content": ref.to_string()})
        # 2.5 Ask the LLM for a new reply based on the appended information
        response = async_llm_chain_call(messages, MODEL_sql, sampling_count=1)

        last_response = response.choices[0].message.content  # Store the latest response
        if "全部完成" in response.choices[0].message.content:
            # Replace the trailing hint message with the final assistant reply
            del messages[-1]
            messages.append({"role": "assistant", "content": last_response})
            last_response = summarize(messages)
            break
        round_count += 1
    # Persist the full dialogue for later inspection.
    os.makedirs("dialog",exist_ok=True)
    with open(f"dialog/{dialog_id}.json", "w", encoding="utf-8") as f:
        json.dump(messages, f, ensure_ascii=False, indent=4)
    parsed_data = find_json(last_response)
    final_string = process_dict(parsed_data)
    return final_string  # Return the processed final answer



### 3.2.3 定义获取问题答案的函数

In [None]:
import os
from tqdm import tqdm
import time
from collections import Counter
import requests
import pickle

def get_answer(question, question_id, database_L_zh):
    """
    Answer a single question by delegating to run_conversation.

    Progress and failure messages are printed and appended to log.txt.
    Exceptions raised while answering are caught and reported; in that case
    a fixed error string is returned instead of propagating the exception.

    Parameters:
        question (str): The question that needs an answer.
        question_id: Identifier of the question, forwarded to run_conversation.
        database_L_zh: Forwarded unchanged to run_conversation.

    Returns:
        The value returned by run_conversation, or the string
        "An error occurred while retrieving the answer." on failure.
    """
    def _report(line):
        # Echo the line to stdout, then append it to the shared log file.
        print(line)
        with open("log.txt", "a", encoding="utf-8") as log_file:
            log_file.write(f"{line}\n")

    try:
        _report(f"Attempting to answer the question: {question}")
        return run_conversation(question, question_id, database_L_zh)
    except Exception as e:
        _report(f"Error occurred while executing get_answer: {e}")
        return "An error occurred while retrieving the answer."
    
def question_rew(context_text, original_question):
    """
    Rewrite a follow-up question so it can be understood without the
    earlier turns of the conversation.

    The model selected by MODEL_sql is prompted (with one worked example) to
    reply with a JSON object whose "重写" field holds the standalone question.
    That field is extracted with find_json. If the model call fails, the reply
    cannot be parsed, or the "重写" value is missing, empty or not a string,
    a fallback prompt that embeds the previous Q&A text together with the
    original question is returned instead.

    Parameters:
        context_text (str): The previous questions and answers.
        original_question (str): The question to be rewritten.

    Returns:
        str: The rewritten question, or the fallback prompt.
    """
    prompt = (
        "请将多回合对话中的用户问题重新表述，使重写的问题可以充分表达用户的信息需求，而不需要上下文。"
        "要求不改变原意，不要遗漏信息，要包含内容中与问题有关的重要信息（比如代号、公司、股票、职位等），特别是时间。\n"
        "我将给你举个多回合对话的例子，每个回合都包含一个问题和重写。重写部分使用json格式，其中解释了为什么要这样重写：\n"
        "Example : 之前的对话: '问题：最新更新的2021年度报告中，机构持有无限售流通A股数量合计最多的公司简称是？  回答：公司简称 帝尔激光',"
        "帮我重写这个问题'在这份报告中，该公司机构持有无限售流通A股比例合计是多少，保留2位小数？' \n"
        """assistant:```json
        {"原因": "根据之前的问答，这份报告是指 最新更新的2021年度报告，该公司机构是指 帝尔激光 。所以在用户问题中替换这2个意义不明确的词语，并保留用户的信息需求，即最新更新的2021年度报告中,公司简称 帝尔激光 持有无限售流通A股比例合计是多少，保留2位小数？",
         "重写": "最新更新的2021年度报告中,公司简称 帝尔激光 持有无限售流通A股比例合计是多少，保留2位小数？"}
         ```\n"""
        "现在，开始你的任务:\n"
        f"之前的对话: '{context_text}',帮我重写这个问题'{original_question}' \n assistant:"
    )
    # Used whenever a clean rewrite cannot be obtained from the model.
    fallback = f"下面的内容是之前的问答'{context_text}'，之前的问答只能作为参考，不一定正确。现在要回答：'{original_question}' "

    messages = [{"role": "user", "content": prompt}]
    try:
        # The model call is inside the try so that a failed request degrades
        # to the fallback prompt instead of aborting the caller's batch loop.
        response = async_llm_chain_call(messages, MODEL_sql, sampling_count=1)
        answer = response.choices[0].message.content
        rewrite = find_json(answer)["重写"]
    except Exception as e:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        print(f"question_rew: rewrite failed, using fallback prompt: {e}")
        return fallback

    # Guard against an empty or non-string "重写" value in the model's JSON.
    if isinstance(rewrite, str) and rewrite.strip():
        return rewrite
    return fallback


def main_answer(q_json_list, start_index=0, end_index=None):
    """
    Answer the questions of every entry in a slice of q_json_list.

    Each entry is a dict with a 'tid' and a 'team' list; every member of
    'team' carries an 'id' and a 'question'. For each entry in the range:
      1. If result/<tid>.pkl already exists, the pickled result is loaded
         and reused instead of recomputing the entry.
      2. Otherwise the questions are answered in order. The first question
         is used as-is; each later one is first rewritten by question_rew
         using the accumulated question/answer history of the entry.
      3. Answers are written back into the members of 'team' (in place),
         and the entry is pickled to result/<tid>.pkl.
    Progress messages are printed and appended to log.txt.

    Parameters:
        q_json_list (list): List of data objects with keys 'tid' and 'team'.
        start_index (int): First index of the slice to process.
        end_index (int): End of the slice (non-inclusive). None, or a value
                         past the end of the list, means "up to the end".

    Returns:
        list: One {"tid": ..., "team": ...} dict per processed entry.
    """
    def _report(line):
        # Echo the line to stdout, then append it to the shared log file.
        print(line)
        with open('log.txt', 'a', encoding='utf-8') as log_file:
            log_file.write(f"{line}\n")

    total = len(q_json_list)
    if end_index is None or end_index > total:
        end_index = total

    os.makedirs("result", exist_ok=True)
    # Snapshot of the cache directory taken once, before processing starts.
    cached_names = set(os.listdir("result"))
    processed = []

    for idx in tqdm(range(start_index, end_index), desc="Processing JSON data in range"):
        entry = q_json_list[idx]
        tid = entry['tid']

        if tid + ".pkl" in cached_names:
            # NOTE: pickle.load executes arbitrary code from the file; the
            # result/ directory must only contain files written by this script.
            with open(f"result/{tid}.pkl", "rb") as cache_file:
                processed.append(pickle.load(cache_file))
            continue

        started = time.time()

        # Collect (id, question) pairs up front, in team order.
        qa_pairs = [(member["id"], member["question"]) for member in entry["team"]]
        answers = {}
        history = ''

        for qid, qtext in qa_pairs:
            # Only follow-up questions need rewriting against the history.
            effective_question = question_rew(history, qtext) if history else qtext

            reply = get_answer(effective_question, qid, database_L_zh)
            _report(f'----------answer:{reply}')
            answers[qid] = reply
            history += "问题：" + qtext + "回答：" + reply + "\n"

        # Attach each answer to its team member.
        for member in entry["team"]:
            member["answer"] = answers.get(member["id"], "无答案")

        record = {"tid": tid, "team": entry["team"]}
        with open(f"result/{tid}.pkl", "wb") as cache_file:
            pickle.dump(record, cache_file)
        processed.append(record)

        elapsed = time.time() - started
        _report(f"Completed processing JSON index {idx} in {elapsed:.2f} seconds")

    return processed



In [None]:
# Load the question set. question_data_path and json are expected to be
# defined/imported in an earlier cell of this notebook.
with open(question_data_path, 'r', encoding='utf-8') as file:
    q_json_list = json.load(file)

# Choose the slice of entries to process. end_idx is exclusive, and
# main_answer clamps it to len(q_json_list), so start_idx=0 / end_idx=101
# processes indices 0 through 100, i.e. at most the first 101 entries.
start_idx = 0
end_idx = 101  # exclusive upper bound: the last index processed is 100

results = main_answer(q_json_list, start_index=start_idx, end_index=end_idx)

# Write the collected results to result.json as UTF-8, pretty-printed JSON.
with open('result.json', 'w', encoding='utf-8') as json_file:
    json.dump(results, json_file, ensure_ascii=False, indent=4)