<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/symbolic_theory_suite_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# symbolic_theory_suite.py
from __future__ import annotations

import argparse
import json
import os
import re
from dataclasses import dataclass
from itertools import combinations
from typing import Any, Dict, List, Tuple, Optional

import sympy as sp

try:
    import networkx as nx  # optional
    _HAS_NX = True
except Exception:
    _HAS_NX = False


# ------------------------- Utilities -------------------------

def _sanitize_name(x: Any) -> str:
    s = str(x)
    s = re.sub(r"[^0-9a-zA-Z_]", "_", s)
    if not s or s[0].isdigit():
        s = f"n_{s}"
    return s


def _graph_nodes_and_edges(g: Any) -> Tuple[List[Any], List[Tuple[Any, Any]]]:
    """Extract ``(nodes, edges)`` from several graph flavours.

    Tried in order:
      1. networkx graphs (only when networkx imported successfully);
      2. objects exposing ``node_num`` and ``is_directed(i, j)``: nodes are
         ``0..node_num-1`` and every ordered pair is probed for a directed edge;
      3. any object with a ``nodes`` attribute (called if callable), with an
         optional ``edges`` attribute handled the same way.

    Raises:
        TypeError: if none of the above yields a node collection.
    """
    # 1) networkx
    if _HAS_NX and isinstance(g, (nx.Graph, nx.DiGraph)):
        return list(g.nodes()), list(g.edges())

    # 2) causal-learn-like Graph: index-based nodes, edges discovered by probing
    if hasattr(g, "node_num") and hasattr(g, "is_directed"):
        indices = list(range(g.node_num))
        found = []
        for src in indices:
            for dst in indices:
                if src != dst:
                    try:
                        if g.is_directed(src, dst):
                            found.append((src, dst))
                    except Exception:
                        # Best effort: a pair the graph refuses to answer for is skipped.
                        pass
        # dict.fromkeys drops repeats while keeping first-seen order
        return indices, list(dict.fromkeys(found))

    # 3) Generic attribute-based access
    if not hasattr(g, "nodes"):
        raise TypeError("Unsupported graph type: cannot extract nodes.")
    node_source = g.nodes
    nodes = list(node_source() if callable(node_source) else node_source)

    if hasattr(g, "edges"):
        edge_source = g.edges
        edges = list(edge_source() if callable(edge_source) else edge_source)
    else:
        edges = []
    return nodes, edges


# ------------------------- Configuration -------------------------

@dataclass
class BuildConfig:
    """Options read by ``SymbolicTheoryConstructor.build``."""
    # When True and the graph has at least one edge, one equation is emitted per
    # undirected edge; otherwise one equation is emitted per pair of nodes.
    restrict_to_edges: bool = False
    # When True, every equation gets an additional symbolic constant c_<u>_<v>.
    add_constant: bool = True
    # Prefix for variable symbols; build() overrides it with the constructor's
    # own prefix when the two differ.
    var_prefix: str = "v"


@dataclass
class SolveConfig:
    """Options read by ``SymbolicTheoryConstructor.try_solve``."""
    # Coefficient and constant substitution (keeps systems linear, numeric)
    coeff_value: float = 1.0  # substituted for every a_* and b_* symbol
    const_value: float = 0.0  # substituted for every c_* symbol

    # Anchors: map of variable symbol names to values (e.g., {"v_x": 1.0, "v_y": -1.0}).
    # None (the default) means "no anchors"; the annotation is Optional to say so.
    anchors: Optional[Dict[str, float]] = None

    # If underdetermined, greedily add anchors with this value until full rank (if possible)
    auto_anchor: bool = False
    auto_anchor_value: float = 1.0

    # Limits and simplification
    max_equations: int = 100  # only the first max_equations equations are considered
    # NOTE(review): no code in this file reads `simplify`; kept for interface compatibility.
    simplify: bool = True


# ------------------------- Constructor -------------------------

