In [None]:
#!/usr/bin/env python3
from gbp.gbp import *
from gbp.factor import *
from gbp.grid import *
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

class MultilayerAbstractionSystem:
    """Hierarchy of abstraction layers stacked on top of a base graph.

    ``self.layers`` is ordered from finest (index 0, wrapping the base graph)
    to coarsest. Coarser layers are only created when
    ``build_abstraction_layers`` is called.
    """

    def __init__(self, base_graph, num_layers=3):
        # Requested total number of layers, base layer included.
        self.num_layers = num_layers
        # Filled finest -> coarsest by build_base_layer / build_abstraction_layers.
        self.layers = []
        self.build_base_layer(base_graph)

    def build_base_layer(self, base_graph):
        """Wrap ``base_graph`` in a level-0 layer and append it to ``self.layers``."""
        self.layers.append(AbstractionLayer(level=0, graph=base_graph))

    def build_abstraction_layers(self):
        """Append ``num_layers - 1`` coarser layers, each derived from the current top."""
        for _ in range(self.num_layers - 1):
            coarsest = self.layers[-1]
            self.layers.append(coarsest.abstract_to_next_layer())

    def run_gbp_on_all_layers(self, num_iters=20):
        """Run GBP for ``num_iters`` iterations on every layer, finest first."""
        for current in self.layers:
            current.run_gbp(num_iters=num_iters)

    def recover_to_base(self):
        """
        Push solutions back down the hierarchy, from the coarsest layer to the base.

        For every adjacent (finer, coarser) pair, starting at the top of the
        hierarchy, calls ``coarser.propagate_to_finer(finer)``. This is the hook
        for a coarse correction or interpolation step.
        """
        adjacent_pairs = list(zip(self.layers, self.layers[1:]))
        for finer, coarser in reversed(adjacent_pairs):
            coarser.propagate_to_finer(finer)


