### PGVector RAG Tool for Question Answering - Text to PostgreSQL SQL with Groupchat

In [1]:
from dotenv import load_dotenv
load_dotenv() 
import numpy as np
import pandas as pd
import json
import os
import openai
import time
from openai import AzureOpenAI
import psycopg
from sentence_transformers import SentenceTransformer
import autogen
from autogen import GroupChat
from autogen import GroupChatManager
from autogen import AssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen import ConversableAgent, UserProxyAgent, config_list_from_json
from autogen.retrieve_utils import TEXT_FORMATS
import psycopg2
from psycopg2.extras import execute_values
from pgvector.psycopg2 import register_vector
from autogen import AssistantAgent, ConversableAgent, UserProxyAgent

  from .autonotebook import tqdm as notebook_tqdm
flaml.automl is not available. Please install flaml[automl] to enable AutoML functionalities.


Reference: https://microsoft.github.io/autogen/0.2/docs/reference/agentchat/contrib/vectordb/pgvectordb

In [2]:
llm_config = {
    "config_list": [
        {
            "model": "gpt-4o",
            "api_key": os.getenv("OPENAI_API_KEY"),
            "api_type": "azure",
            "base_url": os.getenv("OPENAI_API_BASE"),
            "api_version": os.getenv("API_VERSION"),
        },
    ],
    "temperature": 0.0,
    "timeout": 300,
}

In [3]:
def embed_openai(text: str):
    client = AzureOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_version="2024-02-01",
        azure_endpoint=os.getenv("OPENAI_API_BASE")
    )

    response = client.embeddings.create(
        input=text,
        model="text-embedding-ada-002"
    )

    embeddings = [np.array(embedding.embedding) for embedding in response.data]

    if len(embeddings) == 1:
        return embeddings[0]

    return np.array(embeddings)

In [4]:
conn = psycopg2.connect(
    dbname='test',
    user='postgres',
    password='test_password',
    host='localhost',
    port='5433'
)

cur = conn.cursor()

cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
tables = cur.fetchall()

for table in tables:
    print(table)

for table in tables:
    table_name = table[0]  
    print(f"Dropping table: {table_name}") 
    
    cur.execute(f"DROP TABLE IF EXISTS {table_name} CASCADE;")
    conn.commit() 

cur.close()
conn.close()

print("All tables in 'public' schema have been dropped.")

All tables in 'public' schema have been dropped.


Catalog >> Limit the search to 3-5 tables RAG on (description of the tables (embeddings)) >> run the logic on the retrieved tables 

### data

In [5]:
data = {
    "id": ['442515', '752565', '742142', '741265', '147785'],
    "documents": [
        'All new drivers must complete a state-approved education course that covers traffic laws, signs, and safe driving practices.',
        'Driver must complete a minimum of 10 hours of behind-the-wheel driving instruction with a licensed instructor.',
        'Drivers over 65 must complete a refresher course to renew their driver license, focusing on changes in road safety and laws.',
        'Drivers seeking an advanced license must complete a defensive driving course to improve skills in handling challenging road conditions.',
        'Teenagers between the ages of 16-18 are required to complete a state-certified driver safety course before obtaining a license.'
    ],
    "metadatas": [
        'Driver Education Course Requirement', 
        'Behind-the-Wheel Training', 
        'License Renewal Training', 
        'Advanced Driving Course', 
        'Teen Driver Safety Training'
    ]
}

df = pd.DataFrame(data)

df['embedding'] = df["documents"].apply(lambda x: embed_openai(x)) 

display(df.head())

Unnamed: 0,id,documents,metadatas,embedding
0,442515,All new drivers must complete a state-approved...,Driver Education Course Requirement,"[0.009563765488564968, -0.0038687754422426224,..."
1,752565,Driver must complete a minimum of 10 hours of ...,Behind-the-Wheel Training,"[-0.001858723466284573, -0.0006039661238901317..."
2,742142,Drivers over 65 must complete a refresher cour...,License Renewal Training,"[0.005323335062712431, -0.007420213893055916, ..."
3,741265,Drivers seeking an advanced license must compl...,Advanced Driving Course,"[-0.0024872892536222935, 0.004641841631382704,..."
4,147785,Teenagers between the ages of 16-18 are requir...,Teen Driver Safety Training,"[0.013322370126843452, -0.003426025155931711, ..."


In [6]:
conn = psycopg2.connect(
    dbname='test',
    user='postgres',
    password='test_password',
    host='localhost',
    port='5433'
)

cur = conn.cursor()

cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
conn.commit()

data_list = [
    (
        row['id'], 
        row['documents'], 
        row['metadatas'], 
        row['embedding'].tolist() if isinstance(row['embedding'], np.ndarray) else row['embedding']
    )
    for index, row in df.iterrows()
]

