# Generative AI + GraphRAG Demo
## House/Apartment Adjacency (Kùzu-based, No-Geometry, Global-Candidate Builder)

In [None]:
# You don't need to run this cell if topologicpy is installed with pip.
# NOTE: the path below points at the original author's local checkout of the
# topologicpy sources (its `src` folder); adjust it to your own checkout or skip
# this cell entirely.
import sys
sys.path.append("C:/Users/sarwj/OneDrive - Cardiff University/Documents/GitHub/topologicpy/src")

In [1]:
# --- TopologicPy imports ---
from topologicpy.Vertex import Vertex
from topologicpy.Topology import Topology
from topologicpy.Dictionary import Dictionary
from topologicpy.Kuzu import Kuzu
from topologicpy.Graph import Graph

In [None]:
"""
GraphRAG Demo — House/Apartment Adjacency (Kùzu-based, No-Geometry, Global-Candidate Builder)
=====================================================================================

**What this is**
A compact, Jupyter-friendly demo that:
1) Reads *TopologicPy-like* graphs (JSON) from a folder.
2) Builds *topological* graphs strictly from the file's **vertices** and **edges** (ignores geometry entirely).
3) Loads them into a Kùzu DB using your `Kuzu.py` schema (Graph, Vertex, Edge).
4) **New logic:** At each iteration, we:
   - Build a **global candidate list** of neighbor labels by querying *all graphs* for the labels present in the **currently built graph** (frequency-ranked).
   - Ask the LLM to pick **one** action: either
     - **ADD** a node (may choose from the list or propose a new label not in the list) and connect it to a chosen existing node, or
     - **CONNECT** two existing nodes (no new node).
   - Apply the action to the **working graph** (we create/update it in the DB).
   - Save a full-graph snapshot and repeat until a stopping rule is met.

Notes
-----
- We *ignore* any polygon/geometry in JSON and rely solely on `vertices` and `edges`.
- Vertices get `label` from the first of `node_name`, `roomtype`, or `zone_name` that is present; fallback to vertex id.
- `x,y,z` default to 0.0 if missing. Original vertex/edge dicts preserved in `props` JSON.
- If `ANTHROPIC_API_KEY` is not set or Anthropic SDK is unavailable, a deterministic heuristic is used so the demo still runs.
- Edge suggestions, when accepted, are inserted with label `"suggested"` (bidirectional for simplicity).
- **Requested enhancement:** the seed node now **copies props (and x,y,z if present)** from the best-matching example across all graphs.

"""
from __future__ import annotations
import os, json, glob
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from collections import Counter

# --- Kùzu manager (ensure Kuzu.py is on sys.path or in the same directory) ---
from topologicpy.Kuzu import Kuzu

# --- Optional Claude/Anthropic (used only if available + key set) ---
try:
    from anthropic import Anthropic
    from dotenv import load_dotenv
    load_dotenv()
    anthropic_available = True
except Exception:
    anthropic_available = False

# --- Optional TopologicPy for snapshots -> real Graph objects ---
try:
    from topologicpy.Graph import Graph as TPGraph
    from topologicpy.Vertex import Vertex as TPVertex
    from topologicpy.Edge import Edge as TPEdge
    from topologicpy.Dictionary import Dictionary as TPDict
    from topologicpy.Topology import Topology as TPTopology
    _TOPOLOGICPY_AVAILABLE = True
except Exception:
    _TOPOLOGICPY_AVAILABLE = False

# ---------------------
# Data models
# ---------------------
@dataclass
class Vtx:
    """One vertex read from a TopologicPy-like graph JSON file.

    ``props`` holds the vertex's original dictionary from the file unchanged.
    """

    id: str  # local vertex id (the key in the JSON "vertices" mapping)
    label: str  # display label, e.g. a room name
    x: float
    y: float
    z: float
    props: dict[str, Any]  # original vertex dictionary (stored as JSON in the DB)

@dataclass
class ERel:
    """One directed edge read from a TopologicPy-like graph JSON file.

    Endpoints are local vertex ids; ``props`` holds the edge's original
    dictionary from the file unchanged.
    """

    src: str  # local id of the source vertex
    dst: str  # local id of the target vertex
    label: str  # edge label / connectivity type (e.g. "door")
    props: dict[str, Any]  # original edge dictionary (stored as JSON in the DB)

# ---------------------
# JSON → (Vertices, Edges)
# ---------------------

def load_topologic_graph(path: str) -> tuple[list[Vtx], list[ERel]]:
    """Load a TopologicPy-like graph JSON. We ignore geometry; we only use vertices and edges dictionaries.

    Expected (flexible) shape:
    {
      "vertices": {
          "Vertex_0000": {"node_name": "Entrance", "x": 1.2, "y": 3.4, ...},
          ...
      },
      "edges": {
          "Edge_00": {"source": "Vertex_0000", "target": "Vertex_0004", "connectivity": "door", ...},
          ...
      }
    }

    Vertex label: first present of ``node_name``, ``roomtype``, ``zone_name``,
    otherwise the vertex id. Coordinates default to 0.0 when missing or null.
    Edge label: ``connectivity``, else ``label``, else ``"adjacent"``.
    Edges without a usable ``source`` or ``target`` are skipped.

    Parameters
    ----------
    path : str
        Path to a UTF-8 encoded JSON file.

    Returns
    -------
    (vertices, edges) : tuple[list[Vtx], list[ERel]]
        Each object's original dictionary is kept in its ``props``.
    """

    def _coord(value) -> float:
        # A missing key and an explicit JSON null both fall back to 0.0.
        return 0.0 if value is None else float(value)

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    raw_vs: Dict[str, Dict[str, Any]] = data.get("vertices", {}) or {}
    raw_es: Dict[str, Dict[str, Any]] = data.get("edges", {}) or {}

    vertices: list[Vtx] = []
    for vid, v in raw_vs.items():
        label = v.get("node_name") or v.get("roomtype") or v.get("zone_name") or str(vid)
        vertices.append(Vtx(id=vid, label=str(label),
                            x=_coord(v.get("x")), y=_coord(v.get("y")), z=_coord(v.get("z")),
                            props=v))

    edges: list[ERel] = []
    for e in raw_es.values():
        src_raw = e.get("source")
        dst_raw = e.get("target")
        # Test for None BEFORE str(): str(None) == "None" is truthy and would
        # otherwise slip through the emptiness check below.
        if src_raw is None or dst_raw is None:
            continue
        src, dst = str(src_raw), str(dst_raw)
        if not src or not dst:
            continue
        label = str(e.get("connectivity") or e.get("label") or "adjacent")
        edges.append(ERel(src=src, dst=dst, label=label, props=e))

    return vertices, edges

# ---------------------
# Kùzu helpers
# ---------------------

def ensure_schema(manager):
    """Ask ``Kuzu.EnsureSchema`` to make sure the DB schema behind *manager* exists.

    Per the module notes, this is the Graph/Vertex/Edge schema from Kuzu.py.
    Output is not silenced (``silent=False``). Returns None.
    """
    Kuzu.EnsureSchema(manager, silent=False)
    return None


