# Interactive Number Line and Binary Search Tree (BST) Visualization

- This Jupyter Notebook lets you use a slider to select a number `n` (up to `leafLimit`, default 127). Once you select a number, it generates an equally spaced number line from 1 to the selected number, padded with gray placeholder numbers up to the next value of the form 2^d − 1 so that the tree is perfect. It then constructs a Binary Search Tree (BST) with the following rules:
- **Odd numbers** always lie on the **bottom level** of the tree (they are the leaves); every node is drawn directly above its own position on the number line.
- **Even numbers** are placed higher on the y-axis, according to their position in a **perfect binary search tree**: the more times a number is divisible by 2, the closer it sits to the root.

The visualization will show the points on the number line and the corresponding BST layout.

In [1]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Upper bound for the "Number of Points" slider. 127 == 2**7 - 1, so the
# largest selectable value fills a perfect 7-level tree with no padding nodes.
leafLimit = 2**7 - 1

def _is_prime(value):
    """
    Return True if ``value`` is a prime number.

    Uses trial division by every integer from 2 up to floor(sqrt(value)).
    Values <= 1 are not prime.
    """
    if value <= 1:
        return False
    for divisor in range(2, int(np.sqrt(value)) + 1):
        if value % divisor == 0:
            return False
    return True


def _build_perfect_bst(total_nodes, debug_messages):
    """
    Build a balanced BST over the keys 1..total_nodes and lay it out for plotting.

    Each subtree root is the median of its sorted key range, so when
    total_nodes == 2**depth - 1 the result is a perfect tree. Every key is
    placed at (key, -node_depth): x lines up with the number line drawn below,
    the root sits at y = 0, and each level is one unit lower.

    Parameters:
    - total_nodes: Number of keys (1..total_nodes) to insert.
    - debug_messages: List that receives one placement message per node.

    Returns:
    - (G, node_positions): a networkx DiGraph with parent -> child edges, and a
      dict mapping each key to its (x, y) plot position.
    """
    G = nx.DiGraph()
    node_positions = {}

    def place(sorted_nodes, current_depth, parent):
        # Recursively place the median of ``sorted_nodes`` as the subtree root.
        if not sorted_nodes:
            return

        mid_index = len(sorted_nodes) // 2
        root = sorted_nodes[mid_index]
        G.add_node(root)
        node_positions[root] = (root, -current_depth)  # y is negative for downward direction
        if parent is not None:
            G.add_edge(parent, root)

        debug_messages.append(
            f"Depth {current_depth}: Placing node {root} at position ({root}, {-current_depth})"
        )

        place(sorted_nodes[:mid_index], current_depth + 1, root)
        place(sorted_nodes[mid_index + 1:], current_depth + 1, root)

    place(list(range(1, total_nodes + 1)), 0, None)
    return G, node_positions


def _draw_extra_info(ax, included_nodes, node_positions, depth):
    """
    Annotate each included node with its fraction, binary, and reversed-binary forms.

    The fraction is ``node / 2**depth`` written without reduction. Because the
    keys run from 1 to 2**depth - 1, this is the node's position as a dyadic
    fraction strictly between 0 and 1.
    """
    # Place each box just up-and-right of its node marker.
    x_offset = 0.2
    y_offset = 0.2
    for node in included_nodes:
        frac = f"{node}/{2**depth}"
        binary_str = bin(node)[2:]  # Remove '0b' prefix
        reverse_binary = binary_str[::-1]
        extra_info = f"Fraction: {frac}\nBinary: {binary_str}\nRev Bin: {reverse_binary}"
        x, y = node_positions[node]
        ax.text(
            x + x_offset, y + y_offset,
            extra_info,
            fontsize=10, color='black',
            bbox=dict(facecolor='white', alpha=0.6, boxstyle='round,pad=0.2'),
            verticalalignment='bottom', horizontalalignment='left',
            wrap=True,
        )


def _draw_number_line(ax, nodes, n, number_line_y):
    """
    Draw every key in ``nodes`` as a point on a horizontal line at ``number_line_y``.

    Labels for keys <= n (the included ones) are blue; padding keys are gray.
    """
    ax.scatter(nodes, [number_line_y] * len(nodes), color='green', s=100, label='Number Line')
    for node in nodes:
        ax.annotate(
            str(node),
            (node, number_line_y - 0.3),
            fontsize=10,
            ha='center',
            va='top',
            color='blue' if node <= n else 'gray',
        )


