In [1]:
import numpy as np
from stochastic_solvers import parallel_deflation, eigengame
from utils import *

# Seed NumPy's legacy global RNG so that anything drawing from np.random
# below is repeatable from a fresh kernel.
# NOTE(review): this only helps if the helpers in utils / stochastic_solvers
# sample through np.random's global state — confirm.
np.random.seed(41)

In [2]:
import torchvision

# Experiment-wide settings. They are defined before the reference solution so
# that the number of reference eigenpairs is taken from `r` instead of a
# separately hard-coded constant that could drift out of sync with the solvers.
num_trials = 10    # independent repetitions of each solver run
r = 30             # number of leading eigenvectors to recover
batch_size = 1000  # forwarded unchanged to the solvers

# Load MNIST through torchvision (downloaded into the working directory if absent).
mnist_data = torchvision.datasets.MNIST('./', download=True)

# One flattened image per row, divided by 255, then centred by subtracting the
# per-pixel mean over all images.
trn_img_flt = np.array([np.array(img).flatten() for img, _ in mnist_data]) / 255.
trn_img_mean = trn_img_flt.mean(axis=0)
trn_img_centered = trn_img_flt - trn_img_mean.reshape((1, -1))
print(trn_img_centered.shape)

# Empirical covariance of the centred data and its exact eigendecomposition.
# np.linalg.eigh returns eigenvalues in ascending order, so both outputs are
# reversed before the leading `r` pairs are kept.
mat = trn_img_centered.T @ trn_img_centered / trn_img_centered.shape[0]
evals_raw, evecs_raw = np.linalg.eigh(mat)
true_evals = evals_raw[::-1][:r]
true_evecs = evecs_raw[:, ::-1][:, :r].T  # one reference eigenvector per row


def data_gen(sample):
    """Return `random_batch(trn_img_centered, sample)`.

    Presumably draws `sample` random rows of the centred data — the exact
    sampling scheme lives in `utils.random_batch`; confirm there.
    """
    return random_batch(trn_img_centered, sample)


def eval_func(x):
    """Return `compute_avg_error(true_evecs, x)` for a candidate set of eigenvectors `x`."""
    return compute_avg_error(true_evecs, x)

(60000, 784)


# Parallel Deflation

In [3]:
# Parallel deflation with T = 1.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 1200, 1
# Step size 10 / (1 + floor(idx / 10)): the denominator grows by one every 10 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = parallel_deflation(data_gen, r, L, T, step_size_func, batch_size)
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T1 = np.mean(all_max_errs, axis=0)
max_err_std_T1 = np.std(all_max_errs, axis=0)
avg_err_mean_T1 = np.mean(all_avg_errs, axis=0)
avg_err_std_T1 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.07931484685691509
Running Trial #2
Trial #2 Error: 0.09914454240575465
Running Trial #3
Trial #3 Error: 0.09418131663907331
Running Trial #4
Trial #4 Error: 0.21964881389109794
Running Trial #5
Trial #5 Error: 0.07981865075718252
Running Trial #6
Trial #6 Error: 0.08185703833859334
Running Trial #7
Trial #7 Error: 0.10887208842795389
Running Trial #8
Trial #8 Error: 0.08195899120061222
Running Trial #9
Trial #9 Error: 0.08597193398035886
Running Trial #10
Trial #10 Error: 0.2025477628134142


In [4]:
# Parallel deflation with T = 5.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 240, 5
# Step size 10 / (1 + floor(idx / 2)): the denominator grows by one every 2 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = parallel_deflation(data_gen, r, L, T, step_size_func, batch_size)
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T5 = np.mean(all_max_errs, axis=0)
max_err_std_T5 = np.std(all_max_errs, axis=0)
avg_err_mean_T5 = np.mean(all_avg_errs, axis=0)
avg_err_std_T5 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.08677165995138474
Running Trial #2
Trial #2 Error: 0.09958614315861679
Running Trial #3
Trial #3 Error: 0.3381445477950125
Running Trial #4
Trial #4 Error: 0.1355754896618563
Running Trial #5
Trial #5 Error: 0.10306586834036491
Running Trial #6
Trial #6 Error: 0.097766789220794
Running Trial #7
Trial #7 Error: 0.11135735956953237
Running Trial #8
Trial #8 Error: 0.09336796493401384
Running Trial #9
Trial #9 Error: 0.2621280428012151
Running Trial #10
Trial #10 Error: 0.2228245209869005


In [5]:
import json

# Bundle the (mean, std) error curves per T setting. The arrays are turned into
# plain lists so that json can serialise them (tuples are written as JSON arrays).
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# Write both summaries as a single JSON document (the file is truncated if it exists).
with open('parallel_deflation_mnist_sto.txt', 'w+') as jfile:
    json.dump({'max_result': max_err_results, 'avg_result': avg_err_results}, jfile)

# EigenGame-Mu

In [6]:
# EigenGame with update='mu' and T = 1.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 1200, 1
# Step size 10 / (1 + floor(idx / 10)): the denominator grows by one every 10 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='mu')
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T1 = np.mean(all_max_errs, axis=0)
max_err_std_T1 = np.std(all_max_errs, axis=0)
avg_err_mean_T1 = np.mean(all_avg_errs, axis=0)
avg_err_std_T1 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.2800819053946179
Running Trial #2
Trial #2 Error: 0.35477575780643655
Running Trial #3
Trial #3 Error: 0.06114400574545103
Running Trial #4
Trial #4 Error: 0.08765750793496763
Running Trial #5
Trial #5 Error: 0.06116958915120261
Running Trial #6
Trial #6 Error: 0.06521674546904341
Running Trial #7
Trial #7 Error: 0.12841713895099455
Running Trial #8
Trial #8 Error: 0.08263263885944094
Running Trial #9
Trial #9 Error: 0.13691939394910063
Running Trial #10
Trial #10 Error: 0.10995566216014638


