In [None]:
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import neuropacks as packs
import numpy as np
import os

from noise_correlations import analysis, utils

%matplotlib inline

In [None]:
# Typeset all figure text with LaTeX in a serif face; the labels below rely on
# LaTeX markup such as \textbf{...}.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'

In [None]:
# Directory holding the fitted-result HDF5 files. Defaults to the original
# storage location but can be redirected without editing the notebook by
# setting the NEUROCORR_FITS_PATH environment variable.
base_path = os.environ.get("NEUROCORR_FITS_PATH", "/storage/fits/neurocorr")

In [None]:
# Full paths to the four result files (three PVC11 fits and one RET2 fit),
# all located directly under base_path.
pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path = (
    os.path.join(base_path, fname)
    for fname in (
        'exp4_1_tuned_values_pvc11_65_1000_1000.h5',
        'exp4_2_tuned_values_pvc11_50_1000_1000.h5',
        'exp4_3_tuned_values_pvc11_75_1000_1000.h5',
        'exp4_tuned_values_ret2_50_1000_1000.h5',
    )
)

In [None]:
# Open every result file read-only. The handles are kept open because the
# cells below index into the datasets directly. Order matters: later cells
# address the datasets by position (0-2 are PVC11, 3 is RET2).
results = [
    h5py.File(path, 'r')
    for path in (pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path)
]
pvc11_1, pvc11_2, pvc11_3, ret2 = results

In [None]:
# For each dataset, compute the fraction of 'v_r_lfi' entries (along its last
# axis) that lie strictly below the matching 'v_lfi' entry. These fractions
# are what the figures below plot as "Percentile".
# NOTE(review): presumably 'v_r_lfi' holds values from a null/reference
# distribution for each 'v_lfi' entry -- confirm against the fitting code.
p_r_vals = []
for result in results:
    observed = result['v_lfi'][:]
    reference = result['v_r_lfi'][:]
    # A trailing axis on `observed` broadcasts it against every reference draw.
    p_r_vals.append(np.mean(observed[..., np.newaxis] > reference, axis=-1))

In [None]:
# Per-dataset Fano factors (variance / mean, averaged over the units of each
# dimlet) for every (dimension index, dimlet-stimulus pairing):
#   opt_ffs -- variances taken from the diagonal of the stored
#              'opt_covariances' matrix
#   obs_ffs -- variances taken from the diagonal of the average of the two
#              per-stimulus sample covariances
# Both are normalised by the mean response averaged over the two stimuli.
opt_ffs = []
obs_ffs = []

for result in results:
    print(result)
    X = result['X'][:]
    stimuli = result['stimuli'][:]
    # Pull the index tables into memory once. Indexing the HDF5 datasets
    # inside the nested loop below would cost one disk read per pairing.
    units = result['units'][:]
    stims = result['stims'][:]
    n_dims, n_dimlet_stims = stims.shape[:2]
    opt_ff = np.zeros((n_dims, n_dimlet_stims))
    obs_ff = np.zeros_like(opt_ff)
    for dim_idx in range(n_dims):
        # All covariance matrices for this dimension index, read in one go.
        opt_covs = result['opt_covariances'][f'{dim_idx}'][:]
        for pairing in range(n_dimlet_stims):
            # A dimlet at index dim_idx is made of dim_idx + 2 units.
            dimlet = units[dim_idx, pairing, :dim_idx + 2].astype('int')
            stim1, stim2 = stims[dim_idx, pairing]
            # Get means: rows of X for each stimulus, restricted to the
            # dimlet's columns.
            X1 = X[stim1 == stimuli][:, dimlet]
            mu1 = np.mean(X1, axis=0)
            X2 = X[stim2 == stimuli][:, dimlet]
            mu2 = np.mean(X2, axis=0)
            mu_mean = np.mean(np.vstack((mu1, mu2)), axis=0)
            # Get variances
            var_opt = np.diag(opt_covs[pairing])
            opt_ff[dim_idx, pairing] = np.mean(var_opt / mu_mean)
            var_obs = np.diag(0.5 * (np.cov(X1.T) + np.cov(X2.T)))
            obs_ff[dim_idx, pairing] = np.mean(var_obs / mu_mean)
    opt_ffs.append(opt_ff)
    obs_ffs.append(obs_ff)

In [None]:
# Largest optimal Fano factor in the fourth dataset (RET2).
np.max(opt_ffs[3])