class SymbolicTheoryConstructor:
    def __init__(self, var_prefix: str = "v") -> None:
        """Remember the prefix used when naming variable symbols.

        Args:
            var_prefix: Text placed before ``_<sanitized node name>`` in every
                variable symbol created by this constructor.
        """
        self.var_prefix = var_prefix

    def _build_var_map(self, nodes: List[Any]) -> Dict[Any, sp.Symbol]:
        """Map every node to a real SymPy symbol named ``<var_prefix>_<sanitized node>``."""
        symbols_by_node = {}
        for node in nodes:
            label = f"{self.var_prefix}_{_sanitize_name(node)}"
            symbols_by_node[node] = sp.Symbol(label, real=True)
        return symbols_by_node

    def build(self, causal_graph: Any, cfg: BuildConfig | None = None) -> Dict[str, Any]:
        """Build one symbolic linear equation per selected pair of nodes.

        For each pair ``(u, v)`` the equation is
        ``a_u_v * var_u + b_u_v * var_v (+ c_u_v) = 0`` where the names are built
        from the sanitized node labels and all symbols are real.

        Args:
            causal_graph: Any graph accepted by ``_graph_nodes_and_edges``.
            cfg: Build options. When omitted, defaults are used. The
                constructor's ``var_prefix`` always wins over ``cfg.var_prefix``;
                the caller's config object is never mutated.

        Returns:
            A dict with keys ``"equations"``, ``"var_map"``, ``"graph"`` and
            ``"config"``.

        Raises:
            ValueError: If the graph has no nodes.

        Note:
            With ``restrict_to_edges=True`` but a graph without edges, the
            all-pairs construction is used instead.
        """
        if cfg is None:
            cfg = BuildConfig(var_prefix=self.var_prefix)
        elif cfg.var_prefix != self.var_prefix:
            cfg = BuildConfig(
                restrict_to_edges=cfg.restrict_to_edges,
                add_constant=cfg.add_constant,
                var_prefix=self.var_prefix,
            )

        nodes, edges = _graph_nodes_and_edges(causal_graph)
        if not nodes:
            raise ValueError("Graph has no nodes.")
        var_map = self._build_var_map(nodes)

        # choose pairs
        if cfg.restrict_to_edges and edges:
            try:
                # Collapse (u, v)/(v, u) into one undirected pair, in sorted order.
                pairs = sorted({tuple(sorted((u, v))) for (u, v) in edges})
            except TypeError:
                # Node labels are not mutually orderable (e.g. ints mixed with
                # strings). Fall back to ordering by position in `nodes`, which
                # is always well defined, instead of failing.
                position = {node: idx for idx, node in enumerate(nodes)}
                ranked = {tuple(sorted((position[u], position[v]))) for (u, v) in edges}
                pairs = [(nodes[i], nodes[j]) for i, j in sorted(ranked)]
        else:
            pairs = list(combinations(nodes, 2))

        eqs: List[sp.Equality] = []
        for u, v in pairs:
            u_name = _sanitize_name(u)
            v_name = _sanitize_name(v)
            a, b = sp.symbols(f"a_{u_name}_{v_name} b_{u_name}_{v_name}", real=True)
            expr = a * var_map[u] + b * var_map[v]
            if cfg.add_constant:
                expr = expr + sp.symbols(f"c_{u_name}_{v_name}", real=True)
            eqs.append(sp.Eq(expr, 0))

        return {"equations": eqs, "var_map": var_map, "graph": causal_graph, "config": cfg}

    # ---------- Utilities ----------
    def pretty_print(self, theory: Dict[str, Any], latex: bool = False) -> str:
        """Render the theory's equations, one per line.

        Each entry of ``theory["equations"]`` (missing key means no equations)
        is formatted with ``sp.latex`` when ``latex`` is true, otherwise with
        ``str``.
        """
        render = sp.latex if latex else str
        rendered = [render(equation) for equation in theory.get("equations", [])]
        return "\n".join(rendered)

    def substitute_params(self, eqs: List[sp.Equality], coeff_value: float, const_value: float) -> List[sp.Equality]:
        """Replace coefficient and constant symbols by numeric values.

        Symbols are classified purely by name: those starting with ``a_`` or
        ``b_`` receive ``coeff_value``, those starting with ``c_`` receive
        ``const_value``. Any other symbol is left untouched. (A variable prefix
        of ``a``, ``b`` or ``c`` would therefore be substituted as well.)

        Args:
            eqs: Equations whose left-hand side is to be rewritten.
            coeff_value: Value for every ``a_*`` / ``b_*`` symbol.
            const_value: Value for every ``c_*`` symbol.

        Returns:
            New equations ``simplified_lhs = 0``, one per input equation.
        """
        subs = {}
        for eq in eqs:
            for sym in eq.free_symbols:
                name = str(sym)
                if name.startswith(("a_", "b_")):
                    subs[sym] = coeff_value
                elif name.startswith("c_"):
                    subs[sym] = const_value
        # evaluate=False: when the left-hand side collapses to a number (for
        # example with coeff_value == 0), Eq(...) would otherwise evaluate to a
        # SymPy Boolean that has no .lhs, and later processing would crash.
        return [sp.Eq(sp.simplify(eq.lhs.subs(subs)), 0, evaluate=False) for eq in eqs]

    def apply_anchors(self, eqs: List[sp.Equality], anchors: Dict[sp.Symbol, float]) -> List[sp.Equality]:
        """Substitute fixed values for anchored variable symbols.

        Args:
            eqs: Equations whose left-hand side is to be rewritten.
            anchors: Mapping from symbol to the value it is pinned to.

        Returns:
            ``eqs`` itself when ``anchors`` is empty or None; otherwise new
            equations ``simplified_lhs = 0``.
        """
        if not anchors:
            return eqs
        subs = dict(anchors)
        # evaluate=False: once every variable of an equation is anchored the
        # left-hand side is a plain number, and Eq(number, 0) would otherwise
        # evaluate to a SymPy Boolean without .lhs (crashing later steps).
        return [sp.Eq(sp.simplify(eq.lhs.subs(subs)), 0, evaluate=False) for eq in eqs]

    def deduplicate_equations(self, eqs: List[sp.Equality]) -> List[sp.Equality]:
        """
        Remove duplicates and trivial identities (0 = 0) after normalization by scalar multiple.

        Each equation is reduced to ``lhs - rhs``, divided by its first
        non-zero coefficient, and keyed by ``sp.srepr`` of the result. Inputs
        that already evaluated to a SymPy Boolean are tolerated: a true value
        is dropped as trivial, a false value is kept as the contradiction
        ``1 = 0``.

        Returns:
            Unevaluated equations ``normalized_lhs = 0`` in first-seen order.

        Raises:
            TypeError: If an entry is neither an equality nor a Boolean.
        """
        canon = []
        seen = set()
        for eq in eqs:
            if isinstance(eq, sp.Equality):
                lhs = sp.simplify(eq.lhs - eq.rhs)
            elif eq == sp.true:
                continue  # an identity that was already evaluated away
            elif eq == sp.false:
                lhs = sp.Integer(1)  # an evaluated contradiction; keep it visible
            else:
                raise TypeError(f"Expected an equality, got {type(eq).__name__}")
            if lhs == 0:
                continue
            # Normalize by first non-zero coefficient to avoid duplicates up to scale
            scale = next((c for c in lhs.as_coefficients_dict().values() if c != 0), None)
            lhs_norm = sp.simplify(lhs / scale) if scale is not None else lhs
            key = sp.srepr(lhs_norm)
            if key not in seen:
                seen.add(key)
                # evaluate=False keeps constant residuals (contradictions) as
                # Equality objects instead of collapsing them to a Boolean.
                canon.append(sp.Eq(lhs_norm, 0, evaluate=False))
        return canon

    def _rank_and_solution(self, eqs: List[sp.Equality], unknowns: List[sp.Symbol]) -> Tuple[int, int, Optional[List[Dict[sp.Symbol, Any]]]]:
        if not unknowns:
            return 0, 0, [dict()]  # nothing to solve; vacuously solved
        # Convert to matrix A*x = 0 (homogeneous), since RHS is 0
        A, b = sp.linear_eq_to_matrix([eq.lhs for eq in eqs], unknowns)
        rank_A = A.rank()
        # augmented rank equals rank_A because b is zero vector in this construction
        rank_aug = rank_A
        sols = []
        if rank_A == len(unknowns):
            # Unique solution for homogeneous is all zeros; but anchors may induce specific values.
            # Because we anchored some variables, unknowns can still be uniquely determined (possibly zeros).
            res = sp.linsolve((A, b))
            # linsolve returns FiniteSet((vals...))
            for sol_tuple in list(res):
                sols.append({var: val for var, val in zip(unknowns, sol_tuple)})
        else:
            # Multiple or infinite solutions (free parameters)
            res = sp.linsolve((A, b))
            if res:
                for sol_tuple in list(res):
                    sols.append({var: val for var, val in zip(unknowns, sol_tuple)})
        return rank_A, rank_aug, sols if sols else None

    def try_solve(self, theory: Dict[str, Any], solve_cfg: SolveConfig | None = None) -> Dict[str, Any]:
        """
        Attempt to solve a subset of equations by substituting coefficient/constant values,
        anchoring variables, and solving linear system for remaining variables.
        Can greedily auto-anchor to reach full rank.
        """
        if solve_cfg is None:
            solve_cfg = SolveConfig()

        eqs_all = list(theory["equations"])[: solve_cfg.max_equations]
        if not eqs_all:
            return {"status": "no_equations", "solution": None}

        var_map: Dict[Any, sp.Symbol] = theory["var_map"]
        vars_list = list(var_map.values())
        if not vars_list:
            return {"status": "no_variables", "solution": None}

        # Substitute parameters
        eqs_sub = self.substitute_params(eqs_all, solve_cfg.coeff_value, solve_cfg.const_value)

        # Prepare anchor mapping by symbol
        anchor_map_by_name = solve_cfg.anchors or {}
        anchors_by_sym: Dict[sp.Symbol, float] = {}
        for sym in vars_list:
            if str(sym) in anchor_map_by_name:
                anchors_by_sym[sym] = anchor_map_by_name[str(sym)]

        # Apply anchors
        eqs_anchored = self.apply_anchors(eqs_sub, anchors_by_sym)

        # Deduplicate
        eqs_canon = self.deduplicate_equations(eqs_anchored)

        # Unknowns are non-anchored variables
        unknowns = [v for v in vars_list if v not in anchors_by_sym]

        # Rank and solution
        rank_A, _, sols = self._rank_and_solution(eqs_canon, unknowns)

        status = "multiple" if (rank_A < len(unknowns)) else "unique"
        if len(unknowns) == 0:
            status = "anchored_only"

        # Optionally auto-anchor to reach full rank
        auto_anchor_steps: List[str] = []
        if solve_cfg.auto_anchor and status != "unique":
            to_anchor = [v for v in vars_list if v not in anchors_by_sym]
            # Greedy: anchor variables until rank equals unknown count or none left
            eqs_work = eqs_canon[:]
            anchors_work = dict(anchors_by_sym)
            unknowns_work = [v for v in vars_list if v not in anchors_work]
            while True:
                rank_A, _, _ = self._rank_and_solution(eqs_work, unknowns_work)
                if rank_A >= len(unknowns_work) or not unknowns_work:
                    break
                # anchor the next variable
                v_add = unknowns_work[0]
                anchors_work[v_add] = solve_cfg.auto_anchor_value
                auto_anchor_steps.append(f"{v_add}={solve_cfg.auto_anchor_value}")
                eqs_work = self.apply_anchors(eqs_canon, anchors_work)
                eqs_work = self.deduplicate_equations(eqs_work)
                unknowns_work = [v for v in vars_list if v not in anchors_work]
            # Recompute final solution
            rank_A, _, sols = self._rank_and_solution(eqs_work, unknowns_work)
            anchors_by_sym = anchors_work
            eqs_canon = eqs_work
            unknowns = unknowns_work
            status = "unique" if rank_A == len(unknowns) else "multiple" if len(unknowns) > 0 else "anchored_only"

        # Build assignment list(s)
        assignments: List[Dict[str, Any]] = []
        if sols:
            for sol in sols:
                # Merge anchors and solved unknowns into a single dict
                merged = {str(k): v for k, v in sol.items()}
                merged.update({str(k): v for k, v in anchors_by_sym.items()})
                assignments.append(merged)
        else:
            # If there are no unknowns (all anchored), that's still a valid assignment
            if len(unknowns) == 0 and anchors_by_sym:
                assignments.append({str(k): v for k, v in anchors_by_sym.items()})

        return {
            "status": status,
            "equations_used": len(eqs_canon),
            "rank": rank_A,
            "unknowns": [str(u) for u in unknowns],
            "anchors": {str(k): v for k, v in anchors_by_sym.items()},
            "assignments": assignments,
            "auto_anchor_steps": auto_anchor_steps,
        }

    def export_summary(self, theory: Dict[str, Any], result: Dict[str, Any]) -> str:
        """Format a plain-text report of a theory and of one solve attempt.

        The first section lists the build mode, the variable symbols, the
        equation count and up to five sample equations. The second section
        lists status, equations used and rank, followed by anchors, unknowns
        and auto-anchor steps when present, and finally the first assignment
        (keys sorted) or ``"- Assignment: none"``.
        """
        build_cfg = theory["config"]
        symbols = theory["var_map"].values()
        equations = theory["equations"]

        mode = "edges-only" if build_cfg.restrict_to_edges else "all-pairs"
        report = [
            "=== Theory Summary ===",
            f"- Mode: {mode}; constant term: {build_cfg.add_constant}",
            f"- Variables: {', '.join(str(s) for s in symbols)}",
            f"- Equations: {len(equations)}",
            "",
            "Sample equations (up to 5):",
        ]
        report.extend(f"  • {eq}" for eq in equations[:5])

        report.extend([
            "",
            "=== Solve Attempt ===",
            f"- Status: {result.get('status')}",
            f"- Equations used: {result.get('equations_used', 0)}; Rank(A): {result.get('rank')}",
        ])

        anchors = result.get("anchors")
        if anchors:
            report.append(f"- Anchors: {', '.join(f'{k}={v}' for k, v in anchors.items())}")
        unknowns = result.get("unknowns")
        if unknowns:
            report.append(f"- Unknowns: {', '.join(unknowns)}")
        steps = result.get("auto_anchor_steps")
        if steps:
            report.append(f"- Auto-anchored: {', '.join(steps)}")

        assignments = result.get("assignments")
        if not assignments:
            report.append("- Assignment: none")
        else:
            report.append("- Assignment (first):")
            first = assignments[0]
            report.extend(f"  {name} = {first[name]}" for name in sorted(first))

        return "\n".join(report)


