### 1. 导入模型

In [12]:
import importlib

# 选择模型 （llama3_8b, llama3_70b, qwen2_7b, qwen2_72b, gpt_4o_mini）
model = "llama3_8b"

module = importlib.import_module('core.models')
LLM = getattr(module, model)

### 2. 构建 输入 Prompt

In [2]:
from core.utils import *

# 选择数据库 （spider, bird）
dataset = 'spider'

prompt_list = prompt_construction(dataset)

print(len(prompt_list))
print(prompt_list[-1])

8034
A database 'real_estate_properties' has 5 tables named 'Ref_Feature_Types, Ref_Property_Types, Other_Available_Features, Properties, Other_Property_Features'.
Ref_Feature_Types table has columns: 'feature_type_code, feature_type_name'.
Ref_Property_Types table has columns: 'property_type_code, property_type_description'.
Other_Available_Features table has columns: 'feature_id, feature_type_code, feature_name, feature_description'.
Properties table has columns: 'property_id, property_type_code, date_on_market, date_sold, property_name, property_address, room_count, vendor_requested_price, buyer_offered_price, agreed_selling_price, apt_feature_1, apt_feature_2, apt_feature_3, fld_feature_1, fld_feature_2, fld_feature_3, hse_feature_1, hse_feature_2, hse_feature_3, oth_feature_1, oth_feature_2, oth_feature_3, shp_feature_1, shp_feature_2, shp_feature_3, other_property_details'.
Other_Property_Features table has columns: 'property_id, feature_id, property_feature_description'.

Gave m

### 3. 生成 SQL，存储

In [8]:
import time

save_path = f'./results/sql_gen_ori/{model}_{dataset}_ori.txt'

with open(save_path, 'a')as f:
    for index, prompt in enumerate(prompt_list):
        if index >= 0 : # 如果数据太多，可以设置 index， 分段生成，不必一次性生成
            print("-------------------------" + str(index) + ' sample_num')
            output = LLM(prompt)
            print(output)

            f.write(f'-------------------------{str(index)} sample_num\n')
            f.write(output + '\n')

            # time.sleep(0.5)           # 如果报错：访问频率过快，则强制休眠，模型较小时使用

-------------------------0 sample_num
```sql
SELECT COUNT(h.head_ID)
FROM head h
JOIN management m ON h.head_ID = m.head_ID
WHERE h.age > 56;
```
-------------------------1 sample_num
```sql
SELECT h.name, h.born_state, h.age
FROM head h
JOIN management m ON h.head_ID = m.head_ID
JOIN department d ON m.department_ID = d.Department_ID
ORDER BY h.age;
```
-------------------------2 sample_num
```sql
SELECT EXTRACT(YEAR FROM Creation) AS Creation_Year, Name, Budget_in_Billions
FROM department;
```
-------------------------3 sample_num
```sql
SELECT MAX(Budget_in_Billions) AS Max_Budget, MIN(Budget_in_Billions) AS Min_Budget
FROM department;
```
-------------------------4 sample_num
```sql
SELECT AVG(Num_Employees)
FROM department
WHERE Ranking BETWEEN 10 AND 15;
```
-------------------------5 sample_num
```sql
SELECT h.name
FROM head h
WHERE h.born_state <> 'California';
```
-------------------------6 sample_num
```sql
SELECT DISTINCT YEAR(d.Creation)
FROM department d
JOIN management m O

### 4. 输出清洗

In [3]:
from core.utils import *

# save_path = f'./results/sql_gen_ori/{model}_{dataset}_ori.txt'
sql_list = sql_match(save_path) 

# 存入文件
with open(f'./results/sql_gen_clean/{model}_{dataset}_clean.txt', 'w') as f:
    for sql in sql_list:
        f.write(sql + '\n')

print(len(sql_list))
sql_list[-1]

101


'SELECT student_id FROM Student_Course_Registrations UNION SELECT student_id FROM Student_Course_Attendance;'

### 5. 结果评估

In [8]:
from core.evaluate import *

pre_sql_path = './results/sql_gen_clean/llama3_70b_spider_clean.txt'
gold_sql_path = './dataset/data_instance/spider_all_gold.txt'
db_path = './dataset/spider/database'

res_list = ex_evaluation(pre_sql_path, gold_sql_path, db_path)

print_context = f"""pre_num: {len(res_list)};
ex: {res_list.count(1)/len(res_list):.4f}"""

print(print_context)
res_list

pre_num: 101;
ex: 0.7129


[0,
 0,
 -1,
 1,
 1,
 1,
 -1,
 1,
 -1,
 1,
 0,
 1,
 0,
 -1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 -1,
 1,
 -1,
 -1,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

### 6. 存入结果

In [11]:
import json

sql_incor_path = f"./results/sql_incor/{model}_{dataset}_incor.json"

with open(pre_sql_path, 'r') as f:
    pre_sql_list = [s.strip() for s in f.readlines()]

sql_incor_list = []
for index, _ in enumerate(res_list):
    if res_list[index] == 0:  # 只考虑执行结果不正确的，排除无法执行以及超时错误
        sql_incor_list.append(pre_sql_list[index])

data = {
    "res_list": res_list,
    "ex": res_list.count(1)/len(res_list),
    "sql_incor": sql_incor_list,
    "sql_gen_clean": pre_sql_list
}

with open(sql_incor_path, 'w')as f:
    json.dump(data, f, indent=4)