In [None]:
# @title Install Necessary Packages
# One-time dependency install. Set `first` to False to skip the installs when
# re-running this cell in a session where the packages are already present.
first = True
if first:
  # Each line shells out to pip; --quiet suppresses the install log.
  ! pip install google-cloud-aiplatform --upgrade --quiet
  ! pip install shapely==1.8.5 --quiet
  ! pip install sqlalchemy --upgrade --quiet
  ! pip install asyncio asyncpg cloud-sql-python-connector["asyncpg"] --quiet
  ! pip install numpy pandas --quiet
  ! pip install pgvector --quiet
  ! pip install pg8000 --quiet
  ! pip install gradio --quiet

In [None]:
# get_ipython().kernel.do_shutdown(True)

In [1]:
# Authenticate the current Colab user via the google.colab helper.
from google.colab import auth
auth.authenticate_user()

In [2]:
import os
# `name = !cmd` captures the shell command's output as a list of lines;
# the first line is kept as the account name. `auth_user` is later written
# to the "user" column of the query log by append_2_bq.
auth_user=!gcloud config get-value account
auth_user=auth_user[0]
print('Authenticated User: ' + str(auth_user))


Authenticated User: moksh.google@shoppersstop.com


In [3]:
#@title Assignment of Variables
# Recorded in the "source_type" column of the query log (see append_2_bq).
source_type='BigQuery'


# @markdown Provide the below details to start using the notebook
PROJECT_ID='ss-genai-npd-svc-prj-01' # @param {type:"string"}
LLM_ENDPOINT_REGION = 'asia-southeast1' # @param {type:"string"}
DATAPROJECT_ID='ss-genai-npd-svc-prj-01'  # This needs to be adjusted when using the bq public dataset

# Set and show the active gcloud project
!gcloud config set project {PROJECT_ID}
!gcloud config get-value project
#!bash gcloud auth application-default login


# BQ schema (dataset) where the tables live

schema="NA" # @param {type:"string"}.  ### DDL extraction performed at this level, for the entire schema
USER_DATASET= DATAPROJECT_ID + '.' + schema

# Comma-separated fully qualified table ids (not referenced by the cells visible in this notebook export).
table_id_list="ss-jarvis-npd-svc-prj-01.Customer.FC_Customer_Master, ss-jarvis-npd-svc-prj-01.Customer.Persona,ss-jarvis-npd-svc-prj-01.Product.Product_Master, ss-jarvis-npd-svc-prj-01.Sales.storesales,ss-jarvis-npd-svc-prj-01.Sales.ecommdenormdata,ss-jarvis-npd-svc-prj-01.Store.dim_location_full" # @param {type:"string"}


# BQ schema (dataset) where the tables live

# Execution Parameters (read by call_gen_sql unless noted otherwise)
# SQL_VALIDATION is not referenced by the cells visible in this notebook export.
SQL_VALIDATION='ALL'
# When True, ' ROWID' is appended to the first generated SQL to exercise the retry path.
INJECT_ONE_ERROR=False
# When True, SQL that passes the dry run is executed and its result printed.
EXECUTE_FINAL_SQL=True
# Upper bound compared against the number of rewrite attempts in the retry loop.
SQL_MAX_FIX_RETRY=3
# Gates the "Adding Known Good SQL to Vector DB" step (the insert call itself is commented out).
AUTO_ADD_KNOWNGOOD_SQL=True

# Analytics Warehouse (read by append_2_bq)
# When True, every attempt is logged to PROJECT_ID.DATASET_NAME.LOG_TABLE_NAME.
ENABLE_ANALYTICS=True
DATASET_NAME='nl2sql'
# DATASET_LOCATION is not referenced by the cells visible in this notebook export.
DATASET_LOCATION='asia-south1'
LOG_TABLE_NAME='query_logs'
# Running transcript of one call_gen_sql invocation; reset at the start of each call.
FULL_LOG_TEXT=''


# Vertex AI models to use (ids must be ones createModel accepts)
model_id='gemini-pro' # @param {type:"string"}
chat_model_id='codechat-bison-32k' # @param {type:"string"}
# embeddings_model is not referenced by the cells visible in this notebook export.
embeddings_model='textembedding-gecko@001'





Updated property [core/project].
ss-genai-npd-svc-prj-01


In [4]:
# @title Common Imports
import time
import datetime
from datetime import datetime, timezone
import hashlib
import vertexai
import pandas
import pandas_gbq
import matplotlib.pyplot as plt
from sqlalchemy import create_engine
from sqlalchemy import text
import pandas as pd
from google.colab import data_table
data_table.enable_dataframe_formatter()
import json
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from logging import exception
import asyncio
import asyncpg
from google.cloud.sql.connector import Connector
import numpy as np
from pgvector.asyncpg import register_vector
from google.cloud import aiplatform
from vertexai.language_models import TextEmbeddingModel
import gradio as gr

