# Table Metadata Store
====================================

Prerequisites
1. Install Neo4j
2. Install Bird Dataset for Text2SQL Task

Files:
- **Schema loading**: `create_graph_db.py`
- **Database reset**: `delete_graph_db.py`
- **Steiner tree join path**: `steiner_tree.py`
- **Metadata extraction for top-k**: `query_metadata_top_k.py`

## Neo4j
====================================

Install Neo4j via Homebrew

```bash
brew install neo4j
brew services start neo4j
```

Access the Neo4j Browser GUI at http://localhost:7474/

Configure your credentials in `db_config.json`
```json
{
    "username": "neo4j",
    "password": "your_password_here"
}
```

### Creating / Deleting Graph Database
====================================

The following code is in:
- `create_graph_db.py`
- `delete_graph_db.py`

In [None]:
# Code from create_graph_db.py

from neo4j import GraphDatabase
import json

# Load the Neo4j credentials and the table schemas to import.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale default (which is not UTF-8 on every system).
with open("table_metadata_store/graph_db/db_config.json", encoding="utf-8") as f:
    db_config = json.load(f)
with open("bird/data/train/train_tables.json", encoding="utf-8") as f:
    schemas = json.load(f)

# Bolt URI of the Neo4j instance; credentials come from db_config.json.
URI = "bolt://localhost:7687"
AUTH = (db_config["username"], db_config["password"])

driver = GraphDatabase.driver(URI, auth=AUTH)

def load_schema(tx, db_id, table_names, column_names, column_types, primary_keys, foreign_keys):
    """Write one database schema into the graph (excerpt: table nodes only).

    Args:
        tx: Object exposing ``run(query, **params)``; every statement is
            issued through it.
        db_id: Identifier of the database the tables belong to.
        table_names: Table names; one ``Table`` node is merged per entry.
        column_names, column_types, primary_keys, foreign_keys: Not used by
            the part of the function shown in this excerpt.
    """
    # Create Table nodes; full_name is "<db_id>.<table>".
    merge_table = "MERGE (t:Table {db_id: $db_id, name: $name, full_name: $full_name})"
    for table_name in table_names:
        tx.run(
            merge_table,
            db_id=db_id,
            name=table_name,
            full_name=f"{db_id}.{table_name}",
        )
    # ... (The code is too long to include here)

**Explanation:**
- Each table becomes a `Table` node.
- Each column becomes a `Column` node.
- `HAS_COLUMN` and `HAS_PRIMARY_KEY` relationships connect tables to their columns.
- `FOREIGN_KEY_TO` relationships connect columns to the columns they reference.

In [None]:
# Continue loading the schema (add columns and FKs)
def load_schema(tx, db_id, table_names, column_names, column_types, primary_keys, foreign_keys):
    """Write one database schema into the graph (excerpt: column nodes).

    Merges one ``Column`` node per real column and links it to its table with
    ``HAS_PRIMARY_KEY`` when the column is part of a primary key, otherwise
    with ``HAS_COLUMN``.

    Args:
        tx: Object exposing ``run(query, **params)``.
        db_id: Identifier of the database being loaded.
        table_names: Table names, indexed by the table index stored in
            ``column_names``.
        column_names: Sequence of ``(table_idx, col_name)`` pairs; entries
            with ``table_idx == -1`` (the ``*`` pseudo-column) are skipped.
        column_types: Data type per column, parallel to ``column_names``.
        primary_keys: Primary-key column indices. Each entry is either a
            single int or a sequence of ints (composite key).
        foreign_keys: Not used by the part of the function shown here.
    """
    # ... (table creation as above)
    # Create Column nodes and HAS_COLUMN/HAS_PRIMARY_KEY edges

    # Flatten the primary keys once, up front. Every member of a composite
    # key counts as a primary-key column, not just its first index, and a set
    # keeps the per-column membership test O(1).
    pk_indices = set()
    for pk in primary_keys:
        if isinstance(pk, int):
            pk_indices.add(pk)
        else:
            pk_indices.update(pk)

    for col_idx, (table_idx, col_name) in enumerate(column_names):
        if table_idx == -1:  # skip *
            continue
        table_name = table_names[table_idx]
        params = {
            "db_id": db_id,
            "table_name": table_name,
            "col_name": col_name,
            "data_type": column_types[col_idx],
            "full_col_name": f"{db_id}.{table_name}.{col_name}",
            "full_table_name": f"{db_id}.{table_name}",
        }
        # Relationship types cannot be query parameters in Cypher, hence the
        # two otherwise identical statements.
        if col_idx in pk_indices:
            tx.run(
                """
                MERGE (c:Column {db_id: $db_id, table_name: $table_name, name: $col_name, data_type: $data_type, full_name: $full_col_name})
                WITH c
                MATCH (t:Table {full_name: $full_table_name})
                MERGE (t)-[:HAS_PRIMARY_KEY]->(c)
                """,
                **params
            )
        else:
            tx.run(
                """
                MERGE (c:Column {db_id: $db_id, table_name: $table_name, name: $col_name, data_type: $data_type, full_name: $full_col_name})
                WITH c
                MATCH (t:Table {full_name: $full_table_name})
                MERGE (t)-[:HAS_COLUMN]->(c)
                """,
                **params
            )
    # ... (The code is too long to include here)

