# 1 导包

In [1]:
import os
from typing import List, Dict, Optional

from dbgpt_hub.data_process.data_utils import extract_sql_prompt_dataset
from dbgpt_hub.llm_base.chat_model import ChatModel
from dbgpt_hub.dataset_util.excel_processor import ExcelProcessor



# 2 准备预测table数据

In [11]:
# Prompt template for the table-selection step. It has two "{}" slots that are
# filled, in order, with the instruction text and the user's question.
#
# The runs of spaces and the two literal double-quote characters are part of
# the original string value (the indentation of backslash-continued lines
# stays inside the literal) and are reproduced exactly here.
# NOTE(review): presumably the fine-tuned model was trained on prompts with the
# same artifacts — confirm before tidying them up.
TABLE_PROMPT = "".join(
    [
        "I want you to act as a SQL terminal in front of an example database,",
        " " * 10,
        "you need only to return the table name to me.Below is an instruction that describes a task,",
        " " * 10,
        "Write a response that appropriately completes the request.\n",
        '"\n',
        " " * 8,
        '"##Instruction:\n{}\n###Input:\n{}\n\n###Response:',
    ]
)

In [12]:
def prepare_table_data(
    input: Optional[str],
    db_id: str = "tp_mis",
    table_file: str = "/root/Neuvix/DB-GPT-HUB/dbgpt_hub/data/tp_mis/all_table.xlsx",
) -> str:
    """Build the table-selection prompt for a natural-language question.

    The table list is read from ``table_file``; for every row, element 1 is used
    as the table name and element 2 as the table description (both converted
    with ``str``). These are embedded in an instruction sentence which, together
    with the question, is substituted into ``TABLE_PROMPT``.

    Args:
        input: The user's question. ``None`` is accepted and is rendered as the
            text ``None`` by ``str.format``. (The name shadows the builtin but is
            kept for backward compatibility.)
        db_id: Database name mentioned in the instruction. Defaults to the
            value that was previously hard-coded.
        table_file: Path of the Excel sheet listing the tables. Defaults to the
            path that was previously hard-coded.

    Returns:
        The fully formatted prompt string.
    """
    # Single pass over the sheet so the rows are consumed only once.
    table_names = []
    table_descriptions = []
    for row in ExcelProcessor().read_excel(table_file):
        table_names.append(str(row[1]))
        table_descriptions.append(str(row[2]))

    tables = ', '.join(table_names)
    tables_description = ', '.join(table_descriptions)

    # The preamble is repeated inside the instruction on purpose: the text is
    # kept exactly as it was so the prompt seen by the model does not change.
    instruction = '''I want you to act as a SQL terminal in front of an example database, you need only to return the table name to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. {} contains tables such as {}. The descriptions of these tables are {}'''.format(db_id, tables,tables_description)
    return TABLE_PROMPT.format(instruction,input)

# 3 加载模型

In [4]:
# Settings handed to ChatModel: base model weights, prompt template name,
# fine-tuning type, and the LoRA adapter checkpoint directory.
# NOTE(review): the template is "llama2" although both paths name a llama-3
# model — confirm this matches what the adapter was fine-tuned with.
# The batch-prediction keys ("predict_file_path", "predict_out_dir",
# "predicted_out_filename") were commented out in this cell and are not set.
args = dict(
    model_name_or_path="/root/Neuvix/hugging-face/models--defog--llama-3-sqlcoder-8b",
    template="llama2",
    finetuning_type="lora",
    checkpoint_dir="/root/Neuvix/DB-GPT-HUB/dbgpt_hub/output/adapter/llama-3-sqlcoder-lora",
)

model = ChatModel(args)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

06/21/2024 09:57:51 - INFO - dbgpt_hub.llm_base.adapter - Fine-tuning method: LoRA
06/21/2024 09:58:10 - INFO - dbgpt_hub.llm_base.adapter - Merged 1 model checkpoint(s).
06/21/2024 09:58:10 - INFO - dbgpt_hub.llm_base.adapter - Loaded fine-tuned model from checkpoint(s): /root/Neuvix/DB-GPT-HUB/dbgpt_hub/output/adapter/llama-3-sqlcoder-lora
06/21/2024 09:58:28 - INFO - dbgpt_hub.llm_base.load_tokenizer - trainable params: 0 || all params: 8030261248 || trainable%: 0.0000
06/21/2024 09:58:43 - INFO - dbgpt_hub.data_process.data_utils - Add pad token: <|eot_id|>


