In [4]:
import os
import pickle
import shutil
import pandas as pd
import gradio as gr
from config import PATHS
from secret_keys import *
import gradio as gr
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
    exc,
)
from smolagents import CodeAgent, InferenceClientModel, tool
# SQLAlchemy engine for the SQLite database file agentDB.db; insert_rows uses it as its default engine.
engine = create_engine("sqlite:///agentDB.db")
# Shared MetaData registry: create_dynamic_table registers its Table here and update_table calls create_all on it.
metadata_obj = MetaData()

  from .autonotebook import tqdm as notebook_tqdm


# Define Custom Exception Classes

In [5]:
class ColumnDataTypeError(Exception):
    """Raised when the data-type count and the column counts disagree.

    The three counts are stored on the instance unchanged.
    NOTE(review): presumably they are the number of data types, the number of
    uploaded columns and the length of the dynamic table's columns mapping —
    confirm against the code that raises this.
    """

    def __init__(
        self,
        num_col,
        NUM_COL,
        len_columns,
        message="Number of table data types must match number of uploaded columns and the length of the columns variable in dynamic table",
    ):
        # Hand the text to Exception first so str(error) yields the message.
        super().__init__(message)
        self.num_col, self.NUM_COL, self.len_columns = num_col, NUM_COL, len_columns

class TableIntializationError(Exception):
    """Raised when a table that is expected to exist is missing.

    Args:
        TABLE_NAME: Name of the missing table; stored on the instance.
        message: Optional explicit error text. When omitted, the message is
            built from ``TABLE_NAME``.
    """

    def __init__(self, TABLE_NAME, message=None):
        self.TABLE_NAME = TABLE_NAME
        # Build the default per instance. A default f-string in the signature
        # is evaluated once, at definition time, and would always name the
        # configured table instead of the one passed in.
        if message is None:
            message = f"Table '{TABLE_NAME}' does not exist."
        super().__init__(message)

# Load Table Data

