In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego.style as style
import numpy as np
import os

from noise_correlations import discriminability, utils
from mpl_lego.colorbar import append_cax_to_ax
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style; the axis labels below are raw LaTeX
# strings (e.g. r'\textbf{...}'), which presumably rely on this style enabling
# LaTeX text rendering.
style.use_latex_style()

In [None]:
# Root directories: raw data (rotations file) and the exp07 fit outputs.
data_path, fits_path = (
    '/storage/data/neurocorr',
    '/storage/fits/neurocorr/exp07',
)

In [None]:
# Fit results for the pvc11 session, and the file of rotation matrices.
pvc11_1_path, rotations_path = (
    os.path.join(fits_path, 'exp07_1_pvc11_20_1000_1000.h5'),
    os.path.join(data_path, 'rotations.h5'),
)

In [None]:
# Open both HDF5 files read-only. The handles are kept open because later cells
# index datasets straight from them (e.g. pvc11_1['p_r_lfi'][...]).
pvc11_1, rotations = (h5py.File(path, 'r') for path in (pvc11_1_path, rotations_path))

In [None]:
# Load the arrays that many later cells reuse into memory.
X, stimuli, stims, units, v_lfi = (
    pvc11_1[key][:] for key in ('X', 'stimuli', 'stims', 'units', 'v_lfi')
)

# Axes of v_r_lfi: (dimlet sizes, dimlet/stimulus pairings, repeats).
n_dims, n_dimlet_stims, n_repeats = pvc11_1['v_r_lfi'].shape
# Length of the unit axis of `units` (sliced as units[dim, pairing, :size] below).
n_max_units = units.shape[2]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# Dimlet sizes 2..n_max_units; plotted against the first axis of v_lfi.
dims = np.arange(2, n_max_units + 1)
lfi_q25, lfi_q75 = np.percentile(v_lfi, q=[25, 75], axis=1)
lfi_median = np.median(v_lfi, axis=1)

# Interquartile band across dimlets with the median drawn on top.
ax.fill_between(x=dims, y1=lfi_q25, y2=lfi_q75, color='C0', alpha=0.1)
ax.plot(dims, lfi_median, linewidth=4, color='C0')

ax.set_xlim([2, 20])
ax.set_ylim([0, 0.03])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{LFI}', fontsize=15)

In [None]:
# Largest value in v_r_lfi over every dim, pairing and repeat.
pvc11_1['v_r_lfi'][:].max()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# Dimlet sizes 2..n_max_units, matching the first axis of the arrays below.
dims = 2 + np.arange(n_max_units - 1)
# Percentile values in [0, 1] (ylim below); the last axis is summarised
# across dimlets.
p_r_lfi = pvc11_1['p_r_lfi'][:]
p_s_lfi = pvc11_1['p_s_lfi'][:]

# (values, (lower, upper) band percentiles, colour, legend label).
# Labels escape underscores because the LaTeX style renders label text.
curves = [
    (p_r_lfi, (25, 75), 'red', r'\texttt{p\_r\_lfi}'),
    # NOTE(review): the gray band spans 40-60, narrower than the 25-75 band
    # used for p_r_lfi -- confirm this asymmetry is intentional.
    (p_s_lfi, (40, 60), 'gray', r'\texttt{p\_s\_lfi}'),
]
for values, (q_lo, q_hi), color, label in curves:
    ax.fill_between(
        x=dims,
        y1=np.percentile(values, q=q_lo, axis=-1),
        y2=np.percentile(values, q=q_hi, axis=-1),
        color=color,
        alpha=0.1)
    ax.plot(
        dims,
        np.median(values, axis=-1),
        linewidth=4,
        color=color,
        label=label)

ax.set_xlim([2, 20])
ax.set_ylim([0, 1])
ax.tick_params(labelsize=15)
ax.legend(prop={'size': 12})
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
# These are percentiles, not raw LFI values -- labelled the same way as the
# p_r_lfi scatter plots later in the notebook.
ax.set_ylabel(r'\textbf{LFI Percentile}', fontsize=15)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

rng = np.random.default_rng(289381)
# Dimlet-size indices to show (dimlet size = index + 2).
dim_idxs = [0, 1, 3, 8, 18]
n_sampled_pairings = 1000
# Sample pairings (with replacement) over the file's real pairing count rather
# than a hard-coded 12000, so indices can never run past the data.
random_pairing = rng.integers(low=0, high=n_dimlet_stims, size=n_sampled_pairings)