create_table_query = """
CREATE TABLE IF NOT EXISTS policies (
    id TEXT PRIMARY KEY, 
    documents TEXT,
    metadatas TEXT,  
    embedding vector(1536)  -- Adjust the size to match your model's embedding dimensions
);
"""

cur.execute(create_table_query)
conn.commit()

insert_query = """
INSERT INTO policies (id, documents, metadatas, embedding) 
VALUES %s
ON CONFLICT (id) DO UPDATE 
SET documents = EXCLUDED.documents,
    metadatas = EXCLUDED.metadatas,
    embedding = EXCLUDED.embedding;
"""

execute_values(cur, insert_query, data_list)

conn.commit()

cur.close()
conn.close()

print(f"{len(data_list)} rows inserted or updated successfully.")

5 rows inserted or updated successfully.


## Text to PostgreSQL SQL with Groupchat 

### Tool Schema
Pydantic model

https://microsoft.github.io/autogen/0.2/docs/tutorial/tool-use#tool-schema

### Tool 1

In [7]:
def sql_writer() -> str: 

    conn = psycopg2.connect(
        dbname=os.getenv("MYPOSTGRES_DATABASE"),
        user=os.getenv("MYPOSTGRES_USER"),
        password=os.getenv("MYPOSTGRES_PASSWORD"), 
        host=os.getenv("MYPOSTGRES_HOST"),
        port=os.getenv("MYPOSTGRES_PORT")
    )

    cur = conn.cursor()

    query = """
    SELECT 
        table_schema, 
        table_name, 
        column_name, 
        data_type
    FROM 
        information_schema.columns
    WHERE 
        table_schema NOT IN ('information_schema', 'pg_catalog')  -- Exclude system schemas
    ORDER BY 
        table_schema, 
        table_name, 
        ordinal_position;
    """

    cur.execute(query)

    rows = cur.fetchall()

    schema_info = {}
    for row in rows:
        table_schema, table_name, column_name, data_type = row
        if table_schema not in schema_info:
            schema_info[table_schema] = {}
        if table_name not in schema_info[table_schema]:
            schema_info[table_schema][table_name] = []
        schema_info[table_schema][table_name].append({
            "column_name": column_name,
            "data_type": data_type
        })

    cur.close()
    conn.close()  

    return json.dumps(schema_info, indent=4)

json_columns_info = sql_writer()
print(json_columns_info)

{
    "public": {
        "policies": [
            {
                "column_name": "id",
                "data_type": "text"
            },
            {
                "column_name": "documents",
                "data_type": "text"
            },
            {
                "column_name": "metadatas",
                "data_type": "text"
            },
            {
                "column_name": "embedding",
                "data_type": "USER-DEFINED"
            }
        ]
    }
}


### Tool 2

In [8]:
def execute_query(query: str) -> str: 
    try:
        conn = psycopg2.connect(
            dbname=os.getenv("MYPOSTGRES_DATABASE"),
            user=os.getenv("MYPOSTGRES_USER"),
            password=os.getenv("MYPOSTGRES_PASSWORD"), 
            host=os.getenv("MYPOSTGRES_HOST"),
            port=os.getenv("MYPOSTGRES_PORT")
        )

        register_vector(conn)
        cur = conn.cursor()
        cur.execute(query)
        res = cur.fetchall()
        cur.close()
        conn.close()
        return res
    except Exception as e:
        print(f"An error occurred: {e}")
        return None

### Agent 1

In [9]:
template1 = """
Role: PostgreSQL SQL Writer Specialist
Task: Given a question and schema information you receive using the tool sql_writer, your job is to create a PostgreSQL SQL query
based on the question requirements.
Output: Valid PostgreSQL SQL query (with proper joins, filtering, grouping, etc., based on the given goal).
Do not wrap the output in quotation marks.
Do not wrap the output in code block delimiters (```)
"""

In [10]:
text_to_sql_agent = ConversableAgent(
    name="text_to_sql_agent",
    system_message= template1 + "schema information:",
    llm_config=llm_config,
)

### Agent 2

In [11]:
template2 = """
PostgreSQL SQL Executor Task: Given a PostgreSQL SQL query you receive,
your job is to execute the PostgreSQL SQL query using the tool execute_query.
Output: Valid results
Example: execute_query("SELECT metadatas FROM public.policies;")
Results: 
[('Driver Education Course Requirement',),
 ('Behind-the-Wheel Training',),
 ('License Renewal Training',),
 ('Advanced Driving Course',),
 ('Teen Driver Safety Training',)]

Do not wrap the output in quotation marks.
Do not wrap the output in code block delimiters (```)
"""