In [None]:
class AbstractionLayer:
    """One level of the multilayer abstraction hierarchy.

    A layer is either:
      * a base layer (``graph=`` given), which wraps an existing graph as-is, or
      * an abstract layer (``PrevLayer=`` given), built by taking a supergraph
        from the previous layer and projecting each variable's belief onto a
        low-dimensional principal subspace.

    Attributes:
        level: index of this layer in the hierarchy (0 = base).
        super_graph: graph the abstraction was built from (None for the base layer).
        abstract_graph: reduced graph built from ``super_graph`` (None for the base layer).
        graph: graph GBP runs on at this level; the next coarser layer is built from it.
    """

    def __init__(self, level, PrevLayer=None, graph=None):
        """Create a layer at ``level``.

        Args:
            level: index of the layer in the hierarchy.
            PrevLayer: the next finer layer; required unless ``graph`` is given.
            graph: an existing graph to wrap directly (base layer). Takes
                precedence over ``PrevLayer`` if both are given.

        Raises:
            TypeError: if neither ``PrevLayer`` nor ``graph`` is provided.
        """
        self.level = level
        if graph is not None:
            # Base layer: nothing to coarsen, operate on the given graph directly.
            self.super_graph = None
            self.abstract_graph = None
            self.graph = graph
        elif PrevLayer is not None:
            self.super_graph = self.build_super_graph(PrevLayer)
            self.abstract_graph = self.build_abstract_graph(self.super_graph)
            # The reduced graph is what this level solves and what the next
            # coarser layer is built from.
            self.graph = self.abstract_graph
        else:
            raise TypeError("AbstractionLayer requires either PrevLayer or graph")

    def abstract_to_next_layer(self):
        """Build and return the next coarser layer on top of this one."""
        return type(self)(level=self.level + 1, PrevLayer=self)

    def build_super_graph(self, PrevLayer):
        """Return the graph this layer's abstraction is built from.

        Currently this returns ``PrevLayer.graph`` unchanged; no variable or
        factor aggregation is performed yet.

        TODO: aggregate fine variables/factors into supernodes, e.g. with
        ``build_coarse_slam_graph(prior_facs_fine=..., between_facs_fine=...,
        H=..., W=..., stride=2)``, and wrap its
        ``(varis, prior_facs, horizontal_facs, vertical_facs)`` result in an
        object exposing those attributes. The fine factors and the grid size
        (H, W) must first be made available from ``PrevLayer``.
        """
        return PrevLayer.graph

    def build_abstract_graph(self, supergraph, r=2):
        """Project every supergraph variable onto its top-``r`` principal subspace.

        For variable i with belief mean mu_i and covariance Sigma_i, B_i holds
        the ``r`` eigenvectors of Sigma_i with the largest eigenvalues. The
        reduced belief is

            mu_abs_i    = B_i^T mu_i
            Sigma_abs_i = B_i^T Sigma_i B_i

        stored in information form (Lam = Sigma_abs_i^-1, eta = Lam mu_abs_i).
        The offset k_i = mu_i - B_i mu_abs_i is kept so that
        mu_i == B_i mu_abs_i + k_i.

        Factors are copied from the supergraph with ``potential=None``
        (presumably to be recomputed in the reduced space; TODO confirm).

        Args:
            supergraph: object exposing ``varis``, ``prior_facs``,
                ``horizontal_facs`` and ``vertical_facs``.
            r: dimension of the reduced subspace for each variable.

        Returns:
            Graph with the reduced variables, the copied factors, ``r``, and
            the per-variable projection matrices ``Bs`` and offsets ``ks``.
        """
        # jnp is not imported explicitly at module level; do not rely on a star-import.
        import jax.numpy as jnp

        varis_sup = supergraph.varis

        # === 1. Abstract variables ===
        sup_mu = varis_sup.belief.mu()
        sup_sigma = varis_sup.belief.sigma()

        Bs, ks, abs_beliefs = [], [], []
        n_vars = len(varis_sup.var_id)
        for i in range(n_vars):
            B_i, k_i, mu_abs_i, sigma_abs_i = self._principal_projection(
                sup_mu[i], sup_sigma[i], r
            )
            lam_abs_i = jnp.linalg.inv(sigma_abs_i)  # precision of reduced belief
            eta_abs_i = lam_abs_i @ mu_abs_i         # information vector
            abs_beliefs.append(Gaussian(eta_abs_i, lam_abs_i))
            Bs.append(B_i)
            ks.append(k_i)

        # Zero-initialised (eta, Lambda) messages for every factor port, now r-dimensional.
        N, Ni_v, _ = varis_sup.msgs.eta.shape
        abs_msgs = Gaussian(jnp.zeros((N, Ni_v, r)), jnp.zeros((N, Ni_v, r, r)))

        varis_abs = Variable(
            var_id=varis_sup.var_id,
            belief=tree_stack(abs_beliefs, axis=0),
            msgs=abs_msgs,
            adj_factor_idx=jnp.stack(varis_sup.adj_factor_idx),
        )

        # === 2. Factors: same structure as the supergraph, potential cleared ===
        prior_facs_abs = self._copy_factor_without_potential(supergraph.prior_facs)
        horizontal_facs_abs = self._copy_factor_without_potential(supergraph.horizontal_facs)
        vertical_facs_abs = self._copy_factor_without_potential(supergraph.vertical_facs)

        # === 3. Abstract graph ===
        return Graph(
            varis=varis_abs,
            prior_facs=prior_facs_abs,
            horizontal_facs=horizontal_facs_abs,
            vertical_facs=vertical_facs_abs,
            r=r,    # dimension of the reduced subspace
            Bs=Bs,  # projection matrix per variable
            ks=ks,  # offset per variable
        )

    @staticmethod
    def _principal_projection(mu, sigma, r):
        """Project a Gaussian (mu, sigma) onto its ``r`` largest-variance directions.

        Returns:
            B: matrix whose columns are the leading eigenvectors of ``sigma``
               (fewer than ``r`` columns if ``sigma`` is smaller than r x r).
            k: offset ``mu - B @ mu_abs``, so that ``mu == B @ mu_abs + k``.
            mu_abs: projected mean ``B.T @ mu``.
            sigma_abs: projected covariance ``B.T @ sigma @ B``.
        """
        eigvals, eigvecs = np.linalg.eigh(sigma)
        # eigh returns eigenvalues in ascending order; reverse to keep the top r.
        order = np.argsort(eigvals)[::-1]
        B = eigvecs[:, order[:r]]
        mu_abs = B.T @ mu
        sigma_abs = B.T @ sigma @ B
        return B, mu - B @ mu_abs, mu_abs, sigma_abs

    @staticmethod
    def _copy_factor_without_potential(fac):
        """Rebuild ``fac`` as a new Factor with identical fields except ``potential=None``."""
        return Factor(
            factor_id=fac.factor_id,
            z=fac.z,
            z_Lam=fac.z_Lam,
            threshold=fac.threshold,
            potential=None,
            adj_var_id=fac.adj_var_id,
            adj_var_idx=fac.adj_var_idx,
        )

    def run_gbp(self, num_iters=20):
        """Run GBP on ``self.graph`` for ``num_iters`` iterations.

        TODO: not implemented yet; currently a no-op.
        """

    def propagate_to_finer(self, finer_layer):
        """Propagate means or corrections from this layer down to ``finer_layer``.

        TODO: not implemented yet; currently a no-op.
        """

    def stack_variable_nodes(self):
        """Collect the fine variable means of each supernode into one stacked vector.

        TODO: not implemented yet; currently a no-op. The original note suggests
        four fine variables per supernode, giving an 8D vector; verify.
        """

    def compute_svd_embeddings(self, supernodes):
        """Map each stacked supernode vector to a low-dimensional embedding.

        TODO: not implemented yet; currently a no-op. Intended to use SVD or
        PCA (the original note mentions 8D -> 2D).
        """

    def build_graph_from_embeddings(self, embeddings):
        """Connect adjacent embedded nodes in a grid pattern and build a new graph.

        TODO: not implemented yet; currently a no-op.
        """

    def create_supernode_mapping(self):
        """Record which fine nodes belong to which coarse node.

        TODO: not implemented yet; currently a no-op.
        """


In [None]:
class FactorGraph:
    """Minimal container pairing variable nodes with the factors connecting them.

    All methods besides the constructor are placeholders that currently do nothing.
    """

    def __init__(self, variable_nodes, factors):
        # Stored as given (no copy): the variable nodes and the factor nodes.
        self.variable_nodes, self.factors = variable_nodes, factors

    def build_graph(self):
        """Derive the internal graph structure from the stored nodes and factors (TODO; no-op)."""

    def add_factor(self, factor):
        """Insert ``factor`` into the graph (TODO; no-op, ``self.factors`` is unchanged)."""

    def update_variable(self, var_node, new_value):
        """Set ``var_node`` to ``new_value`` (TODO; no-op)."""

    def run_gbp(self, num_iters=20):
        """Run GBP on the graph for ``num_iters`` iterations (TODO; no-op)."""
