In [1]:
from ollama import chat
from ollama import ChatResponse
import torch
import json
import os

In [2]:
def load_json(filepath):
    """Read a JSON file from disk and return the parsed value.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    default encoding, so files containing non-ASCII text are read the same
    way on every operating system.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The parsed JSON value, or None when the file does not exist or its
        content is not valid JSON. A message is printed in both failure cases.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        return None
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {filepath}")
        return None

In [1]:
def save_json(data, filepath):
    """Write ``data`` to ``filepath`` as JSON indented by four spaces.

    Failures are reported with a printed message instead of being raised,
    so the function always returns None.

    Args:
        data: A JSON-serializable object.
        filepath: Destination path; an existing file is overwritten.
    """
    try:
        # Serialize before opening: mode 'w' truncates the target immediately,
        # so encoding straight into the handle would wipe an existing file
        # whenever ``data`` turns out not to be serializable.
        serialized = json.dumps(data, indent=4)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(serialized)
        print(f"Data saved successfully to {filepath}")
    except Exception as e:
        print(f"Error saving data to {filepath}: {e}")

In [4]:
# Registry of system prompts, looked up by a short task label such as "SQL".
system_prompt_dict = dict()

In [5]:
# System prompt for the SQL task: asks the model to split a script into typed
# components and to reply with one JSON object in the schema spelled out below.
# The text is sent to the model verbatim, so any edit here changes its behavior.
# NOTE(review): the schema example contains a `/* ... */` placeholder and
# instruction 3 says "NULL"; neither is valid JSON, so the model may echo them
# into its answer — confirm against real outputs before relying on strict parsing.
system_prompt_dict["SQL"] = """
You are an expert SQL analyst. Your task is to dissect a given SQL script and provide a structured JSON response detailing all of its components and overall functionality. You must strictly adhere to the following JSON format:
{
  "output": {
    "database_system": "<DATABASE_SYSTEM_IF_SPECIFIED_OR_GENERAL_SQL>",
    "components": [
      {
        "component_type": "<TYPE_OF_SQL_COMPONENT>",
        "component_name": "<NAME_OF_COMPONENT_IF_APPLICABLE>",
        "component_code": "<THE_ACTUAL_SQL_CODE_OF_THE_COMPONENT>",
        "component_description": "<DETAILED_DESCRIPTION_OF_COMPONENT_FUNCTIONALITY>"
      },
      { /* ... more components ... */ }
    ],
    "overall_description": "<DETAILED_SUMMARY/DESCRIPTION_OF_THE_ENTIRE_SQL_SCRIPT_FUNCTIONALITY>"
  }
}

**Instructions:**
1. Identify the SQL Variant: Determine if the script is specific to MySQL, PostgreSQL, SQL Server, Oracle, etc., or general SQL.
2. Component Types: Classify each component accurately, such as TABLE_DEFINITION, VIEW, STORED_PROCEDURE, FUNCTION, TRIGGER, INDEX, QUERY, etc.
3. Component Names: Provide the correct identifier for each component (table name, procedure name, etc.), or NULL if not applicable.
4. Component Code: Include the complete, unmodified SQL code for each component.
5. Component Descriptions: Provide a technical explanation of what each component does, including its role in data storage, manipulation, or retrieval.
6. Overall Description: Offer a detailed summary of the entire SQL script, its purpose, and how different components interact.
7. Strict JSON Output: Your ENTIRE response must be ONLY the valid JSON object. Do not include any explanations, introductions, or additional text outside the JSON structure.

