In [None]:
!pip install -r requirements.txt
!pip install ipywidgets
# !pip install sagemaker -U

In [None]:
import re
import pandas as pd
from io import StringIO
import json
import time
import boto3
# import sentencepiece
import pandas as pd
from anthropic import Anthropic
CLAUDE = Anthropic()
import multiprocessing
import subprocess
import shutil
import os
import codecs
import uuid
from transformers import LlamaTokenizer
import tiktoken
from transformers import AutoTokenizer
REDSHIFT=boto3.client('redshift-data')
S3=boto3.client('s3')
from botocore.config import Config
import ipywidgets as widgets
from IPython.display import display

# botocore client settings used for the Bedrock client below:
# a 120 second read timeout and max_attempts=4 for retries.
config = Config(read_timeout=120, retries={"max_attempts": 4})
BEDROCK = boto3.client(
    service_name='bedrock-runtime',
    region_name='us-east-1',
    config=config,
)
# Name of the SageMaker endpoint that the deploy cell creates and the LLM helpers invoke.
MIXTRAL_ENDPOINT = "mixtral"

### Deploy Mixtral 8x7B Instruct to SageMaker Endpoint

In [None]:
# Note this requires an ml.g5.48xlarge instance.
# Deploys the JumpStart model identified by `model_id` to a SageMaker endpoint
# named MIXTRAL_ENDPOINT; `predictor` is the handle returned by deploy().
# NOTE(review): re-running this cell while the endpoint already exists will
# presumably fail on the duplicate endpoint name — confirm before Run All.
model_id = "huggingface-llm-mixtral-8x7b-instruct"
from sagemaker.jumpstart.model import JumpStartModel
model = JumpStartModel(model_id=model_id)
predictor = model.deploy(endpoint_name=MIXTRAL_ENDPOINT)

## REDSHIFT

#### Change the parameters below to match your provisioned Amazon Redshift cluster

In [None]:
# Connection settings passed to every Redshift Data API call in this notebook.
# NOTE(review): the helper functions below call the module-level REDSHIFT client,
# not `redshift_client`; this variable is not referenced in the visible cells.
redshift_client = boto3.client('redshift-data')
CLUSTER_IDENTIFIER = 'redshift-cluster-1'  # identifier of the provisioned cluster
DATABASE = 'dev'  # database the SQL statements are executed against
DB_USER = 'awsuser' 

In [None]:
# (Removed) This cell repeated the Redshift connection settings of the previous cell
# verbatim; redshift_client, CLUSTER_IDENTIFIER, DATABASE and DB_USER are defined there.

In [None]:
def token_counter(path):
    """
    Load a Llama tokenizer.
    Args:
        path: Value handed to LlamaTokenizer.from_pretrained.
    Returns:
        The tokenizer object produced by LlamaTokenizer.from_pretrained(path).
    """
    return LlamaTokenizer.from_pretrained(path)
def mixtral_counter(path):
    """
    Load a tokenizer through AutoTokenizer (used in this notebook for Mixtral token counts).
    Args:
        path: Value handed to AutoTokenizer.from_pretrained.
    Returns:
        The tokenizer object produced by AutoTokenizer.from_pretrained(path).
    """
    return AutoTokenizer.from_pretrained(path)

In [None]:
def query_llm(prompts,tokens):
    """
    Send a prompt to the SageMaker endpoint named MIXTRAL_ENDPOINT and return the generated text.
    Used in this notebook for the SQL-generation and SQL-debugging prompts.
    Args:
        prompts: Prompt text sent as the request's "inputs".
        tokens: Value sent as "max_new_tokens".
    Returns:
        The 'generated_text' entry of the first element of the decoded JSON response.
    """
    import boto3  # function-local imports retained so the function stays self-contained
    import json

    generation_settings = {
        "max_new_tokens": tokens,
        # "top_p": params['top_p'],
        "temperature": 0.1,
        "return_full_text": False,
    }
    request_body = json.dumps({"inputs": prompts, "parameters": generation_settings})
    runtime = boto3.client("sagemaker-runtime")
    raw_reply = runtime.invoke_endpoint(
        Body=request_body,
        EndpointName=MIXTRAL_ENDPOINT,
        ContentType="application/json",
    )
    decoded = json.loads(raw_reply['Body'].read().decode())
    return decoded[0]['generated_text']

