In [None]:
import os

# Run from the parent directory, which is expected to contain `src/`, `utils/`
# and `sparse_coder/` (imported / read below). Guarded so that re-running this
# cell does not climb one more level each time, and plain Python so it does not
# depend on IPython automagic being enabled.
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
import os
from utils.rf_plot import show_fields

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
"""Basic test of the code."""
import numpy as np
# import matplotlib.pyplot as plt
from scipy.io import loadmat

from src.model import EMBurak
from src.analyzer import DataAnalyzer
from utils.image_gen import ImageGenerator

# Tests the algorithm using a '0' from mnist and a sparse coding dictionary

data = loadmat('sparse_coder/output/mnist_dictionary.mat')
# Rows of D pair with the coefficient vector (see A[:, np.newaxis] * D below);
# columns are pixels.
D = data['D']

_, N_pix = D.shape

L_I = int(np.sqrt(N_pix))  # Linear dimension of image
# The image is assumed square: fail loudly instead of silently truncating L_I.
assert L_I * L_I == N_pix, (
    'dictionary has {} pixels per element, which is not a square image'.format(N_pix))

ig = ImageGenerator(L_I)
ig.make_digit()
ig.normalize()

s_gen = ig.img
s_gen_name = ig.img_name

# Motion model used to generate the data vs. the one assumed by the EM prior.
motion_gen = {'mode': 'Diffusion', 'dc': 100.}
motion_prior = {'mode': 'PositionDiffusion', 'dc': 100.}

# Sub-directory name for saved results; read back later as output/<output_dir_base>.
output_dir_base = 'routing_vis'

# NOTE(review): the model receives the image shifted by -0.5 (presumably to
# centre it for s_range='sym'); confirm this matches what gen_data expects.
emb = EMBurak(s_gen - 0.5, D, motion_gen, motion_prior, n_t=10, save_mode=True,
              s_gen_name=s_gen_name, n_itr=10, lamb=0.0, s_range='sym',
              output_dir_base=output_dir_base, save_pix_rf_coupling=True)

In [None]:
# Simulate responses for the generated image, fit them with EM, and save the run.
# NOTE(review): `emb` was constructed with `s_gen - 0.5`, but the data here are
# generated from the unshifted `s_gen` — confirm gen_data applies the same offset.
XR, YR, R = emb.gen_data(s_gen)
emb.run_em(R)
emb.save()  # presumably writes the .pkl files read back from output/<output_dir_base> below

In [None]:
# Location of the saved runs (assumed to be output/<output_dir_base>, matching
# EMBurak's save path — TODO confirm against EMBurak.save).
output_root = 'output'
output_dir = os.path.join(output_root, output_dir_base)

In [None]:
# Full paths of every pickle in the run directory, in lexicographic order;
# the cell displays how many were found.
pkl_fns = sorted(
    os.path.join(output_dir, name)
    for name in os.listdir(output_dir)
    if name.endswith('.pkl')
)
len(pkl_fns)

In [None]:
# Pick the last file in sorted order (assumes run filenames sort chronologically
# — TODO confirm). Fail with a clear message if the run produced no pickles.
if not pkl_fns:
    raise FileNotFoundError('no .pkl files found in {!r}'.format(output_dir))
pkl_fn = pkl_fns[-1]

In [None]:
# Load the selected run for analysis.
# NOTE(review): this presumably unpickles `pkl_fn`; only load files this notebook produced.
da = DataAnalyzer.fromfilename(pkl_fn)

In [None]:
# Spikes from the saved run (this replaces the R returned by gen_data above).
R = da.R
t = 0  # time step inspected in the plots below
# Row of R (presumably one per neuron) with the highest mean activity over time.
j = R.mean(axis=1).argmax()

In [None]:
# Border mask for an L_I x L_I image: 1 on the outermost rows/columns, 0 inside,
# flattened to a length-L_I**2 vector so it can be paired with pixel vectors.
frame = np.ones((L_I, L_I))
frame[1:-1, 1:-1] = 0
frame = frame.ravel()

In [None]:
from matplotlib.patches import Ellipse

In [None]:
import matplotlib.colors as mcolors

# Overlay colormap: white on [0, 0.5], then white -> red on [0.5, 1]. Same map
# as make_colormap([white, 0.5, white, red]) in the last cell; defined here too so
# that Restart & Run All does not hit a NameError on `rvb`.
rvb = mcolors.LinearSegmentedColormap.from_list(
    'CustomMap', [(0.0, 'white'), (0.5, 'white'), (1.0, 'red')])

# Pixel coupling of neuron j at time step t, scaled so its largest |value| is 1.
coup = da.data['EM_data'][t]['pix_rf_coupling'][:, j]
coup_peak = np.max(np.abs(coup))
if coup_peak > 0:  # an all-zero coupling would otherwise become NaN (0 / 0)
    coup = coup / coup_peak

A = da.data['EM_data'][t]['coeff_est']  # coefficient estimate at step t

xe = da.data['XE']
ye = da.data['YE']
de = da.data['de']

fig, axes = plt.subplots(2, 2, figsize=(15, 15))

# Top left: image estimate at step t.
ax = axes[0, 0]
da.plot_image_estimate(fig, ax, t, colorbar=False)

