In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from ipywidgets import interact, FloatSlider, VBox, HBox, Output
import ipywidgets as widgets
from IPython.display import display, clear_output
%matplotlib inline

# Global matplotlib styling for every figure drawn below:
# - 'axes.unicode_minus': False -> negative tick labels use an ASCII hyphen
#   instead of the Unicode minus sign.
# - 'font.family': DejaVu Sans covers the subscript characters (₁, ₂) that
#   appear in the plot and diagram labels.
plt.rcParams.update({
    'axes.unicode_minus': False,
    'font.family': 'DejaVu Sans',
})

class XORVisualization:
    """Interactive demo of a 2-2-1 step-activation MLP on the XOR problem.

    Each neuron's parameters are stored as a ``[w1, w2, bias]`` list:
    ``w_h1`` / ``w_h2`` for the two hidden neurons (inputs ``x1, x2``) and
    ``w_out`` for the output neuron (inputs ``h1, h2``).
    """

    # Known-good XOR weights (used by the "Show XOR Solution" button):
    # H1 = AND(x1, x2), H2 = OR(x1, x2), output fires for "H2 and not H1",
    # i.e. OR minus AND = XOR.
    XOR_SOLUTION = ([1, 1, -1.5], [1, 1, -0.5], [-2, 1, -0.5])

    def __init__(self):
        # Starting weights. NOTE: with this output layer (v1=1, v2=-2) the
        # network outputs 0 for all four inputs, so (0,1) and (1,0) start out
        # misclassified; XOR_SOLUTION holds a set that classifies all four.
        self.w_h1 = [1, 1, -1.5]  # First hidden layer neuron weights [w1, w2, b]
        self.w_h2 = [1, 1, -0.5]  # Second hidden layer neuron weights [w1, w2, b]
        self.w_out = [1, -2, -0.5]  # Output layer neuron weights [w1, w2, b]

        # Input data (XOR problem): the four binary points and their labels.
        self.X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
        self.y = np.array([0, 1, 1, 0])

        # Widget area that update_plot() redraws into.
        self.output = Output()

    def activate(self, x, weights):
        """Step neuron: return 1 if ``w1*x[0] + w2*x[1] + b > 0``, else 0.

        Args:
            x: two-element sequence of inputs.
            weights: ``[w1, w2, b]``.
        """
        weighted_sum = weights[0] * x[0] + weights[1] * x[1] + weights[2]
        return 1 if weighted_sum > 0 else 0

    def predict(self, x):
        """Return the network's 0/1 output for a single 2-element input ``x``."""
        h1 = self.activate(x, self.w_h1)
        h2 = self.activate(x, self.w_h2)
        return self.activate([h1, h2], self.w_out)

    def compute_grid(self, resolution=100):
        """Evaluate the network on a square grid covering [-0.5, 1.5] x [-0.5, 1.5].

        Args:
            resolution: number of samples per axis (default 100).

        Returns:
            ``(X1, X2, Z)``: the ``resolution x resolution`` coordinate arrays
            from ``np.meshgrid`` and the float 0/1 network output at each point.
        """
        axis = np.linspace(-0.5, 1.5, resolution)
        X1, X2 = np.meshgrid(axis, axis)

        def step_layer(a, b, weights):
            # Vectorized activate(): same arithmetic order, so each cell equals
            # what predict() returns for that point.
            return (weights[0] * a + weights[1] * b + weights[2] > 0).astype(float)

        H1 = step_layer(X1, X2, self.w_h1)
        H2 = step_layer(X1, X2, self.w_h2)
        Z = step_layer(H1, H2, self.w_out)
        return X1, X2, Z

    def create_mlp_diagram(self, ax):
        """Draw the 2-2-1 network on ``ax`` annotated with the current weights.

        Layout in axes coordinates: inputs at x=0.1, hidden layer at x=0.5,
        output at x=0.9. In each two-node layer, node 1 (X₁, H₁) is the lower
        node at y=0.3 and node 2 (X₂, H₂) the upper node at y=0.7.
        """
        layer_x = [0.1, 0.5, 0.9]
        input_nodes = [0.3, 0.7]
        hidden_nodes = [0.3, 0.7]
        output_y = 0.5

        def fmt(value):
            # Compact numbers: 1 -> '1', 1.2000000000000002 -> '1.2'.
            return f'{value:g}'

        # Nodes: input (blue), hidden (green), output (red).
        for x, ys, color in ((layer_x[0], input_nodes, 'blue'),
                             (layer_x[1], hidden_nodes, 'green'),
                             (layer_x[2], [output_y], 'red')):
            for y in ys:
                ax.add_artist(plt.Circle((x, y), 0.05, fill=False, color=color))

        # Connections: fully connected input->hidden and hidden->output.
        for i_y in input_nodes:
            for h_y in hidden_nodes:
                ax.plot([layer_x[0], layer_x[1]], [i_y, h_y], 'gray', alpha=0.5)
        for h_y in hidden_nodes:
            ax.plot([layer_x[1], layer_x[2]], [h_y, output_y], 'gray', alpha=0.5)

        # Node labels, placed just above each node.
        for x, y, label in ((layer_x[0], input_nodes[0], 'X₁'),
                            (layer_x[0], input_nodes[1], 'X₂'),
                            (layer_x[1], hidden_nodes[0], 'H₁'),
                            (layer_x[1], hidden_nodes[1], 'H₂'),
                            (layer_x[2], output_y, 'Y')):
            ax.text(x - 0.05, y + 0.06, label, fontsize=10)

        # Weight labels (w_ij: input i -> hidden j; v_j: hidden j -> output).
        ax.text(0.3, 0.52, f'w₁₁: {fmt(self.w_h1[0])}', fontsize=8)
        ax.text(0.3, 0.48, f'w₂₁: {fmt(self.w_h1[1])}', fontsize=8)
        ax.text(0.3, 0.42, f'w₁₂: {fmt(self.w_h2[0])}', fontsize=8)
        ax.text(0.3, 0.38, f'w₂₂: {fmt(self.w_h2[1])}', fontsize=8)

        ax.text(0.7, 0.52, f'v₁: {fmt(self.w_out[0])}', fontsize=8)
        ax.text(0.7, 0.48, f'v₂: {fmt(self.w_out[1])}', fontsize=8)

        # Bias labels, each next to its own node: b₁ below H₁ (lower node),
        # b₂ above H₂ (upper node), b₃ below the output node.
        ax.text(0.5, 0.2, f'b₁: {fmt(self.w_h1[2])}', fontsize=8)
        ax.text(0.5, 0.84, f'b₂: {fmt(self.w_h2[2])}', fontsize=8)
        ax.text(0.9, 0.4, f'b₃: {fmt(self.w_out[2])}', fontsize=8)

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_title('MLP Network Structure')

    def update_plot(self):
        """Redraw both panels into ``self.output``.

        Left: the network's decision regions plus the four XOR points, titled
        with whether all four are classified correctly. Right: the network
        diagram annotated with the current weights.
        """
        with self.output:
            clear_output(wait=True)

            fig = plt.figure(figsize=(15, 7))
            gs = fig.add_gridspec(1, 2)

            # Left panel: decision regions.
            ax1 = fig.add_subplot(gs[0, 0])
            X1, X2, Z = self.compute_grid()
            # Fixed levels give one band per class even when Z is constant
            # (common with random weights). The reversed colormap paints class 0
            # reddish and class 1 bluish, matching the point colours below.
            ax1.contourf(X1, X2, Z, levels=[-0.5, 0.5, 1.5], cmap='coolwarm_r', alpha=0.8)

            for (px, py), target in zip(self.X, self.y):
                color = 'blue' if target == 1 else 'red'
                ax1.scatter(px, py, color=color, s=100, edgecolor='black')
                ax1.text(px + 0.05, py + 0.05, f'({px},{py})→{target}', fontsize=10)

            ax1.set_xlim(-0.1, 1.1)
            ax1.set_ylim(-0.1, 1.1)
            ax1.set_xlabel('X₁')
            ax1.set_ylabel('X₂')
            ax1.grid(True, alpha=0.3)

            # Legend entries via empty scatters, one per class colour.
            ax1.scatter([], [], color='blue', s=50, edgecolor='black', label='Class 1')
            ax1.scatter([], [], color='red', s=50, edgecolor='black', label='Class 0')
            ax1.legend()

            all_correct = all(self.predict(x) == target for x, target in zip(self.X, self.y))
            status = ("Success! All XOR points classified correctly." if all_correct
                      else "Some points are still misclassified.")
            ax1.set_title(f'XOR Problem Decision Boundary - {status}')

            # Right panel: network diagram.
            ax2 = fig.add_subplot(gs[0, 1])
            self.create_mlp_diagram(ax2)

            plt.tight_layout()
            plt.show()

    def setup_interactive_widgets(self):
        """Create, wire up and display the weight controls, then draw once.

        Each neuron gets three sliders (w₁, w₂, bias) in [-3, 3] with step 0.1.
        Moving a slider copies all slider values into the weight lists and
        redraws. The buttons load ``XOR_SOLUTION`` or random weights.
        """
        def make_sliders(prefix, weights):
            # Three sliders for one neuron's [w1, w2, bias] list.
            return [
                FloatSlider(value=value, min=-3, max=3, step=0.1,
                            description=f'{prefix} {label}')
                for value, label in zip(weights, ('w₁:', 'w₂:', 'bias:'))
            ]

        h1_sliders = make_sliders('H₁', self.w_h1)
        h2_sliders = make_sliders('H₂', self.w_h2)
        out_sliders = make_sliders('Output', self.w_out)
        all_sliders = h1_sliders + h2_sliders + out_sliders

        # True while a button pushes values into the sliders. It suppresses the
        # per-slider redraw; otherwise one click would redraw up to 9 times.
        syncing = False

        def read_sliders():
            # Copy the sliders' current values into the model weights.
            self.w_h1 = [s.value for s in h1_sliders]
            self.w_h2 = [s.value for s in h2_sliders]
            self.w_out = [s.value for s in out_sliders]

        def on_change(change):
            if syncing:
                return
            read_sliders()
            self.update_plot()

        for slider in all_sliders:
            slider.observe(on_change, names='value')

        def apply_weights(w_h1, w_h2, w_out):
            # Move every slider, then sync the model and redraw exactly once.
            nonlocal syncing
            syncing = True
            try:
                for slider, value in zip(all_sliders, [*w_h1, *w_h2, *w_out]):
                    slider.value = value
            finally:
                syncing = False
            # Read back from the sliders so the plot uses exactly what they show.
            read_sliders()
            self.update_plot()

        controls = HBox([VBox(h1_sliders), VBox(h2_sliders), VBox(out_sliders)])

        solution_button = widgets.Button(
            description='Show XOR Solution',
            button_style='success'
        )
        solution_button.on_click(lambda b: apply_weights(*self.XOR_SOLUTION))

        random_button = widgets.Button(
            description='Random Weights',
            button_style='info'
        )

        def on_random_click(b):
            # Round to the slider step (0.1) so the plotted weights match the
            # slider readouts.
            w_h1, w_h2, w_out = np.round(np.random.uniform(-2, 2, (3, 3)), 1).tolist()
            apply_weights(w_h1, w_h2, w_out)

        random_button.on_click(on_random_click)

        buttons = HBox([solution_button, random_button])

        display(widgets.HTML("<h2>Solving XOR Problem with Multilayer Perceptron (MLP)</h2>"))
        display(widgets.HTML("<p>Adjust the sliders to change each neuron's weights and biases to solve the XOR problem.</p>"))
        display(controls)
        display(buttons)
        display(self.output)

        # Show initial plot
        self.update_plot()

# Entry point: build the demo, then display its controls and initial plot.
viz: XORVisualization = XORVisualization()
viz.setup_interactive_widgets()

HTML(value='<h2>Solving XOR Problem with Multilayer Perceptron (MLP)</h2>')

HTML(value="<p>Adjust the sliders to change each neuron's weights and biases to solve the XOR problem.</p>")

HBox(children=(VBox(children=(FloatSlider(value=1.0, description='H₁ w₁:', max=3.0, min=-3.0), FloatSlider(val…

HBox(children=(Button(button_style='success', description='Show XOR Solution', style=ButtonStyle()), Button(bu…

Output()