In [None]:
import os
import re

import ollama
import psycopg2
from psycopg2 import sql

# =========================================================
# CONFIG
# =========================================================
# PostgreSQL connection settings. Each value can be overridden through an
# environment variable (RAG_DB_NAME, RAG_DB_USER, RAG_DB_PASSWORD,
# RAG_DB_HOST, RAG_DB_PORT) so real credentials need not be committed to
# source; the defaults are the original local-development values.
DB_CONFIG = {
    "dbname": os.environ.get("RAG_DB_NAME", "rag_db"),
    "user": os.environ.get("RAG_DB_USER", "postgres"),
    "password": os.environ.get("RAG_DB_PASSWORD", "postgres"),
    "host": os.environ.get("RAG_DB_HOST", "localhost"),
    "port": int(os.environ.get("RAG_DB_PORT", "5432")),
}

EMBED_MODEL = "nomic-embed-text"  # Ollama model used by embed_text()
CHAT_MODEL = "deepseek-r1"  # Ollama model used for routing and answering
# Dimension of the VECTOR(n) column created by ensure_table(); it must match
# the length of the vectors EMBED_MODEL returns or inserts will be rejected.
EMBED_DIM = 768


# =========================================================
# DB CONNECTION
# =========================================================
conn = psycopg2.connect(**DB_CONFIG)
cur = conn.cursor()  # single shared cursor used by every DB helper below

# pgvector supplies the VECTOR column type and the <=> operator used by the
# helpers; IF NOT EXISTS makes this statement safe to run on every start-up.
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
conn.commit()


# =========================================================
# DB UTILITIES
# =========================================================
def get_all_tables():
    """Return the names of all base tables in the ``public`` schema.

    The names are sorted alphabetically by the query's ORDER BY. Uses the
    module-level cursor ``cur``.
    """
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
          AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    cur.execute(query)
    # Each row is a 1-tuple because only table_name is selected.
    return [name for (name,) in cur.fetchall()]


# Table names present in the public schema at start-up; the CLI appends to
# this list whenever it creates a new table.
table_names = get_all_tables()
# Comma-separated form of the start-up list. It is computed once here and is
# NOT refreshed when table_names grows later in the session.
table_list_str = ", ".join(name for name in table_names)


def ensure_table(tablename: str):
    """Create table *tablename* for texts and their embeddings if it is absent.

    Columns: ``id`` (SERIAL primary key), ``context`` (TEXT) and
    ``embedding`` (``VECTOR(EMBED_DIM)``). The table name goes through
    ``sql.Identifier``, so it is quoted rather than spliced in raw.

    The statement runs inside ``with conn:``, which commits on success and
    rolls back on failure. Without the rollback, a failed statement would
    leave the shared connection in an aborted transaction and every later
    query on it would fail.
    """
    query = sql.SQL("""
        CREATE TABLE IF NOT EXISTS {} (
            id SERIAL PRIMARY KEY,
            context TEXT,
            embedding VECTOR({})
        )
    """).format(
        sql.Identifier(tablename),
        sql.SQL(str(EMBED_DIM))
    )
    with conn:
        cur.execute(query)


def insert_data(tablename: str, context: str, embedding: list):
    """Insert one (*context*, *embedding*) row into *tablename* and commit it.

    *embedding* is sent as a query parameter. psycopg2 adapts a Python list
    to a PostgreSQL array, which presumably reaches the VECTOR column through
    pgvector's array-to-vector cast.

    The insert runs inside ``with conn:``, which commits on success and rolls
    back on failure (e.g. a wrong embedding length). This keeps the shared
    connection usable instead of leaving it in an aborted transaction.
    """
    query = sql.SQL("""
        INSERT INTO {} (context, embedding)
        VALUES (%s, %s)
    """).format(sql.Identifier(tablename))
    with conn:
        cur.execute(query, (context, embedding))


