# Preventing SQL Injection

Use prompt engineering to keep an LLM from generating or executing unsafe SQL statements.

---

## Suggested SageMaker Environment
SageMaker Image: sagemaker-distribution-cpu

Kernel: Python 3

Instance Type: ml.m5.large

---

## Contents

1. [Install Dependencies](#step-1-install-dependencies)
1. [Configure Athena and Bedrock Client](#step-2-configure-athena-and-bedrock-client)
1. [Prevent SQL Injection with DELETE request](#step-3-prevent-sql-injection-with-delete-request)
1. [Prevent SQL Injection with UPDATE request](#step-4-prevent-sql-injection-with-update-request)
1. [Apply Read Restrictions](#step-5-apply-read-restrictions)

---

## Objective
This notebook provides code snippets that apply prompt engineering to prevent unsafe SQL statements such as INSERT, UPDATE, and DELETE from being generated and executed.

---

## The Approach to Text-to-SQL Security

SQL Injection is an attack that targets the database layer of an application by injecting rogue SQL statements through user inputs or application parameters. These attacks can lead to unauthorized exposure, modification, or deletion of sensitive data stored in the database.
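
To make the attack concrete, the following is a minimal sketch of a classic injection, assuming a hypothetical two-row `customer` table in an in-memory SQLite database (the notebook itself uses Athena, not SQLite). It contrasts a statement built by string concatenation with one that binds the input as a parameter.

```python
import sqlite3

# In-memory database with a hypothetical two-row table, for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customer (c_customer_id TEXT, c_first_name TEXT)")
conn.executemany(
    "INSERT INTO customer VALUES (?, ?)",
    [("A", "Alice"), ("B", "Bob")],
)

# Attacker-controlled input that tries to widen the WHERE clause.
user_input = "A' OR '1'='1"

# Unsafe: the input is concatenated into the statement and changes its logic.
unsafe_sql = (
    "SELECT c_first_name FROM customer "
    f"WHERE c_customer_id = '{user_input}'"
)
unsafe_rows = conn.execute(unsafe_sql).fetchall()

# Safer: the input is bound as a parameter and treated purely as a value.
safe_rows = conn.execute(
    "SELECT c_first_name FROM customer WHERE c_customer_id = ?",
    (user_input,),
).fetchall()

print(len(unsafe_rows))  # 2: every row is returned
print(len(safe_rows))    # 0: no customer has that literal id
```

In a text-to-SQL application the LLM writes the whole statement, so parameter binding alone does not help; that is why the rest of this notebook constrains what the model is willing to generate.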

Securing generative AI applications requires identifying and preventing unsafe SQL statements, particularly those that perform potentially harmful actions such as INSERT, UPDATE, or DELETE. The focus here is to use prompt engineering to build a framework that identifies such statements and restricts their execution, because they could compromise data integrity or privacy.

For a deeper dive into how prompt injection can lead to SQL injection attacks in text-to-SQL use cases, and into approaches for preventing them, see this paper:

https://arxiv.org/pdf/2308.01990.pdf

# Step 1: Install Dependencies

Here we will install all the required dependencies to run this notebook. 

In [None]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3~=1.34"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install langchain --quiet

In [None]:
import json
import boto3
from botocore.config import Config
import sys
from langchain.prompts import PromptTemplate

sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

# Step 2: Configure Athena and Bedrock Client

### Replace these variables with values from your setup.

In [None]:

ATHENA_RESULTS_S3_LOCATION = "<workshop bucket name>" # available in cloudformation outputs
ATHENA_CATALOG_NAME = "<athena catalog name>" # available in cloudformation outputs
DB_NAME = "tpcds1"

In [None]:
retry_config = Config(retries = {'max_attempts': 100})
bedrock_region = athena_region = boto3.session.Session().region_name
session = boto3.Session(region_name=bedrock_region)
bedrock = session.client('bedrock-runtime', region_name=bedrock_region, config=retry_config)
accept = 'application/json'
content_type = 'application/json'

In [None]:
model_id = 'anthropic.claude-v2'

# dsl was already imported in Step 1, so it is not imported again here
din_sql = dsl.DIN_SQL(bedrock_model_id=model_id)

In [None]:
din_sql.athena_connect(catalog_name=ATHENA_CATALOG_NAME, 
               db_name=DB_NAME, 
               s3_prefix=ATHENA_RESULTS_S3_LOCATION)

In [None]:
schema = {'customer': 'c_customer_sk|c_customer_id|c_current_cdemo_sk|c_current_hdemo_sk|c_current_addr_sk|c_first_shipto_date_sk|c_first_sales_date_sk|c_salutation|c_first_name|c_last_name|c_preferred_cust_flag|c_birth_day|c_birth_month|c_birth_year|c_birth_country|c_login|c_email_address|c_last_review_date_sk|'}
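
The schema above is a single pipe-delimited string with a trailing `|`. As a minimal sketch, the hypothetical helper `parse_schema` below (not part of `din_sql_lib`) splits such a string into a clean list of column names, which can be handy for checking that a generated query only references columns that exist.

```python
def parse_schema(schema: dict) -> dict:
    """Split each pipe-delimited schema string into a list of column names."""
    return {
        table: [col for col in columns.split("|") if col]  # drop empty trailing item
        for table, columns in schema.items()
    }


# Shortened example schema in the same format as the notebook's `schema` dict.
example_schema = {"customer": "c_customer_sk|c_customer_id|c_first_name|"}
columns_by_table = parse_schema(example_schema)
print(columns_by_table["customer"])
```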

# Step 3: Prevent SQL injection with delete request

In [None]:
# A user request to delete all customer records, which could damage the database
query = "delete all the customers"

### Prompt without protection against unsafe SQL operations
In the initial version of the prompt below, the instructions include baseline protections, such as using only the columns from the schema and querying only the columns relevant to the question rather than all of them. However, nothing guards against unsafe operations such as "insert", "delete", or "update", which alter the data in the database.

In [None]:
# Prompt without protection against SQL injection

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question}
        Instructions:
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
          
        <Question>"How many users do we have?"</Question>
        <SQL>SELECT SUM(users) FROM customers</SQL>

        <Question>"How many users do we have for Mobile?"</Question>
        <SQL>SELECT SUM(users) FROM customer WHERE source_medium='Mobile'</SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

### Without protection, the LLM generates a DELETE statement that would delete all the customer records if executed.

In [None]:
# Invoke model to generate response
response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)
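
The same build-body, invoke, and parse sequence is repeated for every prompt in this notebook. A minimal sketch of a wrapper follows; `generate_completion` and `FakeBedrockClient` are hypothetical names introduced here, and the fake client only exists so the helper can be tried without AWS credentials. The request and response fields mirror the ones already used in the cells above.

```python
import io
import json


def generate_completion(client, model_id: str, prompt: str,
                        max_tokens: int = 1500, temperature: float = 0.0) -> str:
    """Send a text-completion prompt and return the 'completion' field."""
    body = json.dumps({
        "prompt": prompt,
        "max_tokens_to_sample": max_tokens,
        "temperature": temperature,
    })
    response = client.invoke_model(
        body=body,
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
    )
    return json.loads(response["body"].read())["completion"]


class FakeBedrockClient:
    """Hypothetical stand-in that returns a canned completion."""

    def invoke_model(self, body, modelId, accept, contentType):
        payload = json.dumps({"completion": " <SQL>SELECT 1</SQL>"})
        return {"body": io.BytesIO(payload.encode("utf-8"))}


completion = generate_completion(
    FakeBedrockClient(), "anthropic.claude-v2", "\n\nHuman: hi\n\nAssistant:"
)
print(completion)
```

In the notebook, the `bedrock` client created in Step 2 would be passed in place of the fake client.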

### The prompt below shows how to avoid generating unsafe SQL such as DELETE
#### The following instruction and example are added to the prompt:


"Reject any question that ask for insert, update, and delete actions"

  "\<Question\>"Delete all the customers"\</Question\>"
  
  "\<SQL\>I don't have permission to generate or execute SQLs which can change data\</SQL\>"


In [None]:
# Prompt with protection against SQL injection

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question}
        Instructions:
           Reject any question that ask for insert, update, and delete actions  
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
          
        <Question>"How many users do we have?"</Question>
        <SQL>SELECT SUM(users) FROM customers</SQL>

        <Question>"How many users do we have for Mobile?"</Question>
        <SQL>SELECT SUM(users) FROM customer WHERE source_medium='Mobile'</SQL>

        <Question>"Delete all the customers"</Question>
        <SQL>I don't have permission to generate or execute SQLs which can change data</SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

### Given the same user request, the improved prompt leads the LLM to generate a different answer that avoids deleting all the customers from the table

In [None]:
# Invoke model to generate response

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)
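
A prompt instruction lowers the chance of unsafe output but does not guarantee it, so it can help to check the generated statement before it is executed. The sketch below is one conservative way to do that; `is_read_only` and the keyword list are assumptions made for illustration, and the check is a coarse text filter, not a SQL parser.

```python
import re

# Keywords that modify data or schema; this list is an assumption, not exhaustive.
FORBIDDEN_KEYWORDS = (
    "insert", "update", "delete", "drop", "alter", "create",
    "truncate", "merge", "grant", "revoke",
)


def is_read_only(sql: str) -> bool:
    """Return True only for a single statement that starts with SELECT or WITH
    and contains none of the forbidden keywords as whole words."""
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        return False  # empty, or more than one statement
    if not re.match(r"^(select|with)\b", statement, flags=re.IGNORECASE):
        return False
    lowered = statement.lower()
    return not any(re.search(rf"\b{kw}\b", lowered) for kw in FORBIDDEN_KEYWORDS)


print(is_read_only("SELECT c_first_name FROM customer"))
print(is_read_only("DELETE FROM customer"))
```

The filter errs on the side of rejection: a SELECT whose string literal contains a word such as `update` or a semicolon is also refused. The refusal message from the prompt fails the check as well, so it is never sent to the database.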

# Step 4: Prevent SQL injection with update request

In [None]:
# A user request to bulk-update a table in the database
query = "update all the customers who were born after 2000 and set their first name to 'John'"

### Prompt without protection against unsafe SQL operations
In the initial version of the prompt below, the instructions include baseline protections, such as using only the columns from the schema and querying only the columns relevant to the question rather than all of them. However, nothing guards against unsafe operations such as "insert", "delete", or "update", which alter the data in the database.

In [None]:
# Prompt without protection against SQL injection

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question}
        Instructions:
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
          
        <Question>"How many users do we have?"</Question>
        <SQL>SELECT SUM(users) FROM customers</SQL>

        <Question>"How many users do we have for Mobile?"</Question>
        <SQL>SELECT SUM(users) FROM customer WHERE source_medium='Mobile'</SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

### Without operation protection, the LLM generates an UPDATE statement that would update all the customers born after 2000 if executed.

In [None]:
# Invoke model to generate response

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)

### The prompt below shows how to avoid generating the unsafe UPDATE statement in this example.
#### The following instruction and example are added to the prompt:


"Reject any question that ask for insert, update, and delete actions. Don't generate SQL statement."

  "\<Question\>Delete all the customers\</Question\>"
  
  "\<SQL\>I don't have permission to generate or execute SQLs which can change data\</SQL\>"



In [None]:
# Prompt with protection against SQL injection

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question}
        Instructions:
           Reject any question that ask for insert, update, and delete actions. Don't generate SQL statement.  
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
          
        <Question>"How many customers do we have?"</Question>
        <SQL>SELECT SUM(customers) FROM customers</SQL>

        <Question>"How many customers do we have for Mobile?"</Question>
        <SQL>SELECT SUM(customers) FROM customer WHERE source_medium='Mobile'</SQL>

        <Question>"Delete all the customers"</Question>
        <SQL>I don't have permission to generate or execute SQLs which can change data</SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

### Given the same user request, the improved prompt leads the LLM to return a warning message instead of a SQL statement that could modify database data.

In [None]:
# Invoke model to generate response

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)
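
The protected and unprotected templates in Steps 3 and 4 differ mainly by one instruction and one few-shot example. A minimal sketch of assembling the prompt from parts follows, so the guard can be switched on or off without copying the whole template. `build_prompt` is a hypothetical helper that uses plain string formatting instead of `PromptTemplate`, and its wording paraphrases the notebook's prompt.

```python
def build_prompt(schema, question, instructions, examples):
    """Assemble a text-completion prompt from instructions and few-shot examples."""
    instruction_block = "\n".join(f"- {line}" for line in instructions)
    example_block = "\n\n".join(
        f"<Question>{q}</Question>\n<SQL>{a}</SQL>" for q, a in examples
    )
    return (
        "\n\nHuman: "
        f"Read database schema {schema} which contains a json list of table names "
        "and their pipe-delimited schemas.\n"
        "Create a syntactically correct awsathena query to answer the question.\n"
        f"Instructions:\n{instruction_block}\n"
        "Return the sql query inside the <SQL></SQL> tag.\n\n"
        f"{example_block}\n\n"
        f"<Question>{question}</Question>"
        "\n\nAssistant:"
    )


# Guard instruction and refusal example taken from the protected prompt.
GUARD_INSTRUCTION = (
    "Reject any question that ask for insert, update, and delete actions. "
    "Don't generate SQL statement."
)
GUARD_EXAMPLE = (
    "Delete all the customers",
    "I don't have permission to generate or execute SQLs which can change data",
)

guarded_prompt = build_prompt(
    schema={"customer": "c_customer_id|c_first_name|c_birth_year|"},
    question="update all the customers who were born after 2000",
    instructions=[GUARD_INSTRUCTION, "Be careful to not query for columns that do not exist."],
    examples=[GUARD_EXAMPLE],
)
print(guarded_prompt)
```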

# Step 5: Apply Read Restrictions

### This user query could retrieve information that the user does not have permission to access, compromising data privacy.

In [None]:
# A query to retrieve customer information that the requesting customer should not have access to.
query = "give me customer information for customers who were born before 1930"

In [None]:
# Prompt without read restrictions

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question} 
        Instructions:
           Reject any question that ask for insert, update, and delete actions  
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
        
        <Question>"How many customers do we have?"</Question>
        <SQL>SELECT SUM(customers) FROM customers</SQL>

        <Question>"How many customers do we have for Mobile?"</Question>
        <SQL>SELECT SUM(customers) FROM customer WHERE source_medium='Mobile'</SQL>

        <Question>"Delete all the customers"</Question>
        <SQL>I don't have permission to generate or execute SQLs which can change data</SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

In [None]:
# Invoke model to generate response

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)

### Without a read restriction in the prompt, the LLM generates a SQL statement that retrieves all the customer information, which the user should not have access to.

In [None]:
results = din_sql.query(sql.split('<SQL>')[1].split('</SQL>')[0])
results
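
The expression `sql.split('<SQL>')[1]` raises an `IndexError` when the completion contains no `<SQL>` tag. A minimal sketch of a more forgiving extraction follows; `extract_sql` is a hypothetical helper, not part of `din_sql_lib`.

```python
import re
from typing import Optional


def extract_sql(completion: str) -> Optional[str]:
    """Return the text inside the first <SQL>...</SQL> pair, or None if absent."""
    match = re.search(
        r"<SQL>(.*?)</SQL>", completion, flags=re.DOTALL | re.IGNORECASE
    )
    if match is None:
        return None
    return match.group(1).strip()


print(extract_sql(" <SQL>\nSELECT c_first_name FROM customer\n</SQL>"))
print(extract_sql("Sorry, I cannot help with that."))
```

Returning `None` lets the caller decide what to do when no query was produced, instead of failing inside the string handling.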

### The prompt below shows how to apply read restrictions
#### Add this instruction and example to the prompt, and provide the customer_id at runtime.


"The question will be asked by a customer with a customer_id. The query should only return results for the customer_id of the customer asking the question as to protect the privacy of other customers. 
For example, a customer with customer_id='A' can not see the information of customer with customer_id='B'. The customer_id of the customer asking the question is: {customer_id}"

"\<Question\>Give me customer information for Mobile\</Question\>"

"\<SQL\>SELECT * FROM customer WHERE source_medium='Mobile' and customer_id = {customer_id} \</SQL\>"

In [None]:
# Prompt with read restrictions

prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read database schema {schema} which contains a json list of table names and their pipe-delimited schemas.
        Use the schema, first create a syntactically correct awsathena query to answer the question {input_question} 
        Instructions:
           The question will be asked by a customer with a customer_id. The query should only return results for the customer_id of the customer asking the question as to protect the privacy of other customers. 
           For example, a customer with customer_id='A' can not see the information of customer with customer_id='B'. The customer_id of the customer asking the question is: {customer_id}
           Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
           Pay attention to use only the column names that you can see in the schema description. 
           Be careful to not query for columns that do not exist. 
           Pay attention to which column is in which table. 
           Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:
           Return the sql query inside the <SQL></SQL> tag.
        
        <Question>"How many customers do we have?"</Question>
        <SQL>SELECT SUM(customers) FROM customers</SQL>

        <Question>"How many customers do we have for Mobile?"</Question>
        <SQL>SELECT SUM(customers) FROM customer WHERE source_medium='Mobile'</SQL>
        
        <Question>"Give me customer information for Mobile"</Question>
        <SQL>SELECT * FROM customer WHERE source_medium='Mobile' and customer_id = {customer_id} </SQL>
          
        <Question>{input_question}</Question>
        \n\n Assistant: """
)
prompt_data= prompt_template.format(schema=schema,input_question = query, customer_id = 'AAAAAAAABMLCAAAA')
#print(prompt_data)

In [None]:
# Formulate Bedrock Model Invoke Body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})

### Given the same user request, the improved prompt leads the LLM to generate a different answer that applies read restrictions
In this case, only the customer information for the specified customer_id is returned.

In [None]:
# Invoke model to generate response

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']
print(sql)

In [None]:
results = din_sql.query(sql.split('<SQL>')[1].split('</SQL>')[0])
results
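
Because the read restriction lives only in the prompt, it can be useful to confirm that the generated query really filters on the caller's id before running it. The sketch below is a coarse textual check under stated assumptions: `is_scoped_to_customer` is a hypothetical helper, and the default column name `c_customer_id` is taken from the schema in Step 2 (the prompt itself refers to `customer_id`).

```python
import re


def is_scoped_to_customer(sql: str, customer_id: str,
                          column: str = "c_customer_id") -> bool:
    """Return True if the SQL contains an equality filter column = 'customer_id'."""
    pattern = rf"\b{re.escape(column)}\s*=\s*'{re.escape(customer_id)}'"
    return re.search(pattern, sql, flags=re.IGNORECASE) is not None


scoped = (
    "SELECT c_first_name, c_last_name FROM customer "
    "WHERE customer.c_customer_id = 'AAAAAAAABMLCAAAA' AND c_birth_year < 1930"
)
unscoped = "SELECT c_first_name, c_last_name FROM customer WHERE c_birth_year < 1930"

print(is_scoped_to_customer(scoped, "AAAAAAAABMLCAAAA"))
print(is_scoped_to_customer(unscoped, "AAAAAAAABMLCAAAA"))
```

The check only confirms that the filter text is present; a query that combines it with an `OR` clause would still pass, so enforcing access at the database layer remains the more reliable control.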