In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks as packs
import numpy as np
import os

from mpl_lego.ellipse import plot_cov_ellipse
from noise_correlations import analysis, utils 

In [None]:
# Apply mpl_lego's LaTeX style before any figure is created; the titles and
# legend labels built later in this notebook contain LaTeX markup (\textbf{...}).
mplego.style.use_latex_style()

In [None]:
# Directory that holds the exp09 fit files. The original storage location is
# kept as the default; set the NEUROCORR_EXP09_PATH environment variable to
# point the notebook at another directory without editing this cell.
base_path = os.environ.get('NEUROCORR_EXP09_PATH', '/storage/fits/neurocorr/exp09')

In [None]:
# Full paths to the five exp09 result files, all located directly in base_path.
(pvc11_1_path,
 pvc11_2_path,
 pvc11_3_path,
 ret2_path,
 ecog_path) = (
    os.path.join(base_path, filename)
    for filename in (
        'exp09_1_pvc11_15_1000_1000.h5',
        'exp09_2_pvc11_15_1000_1000.h5',
        'exp09_3_pvc11_15_1000_1000.h5',
        'exp09_ret2_15_1000_1000.h5',
        'exp09_ecog_15_3000_1000.h5',
    )
)

In [None]:
# LaTeX-bold labels "Dataset 1" .. "Dataset 3" (numbered from one).
v1_labels = [r'\textbf{Dataset ' + f'{number}' + '}' for number in (1, 2, 3)]

In [None]:
# Open every result file read-only. The handles are left open on purpose:
# the plotting cell reads datasets from them lazily.
pvc11_1, pvc11_2, pvc11_3, ret2, ecog = (
    h5py.File(file_path, 'r')
    for file_path in (
        pvc11_1_path, pvc11_2_path, pvc11_3_path, ret2_path, ecog_path
    )
)

# One entry per figure panel; the first entry groups the three pvc11 files.
results = [
    [pvc11_1, pvc11_2, pvc11_3],
    ret2,
    ecog,
]

# Size of the third axis of the 'units' dataset in the first pvc11 file.
n_max_units = pvc11_1['units'].shape[2]
# Integer x-axis values 3, 4, ..., n_max_units.
dims = np.arange(3, n_max_units + 1)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharex=True)
plt.subplots_adjust(wspace=0.3)
percentile_lower = 40
percentile_upper = 60
linestyles = ['--', ':', '-']
# (dataset key in the result file, plot colour, legend label) for each curve.
null_models = [
    ('p_s_lfi', 'gray', r'\textbf{Shuffle}'),
    ('p_r_lfi', 'red', r'\textbf{Rotation}'),
    ('p_fa_lfi', 'fuchsia', r'\textbf{Factor Analysis}'),
]


def _plot_null_model_percentiles(ax, result, dims, q_lower, q_upper,
                                 linestyle=None):
    """Draw a median curve and a percentile band for every null model.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the bands and curves.
    result : mapping
        Object indexable by the keys listed in ``null_models`` (here an open
        h5py file). Each dataset is read in full and summarised along axis 1.
        NOTE(review): assumes the summarised arrays have one entry per element
        of ``dims`` -- confirm against the files being plotted.
    dims : array-like
        x-axis values shared by all curves.
    q_lower, q_upper : float
        Percentiles (0-100) bounding the shaded band.
    linestyle : str, optional
        Line style for the median curves. When omitted, matplotlib's default
        line style is used.
    """
    line_kwargs = {} if linestyle is None else {'linestyle': linestyle}

    summaries = []
    for key, color, _ in null_models:
        samples = result[key][:]
        summaries.append((
            color,
            np.median(samples, axis=1),
            np.percentile(samples, q=q_lower, axis=1),
            np.percentile(samples, q=q_upper, axis=1),
        ))

    # Bands first, then medians, in the same order for every dataset.
    for color, _, band_lower, band_upper in summaries:
        ax.fill_between(
            x=dims,
            y1=band_lower,
            y2=band_upper,
            color=color,
            alpha=0.1)
    for color, median, _, _ in summaries:
        ax.plot(dims, median, linewidth=4, color=color, **line_kwargs)


for idx, result in enumerate(results):
    ax = axes[idx]
    if isinstance(result, list):
        # Several recordings share one panel; tell them apart by line style.
        for jj, result_inner in enumerate(result):
            _plot_null_model_percentiles(
                ax, result_inner, dims,
                percentile_lower, percentile_upper,
                linestyle=linestyles[jj])
    else:
        _plot_null_model_percentiles(
            ax, result, dims, percentile_lower, percentile_upper)

# Empty proxy lines so the colour legend shows one entry per null model.
for _, color, label in null_models:
    axes[2].plot([], [], color=color, linewidth=3, label=label)

axes[2].legend(loc='center left',
               bbox_to_anchor=(1.0, 0.5),
               prop={'size': 15})

axes[0].set_title(r'\textbf{V1}', fontsize=20)
axes[1].set_title(r'\textbf{Retina}', fontsize=20)
axes[2].set_title(r'\textbf{AC}', fontsize=20)
for ax in axes:
    ax.set_xlim([3, 15])
    # The y-range is fixed to [0, 1] on every panel; the earlier per-panel
    # set_ylim(bottom=...) calls were overridden by this and have been removed.
    ax.set_ylim([0, 1])
    ax.set_xticks([3, 5, 10, 15])
    ax.tick_params(labelsize=15)

    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=16)
    ax.set_ylabel(r'\textbf{Percentile}', fontsize=16)

# Proxy lines for the line-style legend on the first panel. Datasets are
# numbered from 1 so the entries read "Dataset 1" .. "Dataset 3".
for idx, linestyle in enumerate(linestyles, start=1):
    axes[0].plot(
        [], [],
        linewidth=2,
        linestyle=linestyle,
        color='black',
        label=r'\textbf{Dataset ' + f'{idx}' + '}')
axes[0].legend(loc='best', prop={'size': 12})
mplego.labels.apply_subplot_labels(axes, bold=True, x=-0.16, y=1.10, size=23)
plt.show()