# In [1]:
"""
Make your saccadic search + Dijkstra use LABELS ("row,col") as inputs,
print the path as LABELS, and compute/return the TOTAL COST using edge weights.

What this adds/changes:
- Reads each node’s yEd label "row,col" from the original .graphml and injects it
  into the NetworkX graph as node attribute: graph.nodes[n]["node_label"].
- Lets you call Dijkstra and Saccadic *by labels* (e.g., "14,30" -> "2,30").
- Ensures each edge has a numeric 'weight' (uses weight/w_total/d_weight/d_wtot).
- Displays the solution as a list of LABELS and the TOTAL COST.

Drop-in usage (example):
    map_graph = nx.read_graphml(map_graph_path)
    inject_node_labels_from_graphml_file(map_graph, map_graph_path)  # <-- important
    ensure_edge_weights(map_graph)                                   # <-- important

    # Dijkstra by labels
    path_labels, cost = dijkstra_search_labels(map_graph, "14,30", "2,30")
    print("Dijkstra path:", " -> ".join(path_labels), "| cost =", cost)

    # Saccadic by labels
    s_path_labels, s_cost, found = saccadic_search_strategy_by_label(
        map_graph, start_label="14,30", target_label="2,30", max_steps=500
    )
    print("Saccadic path:", " -> ".join(s_path_labels), "| cost =", s_cost, "| found =", found)
"""

import heapq
import math
import random
import re
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Dict, Tuple, Optional, List, Set, Union, Iterable

import networkx as nx

# ------------------------ Config you already have ------------------------
# Sensor model for the noisy yes/no observation the agent makes at the node it
# is standing on. Read as module globals by update_belief_map() and
# saccadic_search_strategy_by_label().
# With the values below, p_tp + p_fn == 1 and p_fp + p_tn == 1.
p_tp = 0.8  # P(true positive): positive observation while on the target node
p_fp = 0.2  # P(false positive): positive observation while on a non-target node
p_tn = 0.8  # P(true negative): negative observation while on a non-target node
p_fn = 0.2  # P(false negative): negative observation while on the target node

# ------------------------ yEd / GraphML namespaces ------------------------
# Prefix -> namespace-URI map passed to ElementTree's find()/findall() so the
# path expressions below can use the short "g:" (GraphML) and "y:" (yWorks)
# prefixes.
NS = {
    "g": "http://graphml.graphdrawing.org/xmlns",
    "y": "http://www.yworks.com/xml/graphml",
}
# Matches a complete "row,col" label: two (optionally negative) integers
# separated by a comma, with optional whitespace around each. The two integers
# are captured as groups 1 and 2.
LABEL_RE = re.compile(r"^\s*(-?\d+)\s*,\s*(-?\d+)\s*$")

# ========================================================================
#                    LABEL INJECTION (id <-> "row,col")
# ========================================================================

def inject_node_labels_from_graphml_file(
    G: nx.Graph,
    graphml_path: Union[str, Path],
    label_attr: str = "node_label",
) -> Dict[str, str]:
    """
    Copy yEd "row,col" node labels from a .graphml file onto an already-loaded graph.

    The file is parsed with ElementTree. For every GraphML node whose id is also
    a node of G, the first non-blank y:NodeLabel text found under the node's own
    data elements is taken as its label, regardless of which yEd node element
    (ShapeNode, GenericNode, or any other) wraps it. The label is only accepted
    if it matches LABEL_RE ("row,col"); otherwise the node falls back to its id.

    Side effect: sets G.nodes[n][label_attr] for *every* node of G, using
    str(node_id) when no usable label was found.

    Returns a mapping label_str -> node_id. When several nodes share a label,
    the first one encountered wins.
    """
    root = ET.parse(str(graphml_path)).getroot()

    id_to_label: Dict[str, str] = {}
    label_to_id: Dict[str, str] = {}

    # Walk all nodes in the raw GraphML (including nodes of nested graphs).
    for node_el in root.findall(".//g:graph/g:node", NS):
        nid = node_el.attrib.get("id")
        if not nid or nid not in G:
            continue

        text = None
        # Only look inside this node's direct <data> children so labels that
        # belong to nodes of a nested <graph> are never picked up here.
        for data_el in node_el.findall("./g:data", NS):
            for lab in data_el.iterfind(".//y:NodeLabel", NS):
                candidate = (lab.text or "").strip()
                if candidate:
                    text = candidate
                    break
            if text is not None:
                break

        if text and LABEL_RE.match(text):
            id_to_label[nid] = text
            label_to_id.setdefault(text, nid)

    # Inject into the NX graph (fallback to node id if no label found).
    for nid in G.nodes:
        lbl = id_to_label.get(nid, str(nid))
        G.nodes[nid][label_attr] = lbl
        if lbl not in label_to_id:
            # Make sure fallback labels are resolvable through the mapping too.
            label_to_id[lbl] = nid

    return label_to_id