def upsert_graph(manager, graph_id: str, vertices: list[Vtx], edges: list[ERel], undirected: bool):
    """Replace graph `graph_id` in Kùzu with the given vertices/edges, using raw Cypher.

    Steps: make sure the schema exists; delete the existing edges, vertices and
    Graph card stored under `graph_id`; create a new Graph card; then create one
    Vertex per input vertex and one Edge per input edge.
    If undirected=True, we create two directed edges for each input edge.

    Stored vertex ids are namespaced as "<graph_id>:<local id>", so different
    graphs may reuse the same local ids. Vertex/edge `props` are stored as JSON
    strings (json.dumps).

    Parameters
    ----------
    manager : Kùzu manager; presumably exposes `exec(query, params, write=...)`
        as used throughout this module.
    graph_id : str
        Id of the graph to (re)create; also used as the Graph card label.
    vertices : list[Vtx]
    edges : list[ERel]
        Edge endpoints are *local* vertex ids. For an edge whose endpoint is not
        among `vertices`, the MATCH clause finds nothing, so presumably no edge
        is created for it (no error is raised here).
    undirected : bool
        Also create the reverse edge for every input edge.
    """
    ensure_schema(manager)

    # Clear prior graph with same id
    # (edges first: only edges whose BOTH endpoints belong to this graph are removed)
    manager.exec("MATCH (a:Vertex)-[r:Edge]->(b:Vertex) WHERE a.graph_id=$gid AND b.graph_id=$gid DELETE r;",
                 {"gid": graph_id}, write=True)
    manager.exec("MATCH (v:Vertex) WHERE v.graph_id=$gid DELETE v;", {"gid": graph_id}, write=True)
    manager.exec("MATCH (g:Graph) WHERE g.id=$id DELETE g;", {"id": graph_id}, write=True)

    # Create Graph card
    # NOTE(review): num_edges is len(edges) even when undirected=True stores two
    # directed edges per input edge — confirm which count consumers expect.
    manager.exec(
        """
        CREATE (g:Graph {id:$id, label:$label, num_nodes:$n, num_edges:$m, props:$props});
        """,
        {"id": graph_id, "label": graph_id, "n": len(vertices), "m": len(edges), "props": json.dumps({})},
        write=True,
    )

    # Insert vertices

    query = """
    CREATE (v:Vertex {
        id:$id,
        graph_id:$gid,
        label:$label,
        x:$x,
        y:$y,
        z:$z,
        props:$props
    });
    """

    # One CREATE per vertex; the id is namespaced with the graph id.
    for v in vertices:
        params = {
                "id": f"{graph_id}:{v.id}",
                "gid": graph_id,
                "label": v.label,
                "x": float(v.x),
                "y": float(v.y),
                "z": float(v.z),
                "props": json.dumps(v.props),
                }
        manager.exec(query, params, write=True)
    # Insert edges (directed; if undirected, add reverse)
    for e in edges:
        # Endpoints are translated from local ids to the namespaced stored ids.
        params = {"a": f"{graph_id}:{e.src}", "b": f"{graph_id}:{e.dst}", "label": e.label, "props": json.dumps(e.props)}
        manager.exec(
            """
            MATCH (va:Vertex {id:$a}), (vb:Vertex {id:$b})
            CREATE (va)-[:Edge {label:$label, props:$props}]->(vb);
            """,
            params,
            write=True,
        )
        if undirected:
            # Same label/props, opposite direction.
            manager.exec(
                """
                MATCH (va:Vertex {id:$a}), (vb:Vertex {id:$b})
                CREATE (vb)-[:Edge {label:$label, props:$props}]->(va);
                """,
                params,
                write=True,
            )

# --- Small builders for the *working* graph we are constructing ---

def create_graph_card_if_missing(manager, graph_id: str):
    """Create an empty Graph node for *graph_id* unless one already exists.

    A new card uses the graph id as its label, has zero node/edge counts and
    an empty JSON ``props`` ('{}').
    """
    existing = manager.exec("MATCH (g:Graph {id:$id}) RETURN 1 LIMIT 1", {"id": graph_id}, write=False)
    if existing:
        return
    manager.exec(
        "CREATE (g:Graph {id:$id, label:$id, num_nodes:0, num_edges:0, props:'{}'})",
        {"id": graph_id},
        write=True,
    )


def create_vertex(manager, graph_id: str, local_id: str, label: str, props: Dict[str,Any] | None = None,
                  x: float = 0.0, y: float = 0.0, z: float = 0.0):
    """Add one Vertex to *graph_id*, creating the graph's Graph card first if needed.

    The stored id is namespaced as ``"<graph_id>:<local_id>"``. Coordinates are
    coerced to float, and *props* (empty when not given or falsy) is stored as
    a JSON string.
    """
    create_graph_card_if_missing(manager, graph_id)
    params = {
        "id": f"{graph_id}:{local_id}",
        "gid": graph_id,
        "label": label,
        "x": float(x),
        "y": float(y),
        "z": float(z),
        "props": json.dumps(props if props else {}),
    }
    cypher = """
        CREATE (:Vertex {id:$id, graph_id:$gid, label:$label, x:$x, y:$y, z:$z, props:$props});
        """
    manager.exec(cypher, params, write=True)


def edge_exists(manager, graph_id: str, a_local: str, b_local: str) -> bool:
    """Return True if a directed Edge ``a_local -> b_local`` exists in *graph_id*.

    Only that one direction is checked; swap the ids to test the reverse.
    Local ids are namespaced with the graph id before matching.
    """
    found = manager.exec(
        """
        MATCH (a:Vertex {id:$a})-[:Edge]->(b:Vertex {id:$b}) RETURN 1 LIMIT 1
        """,
        {"a": f"{graph_id}:{a_local}", "b": f"{graph_id}:{b_local}"},
        write=False,
    )
    return len(found or []) != 0


def create_edge_bidirectional(manager, graph_id: str, a_local: str, b_local: str, label: str = "suggested",
                              props: Dict[str,Any] | None = None):
    """Connect two vertices of *graph_id* with a pair of opposite directed Edges.

    The ids may carry a trailing annotation such as ``"n3 (Kitchen)"`` (the form
    the LLM prompt asks for); only the first whitespace-separated token is used.

    Only the missing direction(s) are created. A pair that is already connected
    in one direction is completed rather than having that edge duplicated.

    Parameters
    ----------
    manager : Kùzu manager exposing ``exec(query, params, write=...)``.
    graph_id : str
        Graph whose vertices are connected.
    a_local, b_local : str
        Local vertex ids (optionally annotated, see above).
    label : str
        Edge label (default ``"suggested"``).
    props : dict | None
        Edge properties, stored as a JSON string (empty when None).

    Returns
    -------
    bool
        True if at least one edge-creation query was issued. This is not
        verified against vertex existence: if either vertex is absent, MATCH
        finds nothing. False when an id is empty, when both ids are the same
        (self-loop), or when both directions already exist.
    """
    a_tokens = str(a_local or "").split()
    b_tokens = str(b_local or "").split()
    if not a_tokens or not b_tokens:
        print("Warning: Missing vertex id for edge. Skipping.")
        return False
    a_id, b_id = a_tokens[0], b_tokens[0]
    if a_id == b_id:
        # A room adjacent to itself is meaningless in an adjacency graph.
        print("Warning: Self-loop requested. Skipping.")
        return False

    missing = []
    if not edge_exists(manager, graph_id, a_id, b_id):
        missing.append((a_id, b_id))
    if not edge_exists(manager, graph_id, b_id, a_id):
        missing.append((b_id, a_id))
    if not missing:
        print("Warning: Edge already exists. Skipping.")
        return False

    props_json = json.dumps(props or {})
    for src, dst in missing:
        manager.exec(
            """
            MATCH (a:Vertex {id:$a}), (b:Vertex {id:$b})
            CREATE (a)-[:Edge {label:$lbl, props:$props}]->(b);
            """,
            {"a": f"{graph_id}:{src}", "b": f"{graph_id}:{dst}", "lbl": label, "props": props_json},
            write=True,
        )
    return True