In [None]:
def qna_llm(prompts,params):
    """
    Function to prompt the model to generate natural language answers from sql results
    Args:
        prompts: Prompt text sent as the request's "inputs".
        params (dict): Must provide 'model_id', 'text-token' (sent as max_new_tokens)
            and 'temp' (sent as temperature).
    Returns:
        The 'generated_text' entry of the first element of the decoded JSON response.
    Raises:
        ValueError: If params['model_id'] does not contain "mixtral". Only the Mixtral
            endpoint is implemented; previously this case failed with an
            UnboundLocalError on the return statement.
    """
    if 'mixtral' not in params['model_id'].lower():
        raise ValueError(
            f"Unsupported model_id {params['model_id']!r}: only a Mixtral endpoint is implemented."
        )
    import boto3
    import json
    payload = {
        "inputs":prompts,
        "parameters": {"max_new_tokens": params['text-token'],
                       # "top_p": params['top_p'],
                       "temperature": params['temp'],
                       "return_full_text": False,}
    }
    llama=boto3.client("sagemaker-runtime")
    output=llama.invoke_endpoint(Body=json.dumps(payload), EndpointName=MIXTRAL_ENDPOINT,ContentType="application/json")
    answer=json.loads(output['Body'].read().decode())[0]['generated_text']
    return answer

In [None]:
def chunk_csv_rows(csv_rows, max_token_per_chunk, tokenizer=None):
    """
    Chunk CSV rows based on the maximum token count per chunk.
    Every chunk starts with the header row, and the header's tokens count
    towards the chunk budget.
    Args:
        csv_rows (list): List of CSV rows (strings); the first row is treated as the header.
        max_token_per_chunk (int): Maximum token count per chunk, header included.
        tokenizer (optional): Object with an ``encode(text)`` method whose result length
            is the token count. When omitted, the tokenizer is loaded once through
            mixtral_counter("mistralai/Mixtral-8x7B-v0.1").
    Returns:
        list: List of chunks, each a newline-joined string of the header plus rows.
            Empty when there are no data rows.
    Raises:
        ValueError: If a single CSV row (plus the header) exceeds max_token_per_chunk.
    """
    if not csv_rows:
        return []
    if tokenizer is None:
        # Load once instead of once per row.
        tokenizer = mixtral_counter("mistralai/Mixtral-8x7B-v0.1")

    header = csv_rows[0]  # Assuming the first row is the header
    data_rows = csv_rows[1:]
    header_token = len(tokenizer.encode(header))

    chunks = []
    current_chunk = []
    current_token_count = 0
    for row in data_rows:
        token = len(tokenizer.encode(row))
        # Check every row up front, so an oversized row is rejected wherever it
        # appears rather than only when the current chunk happens to be empty.
        if header_token + token > max_token_per_chunk:
            raise ValueError("A single CSV row exceeds the specified max_token_per_chunk.")
        if current_token_count + token + header_token > max_token_per_chunk:
            chunks.append("\n".join([header] + current_chunk))
            current_chunk = []
            current_token_count = 0
        current_chunk.append(row)
        current_token_count += token

    if current_chunk:
        chunks.append("\n".join([header] + current_chunk))
    return chunks

In [None]:
def get_tables_redshift(cluster_identifier, database, db_user, schema):
    """
    Get a list of table names in a specified schema from an Amazon Redshift cluster.
    Follows the response's NextToken so that tables beyond the first page are included.
    Args:
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database containing the tables.
        db_user (str): The username used to authenticate with the Redshift cluster.
        schema (str): The schema pattern to filter tables.
    Returns:
        list: A list of table names in the specified schema.
    """
    request = {
        "ClusterIdentifier": cluster_identifier,
        "Database": database,
        "DbUser": db_user,
        "SchemaPattern": schema,
    }
    table_names = []
    while True:
        page = REDSHIFT.list_tables(**request)
        table_names.extend(entry['name'] for entry in page.get('Tables', []))
        next_token = page.get('NextToken')
        if not next_token:
            return table_names
        request['NextToken'] = next_token

