In [1]:
from llama_index.core import SQLDatabase
from llama_index.core.query_engine import NLSQLTableQueryEngine
from IPython.display import Markdown, display
from llama_index.core import Settings
import os

os.environ["OPENAI_API_KEY"]=''

In [2]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    Integer,
    String,
    Date,
    ForeignKey,
    insert
)
from datetime import datetime

# Function to parse string to date
def parse_date(date_str):
    return datetime.strptime(date_str, "%Y-%m-%d").date()

# Create an in-memory SQLite database engine
engine = create_engine("sqlite:///:memory:")
metadata = MetaData()  #An object that holds information about the database schema (tables, columns, etc.).

# Define the 'employee' table
employee_table = Table(
    "employee",
    metadata,
    Column("emp_id", Integer, primary_key=True),
    Column("first_name", String(40)),
    Column("last_name", String(40)),
    Column("birth_day", Date),
    Column("sex", String(1)),
    Column("salary", Integer),
    Column("super_id", Integer, ForeignKey("employee.emp_id", ondelete="SET NULL")),
    Column("branch_id", Integer, ForeignKey("branch.branch_id", ondelete="SET NULL")),
)

"""A foreign key is a database concept used to maintain relationships between tables and ensure data integrity. 
It is a column (or a set of columns) in one table that establishes a link to a column in another table (or sometimes the same table)
ON DELETE: Specifies what happens in the child table when a row in the parent table is deleted (e.g., CASCADE, SET NULL, RESTRICT).
ON UPDATE: Specifies what happens when a referenced value in the parent table is updated.
"""

# Define the 'branch' table
branch_table = Table(
    "branch",
    metadata,
    Column("branch_id", Integer, primary_key=True),
    Column("branch_name", String(40)),
    Column("mgr_id", Integer, ForeignKey("employee.emp_id", ondelete="SET NULL")),
    Column("mgr_start_date", Date),
)

# Define the 'client' table
client_table = Table(
    "client",
    metadata,
    Column("client_id", Integer, primary_key=True),
    Column("client_name", String(40)),
    Column("branch_id", Integer, ForeignKey("branch.branch_id", ondelete="SET NULL")),
)

# Define the 'works_with' table
works_with_table = Table(
    "works_with",
    metadata,
    Column("emp_id", Integer, ForeignKey("employee.emp_id", ondelete="CASCADE")),
    Column("client_id", Integer, ForeignKey("client.client_id", ondelete="CASCADE")),
    Column("total_sales", Integer),
)

# Define the 'branch_supplier' table
branch_supplier_table = Table(
    "branch_supplier",
    metadata,
    Column("branch_id", Integer, ForeignKey("branch.branch_id", ondelete="CASCADE")),
    Column("supplier_name", String(40)),
    Column("supply_type", String(40)),
)

# Create all tables
metadata.create_all(engine)

# Insert data into tables
employee_details =[
        {"emp_id": 100, "first_name": "David", "last_name": "Wallace", "birth_day": parse_date("1967-11-17"), "sex": "M", "salary": 250000, "super_id": None, "branch_id": None},
        {"emp_id": 101, "first_name": "Jan", "last_name": "Levinson", "birth_day": parse_date("1961-05-11"), "sex": "F", "salary": 110000, "super_id": 100, "branch_id": 1},
        {"emp_id": 102, "first_name": "Michael", "last_name": "Scott", "birth_day": parse_date("1964-03-15"), "sex": "M", "salary": 75000, "super_id": 100, "branch_id": None},
        {"emp_id": 103, "first_name": "Angela", "last_name": "Martin", "birth_day": parse_date("1971-06-25"), "sex": "F", "salary": 63000, "super_id": 102, "branch_id": 2},
        {"emp_id": 104, "first_name": "Kelly", "last_name": "Kapoor", "birth_day": parse_date("1980-02-05"), "sex": "F", "salary": 55000, "super_id": 102, "branch_id": 2},
        {"emp_id": 105, "first_name": "Stanley", "last_name": "Hudson", "birth_day": parse_date("1958-02-19"), "sex": "M", "salary": 69000, "super_id": 102, "branch_id": 2},
        {"emp_id": 107, "first_name": "Andy", "last_name": "Bernard", "birth_day": parse_date("1958-02-19"), "sex": "M", "salary": 69000, "super_id": 102, "branch_id": 3},
        {"emp_id": 108, "first_name": "Jim", "last_name": "Halpert", "birth_day": parse_date("1958-02-19"), "sex": "M", "salary": 69000, "super_id": 102, "branch_id": 3},
    ]

