In [None]:
# --- Imports and global plotting configuration (publication style) ---------
import os
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline
# plt.rcParams['image.interpolation'] = 'nearest'
# Image estimates are drawn in grayscale unless a cmap is passed explicitly.
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['font.size'] = 10
plt.rcParams['font.sans-serif'] = [u'Arial']
from matplotlib.patches import Rectangle

# Thin strokes for publication figures; applied to lines, axes frames and major ticks.
LINEWIDTH = 0.25

plt.rcParams['lines.linewidth'] = LINEWIDTH
plt.rcParams['axes.linewidth'] = LINEWIDTH
plt.rcParams['xtick.major.width'] = LINEWIDTH
plt.rcParams['ytick.major.width'] = LINEWIDTH

# Plain legends: square corners and no frame.
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.frameon'] = False


# Default line colors: the 5-color ColorBrewer 'Set1' qualitative palette.
import brewer2mpl
bmap = brewer2mpl.get_map('Set1', 'qualitative', 5)
colors = bmap.mpl_colors

from cycler import cycler

plt.rcParams['axes.prop_cycle'] = cycler(color=colors)


from scipy.stats import ks_2samp
from sklearn.neighbors import KernelDensity

import pandas as pd


# Project-local helpers. NOTE(review): `context` presumably makes the repo's
# `src` / `utils` packages importable from this notebook directory -- confirm.
from context import src, utils
from src.analyzer import DataAnalyzer, plot_fill_between
from utils.plot_utils import label_subplot, equalize_y_axes

from collections import OrderedDict

from motion_plot_utils import (
    nice_print, plot_results, tuning_plot, 
    collect_analyzers, groupby, analyzers_to_dataframe,
)

In [None]:
# Re-import edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# Experiment batch(es) to analyze. Exactly one assignment is active; swap it for
# one of the alternatives below to regenerate the corresponding plots.
#   ['motion_benefit_nom1', 'motion_benefit6']  -> main plots of motion benefit
#   ['motion_benefit_drop']                     -> motion benefit with cone loss
#   ['motion_benefit_best_dc_sweep']            -> find best diffusion constant
#   ['motion_gain']                             -> gain factor on eye motions
#   ['stimlus_size']                            -> stimulus size experiments
output_dir_list = ['extended_motion_benefit']  # Extended motion benefit with cone loss experiments
data_fns, da_ = collect_analyzers(output_dir_list)

In [None]:
# Columns whose unique combinations define one experimental condition.
grouping_columns = [
    'inference_mode', 'gen_mode',
    'gen_dc', 'prior_dc',
    'drop_prob', 'scaling_factor', 'ds',
]
# `gen_mode` values; the figures label EXP as 'Motion' and NOM as 'No Motion'.
EXP, NOM = 'Experiment', 'Diffusion'

In [None]:
# Reference analyzer and its time grid (assumed shared by all runs);
# `tf` is the final time point, used as the "final SNR" column below.
da = da_[0]
t = da.time_list()
tf = t[-1]

# One row per run, then (condition-dict, runs) groups over grouping_columns.
data = analyzers_to_dataframe(da_)
grouped = groupby(
    data,
    grouping_columns,
)

In [None]:
# Tuning curve of SNR vs. cone loss -- only meaningful for the drop-prob
# experiments. The reassignments of tuning_key / x_label / log_yscale /
# x_normalizer that used to follow the plot were never used (and left
# misleading global values behind), so they were removed.
tuning_key = 'drop_prob'
x_label = 'Cone Loss'
log_yscale = False
x_normalizer = 1.

fig, ax = plt.subplots(1, 1, figsize=(3, 3))

tuning_plot(
    ax=ax, grouped=grouped, tuning_key=tuning_key,
    x_label=x_label, log_yscale=log_yscale, x_normalizer=x_normalizer,
    t=t,
)




In [None]:
# Motion-gain experiments (top-left panel of the SI figure).
output_dir_list = ['motion_gain']  # Experiments testing gain factor on eye motions
_, da1_ = collect_analyzers(output_dir_list)
data1 = analyzers_to_dataframe(da1_)
grouped1 = groupby(data1, grouping_columns)

