# In [None]:
import sqlite3
import openai
import networkx as nx
import json

# Schema Extractor
class SchemaExtractor:
    """Read table, column and foreign-key metadata from a SQLite database file."""

    def __init__(self, db_file):
        # Path to the SQLite database; a fresh connection is opened per extract_schema() call.
        self.db_file = db_file

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a double-quoted SQL identifier, escaping embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    def extract_schema(self):
        """Return the column names and foreign keys of every user table.

        SQLite-internal tables (names starting with ``sqlite_``, e.g.
        ``sqlite_sequence``) are skipped because they are not part of the
        user's schema.

        Returns:
            A ``(schema, foreign_keys)`` tuple. ``schema`` maps each table name
            to its column names in declaration order. ``foreign_keys`` maps each
            table name to a list of dicts with keys ``from_column``,
            ``to_table`` and ``to_column``. ``to_column`` may be ``None`` when
            the constraint relies on the parent's implicit primary key.
        """
        conn = sqlite3.connect(self.db_file)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [
                row[0]
                for row in cursor.fetchall()
                if not row[0].lower().startswith("sqlite_")
            ]

            schema = {}
            foreign_keys = {}
            for table_name in table_names:
                # PRAGMA arguments cannot be bound as parameters, so quote the
                # identifier. Unquoted names such as "order" or "my table" are
                # syntax errors.
                quoted = self._quote_identifier(table_name)

                cursor.execute(f"PRAGMA table_info({quoted});")
                # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                schema[table_name] = [col[1] for col in cursor.fetchall()]

                cursor.execute(f"PRAGMA foreign_key_list({quoted});")
                # foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
                foreign_keys[table_name] = [
                    {"from_column": key[3], "to_table": key[2], "to_column": key[4]}
                    for key in cursor.fetchall()
                ]
        finally:
            # Close even when a query fails so the file handle is not leaked.
            conn.close()
        return schema, foreign_keys

# GPT-4 Query Function
def query_gpt4_for_subgraph(user_query, schema, foreign_keys):
    """Ask GPT-4 which tables and columns answer *user_query*; return its raw reply text.

    Args:
        user_query: Natural-language question to embed in the prompt.
        schema: Mapping of table name -> iterable of column-name strings.
        foreign_keys: Mapping of table name -> foreign-key description, rendered
            into the prompt via its ``str()`` form.

    Returns:
        The ``content`` of the first choice's message. The prompt requests JSON,
        but the model is not forced to comply.
    """
    table_lines = (
        f"Table: {name}\nColumns: {', '.join(cols)}" for name, cols in schema.items()
    )
    tables_block = "\n".join(table_lines)

    relation_lines = (
        f"Table: {name}\nForeign Keys: {fks}" for name, fks in foreign_keys.items()
    )
    relations_block = "\n".join(relation_lines)

    prompt = f"""
You are an assistant that analyzes database schemas to answer user queries. 
Here is the database schema:

{tables_block}

Here are the foreign key relationships:

{relations_block}

The user's question is: "{user_query}"

Identify the relevant tables and columns needed to answer the question. Additionally, provide a subgraph schema that connects these tables and columns, including their relationships. Format the response in JSON:

{{
    "relevant_tables": ["table_name1", "table_name2"],
    "relevant_columns": ["table_name1.column1", "table_name2.column3"],
    "subgraph_schema": {{
        "nodes": ["table_name1", "table_name2", "table_name1.column1", "table_name2.column3"],
        "edges": [
            {{"from": "table_name1", "to": "table_name2", "relation": "foreign_key", "details": {{"from_column": "column1", "to_column": "column2"}}}},
            {{"from": "table_name1", "to": "table_name1.column1", "relation": "column"}},
            {{"from": "table_name2", "to": "table_name2.column3", "relation": "column"}}
        ]
    }}
}}
"""

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    # NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK interface and
    # was removed in openai>=1.0 — confirm the pinned SDK version.
    completion = openai.ChatCompletion.create(model="gpt-4", messages=conversation)
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]