# 4 推理table

In [17]:
# Step 1 of the pipeline: ask the model which table(s) the question needs.
# The question literal means "How many drivers are there?".
query = prepare_table_data('有多少司机')
# model.chat's result is unpacked into two values; only the first (the
# generated text) is kept, the second is discarded.
response, _ = model.chat(query=query)
response

'change_driver_basic_info'

In [21]:
# Split the model's comma-separated answer into table names. Whitespace around
# each name is stripped and empty pieces are dropped, so an answer such as
# "table_a, table_b" yields names that can be used as schema dictionary keys
# (a plain split(",") would produce " table_b" and fail the later lookup).
pred_tables = [name.strip() for name in response.split(",") if name.strip()]
pred_tables

['change_driver_basic_info']

# 5 获取table的schema

In [28]:
from dbgpt_hub.dataset_util.excel_processor import ExcelProcessor
# Database identifier and relative metadata locations.
# In the cells visible here only DB_ID is read by live code; DATA_PATH and
# SCHEMA appear only in commented-out path expressions.
DB_ID = "tp_mis"
DATA_PATH = "dbgpt_hub/data"
SCHEMA = "schema"

# Shared Excel reader used by generate_sql_table_meta_data.
excel_processor = ExcelProcessor()

def generate_sql_table_meta_data() -> Optional[Dict]:
    """Collect column metadata for every table listed in all_table.xlsx.

    For each row of the table list, element 1 (converted with ``str``) is taken
    as the table name and ``<schema dir>/<table name>.xlsx`` is read. From every
    schema row, element 1 is collected as a column name and element 3 as that
    column's comment. A row whose element 4 equals 1 marks its column as the
    primary key; if several rows match, the last one wins, and if none match
    the primary key is ``None``.

    Returns:
        A dict mapping each table name to
        ``{"columns": [...], "column_comments": [...], "primary_key": ...}``.
    """
    table_list_file = "/root/Neuvix/DB-GPT-HUB/dbgpt_hub/data/tp_mis/all_table.xlsx"
    schema_dir = "/root/Neuvix/DB-GPT-HUB/dbgpt_hub/data/tp_mis/schema"

    # Resolve all table names first, then load one schema sheet per table.
    table_names = [str(row[1]) for row in excel_processor.read_excel(table_list_file)]

    meta_by_table = {}
    for table_name in table_names:
        schema_rows = excel_processor.read_excel(os.path.join(schema_dir, table_name + ".xlsx"))

        column_names = []
        comments = []
        key_column = None
        for row in schema_rows:
            column_name = row[1]
            column_names.append(column_name)
            comments.append(row[3])
            if row[4] == 1:
                key_column = column_name

        meta_by_table[table_name] = {
            "columns": column_names,
            "column_comments": comments,
            "primary_key": key_column,
        }
    return meta_by_table

# 6 推理SQL

In [29]:
# Prompt template for the SQL-generation step. It has two "{}" slots that are
# filled, in order, with the schema instruction and the user's question.
#
# The runs of spaces and the two literal double-quote characters are part of
# the original string value (the indentation of backslash-continued lines
# stays inside the literal) and are reproduced exactly here.
# NOTE(review): presumably the fine-tuned model was trained on prompts with the
# same artifacts — confirm before tidying them up.
SQL_PROMPT = "".join(
    [
        "I want you to act as a SQL terminal in front of an example database,",
        " " * 10,
        "you need only to return the sql command to me.Below is an instruction that describes a task,",
        " " * 10,
        "Write a response that appropriately completes the request.\n",
        '"\n',
        " " * 8,
        '"##Instruction:\n{}\n###Input:\n{}\n\n###Response:',
    ]
)