def generate_perfect_bst(n, show_extra_info, output_widget):
    """
    Generates and visualizes a perfect Binary Search Tree (BST).

    The keys 1..(2**depth - 1) are arranged as a perfect BST, where ``depth`` is
    the smallest value with 2**depth - 1 >= n. Keys 1..n are drawn as
    "Included Nodes" (green if prime, blue otherwise). Any keys above n are
    drawn in gray as "Required Nodes", the padding that keeps the tree perfect.
    A number line of all keys is drawn beneath the tree. After the plot,
    debug information is printed.

    Parameters:
    - n: Number of included nodes in the BST (must be >= 1).
    - show_extra_info: Boolean flag to show/hide extra node information.
    - output_widget: The Output widget where the plot and debug info will be displayed.
      It is not referenced in the body. Callers route output into it by calling
      this function inside ``with output_widget:``.

    Errors are not raised to the caller. They are collected and printed together
    with the debug messages.
    """
    # Clear previous output; wait=True defers the clear until new output arrives.
    clear_output(wait=True)

    debug_messages = []
    # Tracked outside the try so the finally-clause can always release the figure.
    fig = None

    try:
        if n < 1:
            raise ValueError("Number of points must be at least 1.")

        # Smallest depth whose perfect tree (2**depth - 1 nodes) can hold n nodes.
        depth = 1
        while (2**depth - 1) < n:
            depth += 1
        total_nodes = 2**depth - 1

        debug_messages.append("----- Debug Information -----")
        debug_messages.append(f"Requested number of points (n): {n}")
        debug_messages.append(f"Calculated depth: {depth}")
        debug_messages.append(f"Total nodes in BST: {total_nodes}")

        nodes = list(range(1, total_nodes + 1))
        G, node_positions = _build_perfect_bst(total_nodes, debug_messages)

        debug_messages.append(f"Total nodes added to the graph: {G.number_of_nodes()}")
        debug_messages.append(f"Total edges added to the graph: {G.number_of_edges()}")

        # Keys the user asked for vs. padding keys needed to make the tree perfect.
        included_nodes = list(range(1, n + 1))
        required_nodes = list(range(n + 1, total_nodes + 1))

        missing_included = [node for node in included_nodes if node not in G.nodes]
        if missing_included:
            debug_messages.append(f"Warning: Included nodes missing in the graph: {missing_included}")
        else:
            debug_messages.append("All included nodes are present in the graph.")

        missing_required = [node for node in required_nodes if node not in G.nodes]
        if missing_required:
            debug_messages.append(f"Warning: Required nodes missing in the graph: {missing_required}")
        else:
            debug_messages.append("All required nodes are present in the graph.")

        fig, ax = plt.subplots(figsize=(16, 12))

        try:
            nx.draw_networkx_edges(G, node_positions, edge_color='gray', width=2, ax=ax)
        except Exception as e:
            debug_messages.append(f"Error drawing edges: {e}")

        # Included nodes: green for primes, blue for non-primes.
        node_colors = ['green' if _is_prime(node) else 'blue' for node in included_nodes]
        nx.draw_networkx_nodes(
            G, node_positions, nodelist=included_nodes, node_color=node_colors,
            node_size=800, label='Included Nodes', ax=ax,
        )

        if required_nodes:
            try:
                nx.draw_networkx_nodes(
                    G, node_positions, nodelist=required_nodes, node_color='gray',
                    node_size=800, label='Required Nodes', ax=ax,
                )
            except Exception as e:
                debug_messages.append(f"Error drawing required nodes: {e}")

        # Only included nodes get a value label; padding nodes stay blank.
        main_labels = {node: str(node) for node in included_nodes}
        try:
            nx.draw_networkx_labels(
                G, node_positions, labels=main_labels, font_size=12, font_color='white',
                verticalalignment='center', horizontalalignment='center', ax=ax,
            )
        except Exception as e:
            debug_messages.append(f"Error drawing main labels: {e}")

        if show_extra_info:
            _draw_extra_info(ax, included_nodes, node_positions, depth)

        # Number line sits 3 units below the deepest level of the tree.
        number_line_y = min(y for _, y in node_positions.values()) - 3
        try:
            _draw_number_line(ax, nodes, n, number_line_y)
        except Exception as e:
            debug_messages.append(f"Error drawing number line: {e}")

        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_xticks(nodes)
        y_min = number_line_y - 2
        y_max = 0  # Root is at y=0
        ax.set_yticks(range(int(y_min), int(y_max) + 1))

        ax.set_xlim(0, total_nodes + 1)
        ax.set_ylim(y_min, y_max + 1)

        ax.set_title(f"Perfect Binary Search Tree (BST) for 1 to {n}", fontsize=20)
        ax.set_xlabel("Number Line Position", fontsize=16)
        ax.set_ylabel("Depth in BST", fontsize=16)
        ax.legend(loc='upper right', fontsize=12)

        fig.tight_layout()

        try:
            display(fig)
        except Exception as e:
            debug_messages.append(f"Error displaying plot: {e}")

    except Exception as e:
        debug_messages.append(f"Error: {e}")
    finally:
        # Close on every path. Otherwise a failed redraw leaves an open pyplot
        # figure behind, and these accumulate across slider moves.
        if fig is not None:
            plt.close(fig)

    for message in debug_messages:
        print(message)