def list_working_nodes(manager, graph_id: str) -> list[Dict[str, Any]]:
    """Return the vertices of *graph_id* as ``[{"id", "label", "props"}, ...]``.

    Rows are ordered by stored id. ``id`` is the local id with the
    ``"<graph_id>:"`` namespace prefix removed; this is correct even when the
    graph id itself contains ``':'``. ``label`` is never None (NULL becomes "").
    ``props`` is returned exactly as stored (a JSON string for vertices written
    by this module).
    """
    rows = manager.exec(
        "MATCH (v:Vertex) WHERE v.graph_id=$gid RETURN v.id AS id, v.label AS label, v.props AS props ORDER BY id",
        {"gid": graph_id}, write=False
    ) or []
    prefix = f"{graph_id}:"

    def _local(stored_id: str) -> str:
        # Strip the known namespace; fall back to "after the first ':'" otherwise.
        if stored_id.startswith(prefix):
            return stored_id[len(prefix):]
        return stored_id.split(":", 1)[-1]

    return [{"id": _local(r["id"]),
             "label": r.get("label") or "",
             "props": r.get("props")} for r in rows]

def list_working_edges(manager, graph_id: str) -> list[Dict[str, Any]]:
    """
    Return every directed edge whose two endpoints belong to *graph_id*, as dicts:
    [{'a': 'n0', 'b': 'n1', 'label': 'suggested', 'props': '<json string>'}, ...]

    ``a``/``b`` are local vertex ids with the ``"<graph_id>:"`` prefix removed
    (correct even when the graph id contains ':'). ``label`` is never None.
    ``props`` is returned as stored: a JSON string for edges written by this
    module, or ``{}`` when the row has no props column.
    A pair stored in both directions appears twice.
    """
    rows = manager.exec(
        """
        MATCH (a:Vertex)-[r:Edge]->(b:Vertex)
        WHERE a.graph_id=$gid AND b.graph_id=$gid
        RETURN a.id AS a, b.id AS b, r.label AS label, r.props AS props
        """,
        {"gid": graph_id}, write=False
    ) or []
    prefix = f"{graph_id}:"

    def _local(stored_id: str) -> str:
        # Strip the known namespace; fall back to "after the first ':'" otherwise.
        if stored_id.startswith(prefix):
            return stored_id[len(prefix):]
        return stored_id.split(":", 1)[-1]

    return [
        {
            "a": _local(r["a"]),
            "b": _local(r["b"]),
            "label": r.get("label") or "",
            "props": r.get("props", {}),
        }
        for r in rows
    ]

def max_neighbors_for_label(
    manager,
    label_query: str,
    *,
    undirected: bool = True,
    substring: bool = True,
    selection_mode: str = "first",            # "first" | "max_out_degree_within_graph"
    include_graph_ids: set[str] | None = None,
    exclude_graph_ids_prefixes: tuple[str, ...] = ("work_",),
    exclude_edge_labels: set[str] = frozenset({"suggested"})
) -> int:
    """
    Return the largest number of unique neighbors that a vertex matching
    `label_query` has, across all allowed graphs in the DB.

    Only **one representative vertex per graph** is considered, to avoid
    inflation when a graph contains several vertices with the same label.

    Parameters
    ----------
    manager : Kùzu manager exposing `exec(query, params, write=...)`.
    label_query : str
        Label to look for, compared case-insensitively after stripping.
        An empty/None query matches every label when `substring=True`.
    undirected : bool
        Count in- and out-neighbors together (True) or out-neighbors only.
    substring : bool
        True: match when `label_query` is a substring of the vertex label.
        False: match only on (case-insensitive) equality.
    selection_mode : str
        How the per-graph representative is chosen:
          - "first": smallest v.id (string order; stable & fast). Also used for
            any unrecognised value.
          - "max_out_degree_within_graph": the matching vertex with the highest
            out-degree (using filtered edges); its neighbors are then counted
            according to `undirected`.
    include_graph_ids : set[str] | None
        If given, only these graph ids are considered, and the prefix exclusion
        is not applied.
    exclude_graph_ids_prefixes : tuple[str, ...]
        Graph ids starting with any of these prefixes are ignored (default
        "work_", presumably the working graphs being built).
    exclude_edge_labels : set[str]
        Edges with these labels are ignored (default: "suggested").

    Returns
    -------
    int
        Maximum neighbor count over the representatives; 0 if nothing matches.

    Note: every call fetches all vertices and all edges of the DB.
    """

    # --- Fetch vertices
    rows_v = manager.exec(
        """
        MATCH (v:Vertex)
        RETURN v.id AS id, v.graph_id AS gid, v.label AS label
        """,
        {}, write=False
    ) or []

    def graph_allowed(gid: str) -> bool:
        # An explicit allow-list takes precedence over the prefix exclusion.
        if include_graph_ids is not None:
            return gid in include_graph_ids
        return not any(gid.startswith(pref) for pref in exclude_graph_ids_prefixes)

    needle = (label_query or "").strip().lower()
    # group matching vertices by graph id
    matches_by_gid: dict[str, list[str]] = {}
    # NOTE(review): labels_by_vid is filled below but never read in this function.
    labels_by_vid: dict[str, str] = {}

    for r in rows_v:
        vid = r["id"]
        # NOTE(review): a NULL graph_id would come back as None and break
        # graph_allowed(); assumed not to happen for vertices written by this module.
        gid = r.get("gid", "")
        lbl = str(r.get("label") or "")
        if not graph_allowed(gid):
            continue
        labels_by_vid[vid] = lbl
        lbln = lbl.strip().lower()
        ok = (needle in lbln) if substring else (lbln == needle)
        if ok:
            matches_by_gid.setdefault(gid, []).append(vid)

    if not matches_by_gid:
        return 0

    # --- Fetch edges once (filtered)
    rows_e = manager.exec(
        """
        MATCH (a:Vertex)-[r:Edge]->(b:Vertex)
        RETURN a.id AS a, a.graph_id AS agid, b.id AS b, b.graph_id AS bgid, r.label AS rlabel
        """,
        {}, write=False
    ) or []

    # Keep only edges with both endpoints in allowed graphs and a non-excluded label.
    edges = []
    for r in rows_e:
        a, agid = r["a"], r.get("agid", "")
        b, bgid = r["b"], r.get("bgid", "")
        if not (graph_allowed(agid) and graph_allowed(bgid)):
            continue
        if str(r.get("rlabel") or "") in exclude_edge_labels:
            continue
        edges.append((a, b))

    # Pre-index neighbors for quick degree checks
    # (sets, so parallel edges between the same two vertices count once)
    out_neighbors: dict[str, set[str]] = {}
    in_neighbors: dict[str, set[str]] = {}
    for a, b in edges:
        out_neighbors.setdefault(a, set()).add(b)
        in_neighbors.setdefault(b, set()).add(a)

    # Choose exactly one representative per graph
    rep_ids: set[str] = set()
    if selection_mode == "max_out_degree_within_graph":
        for gid, vids in matches_by_gid.items():
            # pick the vertex with max OUT-degree (based on filtered edges)
            best_vid = max(vids, key=lambda v: len(out_neighbors.get(v, set())))
            rep_ids.add(best_vid)
    else:  # "first" (deterministic by min id)
        for gid, vids in matches_by_gid.items():
            rep_ids.add(min(vids))

    # Compute neighbor counts for representatives only
    def neighbor_set(v: str) -> set[str]:
        # Undirected: union, so a->b plus b->a counts as ONE neighbor.
        if undirected:
            return out_neighbors.get(v, set()) | in_neighbors.get(v, set())
        return out_neighbors.get(v, set())

    return max((len(neighbor_set(v)) for v in rep_ids), default=0)



