In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize as opt
from typing  import Tuple

In [None]:
import akagi
import forward

## Set parameter values

In [None]:
# Total number of cells; must be a perfect square because the cells are laid
# out on a square n x n grid.
num_cells = 4 ** 2

In [None]:
# Lay the cells out on a regular n x n grid spanning [-1, 1] x [-1, 1].
# Round (rather than truncate) the square root and check it, so a non-square
# num_cells fails here with a clear message instead of an IndexError later.
n = int(round(num_cells ** 0.5))
assert n * n == num_cells, f"num_cells={num_cells} must be a perfect square"

x = np.linspace(-1, 1, n)
y = np.linspace(-1, 1, n)
xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

# Row k = i * n + j of `loc` holds [x[i], y[j]]: the C-order ravel of an
# 'ij'-indexed meshgrid gives the same ordering as a nested i/j loop.
loc = np.column_stack([xv.ravel(), yv.ravel()])

In [None]:
# Initial population: a Gaussian ring of radius `rad` around the origin.
rad = 0.8
dist_from_origin = ((loc**2).sum(axis=1)) ** 0.5
N_init = np.exp(-(dist_from_origin - rad) ** 2)
# Normalise by the maximum so the most populated cell holds 1000 people.
# (Previously this multiplied by N_init.max(), which does not normalise.)
# astype(int) truncates to whole people.
N_init = (1000 * N_init / N_init.max()).astype(int)
N_init.reshape((n, n))

In [None]:
# Cell k = i * n + j maps to pixel [i, j], so image rows index x and columns index y.
fig, ax = plt.subplots()
im = ax.imshow(N_init.reshape((n, n)))
fig.colorbar(im, ax=ax, label='initial population')
ax.set_title('Initial population per cell');

In [None]:
# Per-cell parameter s: a Gaussian bump centred on the origin, exp(-4 |loc|^2).
sq_dist = (loc**2).sum(axis=1)
s = np.exp(-4 * sq_dist)
s.reshape((n, n))

In [None]:
# Cell k = i * n + j maps to pixel [i, j], so image rows index x and columns index y.
fig, ax = plt.subplots()
im = ax.imshow(s.reshape((n, n)))
fig.colorbar(im, ax=ax, label='$s$')
ax.set_title('True parameter $s$ per cell');

In [None]:
# Per-cell parameter pi: the initial population rescaled so its maximum is 0.1.
peak_population = N_init.max()
pi = N_init / peak_population / 10
pi.reshape((n, n))

In [None]:
# Cell k = i * n + j maps to pixel [i, j], so image rows index x and columns index y.
fig, ax = plt.subplots()
im = ax.imshow(pi.reshape((n, n)))
fig.colorbar(im, ax=ax, label=r'$\pi$')
ax.set_title(r'True parameter $\pi$ per cell');

In [None]:
# Matrix of Euclidean distances between cell centres: d[i, j] = |loc[i] - loc[j]|.
# Broadcasting replaces the O(num_cells^2) Python double loop.
pair_diff = loc[:, np.newaxis, :] - loc[np.newaxis, :, :]
d = (pair_diff**2).sum(axis=-1) ** 0.5
assert d.shape == (num_cells, num_cells), "loc must have one row per cell"

In [None]:
# True model parameters passed to the simulator (and K also to the estimator).
# NOTE(review): meaning of beta and K is defined in forward.py / akagi.py — confirm.
beta, K = 1, 1.5

## Generate synthetic data

In [None]:
# Build the forward model from the true parameters defined above.
simulator = forward.ForwardSimulator(pi, s, beta, d, K)
# Turn off the simulator's noise term.
# NOTE(review): assumes noise_amplitude scales the simulator's added noise — confirm in forward.py.
simulator.noise_amplitude = 0.0

In [None]:
# Number of steps passed to the simulator.
num_steps = 11
# Seed NumPy's global RNG, presumably so any randomness inside simulate() is reproducible.
np.random.seed(0)
# N: simulated populations; M_true: the true movements.
# NOTE(review): exact shapes are defined by ForwardSimulator.simulate — confirm.
N, M_true = simulator.simulate(N_init, num_steps)

In [None]:
fig, ax = plt.subplots()

# Assumes rows of N are time steps and columns are cells, matching the axis labels.
cs = ax.imshow(N)
cb = fig.colorbar(cs, ax=ax, label='population')

ax.set_ylabel('t')
ax.set_xlabel('cell number')
# Trailing ';' suppresses the Text repr of the last call.
ax.set_title('Simulated population');

In [None]:
fig, ax = plt.subplots()

# Draw the total thick and black: with more lines than colours in the default
# cycle, a default-coloured 'total' line would share its colour with a cell.
ax.plot(N.sum(axis=1), color='black', linewidth=2, label='total')

# One unlabelled line per cell (column of N).
for i in range(N.shape[1]):
    ax.plot(N[:, i], linewidth=0.8)

ax.set_xlabel('time')
ax.set_ylabel('population')
ax.set_title('Simulated population over time')

ax.set_yscale('log')

# Attach the legend to the axes (fig.legend can overlap the plot area);
# the trailing ';' suppresses the Legend repr.
ax.legend();

## Estimate movement from the synthetic data

In [None]:
# Set up the estimator from the simulated populations N, distances d and K.
a = akagi.Akagi(N, d, K)
# NOTE(review): `lamda` (sic, presumably to avoid the `lambda` keyword) looks like
# a tuning weight of the estimator; 10 is a hand-picked value — confirm in akagi.py.
a.lamda = 10

In [None]:
# Snapshot the estimator's starting point so it can be compared with the
# fitted values after inference.
M_init, pi_init, s_init = a.M.copy(), a.pi.copy(), a.s.copy()
beta_init = a.beta

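Pre-compile the numba-jitted functions. The first call of each one triggers JIT compilation, so each is timed twice to separate compile time from run time.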

In [None]:
# Called twice: the first call includes numba JIT compilation; the second
# shows the steady-state run time.
%time akagi._cost(M_init, N)
%time akagi._cost(M_init, N)

In [None]:
# Precompute the two terms that are computed from the parameters alone (not M),
# so likelihood() can reuse them via its keyword arguments.
term_0_log = a.term_0_log(a.pi)
braces = a.term_1_braces(a.pi, a.s, a.beta, a.d)

# Timed twice: the first call presumably includes numba JIT compilation.
%time a.likelihood(a.M, a.pi, a.s, a.beta, term_0_log=term_0_log, term_1_braces=braces)
%time a.likelihood(a.M, a.pi, a.s, a.beta, term_0_log=term_0_log, term_1_braces=braces)

In [None]:
%%time
# %%prun
result = a.exact_inference(1e-4)

## Compare the estimates with the ground truth

Does the estimated `M` look similar to the true `M_true`?

In [None]:
# Relative L1 error of the estimated M with respect to the ground truth.
np.abs(M_true - a.M).sum() / M_true.sum()

In [None]:
# Preview the true M at time index 0, the same index as the a.M[0] preview, so
# the two tables are directly comparable (M_true and a.M are subtracted
# element-wise for the relative error). Previously this showed index 1.
M_true[0][:10, :10]

In [None]:
# Estimated M at time index 0, top-left 10x10 block, rounded to whole people.
np.rint(a.M[0][:10, :10]).astype(int)

Does the estimated outflow from each cell match its population at every step? The entries below should be close to zero.

In [None]:
# a.M summed over its last axis (presumably the total outflow of each cell)
# minus the population at every step except the last; rounded to integers.
outflow = a.M.sum(axis=2)
np.rint(outflow - N[:-1]).astype(int)

In [None]:
# Likelihood at the starting point, then at the fitted values.
likelihood_before = a.likelihood(M_init, pi_init, s_init, beta_init)
print(likelihood_before)
likelihood_after = a.likelihood(a.M, a.pi, a.s, a.beta)
print(likelihood_after)

Have $\pi$, $s$ and $\beta$ converged to their true values?

In [None]:
def comp_plot(exact, est, title, ax=None):
    """Scatter estimated values against exact ones, with a grey y = x reference line.

    Points lying on the grey line are perfectly recovered.

    Parameters
    ----------
    exact : array-like
        Ground-truth values. Any shape; flattened before plotting.
    est : array-like
        Estimated values with the same number of elements as ``exact``.
    title : str
        Axes title (matplotlib mathtext such as ``"$s$"`` is allowed).
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure and axes are created when omitted.

    Returns
    -------
    None
        Nothing is returned, so a bare call ending a cell prints no repr.

    Raises
    ------
    ValueError
        If the inputs are empty or differ in number of elements.
    """
    exact = np.asarray(exact).ravel()
    est = np.asarray(est).ravel()
    if exact.size == 0:
        raise ValueError("exact and est must not be empty")
    if exact.size != est.size:
        raise ValueError(
            "exact and est must have the same number of elements, "
            f"got {exact.size} and {est.size}"
        )

    if ax is None:
        _, ax = plt.subplots()

    # Identity line spanning the joint range of both arrays.
    lo = min(exact.min(), est.min())
    hi = max(exact.max(), est.max())
    ident = np.linspace(lo, hi)
    ax.plot(ident, ident, alpha=0.5, color='gray')

    ax.scatter(exact, est, marker="+")

    ax.set_xlabel(r"exact")
    ax.set_ylabel(r"estimated")
    ax.set_title(title)

    # Equal scales so deviations from y = x are not visually distorted.
    ax.set_aspect('equal')

In [None]:
# True vs estimated pi, one point per cell.
comp_plot(exact=pi, est=a.pi, title=r"$\pi$")

In [None]:
# True vs estimated s, each normalised by its own maximum.
# NOTE(review): presumably s is only identifiable up to a scale factor — confirm.
comp_plot(exact=s / s.max(), est=a.s / a.s.max(), title=r"$s$")

In [None]:
# True vs estimated M, one point per entry.
comp_plot(exact=M_true.ravel(), est=a.M.ravel(), title=r"$M$")

In [None]:
# True beta used to generate the synthetic data.
beta

In [None]:
# Value of a.beta after running exact_inference.
a.beta

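Save the synthetic data and the true parameters. `np.save` writes one `.npy` file per variable to the working directory.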

In [None]:
# Persist the synthetic data and true parameters. np.save appends ".npy" to
# each name and writes to the current working directory.
to_save = {
    "N": N,
    "d": d,
    "K": K,
    "beta": beta,
    "s": s,
    "pi": pi,
    "M_true": M_true,
}
for fname, value in to_save.items():
    np.save(fname, value)