In [5]:
# @title Model Endpoint Creation
def createModel(PROJECT_ID, LLM_ENDPOINT_REGION, model_id):
  """Return the Vertex AI model object that corresponds to `model_id`.

  Args:
    PROJECT_ID: Not used by this function; kept so existing calls keep working.
    LLM_ENDPOINT_REGION: Not used by this function; kept for the same reason.
    model_id: One of 'code-bison-32k', 'text-bison-32k', 'codechat-bison-32k'
      or 'gemini-pro'.

  Returns:
    A CodeGenerationModel, TextGenerationModel, CodeChatModel or
    GenerativeModel instance, depending on `model_id`.

  Raises:
    ValueError: If `model_id` is not one of the supported ids.
  """
  supported_ids = ('code-bison-32k', 'text-bison-32k', 'codechat-bison-32k', 'gemini-pro')
  # Validate before importing the SDK so a typo fails fast with a useful message.
  if model_id not in supported_ids:
    raise ValueError(
        f"Unsupported model_id {model_id!r}; expected one of: {', '.join(supported_ids)}")

  from vertexai.preview.language_models import TextGenerationModel
  from vertexai.preview.language_models import CodeGenerationModel
  from vertexai.preview.language_models import CodeChatModel
  from vertexai.preview.generative_models import GenerativeModel

  if model_id == 'gemini-pro':
    return GenerativeModel("gemini-pro")

  # The remaining ids are loaded through from_pretrained on their model class.
  pretrained_classes = {
      'code-bison-32k': CodeGenerationModel,
      'text-bison-32k': TextGenerationModel,
      'codechat-bison-32k': CodeChatModel,
  }
  return pretrained_classes[model_id].from_pretrained(model_id)


# Initialise the Vertex AI SDK, then build the two module-level model handles:
# `model` is used for generation/validation/summarisation, `chat_model` for the retry chat.
vertexai.init(project=PROJECT_ID, location=LLM_ENDPOINT_REGION)
model=createModel(PROJECT_ID, LLM_ENDPOINT_REGION,model_id)
chat_model=createModel(PROJECT_ID, LLM_ENDPOINT_REGION,chat_model_id)

# Utility Functions

In [6]:
def clean_sql(result):
  """Strip Markdown code-fence markers ("```sql" and "```") from `result`."""
  # Order matters: the tagged fence must go first, otherwise "sql" would be left behind.
  for fence in ("```sql", "```"):
    result = result.replace(fence, "")
  return result

def clean_json(result):
  """Strip Markdown fence markers and a leading "json" language tag from `result`.

  Only the fence markers and a language tag at the very start of the text are
  removed. Occurrences of "json" inside the payload (for example in an error
  message that mentions a column name) are preserved.

  Args:
    result: Raw model output expected to contain a JSON document.

  Returns:
    The text with fences and the leading tag removed; surrounding whitespace
    is left in place.
  """
  import re

  cleaned = result.replace("```json", "").replace("```", "")
  # If the fences were stripped earlier (e.g. by clean_sql) a bare "json" tag
  # is left at the start; drop just that token.
  cleaned = re.sub(r'^\s*json\b', '', cleaned, count=1, flags=re.IGNORECASE)
  return cleaned

def execute_final_sql(generated_sql):
  """Run `generated_sql` through pandas_gbq in PROJECT_ID and return the resulting DataFrame."""
  return pandas_gbq.read_gbq(generated_sql, project_id=PROJECT_ID)

In [7]:
def summarize_results(question, generated_sql, final_exec_result_df):
  """Ask the module-level `model` to answer `question` in prose from the SQL result.

  Args:
    question: The user's natural-language question.
    generated_sql: The SQL that produced the result.
    final_exec_result_df: DataFrame with the query result; it is rendered into
      the prompt with `to_markdown(index=False)`.

  Returns:
    The text of the first candidate returned by `model.generate_content`.
  """
  summary_prompt=f''' You are expert in summarizing results in natural language to a user's question in natural language.
    User has asked the question : "{question}" and the SQL generated for the question is : "{generated_sql}"

    By running the SQL the results are as below

    {final_exec_result_df.to_markdown(index=False)}

    Us the SQL results above and answer the question in natural language.
    '''
  llm_response = model.generate_content(summary_prompt, stream=False)
  first_candidate = llm_response.candidates[0]
  return first_candidate.text


# Query Build

In [8]:
# @title RAG Based SQL Generator
def gen_dyn_rag_sql(question,table_result_joined,column_result_joined,similar_questions):
  """Generate a BigQuery SQL statement for `question` from retrieved context.

  Args:
    question: The user's natural-language question.
    table_result_joined: Table descriptions inserted under "Table Schema".
    column_result_joined: Column descriptions inserted under "Column Description".
    similar_questions: Example question/SQL pairs inserted as few-shot examples.

  Returns:
    The first candidate's text with code-fence markers removed by clean_sql.

  Raises:
    ValueError: If the module-level `model_id` is neither 'code-bison-32k'
      nor 'gemini-pro'.
  """
  global FULL_LOG_TEXT
  # Sentinel statement the model is told to emit for off-topic questions.
  not_related_msg="select 'Question is not related to the dataset' as unrelated_answer from dual;"
  sql_prompt = f"""

      You are a BigQuery SQL guru. Write a SQL comformant query for Bigquery that answers the following question while using the provided context to correctly refer to the BigQuery tables and the needed column names.

      Guidelines:
      - Only answer questions relevant to the tables listed in the table schema. If a non-related question comes, answer exactly: {not_related_msg}
      - Join as minimal tables as possible.
      - When joining tables ensure all join columns are the same data_type.
      - Analyze the database and the table schema provided as parameters and undestand the relations (column and table relations).
      - Consider alternative options to CAST function. If performing a CAST, use only Bigquery supported datatypes.
      - Don't include any comments in code.
      - Remove ```sql and ``` from the output and generate the SQL in single line.
      - Tables should be refered to using a fully qualified name (project_id.owner.table_name).
      - Use all the non-aggregated columns from the "SELECT" statement while framing "GROUP BY" block.
      - Return syntactically and symantically correct SQL for BigQuery with proper relation mapping i.e project_id, owner, table and column relation.
      - Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
      - Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
      - Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
      - Table names are case sensitive. DO NOT uppercase or lowercase the table names.

    Table Schema:
    {table_result_joined}

    Column Description:
    {column_result_joined}

    Question/SQL Generated Examples:

    {similar_questions}

    Question:
    {question}

    SQL Generated:


    """

  # Echo the prompt and record it in the run transcript before calling the model.
  print('\n LLM GEN SQL Prompt: \n' + sql_prompt)
  FULL_LOG_TEXT += '\n LLM GEN SQL Prompt:  ... \n'
  FULL_LOG_TEXT += '\n' + sql_prompt + '\n'

  if model_id == 'gemini-pro':
    llm_response = model.generate_content(sql_prompt, stream=False)
  elif model_id == 'code-bison-32k':
    llm_response = model.predict(sql_prompt, max_output_tokens = 1024, temperature=0)
  else:
    raise ValueError('model_id not found')

  cleaned_sql = clean_sql(str(llm_response.candidates[0].text))
  print(cleaned_sql)
  return cleaned_sql