def search_similar(tablename: str, embedding: list, limit: int = 3):
    """Return up to *limit* stored ``context`` texts nearest to *embedding*.

    Rows of *tablename* are ordered by pgvector's ``<=>`` operator (cosine
    distance), smallest first, so the most similar text comes first.

    Args:
        tablename: Table to search (quoted via ``sql.Identifier``).
        embedding: Query vector; cast to ``vector`` in SQL.
        limit: Maximum number of texts to return.

    Returns:
        A list of context strings (empty if the table has no rows).
    """
    query = sql.SQL("""
        SELECT context
        FROM {}
        ORDER BY embedding <=> %s::vector
        LIMIT %s
    """).format(sql.Identifier(tablename))
    # ``with conn`` ends the read transaction (commit) instead of leaving the
    # session idle in transaction. If the query fails (e.g. a vector-dimension
    # mismatch), it rolls back so the connection stays usable.
    with conn:
        cur.execute(query, (embedding, limit))
        return [row[0] for row in cur.fetchall()]


# =========================================================
# EMBEDDING
# =========================================================
def embed_text(text: str) -> list:
    """Embed *text* with the Ollama model named by ``EMBED_MODEL``.

    Returns the ``"embedding"`` field of the Ollama response, i.e. the
    vector as a list that the DB helpers store or search with.
    """
    response = ollama.embeddings(prompt=text, model=EMBED_MODEL)
    vector = response["embedding"]
    return vector


# =========================================================
# LLM ROUTING
# =========================================================
# PostgreSQL keeps only the first NAMEDATALEN - 1 (63) bytes of an identifier
# and silently truncates longer ones; truncating here keeps table_names in
# sync with the name actually created.
_MAX_IDENTIFIER_LEN = 63
# Used by record_data() when the model's reply yields no usable name.
_FALLBACK_TABLE = "general"


def normalize_table_name(raw: str, existing=()) -> str:
    """Reduce a raw chat-model reply to a bare table name.

    Steps:
    1. Drop reasoning: keep only text after the last ``</think>``, and
       discard anything after an unterminated ``<think>``.
    2. Take the first non-empty line and strip quote/markdown decoration
       such as backticks, ``**``, quotes and a trailing period or colon.
    3. If it equals a name in *existing* case-insensitively, return that
       existing name verbatim so membership checks against it succeed.
    4. Otherwise convert it to lowercase ASCII snake_case, truncated to
       ``_MAX_IDENTIFIER_LEN`` characters. If that result matches an
       existing name, return the existing name.

    Args:
        raw: The model's reply text.
        existing: Known table names to match against.

    Returns:
        The table name, or ``""`` if nothing usable remains (for example a
        reply that was only reasoning or contained no ASCII letters/digits).
    """
    text = raw.rsplit("</think>", 1)[-1].split("<think>", 1)[0]
    first_line = next(
        (line.strip() for line in text.splitlines() if line.strip()), ""
    )
    candidate = first_line.strip(" \t`'\"*.:")

    known = {name.lower(): name for name in existing}
    if candidate.lower() in known:
        return known[candidate.lower()]

    slug = re.sub(r"[^a-z0-9]+", "_", candidate.lower()).strip("_")
    slug = slug[:_MAX_IDENTIFIER_LEN].rstrip("_")
    return known.get(slug, slug)


def record_data(context: str) -> str:
    """Ask the chat model which table *context* should be stored in.

    The prompt lists the current ``table_names``. It is read at call time,
    so tables created earlier in this session are offered too. The model
    either picks one of them or proposes a new name. The reply is cleaned
    with ``normalize_table_name``. If nothing usable remains,
    ``_FALLBACK_TABLE`` is returned so the caller never tries to create a
    table with an empty name.
    """
    system_prompt = (
        "You are a strict decision engine.\n"
        "If the context clearly matches an existing table, return EXACTLY that table name.\n"
        "If not, create ONE new short english table name in lowercase snake_case.\n"
        "Return ONLY the table name.\n\n"
        f"Existing tables: {', '.join(table_names)}"
    )

    response = ollama.chat(
        model=CHAT_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": context},
        ]
    )
    table = normalize_table_name(response["message"]["content"], table_names)
    return table or _FALLBACK_TABLE


def choose_table_for_question(question: str) -> str:
    """Ask the chat model which existing table best answers *question*.

    Returns ``""`` without calling the model when there are no tables yet.
    Otherwise the reply is cleaned with ``normalize_table_name`` and matched
    case-insensitively against ``table_names``. Callers must still check
    membership, because the model may name a table that does not exist.
    """
    if not table_names:
        return ""

    system_prompt = (
        "You are a semantic router.\n"
        "Choose ONE most relevant table for answering the question.\n"
        "Return ONLY the table name.\n\n"
        f"Existing tables: {', '.join(table_names)}"
    )

    response = ollama.chat(
        model=CHAT_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
    )
    return normalize_table_name(response["message"]["content"], table_names)