In [7]:
# EigenGame with update='mu' and T = 5.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 240, 5
# Step size 10 / (1 + floor(idx / 2)): the denominator grows by one every 2 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='mu')
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T5 = np.mean(all_max_errs, axis=0)
max_err_std_T5 = np.std(all_max_errs, axis=0)
avg_err_mean_T5 = np.mean(all_avg_errs, axis=0)
avg_err_std_T5 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.20000367110492265
Running Trial #2
Trial #2 Error: 0.18002064143225627
Running Trial #3
Trial #3 Error: 0.11461935797674873
Running Trial #4
Trial #4 Error: 0.26313974013063113
Running Trial #5
Trial #5 Error: 0.19840596289598525
Running Trial #6
Trial #6 Error: 0.12021433784788399
Running Trial #7
Trial #7 Error: 0.08374777631890154
Running Trial #8
Trial #8 Error: 0.09099088260523475
Running Trial #9
Trial #9 Error: 0.1866020109061588
Running Trial #10
Trial #10 Error: 0.2387226931153399


In [8]:
import json

# Bundle the (mean, std) error curves per T setting. The arrays are turned into
# plain lists so that json can serialise them (tuples are written as JSON arrays).
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# Write both summaries as a single JSON document (the file is truncated if it exists).
with open('eigengame_mu_mnist_sto.txt', 'w+') as jfile:
    json.dump({'max_result': max_err_results, 'avg_result': avg_err_results}, jfile)

# EigenGame-Alpha

In [9]:
# EigenGame with update='alpha' and T = 1.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 1200, 1
# Step size 10 / (1 + floor(idx / 10)): the denominator grows by one every 10 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 10))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='alpha')
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T1 = np.mean(all_max_errs, axis=0)
max_err_std_T1 = np.std(all_max_errs, axis=0)
avg_err_mean_T1 = np.mean(all_avg_errs, axis=0)
avg_err_std_T1 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.06680067135104835
Running Trial #2
Trial #2 Error: 0.08383224909477664
Running Trial #3
Trial #3 Error: 0.2521713095341494
Running Trial #4
Trial #4 Error: 0.06849535367003343
Running Trial #5
Trial #5 Error: 0.09019275791369717
Running Trial #6
Trial #6 Error: 0.06280845752427353
Running Trial #7
Trial #7 Error: 0.09847965652348292
Running Trial #8
Trial #8 Error: 0.06646231971721014
Running Trial #9
Trial #9 Error: 0.0873212897821445
Running Trial #10
Trial #10 Error: 0.11153374469673465


In [10]:
# EigenGame with update='alpha' and T = 5.
# L and T are handed straight to the solver (presumably L outer rounds of T
# inner iterations each — confirm in stochastic_solvers); here L * T = 1200.
L, T = 240, 5
# Step size 10 / (1 + floor(idx / 2)): the denominator grows by one every 2 indices.
step_size_func = lambda idx: 10 / (1 + np.floor(idx / 2))

# One row per trial; each row receives the per-snapshot errors of the returned
# history, which is therefore expected to hold L + 1 snapshots.
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx+1}')
    evecs_hist = eigengame(data_gen, r, L, T, step_size_func, batch_size, update='alpha')
    all_max_errs[trial_idx] = [compute_max_error(true_evecs, evecs) for evecs in evecs_hist]
    all_avg_errs[trial_idx] = [compute_avg_error(true_evecs, evecs) for evecs in evecs_hist]
    print(f'Trial #{trial_idx+1} Error: {all_avg_errs[trial_idx,-1]}')

# Collapse the trial axis into a mean curve and a standard-deviation curve.
max_err_mean_T5 = np.mean(all_max_errs, axis=0)
max_err_std_T5 = np.std(all_max_errs, axis=0)
avg_err_mean_T5 = np.mean(all_avg_errs, axis=0)
avg_err_std_T5 = np.std(all_avg_errs, axis=0)

Running Trial #1
Trial #1 Error: 0.10539663923046567
Running Trial #2
Trial #2 Error: 0.0805605605666201
Running Trial #3
Trial #3 Error: 0.09776179975499993
Running Trial #4
Trial #4 Error: 0.14264293137552858
Running Trial #5
Trial #5 Error: 0.08702676952366487
Running Trial #6
Trial #6 Error: 0.3419322727578004
Running Trial #7
Trial #7 Error: 0.1886071802060075
Running Trial #8
Trial #8 Error: 0.35532805096160336
Running Trial #9
Trial #9 Error: 0.06182658173880311
Running Trial #10
Trial #10 Error: 0.16130111013789586


In [11]:
import json

# Bundle the (mean, std) error curves per T setting. The arrays are turned into
# plain lists so that json can serialise them (tuples are written as JSON arrays).
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# Write both summaries as a single JSON document (the file is truncated if it exists).
with open('eigengame_alpha_mnist_sto.txt', 'w+') as jfile:
    json.dump({'max_result': max_err_results, 'avg_result': avg_err_results}, jfile)