In [None]:
# Run from the repository root so `src.analyzer` and the relative `output/...` paths resolve.
# Guarded (instead of a bare `cd ..`) so that re-running this cell does not keep climbing the tree.
import os
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
import os
import pandas as pd
from src.analyzer import *
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
output_dir = 'output/motion_benefit3'

In [None]:
def plot_fill_between(t, data, label='', c=None, k=1., ax=None):
    """
    Plot the mean of `data` across samples with a shaded band of +/- k standard deviations.
    
    Parameters
    ----------
    t : array-like, shape (timesteps, )
        Times for each data point (x-axis values).
    data : array or pandas.DataFrame, shape (samples, timesteps)
        One row per sample. Mean and standard deviation are taken over axis 0
        via ``data.mean(0)`` / ``data.std(0)``. Note that NumPy arrays use the
        population std (ddof=0) while pandas DataFrames -- which is what the
        plotting cells in this notebook pass -- use the sample std (ddof=1).
    label : str
        Legend label attached to the mean line (the band gets no label).
    c : matplotlib colour or None
        Colour used for both the band and the mean line. If None, matplotlib's
        default colour cycling applies to each call separately.
    k : float
        Scaling factor for standard deviations: the band spans mean +/- k * std.
    ax : matplotlib Axes or None
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which
        matches the previous pyplot-state-machine behaviour.
    
    Returns
    -------
    ax : the Axes that was drawn on.
    """
    if ax is None:
        ax = plt.gca()
    mm = data.mean(0)
    sd = data.std(0) * k
    ax.fill_between(t, mm - sd, mm + sd, alpha=0.5, color=c)
    ax.plot(t, mm, color=c, label=label)
    return ax

In [None]:
# Collect every pickled run in output_dir, sorted by path so the run order is reproducible.
pkl_fns = sorted(
    os.path.join(output_dir, fn)
    for fn in os.listdir(output_dir)
    if fn.endswith('.pkl')
)
# Fail early with a clear message: later cells index da_[0], which would otherwise raise a bare IndexError.
assert pkl_fns, 'no .pkl files found in {}'.format(output_dir)
# pkl_fns = pkl_fns[120:]  # optional: analyse only a subset of the runs
len(pkl_fns)

In [None]:
# Load one DataAnalyzer per pickle file, in sorted filename order.
da_ = list(map(DataAnalyzer.fromfilename, pkl_fns))
# Time axis of the first run (the plots below label it 'time (ms)').
t = da_[0].time_list()

In [None]:
# One record per run: [inference_mode, dc, ds, snr(t0), snr(t1), ...]; the column names
# are assigned when the DataFrame is built further down.
# Reuse the analyzers already loaded into `da_` instead of unpickling every file a second time.
out_ = []
for da in da_:
    out = [
        da.data['EM_data']['mode'],      # eye path given or not
        # da.data['motion_gen']['mode'],  # alternative key: true eye movements or no eye movements
        da.data['motion_gen']['dc'],
        da.data['ds'],                   # image size
    ]
    out_.append(out + da.snr_list())

In [None]:
# Time axis of the last run. Read it from `da_` explicitly rather than from the loop
# variable `da` leaked by the previous cell, so this cell does not depend on hidden state.
t = da_[-1].time_list()

In [None]:
# One row per run: the three grouping keys followed by one SNR column per time point in `t`.
data = pd.DataFrame.from_records(out_, columns=['inference_mode', 'dc', 'ds'] + list(t))
# data = data[data['ds'] == 0.75]  # optional: restrict to a single image size
# The module-level `pd.groupby(df, ...)` was removed from pandas; use the DataFrame method.
grouped = data.groupby(['inference_mode', 'dc', 'ds'])
# Colour palette for the groups (repeated once, giving 12 entries).
c_ = ['g', 'b', 'r', 'y', 'm', 'c']
c_ = c_ + c_

In [None]:
plt.figure(figsize=(10, 7))
plt.title('SNR as a function of time')
# Pick colours by index modulo the palette size: zip(c_, grouped) silently
# dropped every group beyond len(c_), so those runs never appeared in the plot.
for i, (name, group) in enumerate(grouped):
    label = 'Mode: {} dc: {:.2f} ds: {:.2f}'.format(*name)
    plot_fill_between(t, group[list(t)], label=label, c=c_[i % len(c_)], k=1)
plt.xlabel('time (ms)')
plt.ylabel('SNR')
plt.legend(loc='upper left');
# plt.savefig(os.path.join(output_dir, 'motion_benefit.png'), dpi=200)

In [None]:
# First run only: call DataAnalyzer.plot_image_and_rfs (image + receptive fields, per the method name).
# NOTE(review): s=50 is forwarded unchanged; presumably a marker size -- confirm in src.analyzer.
da_[0].plot_image_and_rfs(s=50)

In [None]:
# Side-by-side panels for the first run: left = plot_base_image(), right = plot_image_estimate(20).
# NOTE(review): the meaning of the argument 20 (time step? iteration?) is not visible here -- confirm.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
da_[0].plot_base_image()
plt.subplot(1, 2, 2)
da_[0].plot_image_estimate(20)

# Group SNR plots by image size (`ds`) and diffusion constant (`dc_gen`)

In [None]:
# One record per run: [ds, dc, snr(t0), snr(t1), ...].
# Reuse the analyzers already loaded into `da_` instead of unpickling every file again.
out_ = []
for da in da_:
    out = [da.data['ds'], da.data['motion_gen']['dc']]
    # NOTE(review): the earlier table calls da.snr_list(); confirm which spelling
    # DataAnalyzer actually defines (SNR_list vs snr_list) -- one of the two cells may fail.
    out_.append(out + da.SNR_list())

In [None]:
# Reset `t` to the first run's time axis (an earlier cell reassigned it from the last run);
# the SNR columns of the table below are labelled with these time points.
t = da_[0].time_list()

In [None]:
# Table with one row per run: image size, generating dc, then one SNR column per time point.
data = pd.DataFrame.from_records(
    out_,
    columns=['ds', 'dc_gen'] + list(t),
)
# Optional: keep a single image size only.
# data = data[data['ds'] == 0.5]

In [None]:
# The module-level `pd.groupby(df, ...)` was removed from pandas; use the DataFrame method.
grouped = data.groupby(['ds', 'dc_gen'])
# Colour palette for the (ds, dc_gen) groups.
c_ = ['g', 'b', 'r', 'y', 'm', 'c']

In [None]:
plt.figure(figsize=(10, 7))
plt.title('SNR as a function of time')
# Pick colours by index modulo the palette size: zip(c_, grouped) silently
# dropped every (ds, dc_gen) group beyond len(c_), so those runs never appeared in the plot.
for i, (name, group) in enumerate(grouped):
    label = 'ds={}, dc={}'.format(*name)
    plot_fill_between(t, group[list(t)], label=label, c=c_[i % len(c_)])
plt.xlabel('time (ms)')
plt.ylabel('SNR')
plt.legend(loc='upper left');
# NOTE: this filename matches the first SNR figure's; rename before enabling to avoid overwriting it.
# plt.savefig(os.path.join(output_dir, 'motion_benefit.png'), dpi=200)