In [1]:
import numpy as np
from centralized_solvers import eigengame, parallel_deflation
from utils import *

# Seed NumPy's legacy global RNG once, before any experiment cell runs.
# NOTE(review): this only fixes draws made through the np.random.* global state;
# confirm the solvers in centralized_solvers draw from it (not from an unseeded Generator).
np.random.seed(seed=41)

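# Load and Preprocess MNIST Dataset

Flatten the MNIST training images, rescale pixel values to [0, 1], and center them. Then form the covariance matrix `mat` and take its top 30 eigenvectors (computed with `np.linalg.eigh`) as the ground truth `true_evecs`, against which every solver's error is measured below.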

In [2]:
import torchvision

# Experiment configuration shared by every solver cell below.
num_trials = 10  # independent trials per (method, T) setting
r = 30           # number of leading eigenvectors: used for the ground truth AND passed to the solvers


def step_size_func(idx):
    """Constant step-size schedule passed to the solvers; always returns 0 and ignores `idx`.

    NOTE(review): a zero step size presumably means the 'pw'/'alpha'/'mu' update rules do not
    rely on it here -- confirm against centralized_solvers.
    """
    return 0


mnist_data = torchvision.datasets.MNIST('./', download=True)

# Flatten each training image into one row and rescale pixels to [0, 1]. The dataset's raw
# uint8 `data` tensor holds the same pixels that indexing the dataset would decode via PIL
# (no transform is set), so this avoids converting every image individually.
trn_img_flt = mnist_data.data.numpy().reshape(len(mnist_data), -1) / 255.
trn_img_mean = trn_img_flt.mean(axis=0)
trn_img_centered = trn_img_flt - trn_img_mean.reshape((1, -1))
# Covariance of the centered pixels, normalized by the number of images (not n - 1).
mat = trn_img_centered.T @ trn_img_centered / trn_img_centered.shape[0]

# eigh returns eigenvalues in ascending order: reverse to keep the top-r pairs, then
# transpose so that each ROW of `true_evecs` is one eigenvector.
evals_raw, evecs_raw = np.linalg.eigh(mat)
true_evals = evals_raw[::-1][:r]
true_evecs = evecs_raw[:, ::-1][:, :r].T
assert true_evecs.shape == (r, mat.shape[0]), 'ground truth must have one row per requested eigenvector'

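## Parallel Deflation-MNIST

Each cell below runs `num_trials` independent trials of `parallel_deflation` with `update='pw'` for one `(L, T)` setting: (500, 1), (300, 3), or (200, 5). For every iterate in the returned history, it records `compute_max_error` and `compute_avg_error` against `true_evecs`. The per-iterate mean and standard deviation across trials are then saved as JSON to `parallel_deflation_mnist.txt`.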

In [3]:
# Parallel deflation with update='pw': T = 1, L = 500.
L = 500
T = 1

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 0.0012911806414984868
Running Trial #2
Trial #2 Error: 0.0007581033047346986
Running Trial #3
Trial #3 Error: 0.0017219689803699082
Running Trial #4
Trial #4 Error: 0.08568833012720048
Running Trial #5
Trial #5 Error: 0.0033488808258883315
Running Trial #6
Trial #6 Error: 0.00026956639105399577
Running Trial #7
Trial #7 Error: 0.002020231562018738
Running Trial #8
Trial #8 Error: 0.0014138823110246626
Running Trial #9
Trial #9 Error: 5.968110946590731e-05
Running Trial #10
Trial #10 Error: 0.00978324677440718


In [4]:
# Parallel deflation with update='pw': T = 3, L = 300.
L = 300
T = 3

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
max_err_mean_T3, max_err_std_T3 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T3, avg_err_std_T3 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 9.511821003990935e-06
Running Trial #2
Trial #2 Error: 2.6115134195799512e-05
Running Trial #3
Trial #3 Error: 3.234525558147667e-05
Running Trial #4
Trial #4 Error: 1.2876362316179243e-06
Running Trial #5
Trial #5 Error: 6.120369666192785e-07
Running Trial #6
Trial #6 Error: 6.160634560337176e-06
Running Trial #7
Trial #7 Error: 1.846781170240909e-05
Running Trial #8
Trial #8 Error: 7.939117177244544e-06
Running Trial #9
Trial #9 Error: 6.5524795991680816e-06
Running Trial #10
Trial #10 Error: 1.8719203803372405e-06


In [5]:
# Parallel deflation with update='pw': T = 5, L = 200.
L = 200
T = 5

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = parallel_deflation(mat, r, L, T, step_size_func, update='pw')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 2.4408356754897522e-06
Running Trial #2
Trial #2 Error: 8.855553364358597e-06
Running Trial #3
Trial #3 Error: 5.273748101160096e-06
Running Trial #4
Trial #4 Error: 1.7431174297201346e-06
Running Trial #5
Trial #5 Error: 2.2276114631740875e-06
Running Trial #6
Trial #6 Error: 2.240383515010395e-06
Running Trial #7
Trial #7 Error: 2.0968321874009244e-06
Running Trial #8
Trial #8 Error: 1.971576235862463e-06
Running Trial #9
Trial #9 Error: 4.445443716697782e-07
Running Trial #10
Trial #10 Error: 1.680333061100298e-06


In [6]:
import json

# NOTE(review): the *_T1 / *_T3 / *_T5 names are reassigned by later sections (EigenGame),
# so this cell must run right after the three parallel-deflation cells above.
# Each entry is (per-iterate mean, per-iterate std); .tolist() makes the arrays JSON-serializable.
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=3': (max_err_mean_T3.tolist(), max_err_std_T3.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=3': (avg_err_mean_T3.tolist(), avg_err_std_T3.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# The file contains JSON despite its .txt extension (filename left unchanged).
# Opened write-only: the handle is never read back, so read/write mode ('w+') is unnecessary.
with open('parallel_deflation_mnist.txt', 'w') as jfile:
    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)

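## EigenGame-Alpha-MNIST

Same protocol as above, but using `eigengame` with `update='alpha'` for `(L, T)` = (500, 1) and (200, 5). Results are saved as JSON to `eigengame_alpha_mnist.txt`.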

In [7]:
# EigenGame with update='alpha': T = 1, L = 500.
L = 500
T = 1

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='alpha')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
# NOTE: these names overwrite the parallel-deflation T=1 statistics.
max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 0.0013855998608143286
Running Trial #2
Trial #2 Error: 0.000780804400791811
Running Trial #3
Trial #3 Error: 0.0165827110447389
Running Trial #4
Trial #4 Error: 0.006610207604548143
Running Trial #5
Trial #5 Error: 0.02401295088768233
Running Trial #6
Trial #6 Error: 0.005669936512737841
Running Trial #7
Trial #7 Error: 0.00023028584947130785
Running Trial #8
Trial #8 Error: 0.0027082466731954513
Running Trial #9
Trial #9 Error: 0.00010240437484804006
Running Trial #10
Trial #10 Error: 0.002617239526788341


In [8]:
# EigenGame with update='alpha': T = 5, L = 200.
L = 200
T = 5

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='alpha')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
# NOTE: these names overwrite the parallel-deflation T=5 statistics.
max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 6.786166153392613e-07
Running Trial #2
Trial #2 Error: 2.589116958093376e-06
Running Trial #3
Trial #3 Error: 1.923577733409992e-06
Running Trial #4
Trial #4 Error: 7.434171249429239e-07
Running Trial #5
Trial #5 Error: 6.789437383350442e-07
Running Trial #6
Trial #6 Error: 2.3654860059004343e-07
Running Trial #7
Trial #7 Error: 1.3818768405162847e-07
Running Trial #8
Trial #8 Error: 2.601223797730691e-06
Running Trial #9
Trial #9 Error: 2.9405545403869946e-06
Running Trial #10
Trial #10 Error: 5.267681671121682e-06


In [9]:
import json

# NOTE(review): the *_T1 / *_T5 names are shared with the parallel-deflation and EigenGame-Mu
# sections, so this cell must run right after the two EigenGame-Alpha cells above.
# Each entry is (per-iterate mean, per-iterate std); .tolist() makes the arrays JSON-serializable.
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# The file contains JSON despite its .txt extension (filename left unchanged).
# Opened write-only: the handle is never read back, so read/write mode ('w+') is unnecessary.
with open('eigengame_alpha_mnist.txt', 'w') as jfile:
    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)

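## EigenGame-Mu-MNIST

Same protocol as above, but using `eigengame` with `update='mu'` for `(L, T)` = (500, 1) and (200, 5). Results are saved as JSON to `eigengame_mu_mnist.txt`.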

In [10]:
# EigenGame with update='mu': T = 1, L = 500.
L = 500
T = 1

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='mu')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
# NOTE: these names overwrite the earlier sections' T=1 statistics.
max_err_mean_T1, max_err_std_T1 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T1, avg_err_std_T1 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 0.0002828524228628293
Running Trial #2
Trial #2 Error: 0.0008641836877627376
Running Trial #3
Trial #3 Error: 0.0022671544850690266
Running Trial #4
Trial #4 Error: 0.0010584118205372256
Running Trial #5
Trial #5 Error: 1.7724303300665546e-05
Running Trial #6
Trial #6 Error: 0.0024958343802720657
Running Trial #7
Trial #7 Error: 0.017461269201645743
Running Trial #8
Trial #8 Error: 0.00012372311472815897
Running Trial #9
Trial #9 Error: 0.003511987358313588
Running Trial #10
Trial #10 Error: 0.0024374102225675287


In [11]:
# EigenGame with update='mu': T = 5, L = 200.
L = 200
T = 5

# One row per trial, one column per entry of the solver history
# (the width L + 1 assumes the history holds L + 1 iterates).
all_max_errs = np.zeros((num_trials, L + 1))
all_avg_errs = np.zeros_like(all_max_errs)
for trial_idx in range(num_trials):
    print(f'Running Trial #{trial_idx + 1}')
    evecs_hist = eigengame(mat, r, L, T, step_size_func, update='mu')
    all_max_errs[trial_idx, :] = [compute_max_error(true_evecs, est) for est in evecs_hist]
    all_avg_errs[trial_idx, :] = [compute_avg_error(true_evecs, est) for est in evecs_hist]
    # Report the max error of the final iterate.
    print(f'Trial #{trial_idx + 1} Error: {all_max_errs[trial_idx, -1]}')

# Per-iterate mean and population std (ddof=0) across trials.
# NOTE: these names overwrite the earlier sections' T=5 statistics.
max_err_mean_T5, max_err_std_T5 = all_max_errs.mean(axis=0), all_max_errs.std(axis=0)
avg_err_mean_T5, avg_err_std_T5 = all_avg_errs.mean(axis=0), all_avg_errs.std(axis=0)

Running Trial #1
Trial #1 Error: 5.515442944562956e-07
Running Trial #2
Trial #2 Error: 5.09698133959981e-07
Running Trial #3
Trial #3 Error: 6.025304382787862e-07
Running Trial #4
Trial #4 Error: 8.407622852060422e-06
Running Trial #5
Trial #5 Error: 5.790400274910167e-07
Running Trial #6
Trial #6 Error: 9.805371687628092e-07
Running Trial #7
Trial #7 Error: 7.477978162145817e-06
Running Trial #8
Trial #8 Error: 1.059177011399501e-05
Running Trial #9
Trial #9 Error: 1.0330277688763746e-06
Running Trial #10
Trial #10 Error: 3.0264401453122348e-06


In [12]:
import json

# NOTE(review): the *_T1 / *_T5 names are shared with the parallel-deflation and EigenGame-Alpha
# sections, so this cell must run right after the two EigenGame-Mu cells above.
# Each entry is (per-iterate mean, per-iterate std); .tolist() makes the arrays JSON-serializable.
max_err_results = {
    'T=1': (max_err_mean_T1.tolist(), max_err_std_T1.tolist()),
    'T=5': (max_err_mean_T5.tolist(), max_err_std_T5.tolist()),
}
avg_err_results = {
    'T=1': (avg_err_mean_T1.tolist(), avg_err_std_T1.tolist()),
    'T=5': (avg_err_mean_T5.tolist(), avg_err_std_T5.tolist()),
}

# The file contains JSON despite its .txt extension (filename left unchanged).
# Opened write-only: the handle is never read back, so read/write mode ('w+') is unnecessary.
with open('eigengame_mu_mnist.txt', 'w') as jfile:
    json.dump(dict(max_result=max_err_results, avg_result=avg_err_results), jfile)