In [None]:
import h5py
import matplotlib.pyplot as plt
import neuropacks as packs
import numpy as np
import os

from pratplot.ellipse import plot_cov_ellipse
from noise_correlations import analysis, utils 

%matplotlib inline

In [None]:
# Render all figure text through LaTeX in a serif face. The font list must be
# stored under the key that matches "font.family" ("font.serif" for a serif
# family); a list stored under "font.sans-serif" is not consulted here.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"]})

In [None]:
# Directory holding the fitted-result HDF5 files. Set the NEUROCORR_FITS_PATH
# environment variable to point the notebook at another location; when it is
# unset the default storage path is used.
base_path = os.environ.get('NEUROCORR_FITS_PATH', '/storage/fits/neurocorr')

In [None]:
# Full paths of the five result files, all located directly under base_path.
(pvc11_1_path,
 pvc11_2_path,
 pvc11_3_path,
 ac1_path,
 ret2_path) = (
    os.path.join(base_path, filename)
    for filename in (
        'exp4_1_resp_values_pvc11_75_1000_1000.h5',
        'exp4_2_resp_values_pvc11_50_1000_1000.h5',
        'exp4_3_resp_values_pvc11_75_1000_1000.h5',
        'exp4_resp_values_ac1_50_1000_1000.h5',
        'exp4_tuned_values_ret2_50_1000_1000.h5',
    )
)

In [None]:
# Bold LaTeX legend labels, one per V1 dataset.
v1_labels = [r'\textbf{Dataset %d}' % number for number in (1, 2, 3)]

In [None]:
# Open every result file read-only. The handles are left open on purpose:
# later cells index datasets directly from them.
pvc11_1, pvc11_2, pvc11_3, ac1, ret2 = (
    h5py.File(path, 'r')
    for path in (pvc11_1_path, pvc11_2_path, pvc11_3_path, ac1_path, ret2_path)
)

# Plotting order: the three V1 datasets, then retina, then auditory cortex.
results = [pvc11_1, pvc11_2, pvc11_3, ret2, ac1]

In [None]:
# For each result file: the fraction of entries along the last axis of
# 'v_s_lfi' that the matching 'v_lfi' value exceeds (plotted below as the
# shuffle percentile).
p_s_vals = [
    (np.expand_dims(fit['v_lfi'][:], axis=-1) > fit['v_s_lfi'][:]).mean(axis=-1)
    for fit in results
]

In [None]:
# For each result file: the fraction of entries along the last axis of
# 'v_r_lfi' that the matching 'v_lfi' value exceeds (plotted below as the
# rotation percentile).
p_r_vals = [
    (np.expand_dims(fit['v_lfi'][:], axis=-1) > fit['v_r_lfi'][:]).mean(axis=-1)
    for fit in results
]

In [None]:
def _plot_lfi(ax, result, color, center=np.mean, band=(25, 75), label=None):
    """Plot LFI against dimlet dimension with a shaded percentile band.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    result : mapping
        Open results file. Its ``'v_lfi'`` dataset is plotted, and the size
        of the third axis of its ``'units'`` dataset sets the largest
        dimension on the x-axis.
    color : str
        Matplotlib color used for both the curve and the band.
    center : callable
        Reduction applied along axis 1 of ``v_lfi`` to obtain the curve
        (``np.mean`` or ``np.median``).
    band : tuple of (low, high)
        Percentiles, taken along axis 1 of ``v_lfi``, that bound the band.
    label : str, optional
        Legend label for the curve; no label is passed when ``None``.
    """
    v_lfi = result['v_lfi'][:]
    n_max_units = result['units'].shape[2]
    # Dimensions run 2, 3, ..., n_max_units.
    dims = 2 + np.arange(n_max_units - 1)
    q_low, q_high = band

    ax.fill_between(
        x=dims,
        y1=np.percentile(v_lfi, q=q_low, axis=1),
        y2=np.percentile(v_lfi, q=q_high, axis=1),
        color=color,
        alpha=0.1
    )
    line_kwargs = {} if label is None else {'label': label}
    # Reduce the in-memory array; indexing result['v_lfi'] here would read
    # the dataset from the file a second time.
    ax.plot(dims, center(v_lfi, axis=1), linewidth=4, color=color, **line_kwargs)


