# In [None]:
# Working BP Visualizer with Proper Dynamic Updates
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import ipywidgets as widgets
from IPython.display import display, clear_output
import sys

# Add the parent directory to Python path
sys.path.append('..')

# Test imports first
try:
    from bp_base.factor_graph import FactorGraph
    from base_all.agents import VariableAgent, FactorAgent
    from base_all.components import Message
    from bp_base.bp_computators import MinSumComputator
    from base_all.bp_engine_base import BPEngine
    print("✓ All imports successful!")
except ImportError as e:
    print(f"✗ Import error: {e}")

def create_simple_factor_graph():
    """Build the smallest useful test graph: x1 -- f12 -- x2.

    Both variables have a domain of size 2. The single factor gets its cost
    table from a creation function (rather than a pre-built array) that always
    yields the fixed 2x2 matrix [[1.0, 3.0], [2.0, 0.5]]. Progress and the
    resulting graph/cost-table sizes are printed.

    Returns:
        The constructed ``FactorGraph``.
    """
    print("Creating simple factor graph...")

    def simple_cost_func(n_vars, domain, **kwargs):
        """Ignore every argument and return the hard-coded 2x2 cost matrix."""
        return np.array([[1.0, 3.0], [2.0, 0.5]])

    # Variables are created in the order x1, x2, exactly as before.
    left, right = (VariableAgent(label, domain=2) for label in ("x1", "x2"))
    pair_factor = FactorAgent(
        "f12", domain=2, ct_creation_func=simple_cost_func, param={}
    )

    # FactorGraph builds each factor's cost table from its creation function.
    graph = FactorGraph(
        variable_li=[left, right],
        factor_li=[pair_factor],
        edges={pair_factor: [left, right]},
    )

    print(f"✓ Factor graph created with {len(graph.variables)} variables and {len(graph.factors)} factors")
    print(f"✓ Cost table shape: {graph.factors[0].cost_table.shape}")
    return graph

def create_cycle_factor_graph(n_vars=4, domain=3):
    """Create a cycle factor graph with n_vars variables."""
    print(f"Creating cycle factor graph with {n_vars} variables...")

    # Create variables
    variables = [VariableAgent(f"x{i}", domain=domain) for i in range(n_vars)]

    # Create factors connecting adjacent variables in a cycle
    factors = []
    edges = {}

    for i in range(n_vars):
        j = (i + 1) % n_vars  # Next variable in cycle

        # Create different cost functions for each factor
        def make_cost_func(idx):
            def cost_func(n_vars, domain, **kwargs):
                """Cost function that creates different preferences."""
                np.random.seed(idx)  # For reproducibility
                costs = np.random.rand(domain, domain) * 5
                # Add some structure - prefer diagonal or off-diagonal
                if idx % 2 == 0:
                    # Prefer same values (diagonal)
                    for k in range(domain):
                        costs[k, k] *= 0.3
                else:
                    # Prefer different values (off-diagonal)
                    for k in range(domain):
                        costs[k, (k+1) % domain] *= 0.3
                return costs
            return cost_func

        factor_name = f"f{i}{j}"
        factor = FactorAgent(factor_name, domain=domain,
                           ct_creation_func=make_cost_func(i), param={})
        factors.append(factor)
        edges[factor] = [variables[i], variables[j]]

    # Create factor graph
    fg = FactorGraph(
        variable_li=variables,
        factor_li=factors,
        edges=edges
    )

    print(f"✓ Cycle graph created with {len(fg.variables)} variables and {len(fg.factors)} factors")
    return fg

