# Introduction

This note will create dataset.

# Login Huggingface

In [29]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HUGGINGFACE_W")

login(token = hf_token)


# Generate dataset 

In [30]:
import json
import random

def generate_dataset(num_samples):
    """Generates a dataset of database operations and corresponding SQL statements.

    Args:
        num_samples: The number of samples to generate.

    Returns:
        A list of dictionaries, where each dictionary contains an "input" (JSON string representing the database operation) and an "output" (SQL string).
    """
    dataset = []
    for _ in range(num_samples):
        # Randomly choose an action type
        action = random.choice(["query", "insert", "update", "delete"])

        # Generate a table name
        table = random.choice(["users", "products", "orders", "customers", "employees", "books", "sales"])

        input_data = {"action": action, "table": table}
        output_sql = ""

        if action == "query":
            # Randomly choose columns to query
            columns = random.sample(["*"] + get_table_columns(table), random.randint(1, len(get_table_columns(table)) + 1))
            input_data["columns"] = columns

            # Generate a WHERE clause
            where_clause = generate_where_clause(table)
            if where_clause:
                input_data["where"] = where_clause

            join_data = None  #  在这里初始化 join_data 为 None

            # Generate a JOIN clause (simplified for Kaggle example)
            if random.random() < 0.3 and table == "employees": # 30% chance to generate a JOIN, only for employees table in this example
                join_data = generate_join_clause(table)
                if join_data:
                    input_data["join"] = join_data

            output_sql = generate_select_sql(table, columns, where_clause, join_data)

        elif action == "insert":
            # Generate data for insertion
            data = generate_insert_data(table)
            input_data["data"] = data
            output_sql = generate_insert_sql(table, data)

        elif action == "update":
            # Generate data for update
            data = generate_update_data(table)
            input_data["data"] = data

            # Generate a WHERE clause
            where_clause = generate_where_clause(table)
            if where_clause:
                input_data["where"] = where_clause
            output_sql = generate_update_sql(table, data, where_clause)

        elif action == "delete":
            # Generate a WHERE clause
            where_clause = generate_where_clause(table)
            if where_clause:
                input_data["where"] = where_clause
            output_sql = generate_delete_sql(table, where_clause)

        dataset.append({"input": json.dumps(input_data), "output": output_sql})
    return dataset


def get_table_columns(table):
    # Simplified columns for Kaggle demo
    columns = {
        "users": ["id", "name", "email", "age"],
        "products": ["id", "name", "price", "category"],
        "orders": ["order_id", "status", "customer_id"],
        "customers": ["customer_id", "name", "country"],
        "employees": ["id", "name", "salary", "department_id"],
        "books": ["title", "author", "publication_year"],
        "sales": ["order_date", "amount"]
    }
    return columns.get(table, [])


def generate_where_clause(table):
    columns = get_table_columns(table)
    if not columns:
        return None

    column = random.choice(columns)
    operator = random.choice(["=", ">", "<", ">=", "<=", "!="])
    value = generate_value_for_column(column)  # Use the improved value generation

    return f"{column} {operator} {value}"

def generate_join_clause(table):
    if table == "employees":
        return {"table": "departments", "on": "employees.department_id = departments.id", "columns": ["department_name"]}
    return None


def generate_value_for_column(column):
    # More diverse value generation
    value_types = {
        "id": random.randint(1, 100),
        "age": random.randint(18, 65),
        "price": round(random.uniform(1.0, 100.0), 2),
        "salary": random.randint(30000, 100000),
        "order_id": random.randint(1000, 9999),
        "customer_id": random.randint(100, 999),
        "amount": round(random.uniform(10.0, 1000.0), 2),
        "name": random.choice(["Alice", "Bob", "Charlie", "David"]),
        "email": lambda: f"{random.choice(['test', 'example'])}@{random.choice(['gmail', 'yahoo'])}.com",
        "category": random.choice(["Electronics", "Books", "Clothing"]),
        "status": random.choice(["Pending", "Shipped", "Delivered"]),
        "country": random.choice(["USA", "Canada", "UK"]),
        "title": random.choice(["The Hitchhiker's Guide to the Galaxy", "Pride and Prejudice"]),
        "author": random.choice(["Douglas Adams", "Jane Austen"]),
        "publication_year": random.randint(1900, 2023),
        "order_date": lambda: f"'{random.randint(2020, 2023)}-{random.randint(1, 12):02}-{random.randint(1, 28):02}'",
        "department_id": random.randint(1, 5) # Example for department_id
    }

    if column in value_types:
        if callable(value_types[column]):
            return value_types[column]()
        else:
            return f"'{value_types[column]}'" if isinstance(value_types[column], str) else str(value_types[column])
    else:
      return f"'{random.choice(['value1', 'value2', 'value3'])}'"