In [None]:
def plot_ff_vs_percentile(ffs, p_vals, pvc11_xmax, ret2_xmax, dim_idxs=(0, 2, 4)):
    """Grid of log-scaled hexbin densities of percentile against Fano factor.

    One row is drawn per entry of ``dim_idxs`` and one column per dataset
    (three PVC11 recordings followed by RET2).

    Parameters
    ----------
    ffs : sequence of ndarray
        Per-dataset Fano factors, indexed as ``ffs[dataset][dim_idx]``.
    p_vals : sequence of ndarray
        Per-dataset percentiles, indexed the same way as ``ffs``.
    pvc11_xmax : float
        Upper x bound (hexbin extent and axis limit) for the PVC11 columns.
    ret2_xmax : float
        Upper x bound (hexbin extent and axis limit) for the RET2 column.
    dim_idxs : sequence of int
        Dimension indices to plot, one per row.

    Returns
    -------
    fig, axes
        The created figure and its 2-D array of axes.
    """
    titles = [
        r'\textbf{PVC11, 1}',
        r'\textbf{PVC11, 2}',
        r'\textbf{PVC11, 3}',
        r'\textbf{RET2}',
    ]
    x_uppers = [pvc11_xmax, pvc11_xmax, pvc11_xmax, ret2_xmax]
    # The PVC11 panels start slightly left of zero; the RET2 panel starts at 0.
    x_lowers = [-0.05, -0.05, -0.05, 0]

    fig, axes = plt.subplots(
        len(dim_idxs), len(titles), figsize=(18, 10), sharey=True)

    for row, d_idx in enumerate(dim_idxs):
        for col, ax in enumerate(axes[row]):
            ax.hexbin(
                x=ffs[col][d_idx],
                y=p_vals[col][d_idx],
                bins='log',
                gridsize=50,
                cmap='Greys',
                extent=(0, x_uppers[col], 0, 1))
            ax.set_xlim([x_lowers[col], x_uppers[col]])
            ax.set_ylim([0, 1.1])
        # The dimlet at index d_idx contains d_idx + 2 units (see the Fano
        # factor computation), so derive the row label from the index rather
        # than hard-coding it. The hard-coded labels previously read
        # D=2, D=4, D=8 for indices 0, 2, 4; index 4 is D=6.
        axes[row, 0].set_ylabel(rf'$D={d_idx + 2}$', fontsize=25)

    for ax, title in zip(axes[0], titles):
        ax.set_title(title, fontsize=22)

    # Shared axis captions, positioned relative to one anchor panel each.
    fig.text(
        x=0.20, y=-0.2,
        s=r'\textbf{Variance / Mean}',
        va='center',
        ha='center',
        transform=axes[-1, 2].transAxes,
        fontsize=25)
    fig.text(
        x=-0.3, y=0.5,
        s=r'\textbf{Percentile}',
        va='center',
        ha='center',
        transform=axes[len(dim_idxs) // 2, 0].transAxes,
        rotation=90,
        fontsize=25)
    return fig, axes


# Percentile against the Fano factor from the optimal covariances.
fig, axes = plot_ff_vs_percentile(opt_ffs, p_r_vals, pvc11_xmax=100, ret2_xmax=0.30)
plt.show()

# Percentile against the Fano factor from the observed covariances.
fig, axes = plot_ff_vs_percentile(obs_ffs, p_r_vals, pvc11_xmax=10, ret2_xmax=0.10)
plt.show()

In [None]:
# Colour for every point of the first dataset at dimension index 0: scale the
# percentile to an integer in [0, 255] and use it as an index into the
# 'plasma' colormap.
plasma = plt.get_cmap('plasma')
lut_idx = (255 * p_r_vals[0][0]).astype('int')
cs = plasma(lut_idx)

In [None]:
# Optimal vs. observed Fano factor for the first dataset at dimension index 0,
# coloured by percentile, with the identity line for reference.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.scatter(
    obs_ffs[0][0],
    opt_ffs[0][0],
    alpha=0.10,
    c=cs)
# Capture the autoscaled x-range once and pin BOTH axes to it. Drawing the
# line first and leaving x on autoscale would let x grow by a margin while y
# stays fixed, so the axes would no longer match and the identity line would
# stop short of the corners.
lims = ax.get_xlim()
ax.plot(lims, lims)
ax.set_xlim(lims)
ax.set_ylim(lims)
ax.set_xlabel(r'\textbf{Observed Fano Factor}', fontsize=18)
ax.set_ylabel(r'\textbf{Optimal Fano Factor}', fontsize=18)
plt.show()

In [None]:
from mpl_lego.colorbar import append_cax_to_ax
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

In [None]:
# Mean percentile per hexagon in the (observed, optimal) Fano factor plane for
# the first dataset at dimension index 0.
ff_max = 20
# One normalisation shared by the hexagons and the colorbar. Without it,
# hexbin autoscales its colours to the range of the per-bin means while the
# colorbar is drawn for [0, 1], so the bar would not describe the plot.
norm = Normalize(vmin=0, vmax=1)

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
hb = ax.hexbin(
    obs_ffs[0][0],
    opt_ffs[0][0],
    C=p_r_vals[0][0],
    gridsize=200,
    cmap='plasma',
    norm=norm,
    extent=(0, ff_max, 0, ff_max))
# Identity line drawn from explicit bounds so it is y = x regardless of how
# the view limits were autoscaled.
ax.plot([0, ff_max], [0, ff_max], color='k')
ax.set_xlim(0, ff_max)
ax.set_ylim(0, ff_max)
cax = append_cax_to_ax(ax, width=0.03)
# Build the colorbar from the hexbin collection itself so it always reflects
# the colours actually drawn.
fig.colorbar(hb, cax=cax)
cax.set_ylabel(r'\textbf{Average Percentile}', fontsize=15, rotation=270, labelpad=20)
ax.set_xlabel(r'\textbf{Observed Fano Factor}', fontsize=18)
ax.set_ylabel(r'\textbf{Optimal Fano Factor}', fontsize=18)
plt.savefig('woof.pdf', bbox_inches='tight')