In [None]:
def get_db_redshift(cluster_identifier, database, db_user):
    """
    Get a list of databases from an Amazon Redshift cluster.
    Follows the response's NextToken so that databases beyond the first page are included.
    Args:
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database used for the request.
        db_user (str): The username used to authenticate with the Redshift cluster.
    Returns:
        list: A list of databases in the Redshift cluster.
    """
    request = {
        "ClusterIdentifier": cluster_identifier,
        "Database": database,
        "DbUser": db_user,
    }
    databases = []
    while True:
        page = REDSHIFT.list_databases(**request)
        databases.extend(page.get('Databases', []))
        next_token = page.get('NextToken')
        if not next_token:
            return databases
        request['NextToken'] = next_token

In [None]:
def get_schema_redshift(cluster_identifier, database, db_user):
    """
    Get a list of schemas from an Amazon Redshift cluster.
    Follows the response's NextToken so that schemas beyond the first page are included.
    Args:
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database containing the schemas.
        db_user (str): The username used to authenticate with the Redshift cluster.
    Returns:
        list: A list of schemas in the Redshift cluster.
    """
    request = {
        "ClusterIdentifier": cluster_identifier,
        "Database": database,
        "DbUser": db_user,
    }
    schemas = []
    while True:
        page = REDSHIFT.list_schemas(**request)
        schemas.extend(page.get('Schemas', []))
        next_token = page.get('NextToken')
        if not next_token:
            return schemas
        request['NextToken'] = next_token

In [None]:
def _await_batch_statement(statement_id, poll_seconds=1):
    """Poll describe_statement until the statement leaves SUBMITTED/PICKED/STARTED; return the last description."""
    description = REDSHIFT.describe_statement(Id=statement_id)
    while description['Status'] in ('SUBMITTED', 'PICKED', 'STARTED'):
        time.sleep(poll_seconds)
        description = REDSHIFT.describe_statement(Id=statement_id)
    return description


def _collect_statement_result(statement_id):
    """Fetch every result page of a statement and return the first response with all Records merged in."""
    merged = REDSHIFT.get_statement_result(Id=statement_id)
    records = list(merged['Records'])
    next_token = merged.get('NextToken')
    while next_token:
        page = REDSHIFT.get_statement_result(Id=statement_id, NextToken=next_token)
        records.extend(page['Records'])
        next_token = page.get('NextToken')
    merged['Records'] = records
    return merged


def execute_query_with_pagination( sql_query, cluster_identifier, database, db_user):
    """
    Execute multiple SQL queries in Amazon Redshift as one batch, with pagination support.
    Args:
        sql_query (list): The SQL statements to execute, passed as ``Sqls`` to batch_execute_statement.
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database.
        db_user (str): The username used to authenticate with the Redshift cluster.
    Returns:
        list: One CSV string per sub-statement, in the order reported by describe_statement.
    Raises:
        RuntimeError: If the batch ends in a status other than FINISHED (previously this
            looped forever), or if the results are still unavailable after all retries
            (previously a partial list was returned silently).
    """
    response_b = REDSHIFT.batch_execute_statement(
        ClusterIdentifier=cluster_identifier,
        Database=database,
        DbUser=db_user,
        Sqls=sql_query
    )
    describe_b = _await_batch_statement(response_b['Id'])
    if describe_b['Status'] != "FINISHED":
        raise RuntimeError(
            f"Batch statement {response_b['Id']} ended with status {describe_b['Status']}: "
            f"{describe_b.get('Error', 'no error message returned')}"
        )

    max_attempts = 5
    last_error = None
    for _ in range(max_attempts):
        try:
            # Build a fresh list on each attempt so a retry cannot duplicate
            # results gathered before the failure.
            return [
                get_redshift_table_result(_collect_statement_result(sub['Id']))
                for sub in describe_b['SubStatements']
            ]
        except REDSHIFT.exceptions.ResourceNotFoundException as err:
            last_error = err
            time.sleep(2)
    raise RuntimeError(
        f"Results of batch statement {response_b['Id']} were not available after {max_attempts} attempts"
    ) from last_error

In [None]:
def get_redshift_table_result(response):
    """
    Extracts result data from a Redshift query response and returns it as a CSV string.
    Args:
        response (dict): The response object from a Redshift query; must provide
            'ColumnMetadata' (each entry with a 'name') and 'Records' (rows of
            single-entry field dicts).
    Returns:
        str: A CSV string (no index column) containing the result data. SQL NULLs are
            written as empty cells.
    """
    columns = [meta['name'] for meta in response['ColumnMetadata']]
    data = [
        [
            # A NULL arrives as {'isNull': True}; taking the dict's first value
            # would put the literal True into the table.
            None if field.get('isNull') else next(iter(field.values()), None)
            for field in record
        ]
        for record in response['Records']
    ]
    df = pd.DataFrame(data, columns=columns)
    return df.to_csv(index=False)

