In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, widgets, VBox, Label

def plot_perceptron(w1, w2, b):
    """Plot the decision boundary of a single neuron on the square [-10, 10]^2.

    The neuron's logit is ``w1 * x1 + w2 * x2 + b``. The figure shows:
      * the boundary line ``logit == 0`` (blue), when one exists,
      * the half-plane where ``logit >= 0`` shaded in red,
      * four fixed sample points, marked red '+' if their logit is >= 0,
        otherwise green 'o', each annotated with its coordinates and logit.

    Parameters
    ----------
    w1, w2 : float
        Weights applied to the inputs ``x1`` and ``x2``.
    b : float
        Bias term.

    Returns
    -------
    None
        The figure is rendered via ``plt.show()``. Keep returning ``None``:
        ``ipywidgets.interactive`` displays any non-None return value.
    """
    axis_min, axis_max = -10, 10

    # Evaluate the logit on a dense grid to shade the non-negative half-plane.
    x1 = np.linspace(axis_min, axis_max, 400)
    x2 = np.linspace(axis_min, axis_max, 400)
    X1, X2 = np.meshgrid(x1, x2)
    grid_logit = w1 * X1 + w2 * X2 + b

    # Mask of grid points whose logit is non-negative (the boundary itself
    # counts as positive, matching the point classification below).
    positive_region = (grid_logit >= 0).astype(float)

    fig, ax = plt.subplots(figsize=(5, 5))

    # Draw the boundary w1*x1 + w2*x2 + b = 0 without modifying w1/w2/b, so the
    # point logits computed later use the true weights.
    boundary_drawn = False
    if w2 != 0:
        x1_plot = np.linspace(axis_min, axis_max, 100)
        x2_plot = -(w1 * x1_plot + b) / w2
        ax.plot(x1_plot, x2_plot, '-b', label='Gerade')
        boundary_drawn = True
    elif w1 != 0:
        # w2 == 0: the boundary is the vertical line x1 = -b / w1.
        x1_vertical = -b / w1
        ax.plot([x1_vertical, x1_vertical], [axis_min, axis_max], '-b', label='Gerade')
        boundary_drawn = True
    # else: w1 == w2 == 0, the logit is the constant b and there is no boundary line.

    # Shade the region where the logit is >= 0.
    ax.contourf(X1, X2, positive_region, levels=[0.5, 1], colors=['red'], alpha=0.1)

    # Four fixed sample points, classified and annotated in a single pass.
    points = np.array([[-5, -5], [-5, 5], [5, -5], [5, 5]])
    for px, py in points:
        point_logit = w1 * px + w2 * py + b
        if point_logit >= 0:
            color, marker, edge_width = 'red', '+', 2
        else:
            color, marker, edge_width = 'green', 'o', 1
        ax.plot(px, py, marker=marker, color=color, markersize=10, markeredgewidth=edge_width)
        ax.text(px, py + 0.5, f'({px}, {py})', fontsize=9, verticalalignment='bottom', horizontalalignment='center')
        ax.text(px, py - 0.5, f'Logit: {point_logit:.2f}', fontsize=9, verticalalignment='top', horizontalalignment='center')

    ax.set_title('Ein einziges Neuron')
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_xlim([axis_min, axis_max])
    ax.set_ylim([axis_min, axis_max])
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.grid(True)
    # Only add a legend when a labeled boundary line exists; otherwise
    # matplotlib warns that no labeled artists were found.
    if boundary_drawn:
        ax.legend(loc='upper right')
    plt.show()


# One FloatSlider per parameter of plot_perceptron. interactive() labels each
# slider with its keyword name and re-invokes plot_perceptron on every change.
perceptron_sliders = {
    'w1': widgets.FloatSlider(value=1.0, min=-2.0, max=2.0, step=0.1),
    'w2': widgets.FloatSlider(value=1.2, min=-2.0, max=2.0, step=0.1),
    'b': widgets.FloatSlider(value=0.0, min=-10.0, max=10.0, step=0.1),
}
interactive_plot = interactive(plot_perceptron, **perceptron_sliders)

# Last expression of the cell: displays the sliders together with the plot.
interactive_plot

interactive(children=(FloatSlider(value=1.0, description='w1', max=2.0, min=-2.0), FloatSlider(value=1.2, desc…