In [None]:
import json
import os

from tqdm import tqdm

In [None]:
# Directory the schema file is read from and the enriched JSON is written to.
# /content/ is the Colab working directory (the notebook downloads via google.colab).
INPUT_LOCATION = "/content/"
OUTPUT_LOCATION = "/content/"

TABLES_FILE = "tables.json"              # schema list consumed by main()
OUTPUT_FILE = "cosql_with_intents.json"  # per-table intent records produced by main()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

MODEL_NAME = "microsoft/phi-1_5"

# Upper bound on completion length per table; generate_intents trims whatever
# the model writes after its answer.
MAX_NEW_TOKENS = 300

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto")

# Deterministic greedy decoding. `temperature` is deliberately NOT passed: it is
# only used when do_sample=True, transformers warns that it is ignored otherwise,
# and 0.0 is not a valid sampling temperature in any case.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    do_sample=False,
    max_new_tokens=MAX_NEW_TOKENS,
)

tokenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/736 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.84G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

Device set to use cuda:0


In [None]:
def generate_intents_old(table_name, columns, data_types):
    """Legacy variant of intent generation, kept for comparison (main() does not call it).

    Builds a two-example few-shot prompt followed by ``table_name`` and its
    columns, runs it through the module-level ``llm`` pipeline (which returns
    prompt + completion here) and parses the completion.

    Compared with ``generate_intents``: there is no instruction preamble, and
    column names are kept exactly as the model wrote them (no lower-casing, no
    filtering against ``columns``).

    Args:
        table_name: Table name, inserted verbatim into the prompt.
        columns: Column names, paired positionally with ``data_types``.
        data_types: Column type strings, in the same order as ``columns``.

    Returns:
        ``(table_intent, column_intents)``: the text written under
        "Table Purpose:" (the whole trimmed completion if the model never wrote
        "Column Descriptions:"), and a dict of column name -> description.
    """
    few_shot_prompt = """
Example 1:
Table: inventory
Columns:
- item_id (INTEGER)
- item_name (TEXT)
- quantity (INTEGER)
- restock_date (DATE)

Table Purpose:
Tracks the stock levels and restocking schedules of items in a warehouse.

Column Descriptions:
- item_id: Unique identifier for each inventory item.
- item_name: Name or label of the item.
- quantity: Number of units currently in stock.
- restock_date: Date when the item is expected to be restocked.

---

Example 2:
Table: employee_attendance
Columns:
- employee_id (INTEGER)
- date (DATE)
- status (TEXT)

Table Purpose:
Logs the daily attendance status of company employees.

Column Descriptions:
- employee_id: Unique identifier for the employee.
- date: The calendar date of the attendance record.
- status: Attendance status for the day (e.g., Present, Absent, Sick).

---
""".strip()

    column_lines = "\n".join(
        f"- {col} ({dtype})" for col, dtype in zip(columns, data_types)
    )

    prompt = (
        f"{few_shot_prompt}\n\n"
        f"Table: {table_name}\n"
        f"Columns:\n{column_lines}\n\n"
        f"Table Purpose:\n"
    )

    result = llm(prompt)[0]["generated_text"]
    # The pipeline returns prompt + completion by default; keep only the completion.
    generated = result[len(prompt):].strip()

    # The model can keep imitating the few-shot layout after answering ("---",
    # then a made-up "Example N" / "Table:" block). Those extra "- col: ..." lines
    # would otherwise be parsed as this table's columns, so cut at the earliest
    # run-on marker.
    stop_tokens = ("```", "2. Write a", "CREATE TABLE", "# Solution",
                   "---", "\nExample ", "\nTable:")
    for stop_token in stop_tokens:
        if stop_token in generated:
            generated = generated.split(stop_token, 1)[0].strip()

    column_intents = {}
    if "Column Descriptions:" in generated:
        table_intent, column_block = generated.split("Column Descriptions:", 1)
        table_intent = table_intent.strip()
        for raw_line in column_block.strip().splitlines():
            line = raw_line.strip()  # tolerate indented bullets
            if line.startswith("-") and ":" in line:
                col, intent = line[1:].split(":", 1)
                # First description wins if the model repeats a column.
                column_intents.setdefault(col.strip(), intent.strip())
    else:
        table_intent = generated
    return table_intent, column_intents