# ---------------------
# Global candidate list (across *all* graphs)
# ---------------------

def _working_vertices_and_edges(manager, working_graph_id: str):
    """Return (vertices, edges) for the working graph.
       vertices: [{id_local, label, props}]
       edges:    [(a_local, b_local)]   (directed, as stored)
    """
    rows_v = manager.exec(
        """
        MATCH (v:Vertex) WHERE v.graph_id=$gid
        RETURN v.id AS id, v.label AS label, v.props AS props
        """,
        {"gid": working_graph_id}, write=False
    ) or []
    rows_e = manager.exec(
        """
        MATCH (a:Vertex)-[:Edge]->(b:Vertex)
        WHERE a.graph_id=$gid AND b.graph_id=$gid
        RETURN a.id AS a, b.id AS b
        """,
        {"gid": working_graph_id}, write=False
    ) or []

    vertices = [{"id_local": r["id"].split(":", 1)[1],
                 "label": r.get("label", ""),
                 "props": r.get("props", {})} for r in rows_v]
    edges = [(r["a"].split(":", 1)[1], r["b"].split(":", 1)[1]) for r in rows_e]
    return vertices, edges


def _anchor_labels_with_degree_cap(manager, working_graph_id: str, oracle: dict) -> tuple[list[str], dict]:
    """Select the labels of working-graph nodes that may still be expanded.

    For each vertex of the working graph, compute its *undirected* degree. This
    is the number of unique neighbors, so a->b plus b->a counts once. Compare it
    with ``max_neighbors_for_label`` for the vertex's label across the DB:
      - degree <= that maximum: the label becomes an anchor;
      - degree  > that maximum: the label is marked saturated (``oracle[label] = 0``).
    Labels already marked saturated in *oracle* are skipped without querying.

    Parameters
    ----------
    manager : Kùzu manager.
    working_graph_id : str
        Id of the graph being built.
    oracle : dict
        Maps label -> 0 for saturated labels. Mutated in place and also returned.

    Returns
    -------
    (anchor_labels, oracle)
        Sorted, de-duplicated, non-empty anchor labels, and the updated oracle.
    """
    vertices, edges = _working_vertices_and_edges(manager, working_graph_id)

    # Build undirected adjacency (so an edge a->b and b->a counts as ONE connection)
    adj = {v["id_local"]: set() for v in vertices}
    for a, b in edges:
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)

    id_to_label = {v["id_local"]: v["label"] for v in vertices}
    # max_neighbors_for_label scans the whole DB on every call, so memoize it
    # per label for the duration of this call (the DB is not modified here).
    max_degree_by_label: dict = {}
    anchor_labels = []
    print(f" The following nodes have less connections than what the maximum found in the graph database so they will be considered for expansion:")
    for vid, nbrs in adj.items():
        deg = len(nbrs)
        anchor_label = id_to_label.get(vid, "")
        if oracle.get(anchor_label, 1) == 0:
            continue  # already known to be saturated
        if anchor_label not in max_degree_by_label:
            max_degree_by_label[anchor_label] = max_neighbors_for_label(manager, anchor_label)
        max_degree = max_degree_by_label[anchor_label]
        if deg <= max_degree:
            print(f"  . {anchor_label} (No. Connections: {deg}, Max found in DB: {max_degree}) ")
            anchor_labels.append(anchor_label)
        else:
            oracle[anchor_label] = 0

    # Clean, dedupe, keep non-empty
    return sorted({lbl for lbl in anchor_labels if lbl}), oracle


def fetch_all_pairs(manager, working_graph_id: str, oracle: dict) -> tuple[list[tuple[str, str]], dict]:
    """
    Collect (a_label, b_label) edge pairs from the *entire* DB whose source label is an anchor.

    Anchors are the labels of working-graph nodes whose undirected degree does
    not exceed the per-label maximum found in the DB (see
    ``_anchor_labels_with_degree_cap``). Labels are compared case-insensitively
    after stripping.

    Returns
    -------
    (pairs, oracle)
        ``pairs`` is a list of (a_label, b_label), one per stored directed edge
        (so bidirectionally stored edges contribute twice). ``oracle`` is the
        updated saturation map. It is returned even when there are no anchors,
        so callers that rebind their oracle keep the saturation marks.
    """
    # 1) Get anchor labels from the working graph, filtered by degree cap
    anchors, oracle = _anchor_labels_with_degree_cap(manager, working_graph_id, oracle)
    if not anchors:
        return [], oracle  # nothing to expand from

    # Strip like the DB labels below, so stray whitespace cannot prevent a match.
    anchors_lower = {a.strip().lower() for a in anchors}

    # 2) Pull all pairs across all graphs (no graph_id filter), then filter by anchor labels in Python
    rows = manager.exec(
        """
        MATCH (a:Vertex)-[:Edge]->(b:Vertex)
        RETURN a.label AS a_label, b.label AS b_label
        """,
        {}, write=False
    ) or []

    pairs = []
    for r in rows:
        a_lab = str(r.get("a_label") or "").strip()
        b_lab = str(r.get("b_label") or "").strip()
        if a_lab and b_lab and a_lab.lower() in anchors_lower:
            pairs.append((a_lab, b_lab))

    return pairs, oracle

def candidate_counts_for_labels(manager, working_graph_id: str, labels: list[str], oracle: dict) -> tuple[list[tuple[str, int]], dict]:
    """Aggregate neighbor-label frequencies across *all* graphs for source labels in *labels*.

    Source labels are matched case-insensitively after stripping. The pairs come
    from ``fetch_all_pairs``, so only anchors of the working graph contribute.

    Returns
    -------
    (counts, oracle)
        ``counts`` is ``[(neighbor_label, frequency), ...]`` sorted by descending
        frequency, then by label. ``oracle`` is the updated saturation map.
    """
    pairs, oracle = fetch_all_pairs(manager, working_graph_id=working_graph_id, oracle=oracle)
    label_set = {(lbl or "").strip().lower() for lbl in labels}
    cnt = Counter(b for (a, b) in pairs if a.lower() in label_set)
    cnt.pop("", None)  # defensive: never propose an empty label
    return sorted(cnt.items(), key=lambda kv: (-kv[1], kv[0])), oracle

# ---------------------
# Seed props copier — find best example for a label across all graphs
# ---------------------

import math
from collections import Counter
from statistics import median, mean
from typing import Optional, Dict, Any

