In [None]:
import numbers
import os
import warnings

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline
# plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['font.size'] = 10
plt.rcParams['font.sans-serif'] = [u'Arial']
from matplotlib.patches import Rectangle



LINEWIDTH = 0.25

plt.rcParams['lines.linewidth'] = LINEWIDTH
plt.rcParams['axes.linewidth'] = LINEWIDTH
plt.rcParams['xtick.major.width'] = LINEWIDTH
plt.rcParams['ytick.major.width'] = LINEWIDTH

plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.frameon'] = False


import brewer2mpl
bmap = brewer2mpl.get_map('Set1', 'qualitative', 5)
colors = bmap.mpl_colors

from cycler import cycler

plt.rcParams['axes.prop_cycle'] = cycler(color=colors)



from scipy.stats import ks_2samp
from sklearn.neighbors import KernelDensity

import pandas as pd


from context import src, utils
from src.analyzer import DataAnalyzer, plot_fill_between
from utils.plot_utils import label_subplot, equalize_y_axes


def nice_print(lst):
    """Format a sequence of values as one ' | '-separated table row.

    Strings are used verbatim. Integers are zero-padded to two digits
    ('{:02}'). Other real numbers are rendered as '{:6.2f}'. Anything else
    falls back to ``str()``.

    Numeric checks use the ``numbers`` ABCs, so numpy scalars (e.g.
    ``np.int64`` group keys, ``np.float32`` values) are formatted like their
    builtin counterparts. Booleans are printed as ``True``/``False`` instead
    of being zero-padded to ``01``/``00``.

    Args:
        lst: iterable of values to render.

    Returns:
        The formatted row ('' for an empty input).
    """
    cells = []
    for value in lst:
        if isinstance(value, str):
            cells.append(value)
        elif isinstance(value, (bool, np.bool_)):
            # bool is a subclass of int; print it as a flag, not as 00/01.
            cells.append(str(value))
        elif isinstance(value, numbers.Integral):
            cells.append('{:02}'.format(int(value)))
        elif isinstance(value, numbers.Real):
            cells.append('{:6.2f}'.format(float(value)))
        else:
            cells.append(str(value))
    return ' | '.join(cells)

from collections import OrderedDict

def groupby(data, grouping_columns):
    """
    Groupby that has group names that are a dictionary.

    Groups ``data`` by ``grouping_columns`` and returns a list of
    ``(name, group)`` pairs, where ``name`` is an ``OrderedDict`` mapping each
    grouping column to that group's value (so callers can write
    ``name['gen_mode']``).

    Args:
        data: DataFrame to group.
        grouping_columns: a single column name or a list of column names.

    Returns:
        list of ``(OrderedDict, DataFrame)`` pairs in pandas' group order.
    """
    # Accept a bare column name; zipping a string would pair characters with keys.
    if isinstance(grouping_columns, str):
        grouping_columns = [grouping_columns]
    else:
        grouping_columns = list(grouping_columns)

    pairs = []
    for key, group in data.groupby(grouping_columns):
        # pandas < 2.0 yields a scalar key for a length-1 list of columns,
        # pandas >= 2.0 yields a 1-tuple; normalise both to a tuple.
        if not isinstance(key, tuple):
            key = (key,)
        pairs.append((OrderedDict(zip(grouping_columns, key)), group))
    return pairs

In [None]:
# Reload edited modules (e.g. src.analyzer, utils.plot_utils) automatically before running code.
%load_ext autoreload
%autoreload 2

In [None]:
def plot_results(ax, grouped, times=None):
    """Plot SNR vs. Time for the different inference modes.

    Draws one mean/confidence-interval curve per group via
    ``plot_fill_between`` and builds a legend of solid colour patches that
    match each curve's line colour.

    Args:
        ax: matplotlib Axes to draw into.
        grouped: list of ``(name, group)`` pairs as returned by ``groupby``.
            ``name['gen_mode']`` must be ``EXP`` or ``NOM``.
            ``name['inference_mode']`` must be ``'EM'`` or ``'NoMotion'``.
            ``group`` must have one column per time point.
        times: time points (column labels of ``group``) to plot. Defaults to
            the notebook-level ``t`` for backward compatibility.
    """
    if times is None:
        times = t  # notebook-global time axis
    gen_labels = {EXP: 'S:M', NOM: 'S:NM'}
    inf_labels = {'EM': 'D:EM', 'NoMotion': 'D:NM'}  # (was 'D:EM ' with a stray space)

    handles = []
    labels = []
    for name, group in grouped:
        label = gen_labels[name['gen_mode']] + ' | ' + inf_labels[name['inference_mode']]
        plot_fill_between(ax, times, group[list(times)], label=label, confidence_interval=True)
        # Reuse the colour of the line just drawn for the legend patch.
        color = ax.get_lines()[-1].get_color()
        handles.append(Rectangle((0, 0), 1, 1, fc=color, hatch=None, linewidth=0))
        labels.append(label)
    ax.legend(handles, labels, loc='upper left', prop={'size': 7}, labelspacing=0.2)