# Stimulus-size experiments. The earlier assignment to
# ['stimlus_size', 'stimlus_size1'] was immediately overwritten and never used.
# Directory names keep the original 'stimlus' spelling.
output_dir_list = ['stimlus_size2']
_, da2_ = collect_analyzers(output_dir_list)
data2 = analyzers_to_dataframe(da2_)
grouped2 = groupby(data2, grouping_columns)



In [None]:
# SI Figure: tuning curves (top row) and median-SNR reconstructions for three
# stroke widths under motion / no-motion stimuli (rows 2-4).
#
# This cell deliberately does NOT rebind the notebook-level `grouped` / `da`:
# later cells read `grouped` and expect the main batch, not grouped2.

nrows = 4
ncols = 2
a = 3.5 / 2  # panel edge length in inches
fig, axes = plt.subplots(nrows, ncols, figsize=(a * ncols, a * nrows))

for i in range(2):
    if i == 0:
        tuning_key = 'scaling_factor'
        x_label = 'Motion Gain'
        log_yscale = False
        x_normalizer = 1.
        grouped_si = grouped1
        loc = None
    if i == 1:
        tuning_key = 'ds'
        x_label = 'Stroke Width'
        # Linear y-scale, stated explicitly instead of inheriting the value
        # left over from the i == 0 iteration.
        log_yscale = False
        x_normalizer = 0.5
        grouped_si = grouped2
        loc = 'upper left'

    ax = axes[0, i]
    tuning_plot(
        ax=ax, grouped=grouped_si, tuning_key=tuning_key,
        x_label=x_label, log_yscale=log_yscale, x_normalizer=x_normalizer,
        t=t,
        loc=loc
    )

    if i == 0:
        # y starts at 0; dashed line marks a motion gain of 1.0.
        ax.set_ylim([0, ax.get_ylim()[-1]])
        ax.axvline(x=1.0, ls='--', c='black')

    if i == 1:
        ax.set_ylim(0, 25.)
        ax.set_xlim(0, 1.5 * 0.8)
        ax.set_xticks([0, 0.8])

# Rows 1-3: one reconstruction per (ds, stimulus mode). ds is rounded before the
# lookup because it is a float parameter.
# NOTE(review): `t` comes from the main batch (da_[0]); this assumes the
# stimulus-size runs share the same time grid -- confirm.
ds_map = {0.25: 0, 0.4: 1, 0.6: 2}
for key, group in grouped2:
    ds = round(key['ds'], 2)
    gen_mode = key['gen_mode']
    row = ds_map.get(ds)
    if row is None:
        continue
    row += 1
    col = {EXP: 0, NOM: 1}[gen_mode]
    ax = axes[row, col]
    # Run whose final SNR is the (upper) median of this condition.
    median_pos = np.argsort(group[t[-1]].values)[int(len(group) * 0.5)]
    da_si = da2_[group.index[median_pos]]
    da_si.plot_image_estimate(fig, ax, da_si.N_itr - 1,
                              colorbar=False)
    snr = da_si.snr_list()[-1]
    stim_mode = {NOM: 'NM', EXP: 'M'}[gen_mode]
    # Title shows the stroke width w = 2 * ds.
    ax.set_title('SNR={:.2f} w={:.1f} S:{}'.format(snr, 2 * ds, stim_mode))

    # Symmetric limits around the origin with three ticks.
    lim = ax.get_xlim()[1]
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    tick = int(lim * 0.8)
    ax.set_xticks([-tick, 0, tick])
    ax.set_yticks([-tick, 0, tick])

for i, ax in enumerate(axes.flat):
    ax.set_title(ax.get_title(), fontdict={'size': 10})
    label = chr(ord('A') + i)
    label_subplot(fig=fig, ax=ax, label=label, dx=0.10, dy=0.03)

    # Image panels (everything below the first row) are drawn without a frame.
    if i >= 2:
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_visible(False)


plt.tight_layout(pad=0.5)
plt.savefig('../output/si_extended_tuning.pdf')

In [None]:
# Summary table for the stimulus-size batch: group number, grouping values,
# mean final SNR and number of runs.
print(['Group'] + grouping_columns + ['SNR', 'N_EXP'])
for group_no, (name, group) in enumerate(grouped2):
    row = [group_no] + list(name.values())
    row.append(group[list(t)[-1]].mean())
    row.append(len(group))
    print(nice_print(row))