Analyze the following SQL script properly and return ONLY the JSON response with no additional text:
"""

In [6]:
def string_to_json(input_string):
    """Attempt to decode ``input_string`` as JSON.

    Returns:
        A ``(value, error)`` pair. On success it is ``(parsed, None)``; when
        the text is not valid JSON it is ``(None, message)``, where
        ``message`` is the decoder's error text.
    """
    try:
        parsed = json.loads(input_string)
    except json.JSONDecodeError as decode_error:
        return None, str(decode_error)
    return parsed, None

In [7]:
def make_data(data, system_prompt, model_name="qwen2.5-coder:32b"):
    """Send every entry of ``data`` to a chat model and collect the replies.

    One chat request is issued per entry: ``system_prompt`` is sent as the
    system message and the entry's value as the user message.

    Args:
        data: Mapping of sample id to the text passed as the user message.
        system_prompt: Text passed as the system message of every request.
        model_name: Name of the model handed to ``chat``. Defaults to the
            model this notebook was originally run with.

    Returns:
        A dict with the same keys as ``data``; each value is
        ``{"input": <original text>, "output": <model reply content>}``.
    """
    output_data = {}

    for key, value in data.items():
        response: ChatResponse = chat(model=model_name, messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': value},
        ])

        output_data[key] = {
            "input": value,
            "output": response['message']['content'],
        }

        # Progress marker: one line per completed request.
        print("Processed:", key)

    return output_data

In [8]:
# load_json returns None when the file is missing or malformed; the dict check
# below turns that into an explicit TypeError instead of a later crash.
data = load_json("sql_train.json")

# Number of leading entries of the training file that are sent to the model.
num_keys = 25
if not isinstance(data, dict):
    raise TypeError("Input must be a dictionary.")

if not isinstance(num_keys, int) or num_keys < 0:
    raise ValueError("num_keys must be a non-negative integer.")

# Take the first num_keys items in insertion order. zip() stops as soon as the
# range is exhausted, so the rest of the dictionary is never walked, and a
# dictionary shorter than num_keys is simply copied whole.
sliced_dict = {
    key: value for _, (key, value) in zip(range(num_keys), data.items())
}

# Drop the reference to the full dataset once the sample has been taken.
del data

In [9]:
# Sanity check: report how many entries were selected for processing.
print(f"Length of sliced dictionary: {len(sliced_dict)}")

Length of sliced dictionary: 25


In [10]:
# Query the model once per selected script, using the SQL system prompt.
output = make_data(data=sliced_dict, system_prompt=system_prompt_dict["SQL"])

Processed: sql_0
Processed: sql_1
Processed: sql_2
Processed: sql_3
Processed: sql_4
Processed: sql_5
Processed: sql_6
Processed: sql_7
Processed: sql_8
Processed: sql_9
Processed: sql_10
Processed: sql_11
Processed: sql_12
Processed: sql_13
Processed: sql_14
Processed: sql_15
Processed: sql_16
Processed: sql_17
Processed: sql_18
Processed: sql_19
Processed: sql_20
Processed: sql_21
Processed: sql_22
Processed: sql_23
Processed: sql_24


In [11]:
# Dump every input/response pair for manual inspection.
for record in output.values():
    print(f"Input: \n{record['input']}\n")
    print(f"Output: \n{record['output']}\n")
    print("-" * 80)

Input: 
<reponame>Dragontalker/MySQL-study-notes
#视图

/*
含义: 虚拟表, 和普通表一样使用
mysql5.1版本出现的新特性, 是通过表动态生成的数据

比如: 舞蹈班和普通班的对比

		创建语法的关键字	是否实际占用物理空间		使用
视图		create table	 只保存了sql逻辑		增删改查, 一般不能增删改
表		create view		   保准了数据ALTER	增删改查
*/

#案例: 查询姓张的学生名和专业名
SELECT stuname, major_name
FROM stuinfo AS s
INNER JOIN major AS m 
ON s.major_id = m.id
WHERE s.stuname LIKE '%张';

#创建视图
CREATE VIEW v1
AS
SELECT stuname, major_name
FROM stuinfo AS s
INNER JOIN major AS m
ON s.major_id = m.id;

#使用视图
SELECT * FROM v1 WHERE s.stuname LIKE '%张';

#一、创建视图
/*
语法:
create view 视图名
as
查询语句;
*/

#1. 查询姓名中包含a字符的员工名、部门名和工种新消息
#(1)创建
CREATE VIEW myv1
AS
SELECT last_name, department_name, job_title
FROM employees AS e
JOIN departments AS d
ON e.department_id = d.department_id
JOIN jobs AS j
ON j.job_id = e.job_id;

#(2)使用
SELECT *
FROM myv1
WHERE last_name LIKE '%a%';

#2. 查询各部门的平均工资级别
#(1)创建视图查看每个部门的平均工资
CREATE VIEW myv2
AS
SELECT AVG(salary) AS ag, department_id
FROM employees
GROUP BY department_id;

#(2)使用
SELECT 

In [12]:
import pickle

# Checkpoint the raw model responses on disk so they can be reloaded later
# without querying the model again.
with open("data.pickle", mode="wb") as pickle_file:
    pickle.dump(output, pickle_file)

In [2]:
import pickle

# Restore the checkpointed responses. pickle.load can execute arbitrary code,
# so only ever point this at a file produced by this notebook.
with open("data.pickle", mode="rb") as pickle_file:
    loaded_output = pickle.load(pickle_file)

print("Loaded output from pickle file:")
# Number the records so individual ones can be referred to by position later.
for position, record in enumerate(loaded_output.values()):
    print(f"Input {position}: \n{record['input']}\n")
    print(f"Output {position}: \n{record['output']}\n")
    print("-" * 80)

Loaded output from pickle file:
Input 0: 
<reponame>Dragontalker/MySQL-study-notes
#视图

/*
含义: 虚拟表, 和普通表一样使用
mysql5.1版本出现的新特性, 是通过表动态生成的数据

比如: 舞蹈班和普通班的对比

		创建语法的关键字	是否实际占用物理空间		使用
视图		create table	 只保存了sql逻辑		增删改查, 一般不能增删改
表		create view		   保准了数据ALTER	增删改查
*/

#案例: 查询姓张的学生名和专业名
SELECT stuname, major_name
FROM stuinfo AS s
INNER JOIN major AS m 
ON s.major_id = m.id
WHERE s.stuname LIKE '%张';

#创建视图
CREATE VIEW v1
AS
SELECT stuname, major_name
FROM stuinfo AS s
INNER JOIN major AS m
ON s.major_id = m.id;

#使用视图
SELECT * FROM v1 WHERE s.stuname LIKE '%张';

#一、创建视图
/*
语法:
create view 视图名
as
查询语句;
*/

#1. 查询姓名中包含a字符的员工名、部门名和工种新消息
#(1)创建
CREATE VIEW myv1
AS
SELECT last_name, department_name, job_title
FROM employees AS e
JOIN departments AS d
ON e.department_id = d.department_id
JOIN jobs AS j
ON j.job_id = e.job_id;

#(2)使用
SELECT *
FROM myv1
WHERE last_name LIKE '%a%';

#2. 查询各部门的平均工资级别
#(1)创建视图查看每个部门的平均工资
CREATE VIEW myv2
AS
SELECT AVG(salary) AS ag, department_id
FROM employees
GROUP

In [3]:
# One-off repair for sql_12: keep only the text before the first ``` marker.
# str.partition returns the whole string when the marker is absent, so
# re-running this cell is a no-op instead of raising ValueError like .index().
loaded_output["sql_12"]["output"] = loaded_output["sql_12"]["output"].partition("```")[0]

In [4]:
import json  # imported here so the cell runs even if the top import cell was skipped

# One-off repair for sql_12: the stored reply is missing its final closing
# brace. Append it only while the text does not parse as JSON, so running this
# cell a second time cannot add a stray extra brace.
try:
    json.loads(loaded_output["sql_12"]["output"])
except json.JSONDecodeError:
    loaded_output["sql_12"]["output"] = loaded_output["sql_12"]["output"] + "\n}"

In [5]:
# After manual verification: positions (in loaded_output's iteration order)
# of the records that are excluded from the final dataset.
remove_ind_num = [1, 2, 5, 6, 14, 20]

final_data = {}
for position, sample_id in enumerate(loaded_output):
    if position in remove_ind_num:
        continue
    print(f"Keeping index {position} in output.")
    final_data[sample_id] = loaded_output[sample_id]

Keeping index 0 in output.
Keeping index 3 in output.
Keeping index 4 in output.
Keeping index 7 in output.
Keeping index 8 in output.
Keeping index 9 in output.
Keeping index 10 in output.
Keeping index 11 in output.
Keeping index 12 in output.
Keeping index 13 in output.
Keeping index 15 in output.
Keeping index 16 in output.
Keeping index 17 in output.
Keeping index 18 in output.
Keeping index 19 in output.
Keeping index 21 in output.
Keeping index 22 in output.
Keeping index 23 in output.
Keeping index 24 in output.


In [6]:
# Display the ids of the records that survived the manual filtering step.
final_data.keys()

dict_keys(['sql_0', 'sql_3', 'sql_4', 'sql_7', 'sql_8', 'sql_9', 'sql_10', 'sql_11', 'sql_12', 'sql_13', 'sql_15', 'sql_16', 'sql_17', 'sql_18', 'sql_19', 'sql_21', 'sql_22', 'sql_23', 'sql_24'])

In [7]:
# Save the output to a JSON file
# save_json looks up `json` in the notebook namespace, so it must be imported
# before the call; presumably this matters when the top import cell was not
# re-run after a kernel restart.
import json

output_json_path = "final_sql_data.json"
# No try/except here: save_json already catches every Exception and prints its
# own error message, so an outer handler could never be reached.
save_json(final_data, output_json_path)

Data saved successfully to final_sql_data.json