# =========================================================
# QA
# =========================================================
def answer_question(question: str, contexts: list, language: str) -> str:
    """Answer *question* in *language* using only the retrieved *contexts*.

    Each context becomes a ``- `` bullet in the user message. The system
    prompt tells the model to say it doesn't know if the context is
    insufficient.

    Reasoning models such as deepseek-r1 may prefix their reply with a
    ``<think>...</think>`` section. Everything up to the last ``</think>``
    is dropped and surrounding whitespace is stripped, so only the final
    answer is returned.
    """
    joined_context = "\n".join(f"- {c}" for c in contexts)

    system_prompt = (
        "You are a question answering system.\n"
        "Answer ONLY using the provided context.\n"
        f"Answer in {language}.\n"
        "If the context is insufficient, say you don't know."
    )

    response = ollama.chat(
        model=CHAT_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"Context:\n{joined_context}\n\nQuestion:\n{question}"
            }
        ]
    )
    content = response["message"]["content"]
    # No-op when the server already separates thinking from the answer.
    return content.rsplit("</think>", 1)[-1].strip()


# =========================================================
# CLI
# =========================================================
def _ingest_once() -> None:
    """Read one text from the user, route it to a table, embed and store it."""
    context = input("Context to save: ").strip()
    if not context:
        return

    table = record_data(context)
    print("Chosen table:", table)

    # First use of a table: create it and remember it so later routing
    # prompts and membership checks see it.
    if table not in table_names:
        ensure_table(table)
        table_names.append(table)

    insert_data(table, context, embed_text(context))
    print("Saved ✅")


def _ask_once(language: str) -> None:
    """Read one question, retrieve similar stored texts and print an answer."""
    question = input("Question: ").strip()
    if not question:
        return

    table = choose_table_for_question(question)
    print("Searching in table:", table)

    # The router may name a table that does not exist; nothing to search then.
    if table not in table_names:
        print("No relevant data found ❌")
        return

    contexts = search_similar(table, embed_text(question))
    if not contexts:
        print("No matching context ❌")
        return

    answer = answer_question(question, contexts, language)
    print("\nAnswer:")
    print(answer)


print("Select answer language:")
print("1 = Thai")
print("2 = English")

try:
    lang_choice = input("Language (1/2): ").strip()
    # Anything other than "2" selects Thai, the default.
    answer_language = "English" if lang_choice == "2" else "Thai"

    print(f"Answer language set to: {answer_language}")

    print("\nChoose mode:")
    print("1 = Ingest (save memory)")
    print("2 = Ask (question)")
    print("Type /bye to exit")

    while True:
        mode = input("\nMode (1/2): ").strip()

        if mode.lower() == "/bye":
            print("Bye 👋")
            break

        # A single failed command must not end the whole session.
        try:
            if mode == "1":
                _ingest_once()
            elif mode == "2":
                _ask_once(answer_language)
            else:
                print("Invalid mode")
        except psycopg2.Error as exc:
            # psycopg2 leaves the transaction aborted after an error; roll
            # back so the next command can use the shared connection again.
            conn.rollback()
            print(f"Database error: {exc}")
        except (ollama.ResponseError, ConnectionError) as exc:
            print(f"Model error: {exc}")
except (EOFError, KeyboardInterrupt):
    # Closed stdin or Ctrl-C / notebook interrupt: leave like /bye does.
    print("\nBye 👋")
finally:
    # =========================================================
    # CLEANUP
    # =========================================================
    # Runs on every exit path, so the cursor and connection are not leaked
    # when a command raises.
    cur.close()
    conn.close()


Select answer language:
1 = Thai
2 = English


Language (1/2):  1


Answer language set to: Thai

Choose mode:
1 = Ingest (save memory)
2 = Ask (question)
Type /bye to exit



Mode (1/2):  1
Context to save:  พีพี ชื่อจริงชื่อ ศักย์ศรณ์


Chosen table: users
Saved ✅



Mode (1/2):  2
Question:  พีพีคือใคร


Searching in table: users

Answer:
พีพีคือ nickname ของ ศักย์ศรณ์
