<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/rge_abstraction_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install causal-learn

In [None]:
# rge/abstraction.py
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import List, Optional, Any

import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

try:
    import networkx as nx  # optional
    _HAS_NX = True
except Exception:
    _HAS_NX = False

# causal-learn (PyPI name "causal-learn", import name "causallearn"): pip install causal-learn
from causallearn.search.ConstraintBased.PC import pc


@dataclass
class ExtractConfig:
    """
    Settings consumed by ``LatentInvariantExtractor.extract``.

    Fields are listed in pipeline order: clustering of the latent codes,
    optional PCA reduction, PC causal discovery, and finally output/logging.
    Positional construction follows this field order.
    """

    # ---- Spectral clustering of latent codes --------------------------------
    # Explicit number of clusters; leave as None to let the extractor pick one.
    n_clusters: Optional[int] = None
    # Upper bound on the automatically chosen number of clusters.
    max_clusters: int = 8
    # Clusters holding fewer samples than this are skipped (PC is not run on them).
    min_cluster_size: int = 20
    # Whether to standardize the latent codes before clustering / PCA.
    standardize: bool = True
    # SpectralClustering affinity; k-NN is preferred over "rbf" for high-dim codes.
    affinity: str = "nearest_neighbors"
    # Neighbor count for the k-NN affinity; None means "choose automatically".
    n_neighbors: Optional[int] = None
    # Seed forwarded to SpectralClustering and PCA; None for nondeterministic runs.
    random_state: Optional[int] = 42

    # ---- Dimensionality reduction before PC ---------------------------------
    # Target number of principal components; None skips the reduction entirely.
    pca_dim: Optional[int] = 16

    # ---- PC causal discovery -------------------------------------------------
    # Significance level of PC's conditional-independence tests.
    pc_alpha: float = 0.05
    # Name of the CI test handed to causal-learn's `pc` ("fisherz" suits continuous data).
    pc_indep_test: str = "fisherz"

    # ---- Output / logging ----------------------------------------------------
    # Return networkx graphs (when networkx is importable) instead of causal-learn graphs.
    to_networkx: bool = False
    # Print progress messages and forward verbosity to `pc`.
    verbose: bool = False


