<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/rge_theory_constructor_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# rge_theory_constructor.py
from __future__ import annotations

import re
from dataclasses import dataclass
from itertools import combinations
from typing import Any, Dict, List, Tuple

import sympy as sp

try:
    import networkx as nx  # optional, improves graph handling if available
    _HAS_NX = True
except Exception:
    _HAS_NX = False


# ------------------------- Utilities -------------------------

def _sanitize_name(x: Any) -> str:
    """
    Make a safe symbol name fragment from an arbitrary node label.
    """
    s = str(x)
    s = re.sub(r"[^0-9a-zA-Z_]", "_", s)
    if not s or s[0].isdigit():
        s = f"n_{s}"
    return s


def _graph_nodes_and_edges(g: Any) -> Tuple[List[Any], List[Tuple[Any, Any]]]:
    """
    Return ``(nodes, edges)`` for one of the supported graph flavours.

    Recognised, in order:
      1. networkx ``Graph``/``DiGraph`` instances (only when networkx imported).
      2. Objects exposing both ``node_num`` and ``is_directed`` (causal-learn-like):
         nodes are ``0 .. node_num - 1`` and an edge ``(i, j)`` is recorded for
         every ordered pair of distinct indices where ``is_directed(i, j)`` is truthy.
      3. Any object with ``nodes`` (and optionally ``edges``), each either a
         callable or a plain iterable attribute.

    Raises:
        TypeError: if the object matches none of the above (no ``nodes`` attribute).
    """
    # 1) networkx graphs: the library already provides both views.
    if _HAS_NX and isinstance(g, (nx.Graph, nx.DiGraph)):
        return list(g.nodes()), list(g.edges())

    # 2) causal-learn-like graphs: probe every ordered pair of distinct indices.
    if hasattr(g, "node_num") and hasattr(g, "is_directed"):
        indices = list(range(g.node_num))
        directed: List[Tuple[Any, Any]] = []
        for src in indices:
            for dst in indices:
                if src == dst:
                    continue
                try:
                    linked = bool(g.is_directed(src, dst))
                except Exception:
                    # Best-effort: a pair that cannot be queried is treated as "no edge".
                    continue
                if linked:
                    directed.append((src, dst))
        # Order-preserving de-duplication.
        return indices, list(dict.fromkeys(directed))

    # 3) Generic objects: `nodes` is mandatory, `edges` is optional.
    if not hasattr(g, "nodes"):
        raise TypeError("Unsupported graph type: cannot extract nodes. Provide a networkx-like or causal-learn-like graph.")
    node_source = g.nodes
    node_list = list(node_source() if callable(node_source) else node_source)

    if hasattr(g, "edges"):
        edge_source = g.edges
        edge_list = list(edge_source() if callable(edge_source) else edge_source)
    else:
        edge_list = []
    return node_list, edge_list


# ------------------------- Configuration -------------------------

@dataclass
class BuildConfig:
    """
    Options controlling which relations a theory build proposes.

    Fields (in positional order):
      restrict_to_edges -- True: pair up only nodes joined by an edge.
                           False: pair up every two nodes of the graph.
                           (SymbolicTheoryConstructor.build also uses all pairs
                           when the graph yields no edges.)
      add_constant      -- True: each linear relation gets a free constant term 'c'.
      var_prefix        -- prefix of the per-node SymPy variable names.
                           (SymbolicTheoryConstructor.build replaces this with the
                           constructor's own prefix when the two differ.)
    """
    # Default is the exhaustive mode: every node pair gets a relation.
    restrict_to_edges: bool = False
    # Default relations are affine (include the constant term).
    add_constant: bool = True
    # Default node variables are named "v_<label>".
    var_prefix: str = "v"


# ------------------------- Constructor -------------------------

