In [16]:
import lamini
import logging
import random
from typing import AsyncIterator, Iterator, Union
import sqlite3
import copy
from tqdm import tqdm

import pandas as pd
import jsonlines
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_node import GenerationNode
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_pipeline import GenerationPipeline
from util.get_schema import get_schema, get_schema_s
from util.make_llama_3_prompt import make_llama_3_prompt

# Module-level logger used by the SQL validity checks and pipeline stages in this notebook.
logger = logging.getLogger(__name__)
# One sqlite3 connection, reused by every `pd.read_sql` validity check in this notebook.
# NOTE(review): this connection is writable and is later used to execute model-generated SQL;
# consider opening it read-only (e.g. a `file:...?mode=ro` URI) — confirm nothing needs to write.
engine = sqlite3.connect("./nba_roster.db")


class Args:
    """Plain settings holder for this notebook.

    Each constructor argument is stored unchanged on the instance under the
    same name; no validation is performed.
    """

    def __init__(
        self,
        max_examples=100,
        sql_model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        gold_file_name="gold-test-set.jsonl",
        training_file_name="generated_queries.jsonl",
        num_to_generate=10,
    ):
        # File names are bare names; the load/save helpers prefix them with "data/".
        self.gold_file_name = gold_file_name
        self.training_file_name = training_file_name

        # Generation settings.
        self.num_to_generate = num_to_generate
        self.max_examples = max_examples
        self.sql_model_name = sql_model_name

In [17]:
# System prompt: persona, then the table schema, then a lead-in for the worked examples.
system = (
    "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
    "Consider a table called 'nba_roster' with the following schema (columns)\n"
    + get_schema_s()
    + "Consider the following questions, and queries used to answer them:\n"
)


In [18]:
system

'You are an NBA analyst with 15 years of experience writing complex SQL queries.\nConsider a table called \'nba_roster\' with the following schema (columns)\n0|Team|TEXT eg. "Toronto Raptors"\n1|NAME|TEXT eg. "Otto Porter Jr."\n2|Jersey|TEXT eg. "0" and when null has a value "NA"\n3|POS|TEXT eg. "PF"\n4|AGE|INT eg. "22" in years\n5|HT|TEXT eg. `6\' 7"` or `6\' 10"` castable to int\n6|WT|TEXT eg. "232 lbs" \n7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"\n8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"\nConsider the following questions, and queries used to answer them:\n'

In [19]:
# One worked example (a question and the query that answers it), appended to the system prompt.
question = "What is the median weight in the NBA?"
sql = "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"

system += f"Question: {question}\nQuery: {sql}\n"


In [20]:
print(system)

You are an NBA analyst with 15 years of experience writing complex SQL queries.
Consider a table called 'nba_roster' with the following schema (columns)
0|Team|TEXT eg. "Toronto Raptors"
1|NAME|TEXT eg. "Otto Porter Jr."
2|Jersey|TEXT eg. "0" and when null has a value "NA"
3|POS|TEXT eg. "PF"
4|AGE|INT eg. "22" in years
5|HT|TEXT eg. `6' 7"` or `6' 10"` castable to int
6|WT|TEXT eg. "232 lbs" 
7|COLLEGE|TEXT eg. "Michigan" and when null has a value "--"
8|SALARY|TEXT eg. "$9,945,830" and when null has a value "--"
Consider the following questions, and queries used to answer them:
Question: What is the median weight in the NBA?
Query: select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;



In [21]:
# User turn: request two new queries and pin the JSON shape of the answer.
user = (
    "Write two queries that are similar but different to those above.\n"
    "Format the queries as a JSON object, i.e.\n"
    '{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.\n'
)


In [22]:
print(user)

