In [None]:
cd ..

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['image.interpolation'] = 'nearest'

from src.gauss import EMGauss
from utils.image_gen import ImageGenerator
from sparse_coder.prep_field_dataset import get_data_matrix

In [None]:
# Side length of the square image patch (used as `l_patch` when loading the
# data and in every `reshape(l_i, l_i)` below).
l_i = 20
# Passed to EMGauss as `n_t`; also reused as the EM iteration count.
n_t = 150
# Passed straight through to EMGauss -- presumably the image pixel spacing,
# the receptor spacing and the receptive-field size ratio; TODO confirm.
ds, de, rf_ratio = 0.3, 1.09, 0.203

In [None]:
# Derived length handed to EMGauss as `l_n`: the image extent (l_i * ds)
# divided by sqrt(2).
l_n = l_i * ds / np.sqrt(2)

In [None]:
# Load 10 patches of side `l_i` from the image file; `mat` is indexed per
# patch below (presumably one flattened patch per row -- TODO confirm).
mat = get_data_matrix(
    path='sparse_coder/data/final/IMAGES.npy',
    l_patch=l_i,
    n_patches=10,
)

In [None]:
# Build the EM model. Keyword arguments are grouped by apparent role
# (geometry, motion model, inference/diagnostics); their exact meaning is
# defined by EMGauss and is not visible from this notebook.
emg = EMGauss(
    l_i=l_i,
    l_n=l_n,
    ds=ds,
    de=de,
    rf_ratio=rf_ratio,
    n_t=n_t,
    motion_gen={'mode': 'Diffusion', 'dc': 4.},
    motion_prior={'dc': 10.},
    n_p=50,
    sig_obs=0.1,
    print_mode=True,
)

In [None]:
# Take patch 3, cast it to float32 and clamp it to [-1, 1]. This is the image
# the simulated data is generated from.
s_gen = np.clip(mat[3].astype('float32'), -1, 1)
plt.imshow(s_gen.reshape(l_i, l_i))
plt.colorbar()

In [None]:
# Simulate data from the image. `m` is fed to the EM run below; `xr`/`yr` are
# used later as the "actual" path (presumably horizontal and vertical
# components -- TODO confirm against EMGauss.gen_data).
m, xr, yr = emg.gen_data(s_gen)

In [None]:
# Distribution of all values in `m`, with counts on a log axis.
plt.hist(m.ravel(), bins=50);
plt.yscale('log')

In [None]:
# The two generated path components, plotted against sample index.
plt.plot(xr)
plt.plot(yr)

In [None]:
# Run EM on the simulated data. `data` is indexed per entry further down
# (data[q], data[-1]) to obtain image estimates.
s, data = emg.run_em(m, n_passes=1, n_itr=n_t, reg=1.)

In [None]:
# Presumably populates emg.pf.means / emg.pf.sdevs, which are read in the
# next cell -- TODO confirm.
emg.pf.calculate_means_sdevs()

In [None]:
# Path estimate and its spread; columns 0 and 1 are read as the two path
# dimensions by the helper functions below.
means = emg.pf.means
sdevs = emg.pf.sdevs

In [None]:
def plot_path_estimate(est_mean, est_sdev, xyr, d, q, dt=0.001):
    """
    Plot the actual and the estimated path for one dimension.

    The estimate is drawn as a line surrounded by a band of plus/minus one
    `est_sdev`, and the actual path is drawn on top. Everything is drawn
    with pyplot, i.e. onto the current axes.

    Parameters
    ----------
    est_mean : array, indexed as est_mean[:, d]
        Estimated position per time step. Its first axis sets the number
        of time steps plotted.
    est_sdev : array, indexed as est_sdev[:, d]
        Spread of the estimate; used as the half-width of the shaded band.
    xyr : sequence
        Actual paths; `xyr[d]` is plotted against time.
    d : int
        Dimension to plot (0 = horizontal, 1 = vertical).
    q : int
        EM iteration number. Currently unused; kept so existing callers
        keep working.
    dt : float
        Time between consecutive samples, in seconds.

    Raises
    ------
    ValueError
        If `d` is neither 0 nor 1.
    """
    # Validate before indexing `xyr`, so that an out-of-range `d` produces
    # the intended ValueError instead of an IndexError from `xyr[d]`.
    if d not in (0, 1):
        raise ValueError('d must be either 0 or 1')

    n_t = est_mean.shape[0]
    path = xyr[d]

    tt = dt * np.arange(n_t)
    plt.fill_between(tt,
                     est_mean[:, d] - est_sdev[:, d],
                     est_mean[:, d] + est_sdev[:, d],
                     alpha=0.5, linewidth=1.)
    plt.plot(tt,
             est_mean[:, d], label='estimate')
    plt.plot(tt,
             path, label='actual')
    plt.xlabel('Time (s)')
    plt.ylabel('Relative position (arcmin)')