# Top right: spikes at step t; circle neuron j if it fired.
ax = axes[0, 1]
da.plot_spikes(ax, t)
if R[j, t] == 1:
    ax.add_patch(plt.Circle((xe[j], -ye[j]), de * 0.4, alpha=0.5, color='red', fill=False))

# Bottom left: coefficient-weighted dictionary, with neuron j's framed pixel
# coupling overlaid when it fired.
ax = axes[1, 0]
AD = A[:, np.newaxis] * D
show_fields(AD, fig=fig, ax=ax, colorbar=True, pos_only=True)
if R[j, t] == 1:
    show_fields(
        np.outer(coup, frame),
        alpha=0.5,
        cmap=rvb,
        fig=fig,
        ax=ax,
        colorbar=False, pos_only=True)

# Bottom right: path estimate as nested ellipses.
# NOTE(review): both lookups index with t (EM state saved at step t, then its
# entry for time t) — confirm that is the intended pairing.
ax = axes[1, 1]
mu = da.data['EM_data'][t]['path_means'][t]
sig = da.data['EM_data'][t]['path_sdevs'][t]
# Loop variable is `scale`, not `j`, so the neuron index chosen above survives
# and re-running this cell gives the same figure.
for scale, alpha in zip((1, 2, 3), (1.0, 0.5, 0.25)):
    # NOTE(review): Ellipse width/height are full axis lengths, so this spans
    # +/- scale * sig / 2 around mu — confirm whether 2 * scale * sig was intended.
    ax.add_patch(Ellipse(mu, width=sig[0] * scale, height=sig[1] * scale, alpha=alpha))

ax.set_xlim(axes[0, 0].get_xlim())
ax.set_ylim(axes[0, 0].get_ylim());

In [None]:
# Standalone, fully opaque view of neuron j's framed pixel-coupling overlay.
coupling_overlay = np.outer(coup, frame)
fig, ax = plt.subplots()
show_fields(coupling_overlay, fig=fig, ax=ax, cmap=rvb, alpha=1,
            colorbar=False, pos_only=True)


In [None]:
# Distribution of the framed coupling values. Pixels inside the frame are
# masked to exactly 0, so expect a large spike at 0. plt.show() ends the cell so
# the (counts, bins, patches) tuple is not dumped as output.
plt.hist(np.outer(coup, frame).ravel(), bins=200)
plt.xlabel('normalized coupling x frame mask')
plt.ylabel('count')
plt.show()

In [None]:
# Sanity check of the overlay colormap on a toy image: the anti-diagonal is 1,
# everything else 0. The figure and axes are created together so the colorbar is
# attached to this figure rather than a stale `fig` from an earlier cell.
fig, ax = plt.subplots()
b = np.fliplr(np.eye(5))  # b[i, 4 - i] = 1, zeros elsewhere
cax = ax.imshow(b, alpha=0.5, cmap=rvb, vmin=0)
fig.colorbar(cax, ax=ax)
ax.set_title('rvb colormap check (anti-diagonal = 1)');

In [None]:
import numpy as np
import matplotlib.colors as mcolors


def make_colormap(seq):
    """Build a piecewise-linear LinearSegmentedColormap from colours and stops.

    seq: RGB 3-tuples interleaved with float stop positions, e.g.
        [c1, 0.5, c2, c3]. Stops should be increasing and lie strictly inside
        (0, 1); the end stops 0.0 and 1.0 are added here. At each stop the map
        jumps from the tuple just before it to the tuple just after it, and
        channels are interpolated linearly between stops.
    Returns a matplotlib LinearSegmentedColormap named 'CustomMap'.
    """
    # Placeholder colours outside the automatic 0.0 / 1.0 stops. Matplotlib never
    # reads y0 of the first row or y1 of the last row, so their value is irrelevant.
    sentinel = (None, None, None)
    padded = [sentinel, 0.0]
    padded.extend(seq)
    padded += [1.0, sentinel]

    cdict = {'red': [], 'green': [], 'blue': []}
    for pos, stop in enumerate(padded):
        if not isinstance(stop, float):
            continue  # colour tuples only matter as neighbours of a stop
        r_lo, g_lo, b_lo = padded[pos - 1]
        r_hi, g_hi, b_hi = padded[pos + 1]
        cdict['red'].append([stop, r_lo, r_hi])
        cdict['green'].append([stop, g_lo, g_hi])
        cdict['blue'].append([stop, b_lo, b_hi])
    return mcolors.LinearSegmentedColormap('CustomMap', cdict)


# Overlay colormap used by the coupling plots: white on [0, 0.5], then a linear
# ramp from white to red on [0.5, 1]. (An earlier red/violet/blue variant was
# assigned to `rvb` and immediately overwritten; that dead assignment is removed.)
c = mcolors.to_rgb
rvb = make_colormap([c('white'), 0.5, c('white'), c('red')])

# Quick visual check of rvb on random points, seeded so the figure is reproducible.
rng = np.random.default_rng(0)
N = 1000
points = rng.uniform(0, 10, size=(N, 2))
colors = rng.uniform(-2, 2, size=(N,))
fig, ax = plt.subplots()
sc = ax.scatter(points[:, 0], points[:, 1], c=colors, cmap=rvb)
fig.colorbar(sc, ax=ax)
ax.set_title('rvb on uniform random values in [-2, 2]')
plt.show()