In [None]:
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns

%matplotlib inline

In [None]:
# Directory holding the exp4 HDF5 fit results. Set the NEUROCORR_FITS_DIR
# environment variable to point elsewhere (e.g. on another machine); the
# default is the original mounted-volume location.
base_path = os.environ.get('NEUROCORR_FITS_DIR', '/Volumes/pss/fits/neurocorr')

In [None]:
# One HDF5 result file per dataset, each resolved under base_path.
# NOTE(review): pvc11_2 carries a "_50_" token while pvc11_1/pvc11_3 carry
# "_75_", and the retina file is "tuned_values" rather than "resp_values".
# Presumably these reflect per-dataset run settings; confirm against the fits.
pvc11_1_path, pvc11_2_path, pvc11_3_path, ac1_path, ret2_path = (
    os.path.join(base_path, file_name)
    for file_name in (
        'exp4_1_resp_values_pvc11_75_1000_1000.h5',
        'exp4_2_resp_values_pvc11_50_1000_1000.h5',
        'exp4_3_resp_values_pvc11_75_1000_1000.h5',
        'exp4_resp_values_ac1_50_1000_1000.h5',
        'exp4_tuned_values_ret2_50_1000_1000.h5',
    )
)

In [None]:
# Legend labels, one per entry of `results` (same order).
# NOTE(review): `\textbf{...}` only renders as bold text when matplotlib's
# `text.usetex` is enabled; this notebook never sets it, so presumably it
# comes from a matplotlibrc/style file. TODO confirm.
labels = [
    r'\textbf{' + display_name + '}'
    for display_name in ('V1, 1', 'V1, 2', 'V1, 3', 'AC', 'Retina')
]

In [None]:
# Open every result file read-only, in the same order as `labels`.
# The handles are intentionally left open: later cells keep reading
# datasets ('units', 'v_s_lfi', 'v_r_lfi') straight from them.
pvc11_1, pvc11_2, pvc11_3, ac1, ret2 = (
    h5py.File(file_path, 'r')
    for file_path in (pvc11_1_path, pvc11_2_path, pvc11_3_path, ac1_path, ret2_path)
)
results = [pvc11_1, pvc11_2, pvc11_3, ac1, ret2]

# Observed LFI values, materialised in memory via `[:]`.
v_lfis = [h5_file['v_lfi'][:] for h5_file in results]

# Observed LFI Values

In [None]:
# Observed LFI versus dimlet dimension, one curve per dataset.
# Line = mean across axis 1 of v_lfi; shaded band = 25th-75th percentile (IQR)
# across the same axis. Note the line is a mean, not the band's median.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

for idx, (result, v_lfi) in enumerate(zip(results, v_lfis)):
    # Row i of v_lfi is plotted at dimlet dimension i + 2, i.e. dims 2 .. n_max_units.
    n_max_units = result['units'].shape[2]
    dims = 2 + np.arange(n_max_units - 1)
    # Use the in-memory array rather than re-reading the HDF5 dataset.
    lfi_mean = np.mean(v_lfi, axis=1)

    ax.fill_between(
        x=dims,
        y1=np.percentile(v_lfi, q=25, axis=1),
        y2=np.percentile(v_lfi, q=75, axis=1),
        color=f'C{idx}',
        alpha=0.1
    )
    ax.plot(
        dims, lfi_mean,
        linewidth=4,
        color=f'C{idx}',
        label=labels[idx]
    )

ax.legend(
    loc='center left',
    bbox_to_anchor=(1.0, 0.5),
    prop={'size': 15}
)
# Fixed axis ranges shared with the percentile figure below.
ax.set_xlim([2, 75])
ax.set_ylim([0, 0.1])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{LFI}', fontsize=15)
plt.savefig('exp4_lfi_vals.pdf', bbox_inches='tight')

# Null-model percentiles of the observed LFI

In [None]:
# For each dataset: the fraction of shuffle-null LFI samples (last axis of
# 'v_s_lfi') that are strictly greater than the observed LFI, i.e. an empirical
# one-sided p-value per entry of v_lfi.
# Assumes 'v_s_lfi' has v_lfi's shape plus one trailing null-sample axis — TODO confirm.
p_s_vals = [
    (observed_lfi[..., np.newaxis] < h5_file['v_s_lfi'][:]).mean(axis=-1)
    for observed_lfi, h5_file in zip(v_lfis, results)
]

In [None]:
# For each dataset: the fraction of rotation-null LFI samples (last axis of
# 'v_r_lfi') that are strictly greater than the observed LFI, i.e. an empirical
# one-sided p-value per entry of v_lfi.
# Assumes 'v_r_lfi' has v_lfi's shape plus one trailing null-sample axis — TODO confirm.
p_r_vals = [
    (observed_lfi[..., np.newaxis] < h5_file['v_r_lfi'][:]).mean(axis=-1)
    for observed_lfi, h5_file in zip(v_lfis, results)
]

