In [712]:
# %matplotlib widget

In [713]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

In [714]:
X, y = make_blobs(500, centers=2, cluster_std=1.4)
y[y == 0] = -1

In [715]:
def decision_func(X: np.ndarray, y: np.ndarray, b: float, w: np.ndarray) -> np.ndarray:
    return np.dot(X, w) + b

In [716]:
def loss(X: np.ndarray, y: np.ndarray, w: np.ndarray, b: np.ndarray) -> float:
        return np.mean(np.max(0, 1 - y * decision_func(X, y, b, w))) + 0.01*np.sum(w**2)

In [717]:
def training_func(X: np.ndarray, y: np.ndarray, b: float, w: np.ndarray, epochs: int, lr: float = 0.1) -> tuple[np.ndarray]:
    for e in range(epochs):
        print(f'Epoch {e}/{epochs}, Loss: {loss(X, y, w, b)}{20*" "}', end='\r')

        batch_mask = np.random.choice(y.size, size=32, replace=False)
        X_batch, y_batch = X[batch_mask], y[batch_mask]

        for x_i, y_i in zip(X_batch, y_batch):
            dL_dw = -(y_i * x_i) * ((y_i * np.dot(x_i, w) + b) < 1) + 2*0.01*w
            dL_db = -y_i * ((y_i * np.dot(x_i, w) + b) < 1)

            w -= lr*dL_dw
            b -= lr*dL_db

    return w, b

In [718]:
def fit(X: np.ndarray, y: np.ndarray) -> np.ndarray:
    w, b = training_func(X, y, 0.0, np.random.random(X.shape[1]), 1000)
    print(w, b)
    return decision_func(X, y, b, w), w

In [719]:
def draw_decision_boundary(X: np.ndarray, y: np.ndarray):
    xx, yy = np.meshgrid(np.linspace(X[:, 0].min()-1, X[:, 0].max()+1, X.shape[0]),
                         np.linspace(X[:, 1].min()-1, X[:, 1].max()+1, X.shape[0]))
    X_grid = np.c_[xx.ravel(), yy.ravel()]
    pred, w = fit(X_grid, y)
    pred = pred.reshape(xx.shape)

    fig, ax = plt.subplots()
    fig.set_figwidth(12)
    fig.set_figheight(10)
    ax.contour(xx, yy, pred, levels=[-1, 0, 1], linestyles=['--', '-', '--'], colors=['r', 'k', 'r'])
    ax.scatter(X[:, 0], X[:, 1], c=y)
    ax.plot(np.linspace(0, w[0], 100), np.linspace(0, w[1], 100))
    plt.show()

In [720]:
draw_decision_boundary(X, y)

TypeError: decision_func() missing 3 required positional arguments: 'y', 'b', and 'w'