# Parse GPT-4 Response
def _strip_code_fence(text):
    """Return *text* trimmed of whitespace and of a surrounding Markdown code fence."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # The opening fence line may carry a language tag (```json); drop the whole line.
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def parse_gpt4_response(response):
    """Parse the model's JSON reply into its three top-level parts.

    Replies wrapped in a Markdown code fence, or surrounded by a line of prose,
    are tolerated. Without this, ``json.loads`` rejects them outright.

    Args:
        response: Raw text content returned by the model.

    Returns:
        ``(relevant_tables, relevant_columns, subgraph_schema)``, read from the
        JSON keys of the same names.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from the reply.
        KeyError: If one of the three expected keys is missing.
    """
    text = _strip_code_fence(response)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span to skip leading/trailing prose.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            raise
        parsed = json.loads(text[start:end + 1])
    return (
        parsed["relevant_tables"],
        parsed["relevant_columns"],
        parsed["subgraph_schema"],
    )

# Build Subgraph
def build_graph_from_schema(subgraph_schema):
    """Convert the model's ``subgraph_schema`` dict into a networkx graph.

    Args:
        subgraph_schema: Dict with a ``"nodes"`` list of node names and an
            ``"edges"`` list of dicts. Each edge dict holds ``"from"`` and
            ``"to"``, plus optional ``"relation"`` and ``"details"`` entries.

    Returns:
        An ``nx.Graph``. Every edge carries ``relation`` and ``details``
        attributes, which are ``None`` when the edge dict omits them. Endpoints
        named only in edges are added as nodes implicitly by ``add_edge``.
    """
    subgraph = nx.Graph()

    for node_name in subgraph_schema["nodes"]:
        subgraph.add_node(node_name)

    for spec in subgraph_schema["edges"]:
        attributes = {key: spec.get(key) for key in ("relation", "details")}
        # NOTE(review): nx.Graph is undirected and keeps a single edge per node
        # pair. The from/to direction is lost, and a second edge between the same
        # pair updates the first one's attributes — confirm this is intended.
        subgraph.add_edge(spec["from"], spec["to"], **attributes)

    return subgraph

# RAG Pipeline
class RAGPipeline:
    """Chain schema extraction, the GPT-4 relevance query and subgraph construction.

    Args:
        db_file: Path to the SQLite database, handed to ``SchemaExtractor``.
        user_query: Natural-language question forwarded to the model.
    """

    def __init__(self, db_file, user_query):
        self.db_file = db_file
        self.user_query = user_query
        self.schema_extractor = SchemaExtractor(db_file)

    def run(self):
        """Execute every stage in order, printing progress and intermediate results.

        Returns:
            ``(relevant_tables, relevant_columns, graph)``. The first two are taken
            from the parsed model reply; ``graph`` is the networkx graph built
            from the reply's subgraph schema.
        """
        print("Extracting schema...")
        tables_to_columns, fk_map = self.schema_extractor.extract_schema()

        print("Querying GPT-4 for relevant tables, columns, and subgraph...")
        raw_reply = query_gpt4_for_subgraph(self.user_query, tables_to_columns, fk_map)
        tables, columns, subgraph_spec = parse_gpt4_response(raw_reply)

        for label, items in (("Relevant Tables:", tables), ("Relevant Columns:", columns)):
            print(label, items)

        print("Building subgraph...")
        subgraph = build_graph_from_schema(subgraph_spec)

        print("Subgraph Nodes:")
        print(subgraph.nodes(data=True))
        print("Subgraph Edges:")
        print(subgraph.edges(data=True))

        return tables, columns, subgraph

# Main Runner
if __name__ == "__main__":
    # NOTE(review): "01/05/2024" is ambiguous (5 Jan vs 1 May) to both readers
    # and the model — consider an ISO date once the intent is confirmed.
    DB_FILE, USER_QUERY = (
        "company_data.db",
        "How many people have booked at least 50 hours before 01/05/2024?",
    )

    # Build the pipeline, then run every stage; results stay bound at module
    # level for interactive inspection.
    pipeline = RAGPipeline(DB_FILE, USER_QUERY)
    relevant_tables, relevant_columns, graph = pipeline.run()