In [12]:
execute_sql_agent = ConversableAgent(
    name="execute_sql_agent",
    system_message= template2 + "Query:",
    llm_config=llm_config,
)

### Agent 3

In [13]:
user = ConversableAgent(
    name="user",
    system_message= "You are asking questions to your assistants and reporting back in a detailed tone.",
    llm_config=llm_config,
    human_input_mode= "NEVER",
)

### Registering Tools

In [14]:
text_to_sql_agent.register_for_llm(
    name="sql_writer", 
    description="Generate PostgreSQL SQL query."
)(sql_writer)


execute_sql_agent.register_for_llm(
    name="execute_query",
    description= "Execute a PostgreSQL SQL query."
)(execute_query)


user.register_for_execution(
    name="sql_writer"
)(sql_writer)


execute_sql_agent.register_for_execution(
    name="execute_query"
)(execute_query)

<function __main__.execute_query(query: str) -> str>

### Group Chat

In [15]:
allowed_transitions = {
    user: [text_to_sql_agent],
    text_to_sql_agent: [execute_sql_agent],
    execute_sql_agent: [user]
}

constrained_graph_chat = GroupChat(
    agents=[user,
            text_to_sql_agent,
            execute_sql_agent
           ],
    allowed_or_disallowed_speaker_transitions=allowed_transitions,
    speaker_transitions_type="allowed",
    messages=[],
    max_round=10,
    send_introductions=True,
)

constrained_group_chat_manager = GroupChatManager(
    groupchat=constrained_graph_chat,
    llm_config=llm_config,
)

## Question

In [16]:
user_query = "Give me the content abou Teen Driver Safety Training?"

## initiate_chat

In [17]:
chat_result = user.initiate_chat(
    constrained_group_chat_manager,
    message= f"question: {user_query}",
    summary_method="last_msg",
)

[33muser[0m (to chat_manager):

question: Give me the content abou Teen Driver Safety Training?

--------------------------------------------------------------------------------
[32m
Next speaker: text_to_sql_agent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[33mtext_to_sql_agent[0m (to chat_manager):

[32m***** Suggested tool call (call_gVsx8wI2Qcrf9PgodsqED8GY): sql_writer *****[0m
Arguments: 
{"question": "Give me the content about Teen Driver Safety Training?"}
[32m***************************************************************************[0m

--------------------------------------------------------------------------------
[32m
Next speaker: user
[0m
[35m
>>>>>>>> EXECUTING FUNCTION sql_writer...[0m
[33muser[0m (to chat_manager):

[33muser[0m (to chat_manager):

[32m***** Response from calling tool (call_gVsx8wI2Qcrf9PgodsqED8GY) *****[0m
Error: sql_writer() got an unexpected keyword argument 'question'
[32m*******************************************************

## Group Chat - Customize Speaker Selection

In [18]:
def state_transition(last_speaker, groupchat):
    messages = groupchat.messages

    if last_speaker is user:
        return text_to_sql_agent
    elif last_speaker is text_to_sql_agent:
        return execute_sql_agent
    elif last_speaker is execute_sql_agent:
        return None

groupchat = autogen.GroupChat(
    agents=[user,
            text_to_sql_agent,
            execute_sql_agent
           ],
    allowed_or_disallowed_speaker_transitions=allowed_transitions,
    speaker_transitions_type="allowed",
    messages=[],
    max_round=15,
    send_introductions=True,
)

manager2 = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)

## initiate_chat

In [19]:
user_query = "Give me the content abou Teen Driver Safety Training?"

response = user.initiate_chat(
    manager2,
    message=user_query
)

[33muser[0m (to chat_manager):

Give me the content abou Teen Driver Safety Training?

--------------------------------------------------------------------------------
[32m
Next speaker: text_to_sql_agent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[33mtext_to_sql_agent[0m (to chat_manager):

[32m***** Suggested tool call (call_w7kDq6BxRbLJEyzJlEVZ9JMZ): sql_writer *****[0m
Arguments: 
{"schema_information":{"tables":[{"name":"public.policies","columns":[{"name":"metadatas","type":"text"}]}]},"question":"Give me the content about Teen Driver Safety Training"}
[32m***************************************************************************[0m

--------------------------------------------------------------------------------
[32m
Next speaker: user
[0m
[35m
>>>>>>>> EXECUTING FUNCTION sql_writer...[0m
[33muser[0m (to chat_manager):

[33muser[0m (to chat_manager):

[32m***** Response from calling tool (call_w7kDq6BxRbLJEyzJlEVZ9JMZ) *****[0m
Error: sql_writer() got an une