In [None]:
import os

import numpy as np
from scipy.stats import ks_2samp
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

# Global figure styling: nearest-neighbour grey images, small Arial text,
# hairline strokes and frameless, square-cornered legends.
LINEWIDTH = 0.25

_rc_overrides = [
    ('image.interpolation', 'nearest'),
    ('image.cmap', 'gray'),
    ('font.size', 7),
    ('font.sans-serif', [u'Arial']),
    ('lines.linewidth', LINEWIDTH),
    ('axes.linewidth', LINEWIDTH),
    ('xtick.major.width', LINEWIDTH),
    ('ytick.major.width', LINEWIDTH),
    ('legend.fancybox', False),
    ('legend.frameon', False),
]
# Item assignment (not dict.update) so each value goes through rcParams validation.
for _rc_key, _rc_value in _rc_overrides:
    plt.rcParams[_rc_key] = _rc_value

from matplotlib.ticker import MaxNLocator

import brewer2mpl
# Qualitative 6-colour ColorBrewer 'Set1' palette, used as the default line colour cycle.
bmap = brewer2mpl.get_map('Set1', 'qualitative', 6)
colors = bmap.mpl_colors
# 'axes.color_cycle' was deprecated in matplotlib 1.5 and removed in 2.0 in
# favour of 'axes.prop_cycle'; support both so the notebook runs on either.
if 'axes.prop_cycle' in matplotlib.rcParams:
    matplotlib.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=colors)
else:
    matplotlib.rcParams['axes.color_cycle'] = colors

import pandas as pd


from context import src, utils
from src.analyzer import DataAnalyzer, plot_fill_between
from utils.rf_plot import show_fields
from utils.plot_utils import label_subplot, expand_legend_linewidths

# Dictionary (prior) names as stored in each trial's data['D_name'].
INDEP = 'Indep'
SPAR = 'Sparse'
PCA = 'PCA'

# Short display label for each dictionary type.
dict_label_ = {
    INDEP: 'IND',
    SPAR: 'SP',
    PCA: 'PCA',
}


def get_label(D_name, lamb, quad_reg):
    """Get a short plot label for a prior / regularisation combination.

    Parameters
    ----------
    D_name : str
        Dictionary name (``INDEP``, ``SPAR`` or ``PCA``).
    lamb : float
        L1 penalty weight; only consulted for the sparse prior.
    quad_reg : float
        Quadratic (L2) penalty weight; only consulted for the sparse prior.

    Returns
    -------
    str
        ``'PCA'`` or ``'IND'`` for those dictionaries. For the sparse prior,
        ``'SP '`` followed by ``'Z'`` (no penalty), ``'L2'``, ``'L1'`` or
        ``'L12'`` (both). If a weight is negative or NaN, no suffix is added
        and the result is just ``'SP '``. Any other ``D_name`` is returned as
        a string; previously this raised ``UnboundLocalError``.
    """
    if D_name == PCA:
        return 'PCA'
    if D_name == INDEP:
        return 'IND'
    if D_name != SPAR:
        return str(D_name)

    s = 'SP '
    if lamb == 0 and quad_reg == 0:
        s += 'Z'
    elif lamb == 0 and quad_reg > 0:
        s += 'L2'
    elif lamb > 0 and quad_reg == 0:
        s += 'L1'
    elif lamb > 0 and quad_reg > 0:
        s += 'L12'
    return s

In [None]:
# Collect every HDF5 result file from the experiment output directories.
# NOTE: `output_dir` deliberately outlives the loop; the video cell uses it.
data_fns = []
for output_dir in ['../output/natural_sparsity_van_hateren_sp_quad_reg']:
    h5_names = [fn for fn in os.listdir(output_dir) if fn.endswith('.h5')]
    data_fns += [os.path.join(output_dir, fn) for fn in h5_names]
data_fns = sorted(data_fns)
len(data_fns)

In [None]:
# Load one DataAnalyzer per result file; list() keeps it indexable and re-iterable.
da_ = list(map(DataAnalyzer.fromfilename, data_fns))

In [None]:
# One flat record per trial: prior/regularisation settings plus SNR at each time.
records = []
for da in da_:
    meta = da.data
    record = {
        'D_name': meta['D_name'],
        'ds': meta['ds'],
        'lamb': meta['lamb'],
        'dc': meta['motion_gen']['dc'],
        'qr': meta['quad_reg'].mean(),
    }
    # One column per time point, holding the SNR at that time.
    record.update(zip(da.time_list(), da.snr_list()))
    # Integer fingerprint of the generating image; presumably lets trials on the
    # same image be grouped together (it is used as a groupby key later).
    record['s_gen'] = int(da.S_gen.sum() * 10000)
    records.append(record)