In [None]:
# Select which experiment output folder(s) to analyse. Exactly one assignment
# should be active; the alternatives are kept as comments.
# output_dir_list = ['motion_benefit_nom1', 'motion_benefit6']  # Main plots of motion benefit
# output_dir_list = ['motion_benefit_drop']  # Motion benefit in the presence of cone loss
# output_dir_list = ['motion_benefit_best_dc_sweep']  # Find best diffusion constant
# output_dir_list = ['extended_motion_benefit']  # Extended motion benefit with cone loss experiments
# output_dir_list = ['motion_gain']  # Experiments testing gain factor on eye motions
output_dir_list = ['stimlus_size']  # Experiments varying stimulus size (spelling must match the folder name)
output_dir_list = [os.path.join('../output', folder) for folder in output_dir_list]

# Collect every .h5 result file. NB: `output_dir` stays bound to the last
# folder and is used by later cells as the figure output directory.
data_fns = []
for output_dir in output_dir_list:
    data_fns.extend(
        os.path.join(output_dir, fn)
        for fn in os.listdir(output_dir)
        if fn.endswith('.h5')
    )
data_fns.sort()
len(data_fns)

In [None]:
# One DataAnalyzer per result file. Build a list rather than using map():
# da_ is iterated again in the next cell and indexed (da_[-1], da_[idx]) later,
# and on Python 3 map() returns a one-shot iterator.
da_ = [DataAnalyzer.fromfilename(fn) for fn in data_fns]

for da in da_:
    da.s_range = 'pos'  # NOTE(review): meaning of s_range is defined in DataAnalyzer

In [None]:
# One flat record per run: its parameters plus the SNR at every time point
# (keyed by the time value). NB: `da` is left bound to the last run, and the
# next cell reads it.
out_ = []
for da in da_:
    out = {'inference_mode': da.data['EM_data/mode']}
    motion_gen = da.data['motion_gen']
    out['gen_mode'] = motion_gen['mode']   # True eye movements or no eye movements
    out['gen_dc'] = motion_gen.get('dc', 0.0)
    try:
        out['drop_prob'] = da.data['drop_prob']
    except KeyError:  # runs saved without this key count as 0.0
        out['drop_prob'] = 0.0
    out['scaling_factor'] = motion_gen.get('scaling_factor', 1.0)
    out['prior_dc'] = da.data['motion_prior']['dc']
    out['ds'] = da.data['ds']  # Image size
    out.update(zip(da.time_list(), da.snr_list()))
    out_.append(out)

In [None]:
# gen_mode values: plotted as 'Motion' (EXP) and 'No Motion' (NOM).
EXP = 'Experiment'
NOM = 'Diffusion'  # no motion

# Time axis taken from the last run of the previous loop (assumed shared by all runs).
t = da.time_list()
tf = t[-1]  # final time point
data = pd.DataFrame(out_)

grouping_columns = [
    'inference_mode', 'gen_mode', 'gen_dc', 'prior_dc',
    'drop_prob', 'scaling_factor', 'ds',
]
grouped = groupby(data, grouping_columns)

In [None]:
# Final-time SNR as a function of one swept parameter, Motion vs No Motion.
# Choose the sweep by leaving exactly one configuration uncommented.

# # Cone-loss experiments
# tuning_key = 'drop_prob'
# x_label = 'Cone Loss'
# log_snr = False
# x_normalizer = 1.

# # Motion-gain experiments
# tuning_key = 'scaling_factor'
# x_label = 'Motion Gain'
# log_snr = False
# x_normalizer = 1.

