# Parse Query Components

In [8]:
import sqlglot
from dataclasses import dataclass

In [1]:
@dataclass
class QueryComponents:
    tables: set
    columns: set
    joins: list
    conditions: list

        
def parse_sql(query: str) -> QueryComponents:
    parsed = sqlglot.parse_one(query)
    components = QueryComponents(
        tables=set(),
        columns=set(),
        joins=[],
        conditions=[]
    )
    
    # Extract tables and their aliases
    table_aliases = {}
    for table in parsed.find_all(sqlglot.exp.Table):
        components.tables.add(table.name)
        if table.alias:
            table_aliases[table.alias] = table.name
    
    # Extract columns with table context
    for col in parsed.find_all(sqlglot.exp.Column):
        table_name = table_aliases.get(col.table, col.table)  # Resolve aliases
        components.columns.add(f"{table_name}.{col.name}")
    
    # Extract joins with corrected source/target
    for join in parsed.find_all(sqlglot.exp.Join):
        source_table = join.this.alias or join.this.name
        target_table = join.args.get('on').find(sqlglot.exp.Column).table
        components.joins.append({
            "source": table_aliases.get(source_table, source_table),
            "target": table_aliases.get(target_table, target_table),
            "type": "LEFT JOIN" if join.side == "LEFT" else "INNER JOIN"
        })
    
    # Extract conditions
    components.conditions = [where.this.sql() for where in parsed.find_all(sqlglot.exp.Where)]
    
    return components

In [2]:
parse_sql("SELECT users.name FROM users RIGHT JOIN orders ON users.id = orders.user_id")

QueryComponents(tables={'users', 'orders'}, columns={'orders.user_id', 'users.name', 'users.id'}, joins=[{'source': 'orders', 'target': 'users', 'type': 'INNER JOIN'}], conditions=[])

In [3]:
parse_sql("DELETE FROM products WHERE price > 100")

QueryComponents(tables={'products'}, columns={'.price'}, joins=[], conditions=['price > 100'])

In [4]:
queries = [
    "SELECT users.name FROM users JOIN orders ON users.id = orders.user_id",
    "DELETE FROM products WHERE price > 100"
]
for q in queries:
    print(parse_sql(q))

QueryComponents(tables={'users', 'orders'}, columns={'orders.user_id', 'users.name', 'users.id'}, joins=[{'source': 'orders', 'target': 'users', 'type': 'INNER JOIN'}], conditions=[])
QueryComponents(tables={'products'}, columns={'.price'}, joins=[], conditions=['price > 100'])


# Build Knowledge Graph

In [13]:
import numpy as np
import networkx as nx
from sentence_transformers import SentenceTransformer

In [76]:
class KnowledgeGraph:
    def __init__(self):
        self.graph = nx.MultiDiGraph()
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        
    def add_query(self, query: str):
        components = parse_sql(query)
        query_id = f"query_{hash(query)}"
        
        # Add nodes
        self.graph.add_node(query_id, type="query", embedding=self._get_embedding(query))
        
        # Add tables and connect to query
        for table in components.tables:
            self.graph.add_node(table, type="table")
            self.graph.add_edge(query_id, table, relation="accesses")
            
            # Add columns as sub-nodes of tables
            for col in [c for c in components.columns if c.startswith(f"{table}.")]:
                col_node = f"{table}.{col.split('.')[1]}"
                self.graph.add_node(col_node, type="column")
                self.graph.add_edge(table, col_node, relation="has_column")
        
        # Add joins (now column-aware)
        for join in components.joins:
            self.graph.add_edge(
                join["source"], 
                join["target"],
                relation="join",
                type=join["type"],
                label=f"{join['source']} ↔ {join['target']}"  # For visualization
            )
    
    def _get_embedding(self, text: str) -> np.ndarray:
        return self.encoder.encode(text, convert_to_tensor=True).cpu().numpy()

In [79]:
kg = KnowledgeGraph()