# ------------------------- Demo / CLI -------------------------

def _build_demo_graph():
    """Return the three-node demo chain ``x -> y -> z``.

    A ``networkx.DiGraph`` is returned when networkx was imported; otherwise a
    minimal stand-in object exposing ``nodes()`` and ``edges()``.
    """
    demo_nodes = ["x", "y", "z"]
    demo_edges = [("x", "y"), ("y", "z")]

    if not _HAS_NX:
        # Fallback
        class _G:
            def __init__(self):
                self._nodes = demo_nodes
                self._edges = demo_edges

            def nodes(self):
                return self._nodes

            def edges(self):
                return self._edges

        return _G()

    graph = nx.DiGraph()
    graph.add_nodes_from(demo_nodes)
    graph.add_edges_from(demo_edges)
    return graph


def _parse_anchor(s: str) -> Tuple[str, float]:
    if "=" not in s:
        raise argparse.ArgumentTypeError("Anchors must be of form NAME=VALUE, e.g., v_x=1.0")
    name, val = s.split("=", 1)
    try:
        return name.strip(), float(val.strip())
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid anchor value: {val}")


def main(argv: Optional[List[str]] = None) -> None:
    """Build the demo theory, try to solve it, print a report and write artifacts.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Side effects:
        Prints the equations and the summary, creates ``--out-dir`` if needed,
        and writes ``<mode>_summary.txt``, ``<mode>_solution.json`` and, with
        ``--latex``, ``<mode>.tex`` into it.
    """
    import sys  # local import: only needed for the warning below

    parser = argparse.ArgumentParser(description="Build, solve, and export symbolic theories from a graph.")
    parser.add_argument("--mode", choices=["edges", "all"], default="edges", help="Use only edges (edges) or all pairs (all).")
    parser.add_argument("--omit-constant", action="store_true", help="Omit constant term in equations.")
    parser.add_argument("--coeff", type=float, default=1.0, help="Coefficient value to substitute for all a_*, b_*.")
    parser.add_argument("--const", type=float, default=0.0, help="Constant value to substitute for all c_*.")
    parser.add_argument("--anchor", type=_parse_anchor, action="append", default=[], help="Anchor variable, e.g., v_x=1.0 (repeatable).")
    parser.add_argument("--auto-anchor", action="store_true", help="Greedily anchor additional variables to reach full rank.")
    parser.add_argument("--auto-anchor-value", type=float, default=1.0, help="Value for auto-added anchors.")
    parser.add_argument("--max-eq", type=int, default=100, help="Max equations to consider.")
    parser.add_argument("--out-dir", type=str, default="theory_reports", help="Directory to write summaries/LaTeX.")
    parser.add_argument("--latex", action="store_true", help="Also export LaTeX equations.")

    # Notebook kernels (Jupyter/Colab) put their own "-f <connection file>"
    # into sys.argv; a strict parse_args() would exit on it. Accept unknown
    # arguments but report them so command-line typos do not go unnoticed.
    args, ignored = parser.parse_known_args(argv)
    if ignored:
        print(f"[warn] Ignoring unrecognized arguments: {' '.join(ignored)}", file=sys.stderr)

    # Build demo graph
    G = _build_demo_graph()
    ctor = SymbolicTheoryConstructor(var_prefix="v")

    cfg = BuildConfig(restrict_to_edges=(args.mode == "edges"), add_constant=not args.omit_constant, var_prefix="v")
    theory = ctor.build(G, cfg)

    solve_cfg = SolveConfig(
        coeff_value=args.coeff,
        const_value=args.const,
        anchors=dict(args.anchor) if args.anchor else {},
        auto_anchor=args.auto_anchor,
        auto_anchor_value=args.auto_anchor_value,
        max_equations=args.max_eq,
    )
    result = ctor.try_solve(theory, solve_cfg)
    summary = ctor.export_summary(theory, result)  # built once, printed and saved

    # Print equations and summary
    print("=== Symbolic Theory: Build • Solve • Export ===\n")
    print("--- Equations ---")
    print(ctor.pretty_print(theory))
    print("")
    print(summary)

    # Write artifacts
    out_dir = args.out_dir
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{args.mode}_summary.txt"), "w", encoding="utf-8") as f:
        f.write(summary)
    if args.latex:
        with open(os.path.join(out_dir, f"{args.mode}.tex"), "w", encoding="utf-8") as f:
            f.write(ctor.pretty_print(theory, latex=True))
    with open(os.path.join(out_dir, f"{args.mode}_solution.json"), "w", encoding="utf-8") as f:
        # default=str: SymPy values in the result are stored as their string form
        json.dump(result, f, indent=2, default=str)

    print(f"\n[Saved] {args.mode}_summary.txt, {args.mode}_solution.json{', and ' + args.mode + '.tex' if args.latex else ''} to {out_dir}/")


# Run the demo CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()