# RAG Preprocessing

In [64]:
from openai import OpenAI
import os
import psycopg2
import instructor
from pydantic import BaseModel
import openai
import sqlvalidator
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain

In [65]:
postgres_username = os.getenv("POSTGRES_USERNAME")
postgres_pwd = os.getenv("POSTGRES_PASSWORD")

client = instructor.from_openai(OpenAI())

In [66]:
conn = psycopg2.connect(
    dbname="postgresdb",
    user=postgres_username,
    password=postgres_pwd,
    host="host.docker.internal",  # e.g., "localhost"
    port="5433"        # default PostgreSQL port
)
cursor = conn.cursor()

In [8]:
sql_query = "SELECT * FROM us_attractions LIMIT 5;"

cursor.execute(sql_query)
rows = cursor.fetchall()

In [14]:
def remove_vector_like_elements(tuples_list):
    new_list = []
    for tpl in tuples_list:
        filtered_tpl = tuple(
            x for x in tpl 
            if not (isinstance(x, str) and x.startswith('[') and x.endswith(']'))
        )
        new_list.append(filtered_tpl)
    return new_list

In [15]:
new_list = remove_vector_like_elements(rows)
new_list[0]

(1331,
 'Waterfall Garden',
 'Park',
 4.6,
 1054.0,
 'Park',
 'Waterfall Garden, 219 2nd Ave S, Seattle, WA 98104',
 'Seattle',
 'USA',
 'WA',
 None,
 'Nature',
 4848.4,
 4.66,
 'Manchester, Seattle')

In [68]:
class SQLQuery(BaseModel):
    sql_command: str


class RAGGenerationResponse(BaseModel):
    answer: str

In [None]:
def is_valid_sql(query):
    parsed = sqlvalidator.parse(query)
    return parsed.is_valid()


def build_sql_generate_prompt(user_query):
    prompt = f"""
    You are a PostgreSQL expert. You only respond with PostgreSQL commands for the question asked by the user.

    You are given a database schema:
        Schema: public
        Table: us_attractions
        Columns:
        - id INTEGER PRIMARY KEY
        - name VARCHAR(250)
        - main_category VARCHAR(250)
        - rating REAL
        - reviews REAL
        - categories VARCHAR(250)
        - address VARCHAR(250)
        - city VARCHAR(250)
        - country VARCHAR(250)
        - state VARCHAR(250)
        - zipcode INTEGER
        - broader_category VARCHAR(250)
        - weighted_score REAL
        - weighted_average REAL
        - all_cities VARCHAR(250)

    The us_attractions table only has information for USA only. The values under the country column are all 'USA'. 

    Translate the following user question into PostgreSQL query statement:

    "{user_query}"
    Instructions:
    - Write PostgreSQL query using the "public" schema for all tables (e.g., public.us_attraction).
    - If the US State label is given in full, search for the abbreviated form in all caps.
    - Always perform LOWER() on all string type comparisons, filtering, etc.
    - If user does not specify the number of top-ranked records, provide up to 5 records.
    - If you cannot respond a PostgreSQL command respond with 'Sorry, no relevant data was found in the database for your query.'. Don't respond with anything else.
    """
    return prompt


def build_rag_response_prompt(rows, user_question, sql_query):
    formatted_rows = "\n".join([", ".join(map(str, row)) for row in rows])
    # print(formatted_rows)

    # Create a prompt for the LLM
    prompt = f'''Here are the SQL query results:\n{formatted_rows}

    Generated by this SQL query: {sql_query}\n
    '''
    if user_question:
        prompt += f"Based on these results, answer the question: {user_question}"

    return prompt


def generate_sql_query(prompt):
    response, _ = client.chat.completions.create_with_completion(
        model="gpt-4.1-mini",
        response_model=SQLQuery,
        messages=[{"role":"user", "content": prompt}],
        temperature=0
    )
    return response


def retrieve_from_postgres(cursor, sql_query):
    cursor.execute(sql_query)
    rows = cursor.fetchall()
    return rows


def generate_answer(prompt):
    response, _ = client.chat.completions.create_with_completion(
        model="gpt-4.1-mini",
        response_model=RAGGenerationResponse,
        messages=[{"role":"user", "content": prompt}],
        temperature=0.
    )
    return response.answer




In [82]:
def get_embedding(text):
    result = openai.embeddings.create(
        input=[text],
        model="text-embedding-3-small"
    )
    embedding = result.data[0].embedding 
    return embedding


def generate_sql_command(question):
    response, _ = client.chat.completions.create_with_completion(
        model="gpt-4.1-mini",
        response_model=SQLQuery,
        messages=[{"role":"user", "content": question}],
        temperature=0
    )
    return response


def execute_sql_query(cursor, sql_query):
    rows = list()
    if is_valid_sql(sql_query):
        cursor.execute(sql_query)
        rows = cursor.fetchall()
    return rows