In [None]:
data = pd.DataFrame(records)
# Keep rows with lamb < 0.01 or qr == 0, i.e. drop the extra trials that
# combine a large L1 weight with a non-zero quadratic weight.
data = data[(data['lamb'] < 0.01) | (data['qr'] == 0)]
# Time axis taken explicitly from the last analyzer, rather than from the loop
# variable `da` leaking out of the previous cell. Assumes every trial shares
# the same time axis.
t = da_[-1].time_list()
tf = t[-1]
grouping_columns = ['D_name', 'ds', 'lamb', 'dc', 'qr']
# DataFrame.groupby replaces the deprecated top-level pd.groupby.
grouped = data.groupby(grouping_columns)
len(grouped)

In [None]:
# One summary line per group: index, trial count, key values, and the mean
# SNR at the final time point.
for i, (name, group) in enumerate(grouped):
    desc = ' | '.join('{} {}'.format(col, val)
                      for col, val in zip(grouping_columns, name))
    print('Group: {} | Size: {:2d} | '.format(i, len(group)) + desc)
    print(group[tf].mean())

In [None]:
# Normalise every trial's SNR by the mean PCA SNR on the same generating image
# (same `s_gen`), so PCA trials average to 1.0 per image. Images without a PCA
# trial get no '<t>_norm_snr' columns and end up NaN after the concat.
group_ = []
for key, group in data.groupby('s_gen'):
    group = group.copy()
    pca_trials = group[group['D_name'] == PCA]
    for tt in t:
        pca_snr = pca_trials[tt].mean()
        if np.isnan(pca_snr):
            continue
        group[str(tt) + '_norm_snr'] = group[tt] / pca_snr
    group_.append(group)
# Column labels for SNR relative to PCA; a real list so t_n[-1] works on Py2 and Py3.
t_n = [str(x) + '_norm_snr' for x in t]
data = pd.concat(group_)

In [None]:
# Mean and standard deviation of the final normalised SNR for each
# (prior, L1 weight, quadratic weight) combination, pooled over other keys.
final_norm_col = str(tf) + '_norm_snr'
data.groupby(['D_name', 'lamb', 'qr'])[final_norm_col].agg(['mean', 'std'])

In [None]:
# Rebuild the grouping now that `data` carries the normalised-SNR columns.
# DataFrame.groupby replaces the deprecated top-level pd.groupby.
grouped = data.groupby(grouping_columns)

In [None]:
# da = da_[-3]

# plt.imshow(da.data['S_gen'], vmin=-0.5, vmax=0.5, cmap='bwr')
# plt.colorbar()

# da.s_range = 'sym'
# fig, ax = da.plot_em_estimate(-1, figsize=(10, 10))

In [None]:
# for i, da in enumerate(da_):
#     print 'File {:3d} ds {:.2f}, dname: {:7s} img {:7.3f}, lamb: {:.3f}'.format(
#         i, da.data['ds'], da.data['D_name'], da.snr_one_iteration(da.N_itr - 1), da.data['lamb'])

In [None]:
# Normalised SNR (relative to the per-image PCA mean) over time, one curve per group.
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_title('SNR rel PCA as a function of time (ms)')
for name, group in grouped:
    D_name, ds, lamb, dc, quad_reg = name
    label = get_label(D_name, lamb, quad_reg)
    plot_fill_between(ax, t, group[t_n], label=label, c=None, k=1.)
ax.set_xlabel('time (ms)')
ax.set_ylabel('SNR rel PCA')
ax.legend(loc='upper left')
plt.savefig(os.path.join('../output/', 'dict_compare.png'), dpi=200)

In [None]:
# Final-time normalised SNR samples and the matching groupby key, per group.
names, final_snrs = [], []
for k, group in grouped:
    names.append(k)
    final_snrs.append(group[t_n[-1]].values)

In [None]:
# Display the first groupby key (wrapped in a 1-tuple).
(names[0],)

In [None]:
def get_label_(row):
    """Display label for a `data` row (reads its 'D_name', 'lamb', 'qr' columns)."""
    return get_label(row['D_name'], row['lamb'], row['qr'])


