In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks as packs
import numpy as np
import os

from mpl_lego.ellipse import plot_cov_ellipse
from noise_correlations import analysis, utils 
from settings import colors, titles

In [None]:
# Presumably switches matplotlib to LaTeX text rendering (mpl_lego helper).
# Later cells pass raw LaTeX markup such as r'\textbf{...}' to labels and
# titles, which depends on this; keep it above the plotting cells.
mplego.style.use_latex_style()

In [None]:
# Directory holding the exp09 fit files. Defaults to the original cluster
# location; set the NEUROCORR_FITS_DIR environment variable to run this
# notebook against a copy stored elsewhere.
base_path = os.environ.get('NEUROCORR_FITS_DIR', '/storage/fits/neurocorr/exp09')

In [None]:
# HDF5 fit files for the three datasets plotted below, all under `base_path`.
ret2_path, pvc11_1_path, ecog_path = (
    os.path.join(base_path, file_name)
    for file_name in (
        'exp09_ret2_15_1000_1000.h5',
        'exp09_1_pvc11_15_1000_1000.h5',
        'exp09_ecog_15_3000_1000.h5',
    )
)

# Extra datasets
# pvc11_2_path = os.path.join(base_path, 'exp09_2_pvc11_15_1000_1000.h5')
# pvc11_3_path = os.path.join(base_path, 'exp09_3_pvc11_15_1000_1000.h5')

In [None]:
# Open each fit file read-only, in the same order as the figure panels.
# The handles are kept open because later cells index into them
# (e.g. `result[group]`).
# NOTE(review): no visible cell closes these handles; close them when done.
ret2, pvc11_1, ecog = (
    h5py.File(fit_path, 'r')
    for fit_path in (ret2_path, pvc11_1_path, ecog_path)
)
results = [ret2, pvc11_1, ecog]

# Extra datasets
# pvc11_2 = h5py.File(pvc11_2_path, 'r')
# pvc11_3 = h5py.File(pvc11_3_path, 'r')

In [None]:
# Size of axis 2 of the RET2 'units' dataset; presumably the largest dimlet
# size in the fit. NOTE(review): only ret2 is consulted, so this assumes the
# other fit files share the same maximum.
n_max_units = ret2['units'].shape[2]
# Dimlet dimensions 3, 4, ..., n_max_units (inclusive): x-values of the LFI curves.
dims = np.arange(3, n_max_units + 1)

In [None]:
# Bold the panel titles for LaTeX rendering. The raw titles are re-read from
# `settings` instead of transforming the notebook's `titles` in place, so
# re-running this cell does not bold them a second time.
from settings import titles as raw_titles
titles = mplego.labels.bold_text(raw_titles)

In [None]:
# Each LFI result group paired with its legend label. Keeping them in one
# table keeps `groups` and `labels` aligned index-for-index.
_group_label_pairs = (
    ('v_lfi', 'Observed'),
    ('v_s_lfi', 'Shuffle'),
    ('v_r_lfi', 'Rotation'),
    ('v_fa_lfi', 'FA'),
)
groups = [group_name for group_name, _ in _group_label_pairs]
labels = mplego.labels.bold_text(
    [label_text for _, label_text in _group_label_pairs])
del _group_label_pairs

In [None]:
# Figure 3a: single-trial responses of two RET2 units under two stimuli,
# with each stimulus-conditioned mean overlaid as a larger black-edged marker.
mu_scatter_size = 100
ax_label_size = 18

fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# Load the RET2 recording through its neuropack.
data_path = '/storage/data/ret2/200114_fov1_data.mat'
pack = packs.RET2(data_path=data_path)
X_ret2 = pack.get_response_matrix(cells='tuned', response='max')
stimuli_ret2 = pack.angles
unique_stimuli_ret2 = pack.unique_angles

# Stimulus pair and unit pair (column indices of X_ret2) for the example.
stim1, stim2 = 0, 60
neuron1, neuron2 = 32, 25

# Rows = trials of one stimulus, columns = the two selected units.
X1 = X_ret2[stimuli_ret2 == stim1][:, [neuron1, neuron2]]
X2 = X_ret2[stimuli_ret2 == stim2][:, [neuron1, neuron2]]
mu1 = np.mean(X1, axis=0)
mu2 = np.mean(X2, axis=0)