class CycleBPVisualizer:
    """Step-by-step min-sum belief-propagation visualizer for a cycle factor graph.

    The graph is built with ``create_cycle_factor_graph``. Each call to
    ``step_algorithm`` runs one phase: variables when ``step`` is even,
    factors when it is odd. It then redraws the factor graph, the per-variable
    beliefs and up to the first three factor cost tables into an
    ``ipywidgets.Output`` area. ``run`` displays that area together with
    Next Step / Auto Run / Reset buttons.

    Args:
        n_vars: Number of variables in the cycle (forwarded to
            ``create_cycle_factor_graph``).
        domain: Domain size of every variable.
    """

    def __init__(self, n_vars=4, domain=3):
        # Create a cycle factor graph
        self.fg = create_cycle_factor_graph(n_vars=n_vars, domain=domain)
        # NOTE(review): self.engine is not used by the stepping code below;
        # confirm whether it is still needed.
        self.engine = BPEngine(self.fg, computator=MinSumComputator())

        # Step counter; its parity selects the phase (even: variables, odd: factors).
        self.step = 0
        self.max_steps = 20  # More steps for cycle convergence

        # Toggled by the Auto Run button (see run()).
        self.auto_running = False

        # Output widget that update_display clears and redraws into.
        self.output = widgets.Output()

        # Fixed node positions so the drawing does not jump between steps.
        self.pos = self._create_circular_layout()

        print("Initializing zero messages...")
        self._init_zero_messages()

        print("✓ Cycle visualizer initialized successfully!")

    def _init_zero_messages(self):
        """Deliver an all-zero message from every adjacent factor to each variable."""
        for var in self.fg.variables:
            for neighbor in self.fg.G.neighbors(var):
                if isinstance(neighbor, FactorAgent):
                    var.mailer.receive_messages(
                        Message(
                            data=np.zeros(var.domain),
                            sender=neighbor,
                            recipient=var
                        )
                    )

    def _create_circular_layout(self):
        """Return fixed 2-D positions for every variable and factor node.

        Variables are spaced evenly on a circle of radius 1.5. Each factor is
        placed at the centroid of its already-positioned neighbors, scaled by
        0.6 toward the origin. For a two-variable factor this is the midpoint
        of its edge. A factor with no positioned neighbor goes to the origin,
        so every node has a position (the drawing code indexes ``self.pos``
        for each node and message endpoint).

        Returns:
            dict mapping node -> ``(x, y)`` tuple.
        """
        pos = {}
        n_vars = len(self.fg.variables)

        # Place variables on outer circle
        for i, var in enumerate(self.fg.variables):
            angle = 2 * np.pi * i / n_vars
            pos[var] = (1.5 * np.cos(angle), 1.5 * np.sin(angle))

        # Place factors on an inner circle, between the variables they connect
        for factor in self.fg.factors:
            coords = [pos[v] for v in self.fg.G.neighbors(factor) if v in pos]
            if coords:
                cx = sum(x for x, _ in coords) / len(coords)
                cy = sum(y for _, y in coords) / len(coords)
                pos[factor] = (cx * 0.6, cy * 0.6)
            else:
                pos[factor] = (0.0, 0.0)

        return pos

    def get_beliefs(self):
        """Return ``{variable name: belief vector}`` for all variables.

        A belief is the element-wise sum of the ``data`` of the messages
        currently in the variable mailer's ``inbox``. It is all zeros when the
        mailer has no (or an empty) ``inbox``.
        """
        beliefs = {}
        for var in self.fg.variables:
            belief = np.zeros(var.domain)
            for msg in getattr(var.mailer, 'inbox', None) or ():
                belief += msg.data
            beliefs[var.name] = belief
        return beliefs

    def step_algorithm(self):
        """Run one BP phase (variables on even steps, factors on odd) and redraw."""
        print(f"\n=== STEP {self.step + 1} ===")

        if self.step >= self.max_steps:
            print("Maximum steps reached!")
            return

        if self.step % 2 == 0:
            print("Variables computing and sending messages...")
            agents = self.fg.variables
        else:
            print("Factors computing and sending messages...")
            agents = self.fg.factors

        # Same sequence for both phases: compute, send, clear inbox, prepare.
        for agent in agents:
            if hasattr(agent, 'compute_messages'):
                agent.compute_messages()
            if hasattr(agent.mailer, 'send'):
                agent.mailer.send()
            agent.empty_mailbox()
            if hasattr(agent.mailer, 'prepare'):
                agent.mailer.prepare()

        self.step += 1
        self.update_display()

    def reset_algorithm(self):
        """Clear every mailbox, re-seed zero messages, reset the step counter and redraw."""
        print("\n=== RESET ===")
        self.step = 0

        # Clear all messages
        for node in self.fg.G.nodes():
            node.empty_mailbox()
            node.empty_outgoing()
            if hasattr(node.mailer, '_incoming'):
                node.mailer._incoming.clear()
            if hasattr(node.mailer, '_outgoing'):
                node.mailer._outgoing.clear()

        self._init_zero_messages()

        print("Reset complete!")
        self.update_display()

    def update_display(self):
        """Redraw graph, beliefs, cost tables and a status line into ``self.output``."""
        with self.output:
            clear_output(wait=True)

            fig = plt.figure(figsize=(18, 12))
            gs = fig.add_gridspec(3, 3, height_ratios=[2, 1, 1], width_ratios=[1, 1, 1])
            ax_graph = fig.add_subplot(gs[0, :])    # graph: full top row
            ax_beliefs = fig.add_subplot(gs[1, :])  # beliefs: full middle row
            ax_tables = [fig.add_subplot(gs[2, i]) for i in range(3)]  # cost tables

            self._draw_graph(ax_graph)
            self._draw_beliefs(ax_beliefs)
            self._draw_cost_tables(ax_tables)

            phase = 'Variables → Factors' if self.step % 2 == 0 else 'Factors → Variables'
            info_text = (f"Step: {self.step} | "
                         f"Phase: {phase} | "
                         f"Global Cost: {self.fg.global_cost:.2f}")
            fig.text(0.5, 0.02, info_text, ha='center', fontsize=12,
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

            plt.tight_layout()
            plt.show()
            # Release the figure so repeated redraws do not accumulate open figures.
            plt.close(fig)

    def _draw_graph(self, ax):
        """Draw nodes, edges, labels and one red arrow per queued outgoing message."""
        graph = self.fg.G
        var_nodes = [n for n in graph.nodes() if isinstance(n, VariableAgent)]
        factor_nodes = [n for n in graph.nodes() if isinstance(n, FactorAgent)]

        nx.draw_networkx_edges(graph, self.pos, ax=ax, alpha=0.3, width=2)
        nx.draw_networkx_nodes(graph, self.pos, nodelist=var_nodes,
                               node_color='lightblue', node_shape='o',
                               node_size=1200, ax=ax)
        nx.draw_networkx_nodes(graph, self.pos, nodelist=factor_nodes,
                               node_color='lightgreen', node_shape='s',
                               node_size=800, ax=ax)
        nx.draw_networkx_labels(graph, self.pos, ax=ax, font_size=12)

        for node in graph.nodes():
            for msg in getattr(node.mailer, '_outgoing', None) or ():
                start = self.pos[msg.sender]
                end = self.pos[msg.recipient]
                dx = end[0] - start[0]
                dy = end[1] - start[1]
                norm = np.sqrt(dx**2 + dy**2)
                if norm == 0:
                    # Sender and recipient share a position: no direction to draw.
                    continue
                # Shorten the arrow at both ends so it does not hide the node markers.
                start_offset = (start[0] + 0.15*dx/norm, start[1] + 0.15*dy/norm)
                end_offset = (end[0] - 0.15*dx/norm, end[1] - 0.15*dy/norm)
                ax.annotate('', xy=end_offset, xytext=start_offset,
                            arrowprops=dict(arrowstyle='->', lw=2.5,
                                            color='red', alpha=0.8))

        ax.set_title(f"Cycle Factor Graph - Step {self.step}", fontsize=16)
        ax.axis('equal')
        ax.axis('off')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)

    def _draw_beliefs(self, ax):
        """Draw grouped bars of every variable's belief; outline each minimum in black."""
        beliefs = self.get_beliefs()
        n_vars = len(self.fg.variables)
        if not beliefs or n_vars == 0:
            return

        bar_width = 0.8 / n_vars
        colors = plt.cm.tab10(np.linspace(0, 1, n_vars))
        max_domain = 0

        for i, (var_name, belief) in enumerate(sorted(beliefs.items())):
            # Each variable's bars are shifted so the group is centred on the tick.
            x_pos = np.arange(len(belief)) + (i - n_vars/2 + 0.5) * bar_width
            bars = ax.bar(x_pos, belief, bar_width,
                          label=var_name, color=colors[i], alpha=0.7)

            # Highlight the lowest-cost value (skipped while the belief is all zeros).
            if np.any(belief != 0):
                min_idx = np.argmin(belief)
                bars[min_idx].set_edgecolor('black')
                bars[min_idx].set_linewidth(3)

            max_domain = max(max_domain, len(belief))

        ax.set_xlabel("Domain Value")
        ax.set_ylabel("Belief (Cost)")
        ax.set_title("Variable Beliefs")
        ax.legend(ncol=n_vars)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(np.arange(max_domain))

    def _draw_cost_tables(self, axes):
        """Render the first ``len(axes)`` factors' 2-D cost tables as annotated heatmaps."""
        shown = self.fg.factors[:len(axes)]

        for ax, factor in zip(axes, shown):
            ct = factor.cost_table
            if ct is None or ct.ndim != 2:
                continue

            ax.imshow(ct, cmap='YlOrRd', aspect='auto')

            # White text on cells above the table mean, black otherwise, for contrast.
            mean_cost = np.mean(ct)
            for (i, j), value in np.ndenumerate(ct):
                color = 'white' if value > mean_cost else 'black'
                ax.text(j, i, f'{value:.1f}', ha="center", va="center",
                        color=color, fontsize=9)

            ax.set_title(f"{factor.name}", fontsize=10)

            # NOTE(review): names are sorted alphabetically, which may not match the
            # table's axis order (presumably the factor's edge order) — confirm.
            var_names = sorted(v.name for v in self.fg.G.neighbors(factor))
            if len(var_names) >= 2:
                ax.set_xlabel(var_names[1], fontsize=9)
                ax.set_ylabel(var_names[0], fontsize=9)

            ax.set_xticks(range(ct.shape[1]))
            ax.set_yticks(range(ct.shape[0]))

        # Hide slots left over when there are fewer factors than axes.
        for ax in axes[len(shown):]:
            ax.axis('off')

    def run(self):
        """Display the output area and control buttons, then draw the initial state."""
        # Display the output widget first
        display(self.output)

        btn_step = widgets.Button(description="Next Step", button_style='primary',
                                  layout=widgets.Layout(width='120px'))
        btn_reset = widgets.Button(description="Reset", button_style='danger',
                                   layout=widgets.Layout(width='120px'))
        btn_auto = widgets.Button(description="Auto Run", button_style='success',
                                  layout=widgets.Layout(width='120px'))

        self.auto_running = False

        def on_step_click(b):
            self.step_algorithm()

        def on_reset_click(b):
            self.auto_running = False
            btn_auto.description = "Auto Run"
            self.reset_algorithm()

        def on_auto_click(b):
            import time
            self.auto_running = not self.auto_running
            btn_auto.description = "Stop" if self.auto_running else "Auto Run"

            # NOTE(review): this loop runs synchronously inside the click callback,
            # so in a standard kernel further clicks (including "Stop") are likely
            # only processed after it returns.
            while self.auto_running and self.step < self.max_steps:
                self.step_algorithm()
                time.sleep(0.5)  # Delay between steps

            # Always end in the idle state. Previously, clicking after max_steps was
            # reached left auto_running=True and the button stuck on "Stop".
            self.auto_running = False
            btn_auto.description = "Auto Run"

        btn_step.on_click(on_step_click)
        btn_reset.on_click(on_reset_click)
        btn_auto.on_click(on_auto_click)

        controls = widgets.HBox([btn_step, btn_auto, btn_reset],
                                layout=widgets.Layout(justify_content='center'))
        display(controls)

        print("\n=== INITIAL STATE ===")
        self.update_display()

