In [None]:
%pip install -U opensearch-py
%pip install -U boto3
%pip install -U retrying
%pip install -U jq
%pip install -U langchain
from IPython.core.display import HTML
import warnings
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
warnings.filterwarnings('ignore')
print("Restarted")

In [None]:
import logging
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import time
import json
import jq
from aws import Aws
from langchain.vectorstores import OpenSearchVectorSearch
from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain_community.chat_models import BedrockChat
from langchain_core.documents.base import Document
import pandas
from urllib.parse import urlparse

In [None]:
# Notebook-wide logger. propagate=False keeps records away from the root logger so they
# are not printed a second time if a root handler is configured.
logger = logging.getLogger("Main")
logger.setLevel(logging.DEBUG)
logger.propagate = False
# Re-running this cell must not stack another StreamHandler onto the same named logger;
# each extra handler would print every message one more time.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
logger.info("Logger initialized")

In [None]:
import os

# Deployment-specific resource names. Each value can be overridden with the environment
# variable named in its os.environ.get() call, so the notebook can target another stack
# without code edits; the literals are the defaults for the original deployment.
# OpenSearch Serverless endpoint host (used for both the admin client and the vector store).
aoss_node = os.environ.get("AOSS_NODE", "mey2dta082iatb0w5x43.us-east-1.aoss.amazonaws.com")
# OpenSearch index that holds the table-metadata embeddings.
index_name = os.environ.get("AOSS_INDEX_NAME", "snowflake")
# Glue database whose tables are read with get_tables.
database_name = os.environ.get("GLUE_DATABASE_NAME", "snowflake")
# Athena catalog passed as QueryExecutionContext["Catalog"].
catalog_name = os.environ.get("ATHENA_CATALOG_NAME", "snowflake")
# SigV4 signing service name for OpenSearch Serverless.
service = "aoss"
# S3 bucket and prefix that receive Athena query output.
glue_databucket_name = os.environ.get(
    "ATHENA_RESULT_BUCKET", "text-to-sql-with-athena-and-sn-assetbucket1d025086-uvkbiuwwsgqb"
)
query_result_folder = os.environ.get("ATHENA_RESULT_PREFIX", "athena-workgroup")
athena_work_group = os.environ.get("ATHENA_WORK_GROUP", "snowflake-workgroup")
aws = Aws()

In [None]:
# Identify the caller once; the logged identity and the account id come from the same
# STS response instead of two separate API calls.
response = aws.sts.get_caller_identity()
logger.info(json.dumps(response))
account_id = response.get('Account')

# Region and credentials are both taken from one default boto3 session so they stay
# consistent with each other.
boto3_session = boto3.session.Session()
region = boto3_session.region_name
credentials = boto3_session.get_credentials()
# SigV4 request signing for OpenSearch Serverless (`service` is the signing service name).
auth = AWSV4SignerAuth(credentials, region, service)

In [None]:

# get_tables returns results one page at a time; follow NextToken so databases with many
# tables are read completely instead of only the first page.
table_list = []
request = {"CatalogId": account_id, "DatabaseName": database_name}
while True:
    response = aws.glue.get_tables(**request)
    table_list.extend(response.get('TableList', []))
    next_token = response.get('NextToken')
    if not next_token:
        break
    request['NextToken'] = next_token

# Reduce each Glue table to {Catalog, Database, Table, Columns}; Database and Table are the
# 2nd and 3rd dot-separated parts of StorageDescriptor.Location.
# NOTE(review): assumes every Location looks like "<x>.<schema>.<table>" — confirm for the
# Snowflake connector; a table with no Location presumably makes the jq filter error.
program=jq.compile(".TableList[] |  {Catalog: .DatabaseName, Database: (.StorageDescriptor.Location | split(\".\") | .[1]), Table: (.StorageDescriptor.Location | split(\".\") | .[2]), Columns: (.StorageDescriptor.Columns | map({Name:.Name , Type: .Type, Comment: .Comment})) }")
# Round-trip through JSON with default=str so non-JSON values (e.g. datetimes) become strings jq can ingest.
snowflake_tables = program.input(json.loads(json.dumps({'TableList': table_list}, default=str))).all()
print(snowflake_tables)

In [None]:
# One LangChain Document per table; its page_content is the table's metadata serialized as JSON.
docs = [Document(page_content=json.dumps(table)) for table in snowflake_tables]
logger.info(docs)

In [None]:

embeddings = BedrockEmbeddings(
    client=aws.bedrock,
    model_id="amazon.titan-embed-text-v1",
)

