In [None]:
import os

import numpy as np
from scipy.stats import ks_2samp
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.size'] = 7

import pandas as pd


from context import src, utils
from src.analyzer import DataAnalyzer, plot_fill_between
from utils.rf_plot import show_fields

# Simulation results (*.h5) are read from this directory and figures are saved into it.
output_dir = '/'.join(['..', 'output', 'natural_sparsity_van_hateren2'])


In [None]:
# Collect the simulation result files (HDF5, despite the `pkl_` name) in a stable order.
pkl_fns = sorted(os.path.join(output_dir, fn)
                 for fn in os.listdir(output_dir)
                 if fn.endswith('.h5'))
# Fail early: every later cell assumes at least one result file was found.
assert pkl_fns, 'no .h5 result files found in {}'.format(output_dir)
len(pkl_fns)

In [None]:
# One DataAnalyzer per result file, in the same (sorted) order as pkl_fns.
da_ = list(map(DataAnalyzer.fromfilename, pkl_fns))

In [None]:
# One row per file: the identifying parameters followed by the SNR at every time step.
records = []
for da in da_:
    params = [da.data[key] for key in ('D_name', 'ds', 'lamb')]
    records.append(params + da.snr_list())

In [None]:
# Time axis shared by every run. Take it from the last analyzer explicitly
# rather than from the loop variable leaked by the previous cell.
t = da_[-1].time_list()
data = pd.DataFrame.from_records(records, columns=['D_name', 'ds', 'lamb'] + list(t))
# The top-level `pd.groupby(df, ...)` was deprecated and then removed from
# pandas; the DataFrame method is the supported spelling.
grouped = data.groupby(['D_name', 'ds', 'lamb'])
len(grouped)

In [None]:
for i, (name, group) in enumerate(grouped):
    D_name, ds, lamb = name
    header = 'Group: {} | Size: {:2d} | Prior: {} | ds: {} | lamb: {}'
    print(header.format(i, len(group), D_name, ds, lamb))
    # Mean SNR at the final time step across the files in this group.
    print(group[list(t)[-1]].mean())

In [None]:
# Inspect a single run in the following cells: the last file in sorted order.
da = da_[-1]

In [None]:
# Show da.data['S_gen'] on a symmetric diverging color scale.
image = plt.imshow(da.data['S_gen'], cmap='bwr', vmin=-0.5, vmax=0.5)
plt.colorbar()

In [None]:
# Base image of the selected run, on a fresh single-axes figure.
fig, ax = plt.subplots(nrows=1, ncols=1)
da.plot_base_image(fig, ax)

In [None]:
# Update the spike moving average (argument 0.01), then plot the OFF-cell spikes.
da.compute_spike_moving_average(0.01)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
da.plot_spikes(ax, 100, mode='OFF')

In [None]:
# Overlay the first half of the RF centers on the spike axes `ax` from the
# previous cell.
# NOTE(review): `n_n` was never defined in this notebook, so this cell raised
# NameError. Assume it is the number of neurons, i.e. the length of the XE
# array — confirm.
n_n = len(da.data['XE'])
half = n_n // 2  # floor division, same result as Python 2 `n_n/2`
ax.scatter(da.data['XE'][0:half], da.data['YE'][0:half])
ax.set_aspect('equal')

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
da.plot_image_estimate(fig, ax, -1)  # -1: presumably the final iteration

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
da.plot_image_and_rfs(ax=ax, fig=fig)

In [None]:
# EM estimate at index -1 (presumably the final iteration); the method takes
# figsize and returns the (fig, ax) it drew on.
fig, ax = da.plot_em_estimate(-1, figsize=(10, 10))

In [None]:
snr_curve = da.snr_list()
plt.plot(snr_curve)

In [None]:
# Final-iteration SNR for every file, one line per file.
for i, da in enumerate(da_):
    final_snr = da.snr_one_iteration(da.N_itr - 1)
    print('File {:3d} ds {:.2f}, dname: {:7s} img {:.5f}'.format(
        i, da.data['ds'], da.data['D_name'], final_snr))

In [None]:
# Offset into da_ for the three example cells below (da_[q], da_[14 + q], da_[28 + q]).
q = 6

In [None]:
# First example run. NOTE(review): the stride of 14 used by the next cells
# appears to be the number of files per dictionary — confirm.
fig, ax = da_[q].plot_em_estimate(-1, figsize=(20, 20))

In [None]:
fig, ax = da_[q + 14].plot_em_estimate(-1, figsize=(10, 10))

In [None]:
fig, ax = da_[q + 2 * 14].plot_em_estimate(-1, figsize=(10, 10))

In [None]:
# Prior / dictionary names as they appear in the 'D_name' column.
INDEP, SPAR, PCA = 'Indep', 'Sparse', 'PCA'

In [None]:
# One color per plotted group. The order is shuffled with a fixed seed so
# figure colors are the same on every run (the unseeded global shuffle
# changed them each time).
c_ = plt.cm.rainbow(np.linspace(0, 1, 2 * len(grouped)))
c_ = list(np.random.RandomState(0).permutation(c_))

# Long legend labels, keyed by D_name.
label_ = {INDEP: 'Independent Pixel Prior',
          SPAR: 'Sparse Prior',
          PCA: 'PCA'}

In [None]:
# Mean +/- 0.5 sd SNR versus time, one curve per (D_name, ds, lamb) group.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('SNR as a function of time')
for c, (name, group) in zip(c_, grouped):
    D_name, ds, lamb = name
    label = 'D: {}, ds: {:.2f}, lamb: {:.4f}'.format(label_[D_name], ds, lamb)
    plot_fill_between(ax, t, group[list(t)], label=label, c=c, k=0.5)
# Axis labels only need to be set once, not on every loop iteration.
ax.set_xlabel('time (ms)')
ax.set_ylabel('SNR')
ax.legend(loc='upper left')
# plt.savefig(os.path.join(output_dir, 'dict_compare.png'), dpi=200)

In [None]:
final_snrs = []
names = []
for key, group in grouped:
    names.append(key)
    # SNR at the last time step for every file in this group.
    final_snrs.append(group[list(t)[-1]].values)

In [None]:
# Two-sample Kolmogorov-Smirnov test on the final SNRs of groups q1 and q2.
q1, q2 = 1, 0
print('{} {}'.format(names[q1], names[q2]))
ks_2samp(final_snrs[q1], final_snrs[q2])

In [None]:
# Row labels of the files at ds = 0.75, one array per prior, in the order
# [INDEP, PCA, SPAR]. `data` was built from `records` in da_ order with the
# default index, so these labels are also positions in da_. Later cells
# rely on this list order.
idx = [data[(data['D_name'] == key) & (data['ds'] == 0.75)].index.values
       for key in [INDEP, PCA, SPAR]]
# Unpack in the same order as the list above (SPAR and PCA used to be swapped).
indep_idx, pca_idx, spar_idx = idx

In [None]:
indep_example = da_[indep_idx[0]]
indep_example.plot_em_estimate(-1)
# NOTE(review): this shows the independent-prior run although the file name says "sparse".
# plt.savefig(os.path.join(output_dir, 'sparse_example.png'), dpi=200)

In [None]:
def show_fields(d, cmap=plt.cm.gray, m=None, pos_only=False,
                colorbar=True, fig=None, ax=None):
    """
    Plot a collection of square images tiled on a grid.

    Parameters
    ----------
    d : array, shape (n, n_pix)
        A collection of n images unrolled into n_pix length vectors.
        n_pix is assumed to be a perfect square.
    cmap : plt.cm
        Color map for plot
    m : int
        Plot a m by m grid of receptive fields. Defaults to the smallest
        square grid that holds all n images.
    pos_only : bool
        If True the color scale is [0, max|d|]; otherwise [-max|d|, max|d|].
    colorbar : bool
        Whether to attach a colorbar to ``ax``.
    fig, ax : matplotlib Figure and Axes, optional
        Target axes. A new 1x1 figure is created unless both are given.
    """
    if fig is None or ax is None:
        fig, ax = plt.subplots(1, 1)
    n, n_pix = d.shape
    if m is None:
        # Smallest m with m * m >= n (the -0.01 keeps perfect squares exact).
        m = int(np.sqrt(n - 0.01)) + 1

    l = int(np.sqrt(n_pix))  # Linear dimension of the image

    mm = np.max(np.abs(d))

    # Background and the 1-pixel separators between tiles are drawn at +max|d|.
    out = np.zeros(((l + 1) * m - 1, (l + 1) * m - 1)) + mm

    for u in range(n):
        # Floor division: plain `/` yields a float row index under Python 3.
        i = u // m
        j = u % m
        out[(i * (l + 1)):(i * (l + 1) + l),
            (j * (l + 1)):(j * (l + 1) + l)] = np.reshape(d[u], (l, l))

    m0 = 0 if pos_only else -mm
    cax = ax.imshow(out, cmap=cmap, interpolation='nearest', vmin=m0, vmax=mm)
    if colorbar:
        fig.colorbar(cax, ax=ax)

In [None]:
def plot_fill_between(ax, t, data, label='', c=None, hatch=None, k=1.):
    """
    Plot the mean of ``data`` over samples with a shaded +/- k sd band.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    t : array, shape (timesteps, )
        Times for each data point
    data : array or DataFrame, shape (samples, timesteps)
        Samples to summarize; mean and standard deviation are taken along
        axis 0 with the object's own ``mean``/``std`` (note numpy defaults to
        ddof=0, pandas to ddof=1).
    label : str
        Legend label attached to the mean line.
    c : color
        Color shared by the band and the mean line.
    hatch : str or None
        Hatch pattern for the band.
    k : float
        Half-width of the band, in standard deviations.
    """
    center = data.mean(0)
    half_width = k * data.std(0)
    lower, upper = center - half_width, center + half_width
    ax.fill_between(t, lower, upper, alpha=0.5, color=c, hatch=hatch)
    ax.plot(t, center, color=c, label=label)


In [None]:
# Display the per-prior row indices (order: INDEP, PCA, SPAR) used by the figure cells below.
idx

In [None]:
# Figure: 25 sample dictionary elements (top row) and final reconstructions
# (bottom row) for each prior, plus the original image and SNR-vs-time curves.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(7, 3.5))

short_label_ = {'Indep': 'IND',
                'Sparse': 'SP',
                'PCA': 'PCA'}

# Fixed seed, so the subset of dictionary elements shown (and therefore the
# saved PDF) is the same on every run.
rng = np.random.RandomState(0)

for i, rows in enumerate(idx):
    da = da_[rows[0]]

    ax = axes[1][i]
    da.plot_image_estimate(fig, ax, -1, colorbar=False)
    ax.set_title('SNR = {:.2f}'.format(da.snr_one_iteration(da.N_itr - 1)))
    ax.set_axis_off()

    ax = axes[0][i]
    D = da.data['D']
    D_name = da.data['D_name']
    ax.set_title('Dictionary: {}'.format(short_label_[D_name]))
    # Scale each element (row) to max |value| 1, then show 25 random ones.
    Dc = D / abs(D).max(axis=1, keepdims=True)
    Dc = Dc[rng.permutation(len(Dc))]
    show_fields(Dc[0:25], pos_only=False, fig=fig, ax=ax, colorbar=False)
    ax.set_axis_off()

# `da` is the last analyzer from the loop above.
# NOTE(review): assumes every run shares the same original image — confirm.
ax = axes[1][-1]
da.plot_base_image(fig, ax, colorbar=False, cmap=plt.cm.gray)
ax.set_title('Original')
ax.set_axis_off()

ax = axes[0][-1]
ax.set_title('SNR vs time (ms)')
for c, (name, group) in zip(c_, grouped):
    D_name = name[0]
    plot_fill_between(ax, t, group[list(t)], label=short_label_[D_name], c=c, k=0.5)
ax.legend(loc='upper left')

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'natural_dict_and_rec.pdf'), dpi=300)

In [None]:
# Dictionary from the first file; normalized and plotted in full in the next cells.
D = da_[0].data['D']

In [None]:
# Scale each dictionary element (row) so its largest |value| is 1.
row_peak = np.abs(D).max(axis=1, keepdims=True)
D = D / row_peak

In [None]:
# Every element of the normalized dictionary D, tiled on one figure.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
show_fields(D, fig=fig, ax=ax)
# Save next to the other figures instead of the current working directory.
plt.savefig(os.path.join(output_dir, 'dict.pdf'), dpi=200)

In [None]:
# Frames for the reconstruction video are written into this directory.
video_dir = os.path.join(output_dir, 'video')
needs_creation = not os.path.exists(video_dir)
if needs_creation:
    os.makedirs(video_dir)

Create a video from the rendered frames (`img_%04d.png` in `video_dir`):
```
avconv -framerate 20 -i img_%04d.png -c:v libx264 -r 30 rec.mp4
```

In [None]:
# for i in range(da.N_itr):
#     print 'Rendering image {:04d}'.format(i)
# #     da_[spar_idx[1]].plot_em_estimate(i)
# #     plt.savefig(os.path.join(video_dir, 'img_{:04d}.png'.format(i)), dpi=150)
# #     plt.close()

## Dictionaries with reconstructions after 200 ms, DC = 100

In [None]:
# Re-apply the small (7 pt) font used for the paper figures.
plt.rcParams.update({'font.size': 7})

In [None]:
def plot_snr_fcn_time(fig, ax, grouped, label_, ds_target=0.70, times=None):
    """
    Plot mean +/- 0.5 sd SNR versus time for every group at one ds value.

    Parameters
    ----------
    fig : matplotlib Figure
        Unused; kept for call compatibility.
    ax : matplotlib Axes
        Axes to draw on.
    grouped : pandas GroupBy (or iterable of (key, DataFrame) pairs)
        Groups keyed by tuples whose first two entries are (D_name, ds);
        further entries (e.g. lamb) are ignored. Columns include the times.
    label_ : dict
        Maps D_name to its legend label.
    ds_target : float
        Only groups whose ds is numerically equal to this are plotted.
    times : sequence, optional
        Time points (also the SNR column names). Defaults to the
        notebook-level ``t``.
    """
    if times is None:
        times = t
    cols = list(times)

    # One color per prior (r, g, b in sorted D_name order), so all groups of
    # a prior share a color. The old hard-coded list was missing a comma
    # ('g' 'g' -> 'gg', which is not a valid color).
    palette = ['r', 'g', 'b']
    prior_names = sorted(set(name[0] for name, _ in grouped))
    color_of = dict((p, palette[k % len(palette)])
                    for k, p in enumerate(prior_names))

    ax.set_title('SNR as a function of time')
    for name, group in grouped:
        # Keys are (D_name, ds, lamb) in this notebook; unpacking them into
        # two names raised ValueError.
        D_name, ds = name[0], name[1]
        if not np.isclose(ds, ds_target):
            continue
        plot_fill_between(ax, times, group[cols], label=label_[D_name],
                          c=color_of[D_name], k=0.5)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('SNR')
    ax.legend(loc='upper left', prop={'size': '6'})

In [None]:
from src.analyzer import _get_sum_gaussian_image

In [None]:
def plot_image_estimate(self, fig, ax, q, cmap=plt.cm.gray,
                        colorbar=True, vmax=None):
    """Plot the estimated image after iteration q (q == -1 selects the last one).

    Module-level function that takes the analyzer as its first argument,
    e.g. ``plot_image_estimate(da, fig, ax, -1, vmax=2.8)``.
    """
    itr = self.N_itr - 1 if q == -1 else q

    coeffs = self.data['EM_data'][itr]['image_est'].ravel()
    scale = self.data['ds'] / np.sqrt(2)
    res = _get_sum_gaussian_image(coeffs, self.xs, self.ys, scale, n=100)

    snr = self.snr_one_iteration(itr)
    ax.set_title('Estimated Image, S = DA:\n SNR = %.2f' % snr)

    # FIXME: extent calculation could break in future
    half_extent = self.data['ds'] * self.L_I / 2
    cax = ax.imshow(res, cmap=cmap, interpolation='nearest',
                    extent=[-half_extent, half_extent, -half_extent, half_extent],
                    vmax=vmax)
    if colorbar:
        fig.colorbar(cax, ax=ax)

In [None]:
# One run per prior: the second file of each ds = 0.75 group, ordered
# INDEP, SPAR, PCA (idx itself is ordered INDEP, PCA, SPAR).
example_rows = [idx[u][1] for u in [0, 2, 1]]

# Short labels keyed by D_name. The notebook's priors are 'Indep', 'Sparse'
# and 'PCA'. The 'PCA' entry was missing, so looking up the PCA run's label
# raised KeyError (here and inside plot_snr_fcn_time).
label_ = {'Indep': 'IND',
          'Sparse': 'SP',
          'Non-sparse': 'N-SP',
          'PCA': 'PCA'}

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(3.5, 4))

# Right column: final reconstructions.
for row, file_idx in enumerate(example_rows):
    da = da_[file_idx]
    est_ax = axes[row][1]
    plot_image_estimate(da, fig, est_ax, -1, colorbar=False, vmax=2.8)
    est_ax.set_title('{}: SNR = {:.2f}'.format(
        label_[da.data['D_name']], da.snr_one_iteration(da.N_itr - 1)))

# NOTE(review): plot_snr_fcn_time only plots groups with ds == 0.70, while the
# reconstructions above use ds == 0.75 — confirm which is intended.
plot_snr_fcn_time(fig, axes[2][0], grouped, label_)

# `da` is the last run from the loop above (the PCA example).
da.plot_base_image(fig, axes[0][0])
axes[0][0].set_title('Original Pattern')

da.plot_image_and_rfs(fig, axes[1][0], legend=False)
for u in [0, 1]:
    axes[u][0].set_xlabel('x (arcmin)')
    axes[u][0].set_ylabel('y (arcmin)')
axes[1][0].set_title('Pattern and RFs')

# Shrink every panel title to 7 pt.
for ax in axes.flat:
    ax.set_title(ax.get_title(), fontdict={'size': 7})

plt.tight_layout(pad=0.2)
# plt.savefig(os.path.join(output_dir, 'sparsity.pdf'), dpi=300)

In [None]:
# Dictionaries (top row, color scale from 0 up) and final reconstructions
# (bottom row) for the INDEP and PCA priors.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4.5))

for col, rows in enumerate([idx[u] for u in [0, 1]]):
    da = da_[rows[0]]

    est_ax = axes[1][col]
    da.plot_image_estimate(fig, est_ax, -1)
    est_ax.set_title('SNR = {:.2f}'.format(da.snr_one_iteration(da.N_itr - 1)))

    # Draw into the existing axes. show_fields opens a brand-new figure when
    # fig/ax are not given, so the old `plt.subplot(...); show_fields(D)`
    # put each dictionary in a separate figure.
    dict_ax = axes[0][col]
    D = da.data['D']
    D_name = da.data['D_name']
    dict_ax.set_title('Dictionary: {}'.format(D_name))
    show_fields(D, pos_only=True, fig=fig, ax=dict_ax)

# plt.savefig(os.path.join(output_dir, 'dict_and_rec.pdf'), dpi=300)

In [None]:
# Reconstruction over time for one sparse-prior run: the second file at
# ds = 0.75, selected directly by D_name so it cannot pick the wrong prior.
sparse_rows = data[(data['D_name'] == SPAR) & (data['ds'] == 0.75)].index.values
da = da_[sparse_rows[1]]

# The analyzer plotting methods take explicit (fig, ax[, q]) arguments; the
# old zero/one-argument calls raised TypeError.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 4))
for ax, itr in zip(axes.flat, [None, 0, 14, 24, 59, 99]):
    if itr is None:
        da.plot_base_image(fig, ax)
        ax.set_title('True Image')
    else:
        da.plot_image_estimate(fig, ax, itr)
        # NOTE(review): assumes 2 ms per iteration, so iteration 0 ends at 2 ms — confirm.
        ax.set_title('t = {} ms'.format(itr * 2 + 2))
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'sparse_rec_time.png'), dpi=200)

## Plot of dictionaries

In [None]:
# Dictionaries of three runs (files 0, 20 and 40), drawn side by side.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 3))
for ax, file_idx in zip(axes, [0, 20, 40]):
    da = da_[file_idx]
    D = da.data['D']
    D_name = da.data['D_name']
    ax.set_title('Dictionary: {}'.format(D_name))
    # Pass fig/ax explicitly; without them show_fields opens a new figure
    # instead of drawing into this subplot.
    show_fields(D, pos_only=True, fig=fig, ax=ax)
# plt.savefig(os.path.join(output_dir, 'dictionaries.png'), dpi=250)

In [None]:
# Tuning curves of `da`, i.e. whichever analyzer the previous cell left bound
# (da_[40] when the notebook is run in order).
# NOTE(review): hidden state — consider binding the analyzer explicitly here.
plt.figure(figsize=(3, 3))
da.plot_tuning_curves()
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'firing_rate.png'), dpi=200)