In [None]:
from itertools import product
import os, sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from movement import Akagi, ForwardSimulator

In [None]:
class HiddenPrints:
    """Context manager that silences ``print`` output by redirecting stdout to ``os.devnull``.

    On exit the original ``sys.stdout`` is restored and the devnull handle opened
    on entry is closed. Exceptions raised inside the ``with`` block propagate.
    ``__enter__`` returns the manager itself.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so __exit__ closes exactly the handle opened here,
        # even if code inside the block rebinds sys.stdout.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is valid even if closing the handle fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Implicitly returns None (falsy): exceptions from the block are not suppressed.

In [None]:
def generate_params(num_cells, seed=0):
    """Build synthetic model parameters on a square grid of cells covering [-1, 1]^2.

    Parameters
    ----------
    num_cells : int
        Total number of cells. Must be a positive perfect square, because the grid
        is ``n x n`` with ``n = sqrt(num_cells)``.
    seed : int, optional
        Seed applied to NumPy's global RNG. This function draws no random numbers
        itself; the seed is kept for compatibility with existing callers.

    Returns
    -------
    N_init : ndarray of int, shape (num_cells,)
        Initial counts, peaked on a ring of radius 0.8 around the origin.
    d : ndarray of float, shape (num_cells, num_cells)
        Pairwise Euclidean distances between cell centres.
    K : float
        Constant 1.5.
    pi : ndarray of float, shape (num_cells,)
        ``N_init`` rescaled so its maximum is 0.1.
    s : ndarray of float, shape (num_cells,)
        ``exp(-4 * |loc|^2)``, a bump centred on the origin.
    beta : int
        Constant 1.

    Raises
    ------
    ValueError
        If ``num_cells`` is not a positive perfect square. The previous loop-based
        version failed with an IndexError in that case.
    """
    if num_cells < 1:
        raise ValueError(f"num_cells must be a positive perfect square, got {num_cells}")
    n = int(round(np.sqrt(num_cells)))
    if n * n != num_cells:
        raise ValueError(f"num_cells must be a positive perfect square, got {num_cells}")

    np.random.seed(seed)

    x = np.linspace(-1, 1, n)
    y = np.linspace(-1, 1, n)
    # 'ij' indexing + C-order ravel gives loc[i * n + j] == [x[i], y[j]],
    # i.e. the same ordering as a nested "for i: for j:" loop.
    xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')
    loc = np.column_stack([xv.ravel(), yv.ravel()])

    rad = 0.8
    N_init = np.exp(-((loc**2).sum(axis=1)**0.5 - rad)**2)
    # NOTE(review): this multiplies by max() rather than dividing by it; kept as-is
    # to preserve existing results -- confirm intent.
    N_init = (N_init * 1000 * N_init.max()).astype(int)

    s = np.exp((-4 * loc**2).sum(axis=1))

    pi = N_init / N_init.max() / 10

    # Pairwise Euclidean distances via broadcasting: (m, 1, 2) - (1, m, 2) -> (m, m, 2).
    diff = loc[:, None, :] - loc[None, :, :]
    d = (diff**2).sum(axis=-1)**0.5

    beta = 1

    K = 1.5

    return N_init, d, K, pi, s, beta

In [None]:
def generate_fake_data(N_init, d, K, pi, s, beta, seed=0, num_steps=2, noise_amplitude=0.0):
    """Simulate synthetic observations with ``ForwardSimulator``.

    Parameters
    ----------
    N_init, d, K, pi, s, beta
        Model inputs, in the order returned by ``generate_params``.
    seed : int, optional
        Seed for NumPy's global RNG, applied immediately before simulating.
    num_steps : int, optional
        Number of steps passed to ``simulate``. The default is 2; an earlier
        revision used 11.
    noise_amplitude : float, optional
        Value assigned to ``simulator.noise_amplitude`` (default 0.0, as before).

    Returns
    -------
    N, M_true
        The two values returned by ``ForwardSimulator.simulate``. Presumably these
        are the simulated counts and the true movement -- TODO confirm in movement.py.
    """
    simulator = ForwardSimulator(pi, s, beta, d, K)
    simulator.noise_amplitude = noise_amplitude

    np.random.seed(seed)
    N, M_true = simulator.simulate(N_init, num_steps)

    return N, M_true

In [None]:
def estimate(N, M, d, K, eps, lamda):
    """Run Akagi exact inference on observed counts ``N``.

    ``lamda`` is set on the estimator before inference, and ``eps`` is forwarded
    to ``Akagi.exact_inference``. ``M`` is accepted for call-site symmetry but is
    not used. Returns whatever ``exact_inference`` returns; callers in this
    notebook unpack it into four values.
    """
    estimator = Akagi(N, d, K)
    estimator.lamda = lamda
    return estimator.exact_inference(eps)

In [None]:
# Precompile numba expressions
def precompile():
    """Run one small end-to-end simulate-then-estimate pass and discard the output.

    The pass uses a 121-cell (11 x 11) grid with eps=0.1 and lamda=10.0. Only the
    side effect matters: per the cell comment, the first call triggers numba
    compilation so later timings exclude it. Estimator printing is suppressed.
    """
    grid_params = generate_params(121)
    _, distances, K, _, _, _ = grid_params

    counts, moves = generate_fake_data(*grid_params)

    with HiddenPrints():
        _m, _pi, _s, _beta = estimate(counts * 1.0, moves, distances, K, 0.1, 10.0)

# Time three back-to-back calls. Presumably the first includes numba JIT
# compilation and the later ones show the steady-state cost.
%time precompile()
%time precompile()
%time precompile()

In [None]:
def nae(tr, es):
    """Normalised absolute error: ``sum(|tr - es|) / sum(|tr|)``.

    Parameters
    ----------
    tr : array_like or scalar
        Ground-truth values, which also form the normaliser.
    es : array_like or scalar
        Estimated values. Must broadcast against ``tr``.

    Returns
    -------
    numpy scalar
        0 for a perfect estimate. If ``tr`` is all zeros the denominator is 0, and
        NumPy returns ``inf`` (or ``nan`` when the numerator is also 0) with a
        RuntimeWarning.
    """
    # asarray lets plain lists/tuples work too; list - list would raise TypeError.
    tr = np.asarray(tr)
    es = np.asarray(es)
    return np.abs(tr - es).sum() / np.abs(tr).sum()

In [None]:
%%time

columns = ['num regions', 'eps', 'lamda', 'scale', 'M_nae', 'pi_nae', 's_nae', 'beta_nae', 'time (s)']
df = pd.DataFrame(columns=columns)

for (num_cells, seed) in product(np.arange(6, 14)**2, range(3)):

    print("Generate params, ", num_cells, " cells")
    params = generate_params(num_cells, seed=seed)
    _, d, K, pi_true, s_true, beta_true = params

    print("Generate fake data")
    N_true, M_true = generate_fake_data(*params, seed=seed)

    for (eps, lamda, scale) in product(np.logspace(-1, -4, 4), np.logspace(1, 2, 2), np.logspace(0, 3, 1)):
        
        lamda = float(lamda / scale)
        
        print("Estimate M from fake N, eps={eps}, lamda={lamda}".format(eps=eps, lamda=lamda))
        with HiddenPrints():
            t_0 = time.time()
            M_est, pi_est, s_est, beta_est = estimate(N_true*scale, M_true, d, K, eps, lamda)
            t_1 = time.time()

        M_est /= scale
        
        df.loc[len(df)] = [
            num_cells,
            eps, lamda, scale,
            nae(M_true, M_est),
            nae(pi_true, pi_est), nae(s_true/s_true.max(), s_est), nae(beta_true, beta_est), 
            t_1 - t_0,
        ]

In [None]:
# Treat the swept solver settings as discrete levels (used as hue/style in the plots below).
df = df.astype({'eps': 'category', 'lamda': 'category'})
df

In [None]:
# Fastest single estimate() call in the sweep, in seconds.
df.loc[:, 'time (s)'].min()

In [None]:
# Slowest single estimate() call in the sweep, in seconds.
df.loc[:, 'time (s)'].max()

In [None]:
# Error of the M estimate (top) and estimate() run time (bottom) vs grid size, one line per eps.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 10))

# NAE
sns.lineplot(
    data=df,
    x='num regions',
    y='M_nae',
    hue='eps',
    style='eps',
    ax=ax1,
)
ax1.set_ylim(bottom=0)
ax1.set(title='Estimation error of M', xlabel='number of regions', ylabel='NAE of M')

# Time
sns.lineplot(
    data=df,
    x='num regions',
    y='time (s)',
    hue='eps',
    style='eps',
    ax=ax2,
)
# Log-scaled time axis.
ax2.set_yscale('log')
ax2.set(title='Run time of estimate()', xlabel='number of regions', ylabel='time (s)')

fig.tight_layout()

In [None]:
# Distribution of per-run estimate() wall times across the whole sweep.
if hasattr(sns, 'histplot'):
    # seaborn >= 0.11: histplot replaces distplot, which was deprecated in 0.11
    # and removed in 0.14.
    ax = sns.histplot(df['time (s)'])
else:
    # Older seaborn without histplot: fall back to the original call.
    ax = sns.distplot(
        df['time (s)'],
        kde=False,
    )
ax.set(title='Distribution of run times', xlabel='time (s)', ylabel='count');

In [None]:
# Error of the M estimate as a function of each solver setting (log-scaled x).
fig, (ax_eps, ax_lamda) = plt.subplots(2, 1, figsize=(6, 10))

# eps/lamda may have been converted to 'category' for the hue plots. Cast them back
# to float so seaborn draws a numeric x-axis that a log scale can represent.
df_numeric = df.assign(
    eps=df['eps'].astype(float),
    lamda=df['lamda'].astype(float),
)

sns.lineplot(
    data=df_numeric,
    x='eps',
    y='M_nae',
    ax=ax_eps,
)
ax_eps.set_xscale('log')
ax_eps.set_ylim(bottom=0)
ax_eps.set(title='Error of M vs eps', ylabel='NAE of M')

sns.lineplot(
    data=df_numeric,
    x='lamda',
    y='M_nae',
    ax=ax_lamda,
)
ax_lamda.set_xscale('log')
ax_lamda.set_ylim(bottom=0)
ax_lamda.set(title='Error of M vs lamda', ylabel='NAE of M')

fig.tight_layout()

In [None]:
# estimate() run time vs lamda, both axes log-scaled.
fig, ax = plt.subplots()
sns.lineplot(
    # Cast back from 'category' (if set earlier for the hue plots) so the log x-axis is numeric.
    data=df.assign(lamda=df['lamda'].astype(float)),
    x='lamda',
    y='time (s)',
    ax=ax,
)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(title='Run time vs lamda', xlabel='lamda', ylabel='time (s)');

In [None]:
# estimate() run time vs eps, both axes log-scaled.
fig, ax = plt.subplots()
sns.lineplot(
    # Cast back from 'category' (if set earlier for the hue plots) so the log x-axis is numeric.
    data=df.assign(eps=df['eps'].astype(float)),
    x='eps',
    y='time (s)',
    ax=ax,
)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(title='Run time vs eps', xlabel='eps', ylabel='time (s)');

In [None]:
# Per-cell absolute error of pi for the last run of the sweep, drawn on the cell grid.
# pi_est / pi_true are the values left over from the sweep loop's final iteration.
# Derive the grid side from the data instead of hard-coding 13 (which only fits 169 cells).
side = int(round(np.sqrt(pi_true.size)))
fig, ax = plt.subplots()
im = ax.imshow(np.abs(pi_est - pi_true).reshape(side, side))
fig.colorbar(im, ax=ax, label='|pi_est - pi_true|')
ax.set_title('Absolute error of pi (last sweep run)');

In [None]:
# Per-cell absolute error of s for the last run of the sweep, drawn on the cell grid.
# s_true is normalised to max 1, matching how s_nae is computed in the sweep.
# Derive the grid side from the data instead of hard-coding 13 (which only fits 169 cells).
side = int(round(np.sqrt(s_true.size)))
fig, ax = plt.subplots()
im = ax.imshow(np.abs(s_est - s_true / s_true.max()).reshape(side, side))
fig.colorbar(im, ax=ax, label='|s_est - s_true / max(s_true)|')
ax.set_title('Absolute error of s (last sweep run)');