for idx, dim_idx in enumerate(dim_idxs):
    dim = dim_idx + 2
    opt_covs = pvc11_1[f'opt_covs/{dim}']
    # One row per sampled pairing, one column per repeat of v_r_lfi.
    opt_lfis_scaled = np.zeros((n_sampled_pairings, n_repeats))
    for jj, pairing_idx in enumerate(random_pairing):
        this_stim = stims[dim_idx, pairing_idx]
        this_unit = units[dim_idx, pairing_idx, :dim]
        X1, X2 = utils.get_dimstim_responses(X, stimuli, this_unit, this_stim)
        mu1, _ = utils.mean_cov(X1)
        mu2, _ = utils.mean_cov(X2)
        # LFI with the stored optimal covariance (used for both stimuli) is the
        # normaliser: every repeat of v_r_lfi is divided by it.
        opt_cov = opt_covs[pairing_idx]
        opt_lfi = discriminability.lfi(mu1, opt_cov, mu2, opt_cov, dtheta=30.)
        opt_lfis_scaled[jj] = pvc11_1['v_r_lfi'][dim_idx, pairing_idx, :] / opt_lfi
    ax.hist(
        opt_lfis_scaled.ravel(),
        color=f'C{idx}',
        histtype='step',
        density=True,
        lw=3,
        label=f'$d={dim}$',
        bins=np.linspace(0, 1, 30))
ax.legend(
    loc='center left',
    bbox_to_anchor=(1.0, 0.5),
    prop={'size': 14})
ax.tick_params(labelsize=12)
ax.set_xlim([-0.05, 1.05])
ax.set_xlabel(r'\textbf{Normalized LFI}', fontsize=18)
ax.set_ylabel(r'\textbf{Density}', fontsize=18)

In [None]:
# Unit-averaged Fano factors (variance / mean response) for every dim and
# pairing: under the stored optimal covariance, and under the observed
# covariance averaged over the two stimuli.
# NOTE: a later cell rebinds `obs_ff` to a 1-D array; re-run this cell before
# re-running the hexbin plot that indexes obs_ff[dim_idx_to_plot].
opt_ff = np.zeros((n_dims, n_dimlet_stims))
obs_ff = np.zeros_like(opt_ff)

for dim_idx in range(n_dims):
    # Read this dim's optimal covariances once, not once per pairing.
    opt_covs = pvc11_1[f'opt_covs/{dim_idx + 2}'][:]
    for pairing in range(n_dimlet_stims):
        this_stim = stims[dim_idx, pairing]
        this_unit = units[dim_idx, pairing, :dim_idx+2]
        # Get means
        X1, X2 = utils.get_dimstim_responses(X, stimuli, this_unit, this_stim)
        mu1, cov1 = utils.mean_cov(X1)
        mu2, cov2 = utils.mean_cov(X2)
        mu_mean = np.mean(np.vstack((mu1, mu2)), axis=0)
        # Get variances
        var_opt = np.diag(opt_covs[pairing])
        opt_ff[dim_idx, pairing] = np.mean(var_opt / mu_mean)
        var_obs = np.diag(0.5 * (cov1 + cov2))
        obs_ff[dim_idx, pairing] = np.mean(var_obs / mu_mean)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
dim_idx_to_plot = 3
# Fano-factor range shown on both axes.
ff_lims = (0.5, 20)
# Pin the colour scale to [0, 1] and build the colorbar from the hexbin itself,
# so the bar describes the hexes. A standalone Normalize(0, 1) would disagree
# with hexbin's default autoscaling to the data range.
hb = ax.hexbin(
    obs_ff[dim_idx_to_plot],
    opt_ff[dim_idx_to_plot],
    C=pvc11_1['p_r_lfi'][dim_idx_to_plot],
    gridsize=200,
    cmap='plasma',
    vmin=0,
    vmax=1,
    extent=ff_lims + ff_lims)
# Identity line (optimal FF == observed FF), independent of autoscaled limits.
ax.plot(ff_lims, ff_lims, color='k')
cax = append_cax_to_ax(ax, width=0.03)
fig.colorbar(hb, cax=cax)
cax.set_ylabel(r'\textbf{Average Percentile}', fontsize=15, rotation=270, labelpad=20)
ax.set_xlabel(r'\textbf{Observed Fano Factor}', fontsize=18)
ax.set_ylabel(r'\textbf{Optimal Fano Factor}', fontsize=18)

In [None]:
# Unit-averaged Fano factors for one dimlet size: under the observed
# (stimulus-averaged) covariance, and under each of its n_repeats rotations
# R @ cov @ R.T.
# NOTE: this rebinds `obs_ff` (2-D in an earlier cell) to a 1-D array for dim_idx.
dim_idx = 3
rotation_group = rotations[f'{dim_idx+2}']
rot_ff = np.zeros((n_dimlet_stims, n_repeats))
obs_ff = np.zeros(n_dimlet_stims)
# Rotation indices for every (pairing, repeat) of this dim, read in one go
# instead of once per pairing.
R_idxs = pvc11_1['R_idxs'][dim_idx, :, :, 0]

