# GenAI Chatbot


Using HeatWave GenAI

## 1. Getting Everything Ready to Use the Database

This section installs the required tools, imports them, and sets up the connection used to talk to the database.

### a. Install All The Required Libraries

The latest version of the HeatWave AutoML Classification Notebook uses Python 3.11.9. Matching that Python version is strongly recommended to avoid errors when executing the cells.

In [1]:
# Uncomment and run the following lines if you need to install the libraries
# !pip install sshtunnel pandas seaborn matplotlib mysql-connector-python

### b. Import Libraries and Configure Environment

Loads required libraries and sets display, plot, and styling options to improve data visualization and readability throughout the notebook.

In [2]:
# %reload_ext autotime

import re
import sshtunnel
import mysql.connector
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import ast

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 1000)
plt.rcParams['figure.figsize'] = [30, 15]
plt.rcParams['font.size'] = 15
sns.set(color_codes=True)
sns.set(font_scale=1.5)
sns.set_palette("bright")
sns.set_style("whitegrid")

### c. Connection to HeatWave Database Service

This section prepares the connection to your database by saving the necessary addresses and login details.

In [3]:
from getpass import getpass
BASTION_IP = '170.9.234.39'
BASTION_USER = 'opc'
DBSYSTEM_IP = '10.0.1.54'
DBSYSTEM_USER = 'admin'
# Prompt for the password at run time instead of hard-coding it in the notebook
DBSYSTEM_PASSWORD = getpass("Enter your DB password: ")
DBSYSTEM_SCHEMA = 'airportdb'
Model_ID = 'meta.llama-3.1-405b-instruct'
DEBUG = False

### d. Function to start the SSH Tunnel

Sets up the SSH tunnel through the Bastion node to the HeatWave DBSystem, opens the MySQL connection, and defines a helper that returns query results as a DataFrame.

In [4]:
tunnel = sshtunnel.SSHTunnelForwarder((BASTION_IP,22), ssh_username=BASTION_USER, ssh_pkey='/Users/oscarden/Documents/SandBox/Ashburn6/ssh-key-2024-01-02.key', remote_bind_address=(DBSYSTEM_IP, 3306))
tunnel.start()
mydb = mysql.connector.connect(host="127.0.0.1", port=tunnel.local_bind_port, user=DBSYSTEM_USER, password=DBSYSTEM_PASSWORD, database=DBSYSTEM_SCHEMA, allow_local_infile=True, use_pure=True,autocommit=True)
mycursor = mydb.cursor()

def execute_sql(sql):
    if DEBUG:
        print(f'Running {sql}')
    mycursor.execute(sql)
    return pd.DataFrame(mycursor.fetchall(), columns=mycursor.column_names)

execute_sql(f"""SELECT version()""")

| | version() |
|---|---|
| 0 | 9.3.0-cloud |


### a. View the Structure of All Tables in the Database

Asks the database to list all tables and their columns, so we can understand how the data is stored.

In [5]:
# Extract schema information from the database
query = f"""
SELECT
    TABLE_NAME,
    COLUMN_NAME,
    COLUMN_TYPE,
    IS_NULLABLE,
    COLUMN_KEY,
    COLUMN_COMMENT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = '{DBSYSTEM_SCHEMA}'
ORDER BY TABLE_NAME, ORDINAL_POSITION;
"""

db_schema_info = execute_sql(query)

# Preview the first few rows of the schema info
db_schema_info.head(50)

| | TABLE_NAME | COLUMN_NAME | COLUMN_TYPE | IS_NULLABLE | COLUMN_KEY | COLUMN_COMMENT |
|---|---|---|---|---|---|---|
| 0 | airline | airline_id | smallint | NO | PRI | Unique identifier for each airline. |
| 1 | airline | iata | char(2) | YES | UNI | Two-character IATA code assigned to the airline, used globally for identification. |
| 2 | airline | airlinename | varchar(30) | YES | | The full name of the airline. |
| 3 | airline | base_airport | smallint | YES | MUL | ID of the base airport for the airline, referring to the primary operational hub. |
| 4 | airplane | airplane_id | int | NO | PRI | Unique identifier for each airplane. This is the primary key and is auto-incremented. |
| 5 | airplane | capacity | mediumint unsigned | NO | | Maximum number of passengers that the airplane can accommodate. |
| 6 | airplane | type_id | int | NO | MUL | Identifier for the airplane model/type. This is a foreign key referencing the airplane_type table. |
| 7 | airplane | airline_id | int | NO | | Identifier of the airline that owns or operates the airplane. This is a foreign key referencing the airline table. |
| 8 | airplane_type | type_id | int | NO | PRI | Unique identifier for each airplane type or model. |
| 9 | airplane_type | identifier | varchar(50) | YES | MUL | Model identifier or code for the airplane type. |


### b. Clean Up AI-Generated SQL Code

Strips wrapping quotes, the JSON {"text": "..."} envelope, escape sequences, and Markdown code fences from the SQL text returned by the model.

In [6]:
def extract_clean_sql(raw_response):
    # Remove leading and trailing quotes if present
    if raw_response.startswith("'") and raw_response.endswith("'"):
        raw_response = raw_response[1:-1]

    # Try to parse JSON if wrapped in {"text": "..."}
    try:
        parsed = json.loads(raw_response)
        # Only a JSON object carries a "text" field; keep any other payload verbatim
        text = parsed.get("text", "") if isinstance(parsed, dict) else raw_response
    except json.JSONDecodeError:
        text = raw_response

    # Replace escaped newlines and quotes, then drop any remaining backslashes
    cleaned = text.replace('\\n', '\n').replace('\\"', '"').replace('\\', '').strip()

    # Remove ```sql and ``` if present
    if cleaned.startswith('```sql'):
        cleaned = cleaned[6:]  # Remove the first 6 characters: ```sql
    if cleaned.startswith('```'):
        cleaned = cleaned[3:]  # Remove any lingering ```
    if cleaned.endswith('```'):
        cleaned = cleaned[:-3]  # Remove ending ```

    return cleaned.strip()
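
The fence checks above are case-sensitive and only recognize the `sql` tag. As a minimal sketch of a more tolerant alternative, the hypothetical helper `strip_code_fence` (an illustration, not part of the notebook) uses a regular expression to unwrap a fenced block with any language tag, provided the opening fence sits on its own line, and otherwise returns the text unchanged.

```python
import re

# Hypothetical helper, not part of the original notebook.
# Accepts an optional language tag after the opening fence, in any letter case.
_FENCE_RE = re.compile(r"^\s*```[A-Za-z0-9_+-]*\s*\n(.*?)\n?\s*```\s*$", re.DOTALL)

def strip_code_fence(text: str) -> str:
    """Return the body of a fenced code block, or the stripped text if unfenced."""
    match = _FENCE_RE.match(text)
    if match:
        return match.group(1).strip()
    return text.strip()

print(strip_code_fence("```SQL\nSELECT COUNT(*) FROM airline;\n```"))
print(strip_code_fence("  SELECT 1;  "))
```

A single regular expression keeps the opening and closing fence handling in one place, at the cost of being slightly harder to read than the prefix and suffix checks.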

Translates the user's input into English by calling sys.ML_GENERATE with a translation preamble. The structure and logic mirror call_ml_generate().

In [7]:
def translate_to_english(user_input, user_language, model_id=Model_ID):

    # Build preamble
    preamble = (
        f"You are a professional translator. Translate the following text into English, "
        f"keeping the meaning intact. The original language is '{user_language}'. "
        "Return only the English translation without any additional explanations or markdown."
    )

    # Combine preamble and user input
    combined_question = f"{preamble.strip()}\n\n{user_input.strip()}"

    # Escape backslashes first, then single quotes, so the text stays a valid SQL string literal
    combined_question_escaped = combined_question.replace("\\", "\\\\").replace("'", "\\'")

    # Build the SQL to send
    sql = f"""
    SELECT sys.ML_GENERATE(
        '{combined_question_escaped}',
        JSON_OBJECT(
            'task', 'generation',
            'model_id', '{model_id}',
            'language', 'en',  -- Always English output after translation
            'max_tokens', 4000
        )
    ) AS response;
    """

    if DEBUG:
        print("Executing ML_GENERATE to translate to English:")
        print(f"- Original Language: {user_language}")
        print(f"- Text to Translate:\n{user_input}")

    # Execute translation
    mycursor.execute(sql)
    result = mycursor.fetchall()
    translated_text = result[0][0]

    # Clean up the result
    return extract_clean_sql(translated_text)
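
Escaping by hand is easy to get wrong when a prompt contains backslashes or other special characters. As a sketch of an alternative, the hypothetical helper `build_ml_generate_call` below (an assumption, not part of the notebook) returns a statement with `%s` placeholders plus a parameter tuple, so that `mycursor.execute(sql, params)` lets the connector do the escaping. The option names mirror the ones already used in this notebook, and `example-model-id` is a placeholder value.

```python
def build_ml_generate_call(prompt, model_id, language="en", context=None, max_tokens=4000):
    """Build a parameterized sys.ML_GENERATE call (hypothetical helper)."""
    # Option keys are fixed literals; only their values are bound as parameters.
    option_keys = ["task", "model_id", "language"]
    option_values = ["generation", model_id, language]
    if context is not None:
        option_keys.append("context")
        option_values.append(context)
    option_keys.append("max_tokens")
    option_values.append(int(max_tokens))

    pairs = ", ".join(f"'{key}', %s" for key in option_keys)
    sql = f"SELECT sys.ML_GENERATE(%s, JSON_OBJECT({pairs})) AS response"
    params = (prompt, *option_values)
    return sql, params

# The prompt contains both a quote and a backslash; neither needs manual escaping.
sql, params = build_ml_generate_call("It's a test \\ prompt", "example-model-id")
print(sql)
print(params)
```

In the notebook, the returned pair would be passed as `mycursor.execute(sql, params)` in place of the f-string query.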

Converts the user's question into a SQL query based on the database schema context. The question is translated first only when needed (user_language != 'en').

In [8]:
def call_ml_generate(full_question, user_language='en', model_id=Model_ID):

    # Step 1: Translate only if user_language is not English
    if user_language.lower() != 'en':
        translated_question = translate_to_english(full_question, user_language=user_language, model_id=model_id)
    else:
        translated_question = full_question

    # Step 2: Build SQL generation based on the (translated) question
    schema_name = DBSYSTEM_SCHEMA

    preamble = (
        f"You are an expert in MySQL. Convert the following question into a SQL query that retrieves data "
        f"from the '{schema_name}' database. The schema of '{schema_name}' is provided as context. "
        "Avoid using information_schema or metadata queries. Return only the SQL query without explanations or markdown."
    )

    combined_question = f"{preamble.strip()}\n\n{translated_question.strip()}"

    context_lines = [
        f"Table: {row['TABLE_NAME']}, Column: {row['COLUMN_NAME']}, Type: {row['COLUMN_TYPE']}, Nullable: {row['IS_NULLABLE']}, Key: {row['COLUMN_KEY']}, Comment: {row['COLUMN_COMMENT']}"
        for _, row in db_schema_info.iterrows()
    ]
    context_text = "\n".join(context_lines)

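    # Escape backslashes first, then single quotes, so both values stay valid SQL string literals
    combined_question_escaped = combined_question.replace("\\", "\\\\").replace("'", "\\'")
    context_text_escaped = context_text.replace("\\", "\\\\").replace("'", "\\'")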

    sql = f"""
    SELECT sys.ML_GENERATE(
        '{combined_question_escaped}',
        JSON_OBJECT(
            'task', 'generation',
            'model_id', '{model_id}',
            'language', 'en',
            'context', '{context_text_escaped}',
            'max_tokens', 4000
        )
    ) AS response;
    """

    if DEBUG:
        print("Executing ML_GENERATE with:")
        print(f"- User Language: {user_language}")
        print(f"- Final Question to Model:\n{translated_question}")
        print(f"- Context preview:\n" + "\n".join(context_lines[:10]) + "\n...")

    mycursor.execute(sql)
    result = mycursor.fetchall()
    return result[0][0]
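
The context string built above repeats the table name and every field label on each line. As a hedged sketch, the hypothetical function `build_compact_context` groups columns by table instead. It takes plain dictionaries, such as those produced by `db_schema_info.to_dict('records')`, so it can be tried without a database connection. Whether a more compact context improves the generated SQL is an assumption to verify against your model.

```python
from collections import defaultdict

def build_compact_context(rows):
    """Group column descriptions by table (hypothetical alternative layout)."""
    tables = defaultdict(list)
    for row in rows:
        parts = [f"{row['COLUMN_NAME']} {row['COLUMN_TYPE']}"]
        if row.get("COLUMN_KEY"):
            parts.append(f"key={row['COLUMN_KEY']}")
        if row.get("COLUMN_COMMENT"):
            parts.append(f"-- {row['COLUMN_COMMENT']}")
        tables[row["TABLE_NAME"]].append("  " + " ".join(parts))
    # One header line per table, followed by its indented columns.
    return "\n".join(
        f"Table {name}:\n" + "\n".join(columns) for name, columns in tables.items()
    )

# Two rows taken from the schema preview shown earlier in this notebook.
sample_rows = [
    {"TABLE_NAME": "airline", "COLUMN_NAME": "airline_id", "COLUMN_TYPE": "smallint",
     "COLUMN_KEY": "PRI", "COLUMN_COMMENT": "Unique identifier for each airline."},
    {"TABLE_NAME": "airline", "COLUMN_NAME": "airlinename", "COLUMN_TYPE": "varchar(30)",
     "COLUMN_KEY": "", "COLUMN_COMMENT": "The full name of the airline."},
]
print(build_compact_context(sample_rows))
```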

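Executes the generated SQL after rejecting any statement that is not read-only. If MySQL returns an error, the error message is sent back to the model to request a corrected query, for up to max_attempts tries.

In [9]: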
def run_generated_sql_with_repair(raw_response, translated_prompt, max_attempts=10):
    attempt = 0
    current_response = raw_response

    restricted_patterns = re.compile(r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|REPLACE)\b', re.IGNORECASE)

    while attempt < max_attempts:
        attempt += 1

        sql_query = extract_clean_sql(current_response).strip()

        if DEBUG:
            print(f"\nAttempt {attempt}: Running cleaned SQL:")
            print(sql_query)

        if restricted_patterns.search(sql_query):
            return f"❌ Restricted operation detected in SQL:\n{sql_query}"

        try:
            queries = [q.strip() for q in sql_query.split(';') if q.strip()]
            results = []
            columns_set = set()

            for q in queries:
                mycursor.execute(q)

                try:
                    result = mycursor.fetchall()
                    columns = tuple(mycursor.column_names)
                    df = pd.DataFrame(result, columns=columns)
                    results.append((columns, df))
                    columns_set.add(columns)
                except mysql.connector.errors.InterfaceError:
                    results.append(("no_result", f"✅ Query executed successfully (no result set): {q}"))

                while mycursor.nextset():
                    pass

            dataframes = [df for col, df in results if isinstance(df, pd.DataFrame)]
            if len(columns_set) == 1 and len(dataframes) == len(queries):
                return pd.concat(dataframes, ignore_index=True)
            else:
                return {f"Query {i+1}": df for i, (_, df) in enumerate(results)}

        except mysql.connector.Error as err:
            print(f"❌ MySQL Error: {err}")

            # Self-healing prompt
            repair_prompt = (
                f"The original user intent was:\n\n{translated_prompt}\n\n"
                f"The following SQL query was generated but caused an error:\n\n{sql_query}\n\n"
                f"The error message was:\n\n{str(err)}\n\n"
                f"Please regenerate a corrected SQL query that fulfills the user's original intent, "
                f"works against the '{DBSYSTEM_SCHEMA}' database, and only uses SELECT statements. "
                f"Do not include INSERT, UPDATE, DELETE, DROP, ALTER, or any other DDL/DML statements. "
                f"Return only the corrected SQL query without explanations or markdown formatting."
            )

            while mycursor.nextset():
                pass

            current_response = call_ml_generate(repair_prompt)

    return "❌ Failed to generate a valid, executable SQL query after multiple attempts."
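
The keyword filter above scans the whole query text, so it also rejects harmless queries that call the `REPLACE()` string function or mention a blocked word inside a string literal. A minimal sketch of a narrower check follows, under the assumption that only statements starting with `SELECT` or `WITH` should run. The hypothetical `is_read_only` blanks out quoted literals first, then inspects each statement. It is a heuristic rather than a SQL parser, and a read-only database account remains the more reliable safeguard.

```python
import re

# Quoted literals with backslash escapes or doubled quotes.
_SINGLE = r"'(?:[^'\\]|\\.|'')*'"
_DOUBLE = r'"(?:[^"\\]|\\.|"")*"'
_LITERAL_RE = re.compile(f"{_SINGLE}|{_DOUBLE}", re.DOTALL)

# Write keywords; REPLACE is allowed when used as a function call, REPLACE(...).
_WRITE_RE = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE)\b|\bREPLACE\b(?!\s*\()",
    re.IGNORECASE,
)

def is_read_only(sql: str) -> bool:
    """Heuristic check that every statement is a SELECT (hypothetical helper)."""
    # Blank out literals so their content (keywords, semicolons) is ignored.
    stripped = _LITERAL_RE.sub("''", sql)
    statements = [s.strip() for s in stripped.split(";") if s.strip()]
    if not statements:
        return False
    for statement in statements:
        first_word = statement.split(None, 1)[0].upper()
        if first_word not in ("SELECT", "WITH"):
            return False
        if _WRITE_RE.search(statement):
            return False
    return True

print(is_read_only("SELECT REPLACE(airlinename, ' ', '_') FROM airline;"))
print(is_read_only("SELECT 1; DELETE FROM airline"))
```

The check stays conservative: a query that starts with a parenthesis, or that uses a column literally named `update`, is rejected even though it may be safe.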

Turns the SQL result (final_result, a DataFrame or a dict of DataFrames) into a natural-language answer using sys.ML_GENERATE, passing the result as context and answering in the user's language.

In [10]:
def generate_natural_language_answer(user_prompt, final_result, user_language, model_id=Model_ID):
    # Step 1: Prepare the context text from final_result
    if isinstance(final_result, pd.DataFrame):
        context_text = final_result.to_string(index=False)
    elif isinstance(final_result, dict):
        context_parts = []
        for label, df in final_result.items():
            context_parts.append(f"{label}:\n{df.to_string(index=False)}")
        context_text = "\n\n".join(context_parts)
    else:
        context_text = str(final_result)

    # Escape backslashes first, then single quotes, in the user prompt and context
    user_prompt_escaped = user_prompt.replace("\\", "\\\\").replace("'", "\\'")
    context_text_escaped = context_text.replace("\\", "\\\\").replace("'", "\\'")

    # Step 2: Build the SQL to ask ML_GENERATE
    sql = f"""
    SELECT sys.ML_GENERATE(
        'Respond to the following question {user_prompt_escaped} making use of the context provided',
        JSON_OBJECT(
            'task', 'generation',
            'model_id', '{model_id}',
            'language', '{user_language}',
            'context', '{context_text_escaped}',
            'max_tokens', 4000
        )
    ) AS response;
    """

    if DEBUG:
        print("Executing ML_GENERATE to turn SQL result into natural language:")
        print(f"- Prompt: {user_prompt}")
        print(f"- Target Language: {user_language}")
        print(f"- Context preview:\n{context_text[:1000]}...")  # Only show first 1000 chars

    # Step 3: Execute
    mycursor.execute(sql)
    result = mycursor.fetchall()
    natural_language_response = result[0][0]

    # Step 4: Clean output
    return extract_clean_sql(natural_language_response)

In [11]:
def full_pipeline(user_question, user_language, model_id=Model_ID):
    """
    Complete pipeline:
    1. Translate user question to English if needed.
    2. Generate SQL query.
    3. Execute SQL query with self-repair if needed.
    4. If result has <= 24 rows, turn into NL. Otherwise, return structured result.
    """
    # Step 1: Translate the user question if needed
    if user_language.lower() != 'en':
        translated_question = translate_to_english(user_question, user_language=user_language, model_id=model_id)
    else:
        translated_question = user_question

    # Step 2: Generate SQL from translated question
    response = call_ml_generate(translated_question, model_id=model_id)

    # Step 3: Execute SQL with retry/self-repair
    final_result = run_generated_sql_with_repair(response, translated_prompt=translated_question)

    # Step 4: Check result size
    if isinstance(final_result, pd.DataFrame):
        num_rows = len(final_result)
    elif isinstance(final_result, dict):
        num_rows = sum(len(df) for df in final_result.values() if isinstance(df, pd.DataFrame))
    else:
        # A string result is an error or restriction message; return it unchanged
        # rather than asking the model to answer from it
        return final_result

    if DEBUG:
        print(f"Result has {num_rows} rows.")

    # Step 5: Decide based on size
    if num_rows > 24:
        return final_result  # Structured output
    else:
        return generate_natural_language_answer(
            translated_question,
            final_result,
            user_language=user_language,
            model_id=model_id
        )

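Runs the full pipeline on a question asked in Spanish and displays the result according to its type.

In [13]: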
user_question = "Cuantas aerolineas existen?"
user_language = 'es'

result = full_pipeline(user_question, user_language=user_language)

# Display logic based on return type
if isinstance(result, pd.DataFrame):
    display(result)
elif isinstance(result, dict):
    for label, df in result.items():
        print(f"\n🔹 {label}")
        display(df)
else:
    print(result)  # Natural language answer

Existen 113 aerolíneas.
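
The notebook leaves the cursor, the connection, and the SSH tunnel open. A minimal cleanup sketch follows, assuming the `mycursor`, `mydb`, and `tunnel` objects created in section 1.d. The helper name `release_resources` is hypothetical. It relies only on the `close()` methods of the cursor and connection and on the `stop()` method of the tunnel.

```python
def release_resources(cursor=None, connection=None, tunnel=None):
    """Close the cursor and connection, then stop the SSH tunnel (hypothetical helper)."""
    errors = []
    # Release in reverse order of creation: cursor, connection, tunnel.
    for resource, method_name in ((cursor, "close"), (connection, "close"), (tunnel, "stop")):
        if resource is None:
            continue
        try:
            getattr(resource, method_name)()
        except Exception as exc:  # keep going so later resources are still released
            errors.append(f"{method_name}() failed: {exc}")
    return errors

# In the notebook this would be called as:
# release_resources(cursor=mycursor, connection=mydb, tunnel=tunnel)
```

Collecting errors instead of raising on the first failure means a broken connection does not prevent the tunnel from being stopped.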