In [None]:
def execute_query_redshift(sql_query, cluster_identifier, database, db_user):
    """
    Submit one SQL statement to Amazon Redshift through the Data API client REDSHIFT.
    Args:
        sql_query (str): The SQL statement to submit.
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database.
        db_user (str): The username used to authenticate with the Redshift cluster.
    Returns:
        dict: Whatever REDSHIFT.execute_statement returns; callers read its 'Id'.
    """
    request = {
        "ClusterIdentifier": cluster_identifier,
        "Database": database,
        "DbUser": db_user,
        "Sql": sql_query,
    }
    return REDSHIFT.execute_statement(**request)

In [None]:
def single_execute_query(sql_query, cluster_identifier, database, db_user,question, params=None):
    """
    Execute a single SQL query on an Amazon Redshift cluster and process the result.

    Args:
        sql_query (str): The SQL query to execute.
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database.
        db_user (str): The username used to authenticate with the Redshift cluster.
        question (str): The user's question; forwarded to redshift_querys as both
            its ``prompt`` and ``question`` arguments.
        params (dict, optional): Parameters forwarded to redshift_querys (schema, sample,
            prompt, sql-len). When omitted, the notebook-level ``params`` variable is
            used, which is what this function previously always did implicitly.

    Returns:
        tuple: ``(result, sql)`` exactly as returned by redshift_querys.

    Raises:
        ValueError: If ``params`` is not given and no global ``params`` exists.
    """
    if params is None:
        # Backward-compatible fallback to the hidden global the original relied on.
        params = globals().get('params')
        if params is None:
            raise ValueError("single_execute_query needs `params`: pass it explicitly or define a global `params`.")
    response = execute_query_redshift(sql_query, cluster_identifier, database, db_user)
    return redshift_querys(sql_query, response, question, params, cluster_identifier, database, db_user, question)

In [None]:
def llm_debugger(question, statement, error, params): 
    """
    Ask the LLM to repair a SQL statement that failed, given the error and table context.
    Args:
        question (str): The user's question or intent. Currently unused: the prompt
            takes the intent from params['prompt'].
        statement (str): The SQL statement that caused the error.
        error (str): The error message encountered.
        params (dict): Must provide 'schema', 'sample', 'prompt' and 'sql-len'
            ('sql-len' is rounded and passed to query_llm as the token limit).
    Returns:
        str: The model's answer with every backslash removed. The prompt asks the model
            to put the corrected statement inside <sql></sql> tags.
    """
    # The literal backslash in rule 3 is written as "\\" so the string contains exactly
    # one backslash without relying on the invalid escape sequence "\)".
    prompts=f'''<s><<SYS>>[INST]
You are a PostgreSQL developer who is an expert at debugging errors.  

Here are the schema definition of table(s):
{params['schema']}
#############################
Here are example records for each table:
{params['sample']}
#############################
Here is the sql statement that threw the error below:
{statement}
#############################
Here is the error to debug:
{error}
#############################
Here is the intent of the user:
{params['prompt']}
<</SYS>>
First understand the error and think about how you can fix the error.
Use the provided schema and sample row to guide your thought process for a solution.
Do all this thinking inside <thinking></thinking> XML tags.This is a space for you to write down relevant content and will not be shown to the user.

Once your are done debugging, provide the the correct SQL statement without any additional text.
When generating the correct SQL statement:
1. Pay attention to the schema and table name and use them correctly in your generated sql. 
2. Never query for all columns from a table unless the question says so. You must query only the columns that are needed to answer the question.
3. Wrap each column name in double quotes (") to denote them as delimited identifiers. Do not use backslash (\\) to escape underscores (_) in column names. 

Format your response as:
<sql> Correct SQL Statement </sql>[/INST]'''

    answer=query_llm(prompts,round(params['sql-len']))
    # Remove every backslash from the answer before it is parsed and executed.
    answer = answer.replace("\\","")
    return answer