def get_label__(name):
    """Display label for a groupby key ordered (D_name, ds, lamb, dc, qr)."""
    return get_label(name[0], name[2], name[4])

In [None]:
# Two-sample KS test of each group's final normalised SNR against the PCA
# group, keyed by display label. The reference is looked up by name rather
# than by a positional index (which silently depended on groupby sort order).
ref_snrs = final_snrs[[name[0] for name in names].index(PCA)]
pvals = {}
for name, final_snr in zip(names, final_snrs):
    print(name)
    pval = ks_2samp(ref_snrs, final_snr).pvalue
    pvals[get_label__(name)] = pval

In [None]:
# Display the KS-test p-value of each group, keyed by its display label.
pvals

In [None]:
def pval_to_star(p):
    """Convert a p-value into the usual significance-star code.

    Returns '****' for p <= 1e-4, '***' for p <= 1e-3, '**' for p <= 0.01,
    '*' for p <= 0.05 and 'ns' otherwise (including NaN).
    """
    thresholds = (
        (0.0001, '****'),
        (0.001, '***'),
        (0.01, '**'),
        (0.05, '*'),
    )
    for cutoff, code in thresholds:
        if p <= cutoff:
            return code
    return 'ns'

In [None]:
# Map each group label to its significance-star code.
stars = dict((lbl, pval_to_star(p)) for lbl, p in pvals.items())


In [None]:
# Display the significance-star code for each group label.
stars

In [None]:
# Ad-hoc two-sample KS test between two groups picked by position in `names`.
# NOTE(review): positions depend on groupby key ordering; confirm which groups
# indices 1 and 4 refer to before quoting the result.
q1, q2 = 1, 4
print names[q1], names[q2]
ks_2samp(final_snrs[q1], final_snrs[q2])

In [None]:
# Inspect the raw final normalised SNR samples of group q2.
final_snrs[q2]

In [None]:
# Row indices of the trials for each prior, ordered INDEP, PCA, SPAR. The index
# of `data` is each trial's position in da_, so entries can be used as da_[i].
idx__ = [data[data['D_name'] == key].index.values for key in [INDEP, PCA, SPAR]]
# Unpack in the same order as the list above (previously PCA and sparse were swapped).
indep_idx_, pca_idx_, spar_idx_ = idx__

In [None]:
grouping_columns = ['D_name', 'ds', 'lamb', 'dc', 'qr']
# DataFrame.groupby replaces the deprecated top-level pd.groupby.
grouped = data.groupby(grouping_columns)
len(grouped)

In [None]:
# Attach the display label (e.g. 'SP L1') to every trial row.
data['label'] = data.apply(
    lambda row: get_label(row['D_name'], row['lamb'], row['qr']), axis=1)

In [None]:
# Horizontal bar chart of the final normalised SNR per prior, annotated with
# significance stars from the KS tests against PCA.
fig, ax = plt.subplots(figsize=(7/4., 3.6/2))

order = [
    'IND',
    'SP Z',
    'PCA',
    'SP L1',
    'SP L2',
    'SP L12',
]

sns.barplot(
    ax=ax,
    y='label',
    x=t_n[-1],
    data=data,
    order=order,
)
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title('SNR rel PCA at 600 ms')

for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_visible(False)

# Assumes seaborn draws one bar per category, in `order`.
for label, patch in zip(order, ax.patches):
    w = patch.get_width()
    h = patch.get_height()
    y = patch.get_y()
    star = stars[label]
    # Compare by value: `is not` tests identity and only worked via string interning.
    if star != 'ns':
        ax.text(w + 0.1, y + 0.5 * h, star)

ax.yaxis.tick_right()

In [None]:
# Data-space aspect ratio of the bar chart (x-span over y-span). The y-limits
# of the categorical axis are presumably decreasing, hence the minus sign;
# this matches the aspect formula used for the combined figure.
aspect = np.diff(ax.get_xlim()) / -np.diff(ax.get_ylim())
aspect


In [None]:
# NOTE(review): interactive IPython help lookup left over from development;
# it does not affect the analysis and can be removed.
ax.text?

In [None]:
# Debug: y-position of the last bar patch left over from the annotation loop.
# NOTE(review): depends on a loop variable from an earlier cell; remove when done.
patch.get_y()

In [None]:
# Debug: list the bar patches on the current axes.
ax.patches