# Stimulus-size experiments
tuning_key = 'ds'
x_label = 'Relative Stimulus Size'
log_snr = True      # log-scale y axis
x_normalizer = 0.4  # x values are divided by this (0.4 is the ds used for the comparisons below)
DS_MAX = 0.7        # stimulus sizes above this are excluded from the plot

y_label = 'SNR'

item_ = []
for name, group in grouped:
    if tuning_key == 'ds' and name[tuning_key] > DS_MAX:
        continue
    snrs = group[tf].values  # final-time SNR of each run in the group
    item_.append({
        tuning_key: name[tuning_key],
        'gen_mode': name['gen_mode'],
        'snr': snrs,
        'snr_mean': snrs.mean(),
        # NB: despite the key name, this is the standard error of the mean
        # (population std / sqrt(n)), used as the error bar.
        'snr_std': snrs.std() / np.sqrt(len(snrs)),
    })
df = pd.DataFrame(item_)

fig, ax = plt.subplots(1, 1, figsize=(3, 3))
for name, group in df.groupby('gen_mode'):
    ax.errorbar(
        group[tuning_key] / x_normalizer,
        group['snr_mean'], yerr=group['snr_std'],
        label={EXP: 'Motion', NOM: 'No Motion'}[name],
        color={EXP: 'b', NOM: 'r'}[name],
    )
ax.set_ylabel(y_label)
ax.set_xlabel(x_label)
if log_snr:
    ax.set_yscale('log')
ax.legend();

In [None]:
# FIXME: unclear how to show relative performance
# Mean difference in final-time SNR, SNR(Motion) - SNR(No Motion), per swept value.
x_label = 'Normalized Stimulus Size'
fig, ax = plt.subplots(1, 1, figsize=(3, 3))

# Rows: swept values in df order; columns: runs. np.stack requires every row
# to have the same number of runs (TODO confirm this holds for all sweeps).
snrs_ = {
    name: np.stack(group['snr'].values)
    for name, group in df.groupby('gen_mode')
}

# The row-wise subtraction below is only meaningful if both conditions list
# the same swept values in the same order.
x_exp = df.loc[df['gen_mode'] == EXP, tuning_key].values
x_nom = df.loc[df['gen_mode'] == NOM, tuning_key].values
assert np.array_equal(x_exp, x_nom), 'Motion / No Motion rows not aligned by ' + tuning_key
xx = x_exp / x_normalizer

# NOTE(review): No Motion runs are reversed along the run axis, presumably so
# that run i of one condition is not paired with run i of the other.
snr_diff = snrs_[EXP] - snrs_[NOM][:, ::-1]
yy = snr_diff.mean(axis=1)
yerr = snr_diff.std(axis=1) / np.sqrt(snr_diff.shape[1])  # standard error across runs

ax.errorbar(xx, yy, yerr=yerr, ls='-', color='black')
ax.set_ylabel('SNR (Motion) - SNR (No Motion)')
ax.set_xlabel(x_label)
ax.axhline(y=0, ls='--', c='black');
# ax.set_ylim([0, 32])

In [None]:
# One row per parameter group: index, grouping values, mean final SNR, run count.
# print(...) with a single argument behaves the same on Python 2 and 3.
print(['Group'] + grouping_columns + ['SNR', 'N_EXP'])
for i, (name, group) in enumerate(grouped):
    print(nice_print(
        [i]
        + list(name.values())
        + [group[tf].mean()]
        + [len(group)]
    ))

In [None]:
# # Select the best prior_dc for given gen_dc
# grouping_columns = ['inference_mode', 'gen_mode', 'gen_dc', 'ds', 'drop_prob',]
# grouped = groupby(data, grouping_columns)
# data1 = []
# for name, group in grouped:
# #     print group
#     subgrouped = groupby(group, ['prior_dc'])
#     snrs = [subgroup[tf].mean() for _, subgroup in subgrouped]
#     best_group_idx = np.argmax(snrs)
#     _, best_group = subgrouped[best_group_idx]
#     data1.append(best_group)

# data1 = pd.concat(data1)

# df1 = pd.groupby(
#     data1[['gen_dc', t[-1], 'drop_prob']], 
#     ['gen_dc', 'drop_prob']).aggregate([np.mean, np.std])
# df1.reset_index(inplace=True)