In [None]:
def _wait_for_statement(statement_id, poll_seconds):
    """Poll describe_statement until the statement is no longer SUBMITTED/PICKED/STARTED; return the last description."""
    description = REDSHIFT.describe_statement(Id=statement_id)
    while description['Status'] in ('SUBMITTED', 'PICKED', 'STARTED'):
        time.sleep(poll_seconds)
        description = REDSHIFT.describe_statement(Id=statement_id)
    return description


def _extract_sql_block(llm_answer):
    """Return the stripped text between <sql> and </sql> (or end of text); None when there is no <sql> tag."""
    found = re.search(r'<sql>(.*?)(?:</sql>|$)', llm_answer, re.DOTALL)
    return found.group(1).strip() if found else None


def redshift_querys(q_s,response,prompt,params,cluster_identifier, database, db_user,question): 
    """
    Wait for a submitted Redshift statement, let the LLM debug it if it failed, and return the result.

    Args:
        q_s (str): The SQL statement that was submitted.
        response (dict): The response from submitting the statement; its 'Id' is used.
        prompt (str): The user's question or intent, forwarded to llm_debugger.
        params (dict): Additional parameters including schema, sample data, and length;
            forwarded to llm_debugger.
        cluster_identifier (str): The identifier of the Redshift cluster.
        database (str): The name of the database.
        db_user (str): The username used to authenticate with the Redshift cluster.
        question (str): Accepted for interface compatibility; not used by this function.

    Returns:
        tuple: ``(result, sql)``. ``result`` is the query output as a CSV string, or a
        "... NO RESULT AVAILABLE" message when the statement could not be made to
        succeed; ``sql`` is the last statement that was executed.

    Raises:
        RuntimeError: If the statement finished but its result could not be fetched
            after several attempts.
    """
    max_execution = 5          # maximum number of LLM debugging rounds
    max_fetch_attempts = 5     # retries while the finished result is not yet readable

    description = _wait_for_statement(response['Id'], 1)

    attempt_number = 0
    while description['Status'] == "FAILED" and attempt_number < max_execution:
        attempt_number += 1
        print("- - - - - - - - - - - - - -\n")
        print(f"\nDEBUG TRIAL {attempt_number}")
        bad_sql = description['QueryString']
        print(f"\nBAD SQL:\n{bad_sql}")
        error = description['Error']
        print(f"ERROR:{error}")
        print("\nDEBUGGING...")
        # Parse by tag rather than by fixed offset: the old "+ 1" slice dropped the
        # first SQL character whenever the model wrote "<sql>SELECT ...".
        fixed_sql = _extract_sql_block(llm_debugger(prompt, bad_sql, error, params))
        if not fixed_sql:
            print("\nDEBUGGER RESPONSE HAD NO <sql> BLOCK; RETRYING")
            continue
        q_s = fixed_sql
        print(f"\nDEBUGGED SQL {q_s}")
        response = execute_query_redshift(q_s, cluster_identifier, database, db_user)
        description = _wait_for_statement(response['Id'], 2)

    status = description['Status']
    if status == "FAILED":
        print(f"DEBUGGING FAILED IN {str(max_execution)} ATTEMPTS")
        return f"DEBUGGING FAILED IN {str(max_execution)} ATTEMPTS. NO RESULT AVAILABLE", q_s
    if status != "FINISHED":
        # Any other terminal status (e.g. an aborted statement) has no result to fetch.
        return f"QUERY {status}. NO RESULT AVAILABLE", q_s

    last_error = None
    for _ in range(max_fetch_attempts):
        try:
            statement_result = REDSHIFT.get_statement_result(Id=response['Id'])
        except REDSHIFT.exceptions.ResourceNotFoundException as err:
            last_error = err
            time.sleep(5)
        else:
            return get_redshift_table_result(statement_result), q_s
    raise RuntimeError(
        f"Result of statement {response['Id']} was not available after {max_fetch_attempts} attempts"
    ) from last_error