In [30]:
def prepare_sql_data(input: Optional[str], pred_tables:Optional[List]) -> str:
    """Build the SQL-generation prompt for a question and its predicted tables.

    The instruction lists the predicted tables and, for each one, its columns,
    the column comments and the primary key, using the metadata returned by
    ``generate_sql_table_meta_data``. The result is substituted into
    ``SQL_PROMPT`` together with the question.

    For well-formed input the output is identical to the previous version. In
    addition:
      * table names are stripped of surrounding whitespace and empty names are
        dropped, so pieces of a "a, b"-style model answer still match;
      * column names / comments that are not strings are converted with
        ``str`` and empty (``None``) cells become empty text instead of
        breaking ``", ".join``;
      * a table without a primary key no longer fails on ``None + str``; the
        primary-key sentence is simply omitted for it.

    Args:
        input: The user's question (name kept for backward compatibility even
            though it shadows the builtin).
        pred_tables: Table names predicted by the table-selection step.

    Returns:
        The fully formatted prompt string.

    Raises:
        KeyError: If a predicted table is not present in the metadata.
    """
    def _text(value) -> str:
        # Empty spreadsheet cells are assumed to arrive as None — TODO confirm.
        return "" if value is None else str(value)

    all_table_dict = generate_sql_table_meta_data()
    table_names = [name.strip() for name in pred_tables if name.strip()]

    instruction_sql = (DB_ID + " contains tables such as " + ", ".join(table_names) + ". ")
    for name in table_names:
        if name not in all_table_dict:
            # Same exception type as the plain lookup, with a clearer message.
            raise KeyError(f"predicted table {name!r} is not in the schema metadata")
        meta = all_table_dict[name]

        columns = [_text(column) for column in meta["columns"]]
        instruction_sql += ("Table " + name + " has columns such as " + ", ".join(columns) + ". ")

        comments = [_text(comment) for comment in meta["column_comments"]]
        instruction_sql += "The comments of columns are " + ", ".join(comments) + ". "

        primary_key = meta["primary_key"]
        if primary_key is not None:
            instruction_sql += _text(primary_key) + " is the primary key."
        # Each table's description ends with a newline, as before.
        instruction_sql += "\n"
    return SQL_PROMPT.format(instruction_sql,input)

# Step 2 of the pipeline: build the SQL-generation prompt for the same question
# ("How many drivers are there?"), restricted to the tables in pred_tables.
query = prepare_sql_data('有多少司机',pred_tables)
query

'I want you to act as a SQL terminal in front of an example database,          you need only to return the sql command to me.Below is an instruction that describes a task,          Write a response that appropriately completes the request.\n"\n        "##Instruction:\ntp_mis contains tables such as change_driver_basic_info. Table change_driver_basic_info has columns such as id, driver_no, driver_name, class_group. The comments of columns are id, 工号, 司机姓名, 班组. id is the primary key.\n\n###Input:\n有多少司机\n\n###Response:'

In [31]:
# Generate the SQL for the prompt built above; the second value returned by
# model.chat is discarded.
response, _ = model.chat(query=query)
response

'SELECT COUNT(driver_name) FROM tp_mis.change_driver_basic_info  ; '

# 00 以下为无用代码

In [1]:


def prepare_dataset(
    predict_file_path: Optional[str] = None,
) -> List[Dict]:
    """Load a JSON prediction file and convert every entry into a prompt record.

    Args:
        predict_file_path: Path of a JSON file whose top-level value is iterated
            item by item. The default of ``None`` is kept for backward
            compatibility, but a real path must be supplied (``open(None)``
            raises ``TypeError``).

    Returns:
        One ``extract_sql_prompt_dataset(item)`` result per item in the file,
        in file order.
    """
    # No cell visible in this notebook imports json, so import it here to keep
    # the function usable after a kernel restart.
    import json

    # Read explicitly as UTF-8 so non-ASCII text (the prompts contain Chinese)
    # does not depend on the platform's default encoding.
    with open(predict_file_path, "r", encoding="utf-8") as fp:
        data = json.load(fp)
    return [extract_sql_prompt_dataset(item) for item in data]