for row in employee_details:
    stmt = insert(employee_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

# Insert into branch table
branch_table_values=[
        {"branch_id": 1, "branch_name": "Corporate", "mgr_id": 100, "mgr_start_date": parse_date("2006-02-09")},
        {"branch_id": 2, "branch_name": "Scranton", "mgr_id": 102, "mgr_start_date": parse_date("1992-04-06")},
        {"branch_id": 3, "branch_name": "Stamford", "mgr_id": 106, "mgr_start_date": parse_date("1998-02-13")},
    ]

for row in branch_table_values:
    stmt = insert(branch_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)


    # Insert into client table
client_table_values=[
        {"client_id": 400, "client_name": "Dunmore Highschool", "branch_id": 2},
        {"client_id": 401, "client_name": "Lackawana Country", "branch_id": 2},
        {"client_id": 402, "client_name": "FedEx", "branch_id": 3},
        {"client_id": 403, "client_name": "John Daly Law, LLC", "branch_id": 3},
        {"client_id": 404, "client_name": "Scranton Whitepages", "branch_id": 2},
        {"client_id": 405, "client_name": "Times Newspaper", "branch_id": 3},
    ]
for row in client_table_values:
    stmt = insert(client_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)


    # Insert into works_with table
works_with_table_values = [
        {"emp_id": 105, "client_id": 400, "total_sales": 55000},
        {"emp_id": 102, "client_id": 401, "total_sales": 267000},
        {"emp_id": 108, "client_id": 402, "total_sales": 22500},
        {"emp_id": 107, "client_id": 403, "total_sales": 5000},
        {"emp_id": 108, "client_id": 403, "total_sales": 12000},
        {"emp_id": 105, "client_id": 404, "total_sales": 33000},
        {"emp_id": 107, "client_id": 405, "total_sales": 26000},
        {"emp_id": 102, "client_id": 406, "total_sales": 15000},
        {"emp_id": 105, "client_id": 406, "total_sales": 130000},
  ]
for row in works_with_table_values:
    stmt = insert(works_with_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)


    # Insert into branch_supplier table
branch_supplier_table_values=[
        {"branch_id": 2, "supplier_name": "Hammer Mill", "supply_type": "Paper"},
        {"branch_id": 2, "supplier_name": "Uni-ball", "supply_type": "Writing Utensils"},
        {"branch_id": 3, "supplier_name": "Patriot Paper", "supply_type": "Paper"},
        {"branch_id": 2, "supplier_name": "J.T. Forms & Labels", "supply_type": "Custom Forms"},
        {"branch_id": 3, "supplier_name": "Uni-ball", "supply_type": "Writing Utensils"},
        {"branch_id": 3, "supplier_name": "Hammer Mill", "supply_type": "Paper"},
        {"branch_id": 3, "supplier_name": "Stamford Lables", "supply_type": "Custom Forms"},
    ]
for row in branch_supplier_table_values:
    stmt = insert(branch_supplier_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [3]:
import pandas as pd
from sqlalchemy import select

def display_table_as_dataframe(table, connection):
    stmt = select(table)  # Create a SELECT statement for the table
    result = connection.execute(stmt)  # Execute the statement
    df = pd.DataFrame(result.fetchall(), columns=result.keys())  # Convert the result to a pandas DataFrame
    print(df)  # Display the DataFrame

# Display the tables
with engine.connect() as conn:
    print("Employee Table:")
    display_table_as_dataframe(employee_table, conn)
    print("\nBranch Table:")
    display_table_as_dataframe(branch_table, conn)
    print("\nClient Table:")
    display_table_as_dataframe(client_table, conn)


Employee Table:
   emp_id first_name last_name   birth_day sex  salary  super_id  branch_id
0     100      David   Wallace  1967-11-17   M  250000       NaN        NaN
1     101        Jan  Levinson  1961-05-11   F  110000     100.0        1.0
2     102    Michael     Scott  1964-03-15   M   75000     100.0        NaN
3     103     Angela    Martin  1971-06-25   F   63000     102.0        2.0
4     104      Kelly    Kapoor  1980-02-05   F   55000     102.0        2.0
5     105    Stanley    Hudson  1958-02-19   M   69000     102.0        2.0
6     107       Andy   Bernard  1958-02-19   M   69000     102.0        3.0
7     108        Jim   Halpert  1958-02-19   M   69000     102.0        3.0

Branch Table:
   branch_id branch_name  mgr_id mgr_start_date
0          1   Corporate     100     2006-02-09
1          2    Scranton     102     1992-04-06
2          3    Stamford     106     1998-02-13

Client Table:
   client_id          client_name  branch_id
0        400   Dunmore Highschool

In [4]:
with engine.connect() as conn:
    print("\nWorks With Table:")
    display_table_as_dataframe(works_with_table, conn)
    print("\nBranch Supplier Table:")
    display_table_as_dataframe(branch_supplier_table, conn)


Works With Table:
   emp_id  client_id  total_sales
0     105        400        55000
1     102        401       267000
2     108        402        22500
3     107        403         5000
4     108        403        12000
5     105        404        33000
6     107        405        26000
7     102        406        15000
8     105        406       130000

Branch Supplier Table:
   branch_id        supplier_name       supply_type
0          2          Hammer Mill             Paper
1          2             Uni-ball  Writing Utensils
2          3        Patriot Paper             Paper
3          2  J.T. Forms & Labels      Custom Forms
4          3             Uni-ball  Writing Utensils
5          3          Hammer Mill             Paper
6          3      Stamford Lables      Custom Forms


In [5]:
from sqlalchemy import text

with engine.connect() as con:
    rows = con.execute(text("SELECT emp_id from employee"))
    for row in rows:
        print(row)

(100,)
(101,)
(102,)
(103,)
(104,)
(105,)
(107,)
(108,)


In [6]:
from llama_index.llms.openai import OpenAI
llm = OpenAI(temperature=0.1, model="gpt-4o-mini")

# Text-to-SQL Query Engine

In [7]:
sql_database = SQLDatabase(engine, include_tables=["employee","branch"])
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["employee","branch"], llm=llm
)
query_str = "I need name of the all emplyees who are with client John Daly Law LLC"
response = query_engine.query(query_str)

In [8]:
display(Markdown(f"{response}"))

There are currently no employees associated with the client John Daly Law LLC.