In [None]:
# Move to the project root (the directory containing `src/`, which the
# imports below need). Guarded so re-running the cell, or starting the kernel
# in the root already, does not keep climbing the directory tree.
import os
if not os.path.isdir('src'):
    os.chdir('..')
os.getcwd()

In [None]:
import os
import numpy as np
import pandas as pd
# Project-local modules: importable only when the working directory contains
# `src/` and `utils/` (the first cell changes directory for this).
from src.analyzer import DataAnalyzer, plot_fill_between
import matplotlib.pyplot as plt
from utils.rf_plot import show_fields
from scipy.stats import ks_2samp
%matplotlib inline
# Directory holding the .pkl result files; the figures below are saved here too.
output_dir = 'output/sparsity'

In [None]:
# Every result pickle in output_dir, as full paths sorted for a stable order.
pkl_fns = sorted(os.path.join(output_dir, fn)
                 for fn in os.listdir(output_dir)
                 if fn.endswith('.pkl'))
len(pkl_fns)

In [None]:
# One DataAnalyzer per result file, in the same (sorted) order as pkl_fns.
# NOTE(review): fromfilename presumably unpickles the file -- only load trusted results.
da_ = list(map(DataAnalyzer.fromfilename, pkl_fns))

In [None]:
# One row per run: [prior name, ds, SNR at each time step...].
records = []
for da in da_:
    meta = [da.data['D_name'], da.data['ds']]
    records.append(meta + da.snr_list())

In [None]:
# Time axis used as the SNR column labels. Taken explicitly from the last
# analyzer (previously this read `da` leaked from the loop in the cell above,
# which is the same object). Assumes all runs share this axis -- TODO confirm.
t = da_[-1].time_list()
data = pd.DataFrame.from_records(records, columns=['D_name', 'ds'] + list(t))
# One group per (prior, ds) pair. DataFrame.groupby replaces the deprecated
# top-level pd.groupby function.
grouped = data.groupby(['D_name', 'ds'])
len(grouped)

In [None]:
# Mean SNR at the final time step for each (prior, ds) group.
last_time = list(t)[-1]
for i, ((D_name, ds), group) in enumerate(grouped):
    print('Group: {} | Prior: {} | ds: {}'.format(i, D_name, ds))
    print(group[last_time].mean())

In [None]:
# Values of the 'D_name' column identifying each dictionary prior.
INDEP, SPAR, NSPAR = 'Indep', 'Sparse', 'Non-sparse'

In [None]:
# One rainbow colour per (prior, ds) group, in shuffled order (presumably so
# adjacent groups get contrasting colours). The shuffle is seeded so colours,
# and therefore the saved figures, are identical across re-runs.
c_ = plt.cm.rainbow(np.linspace(0, 1, len(grouped)))
np.random.RandomState(0).shuffle(c_)
c_ = list(c_)

# Legend text for each prior key.
label_ = {INDEP: 'Independent Pixel Prior',
          SPAR: 'Sparse Prior',
          NSPAR: 'Non-sparse Prior'}

In [None]:
# SNR over time for every prior at a single ds value, one filled band per
# group (drawn by plot_fill_between from src.analyzer).
DS_SHOWN = 0.70
plt.figure(figsize=(5, 5))
plt.title('SNR as a function of time')
for c, ((D_name, ds), group) in zip(c_, grouped):
    # ds is a float group key: compare with a tolerance, not exact equality.
    if not np.isclose(ds, DS_SHOWN):
        continue
    plot_fill_between(t, group[list(t)], label=label_[D_name], c=c, k=0.5)
plt.xlabel('time (ms)')
plt.ylabel('SNR')
plt.legend(loc='upper left')
plt.savefig(os.path.join(output_dir, 'dict_compare.png'), dpi=200)

In [None]:
# Final-time-step SNR samples per group, with the matching (prior, ds) keys.
last_col = list(t)[-1]
names, final_snrs = [], []
for key, grp in grouped:
    names.append(key)
    final_snrs.append(grp[last_col].values)

In [None]:
# Two-sample Kolmogorov-Smirnov test between the final SNRs of two groups,
# chosen by their position in `names` / `final_snrs`.
q1, q2 = 4, 7
sample_a, sample_b = final_snrs[q1], final_snrs[q2]
print('{} {}'.format(names[q1], names[q2]))
ks_2samp(sample_a, sample_b)

In [None]:
# Row positions (data's default index follows the order of da_) of the ds = 0.7
# runs for each prior, ordered Independent, Sparse, Non-sparse. ds is a float,
# so it is matched with a tolerance rather than exact equality.
idx = [data[(data['D_name'] == key) & np.isclose(data['ds'], 0.7)].index.values
       for key in [INDEP, SPAR, NSPAR]]
indep_idx, spar_idx, nspar_idx = idx

In [None]:
# EM-estimate plot (index -1, presumably the last time step) for the second
# sparse-prior run at ds = 0.7.
sparse_da = da_[spar_idx[1]]
sparse_da.plot_em_estimate(-1)
plt.savefig(os.path.join(output_dir, 'sparse_example.png'), dpi=200)

Dictionaries and their reconstructions after 200 ms (DC = 100)

In [None]:
# 2x3 grid: top row shows each prior's dictionary, bottom row its image
# estimate at index -1. Columns: Independent, Non-sparse, Sparse (second
# ds = 0.7 run of each). The named index arrays are used instead of `idx`,
# because a later cell used to reuse `idx` as a loop variable and break this
# cell on re-run.
plt.figure(figsize=(15, 8))

for col, run_idx in enumerate([indep_idx, nspar_idx, spar_idx]):
    da = da_[run_idx[1]]

    plt.subplot(2, 3, col + 4)
    da.plot_image_estimate(-1)

    plt.subplot(2, 3, col + 1)
    D = da.data['D']
    D_name = da.data['D_name']
    plt.title('Dictionary: {}'.format(D_name))
    show_fields(D, pos_only=True)

plt.savefig(os.path.join(output_dir, 'dict_and_rec.png'), dpi=250)

In [None]:
# True image followed by the sparse-prior estimate at five time steps.
plt.figure(figsize=(7, 4))
# plt.suptitle('Reconstruction as a function of time for Sparse Image Prior')
da = da_[spar_idx[1]]

plt.subplot(2, 3, 1)
da.plot_base_image()
plt.title('True Image')

for panel, step in enumerate([0, 14, 24, 59, 99], start=2):
    plt.subplot(2, 3, panel)
    da.plot_image_estimate(step)
    # Step index -> ms label; appears to assume 2 ms per step -- TODO confirm.
    plt.title('t = {} ms'.format(step * 2 + 2))

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'sparse_rec_time.png'), dpi=200)

Plot of Dictionaries

In [None]:
# One panel per dictionary, taken from runs 0, 20 and 40 (assumed to use
# different priors -- TODO confirm against the sorted pkl filenames).
plt.figure(figsize=(12, 3))
# Loop variable is `run_idx`, not `idx`: reusing `idx` overwrote the
# per-prior index list built earlier and broke cells that read it.
for i, run_idx in enumerate([0, 20, 40]):
    da = da_[run_idx]
    plt.subplot(1, 3, i + 1)

    D = da.data['D']
    D_name = da.data['D_name']
    plt.title('Dictionary: {}'.format(D_name))
    show_fields(D, pos_only=True)
# plt.savefig(os.path.join(output_dir, 'dictionaries.png'), dpi=250)

In [None]:
# Tuning curves for run da_[40]. This is the same run the dictionary loop
# above ends on; it is referenced explicitly instead of relying on `da`
# leaking out of that loop.
plt.figure(figsize=(3, 3))
da_[40].plot_tuning_curves()
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'firing_rate.png'), dpi=200)