# fig, ax = plt.subplots(1, 1, figsize=(3, 3))
# for name, group in df1.groupby('drop_prob'):
#     ax.plot(group['gen_dc'], group[t[-1]]['mean'], label='Drop Prob: {:.2f}'.format(name))
# ax.set_ylabel('SNR')
# ax.set_xlabel('gen_dc')
# ax.set_xscale('log')
# ax.legend()
# plt.savefig('../output/drop_prob.pdf')

In [None]:
# SNR-vs-time curves for every parameter group on one set of axes.
fig, ax = plt.subplots(figsize=(4, 4))
ax.set(title='SNR as a function of time')
plot_results(ax, grouped)

In [None]:
# Final-time SNR and gen_mode label of each group.
final_snrs = [group[tf].values for _, group in grouped]
label_ = [name['gen_mode'] for name, _ in grouped]

# Two-sample Kolmogorov-Smirnov test between the first two groups only.
if len(final_snrs) != 2:
    warnings.warn('KS test uses only the first 2 of {} groups ({})'.format(
        len(final_snrs), label_[:2]))
res = ks_2samp(final_snrs[0], final_snrs[1])
# Default p-value is two-sided (author's note: two sided / 2 = one sided);
# ks_2samp(..., alternative='less'/'greater') gives a proper one-sided test.
print(res)



In [None]:
# Kernel density estimate of each group's final-time SNR (KernelDensity with
# default settings), titled with the KS p-value from the previous cell.
snr_grid = np.linspace(0, 20, 1000)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(4, 4))
for snr_values, label in zip(final_snrs, label_):
    kde = KernelDensity().fit(snr_values[:, np.newaxis])
    density = np.exp(kde.score_samples(snr_grid))
    ax.plot(snr_grid, density, c=None, label=label)
ax.set_xlabel('SNR')
ax.set_ylabel('Normalized Density')
ax.set_title('Density of SNR, p = {:.1E}'.format(res.pvalue))
ax.legend()
fig.tight_layout()
# fig.savefig(os.path.join(output_dir, 'motion_benefit_density.png'), dpi=200)

In [None]:
# Row indices (= positions in da_) of EM-decoded runs at ds = 0.4, for the
# EXP and NOM generation modes. ds is read back from the result files, so
# compare with np.isclose rather than exact float equality.
idx = [data[
        (data['gen_mode'] == key) &
        np.isclose(data['ds'], 0.4) &
        (data['inference_mode'] == 'EM')
    ].index.values for key in [EXP, NOM]]
exp_idx, nom_idx = idx

In [None]:
# Row indices (= positions in da_) of runs decoded with inference_mode
# 'NoMotion' at ds = 0.4, for the EXP and NOM generation modes. np.isclose
# avoids exact float equality on the stored ds values.
idx = [data[
        (data['gen_mode'] == key) &
        np.isclose(data['ds'], 0.4) &
        (data['inference_mode'] == 'NoMotion')
    ].index.values for key in [EXP, NOM]]
nom_exp_idx, nom_nom_idx = idx
# np.random.shuffle(exp_idx)
# np.random.shuffle(nom_idx)

In [None]:
# da = da_[nom_exp_idx[0]]

# fig, ax = plt.subplots(1, 1)
# da.plot_image_estimate(fig=fig, ax=ax, q=100)

In [None]:
# da = da_[exp_idx[2]]

# da.plot_em_estimate(-1)

In [None]:
# Inspect the last loaded run.
da = da_[-1]

In [None]:
# Image and RFs of the selected run (DataAnalyzer.plot_image_and_rfs) at index -1.
fig, ax = plt.subplots(figsize=(10, 10))
axes = ax
# da.plot_image_estimate(fig, ax, da.N_itr - 1)
da.plot_image_and_rfs(fig, ax, -1)

In [None]:
# Display the analyzer's `xr` attribute (defined in DataAnalyzer, not visible here).
da.xr

In [None]:
# Re-select the last loaded run as `da`.
da = da_[-1]

In [None]:
# Four-panel figure saved as ../output/motion_benefit_drop.pdf:
#   A (axes[0, 0]) plot_image_and_rfs of the motion-stimulus EM run,
#   B (axes[0, 1]) SNR vs time for every group,
#   C/D (axes[1, :]) final image estimates of the motion / no-motion runs.
# NOTE(review): the file name is fixed regardless of the output_dir_list
# selected above. The loop below rebinds the globals `da` and `label`, and
# the next figure cell reads `da`.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(3.5, 3.5))

