In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist

# Seed NumPy's global RNG once: the cluster samples, w_init and the perceptron's
# per-epoch permutations below all draw from it, so outputs are reproducible.
np.random.seed(seed=2)

In [2]:
# Two 2-D Gaussian clusters sharing one covariance matrix; only the means differ.
means = [[2, 2], [4, 2]]
cov = [[0.3, 0.2], [0.2, 0.3]]
N = 10  # samples drawn per cluster
# multivariate_normal returns shape (N, 2); transposing makes each column one point.
# The comprehension draws means[0] first, then means[1] -- same RNG order as two separate calls.
X0, X1 = [np.random.multivariate_normal(mu, cov, N).T for mu in means]

In [3]:
# Class +1 samples: 2 rows (features) x N columns (one point per column).
X0

array([[2.22096057, 2.70132234, 3.08493823, 2.02701417, 2.73223639,
        1.21171968, 2.22920603, 1.8637762 , 1.74682699, 2.37191737],
       [2.19579728, 3.43487375, 2.70849736, 1.47010441, 2.32571583,
        2.23682627, 1.72925457, 1.59716548, 2.27230351, 2.37595358]])

In [4]:
# Class -1 samples: 2 rows (features) x N columns (one point per column).
X1

array([[4.47403369, 4.09281249, 4.22222334, 4.58438569, 4.74493118,
        3.6355797 , 5.19217738, 3.51075436, 3.93784332, 3.8787214 ],
       [2.4040742 , 1.65061706, 2.11659863, 2.05326933, 2.67628604,
        2.63347726, 3.2425902 , 2.11880111, 1.56029947, 2.12126884]])

In [5]:
# Put the two classes side by side: columns 0..N-1 come from X0, N..2N-1 from X1.
X = np.hstack((X0, X1))
X

array([[2.22096057, 2.70132234, 3.08493823, 2.02701417, 2.73223639,
        1.21171968, 2.22920603, 1.8637762 , 1.74682699, 2.37191737,
        4.47403369, 4.09281249, 4.22222334, 4.58438569, 4.74493118,
        3.6355797 , 5.19217738, 3.51075436, 3.93784332, 3.8787214 ],
       [2.19579728, 3.43487375, 2.70849736, 1.47010441, 2.32571583,
        2.23682627, 1.72925457, 1.59716548, 2.27230351, 2.37595358,
        2.4040742 , 1.65061706, 2.11659863, 2.05326933, 2.67628604,
        2.63347726, 3.2425902 , 2.11880111, 1.56029947, 2.12126884]])

In [6]:
# Labels aligned with the columns of X: +1 for the X0 block, -1 for the X1 block.
# Kept as a (1, 2N) float row so it compares directly with the classifier's output.
y = np.array([[1.0] * N + [-1.0] * N])
y

array([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.]])

In [7]:
# Augment every point with a leading constant 1 so the bias is learned as the
# first weight component. X is rebuilt from X0/X1 rather than prepending to the
# current X, so re-running this cell is idempotent (prepending to X itself would
# add one more row of ones on every re-run).
X = np.concatenate((X0, X1), axis=1)
X = np.concatenate((np.ones((1, X.shape[1])), X), axis=0)
X

