In [1]:
# Simplified BP Message Passing Visualizer with R Message Details
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import networkx as nx
from IPython.display import display, clear_output
import ipywidgets as widgets
import time

class SimpleBPVisualizer:
    """Interactive step-by-step BP visualizer focusing on message flow.

    The visualizer wraps an engine object and renders, inside an
    ``ipywidgets.Output`` area:

    * the factor graph with the messages currently sitting in each node's
      outbox,
    * the current belief of every variable,
    * optionally the cost table of every factor, and
    * optionally the augmented cost tables behind the outgoing messages of
      the first factor that has a non-empty outbox.

    The engine is expected to expose ``graph`` (with ``G``, ``variables``,
    ``factors`` and ``global_cost``) and ``step(i)``; nodes are expected to
    expose ``name`` and ``mailer`` (``inbox`` / ``outbox``).  Those objects
    are defined outside this file, so this is the interface as used here,
    not a verified contract.
    """

    def __init__(self, engine):
        """Store the engine, compute the node layout and build the widgets.

        Args:
            engine: BP engine whose ``graph`` attribute is the factor graph
                to visualize and whose ``step`` method advances one step.
        """
        self.engine = engine
        self.graph = engine.graph
        self.step_count = 0
        self.show_r_details = True
        self.show_cost_tables = True

        # Create layout (variables on one side, factors on the other)
        self.pos = nx.bipartite_layout(
            self.graph.G,
            nodes=self.graph.variables,
            scale=2.0,
            center=[0, 0]
        )

        # UI elements
        self.output = widgets.Output()
        self.btn_step = widgets.Button(description="Next Step", button_style='primary')
        self.btn_reset = widgets.Button(description="Reset", button_style='danger')
        self.chk_r_details = widgets.Checkbox(value=True, description='Show R Message Details')
        self.chk_cost_tables = widgets.Checkbox(value=True, description='Show Cost Tables')
        self.info_label = widgets.Label(value="Step: 0")

        self.btn_step.on_click(self.on_step)
        self.btn_reset.on_click(self.on_reset)
        self.chk_r_details.observe(self.on_toggle_details, 'value')
        self.chk_cost_tables.observe(self.on_toggle_tables, 'value')

        self.controls = widgets.VBox([
            widgets.HBox([self.btn_step, self.btn_reset, self.info_label]),
            widgets.HBox([self.chk_r_details, self.chk_cost_tables])
        ])

    def visualize_state(self):
        """Redraw the whole figure for the current state of the factor graph.

        The top row always holds the graph and the belief chart.  One extra
        full-width row is added for the cost tables when that option is on,
        and one for the R-message details when that option is on and at least
        one factor has a message in its outbox.
        """
        with self.output:
            clear_output(wait=True)

            # Decide which optional panels are needed this time.
            show_r = bool(self.show_r_details) and any(
                f.mailer.outbox for f in self.graph.factors
            )
            show_tables = bool(self.show_cost_tables)

            # Only allocate rows that will actually be used so no blank band
            # is left in the figure when a panel is switched off.
            n_rows = 1 + int(show_tables) + int(show_r)
            figsize = (14, 6) if n_rows == 1 else (18, 4 * n_rows)
            fig = plt.figure(figsize=figsize, layout="constrained")
            gs = fig.add_gridspec(n_rows, 2, hspace=0.3, wspace=0.3)

            ax_graph = fig.add_subplot(gs[0, 0])
            ax_beliefs = fig.add_subplot(gs[0, 1])

            next_row = 1
            ax_tables = None
            if show_tables:
                ax_tables = fig.add_subplot(gs[next_row, :])
                next_row += 1
            ax_r = fig.add_subplot(gs[next_row, :]) if show_r else None

            # Draw factor graph
            self.draw_graph_with_messages(ax_graph)

            # Draw beliefs
            self.draw_beliefs(ax_beliefs)

            # Draw cost tables
            if ax_tables is not None:
                self.draw_cost_tables(ax_tables)

            # Draw R message computation details
            if ax_r is not None:
                self.draw_r_message_details(ax_r)

            plt.show()

    def draw_cost_tables(self, ax):
        """Draw every factor's cost table in a grid laid over ``ax``.

        ``ax`` itself is cleared and hidden; it only provides the rectangle
        in which up to four panels per row are placed.  2-D tables are shown
        as annotated heatmaps, 1-D tables as bar charts.

        Args:
            ax: Matplotlib axes whose figure area is used for the panels.
        """
        factors = sorted(self.graph.factors, key=lambda f: f.name)
        if not factors:
            ax.text(0.5, 0.5, "No factors", ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')
            return

        # Get axis position
        ax_pos = ax.get_position()
        fig = ax.figure

        # Clear main axis
        ax.clear()
        ax.axis('off')

        n_factors = len(factors)
        cols = min(4, n_factors)
        rows = (n_factors + cols - 1) // cols

        # Each panel fills most of its grid cell, leaving a small gutter.
        cell_width = ax_pos.width / cols
        cell_height = ax_pos.height / rows
        panel_width = cell_width * 0.9
        panel_height = cell_height * 0.85

        for idx, factor in enumerate(factors):
            row, col = divmod(idx, cols)

            # Cells are filled left-to-right, top-to-bottom.
            left = ax_pos.x0 + col * cell_width
            bottom = ax_pos.y1 - (row + 1) * cell_height

            # Create subplot axis
            sub_ax = fig.add_axes([left, bottom, panel_width, panel_height])

            if factor.cost_table is None:
                sub_ax.text(0.5, 0.5, "No cost table", ha='center', va='center')
                sub_ax.axis('off')
                continue

            ct = factor.cost_table

            if ct.ndim == 2:
                # 2D heatmap
                im = sub_ax.imshow(ct, cmap='YlOrRd', aspect='auto', interpolation='nearest')

                # Add values; dark cells get white text for contrast.
                threshold = ct.max() / 2
                for i in range(ct.shape[0]):
                    for j in range(ct.shape[1]):
                        color = 'white' if ct[i, j] > threshold else 'black'
                        sub_ax.text(j, i, f'{ct[i, j]:.1f}',
                                    ha="center", va="center", color=color, fontsize=9)

                # Labels (first connected variable on rows, second on columns)
                var_names = list(factor.connection_number.keys())
                if len(var_names) >= 2:
                    sub_ax.set_xlabel(var_names[1], fontsize=9)
                    sub_ax.set_ylabel(var_names[0], fontsize=9)

                sub_ax.set_xticks(range(ct.shape[1]))
                sub_ax.set_yticks(range(ct.shape[0]))

                # Colorbar
                cbar = plt.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=7)

            elif ct.ndim == 1:
                # 1D bar chart
                bars = sub_ax.bar(range(len(ct)), ct, color='coral', edgecolor='darkred')
                sub_ax.set_xlabel("Value", fontsize=9)
                sub_ax.set_ylabel("Cost", fontsize=9)

                # Add value labels.  A dedicated name is used for the bar
                # height so the panel height above is not overwritten.
                for i, bar in enumerate(bars):
                    bar_height = bar.get_height()
                    sub_ax.text(bar.get_x() + bar.get_width() / 2., bar_height,
                                f'{ct[i]:.1f}', ha='center', va='bottom', fontsize=8)

            else:
                # Only 1-D and 2-D tables have a drawing here.
                sub_ax.text(0.5, 0.5, f"{ct.ndim}-D cost table not shown",
                            ha='center', va='center', transform=sub_ax.transAxes)
                sub_ax.axis('off')

            sub_ax.set_title(f"{factor.name}", fontsize=10, fontweight='bold')

        # Add overall title
        fig.text(ax_pos.x0 + ax_pos.width / 2, ax_pos.y1 + 0.02,
                 "Factor Cost Tables", fontsize=14, fontweight='bold', ha='center')

    def draw_graph_with_messages(self, ax):
        """Draw the factor graph and overlay every message found in an outbox.

        Variables are drawn as light-blue circles and factors as light-green
        squares.  Each outbox message becomes a red arrow from sender to
        recipient with an abbreviated value label at its midpoint.

        Args:
            ax: Matplotlib axes to draw on.
        """
        # Draw edges
        nx.draw_networkx_edges(self.graph.G, self.pos, ax=ax, alpha=0.2)

        # Draw variable nodes
        var_nodes = self.graph.variables
        nx.draw_networkx_nodes(
            self.graph.G, self.pos, nodelist=var_nodes,
            node_shape='o', node_color='lightblue',
            node_size=1000, ax=ax
        )

        # Draw factor nodes
        factor_nodes = self.graph.factors
        nx.draw_networkx_nodes(
            self.graph.G, self.pos, nodelist=factor_nodes,
            node_shape='s', node_color='lightgreen',
            node_size=1000, ax=ax
        )

        # Labels
        labels = {n: n.name for n in self.graph.G.nodes()}
        nx.draw_networkx_labels(self.graph.G, self.pos, labels, ax=ax)

        # Draw messages as arrows
        for node in self.graph.G.nodes():
            for msg in node.mailer.outbox:
                start = self.pos[msg.sender]
                end = self.pos[msg.recipient]

                # Draw arrow
                ax.annotate('', xy=end, xytext=start,
                            arrowprops=dict(arrowstyle='->', lw=2, color='red'))

                # Show message values (only the first two entries when longer)
                mid = [(start[0] + end[0]) / 2, (start[1] + end[1]) / 2]
                if len(msg.data) > 2:
                    msg_text = f"{msg.data[:2].round(1)}..."
                else:
                    msg_text = str(msg.data.round(1))
                ax.text(mid[0], mid[1], msg_text, fontsize=8,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

        ax.set_title(f"Factor Graph - Step {self.step_count}", fontsize=14)
        ax.axis('equal')
        ax.axis('off')

    def draw_beliefs(self, ax):
        """Draw a grouped bar chart of every variable's current belief.

        The bars of all variables for one domain value are centred on that
        value's tick.  The entry with the largest belief of each variable is
        highlighted in red, and the graph's global cost is shown in a corner
        box.  Nothing is drawn when the graph has no variables.

        Args:
            ax: Matplotlib axes to draw on.
        """
        variables = self.graph.variables
        n_vars = len(variables)
        if n_vars == 0:
            return

        width = 0.8 / n_vars
        longest = 0

        # Plot each variable's belief
        for i, var in enumerate(variables):
            belief = np.asarray(var.belief)
            # Positions come from the belief itself so variables with
            # different domain sizes can share the chart.
            x_vals = np.arange(len(belief))
            longest = max(longest, len(belief))

            # Centre the group of n_vars bars on the integer tick.
            x_offset = (i - (n_vars - 1) / 2) * width

            bars = ax.bar(x_vals + x_offset, belief, width,
                          label=var.name, alpha=0.7)

            # Highlight max belief
            if len(belief):
                max_idx = int(np.argmax(belief))
                bars[max_idx].set_color('red')
                bars[max_idx].set_alpha(1.0)

        ax.set_xlabel("Domain Value")
        ax.set_ylabel("Belief")
        ax.set_title("Variable Beliefs")
        ax.legend()
        ax.set_xticks(np.arange(longest))

        # Show global cost
        cost = self.graph.global_cost
        ax.text(0.95, 0.95, f"Global Cost: {cost:.2f}",
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

    @staticmethod
    def _augment_cost_table(factor, recipient_var):
        """Return the factor's cost table plus messages from the other senders.

        Every inbox message whose sender is not ``recipient_var`` is
        broadcast along the axis given by ``factor.connection_number`` for
        that sender (0 adds per row, anything else adds per column).
        Messages from senders missing in ``connection_number`` are ignored.

        Args:
            factor: Object with ``cost_table`` (2-D array), ``mailer.inbox``
                and ``connection_number`` (sender name -> axis).
            recipient_var: The node the outgoing message is addressed to.

        Returns:
            A new float array; ``factor.cost_table`` is left untouched.
        """
        # Work in float: cost tables may be integer arrays while messages
        # are floats, and an in-place float-into-int add raises in NumPy.
        augmented = np.array(factor.cost_table, dtype=float)

        for in_msg in factor.mailer.inbox:
            if in_msg.sender != recipient_var:
                sender_dim = factor.connection_number.get(in_msg.sender.name)
                if sender_dim is not None:
                    data = np.asarray(in_msg.data)
                    if sender_dim == 0:  # Rows
                        augmented += data.reshape(-1, 1)
                    else:  # Columns
                        augmented += data.reshape(1, -1)

        return augmented

    def draw_r_message_details(self, ax):
        """Draw how the first active factor's outgoing messages are formed.

        For up to three messages in the outbox of the first factor that has
        any, the augmented cost table (cost table plus incoming messages from
        the other variables) is shown as a heatmap, with the minimum along
        the eliminated axis outlined in blue and the message values printed
        underneath.

        Args:
            ax: Matplotlib axes whose figure area is used for the panels.
        """
        # Check if we're in a state where R messages exist
        active_factors = [f for f in self.graph.factors if f.mailer.outbox]

        if not active_factors:
            ax.text(0.5, 0.5, "No R messages being computed",
                    ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')
            return

        # Show computation for first active factor
        factor = active_factors[0]

        if factor.cost_table is None or factor.cost_table.ndim != 2:
            ax.text(0.5, 0.5, f"Factor {factor.name}: Cost table not available or not 2D",
                    ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')
            return

        # Get axis position for proper layout
        ax_pos = ax.get_position()
        fig = ax.figure

        # Clear the axis but keep it for reference
        ax.clear()
        ax.axis('off')

        # At most three messages are shown side by side; the outbox is
        # non-empty here because the factor was selected as active.
        n_messages = min(3, len(factor.mailer.outbox))
        width = ax_pos.width / n_messages

        for idx, msg in enumerate(factor.mailer.outbox[:n_messages]):
            # Create new axis for this message
            left = ax_pos.x0 + idx * width
            sub_ax = fig.add_axes([left, ax_pos.y0, width * 0.9, ax_pos.height])

            recipient_var = msg.recipient
            augmented = self._augment_cost_table(factor, recipient_var)

            # Find recipient dimension
            recipient_dim = factor.connection_number.get(recipient_var.name)

            # Visualize augmented table
            im = sub_ax.imshow(augmented, cmap='YlOrRd', aspect='auto')

            # Compute and highlight minimums (for min-sum)
            if recipient_dim == 0:  # Keep rows, minimize over columns
                for i in range(augmented.shape[0]):
                    min_j = np.argmin(augmented[i, :])
                    rect = patches.Rectangle((min_j - 0.4, i - 0.4), 0.8, 0.8,
                                             linewidth=3, edgecolor='blue',
                                             facecolor='none')
                    sub_ax.add_patch(rect)
            else:  # Keep columns, minimize over rows
                for j in range(augmented.shape[1]):
                    min_i = np.argmin(augmented[:, j])
                    rect = patches.Rectangle((j - 0.4, min_i - 0.4), 0.8, 0.8,
                                             linewidth=3, edgecolor='blue',
                                             facecolor='none')
                    sub_ax.add_patch(rect)

            # Add value annotations; dark cells get white text for contrast.
            threshold = augmented.max() / 2
            for i in range(augmented.shape[0]):
                for j in range(augmented.shape[1]):
                    color = 'white' if augmented[i, j] > threshold else 'black'
                    sub_ax.text(j, i, f'{augmented[i, j]:.1f}',
                                ha="center", va="center", color=color, fontsize=8)

            # Show resulting message
            result_text = f"R→{recipient_var.name}: {msg.data.round(2)}"
            sub_ax.text(0.5, -0.15, result_text, transform=sub_ax.transAxes,
                        ha='center', va='top', fontweight='bold', color='blue',
                        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

            # Labels
            var_names = list(factor.connection_number.keys())
            if len(var_names) >= 2:
                sub_ax.set_xlabel(var_names[1])
                sub_ax.set_ylabel(var_names[0])
            sub_ax.set_title(f"{factor.name} → {recipient_var.name}", fontsize=10)

            # Add minimal colorbar
            cbar = plt.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(labelsize=7)

        # Add overall title
        fig.text(ax_pos.x0 + ax_pos.width / 2, ax_pos.y1 + 0.02,
                 f"R Message Computation: {factor.name}",
                 fontsize=12, fontweight='bold', ha='center')

    def on_step(self, b):
        """Button callback: run one engine step and redraw.

        Args:
            b: The clicked button (unused).
        """
        # Run one step; the engine receives the index of the step to run.
        self.engine.step(self.step_count)
        self.step_count += 1

        # Update info
        self.info_label.value = f"Step: {self.step_count}"

        # Visualize
        self.visualize_state()

    def on_reset(self, b):
        """Button callback: clear all mailboxes and history, then redraw.

        Args:
            b: The clicked button (unused).
        """
        # Reset step counter
        self.step_count = 0
        self.info_label.value = "Step: 0"

        # Clear all messages and mailboxes
        for node in self.graph.G.nodes():
            node.empty_mailbox()
            node.empty_outgoing()
            # Some mailer implementations keep extra private buffers; clear
            # them too when present.
            if hasattr(node.mailer, '_incoming'):
                node.mailer._incoming.clear()
            if hasattr(node.mailer, '_outgoing'):
                node.mailer._outgoing.clear()

        # Reinitialize messages for variables
        for var in self.graph.variables:
            for neighbor in self.graph.G.neighbors(var):
                var.mailer.set_first_message(var, neighbor)

        # Reset engine history
        if hasattr(self.engine, 'history'):
            self.engine.history.cycles.clear()
            self.engine.history.beliefs.clear()
            self.engine.history.assignments.clear()
            self.engine.history.costs.clear()

        # Redraw
        self.visualize_state()

    def on_toggle_details(self, change):
        """Checkbox callback: show or hide the R-message detail panel.

        Args:
            change: Traitlets change dict; ``change['new']`` is the new value.
        """
        self.show_r_details = change['new']
        self.visualize_state()

    def on_toggle_tables(self, change):
        """Checkbox callback: show or hide the cost-table panel.

        Args:
            change: Traitlets change dict; ``change['new']`` is the new value.
        """
        self.show_cost_tables = change['new']
        self.visualize_state()

    def run(self):
        """Display the controls and the output area, then draw the first frame."""
        display(self.controls)
        display(self.output)

        # Initial visualization
        self.visualize_state()


# Example: Quick visualization of BP on a simple graph
def demo_bp_visualization():
    """Launch the visualizer on a ring of four variables.

    Four variables with domain size 3 are linked pairwise around a cycle
    (x0-x1, x1-x2, x2-x3, x3-x0), each link being a factor with a fixed
    3x3 cost table.  A min-sum engine is built on the resulting factor
    graph and handed to ``SimpleBPVisualizer``.
    """
    from bp_base.agents import VariableAgent, FactorAgent
    from bp_base.factor_graph import FactorGraph
    from bp_base.computators import MinSumComputator
    from bp_base.bp_engine_base import BPEngine

    ring_size = 4
    variables = [VariableAgent(f"x{k}", domain=3) for k in range(ring_size)]

    # One hand-picked table per ring edge, in edge order.
    cost_tables = [
        np.array([[1, 4, 2], [3, 0, 5], [2, 3, 1]]),  # f01
        np.array([[2, 1, 3], [0, 4, 2], [5, 1, 3]]),  # f12
        np.array([[3, 2, 0], [1, 3, 4], [2, 5, 1]]),  # f23
        np.array([[0, 3, 2], [4, 1, 3], [2, 2, 4]])   # f30
    ]

    factors = []
    edges = {}
    for head, table in enumerate(cost_tables):
        tail = (head + 1) % ring_size
        # The table is bound as a default argument so each factor keeps
        # its own table rather than the last one seen by the loop.
        link = FactorAgent(
            f"f{head}{tail}",
            domain=3,
            ct_creation_func=lambda n, d, ct=table: ct,
            param={}
        )
        factors.append(link)
        edges[link] = [variables[head], variables[tail]]

    # Assemble graph, engine and visualizer, then start the UI.
    fg = FactorGraph(variables, factors, edges)
    engine = BPEngine(fg, computator=MinSumComputator())
    SimpleBPVisualizer(engine).run()

# Alternative demo with a tree structure
def demo_tree_bp():
    """Launch the visualizer on a small tree-shaped factor graph.

    The tree has x0 as root, x1 and x2 as its children and x3 as a child
    of x1; every edge is a factor with a fixed 3x3 cost table.  A min-sum
    engine is built on the graph and handed to ``SimpleBPVisualizer``.
    """
    from bp_base.agents import VariableAgent, FactorAgent
    from bp_base.factor_graph import FactorGraph
    from bp_base.computators import MinSumComputator
    from bp_base.bp_engine_base import BPEngine

    # x0 = root, x1 / x2 = children, x3 = grandchild
    x0, x1, x2, x3 = [VariableAgent(f"x{k}", domain=3) for k in range(4)]

    def make_factor(label, rows):
        # A fresh array is produced on every call of the creation function.
        return FactorAgent(label, domain=3,
                           ct_creation_func=lambda n, d: np.array(rows),
                           param={})

    f01 = make_factor("f01", [[0, 2, 5], [3, 1, 4], [2, 3, 1]])
    f02 = make_factor("f02", [[1, 3, 2], [4, 0, 3], [2, 5, 1]])
    f13 = make_factor("f13", [[2, 1, 4], [0, 3, 2], [5, 1, 3]])

    # Which variables each factor connects
    edges = {
        f01: [x0, x1],
        f02: [x0, x2],
        f13: [x1, x3]
    }

    fg = FactorGraph(
        variable_li=[x0, x1, x2, x3],
        factor_li=[f01, f02, f13],
        edges=edges
    )

    engine = BPEngine(fg, computator=MinSumComputator())

    viz = SimpleBPVisualizer(engine)
    viz.run()

# Entry point: launch the 4-variable cycle demo when this code is executed
# as the main module; demo_tree_bp() is the alternative and is not called here.
if __name__ == "__main__":
    demo_bp_visualization()

NetworkX version: 3.4.2
Attempting to load: C:\Users\Public\projects\Belief_propagation_simulator_\configs\factor_graphs\factor-graph-cycle-3-random_intlow1,high100-number5.pkl
File does not exist: C:\Users\Public\projects\Belief_propagation_simulator_\configs\factor_graphs\factor-graph-cycle-3-random_intlow1,high100-number5.pkl
Available factor graph files in C:\Users\Public\projects\Belief_propagation_simulator_\configs\factor_graphs:
  - factor-graph-cycle-3-random_intlow100,high2000.3-number0.pkl
Using first available file: C:\Users\Public\projects\Belief_propagation_simulator_\configs\factor_graphs\factor-graph-cycle-3-random_intlow100,high2000.3-number0.pkl
Graph loaded. Type: <class 'bp_base.factor_graph.FactorGraph'>

Factor graph details:
Variables: 3
Factors: 3
Graph nodes: 6
Graph edges: 6

First few nodes:
  - X1
  - X2
  - X3
  - F12
  - F23

First variable: x1, Domain: 3

First factor: f12
Cost table shape: (3, 3)

Repaired graph saved to: C:\Users\Public\projects\Belief_

VBox(children=(HBox(children=(Button(button_style='primary', description='Next Step', style=ButtonStyle()), Bu…

Output()