In [None]:
# # Select the best prior_dc for given gen_dc
# grouping_columns = ['inference_mode', 'gen_mode', 'gen_dc', 'ds', 'drop_prob',]
# grouped = groupby(data, grouping_columns)
# data1 = []
# for name, group in grouped:
# #     print group
#     subgrouped = groupby(group, ['prior_dc'])
#     snrs = [subgroup[tf].mean() for _, subgroup in subgrouped]
#     best_group_idx = np.argmax(snrs)
#     _, best_group = subgrouped[best_group_idx]
#     data1.append(best_group)

# data1 = pd.concat(data1)

# df1 = pd.groupby(
#     data1[['gen_dc', t[-1], 'drop_prob']], 
#     ['gen_dc', 'drop_prob']).aggregate([np.mean, np.std])
# df1.reset_index(inplace=True)


# fig, ax = plt.subplots(1, 1, figsize=(3, 3))
# for name, group in df1.groupby('drop_prob'):
#     ax.plot(group['gen_dc'], group[t[-1]]['mean'], label='Drop Prob: {:.2f}'.format(name))
# ax.set_ylabel('SNR')
# ax.set_xlabel('gen_dc')
# ax.set_xscale('log')
# ax.legend()
# plt.savefig('../output/drop_prob.pdf')

In [None]:
# SNR over time for every condition in `grouped` (drawn by plot_results).
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))
ax.set_title('SNR as a function of time')
plot_results(ax, grouped)

In [None]:
# Final-time SNR samples for each condition in `grouped`.
final_snrs = [runs[tf].values for cond, runs in grouped]

# Two-sample Kolmogorov-Smirnov test between the first two conditions.
# NOTE(review): assumes the first two groups are the motion / no-motion pair
# of interest -- confirm against the grouping order.
res = ks_2samp(final_snrs[0], final_snrs[1])
# Original note: "two sided / 2 = one sided". NOTE(review): halving the
# two-sided KS p-value is only an approximation of a one-sided test.
print(res)
label_ = [cond['gen_mode'] for cond, runs in grouped]



In [None]:
# Kernel density estimate (sklearn KernelDensity defaults) of the final SNR for
# each condition, titled with the KS p-value `res` computed just above.
fig, ax = plt.subplots(figsize=(4, 4))
snr_grid = np.linspace(0, 20, 1000)[:, np.newaxis]
for samples, cond_label in zip(final_snrs, label_):
    kde = KernelDensity().fit(samples[:, np.newaxis])
    density = np.exp(kde.score_samples(snr_grid))
    ax.plot(snr_grid, density, c=None, label=cond_label)
ax.set_xlabel('SNR')
ax.set_ylabel('Normalized Density')
ax.set_title('Density of SNR, p = {:.1E}'.format(res.pvalue))
ax.legend()
plt.tight_layout()

In [None]:
# Run indices (DataFrame index labels) for EM inference at ds = 0.4, split by
# stimulus mode: motion (EXP) -> exp_idx, no motion (NOM) -> nom_idx.
# np.isclose instead of `==`: ds is a float parameter (the notebook rounds it
# before comparing elsewhere), so exact equality can silently miss runs.
is_em_ds04 = (data['inference_mode'] == 'EM') & np.isclose(data['ds'], 0.4)
idx = [data[(data['gen_mode'] == key) & is_em_ds04].index.values for key in [EXP, NOM]]
exp_idx, nom_idx = idx

In [None]:
# Run indices for NoMotion inference at ds = 0.4, split by stimulus mode:
# motion (EXP) -> nom_exp_idx, no motion (NOM) -> nom_nom_idx.
# np.isclose instead of `==` because ds is a float parameter.
is_nomotion_ds04 = (data['inference_mode'] == 'NoMotion') & np.isclose(data['ds'], 0.4)
idx = [data[(data['gen_mode'] == key) & is_nomotion_ds04].index.values for key in [EXP, NOM]]
nom_exp_idx, nom_nom_idx = idx

In [None]:
# da = da_[nom_exp_idx[0]]

# fig, ax = plt.subplots(1, 1)
# da.plot_image_estimate(fig=fig, ax=ax, q=100)

In [None]:
# da = da_[exp_idx[2]]

# da.plot_em_estimate(-1)

In [None]:
# Figure: motion benefit in the presence of cone loss.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(3.5, 3.5))