def generate_insert_data(table):
    columns = get_table_columns(table)
    data = {}
    for column in columns:
        if column != "id" and column != "order_id" and column != "customer_id": # Assuming auto-increment for id and order_id and customer_id
            data[column] = generate_value_for_column(column)
    return data

def generate_update_data(table):
    columns = get_table_columns(table)
    data = {}
    for column in columns:
        if column != "id" and column != "order_id" and column != "customer_id": # Avoid updating IDs in this simplified example
            data[column] = generate_value_for_column(column)
    return data

def generate_select_sql(table, columns, where_clause, join_data):
    column_str = ", ".join(columns)
    sql = f"SELECT {column_str} FROM {table}"
    if join_data:
        sql += f" JOIN {join_data['table']} ON {join_data['on']}"
        if "columns" in join_data and join_data["columns"]:
            join_columns_str = ", ".join([f"{join_data['table']}.{col}" for col in join_data['columns']])
            if column_str == "*":
                column_str = f"{table}.*, {join_columns_str}"
            else:
                column_str += ", " + join_columns_str
            sql = f"SELECT {column_str} FROM {table} JOIN {join_data['table']} ON {join_data['on']}" # Reconstruct SELECT with correct columns

    if where_clause:
        sql += f" WHERE {where_clause}"
    return sql + ";"


def generate_insert_sql(table, data):
    columns = ", ".join(data.keys())
    values = ", ".join([str(v) if isinstance(v, (int, float)) else f"'{v}'" for v in data.values()])
    return f"INSERT INTO {table} ({columns}) VALUES ({values});"

def generate_update_sql(table, data, where_clause):
    set_clauses = []
    for k, v in data.items():
        if isinstance(v, (int, float)):
            set_clauses.append(f"{k} = {v}")
        else:
            set_clauses.append(f"{k} = '{v}'")

    set_clause = ", ".join(set_clauses)
    sql = f"UPDATE {table} SET {set_clause}"
    if where_clause:
        sql += f" WHERE {where_clause}"
    return sql + ";"

def generate_delete_sql(table, where_clause):
    sql = f"DELETE FROM {table}"
    if where_clause:
        sql += f" WHERE {where_clause}"
    return sql + ";"


# Generate the dataset and print it (or save to a file for Kaggle)
dataset = generate_dataset(20) # Increased to 20 for more data
print(json.dumps(dataset, indent=4))

[
    {
        "input": "{\"action\": \"query\", \"table\": \"employees\", \"columns\": [\"*\", \"name\", \"department_id\", \"salary\", \"id\"], \"where\": \"name >= 'David'\"}",
        "output": "SELECT *, name, department_id, salary, id FROM employees WHERE name >= 'David';"
    },
    {
        "input": "{\"action\": \"delete\", \"table\": \"customers\", \"where\": \"name > 'David'\"}",
        "output": "DELETE FROM customers WHERE name > 'David';"
    },
    {
        "input": "{\"action\": \"query\", \"table\": \"customers\", \"columns\": [\"customer_id\", \"*\", \"name\", \"country\"], \"where\": \"country = 'USA'\"}",
        "output": "SELECT customer_id, *, name, country FROM customers WHERE country = 'USA';"
    },
    {
        "input": "{\"action\": \"delete\", \"table\": \"employees\", \"where\": \"salary <= 74733\"}",
        "output": "DELETE FROM employees WHERE salary <= 74733;"
    },
    {
        "input": "{\"action\": \"update\", \"table\": \"products\", \"data