In [2]:
import numpy as np
import random
import argparse

# Paths of the data files handed to get_data further down.
tra_path='hw3_train.dat'
tst_path='hw3_test.dat'
# NOTE(review): not referenced by the visible cells; get_data has its own
# `bias` parameter whose default is also 1.0.
bias = 1.0

def get_data(path, bias=1.0, transform=None):
    """Load a tab-separated numeric file into a feature matrix and label vector.

    Every non-blank line must hold tab-separated floats. The last value on a
    line is the label; the values before it are the features. A constant
    `bias` entry is prepended to each feature row as column 0.

    Args:
        path: Path of the text file to read.
        bias: Constant placed in column 0 of every feature row.
        transform: Optional callable applied to the feature matrix X (labels
            are left untouched); its return value replaces X.

    Returns:
        Tuple `(X, Y)`: X is 2-D (one row per example, bias first), Y is 1-D.

    Raises:
        ValueError: If a field cannot be parsed as a float or the file
            contains no data rows.
    """
    rows = []
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(path, 'r') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            rows.append([bias] + [float(v) for v in line.split('\t')])

    if not rows:
        raise ValueError('no data rows found in {!r}'.format(path))

    data = np.array(rows)
    # Copy so X and Y do not share memory with the combined array.
    X, Y = data[:, :-1].copy(), data[:, -1].copy()

    if transform is not None:
        X = transform(X)

    return X, Y

def get_wLIN(X, Y):
    """Least-squares linear-regression weights, w = pinv(X) @ Y.

    Uses the SVD-based Moore-Penrose pseudo-inverse instead of explicitly
    inverting X^T X. For a full-column-rank X this is the same solution up to
    floating-point error; for rank-deficient or badly conditioned X (e.g. after
    a high-order polynomial transform) it returns the minimum-norm
    least-squares solution rather than raising or amplifying round-off.

    Args:
        X: 2-D feature matrix, one row per example.
        Y: 1-D target vector with one entry per row of X.

    Returns:
        1-D weight vector with one entry per column of X.
    """
    return np.matmul(np.linalg.pinv(X), Y)

def sigmoid(s):
    """Logistic function 1 / (1 + exp(-s)); applied elementwise to arrays."""
    exp_neg = np.exp(-s)
    return 1 / (1 + exp_neg)

def sign(s):
    """Sign of `s` with zero mapped to -1.

    Works for arrays and for plain scalars (the item-assignment form
    `out[out == 0] = -1` fails on the scalar that np.sign returns for scalar
    input). The argument is never modified.

    Args:
        s: Scalar or array of numbers.

    Returns:
        numpy array (0-d for scalar input) holding +1 where s > 0 and -1 where
        s <= 0.
    """
    signs = np.sign(s)
    # np.where builds a new array, so this handles 0-d results as well.
    return np.where(signs == 0, -1, signs)

def Q_transform(X, Q=3):
    """Append elementwise powers 2..Q of every column except column 0.

    The result is X followed by (X[:, 1:])**2, ..., (X[:, 1:])**Q stacked
    side by side; no cross terms are produced. With Q < 2 nothing is appended.
    """
    features = X[:, 1:]
    blocks = [X]
    for power in range(2, Q + 1):
        blocks.append(features ** power)
    return np.hstack(blocks)
        
def err(w, X, Y, mode='sqr'):
    """Average error of the linear hypothesis `w` on the data set (X, Y).

    Args:
        w: 1-D weight vector.
        X: 2-D feature matrix, one row per example.
        Y: 1-D targets; for 'ce' and '01' they are used as +/-1 labels.
        mode: 'sqr' for mean squared error, 'ce' for mean cross-entropy
            (logistic) error, '01' for the fraction of misclassified examples.

    Returns:
        The mean error as a float.

    Raises:
        ValueError: If `mode` is not one of 'sqr', 'ce', '01'.
    """
    scores = np.matmul(X, w)
    if mode == 'sqr':
        return ((scores - Y) ** 2).mean()
    if mode == 'ce':
        # -log(sigmoid(y*s)) == log(1 + exp(-y*s)); logaddexp evaluates this
        # without overflowing exp() or taking log(0) for large negative margins.
        return np.logaddexp(0, -Y * scores).mean()
    if mode == '01':
        predicted = sign(scores)
        return (Y.astype(int) != predicted.astype(int)).mean()
    # Previously an unknown mode fell through and silently returned None.
    raise ValueError("unknown mode {!r}; expected 'sqr', 'ce' or '01'".format(mode))
        
def SGD(X, Y, lr, w_init=None, step_num=5566, mode='sqr'):
    """Stochastic gradient descent using one uniformly sampled example per step.

    Args:
        X: 2-D feature matrix, one row per example.
        Y: 1-D targets (+/-1 labels for mode 'ce').
        lr: Learning rate (step size).
        w_init: Initial weight vector; zeros when None.
        step_num: Maximum number of updates.
        mode: 'sqr' minimises squared error and stops early once the full
            in-sample squared error is within 1% of the closed-form
            linear-regression error; 'ce' minimises cross-entropy error and
            always runs `step_num` updates.

    Returns:
        Tuple `(w, step)`: the final weights and the number of updates made.

    Raises:
        ValueError: If `mode` is neither 'sqr' nor 'ce'.
    """
    if mode not in ('sqr', 'ce'):
        raise ValueError("unknown mode {!r}; expected 'sqr' or 'ce'".format(mode))

    def random_pick(X, Y):
        # Length-1 slices keep x two-dimensional and y one-dimensional, so the
        # matrix expressions in descent_direction work for a single example.
        idx = random.randint(0, X.shape[0] - 1)
        return X[idx:idx+1], Y[idx:idx+1]

    def descent_direction(w, X, Y):
        # Negative gradient of the batch error at w (hence the `+` in the update).
        batch_size = X.shape[0]
        if mode == 'sqr':
            return -(2 / batch_size) * (np.matmul(X.T, np.matmul(X, w)) - np.matmul(X.T, Y))
        return np.mean(sigmoid(-Y * np.matmul(X, w)).reshape(-1, 1) * (Y.reshape(-1, 1) * X), axis=0)

    if mode == 'sqr':
        # Early-stopping target; it does not change during training, so it is
        # computed once instead of on every step.
        stop_threshold = 1.01 * err(get_wLIN(X, Y), X, Y, mode='sqr')

    # initialization
    step = 0
    w = np.zeros(X.shape[1:]) if w_init is None else w_init

    # training
    while step < step_num:
        x, y = random_pick(X, Y)
        w = w + lr * descent_direction(w, x, y)
        step += 1

        # check early stopping on the full data set
        if mode == 'sqr' and err(w, X, Y, mode='sqr') <= stop_threshold:
            break

    return w, step