# Last EM-decoded run at ds = 0.4 for each generation mode (EXP, NOM).
da_exp, da_nom = [da_[idx_[-1]] for idx_ in [exp_idx, nom_idx]]

ax = axes[0, 0]
da_exp.plot_image_and_rfs(fig, ax, -1)
# Zoom to [-4, 4] on both axes, with ticks at -3, 0, 3.
a = 4; ax.set_xlim([-a, a]); ax.set_ylim([-a, a])
b = 3; ax.set_xticks([-b, 0, b]); ax.set_yticks([-b, 0, b])

ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
# ax.xaxis.set_major_locator(MaxNLocator(3, symmetric=True))
# ax.yaxis.set_major_locator(MaxNLocator(3, symmetric=True))
# ax.axis('off')


ax = axes[0, 1]
plot_results(ax, grouped)
ax.set_xlim([0., 700.])
ax.set_xticks([0, 350, 700])
ax.set_title('SNR vs time (ms)')
# ax.set_xlabel('time (ms)')
# ax.set_ylabel('SNR')
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')


# ax.xaxis.set_major_locator(MaxNLocator(3))
ax.yaxis.set_major_locator(MaxNLocator(4))

# Bottom row: final image estimate of each run, titled with its SNR at the last iteration.
for v, da, label in zip(
    [0, 1],
    [da_exp, da_nom],
    ['Motion', 'No Motion']):
    ax = axes[1, v]
    da.plot_image_estimate(fig, ax, da.N_itr - 1,
                           colorbar=False)
    ax.set_title('{}: SNR = {:.2f}'.format(label, da.snr_one_iteration(da.N_itr - 1)))
    ax.set_xticks([])
    ax.set_yticks([])

#     ax.xaxis.set_major_locator(MaxNLocator(3))
#     ax.yaxis.set_major_locator(MaxNLocator(3))


# Uniform title size, panel letters A-D, and no spines on any panel.
for i, ax in enumerate(axes.flat):
    ax.set_title(ax.get_title(), fontdict={'size': 10})
    label=chr(ord('A') + i)
    label_subplot(fig=fig, ax=ax, label=label, dx=0.07, dy=0.03)

    
    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_visible(False)
    
plt.tight_layout(pad=0.5)
plt.savefig(os.path.join('../output', 'motion_benefit_drop.pdf'), dpi=200)

In [None]:
# Main motion-benefit figure, saved to ../output/motion_benefit.pdf:
#   A (ax1)          stimulus image (plot_base_image)
#   B (ax7)          SNR vs time for every group
#   C (ax2, ax3)     final estimates: S:M | D:EM (da_exp) and S:NM | D:EM (da_nom)
#   D (ax4_)         da_exp (EM decoder) at three snapshots
#   E (ax6_)         da_nom_exp (motion stimulus, NoMotion decoder) at the same snapshots
#   F (ax5_)         x / y path estimate of da_exp
fig = plt.figure(figsize=(7.5, 3.75))
size = (6, 12)  # subplot2grid layout: 6 rows x 12 columns
ax1 = plt.subplot2grid(size, (0, 0), colspan=3, rowspan=3)
ax7 = plt.subplot2grid(size, (0, 3), colspan=3, rowspan=3)
ax2 = plt.subplot2grid(size, (3, 0), colspan=3, rowspan=3)
ax3 = plt.subplot2grid(size, (3, 3), rowspan=3, colspan=3)
ax4_ = [plt.subplot2grid(size, (0, 6 + 2 * i), rowspan=2, colspan=2) for i in range(3)]
ax6_ = [plt.subplot2grid(size, (2, 6 + 2 * i), rowspan=2, colspan=2) for i in range(3)]
ax5_ = [plt.subplot2grid(size, (4, 7 + 2 * i), rowspan=2, colspan=2) for i in range(2)]


def _hide_spines(ax, sides=('top', 'bottom', 'left', 'right')):
    """Make the given spines of `ax` invisible."""
    for side in sides:
        ax.spines[side].set_visible(False)


# The 5th EM-decoded run (ds = 0.4) of each generation mode, and the first
# motion-stimulus run decoded with inference_mode 'NoMotion'.
da_exp, da_nom = [da_[idx_[4]] for idx_ in [exp_idx, nom_idx]]
da_nom_exp = da_[nom_exp_idx[0]]