In [None]:
def redshift_qna(params):
    """
    Execute a Q&A process for generating SQL queries based on user questions.

    Steps: read column metadata and three sample rows per table, ask the LLM for a SQL
    statement, execute (and if needed debug) it, then ask the LLM to answer the question
    from the query output. Large outputs are summarised chunk by chunk and merged.

    Args:
        params (dict): Must provide 'db' (schema name used in the SQL), 'tables',
            'prompt', 'model_id', 'text-token', 'temp' and 'sql-len'. This function
            adds/overwrites params['schema'] and params['sample'].
    Returns:
        tuple: A tuple containing the response, generated SQL statement, and query output.
    Raises:
        ValueError: If the model's first answer contains no <sql> block.
    """
    max_direct_tokens = 28000   # above this, the output is summarised in chunks
    chunk_token_budget = 20000  # token budget handed to chunk_csv_rows

    schema_sql=f"SELECT table_catalog,table_schema,table_name,column_name,ordinal_position,is_nullable,data_type FROM information_schema.columns WHERE table_schema='{params['db']}'"
    # Qualify sample queries with DATABASE (the database the batch runs against)
    # instead of a hard-coded 'dev'.
    sample_sqls=[f"SELECT * from {DATABASE}.{params['db']}.{table} LIMIT 3" for table in params['tables']]
    question=params['prompt']
    results=execute_query_with_pagination([schema_sql]+sample_sqls, CLUSTER_IDENTIFIER, DATABASE, DB_USER)

    schema_lines=results[0].split('\n')
    col_names=schema_lines[0]
    observations="\n".join(sorted(schema_lines[1:])).strip()
    params['schema']=f"{col_names}\n{observations}"
    params['sample']=''.join(f"{examples}\n\n" for examples in results[1:])

    # The literal backslash in instruction 4 is written as "\\" to avoid the invalid
    # escape sequence "\)"; the text sent to the model is unchanged.
    prompts=f"""<s><<SYS>>[INST]
You are an expert PostgreSQL developer. Your job is to provide a syntactically correct PostgreSQL query given a user question.
Here are the schema definition of table(s):
########
{params['schema']}
########

Here are example records for each table:
##########
{params['sample']}
###########
<</SYS>>
Here are some instructions when generating SQL statements:
1. Determine the necessary table(s) and schema needed for an accurate query.
2. Limit your queries to only the required columns to prevent unnecessary data retrieval and improve query performance.
3. For clarity and to prevent potential conflicts, always include the schema name when referencing table names in your SQL queries.
4. When working with Amazon Redshift table and column names containing underscores, do not use the backslash escape character (\\). Instead, use double quotes ("") to enclose the names in your queries.
5. Do not mention 'dev' or 'public' in the queries.
In your response, provide a single SQL statement to answer the question, avoid additional text that would cause failure during executing the sql. 
Format your response as:
<sql>
generated SQL statement 
</sql>

Question: {question}[/INST]"""

    llm_answer=query_llm(prompts,200)
    sql_match = re.search(r'<sql>(.*?)(?:</sql>|$)', llm_answer, re.DOTALL)
    if sql_match is None:
        raise ValueError(f"The model response did not contain a <sql> block: {llm_answer!r}")
    q_s = sql_match.group(1).replace("\\","").strip()
    print(f" FIRST ATTEMPT SQL:\n{q_s}")

    # Execute with the params this function was given; going through
    # single_execute_query would pick up whatever global `params` happens to exist.
    statement = execute_query_redshift(q_s, CLUSTER_IDENTIFIER, DATABASE, DB_USER)
    output, q_s = redshift_querys(q_s, statement, question, params, CLUSTER_IDENTIFIER, DATABASE, DB_USER, question)
    output_tokens=len(mixtral_counter("mistralai/Mixtral-8x7B-v0.1").encode(output))

    if output_tokens>max_direct_tokens:
        csv_rows=output.split('\n')
        chunk_rows=chunk_csv_rows(csv_rows, chunk_token_budget)
        initial_summary=[]
        for chunk in chunk_rows:
            prompts=f'''<s><<SYS>>[INST]You are a helpful and truthful assistant. Your job is provide answers based on samples of a tabular data provided.

Here is the tabular data:
#######
{chunk}
#######
<</SYS>>
Question: {question}

When providing your response:
- First, review the result to understand the information within. Then provide a complete answer to the my question, based on the result.
- If you can't answer the question, please say so[/INST]'''
            initial_summary.append(qna_llm(prompts,params))
        prompts = f'''<s><<SYS>>[INST]You are a helpful and truthful assistant.

Here are multiple answer for a question on different subset of a tabular data:
#######
{initial_summary}
#######
<</SYS>>
Question: {question}
Based on the given question above, merege all answers provided in a coherent singular answer[/INST]'''
        response=qna_llm(prompts,params)

    else:
        prompts=f'''<s><<SYS>>[INST]You are a helpful and truthful assistant. Your job is to examine a sql statement and its generated result, then provide a response to my question.

Here is the sql query:
{q_s}

Here is the corresponding sql query result:
{output}
<</SYS>>
question: {question}

When providing your response:
- First, review the sql query and the corresponding result. Then provide a complete answer to the my question, based on the result.
- If you can't answer the question, please say so[/INST]'''
        response=qna_llm(prompts, params)
    return response, q_s,output