class LatentInvariantExtractor(nn.Module):
    """
    Autoencoder → latent → spectral clusters → causal graph:
        raw → z → {clusters} → PC algorithm → invariant subgraphs

    The encoder and decoder are two-layer MLPs (``in_dim → 256 → latent_dim``
    and ``latent_dim → 256 → in_dim``). :meth:`extract` encodes a batch,
    clusters the latent codes with spectral clustering, and runs the PC
    algorithm separately on every sufficiently large cluster.
    """

    # SpectralClustering affinities that actually consume ``n_neighbors``.
    _NEIGHBOR_AFFINITIES = ("nearest_neighbors", "precomputed_nearest_neighbors")

    def __init__(self, in_dim: int = 6, latent_dim: int = 128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, in_dim)
        )
        self.latent_dim = latent_dim

    def forward(self, x: torch.Tensor):
        """Return ``(z, x_hat)``: the latent code of ``x`` and its reconstruction."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return z, x_hat

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _resolve_n_clusters(n_samples: int, cfg: ExtractConfig) -> int:
        """Pick the cluster count (explicit or ~1 per 100 samples), clamped to [2, n_samples]."""
        if cfg.n_clusters is None:
            # Heuristic: about one cluster per ~100 samples, capped by max_clusters.
            k = max(2, min(cfg.max_clusters, n_samples // 100))
        else:
            k = int(cfg.n_clusters)
        if k > n_samples:
            warnings.warn(f"Reducing n_clusters from {k} to {n_samples} (n_samples limit).")
            k = n_samples
        return max(k, 2)  # SpectralClustering requires at least 2

    @classmethod
    def _spectral_labels(cls, z_np: np.ndarray, n_clusters: int, cfg: ExtractConfig) -> np.ndarray:
        """Cluster the rows of ``z_np`` and return one integer label per row."""
        n_samples = z_np.shape[0]
        params: dict = dict(
            n_clusters=n_clusters,
            affinity=cfg.affinity,
            random_state=cfg.random_state,
            assign_labels="kmeans",
        )
        # Only pass n_neighbors for k-NN affinities. Recent scikit-learn
        # versions validate it as an int >= 1 even when the affinity ignores
        # it, so forwarding None (as the "rbf" case used to) raises.
        if cfg.affinity in cls._NEIGHBOR_AFFINITIES:
            n_neighbors = 10 if cfg.n_neighbors is None else cfg.n_neighbors
            # A sample cannot have more neighbors than there are other samples.
            params["n_neighbors"] = min(n_neighbors, n_samples - 1)
        return SpectralClustering(**params).fit_predict(z_np)

    @staticmethod
    def _reduce_for_pc(z_np: np.ndarray, cfg: ExtractConfig) -> np.ndarray:
        """Optionally project ``z_np`` onto its leading principal components to keep PC tractable."""
        n_samples, n_features = z_np.shape
        if cfg.pca_dim is None or cfg.pca_dim >= n_features:
            return z_np
        # PCA cannot produce more components than there are samples.
        n_components = min(cfg.pca_dim, n_samples)
        if n_components < cfg.pca_dim:
            warnings.warn(
                f"Reducing pca_dim from {cfg.pca_dim} to {n_components} (n_samples limit)."
            )
        pca = PCA(n_components=n_components, random_state=cfg.random_state)
        return pca.fit_transform(z_np)

    @staticmethod
    def _graph_to_networkx(G: Any) -> "nx.DiGraph":
        """Convert a causal-learn graph to a networkx.DiGraph via its endpoint matrix ``G.graph``.

        Uses causal-learn's documented PC encoding:
          - ``graph[j, i] == 1 and graph[i, j] == -1``  ->  i --> j
          - ``graph[i, j] == graph[j, i] == -1``        ->  i --- j
          - ``graph[i, j] == graph[j, i] == 1``         ->  i <-> j
        Directed edges become a single arc with ``edge_type="directed"``.
        Undirected / bidirected (and any other endpoint combination) become
        two opposite arcs tagged ``"undirected"`` / ``"bidirected"`` / ``"other"``,
        so callers can filter on the ``edge_type`` edge attribute.
        """
        adj = np.asarray(G.graph)
        n_nodes = adj.shape[0]
        G_nx = nx.DiGraph()
        G_nx.add_nodes_from(range(n_nodes))
        symmetric_kinds = {(-1, -1): "undirected", (1, 1): "bidirected"}
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                a_ij, a_ji = int(adj[i, j]), int(adj[j, i])
                if a_ij == 0 and a_ji == 0:
                    continue  # not adjacent
                if a_ij == -1 and a_ji == 1:
                    G_nx.add_edge(i, j, edge_type="directed")
                elif a_ij == 1 and a_ji == -1:
                    G_nx.add_edge(j, i, edge_type="directed")
                else:
                    kind = symmetric_kinds.get((a_ij, a_ji), "other")
                    G_nx.add_edge(i, j, edge_type=kind)
                    G_nx.add_edge(j, i, edge_type=kind)
        return G_nx

    # --------------------------------------------------------------- public API

    @torch.no_grad()
    def extract(self, data: torch.Tensor, cfg: Optional[ExtractConfig] = None) -> List[Any]:
        """
        Returns a list of candidate invariant structures as causal subgraphs.

        - data: tensor (or array-like) of shape (n_samples, in_dim). It is moved
          to the device/dtype of the module's parameters before encoding.
        - cfg: ExtractConfig with clustering and PC options (defaults if None).

        Each returned item corresponds to one cluster with at least
        ``cfg.min_cluster_size`` samples and is either:
          - the causal-learn Graph (cg.G) produced by PC, or
          - a networkx.DiGraph if cfg.to_networkx=True and networkx is installed
            (edges carry an ``edge_type`` attribute; see ``_graph_to_networkx``).

        The module's train/eval mode is restored before returning.

        Raises:
            ValueError: if ``data`` is not 2-D or has fewer than 4 samples.
        """
        if cfg is None:
            cfg = ExtractConfig()

        # Fail-fast validations
        data = torch.as_tensor(data)
        if data.dim() != 2:
            raise ValueError(f"data must be 2D (n_samples, in_dim); got shape {tuple(data.shape)}")
        n_samples = data.shape[0]
        if n_samples < 4:
            raise ValueError(f"Need at least 4 samples; got {n_samples}")

        # Match the parameters' device/dtype so e.g. CPU float64 input works
        # with a float32 model on another device.
        param = next(self.parameters(), None)
        if param is not None:
            data = data.to(device=param.device, dtype=param.dtype)

        # 1) Encode to latent in eval mode, then restore the caller's mode so a
        #    mid-training call does not leave the model stuck in eval().
        was_training = self.training
        self.eval()
        try:
            z, _ = self.forward(data)
        finally:
            self.train(was_training)
        z_np = z.detach().cpu().numpy()

        # 2) Optional standardization (recommended)
        if cfg.standardize:
            z_np = StandardScaler().fit_transform(z_np)

        # 3) + 4) Choose the number of clusters and run spectral clustering.
        n_clusters = self._resolve_n_clusters(n_samples, cfg)
        clusters = self._spectral_labels(z_np, n_clusters, cfg)

        # 5) Optional dimensionality reduction before PC (keeps PC tractable).
        X_for_pc = self._reduce_for_pc(z_np, cfg)

        want_nx = cfg.to_networkx and _HAS_NX
        if cfg.to_networkx and not _HAS_NX:
            warnings.warn("networkx not installed; returning causal-learn Graph instead.")

        # 6) Causal discovery inside each cluster.
        invariants: List[Any] = []
        unique_clusters = np.unique(clusters)
        for c in unique_clusters:
            idx = clusters == c
            size = int(idx.sum())
            if size < cfg.min_cluster_size:
                if cfg.verbose:
                    print(f"[SKIP] cluster {c}: size {size} < min_cluster_size {cfg.min_cluster_size}")
                continue

            # causal-learn expects shape (n_samples, n_features)
            cg = pc(
                X_for_pc[idx, :],
                alpha=cfg.pc_alpha,
                indep_test=cfg.pc_indep_test,
                verbose=cfg.verbose
            )
            invariants.append(self._graph_to_networkx(cg.G) if want_nx else cg.G)

        if cfg.verbose:
            print(f"[DONE] clusters processed: {len(unique_clusters)}, invariants kept: {len(invariants)}")

        return invariants