def find_best_example_for_label(manager, attach_to, label_substring: str) -> Optional[Dict[str, Any]]:
    """
    Search the entire Kùzu DB for edges (attach_to_label -> target_label). Use them to
    pick an example target vertex and a RECOMMENDED position for a new target node.

    The most popular direction (16 angle bins of 22.5° in the XY plane) and the
    typical distance (median within that bin) of the matching edges are applied
    as an offset from the attach_to anchor.

    Matching is case-insensitive and uses only the FIRST word of each label as a
    substring test (e.g. "Living" from "Living Room" also matches "Living Area").
    An empty, blank or None label matches every vertex on that side.

    Parameters
    ----------
    manager : Kùzu manager exposing `exec(query, params, write=...)`.
    attach_to : str | dict | None
        - str: label of the attach node (e.g. "Entrance"). Anchor coords are (0,0,0).
        - dict: {"label": ...} plus optional {"x","y","z"}. Coordinates that are
                present and not None are used as the anchor; others default to 0.0.
        - anything else (e.g. None): no source-label filter, anchor at the origin.
    label_substring : str
        Target node label (only its first word is used for matching).

    Returns
    -------
    dict or None:
      {
        "best_example": {"gid","id","label","x","y","z","props"},
        "recommended": {"x": float, "y": float, "z": float, "distance": float}
      }
      ``best_example`` is the first matching edge target carrying the most
      frequent matching target label. Returns None if no edge in the corpus matches.
      NULL coordinates in the DB are treated as 0.0.
    """

    def _first_word(text) -> str:
        # Lower-cased first token; "" for None/empty/whitespace-only input
        # (the original `.split()[0]` raised IndexError for blank labels).
        words = str(text or "").split()
        return words[0].lower() if words else ""

    def _num(value) -> float:
        # NULL / missing coordinates count as 0.0 instead of raising TypeError.
        return 0.0 if value is None else float(value)

    # Normalize inputs
    anchor_x = anchor_y = anchor_z = 0.0
    if isinstance(attach_to, str):
        attach_to_label = attach_to
    elif isinstance(attach_to, dict):
        attach_to_label = attach_to.get("label", "")
        anchor_x = _num(attach_to.get("x"))
        anchor_y = _num(attach_to.get("y"))
        anchor_z = _num(attach_to.get("z"))
    else:
        attach_to_label = ""

    a_word = _first_word(attach_to_label)
    b_word = _first_word(label_substring)

    # 1) Pull ALL (a -> b) pairs with coordinates
    rows = manager.exec(
        """
        MATCH (a:Vertex)-[:Edge]->(b:Vertex)
        RETURN
          a.graph_id AS agid, a.id AS aid, a.label AS a_label, a.x AS ax, a.y AS ay, a.z AS az,
          b.graph_id AS bgid, b.id AS bid, b.label AS b_label, b.x AS bx, b.y AS by, b.z AS bz, b.props AS bprops
        """,
        {}, write=False
    ) or []

    # 2) Filter in Python (case-insensitive, first-word substring heuristic)
    pairs = []
    for r in rows:
        a_lab = str(r.get("a_label") or "")
        b_lab = str(r.get("b_label") or "")
        if a_word not in a_lab.lower() or b_word not in b_lab.lower():
            continue
        ax, ay, az = _num(r.get("ax")), _num(r.get("ay")), _num(r.get("az"))
        bx, by, bz = _num(r.get("bx")), _num(r.get("by")), _num(r.get("bz"))
        dx, dy, dz = (bx - ax), (by - ay), (bz - az)
        pairs.append({
            "agid": r.get("agid"), "aid": r.get("aid"), "a_label": a_lab, "ax": ax, "ay": ay, "az": az,
            "bgid": r.get("bgid"), "bid": r.get("bid"), "b_label": b_lab, "bx": bx, "by": by, "bz": bz,
            "bprops": r.get("bprops", {}),
            "dx": dx, "dy": dy, "dz": dz, "dist": math.sqrt(dx*dx + dy*dy + dz*dz),
        })

    if not pairs:
        return None

    # 3) Choose a "best example" node: first edge target with the most frequent target label
    target_counts = Counter(p["b_label"] for p in pairs)
    best_target_label, _ = max(target_counts.items(), key=lambda kv: kv[1])
    best_row = next(p for p in pairs if p["b_label"] == best_target_label)
    best_example = {
        "gid": best_row["bgid"],
        "id":  best_row["bid"],
        "label": best_row["b_label"],
        "x": best_row["bx"], "y": best_row["by"], "z": best_row["bz"],
        "props": best_row.get("bprops", {})
    }

    # 4) Determine the most popular direction bin (XY angle) and typical distance
    DIR_BIN_COUNT = 16     # 22.5-degree bins

    def _dir_bin(p) -> int:
        ang = math.degrees(math.atan2(p["dy"], p["dx"])) % 360.0
        return int((ang / 360.0) * DIR_BIN_COUNT) % DIR_BIN_COUNT

    dir_bins = [_dir_bin(p) for p in pairs]
    dir_mode_bin, _ = Counter(dir_bins).most_common(1)[0]
    in_mode_bin = [p for p, db in zip(pairs, dir_bins) if db == dir_mode_bin]

    # Representative unit direction: normalised mean of the unit vectors in the mode bin
    # (zero-length vectors have no direction and are skipped)
    sel_vectors = [(p["dx"]/p["dist"], p["dy"]/p["dist"], p["dz"]/p["dist"])
                   for p in in_mode_bin if p["dist"] > 1e-9]
    if sel_vectors:
        mx = sum(v[0] for v in sel_vectors)/len(sel_vectors)
        my = sum(v[1] for v in sel_vectors)/len(sel_vectors)
        mz = sum(v[2] for v in sel_vectors)/len(sel_vectors)
        norm = math.sqrt(mx*mx + my*my + mz*mz) or 1.0
        unit_dir = (mx/norm, my/norm, mz/norm)
    else:
        unit_dir = (1.0, 0.0, 0.0)  # fallback

    # Typical distance: median of distances in that same direction bin (robust)
    sel_dists = [p["dist"] for p in in_mode_bin]
    rec_dist = median(sel_dists) if sel_dists else 0.0

    # 5) Compute recommended coordinates as an offset from the provided attach_to anchor
    rx = anchor_x + unit_dir[0] * rec_dist
    ry = anchor_y + unit_dir[1] * rec_dist
    rz = anchor_z + unit_dir[2] * rec_dist

    return {
        "best_example": best_example,
        "recommended": {"x": rx, "y": ry, "z": rz, "distance": rec_dist}
    }

# ---------------------
# TopologicPy graph snapshots (FULL graph export)
# ---------------------

def _build_tp_graph(vertices: list[Dict[str, Any]], edges: list[Tuple[str, str]]) -> Any:
    """Return a TopologicPy Graph if available; otherwise a plain dict with vertices/edges.

    Vertices: list of dicts {id, label, x, y, z, props}
    Edges: list of (src_local_id, dst_local_id)

    When TopologicPy is importable:
      - Each vertex is placed at a RANDOM position (x, y uniform in [0, 10], z = 0).
        The x/y/z values in `vertices` are not used here.
      - `props` (a dict, or a JSON string that is parsed) becomes the vertex's
        TopologicPy Dictionary. Unparseable strings are kept under "_raw_props".
        Empty props attach no dictionary. Errors while attaching are silently
        ignored, and `label` is only included if it is already a key in `props`.
      - Each (src, dst) pair becomes one TopologicPy Edge. Pairs with an unknown
        endpoint are skipped. A connection stored in both directions therefore
        yields two Edges (NOTE(review): confirm whether Graph.ByVerticesEdges
        merges them).
      - If Graph construction raises, {"vertices": [...TP vertices], "edges": [...TP edges]}
        is returned instead.
    Without TopologicPy, the inputs are returned unchanged as {"vertices": ..., "edges": ...}.
    """
    import random

    if _TOPOLOGICPY_AVAILABLE:
        id_to_vertex: Dict[str, Any] = {}
        tp_vertices: list[Any] = []
        for v in vertices:
            # Random layout position; stored coordinates are ignored.
            x = random.uniform(0,10)
            y = random.uniform(0,10)
            z = 0
            vx = TPVertex.ByCoordinates(x,y,z)
            # Attach dictionary from props (inherited from graph DB)
            props = v.get("props", {})
            if isinstance(props, str):
                try:
                    props = json.loads(props)
                except Exception:
                    props = {"_raw_props": props}
            if isinstance(props, dict) and props:
                keys = list(props.keys())
                vals = list(props.values())
                try:
                    d = TPDict.ByKeysValues(keys, vals)
                    vx = TPTopology.SetDictionary(vx, d)
                except Exception:
                    # Best effort: keep the vertex without a dictionary.
                    pass
            id_to_vertex[v["id"]] = vx
            tp_vertices.append(vx)
        tp_edges: list[Any] = []
        for (s, t) in edges:
            sv = id_to_vertex.get(s)
            tv = id_to_vertex.get(t)
            if sv is not None and tv is not None:
                tp_edges.append(TPEdge.ByStartVertexEndVertex(sv, tv))
        try:
            return TPGraph.ByVerticesEdges(tp_vertices, tp_edges)
        except Exception:
            return {"vertices": tp_vertices, "edges": tp_edges}
    else:
        return {"vertices": vertices, "edges": edges}