In [6]:
def load_rows():
    """Read the pickled column dictionary and pivot it into row dictionaries.

    Returns:
        A ``(col_names, rows, num_cols)`` tuple: the dictionary keys as a
        list, one ``{column: value}`` dict per row index, and the number of
        columns.

    Raises:
        ValueError: If the unpickled dictionary has no keys.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # pickle file was produced by this notebook and is trusted.
    with open(PATHS.PKL_FILE_PATH, "rb") as pkl_file:
        table_data = pickle.load(pkl_file)

    col_names = list(table_data.keys())
    num_cols = len(col_names)
    if num_cols == 0:
        raise ValueError("The dictionary is empty.")

    # The first column's length determines how many rows are built.
    row_count = len(table_data[col_names[0]])
    rows = [
        {name: table_data[name][idx] for name in col_names}
        for idx in range(row_count)
    ]
    return col_names, rows, num_cols

#   Setup SQL Environment
1. SQL engine initialization

In [7]:
def insert_rows(rows, table, engine=engine):
    """Insert every row dict into ``table`` inside a single transaction.

    Args:
        rows: Iterable of ``{column_name: value}`` dicts, one per row.
        table: The SQLAlchemy ``Table`` to insert into.
        engine: Engine to connect with; defaults to the module-level engine.

    Raises:
        sqlalchemy.exc.SQLAlchemyError: Propagated from statement compilation
            or execution. In that case none of the rows are kept.
    """
    # One transaction for the whole batch: it commits when the block exits
    # normally and rolls back if any insert raises, so a failure part-way
    # through cannot leave a partial insert behind (and we avoid one commit
    # per row).
    with engine.begin() as connection:
        for row in rows:
            connection.execute(insert(table).values(**row))

2. Dynamic table

In [19]:
def create_dynamic_table(table_name, columns):
    """Define a table on the shared metadata from a name -> type mapping.

    Args:
        table_name: Name given to the table.
        columns: Mapping of column name to SQLAlchemy column type; one column
            is declared per item, in mapping order.

    Returns:
        The ``Table`` object, which always starts with an integer primary-key
        column named ``id``.
    """
    id_column = Column('id', Integer, primary_key=True)
    user_columns = [
        Column(col_name, col_type) for col_name, col_type in columns.items()
    ]
    # extend_existing=True lets this be called again for a table name that is
    # already registered on metadata_obj.
    return Table(
        table_name,
        metadata_obj,
        id_column,
        *user_columns,
        extend_existing=True,
    )
    
def update_table(column_type):
    """Create the configured table from the pickled upload and insert its rows.

    Args:
        column_type: Comma separated data type names, one per uploaded column
            and in column order. Accepted names are ``String``, ``Integer``
            and ``Float``; surrounding whitespace and letter case are ignored.

    Returns:
        A feedback string: the success message, or a description of what went
        wrong (no usable upload, wrong number of types, unknown type name, or
        a database error raised while inserting).
    """
    type_by_name = {"string": String, "integer": Integer, "float": Float}

    # Load rows for the table; report a missing/empty upload as feedback
    # instead of letting the exception escape the UI callback.
    try:
        col_names, rows, _ = load_rows()
    except (FileNotFoundError, ValueError) as e:
        return f"{e}. Upload a csv before submitting the column data types."

    # Split the user's text into one type name per column.
    type_names = (column_type or "").split(",")
    if len(type_names) != len(col_names):
        return f"Number of data types ({len(type_names)}) does not match number of columns ({len(col_names)})."

    # Map each column name to its SQLAlchemy type, by position.
    columns = {}
    for col_name, type_name in zip(col_names, type_names):
        sql_type = type_by_name.get(type_name.strip().lower())
        if sql_type is None:
            return f"A data type you entered was invalid."
        columns[col_name] = sql_type
    print(columns)

    dynamic_table = create_dynamic_table(PATHS.TABLE_NAME, columns)
    metadata_obj.create_all(engine)

    try:
        insert_rows(rows, dynamic_table)
    except exc.CompileError as e:
        return (f"{e}.")
    except exc.OperationalError as e:
        return (f"{e}. agentDB has already had it's schema defined.")
    except exc.SQLAlchemyError as e:
        # Any other database failure is also shown in the feedback box.
        return f"{e}."
    return "Row insertion succesful"
# column_type = "Integer, Integer, Integer, Integer, Integer, String, String, String, String, String, Integer, String, String"
# update_table(column_type)

#   Agent Setup
1.  Make SQL table retrievable by a tool

In [22]:
def table_description():
    """Describe the configured table's columns as a bulleted text list.

    Returns:
        ``"Columns:"`` followed by one `` - name: type`` line per column, or a
        message starting with ``"NoSuchTableError:"`` when the inspector
        cannot find the table.
    """
    inspector = inspect(engine)
    try:
        column_lines = [
            f" - {col['name']}: {col['type']}"
            for col in inspector.get_columns(PATHS.TABLE_NAME)
        ]
    except exc.NoSuchTableError as e:
        return f"NoSuchTableError: {e}. The referenced table does not exist."
    return "Columns:\n" + "\n".join(column_lines)
table_description()

'Columns:\n - id: INTEGER\n - price: INTEGER\n - area: INTEGER\n - bedrooms: INTEGER\n - bathrooms: INTEGER\n - stories: INTEGER\n - mainroad: VARCHAR\n - guestroom: VARCHAR\n - basement: VARCHAR\n - hotwaterheating: VARCHAR\n - airconditioning: VARCHAR\n - parking: INTEGER\n - prefarea: VARCHAR\n - furnishingstatus: VARCHAR'

2. Error checking tools

In [11]:
def table_check()-> str:
    """Report whether the configured table is present in the database.

    Returns:
        ``"Table '<name>' exists."`` when the inspector finds the table,
        otherwise a message that starts with ``"NoSuchTableError:"``.
    """
    try:
        if not inspect(engine).has_table(PATHS.TABLE_NAME):
            raise exc.NoSuchTableError()
        return f"Table '{PATHS.TABLE_NAME}' exists."
    except exc.NoSuchTableError as missing:
        return f"NoSuchTableError: {missing} The referenced table does not exist."

3.  SQL Tool

In [12]:
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
        The Table is named agent_table. 
        Args:
            query: The query to be performed on the table. This should always be correct SQL.
        """
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it to the agent as the tool description.
    # It hard-codes the name agent_table — confirm it matches PATHS.TABLE_NAME.
    try:
        # engine.begin() commits when the block exits normally and rolls back
        # if the statement raises, so no autocommit option is needed.
        with engine.begin() as con:
            result = con.execute(text(query))
            # A result object is always truthy, so `if not result` can never
            # detect a statement without rows; returns_rows is the real signal
            # (False for e.g. UPDATE/DROP without RETURNING).
            if not result.returns_rows:
                return "No rows found, include the `RETURNING` keyword to ensure the result object always returns rows."
            # One line per row, using the row's own string representation.
            return "".join(f"{row}\n" for row in result)
    except exc.SQLAlchemyError as e:
        return f"{e}. Include the `RETURNING` keyword to ensure the result object always returns rows."