# A: stimulus image. Uses da_exp explicitly instead of whichever `da` earlier
# cells left behind (presumably all ds = 0.4 runs share the base image —
# TODO confirm).
ax = ax1
da_exp.plot_base_image(fig=fig, ax=ax, colorbar=False)
ax.set_title('Image on the Retina')
ax.set_xticks([-3, 0, 3])
ax.set_yticks([-3, 0, 3])
_hide_spines(ax)

# B: SNR vs time for every group.
ax = ax7
plot_results(ax, grouped)
ax.set_xlim([0., 700.])
ax.set_xticks([0, 350, 700])
ax.set_title('SNR vs time (ms)')
_hide_spines(ax, ('top', 'right'))
ax.yaxis.set_major_locator(MaxNLocator(3))

# C: final image estimates, titled by condition, SNR in the x label.
for run, label, ax in zip(
        [da_exp, da_nom],
        ['S:M | D:EM', 'S:NM | D:EM'],
        [ax2, ax3]):
    last_itr = run.N_itr - 1
    run.plot_image_estimate(fig, ax, last_itr, colorbar=False)
    ax.set_title(label)
    ax.set_xlabel('SNR = {:.2f}'.format(run.snr_one_iteration(last_itr)))
    ax.set_xticks([])
    ax.set_yticks([])
    _hide_spines(ax)
    ax.set_aspect('equal')

# D / E: snapshots of the same motion stimulus decoded with EM (ax4_) and
# with the NoMotion decoder (ax6_).
for i, q in enumerate([9, 39, 79]):
    # Snapshot time read from the EM run and used for both rows (assumes both
    # runs share the time grid).
    t0 = da_exp.data['EM_data/{}/time_steps'.format(q)]
    for ax, run in [(ax4_[i], da_exp), (ax6_[i], da_nom_exp)]:
        run.plot_image_estimate(fig, ax, q=q, colorbar=False)
        ax.set_title('T = {} ms'.format(t0), fontsize=7)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        _hide_spines(ax)

# F: x / y path estimate of the EM motion run at q = 139.
q = 139
for i, title in enumerate(['x', 'y']):
    ax = ax5_[i]
    da_exp.plot_path_estimate(ax, q, i)
    ax.set_title(title)
    ax.set_ylabel('')
    ax.set_xlabel('')
    # Presumably the x axis is in seconds; the 0.7 tick is labelled 700 (ms).
    ax.set_xticks([.7])
    ax.set_xticklabels([700])
    ax.set_yticks([0, 5])
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    _hide_spines(ax, ('top', 'right'))
equalize_y_axes(ax5_[0], ax5_[1])

# Panel letters A-F (ax3 gets no letter of its own).
for i, ax in enumerate([ax1, ax7, ax2, ax4_[0], ax6_[0], ax5_[0]]):
    label_subplot(fig=fig, ax=ax, label=chr(ord('A') + i), dx=0.03, dy=0.01)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

for ax in fig.axes:
    ax.title.set_fontsize(10)

plt.subplots_adjust(wspace=1, hspace=3, top=0.9, left=0.05, right=0.95, bottom=0.1)
plt.savefig('../output/motion_benefit.pdf')

In [None]:
# plt.figure(figsize=(4, 4))
# plt.title('SNR as a function of time')
# alpha = 0.5
# for c, (name, group) in zip(c_, grouped):
#     inf_mode, gen_mode, prior_dc, ds = name
#     label = {EXP: 'Motion', NOM: 'No Motion'}[gen_mode]
#     plt.plot(t, group[list(t)].T.iloc[:, 0], c=c, label=label, alpha=alpha)
#     plt.plot(t, group[list(t)].T.iloc[:, 1:], c=c, alpha=alpha);
#     plt.xlabel('time (ms)')
# plt.legend(loc='upper left')
# # plt.savefig(os.path.join(output_dir, 'snr_time_no_motion_vs_motion_0.png'), dpi=200)

In [None]:
# Row indices of runs with gen_mode NOM ('No Motion') decoded with
# inference_mode 'NoMotion', across all stimulus sizes (no ds filter).
idx = data[
    (data['gen_mode'] == NOM) &
    (data['inference_mode'] == 'NoMotion')
].index.values
print(idx)

