In [1]:
# `imp` was deprecated in Python 3.4 and removed in 3.12; importlib.reload is the replacement.
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import plasma
from noise_correlations import discriminability, null_models, plot, utils

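# Subsampling

For each metric below, randomly rotate the covariance and compute the metric and its p-value under a random-rotation null model in the full ambient space. Then repeat the test after randomly subsampling to fewer dimensions, to see how subsampling changes the result.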

In [16]:
# Pick up edits to the project's utils module without restarting the kernel.
reload(utils)
rng = np.random.RandomState(20190214)

# Two classes in `ambient_dim` dimensions. Their means have +/-1 entries and
# differ only in the first two coordinates. Both classes share the covariance
# built by utils.uniform_correlation_matrix(ambient_dim, var, corr).
ambient_dim = 100
var = 1.
corr = .1
mu0 = np.concatenate([np.ones(1), -np.ones(ambient_dim-1)])
mu1 = np.concatenate([-np.ones(1), np.ones(1), -np.ones(ambient_dim-2)])
# Sanity check: the mean difference must be orthogonal to the all-ones vector.
if not np.allclose(np.dot(mu1 - mu0, np.ones(ambient_dim)), 0):
    raise ValueError('mu1 - mu0 must be orthogonal to the all-ones vector.')
cov0 = utils.uniform_correlation_matrix(ambient_dim, var, corr)
cov1 = cov0.copy()

# Experiment sizes.
n_rots = 20             # random rotations of the covariance
n_trials = 100          # random subsamples per (rotation, subsample dimension)
n_sample_dims = 5       # number of subsample dimensions
n_null_samples = 10000  # sample count passed to null_models.eval_null

# Log-spaced subsample dimensions from 2 to ambient_dim // 2. Round to the
# nearest integer instead of casting with dtype=int: the cast truncates, so
# floating-point round-off such as 9.9999999 would silently become 9.
sample_dims = np.round(np.geomspace(2, ambient_dim // 2, n_sample_dims)).astype(int)

# LFI (linear Fisher information)

In [None]:
# Preallocate results: one ambient-space LFI value and p-value per rotation, and
# one subsampled value and p-value per (rotation, subsample dimension, trial).
ambient_vals = np.zeros(n_rots)
ambient_ps = np.zeros(n_rots)
sample_vals = np.zeros((n_rots, n_sample_dims, n_trials))
sample_ps = np.zeros((n_rots, n_sample_dims, n_trials))


# Build the ambient-space null distribution once and reuse it for every rotation.
# eval_null returns a 3-tuple; only the middle element is kept here. Judging by
# the `vals >= val` comparison below, it is an array of null-model LFI values.
# cov0 is passed for both classes (cov1 is an identical copy). same_null=True
# presumably applies one shared null transform to both classes -- TODO confirm.
_, vals, _ = null_models.eval_null(mu0, cov0, mu1, cov0,
                                   null_models.random_rotation, discriminability.lfi,
                                   n_null_samples, same_null=True)

# Cost: n_rots * n_sample_dims * n_trials calls to eval_null, each with
# n_null_samples null samples, so this cell is expensive to run.
for ii in range(n_rots):
    # One randomly rotated covariance, shared by both classes.
    # NOTE(review): random_rotation is not passed `rng`, so unless it is seeded
    # elsewhere these draws are not controlled by the seed above -- confirm.
    _, covp = null_models.random_rotation(mu0, cov0)
    val = discriminability.lfi(mu0, covp, mu1, covp)
    ambient_vals[ii] = val
    # Empirical p-value: fraction of null values at least as large as observed.
    ambient_ps[ii] = (vals >= val).mean()
    for jj, dim in enumerate(sample_dims):
        for kk in range(n_trials):
            # Subsample the means and the rotated covariance down to `dim`
            # dimensions (utils.subsample_cov, drawing from rng), then rerun the
            # null-model test in that subspace.
            (sample_mu0, sample_mu1), sample_covp = utils.subsample_cov([mu0, mu1], covp, dim, rng)
            # This rebinds `val`; the ambient value was already stored above.
            val, _, p = null_models.eval_null(sample_mu0, sample_covp, sample_mu1, sample_covp,
                                              null_models.random_rotation, discriminability.lfi,
                                              n_null_samples, same_null=True)
            sample_vals[ii, jj, kk] = val
            sample_ps[ii, jj, kk] = p

In [None]:
# Mean subsampled LFI vs. subsample dimension. There is one line per random
# rotation, coloured by that rotation's ambient-space p-value; the right panel
# is the colour key.
f, (ax0, ax1) = plt.subplots(1, 2)
for ii in range(n_rots):
    # sample_vals[ii] is (n_sample_dims, n_trials); average over the trials axis.
    ax0.plot(sample_dims, sample_vals[ii].mean(axis=-1), c=plasma(ambient_ps[ii]))
ax0.set_xscale('log')
ax0.set_xlabel('subsample dimension')
ax0.set_ylabel('mean LFI over trials')
# Hand-made vertical colour key. [::-1] puts p = 1 at the top of the image,
# matching the y range of `extent`.
ax1.imshow(np.linspace(0, 1, 100)[::-1, np.newaxis], cmap='plasma', extent=[0, 1, 0, 1])
ax1.set_xticks([])
ax1.set_ylabel('ambient p-value');

In [None]:
# Mean subsampled LFI p-value vs. subsample dimension. There is one line per
# random rotation, coloured by that rotation's ambient-space p-value; the right
# panel is the colour key.
# NOTE(review): on the log y-axis, any rotation whose mean p-value is exactly 0
# will not be drawn.
f, (ax0, ax1) = plt.subplots(1, 2)
for ii in range(n_rots):
    # sample_ps[ii] is (n_sample_dims, n_trials); average over the trials axis.
    ax0.plot(sample_dims, sample_ps[ii].mean(axis=-1), c=plasma(ambient_ps[ii]))
ax0.set_xscale('log')
ax0.set_yscale('log')
ax0.set_xlabel('subsample dimension')
ax0.set_ylabel('mean LFI p-value over trials')
# Hand-made vertical colour key. [::-1] puts p = 1 at the top of the image,
# matching the y range of `extent`.
ax1.imshow(np.linspace(0, 1, 100)[::-1, np.newaxis], cmap='plasma', extent=[0, 1, 0, 1])
ax1.set_xticks([])
ax1.set_ylabel('ambient p-value');

# sD_KL (symmetric KL divergence, i.e. Jeffreys divergence)

In [None]:
# Same experiment as the LFI section, but with the Jeffreys divergence, and with
# the two class covariances rotated by separate random_rotation calls.
# NOTE: this cell rebinds sample_dims, ambient_vals, ambient_ps, sample_vals and
# sample_ps, overwriting the LFI results. Re-run the LFI cells before
# re-plotting LFI.
# Subsample dimensions are log-spaced from 2 to ambient_dim - 1. Round to the
# nearest integer instead of casting with dtype=int: the cast truncates, so
# floating-point round-off such as 9.9999999 would silently become 9.
sample_dims = np.round(np.geomspace(2, ambient_dim - 1, n_sample_dims)).astype(int)
ambient_vals = np.zeros(n_rots)
ambient_ps = np.zeros(n_rots)
sample_vals = np.zeros((n_rots, n_sample_dims, n_trials))
sample_ps = np.zeros((n_rots, n_sample_dims, n_trials))

# Ambient-space null distribution, built once and reused for every rotation.
# Only the middle element of eval_null's 3-tuple is kept; it is compared
# elementwise against the observed value below. same_null is left at its
# default here, unlike the LFI section.
_, vals, _ = null_models.eval_null(mu0, cov0, mu1, cov1,
                                   null_models.random_rotation, discriminability.mv_normal_jeffreys,
                                   n_null_samples)

for ii in range(n_rots):
    # Rotate each class covariance with its own random_rotation call.
    # NOTE(review): random_rotation is not passed `rng` -- confirm how it is seeded.
    _, cov0p = null_models.random_rotation(mu0, cov0)
    _, cov1p = null_models.random_rotation(mu1, cov1)
    val = discriminability.mv_normal_jeffreys(mu0, cov0p, mu1, cov1p)
    ambient_vals[ii] = val
    # Empirical p-value: fraction of null values at least as large as observed.
    ambient_ps[ii] = (vals >= val).mean()
    for jj, dim in enumerate(sample_dims):
        for kk in range(n_trials):
            # Subsample both means and both rotated covariances down to `dim`
            # dimensions (drawing from rng), then rerun the null-model test there.
            (sample_mu0, sample_mu1), [sample_cov0p, sample_cov1p] = utils.subsample_cov([mu0, mu1], [cov0p, cov1p], dim, rng)
            # This rebinds `val`; the ambient value was already stored above.
            val, _, p = null_models.eval_null(sample_mu0, sample_cov0p, sample_mu1, sample_cov1p,
                                              null_models.random_rotation, discriminability.mv_normal_jeffreys,
                                              n_null_samples)
            sample_vals[ii, jj, kk] = val
            sample_ps[ii, jj, kk] = p

In [None]:
# Mean subsampled Jeffreys divergence (sD_KL) vs. subsample dimension. There is
# one line per random rotation, coloured by that rotation's ambient-space
# p-value; the right panel is the colour key.
f, (ax0, ax1) = plt.subplots(1, 2)
for ii in range(n_rots):
    # sample_vals[ii] is (n_sample_dims, n_trials); average over the trials axis.
    ax0.plot(sample_dims, sample_vals[ii].mean(axis=-1), c=plasma(ambient_ps[ii]))
ax0.set_xscale('log')
ax0.set_xlabel('subsample dimension')
ax0.set_ylabel('mean Jeffreys divergence over trials')
# Hand-made vertical colour key. [::-1] puts p = 1 at the top of the image,
# matching the y range of `extent`.
ax1.imshow(np.linspace(0, 1, 100)[::-1, np.newaxis], cmap='plasma', extent=[0, 1, 0, 1])
ax1.set_xticks([])
ax1.set_ylabel('ambient p-value');

In [None]:
# Mean subsampled Jeffreys-divergence p-value vs. subsample dimension. There is
# one line per random rotation, coloured by that rotation's ambient-space
# p-value; the right panel is the colour key.
# NOTE(review): on the log y-axis, any rotation whose mean p-value is exactly 0
# will not be drawn.
f, (ax0, ax1) = plt.subplots(1, 2)
for ii in range(n_rots):
    # sample_ps[ii] is (n_sample_dims, n_trials); average over the trials axis.
    ax0.plot(sample_dims, sample_ps[ii].mean(axis=-1), c=plasma(ambient_ps[ii]))
ax0.set_xscale('log')
ax0.set_yscale('log')
ax0.set_xlabel('subsample dimension')
ax0.set_ylabel('mean Jeffreys p-value over trials')
# Hand-made vertical colour key. [::-1] puts p = 1 at the top of the image,
# matching the y range of `extent`.
ax1.imshow(np.linspace(0, 1, 100)[::-1, np.newaxis], cmap='plasma', extent=[0, 1, 0, 1])
ax1.set_xticks([])
ax1.set_ylabel('ambient p-value');