4. Perform a Query for Debugging

In [11]:
# # sql_engine(f"UPDATE {PATHS.TABLE_NAME} SET Paid = Paid - 5 WHERE name = 'John Smith';")
# # sql_engine(f"SELECT * FROM {PATHS.TABLE_NAME} WHERE name = 'John Smith';")
# result = sql_engine(query=f"SELECT * FROM {PATHS.TABLE_NAME};")
# print(result)          

5.  Agent Initialization

In [13]:
# Model and agent initialization
def agent_setup():
    """Build a new code agent that can query the database via ``sql_engine``.

    Returns:
        A ``CodeAgent`` limited to 5 steps, backed by an
        ``InferenceClientModel`` configured for the "nebius" provider.
    """
    model_options = {
        "api_key": NEBIUS_API_KEY,
        "model_id": "Qwen/Qwen3-235B-A22B",  # alternative: Qwen/Qwen3-4B
        "provider": "nebius",
    }
    llm = InferenceClientModel(**model_options)

    # The agent's only tool is the SQL tool.
    return CodeAgent(tools=[sql_engine], model=llm, max_steps=5)

In [13]:
# sql_agent.visualize()

# Define Function For Prompt Insertion

In [14]:
def run_prompt(prompt, history):
    """Answer a chat prompt with the SQL agent, or report a missing table.

    Args:
        prompt: The user's message.
        history: Accepted but not used by this function.

    Returns:
        The agent's result for the prompt (extended with fixed instructions
        and the table description), or the table-check message plus a hint
        when the table does not exist.
    """
    schema_text = table_description()
    status = table_check()
    if "NoSuchTableError" not in status:
        # A new agent is built for every prompt.
        agent = agent_setup()
        instructions = f". Always wrap the result in relevant context and enforce the results object returning rows. Table description is as follows:{schema_text}"
        return agent.run(prompt + instructions)
    return status + " Check the table has the expected name and it is consistent."

# Define Functions For Gradio interface

In [None]:
def vote(data: gr.LikeData):
    """Print whether a response was upvoted or downvoted, with its text.

    Args:
        data: Like event payload; ``data.liked`` selects the wording and
            ``data.value["value"]`` is appended to it.
    """
    prefix = (
        "You upvoted this response: "
        if data.liked
        else "You downvoted this response: "
    )
    print(prefix + data.value["value"])

def process_file(fileobj):
    """Copy an uploaded csv next to the temp path and pickle its contents.

    Args:
        fileobj: The upload, either a filesystem path (``str`` or
            ``os.PathLike``) or an object whose ``name`` attribute holds the
            path. ``None`` (nothing uploaded) is ignored.

    Returns:
        Whatever ``csv_2_dict`` returns for the copied file, or ``None`` when
        nothing was uploaded.
    """
    if fileobj is None:
        return None

    # Resolve the source path once and use it for both the basename and the
    # copy. Previously the argument was used as a path in one call and as an
    # object with `.name` in the next, which only works for types that are both.
    if isinstance(fileobj, (str, os.PathLike)):
        source_path = os.fspath(fileobj)
    else:
        source_path = fileobj.name

    # NOTE(review): plain concatenation assumes PATHS.TEMP_PATH ends with a
    # path separator — confirm in config.
    csv_path = PATHS.TEMP_PATH + os.path.basename(source_path)
    # copy file to path
    shutil.copyfile(source_path, csv_path)
    return csv_2_dict(csv_path)