class SymbolicTheoryConstructor:
    """
    Takes a causal graph → proposes candidate symbolic relations.

    Default: for each pair (u, v), proposes a linear constraint
        a_{u,v} * V_u + b_{u,v} * V_v + c_{u,v} = 0
    where V_u, V_v are SymPy symbols bound to graph nodes.

    Returns (from ``build``):
      dict with:
        - 'equations': List[sp.Equality]
        - 'var_map': Dict[node_label, sp.Symbol]
        - 'graph': original graph object
        - 'config': BuildConfig used
    """
    def __init__(self, var_prefix: str = "v"):
        """Remember the prefix used when naming per-node variables."""
        self.var_prefix = var_prefix

    def _build_var_map(self, nodes: List[Any]) -> Dict[Any, sp.Symbol]:
        """Map every node label to a real SymPy symbol named ``<prefix>_<sanitized label>``."""
        return {
            n: sp.Symbol(f"{self.var_prefix}_{_sanitize_name(n)}", real=True)
            for n in nodes
        }

    @staticmethod
    def _edge_pairs(nodes: List[Any], edges: List[Tuple[Any, Any]]) -> List[Tuple[Any, Any]]:
        """
        Collapse (possibly directed) edges into a sorted list of unique unordered pairs.

        Labels are ordered by their natural ordering when they are mutually
        comparable. Labels that cannot be compared with each other (e.g. an int
        and a str) would make ``sorted`` raise ``TypeError``; in that case the
        position of each label in ``nodes`` is used as the ordering instead.
        """
        try:
            return sorted({tuple(sorted((u, v))) for (u, v) in edges})
        except TypeError:
            # Unorderable labels: fall back to a deterministic order by node position.
            position = {n: i for i, n in enumerate(nodes)}
            rank = position.__getitem__
            unique = {tuple(sorted((u, v), key=rank)) for (u, v) in edges}
            return sorted(unique, key=lambda pair: (rank(pair[0]), rank(pair[1])))

    def build(self, causal_graph: Any, cfg: BuildConfig | None = None) -> Dict[str, Any]:
        """
        Build the theory dictionary for ``causal_graph``.

        Args:
            causal_graph: any graph accepted by ``_graph_nodes_and_edges``.
            cfg: build options; a default ``BuildConfig`` is used when omitted.
                 Its ``var_prefix`` is overridden by the constructor's prefix
                 when the two differ (a copy is made; the argument is not mutated).

        Returns:
            dict with keys 'equations', 'var_map', 'graph' and 'config'.

        Raises:
            ValueError: if the graph has no nodes.
            TypeError: propagated from ``_graph_nodes_and_edges`` for unsupported graphs.
        """
        if cfg is None:
            cfg = BuildConfig(var_prefix=self.var_prefix)
        elif cfg.var_prefix != self.var_prefix:
            # harmonize prefix if passed via ctor
            cfg = BuildConfig(
                restrict_to_edges=cfg.restrict_to_edges,
                add_constant=cfg.add_constant,
                var_prefix=self.var_prefix,
            )

        nodes, edges = _graph_nodes_and_edges(causal_graph)
        if not nodes:
            raise ValueError("Graph has no nodes; cannot construct theory.")
        var_map = self._build_var_map(nodes)

        # Choose pairs. With no edges available, edge mode falls back to all pairs.
        if cfg.restrict_to_edges and edges:
            # Use undirected interpretation for pairing if underlying graph is directed
            pairs = self._edge_pairs(nodes, edges)
        else:
            pairs = list(combinations(nodes, 2))

        eqs: List[sp.Equality] = []
        for u, v in pairs:
            suffix = f"{_sanitize_name(u)}_{_sanitize_name(v)}"
            a, b = sp.symbols(f"a_{suffix} b_{suffix}", real=True)
            expr = a * var_map[u] + b * var_map[v]
            if cfg.add_constant:
                expr = expr + sp.symbols(f"c_{suffix}", real=True)
            eqs.append(sp.Eq(expr, 0))

        return {
            "equations": eqs,
            "var_map": var_map,
            "graph": causal_graph,
            "config": cfg,
        }

    def pretty_print(self, theory: Dict[str, Any], latex: bool = False) -> str:
        """
        Render the theory equations, one per line, in plain text or LaTeX.

        A theory without an 'equations' entry renders as the empty string.
        """
        render = sp.latex if latex else str
        return "\n".join(render(eq) for eq in theory.get("equations", []))

    def with_sample_values(
        self,
        theory: Dict[str, Any],
        coeff_value: float = 1.0,
        const_value: float = 0.0
    ) -> List[sp.Equality]:
        """
        Return equations with coefficients replaced by sample numbers.

        Symbols whose names start with ``a_`` or ``b_`` become ``coeff_value``
        and those starting with ``c_`` become ``const_value``. Node variables
        listed in ``theory['var_map']`` are never substituted, even when the
        variable prefix itself is "a", "b" or "c".
        """
        # Node variables must survive substitution regardless of how they are named.
        node_symbols = set(theory.get("var_map", {}).values())
        substitutions = {}
        for eq in theory["equations"]:
            for sym in eq.free_symbols:
                if sym in node_symbols:
                    continue
                name = str(sym)
                if name.startswith(("a_", "b_")):
                    substitutions[sym] = coeff_value
                elif name.startswith("c_"):
                    substitutions[sym] = const_value
        return [eq.subs(substitutions) for eq in theory["equations"]]


# ------------------------- Demo / CLI -------------------------

def _build_demo_graph():
    """
    Return the three-node demo chain x -> y -> z.

    A networkx ``DiGraph`` is returned when networkx was imported; otherwise a
    tiny stand-in object exposing ``nodes()`` and ``edges()`` is returned.
    """
    labels = ["x", "y", "z"]
    links = [("x", "y"), ("y", "z")]

    if not _HAS_NX:
        # Fallback: minimal graph-like object
        class _G:
            def __init__(self):
                self._nodes = labels
                self._edges = links

            def nodes(self):
                return self._nodes

            def edges(self):
                return self._edges

        return _G()

    demo = nx.DiGraph()
    demo.add_nodes_from(labels)
    demo.add_edges_from(links)
    return demo


def main():
    """
    Run the demo: build theories for the demo graph in edge-only and all-pairs
    mode and print the resulting variables and equations to stdout.
    """
    print("=== Symbolic Theory Constructor Demo ===")
    demo_graph = _build_demo_graph()
    constructor = SymbolicTheoryConstructor(var_prefix="v")

    # Scenario A: only adjacent pairs (edges)
    edge_theory = constructor.build(
        demo_graph,
        BuildConfig(restrict_to_edges=True, add_constant=True, var_prefix="v"),
    )
    print("\n-- Adjacent pairs only --")
    variable_names = [str(symbol) for symbol in edge_theory["var_map"].values()]
    print(f"Variables: {variable_names}")
    edge_equation_count = len(edge_theory["equations"])
    print(f"Equations count: {edge_equation_count}")
    print("\nPlain text equations:")
    print(constructor.pretty_print(edge_theory))
    print("\nLaTeX equations:")
    print(constructor.pretty_print(edge_theory, latex=True))
    print("\nSample substitution (a=b=1, c=0):")
    sampled = constructor.with_sample_values(edge_theory)
    # Show at most the first five substituted equations.
    for sample_eq in sampled[:5]:
        print(sample_eq)

    # Scenario B: all pairs
    full_theory = constructor.build(
        demo_graph,
        BuildConfig(restrict_to_edges=False, add_constant=True, var_prefix="v"),
    )
    print("\n-- All node pairs --")
    full_equations = full_theory["equations"]
    print(f"Equations count: {len(full_equations)}")
    for equation in full_equations:
        print(equation)

    print("\nDone.")


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()