In [9]:
# @title Validate Syntax of SQL
def validate_sql_syntax(question, generated_sql, table_result_joined,column_result_joined,similar_questions):
  """Ask the module-level `model` to classify `generated_sql` as valid or invalid.

  Args:
    question: The user's natural-language question.
    generated_sql: The SQL statement to review.
    table_result_joined: Table descriptions inserted under "Table Schema".
    column_result_joined: Column descriptions inserted under "Columns Description".
    similar_questions: Example question/SQL pairs inserted under "Examples".

  Returns:
    The model's reply as text after clean_sql and clean_json; the prompt asks
    for a JSON document with the two elements "valid" and "errors".

  Raises:
    ValueError: If the module-level `model_id` is neither 'code-bison-32k'
      nor 'gemini-pro'.
  """
  validation_prompt = f"""

    Classify the SQL query (generated sql) as valid or invalid.

    Instructions:
    - Use ONLY the column names (column_name) mentioned in Table Schema.
    - DO NOT USE any other column names outside the schema in this context.
    - To be considered valid, a SQL must be semantically correct, use correct ANSI SQL syntax and answer the question below.
    - Respond using a valid JSON format with only two elements: valid and errors. Remove ```json and ``` from the output
    - JSON must be valid. In the JSON data format, the keys must be enclosed in double quotes. Document must start with LEFT CURLY BRACKET character and end with the RIGHT CURLY BRACKET character

    Table Schema:
    {table_result_joined}

    Columns Description:
    {column_result_joined}

    Examples:
    {similar_questions}

    Question:
    {question}

    SQL Query:
    {generated_sql}

  """

  if model_id == 'gemini-pro':
    llm_response = model.generate_content(validation_prompt, stream=False)
  elif model_id == 'code-bison-32k':
    llm_response = model.predict(validation_prompt, max_output_tokens = 1024, temperature=0)
  else:
    raise ValueError('model_id not found')

  raw_reply = str(llm_response.candidates[0].text)
  # Fence markers are stripped first, then the JSON-specific cleanup runs.
  return clean_json(clean_sql(raw_reply))


In [10]:
# @title Test SQL Plan
def test_sql_plan_execution(generated_sql):
  """Dry-run `generated_sql` on BigQuery to check that it can be planned.

  The query is submitted with dry_run=True, so it is checked but not executed.

  Args:
    generated_sql: The SQL statement to check.

  Returns:
    'Execution Plan OK' when the dry run succeeds, otherwise a string starting
    with 'Error. Message: ' followed by the exception text. Callers compare
    against the success string and feed the error text to the retry chat, so
    exceptions are deliberately converted into a return value.
  """
  try:
    client = bigquery.Client(project=PROJECT_ID)
    job_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)

    # A dry-run query completes immediately.
    query_job = client.query(generated_sql, job_config=job_config)

    print("This query will process {} bytes.".format(query_job.total_bytes_processed))
    return 'Execution Plan OK'
  except Exception as e:
    print(e)
    return 'Error. Message: '+ str(e)