# Load both data files with get_data's defaults (bias value 1.0 in column 0,
# no feature transform). These globals are reused by the cells below.
X_tra, Y_tra = get_data(tra_path)
X_tst, Y_tst = get_data(tst_path)

In [40]:
# Closed-form linear-regression weights fitted on (X_tra, Y_tra), followed by
# their mean squared error on that same data.
wLIN = get_wLIN(X_tra, Y_tra)
print(err(wLIN, X_tra, Y_tra, mode='sqr'))

0.6053223804672917


In [5]:
# Repeat squared-error SGD (lr=0.001, default step cap) 1000 times and report
# the mean number of updates each run made before SGD returned.
# NOTE(review): SGD samples with the `random` module and no seed is set in the
# visible cells, so the printed mean will differ between executions.
update_num_list = []
for _ in range(1000):
    _, update_num = SGD(X_tra, Y_tra, lr=0.001)
    update_num_list.append(update_num)
print(np.mean(update_num_list))

1721.31


In [13]:
# 1000 independent cross-entropy SGD runs (lr=0.001, at most 500 updates each),
# using SGD's default initial weights because w_init is not passed.
# Prints the mean in-sample cross-entropy error of the resulting weights.
ce_loss_list = []
for _ in range(1000):
    w, _ = SGD(X_tra, Y_tra, lr=0.001, step_num=500, mode='ce')
    ce_loss = err(w, X_tra, Y_tra, mode='ce')
    ce_loss_list.append(ce_loss)
print(np.mean(ce_loss_list))

0.5692749939818216


In [22]:
# Same experiment as the previous cross-entropy cell, but every run starts
# from wLIN instead of the default initial weights.
# Requires `wLIN` to exist already (it is assigned in the get_wLIN cell).
ce_loss_list = []
for _ in range(1000):
    w, _ = SGD(X_tra, Y_tra, lr=0.001, w_init=wLIN, step_num=500, mode='ce')
    ce_loss = err(w, X_tra, Y_tra, mode='ce')
    ce_loss_list.append(ce_loss)
print(np.mean(ce_loss_list))

0.5028919699941422


In [31]:
# Absolute difference between the 0/1 error of wLIN on (X_tst, Y_tst) and on
# (X_tra, Y_tra).
abs(err(wLIN, X_tst, Y_tst, mode='01') - err(wLIN, X_tra, Y_tra, mode='01'))

0.3226666666666666

In [38]:
# Apply Q_transform with Q=3 to both feature matrices, refit linear regression
# on the transformed training data, and show the absolute difference between
# its 0/1 error on the test data and on the training data.
X_tra_Q = Q_transform(X_tra, Q=3)
X_tst_Q = Q_transform(X_tst, Q=3)
wLIN_Q = get_wLIN(X_tra_Q, Y_tra)
abs(err(wLIN_Q, X_tst_Q, Y_tst, mode='01') - err(wLIN_Q, X_tra_Q, Y_tra, mode='01'))

0.37366666666666665

In [39]:
# Same comparison as the Q=3 cell, with Q=10 (powers 2..10 of each non-bias
# column appended). This overwrites X_tra_Q, X_tst_Q and wLIN_Q.
X_tra_Q = Q_transform(X_tra, Q=10)
X_tst_Q = Q_transform(X_tst, Q=10)
wLIN_Q = get_wLIN(X_tra_Q, Y_tra)
abs(err(wLIN_Q, X_tst_Q, Y_tst, mode='01') - err(wLIN_Q, X_tra_Q, Y_tra, mode='01'))

0.44666666666666666

In [20]:
def f(x1, x2):
    """Map (x1, x2) to the feature vector (1, x1, x2, x1^2, x1*x2, x2^2)."""
    terms = [1, x1, x2, x1*x1, x1*x2, x2*x2]
    return np.array(terms)
# Each row is (x1, x2, label).
X = [[0, 1, -1], [1,-0.5, -1], [-1, 0, -1], [-1,2,1], [2,0,1], [1,-1.5,1], [0,-2,1]]
# Candidate weight vectors, one per row, over the six features produced by f.
w = np.array([[-9,1,0,2,-2,3], [-5,-1,2,3,-7,2], [9,-1,4,2,-2,3], [2,1,-4,-2,7,-4], [-7,0,0,2,-2,3]])
# Print the index of every candidate that classifies all rows of X correctly.
for i, candidate in enumerate(w):
    mistakes = 0
    for x1, x2, label in X:
        score = sum(f(x1, x2) * candidate)
        # A score of exactly zero is counted as the -1 class.
        pred = 1 if score > 0 else -1
        if pred != label:
            mistakes += 1
    if mistakes == 0:
        print(i)

0
4