**Explanation:**  
- Columns are created as nodes and linked to their tables.
- Primary keys are marked with a `HAS_PRIMARY_KEY` relationship to tables.
- Other columns are linked with `HAS_COLUMN` to tables.

In [None]:
# Create FOREIGN_KEY edges
def load_schema(tx, db_id, table_names, column_names, column_types, primary_keys, foreign_keys):
    """Write one database schema into the graph (excerpt: foreign-key edges).

    For every usable ``foreign_keys`` entry, merges a ``FOREIGN_KEY_TO``
    relationship from the referencing ``Column`` node to the referenced one.
    Both nodes are looked up by their ``<db_id>.<table>.<column>`` full name.

    Args:
        tx: Object exposing ``run(query, **params)``.
        db_id: Identifier of the database being loaded.
        table_names: Table names, indexed by the table index stored in
            ``column_names``.
        column_names: Sequence of ``(table_idx, col_name)`` pairs.
        column_types, primary_keys: Not used by the part shown here.
        foreign_keys: Sequence of ``[from_col_idx, to_col_idx]`` entries.
            Empty entries, entries with fewer than two items and entries
            whose target index is ``None`` are ignored.
    """
    # ... (table and column creation as above)
    # Create FOREIGN_KEY edges
    for fk in foreign_keys:
        if not fk:
            continue
        source_idx = fk[0]
        if len(fk) < 2 or fk[1] is None:
            continue
        target_idx = fk[1]

        source_table_idx, source_col = column_names[source_idx]
        target_table_idx, target_col = column_names[target_idx]
        source_full_name = f"{db_id}.{table_names[source_table_idx]}.{source_col}"
        target_full_name = f"{db_id}.{table_names[target_table_idx]}.{target_col}"
        tx.run(
            """
            MATCH (from:Column {full_name: $from_full_col})
            MATCH (to:Column {full_name: $to_full_col})
            MERGE (from)-[:FOREIGN_KEY_TO]->(to)
            """,
            from_full_col=source_full_name,
            to_full_col=target_full_name,
        )

**Explanation:**  
- Foreign key relationships are created between columns using the `FOREIGN_KEY_TO` relationship.

In [None]:
# Code from delete_graph_db.py

from neo4j import GraphDatabase
import json

# Load the Neo4j credentials. JSON is UTF-8 by specification, so decode
# explicitly rather than relying on the platform's locale default.
with open("table_metadata_store/graph_db/db_config.json", encoding="utf-8") as f:
    db_config = json.load(f)

URI = "bolt://localhost:7687"
AUTH = (db_config["username"], db_config["password"])

driver = GraphDatabase.driver(URI, auth=AUTH)

