In [2]:
import numpy as np
import math

# Perceptron ex4

In [41]:
np.random.seed(37)

def get_data(N, P):
    """
    Creates data, initial weights, and corresponding labels
    
    Parameters:
    N : number of samples
    P : number of patterns
    
    Returns:
    xi : np matrix size (N, P), binary -1 or 1
    w0 : np matrix size (1, N), normally distributed
    labels : list length P, indicating the truth labels of xi * w0
    """
    xi = np.random.choice([0, 1], size=(N, P))
    xi[xi==0] = -1
    
    w0 = np.random.randn(1, N)
    labels = []
    for j in range(P):
        labels.append(np.sign(xi[:,j] * w0))
        
    return xi, w0, labels


def train(x, w0, eta=0.01, n_epoch=100):
    """Trains the perceptron
    
    Parameters:
    x : input data, size (N, P)
    w0 : initial weights, size (1, N)
    n_epoch : number of iterations to train for
    
    Returns:
    w : weight parameters after training such that 
    """
    w = w0
    for epoch in range(n_epoch):
        labels = np.sign(np.dot(x.T, w.T))
        update_idxs = np.equal(labels, -1) + np.equal(labels, 0)
        if np.all(update_idxs == 0):
            return w
        else:
            w = w + eta * np.sum(np.multiply(x.T, update_idxs).T, axis=1)
    return w

def error(N, P):
    """Calculates the error according to the derived formula in (a)"""
    return 8/P * (5.991 + N * (1.693 + math.log(P) - math.log(N)))

def numerical_error(x, w, y):
    """Calculates the numerical error according to the test set"""
    preds = np.multiply(x.T, w).T
    return np.sum(y) / preds.shape[1]
    
N = 10
Ps = [10, 50, 100, 500, 1000]

# Hyperparams
n_learning_runs = 100
eta = 0.01

for P in Ps:
    # Get data
    train_xi, train_w0, train_labels = get_data(N, P)
    test_xi, test_w, test_labels = get_data(N, 10000) 
    
    # Train weights
    w = train(train_xi, train_w0, eta, n_learning_runs)
    
    # Obtain bound + error on test set
    epsilon = error(N, P)
    ner = numerical_error(test_xi, test_w, test_labels)
    print(epsilon, ner)

18.3368 0.0478
6.242460659894561 0.0608
3.675748074395236 -0.023
0.9926596808685032 0.0652
0.5517816148790473 0.0618