In [None]:
# Pick what to query: the last database listed, then the last schema in that
# database, then every table name in that schema.
db=get_db_redshift(CLUSTER_IDENTIFIER, DATABASE, DB_USER)[-1]
schm=get_schema_redshift(CLUSTER_IDENTIFIER, db, DB_USER)[-1]
tables=get_tables_redshift(CLUSTER_IDENTIFIER, db, DB_USER,schm)
# NOTE(review): redshift_qna executes its SQL against DATABASE, not `db` —
# confirm the two refer to the same database on your cluster.
db, schm, tables

#### Example prompts:

In [None]:
# Example question: the five biggest spenders on event tickets.
prompt1 = "Who are the 5 people who spent the most on tickets for events?"

In [None]:
# Example question: top five sellers in San Diego by tickets sold in 2008.
prompt2 = "the top five sellers names in San Diego, based on the number of tickets sold in 2008?"

In [None]:
# Example question: the ten events whose tickets took longest to sell.
# ("What where" corrected to "What were" — this text is sent to the model verbatim.)
prompt3 = "What were the 10 events for which tickets took the longest to sell?"

In [None]:
# Example question: the state with the most venues.
prompt4 = "the most popular state to host events based on the number of venues per state."

In [None]:
# Example question: how many venues hosted the show Macbeth.
prompt5 = "Number of Venues where the show Macbeth was held."

In [None]:
# Example question: top ten buyers ranked by quantity.
prompt6 = "what are the top 10 buyers by quantity."

In [None]:
# Example question: occurrence counts for the top ten events.
prompt7 = "for the top 10 events, count the number of times each of them occur."

In [None]:
# Example question: total commissions for Macbeth at Royce Hall.
prompt8 = "Total Commissions Generated for Macbeth at Royce Hall."

In [None]:
# Free-text input box for the question. Type into it before running the next cell,
# which reads entered_text.value.
entered_text = widgets.Text(
    value='',
    description='Enter prompt:',
)
display(entered_text)

In [None]:
# Read the question from the widget. On "Restart & Run All" the box is still empty
# when this cell runs, so fall back to the first example prompt instead of sending
# an empty question to the model.
prompt = entered_text.value.strip() or prompt1
# 'sql-len'    -> token limit used by llm_debugger
# 'text-token' -> max_new_tokens used by qna_llm; 'temp' -> its temperature
# 'db'         -> holds the schema name (schm); 'tables' -> tables sampled for the prompt
params={'sql-len':700,'text-token':500,'tables':tables,'db':schm,'temp':0.1,'model_id':'mixtral',
        "prompt":prompt}
print(params["prompt"])

In [None]:
%%time
# Run the whole text-to-SQL pipeline; redshift_qna returns (answer, sql, query_output).
result_text2sql = redshift_qna(params)

In [None]:
# Query result in natural language (element 0 of the tuple returned by redshift_qna)
print(f"\nAnswer:\n\n{result_text2sql[0]}\n")

In [None]:
# Generated SQL query used (element 1 of the tuple returned by redshift_qna)
from IPython.display import Code  # `Code` was used below without ever being imported

print("\nSQL Query generated from the prompt:\n")
display(Code(result_text2sql[1], language='sql'))
print("")

In [None]:
# Tabular results from the SQL Query 
# result_text2sql[2] is the query output as CSV text; parse it back into a DataFrame.
# NOTE(review): when SQL debugging fails, that element is a plain message string
# rather than CSV — confirm how this cell should behave in that case.
print(f"\nTabular results from the SQL query:\n")
df=pd.read_csv(StringIO(result_text2sql[2]))
df