for pairing in range(n_dimlet_stims):
    this_stim = stims[dim_idx, pairing]
    this_unit = units[dim_idx, pairing, :dim_idx+2]
    X1, X2 = utils.get_dimstim_responses(X, stimuli, this_unit, this_stim)
    mu1, cov1 = utils.mean_cov(X1)
    mu2, cov2 = utils.mean_cov(X2)
    # Get means
    mu_mean = np.mean(np.vstack((mu1, mu2)), axis=0)
    # Get variances
    avg_cov = 0.5 * (cov1 + cov2)
    var_obs = np.diag(avg_cov)
    obs_ff[pairing] = np.mean(var_obs / mu_mean)
    # Fetch all rotations for this pairing in a single HDF5 read. h5py fancy
    # indexing requires increasing, unique indices, so read the sorted unique
    # set and expand it back with the inverse map (keeps repeated indices).
    unique_idx, inverse = np.unique(R_idxs[pairing], return_inverse=True)
    Rs = rotation_group[unique_idx][inverse]
    # Diagonal of R @ avg_cov @ R.T for every rotation at once.
    R_covs = Rs @ avg_cov @ np.swapaxes(Rs, -1, -2)
    var_rot = np.diagonal(R_covs, axis1=-2, axis2=-1)
    rot_ff[pairing] = np.mean(var_rot / mu_mean, axis=-1)

In [None]:
# Pairing indices of the 20 largest p_r_lfi values for dim_idx (ascending order).
pvc11_1['p_r_lfi'][dim_idx].argsort()[-20:]

In [None]:
# Pairing indices of the 20 smallest p_r_lfi values for dim_idx (ascending order).
pvc11_1['p_r_lfi'][dim_idx].argsort()[:20]

In [None]:
# The 20 largest p_r_lfi values for dim_idx, in ascending order.
p_r_lfi_dim = pvc11_1['p_r_lfi'][dim_idx]
p_r_lfi_dim[p_r_lfi_dim.argsort()[-20:]]

In [None]:
# Rotated-FF distribution for a fixed example pairing, with its observed FF marked.
example_pairing = 7823
fig, ax = plt.subplots()
ax.hist(rot_ff[example_pairing])
ax.axvline(obs_ff[example_pairing])
ax.set_xlabel(r'\textbf{Fano Factor}', fontsize=18)

In [None]:
# Fraction of rotations whose FF is strictly below the pairing's observed FF.
ff_percentile = np.count_nonzero(obs_ff[..., np.newaxis] > rot_ff, axis=1) / n_repeats
# Median (not mean) over rotations of observed FF / rotated FF.
ff_median = np.median(obs_ff[..., np.newaxis] / rot_ff, axis=1)
# Backward-compatible alias: later cells plot `ff_mean`, which has always held
# this median.
ff_mean = ff_median

In [None]:
# 5th percentile of the rotated FFs for each pairing.
rot_ff_p05 = np.percentile(rot_ff, 5, axis=1)
rot_ff_p05

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

# One point per pairing of dim_idx: the FF percentile among rotations against
# the p_r_lfi percentile.
lfi_percentile = pvc11_1['p_r_lfi'][dim_idx]
ax.scatter(ff_percentile, lfi_percentile, alpha=0.25)

ax.set_xlabel(r'\textbf{Fano Factor Percentile}', fontsize=18)
ax.set_ylabel(r'\textbf{LFI Percentile}', fontsize=18)
ax.set_xlim([-0.05, 1])
ax.set_ylim([-0.05, 1])
ax.set_aspect('equal')

In [None]:
# `figsize=)` was a SyntaxError; also draw on the explicit axes and label them.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# One point per pairing of dim_idx: median observed/rotated FF ratio (stored in
# ff_mean) against the p_r_lfi percentile.
ax.scatter(
    ff_mean,
    pvc11_1['p_r_lfi'][dim_idx],
    alpha=0.1)
ax.set_xlabel(r'\textbf{Median Observed / Rotated Fano Factor}', fontsize=18)
ax.set_ylabel(r'\textbf{LFI Percentile}', fontsize=18)

In [None]:
# Median FF over each pairing's rotations, against its p_r_lfi percentile.
median_rot_ff = np.percentile(rot_ff, q=50, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.scatter(median_rot_ff, pvc11_1['p_r_lfi'][dim_idx], alpha=0.1)
ax.set_xlabel(r'\textbf{Median FF across rotations}', fontsize=18)
ax.set_ylabel(r'\textbf{LFI percentile}', fontsize=18)