def csv_2_dict(path):
    """Convert a csv file into a pickled ``{column: [values]}`` dictionary.

    Rows containing any missing value are dropped before conversion. The
    dictionary is written to ``PATHS.PKL_FILE_PATH``; nothing is returned.

    Args:
        path: Location of the csv file to read.
    """
    table_data = pd.read_csv(path).dropna().to_dict(orient='list')
    with open(PATHS.PKL_FILE_PATH, "wb") as pkl_file:
        pickle.dump(table_data, pkl_file)

def change_insert_mode(choice):
    """Reset the configured table when the "Upload New" mode is selected.

    Args:
        choice: The selected insertion mode. Only ``"Upload New"`` has an
            effect; any other value leaves the table untouched.
    """
    if choice != "Upload New":
        return

    # Drop the table from the database if it is there.
    if "NoSuchTableError" not in table_check():
        sql_engine(f"DROP TABLE {PATHS.TABLE_NAME};")

    # Also forget the in-memory definition. create_dynamic_table uses
    # extend_existing=True, which adds columns to an already registered Table
    # but never removes old ones, so a stale definition would leak the
    # previous upload's columns into the new table.
    stale_table = metadata_obj.tables.get(PATHS.TABLE_NAME)
    if stale_table is not None:
        metadata_obj.remove(stale_table)

    
# Gradio app with two tabs: one to load a csv into the table, one to chat with the SQL agent.
with gr.Blocks() as demo:
    with gr.Tab("Table Setup"):
        # Insertion mode selector; change_insert_mode is registered on the radio's input event.
        insert_mode = gr.Radio(["Upload New", "Upload to Existing"], label="Insertion Mode", info="Warning selecting Upload New will immediately drop existing table, leaving unseleted will add to existing table.")
        insert_mode.input(fn=change_insert_mode, inputs=insert_mode, outputs=None)
        gr.Markdown("Next upload the csv:")
        # File upload form: the uploaded file is handed to process_file; no output component is shown.
        gr.Interface(
            fn=process_file,
            inputs=[
                "file",
            ],
            outputs= None,
            flagging_mode = "never"
            )
        # Column data types are typed as text; update_table's return string is shown in the feedback box.
        column_type =  gr.Textbox(label="Enter column data types (String, Integer, Float) as a comma seperated list:")
        column_type_message = gr.Textbox(label="Feedback:")
        col_type_button = gr.Button("Submit")
        col_type_button.click(update_table, inputs=column_type, outputs=[column_type_message,])
    with gr.Tab("Text2SQL Agent"):
        # Chat tab: run_prompt answers each message; vote is registered for like events on the chatbot.
        chatbot = gr.Chatbot(type="messages" , placeholder=f"<strong>Ask agent to perform a query.</strong>" )
        chatbot.like(vote, None, None)
        gr.ChatInterface(fn=run_prompt, type="messages", chatbot=chatbot)
# Start the app from the notebook cell.
demo.launch(debug = True)

* Running on local URL:  http://127.0.0.1:7864
* To create a public link, set `share=True` in `launch()`.


{'contact_email': <class 'sqlalchemy.sql.sqltypes.String'>, 'name': <class 'sqlalchemy.sql.sqltypes.String'>, 'camps': <class 'sqlalchemy.sql.sqltypes.Integer'>, 'Owed': <class 'sqlalchemy.sql.sqltypes.Float'>, 'Paid': <class 'sqlalchemy.sql.sqltypes.Float'>, 'Balance': <class 'sqlalchemy.sql.sqltypes.Float'>, 'Month': <class 'sqlalchemy.sql.sqltypes.String'>, 'number': <class 'sqlalchemy.sql.sqltypes.Integer'>}


In [None]:
# String, String, Integer, Float, Float, Float, String, Integer
# Integer, Integer, Integer, Integer, Integer, String, String, String, String, String, Integer, String, String