### Ask a Question

**Suggested SageMaker JupyterLab notebook environment setup:**

SageMaker Image: sagemaker-distribution-cpu

Kernel: Python 3

Instance Type: ml.m5.large

![RAG workflow diagram](content/rag.png)

# Install dependencies

Here we will install all the required dependencies to run this notebook. 

In [None]:
!python -m ensurepip --upgrade
!pip install langchain --quiet
!pip install jq --quiet
!pip install faiss-cpu --quiet

**Restart your Kernel before proceeding**

In [None]:
# Standard library
import json
import sys

# AWS SDK
import boto3
from botocore.config import Config

# Make the parent directory importable so the shared `libs` package resolves.
# This has to happen before the `libs` import below.
sys.path.append('../')

from libs.din_sql import din_sql_lib as dsl

### Replace these variables with values from your setup

In [None]:

# Placeholders below must be replaced with values from your own setup.
ATHENA_RESULTS_S3_LOCATION = "<workshop bucket name>" # available in cloudformation outputs
ATHENA_CATALOG_NAME = "<athena catalog name>" # available in cloudformation outputs
DB_NAME = "tpcds1"
# Location of the saved FAISS index of example questions. Kept identical to the
# path the vector-store loading cell reads, so this setting is not dead config.
# NOTE(review): confirm the index lives at this path relative to the notebook.
DB_FAISS_PATH = './vectorstore/db_faiss'
# Large retry budget for Bedrock calls (presumably to absorb throttling).
retry_config = Config(retries = {'max_attempts': 100})
bedrock_region = boto3.session.Session().region_name
# Fail here with an actionable message instead of a later, less obvious
# client-construction error when no region is configured.
if bedrock_region is None:
    raise ValueError(
        "No AWS region is configured; set AWS_DEFAULT_REGION or a region in "
        "your AWS profile so the Bedrock client can be created."
    )

In [None]:
# Bedrock runtime client, shared by the embeddings wrapper and the direct
# invoke_model call later in the notebook.
session = boto3.Session(region_name=bedrock_region)
bedrock = session.client(
    service_name='bedrock-runtime',
    region_name=bedrock_region,
    config=retry_config,
)

In [None]:
from langchain.embeddings import BedrockEmbeddings

# Embedding wrapper that reuses the Bedrock runtime client created earlier.
bedrock_embeddings = BedrockEmbeddings(
    client=bedrock,
)

In [None]:
from langchain.vectorstores import FAISS
# NOTE(review): this reassignment overrides DB_FAISS_PATH from the configuration
# cell; confirm which location actually holds the saved index and keep the two
# in sync.
DB_FAISS_PATH = './vectorstore/db_faiss'
# Load the saved FAISS index of example questions, embedding queries with the
# same Bedrock embeddings wrapper.
question_db = FAISS.load_local(DB_FAISS_PATH, bedrock_embeddings)

In [None]:
# Natural-language question to translate into an Athena SQL query.
query = (
    "Find the top 10 customer name "
    "by total dollars spent"
)

### Get table and column information using both similarity search and keyword search

In [None]:
import json

# Semantic pass: each matching stored question carries the name and schema of
# the table it targets; collect them keyed by table name.
schema = {}
results_with_scores = question_db.similarity_search_with_score(query)
for match_doc, _match_score in results_with_scores:
    metadata = match_doc.metadata
    print(metadata['question'])
    schema[metadata['tableName']] = metadata['tableSchema']
  


In [None]:
import libs.din_sql.din_sql_lib as dsl

# DIN-SQL helper configured with Claude v2 on Bedrock.
din_sql = dsl.DIN_SQL(
    bedrock_model_id='anthropic.claude-v2',
)

In [None]:
# Point DIN-SQL at the configured Athena catalog and database; the
# find_tables/get_schema lookups in the next cell presumably use this connection.
din_sql.athena_connect(
    catalog_name=ATHENA_CATALOG_NAME,
    db_name=DB_NAME,
    s3_prefix=ATHENA_RESULTS_S3_LOCATION,
)

In [None]:
import string
from functools import reduce

# Keyword pass: add every table whose name appears as a word in the question
# and was not already found by the similarity search.
list_tables = din_sql.find_tables(DB_NAME)
# Normalize question words (any whitespace, surrounding punctuation, case) so
# "Customer" or "customer?" still match a table named "customer".
list_words = [word.strip(string.punctuation).lower() for word in query.split()]
question_words = set(list_words)  # O(1) membership instead of scanning a list

# dict.fromkeys de-duplicates while preserving find_tables() order.
intersection = [
    table for table in dict.fromkeys(list_tables)
    if table.lower() in question_words
]
for table in intersection:
    if table in schema:
        print(f"{table}: schema already retrieved by similarity search")
    else:
        schema[table] = din_sql.get_schema(DB_NAME, table)


In [None]:
# Display the combined table-name -> schema mapping gathered by both passes.
schema

In [None]:
from langchain.prompts import PromptTemplate

# Claude v2 text-completion prompt. Anthropic's legacy completions format
# expects the prompt to start with "\n\nHuman:" and end with "\n\nAssistant:"
# exactly: no space before "Assistant" and no trailing whitespace.
prompt_template = PromptTemplate.from_template(
    """\n\nHuman:
        Read the database schema {schema}, which maps each table name to its pipe-delimited schema.
        Using that schema, create a syntactically correct Amazon Athena query that answers the question {input_question}
        Instructions:
           Never query for all the columns from a specific table; only ask for a few relevant columns given the question.
           Use only the column names that you can see in the schema description.
           Be careful not to query for columns that do not exist.
           Pay attention to which column is in which table.
           Qualify column names with the table name when needed.
           Return the SQL query inside <SQL></SQL> tags, as in the examples below.

        <Question>"How many users do we have?"</Question>
        <SQL>SELECT SUM(users) FROM customers</SQL>

        <Question>"How many users do we have for Mobile?"</Question>
        <SQL>SELECT SUM(users) FROM customers WHERE source_medium='Mobile'</SQL>

        <Question>{input_question}</Question>
\n\nAssistant:"""
)
prompt_data = prompt_template.format(schema=schema, input_question=query)
print(prompt_data)

In [None]:
import re

# Claude v2 text-completion request; temperature 0 for repeatable SQL output.
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500, "temperature": 0.0})
model_id = 'anthropic.claude-v2'  # change this to use a different version from the model provider
accept = 'application/json'
content_type = 'application/json'

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

# Raw model output; the prompt asks for the query wrapped in <SQL></SQL> tags.
completion = response_body['completion']
# Keep only the query itself so `sql` can be run directly. Fall back to the
# whole completion if the model did not use the tags.
sql_match = re.search(r"<SQL>(.*?)</SQL>", completion, re.DOTALL | re.IGNORECASE)
sql = sql_match.group(1).strip() if sql_match else completion


In [None]:
# Show the generated SQL.
print(f"{sql}")