# Advanced Prompting for Text-to-SQL: DIN-SQL
Use of advanced prompting techniques to convert a natural language question to SQL

---

## Suggested SageMaker Environment
SageMaker Image: sagemaker-distribution-cpu

Kernel: Python 3

Instance Type: ml.m5.large

---

## Contents

1. [Install Dependencies](#step-1-install-dependencies)
1. [Set up Athena Connection](#step-2-set-up-connection-to-the-tpc-ds-data-set-in-athena)
1. [Schema Linking](#step-3-determine-schema-links)
1. [Classify Query Complexity](#step-4-classify-sql-complexity)
1. [Generate SQL Query](#step-5-generate-sql-query)
1. [Execute SQL Query](#step-6-execute-query)
1. [Validate Results](#step-7-validate-results)
1. [Self-Correction](#step-8-self-correction)
1. [Experiment](#step-9-experiment)
1. [Citation](#citation)

---

## Objective
This notebook provides code snippets for implementing one approach to converting a natural language question into a SQL query that answers it.

---

## The Approach to the Text-to-SQL Problem
We'll implement the DIN-SQL prompting strategy to break a question down into smaller parts, get an understanding of the query complexity, and ultimately create a valid SQL statement. As shown below, this process consists of four main prompting steps:

1. Schema Linking
2. Classification and decomposition
3. SQL code generation
4. Self-correction

For a deeper dive into the methodology and findings about this approach, please read the full paper here: https://arxiv.org/pdf/2304.11015.pdf

![DIN-SQL methodology](content/din_sql_methodology.png)
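The following minimal sketch chains the four steps in plain Python to show how data flows between them. The function name `din_sql_pipeline` and the toy XML-tagged prompts are hypothetical stand-ins for the library's Jinja templates. `llm` can be any callable that maps a prompt string to a completion string.

```python
from typing import Callable, Dict


def din_sql_pipeline(question: str, llm: Callable[[str], str]) -> Dict[str, str]:
    """Chain the four DIN-SQL steps; each step's output feeds the next prompt.

    The prompts here are hypothetical placeholders, not the library's templates.
    """
    # 1. Schema linking: which tables, columns and joins does the question need?
    links = llm(f"<task>schema_linking</task><question>{question}</question>")
    # 2. Classification: EASY, NON-NESTED or NESTED, given the schema links
    label = llm(f"<task>classify</task><question>{question}</question>"
                f"<links>{links}</links>").strip()
    # 3. Generation: a class-specific prompt produces a draft query
    draft = llm(f"<task>generate</task><class>{label}</class>"
                f"<question>{question}</question><links>{links}</links>")
    # 4. Self-correction: the model reviews and repairs its own draft
    final = llm(f"<task>fix</task><question>{question}</question><sql>{draft}</sql>")
    return {"links": links, "label": label, "draft": draft, "sql": final}
```

Because the model is passed in as a callable, the control flow is easy to test. A fake `llm` that returns canned strings is enough to check that each step's output reaches the next prompt.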

### Tools
SQLAlchemy, Anthropic, Amazon Bedrock SDK (Boto3), PyAthena, Jinja2

---

### Step 1: Install Dependencies

Here we will install all the required dependencies to run this notebook. **You can ignore the following errors** that may arise due to dependency conflicts for libraries we won't be using in this module:
```
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dash 2.14.1 requires dash-core-components==2.0.0, which is not installed.
dash 2.14.1 requires dash-html-components==2.0.0, which is not installed.
dash 2.14.1 requires dash-table==5.0.0, which is not installed.
jupyter-ai 2.5.0 requires faiss-cpu, which is not installed.
amazon-sagemaker-jupyter-scheduler 3.0.4 requires pydantic==1.*, but you have pydantic 2.6.0 which is incompatible.
gluonts 0.13.7 requires pydantic~=1.7, but you have pydantic 2.6.0 which is incompatible.
jupyter-ai 2.5.0 requires pydantic~=1.0, but you have pydantic 2.6.0 which is incompatible.
jupyter-ai-magics 2.5.0 requires pydantic~=1.0, but you have pydantic 2.6.0 which is incompatible.
jupyter-scheduler 2.3.0 requires pydantic~=1.10, but you have pydantic 2.6.0 which is incompatible.
sparkmagic 0.21.0 requires pandas<2.0.0,>=0.17.1, but you have pandas 2.1.2 which is incompatible.
tensorflow 2.12.1 requires typing-extensions<4.6.0,>=3.6.6, but you have typing-extensions 4.9.0 which is incompatible.
```


In [None]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3~=1.34"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet

Import the `din_sql` library to assist with using the prompts written in the paper. Note that we've leveraged Jinja for our prompt templating.

In [None]:
import boto3
import sys

sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

### Step 2: Set up connection to the TPC-DS data set in Athena

Initialize the following variables with details for your account and for how you set up the Athena data source connector for the TPC-DS dataset. You can find these values in the CloudFormation outputs.

In [None]:
ATHENA_RESULTS_S3_LOCATION = "<workshop bucket name>" # available in cloudformation outputs
ATHENA_CATALOG_NAME = "<athena catalog name>" # available in cloudformation outputs
DB_NAME = "tpcds1"

Instantiate the `din_sql` class with the Bedrock model of your choice. The prompts in this module are tailored to work well with Claude v2, so we'll use that model.

In [None]:
din_sql = dsl.DIN_SQL(bedrock_model_id='anthropic.claude-v2')

Create a connection to Athena using the information entered above. We'll use this connection to test the generated SQL. It is also used to augment the prompts in DIN-SQL.

In [None]:
din_sql.athena_connect(catalog_name=ATHENA_CATALOG_NAME, 
               db_name=DB_NAME, 
               s3_prefix=ATHENA_RESULTS_S3_LOCATION)
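We haven't reproduced what `athena_connect` does internally. One plausible approach is to create a SQLAlchemy engine for Athena through PyAthena's `awsathena+rest` dialect. The helper below only builds the connection URL.

Several details are our own assumptions, so check them against the PyAthena documentation for your version:
- the region;
- the `s3://` form of the staging directory;
- that PyAthena forwards the `catalog_name` query parameter to its connection.

```python
from urllib.parse import quote_plus


def athena_sqlalchemy_url(region: str, db_name: str, s3_staging_dir: str,
                          catalog_name: str = "AwsDataCatalog") -> str:
    """Build a PyAthena SQLAlchemy URL (hypothetical helper).

    The empty "user:password" section (":@") means no keys are embedded in the
    URL, so credentials come from the default AWS credential chain.
    """
    return (
        f"awsathena+rest://:@athena.{region}.amazonaws.com:443/{db_name}"
        # Query-string values must be URL-encoded (s3:// contains ':' and '/')
        f"?s3_staging_dir={quote_plus(s3_staging_dir)}"
        f"&catalog_name={quote_plus(catalog_name)}"
    )

# Usage (requires the PyAthena package installed earlier):
# from sqlalchemy import create_engine
# engine = create_engine(athena_sqlalchemy_url(
#     "us-east-1", DB_NAME, "s3://<bucket>/athena-results/", ATHENA_CATALOG_NAME))
```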

### Step 3: Determine Schema Links 

The first step in the DIN-SQL process is schema linking: identifying which tables, columns, and foreign key relationships are required to answer the question. Let's take a look at how the prompt for this task is designed.

In [None]:
!head ../libs/din_sql/prompt_templates/schema_linking_prompt.txt.jinja

In [None]:
# inspect the table and column listing that find_fields returns for the database
fields = din_sql.find_fields(db_name=DB_NAME)
print(fields)

If you take a look at the prompt template, you can see we're using some [Anthropic Prompting best practices](https://docs.anthropic.com/claude/docs/introduction-to-prompt-design) to improve results when working with Claude:
1. [Mark different parts of the prompt](https://docs.anthropic.com/claude/docs/constructing-a-prompt#mark-different-parts-of-the-prompt) using XML tags. In our example, we use XML tags and ```sql fences to organize the output.
2. [We use many examples](https://docs.anthropic.com/claude/docs/constructing-a-prompt#examples-optional). This many-shot technique gives Claude a large number of examples to learn from.
3. [We ask Claude to think step-by-step](https://docs.anthropic.com/claude/docs/ask-claude-to-think-step-by-step).
4. We use [Roleplay Dialogue](https://docs.anthropic.com/claude/docs/roleplay-dialogue) to help Claude act the part of a relational database expert.
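The small sketch below shows all four practices together in a schema-linking prompt, written in plain Python. It is not the library's Jinja template. The function name, the `Table ..., columns = [...]` schema format, and the exact wording are illustrative assumptions.

```python
from typing import Dict, List


def build_schema_linking_prompt(question: str,
                                tables: Dict[str, List[str]],
                                examples: List[str]) -> str:
    """Assemble a prompt using role-play, XML tags, many-shot examples and a
    step-by-step instruction (hypothetical helper, not the library's template)."""
    # Role-play: frame the model as a relational database expert
    parts = ["You are an expert in relational databases and SQL."]
    # XML tags keep the schema clearly separated from the rest of the prompt
    schema = "\n".join(f"Table {name}, columns = [{', '.join(cols)}]"
                       for name, cols in tables.items())
    parts.append(f"<schema>\n{schema}\n</schema>")
    # Many-shot: each worked example is wrapped in its own <example> tags
    for ex in examples:
        parts.append(f"<example>\n{ex}\n</example>")
    # Step-by-step: request reasoning first, then a tagged final answer
    parts.append(f"<question>{question}</question>\n"
                 "Think step by step, then list the schema links inside <links></links> tags.")
    return "\n\n".join(parts)
```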

Let's see what our prompt looks like by passing a question and database name to the `schema_linking_prompt_maker` method. Note the use of tags.

In [None]:
question = "Which customer spent the most money in the web store?"

schema_links_prompt = din_sql.schema_linking_prompt_maker(question, DB_NAME)
print(schema_links_prompt)

Now that we have our schema linking prompt, let's see what Claude comes up with.

In [None]:
schema_links = din_sql.llm_generation(
                    schema_links_prompt,
                    stop_sequences=['</example>']
                    )
print(schema_links)

As you can see, Claude reasoned its way through identifying the foreign key relationships between tables, because we gave it a list of tables and their columns to inspect. Let's use the `<links>` tags to clean up the response and store this list for the next step in the DIN-SQL method.

In [None]:
links = schema_links.split('<links>')[1].split('</links>')[0].replace('\n','')
links
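The `split`-based parsing above raises an `IndexError` if Claude omits the tags. A more defensive sketch (the helper name `extract_tag` is hypothetical) uses a regular expression and returns `None` when the tag is missing.

Like the `split` version, it takes the first matching pair. Unlike it, it strips surrounding whitespace rather than removing every newline. The same helper works for the `<label>` tag in the next step.

```python
import re
from typing import Optional


def extract_tag(text: str, tag: str) -> Optional[str]:
    """Return the stripped contents of the first <tag>...</tag> pair, or None."""
    # re.DOTALL lets '.' match newlines so multi-line contents are captured;
    # the non-greedy '.*?' stops at the first closing tag after the opening tag.
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None
```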

### Step 4: Classify SQL Complexity

The next step in the process is to classify the complexity of the SQL needed to answer the question. Let's take a look at the prompt.

In [None]:
!head ../libs/din_sql/prompt_templates/classification_prompt.txt.jinja

Here we give Claude a decision-making framework, expressed as simple if/then logic, for determining which class of query is required to answer the question.

Feel free to take a closer look at how this prompt uses examples of each class to teach Claude how to make decisions. Once complete, go ahead and send your prompt to Claude to classify the complexity of this query.

In [None]:
classification = din_sql.llm_generation(
    din_sql.classification_prompt_maker(question, DB_NAME, links)
    )
print(classification)

You can see that Claude is using the room we gave it to think through the decision. Let's parse the result using the `<label>` tag and move on to SQL code generation.

In [None]:
predicted_class = classification.split("<label>")[1].split("</label>")[0].strip()
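In the DIN-SQL paper, each class (EASY, NON-NESTED, NESTED) gets its own generation prompt. This notebook, by contrast, always calls `medium_prompt_maker`.

A minimal sketch of routing on `predicted_class` could look like the code below. Only `medium_prompt.txt.jinja` appears in this notebook; the other two file names are hypothetical placeholders.

```python
# Map DIN-SQL class labels to generation templates.
PROMPT_FOR_CLASS = {
    "EASY": "easy_prompt.txt.jinja",          # hypothetical file name
    "NON-NESTED": "medium_prompt.txt.jinja",  # the template used in this notebook
    "NESTED": "hard_prompt.txt.jinja",        # hypothetical file name
}


def template_for(predicted_class: str) -> str:
    """Pick the generation template for a classifier label."""
    # Normalize the label: the model may add whitespace or quotes, or vary case
    label = predicted_class.strip().strip('"').upper()
    # Fall back to the NON-NESTED template if the label is unrecognized
    return PROMPT_FOR_CLASS.get(label, PROMPT_FOR_CLASS["NON-NESTED"])
```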

### Step 5: Generate SQL Query

With our question in hand, the complexity of the required query classified, and our schema links identified, we are ready to generate our SQL statement. Before we do that, let's look at the prompt. Since 'NON-NESTED' classes use the 'medium_prompt' template, we'll take a look at that one.

In [None]:
!head ../libs/din_sql/prompt_templates/medium_prompt.txt.jinja

Because these types of SQL queries require a join, this prompt includes many examples that use a join, so Claude understands how to use one. Let's send our prompt to Claude to see what it generates. Note that we reuse the `</example>` end tag from our prompt as a stop sequence. If Claude follows the format we instructed, it stops generating once it closes its example instead of continuing with more.
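We haven't reproduced `llm_generation`, but a stop sequence is simply part of the request sent to Bedrock. The sketch below builds a request body in Claude v2's text-completion format on Bedrock. The `max_tokens` and `temperature` defaults are our own choices. The `invoke_model` call appears only as a comment because it needs AWS credentials.

```python
import json
from typing import Iterable, Optional


def claude_v2_request_body(prompt: str,
                           stop_sequences: Optional[Iterable[str]] = None,
                           max_tokens: int = 2048,
                           temperature: float = 0.0) -> str:
    """Build a JSON request body for Claude v2 text completions on Bedrock."""
    body = {
        # Claude v2's text-completion API expects Human/Assistant turns
        "prompt": f"\n\nHuman: {prompt}\n\nAssistant:",
        "max_tokens_to_sample": max_tokens,
        "temperature": temperature,
        # Generation halts as soon as the model emits any of these strings
        "stop_sequences": list(stop_sequences or []),
    }
    return json.dumps(body)

# Usage (needs AWS credentials with Bedrock access):
# import boto3
# client = boto3.client("bedrock-runtime")
# resp = client.invoke_model(modelId="anthropic.claude-v2",
#                            body=claude_v2_request_body(prompt, ["</example>"]),
#                            contentType="application/json", accept="application/json")
# completion = json.loads(resp["body"].read())["completion"]
```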

In [None]:
sql_qry = din_sql.llm_generation(
                    prompt=din_sql.medium_prompt_maker(
                        test_sample_text=question, 
                        database=DB_NAME, 
                        schema_links=links,
                        sql_tag_start='```sql',
                        sql_tag_end='```'),
                    stop_sequences=['</example>'])
print(f"{sql_qry}")

You can see how Claude follows the instructions, thinks step-by-step, and encloses the SQL statement in our chosen tags. Let's parse out the last query: per our instructions, the final query in the chain of thought should be the most accurate one.

In [None]:
SQL = sql_qry.split('```sql')[-1].split('```')[0].strip()
print(f"{SQL}")
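Because the prompt asks Claude to reason step-by-step, a response can contain several SQL code fences. A regex-based sketch (the helper name `last_sql_block` is hypothetical) returns the last complete fence, or `None` if there is none.

This differs from the `split` version in one way. If the final fence is cut off without a closing marker, this helper falls back to the previous complete fence instead of returning the truncated text.

```python
import re
from typing import Optional

# Built programmatically so this snippet can itself live inside a markdown fence
FENCE = "`" * 3


def last_sql_block(text: str) -> Optional[str]:
    """Return the contents of the last complete sql-tagged fence, stripped, or None."""
    pattern = re.escape(FENCE) + r"sql\s*(.*?)" + re.escape(FENCE)
    # DOTALL so a block can span lines; non-greedy so each block ends at its own fence
    blocks = re.findall(pattern, text, flags=re.DOTALL | re.IGNORECASE)
    return blocks[-1].strip() if blocks else None
```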

### Step 6: Execute Query

Let's run our query to see whether the results match what we'd expect and whether it actually answers our question. We'll do this by fetching the SQLAlchemy result set and loading it into a pandas DataFrame.

In [None]:
import pandas as pd
result_set = din_sql.query(SQL)
pd.DataFrame(result_set)

### Step 7: Validate Results
Let's make sure this answer is correct by submitting a query we know will list the top 10 customers by web sales. 

In [None]:
validation_query = """
    SELECT "c"."c_customer_sk"
    , "c"."c_first_name"
    , "c"."c_last_name"
    , SUM("ws"."ws_ext_list_price") as total_sales
    FROM "customer" "c" 
    JOIN "web_sales" "ws" 
        ON "ws"."ws_bill_customer_sk" = "c"."c_customer_sk"   
    GROUP BY "c"."c_customer_sk"
    , "c"."c_first_name"
    , "c"."c_last_name"
    ORDER BY total_sales desc
    limit 10
"""
validation_set = din_sql.query(validation_query)
pd.DataFrame(validation_set)
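To see what the reference query computes, the sketch below runs the same join and aggregation against an in-memory SQLite database with a few made-up rows. The data is purely illustrative, not TPC-DS. The query works in three steps:
- attribute each web sale to the customer in `ws_bill_customer_sk`;
- sum the list prices per customer;
- sort the totals in descending order.

```python
import sqlite3

# Toy, made-up rows purely to illustrate what the reference query computes
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE customer (c_customer_sk INTEGER, c_first_name TEXT, c_last_name TEXT);
    CREATE TABLE web_sales (ws_bill_customer_sk INTEGER, ws_ext_list_price REAL);
    INSERT INTO customer VALUES (1, 'Ana', 'Diaz'), (2, 'Bo', 'Chen');
    INSERT INTO web_sales VALUES (1, 10.0), (2, 25.0), (1, 20.0);
""")

# Same shape as the validation query: join, group per customer, rank by total
rows = conn.execute("""
    SELECT c.c_customer_sk, c.c_first_name, c.c_last_name,
           SUM(ws.ws_ext_list_price) AS total_sales
    FROM customer c
    JOIN web_sales ws ON ws.ws_bill_customer_sk = c.c_customer_sk
    GROUP BY c.c_customer_sk, c.c_first_name, c.c_last_name
    ORDER BY total_sales DESC
    LIMIT 10
""").fetchall()
print(rows)  # [(1, 'Ana', 'Diaz', 30.0), (2, 'Bo', 'Chen', 25.0)]
```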

Do you see the same customer SK at the top of the list? Which fields did the generated query use, compared with the correct one?
If the query threw an error, feel free to move on to Self-Correction, where we'll let the LLM correct the query.
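Eyeballing the two tables works for one question. To check several questions, a small comparison helper is handy. This sketch (hypothetical name `same_top_row`) makes three assumptions:
- both result sets are sequences of tuple-like rows, as `din_sql.query` appears to return;
- both are ordered by the ranking metric;
- both have the customer key in the same column.

The generated query may select or order columns differently, in which case adjust `key_index`.

```python
from typing import Any, Sequence


def same_top_row(generated: Sequence[Sequence[Any]],
                 reference: Sequence[Sequence[Any]],
                 key_index: int = 0) -> bool:
    """True if both result sets are non-empty and agree on the key of their first row."""
    # An empty result (e.g. from a failed or over-filtered query) cannot be confirmed
    if not generated or not reference:
        return False
    return generated[0][key_index] == reference[0][key_index]

# Usage: same_top_row(result_set, validation_set)
```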

### Step 8: Self Correction

This is the last step in the process. We make one last check of our SQL code by asking Claude to fix anything that is wrong with the code for the given SQL dialect. Let's take a look at those instructions now to see how that is done. 

In [None]:
!head ../libs/din_sql/prompt_templates/clean_query_prompt.txt.jinja

Now we'll use this template to create a prompt for our query in the "presto" dialect, since Athena's query engine is based on Presto.

In [None]:
revised_sql = din_sql.debugger_generation(
            prompt=din_sql.debugger(question, DB_NAME, SQL, sql_dialect='presto')
            ).replace("\n", " ")
print(f"{revised_sql}")

With our revised SQL returned, let's parse it out of the response using our code tags.

In [None]:
SQL = revised_sql.split('```sql')[1].split('```')[0].strip()
print(f"{SQL}")
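Self-correction can also be wrapped in a retry loop that runs the query and asks for a fix only when execution fails. Note that this swaps in execution feedback (the database error message), which the paper's self-correction prompt does not use. The name `execute_with_self_correction` and the `run`/`fix` callables are hypothetical.

In this notebook, `run` could be `din_sql.query`. `fix` could wrap `din_sql.debugger_generation` plus the parsing from the cell above, ignoring the error message if the prompt has no slot for it.

```python
from typing import Any, Callable, Tuple


def execute_with_self_correction(sql: str,
                                 run: Callable[[str], Any],
                                 fix: Callable[[str, str], str],
                                 max_attempts: int = 3) -> Tuple[str, Any]:
    """Run sql; on failure ask fix(sql, error_message) for a revision and retry.

    Returns the SQL that finally succeeded together with its result.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return sql, run(sql)
        except Exception as exc:  # broad on purpose: any driver error triggers a fix
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last database error
            sql = fix(sql, str(exc))
    raise AssertionError("unreachable")
```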

### Step 9: Experiment

Are the results what you expected? If not, how can you improve the prompting to generalize better?

Below is another run through of the process end-to-end for a different question. 

In [None]:
question = 'What year had the highest catalog sales?'

# get schema links
schema_links_prompt = din_sql.schema_linking_prompt_maker(question, DB_NAME)
schema_links = din_sql.llm_generation(
                    schema_links_prompt,
                    stop_sequences=['</example>']
                    )
print(schema_links)
links = schema_links.split('<links>')[1].split('</links>')[0].replace('\n','')

# classify and decompose
classification = din_sql.llm_generation(
    din_sql.classification_prompt_maker(question, DB_NAME, links)
    )
print(classification)
predicted_class = classification.split("<label>")[1].split("</label>")[0].strip()

# generate SQL
sql_qry = din_sql.llm_generation(
                    prompt=din_sql.medium_prompt_maker(
                        test_sample_text=question, 
                        database=DB_NAME, 
                        schema_links=links,
                        sql_tag_start='```sql',
                        sql_tag_end='```'),
                    stop_sequences=['</example>'])
print(f"{sql_qry}")
# parse the last SQL block from this question's generation step
SQL = sql_qry.split('```sql')[-1].split('```')[0].strip()
print(f"{SQL}")

In [None]:
pd.DataFrame(din_sql.query(SQL))

If the query threw an error, try self-correction once more.

In [None]:
#self correction
revised_sql = din_sql.debugger_generation(
            prompt=din_sql.debugger(question, DB_NAME, SQL, sql_dialect='presto')
            ).replace("\n", " ")
print(f"{revised_sql}")
SQL = revised_sql.split('```sql')[1].split('```')[0].strip()
print(f"{SQL}")

# see results
result_set = pd.DataFrame(din_sql.query(SQL))
result_set

### Citation
```
@article{pourreza2023din,
  title={DIN-SQL: Decomposed In-Context Learning of Text-to-SQL with Self-Correction},
  author={Pourreza, Mohammadreza and Rafiei, Davood},
  journal={arXiv preprint arXiv:2304.11015},
  year={2023}
}
```

Paper: https://arxiv.org/abs/2304.11015

Code: https://github.com/MohammadrezaPourreza/Few-shot-NL2SQL-with-prompting