In [80]:
queries = [
    "SELECT users.name FROM users JOIN orders ON users.id = orders.user_id",
    "DELETE FROM products WHERE price > 100"
]
for q in queries:
    kg.add_query(q)

# Visualize Knowledge Graph

In [81]:
from pyvis.network import Network

def visualize_graph(kg):
    net = Network(notebook=True, directed=True, cdn_resources='in_line')
    
    # Add nodes with styling
    for node in kg.graph.nodes():
        if 'table' in node.lower():
            net.add_node(node, color='#FFA07A', shape='box')  # Light salmon for tables
        elif 'column' in node.lower():
            net.add_node(node, color='#98FB98', size=15)  # Pale green for columns
        else:
            net.add_node(node, color='#ADD8E6')  # Light blue for queries
    
    # Add edges with labels
    for src, dst, data in kg.graph.edges(data=True):
        net.add_edge(src, dst, label=data.get('type', ''), title=data.get('relation', ''))
    
    with open("sql_graph.html", "w", encoding="utf-8") as f:
        f.write(net.generate_html())

In [82]:
visualize_graph(kg)

In [83]:
kg.graph.nodes

NodeView(('query_7979636527217952724', 'users', 'users.name', 'users.id', 'orders', 'orders.user_id', 'query_5240566835026746143', 'products'))

In [84]:
kg.graph.edges

OutMultiEdgeView([('query_7979636527217952724', 'users', 0), ('query_7979636527217952724', 'orders', 0), ('users', 'users.name', 0), ('users', 'users.id', 0), ('orders', 'orders.user_id', 0), ('orders', 'users', 0), ('query_5240566835026746143', 'products', 0)])

In [85]:
kg.graph.get_edge_data(u='users', v='users.name')

{0: {'relation': 'has_column'}}

In [86]:
kg.graph.get_edge_data(u='query_7979636527217952724', v='users')

{0: {'relation': 'accesses'}}

In [87]:
kg.graph.get_edge_data(u='orders', v='users')

{0: {'relation': 'join', 'type': 'INNER JOIN', 'label': 'orders ↔ users'}}

# Setup Q&A Engine

In [70]:
from transformers import pipeline

In [71]:
torch.tensor([3,4, 5])

tensor([3, 4, 5])

In [72]:
class QAEngine:
    def __init__(self, kg: KnowledgeGraph):
        self.kg = kg
        self.llm = pipeline(
            "text-generation",
            model="google/flan-t5-large",
            device="cpu"
        )
        
    def answer(self, question: str):
        # 1. Retrieve relevant subgraph
        question_embed = self.kg._get_embedding(question)
        similarities = [
            (node, torch.cosine_similarity(torch.tensor(question_embed), torch.tensor(self.kg.graph.nodes[node]["embedding"]), dim=0))
            for node in self.kg.graph.nodes if "embedding" in self.kg.graph.nodes[node]
        ]
        top_nodes = sorted(similarities, key=lambda x: -x[1])[:3]
        
        # 2. Generate answer
        context = "\n".join([f"Node {n[0]}: {self.kg.graph.nodes[n[0]]['type']}" for n in top_nodes])
        prompt = f"""SQL Knowledge Graph Context:
{context}
Question: {question}
Answer:"""
        
        return self.llm(prompt, max_new_tokens=200)[0]['generated_text']

In [74]:
import torch

engine = QAEngine(kg)

Device set to use cpu
The model 'T5ForConditionalGeneration' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForConditionalGeneration', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeo

# Inference

In [75]:
result = engine.answer("What are the columns of 'products' table?")
print(result)

SQL Knowledge Graph Context:
Node DELETE FROM products WHERE price > 100: query
Node SELECT users.name FROM users JOIN orders ON users.id = orders.user_id: query
Question: What are the columns of 'products' table?
Answer:


In [45]:
result = engine.answer("What are the columns of 'users' table?")
print(result)

('query_7979636527217952724', tensor(0.3086))
SQL Knowledge Graph Context:
Node query_7979636527217952724: query
Node query_5240566835026746143: query
Question: What are the columns of 'users' table?
Answer:
