In [None]:
import os

import numpy as np
from scipy.stats import ks_2samp
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.size'] = 7
from matplotlib.ticker import MaxNLocator

import pandas as pd


from context import src, utils
from src.analyzer import DataAnalyzer, plot_fill_between
from utils.rf_plot import show_fields
from utils.plot_utils import label_subplot


# Directory holding the pickled (.pkl) / HDF5 (.h5) result files loaded below.
output_dir = '../output/natural_sparsity_van_hateren3'

In [None]:
# Sorted full paths of every result file (.pkl or .h5) in output_dir.
pkl_fns = sorted(
    os.path.join(output_dir, fn)
    for fn in os.listdir(output_dir)
    if fn.endswith(('.pkl', '.h5'))
)
len(pkl_fns)

In [None]:
# One DataAnalyzer per result file, in the same (sorted) order as pkl_fns.
da_ = list(map(DataAnalyzer.fromfilename, pkl_fns))

In [None]:
# One record per run: the run's parameters plus one column per time point
# holding the SNR at that time.
records = []
for da in da_:
    record = {
        'D_name': da.data['D_name'],
        'ds': da.data['ds'],
        'lamb': da.data['lamb'],
        'dc': da.data['motion_gen']['dc'],
    }
    record.update(zip(da.time_list(), da.snr_list()))
    records.append(record)

In [None]:
# Row i of `data` corresponds to da_[i] (default RangeIndex over `records`).
data = pd.DataFrame(records)
# Time axis taken explicitly from the last analyzer instead of the loop
# variable leaked by the previous cell (same object on an in-order run).
# NOTE(review): assumes every run shares the same time axis — TODO confirm.
t = da_[-1].time_list()
grouping_columns = ['D_name', 'ds', 'lamb', 'dc']
# Top-level pd.groupby was removed from pandas; use the DataFrame method.
grouped = data.groupby(grouping_columns)
len(grouped)

In [None]:
# Print each group's key and size, then the mean SNR at the final time point.
# Single-argument print(...) works on both Python 2 and 3.
group_fmt = ' | '.join(col + ' {}' for col in grouping_columns)
final_t = list(t)[-1]
for i, (name, group) in enumerate(grouped):
    print('Group: {} | Size: {:2d} | '.format(i, len(group)) + group_fmt.format(*name))
    print(group[final_t].mean())

In [None]:
# Pick one run (the 4th file in sorted order) for the inspection plots below.
da = da_[3]

In [None]:
# Show the run's S_gen array on a fixed symmetric colour scale [-0.5, 0.5].
fig, ax = plt.subplots()
image = ax.imshow(da.data['S_gen'], vmin=-0.5, vmax=0.5, cmap='bwr')
fig.colorbar(image, ax=ax)
ax.set_title('S_gen');

In [None]:
# NOTE(review): 'sym' presumably selects a symmetric colour range — confirm in
# DataAnalyzer. This mutates the shared analyzer object `da` (da_[3]) in place.
da.s_range = 'sym'
# EM estimate at the final iteration (-1).
fig, ax = da.plot_em_estimate(-1, figsize=(10, 10))

In [None]:
# Final-iteration SNR of every run. Uses its own loop variable so the `da`
# chosen above is not silently rebound to the last analyzer.
for i, analyzer in enumerate(da_):
    print('File {:3d} ds {:.2f}, dname: {:7s} img {:.5f}'.format(
        i, analyzer.data['ds'], analyzer.data['D_name'],
        analyzer.snr_one_iteration(analyzer.N_itr - 1)))

In [None]:
# Prior names as stored in each run's data['D_name'].
INDEP, SPAR, PCA = 'Indep', 'Sparse', 'PCA'

In [None]:
# 2 * len(grouped) rainbow colours in shuffled order. A fixed-seed generator
# keeps colour assignments reproducible across kernel restarts.
c_ = list(np.random.RandomState(0).permutation(
    plt.cm.rainbow(np.linspace(0, 1, 2 * len(grouped)))))