In [None]:
# Debug: height of the last bar patch left over from the annotation loop.
# NOTE(review): depends on a loop variable from an earlier cell; remove when done.
patch.get_height()

In [None]:
# Mean and standard error of the final normalised SNR for every group.
x_label_, yy_, yys_ = [], [], []
for key, group in grouped:
    final = group[t_n[-1]]
    x_label_.append(key)
    yy_.append(final.mean())
    yys_.append(final.std() / np.sqrt(len(final)))

In [None]:
# Recompute the final normalised SNR per group, then run a two-sample KS test
# between groups 0 and 2. NOTE(review): positional indices follow groupby order.
final_snrs = [grp[t_n[-1]].values for key_, grp in grouped]
res = ks_2samp(final_snrs[0], final_snrs[2])
print(res)

In [None]:
# x positions for the per-group bars, one per groupby key.
xx = np.arange(len(x_label_))

In [None]:
# Bar chart (mean +/- s.e.m.) of the final normalised SNR per group, labelled
# by groupby key. Explicit figure/axes instead of the pyplot state machine.
fig, ax = plt.subplots()
ax.bar(xx, yy_, 0.25, yerr=yys_)
ax.set_xticks(xx)
ax.set_xticklabels([str(key) for key in x_label_], rotation=90)
ax.set_ylabel('SNR rel PCA')

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(7, 3.6))

# Pick one generating image (the 6th s_gen group). For that image, take the
# trial of each prior: IND, PCA, and sparse with both L1 and quadratic penalties.
key, group = list(data.groupby(['s_gen']))[5]

idx0 = group[
    (group['D_name'] == INDEP)
].index.values[0]

idx1 = group[
    (group['D_name'] == PCA)
].index.values[0]

idx2 = group[
    (group['D_name'] == SPAR) &
    (group['lamb'] > 0) &
    (group['qr'] > 0)
].index.values[0]

idx_ = [idx0, idx1, idx2]

# Seeded RNG so the displayed dictionary subset is reproducible across re-runs.
rng = np.random.RandomState(0)

# Columns 0-2: top row shows 25 random dictionary elements, bottom row the reconstruction.
for i, idx in enumerate(idx_):
    da = da_[idx]
    ax = axes[1][i]
    da.plot_image_estimate(fig, ax, -1, colorbar=False)
    ax.set_title('SNR = {:.2f}'.format(da.snr_one_iteration(da.N_itr - 1)))
    ax.set_axis_off()

    ax = axes[0][i]
    D = np.asarray(da.data['D'])
    ax.set_title('Dictionary: {}'.format(dict_label_[da.data['D_name']]))
    # Select a random subset of rows without shuffling da.data['D'] in place,
    # which mutated the analyzer's stored dictionary on every run.
    rows = rng.permutation(len(D))[:25]
    show_fields(D[rows], pos_only=False, fig=fig, ax=ax, colorbar=False, normed=True)
    ax.set_axis_off()

# Bottom-right: original image and RFs for the last trial plotted above (sparse L1+L2).
ax = axes[1][-1]
da.plot_image_and_rfs(fig=fig, ax=ax, alpha_rf=0.25)
a = da.data['ds'] * da.L_I / 2
ax.set_xlim([-a, a])
ax.set_ylim([-a, a])
ax.set_title('Original')
ax.set_axis_off()

# Top-right: final normalised SNR per prior with KS-test significance stars.
ax = axes[0][-1]

order = [
    'IND',
    'SP Z',
    'PCA',
    'SP L1',
    'SP L2',
    'SP L12',
]

sns.barplot(
    ax=ax,
    y='label',
    x=t_n[-1],
    data=data,
    order=order,
)
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title('SNR rel PCA at 600 ms')

for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_visible(False)

# Assumes seaborn draws one bar per category, in `order`.
for label, patch in zip(order, ax.patches):
    w = patch.get_width()
    h = patch.get_height()
    y = patch.get_y()
    star = stars[label]
    # Compare by value: `is not` tests identity and only worked via string interning.
    if star != 'ns':
        ax.text(w + 0.1, y + 0.5 * h, star)

ax.yaxis.tick_right()

# Panel letters a-h, assigned column by column.
for i, ax in enumerate(axes.T.flat):
    label = chr(ord('a') + i)
    label_subplot(fig=fig, ax=ax, label=label, dy=0.03)

    ax.set_title(ax.get_title(), fontdict={'size': 7})

