In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

from importlib import reload
from noise_correlations import plot, utils
from scipy.stats import iqr

%matplotlib inline

# Load data

In [None]:
# Per-experiment metadata, unpacked into three parallel lists that share the
# same ordering:
#   datasets - keys used to index the result dictionaries
#   names    - human-readable names used in figure titles
#   files    - results filenames
datasets, names, files = map(list, zip(
    ('pvc11-2', 'Monkey V1 (2)', 'exp1_2_pvc11_50_1000_10000.npz'),
    ('pvc11-3', 'Monkey V1 (3)', 'exp1_3_pvc11_50_1000_10000.npz'),
    ('ret2', 'Mouse Retina', 'exp1_ret2_50_1000_10000.npz'),
))

In [None]:
# Map each dataset key to the full path of its results file, which lives in
# fits/neurocorr under the user's home directory. The home directory is
# resolved with expanduser rather than os.environ['HOME'] so the lookup does
# not raise KeyError on platforms where HOME is not set; when HOME is set the
# resulting path is the same.
results_dir = os.path.join(os.path.expanduser('~'), 'fits/neurocorr')
results_paths = {dataset: os.path.join(results_dir, file)
                 for dataset, file in zip(datasets, files)}

In [None]:
# Load every results archive and collect its arrays into one dictionary per
# quantity, each keyed by dataset name.
#   p_s_* : p-values plotted as "Shuffle" in the figures of this notebook
#   p_r_* : p-values plotted as "Rotation" in the figures of this notebook
#   v_*   : the metric values themselves
p_s_lfis, p_s_sdkls = {}, {}
p_r_lfis, p_r_sdkls = {}, {}
v_lfis, v_sdkls = {}, {}

# archive key -> dictionary that gathers that array across datasets
stores = {
    'p_s_lfi': p_s_lfis,
    'p_s_sdkl': p_s_sdkls,
    'p_r_lfi': p_r_lfis,
    'p_r_sdkl': p_r_sdkls,
    'v_lfi': v_lfis,
    'v_sdkl': v_sdkls,
}

for dataset, results_path in results_paths.items():
    # the context manager closes the .npz archive once its arrays are read
    with np.load(results_path) as results:
        for key, store in stores.items():
            store[dataset] = results[key]

In [None]:
# Each experiment is assumed to have the same dimlet dimensions, so the number
# of dimlet sizes is read from the first axis of one dataset's p-value array.
n_dims, _ = p_s_lfis['pvc11-2'].shape
# dimlet sizes start at 2: entry i of `dims` is the size i + 2
dims = 2 + np.arange(n_dims)
# NOTE(review): presumably these match the "1000" and "10000" fields of the
# results filenames -- confirm against the script that produced the fits.
n_dimlets = 1000
n_repeats = 10000
# original (misspelled) name kept as an alias so any code still referring to
# it keeps working
n_repeas = n_repeats

# Behavior of metrics as dimensionality increases

