In [5]:
# Mount Google Drive at /content/drive; the ATIS dataset is opened from
# "/content/drive/My Drive/atis.json" in a later cell.
# NOTE(review): force_remount=True presumably forces a fresh mount when the
# cell is re-run — confirm against the google.colab documentation.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [6]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import re

import spacy
import numpy as np
from collections import defaultdict, Counter

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

# For reproducibility
import random
import os



In [7]:
# Database schema as comma-separated text. The first row is the header
# (table name, field name, primary-key flag, foreign-key flag, type) and rows
# made only of dashes separate one table from the next. The string is
# interpolated verbatim into the user prompt built by get_model_response, so
# any edit here changes what the model sees.
schema = '''
Table Name, Field Name, Is Primary Key, Is Foreign Key, Type
AIRCRAFT, AIRCRAFT_CODE, n, y, varchar(3)
AIRCRAFT, AIRCRAFT_DESCRIPTION, n, n, varchar(50)
AIRCRAFT, MANUFACTURER, n, n, varchar(30)
AIRCRAFT, BASIC_TYPE, n, n, varchar(30)
AIRCRAFT, ENGINES, n, n, int(11)
AIRCRAFT, PROPULSION, n, n, varchar(10)
AIRCRAFT, WIDE_BODY, n, n, varchar(3)
AIRCRAFT, WING_SPAN, n, n, int(11)
AIRCRAFT, LENGTH, n, n, int(11)
AIRCRAFT, WEIGHT, n, n, int(11)
AIRCRAFT, CAPACITY, n, n, int(11)
AIRCRAFT, PAY_LOAD, n, n, int(11)
AIRCRAFT, CRUISING_SPEED, n, n, int(11)
AIRCRAFT, RANGE_MILES, n, n, int(11)
AIRCRAFT, PRESSURIZED, n, n, varchar(3)
-, -, -, -, -
AIRLINE, AIRLINE_CODE, n, y, varchar(2)
AIRLINE, AIRLINE_NAME, n, n, text
AIRLINE, NOTE, n, n, text
-, -, -, -, -
AIRPORT, AIRPORT_CODE, n, y, varchar(3)
AIRPORT, AIRPORT_NAME, n, n, text
AIRPORT, AIRPORT_LOCATION, n, n, text
AIRPORT, STATE_CODE, n, n, varchar(2)
AIRPORT, COUNTRY_NAME, n, n, varchar(6)
AIRPORT, TIME_ZONE_CODE, n, n, varchar(3)
AIRPORT, MINIMUM_CONNECT_TIME, n, n, int(11)
-, -, -, -, -
AIRPORT_SERVICE, CITY_CODE, n, y, varchar(4)
AIRPORT_SERVICE, AIRPORT_CODE, n, y, varchar(3)
AIRPORT_SERVICE, MILES_DISTANT, n, n, int(11)
AIRPORT_SERVICE, DIRECTION, n, n, varchar(2)
AIRPORT_SERVICE, MINUTES_DISTANT, n, n, int(11)
-, -, -, -, -
CITY, CITY_CODE, n, y, varchar(4)
CITY, CITY_NAME, n, n, varchar(18)
CITY, STATE_CODE, n, y, varchar(2)
CITY, COUNTRY_NAME, n, n, varchar(6)
CITY, TIME_ZONE_CODE, n, n, varchar(3)
-, -, -, -, -
CLASS_OF_SERVICE, BOOKING_CLASS, y, y, varchar(2)
CLASS_OF_SERVICE, RANK, n, n, int(11)
CLASS_OF_SERVICE, CLASS_DESCRIPTION, n, n, text
-, -, -, -, -
CODE_DESCRIPTION, CODE, y, n, varchar(4)
CODE_DESCRIPTION, DESCRIPTION, n, n, text
-, -, -, -, -
COMPARTMENT_CLASS, COMPARTMENT, n, n, varchar(5)
COMPARTMENT_CLASS, CLASS_TYPE, n, n, varchar(8)
-, -, -, -, -
DATE_DAY, MONTH_NUMBER, n, n, int(11)
DATE_DAY, DAY_NUMBER, n, n, int(11)
DATE_DAY, YEAR, n, n, int(11)
DATE_DAY, DAY_NAME, n, y, varchar(10)
-, -, -, -, -
DAYS, DAYS_CODE, n, y, varchar(20)
DAYS, DAY_NAME, n, y, varchar(10)
-, -, -, -, -
DUAL_CARRIER, MAIN_AIRLINE, n, n, varchar(2)
DUAL_CARRIER, LOW_FLIGHT_NUMBER, n, n, int(11)
DUAL_CARRIER, HIGH_FLIGHT_NUMBER, n, n, int(11)
DUAL_CARRIER, DUAL_AIRLINE, n, n, varchar(2)
DUAL_CARRIER, SERVICE_NAME, n, n, text
-, -, -, -, -
EQUIPMENT_SEQUENCE, AIRCRAFT_CODE_SEQUENCE, n, y, varchar(12)
EQUIPMENT_SEQUENCE, AIRCRAFT_CODE, n, y, varchar(3)
-, -, -, -, -
FARE, FARE_ID, y, y, int(11)
FARE, FROM_AIRPORT, n, y, varchar(3)
FARE, TO_AIRPORT, n, y, varchar(3)
FARE, FARE_BASIS_CODE, n, y, text
FARE, FARE_AIRLINE, n, n, text
FARE, RESTRICTION_CODE, n, y, text
FARE, ONE_DIRECTION_COST, n, n, int(11)
FARE, ROUND_TRIP_COST, n, n, int(11)
FARE, ROUND_TRIP_REQUIRED, n, n, varchar(3)
-, -, -, -, -
FARE_BASIS, FARE_BASIS_CODE, n, y, text
FARE_BASIS, BOOKING_CLASS, n, y, text
FARE_BASIS, CLASS_TYPE, n, n, text
FARE_BASIS, PREMIUM, n, n, text
FARE_BASIS, ECONOMY, n, n, text
FARE_BASIS, DISCOUNTED, n, n, text
FARE_BASIS, NIGHT, n, n, text
FARE_BASIS, SEASON, n, n, text
FARE_BASIS, BASIS_DAYS, n, y, text
-, -, -, -, -
FLIGHT, AIRCRAFT_CODE_SEQUENCE, n, y, text
FLIGHT, AIRLINE_CODE, n, y, varchar(3)
FLIGHT, AIRLINE_FLIGHT, n, n, text
FLIGHT, ARRIVAL_TIME, n, n, int(11)
FLIGHT, CONNECTIONS, n, n, int(11)
FLIGHT, DEPARTURE_TIME, n, n, int(11)
FLIGHT, DUAL_CARRIER, n, n, text
FLIGHT, FLIGHT_DAYS, n, y, text
FLIGHT, FLIGHT_ID, y, y, int(11)
FLIGHT, FLIGHT_NUMBER, n, n, int(11)
FLIGHT, FROM_AIRPORT, n, y, varchar(3)
FLIGHT, MEAL_CODE, n, y, text
FLIGHT, STOPS, n, n, int(11)
FLIGHT, TIME_ELAPSED, n, n, int(11)
FLIGHT, TO_AIRPORT, n, y, varchar(3)
-, -, -, -, -
FLIGHT_FARE, FLIGHT_ID, n, y, int(11)
FLIGHT_FARE, FARE_ID, n, y, int(11)
-, -, -, -, -
FLIGHT_LEG, FLIGHT_ID, n, y, int(11)
FLIGHT_LEG, LEG_NUMBER, n, n, int(11)
FLIGHT_LEG, LEG_FLIGHT, n, y, int(11)
-, -, -, -, -
FLIGHT_STOP, FLIGHT_ID, n, y, int(11)
FLIGHT_STOP, STOP_NUMBER, n, n, int(11)
FLIGHT_STOP, STOP_DAYS, n, n, text
FLIGHT_STOP, STOP_AIRPORT, n, y, text
FLIGHT_STOP, ARRIVAL_TIME, n, n, int(11)
FLIGHT_STOP, ARRIVAL_AIRLINE, n, n, text
FLIGHT_STOP, ARRIVAL_FLIGHT_NUMBER, n, n, int(11)
FLIGHT_STOP, DEPARTURE_TIME, n, n, int(11)
FLIGHT_STOP, DEPARTURE_AIRLINE, n, n, text
FLIGHT_STOP, DEPARTURE_FLIGHT_NUMBER, n, n, int(11)
FLIGHT_STOP, STOP_TIME, n, n, int(11)
-, -, -, -, -
FOOD_SERVICE, MEAL_CODE, n, y, text
FOOD_SERVICE, MEAL_NUMBER, n, n, int(11)
FOOD_SERVICE, COMPARTMENT, n, n, text
FOOD_SERVICE, MEAL_DESCRIPTION, n, n, varchar(10)
-, -, -, -, -
GROUND_SERVICE, CITY_CODE, n, y, text
GROUND_SERVICE, AIRPORT_CODE, n, y, text
GROUND_SERVICE, TRANSPORT_TYPE, n, n, text
GROUND_SERVICE, GROUND_FARE, n, n, int(11)
-, -, -, -, -
MONTH, MONTH_NUMBER, n, n, int(11)
MONTH, MONTH_NAME, n, n, text
-, -, -, -, -
RESTRICTION, RESTRICTION_CODE, n, y, text
RESTRICTION, ADVANCE_PURCHASE, n, n, int(11)
RESTRICTION, STOPOVERS, n, n, text
RESTRICTION, SATURDAY_STAY_REQUIRED, n, n, text
RESTRICTION, MINIMUM_STAY, n, n, int(11)
RESTRICTION, MAXIMUM_STAY, n, n, int(11)
RESTRICTION, APPLICATION, n, n, text
RESTRICTION, NO_DISCOUNTS, n, n, text
-, -, -, -, -
STATE, STATE_CODE, n, y, text
STATE, STATE_NAME, n, n, text
STATE, COUNTRY_NAME, n, n, text
-, -, -, -, -
TIME_INTERVAL, PERIOD, n, n, text
TIME_INTERVAL, BEGIN_TIME, n, n, int(11)
TIME_INTERVAL, END_TIME, n, n, int(11)
-, -, -, -, -
TIME_ZONE, TIME_ZONE_CODE, n, n, text
TIME_ZONE, TIME_ZONE_NAME, n, n, text
TIME_ZONE, HOURS_FROM_GMT, n, n, int(11)
'''

