In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks as packs
import numpy as np
import os

from mpl_lego.ellipse import plot_cov_ellipse
from noise_correlations import analysis, utils 

In [None]:
# Apply mpl_lego's LaTeX matplotlib style. The figure cell relies on LaTeX
# markup in its labels (r'\textbf{...}'), which presumably needs the LaTeX
# text rendering this style enables -- confirm against mpl_lego.
mplego.style.use_latex_style()

In [None]:
# Root directory of the exp09 fit results. Set the NEUROCORR_EXP09_PATH
# environment variable to point at a different location; otherwise the
# original cluster storage path is used.
base_path = os.environ.get('NEUROCORR_EXP09_PATH', '/storage/fits/neurocorr/exp09')

In [None]:
# Fit-result files for each dataset, all stored under base_path. The pvc11
# dataset has three separate fit files (_1, _2, _3).
(pvc11_1_path,
 pvc11_2_path,
 pvc11_3_path,
 ret2_path,
 ecog_path) = [
    os.path.join(base_path, filename)
    for filename in (
        'exp09_1_pvc11_15_1000_1000.h5',
        'exp09_2_pvc11_15_1000_1000.h5',
        'exp09_3_pvc11_15_1000_1000.h5',
        'exp09_ret2_15_1000_1000.h5',
        'exp09_ecog_15_3000_1000.h5',
    )
]

In [None]:
# Open every fit-result file read-only. The handles stay open so later cells
# can read datasets lazily from them.
pvc11_1, pvc11_2, pvc11_3, ret2, ecog = (
    h5py.File(path, 'r')
    for path in (pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path, ecog_path)
)
# Files plotted in the figure, in the same order as `titles` (V1, Retina, AC).
results = [pvc11_1, ret2, ecog]
# Dimlet dimensions on the x-axis: 3, 4, ..., n_max_units, where n_max_units
# is the size of axis 2 of pvc11_1's 'units' dataset.
n_max_units = pvc11_1['units'].shape[2]
dims = np.arange(3, n_max_units + 1)

In [None]:
# Panel titles, rendered bold, in the same order as `results`.
titles = mplego.labels.bold_text(['V1', 'Retina', 'AC'])

In [None]:
# One row per plotted curve: (HDF5 dataset key, line color, legend label).
curve_specs = [
    ('v_lfi', 'C0', 'Observed'),
    ('v_s_lfi', 'gray', 'Shuffle'),
    ('v_r_lfi', 'red', 'Rotation'),
    ('v_fa_lfi', 'fuchsia', 'FA'),
]
groups = [key for key, _, _ in curve_specs]
colors = [color for _, color, _ in curve_specs]
labels = mplego.labels.bold_text([label for _, _, label in curve_specs])

In [None]:
"""
Figure Settings
"""
# Subplot adjustments
wspace = 0.3
hspace = 0.5
# Line settings
linewidth = 2
line_alpha = 0.8
# Fill settings
fill_alpha = 0.1
# Percentile bounds for curves
percentile_lower = 40
percentile_upper = 60

"""
Figure 3
"""
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
plt.subplots_adjust(wspace=wspace, hspace=hspace)

# Enumerate over results
for idx, (result, ax) in enumerate(zip(results, axes)):
    # Plot observed LFI
    for group, color in zip(groups, colors):
        if group == 'v_lfi':
            values = result[group][:]
        else:
            values = np.median(result[group], axis=2)
            # Alternative: take statistics across all dim-stims and repeats
            # values = np.reshape(result[group], (dims.size, -1))
        median = np.median(values, axis=1)
        lower = np.percentile(values, q=percentile_lower, axis=1)
        upper = np.percentile(values, q=percentile_upper, axis=1)
        # Fill region between percentile bounds
        ax.fill_between(
            x=dims,
            y1=lower,
            y2=upper,
            color=color,
            alpha=fill_alpha)
        ax.plot(
            dims,
            median,
            linewidth=linewidth,
            color=color,
            alpha=line_alpha)

# Set bounds
axes[0].set_ylim(bottom=1e-3)
axes[1].set_ylim(bottom=5e-4, top=1e-2)
axes[2].set_ylim(bottom=5)

# Set axis limits, scales, and labels
for (ax, title) in zip(axes, titles):
    ax.set_xlim([3, 15])
    ax.set_yscale('log')
    ax.set_xticks([3, 5, 10, 15])
    ax.tick_params(labelsize=15)
    
    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=16)
    ax.set_ylabel(r'\textbf{LFI}', fontsize=16)
    ax.set_title(title, fontsize=20)

# Create legend in last axis spot
for color, label in zip(colors, labels):
    axes[-1].plot([], [], color=color, label=label, linewidth=linewidth)
axes[-1].legend(
    loc='center left',
    bbox_to_anchor=(1.05, 0.5),
    prop={'size': 15})

# Apply subplot labels
mplego.labels.apply_subplot_labels(
    axes.ravel()[:-1],
    bold=True,
    x=-0.16,
    y=1.10,
    size=23)

plt.savefig('figure3.pdf', bbox_inches='tight')
plt.show()