array([[1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [2.22096057, 2.70132234, 3.08493823, 2.02701417, 2.73223639,
        1.21171968, 2.22920603, 1.8637762 , 1.74682699, 2.37191737,
        4.47403369, 4.09281249, 4.22222334, 4.58438569, 4.74493118,
        3.6355797 , 5.19217738, 3.51075436, 3.93784332, 3.8787214 ],
       [2.19579728, 3.43487375, 2.70849736, 1.47010441, 2.32571583,
        2.23682627, 1.72925457, 1.59716548, 2.27230351, 2.37595358,
        2.4040742 , 1.65061706, 2.11659863, 2.05326933, 2.67628604,
        2.63347726, 3.2425902 , 2.11880111, 1.56029947, 2.12126884]])

In [8]:
def h(W, x):
    """Perceptron decision function: the sign of the linear score ``W^T x``.

    ``x`` may be a single column (d, 1) or a batch of columns (d, n); the
    result holds +1 or -1 per column, and 0 for a point lying exactly on the
    decision boundary (``np.sign(0) == 0``).
    """
    scores = np.dot(W.T, x)
    return np.sign(scores)

In [9]:
def has_converged(X, y, W):
    """Return True when weights ``W`` classify every column of ``X`` as ``y``.

    The predictions from ``h`` are compared with ``y`` by value and by shape,
    so with ``W`` shaped (d, 1) the labels must be a (1, N) row.
    """
    predictions = h(W, X)
    return np.array_equal(predictions, y)

def perceptron(X, y, w_init, max_epochs=None):
    """Train a perceptron (PLA) and record every weight vector it visits.

    Each epoch visits the columns of ``X`` in a fresh random order drawn from
    NumPy's global RNG. Whenever the current weights misclassify a point, the
    update ``w <- w + y_i * x_i`` is applied. Training stops after the first
    epoch at which every point is classified correctly.

    Parameters
    ----------
    X : ndarray, shape (d, N)
        One point per column (in this notebook row 0 is the constant bias 1).
    y : ndarray, shape (1, N)
        Labels, expected to be +1/-1 since they are compared with ``np.sign``.
    w_init : ndarray, shape (d, 1)
        Starting weights.
    max_epochs : int or None, optional
        Upper bound on the number of passes over the data. ``None`` (default)
        loops until convergence, which never terminates when the data are not
        linearly separable.

    Returns
    -------
    tuple (W, miss_points)
        ``W`` is the list of weight vectors, starting with ``w_init``, with one
        entry appended per update (so ``W[-1]`` is the final model).
        ``miss_points`` holds the column index of each point that triggered an
        update, in order.

    Raises
    ------
    ValueError
        If ``max_epochs`` is given and is smaller than 1.
    RuntimeError
        If ``max_epochs`` epochs pass without every point being classified
        correctly.
    """
    if max_epochs is not None and max_epochs < 1:
        raise ValueError("max_epochs must be a positive integer or None")
    W = [w_init]
    d, N = X.shape
    miss_points = []
    epoch = 0
    while True:
        # Fresh random visiting order every epoch (one permutation draw per epoch).
        for j in np.random.permutation(N):
            xi = X[:, j].reshape(d, 1)
            yi = y[0, j]
            # h yields +1/-1, or 0 on the boundary; anything != yi is a mistake.
            if h(W[-1], xi) != yi:
                miss_points.append(int(j))  # plain int, not np.int64, for clean printing
                W.append(W[-1] + yi * xi)
        epoch += 1
        if has_converged(X, y, W[-1]):
            return (W, miss_points)
        if max_epochs is not None and epoch >= max_epochs:
            raise RuntimeError(
                f"perceptron did not converge within {max_epochs} epochs; "
                "the data may not be linearly separable"
            )

In [11]:
# Fit from a random starting point. `w` receives every weight vector visited and
# `m` the column indices of the points that triggered each update.
d = X.shape[0]  # number of rows of X: bias row + features
w_init = np.random.randn(d, 1)
w, m = perceptron(X, y, w_init)

print('Misspoint', m)
print(w)
print(len(w))  # number of updates + 1, since w_init is the first entry

Misspoint [14, 4, 2, 10, 8, 16, 0, 1, 18, 3, 14, 9]
[array([[1.73118467],
       [1.58160763],
       [0.01896191]]), array([[ 0.73118467],
       [-3.16332356],
       [-2.65732414]]), array([[ 1.73118467],
       [-0.43108717],
       [-0.33160831]]), array([[2.73118467],
       [2.65385106],
       [2.37688905]]), array([[ 1.73118467],
       [-1.82018263],
       [-0.02718515]]), array([[ 2.73118467],
       [-0.07335564],
       [ 2.24511836]]), array([[ 1.73118467],
       [-5.26553302],
       [-0.99747185]]), array([[ 2.73118467],
       [-3.04457245],
       [ 1.19832543]]), array([[ 3.73118467],
       [-0.3432501 ],
       [ 4.63319918]]), array([[ 2.73118467],
       [-4.28109342],
       [ 3.07289971]]), array([[ 3.73118467],
       [-2.25407925],
       [ 4.54300412]]), array([[ 2.73118467],
       [-6.99901043],
       [ 1.86671808]]), array([[ 3.73118467],
       [-4.62709307],
       [ 4.24267166]])]
13