In [None]:
# First run matched by the selection in the previous cell.
da = da_[idx[0]]

In [None]:
# Image estimate of the selected run at q=-1.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
da.plot_image_estimate(fig=fig, ax=ax, q=-1)

In [None]:
# DataAnalyzer.plot_em_estimate for the selected run at index -1.
da.plot_em_estimate(-1)
# plt.savefig(os.path.join(output_dir, 'ds04_e_reconstruction.png'), dpi=200)

In [None]:
# Final image estimates of up to 10 EM-decoded NOM runs, saved in output_dir
# (the last data folder scanned when listing the files).
fig, axes = plt.subplots(5, 2, figsize=(4 * 2, 3.5 * 5))
for ax, run_idx in zip(axes.flat, nom_idx[:10]):
    da_[run_idx].plot_image_estimate(fig, ax, -1)
plt.savefig(os.path.join(output_dir, 'e_reconstructions_no_motion.png'), dpi=200)

In [None]:
# Final image estimates of up to 10 EM-decoded EXP runs, saved in output_dir
# (the last data folder scanned when listing the files).
fig, axes = plt.subplots(5, 2, figsize=(4 * 2, 3.5 * 5))
for ax, run_idx in zip(axes.flat, exp_idx[:10]):
    da_[run_idx].plot_image_estimate(fig, ax, -1)
plt.savefig(os.path.join(output_dir, 'e_reconstructions_motion.png'), dpi=200)

In [None]:
# Image estimates of `da` at q = 0, 10, ..., 130.
# Fixed: plot_image_estimate takes (fig, ax, q) as in the other cells (the
# old call passed q as `fig`), and the loop variable no longer overwrites the
# notebook-global time axis `t`.
snapshot_qs = list(range(0, 140, 10))
fig, axes = plt.subplots(5, 5, figsize=(20, 20))
for ax, q in zip(axes.flat, snapshot_qs):
    da.plot_image_estimate(fig, ax, q=q)
# Hide the grid cells left over after the snapshots.
for ax in axes.flat[len(snapshot_qs):]:
    ax.axis('off')

# Group SNR plots by image size and diffusion constant

In [None]:
# Legacy summary: one record per run, (ds, generation dc, SNR at each time point).
# Fixed:
#   - `pkl_fns` was never defined; iterate the result files in data_fns.
#   - The SNR accessor is snr_list (as used when building the main records).
#   - SNRs are converted to a list so they concatenate instead of broadcasting.
out_ = []
for fn in data_fns:
    da = DataAnalyzer.fromfilename(fn)
    record = [da.data['ds'], da.data['motion_gen'].get('dc', 0.0)]
    out_.append(record + list(da.snr_list()))

In [None]:
# Time axis (SNR column labels) taken from the first loaded run.
t = da_[0].time_list()

In [None]:
# One row per run: stimulus size, generation dc, then SNR per time point.
columns = ['ds', 'dc_gen'] + list(t)
data = pd.DataFrame.from_records(out_, columns=columns)
# data = data[data['ds'] == 0.5]

In [None]:
# Top-level pd.groupby() is deprecated and removed in later pandas releases;
# use the DataFrame method instead.
grouped = data.groupby(['ds', 'dc_gen'])
c_ = ['g', 'b', 'r', 'y', 'm', 'c']  # one colour per group; zip() below stops after 6 groups

In [None]:
# SNR vs time for each (ds, dc_gen) group, at most one group per colour in c_.
fig, ax = plt.subplots(figsize=(10, 7))
ax.set_title('SNR as a function of time')
for c, (name, group) in zip(c_, grouped):
    label = 'ds={}, dc={}'.format(*name)
    # plot_fill_between takes the target Axes first (as in plot_results).
    # NOTE(review): `c=c` is carried over from the original call; confirm
    # that plot_fill_between accepts it.
    plot_fill_between(ax, t, group[list(t)], label=label, c=c)
ax.legend(loc='upper left')
# fig.savefig(os.path.join(output_dir, 'motion_benefit.png'), dpi=200)

# Uncertainty

In [None]:
# Inspect the first loaded run.
da = da_[0]

In [None]:
# Raw stored entry 'EM_data/0' (the same key family as 'EM_data/<q>/time_steps').
da.data['EM_data/0']