In [None]:
!pip install transformers

In [None]:
# Serialized database schemas of the form
#   "| <db_id> | <table> : <col>, <col>, ... | <table> : ..."
# written one literal per table; adjacent string literals are concatenated.
# NOTE(review): sales_data_schema uses "table:" while the other two use
# "table :" — confirm whether that difference is intentional before changing
# it, since it alters the text fed to the models.
concert_singer_schema = (
    '| concert_singer'
    ' | stadium : stadium_id, location, name, capacity, highest, lowest, average'
    ' | singer : singer_id, name, country, song_name, song_release_year, age, is_male'
    ' | concert : concert_id, concert_name, theme, stadium_id, year'
    ' | singer_in_concert : concert_id, singer_id'
)
sales_data_schema = (
    '| sales_data'
    ' | sales: transaction_id, volume, product_id, product_price_at_the_moment_of_transaction_per_unit, datetime, customer_id'
    ' | customers: customer_id, name, location'
    ' | products: id, name, current_price, amount_in_stock'
)
warehousing_schema = (
    '| warehousing'
    ' | warehouses : warehouse_id, location, max_storage_volume, has_temperature_control'
    ' | packages : id, warehouse_id, package_volume, requires_temperature_control, stored_until_date, owner_id'
    ' | owners : id, name'
)

from enum import Enum

# Difficulty label attached to each benchmark question; `.value` is the
# lowercase string written into the reports. Members iterate in the order
# HARD, MEDIUM, EASY.
query_difficulty = Enum(
    'query_difficulty',
    [('HARD', 'hard'), ('MEDIUM', 'medium'), ('EASY', 'easy')],
)

# Benchmark questions grouped by the database they target, flattened into a
# list of (question, serialized schema, difficulty) tuples in group order.
# NOTE(review): some questions contain grammar slips ("did we made",
# "What amount of ... we sold", a missing "?") — kept verbatim so results stay
# comparable with earlier runs; fix deliberately if re-running the benchmark.
single_question_queries = [
    (question, schema, difficulty)
    for schema, questions in (
        (concert_singer_schema, (
            ('How many singers do we have?', query_difficulty.EASY),
            ('What is the stadium with the highest capacity?', query_difficulty.EASY),
            ('Which stadium can hold the most people?', query_difficulty.EASY),
            ('What was the place of the latest concert?', query_difficulty.MEDIUM),
            ('Where was the latest concert of Taylor Swift?', query_difficulty.MEDIUM),
        )),
        (sales_data_schema, (
            ('How much money did we made in total?', query_difficulty.EASY),
            ('What amount of "Lego set Technic 0254" we sold in total?', query_difficulty.MEDIUM),
            ('Which client paid us the highest amount of money?', query_difficulty.MEDIUM),
            ('What was the most expensive thing we ever sold', query_difficulty.MEDIUM),
            ('What was the most expensive thing we sold in the year 2020?', query_difficulty.MEDIUM),
        )),
        (warehousing_schema, (
            ('What is the most spacious warehouse?', query_difficulty.EASY),
            ('What is the total capacity of our warehouses?', query_difficulty.EASY),
            ('Which packages require temperature control?', query_difficulty.EASY),
            ('Which packages require thermostats?', query_difficulty.EASY),
            ('Which packages are stored in the warehouses equipped with temperature control?', query_difficulty.MEDIUM),
            ('Which packages are stored in the warehouses with thermostats?', query_difficulty.MEDIUM),
            ('Who owns the most packages stored in our warehouses?', query_difficulty.MEDIUM),
            ('Who owns the most packages?', query_difficulty.MEDIUM),
            ('Which packages are stored in violation of temperature conditions?', query_difficulty.MEDIUM),
        )),
    )
    for question, difficulty in questions
]

In [None]:
from transformers import pipeline

In [None]:
# T5-3B fine-tuned on Spider, used as a text-to-SQL generator. The task is
# given explicitly so pipeline() does not have to query the Hub to infer it
# (that lookup fails offline / from a local cache).
pipe = pipeline(task="text2text-generation", model="tscholak/cxmefzzi")

In [None]:
from tqdm import tqdm
import os

# Directory the markdown reports are written to. The default assumes Google
# Drive is mounted at /content/drive (no mount call is visible in this
# notebook — TODO confirm); set TEXT_TO_SQL_REPORT_DIR to write elsewhere.
drive_root = os.environ.get('TEXT_TO_SQL_REPORT_DIR', '/content/drive/MyDrive/Colab/vp')


In [None]:
# Run the T5-3B (Spider) model on every benchmark question and write a markdown
# report; the "Is correct?" field is left blank for manual grading.
with open(f'{drive_root}/text_to_sql_report.md', 'w', encoding='utf-8') as report_file:
    report_file.write("## Model: T5-3B, finetuned on spider\n\n")
    results = []
    for question, schema, difficulty in tqdm(single_question_queries):
        # Model input is "<question> | <db_id> | <table> : <cols> | ..."; the
        # serialized schema already begins with "| ".
        result = pipe(f"{question} {schema}")[0]['generated_text']
        results.append(result)

        report_file.write(f"Question [{difficulty.value}]: {question}\n\n")
        report_file.write(f"Schema: `{schema}`\n\n")
        report_file.write(f"Response: `{result}`\n\n")
        report_file.write("Is correct?: \n\n")
        report_file.write("---\n\n")
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"{question} {result}")

In [None]:
# Next: benchmark GPT (code-davinci-002) on the same questions.

In [None]:
!pip install openai

In [None]:
import openai
from time import sleep

def query_openai(query: str, req_type: str, max_try: int = 3, **kwargs):
    """Send ``query`` to the legacy (openai<1.0) OpenAI API, retrying on errors.

    Args:
        query: Text sent as ``input`` (search/edit) or ``prompt`` (completion).
        req_type: ``'search'`` (Embedding), ``'completion'`` (Completion) or
            ``'edit'`` (Edit).
        max_try: Maximum number of attempts; ``0`` or less makes no request.
        **kwargs: Forwarded verbatim to the underlying ``*.create`` call
            (e.g. ``model``, ``temperature``, ``max_tokens``, ``stop``).

    Returns:
        ``(payload, total_tokens)`` — the embedding vector for ``'search'``,
        the generated text otherwise — or ``(None, None)`` if every attempt
        raised.

    Raises:
        ValueError: If ``req_type`` is unsupported. Checked before any request
            so it is not swallowed and retried by the error-handling loop.
    """
    def _search():
        response = openai.Embedding.create(input=query, **kwargs)
        return response['data'][0]['embedding'], response["usage"]["total_tokens"]

    def _completion():
        response = openai.Completion.create(prompt=query, **kwargs)
        return response["choices"][0]["text"], response["usage"]["total_tokens"]

    def _edit():
        response = openai.Edit.create(input=query, **kwargs)
        return response["choices"][0]["text"], response["usage"]["total_tokens"]

    handlers = {'search': _search, 'completion': _completion, 'edit': _edit}
    if req_type not in handlers:
        raise ValueError(f'Invalid request type: {req_type}')
    send = handlers[req_type]

    for attempt in range(max_try):
        try:
            return send()
        except Exception as e:  # API/network/parsing errors: log and retry
            print("Error in request: ", e)
            # Waiting after the final attempt would only delay the failure.
            if attempt < max_try - 1:
                sleep(3)

    return None, None


In [None]:
from getpass import getpass
import os

# Use OPENAI_API_KEY from the environment when set so "Run All" does not block
# on input; otherwise prompt for the key without echoing it into the notebook.
openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass(prompt="OpenAI API key: ")

In [None]:
# Run code-davinci-002 on every benchmark question and write a markdown report
# in the same layout as the T5 report; "Is correct?" is left for manual grading.
with open(f'{drive_root}/text_to_sql_GPT_report.md', 'w', encoding='utf-8') as report_file:
    report_file.write("## Model: code-davinci-002\n\n")
    results = []
    for question, schema, difficulty in tqdm(single_question_queries):
        # Drop the leading "| <db_id> |" part; only the table list goes in the prompt.
        stripped_schema = "|".join(schema.split("|")[2:]).lstrip()

        prompt = f'''We have a database with the following schema:

{stripped_schema}

As a senior analyst, given the above schema, write a detailed and correct SQLite sql query to answer the analytical question:

"{question}"

```SQL'''

        # The prompt opens a ```SQL fence; stopping on "```" ends the
        # completion at the closing fence.
        response, tokens = query_openai(
            prompt, "completion", max_try=2, model="code-davinci-002",
            temperature=0.0, stop="```", max_tokens=256
        )

        if response is None:
            # query_openai returns (None, None) once all retries fail; record
            # the failure so the remaining questions still run and `results`
            # stays index-aligned with single_question_queries.
            result = '<request failed>'
        else:
            result = response.strip('\n')
        results.append(result)

        report_file.write(f"Question [{difficulty.value}]: {question}\n\n")
        report_file.write(f"Schema: `{schema}`\n\n")
        report_file.write(f"Response: `{result}`\n\n")
        report_file.write("Is correct?: \n\n")
        report_file.write("---\n\n")
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"{question} {result}")

In [None]:
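# Recorded GPT (code-davinci-002) outputs from an earlier run, kept for reference.
# NOTE: 17 entries vs. 19 questions in single_question_queries, so this list does
# not cover every question and may not align index-for-index with it.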
# ['SELECT COUNT(*) FROM singer;',
#  'SELECT name, capacity\nFROM stadium\nWHERE capacity = (SELECT MAX(capacity) FROM stadium)',
#  'SELECT name, capacity\nFROM stadium\nORDER BY capacity DESC\nLIMIT 1',
#  'SELECT stadium.name\nFROM stadium\nINNER JOIN concert ON stadium.stadium_id = concert.stadium_id\nWHERE concert.year = (SELECT MAX(year) FROM concert)',
#  "SELECT s.name, s.location, c.year\nFROM stadium s\nJOIN concert c ON s.stadium_id = c.stadium_id\nJOIN singer_in_concert sc ON c.concert_id = sc.concert_id\nJOIN singer si ON sc.singer_id = si.singer_id\nWHERE si.name = 'Taylor Swift'\nORDER BY c.year DESC\nLIMIT 1",
#  'SELECT SUM(volume * product_price_at_the_moment_of_transaction_per_unit) AS total_revenue\nFROM sales;',
#  'SELECT SUM(volume)\nFROM sales\nWHERE product_id = (SELECT id FROM products WHERE name = "Lego set Technic 0254")',
#  'SELECT customers.name, SUM(sales.volume * sales.product_price_at_the_moment_of_transaction_per_unit) AS total_amount\nFROM sales\nJOIN customers ON customers.customer_id = sales.customer_id\nGROUP BY customers.name\nORDER BY total_amount DESC\nLIMIT 1;',
#  'SELECT\n  product_id,\n  MAX(product_price_at_the_moment_of_transaction_per_unit) AS max_price\nFROM sales\nGROUP BY product_id\nORDER BY max_price DESC\nLIMIT 1;',
#  "SELECT product_id, MAX(product_price_at_the_moment_of_transaction_per_unit)\nFROM sales\nWHERE datetime BETWEEN '2020-01-01' AND '2020-12-31'\nGROUP BY product_id",
#  'SELECT warehouses.warehouse_id, warehouses.location, warehouses.max_storage_volume, warehouses.has_temperature_control, SUM(packages.package_volume) AS total_volume\nFROM warehouses\nLEFT JOIN packages ON warehouses.warehouse_id = packages.warehouse_id\nGROUP BY warehouses.warehouse_id\nORDER BY total_volume DESC\nLIMIT 1',
#  'SELECT SUM(max_storage_volume) FROM warehouses;',
#  'SELECT packages.id, packages.warehouse_id, packages.package_volume, packages.requires_temperature_control, packages.stored_until_date, packages.owner_id\nFROM packages\nINNER JOIN warehouses ON packages.warehouse_id = warehouses.warehouse_id\nWHERE packages.requires_temperature_control = 1 AND warehouses.has_temperature_control = 1',
#  'SELECT * FROM packages WHERE requires_temperature_control = 1;',
#  'SELECT packages.id, packages.warehouse_id, packages.package_volume, packages.requires_temperature_control, packages.stored_until_date, packages.owner_id\nFROM packages\nINNER JOIN warehouses ON packages.warehouse_id = warehouses.warehouse_id\nWHERE warehouses.has_temperature_control = 1',
#  'SELECT packages.id, packages.warehouse_id, packages.package_volume, packages.requires_temperature_control, packages.stored_until_date, packages.owner_id\nFROM packages\nINNER JOIN warehouses ON packages.warehouse_id = warehouses.warehouse_id\nWHERE warehouses.has_temperature_control = 1',
#  'SELECT owners.name, COUNT(packages.id) AS package_count\nFROM packages\nJOIN owners ON packages.owner_id = owners.id\nGROUP BY owners.name\nORDER BY package_count DESC\nLIMIT 1;']