In [8]:
# Load the ATIS JSON data and build one entry per question for both the
# question-based and the query-based dataset splits.

nlp = spacy.load("en_core_web_sm")

# Explicit encoding so the load does not depend on the platform's default locale.
with open("/content/drive/My Drive/atis.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Structures to store splits and template mappings
question_split = {"train": [], "dev": [], "test": []}
query_split = {"train": [], "dev": [], "test": []}
sql_templates = {}
template_id_counter = 0

# Process each item (SQL group)
for item in raw_data:
    # Select the shortest SQL template for this group (ties broken lexicographically)
    shortest_sql = min(item["sql"], key=lambda x: (len(x), x))

    # Assign a unique template ID if new
    if shortest_sql not in sql_templates:
        sql_templates[shortest_sql] = template_id_counter
        template_id_counter += 1
    template_id = sql_templates[shortest_sql]

    # Process each sentence (question) in this group
    for sent in item["sentences"]:
        raw_text = sent["text"]  # contains placeholders like city_name0
        variables = sent["variables"]

        # Replace placeholders in the text for the model input. Longest names go
        # first so a short name (city_name1) can never rewrite the prefix of a
        # longer one (city_name10).
        text = raw_text
        for var_name in sorted(variables, key=len, reverse=True):
            text = text.replace(var_name, variables[var_name])

        # Build tokens and tags in lock-step from the placeholder text. A variable
        # value may tokenize to several tokens (or none); tagging each value token
        # with the variable name keeps len(tokens) == len(tags), which tokenizing
        # the two texts independently did not guarantee.
        # Only tokenization is needed, so the tokenizer is called directly rather
        # than running the full spaCy pipeline on every sentence.
        tokens = []
        tags = []
        for raw_tok in nlp.tokenizer(raw_text):
            tok = raw_tok.text
            if tok in variables:
                value_tokens = [t.text for t in nlp.tokenizer(variables[tok])]
                tokens.extend(value_tokens)
                tags.extend([tok] * len(value_tokens))  # tag is the variable name
            else:
                tokens.append(tok)
                tags.append("O")

        entry = {
            "text": text,                        # real input used for tokenization and modeling
            "template_id": template_id,
            "template_sql": shortest_sql,
            "variables": variables,
            "raw_text": raw_text,               # for debugging only
            "tokens": tokens,                   # used for model input
            "tags": tags                        # ground truth tag for each token
        }

        # Add to splits (the same entry object is shared by both split dicts)
        question_split[sent["question-split"]].append(entry)
        query_split[item["query-split"]].append(entry)

print(sql_templates)

{'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "airport_code0" AND FLIGHTalias0.TO_AIRPORT = AIRPORTalias0.AIRPORT_CODE ;': 0, 'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , DATE_DAY AS DATE_DAYalias0 , DAYS AS DAYSalias0 , FLIGHT AS FLIGHTalias0 WHERE ( CITYalias0.CITY_CODE = AIRPORT_SERVICEalias0.CITY_CODE AND CITYalias0.CITY_NAME = "city_name1" AND CITYalias1.CITY_CODE = AIRPORT_SERVICEalias1.CITY_CODE AND CITYalias1.CITY_NAME = "city_name0" AND FLIGHTalias0.FROM_AIRPORT = AIRPORT_SERVICEalias0.AIRPORT_CODE AND FLIGHTalias0.TO_AIRPORT = AIRPORT_SERVICEalias1.AIRPORT_CODE ) AND DATE_DAYalias0.DAY_NUMBER = day_number0 AND DATE_DAYalias0.MONTH_NUMBER = month_number0 AND DATE_DAYalias0.YEAR = year0 AND DAYSalias0.DAY_NAME = DATE_DAYalias0.DAY_NAME AND FLIGHTalias0.FLIGHT_DAYS = DA

In [None]:
from openai import OpenAI
import time

# The OpenRouter API key is read from the environment instead of being embedded
# in the notebook: a key stored in source is exposed to everyone who can read
# the file or its history. Set OPENROUTER_API_KEY before running this cell; a
# missing variable raises KeyError here rather than failing later on a request.
# NOTE: the key that used to be hard-coded at this spot should be treated as
# leaked and revoked.
client = OpenAI(
  base_url="https://openrouter.ai/api/v1",
  api_key=os.environ["OPENROUTER_API_KEY"],
)

# Convenience aliases for the question-based train and test partitions.
train_data, test_data = question_split["train"], question_split["test"]

# Show how many entries the query-based test partition contains.
print(len(query_split["test"]))

def get_example_string(example):
  """Format one dataset entry as a few-shot "Question / Answer" snippet.

  Args:
    example: entry dict with the keys "text" (question with real values),
      "template_sql" (SQL containing placeholders) and "variables"
      (placeholder name -> value).

  Returns:
    A multi-line string holding the question and its SQL answer. The answer is
    the template with its placeholders replaced by the entry's values, so it
    matches the question text and the form run_model compares against; the raw
    template was previously emitted by mistake, leaving placeholders such as
    city_name0 in the examples.
  """
  filled_sql = substitute_variables(example["template_sql"], example["variables"])

  return f'''
    Example:
      Question: {example["text"]}
      Answer: {filled_sql}
  '''

def get_model_response(question, example_questions):
    """Ask the chat model for a single-line SQL query answering `question`.

    Args:
        question: natural-language question to translate to SQL.
        example_questions: iterable of already formatted few-shot example
            strings; each one is rendered as a "- " bullet. May be empty.

    Returns:
        The object returned by client.chat.completions.create.

    The user message lists the schema, then the examples, then the question,
    which is the order the system prompt announces.
    """
    example_questions_str = '\n'.join(f"- {q}" for q in example_questions)

    return client.chat.completions.create(
        model="meta-llama/llama-3.2-3b-instruct:free",
        messages=
          [
              {
                  "role": "system",
                  "content": "I would like you to generate the most efficacious SQL query based on my natural language input. Try your best to replicate the style of the examples provided (if any are provided). Please do not return anything but a single line SQL query. You are going to receive the schema for the database, then some examples, then finally the question. Ensure that you only return a single SQL query in a single line and nothing else."
              },

              {
                  "role": "user",
                  "content":
                      f"""
                          The schema for the database is: {schema}

                          The example questions are: {example_questions_str}

                          The question you need to answer is: {question}
                      """
              }
          ]
    )

def substitute_variables(raw_text, variables):
    """Replace whole-word occurrences of placeholder names with their values.

    Args:
        raw_text: text containing placeholder names.
        variables: mapping of placeholder name -> replacement string.

    Returns:
        raw_text with every occurrence of a key of `variables` that is delimited
        by word boundaries replaced by the mapped value; other text is untouched.
    """
    alternatives = '|'.join(re.escape(name) for name in variables.keys())
    placeholder_re = re.compile(r'\b(' + alternatives + r')\b')

    def _lookup(match):
        # Fall back to the matched text itself when it is not a known key.
        found = match.group(0)
        return variables.get(found, found)

    return placeholder_re.sub(_lookup, raw_text)

def run_model(num_examples, train_data, test_data):
  """Evaluate the model on `test_data` using random few-shot examples.

  For every test entry, `num_examples` entries are drawn at random from
  `train_data`, formatted as examples and sent with the question to the model.
  The reply is compared by exact string match (ignoring leading and trailing
  whitespace) with the entry's SQL template after variable substitution.
  Progress, the number of correct answers and the accuracy are printed.

  Args:
    num_examples: number of few-shot examples per question (0 for zero-shot).
    train_data: list of entries the examples are sampled from.
    test_data: list of entries to evaluate.
  """
  # Honour the requested number of examples (it used to be ignored in favour of
  # a fixed 40) and never ask for more than exist: random.sample raises
  # ValueError when the sample is larger than the population.
  shots = min(num_examples, len(train_data))

  total = 0
  correct = 0
  for test_example in test_data:
      random_examples = random.sample(train_data, shots)
      completion = get_model_response(
          test_example["text"],
          [get_example_string(example) for example in random_examples],
      )
      # NOTE(review): message.content is presumably optional in the client's
      # response type; an absent reply is scored as an empty (wrong) answer
      # instead of crashing the run.
      model_response = (completion.choices[0].message.content or "").strip()
      actual = substitute_variables(test_example["template_sql"], test_example["variables"])

      print(model_response)
      print(actual)

      # Surrounding whitespace (e.g. a trailing newline) is not a real difference.
      if actual.strip() == model_response:
        correct += 1
      total += 1

      print('--- PROGRESS ---')
      print(total/len(test_data) * 100)
      # Pause between requests; presumably to stay under the API rate limit.
      time.sleep(3)

  print("--- CORRECT ---")
  print(correct)

  print("--- ACCURACY ---")
  # Guard against an empty test set, which would otherwise divide by zero.
  accuracy = correct / total if total else 0.0
  print(f"{accuracy:.5f}")

# Run the evaluation passing 0, 5 and 40 as num_examples: first on the
# query-based split, then on the question-based split.
for _split in (query_split, question_split):
    for _num_examples in (0, 5, 40):
        run_model(_num_examples, _split["train"], _split["test"])

SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , FLIGHT AS FLIGHTalias0 WHERE CITYalias0.CITY_CODE = AIRPORT_SERVICEalias0.CITY_CODE AND CITYalias0.CITY_NAME = "MKE" AND CITYalias1.CITY_CODE = AIRPORT_SERVICEalias1.CITY_CODE AND CITYalias1.CITY_NAME = "various"
SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "MKE" AND FLIGHTalias0.TO_AIRPORT = AIRPORTalias0.AIRPORT_CODE ;
--- PROGRESS ---
0.22371364653243847
SELECT DISTINCT FLIGHT.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , FLIGHT AS FLIGHTalias0 WHERE ( CITYalias0.CITY_CODE = AIRPORT_SERVICEalias0.CITY_CODE AND CITYalias0.CITY_NAME = "city_name1" AND CITYalias1.CITY_CODE = AIRPORT_SERVICEalias1.CITY_CODE AND CITYalias1.CITY_NAME = "city_

KeyboardInterrupt: 