# Long legend names for each prior.
label_ = {INDEP: 'Independent Pixel Prior',
          SPAR: 'Sparse Prior',
          PCA: 'PCA'}

In [None]:
# Mean +/- 0.5 sd SNR over time for every group with dc == 20.
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_title('SNR as a function of time')
for c, (name, group) in zip(c_, grouped):
    D_name, ds, lamb, dc = name
    if dc != 20:
        continue
    label = 'D: {}, ds: {:.2f}, lamb: {:.4f}'.format(label_[D_name], ds, lamb)
    plot_fill_between(ax, t, group[list(t)], label=label, c=c, k=0.5)
ax.set_xlabel('time (ms)')
ax.set_ylabel('SNR')
ax.legend(loc='upper left')
# plt.savefig(os.path.join(output_dir, 'dict_compare.png'), dpi=200)

In [None]:
# Group keys and, per group, the array of SNR values at the final time point.
names, final_snrs = [], []
for key, group in grouped:
    names.append(key)
    final_snrs.append(group[list(t)[-1]].values)

In [None]:
# Two-sample Kolmogorov-Smirnov test on the final-time SNRs of groups q1 and q2.
q1, q2 = 1, 0
print('{} {}'.format(names[q1], names[q2]))
ks_2samp(final_snrs[q1], final_snrs[q2])

In [None]:
# Row indices (== positions in da_) of the ds == 0.75 runs, one array per prior.
# Later cells index `idx` positionally, so its order [INDEP, PCA, SPAR] is kept.
idx = [data[(data['D_name'] == key) & (data['ds'] == 0.75)].index.values
       for key in [INDEP, PCA, SPAR]]
# Unpack in the same order as the keys above.
indep_idx, pca_idx, spar_idx = idx

In [None]:
# _ = da_[indep_idx[0]].plot_em_estimate(-1)
# # plt.savefig(os.path.join(output_dir, 'sparse_example.png'), dpi=200)

In [None]:
def plot_fill_between(ax, t, data, label='', c=None, hatch=None, k=1., alpha=0.5):
    """
    Plot the mean of `data` over time with a shaded band of +/- k standard deviations.

    This redefines, and from here on shadows, the `plot_fill_between` imported
    from `src.analyzer` at the top of the notebook.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    t : array-like, shape (timesteps, )
        Times for each data point.
    data : array or DataFrame, shape (samples, timesteps)
        Samples to summarise; mean and standard deviation are taken over axis 0.
        Note: numpy's `std` defaults to ddof=0, pandas' `DataFrame.std` to ddof=1.
    label : str
        Legend label for the mean line.
    c : color or None
        Colour for the line and the band; None lets `ax.plot` choose the colour,
        and the band then reuses it.
    hatch : str or None
        Hatch pattern for the band.
    k : float
        Scaling factor for standard deviations (half-width of the band).
    alpha : float
        Opacity of the band.
    """
    mean = data.mean(0)
    spread = data.std(0) * k
    lines = ax.plot(t, mean, color=c, label=label)
    # Reuse the colour actually used for the line (matters when c is None).
    band_color = lines[-1].get_color()
    ax.fill_between(t, mean - spread, mean + spread, alpha=alpha,
                    color=band_color, hatch=hatch)


In [None]:
grouping_columns = ['D_name', 'ds', 'lamb', 'dc']
# Only runs with dc == 20, grouped by parameter combination.
# Top-level pd.groupby was removed from pandas; use the DataFrame method.
grouped = data[data['dc'] == 20].groupby(grouping_columns)
len(grouped)

In [None]:
# Publication styling: Arial font, thin lines, plain (frameless, square) legends.
plt.rcParams.update({
    'font.sans-serif': [u'Arial'],
    'lines.linewidth': 0.25,
    'legend.fancybox': False,
    'legend.frameon': False,
})


In [None]:
import brewer2mpl

# Five-colour qualitative ColorBrewer palette ('Set1') as the default line colour cycle.
# (Alternative: brewer2mpl.get_map('Set2', 'qualitative', 7).)
colors = brewer2mpl.get_map('Set1', 'qualitative', 5).mpl_colors