# For each stimulus mode, pick the run whose final SNR is the (upper) median of
# its condition, considering only conditions with drop_prob == 0.3. If several
# conditions share a gen_mode, the last one iterated wins (as before).
# Local names are used so the notebook-level index arrays `exp_idx` / `nom_idx`
# (indexed and sliced by later cells) are not overwritten with scalars.
median_run = {}
for key, group in grouped:
    # isclose: drop_prob is a float parameter, exact comparison can misfire.
    if not np.isclose(key['drop_prob'], 0.3):
        continue
    median_pos = np.argsort(group[t[-1]].values)[int(len(group) * 0.5)]
    print(key)
    median_run[key['gen_mode']] = group.index[median_pos]

da_exp = da_[median_run[EXP]]
da_nom = da_[median_run[NOM]]

# Panel A: plot_image_and_rfs for the motion run, symmetric limits.
ax = axes[0, 0]
da_exp.plot_image_and_rfs(fig, ax, -1)
lim = 4
ax.set_xlim([-lim, lim])
ax.set_ylim([-lim, lim])
tick = 3
ax.set_xticks([-tick, 0, tick])
ax.set_yticks([-tick, 0, tick])
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# Panel B: final SNR vs. cone loss.
ax = axes[0, 1]

tuning_key = 'drop_prob'
x_label = 'Cone Loss'
log_yscale = False
x_normalizer = 1.

tuning_plot(
    ax=ax, grouped=grouped, tuning_key=tuning_key,
    x_label=x_label, log_yscale=log_yscale, x_normalizer=x_normalizer,
    y_label='SNR (t=700 ms)',
    t=t,
)
ax.set_xticks([0.0, 0.3, 0.5, 0.8])
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_major_locator(MaxNLocator(4))

# Make the panel square in display units, matching the image panels.
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))

# Bottom row: final-iteration reconstructions for the motion / no-motion runs.
# The loop variable stays the notebook-level `da` to preserve prior behavior
# for any cell that reads `da` afterwards.
for col, (da, label) in enumerate(zip([da_exp, da_nom], ['Motion', 'No Motion'])):
    ax = axes[1, col]
    da.plot_image_estimate(fig, ax, da.N_itr - 1,
                           colorbar=False)
    ax.set_title('{}: SNR={:.2f}'.format(label, da.snr_one_iteration(da.N_itr - 1)))
    ax.set_xticks([])
    ax.set_yticks([])

for i, ax in enumerate(axes.flat):
    ax.set_title(ax.get_title(), fontdict={'size': 10})
    label = chr(ord('A') + i)
    label_subplot(fig=fig, ax=ax, label=label, dx=0.10, dy=0.03)

    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_visible(False)

plt.tight_layout(pad=0.7)
plt.savefig(os.path.join('../output', 'motion_benefit_drop.pdf'), dpi=200)

In [None]:
# Main motion-benefit figure.
fig = plt.figure(figsize=(7.5, 3.75))
size = (6, 12)  # subplot2grid layout: (rows, cols)

# Left half: base image (ax1), SNR vs time (ax7), final reconstructions (ax2, ax3).
ax1 = plt.subplot2grid(size, (0, 0), colspan=3, rowspan=3)
ax7 = plt.subplot2grid(size, (0, 3), colspan=3, rowspan=3)
ax2 = plt.subplot2grid(size, (3, 0), colspan=3, rowspan=3)
ax3 = plt.subplot2grid(size, (3, 3), rowspan=3, colspan=3)

# Right half (created in this order to keep fig.axes ordering unchanged):
#   ax4_: snapshots of the EM decoder on a moving stimulus,
#   ax6_: snapshots of the NoMotion decoder on a moving stimulus,
#   ax5_: x / y eye-path estimates.
ax4_ = [plt.subplot2grid(size, (0, 6 + 2 * k), rowspan=2, colspan=2) for k in range(3)]
ax6_ = [plt.subplot2grid(size, (2, 6 + 2 * k), rowspan=2, colspan=2) for k in range(3)]
ax5_ = [plt.subplot2grid(size, (4, 7 + 2 * k), rowspan=2, colspan=2) for k in range(2)]

# Fifth EM run at ds = 0.4 for each stimulus mode, and the first NoMotion-decoder
# run on a moving stimulus. Expects exp_idx / nom_idx / nom_exp_idx to be index
# arrays as produced by the run-selection cells.
da_exp = da_[exp_idx[4]]
da_nom = da_[nom_idx[4]]
da_nom_exp = da_[nom_exp_idx[0]]