In [None]:
# For each dataset, draw one figure with two panels (LFI on the left, sDKL on
# the right). Each panel shows the metric averaged over the second axis of its
# values array as a function of dimlet size, with a shaded band extending
# half a standard deviation above and below the mean.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # (axis, metric values, y-axis label) for the two panels
    panels = [
        (axes[0], v_lfis[dataset], r'\textbf{Linear Fisher information}'),
        (axes[1], v_sdkls[dataset], r'\textbf{sD}\textsubscript{\textbf{KL}}'),
    ]
    for ax, values, ylabel in panels:
        center = values.mean(axis=1)
        half_band = 0.5 * values.std(axis=1)
        # mean curve and its spread
        ax.plot(dims, center, color='black', marker='o')
        ax.fill_between(x=dims,
                        y1=center - half_band,
                        y2=center + half_band,
                        color='gray',
                        alpha=0.5)
        # axis formatting; the lower y-limit is pinned only after the data
        # has been drawn so the upper limit still follows the data
        ax.set_xlabel(r'\textbf{Dimlet size}', fontsize=20)
        ax.set_xlim(left=1.5, right=n_dims + 1.5)
        ax.set_ylim(bottom=0)
        ax.set_ylabel(ylabel, fontsize=20)

    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.53, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=25, va='center', ha='center')
    plt.tight_layout()
    plt.savefig('metric_curve_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

# p-value distribution at different dimensionalities

In [None]:
# Dimlet sizes to show, one per column of the figure.
dimlets = [2, 5, 25]

# For each dataset, build a 2 x 3 grid comparing shuffle and rotation
# p-values: top row for LFI, bottom row for sDKL, one column per dimlet size.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))

    # (shuffle p-values, rotation p-values) for the top and bottom rows
    metric_rows = [
        (p_s_lfis[dataset], p_r_lfis[dataset]),
        (p_s_sdkls[dataset], p_r_sdkls[dataset]),
    ]
    for col, dim in enumerate(dimlets):
        for row, (p_shuffle, p_rotation) in enumerate(metric_rows):
            # results arrays are indexed by dimlet size minus 2
            plot.plot_pvalue_comparison(
                p0s=p_shuffle[dim - 2],
                p1s=p_rotation[dim - 2],
                labels=[r'\textbf{Shuffle p-value}', r'\textbf{Rotation p-value}'],
                show_inset=False,
                heatmap=True,
                color_regions=True,
                fax=(fig, axes[row, col])
            )

    # common limits and aspect ratio for every panel
    for ax in axes.ravel():
        ax.set_ylim(top=1.25)
        ax.set_xlim(left=0.75 * 10e-5)
        ax.set_aspect('equal')
    # column titles on the top row only
    for ax, dim in zip(axes[0], dimlets):
        ax.set_title(r'\textbf{Dimlet Size: %s}' % dim, fontsize=25)

    # row labels (rotated, at the left edge) and the dataset name on top
    row_labels = [
        (0.77, r'\textbf{LFI}'),
        (0.30, r'\textbf{sD}\textsubscript{\textbf{KL}}'),
    ]
    for y_pos, text in row_labels:
        fig.text(x=0., y=y_pos, s=text,
                 ha='center', va='center', fontsize=30,
                 rotation=90)
    fig.text(x=0.53, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=30, va='center', ha='center')
    plt.tight_layout()

    plt.savefig('rotation_vs_shuffle_p_values_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

# Mean and median p-value

In [None]:
# For each dataset, plot the mean p-value against dimlet size for the shuffle
# (gray) and rotation (red) p-values, with a shaded band extending half a
# standard deviation above and below the mean. Left panel: LFI; right panel:
# sDKL.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # (axis, shuffle p-values, rotation p-values) for the two panels
    panels = [
        (axes[0], p_s_lfis[dataset], p_r_lfis[dataset]),
        (axes[1], p_s_sdkls[dataset], p_r_sdkls[dataset]),
    ]
    for ax, p_shuffle, p_rotation in panels:
        series = [
            (p_shuffle, 'gray', r'Shuffle'),
            (p_rotation, 'red', r'Rotation'),
        ]
        for p_values, color, label in series:
            mean_p = np.mean(p_values, axis=1)
            # The band width is the standard deviation. The rotation curves
            # previously used np.mean here by mistake, which drew a band whose
            # width equalled the mean itself.
            std_p = np.std(p_values, axis=1)
            ax.plot(
                dims,
                mean_p,
                color=color,
                marker='o',
                label=label)
            ax.fill_between(
                x=dims,
                y1=mean_p - 0.5 * std_p,
                y2=mean_p + 0.5 * std_p,
                color=color,
                alpha=0.25)

    # the curves carry labels, so show them (same legend as the median plot)
    axes[1].legend(loc='best',
                   prop={'size': 20})

    for ax in axes:
        ax.set_xlabel(r'\textbf{Dimlet size}', fontsize=20)
        ax.set_xlim(left=1.5, right=n_dims + 1.5)
        ax.set_ylim([0, 1])
        ax.set_ylabel(r'\textbf{Mean p-value}', fontsize=20)

    axes[0].set_title(r'\textbf{LFI}', fontsize=25)
    axes[1].set_title(r'\textbf{sD}\textsubscript{\textbf{KL}}', fontsize=25)
    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.54, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=25, va='center', ha='center')
    plt.tight_layout()

    plt.savefig('mean_p_value_vs_dimlet_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# For each dataset, plot the median p-value against dimlet size for the
# shuffle (gray) and rotation (red) p-values. The shaded band extends half an
# interquartile range above and below the median. Left panel: LFI; right
# panel: sDKL. Only the right panel carries curve labels and a legend.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # (axis, shuffle p-values, rotation p-values, whether curves get labels)
    panels = [
        (axes[0], p_s_lfis[dataset], p_r_lfis[dataset], False),
        (axes[1], p_s_sdkls[dataset], p_r_sdkls[dataset], True),
    ]
    for ax, p_shuffle, p_rotation, labelled in panels:
        series = [
            (p_shuffle, 'gray', r'Shuffle'),
            (p_rotation, 'red', r'Rotation'),
        ]
        for p_values, color, label in series:
            center = np.median(p_values, axis=1)
            half_band = 0.5 * iqr(p_values, axis=1)
            label_kwargs = {'label': label} if labelled else {}
            ax.plot(dims, center, color=color, marker='o', **label_kwargs)
            ax.fill_between(x=dims,
                            y1=center - half_band,
                            y2=center + half_band,
                            color=color,
                            alpha=0.25)

    lgd = axes[1].legend(loc='best',
                         prop={'size': 20})

    # shared axis formatting
    for ax in axes:
        ax.set_xlabel(r'\textbf{Dimlet size}', fontsize=20)
        ax.set_xlim(left=1.5, right=n_dims + 1.5)
        ax.set_ylim([0, 1])
        ax.set_ylabel(r'\textbf{Median p-value}', fontsize=20)

    panel_titles = [r'\textbf{LFI}', r'\textbf{sD}\textsubscript{\textbf{KL}}']
    for ax, title in zip(axes, panel_titles):
        ax.set_title(title, fontsize=25)
    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.54, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=25, va='center', ha='center')

    plt.tight_layout()
    plt.savefig('median_p_value_vs_dimlet_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

# Example p-value distributions

In [None]:
# Dimlet sizes to show (one column each) and the shared histogram bin edges.
dimlets = [2, 5, 10, 25, 50]
bins = np.linspace(0, 1, 31)

# For each dataset, draw a 2 x 5 grid of p-value histograms: top row LFI,
# bottom row sDKL, one column per dimlet size. Rotation p-values are drawn in
# red and shuffle p-values in gray.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(2, 5, figsize=(12, 4), sharex=True)

    # (rotation p-values, shuffle p-values) for the top and bottom rows
    metric_rows = [
        (p_r_lfis[dataset], p_s_lfis[dataset]),
        (p_r_sdkls[dataset], p_s_sdkls[dataset]),
    ]
    for idx, dim in enumerate(dimlets):
        for row, (p_rotation, p_shuffle) in enumerate(metric_rows):
            # results arrays are indexed by dimlet size minus 2;
            # rotation is drawn first so the shuffle outline lies on top
            for p_values, color in ((p_rotation, 'red'), (p_shuffle, 'gray')):
                axes[row, idx].hist(p_values[dim - 2],
                                    bins=bins,
                                    color=color,
                                    linewidth=2,
                                    histtype='step')

    # Axis formatting does not depend on the column being filled, so it is
    # applied once after all histograms are drawn.
    # title for top plots
    for ax, dim in zip(axes[0], dimlets):
        ax.set_title(r'\textbf{Dimlet Size: %s}' % dim,
                     fontsize=15)
    # x label for bottom plots
    for ax in axes[1]:
        ax.set_xlabel(r'$p$\textbf{ value}', fontsize=18)

    # The ret2 dataset uses a lower upper y-limit. It is applied to every
    # panel; previously the override reused a leftover loop variable and so
    # only reached the last panel.
    y_top = 10000 if dataset == 'ret2' else 15000
    for ax in axes.ravel():
        # `nonpositive` is the current name of the keyword formerly spelled
        # `nonposy` (renamed in Matplotlib 3.3)
        ax.set_yscale('log', nonpositive='clip')
        ax.tick_params(labelsize=13)
        ax.set_ylim([1, y_top])

    axes[0, 0].set_ylabel(r'\textbf{LFI}', fontsize=18)
    axes[1, 0].set_ylabel(r'\textbf{sD}\textsubscript{\textbf{KL}}', fontsize=18)
    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.53, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=23, va='center', ha='center')

    plt.tight_layout()
    plt.savefig('p_values_dist_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

# Look at plots separated by stimulus

In [None]:
# Number of stimulus groups used to split each dataset's p-value columns.
unique_stims = {'pvc11-2': 12, 'pvc11-3': 12, 'ret2': 6}

# For each dataset, plot the median p-value against dimlet size separately
# for each stimulus group. Shuffle p-values use solid lines with 'o' markers,
# rotation p-values use dashed lines with 'x' markers; color encodes the
# group index. Left panel: LFI; right panel: sDKL.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    n_stim = unique_stims[dataset]

    # (axis, shuffle p-values, rotation p-values) for the two panels
    panels = [
        (axes[0], p_s_lfis[dataset], p_r_lfis[dataset]),
        (axes[1], p_s_sdkls[dataset], p_r_sdkls[dataset]),
    ]
    # iterate over stim pairings
    for idx in range(n_stim):
        # Select every n_stim-th column starting at idx. The count comes from
        # n_dimlets (instead of a literal 1000) so this cell stays in step
        # with the mean-by-stimulus cell.
        # NOTE(review): assumes columns are ordered so that the stimulus
        # pairing cycles fastest -- confirm against the fitting script.
        idxs = idx + n_stim * np.arange(n_dimlets)
        # color scales by stimulus value
        color = plot.get_cmap_color('viridis', val=idx, max_val=n_stim)
        for ax, p_shuffle, p_rotation in panels:
            # medians within this stimulus pairing, per dimlet size
            ax.plot(
                dims,
                np.median(p_shuffle[:, idxs], axis=1),
                color=color,
                marker='o')
            ax.plot(
                dims,
                np.median(p_rotation[:, idxs], axis=1),
                color=color,
                linestyle='--',
                marker='x')

    for ax in axes:
        ax.set_xlabel(r'\textbf{Dimlet size}', fontsize=20)
        ax.set_xlim(left=1.5, right=n_dims + 1.5)
        # small margin so curves at exactly 0 or 1 remain visible
        ax.set_ylim([-0.02, 1.02])
        ax.set_ylabel(r'\textbf{Median p-value}', fontsize=20)

    axes[0].set_title(r'\textbf{LFI}', fontsize=25)
    axes[1].set_title(r'\textbf{sD}\textsubscript{\textbf{KL}}', fontsize=25)
    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.54, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=25, va='center', ha='center')
    plt.tight_layout()

    plt.savefig('median_p_value_by_stim_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# Number of stimulus groups used to split each dataset's p-value columns.
unique_stims = {'pvc11-2': 12, 'pvc11-3': 12, 'ret2': 6}

# For each dataset, plot the mean p-value against dimlet size separately for
# each stimulus group. Shuffle p-values use solid lines with 'o' markers,
# rotation p-values use dashed lines with 'x' markers; color encodes the
# group index. Left panel: LFI; right panel: sDKL.
for name, dataset in zip(names, datasets):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    n_stim = unique_stims[dataset]
    # iterate over stim pairings
    for idx in range(n_stim):
        # split up p-values by stimulus pairing: every n_stim-th column,
        # starting at idx
        idxs = idx + n_stim * np.arange(n_dimlets)
        # calculate means within stimulus pairing
        mean_s_lfi = np.mean(p_s_lfis[dataset][:, idxs], axis=1)
        mean_r_lfi = np.mean(p_r_lfis[dataset][:, idxs], axis=1)
        mean_s_sdkl = np.mean(p_s_sdkls[dataset][:, idxs], axis=1)
        mean_r_sdkl = np.mean(p_r_sdkls[dataset][:, idxs], axis=1)
        # color scales by stimulus value
        color = plot.get_cmap_color('viridis', val=idx, max_val=n_stim)
        # plot mean p-values vs dimlet dimension
        axes[0].plot(
            dims,
            mean_s_lfi,
            color=color,
            marker='o')
        axes[0].plot(
            dims,
            mean_r_lfi,
            color=color,
            linestyle='--',
            marker='x')
        axes[1].plot(
            dims,
            mean_s_sdkl,
            color=color,
            marker='o')
        axes[1].plot(
            dims,
            mean_r_sdkl,
            color=color,
            marker='x',
            linestyle='--')

    for ax in axes:
        ax.set_xlabel(r'\textbf{Dimlet size}', fontsize=20)
        ax.set_xlim(left=1.5, right=n_dims + 1.5)
        # small margin so curves at exactly 0 or 1 remain visible
        ax.set_ylim([-0.02, 1.02])
        # this figure shows means, so the axis label says so (it previously
        # read "Median p-value", copied from the median-by-stimulus cell)
        ax.set_ylabel(r'\textbf{Mean p-value}', fontsize=20)

    axes[0].set_title(r'\textbf{LFI}', fontsize=25)
    axes[1].set_title(r'\textbf{sD}\textsubscript{\textbf{KL}}', fontsize=25)
    # dataset name as a figure-level title placed above the panels
    fig.text(x=0.54, y=1.03, s=r'\textbf{' + name + '}',
             fontsize=25, va='center', ha='center')

    plt.tight_layout()
    plt.savefig('mean_p_value_by_stim_' + dataset + '.pdf', bbox_inches='tight')
    plt.show()