def build_label_index(G: nx.Graph, label_attr: str = "node_label") -> Dict[str, str]:
    """
    Return a label -> node_id lookup built from the graph's node attributes.

    A node without the label_attr attribute is indexed under str(node_id).
    If two nodes carry the same label, the one iterated first is kept.
    """
    index: Dict[str, str] = {}
    for node_id, attrs in G.nodes(data=True):
        label = attrs[label_attr] if label_attr in attrs else str(node_id)
        if label not in index:
            index[label] = node_id
    return index


def label_of(G: nx.Graph, node_id: str, label_attr: str = "node_label") -> str:
    """Return the node's display label, or str(node_id) if it has none."""
    attrs = G.nodes[node_id]
    if label_attr in attrs:
        return attrs[label_attr]
    return str(node_id)

# ========================================================================
#                    EDGE WEIGHTS (ensure numeric)
# ========================================================================

def _pick_weight_from_attrs(attrs: Dict) -> Optional[float]:
    """Try weight fields in this order, converting to float if possible."""
    for key in ("weight", "w_total", "d_weight", "d_wtot"):
        if key in attrs and attrs[key] is not None:
            try:
                return float(attrs[key])
            except Exception:
                pass
    return None


def ensure_edge_weights(G: nx.Graph, default: float = 1.0) -> None:
    """
    Give every edge (every parallel edge, for multigraphs) a float 'weight'.

    The value is whatever _pick_weight_from_attrs() selects from the edge's
    attributes ('weight' first, then 'w_total', 'd_weight', 'd_wtot'); when
    nothing usable is found, float(default) is stored instead.

    Mutates the edge attribute dicts of G in place.
    """
    # edges(data=True) yields the live attribute dict as the last tuple item
    # for simple graphs and multigraphs alike, so one loop covers both.
    for *_endpoints, attrs in G.edges(data=True):
        chosen = _pick_weight_from_attrs(attrs)
        if chosen is None:
            chosen = float(default)
        attrs["weight"] = chosen


def edge_weight(G: nx.Graph, u: str, v: str) -> float:
    """
    Return the 'weight' of edge (u, v) as a float.

    - Multigraph: the smallest weight among the parallel u-v edges, or
      float("inf") when there is no such edge.
    - Simple graph: the edge's weight; note that a missing edge yields 1.0
      here (not inf), because an absent edge is treated like an edge with
      no attributes.

    In both cases a missing or non-convertible 'weight' counts as 1.0.
    """

    def _weight_of(attrs) -> float:
        # Missing or unparsable weight falls back to the unit cost.
        try:
            return float(attrs.get("weight", 1.0))
        except Exception:
            return 1.0

    if not G.is_multigraph():
        return _weight_of(G.get_edge_data(u, v, default=None) or {})

    parallel = G.get_edge_data(u, v)
    if not parallel:
        return float("inf")

    lowest = float("inf")
    for attrs in parallel.values():
        lowest = min(lowest, _weight_of(attrs))
    return lowest

