In [5]:
# imports
import sys, os
# make the project's local `code/` directory importable (provides Algorithms and data_streamer below)
sys.path.append(os.getcwd()+'/code')
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import Algorithms
from data_streamer import Streamer, MultiStreamer

In [6]:
# constants
DATA_DIR = os.path.join(os.getcwd(), 'data')
train_data_file = os.path.join(DATA_DIR, 'train_mnist_normalized.csv')
validation_data_file = os.path.join(DATA_DIR, 'test_mnist_normalized.csv')
# fail fast, before the slow loading/training cells, if either input file is missing
# (the validation file is read in the next cell just like the training file)
for data_path in (train_data_file, validation_data_file):
    assert os.path.isfile(data_path), 'missing data file: ' + data_path
k = 10  # dimensionality of representations
d = 784  # input dimensionality; NOTE(review): presumably 28x28 MNIST pixels — must match the CSV rows
learning_rate = 1e-3  # initial learning rate
validation_freq = 1000  # how many training steps in between validations

In [7]:
# data streamer and algorithm
# The training stream is created in the main-loop cell right before it is consumed,
# so no (unused) stream is opened here.
algorithm = Algorithms.rfoja({
    'd': d,
    'k': k,
    'learning_rate': learning_rate,
    'mean_center': 0.0,
    'kernel': 'rbf',
    'kernel_hyperparameter': 0.1,  # NOTE(review): presumably the RBF kernel width — confirm in Algorithms.rfoja
    'm': 50,                       # NOTE(review): meaning is defined by Algorithms.rfoja — confirm
})
# load validation and training data at once so each evaluation is a single batched call
validation_data = np.stack(list(Streamer(validation_data_file).get_stream()), axis=0)
training_data = np.stack(list(Streamer(train_data_file).get_stream()), axis=0)
# PCA on the validation data as a reference ("lower bound") level for the plot.
# The data are NOT mean-centred, mirroring 'mean_center': 0.0 above
# (NOTE(review): confirm this is the intended comparison for a kernel method).
centered_data = validation_data  # - np.mean(validation_data, axis=0, keepdims=True)
# eigvalsh returns the eigenvalues of the symmetric Gram matrix in ascending order (no eigenvectors
# needed), so all but the last k entries are the energy left outside the best rank-k subspace.
lambdas = np.linalg.eigvalsh(np.matmul(centered_data.T, centered_data))
# normalised per matrix element (n * d)
best_validation_loss = lambdas[:-k].sum() / (centered_data.shape[0] * centered_data.shape[1])

In [8]:
# Main training loop: feed the training set to the streaming algorithm one point at a time and,
# every `validation_freq` steps, record the loss on the full validation and training sets.
# NOTE(review): the per-element normalisation below is disabled, while best_validation_loss
# (data cell) IS divided by n*d — confirm both are on the same scale before plotting them together.
validation_loss = []
training_loss = []
min_training_loss = []  # NOTE(review): not filled by any visible cell
training_batch = []     # NOTE(review): not used by any visible cell
counter = 0
train_stream = Streamer(train_data_file).get_stream()
for counter, point in enumerate(train_stream, start=1):
    algorithm.step(point)
    if counter % validation_freq:
        continue  # not an evaluation step
    print(counter)
    loss = algorithm.loss(validation_data)
    # loss /= validation_data.size
    validation_loss.append(loss)
    loss = algorithm.loss(training_data)
    # loss /= training_data.size
    training_loss.append(loss)

0.915201282678
0.78722215559
0.863126180444
0.69470928335
0.89476887963
0.788518602728
0.934677086238
0.845328474239
0.77347523625
0.778611187134
0.692429189099
1.02971657683
0.772655121505
0.801356217231
0.875984906811
0.771904556749
0.675202280727
0.625759615068
0.866627089522
0.872078429718
0.930125304278
0.765529771732
0.922276518481
0.856810876946
0.809547783789
0.814189361749
0.879318255935
0.741287575518
0.772408363526
0.744928464066
0.879765951575
0.740936593593
0.711758227713
0.9355718351
0.555578092504
0.831687036556
0.773508206566
0.713054303528
0.797817206399
0.726380001813
0.996384320796
0.859742279528
0.780838121742
0.862908637713
0.715534969574
0.749246080194
0.910381346506
0.833143367077
0.702206284592
0.751175334676
0.815588780218
0.831574439381
0.737499129164
0.99322505208
0.810817197776
0.929934966551
0.838228287837
0.930218099329
0.750681926371
0.840890580602
0.755635861856
0.66076853465
0.844071365355
0.694767449003
0.75906619361
0.92235590894
0.820777453121
0.8976

0.701488073628
0.861568804794
0.787851480354
0.812308627313
0.791771820454
0.884067544358
0.862570473226
0.802116067266
0.74364582251
0.766205714314
0.812043424336
0.873370156573
0.878771352603
0.724747397501
0.946858599695
0.796745355738
0.7584831787
0.893510082353
0.80140936199
0.977245521356
0.72933083982
0.830393827551
0.855194754062
0.796486936617
0.831872103406
0.708809173236
0.684240512707
0.74379254357
0.751126031178
0.7392414301
0.764592688674
0.878317389313
0.890822426159
0.87623648302
0.793121849158
0.769942782292
0.886341550677
0.862875994271
0.703326400546
0.787196664939
0.849642469139
0.605970092377
0.80859510546
0.820128256067
0.814538126781
0.782779303091
0.840192835413
0.765275577956
0.667027475136
0.766329518876
0.779059851964
0.823366320539
0.663103890087
0.889332301378
0.737513244149
0.872044501035
0.911000618394
0.861542730938
0.705721633791
0.677968208443
0.830860118371
0.750095339609
0.746266878216
0.696762311338
0.826556798229
0.719569717768
0.881002053001
0.683

NameError: global name 'rf_ponts' is not defined

In [None]:
# print() with a single argument behaves the same on Python 2 and Python 3
print(training_loss)

In [None]:
# plot the train and test loss
# Each series gets its own x-axis: if the main loop stopped between the validation and the
# training evaluation, validation_loss is one entry longer and a shared x-axis would make
# plot() raise a length-mismatch error.
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title('Train / test loss vs. training step')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
train_steps = [validation_freq * i for i in range(1, len(training_loss) + 1)]
test_steps = [validation_freq * i for i in range(1, len(validation_loss) + 1)]
ax.plot(train_steps, training_loss, '--b', label='train loss')
ax.plot(test_steps, validation_loss, '--r', label='test loss')
# best_validation_loss is computed from the validation (test) data, so it bounds the TEST curve.
# NOTE(review): it is normalised by n*d while the losses above are not — confirm the scales match.
ax.plot(test_steps, [best_validation_loss] * len(test_steps), '--k', label='test loss lower bound')
_ = ax.legend()