# 'axes.color_cycle' was deprecated in matplotlib 1.5 and later removed;
# 'axes.prop_cycle' takes a cycler instead.
matplotlib.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=colors)


In [None]:
import copy

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(7, 3.6))

# Short labels for each prior.
label_ = {'Indep': 'IND',
          'Sparse': 'SP',
          'PCA': 'PCA'}

# Columns 0-2: first ds == 0.75 run of each prior, in the order of `idx`;
# dictionary on top, final reconstruction below.
for i, idx1 in enumerate(idx):
    da = da_[idx1[0]]
    ax = axes[1][i]
    da.plot_image_estimate(fig, ax, -1, colorbar=False)
    ax.set_title('SNR = {:.2f}'.format(da.snr_one_iteration(da.N_itr - 1)))
    ax.set_axis_off()

    ax = axes[0][i]
    D = da.data['D'].copy()
    D_name = da.data['D_name']
    ax.set_title('Dictionary: {}'.format(label_[D_name]))
    # Shuffle the copy's rows and show the first 25 (a random subset);
    # da.data['D'] itself is untouched.
    np.random.shuffle(D)
    show_fields(D[0:25], pos_only=False, fig=fig, ax=ax, colorbar=False, normed=True)
    ax.set_axis_off()

# Bottom-right: original image with RFs, for the last analyzer from the loop above.
ax = axes[1][-1]
ax.set_title('Original')
da.plot_image_and_rfs(fig=fig, ax=ax, alpha_rf=0.25)
a = da.data['ds'] * da.L_I / 2
ax.set_xlim([-a, a])
ax.set_ylim([-a, a])
ax.set_axis_off()

# Top-right: mean +/- 0.5 sd SNR over time for each group with lamb <= 0.005,
# coloured by the axes colour cycle (c=None).
ax = axes[0][-1]
ax.set_title('SNR vs time (ms)')
for name, group in grouped:
    D_name, ds, lamb, dc = name
    if lamb > 0.005:
        continue
    label = label_[D_name]
    if D_name == 'Sparse' and lamb == 0:
        label += 'Z'
    plot_fill_between(ax, t, group[list(t)], label=label, c=None, k=0.5, alpha=0.75)

# Thicker legend lines: set the width on copies so the plotted lines keep theirs.
handles, labels = ax.get_legend_handles_labels()
handles = [copy.copy(ha) for ha in handles]
for ha in handles:
    ha.set_linewidth(3)
ax.legend(handles=handles, labels=labels, loc='upper left')

ax.yaxis.set_major_locator(MaxNLocator(4))
for side in ['top', 'bottom', 'left', 'right']:
    ax.spines[side].set_visible(False)
TICK_WIDTH = 0.25
ax.xaxis.set_tick_params(width=TICK_WIDTH)
ax.yaxis.set_tick_params(width=TICK_WIDTH)

# Panel letters a, b, c, ... assigned column by column.
for i, ax in enumerate(axes.T.flat):
    label_subplot(fig=fig, ax=ax, label=chr(ord('a') + i), dy=0.03)
    ax.set_title(ax.get_title(), fontdict={'size': 7})

plt.subplots_adjust(hspace=0.3)
plt.savefig(os.path.join('../output', 'natural_dict_and_rec.pdf'), dpi=300)

In [None]:
# Stray interactive help lookup (`plt.subplots_adjust?`) removed; use it in a scratch cell if needed.

In [None]:
fig, ax = plt.subplots(1, 1)
da.plot_image_and_rfs(fig=fig, ax=ax, alpha_rf=0.25)

# Zoom in on the central [-3, 3] x [-3, 3] window.
half_width = 3
for set_limits in (ax.set_xlim, ax.set_ylim):
    set_limits([-half_width, half_width])

# ax.set_axis_off()