def path_total_cost(G: nx.Graph, path_ids: List[str]) -> float:
    """
    Return the summed edge_weight() of consecutive node pairs in path_ids.

    A path with fewer than two nodes costs 0.0.
    """
    cost = 0.0
    # Pair each node with its successor: (n0, n1), (n1, n2), ...
    for src, dst in zip(path_ids[:-1], path_ids[1:]):
        cost = cost + edge_weight(G, src, dst)
    return cost

# ========================================================================
#                    DIJKSTRA (by labels)
# ========================================================================

def dijkstra_search_labels(
    G: nx.Graph,
    start_label: str,
    target_label: str,
    *,
    label_attr: str = "node_label",
) -> Tuple[List[str], float]:
    """
    Run Dijkstra from start_label to target_label.

    Labels are resolved to node ids through build_label_index(). Edge weights
    are normalised with ensure_edge_weights() before routing (this rewrites
    each edge's 'weight' attribute in place), so the route is planned on the
    same numeric weights that are summed for the reported cost, even if the
    caller did not normalise the graph beforehand.

    Returns (path_as_LABELS, total_cost).

    Raises KeyError if either label is unknown. nx.NetworkXNoPath from
    nx.dijkstra_path() propagates when the target is unreachable.
    """
    label_index = build_label_index(G, label_attr=label_attr)
    if start_label not in label_index:
        raise KeyError(f"Start label not found: {start_label}")
    if target_label not in label_index:
        raise KeyError(f"Target label not found: {target_label}")

    # Validate labels first so a bad call does not touch the graph.
    ensure_edge_weights(G)

    s = label_index[start_label]
    t = label_index[target_label]

    path_ids = nx.dijkstra_path(G, source=s, target=t, weight="weight")
    cost = path_total_cost(G, path_ids)
    path_labels = [label_of(G, nid, label_attr=label_attr) for nid in path_ids]
    return path_labels, cost

# ========================================================================
#                 Belief map + Saccadic (by labels)
# ========================================================================

def set_target_by_label(G: nx.Graph, target_label: str, *, label_attr: str = "node_label") -> None:
    """
    Flag the node carrying target_label as the search target.

    Every node that already has a 'target' attribute gets it reset to False,
    then the chosen node gets 'target' = True. Nodes that never had the
    attribute are left without it.

    Raises KeyError if no node carries target_label.
    """
    index = build_label_index(G, label_attr=label_attr)
    if target_label not in index:
        raise KeyError(f"Target label not found: {target_label}")
    chosen = index[target_label]

    # Reset any flag left over from a previous run.
    for _node, attrs in G.nodes(data=True):
        if "target" in attrs:
            attrs["target"] = False

    G.nodes[chosen]["target"] = True


def initialize_belief_map(G: nx.Graph) -> None:
    """
    Store a uniform prior 1/N in every node's 'P(target)' attribute.

    Does nothing for a graph without nodes.
    """
    node_ids = list(G.nodes())
    if not node_ids:
        return
    prior = 1.0 / len(node_ids)
    nx.set_node_attributes(G, dict.fromkeys(node_ids, prior), "P(target)")


def update_belief_map(G: nx.Graph, agent_node: str, observation: bool) -> None:
    """
    Bayes-update every node's 'P(target)' after one observation at agent_node.

    For each node n the prior is multiplied by the likelihood of the
    observation under the hypothesis "the target is at n":
      - n == agent_node: p_tp for a positive observation, p_fn for a negative one
      - any other node:  p_fp for a positive observation, p_tn for a negative one
    and the products are normalised by their sum. A node without 'P(target)'
    is treated as having prior 0.0. If the sum is 0.0 the beliefs are left
    untouched.

    Uses the module globals p_tp, p_fp, p_tn, p_fn.
    """
    if observation:
        lik_here, lik_elsewhere = p_tp, p_fp
    else:
        lik_here, lik_elsewhere = p_fn, p_tn

    unnormalised: Dict[str, float] = {}
    evidence = 0.0  # P(observation), accumulated over all hypotheses
    for node, attrs in G.nodes(data=True):
        prior = attrs.get("P(target)", 0.0)
        joint = prior * (lik_here if node == agent_node else lik_elsewhere)
        unnormalised[node] = joint
        evidence += joint

    if evidence == 0.0:
        return  # nothing to normalise by

    posterior = {node: joint / evidence for node, joint in unnormalised.items()}
    nx.set_node_attributes(G, posterior, "P(target)")