def execute_sql_similarity_match(cursor, input_embedding, top_k=5):
    vector_search_sql = "SELECT * FROM attraction_table ORDER BY embedding <-> %s::vector LIMIT %s"
    cursor.execute(vector_search_sql, (input_embedding, top_k))
    rows = cursor.fetchall()
    return rows

In [83]:
def rag_pipeline(user_question, psycopg_cursor):
    query_result = list()
    sql_prompt = build_sql_generate_prompt(user_question)
    text2sql_response = generate_sql_command(sql_prompt)
    print(text2sql_response)
    if is_valid_sql(text2sql_response.sql_command):
        query_result = retrieve_from_postgres(psycopg_cursor, text2sql_response.sql_command)
        formatted_query_result = [", ".join(map(str, row)) for row in query_result]
        if not query_result:
            # No data found guardrail
            answer = "Sorry, no relevant data was found in the database for your query."
        else:
            response_prompt = build_rag_response_prompt(query_result, user_question, text2sql_response.sql_command)
            answer = generate_answer(response_prompt)
    else:
        answer = text2sql_response
    
    final_result = {
        "answer": answer,
        "question": user_question,
        # "retrieved_context_ids": [row[0] for row in query_result],
        "retrieved_context": formatted_query_result,
        "generated_sql": text2sql_response.sql_command
    }
    return final_result

In [86]:
user_question = "Are there theme parks suitable for family outings in San Diego?"
sql_prompt = build_sql_generate_prompt(user_question)
sql_query = generate_sql_query(sql_prompt)
# answer = rag_pipeline(user_question, cursor)

In [87]:
sql_query

SQLQuery(sql_command="SELECT * FROM public.us_attractions WHERE LOWER(city) = 'san diego' AND LOWER(categories) LIKE '%theme park%' AND LOWER(main_category) LIKE '%family%' ;")

In [80]:
answer

{'answer': 'Some top-rated tourist attractions in Savannah, Georgia include:\n1. Hearse Ghost Tours (Rating: 4.9, Reviews: 1912) - 1410 E Broad St, Savannah, GA 31401\n2. Pin Point Heritage Museum (Rating: 4.9, Reviews: 436) - 9924 Pin Point Ave, Savannah, GA 31406\n3. SCADstory (Rating: 4.9, Reviews: 81) - 342 Bull St, Savannah, GA 31401\n4. Forsyth Park (Rating: 4.8, Reviews: 16538) - Savannah, GA 31401\n5. The Cathedral Basilica of St. John the Baptist (Rating: 4.8, Reviews: 5911) - 222 E Harris St, Savannah, GA 31401\n6. Fort Pulaski National Monument (Rating: 4.8, Reviews: 5221) - 101 Fort Pulaski Rd, Savannah, GA\n7. Fountain at Forsyth Park (Rating: 4.8, Reviews: 4234) - 1 W Gaston St, Savannah, GA 31401\n8. Graveface Museum (Rating: 4.8, Reviews: 642) - 410 E Lower, Factors Walk, Savannah, GA 31401\n9. Monterey Square (Rating: 4.8, Reviews: 586) - 11 W Gordon St, Savannah, GA 31401\n10. Madison Square (Rating: 4.8, Reviews: 504) - 332 Bull St, Savannah, GA 31401',
 'question': 

In [44]:
type(sql_query)

__main__.SQLQuery

In [45]:
sql_query.sql_query

"SELECT name, rating, reviews, address, city, state, zipcode FROM public.us_attractions WHERE city = 'Savannah' AND state = 'GA' ORDER BY rating DESC, reviews DESC LIMIT 10;"

In [61]:
answer['retrieved_context']

[(24,
  'Pin Point Heritage Museum',
  'Museum',
  4.9,
  436.0,
  'Museum, Heritage museum, Tourist attraction',
  'Pin Point Heritage Museum, 9924 Pin Point Ave, Savannah, GA 31406',
  'Savannah',
  'USA',
  'GA',
  None,
  'Cultural',
  2136.4,
  4.53,
  'Atlanta, Augusta, Chattanooga, Savannah'),
 (39,
  'SCADstory',
  'Tourist attraction',
  4.9,
  81.0,
  'Tourist attraction, Amusement park ride, Museum, Private university',
  'SCADstory, 342 Bull St, Savannah, GA 31401',
  'Savannah',
  'USA',
  'GA',
  None,
  'Entertainment',
  396.9,
  4.59,
  'Atlanta, Augusta, Chattanooga, Savannah'),
 (9,
  'Hearse Ghost Tours',
  'Tourist attraction',
  4.9,
  1912.0,
  'Tourist attraction, Entertainment agency, Tour operator, Travel clinic',
  'Hearse Ghost Tours, 1410 E Broad St, Savannah, GA 31401',
  'Savannah',
  'USA',
  'GA',
  None,
  'Entertainment',
  9368.8,
  4.59,
  'Atlanta, Augusta, Chattanooga, Savannah'),
 (29,
  'Beach Institute African American Cultural Center',
  'Muse