# Panel A: the stimulus image.
# NOTE(review): this used to draw from whatever `da` was left over from an
# earlier cell; it now uses this figure's motion run explicitly. Assumes the
# base image is identical across runs with the same ds -- confirm.
ax = ax1
da_exp.plot_base_image(fig=fig, ax=ax, colorbar=False)
ax.set_title('Image on the Retina')
ax.set_xticks([-3, 0, 3])
ax.set_yticks([-3, 0, 3])
for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_visible(False)

# Panel B: SNR vs time for every condition in `grouped`.
ax = ax7
plot_results(ax, grouped)
ax.set_xlim([0., 700.])
ax.set_xticks([0, 350, 700])
ax.set_title('SNR vs time (ms)')
for axis in ['top', 'right']:
    ax.spines[axis].set_visible(False)
ax.yaxis.set_major_locator(MaxNLocator(3))

# Final-iteration reconstructions, motion vs no-motion stimulus (both EM decoder).
for ax, da_run, label in zip(
        [ax2, ax3],
        [da_exp, da_nom],
        ['S:M | D:EM', 'S:NM | D:EM']):
    da_run.plot_image_estimate(fig, ax, da_run.N_itr - 1,
                               colorbar=False)
    ax.set_title('{}'.format(label))
    ax.set_xlabel('SNR = {:.2f}'.format(da_run.snr_one_iteration(da_run.N_itr - 1)))
    ax.set_xticks([])
    ax.set_yticks([])
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_visible(False)
    ax.set_aspect('equal')

# Snapshots at iterations q; the time stamp is read from the motion run and is
# assumed to match the NoMotion-decoder run.
for i, q in enumerate([9, 39, 79]):
    t0 = da_exp.data['EM_data/{}/time_steps'.format(q)]
    for ax, da_run in [(ax4_[i], da_exp), (ax6_[i], da_nom_exp)]:
        da_run.plot_image_estimate(fig, ax, q=q,
                                   colorbar=False)
        ax.set_title('T = {} ms'.format(t0), fontsize=7)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_visible(False)
    # NOTE(review): not rendered, because the y-axis was hidden just above.
    ax6_[i].set_ylabel('Motion | NoM')

# Eye-path estimates (x, y) of the motion run at iteration q = 139.
q = 139
for i, title in enumerate(['x', 'y']):
    ax = ax5_[i]
    da_exp.plot_path_estimate(ax, q, i)
    ax.set_title(title)
    ax.set_ylabel('')
    ax.set_xticks([.7])
    ax.set_xticklabels([700])
    ax.set_xlabel('')
    ax.set_yticks([0, 5])
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    for axis in ['top', 'right']:
        ax.spines[axis].set_visible(False)

equalize_y_axes(ax5_[0], ax5_[1])

# Panel letters A-F.
for i, ax in enumerate([ax1, ax7, ax2, ax4_[0], ax6_[0], ax5_[0]]):
    label = chr(ord('A') + i)
    label_subplot(fig=fig, ax=ax, label=label, dx=0.03, dy=0.01)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

for ax in fig.axes:
    ax.title.set_fontsize(10)

plt.subplots_adjust(wspace=1, hspace=3, top=0.9, left=0.05, right=0.95, bottom=0.1)
plt.savefig('../output/motion_benefit.pdf')

In [None]:
# plt.figure(figsize=(4, 4))
# plt.title('SNR as a function of time')
# alpha = 0.5
# for c, (name, group) in zip(c_, grouped):
#     inf_mode, gen_mode, prior_dc, ds = name
#     label = {EXP: 'Motion', NOM: 'No Motion'}[gen_mode]
#     plt.plot(t, group[list(t)].T.iloc[:, 0], c=c, label=label, alpha=alpha)
#     plt.plot(t, group[list(t)].T.iloc[:, 1:], c=c, alpha=alpha);
#     plt.xlabel('time (ms)')
# plt.legend(loc='upper left')
# # plt.savefig(os.path.join(output_dir, 'snr_time_no_motion_vs_motion_0.png'), dpi=200)

In [None]:
# Runs with a no-motion stimulus decoded by the NoMotion model (any ds).
nomotion_mask = (data['gen_mode'] == NOM) & (data['inference_mode'] == 'NoMotion')
idx = data[nomotion_mask].index.values
print(idx)

