In [5]:
import os
import warnings

warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import numpy as np

# matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*'
# and 3.8 removed the old names, so select whichever spelling this
# installation provides. On older releases this resolves to the original
# 'seaborn' / 'seaborn-notebook' pair.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
plt.style.use(_seaborn_style + '-notebook')

from soma.generators.normal import MultivariateNormalGenerator
from soma.tests import som_test, knn_test, c2s_knn_test, c2s_nn_test
from soma.util.errors import stat_errors_vs_sample_size
from soma.util.plot import plot_errors, plot_time

In [6]:
# Folder under the user's home directory that receives every CSV table and
# EPS figure written by the cells below; it is created on first use.
plot_dir = os.path.expanduser('~/Plots/power_sample_size')
os.makedirs(name=plot_dir, exist_ok=True)

In [7]:
# Two-sample tests under comparison, keyed by a short label. The whole
# mapping is handed to stat_errors_vs_sample_size further down.
tests = dict((
    ('knn', knn_test),
    ('som', som_test),
    ('c2st_knn', c2s_knn_test),
    ('c2st_nn', c2s_nn_test),
))

In [8]:
# Forwarded as the `repeat=` argument of stat_errors_vs_sample_size.
# NOTE(review): presumably the number of repetitions per configuration - confirm.
repeat = 200

# Sample sizes swept by both experiments in this notebook.
samples = np.array((100, 250, 500, 1000, 2500, 5000, 10000))

# Normal (scale)

In [9]:
# Zero mean vector with 1000 components; reused by the "fair" experiment below.
means = np.zeros(1000)

# Two generators built from identical arguments: the shared mean vector and
# wishart_df set to one more than the number of components.
# NOTE(review): presumably each instance draws its own covariance, so the pair
# differs only in scale - confirm against MultivariateNormalGenerator.
ns1, ns2 = (
    MultivariateNormalGenerator(means, wishart_df=means.size + 1)
    for _ in range(2)
)

In [10]:
# Run the error-vs-sample-size helper on the ns1/ns2 pair for every entry in
# `tests` and every size in `samples`.
# NOTE(review): long-running - the saved output of this cell shows a progress
# bar at 0/28 followed by a KeyboardInterrupt, so the cells below have no
# recorded execution.
ns_results = stat_errors_vs_sample_size(
    ns1,
    ns2,
    tests,
    samples,
    repeat=repeat,
)

  0%|          | 0/28 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Keep the raw numbers of the scale experiment next to the figures.
ns_results.to_csv(
    os.path.join(plot_dir, 'ns_results_samples.csv'),
)

In [None]:
# Error plot for the scale experiment, written to the plot folder as EPS.
normal_scale_fig = plot_errors(ns_results)
normal_scale_fig.savefig(
    os.path.join(plot_dir, 'normal_scale_power_samples.eps'),
)

In [None]:
# Timing plot for the scale experiment, written to the plot folder as EPS.
normal_scale_time = plot_time(ns_results)
normal_scale_time.savefig(
    os.path.join(plot_dir, 'normal_scale_time_samples.eps'),
)

# Normal (scale, fair)

In [None]:
# Same construction as the first experiment, reusing the `means` vector
# defined earlier, but with wishart_df raised to 32 times the number of
# components.
# NOTE(review): presumably the larger wishart_df makes the two draws more
# alike (hence "fair") - confirm against MultivariateNormalGenerator.
nsf1, nsf2 = (
    MultivariateNormalGenerator(means, wishart_df=32 * means.size)
    for _ in range(2)
)

In [None]:
# Same helper, same tests and sample sizes, applied to the nsf1/nsf2 pair.
nsf_results = stat_errors_vs_sample_size(
    nsf1,
    nsf2,
    tests,
    samples,
    repeat=repeat,
)

In [None]:
# Keep the raw numbers of the "fair" scale experiment next to the figures.
nsf_results.to_csv(
    os.path.join(plot_dir, 'nsf_results_samples.csv'),
)

In [None]:
# Error plot for the "fair" scale experiment, written to the plot folder as EPS.
normal_scale_fair_fig = plot_errors(nsf_results)
normal_scale_fair_fig.savefig(
    os.path.join(plot_dir, 'normal_scale_fair_power_samples.eps'),
)

In [None]:
# Timing plot for the "fair" scale experiment, written to the plot folder as EPS.
normal_scale_fair_time = plot_time(nsf_results)
normal_scale_fair_time.savefig(
    os.path.join(plot_dir, 'normal_scale_fair_time_samples.eps'),
)