def snapshot_full_graph(manager, graph_id: str) -> Any:
    """Export the **entire current graph** from Kùzu and return a TopologicPy Graph (if available) or a dict.

    Reads every vertex of `graph_id` and every directed edge whose two endpoints
    belong to it. Stored ids are converted to local ids (text after the first
    ':'), and the result is passed to `_build_tp_graph`, which gives vertex
    dictionaries inherited from DB `props`.

    Parameters
    ----------
    manager : Kùzu manager exposing `exec(query, params, write=...)`.
    graph_id : str
        Graph to export.
    """
    import random

    rows_v = manager.exec(
        """
        MATCH (v:Vertex)
        WHERE v.graph_id=$gid
        RETURN v.id AS id, v.label AS label, v.x AS x, v.y AS y, v.z AS z, v.props AS props
        """,
        {"gid": graph_id}, write=False,
    ) or []

    rows_e = manager.exec(
        """
        MATCH (a:Vertex)-[:Edge]->(b:Vertex)
        WHERE a.graph_id=$gid AND b.graph_id=$gid
        RETURN a.id AS a, b.id AS b
        """,
        {"gid": graph_id}, write=False,
    ) or []

    # NOTE: the random.uniform(...) defaults are evaluated for EVERY row (Python
    # evaluates call arguments eagerly), even though the query always returns
    # x/y. They are only used if a row lacks the key; a NULL coordinate comes
    # back as None, not as the random default.
    verts = [{
        "id": r["id"].split(":",1)[1],
        "label": r.get("label",""),
        "x": r.get("x",random.uniform(0,100)*0.1),
        "y": r.get("y",random.uniform(0,100)*0.1),
        "z": r.get("z",0.0),
        "props": r.get("props", {})
    } for r in rows_v]
    eds = [(r["a"].split(":",1)[1], r["b"].split(":",1)[1]) for r in rows_e]
    return _build_tp_graph(verts, eds)

# ---------------------
# Global-candidate logic + LLM action picker — single action per iteration
# ---------------------

def _heuristic_pick_action(current_nodes: list[Dict[str,str]], candidate_counts: list[tuple[str,int]]):
    existing_labels = {n["label"].lower() for n in current_nodes}
    # Try ADD a high-frequency label not already present
    for lab, _ in candidate_counts:
        if lab.lower() not in existing_labels:
            attach_to = current_nodes[0]["id"] if current_nodes else None
            return {"action": "add", "new_label": lab, "attach_to": attach_to}
    # Else CONNECT first two nodes if any
    if len(current_nodes) >= 2:
        return {"action": "connect", "a": current_nodes[0]["id"], "b": current_nodes[1]["id"]}
    return {"action": "stop", "reason": "No candidates and insufficient nodes to connect."}


def llm_pick_action(description,
                    current_nodes: list[Dict[str,str]],
                    candidate_counts: list[tuple[str,int]],
                    current_edges: list
                    ):
    """
    Ask Claude to choose exactly one action for growing the working graph.

    Claude is told that the candidate list is frequency-sorted, and that it may
    also propose a label that is not in the list. The reply is parsed; a
    markdown code fence or text around the JSON object is tolerated. The
    action is normalised to lower case and a one-line suggestion is printed.
    Returns one of:
      {"action":"add","new_label":"Kitchen","attach_to":"<existing_local_id>"}
      {"action":"connect","a":"<existing_local_id>","b":"<existing_local_id>"}
      {"action":"stop","reason":"..."}

    The deterministic `_heuristic_pick_action` is used instead when any of these hold:
      - the Anthropic SDK or the ANTHROPIC_API_KEY environment variable is unavailable;
      - the API call fails (the error is printed);
      - the reply is not a JSON object with a recognised action and the fields
        that action needs.

    Note: ids from the model may carry a label suffix, e.g. "n3 (Kitchen)",
    because the prompt's JSON templates ask for that form.
    """
    def _fallback():
        return _heuristic_pick_action(current_nodes, candidate_counts)

    if (not anthropic_available) or (os.getenv("ANTHROPIC_API_KEY") is None):
        return _fallback()

    try:
        # Initialize Claude client
        client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
        model = os.getenv("CLAUDE_MODEL", "claude-sonnet-4-20250514")

        # Each fragment ends with a space (or newline) so the concatenated
        # sentences stay separated.
        sys_prompt = (
            f"You are designing an adjacency graph that represents {description}. You receive: "
            "(1) a description of what the graph represents, (2) the current graph's nodes, "
            "(3) the current graph's edges, and (4) a frequency-sorted list of candidate neighbor labels "
            f"aggregated from many example graphs. Build a list of nodes usually found in a graph that represents {description}. "
            "You may choose from the provided list of candidate node labels or propose a new label from the list that you built. "
            "Choose exactly ONE action: either ADD a new node with a single connection to an existing node, "
            "or CONNECT two existing nodes, or STOP if no further action is needed. Include a reason for stopping. "
            "Do not repeat previous suggestions. "
            "Return ONLY valid JSON (no markdown, no extra text) with one of these forms:\n"
            "{\"action\":\"add\",\"new_label\":\"<string>\",\"attach_to\":\"<existing_local_id> (<string>)\"}\n"
            "{\"action\":\"connect\",\"a\":\"<existing_local_id> (<string>)\",\"b\":\"<existing_local_id> (<string>)\"}\n"
            "{\"action\":\"stop\",\"reason\":\"<string>\"}"
        )

        user_payload = {
            "description": description,
            "current_nodes": current_nodes,
            "current_edges": current_edges or [],
            "candidate_counts": candidate_counts,
            "note": "The candidate list is sorted by frequency across many graphs; you may propose a new label."
        }

        # Call Claude API
        message = client.messages.create(
            model=model,
            max_tokens=1024,
            temperature=0.3,
            system=sys_prompt,
            messages=[{
                "role": "user",
                "content": json.dumps(user_payload, indent=2)
            }]
        )
        text = message.content[0].text.strip()
    except Exception as e:
        print(f"Claude API error: {e}")
        return _fallback()

    # Strip a leading markdown code fence (``` or ```json ... ```) if present
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(lines[1:-1]) if len(lines) > 2 else text

    # Keep only the outermost {...} span
    start_idx = text.find("{")
    end_idx = text.rfind("}")
    if start_idx != -1 and end_idx > start_idx:
        text = text[start_idx:end_idx + 1]

    try:
        json_data = json.loads(text)
    except ValueError:  # includes json.JSONDecodeError
        return _fallback()
    if not isinstance(json_data, dict):
        return _fallback()

    # Normalise the action (tolerates e.g. "ADD" or "add_node")
    raw_action = str(json_data.get("action") or "").lower()
    for action in ("add", "connect", "stop"):
        if action in raw_action:
            break
    else:
        print(" I don't know what to suggest.")
        return _fallback()
    json_data["action"] = action

    first_label = json_data.get("new_label", json_data.get("a"))
    second = json_data.get("attach_to", json_data.get("b"))
    second_label = second
    if isinstance(second, str):
        parts = second.split(maxsplit=1)
        if len(parts) > 1:
            # "n3 (Living Room)" -> "Living Room" (multi-word labels kept intact)
            second_label = parts[1].strip().strip("()")

    if action == "add":
        if not json_data.get("new_label") or not json_data.get("attach_to"):
            return _fallback()
        print(f" I suggest that you add '{first_label}' and connect it to '{second_label}'")
    elif action == "connect":
        if not json_data.get("a") or not json_data.get("b"):
            return _fallback()
        print(f" I suggest that you connect '{first_label}' to '{second_label}'")
    else:
        print(" I suggest that you stop.")

    return json_data