try:
    with driver.session() as session:
        # Consume the result so a failed delete raises here, before the
        # success message is printed.
        session.run("MATCH (n) DETACH DELETE n").consume()
        print("All nodes and relationships have been deleted.")
finally:
    # Release the driver's connections even when the query fails.
    driver.close()

**Explanation:**  
- [WARNING] This script deletes all nodes and relationships in the Neo4j database.

### Steiner Tree Algorithm for Join Path
====================================

This code finds the additional tables needed to connect the Top-K tables into a single join path.

The following code is in:
- `steiner_tree.py`



In [None]:
# Code from steiner_tree.py

from neo4j import GraphDatabase
import heapq
from collections import defaultdict

class KouMarkowskyAlgorithm:
    """Build a tree that connects a set of terminal nodes in the schema graph.

    The graph for one database is read from Neo4j, then the tree is grown
    from the first terminal by repeatedly attaching the closest terminal that
    is not connected yet.

    NOTE(review): the loop in ``steiner_tree`` attaches one nearest terminal
    per round via a multi-source shortest-path search; confirm this is the
    intended variant of the heuristic the class is named after.
    ``_get_subgraph`` and ``_multi_source_dijkstra`` are defined outside this
    excerpt.
    """

    def __init__(self, uri, auth):
        """Open a Neo4j driver for ``uri`` authenticated with ``auth``."""
        self.driver = GraphDatabase.driver(uri, auth=auth)

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def steiner_tree(self, terminals, db_id):
        """Connect ``terminals`` and report the nodes, edges and join columns.

        Args:
            terminals: Sequence of node names that must be connected. The
                first entry seeds the tree; ties between equally distant
                terminals are broken in the order given here.
            db_id: Database whose subgraph is loaded.

        Returns:
            A dict with
            ``nodes`` - tree nodes in the order they were attached,
            ``edges`` - tree edges as two-element lists,
            ``involved_columns`` - both endpoints of every tree edge whose
            type is ``HAS_PRIMARY_KEY`` or ``FOREIGN_KEY_TO``,
            ``edge_types`` - the edge-type mapping returned by
            ``_get_subgraph``.
            Terminals that cannot be reached are left out of the tree. For an
            empty ``terminals`` every entry is empty and the database is not
            queried.
        """
        if not terminals:
            # Nothing to connect; avoid indexing terminals[0] below.
            return {"nodes": [], "edges": [], "involved_columns": [], "edge_types": {}}

        # Use the session as a context manager so it is released even when
        # reading the subgraph fails.
        with self.driver.session() as session:
            graph, edge_types = session.execute_read(
                KouMarkowskyAlgorithm._get_subgraph, db_id
            )

        unreachable = float('inf')
        start = terminals[0]
        terminal_set = set(terminals)
        tree_nodes = [start]   # attachment order, returned to the caller
        in_tree = {start}      # same content as tree_nodes, O(1) membership
        tree_edges = set()
        involved_columns = set()
        used_terminals = {start}

        while used_terminals != terminal_set:
            frontier, parents = self._multi_source_dijkstra(graph, set(tree_nodes))

            # Pick the closest terminal not yet in the tree. Iterating in
            # caller order (instead of over a set) makes tie-breaking
            # reproducible between runs.
            min_dist, best_terminal = unreachable, None
            for terminal in terminals:
                if terminal in used_terminals:
                    continue
                # assumes frontier is a dict-like mapping node -> distance;
                # a terminal missing from it is treated as unreachable.
                dist = frontier.get(terminal, unreachable)
                if dist < min_dist:
                    min_dist = dist
                    best_terminal = terminal
            if best_terminal is None:
                # The remaining terminals are disconnected from the tree.
                break

            # Reconstruct path: walk the parent pointers back to the tree.
            path = []
            current = best_terminal
            while parents[current] is not None:
                prev = parents[current]
                edge = frozenset([prev, current])
                tree_edges.add(edge)
                rel_type = edge_types.get(edge)
                if rel_type in ("HAS_PRIMARY_KEY", "FOREIGN_KEY_TO"):
                    involved_columns.update([prev, current])
                if current not in in_tree:
                    path.append(current)
                current = prev
            if current not in in_tree:
                path.append(current)

            # The path was collected terminal-first; append it tree-first.
            for node in reversed(path):
                tree_nodes.append(node)
                in_tree.add(node)
            used_terminals.add(best_terminal)

        return {
            "nodes": tree_nodes,
            "edges": [list(edge) for edge in tree_edges],
            "involved_columns": list(involved_columns),
            "edge_types": edge_types
        }