In [None]:
# Left panel: shuffle null (p_s_vals); right panel: rotation null (p_r_vals).
# For each dataset, draw the median across axis 1 as a line and the
# 40th-60th percentile range across axis 1 as a shaded band.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for idx, (result, *panel_vals) in enumerate(zip(results, p_s_vals, p_r_vals)):
    color = f'C{idx}'
    # Dimlet dimensions 2 .. n_max_units, matching the rows of the p-value arrays.
    dims = np.arange(2, result['units'].shape[2] + 1)

    for panel, (ax, p_val) in enumerate(zip(axes, panel_vals)):
        # Only the right-hand (rotation) panel's lines feed the legend.
        legend_kw = {'label': labels[idx]} if panel == 1 else {}
        ax.plot(
            dims, np.median(p_val, axis=1),
            linewidth=3,
            color=color,
            **legend_kw
        )
        ax.fill_between(
            x=dims,
            y1=np.percentile(p_val, q=40, axis=1),
            y2=np.percentile(p_val, q=60, axis=1),
            color=color,
            alpha=0.1
        )

axes[1].legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'size': 15})

for ax in axes:
    ax.set_xlim([2, 75])
    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
    ax.set_ylim([0, 1.05])
    ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
    ax.tick_params(labelsize=15)

axes[0].set_ylabel(r'\textbf{Shuffle Percentile}', fontsize=15)
axes[1].set_ylabel(r'\textbf{Rotation Percentile}', fontsize=15)
plt.savefig('exp4_percentile_vs_dim.pdf', bbox_inches='tight')

# Observed LFI vs. median null-model LFI

In [None]:
# One figure per dataset: observed LFI (x) against the median LFI under the
# shuffle (left) and rotation (right) null models (y). Every DIM_STEP-th row
# of v_lfi (dimlet dimension = row + 2) gets its own filled KDE, coloured
# along the plasma colormap from the first to the last sampled dimension.
DIM_STEP = 10
cmap = plt.get_cmap('plasma')

for result_idx, (result, v_lfi) in enumerate(zip(results, v_lfis)):
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    # Identify which dataset this figure belongs to.
    fig.suptitle(labels[result_idx], fontsize=18)

    # Collapse the trailing null-sample axis to its median.
    # Assumes 'v_s_lfi'/'v_r_lfi' are shaped like v_lfi plus one trailing axis — TODO confirm.
    v_s_lfi = np.median(result['v_s_lfi'][:], axis=-1)
    v_r_lfi = np.median(result['v_r_lfi'][:], axis=-1)

    # Sample row indices within v_lfi's bounds (it has n_max_units - 1 rows),
    # so the last step can never index past the end.
    plot_dims = np.arange(0, v_lfi.shape[0], DIM_STEP)
    # Avoid dividing by zero when only one dimension is sampled.
    color_span = max(plot_dims.size - 1, 1)

    for idx, dim_idx in enumerate(plot_dims):
        color = cmap(int(idx / color_span * 255))
        # Index the data by the sampled row (dim_idx), not the loop counter (idx),
        # so the plotted dimensions match the colours and the colorbar.
        for ax, v_null_median in zip(axes, (v_s_lfi, v_r_lfi)):
            # NOTE(review): log_scale=True requires strictly positive LFI values.
            sns.kdeplot(
                x=v_lfi[dim_idx],
                y=v_null_median[dim_idx],
                fill=True,
                levels=5,
                log_scale=True,
                alpha=0.5,
                color=color,
                ax=ax
            )

    # Colorbar placed just right of the rotation panel. Row i is dimension i + 2,
    # so both ends of the range carry the same +2 offset.
    [[x0, y0], [x1, y1]] = axes[1].get_position().get_points()
    cax1 = fig.add_axes([x1 + 0.05 * (x1 - x0), y0, 0.05 * (x1 - x0), (y1 - y0)])
    cb1 = fig.colorbar(
        mpl.cm.ScalarMappable(
            norm=mpl.colors.Normalize(vmin=2, vmax=plot_dims.max() + 2),
            cmap="plasma"
        ),
        cax=cax1)
    cb1.ax.tick_params(labelsize=18)

    # Labels
    axes[0].set_ylabel(r'\textbf{Median LFI under Shuffle Null Model}', fontsize=18)
    axes[1].set_ylabel(r'\textbf{Median LFI under Rotation Null Model}', fontsize=18)
    cb1.ax.set_ylabel(r'\textbf{Dimension for Rotation Null Model}', labelpad=23, fontsize=18, rotation=270)

    for ax in axes:
        ax.tick_params(labelsize=15)
        ax.set_xlabel(r'\textbf{Observed LFI}', fontsize=18)
        ax.set_xlim([1e-6, 1e0])
        ax.set_ylim([1e-6, 1e0])
        ax.set_xscale('log')
        ax.set_yscale('log')
        # Identity line: both limits are equal, so this is the y = x diagonal.
        ax.plot(ax.get_xlim(), ax.get_ylim(), color='k')