In [11]:
# @title Chat Model INIT for retrys
def init_chat():
  """Start a chat session on the module-level `chat_model` for SQL troubleshooting.

  The session is created with a context that carries the same SQL-writing
  guidelines as the generation prompt. A line is appended to FULL_LOG_TEXT.

  Returns:
    The object returned by `chat_model.start_chat(context=...)`.
  """
  global FULL_LOG_TEXT
  # Sentinel statement the model is told to emit for off-topic questions.
  not_related_msg="select 'Question is not related to the dataset' as unrelated_answer from dual;"

  chat_context = f"""

    You are a BigQuery SQL guru. This session is trying to troubleshoot a Google BigQuery SQL query.
    As the user provides versions of the query and the errors returned by BigQuery,
    return a never seen alternative SQL query that fixes the errors.
    It is important that the query still answer the original question.

      Guidelines:
      - Only answer questions relevant to the tables listed in the table schema. If a non-related question comes, answer exactly: {not_related_msg}
      - Join as minimal tables as possible.
      - When joining tables ensure all join columns are the same data_type.
      - Analyze the database and the table schema provided as parameters and undestand the relations (column and table relations).
      - Consider alternative options to CAST function. If performing a CAST, use only Bigquery supported datatypes.
      - Don't include any comments in code.
      - Remove ```sql and ``` from the output and generate the SQL in single line.
      - Tables should be refered to using a fully qualified name (project_id.owner.table_name).
      - Use all the non-aggregated columns from the "SELECT" statement while framing "GROUP BY" block.
      - Return syntactically and symantically correct SQL for BigQuery with proper relation mapping i.e project_id, owner, table and column relation.
      - Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
      - Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
      - Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
      - Table names are case sensitive. DO NOT uppercase or lowercase the table names.

  """
  print('\n Initializing code chat model ...')
  FULL_LOG_TEXT += '\n Initializing code chat model ... \n'
  return chat_model.start_chat(context=chat_context)


In [12]:
# @title Vector Loop Up for RAG
def get_tables_columns_bqvector(question):
  """Retrieve the table and column descriptions closest to `question`.

  Runs one BigQuery VECTOR_SEARCH (top 5, cosine distance) against the table
  comment embeddings and one against the column comment embeddings. The
  question is passed as a named query parameter rather than spliced into the
  SQL text, so quotes or backslashes in the question cannot break or alter
  the query.

  Args:
    question: The user's natural-language question.

  Returns:
    A tuple (table_results_joined, column_results_joined): the
    `detailed_description` values of the matches concatenated into one string
    each (tables separated by ' \\n', columns by '\\n'). Both are empty strings
    when nothing is found.
  """
  def _vector_search_sql(embeddings_table):
    # The two lookups differ only in the embeddings table they search.
    return f"""  SELECT base.idx, base.detailed_description,distance
FROM VECTOR_SEARCH(
  TABLE `{embeddings_table}`, "ml_generate_embedding_result",
  (
  SELECT ml_generate_embedding_result, content AS query
  FROM ML.GENERATE_EMBEDDING(
  MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
  (SELECT @question AS content))
  ),
  top_k => 5,distance_type=>"COSINE" )"""

  # BigQuery job configuration that binds @question as a STRING parameter.
  query_config = {
      'query': {
          'parameterMode': 'NAMED',
          'queryParameters': [{
              'name': 'question',
              'parameterType': {'type': 'STRING'},
              'parameterValue': {'value': str(question)},
          }],
      }
  }

  tables_df=pandas_gbq.read_gbq(
      _vector_search_sql('ss-genai-npd-svc-prj-01.nl2sql.table_comments_embeddings'),
      project_id=PROJECT_ID, configuration=query_config)
  columns_df=pandas_gbq.read_gbq(
      _vector_search_sql('ss-genai-npd-svc-prj-01.nl2sql.column_comments_embeddings'),
      project_id=PROJECT_ID, configuration=query_config)

  table_results_joined = ''.join(
      description + ' \n' for description in tables_df["detailed_description"])
  column_results_joined = ''.join(
      description + '\n' for description in columns_df["detailed_description"])

  return table_results_joined,column_results_joined

In [13]:
# @title Exact same question match search
def search_sql_vector_by_id(question):
    """Look up a stored SQL statement for exactly this question.

    The lookup key is the hex MD5 of the question text. The question is bound
    as a named query parameter rather than interpolated into the SQL string,
    so quotes in the question cannot break or alter the query.

    Args:
      question: The user's natural-language question.

    Returns:
      The stored `generated_sql` (of the last matching row if there are
      several), or the sentinel string 'SQL Not Found in Vector DB' when there
      is no match. Callers compare against that exact sentinel.
    """
    global FULL_LOG_TEXT
    sql="""select generated_sql from `ss-genai-npd-svc-prj-01.nl2sql.sql_embeddings` where
    idx=to_hex(md5(@question))"""
    # BigQuery job configuration that binds @question as a STRING parameter.
    query_config = {
        'query': {
            'parameterMode': 'NAMED',
            'queryParameters': [{
                'name': 'question',
                'parameterType': {'type': 'STRING'},
                'parameterValue': {'value': str(question)},
            }],
        }
    }
    results=pandas_gbq.read_gbq(sql,project_id=PROJECT_ID,configuration=query_config)

    if results.empty:
        FULL_LOG_TEXT= FULL_LOG_TEXT + '\n SQL Not Found in Vector DB. \n'
        return 'SQL Not Found in Vector DB'

    msg=''
    for index,row in results.iterrows():
        print('\n Record found in Vector DB. Parameters: \n')
        FULL_LOG_TEXT= FULL_LOG_TEXT + '\n Record found in Vector DB. Parameters: \n'
        msg=str(row["generated_sql"])
    return msg