# SigV4-signed client for the OpenSearch Serverless collection. It is used for index
# administration here and is also swapped into the vector store below.
oss_client = OpenSearch(
    hosts=[{'host': aoss_node, 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=300
)

# Start from an empty index so metadata from an earlier run does not linger.
if oss_client.indices.exists(index=index_name):
    logger.info(f"Dropping existing index '{index_name}'")
    response = oss_client.indices.delete(index=index_name)
    logger.info(json.dumps(response))
    time.sleep(10)  # give the deletion time to settle before the index is recreated

# NOTE(review): OpenSearchVectorSearch's constructor only builds clients; the index is
# created by add_documents() on first ingest, so the fixed sleep that used to follow the
# constructor was dead time. Confirm against the installed langchain version.
vector_search = OpenSearchVectorSearch(
    opensearch_url=f"https://{aoss_node}",
    index_name=index_name,
    embedding_function=embeddings,
    http_auth=auth,
    engine="faiss",
)
# Route vector-store traffic through the client above (RequestsHttpConnection, 300 s timeout).
vector_search.client = oss_client

if docs:
    logger.info(f"Creating index '{index_name}'")
    logger.info("Indexing documents...")
    response = vector_search.add_documents(documents=docs)
    time.sleep(60)  # it can take up to a minute for the documents to finish indexing
    logger.info(response)
else:
    # NOTE(review): ingestion sizes the index from the first embedding, so an empty list
    # presumably fails there; with no tables there is nothing to search anyway.
    logger.warning("No table documents to index; skipping ingestion")
logger.info("Done")

In [None]:
# Sampling settings for the SQL-generating model: temperature 0 for the most repeatable
# output, and a stop sequence that ends generation if a new "Human:" turn is emitted.
inference_modifier = dict(
    max_tokens=3000,
    temperature=0,
    top_k=20,
    top_p=1,
    stop_sequences=["\n\nHuman:"],
)
llm = BedrockChat(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    client=aws.bedrock,
    model_kwargs=inference_modifier,
)


def wait_for_result(execution_id, poll_interval=5, timeout=None):
    """Poll an Athena query execution until it reaches a terminal state.

    Args:
    - execution_id (str): QueryExecutionId returned by start_query_execution.
    - poll_interval (float): Seconds to sleep between polls (default 5, as before).
    - timeout (float | None): Give up after this many seconds; None waits forever.

    Returns:
    - dict: The full get_query_execution response once the state is SUCCEEDED,
      FAILED or CANCELLED.

    Raises:
    - TimeoutError: If `timeout` elapses while the query is still non-terminal.
    """
    # Iterative rather than recursive: a recursive poll adds a stack frame every
    # `poll_interval` seconds and overflows the recursion limit on long-running queries.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        logger.info(f"Getting status of query: {execution_id}")
        results = aws.athena.get_query_execution(QueryExecutionId=execution_id)
        status = results['QueryExecution']['Status']['State']
        logger.info(f"Status {status}")
        if status in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
            return results
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Athena query {execution_id} still {status} after {timeout} seconds"
            )
        time.sleep(poll_interval)

def syntax_checker(query_string):
    """Validate a SQL query by asking Athena to EXPLAIN it.

    Args:
    - query_string (str): The candidate SQL query (without the EXPLAIN prefix).

    Returns:
    - str: "Passed" if the EXPLAIN execution SUCCEEDED. Otherwise, a human-readable error:
      Athena's StateChangeReason, a short state description when Athena gives no reason,
      or the exception text when the request itself failed. Never returns None, so callers
      can always feed the message back to the LLM.
    """
    query_config = {
        "OutputLocation": f"s3://{glue_databucket_name}/{query_result_folder}",
        "EncryptionConfiguration": {
            'EncryptionOption': 'SSE_S3',
        }
    }
    query_execution_context = {
        "Catalog": catalog_name,
    }
    query_string = "Explain  " + query_string
    logger.info(f"Executing: {query_string}")
    try:
        logger.info("I am checking the syntax")
        query_execution = aws.athena.start_query_execution(
            QueryString=query_string,
            ResultConfiguration=query_config,
            QueryExecutionContext=query_execution_context,
            WorkGroup=athena_work_group
        )
        execution_id = query_execution["QueryExecutionId"]

        results = wait_for_result(execution_id)
        status = results['QueryExecution']['Status']
        # Use a %s placeholder: passing `status` as an extra argument without one makes
        # logging fail to format the record.
        logger.info("Status: %s", status)
        if status['State'] == 'SUCCEEDED':
            return "Passed"
        # A CANCELLED (or otherwise unexplained) execution may carry no StateChangeReason.
        errmsg = status.get('StateChangeReason', f"Query {status['State']} without a reason")
        logger.error(errmsg)
        return errmsg
    except Exception as e:
        # Submission can be rejected outright (e.g. a request Athena refuses); return the
        # message so the caller has something to act on instead of None.
        logger.exception("Syntax check request failed")
        return str(e)

def get_results(query_string):
    """Run a query on Athena and block until it finishes.

    The query runs against `catalog_name` in `athena_work_group`. Output is written
    SSE-S3-encrypted under s3://<glue_databucket_name>/<query_result_folder>.

    Args:
    - query_string (str): SQL to execute.

    Returns:
    - dict: The terminal get_query_execution response produced by wait_for_result.
    """
    request = dict(
        QueryString=query_string,
        ResultConfiguration={
            "OutputLocation": f"s3://{glue_databucket_name}/{query_result_folder}",
            "EncryptionConfiguration": {'EncryptionOption': 'SSE_S3'},
        },
        QueryExecutionContext={"Catalog": catalog_name},
        WorkGroup=athena_work_group,
    )
    started = aws.athena.start_query_execution(**request)
    return wait_for_result(started["QueryExecutionId"])

def _extract_sql(text):
    """Return the SQL inside the first ``` fence of an LLM reply.

    Takes everything between the first and second ``` markers (or up to the end of the
    text if the fence is never closed). Drops an optional leading ``sql`` language tag
    (any case) and strips surrounding whitespace. Line breaks are preserved so that a
    ``--`` comment cannot swallow the rest of the query.

    Raises:
    - ValueError: If `text` contains no ``` marker at all.
    """
    parts = text.split("```")
    if len(parts) < 2:
        raise ValueError("LLM reply contains no ``` code block")
    body = parts[1].strip()
    # Drop a language tag ("sql\n..." / "SQL ...") but not an identifier such as "sqlx".
    if body[:3].lower() == "sql" and (len(body) == 3 or body[3].isspace()):
        body = body[3:]
    return body.strip()


def generate_sql(prompt, max_attempt=4) -> str:
    """
    Generate and Validate SQL query.

    Each attempt asks the LLM for SQL, extracts the fenced query and checks it with
    syntax_checker. On a syntax error, the error and the rejected query are appended to
    the prompt for the next attempt. An unparseable reply or a failing call uses up an
    attempt without changing the prompt.

    Args:
    - prompt (str): Prompt is user input and metadata from Rag to generating SQL.
    - max_attempt (int): Maximum number of attempts correct the syntax SQL.

    Returns:
    - string: The first query that passes the syntax check. If none passes, the last
      extracted query ("" if no reply could be parsed) is returned and a warning is logged.
    """
    sql_query = ""
    for attempt in range(1, max_attempt + 1):
        logger.info(f'Sql Generation attempt Count: {attempt}')
        try:
            logger.info(f'Attempt #{attempt} to generate the sql')
            generated_sql = llm.invoke(prompt)
            logger.info(generated_sql)
            sql_query = _extract_sql(generated_sql.content)
            logger.info(sql_query)
            syntaxcheckmsg = syntax_checker(sql_query)
            if syntaxcheckmsg == 'Passed':
                logger.info(f'syntax checked for query passed in attempt number :{attempt}')
                return sql_query
            # Feed the Athena error and the rejected query back to the model. The continuation
            # lines deliberately keep the original prompt's exact leading whitespace.
            prompt = f"""{prompt}
                        This is syntax error: {syntaxcheckmsg}. 
                        To correct this, please generate an alternative SQL query which will correct the syntax error.
                        The updated query should take care of all the syntax issues encountered.
                        Follow the instructions mentioned above to remediate the error. 
                        Update the below SQL query to resolve the issue:
                        {sql_query}
                        Make sure the updated SQL query aligns with the requirements provided in the initial question."""
        except Exception as e:
            # Best effort: a bad reply or a failed call costs one attempt but does not abort.
            logger.error(f'FAILED: {e}')
    logger.warning(
        f'No syntactically valid SQL after {max_attempt} attempts; returning last candidate: {sql_query!r}'
    )
    return sql_query

In [None]:


user_query = 'show me all non-adult movie titles from 1980'
logger.info('Searching metadata from vector store')

# Retrieve the table-metadata documents most similar to the question; they become the
# schema context in the LLM prompt.
vector_search_match = vector_search.similarity_search(user_query)
logger.info(vector_search_match)

if not vector_search_match:
    # Stop here: without table metadata there is nothing to generate SQL from, and the
    # results cell would otherwise fail later on an undefined `output_location`.
    logger.error("No vector search match")
    raise RuntimeError("No vector search match")

page_contents = [match.page_content for match in vector_search_match]

details="It is important that the SQL query complies with Athena syntax. During join if column name are same please use alias ex llm.customer_id in select statement. It is also important to respect the type of columns: if a column is string, the value should be enclosed in quotes. If you are writing CTEs then include all the required columns. Please print the resulting SQL query in a sql code markdown block."
final_question = (
    "\n\nHuman:" + details
    + ". The following json document represents the metadata for the tables in the database:  "
    + ", ".join(page_contents)
    + ". Generate SQL to select " + user_query
    + "\n\nAssistant:"
)
logger.info(final_question)

query_string = generate_sql(final_question)
results = get_results(query_string)
query_status = results['QueryExecution']['Status']
if query_status['State'] != "SUCCEEDED":
    # A FAILED or CANCELLED query presumably has no usable result file; fail loudly here
    # instead of letting the results cell try to download it.
    logger.error(results)
    raise RuntimeError(
        f"Athena query {query_status['State']}: "
        f"{query_status.get('StateChangeReason', 'no reason given')}"
    )
output_location = results['QueryExecution']['ResultConfiguration']['OutputLocation']

In [None]:
# The query's output location is an s3://bucket/key URI pointing at a CSV file;
# download that object and render it as a DataFrame.
url = urlparse(output_location)
bucket, key = url.netloc, url.path.lstrip('/')
obj = aws.s3.get_object(Bucket=bucket, Key=key)
results_df = pandas.read_csv(obj['Body'])
display(results_df)