# Draw both trial clouds first and the means afterwards, so the mean
# markers are drawn on top of the individual trials.
for cloud, cloud_color, cloud_alpha in ((X1, 'C0', 0.50), (X2, 'C1', 0.60)):
    ax.scatter(
        cloud[:, 0],
        cloud[:, 1],
        s=70,
        color=cloud_color,
        edgecolor='white',
        alpha=cloud_alpha)
for centroid, centroid_color in ((mu1, 'C0'), (mu2, 'C1')):
    ax.scatter(
        centroid[0], centroid[1],
        color=centroid_color,
        s=mu_scatter_size,
        edgecolor='black')

# Annotate each cloud with its data-matrix symbol (data coordinates).
for text_x, text_y, text, text_color in (
        (0.45, 1.00, r'$\mathbf{X}_{s_1}^d$', 'C0'),
        (1.75, 0.23, r'$\mathbf{X}_{s_2}^d$', 'C1')):
    ax.text(
        x=text_x, y=text_y, s=text,
        va='center',
        ha='center',
        color=text_color,
        size=26)

# Square axes sharing the same range and ticks on x and y.
ax.set_xlim([0., 2])
ax.set_ylim(ax.get_xlim())
ax.set_xticks([0, 0.5, 1, 1.5, 2])
ax.set_yticks(ax.get_xticks())
ax.set_xlabel(r'\textbf{Unit 1 ($\Delta$F/F)}', fontsize=ax_label_size)
ax.set_ylabel(r'\textbf{Unit 2 ($\Delta$F/F)}', fontsize=ax_label_size)
ax.tick_params(labelsize=15)
fig.savefig('figure3a.pdf', bbox_inches='tight')

In [None]:
"""
Figure Settings
"""
# Subplot adjustments
wspace = 0.4
hspace = 0.5
# Label adjustments
subplot_label_size = 22
subplot_x = -0.16
subplot_y = 1.10
axis_label_size = 16
title_size = 18
# Line settings
linewidth = 2
line_alpha = 0.8
# Fill settings
fill_alpha = 0.1
# Percentile bounds for curves
percentile_lower = 40
percentile_upper = 60

"""
Figure 3
"""
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
plt.subplots_adjust(wspace=wspace, hspace=hspace)

# Enumerate over results
for idx, (result, ax) in enumerate(zip(results, axes)):
    # Plot observed LFI
    for group, color in zip(groups, colors.values()):
        if group == 'v_lfi':
            values = result[group][:]
        else:
            values = np.median(result[group], axis=2)
            # Alternative: take statistics across all dim-stims and repeats
            # values = np.reshape(result[group], (dims.size, -1))
        median = np.median(values, axis=1)
        lower = np.percentile(values, q=percentile_lower, axis=1)
        upper = np.percentile(values, q=percentile_upper, axis=1)
        # Fill region between percentile bounds
        ax.fill_between(
            x=dims,
            y1=lower,
            y2=upper,
            color=color,
            alpha=fill_alpha)
        ax.plot(
            dims,
            median,
            linewidth=linewidth,
            color=color,
            alpha=line_alpha)

# Set bounds
axes[0].set_ylim(bottom=5e-4, top=1e-2)
axes[1].set_ylim(bottom=1e-3)
axes[2].set_ylim(bottom=5)

# Set axis limits, scales, and labels
for (ax, title) in zip(axes, titles):
    ax.set_xlim([3, 15])
    ax.set_yscale('log')
    ax.set_xticks([3, 5, 10, 15])
    ax.tick_params(labelsize=15)
    
    ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=axis_label_size)
    ax.set_ylabel(r'\textbf{LFI}', fontsize=axis_label_size)
    ax.set_title(title, fontsize=title_size)

# Create legend in last axis spot
for color, label in zip(colors.values(), labels):
    axes[-1].plot([], [], color=color, label=label, linewidth=linewidth)
axes[-1].legend(
    loc='center left',
    bbox_to_anchor=(1.05, 0.5),
    prop={'size': 15})

# Apply subplot labels
mplego.labels.apply_subplot_labels(
    axes.ravel(),
    labels=['b', 'c', 'd'],
    bold=True,
    x=subplot_x,
    y=subplot_y,
    size=subplot_label_size)

plt.savefig('figure3_base.pdf', bbox_inches='tight')
plt.show()