# ---------------------
# Builder loop — seed from dataset example, then iterate
# ---------------------

def import_folder_to_kuzu(json_folder: str, manager, undirected: bool = True) -> List[str]:
    """Load every ``*.json`` graph found directly in ``json_folder`` into Kùzu.

    Files are handled in sorted path order. A file's graph id is its file name
    without the extension; its vertices and edges are read with
    ``load_topologic_graph`` and written with ``upsert_graph``.

    Args:
        json_folder: directory to scan (non-recursively) for ``*.json`` files.
        manager: Kùzu manager forwarded to ``upsert_graph``.
        undirected: forwarded to ``upsert_graph``.

    Returns:
        The imported graph ids in processing order ([] when nothing matches).
    """
    pattern = os.path.join(json_folder, "*.json")
    imported = []
    for json_path in sorted(glob.glob(pattern)):
        graph_id, _extension = os.path.splitext(os.path.basename(json_path))
        vertex_list, edge_list = load_topologic_graph(json_path)
        upsert_graph(manager, graph_id, vertex_list, edge_list, undirected=undirected)
        imported.append(graph_id)
    return imported

def init_working_graph(manager, working_graph_id: str, start_label: str) -> None:
    """Reset ``working_graph_id`` and seed it with a single vertex ``n0``.

    Existing edges, vertices and the Graph card for ``working_graph_id`` are
    deleted, then the Graph card is recreated. The seed copies its label, props
    and x/y/z from the example returned by ``find_best_example_for_label`` for
    ``start_label``; provenance keys are added to the copied props. When no
    example is found, a minimal vertex labelled ``start_label`` is created at
    the origin.

    Args:
        manager: Kùzu manager exposing ``exec(query, params, write=...)``.
        working_graph_id: id of the graph to reset and seed.
        start_label: label (substring) used to look up the seed example.
    """
    # Reset the working graph: edges first, then vertices, then the Graph card.
    manager.exec("MATCH (a:Vertex)-[r:Edge]->(b:Vertex) WHERE a.graph_id=$gid AND b.graph_id=$gid DELETE r;",
                 {"gid": working_graph_id}, write=True)
    manager.exec("MATCH (v:Vertex) WHERE v.graph_id=$gid DELETE v;", {"gid": working_graph_id}, write=True)
    manager.exec("MATCH (g:Graph) WHERE g.id=$id DELETE g;", {"id": working_graph_id}, write=True)
    create_graph_card_if_missing(manager, working_graph_id)

    best_example_dict = find_best_example_for_label(manager, attach_to=None, label_substring=start_label)
    # Tolerate a result without a "best_example" entry instead of raising KeyError.
    ex = best_example_dict.get("best_example") if best_example_dict is not None else None
    if ex is None:
        # fallback: minimal seed at the origin
        create_vertex(manager, working_graph_id, local_id="n0", label=start_label,
                      props={"source": "seed", "label": start_label}, x=0.0, y=0.0, z=0.0)
        return

    # props may be stored as a JSON string; keep unparsable text under "_raw_props".
    props = ex.get("props", {})
    if isinstance(props, str):
        try:
            props = json.loads(props)
        except ValueError:  # json.JSONDecodeError is a ValueError
            props = {"_raw_props": props}

    # enrich props with provenance
    props = dict(props or {})
    props.update({
        "source": "seed_from_dataset",
        "matched_query": start_label,
        "matched_label": ex.get("label",""),
        "matched_graph_id": ex.get("gid",""),
        "matched_vertex_id": ex.get("id",""),
    })

    def _coord(key: str) -> float:
        # A missing/NULL coordinate may come back as None; float(None) would raise TypeError.
        value = ex.get(key)
        return 0.0 if value is None else float(value)

    # Fall back to start_label when the example's label is missing or empty.
    create_vertex(manager, working_graph_id, local_id="n0", label=ex.get("label") or start_label,
                  props=props, x=_coord("x"), y=_coord("y"), z=_coord("z"))


def _as_local_id(raw: Any) -> str:
    """Return the first whitespace-separated token of ``raw``, or "" if there is none.

    An action may name a node as an id followed by extra text, e.g. ``"n3 (Kitchen)"``;
    only the leading token is the working-graph local id. ``None``/empty input
    yields "" rather than raising IndexError.
    """
    tokens = str(raw or "").split()
    return tokens[0] if tokens else ""


def _apply_add_action(manager, working_graph_id: str, action: Dict[str, Any],
                      current_nodes: List[Dict[str, Any]]) -> bool:
    """Create the node proposed by an ADD action and link it to an existing node.

    Props and coordinates are copied from the best dataset example for the new
    label when one exists. Returns True when a vertex and edge were created and
    False when the action has no usable label or there is no node to attach to.
    """
    new_label = str(action.get("new_label") or "").strip()
    existing_ids = [n["id"] for n in current_nodes]
    attach_to = _as_local_id(action.get("attach_to"))
    if attach_to not in existing_ids:
        print("Warning: Could not find attach_to in existing_ids")
        # Deterministic fallback: the first node returned by list_working_nodes
        # (iterating a set of strings is not reproducible across runs).
        attach_to = existing_ids[0] if existing_ids else None
    if not (new_label and attach_to):
        return False

    # Local ids are "n<k>"; skip any index already in use so ids never collide.
    taken = set(existing_ids)
    index = len(current_nodes)
    while f"n{index}" in taken:
        index += 1
    new_id = f"n{index}"

    # attempt to copy props from best example for new_label
    best_ex_dict = find_best_example_for_label(manager, attach_to=attach_to, label_substring=new_label)
    ex = best_ex_dict['best_example'] if best_ex_dict is not None else None

    if ex is not None:
        props = ex.get("props", {})
        if isinstance(props, str):
            try:
                props = json.loads(props)
            except ValueError:  # json.JSONDecodeError is a ValueError
                props = {"_raw_props": props}
        props = dict(props or {})
        props.update({
            "source": "suggested_node_from_dataset",
            "matched_label": ex.get("label",""),
            "matched_graph_id": ex.get("gid",""),
            "matched_vertex_id": ex.get("id",""),
        })
        x = best_ex_dict['recommended']['x']
        y = best_ex_dict['recommended']['y']
        z = best_ex_dict['recommended']['z']
    else:
        x = y = z = 0.0
        props = {"label": new_label, "source": "suggested_node_no_example"}

    create_vertex(manager, working_graph_id, local_id=new_id, label=new_label,
                  props=props, x=x, y=y, z=z)
    create_edge_bidirectional(manager, working_graph_id, attach_to, new_id, label="suggested",
                              props={"source": "llm"})
    return True


