In [543]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

In [544]:
def random_target():
    '''
    Picks a random target line in [-1, 1] x [-1, 1] and returns its weights.

    Two points (x0, y0) and (x1, y1) are drawn uniformly at random. The line
    through them in point-slope form is y - y0 = m*(x - x0) with
    m = (y0 - y1)/(x0 - x1). Rearranged to a*x + b*y + c = 0 this gives
    a = -m, b = 1, c = -y0 + m*x0.

    Returns the array [c, a, b], i.e. the bias term first so it lines up with
    inputs of the form [1, x, y].
    '''
    # Draw order (x0, y0, x1, y1) matters for reproducibility under a seed.
    x0, y0, x1, y1 = (np.random.uniform(-1, 1) for _ in range(4))

    rise = y0 - y1
    run = x0 - x1

    bias = -y0 + x0 * rise / run
    x_coeff = -rise / run
    y_coeff = 1

    return np.array([bias, x_coeff, y_coeff])

In [545]:
def generate_data(N, target_weights):
    '''
    Creates N random training examples labelled by the target function f.

    Parameters:
        N: number of examples to generate.
        target_weights: weights [w0, w1, ..., wd] of the target; w0 multiplies
            the constant bias input. The input dimension d is taken to be
            len(target_weights) - 1, so a three-weight target yields 2-D points.

    Returns:
        A list of N tuples (x, y). x is a float array [1, x1, ..., xd] with
        each xi drawn uniformly from [-1, 1]; y is 1 when
        dot(target_weights, x) > 0 and -1 otherwise.
    '''
    weights = np.asarray(target_weights, dtype=float)
    d = len(weights) - 1

    data = []
    for _ in range(N):
        # Build the whole input at once instead of growing it with np.append,
        # which reallocates the array on every call.
        x = np.concatenate(([1.0], np.random.uniform(-1, 1, size=d)))
        y = 1 if np.dot(weights, x) > 0 else -1
        data.append((x, y))
    return data

In [610]:
# Logistic regression trained with stochastic gradient descent (SGD),
# averaged over k independent trials. Each trial draws a fresh target line,
# trains on N_TRAIN points until the weights stop moving between epochs, and
# then estimates the out-of-sample cross-entropy error on N_TEST new points.
LEARNING_RATE = 0.01  # SGD step size
TOLERANCE = 0.01      # stop once an epoch moves w by less than this (L2 norm)
N_TRAIN = 100         # training examples per trial
N_TEST = 100          # fresh examples used to estimate E_out per trial

total_eout = 0
total_epochs = 0
k = 100
for trial in range(k):
    s = random_target()
    data = generate_data(N_TRAIN, s)
    w = np.zeros(len(s))

    change = np.inf
    epoch = 0
    while change >= TOLERANCE:
        # Snapshot by value: `w -= ...` below updates the array in place, so
        # `prev_w = w` would alias w and the measured change would always be 0.
        prev_w = w.copy()

        # Shuffle indices rather than the (array, label) tuples themselves;
        # permuting the ragged list would require coercing it to an object
        # array, which newer NumPy releases refuse to do implicitly.
        for idx in np.random.permutation(len(data)):
            x, y = data[idx]
            # Gradient of the pointwise cross-entropy error ln(1 + exp(-y w.x)).
            de = -(y * x) / (1 + np.exp(y * np.dot(w, x)))
            w -= LEARNING_RATE * de

        epoch += 1
        change = np.linalg.norm(prev_w - w)

    # Estimate E_out on points not seen in training, and average over the
    # size of that test set (not the training set).
    test_data = generate_data(N_TEST, s)
    eout = 0
    for x, y in test_data:
        # logaddexp(0, z) == log(1 + exp(z)) without overflowing for large z.
        eout += np.logaddexp(0, -y * np.dot(w, x))
    eout /= len(test_data)

    total_eout += eout
    total_epochs += epoch

print(total_eout / k, total_epochs / k)

[ 0.30524803 -0.10027034  0.05672116] [ 0.30524803 -0.10027034  0.05672116]
[-0.07332646  0.25575143  0.01241638] [-0.07332646  0.25575143  0.01241638]
[0.07273766 0.19235131 0.04044052] [0.07273766 0.19235131 0.04044052]
[-0.33607045 -0.06406115  0.00072091] [-0.33607045 -0.06406115  0.00072091]
[-0.07053308 -0.12925406  0.17774071] [-0.07053308 -0.12925406  0.17774071]
[0.07393758 0.08021953 0.19229654] [0.07393758 0.08021953 0.19229654]
[-0.11913601 -0.20942007  0.0640332 ] [-0.11913601 -0.20942007  0.0640332 ]
[-0.40766909  0.00336565 -0.00573819] [-0.40766909  0.00336565 -0.00573819]
[0.11439788 0.18055363 0.09862843] [0.11439788 0.18055363 0.09862843]
[0.14195146 0.06920501 0.20317361] [0.14195146 0.06920501 0.20317361]
[ 0.02429273 -0.19198847  0.11032663] [ 0.02429273 -0.19198847  0.11032663]
[ 0.04568485 -0.23474188  0.06029522] [ 0.04568485 -0.23474188  0.06029522]
[ 0.07023415 -0.14953264  0.18567865] [ 0.07023415 -0.14953264  0.18567865]
[-0.2288086   0.17321453  0.07390744