In [None]:
# All fields of `D`, the shuffled dictionary copy left by the main-figure loop.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
show_fields(D, fig=fig, ax=ax, normed=True)
plt.savefig('dict.pdf', dpi=200)

In [None]:
# Output directory for video frames (see the commented-out rendering loop below).
video_dir = os.path.join(output_dir, 'video')

# Create it on first use.
if not os.path.exists(video_dir):
    os.makedirs(video_dir)

Create a video from the frames written to `video_dir` (run inside that directory):
```
avconv -framerate 20 -i img_%04d.png -c:v libx264 -r 30 rec.mp4
```

In [None]:
# for i in range(da.N_itr):
#     print 'Rendering image {:04d}'.format(i)
# #     da_[spar_idx[1]].plot_em_estimate(i)
# #     plt.savefig(os.path.join(video_dir, 'img_{:04d}.png'.format(i)), dpi=150)
# #     plt.close()

## Dictionaries with reconstructions after 200 ms, DC = 100

In [None]:
# Re-assert the 7 pt base font size before the figures below.
plt.rcParams.update({'font.size': 7})

In [None]:
def plot_snr_fcn_time(fig, ax, grouped, label_, ds_target=0.70):
    """Plot mean +/- 0.5 sd SNR versus time for each group whose ds matches `ds_target`.

    Parameters
    ----------
    fig : matplotlib Figure
        Not used; kept for call-site compatibility.
    ax : matplotlib Axes
        Axes to draw on.
    grouped : pandas GroupBy
        Groups keyed by tuples whose first two entries are (D_name, ds); any
        further key entries (e.g. lamb, dc) are ignored.
    label_ : dict
        Maps each D_name to its legend label.
    ds_target : float, optional
        Only groups whose ds is close (np.isclose) to this value are plotted.

    Notes
    -----
    Reads the notebook-global time axis `t` and calls `plot_fill_between`.
    """
    # Fixed colour list; zip() below plots at most len(colours) groups.
    colours = ['r', 'r', 'r', 'g', 'g', 'g', 'b', 'b', 'b']

    ax.set_title('SNR as a function of time')
    for c, (name, group) in zip(colours, grouped):
        # Index instead of unpacking so 4-key groupings (D_name, ds, lamb, dc) work too.
        D_name, ds = name[0], name[1]
        if not np.isclose(ds, ds_target):
            continue
        plot_fill_between(ax, t, group[list(t)], label=label_[D_name], c=c, k=0.5)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('SNR')
    ax.legend(loc='upper left', prop={'size': 6})

In [None]:
from src.analyzer import _get_sum_gaussian_image

In [None]:
def plot_image_estimate(self, fig, ax, q, cmap=plt.cm.gray,
                        colorbar=True, vmax=None):
    """Draw the image estimate of EM iteration `q` on `ax`.

    Parameters
    ----------
    self : analyzer object
        Must provide data['EM_data'], data['ds'], xs, ys, L_I, N_itr and
        snr_one_iteration (all read below).
    fig : matplotlib Figure
        Only used to attach the colorbar.
    ax : matplotlib Axes
        Target axes; its title is set to the iteration's SNR.
    q : int
        Iteration index; -1 selects the last iteration (N_itr - 1).
    cmap, vmax
        Forwarded to ax.imshow.
    colorbar : bool
        Whether to add a colorbar for the image.
    """
    itr = self.N_itr - 1 if q == -1 else q

    # Render the estimate with the project helper; n=100 is passed through
    # (presumably the output resolution — confirm in src.analyzer).
    coeffs = self.data['EM_data'][itr]['image_est'].ravel()
    width = self.data['ds'] / np.sqrt(2)
    rendered = _get_sum_gaussian_image(coeffs, self.xs, self.ys, width, n=100)

    snr = self.snr_one_iteration(itr)
    ax.set_title('Estimated Image, S = DA:\n SNR = %.2f' % snr)

    # FIXME: extent calculation could break in future
    half = self.data['ds'] * self.L_I / 2
    image = ax.imshow(rendered, cmap=cmap, interpolation='nearest',
                      extent=[-half, half, -half, half], vmax=vmax)
    if colorbar:
        fig.colorbar(image, ax=ax)