plt.subplots_adjust(hspace=0.4)
plt.savefig(os.path.join('../output', 'natural_dict_and_rec.pdf'), dpi=300)

In [None]:
# Scratch check of Axes.set_aspect on an empty 1x2 figure.
# NOTE(review): appears not to feed into any saved figure; candidate for removal.
fig, axes = plt.subplots(1, 2)
ax = axes[0]
ax.set_aspect(2)

# Create a video

In [None]:
# Directory for rendered video frames, inside the data output directory.
# NOTE(review): relies on `output_dir` left over from the file-discovery loop.
video_dir = os.path.join(output_dir, 'video')
already_there = os.path.exists(video_dir)
if not already_there:
    os.makedirs(video_dir)

Create a video from the rendered frames (run inside `video_dir`):
```bash
avconv -framerate 20 -i img_%04d.png -c:v libx264 -r 30 rec.mp4
```

In [None]:
# for i in range(da.N_itr):
#     print 'Rendering image {:04d}'.format(i)
# #     da_[spar_idx[1]].plot_em_estimate(i)
# #     plt.savefig(os.path.join(video_dir, 'img_{:04d}.png'.format(i)), dpi=150)
# #     plt.close()

Dictionary with Reconstructions after 200 ms, DC = 100

In [None]:
def plot_snr_fcn_time(fig, ax, grouped, label_, ds_value=0.70, times=None):
    """Plot SNR vs. time for every group whose ``ds`` equals ``ds_value``.

    Parameters
    ----------
    fig : matplotlib Figure
        Unused; kept for call compatibility.
    ax : matplotlib Axes
        Axes to draw on.
    grouped : pandas GroupBy
        Groups keyed by tuples that start with ``(D_name, ds, ...)``. Each
        group must have one column per time point in ``times``.
    label_ : dict
        Maps ``D_name`` to a legend label. Unknown names are shown verbatim.
    ds_value : float, optional
        Only groups whose ``ds`` is close to this value are drawn
        (default 0.70, as before).
    times : sequence, optional
        Time points to plot. Defaults to the notebook-level ``t``.
    """
    if times is None:
        times = t  # notebook-level time axis
    times = list(times)

    # One colour per prior, assigned in order of first appearance. The old
    # hard-coded list had a missing comma ('g' 'g' -> 'gg', an invalid colour)
    # and its zip() silently dropped groups beyond the list's length.
    palette = ['r', 'g', 'b']
    color_of = {}

    ax.set_title('SNR as a function of time')
    for name, group in grouped:
        # Works for (D_name, ds) keys as well as the longer grouping_columns keys.
        D_name, ds = name[0], name[1]
        if not np.isclose(ds, ds_value):
            continue
        if D_name not in color_of:
            color_of[D_name] = palette[len(color_of) % len(palette)]
        label = label_.get(D_name, D_name)
        plot_fill_between(ax, times, group[times], label=label,
                          c=color_of[D_name], k=0.5)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('SNR')
    ax.legend(loc='upper left', prop={'size': 6})

In [None]:
from src.analyzer import _get_sum_gaussian_image

In [None]:
def plot_image_estimate(self, fig, ax, q, cmap=plt.cm.gray,
                        colorbar=True, vmax=None):
    """Draw the image estimate from EM iteration ``q`` on ``ax``.

    This is a free-function copy of the analyzer method; call it as
    ``plot_image_estimate(da, fig, ax, q, ...)``. ``q == -1`` selects the
    last iteration (``self.N_itr - 1``). The stored ``image_est`` is rendered
    with ``_get_sum_gaussian_image`` over ``self.xs``/``self.ys`` using width
    ``ds / sqrt(2)``. The title reports the SNR at that iteration, and a
    colorbar is added when ``colorbar`` is true.
    """
    itr = self.N_itr - 1 if q == -1 else q

    image_est = self.data['EM_data'][itr]['image_est'].ravel()
    res = _get_sum_gaussian_image(image_est, self.xs, self.ys,
                                  self.data['ds'] / np.sqrt(2), n=100)

    ax.set_title('Estimated Image, S = DA:\n SNR = %.2f'
                 % self.snr_one_iteration(itr))

    # FIXME: extent calculation could break in future
    half_width = self.data['ds'] * self.L_I / 2
    extent = [-half_width, half_width, -half_width, half_width]
    image = ax.imshow(res, cmap=cmap, interpolation='nearest',
                      extent=extent, vmax=vmax)
    if colorbar:
        fig.colorbar(image, ax=ax)