# ---------------------------------------------------------------------------
# Entry point: print a short menu, then build and show the default visualizer.
# ---------------------------------------------------------------------------
for _line in (
    "=== BP Visualizer for Cycle Graphs ===",
    "\nChoose visualization:",
    "1. Simple 2-variable graph (for testing)",
    "2. 4-variable cycle graph",
    "3. 6-variable cycle graph",
):
    print(_line)

# To pick a different visualization in a notebook, comment/uncomment below.
#
# Option 1: simple 2-variable graph.
# NOTE(review): SimpleBPVisualizer is not defined in this cell; it must be
# provided elsewhere in the notebook before this option can be used.
# viz = SimpleBPVisualizer()
# viz.run()

# Option 2 (default): 4-variable cycle.
print("\n✓ Creating 4-variable cycle visualizer...")
viz = CycleBPVisualizer(n_vars=4, domain=3)
viz.run()

# Usage hints shown under the widget.
for _line in (
    "\n✓ Cycle visualization ready!",
    "🔄 The cycle graph shows belief propagation in a loop",
    "📊 Cost tables: Diagonal preference = same values, Off-diagonal = different values",
    "➡️  Red arrows show active messages being passed",
    "📈 Watch how beliefs oscillate and then converge!",
    "🎯 Global cost shows the total cost of current assignments",
    "\n💡 BP in cycles:",
    "   - Messages loop around the cycle",
    "   - May oscillate without damping",
    "   - Convergence takes more iterations than trees",
    "   - Try 'Auto Run' to see the full convergence!",
):
    print(_line)

# Option 3: larger cycle.
# viz_large = CycleBPVisualizer(n_vars=6, domain=3)
# viz_large.run()