def graphrag_build_loop(manager,
                        working_graph_id: str,
                        start_label: str,
                        description: str,
                        max_steps: int = 8,
                        patience: int = 2,
                        max_connections: int = 2) -> Dict[str,Any]:
    """
    Grow a working graph one LLM-chosen action at a time.

    Steps:
      - Start a fresh working graph (init_working_graph); the seed node copies
        props from the best dataset example for `start_label`.
      - At each step:
          * Build the global candidate list from ALL graphs using the labels
            present in the working graph (candidate_counts_for_labels).
          * Ask the LLM (llm_pick_action) for exactly one action. The action
            name is matched case-insensitively:
              - "add": create `new_label` (it may be a label not in the
                candidate list) and connect it to the node `attach_to`;
              - "connect": add an edge between existing nodes `a` and `b`;
              - anything else (e.g. "stop"): return immediately.
          * Snapshot the full working graph.
      - Node references such as "n3 (Kitchen)" are reduced to their leading
        local id ("n3") before use.

    Args:
        manager: Kùzu manager passed through to the DB helpers.
        working_graph_id: id of the working graph to (re)create and grow.
        start_label: label used to pick the seed node.
        description: natural-language brief passed to the LLM.
        max_steps: maximum number of iterations.
        patience: the loop stops once the cumulative number of actions that
            produced no change exceeds this value.
        max_connections: not used by this loop; kept for API compatibility
            (NOTE(review): wire it up or drop it).

    Returns:
        { 'snapshots': [...], 'actions': [...], 'reason': str }. `snapshots[0]`
        is the seed graph, followed by one snapshot per non-stop step.
    """
    init_working_graph(manager, working_graph_id, start_label)
    snapshots = [snapshot_full_graph(manager, working_graph_id)]
    actions_log: list[Dict[str,Any]] = []

    no_action = 0  # cumulative count of actions that changed nothing
    oracle = {}    # threaded through candidate_counts_for_labels between steps
    for step in range(1, max_steps+1):
        print("STEP:", step)
        current_nodes = list_working_nodes(manager, working_graph_id)
        current_edges = list_working_edges(manager, working_graph_id)
        labels_now = [n["label"] for n in current_nodes]
        cand_counts, oracle = candidate_counts_for_labels(manager=manager,
                                                          working_graph_id=working_graph_id,
                                                          labels=labels_now,
                                                          oracle=oracle)

        action = llm_pick_action(description=description,
                                 current_nodes=current_nodes,
                                 candidate_counts=cand_counts,
                                 current_edges=current_edges)
        # The LLM may answer "ADD"/"Connect"; match case-insensitively, as
        # llm_pick_action itself does when printing its suggestion.
        kind = str(action.get("action") or "").strip().lower()

        if kind == "add":
            if not _apply_add_action(manager, working_graph_id, action, current_nodes):
                no_action += 1

        elif kind == "connect":
            a = _as_local_id(action.get("a"))
            b = _as_local_id(action.get("b"))
            applied = False
            if a and b and a != b:
                applied = create_edge_bidirectional(manager, working_graph_id, a, b, label="suggested",
                                                    props={"source": "llm"})
            # Missing or self-referencing endpoints are also a no-change step.
            if not applied:
                no_action += 1
        else:
            return {"snapshots": snapshots, "actions": actions_log, "reason": action.get("reason","Stopped.")}

        actions_log.append(action)
        snapshots.append(snapshot_full_graph(manager, working_graph_id))

        if no_action > patience:
            print("Ran out of patience with no action. Stopping.")
            return {"snapshots": snapshots, "actions": actions_log, "reason": "Action produced no change."}

    return {"snapshots": snapshots, "actions": actions_log, "reason": f"Reached max steps ({max_steps})."}


## Create a Kuzu DB Manager

In [None]:
import os

# Kùzu DB directory (will be created/used). The default lives in the project
# root; set the KUZU_DB_PATH environment variable to use another location,
# e.g. an absolute Windows path such as "C:/Users/<you>/Desktop/demo_kuzu".
db_path = os.environ.get("KUZU_DB_PATH", "./demo_kuzu")
mgr = Kuzu.Manager(db_path)
print(f"✓ Kuzu database initialized at: {db_path}")

## Import the graphs and store in Kuzu (Run Once)

In [None]:
import os

# Folder with the *.json graphs (Swiss dwelling dataset). Set the
# MSD_JSON_FOLDER environment variable to use another location, e.g. on
# Windows "C:/Users/<you>/Desktop/msd_json/sample_graphs".
json_folder = os.environ.get("MSD_JSON_FOLDER", "/Users/td3003/import_export/msd_json/sample")

# Validate the path BEFORE Kuzu.EmptyDatabase runs, so a wrong folder does not
# leave an emptied database with nothing imported into it.
if not os.path.isdir(json_folder):
    raise FileNotFoundError(f"JSON graph folder not found: {json_folder}")

print(f"Importing graphs from: {json_folder}")
_ = Kuzu.EmptyDatabase(mgr, recreateSchema=False)
gids = import_folder_to_kuzu(json_folder, mgr, undirected=True)
print("Imported", len(gids), "graphs")
if not gids:
    print(f"Warning: no *.json files found in {json_folder}")


## Expand the Graph

In [None]:

# Grow a new working graph; its seed node copies props from a dataset node
# matching start_label.
build_settings = {
    "working_graph_id": "work_demo",
    "start_label": "Entry",
    "description": "4 bedroom apartment with home office and a nursery",
    "max_steps": 12,
    "patience": 2,
    "max_connections": 2,
}
result = graphrag_build_loop(mgr, **build_settings)
print(result["reason"])  # why the loop stopped
result["actions"]  # displayed by the notebook: the action chosen at each step

## The Final Resulting Graph

In [None]:

# The final snapshot is the finished working graph; draw it labelled by "label".
last_graph = result["snapshots"][-1]
label_view_options = dict(
    backgroundColor="white",
    vertexLabelKey="label",
    showVertexLabel=True,
    vertexSize=10,
    width=800,
    height=800,
    camera=[0, 0, 4],
)
Topology.Show(last_graph, **label_view_options)

In [None]:
# Same final graph, but with vertices labelled by their "roomtype" dictionary key.
roomtype_view_options = dict(
    backgroundColor="white",
    vertexLabelKey="roomtype",
    showVertexLabel=True,
    vertexSize=10,
    width=800,
    height=800,
    camera=[0, 0, 4],
)
Topology.Show(last_graph, **roomtype_view_options)

## Show the sequence of suggestions

In [None]:
# Replay the build: for every snapshot, print each vertex's dictionary and draw the graph.
for step, snapshot in enumerate(result["snapshots"]):
    print(f"--- Snapshot {step} ---")
    graph = Graph.Reshape(snapshot, silent=True)
    # With silent=True a failure is not reported, so guard against a None result.
    if graph is None:
        print(f"Snapshot {step}: Graph.Reshape returned None; skipping.")
        continue
    for v in Graph.Vertices(graph) or []:
        d = Topology.Dictionary(v)
        print(Dictionary.Keys(d), Dictionary.Values(d))
    Topology.Show(graph, backgroundColor="white", vertexLabelKey="roomtype", showVertexLabel=True, vertexSize=10, width=400, height=400, camera=[0,0,4])