def saccadic_search_strategy_by_label(
    G: nx.Graph,
    start_label: str,
    target_label: str,
    *,
    max_steps: int = 500,
    label_attr: str = "node_label",
) -> Tuple[List[str], float, bool]:
    """
    Saccadic search:
    - Start at start_label
    - Whenever there is no pending plan, route (via Dijkstra) to the node with
      the highest belief
    - Move one hop along the plan, observe, Bayesian update
    - If the highest-belief node is the node the agent already stands on, the
      agent stays put for that step: it observes again, no travel cost is
      added and the node is not appended to the path a second time
    - Stops when a positive observation occurs at the target node, when the
      chosen node is unreachable, or when max_steps is reached

    Side effects on G: edge 'weight' attributes are normalised through
    ensure_edge_weights(), the 'target' node flag is set through
    set_target_by_label(), and 'P(target)' is (re)initialised and then
    updated on every step.

    Observations are drawn with random.random() using the module globals
    p_tp / p_fp, so results vary between runs unless the caller seeds random.

    Returns (path_as_LABELS, total_travel_cost, target_found_bool).

    Raises KeyError if start_label or target_label is unknown.
    """
    ensure_edge_weights(G)  # make sure 'weight' exists

    label_index = build_label_index(G, label_attr=label_attr)

    if start_label not in label_index:
        raise KeyError(f"Start label not found: {start_label}")
    if target_label not in label_index:
        raise KeyError(f"Target label not found: {target_label}")

    current = label_index[start_label]
    set_target_by_label(G, target_label, label_attr=label_attr)

    initialize_belief_map(G)

    path_taken_ids: List[str] = [current]
    total_cost = 0.0
    target_found = False

    # Remaining hops (node ids) toward the currently chosen node; never
    # contains the node the agent is standing on.
    planned_hops: List[str] = []

    for _ in range(max_steps):
        if not planned_hops:
            # Pick the node with the highest belief (first one wins on ties).
            goal = max(G.nodes(data="P(target)"), key=lambda item: item[1])[0]

            try:
                route = nx.dijkstra_path(G, current, goal, weight="weight")
            except nx.NetworkXNoPath:
                # The chosen node cannot be reached from here; give up.
                break

            # route[0] is always `current`. Dropping it unconditionally leaves
            # an empty plan when goal == current, so the agent stays in place
            # instead of "moving" to itself over a non-existent self edge.
            planned_hops = route[1:]

        if planned_hops:
            nxt = planned_hops.pop(0)
            total_cost += edge_weight(G, current, nxt)
            current = nxt
            path_taken_ids.append(current)

        # Noisy observation at the node the agent is on after this step.
        is_target = G.nodes[current].get("target", False)
        observation = random.random() < (p_tp if is_target else p_fp)
        update_belief_map(G, current, observation)

        if is_target and observation:
            target_found = True
            break

    # Build label path for output
    path_labels = [label_of(G, nid, label_attr=label_attr) for nid in path_taken_ids]
    return path_labels, total_cost, target_found

# ========================================================================
#                      OPTIONAL: pretty-printer
# ========================================================================

def print_path_and_cost(label_path: List[str], total_cost: float, header: str = "Solution"):
    """
    Print a three-line summary: the header, the path with labels joined by
    " -> ", and the total cost formatted to six decimal places.
    """
    print(header)
    rendered_path = " -> ".join(label_path)
    print("Path (labels):", rendered_path)
    cost_line = f"Total cost: {total_cost:.6f}"
    print(cost_line)
