<a href="https://colab.research.google.com/github/YichengShen/cis5220-project/blob/main/gpt3.5_turbo_with_db.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

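# GPT-3.5-turbo text-to-SQL predictions on Spider

Generates SQL for the Spider dev set with `gpt-3.5-turbo` through `gpt_index`'s `GPTSQLStructStoreIndex` (the bare question is sent, without an appended schema hint) and writes the predictions to `dev_pred.sql` on the shared Drive.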

In [None]:
# Install the LLM / indexing libraries used below.
# NOTE(review): versions are unpinned, yet later cells rely on older APIs
# (`gpt_index.GPTSQLStructStoreIndex`, `langchain.schema.BaseLanguageModel`)
# that newer releases may have renamed or removed. Pin the versions this
# notebook was run with (TODO confirm exact versions).
!pip install langchain
!pip install gpt_index

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Mount Google Drive: the Spider zip is copied from, and the predictions are
# written to, paths under /content/drive/Shareddrives/CIS 522.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import argparse
import getpass
import json
import logging
import os
import re
import shutil
import zipfile
from typing import Optional

from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import BaseLanguageModel
from sqlalchemy import create_engine, text
from tqdm import tqdm

from gpt_index import GPTSQLStructStoreIndex, LLMPredictor, SQLDatabase, ServiceContext

In [None]:
# The OpenAI key is read from the environment (e.g. set beforehand with `%env`)
# instead of being hardcoded in the notebook. We prompt only when it is missing,
# so an already-configured key is never overwritten with an empty string.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

In [None]:
# Change this path to where you store spider.zip in your Drive
dataset_zip_path_in_drive = "/content/drive/Shareddrives/CIS 522/spider.zip"
dataset_zip_path_in_runtime = "/content/data/spider.zip"

# Create the destination folder if it does not exist. It is derived from the
# runtime path itself (rather than a cwd-relative `mkdir -p data`), so the copy
# below works regardless of the current working directory.
os.makedirs(os.path.dirname(dataset_zip_path_in_runtime), exist_ok=True)

shutil.copy(dataset_zip_path_in_drive, dataset_zip_path_in_runtime)

'/content/data/spider.zip'

In [None]:
# Extract spider.zip into the folder that holds it (so the data ends up under
# /content/data/spider/), overwriting files from any previous extraction.
# Reuses the runtime path variable instead of repeating the literal path, and
# needs no external `unzip` binary.
with zipfile.ZipFile(dataset_zip_path_in_runtime) as spider_zip:
    spider_zip.extractall(os.path.dirname(dataset_zip_path_in_runtime))

In [None]:
# Raise the root logger's threshold to WARNING; loggers that do not set their
# own level inherit it, so their INFO/DEBUG records are dropped.
# `logging.getLogger()` is the root logger on every Python 3 version, whereas
# the name "root" only resolves to the root logger from Python 3.9 onward.
logging.getLogger().setLevel(logging.WARNING)

In [None]:
_spaces = re.compile(r"\s+")
_newlines = re.compile(r"\n+")

In [None]:
# Spider schema file(s): either a single path, or a list/tuple of paths whose
# entries are concatenated into `table_data`.
table_paths = "/content/data/spider/tables.json"

# Wrap a lone path so the loop below handles one or many files uniformly
# (a list/tuple of paths is left as-is instead of being wrapped again).
if isinstance(table_paths, (str, os.PathLike)):
    table_paths = (table_paths,)

table_data = []
for table_path in table_paths:
    print(f"Loading data from {table_path}")
    with open(table_path) as inf:
        # Accumulate rather than reassign, so that with several files the
        # earlier files' schemas are kept. Each file is assumed to hold a JSON
        # list of schema entries (the shape `format_dict` is applied to).
        table_data.extend(json.load(inf))

Loading data from /content/data/spider/tables.json


In [None]:
def format_dict(input_dict):
    """Summarize one Spider schema entry as a single line of text.

    Args:
        input_dict: one entry of Spider's ``tables.json``. Reads ``db_id``,
            ``table_names`` and ``column_names``, whose items are indexed as
            ``[table_index, column_name]``.

    Returns:
        A one-item dict ``{db_id: "table_a : col_1, col_2 | table_b : col_3"}``.
        Tables keep their ``table_names`` order and columns their
        ``column_names`` order. Spaces become underscores in table and column
        names alike, so multi-word names read consistently as identifiers.
        Columns whose table index matches no table (e.g. ``-1``) are omitted;
        a table with no columns is rendered as ``"name : "``.
    """
    # Bucket columns by table index in a single pass instead of rescanning the
    # whole column list once per table.
    columns_by_table = {}
    for column in input_dict["column_names"]:
        columns_by_table.setdefault(column[0], []).append(column[1].replace(" ", "_"))

    formatted_tables = []
    for table_index, table_name in enumerate(input_dict["table_names"]):
        columns = ", ".join(columns_by_table.get(table_index, []))
        formatted_tables.append(f"{table_name.replace(' ', '_')} : {columns}")

    return {input_dict["db_id"]: " | ".join(formatted_tables)}

# One {db_id: schema summary} dict per schema entry, then folded into a single
# lookup keyed by db_id (on a duplicate db_id the later entry wins).
formatted_table_data = list(map(format_dict, table_data))
merged_formatted_table_data = {}
for db_summary in formatted_table_data:
    merged_formatted_table_data.update(db_summary)

In [None]:
def _generate_sql(
    llama_index: GPTSQLStructStoreIndex,
    nl_query_text: str,
) -> str:
    """Ask ``llama_index`` to translate ``nl_query_text`` into SQL.

    Returns:
        The SQL found in ``response.extra_info["sql_query"]``, flattened onto a
        single line (every whitespace run, newlines included, becomes one
        space) and stripped.

    Raises:
        RuntimeError: if the response carries no ``sql_query`` (or it is None).
    """
    response = llama_index.query(nl_query_text)
    extra_info = response.extra_info
    if extra_info is not None and "sql_query" in extra_info:
        sql = extra_info["sql_query"]
    else:
        sql = None
    if sql is None:
        raise RuntimeError("No SQL query generated.")
    # A single \s+ pass already covers newlines, so no separate \n+ pass is needed.
    return _spaces.sub(" ", sql).strip()

In [None]:
def generate_sql(
    llama_indexes: dict,
    examples: list,
    output_file: str,
    schema_hints: Optional[dict] = None,
) -> None:
    """Predict SQL for each example and write the predictions to ``output_file``.

    Args:
        llama_indexes: maps a ``db_id`` to the index used to answer questions
            about that database.
        examples: examples providing ``db_id`` and ``question`` keys.
        output_file: destination path; it is overwritten. One line is written
            per example, in order: the generated SQL, or ``ERROR`` when
            generation failed for that example.
        schema_hints: optional ``{db_id: schema text}`` mapping (for instance
            ``merged_formatted_table_data``). When given, the text sent to the
            index is ``question + " | " + schema_hints[db_id]``; by default
            only the bare question is sent.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        for example in tqdm(examples, desc=f"Generating {output_file}"):
            db_name = example["db_id"]
            nl_query_text = example["question"]
            try:
                if schema_hints is not None:
                    nl_query_text += " | " + schema_hints[db_name]
                sql_query = _generate_sql(llama_indexes[db_name], nl_query_text)
            except Exception as e:
                # Best effort: report the failure and keep going so that one bad
                # example does not abort the whole run. tqdm.write keeps the
                # progress bar intact, unlike print.
                tqdm.write(
                    f"Failed to generate SQL query for question: "
                    f"{example['question']} on database: {example['db_id']}."
                )
                tqdm.write(str(e))
                sql_query = "ERROR"
            f.write(sql_query + "\n")

In [None]:
# Run configuration (set here in place of command-line arguments).
model_choice = "gpt-3.5-turbo"  # "gpt-3.5-turbo"/"gpt-4" use ChatOpenAI; any other name uses OpenAI
input_path = "/content/data/spider"  # unzipped Spider dataset
# NOTE(review): this folder name hardcodes "GPT3.5"; change it together with
# `model_choice`, otherwise another model's run overwrites these predictions.
output_path = "/content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name"

In [None]:
# Create the prediction folder (and any missing parents). exist_ok avoids the
# race between a separate existence check and the creation, and still raises
# if output_path exists but is not a directory.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Read the three Spider splits (train_spider, train_others, dev) from input_path.
split_files = ("train_spider.json", "train_others.json", "dev.json")
loaded_splits = []
for split_file in split_files:
    with open(os.path.join(input_path, split_file), "r") as f:
        loaded_splits.append(json.load(f))
train_spider, train_others, dev = loaded_splits

In [None]:
# Create one SQLDatabase (paired with its engine) per distinct db_id, in the
# order the db_ids first appear across all splits.
databases = {}
for example in train_spider + train_others + dev:
    db_name = example["db_id"]
    if db_name in databases:
        continue
    db_path = os.path.join(input_path, "database", db_name, db_name + ".sqlite")
    # SQLite silently creates an empty database file for a missing path, which
    # would only surface later as a database with no tables; fail fast instead.
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"Spider database not found: {db_path}")
    engine = create_engine("sqlite:///" + db_path)
    databases[db_name] = (SQLDatabase(engine=engine), engine)

  self._metadata.reflect(
  self._metadata.reflect(
  self.metadata_obj.reflect()
  self.metadata_obj.reflect()
  self._metadata.reflect(
  self.metadata_obj.reflect()
  self._metadata.reflect(
  self.metadata_obj.reflect()
  self._metadata.reflect(
  self.metadata_obj.reflect()


In [None]:
# Create the LlamaIndexes for all databases.
# "gpt-3.5-turbo"/"gpt-4" go through the chat wrapper; other model names
# through the completion-style OpenAI LLM.
if model_choice in ["gpt-3.5-turbo", "gpt-4"]:
    llm: BaseLanguageModel = ChatOpenAI(model=model_choice, temperature=0)
else:
    llm = OpenAI(model=model_choice, temperature=0)
llm_predictor = LLMPredictor(llm=llm)
# The service context depends only on the shared predictor, so build it once
# rather than once per database.
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
llm_indexes = {}
for db_name, (db, engine) in databases.items():
    # GPTSQLStructStoreIndex needs a table name, but it can use any table in
    # the database, so take the first table listed in sqlite_master.
    # The connection is closed deterministically by the `with` block instead of
    # staying open until the partially-read result is garbage-collected.
    with engine.connect() as connection:
        first_table = connection.execute(
            text("select name from sqlite_master where type = 'table'")
        ).fetchone()
    if first_table is None:
        raise RuntimeError(f"Database {db_name!r} contains no tables.")
    table_name = first_table[0]
    llm_indexes[db_name] = GPTSQLStructStoreIndex.from_documents(
        documents=[],
        service_context=service_context,
        sql_database=db,
        table_name=table_name,
    )

  for column in self._inspector.get_columns(table_name):
  for column in self._inspector.get_columns(table_name):
  for foreign_key in self._inspector.get_foreign_keys(table_name):
  for foreign_key in self._inspector.get_foreign_keys(table_name):
  for foreign_key in self._inspector.get_foreign_keys(table_name):


In [None]:
# Peek at the gold SQL of the first three dev examples.
sample_dev = dev[:3]
query = [sample["query"] for sample in sample_dev]
query

['SELECT count(*) FROM singer',
 'SELECT count(*) FROM singer',
 'SELECT name ,  country ,  age FROM singer ORDER BY age DESC']

In [None]:
# Show the one-line schema summary produced for a sample database.
example_db_id = "perpetrator"
merged_formatted_table_data[example_db_id]

'perpetrator : perpetrator_id, people_id, date, year, location, country, killed, injured | people : people_id, name, height, weight, home_town'

In [None]:
# Generate SQL predictions. The dev split is always predicted (dev_pred.sql);
# set `predict_train = True` to also write train_pred.sql for
# train_spider + train_others.
predict_train = False

if predict_train:
    generate_sql(
        llama_indexes=llm_indexes,
        examples=train_spider + train_others,
        output_file=os.path.join(output_path, "train_pred.sql"),
    )
generate_sql(
    llama_indexes=llm_indexes,
    examples=dev,
    output_file=os.path.join(output_path, "dev_pred.sql"),
)

Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:   9%|▉         | 91/1034 [02:27<24:14,  1.54s/it]

Failed to generate SQL query for question: For each continent, list its id, name, and how many countries it has? on database: car_1.
(sqlite3.OperationalError) no such column: c.ContId
[SQL: SELECT c.ContId, c.CountryName, COUNT(*) AS num_countries
FROM countries c
JOIN continents con ON c.Continent = con.ContId
GROUP BY c.Continent
ORDER BY num_countries DESC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  10%|▉         | 100/1034 [02:41<25:42,  1.65s/it]

Failed to generate SQL query for question: Find the name of the makers that produced some cars in the year of 1970? on database: car_1.
(sqlite3.OperationalError) no such column: car_names.MakeId
[SQL: SELECT DISTINCT car_makers.Maker
FROM car_makers
JOIN cars_data ON car_makers.Id = car_names.MakeId
WHERE cars_data.Year = 1970
ORDER BY car_makers.Maker ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  10%|▉         | 101/1034 [02:43<26:51,  1.73s/it]

Failed to generate SQL query for question: What is the name of the different car makers who produced a car in 1970? on database: car_1.
(sqlite3.OperationalError) no such column: car_names.MakeId
[SQL: SELECT DISTINCT car_makers.Maker
FROM car_makers
JOIN cars_data ON car_makers.Id = car_names.MakeId
WHERE cars_data.Year = 1970
ORDER BY car_makers.Maker ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  10%|█         | 104/1034 [02:51<34:08,  2.20s/it]

Failed to generate SQL query for question: Which distinct car models are the produced after 1980? on database: car_1.
(sqlite3.OperationalError) ambiguous column name: Model
[SQL: SELECT DISTINCT Model
FROM car_names
JOIN model_list ON car_names.Model = model_list.Model
JOIN cars_data ON car_names.MakeId = cars_data.Id
WHERE Year > 1980
ORDER BY Model ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  17%|█▋        | 179/1034 [05:20<47:16,  3.32s/it]

Failed to generate SQL query for question: What are the ids and names of all countries that either have more than 3 car makers or produce fiat model ? on database: car_1.
(sqlite3.OperationalError) no such column: continents.Country
[SQL: SELECT CountryId, CountryName FROM countries
WHERE CountryId IN (
  SELECT Country FROM car_makers
  GROUP BY Country
  HAVING COUNT(DISTINCT Maker) > 3
)
OR CountryId IN (
  SELECT DISTINCT continents.Country FROM continents
  JOIN countries ON continents.ContId = countries.Continent
  JOIN car_makers ON countries.CountryId = car_makers.Country
  JOIN model_list ON car_makers.Id = model_list.Maker
  WHERE model_list.Model = 'fiat'
)
ORDER BY CountryName]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  21%|██        | 217/1034 [05:58<21:45,  1.60s/it]

Failed to generate SQL query for question: Count the number of United Airlines flights arriving in ASY Airport. on database: flight_2.
(sqlite3.OperationalError) no such column: airlines.Airline
[SQL: SELECT COUNT(*) FROM flights 
JOIN airports ON flights.DestAirport = airports.AirportCode 
WHERE airlines.Airline = 'United Airlines' AND airports.AirportCode = 'ASY']
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  23%|██▎       | 235/1034 [06:29<21:32,  1.62s/it]

Failed to generate SQL query for question: Which airlines have a flight with source airport AHD? on database: flight_2.
(sqlite3.OperationalError) ambiguous column name: Airline
[SQL: SELECT Airline, Abbreviation FROM airlines
JOIN flights ON airlines.uid = flights.Airline
WHERE flights.SourceAirport = 'AHD'
ORDER BY Airline ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  24%|██▎       | 244/1034 [06:53<32:29,  2.47s/it]

Failed to generate SQL query for question: Find all airlines that have fewer than 200 flights. on database: flight_2.
(sqlite3.OperationalError) ambiguous column name: Airline
[SQL: SELECT Airline, COUNT(*) as num_flights
FROM flights
JOIN airlines ON flights.Airline = airlines.uid
GROUP BY Airline
HAVING num_flights < 200
ORDER BY num_flights ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  24%|██▎       | 245/1034 [06:55<31:57,  2.43s/it]

Failed to generate SQL query for question: Which airlines have less than 200 flights? on database: flight_2.
(sqlite3.OperationalError) ambiguous column name: Airline
[SQL: SELECT Airline, COUNT(*) as num_flights
FROM flights
JOIN airlines ON flights.Airline = airlines.uid
GROUP BY Airline
HAVING num_flights < 200
ORDER BY num_flights ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  44%|████▍     | 456/1034 [11:25<12:02,  1.25s/it]

Failed to generate SQL query for question: List the first and last name of all players in the order of birth date. on database: wta_1.
(sqlite3.OperationalError) Could not decode to UTF-8 column 'last_name' with text 'Treyes Albarrac��N'
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  44%|████▍     | 457/1034 [11:26<11:32,  1.20s/it]

Failed to generate SQL query for question: What are the full names of all players, sorted by birth date? on database: wta_1.
(sqlite3.OperationalError) Could not decode to UTF-8 column 'full_name' with text 'Joselyn Margarita Treyes Albarrac��N'
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  50%|█████     | 517/1034 [12:54<13:14,  1.54s/it]

Failed to generate SQL query for question: What is the name and id of the department with the most number of degrees ? on database: student_transcripts_tracking.
(sqlite3.OperationalError) ambiguous column name: department_id
[SQL: SELECT department_id, department_name, COUNT(degree_program_id) AS num_degrees
FROM Degree_Programs
JOIN Departments ON Degree_Programs.department_id = Departments.department_id
GROUP BY department_id
ORDER BY num_degrees DESC
LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  52%|█████▏    | 537/1034 [13:25<18:00,  2.17s/it]

Failed to generate SQL query for question: What are the first, middle, and last names for everybody enrolled in a Bachelors program? on database: student_transcripts_tracking.
(sqlite3.OperationalError) no such column: degree_program_id
[SQL: SELECT first_name, middle_name, last_name 
FROM Students 
WHERE degree_program_id IN 
    (SELECT degree_program_id 
     FROM Degree_Programs 
     WHERE degree_summary_name LIKE '%Bachelor%')
ORDER BY last_name ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  56%|█████▌    | 575/1034 [14:36<12:20,  1.61s/it]

Failed to generate SQL query for question: What is the date and id of the transcript with the least number of results? on database: student_transcripts_tracking.
(sqlite3.OperationalError) ambiguous column name: transcript_id
[SQL: SELECT transcript_id, transcript_date, COUNT(*) as num_results
FROM Transcript_Contents
JOIN Transcripts ON Transcript_Contents.transcript_id = Transcripts.transcript_id
GROUP BY Transcript_Contents.transcript_id
ORDER BY num_results ASC
LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  85%|████████▌ | 881/1034 [21:32<03:17,  1.29s/it]

Failed to generate SQL query for question: Show me all grades that have at least 4 students. on database: network_1.
(sqlite3.OperationalError) no such column: student_id
[SQL: SELECT grade, COUNT(student_id) AS num_students
FROM Highschooler
GROUP BY grade
HAVING num_students >= 4
ORDER BY grade ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  87%|████████▋ | 899/1034 [22:04<03:57,  1.76s/it]

Failed to generate SQL query for question: Show the ids of high schoolers who have friends and are also liked by someone else. on database: network_1.
(sqlite3.OperationalError) near "Answer": syntax error
[SQL: SELECT DISTINCT f.student_id
FROM Friend f
JOIN Likes l ON f.student_id = l.liked_id
Answer: The query returns the IDs of high schoolers who have friends and are also liked by someone else.]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql:  87%|████████▋ | 900/1034 [22:06<04:20,  1.94s/it]

Failed to generate SQL query for question: What are the ids of students who both have friends and are liked? on database: network_1.
(sqlite3.OperationalError) 1st ORDER BY term does not match any column in the result set
[SQL: SELECT DISTINCT f.student_id 
FROM Friend f 
JOIN Likes l ON f.student_id = l.student_id 
INTERSECT 
SELECT DISTINCT f.friend_id 
FROM Friend f 
JOIN Likes l ON f.friend_id = l.liked_id
ORDER BY student_id ASC]
(Background on this error at: https://sqlalche.me/e/14/e3q8)


Generating /content/drive/Shareddrives/CIS 522/GPT3.5_pred_no_col_name/dev_pred.sql: 100%|██████████| 1034/1034 [25:26<00:00,  1.48s/it]


Evaluation of the predictions written above (`dev_pred.sql`) is done in other notebooks.