In [1]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider, VBox, HBox, Layout
from IPython.display import display

In [2]:
def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid, 1 / (1 + exp(-x)).

    Works element-wise, so it accepts NumPy arrays (the notebook evaluates it
    on the whole ``x`` grid) as well as scalars.

    The naive form ``1 / (1 + np.exp(-x))`` overflows inside ``np.exp`` for
    very negative inputs, which emits a RuntimeWarning even though the final
    value (0.0) is right. Using the identity
    sigmoid(x) = exp(-log(1 + exp(-x))) = exp(-logaddexp(0, -x))
    avoids the overflow: ``np.logaddexp`` is evaluated stably, and ``np.exp``
    of a large negative number just underflows to 0.0.

    Args:
        x: Scalar or array of pre-activations.

    Returns:
        Values in [0, 1], same shape as ``x``.
    """
    return np.exp(-np.logaddexp(0.0, -x))

In [3]:
def swish(x: float, beta: float = 1.0) -> float:
    """Swish activation: x * sigmoid(beta * x).

    The input is *multiplied* by its own sigmoid gate. Adding the two terms
    instead gives a shifted sigmoid, not Swish. With ``beta = 1`` this is
    the SiLU activation.

    Args:
        x: Scalar or NumPy array (evaluated element-wise).
        beta: Slope of the sigmoid gate; defaults to 1.0 (SiLU).

    Returns:
        x * sigmoid(beta * x), same shape as ``x``.
    """
    return x * sigmoid(beta * x)

In [4]:
def swiglu(x: float, w: float, b: float, v: float, c: float, beta: float) -> float:
    """
    Scalar-parameter SwiGLU: (w*x + b) * Swish(v*x + c),
    where Swish(z) = z * sigmoid(β z).

    Parameters
    ----------
    x : input value(s); NumPy arrays are evaluated element-wise.
    w, b : weight and bias of the linear branch (w*x + b).
    v, c : weight and bias of the branch passed through Swish (v*x + c).
    beta : Swish slope β.

    Returns
    -------
    The linear branch multiplied by the Swish-activated branch.
    """
    gate = swish(v * x + c, beta)
    return (w * x + b) * gate

In [11]:
# Evaluation grid shared by every plotting cell below.
x_range = [-5, 5]
x = np.linspace(*x_range, 1000)  # 1000 samples -> visually smooth curve

# Starting value of each SwiGLU parameter (keys match swiglu()'s keyword names).
initial_params = dict(
    w=0.5,      # weight of the linear branch (w*x + b)
    b=0.0,      # bias of the linear branch
    v=-0.2,     # weight of the gate projection fed into Swish (v*x + c)
    c=0.0,      # bias of the gate projection
    beta=0.2,   # Swish slope β
)

# (min, max) slider bounds for each parameter.
param_ranges = dict(
    w=(-1, 2),
    b=(-1, 1),
    v=(-1, 2),
    c=(-1, 1),
    beta=(0, 2),
)

In [6]:
def create_sliders(step: float = 0.05):
    """
    Build one FloatSlider per SwiGLU parameter.

    Bounds are read from ``param_ranges`` and starting values from
    ``initial_params`` (both defined in the configuration cell).

    Args:
        step: Increment shared by every slider (default 0.05).

    Returns:
        dict mapping parameter name to its FloatSlider, in the order
        'w', 'b', 'v', 'c', 'beta'.
    """
    # Labels such as 'V (sig weight):' are wider than a 60px description
    # column and get clipped; 110px fits the longest label while keeping all
    # sliders aligned. 450px per slider puts two sliders per row at the
    # plot's 900px width.
    slider_style = {'description_width': '110px'}
    slider_layout = Layout(width='450px')

    # Parameter name -> slider label; insertion order fixes the dict order.
    descriptions = {
        'w': 'W (weight):',
        'b': 'B (bias):',
        'v': 'V (sig weight):',
        'c': 'C (sig bias):',
        'beta': 'β (scaling):',
    }

    sliders = {}
    for name, description in descriptions.items():
        low, high = param_ranges[name]
        sliders[name] = FloatSlider(
            min=low,
            max=high,
            step=step,
            value=initial_params[name],
            description=description,
            style=slider_style,
            layout=slider_layout,
            readout_format='.2f',
        )

    return sliders


In [7]:
def create_update_function(fig, sliders):
    """
    Return a callback that redraws ``fig`` from the current slider values.

    The callback takes and ignores arbitrary positional arguments, so it can
    be registered directly with ``slider.observe(...)`` or called bare.

    Args:
        fig: FigureWidget whose first trace holds the SwiGLU curve.
        sliders: dict of sliders keyed 'w', 'b', 'v', 'c', 'beta'.

    Returns:
        The ``update_plot`` callback.
    """
    param_names = ('w', 'b', 'v', 'c', 'beta')

    def update_plot(*args):
        """Recompute the curve over the global grid ``x`` and refresh the title."""
        params = {name: sliders[name].value for name in param_names}
        curve = swiglu(x, **params)

        # Group the data and title changes into a single batched update.
        with fig.batch_update():
            fig.data[0].y = curve
            fig.layout.title.text = (
                f"SwiGLU Function: "
                f"w={params['w']:.2f}, "
                f"b={params['b']:.2f}, "
                f"v={params['v']:.2f}, "
                f"c={params['c']:.2f}, "
                f"β={params['beta']:.2f}"
            )

    return update_plot

In [8]:
def create_full_interactive_plot():
    """
    Build the FigureWidget that displays the SwiGLU curve.

    The figure holds a single line trace evaluated on the global grid ``x``
    with ``initial_params``. A FigureWidget (rather than a static Figure) is
    used so the slider callback can later overwrite ``fig.data[0].y`` in
    place and have the change show up live.

    Returns:
        go.FigureWidget with one trace named 'SwiGLU'.
    """
    fig = go.FigureWidget()

    fig.add_scatter(
        x=x,
        y=swiglu(x, **initial_params),
        mode='lines',
        name='SwiGLU',
        line=dict(color='blue', width=3),
    )

    fig.update_layout(
        title="Interactive SwiGLU Function",  # plain string: nothing to interpolate
        xaxis_title="Input (x)",
        yaxis_title="Output f(x)",
        width=900,
        height=600,
        showlegend=True,
        template="plotly_white",  # clean white background
        hovermode='x unified',    # one hover label per x position
    )

    # Light grid lines make it easier to read values off the curve.
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

    return fig

In [9]:
def create_complete_interactive_widget():
    """
    Assemble the SwiGLU explorer: the plot on top, parameter sliders below.

    One shared callback observes every slider's 'value'. The callback is
    also run once up front, so the title shows the starting parameter values
    before any slider is touched.

    Returns:
        VBox containing the FigureWidget and the three-row slider panel.
    """
    fig = create_full_interactive_plot()
    sliders = create_sliders()
    update_plot = create_update_function(fig, sliders)

    # Redraw whenever any slider's 'value' trait changes.
    for slider in sliders.values():
        slider.observe(update_plot, names='value')

    # Sync the title (and curve) with the sliders' starting values now;
    # otherwise the parameter readout only appears after the first drag.
    update_plot()

    # Three rows: (W, B), (V, C), then beta alone (left-aligned).
    slider_panel = VBox([
        HBox([sliders['w'], sliders['b']]),
        HBox([sliders['v'], sliders['c']]),
        HBox([sliders['beta']]),
    ])

    return VBox([fig, slider_panel])

In [12]:
# Build the explorer (plot + sliders) once and render it as this cell's output.
interactive_swiglu = create_complete_interactive_widget()
display(interactive_swiglu)

VBox(children=(FigureWidget({
    'data': [{'line': {'color': 'blue', 'width': 3},
              'mode': 'line…