In [1]:
import numpy as np
from numpy                   import *
from scipy.optimize          import minimize, Bounds, LinearConstraint
from sklearn.model_selection import train_test_split
from plotly.express          import scatter, scatter_3d, imshow
from plotly.graph_objects    import Mesh3d, Layout, Figure, Contour, Scatter, Isosurface
from sklearn.datasets        import load_iris

In [2]:
def linear_svm_primal_non_separable_case(x, y, C):
    n, d = x.shape[0], x.shape[1]
    Q = eye(d+1)
    Q[d,d] = 0.0
    c = zeros((d+1, 1))
    A = -diag(y) @ hstack([x, ones((n, 1))])
    b = -ones(n)
    Aeq = zeros((0, d+1))
    beq = zeros(0)
    lb  = zeros(n)
    ub  = C*ones(n)
    results = quadprog(Q, c, A=A, b=b, Aeq=Aeq, beq=beq, lb=lb, ub=ub)
    w = results.x[:d]
    b = results.x[d]
    return w, b

In [3]:
np.random.seed(123)
x1 = np.random.randn(2, 20)
x2 = np.random.randn(2, 20) + 2
y1, y2 = np.ones(20), -np.ones(20)
X_train = np.hstack([x1, x2]).T
y_train = np.hstack([y1, y2]).T

In [4]:
def quadprog(Q, c, A=None, b=None, Aeq=None, beq=None, lb=None, ub=None, x0=None):
    # DESCRIPTION:
    # min  0.5 * x.T @ Q @ x + c.T @ x
    # s.t. A   @ x <= b
    #      Aeq @ x == beq
    #      lb <= x <= ub
    # init x with x0
    if A   is None: A   =  zeros(0, c.shape[0])
    if b   is None: b   =  zeros(0)
    if Aeq is None: Aeq =  zeros(0, c.shape[0])
    if beq is None: beq =  zeros(0)
    if lb  is None: lb  = -inf * ones(c.shape[0])
    if ub  is None: ub  =  inf * ones(c.shape[0])
    if x0  is None: x0  =  zeros(c.shape[0])
    fun            = lambda x: 0.5 * x.T @ Q @ x + c.T @ x
    thebounds      = Bounds(lb, ub)
    theconstraints = LinearConstraint(vstack([A,Aeq]), 
                concatenate([-inf*ones(A.shape[0]), beq]),
                concatenate([b, beq]))
    return minimize(fun, x0, 
                    bounds=thebounds, 
                    constraints=theconstraints)

In [5]:
C = 1.0
w, b = linear_svm_primal_non_separable_case(X_train, y_train, C)

ValueError: operands could not be broadcast together with shapes (3,) (40,) (40,) 

In [None]:
# function to plot a plane
# this doesn't always show the plane where we are interested in it!
def add_plane(fig, w, b, opacity=0.5, scaling=100.0):
    # find three vectors that are orthogonal to w
    # vector w=(a, b, c) is orthogonal to (-b, a, 0), or (-c, 0, a) or (0, -c, b)
    # each row is a vector orthogonal to w = (a,b,c)
    x_plane = array([[-w[1],  w[0],     0],  # (-b, a, 0)
                     [-w[2],     0,  w[0]],  # (-c, 0, a)
                     [    0, -w[2],  w[1]]]) # ( 0,-c, b)
    # next shift in w direction, such that x_plane @ w + b = 0
    x_plane += -w * b / (w@w)  # shift in w direction
    tau = scaling   # scaling factor, scale in all directions
                    # increase if plane is not visible
    x_plane = array([[1-tau,   0  ,  tau ],
                     [ tau , 1-tau,   0  ],
                     [  0  ,  tau , 1-tau]]) @ x_plane
    fig.add_trace(Mesh3d(
        color='green', opacity=opacity,
        x = x_plane[:,0], y = x_plane[:,1], z = x_plane[:,2], # define three points
        i = [0], j = [1], k = [2]))                           # define a triangle
    return x_plane

In [None]:
def add_margin(fig, w, b):
    # fix the axis (the planes are much larger)
    fix_axis(fig)    # to keep the current view 
    # even when adding a large triangle for the plane
    # the separating hyperplane
    add_plane(fig, w, b)
    # the margin
    # note that for x on one margin we have x @ w + b == +1
    #          and the other margin we have x @ w + b == -1
    add_plane(fig, w, b + 1, opacity=0.15)
    add_plane(fig, w, b - 1, opacity=0.15)

In [None]:
fig = scatter_3d(x   = X_train[:,0],       # x axis of plot
                 y   = X_train[:,1],       # y axis of plot
                 z   = X_train[:,2],       # z axis of plot
                 color = [['class -1', 'class +1'][n] for n in (y+1)//2])
fig['layout']['scene']['aspectmode'] = "data"
add_margin(fig, w, b)
fig.update_layout(title={'text': "BALANCED toy data"})
fig.show()