In [None]:
# Second ds == 0.75 run of each prior, reordered from idx's [INDEP, PCA, SPAR]
# to [INDEP, SPAR, PCA].
example_runs = [idx[u][1] for u in [0, 2, 1]]

# Short labels per prior. 'PCA' is needed: PCA runs are plotted below and
# would otherwise raise KeyError.
label_ = {'Indep': 'IND',
          'Sparse': 'SP',
          'Non-sparse': 'N-SP',
          'PCA': 'PCA'}

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(3.5, 4))

# Right column: final reconstruction of each example run.
for row, run in enumerate(example_runs):
    da = da_[run]
    ax = axes[row][1]
    plot_image_estimate(da, fig, ax, -1, colorbar=False, vmax=2.8)
    ax.set_title('{}: SNR = {:.2f}'.format(
        label_[da.data['D_name']], da.snr_one_iteration(da.N_itr - 1)))

# Bottom-left: SNR over time (group filter is inside plot_snr_fcn_time).
plot_snr_fcn_time(fig, axes[2][0], grouped, label_)

# Left column, top rows: original pattern and pattern with RFs, for the last example run.
da.plot_base_image(fig, axes[0][0])
axes[0][0].set_title('Original Pattern')
da.plot_image_and_rfs(fig, axes[1][0], legend=False)
for row in [0, 1]:
    axes[row][0].set_xlabel('x (arcmin)')
    axes[row][0].set_ylabel('y (arcmin)')
axes[1][0].set_title('Pattern and RFs')

for ax in axes.flat:
    ax.set_title(ax.get_title(), fontdict={'size': 7})

plt.tight_layout(pad=0.2)
# plt.savefig(os.path.join(output_dir, 'sparsity.pdf'), dpi=300)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4.5))

# First two entries of idx: reconstruction in the bottom row, dictionary above it.
for col, run_idx in enumerate(idx[:2]):
    analyzer = da_[run_idx[0]]
    bottom = axes[1][col]
    analyzer.plot_image_estimate(fig, bottom, -1)
    bottom.set_title('SNR = {:.2f}'.format(
        analyzer.snr_one_iteration(analyzer.N_itr - 1)))

    plt.subplot(2, 3, col + 1)
    plt.title('Dictionary: {}'.format(analyzer.data['D_name']))
    show_fields(analyzer.data['D'], pos_only=True)

# plt.savefig(os.path.join(output_dir, 'dict_and_rec.pdf'), dpi=300)

In [None]:
# True image followed by reconstructions at selected iterations, for run spar_idx[1].
# The analyzer plotting methods take (fig, ax, ...) as at their other call sites.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4))
da = da_[spar_idx[1]]
for ax, ii in zip(axes.flat, [None, 0, 14, 24, 59, 99]):
    if ii is None:
        da.plot_base_image(fig, ax)
        ax.set_title('True Image')
    else:
        da.plot_image_estimate(fig, ax, ii)
        # NOTE(review): label assumes 2 ms per iteration — confirm.
        ax.set_title('t = {} ms'.format(ii * 2 + 2))
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'sparse_rec_time.png'), dpi=200)

## Plot of dictionaries

In [None]:
plt.figure(figsize=(12, 3))
# Dictionaries of runs 0, 20 and 40, side by side.
for panel, q in enumerate([0, 20, 40], start=1):
    da = da_[q]
    plt.subplot(1, 3, panel)
    plt.title('Dictionary: {}'.format(da.data['D_name']))
    show_fields(da.data['D'], pos_only=True)
# plt.savefig(os.path.join(output_dir, 'dictionaries.png'), dpi=250)

In [None]:
plt.figure(figsize=(3, 3))
# NOTE(review): `da` is whatever the previous cell's loop left bound (da_[40] on
# an in-order run) — this depends on hidden kernel state.
da.plot_tuning_curves()
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'firing_rate.png'), dpi=200)