In [14]:
# @title Similar Question search
def search_sql_nearest_vector(question,):
    """Fetch up to three stored question/SQL pairs most similar to `question`.

    Runs a BigQuery VECTOR_SEARCH (top 3, cosine distance) over the stored SQL
    embeddings. The question is bound as a named query parameter rather than
    interpolated into the SQL string, so quotes in the question cannot break
    or alter the query.

    Args:
      question: The user's natural-language question.

    Returns:
      A string with one "Question:/Generated SQL:" entry per match, or
      'SQL Nearest Not Found in Vector DB' when there are no matches.
    """
    global FULL_LOG_TEXT

    sql=""" SELECT base.question, base.generated_sql
FROM VECTOR_SEARCH(
  TABLE `ss-genai-npd-svc-prj-01.nl2sql.sql_embeddings`, "ml_generate_embedding_result",
  (
  SELECT ml_generate_embedding_result, content AS query
  FROM ML.GENERATE_EMBEDDING(
  MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
  (SELECT @question AS content))
  ),
  top_k => 3,distance_type=>"COSINE" )

    """
    # BigQuery job configuration that binds @question as a STRING parameter.
    query_config = {
        'query': {
            'parameterMode': 'NAMED',
            'queryParameters': [{
                'name': 'question',
                'parameterType': {'type': 'STRING'},
                'parameterValue': {'value': str(question)},
            }],
        }
    }
    results=pandas_gbq.read_gbq(sql,project_id=PROJECT_ID,configuration=query_config)

    if results.empty:
        print('\n No record near the query was found in the Vector DB. \n')
        FULL_LOG_TEXT= FULL_LOG_TEXT + '\n No record near the query was found in the Vector DB. \n'
        return 'SQL Nearest Not Found in Vector DB'

    return ''.join(
        '\nQuestion:' + str(row["question"]) + '\n' + 'Generated SQL:' + str(row["generated_sql"]) + '\n'
        for index,row in results.iterrows())

In [15]:
# @title Retry SQL generation with Chat Model
def rewrite_sql_chat(chat_session, question, generated_sql, table_result_joined , column_result_joined, error_msg, similar_questions):
  """Ask the chat session for an alternative SQL statement after a failure.

  Args:
    chat_session: Chat session to send the request to.
    question: The user's natural-language question.
    generated_sql: The SQL that failed.
    table_result_joined: Table descriptions inserted under "Table Schema".
    column_result_joined: Column descriptions inserted under "Column Description".
    error_msg: The error text to show the model.
    similar_questions: Example question/SQL pairs inserted as good examples.

  Returns:
    The chat reply with code-fence markers removed by clean_sql.

  Raises:
    ValueError: If the module-level `chat_model_id` is neither
      'codechat-bison-32k' nor 'gemini-pro'.
  """
  retry_prompt = f"""
    What is an alternative SQL statement to address the error mentioned below?
    Present a different SQL from previous ones. It is important that the query still answer the original question.
    Do not repeat suggestions.

  Question:
  "{question}"

  Previously Generated (bad) SQL Query:
 " {generated_sql}"

  Error Message:
  {error_msg}

  Table Schema:
  {table_result_joined}

  Column Description:
  {column_result_joined}

  Good SQL Examples:
  {similar_questions}
  """

  if chat_model_id not in ('codechat-bison-32k', 'gemini-pro'):
    raise ValueError('model_id not found')

  # The two chat APIs expose the reply text differently.
  if chat_model_id == 'gemini-pro':
    reply_text = str(chat_session.send_message(retry_prompt, stream=False).text)
  else:
    reply_text = str(chat_session.send_message(retry_prompt).candidates[0])

  rewritten_sql = clean_sql(reply_text)
  print(rewritten_sql)
  return rewritten_sql


In [16]:
# @title Adding Good SQLs to Vector Tables
def add_vector_sql_collection(question, final_sql):
  """Store a known-good question/SQL pair, with its embedding, in the SQL embeddings table.

  One row is inserted: idx (hex MD5 of the question, the same key
  search_sql_vector_by_id looks up), the question, the SQL, the current
  datetime and the question again as the text to embed. Both values are bound
  as named query parameters instead of being spliced into the SQL text.

  NOTE(review): the previous statement ignored the `final_sql` argument and
  selected a `final_sql` column from `ss-genai-npd-svc-prj-01.nl2sql.sql`
  instead. The select list below keeps the same column names and order, on
  the assumption that the target table was populated positionally from that
  shape. Confirm against the table definition before re-enabling the call.

  Args:
    question: The user's natural-language question.
    final_sql: The SQL statement that answered it.

  Returns:
    A confirmation string naming the question.
  """
  global FULL_LOG_TEXT

  sql='''
      insert into `ss-genai-npd-svc-prj-01.nl2sql.sql_embeddings`
      SELECT * FROM ML.GENERATE_EMBEDDING(
 MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
   (
   select to_hex(md5(@question)) as idx,
       @question as question ,@final_sql as final_sql,current_datetime() as epoch_time,
       @question as content
  )
);
    '''
  # BigQuery job configuration that binds the two STRING parameters.
  query_config = {
      'query': {
          'parameterMode': 'NAMED',
          'queryParameters': [
              {
                  'name': 'question',
                  'parameterType': {'type': 'STRING'},
                  'parameterValue': {'value': str(question)},
              },
              {
                  'name': 'final_sql',
                  'parameterType': {'type': 'STRING'},
                  'parameterValue': {'value': str(final_sql)},
              },
          ],
      }
  }

  print(sql)
  # read_gbq takes `project_id`; the former `project=` keyword raised TypeError.
  pandas_gbq.read_gbq(sql,project_id=PROJECT_ID,configuration=query_config)
  FULL_LOG_TEXT= FULL_LOG_TEXT + '\n Record added to Vector DB. Parameters: \n'
  FULL_LOG_TEXT= FULL_LOG_TEXT + '\n' + str(question) + '\n'

  return 'Question ' + str(question) + ' added to the Vector DB'