In [None]:
# One panel per path dimension: d=0 on the left, d=1 on the right.
plt.figure(figsize=(10, 4))
for d in (0, 1):
    plt.subplot(1, 2, 1 + d)
    plot_path_estimate(means, sdevs, (xr, yr), d=d, q=100)

In [None]:
# Evaluate four tensors held by the model in its session and keep the results.
# NOTE(review): presumably (xe, ye) and (xs, ys) are two coordinate grids;
# they are passed to compare_fourier and snr_one_iteration below -- confirm
# their meaning against the class behind `emg.tb`.
with emg.tb.sess.as_default():
    xe, ye = emg.tb.sess.run([emg.tb.t_xe, emg.tb.t_ye])
    xs, ys = emg.tb.sess.run([emg.tb.t_xs, emg.tb.t_ys])

In [None]:
from src.analyzer import snr

In [None]:
# Index of the last entry in `data`.
q = len(data) - 1

In [None]:
def snr_one_iteration(s_gen, s_est, xyr, xyr_est, xs, ys, t, var):
    """
    Calculate the SNR of an estimated image against the true image.

    Parameters
    ----------
    s_gen : array
        True image; flattened and divided by its maximum before comparison.
    s_est : array
        Estimated image; flattened and divided by its maximum as well.
    xyr : sequence of two arrays
        True path, as (xr, yr).
    xyr_est : array, indexed as xyr_est[:, 0] and xyr_est[:, 1]
        Estimated path.
    xs, ys : arrays
        Passed to `snr` alongside the true image; the same arrays offset by
        the mean path error are passed alongside the estimate.
    t : int
        Number of leading path samples over which the path error is averaged.
    var :
        Passed through to `snr` unchanged.

    Returns
    -------
    The value returned by `snr`.

    Note that we shift the image estimate by the average
        amount that the path estimate was off the true path
        (There is a degeneracy in the representation that this
        fixes. )
    """
    # Use the path that was passed in rather than notebook-level globals, so
    # the result depends only on the arguments.
    xr_true = xyr[0]
    yr_true = xyr[1]

    try:
        xr_est = xyr_est[:, 0]
        yr_est = xyr_est[:, 1]
    except KeyError:
        # Fallback kept for compatibility: no usable path estimate means no shift.
        dx = 0.
        dy = 0.
    else:
        x_err = xr_true[0:t] - xr_est[0:t]
        y_err = yr_true[0:t] - yr_est[0:t]
        # np.mean of an empty slice (e.g. t == 0) is nan and would poison the
        # shifted coordinates; with no samples there is nothing to correct.
        dx = np.mean(x_err) if np.size(x_err) else 0.
        dy = np.mean(y_err) if np.size(y_err) else 0.

    i1 = s_gen.ravel()
    i2 = s_est.ravel()
    i1 = i1 / i1.max()
    i2 = i2 / i2.max()
    return snr(i1, xs, ys,
               i2, xs + dx, ys + dy,
               var)

In [None]:
snr_list = [snr_one_iteration(
        s_gen, data[q], (xr, yr), means, xs, ys, q, 
        var=(0.5 * ds) ** 2) for q in range(n_t)]

In [None]:
plt.plot(snr_list)

In [None]:
from src.gauss_plots import plot_image, plot_rfs, compare_fourier

In [None]:
# Compare the true image with the final estimate (presumably in the Fourier
# domain, going by the helper's name).
compare_fourier(s_gen.reshape(l_i, l_i), data[-1].reshape(l_i, l_i), l_i, ds, de, xe, ye)

# Give the SNR curve its own figure; without it the curve and its title land
# on whichever axes compare_fourier left active.
plt.figure()
plt.plot(snr_list)
plt.title('SNR as a function of time')
plt.xlabel('Time step')
plt.ylabel('SNR');

In [None]:
# Residual between the true image and the final estimate.
residual = s_gen - data[-1]

# Histogram and image each get their own axes; on shared axes the image would
# be drawn over the histogram.
fig, (ax_hist, ax_img) = plt.subplots(1, 2, figsize=(10, 4))
ax_hist.hist(residual.ravel())
ax_hist.set_title('Residual distribution')
im = ax_img.imshow(residual.reshape(l_i, l_i))
ax_img.set_title('Residual image')
fig.colorbar(im, ax=ax_img);