In [None]:
# First run from the selection above.
first_match = idx[0]
da = da_[first_match]

In [None]:
# Image estimate at q=-1 (presumably the last iteration) for the selected run.
fig, ax = plt.subplots(nrows=1, ncols=1)
da.plot_image_estimate(fig=fig, ax=ax, q=-1)

In [None]:
# EM estimate at index -1 (presumably the last iteration) for the selected run.
last_iteration = -1
da.plot_em_estimate(last_iteration)

In [None]:
# Final reconstructions of the first runs in nom_idx (no-motion stimulus, EM
# decoder), one per panel. `output_dir` was never defined in this notebook, so
# the file goes to ../output like the other figures.
fig, axes = plt.subplots(5, 2, figsize=(4 * 2, 3.5 * 5))
for panel, run_idx in enumerate(nom_idx[:axes.size]):
    da_[run_idx].plot_image_estimate(fig, axes.flat[panel], -1)
plt.savefig(os.path.join('../output', 'e_reconstructions_no_motion.png'), dpi=200)

In [None]:
# Final reconstructions of the first runs in exp_idx (motion stimulus, EM
# decoder), one per panel. `output_dir` was never defined in this notebook, so
# the file goes to ../output like the other figures.
fig, axes = plt.subplots(5, 2, figsize=(4 * 2, 3.5 * 5))
for panel, run_idx in enumerate(exp_idx[:axes.size]):
    da_[run_idx].plot_image_estimate(fig, axes.flat[panel], -1)
plt.savefig(os.path.join('../output', 'e_reconstructions_motion.png'), dpi=200)

In [None]:
# Image estimate every 10 iterations (q = 0, 10, ..., 130) for the run in `da`.
# Passes (fig, ax, q) like every other plot_image_estimate call in this
# notebook, and uses `q` rather than `t` so the notebook-level time axis
# is not overwritten.
fig, axes = plt.subplots(5, 5, figsize=(20, 20))
iterations = list(range(0, 140, 10))
for ax, q in zip(axes.flat, iterations):
    da.plot_image_estimate(fig, ax, q)
# Hide grid cells that have no snapshot.
for ax in axes.flat[len(iterations):]:
    ax.axis('off')

# Group SNR plots by image size and diffusion constant

In [None]:
# Per-run records: [ds, generating diffusion constant, SNR at each time point].
# `pkl_fns` was never defined in this notebook; the analyzers already loaded
# into `da_` are used instead of re-reading them from disk.
# snr_list() is the accessor used elsewhere in this notebook (this cell
# previously called SNR_list()).
out_ = []
for da in da_:
    out = [da.data['ds'], da.data['motion_gen']['dc']]
    # list(...) so an array result is concatenated, not broadcast-added.
    out = out + list(da.snr_list())
    out_.append(out)

In [None]:
# Time points (column labels of the SNR records), taken from the first analyzer.
reference_da = da_[0]
t = reference_da.time_list()

In [None]:
# One row per run: ds, generating diffusion constant, then SNR per time point.
record_columns = ['ds', 'dc_gen'] + list(t)
data = pd.DataFrame.from_records(out_, columns=record_columns)

In [None]:
# Group runs by (ds, dc_gen). DataFrame.groupby replaces the module-level
# pd.groupby, which pandas deprecated and later removed.
grouped = data.groupby(['ds', 'dc_gen'])
c_ = ['g', 'b', 'r', 'y', 'm', 'c']  # per-group line colors

In [None]:
# SNR vs time (fill-between band) for each (ds, dc_gen) group. Colors cycle
# through c_, so groups beyond len(c_) are still drawn; zip(c_, grouped)
# silently dropped them.
plt.figure(figsize=(10, 7))
plt.title('SNR as a function of time')
for n, (name, group) in enumerate(grouped):
    label = 'ds={}, dc={}'.format(*name)
    plot_fill_between(t, group[list(t)], label=label, c=c_[n % len(c_)])
plt.legend(loc='upper left')

# Uncertainty

In [None]:
# Inspect the first analyzer of the loaded batch.
FIRST_RUN = 0
da = da_[FIRST_RUN]

In [None]:
# Display the raw 'EM_data/0' entry of this run (presumably iteration 0, by
# analogy with the 'EM_data/{q}/time_steps' keys read above).
da.data['EM_data/0']