In [17]:
# @title Log the process on to BQ
def append_2_bq(model, question, generated_sql, found_in_vector, need_rewrite, failure_step, error_msg):
  """Write one row describing a pipeline step to the BigQuery log table.

  Does nothing except print and log a notice when ENABLE_ANALYTICS is not True.
  Otherwise the row goes to PROJECT_ID.DATASET_NAME.LOG_TABLE_NAME: through a
  streaming insert when the table exists, or through pandas_gbq.to_gbq (which
  creates it) when it does not. The row also carries the module-level
  source_type, auth_user, schema and the current FULL_LOG_TEXT transcript.

  Args:
    model: Identifier of the model used, stored in "model_used".
    question: The user's natural-language question.
    generated_sql: The SQL (or message) to record.
    found_in_vector: Flag stored as-is (callers pass 'Y' or 'N').
    need_rewrite: Flag stored as-is (callers pass 'Y' or 'N').
    failure_step: Name of the failing step, or '' when none.
    error_msg: Error text, or '' when none.

  Returns:
    'Row added' on success, 'Row insert failed' when the streaming insert
    reported row errors, or 'BQ Analytics is disabled'.
  """
  global FULL_LOG_TEXT

  if ENABLE_ANALYTICS is not True:
    print('\n BQ Analytics is disabled so query was not added to BQ log table \n')
    FULL_LOG_TEXT= FULL_LOG_TEXT + '\n BQ Analytics is disabled so query was not added to BQ log table \n'
    return 'BQ Analytics is disabled'

  print('\nInside the Append to BQ block\n')
  table_id=PROJECT_ID + '.' + DATASET_NAME + '.' + LOG_TABLE_NAME
  now = datetime.now()

  # Pass the project explicitly, as test_sql_plan_execution does, instead of
  # relying on the environment to supply a default project.
  client = bigquery.Client(project=PROJECT_ID)

  df1 = pd.DataFrame(columns=[
      'source_type',
      'project_id',
      'user',
      'schema',
      'model_used',
      'question',
      'generated_sql',
      'found_in_vector',
      'need_rewrite',
      'failure_step',
      'error_msg',
      'execution_time',
      'full_log'
      ])

  new_row = {
      "source_type":source_type,
      "project_id":str(PROJECT_ID),
      "user":str(auth_user),
      "schema": schema,
      "model_used": model,
      "question": question,
      "generated_sql": generated_sql,
      "found_in_vector":found_in_vector,
      "need_rewrite":need_rewrite,
      "failure_step":failure_step,
      "error_msg":error_msg,
      "execution_time": now,
      "full_log": FULL_LOG_TEXT
    }

  df1.loc[len(df1)] = new_row

  # Explicit schema for the streaming insert: the columns hold Python objects,
  # so their BigQuery types cannot be auto-detected.
  db_schema=[
        bigquery.SchemaField("source_type", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("project_id", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("user", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("schema", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("model_used", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("question", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("generated_sql", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("found_in_vector", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("need_rewrite", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("failure_step", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("error_msg", bigquery.enums.SqlTypeNames.STRING),
        bigquery.SchemaField("execution_time", bigquery.enums.SqlTypeNames.TIMESTAMP),
        bigquery.SchemaField("full_log", bigquery.enums.SqlTypeNames.STRING),
      ]

  try:
    client.get_table(table_id)  # Make an API request.
    table_exists=True
  except NotFound:
    print("Table {} is not found.".format(table_id))
    table_exists=False

  if table_exists is True:
    print('Performing streaming insert')
    errors = client.insert_rows_from_dataframe(table=table_id, dataframe=df1, selected_fields=db_schema)  # Make an API request.
    # The call returns one list of row errors per chunk; previously this was
    # ignored and success was reported even when rows were rejected.
    row_errors = [row_error for chunk_errors in errors for row_error in chunk_errors]
    if row_errors:
      print("Encountered errors while inserting rows: {}".format(row_errors))
      FULL_LOG_TEXT= FULL_LOG_TEXT + '\n Encountered errors while inserting rows: ' + str(row_errors) + ' \n'
      return 'Row insert failed'
  else:
    # First write: let pandas_gbq create the table from the DataFrame.
    pandas_gbq.to_gbq(df1, table_id, project_id=PROJECT_ID)

  print('\n Query added to BQ log table \n')
  FULL_LOG_TEXT= FULL_LOG_TEXT + '\n Query added to BQ log table \n'
  return 'Row added'


In [21]:
# @title Query call function

def _append_log(text):
  """Append `text` to the module-level FULL_LOG_TEXT transcript."""
  global FULL_LOG_TEXT
  FULL_LOG_TEXT = FULL_LOG_TEXT + text


def _rewrite_and_log(chat_session, question, generated_sql, table_result_joined, column_result_joined, error_msg, similar_questions):
  """Ask the retry chat for an alternative SQL, print and log it, and return it."""
  rewrite_result=rewrite_sql_chat(chat_session, question, generated_sql, table_result_joined , column_result_joined, error_msg, similar_questions)
  print('\n Rewritten SQL:')
  print(rewrite_result)
  _append_log('\n Rewritten SQL: \n')
  _append_log('\n' + rewrite_result + '\n')
  return rewrite_result


def call_gen_sql(question):
  """Answer `question` end to end: look up or generate SQL, validate, run, summarise.

  Flow:
    1. Look for a stored SQL for exactly this question; if found, run it
       (when EXECUTE_FINAL_SQL is True) and log the hit.
    2. Otherwise fetch similar examples and schema context, generate SQL, and
       loop: LLM validation, then BigQuery dry run, then execution. A failing
       step asks the chat model for a rewrite, until the rewrite count exceeds
       SQL_MAX_FIX_RETRY.
    3. If a result DataFrame was produced, summarise it in natural language.

  Every step is appended to FULL_LOG_TEXT and logged through append_2_bq.

  Args:
    question: The user's natural-language question.

  Returns:
    A tuple ("\\n All Done!", generated_sql). `generated_sql` is the last SQL
    produced or looked up, or '' when no SQL was generated because no schema
    context was found.
  """
  # Reset question log variable
  global FULL_LOG_TEXT
  FULL_LOG_TEXT=''

  # Defined up front so every path can reach the summary step and the return
  # statement; previously several paths left these names undefined.
  generated_sql=''
  final_exec_result_df=None

  _append_log('\n User Question: \n')
  _append_log('\n' + str(question) + '\n')

  # Will look into the Vector DB first and see if there is a hash match.
  # If yes, return the known good SQL.
  # If not, return 3 good examples to be used by the LLM
  search_sql_vector_by_id_return=search_sql_vector_by_id(question)
  print("search_sql_vector_by_id_return :" + search_sql_vector_by_id_return)

  if search_sql_vector_by_id_return == 'SQL Not Found in Vector DB':   ### Only go thru the loop if hash of the question is not found in Vector.

    # Look into Vector for similar queries. Similar queries will be added to the LLM prompt (few shot examples)
    similar_questions_return = search_sql_nearest_vector(question)
    _append_log('\n Found Similar Questions ... \n')
    _append_log('\n' + str(similar_questions_return) + '\n')

    unrelated_question=False
    stop_loop = False
    retry_max_count= SQL_MAX_FIX_RETRY
    retry_count=0
    chat_session=init_chat()
    table_result_joined,column_result_joined=get_tables_columns_bqvector(question)

    if len(table_result_joined) > 0 :
      generated_sql=gen_dyn_rag_sql(question,table_result_joined,column_result_joined, similar_questions_return)
      print('Generated SQL:\n' )
      print(generated_sql)
      _append_log('\n Generated SQL: ... \n')
      _append_log('\n' + generated_sql + '\n')
      # The generation prompt tells the model to emit this alias for off-topic questions.
      if 'unrelated_answer' in generated_sql :
        stop_loop=True
        unrelated_question=True
    else:
      stop_loop=True
      unrelated_question=True
      print('No ANN tables found in Vector ...')
      _append_log('\n No ANN tables found in Vector ... \n')

    while (stop_loop is False):

      ### Syntax validation via LLM block
      print('\n Will call PALM next to validate the generated SQL ... \n')
      _append_log('\n Will call PALM next to validate the generated SQL ... \n')
      valid_sql_return=validate_sql_syntax(question, generated_sql, table_result_joined , column_result_joined,similar_questions_return)
      print('Return JSON from validation: ' + valid_sql_return)
      _append_log('\n  Return JSON from validation \n')
      _append_log('\n' + valid_sql_return + '\n')

      json_syntax_result=json.loads(valid_sql_return)
      print('\n SQL Syntax Validity:' + str(json_syntax_result['valid']))
      print('\n SQL Syntax Error Description:' +str(json_syntax_result['errors']) + '\n')
      _append_log('\n SQL Syntax Validity: \n')
      _append_log('\n' + str(json_syntax_result['valid']) + '\n')
      _append_log('\n SQL Syntax Error Description: \n')
      _append_log('\n' + str(json_syntax_result['errors']) + '\n')

      if json_syntax_result['valid'] is True:   # LLM indicated the syntax is valid

        print('\n Testing code execution by performing explain plan on SQL ... \n')
        _append_log('\n Testing code execution by performing explain plan on SQL ...: \n')

        # Test hook: corrupt the first attempt so the retry path is exercised.
        if INJECT_ONE_ERROR is True and retry_count < 1:
          print('\n Injecting error on purpose to test code ... Adding ROWID at the end of the string\n')
          _append_log('\n Injecting error on purpose to test code ... Adding ROWID at the end of the string \n')
          generated_sql=generated_sql + ' ROWID'

        sql_plan_test_result=test_sql_plan_execution(generated_sql) # Calling explain plan
        _append_log('\n Calling explain plan on SQL ...: \n')
        _append_log('\n' + str(sql_plan_test_result) + '\n')

        if sql_plan_test_result == 'Execution Plan OK':  # Explain plan is OK

          stop_loop = True
          _append_log('\n Execution plan came back OK ...: \n')

          if EXECUTE_FINAL_SQL is True:
            final_exec_result_df=execute_final_sql(generated_sql)
            print('\n Question: ' + question + '\n')
            print('\n Final SQL Execution Result: \n')
            _append_log('\n Final SQL Execution Result ... Question: \n')
            _append_log('\n' + question + '\n')
            print(final_exec_result_df)
            _append_log('\n' + str(final_exec_result_df) + '\n')

            if AUTO_ADD_KNOWNGOOD_SQL is True and len(final_exec_result_df) >= 1:  #### Adding to the Known Good SQL Vector DB
              first_cell=str(final_exec_result_df.iloc[0,0])
              if "ORA-" not in first_cell:
                print('\n Adding Known Good SQL to Vector DB ... \n')
                _append_log('\n Adding Known Good SQL to Vector DB ... \n')
                # The insert itself is left disabled, as before:
                # add_vector_sql_collection_return=add_vector_sql_collection( question, generated_sql)
              else:
                ### The result carries an error marker: discard it and retry
                stop_loop = False
                final_exec_result_df=None
                generated_sql=_rewrite_and_log(chat_session, question, generated_sql, table_result_joined, column_result_joined, first_cell, similar_questions_return)
                retry_count+=1

          else:  # Do not execute final SQL
            print("Not executing final SQL since EXECUTE_FINAL_SQL variable is False\n ")
            _append_log('\n Not executing final SQL since EXECUTE_FINAL_SQL variable is False \n')

          append_2_bq(model_id, question, generated_sql, 'N', 'N', '', '')

        else:  # Failure on explain plan execution
          print("Error on explain plan execution: \n " + sql_plan_test_result)
          _append_log('\n Error on explain plan execution \n')
          _append_log('\n' + sql_plan_test_result + '\n')

          append_2_bq(model_id, question, generated_sql, 'N', 'Y', 'explain_plan_validation', sql_plan_test_result )
          ### Need to call retry
          generated_sql=_rewrite_and_log(chat_session, question, generated_sql, table_result_joined, column_result_joined, sql_plan_test_result, similar_questions_return)
          retry_count+=1

      else:  # syntax validation returned False
        print('Syntax Error')
        _append_log('\n Syntax Error found ... \n')
        append_2_bq(model_id, question, generated_sql, 'N', 'Y', 'syntax_validation', str(json_syntax_result['errors']))
        ### Need to call retry
        generated_sql=_rewrite_and_log(chat_session, question, generated_sql, table_result_joined, column_result_joined, str(json_syntax_result['errors']), similar_questions_return)
        retry_count+=1

      if retry_count > retry_max_count:
        stop_loop = True

    # After the while is completed
    if retry_count > retry_max_count:
      print('\n Oopss!!! Could not find a good SQL. This is the best I came up with !!!!! \n')
      print(generated_sql)
      _append_log('\n Oopss!!! Could not find a good SQL. This is the best I came up with !!!!! \n')
      _append_log('\n' + generated_sql + '\n')

    # If query is unrelated to the dataset
    if unrelated_question is True:
      print('\n Question cannot be answered using this dataset! \n')
      _append_log('\n Question cannot be answered using this dataset! \n')
      append_2_bq(model_id, question, 'Question cannot be answered using this dataset!', 'N', 'N', 'unrelated_question', '')

  else:   ## Found the record on vector id
    # The stored SQL is what gets executed, summarised and returned.
    generated_sql=search_sql_vector_by_id_return
    if EXECUTE_FINAL_SQL is True:
      final_exec_result_df=execute_final_sql(generated_sql)
      print('\n Question: ' + question + '\n')
      print('\n Final SQL Execution Result: \n')
      print(final_exec_result_df)
      _append_log('\n Final SQL Execution Result ... Question: \n')
      _append_log('\n' + question + '\n')
      _append_log('\n' + str(final_exec_result_df) + '\n')

    else:  # Do not execute final SQL
      print("Not executing final SQL since EXECUTE_FINAL_SQL variable is False\n ")
      _append_log('\n Not executing final SQL since EXECUTE_FINAL_SQL variable is False \n')
    print('will call append to bq next')
    _append_log('\n will call append to bq next \n')
    append_2_bq(model_id, question, generated_sql, 'Y', 'N', '', '')

  # Summarise only when a query actually produced a result to summarise.
  if final_exec_result_df is not None:
    final_answer=summarize_results(question, generated_sql, final_exec_result_df)
    print('\n Final Summarized Answer: \n')
    print(final_answer)
  else:
    print('\n No query result available to summarize. \n')

  _append_log('\n All Done! \n')
  return "\n All Done!", generated_sql

# Ask here

In [24]:
# Sample questions; keep exactly one `question` assignment active.
# question="Display the top 5 users by the  most total number of orders. Include the first name and the number of orders."
question = "What are the customer details of Abhinav Saxena who is of age 27?"
# question = "Show me summary analytics for Abhinav Saxena who is age 27"
# Time the whole pipeline; call_gen_sql returns (status message, generated SQL).
start = time.time()
call_gen_result, generated_sql=call_gen_sql(question)
end = time.time()
print(call_gen_result)
print('\n Entire flow (including SQL execution) was executed in '+ str(end - start) + ' seconds') # time in seconds

Downloading: |[32m          [0m|
search_sql_vector_by_id_return :SQL Not Found in Vector DB
Downloading: |[32m          [0m|

 No record near the query was found in the Vector DB. 


 Initializing code chat model ...
Downloading: 100%|[32m██████████[0m|


GenericGBQException: Reason: 400 Array inputs are not equal in length; error in ML.DISTANCE expression

Location: asia-southeast1
Job ID: 3f489fbd-b756-41d0-a0b6-e4312b4e6b5c