In [None]:
def generate_intents(table_name, columns, data_types):
    """Ask the LLM for a table's purpose and a description of each listed column.

    Builds a few-shot prompt (instruction preamble, two worked examples, then
    ``table_name`` and its columns), runs it through the module-level ``llm``
    pipeline with ``return_full_text=False`` and parses the completion.

    Args:
        table_name: Table name, inserted verbatim into the prompt.
        columns: Column names, paired positionally with ``data_types``.
        data_types: Column type strings, in the same order as ``columns``.

    Returns:
        ``(table_intent, column_intents)``: the text written under
        "Table Purpose:" (the whole trimmed completion if the model never wrote
        "Column Descriptions:"), and a dict mapping lower-cased column name to
        its description. Only names present in ``columns`` (case-insensitively)
        are kept; columns the model did not describe are simply absent.
    """
    few_shot_prompt = """
You're a helpful assistant. You are given the name of a SQL table and a list of columns with their types.
Your task is to describe the likely purpose of this table in one sentence, and then describe what each listed column likely represents.

Only describe the columns listed. Do not add or describe any columns that are not explicitly given.

Example 1:
Table: inventory
Columns:
- item_id (INTEGER)
- item_name (TEXT)
- quantity (INTEGER)
- restock_date (DATE)

Table Purpose:
Tracks the stock levels and restocking schedules of items in a warehouse.

Column Descriptions:
- item_id: Unique identifier for each inventory item.
- item_name: Name or label of the item.
- quantity: Number of units currently in stock.
- restock_date: Date when the item is expected to be restocked.

---

Example 2:
Table: employee_attendance
Columns:
- employee_id (INTEGER)
- date (DATE)
- status (TEXT)

Table Purpose:
Logs the daily attendance status of company employees.

Column Descriptions:
- employee_id: Unique identifier for the employee.
- date: The calendar date of the attendance record.
- status: Attendance status for the day (e.g., Present, Absent, Sick).

---
""".strip()

    column_lines = "\n".join(
        f"- {col} ({dtype})" for col, dtype in zip(columns, data_types)
    )

    prompt = (
        f"{few_shot_prompt}\n\n"
        f"Table: {table_name}\n"
        f"Columns:\n{column_lines}\n\n"
        f"Table Purpose:\n"
    )

    result = llm(prompt, return_full_text=False)[0]["generated_text"]
    generated = result.strip()

    # Post-processing. The model can keep imitating the few-shot layout after
    # answering: it writes the "---" separator and then a made-up "Example N" /
    # "Table:" block. Those "- col: ..." lines would be parsed too, and for shared
    # names such as "id" they would overwrite the real description. Cut at the
    # earliest run-on marker.
    stop_tokens = ("```", "2. Write a", "CREATE TABLE", "# Solution",
                   "---", "\nExample ", "\nTable:")
    for stop_token in stop_tokens:
        if stop_token in generated:
            generated = generated.split(stop_token, 1)[0].strip()

    column_intents = {}
    expected_columns = {col.lower() for col in columns}

    if "Column Descriptions:" in generated:
        table_intent, column_block = generated.split("Column Descriptions:", 1)
        table_intent = table_intent.strip()

        for raw_line in column_block.strip().splitlines():
            line = raw_line.strip()  # tolerate indented bullets
            if line.startswith("-") and ":" in line:
                col, intent = line[1:].split(":", 1)
                col = col.strip().lower()
                if col in expected_columns:
                    # First description wins if the model repeats a column.
                    column_intents.setdefault(col, intent.strip())
    else:
        table_intent = generated

    return table_intent, column_intents