# One column per brain region. sharex='col' shares the x-axis only within a
# column, so each column can have its own dimension range; a figure-wide
# shared x-axis would force a single range onto all nine panels.
fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex='col')
pvc11ax, retax, ac1ax = axes[0]

# V1: mean with interquartile band, one colour per dataset
for idx, result in enumerate([pvc11_1, pvc11_2, pvc11_3]):
    _plot_lfi(pvc11ax, result, color=f'C{idx}', label=v1_labels[idx])

pvc11ax.legend(
    loc=2,
    prop={'size': 10}
)
pvc11ax.set_ylim([0, 0.1])
pvc11ax.set_yticks([0, 0.05, 0.1])

# Retina: mean with interquartile band
_plot_lfi(retax, ret2, color='C0')
retax.set_ylim([0, 0.015])
retax.set_yticks([0, 0.005, 0.01, 0.015])

# Auditory Cortex: median with a 40th-60th percentile band; y-limits are
# left to matplotlib's autoscaling.
# NOTE(review): this panel uses a different summary statistic and band than
# the V1 and retina panels - confirm that is intended.
_plot_lfi(ac1ax, ac1, color='C0', center=np.median, band=(40, 60))

for ax in axes[0]:
    ax.tick_params(labelsize=15)
    ax.set_ylabel(r'\textbf{LFI}', fontsize=15)
    
# Second/third row: percentiles
# (panel column, colour-cycle index) for each position in `results`: the
# three V1 datasets share column 0 with distinct colours, while retina and
# auditory cortex occupy columns 1 and 2 in the first colour.
panel_for_result = {0: (0, 0), 1: (0, 1), 2: (0, 2), 3: (1, 0), 4: (2, 0)}

for idx, (result, p_s_val, p_r_val) in enumerate(zip(results, p_s_vals, p_r_vals)):
    column, colour_idx = panel_for_result[idx]
    colour = f'C{colour_idx}'
    # Dimensions run from 2 up to the size of the third axis of 'units'.
    dims = np.arange(2, result['units'].shape[2] + 1)

    # Shuffle percentiles go on the middle row, rotation percentiles on the
    # bottom row; both are drawn as a median curve with a 40-60 band.
    for panel, percentiles in ((axes[1, column], p_s_val),
                               (axes[2, column], p_r_val)):
        panel.plot(
            dims, np.median(percentiles, axis=1),
            linewidth=3,
            color=colour,
        )
        panel.fill_between(
            x=dims,
            y1=np.percentile(percentiles, q=40, axis=1),
            y2=np.percentile(percentiles, q=60, axis=1),
            color=colour,
            alpha=0.1
        )
    
# Percentiles lie in [0, 1]; pad both ends equally so curves sitting at 0 or
# 1 are not clipped by the frame, and keep both percentile rows on the same
# scale.
percentile_ylim = [-0.05, 1.05]
percentile_yticks = [0, 0.25, 0.5, 0.75, 1]

for ax in axes[1]:
    ax.set_ylim(percentile_ylim)
    ax.set_yticks(percentile_yticks)
    ax.tick_params(labelsize=15)
    ax.set_ylabel(r'\textbf{Shuffle Percentile}', fontsize=15)

for ax in axes[2]:
    ax.set_ylim(percentile_ylim)
    ax.set_yticks(percentile_yticks)
    ax.tick_params(labelsize=15)
    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
    ax.set_ylabel(r'\textbf{Rotation Percentile}', fontsize=15)

# x-limits and ticks, set column by column (left to right).
column_xticks = [[25, 50, 75], [25, 50], [25, 50]]
for column, xticks in enumerate(column_xticks):
    for ax in axes[:, column]:
        ax.set_xlim([2, xticks[-1]])
        ax.set_xticks(xticks)

plt.tight_layout()

plt.savefig('figure3.pdf', bbox_inches='tight')