In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
import ipywidgets as widgets
from ipywidgets import interact

# Synthetic two-cluster dataset: 100 samples from make_blobs, then an 80/20
# train/test split. Both random_state values are fixed so reruns are identical.
X, y = datasets.make_blobs(
    n_samples=100,
    centers=2,
    random_state=6,
)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Function to create and plot SVM decision boundary
def plot_svm(C=1.0, gamma=0.1):
    """Fit an RBF-kernel SVC on the training split and plot its decision regions.

    Reads the module-level ``X``, ``X_train`` and ``y_train`` arrays, draws
    onto the current pyplot figure and calls ``plt.show()``.

    Parameters
    ----------
    C : float, default 1.0
        Forwarded to ``SVC`` as its ``C`` argument.
    gamma : float, default 0.1
        Forwarded to ``SVC`` as its ``gamma`` argument.

    Returns
    -------
    None
    """
    # Samples per axis of the prediction grid. A fixed count keeps the cost
    # of every redraw bounded; a fixed *step* (previously 0.01) makes the
    # grid grow with the square of the data's extent.
    grid_points = 300

    # Train the SVM model
    model = SVC(C=C, gamma=gamma, kernel='rbf')
    model.fit(X_train, y_train)

    # Plot extents: full data range (train and test) padded by one unit
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, grid_points),
                         np.linspace(y_min, y_max, grid_points))

    # Predict the label of every grid point, then restore the 2-D grid shape
    # that contourf expects
    grid = np.c_[xx.ravel(), yy.ravel()]
    Z = model.predict(grid).reshape(xx.shape)

    # Filled contours show the predicted class regions; training points are
    # drawn on top, coloured by their true label
    plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.coolwarm)
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, s=30,
                cmap=plt.cm.coolwarm, edgecolors='k')

    # Mark support vectors with larger hollow circles
    support = model.support_vectors_
    plt.scatter(support[:, 0], support[:, 1], s=100, facecolors='none',
                edgecolors='k', label='Support Vectors')

    # Plot formatting
    plt.title(f'SVM Decision Boundary (C={C}, gamma={gamma})')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.legend()
    plt.show()

# Interactive widgets to control C and gamma.
# Both sliders are logarithmic (base 10): `min`/`max` are exponents, so C
# spans 10**-2..10**2 and gamma spans 10**-4..10**1.
# continuous_update=False is set on both sliders so the model is not refit
# on every intermediate position while dragging.
C_slider = widgets.FloatLogSlider(
    value=1.0,
    base=10,
    min=-2,
    max=2,
    step=0.1,
    description='C:',
    continuous_update=False,
)
gamma_slider = widgets.FloatLogSlider(
    value=0.1,
    base=10,
    min=-4,
    max=1,
    step=0.1,
    description='Gamma:',
    continuous_update=False,
)

# Interactive function.
# The trailing semicolon stops the notebook from echoing interact()'s return
# value (the plot_svm function object) below the widgets.
interact(plot_svm, C=C_slider, gamma=gamma_slider);


interactive(children=(FloatLogSlider(value=1.0, continuous_update=False, description='C:', max=2.0, min=-2.0),…

<function __main__.plot_svm(C=1.0, gamma=0.1)>