def inference(model: "ChatModel", predict_data: List[Dict], **input_kwargs):
    """Run the chat model over every prepared sample and collect the responses.

    Each item's ``"input"`` is sent to ``model.chat`` with an empty history; the
    input and the response are printed as they are processed.

    Args:
        model: Object exposing ``chat(query=..., history=..., **kwargs)`` whose
            result unpacks into two values, the first being the response.
        predict_data: Records with an ``"input"`` key holding the prompt.
        **input_kwargs: Extra keyword arguments forwarded to ``model.chat``.

    Returns:
        List of responses, in the same order as ``predict_data``.
    """
    # tqdm is not imported by any cell visible in this notebook, so resolve it
    # here; fall back to plain iteration when it is not installed.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_unused):
            return iterable

    res = []
    for item in tqdm(predict_data, desc="Inference Progress", unit="item"):
        print(f"item[input] \n{item['input']}")
        response, _ = model.chat(query=item["input"], history=[], **input_kwargs)
        print("response: ",response)
        res.append(response)
    return res

  from .autonotebook import tqdm as notebook_tqdm


[2024-06-20 11:48:01,922] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)




In [3]:
# Same model settings as the model-loading cell earlier in this notebook:
# base model weights, prompt template name, fine-tuning type, and the LoRA
# adapter checkpoint directory. Running this cell loads the model again.
# NOTE(review): the template is "llama2" although both paths name a llama-3
# model — confirm this matches what the adapter was fine-tuned with.
# The batch-prediction keys ("predict_file_path", "predict_out_dir",
# "predicted_out_filename") were commented out in this cell and are not set.
args = dict(
    model_name_or_path="/root/Neuvix/hugging-face/models--defog--llama-3-sqlcoder-8b",
    template="llama2",
    finetuning_type="lora",
    checkpoint_dir="/root/Neuvix/DB-GPT-HUB/dbgpt_hub/output/adapter/llama-3-sqlcoder-lora",
)

model = ChatModel(args)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:17<00:00,  4.40s/it]

06/20/2024 11:50:10 - INFO - dbgpt_hub.llm_base.adapter - Fine-tuning method: LoRA





06/20/2024 11:50:27 - INFO - dbgpt_hub.llm_base.adapter - Merged 1 model checkpoint(s).
06/20/2024 11:50:27 - INFO - dbgpt_hub.llm_base.adapter - Loaded fine-tuned model from checkpoint(s): /root/Neuvix/DB-GPT-HUB/dbgpt_hub/output/adapter/llama-3-sqlcoder-lora
06/20/2024 11:50:43 - INFO - dbgpt_hub.llm_base.load_tokenizer - trainable params: 0 || all params: 8030261248 || trainable%: 0.0000
06/20/2024 11:50:54 - INFO - dbgpt_hub.data_process.data_utils - Add pad token: <|eot_id|>


In [4]:
# JSON file with the prompts to run; it is passed to prepare_dataset below.
dev_file_path = "/root/Neuvix/DB-GPT-HUB/dbgpt_hub/data/prompt_dev.json"

In [6]:
# Load the prompt records from dev_file_path and display them.
predict_data = prepare_dataset(dev_file_path)
predict_data

