In [None]:
# Work from the project root so the `src` / `utils` imports and the relative
# `output/...` paths used below resolve. Assumes the notebook lives one level
# below that root (as the original `cd ..` implied). Guarded so that re-running
# this cell does not keep climbing one directory higher each time.
import os
if not os.path.isdir('src'):
    os.chdir('..')
os.getcwd()

In [None]:
import os
import numpy as np
import pandas as pd
from src.analyzer import DataAnalyzer, plot_fill_between
import matplotlib.pyplot as plt
from utils.rf_plot import show_fields
from scipy.stats import ks_2samp
%matplotlib inline
# Directory (relative to the project root) that holds the pickled result files for this analysis.
output_dir = "output/sparsity"

In [None]:
# Collect every pickled result file in `output_dir`. The paths are sorted so that
# positional indices used in later cells (e.g. da_[42]) refer to the same run on
# every execution; os.listdir itself returns entries in arbitrary order.
pkl_fns = sorted(
    os.path.join(output_dir, fn)
    for fn in os.listdir(output_dir)
    if fn.endswith('.pkl')
)
# Fail here with a clear message rather than with a confusing error in a later cell.
assert pkl_fns, 'no .pkl files found in {!r}'.format(output_dir)
len(pkl_fns)

In [None]:
# Build one DataAnalyzer per result file, preserving the sorted order of pkl_fns.
# NOTE(review): fromfilename presumably unpickles the file; only load trusted outputs.
da_ = []
for pkl_fn in pkl_fns:
    da_.append(DataAnalyzer.fromfilename(pkl_fn))

In [None]:
# One row per run: [prior name (D_name), ds, followed by the values of snr_list()].
# The SNR values become the time-indexed columns in the next cell.
records = []
for da in da_:
    record = [da.data['D_name'], da.data['ds']] + da.snr_list()
    records.append(record)

In [None]:
# Assemble the per-run rows into a DataFrame and group the runs by (prior name, ds).
# The column labels for the SNR values come from the time axis of the last loaded run
# (taken explicitly, rather than from a loop variable left over in the kernel).
t = da_[-1].time_list()
# Every run must have one SNR value per time point of that axis; otherwise the
# columns would be mislabelled.
assert all(len(r) == 2 + len(t) for r in records), 'runs do not share a common time axis'
data = pd.DataFrame.from_records(records, columns=['D_name', 'ds'] + list(t))
# DataFrame.groupby works in all pandas versions; the top-level pd.groupby function
# is not available in current pandas releases.
grouped = data.groupby(['D_name', 'ds'])
len(grouped)

In [None]:
# For each (D_name, ds) group, print its position in `grouped` and the mean SNR at
# the final time point. These positions are what the KS-test cell uses to pick groups.
for i, ((D_name, ds), group) in enumerate(grouped):
    header = 'Group: {} | Prior: {} | ds: {}'.format(i, D_name, ds)
    print(header)
    print(group[list(t)[-1]].mean())

In [None]:
# Display names for the three priors (not referenced by the cells visible here).
label_ = ['Sparse Prior', 'Non-sparse Learned Prior', 'Independent Pixel Prior']
# One rainbow colour per (D_name, ds) group, shuffled (presumably so that adjacent
# groups get contrasting hues). A fixed, local RandomState keeps the colours identical
# across re-runs without consuming numpy's global RNG state.
c_ = plt.cm.rainbow(np.linspace(0, 1, len(grouped)))
np.random.RandomState(0).shuffle(c_)
c_ = list(c_)

In [None]:
# SNR over time for every prior at a single ds value. Each group is drawn with
# plot_fill_between using its colour from c_ (k=0.5 is passed through unchanged;
# TODO confirm its meaning in src.analyzer).
ds_plot = 0.70  # which ds value to display
plt.figure(figsize=(10, 7))
plt.title('SNR as a function of time')
for c, (name, group) in zip(c_, grouped):
    D_name, ds = name
    # ds is a float group key: compare with a tolerance instead of exact equality.
    if not np.isclose(ds, ds_plot):
        continue
    label = 'D: {}, ds: {:.2f}'.format(D_name, ds)
    plot_fill_between(t, group[list(t)], label=label, c=c, k=0.5)
# Axis labels only need to be set once for the whole figure.
plt.xlabel('time (ms)')
plt.ylabel('SNR')
plt.legend(loc='upper left');
# plt.savefig(os.path.join(output_dir, 'sparse_compare.png'), dpi=200)

In [None]:
# Final-time SNR samples (one array per group) and the matching (D_name, ds) keys.
# Both lists follow `grouped` iteration order, so names[i] labels final_snrs[i].
names, final_snrs = [], []
for k, group in grouped:
    names.append(k)
    final_snrs.append(group[list(t)[-1]].values)

In [None]:
# Two-sample Kolmogorov-Smirnov test comparing the final-time SNR distributions of two
# groups. q1/q2 are positions in `grouped` (the "Group: i" numbers printed above).
q1, q2 = 4, 7
print('{} {}'.format(names[q1], names[q2]))
ks_2samp(final_snrs[q1], final_snrs[q2])

In [None]:
# Plot the EM estimate for one example run, selected by its position (42) in the
# sorted pkl_fns list. The meaning of the argument 30 is defined by
# DataAnalyzer.plot_em_estimate (not visible here; TODO confirm).
da_[42].plot_em_estimate(30)
# plt.savefig(os.path.join(output_dir, 'sparse_example.png'))

In [None]:
# Show the fields of `D` for three example runs (positions 0, 20 and 40 in da_),
# one subplot per run on a 2x2 grid, titled with the run's prior name.
# Note: the loop leaves `da` bound to da_[40] in the notebook namespace.
plt.figure(figsize=(10, 10))
for panel, idx in enumerate((0, 20, 40), start=1):
    da = da_[idx]
    plt.subplot(2, 2, panel)
    fields = da.data['D']
    plt.title('D_name: {}'.format(da.data['D_name']))
    show_fields(fields, pos_only=True)

In [None]:
# Plot (presumably) the image and receptive fields for run da_[20], where 20 is a
# position in the sorted pkl_fns list.
da_[20].plot_image_and_rfs()

In [None]:
# Tuning curves for run da_[40]. When the notebook is run top to bottom, this is the
# run the global `da` held at this point (the last run plotted by the show_fields cell
# above). Indexing it explicitly avoids depending on whatever `da` the kernel holds.
plt.figure(figsize=(3, 3))
da_[40].plot_tuning_curves()
plt.tight_layout()
# plt.savefig(os.path.join(output_dir, 'firing_rate.png'), dpi=200)