In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks as packs
import numpy as np
import os

from mpl_lego.ellipse import plot_cov_ellipse
from noise_correlations import analysis, utils 

In [None]:
# Apply mpl_lego's LaTeX style for all figures in this notebook. The titles and
# labels below use \textbf{...}, which presumably relies on this style enabling
# LaTeX text rendering — TODO confirm against mpl_lego.
mplego.style.use_latex_style()

In [None]:
# Directory holding the exp09 fit files. Defaults to the original cluster
# location; set the NEUROCORR_EXP09_DIR environment variable to point the
# notebook at a local copy without editing the code.
base_path = os.environ.get('NEUROCORR_EXP09_DIR', '/storage/fits/neurocorr/exp09')

In [None]:
# Full paths to the five HDF5 fit files (three V1 sessions, retina, ECoG),
# all located directly under `base_path`.
pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path, ecog_path = (
    os.path.join(base_path, file_name)
    for file_name in (
        'exp09_1_pvc11_15_1000_1000.h5',
        'exp09_2_pvc11_15_1000_1000.h5',
        'exp09_3_pvc11_15_1000_1000.h5',
        'exp09_ret2_15_1000_1000.h5',
        'exp09_ecog_15_3000_1000.h5',
    )
)

In [None]:
# Bold LaTeX panel titles, in the same order as the datasets in `results`
# (they are zipped with the axes when the figure is drawn).
titles = [
    r'\textbf{V1 (1)}',
    r'\textbf{V1 (2)}',
    r'\textbf{V1 (3)}',
    r'\textbf{Retina}',
    r'\textbf{A1}',
]

In [None]:
# Open each fit file read-only. The handles are intentionally left open because
# the plotting cell indexes datasets out of them later. The order of `results`
# determines which panel each dataset is drawn in and must match `titles`.
results = [
    h5py.File(path, 'r')
    for path in (pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path, ecog_path)
]
pvc11_1, pvc11_2, pvc11_3, ret2, ecog = results
# Dimlet dimensions on the x-axis: 3 up to and including the size of the third
# axis of the first V1 file's `units` dataset.
n_max_units = pvc11_1['units'].shape[2]
dims = np.arange(3, n_max_units + 1)

In [None]:
# Figure 3: median LFI with a percentile band versus dimlet dimension for each
# dataset. Observed data are compared with the shuffle, rotation and factor
# analysis (FA) null models. The sixth panel is left empty to host the legend.
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes_list = axes.ravel()
plt.subplots_adjust(wspace=0.3, hspace=0.5)
# Percentile bounds for the shaded band drawn around each median curve
percentile_lower = 40
percentile_upper = 60

# One entry per curve: (HDF5 dataset key, color, legend label, collapse_last).
# When collapse_last is True, the stored dataset is first reduced with a median
# over its last axis. The observed LFI is used as stored.
lfi_curves = [
    ('v_lfi', 'C0', r'\textbf{Observed}', False),
    ('v_s_lfi', 'gray', r'\textbf{Shuffle}', True),
    ('v_r_lfi', 'red', r'\textbf{Rotation}', True),
    ('v_fa_lfi', 'fuchsia', r'\textbf{FA}', True),
]


def plot_median_band(ax, x, values, color, lower, upper):
    """Draw the median of `values` over axis 1 plus a shaded percentile band.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x : x-coordinates, one per row of `values` (here: dimlet dimensions).
    values : 2-D array; row i holds the samples summarized at x[i].
    color : Matplotlib color shared by the band and the median line.
    lower, upper : Percentiles (0-100) bounding the shaded band.
    """
    # Band first, then line, so the median line is drawn on top of its band.
    ax.fill_between(
        x=x,
        y1=np.percentile(values, q=lower, axis=1),
        y2=np.percentile(values, q=upper, axis=1),
        color=color,
        alpha=0.1)
    ax.plot(x, np.median(values, axis=1), linewidth=4, color=color, alpha=0.5)


# Draw every curve on every dataset panel. Within a panel, the draw order
# (observed, shuffle, rotation, FA) sets the stacking order.
for ax, result in zip(axes_list, results):
    for key, color, _, collapse_last in lfi_curves:
        values = np.median(result[key], axis=-1) if collapse_last else result[key][:]
        plot_median_band(ax, dims, values, color, percentile_lower, percentile_upper)

for ax, title in zip(axes_list, titles):
    ax.set_xlim([3, 15])
    ax.set_yscale('log')
    ax.set_xticks([3, 5, 10, 15])
    ax.tick_params(labelsize=15)

    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=16)
    ax.set_ylabel(r'\textbf{LFI}', fontsize=16)
    ax.set_title(title, fontsize=20)

# Per-panel y-bounds. These are set after switching to the log scale so the
# limits are applied on each axis' final scale.
axes_list[2].set_ylim(bottom=1e-3)
axes_list[3].set_ylim(bottom=5e-4, top=1e-2)
axes_list[4].set_ylim(bottom=5)

# Build the legend in the empty last slot. Proxy (empty) lines are added to the
# last data panel, and its legend is anchored to the right, into the blank panel.
axes_list[-1].axis('off')
legend_ax = axes_list[-2]
for _, color, label, _ in lfi_curves:
    legend_ax.plot([], [], color=color, label=label, linewidth=4)
legend_ax.legend(loc='center left',
                 bbox_to_anchor=(1.3, 0.5),
                 prop={'size': 20})
# Apply subplot labels to the five data panels only
mplego.labels.apply_subplot_labels(axes_list[:-1], bold=True, x=-0.16, y=1.10, size=23)

fig.savefig('figure3.pdf', bbox_inches='tight')