[{'input': 'I want you to act as a SQL terminal in front of an example database,          you need only to return the sql command to me.Below is an instruction that describes a task,          Write a response that appropriately completes the request.\n##Instruction:\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ntp_mis contains tables such as change_shipping_basic_work_data. Table change_shipping_basic_work_data has columns such as id, vss_code, call_id, ship_owner, line_name, service_in, service_out, ship_length, ship_width, vssc_name, vssl_name, ship_state, berth_code, berth_direction, head_band, caudal_band, plan_berthing_date_time, act_berthind_date_time, plan_start_work_date_time, act_start_work_date_time, plan_finish_work_date_time, act_finish_work_date_time, import_start_date_time, impor

In [7]:
# Run the model over every prepared prompt; inference prints each input and
# response as it goes and returns the list of responses.
result = inference(model, predict_data)

Inference Progress:   0%|          | 0/2 [00:00<?, ?item/s]

item[input] 
I want you to act as a SQL terminal in front of an example database,          you need only to return the sql command to me.Below is an instruction that describes a task,          Write a response that appropriately completes the request.
##Instruction:
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
tp_mis contains tables such as change_shipping_basic_work_data. Table change_shipping_basic_work_data has columns such as id, vss_code, call_id, ship_owner, line_name, service_in, service_out, ship_length, ship_width, vssc_name, vssl_name, ship_state, berth_code, berth_direction, head_band, caudal_band, plan_berthing_date_time, act_berthind_date_time, plan_start_work_date_time, act_start_work_date_time, plan_finish_work_date_time, act_finish_work_date_time, import_start_date_time, import_fi

Inference Progress:  50%|█████     | 1/2 [00:02<00:02,  2.77s/item]

response:  SELECT vssc_name FROM change_shipping_basic_work_data WHERE ship_state = "计划进港"
item[input] 
I want you to act as a SQL terminal in front of an example database,          you need only to return the sql command to me.Below is an instruction that describes a task,          Write a response that appropriately completes the request.
##Instruction:
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
tp_mis contains tables such as change_shipping_basic_work_data. Table change_shipping_basic_work_data has columns such as id, vss_code, call_id, ship_owner, line_name, service_in, service_out, ship_length, ship_width, vssc_name, vssl_name, ship_state, berth_code, berth_direction, head_band, caudal_band, plan_berthing_date_time, act_berthind_date_time, plan_start_work_date_time, act_start_work_date_tim

Inference Progress: 100%|██████████| 2/2 [00:03<00:00,  1.78s/item]

response:  SELECT vssc_name FROM change_shipping_basic_work_data WHERE ship_state="待进港"





In [9]:
# One-off experiment: a hand-written prompt for table
# change_ship_archives_basic_info with a single "{}" slot for the question.
# NOTE(review): the prompt names the database "tp_mix" while the other prompts
# in this notebook use "tp_mis"; the recorded answer queries
# tp_mix.change_ship_archives_basic_info accordingly — confirm whether "tp_mix"
# is a typo.
# NOTE(review): "##Instruction:", "###Input:" and "###Response:" each appear
# twice in the template — verify that this doubling is intended.
prompt_template='''I want you to act as a SQL terminal in front of an example database,          
you need only to return the sql command to me.Below is an instruction that describes a task,          
Write a response that appropriately completes the request.\n##Instruction:\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ntp_mix contains tables such as change_ship_archives_basic_info. Table change_ship_archives_basic_info has columns such as id, ship_code, chi_and_eng_ship_name, nationality, ship_call_number, ship_imo_number, vessel_class, box_total, total_tonnage_of_ships, net_tonnage_of_ships, ships_deadweight_tonnage, ship_length, bow_position, chimney_position, ship_type_depth, ship_width, ship_height, ship_speed, date_of_construction, ship_type, Liner_identification, group_code, mmsi_code, bei_number, number_of_hatches, number_of_bow_shells, number_of_hatch_covers, empty_draft_tonnage, Heavy_load_draft_tonnage, number_of_deck_starting_layer, beihao_arrangement_order, slot_way, stowage_scheme, belonging_to_the_dock, create_date_time, del_flag, chi_ship_name, eng_ship_name. The comments of columns are 主键ID, 船舶代码, 中英文船名, 国籍, 船舶呼号, 船舶IMO号, 船舶等级？, 总箱量, 船舶总吨位, 船舶净吨位, 船舶载重吨, 船舶长度, 船头位置, 烟囱位置, 船舶型深, 船舶宽度, 船舶高度, 船舶速度, 建造日期, 船舶类型, 班轮标识, 集团代码, MMSI代码, 贝数量, 舱口数, 船头贝数, 舱盖数, 空载吃水吨位, 重载吃水吨位, 甲板起始层序号, 贝号排列顺序, slot_way?, stowage_scheme?, 所属码头, 记录创建时间, 删除标记。0-未删除；1-已删除, 中文船名, 英文船名. id is the primary key.\n\n\n###Input:\n###Input:\n{}\n\n###Response:\n\n###Response:'''
# The question literal means "Show the nationality of ship csc18".
query = prompt_template.format("显示船舶csc18的国籍")
# Only the first value returned by model.chat (the generated text) is kept.
response, _ = model.chat(query=query)
response

'SELECT chi_ship_name,nationality FROM tp_mix.change_ship_archives_basic_info where chi_ship_name="csc18";'