In [5]:
# Data loading
import pandas as pd
import json
import csv
import os

def load_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so non-ASCII (e.g. Chinese) text loads the same way on
    every operating system.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_csv(file_path):
    """Read a CSV file and return all of its rows.

    The file is opened with ``newline=''`` as the ``csv`` module requires, so
    newlines embedded in quoted fields are parsed correctly, and it is decoded
    as UTF-8 instead of the platform's locale encoding.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A list of rows, each row being a list of string fields. A blank line
        in the file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', newline='', encoding='utf-8') as f:
        return list(csv.reader(f))

def process_keywords(keywords_csv):
    """Flatten keyword CSV rows into a flat list of keyword spellings.

    For every row the first column is taken as the keyword; both the keyword
    as written and its lower-cased form are appended, in that order.

    Args:
        keywords_csv: Iterable of rows, each a sequence whose first element is
            the keyword string. Columns after the first are ignored.

    Returns:
        A list containing, for each non-empty row, the keyword followed by its
        lower-cased form (duplicates are kept).
    """
    keywords = []
    for row in keywords_csv:
        if not row:
            # csv.reader yields [] for a blank line; indexing it would raise
            # IndexError, so such rows are skipped.
            continue
        keyword = row[0]
        keywords.extend((keyword, keyword.lower()))
    return keywords

def process_content_data(content_data):
    """Index database content records by their ``db_id``.

    Args:
        content_data: Iterable of records, each supporting ``record['db_id']``.

    Returns:
        A dict mapping each ``db_id`` to its record. If the same ``db_id``
        appears more than once, the last record wins.
    """
    return {record['db_id']: record for record in content_data}

if __name__ == "__main__":
    # Directory of the dataset to process; every input path below is resolved
    # relative to it. The alternatives are kept commented out for switching.
    DATASET='DuSQL'
    # DATASET='NL2SQL'
    # DATASET='CSpider'
    # Both splits are loaded, but only the training split is bound to `datas`.
    datas_dev = load_json(os.path.join(DATASET, 'dev.json'))
    datas_train = load_json(os.path.join(DATASET, 'train.json'))
    datas = datas_train

    # Report how many records were loaded.
    print(len(datas))

    # Keyword rows from the CSV are flattened into a list holding each keyword
    # together with its lower-cased form.
    keywords_csv = load_csv(os.path.join(DATASET, 'keywords.csv'))
    keywords = process_keywords(keywords_csv)

    # Database content records are indexed by their 'db_id'.
    content_data = load_json(os.path.join(DATASET, 'db_content.json'))
    content_data_dict = process_content_data(content_data)


22521


In [6]:
# Generate QA pairs for the selected dataset; DATASET is one of 'DuSQL', 'NL2SQL', 'CSpider'
import unicodedata
import re
import string
import json
import random

# Module-level list, initially empty.
# NOTE(review): nothing in the visible code reads or appends to it - confirm
# whether it is still needed.
instruct_pairs = []

def is_punctuation(char):
    """Return True if ``char`` is exactly one ASCII punctuation character.

    A bare ``char in string.punctuation`` is a substring test, so it would
    also accept the empty string and multi-character runs such as ``"()"``.
    The length check restricts the answer to single characters.

    Args:
        char: The string to test.

    Returns:
        True when ``char`` has length 1 and occurs in ``string.punctuation``,
        otherwise False.
    """
    return len(char) == 1 and char in string.punctuation

def is_contain_chinese(str):
    """Tell whether the text has at least one character in U+4E00..U+9FA5.

    Args:
        str: The text to scan.

    Returns:
        True if any character falls in that range, otherwise False.
    """
    first_hit = re.search('[\u4e00-\u9fa5]', str)
    if first_hit is None:
        return False
    return True

def is_alphanumeric(string):
    """Tell whether the text mixes letters and digits.

    Args:
        string: The text to inspect.

    Returns:
        True only if at least one character is alphabetic and at least one
        character is a digit; False otherwise (including for empty text).
    """
    contains_letter = any(ch.isalpha() for ch in string)
    if not contains_letter:
        return False
    return any(ch.isdigit() for ch in string)

def query2keys(query, keywords):
    """Extract the non-keyword tokens of a space-separated query.

    The query is split on single spaces. A token is dropped when it is empty
    (consecutive spaces), is contained in ``keywords``, or starts with a
    character that ``is_punctuation`` accepts. Every remaining token is kept
    unchanged, in its original order.

    The earlier version branched on whether a token was a non-Chinese mix of
    letters and digits, but both branches appended the same token; that dead
    conditional has been removed without changing the result.

    Args:
        query: The query text to tokenize.
        keywords: Container of tokens to exclude (tested with ``in``).

    Returns:
        A list of the surviving tokens.
    """
    keys_list = []
    for word in query.split(' '):
        # The empty-token test must come before word[0] is evaluated.
        if word == "" or word in keywords or is_punctuation(word[0]):
            continue
        keys_list.append(word)
    return keys_list

def clean_keys_list(keys_list):
    """Keep only ``T<digit>...`` tokens and strip everything up to the last dot.

    A token is kept when it is longer than two characters, starts with ``'T'``
    and has a digit as its second character (e.g. ``"T1.name"``); the value
    stored is the part after the final ``'.'``. All other tokens are dropped.

    The length check is evaluated first so that one-character tokens such as
    ``"T"`` (or empty strings) are skipped instead of raising ``IndexError``
    when the second character is inspected.

    NOTE(review): tokens that do not match the ``T<digit>`` pattern are
    discarded rather than passed through - confirm that this is intended.

    Args:
        keys_list: Iterable of token strings.

    Returns:
        A list with the suffix of every matching token, in input order.
    """
    new_keys_list = []
    for key in keys_list:
        if len(key) > 2 and key[0] == 'T' and key[1].isdigit():
            new_keys_list.append(key.split('.')[-1])
    return new_keys_list

def generate_prompt(question, table_list, pro):
    """Assemble the instruction prompt for one question.

    The prompt opens with a fixed instruction followed by the question. When
    ``table_list`` is non-empty a table section is added: for ``pro > 0.1``
    each table is listed with its fields, otherwise only the table names are
    listed. The prompt always ends with the ``SQL查询:`` marker.

    Args:
        question: The natural-language question text.
        table_list: Mapping of table name to the fields to show for it.
        pro: Number compared against 0.1 to choose the table section format.

    Returns:
        The complete prompt string.
    """
    segments = [
        f"""你是一个自然语言到SQL转换专家，你的任务是将金融领域问题，转换成对应的SQL查询：生成结果只含SQL语句。\n问题: {question}\n"""
    ]
    if len(table_list) != 0:
        show_fields = pro > 0.1
        if show_fields:
            segments.append("查询需要用到的数据库以及对应的字段如下：")
        else:
            segments.append("查询需要用到的数据库表格名称如下：")
        for number, (table, fields) in enumerate(table_list.items(), start=1):
            if show_fields:
                segments.append(f"""表{number}: {table}, 可用字段: {fields}""")
            else:
                segments.append(f"""表{number}: {table},""")
    segments.append("SQL查询:")
    return "".join(segments)

def main():
    """Write one prompt/answer JSON line per record in ``datas``.

    Reads the module-level ``DATASET``, ``datas``, ``keywords`` and
    ``content_data_dict``. For every record, the tokens extracted from its SQL
    query are matched against the tables (and their headers) of the record's
    database; the resulting mapping, the question and a random draw are passed
    to ``generate_prompt``. Each result is written as
    ``{"query": prompt, "answer": sql}`` on its own line of
    ``<DATASET>/sql_chtglm_train.json``.

    Finally prints how many records drew ``pro <= 0.1``.

    Raises:
        KeyError: If a record lacks an expected key or its ``db_id`` is not in
            ``content_data_dict``.
        OSError: If the output file cannot be opened.
    """
    count = 0
    out_path = os.path.join(DATASET, 'sql_chtglm_train.json')
    # ensure_ascii=False emits raw non-ASCII text, so the encoding is pinned to
    # UTF-8 instead of depending on the platform's locale codec.
    with open(out_path, 'w', encoding='utf-8') as f:
        for data in datas:
            question = data['question']
            query = data['query']
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the field order in the prompt is the same on every run (a set
            # would reorder strings from one process to the next).
            keys_list = list(dict.fromkeys(
                clean_keys_list(query2keys(query, keywords))
            ))
            db = content_data_dict[data['db_id']]

            # Tables whose name appears among the keys, each mapped to the
            # keys found in that table's header.
            table_list = {
                table_name: [key for key in keys_list if key in table['header']]
                for table_name, table in db['tables'].items()
                if table_name in keys_list
            }

            pro = random.random()
            if pro <= 0.1:
                count += 1

            prompt = generate_prompt(question, table_list, pro)

            json.dump({"query": prompt, "answer": query}, f, ensure_ascii=False)
            f.write('\n')
    print(count)

# Run the generation only when executed as the main module (not on import).
if __name__ == "__main__":
    main()


2329


In [8]:
# Merge the data from each dataset folder
import os
import json
# Concatenate the per-dataset train/dev files, in the folder order listed
# below, into one combined train file and one combined dev file.
# The files are read and written as UTF-8 explicitly so the result does not
# depend on the platform's locale encoding.
# NOTE(review): assumes the per-dataset files are UTF-8 encoded - confirm.
datas_train = []
datas_dev = []
for name in ['CSpider', 'DuSQL', 'NL2SQL']:
    with open(os.path.join(name, 'sql_chtglm_train.json'), 'r', encoding='utf-8') as f:
        datas_train += f.readlines()
    with open(os.path.join(name, 'sql_chtglm_dev.json'), 'r', encoding='utf-8') as f:
        datas_dev += f.readlines()
with open('sql_chtglm_train_final.json', 'w', encoding='utf-8') as f:
    f.writelines(datas_train)
with open('sql_chtglm_dev_final.json', 'w', encoding='utf-8') as f:
    f.writelines(datas_dev)
        