**Explanation:**  
- This function loads the subgraph for the given `db_id`.
- It then runs the Kou-Markowsky-Berman algorithm to connect the terminal tables (the tables from Top K) through steiner tables (the extra tables needed to connect the terminal tables).
- It tracks which columns are involved in PK/FK relationships along the join path.

### Metadata Extraction for Top-K Tables
====================================

This code returns all the metadata of the top-k tables that is needed for the Text2SQL prompt.

The following code is in:
- `query_metadata_top_k.py`

In [None]:
# Code from query_metadata_top_k.py (core logic)

from neo4j import GraphDatabase
from collections import defaultdict
import json
from steiner_tree import KouMarkowskyAlgorithm

class TopKSteinerMetadata:
    """Collect metadata for the top-k tables plus the tables joining them.

    ``extract_info_top_k``, ``get_table_metadata`` and
    ``get_steiner_connection_metadata`` are defined outside this excerpt.
    """

    def __init__(self, uri, auth):
        """Open a Neo4j driver and a Steiner-tree solver for the same server."""
        self.driver = GraphDatabase.driver(uri, auth=auth)
        self.steiner = KouMarkowskyAlgorithm(uri, auth)

    def close(self):
        """Close this object's driver, then the solver's."""
        self.driver.close()
        self.steiner.close()

    def run(self, top_k):
        """Return terminal and steiner table metadata for every database.

        Args:
            top_k: Passed unchanged to ``extract_info_top_k``, which yields a
                mapping of ``db_id`` to the tables requested for it.

        Returns:
            ``{db_id: {"terminal_tables": ..., "steiner_tables": ...}}`` where
            terminal tables are requested tables present in the Steiner tree
            and steiner tables are the remaining tables of that tree.
        """
        results = {}
        for db_id, tables in self.extract_info_top_k(top_k).items():
            tree = self.steiner.steiner_tree(tables, db_id)
            join_columns = set(tree.get("involved_columns", []))
            requested = set(tables)

            # Reduce tree nodes to table names: "<db>.<table>" is a table,
            # "<db>.<table>.<column>" is mapped to its owning table.
            tree_tables = set()
            for node in tree["nodes"]:
                dots = node.count(".")
                if dots == 1:
                    tree_tables.add(node)
                    continue
                if dots != 2:
                    continue
                owner = KouMarkowskyAlgorithm.extract_table_name(node)
                if owner:
                    tree_tables.add(owner)

            terminal_tables = list(tree_tables & requested)
            steiner_tables = list(tree_tables - requested)
            terminal_metadata = self.get_table_metadata(db_id, terminal_tables)
            steiner_metadata = self.get_steiner_connection_metadata(
                db_id, steiner_tables, join_columns
            )
            results[db_id] = {
                "terminal_tables": terminal_metadata,
                "steiner_tables": steiner_metadata,
            }
        return results

**Explanation:**  
- For each database, the top-k tables are grouped and the Steiner tree is computed.
- Metadata is extracted for terminal and steiner tables.
- Terminal tables get all of their columns, column types, primary keys, and foreign keys extracted.
- Steiner tables get only the information necessary to connect the terminal tables together.

### Summary
- **Schema loading**: `create_graph_db.py`
- **Database reset**: `delete_graph_db.py`
- **Steiner tree join path**: `steiner_tree.py`
- **Metadata extraction for top-k**: `query_metadata_top_k.py`