# ---- Interactive controls -------------------------------------------------

# Chooses n, the number of "included" keys (1..leafLimit). With
# continuous_update=False the value is only committed when the handle is released.
slider = widgets.IntSlider(
    value=7, min=1, max=leafLimit, step=1,
    description='Number of Points:',
    continuous_update=False,
)

# Toggles the per-node Fraction / Binary / Rev Bin annotations.
checkbox = widgets.Checkbox(
    value=False, description='Show Extra Info', disabled=False, indent=False,
)

# One-step nudge buttons placed on either side of the slider.
button_decrement = widgets.Button(description='⏪', tooltip='Decrement by 1', icon='minus')
button_increment = widgets.Button(description='⏩', tooltip='Increment by 1', icon='plus')

# Receives the rendered figure and the printed debug messages.
output = widgets.Output()

def update_plot(n, show_extra_info):
    """
    Redraw the BST into the shared Output widget and refresh the step buttons.

    A step button is disabled once ``n`` reaches the matching slider bound:
    decrement at ``slider.min``, increment at ``slider.max``.
    """
    with output:
        generate_perfect_bst(n, show_extra_info, output_widget=output)
        button_decrement.disabled, button_increment.disabled = (
            n <= slider.min,
            n >= slider.max,
        )

def on_slider_change(change):
    """Slider observer: redraw with the new value and the current checkbox state."""
    new_count = change['new']
    update_plot(new_count, checkbox.value)

def on_checkbox_change(change):
    """Checkbox observer: redraw at the current slider value with extra info toggled."""
    current_count = slider.value
    update_plot(current_count, change['new'])

def on_decrement_clicked(b):
    """Button callback: step the slider down by one unless it is already at its minimum."""
    if slider.value <= slider.min:
        return
    slider.value -= 1

def on_increment_clicked(b):
    """Button callback: step the slider up by one unless it is already at its maximum."""
    if slider.value >= slider.max:
        return
    slider.value += 1

# Link widgets to their respective handlers.
slider.observe(on_slider_change, names='value')
checkbox.observe(on_checkbox_change, names='value')
button_decrement.on_click(on_decrement_clicked)
button_increment.on_click(on_increment_clicked)

# Initial render. Reuse update_plot so the first draw and every later redraw
# (plot plus step-button enabled/disabled state) go through one code path
# instead of a duplicated copy of that logic.
update_plot(slider.value, checkbox.value)

# Layout: [⏪ slider ⏩] on one row, the checkbox below it, and the output area last.
slider_buttons = widgets.HBox([button_decrement, slider, button_increment])
display(widgets.VBox([slider_buttons, checkbox, output]))

VBox(children=(HBox(children=(Button(description='⏪', icon='minus', style=ButtonStyle(), tooltip='Decrement by…