In [None]:
def load_tables(path):
    """Read and parse a JSON file, returning the decoded object.

    ``main`` treats the result as a list of per-database schema dicts
    (keys such as ``db_id``, ``table_names``, ``column_names``, ``column_types``).

    Args:
        path: Filesystem path to the JSON file.
    """
    # JSON text is UTF-8; don't depend on the platform's locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def main():
    """Generate LLM intents for every table in the schema file and save them as JSON.

    Reads the schema list from ``INPUT_LOCATION``/``TABLES_FILE``. For each
    database it groups ``column_names`` by owning table index (entries with
    table index -1 are skipped; presumably the ``*`` pseudo-column in
    Spider-format schemas). It then calls ``generate_intents`` once per table
    and writes a flat list of per-table records (``db_id``, ``table_name``,
    ``columns``, ``table_intent``, ``column_intents``) as indented JSON to
    ``OUTPUT_LOCATION``/``OUTPUT_FILE``.
    """
    # os.path.join works whether or not the configured locations end in "/".
    schemas = load_tables(os.path.join(INPUT_LOCATION, TABLES_FILE))
    output = []
    for schema in tqdm(schemas, desc="databases"):
        db_id = schema.get("db_id", "")
        table_names = schema.get("table_names", [])
        column_names = schema.get("column_names", [])
        column_types = schema.get("column_types", [])

        # table index -> [(column name, column type), ...]; column_types is
        # indexed in parallel with column_names.
        table_columns = {}
        for col_idx, (table_idx, col_name) in enumerate(column_names):
            if table_idx == -1:
                continue
            col_type = column_types[col_idx] if col_idx < len(column_types) else "unknown"
            table_columns.setdefault(table_idx, []).append((col_name, col_type))

        # leave=False keeps finished per-database bars from piling up in the output.
        for table_idx, table_name in enumerate(tqdm(table_names, desc=db_id, leave=False)):
            cols = table_columns.get(table_idx, [])
            col_names = [name for name, _ in cols]
            col_datatypes = [ctype for _, ctype in cols]
            table_intent, column_intents = generate_intents(table_name, col_names, col_datatypes)
            output.append({
                "db_id": db_id,
                "table_name": table_name,
                "columns": [f"{name} ({ctype})" for name, ctype in cols],
                "table_intent": table_intent,
                "column_intents": column_intents,
            })

    output_path = os.path.join(OUTPUT_LOCATION, OUTPUT_FILE)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"Saved enriched table metadata ({len(output)} tables) to {output_path}")

In [None]:
# Run the full enrichment: one LLM call per table across every database in the
# schema file. The recorded run took roughly 6-7 s per table on cuda:0, so a full
# pass over all databases takes a long time.
main()

  0%|          | 0/178 [00:00<?, ?it/s]

 20%|██        | 1/5 [00:07<00:31,  7.79s/it][A
 40%|████      | 2/5 [00:14<00:21,  7.07s/it][A
 60%|██████    | 3/5 [00:20<00:13,  6.86s/it][A
 80%|████████  | 4/5 [00:27<00:06,  6.74s/it][A
100%|██████████| 5/5 [00:34<00:00,  6.81s/it]
  1%|          | 1/178 [00:34<1:40:27, 34.05s/it]
  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:06<00:06,  6.59s/it][A
100%|██████████| 2/2 [00:13<00:00,  6.61s/it]
  1%|          | 2/178 [00:47<1:03:56, 21.80s/it]
  0%|          | 0/11 [00:00<?, ?it/s][A
  9%|▉         | 1/11 [00:06<01:06,  6.62s/it][A
 18%|█▊        | 2/11 [00:13<00:59,  6.58s/it][A
 27%|██▋       | 3/11 [00:19<00:52,  6.60s/it][AYou seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset

 36%|███▋      | 4/11 [00:26<00:46,  6.60s/it][A
 45%|████▌     | 5/11 [00:32<00:39,  6.59s/it][A
 55%|█████▍    | 6/11 [00:39<00:33,  6.61s/it][A
 64%|██████▎   | 7/11 [00:46<0

Saved enriched Spider metadata to cosql_with_intents.json





In [None]:
from google.colab import files
# Trigger a browser download of the JSON written by main(); os.path.join matches
# how the path should be built regardless of a trailing "/" on OUTPUT_LOCATION.
files.download(os.path.join(OUTPUT_LOCATION, OUTPUT_FILE))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>