Write two queries that are similar but different to those above.
Format the queries as a JSON object, i.e.
{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.



In [23]:
# Chain-of-thought instruction: ask for the explanation first, then the two complete sqlite queries.
user += "First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;\n"

In [24]:
print(user)

Write two queries that are similar but different to those above.
Format the queries as a JSON object, i.e.
{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.
First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;



In [25]:
# Combine the user turn and the system prompt into a single prompt (user is passed first, then system).
prompt = make_llama_3_prompt(user, system)

In [None]:
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()

True

In [27]:
# Ask the base Llama 3 model for a structured answer: an explanation plus two SQL queries.
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
result = llm.generate(
    prompt,
    output_type={
        "explanation": "str",
        "sql_query_1": "str",
        "sql_query_2": "str",
    },
    max_new_tokens=200,
)
print(result)

{'explanation': 'I decided to write these new queries to provide more insights into the NBA data. The first query calculates the average height of players in the NBA, while the second query finds the team with the highest average salary. These queries are similar to the original query in that they involve extracting and manipulating', 'sql_query_1': "SELECT AVG(CAST(SUBSTR(HT, 1, INSTR(HT,'')-1) AS INTEGER)) AS average_height FROM nba_roster", 'sql_query_2': "SELECT Team, AVG(CAST(SUBSTR(SALARY, 2, LENGTH(SALARY)-2) AS INTEGER)) AS average_salary FROM nba_roster WHERE SALARY!= '--' GROUP BY Team ORDER BY average_salary DESC LIMIT 1"}


In [28]:
def check_sql_query(query):
    """Report whether `query` can be run through `pd.read_sql` on the shared connection.

    Args:
        query: SQL text to try.

    Returns:
        True when `pd.read_sql` completes without raising, False when it raises
        any `Exception` (the error is logged at debug level).
    """
    try:
        # The result frame is discarded; only success or failure matters here.
        pd.read_sql(query, con=engine)
    except Exception as exc:
        logger.debug(f"Error in SQL query: {exc}")
        return False
    else:
        logger.info(f"SQL query {query} is valid")
        return True


In [29]:
check_sql_query(result["sql_query_1"])

True

In [30]:
check_sql_query(result["sql_query_2"])

True

In [31]:
# Wrap it all up together in a class

class ModelStage(GenerationNode):
    """Generation node that asks the model for two new SQL queries per prompt.

    Each incoming prompt object carries example question/query pairs in
    `data.sample`. The node builds a Llama 3 prompt from them, requests a
    structured response, and emits one prompt object for every generated
    query that passes `check_sql_query`.
    """

    # Response fields that hold a generated SQL query, in the order they are emitted.
    _SQL_QUERY_KEYS = ("sql_query_1", "sql_query_2")

    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=300,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """Template the incoming prompts, then delegate to the parent `generate`.

        Args:
            prompt: Stream of prompt objects; it is consumed with `async for`
                inside `add_template`.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            Whatever the parent `generate` returns for the templated prompts.
        """
        prompt = self.add_template(prompt)

        # Positional extras go before the keyword argument; the call is
        # equivalent to the keyword-first spelling but easier to read.
        results = super().generate(
            prompt,
            *args,
            output_type={
                "explanation": "str",
                "sql_query_1": "str",
                "sql_query_2": "str",
            },
            **kwargs,
        )

        return results

    async def add_template(self, prompts):
        """Yield a new prompt object per input, with the Llama 3 prompt text filled in."""
        async for prompt in prompts:
            new_prompt = make_llama_3_prompt(**self.make_prompt(prompt.data))
            yield PromptObject(prompt=new_prompt, data=prompt.data)

    async def process_results(self, results):
        """Yield one prompt object for each generated query that validates.

        Results that are `None` or have no response are skipped. A response
        field that is missing or empty is skipped as well instead of raising
        `KeyError`. Every yielded object carries a deep copy of the source
        data with `generated_sql_query` set to the accepted query.
        """
        async for result in results:
            if result is None or result.response is None:
                continue

            response = result.response

            logger.info("=====================================")
            for position, key in enumerate(self._SQL_QUERY_KEYS, start=1):
                logger.info(f"Generated query {position}: {response.get(key)}")
            logger.info("=====================================")

            for key in self._SQL_QUERY_KEYS:
                query = response.get(key)
                if not query or not self.check_sql_query(query):
                    continue

                # Deep copy so the emitted objects do not share one data instance.
                new_result = PromptObject(prompt="", data=copy.deepcopy(result.data))
                new_result.data.generated_sql_query = query
                yield new_result

    def make_prompt(self, data):
        """Build the `system` and `user` prompt strings for query generation.

        Args:
            data: Object whose `sample` attribute is an iterable of mappings
                with "question" and "sql" entries.

        Returns:
            A dict with the keys "system" and "user".
        """
        system = "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
        system += (
            "Consider a table called 'nba_roster' with the following schema (columns)\n"
        )
        system += get_schema()
        system += "Consider the following questions, and queries used to answer them:\n"
        for example in data.sample:
            system += "Question: " + example["question"] + "\n"
            system += "Query: " + example["sql"] + "\n"

        # Important: generate relevant queries to your reference data
        # Ideally, close to those that are failing so you can show the model examples of how to do it right!
        user = "Write two queries that are similar but different to those above.\n"
        user += "Format the queries as a JSON object, i.e.\n"
        user += '{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.\n'

        # Next, use Chain of Thought (CoT) and prompt-engineering to help with generating SQL queries
        user += "First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;\n"

        return {"system": system, "user": user}

    def check_sql_query(self, query):
        """Return True when `pd.read_sql` can run `query` on the shared connection.

        Any exception raised by `pd.read_sql` is logged at debug level and
        reported as False.
        """
        # NOTE(review): this executes model-generated SQL on the shared connection;
        # confirm the connection cannot modify the database.
        try:
            pd.read_sql(query, con=engine)
        except Exception as e:
            logger.debug(f"Error in SQL query: {e}")
            return False

        logger.info(f"SQL query {query} is valid")

        return True

In [32]:
# Rebuild the system prompt for the reverse task: given a query, write the question.
# NOTE(review): the header lists queries first, but the example below is appended
# question-first, while QuestionStage.make_question_prompt appends query-first —
# confirm which order is intended.
example_question = "What is the median weight in the NBA?"
example_sql = "select CAST(SUBSTR(WT, 1, INSTR(WT,' ')) as INTEGER) as percentile from nba_roster order by percentile limit 1 offset (select count(*) from nba_roster)/2;"

system = (
    "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n"
    + "Consider a table called 'nba_roster' with the following schema (columns)\n"
    + get_schema()
    + "\n"
    + "Queries, and questions that they are used to answer:\n"
    + "Question: " + example_question + "\n"
    + "Query: " + example_sql + "\n"
)

In [33]:
# Take the second query from the structured SQL response produced by the earlier `llm.generate` cell.
generated_sql = result["sql_query_2"]

In [34]:
# User turn for the reverse task: show the generated query and ask for a matching question.
user = (
    "Now consider the following query.\n"
    + "Query: " + generated_sql + "\n"
    + "Write a question that this query could be used to answer.\n"
)


In [35]:
# Pin the JSON shape of the answer, then ask for the explanation before the question.
user += (
    "Format your response as a JSON object, i.e.\n"
    '{ "explanation": str, "question": str }.\n'
    "First write an explanation in about 3-5 sentences, then write a one sentence question.\n"
)


In [36]:
# Store the question response under its own name. `result` still holds the SQL
# response that the `generated_sql` cell indexes with "sql_query_2"; rebinding
# `result` here would make that cell fail when it is re-run.
prompt = make_llama_3_prompt(user, system)
question_result = llm.generate(prompt, output_type={ "explanation": "str", "question" : "str" }, max_new_tokens=200)
print(question_result)

{'explanation': 'This query calculates the average salary for each team in the NBA, excluding teams with unknown salaries. It does this by first removing the dollar sign and any leading or trailing characters from the salary string, then converting the remaining characters to an integer. The results are then grouped by team and ordered in descending order by average salary, with the team having the highest average salary returned first. The LIMIT 1 clause ensures that only the top team is returned', 'question': 'Which team has the highest average salary among all teams in the NBA'}


In [37]:
# Wrap it all up together in a class which generates a question
# given a query

class QuestionStage(GenerationNode):
    """Generation node that asks the model to phrase a question for a generated SQL query.

    The query is read from `data.generated_sql_query` on each prompt object,
    and the example pairs from `data.sample`.
    """

    def __init__(self):
        super().__init__(
            max_new_tokens=150,
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """Delegate to the parent `generate`, requesting an explanation and a question."""
        response_schema = {
            "explanation": "str",
            "question": "str",
        }
        return super().generate(
            prompt,
            *args,
            output_type=response_schema,
            **kwargs,
        )

    def preprocess(self, obj: PromptObject):
        """Fill `obj.prompt` in place with the Llama 3 prompt built from `obj.data`."""
        prompt_parts = self.make_question_prompt(obj.data)
        obj.prompt = make_llama_3_prompt(**prompt_parts)

    def make_question_prompt(self, data):
        """Build the `system` and `user` prompt strings for question generation.

        Args:
            data: Object with a `sample` iterable of mappings holding "sql" and
                "question" entries, and a `generated_sql_query` string.

        Returns:
            A dict with the keys "system" and "user".
        """
        system_parts = [
            "You are an NBA analyst with 15 years of experience writing complex SQL queries.\n",
            "Consider a table called 'nba_roster' with the following schema (columns)\n",
            get_schema() + "\n",
            "Queries, and questions that they are used to answer:\n",
        ]
        for shot in data.sample:
            system_parts.append("Query: " + shot["sql"] + "\n")
            system_parts.append("Question: " + shot["question"] + "\n")

        user_parts = [
            "Now consider the following query.\n",
            "Query: " + data.generated_sql_query + "\n",
            "Write a question that this query could be used to answer.\n",
            # Using Chain of Thought (CoT) again
            # This time you can do it programmatically with function calling, so you can easily extract a question out of the JSON object
            "Format your response as a JSON object, i.e.\n",
            '{ "explanation": str, "question": str }.\n',
            "First write an explanation in about 3-5 sentences, then write a one sentence question.\n",
        ]

        return {"system": "".join(system_parts), "user": "".join(user_parts)}


In [38]:
class QueryGenPipeline(GenerationPipeline):
    """Two-stage pipeline: `ModelStage` output is fed into `QuestionStage`."""

    def __init__(self):
        super().__init__()
        # Stages are created in the order they run.
        self.model_stage = ModelStage()
        self.question_stage = QuestionStage()

    def forward(self, x):
        """Pass `x` through the model stage, then the question stage, and return the result."""
        with_queries = self.model_stage(x)
        return self.question_stage(with_queries)

In [39]:
async def run_query_gen_pipeline(gold_queries):
    """Create a fresh `QueryGenPipeline` and start it on `gold_queries`.

    The value returned by the pipeline's `call` is handed back as-is; nothing
    is awaited inside this coroutine.
    """
    pipeline = QueryGenPipeline()
    return pipeline.call(gold_queries)

In [40]:
# Yield `args.num_to_generate` prompts, each carrying a random sample of 3 examples from the gold dataset

all_examples = []

async def load_gold_queries(args, sample_size=3, seed=42):
    """Load the gold examples and yield prompt objects seeded with random samples.

    The gold file `data/<args.gold_file_name>` is read in full and stored in
    the module-level `all_examples` list. Then `args.num_to_generate` prompt
    objects are yielded, each carrying an `ExampleSample` of `sample_size`
    gold examples.

    Sampling uses a private `random.Random(seed)`, so the process-wide
    `random` state is left untouched and the sample sequence cannot be
    disturbed by other users of the `random` module while this generator is
    suspended between yields. With the default seed the samples match those
    from `random.seed(42)` followed by `random.sample`.

    Args:
        args: Object providing `gold_file_name` and `num_to_generate`.
        sample_size: Number of gold examples per prompt.
        seed: Seed for the private random generator.

    Yields:
        `PromptObject` instances with an empty prompt and an `ExampleSample` as data.

    Raises:
        ValueError: From `sample` when the gold file holds fewer than
            `sample_size` examples.
    """
    global all_examples

    path = f"data/{args.gold_file_name}"

    with jsonlines.open(path) as reader:
        all_examples = list(reader)

    rng = random.Random(seed)

    for i in range(args.num_to_generate):
        example_sample = ExampleSample(rng.sample(all_examples, sample_size), i)
        yield PromptObject(prompt="", data=example_sample)


class ExampleSample:
    """Pairs a sample of examples with the index of the prompt it belongs to."""

    def __init__(self, sample, index):
        # Both values are stored as given; the sample is not copied.
        self.sample, self.index = sample, index

In [43]:
async def save_generation_results(results, args):
    """Write the generated question/query pairs, then the gold examples, as JSON lines.

    The output file `data/<args.training_file_name>` is opened in write mode.
    One record with the keys "question" and "sql" is written per result, and
    after that every entry of the module-level `all_examples` list is written
    unchanged.

    Results that are `None` or have no response are skipped, so a single
    failed generation does not abort the save with a `TypeError` after the
    file has already been truncated. The progress bar is used as a context
    manager so it is closed even when writing fails.

    Args:
        results: Async iterable of result objects exposing `response` and
            `data.generated_sql_query`.
        args: Object providing `training_file_name`.
    """
    path = f"data/{args.training_file_name}"

    with tqdm(desc="Saving results", unit=" results") as pbar:
        with jsonlines.open(path, "w") as writer:

            async for result in results:
                if result is None or result.response is None:
                    continue

                writer.write(
                    {
                        "question": result.response["question"],
                        "sql": result.data.generated_sql_query,
                    }
                )
                pbar.update()

            # Append the gold examples so the training file contains them too.
            for example in all_examples:
                writer.write(example)
                pbar.update()

In [44]:
# End-to-end run: load gold examples, generate new query/question pairs, save them to disk.
# NOTE(review): the recorded output of this cell is
# "Exception: Must use Python 3.10 or greater for this feature" — presumably the kernel
# must run Python >= 3.10; confirm before re-running.
args = Args()
# Async generator; nothing is read from disk until the pipeline starts consuming it.
gold_queries = load_gold_queries(args)
results = await run_query_gen_pipeline(gold_queries)
await save_generation_results(results, args)



Exception: Must use Python 3.10 or greater for this feature