In [None]:
# The second trial of each prior, ordered IND, SP, PCA. `idx` was never
# defined in this notebook; idx__ holds the per-prior row indices (ordered
# INDEP, PCA, SPAR), which is presumably what `idx` referred to.
trial_idx_ = [idx__[u][1] for u in [0, 2, 1]]

label_ = {'Indep': 'IND',
          'Sparse': 'SP',
          'Non-sparse': 'N-SP',
          'PCA': 'PCA'}

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(3.5, 4))

# Right column: one reconstruction per prior on a shared colour scale.
for (u, v), trial in zip([[0, 1], [1, 1], [2, 1]], trial_idx_):
    da = da_[trial]
    plot_image_estimate(da, fig, axes[u][v], -1, colorbar=False,
                        vmax=2.8)
    axes[u][v].set_title('{}: SNR = {:.2f}'.format(
        label_[da.data['D_name']], da.snr_one_iteration(da.N_itr - 1)))

plot_snr_fcn_time(fig, axes[2][0], grouped, label_)

da.plot_base_image(fig, axes[0][0])
axes[0][0].set_title('Original Pattern')

da.plot_image_and_rfs(fig, axes[1][0], legend=False)
for u in [0, 1]:
    axes[u][0].set_xlabel('x (arcmin)')
    axes[u][0].set_ylabel('y (arcmin)')
axes[1][0].set_title('Pattern and RFs')

for ax in axes.flat:
    ax.set_title(ax.get_title(), fontdict={'size': 7})

plt.tight_layout(pad=0.2)
# plt.savefig(os.path.join(output_dir, 'sparsity.pdf'), dpi=300)

In [None]:
# Dictionaries (top row) and reconstructions (bottom row) for the first trial
# of the first two priors in idx__ (INDEP, PCA). `idx` was never defined in
# this notebook; idx__ is used instead.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4.5))

for i, trial_idx in enumerate([idx__[u] for u in [0, 1]]):
    da = da_[trial_idx[0]]
    da.plot_image_estimate(fig, axes[1][i], -1)
    axes[1][i].set_title('SNR = {:.2f}'.format(da.snr_one_iteration(da.N_itr - 1)))
    D = da.data['D']
    D_name = da.data['D_name']
    # Draw on an explicit axes rather than via plt.subplot's current-axes state.
    axes[0][i].set_title('Dictionary: {}'.format(D_name))
    show_fields(D, pos_only=True, fig=fig, ax=axes[0][i])

# plt.savefig(os.path.join(output_dir, 'dict_and_rec.pdf'), dpi=300)

In [None]:
# Reconstruction of one sparse-prior trial at several EM iterations, next to
# the true image. `spar_idx` was never defined in this notebook, so the sparse
# trials are looked up directly.
spar_trials_ = data[data['D_name'] == SPAR].index.values
da = da_[spar_trials_[1]]
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4))
for ax, ii in zip(axes.flat, [None, 0, 14, 24, 59, 99]):
    # plot_base_image / plot_image_estimate take explicit (fig, ax), as in
    # the other figure cells of this notebook.
    if ii is None:
        da.plot_base_image(fig, ax)
        ax.set_title('True Image')
    else:
        da.plot_image_estimate(fig, ax, ii)
        # NOTE(review): label assumes 2 ms per EM iteration, as originally written.
        ax.set_title('t = {} ms'.format(ii * 2 + 2))
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'sparse_rec_time.png'), dpi=200)

Plot of Dictionaries

In [None]:
# Dictionaries of three example trials (positions 0, 20 and 40 in da_), drawn
# on explicit axes instead of the pyplot current-axes state.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 3))
for ax, q in zip(axes, [0, 20, 40]):
    da = da_[q]
    D = da.data['D']
    D_name = da.data['D_name']
    ax.set_title('Dictionary: {}'.format(D_name))
    show_fields(D, pos_only=True, fig=fig, ax=ax)
# plt.savefig(os.path.join(output_dir, 'dictionaries.png'), dpi=250)

In [None]:
# Tuning curves of the trial `da` left over from the previous cell (da_[40]).
# NOTE(review): depends on `da` from the cell above; set it explicitly if cells move.
plt.figure(figsize=(3, 3))
da.plot_tuning_curves()
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'firing_rate.png'), dpi=200)