In [None]:
import copy
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize as opt

In [None]:
import akagi
import forward

In [None]:
# Cells are laid out on an n x n grid, so this must be a perfect square.
num_cells = 3 ** 2

In [None]:
# Lay the cells out on an n x n square grid covering [-1, 1] x [-1, 1].
# Round before truncating so float error in the sqrt cannot drop a row, and
# fail loudly instead of silently building a smaller grid than num_cells.
n = int(round(num_cells ** 0.5))
assert n * n == num_cells, f"num_cells={num_cells} must be a perfect square"
x = np.linspace(-1, 1, n)
y = np.linspace(-1, 1, n)
xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

# Cell coordinates, shape (num_cells, 2): row k = i * n + j holds (x[i], y[j]),
# which is exactly the row-major flattening of the 'ij' meshgrid.
loc = np.column_stack([xv.ravel(), yv.ravel()])

In [None]:
# Display the cell coordinates: one [x, y] row per cell.
loc

In [None]:
# Initial population: a Gaussian ring of radius `rad` around the origin,
# converted to integer head counts (astype(int) truncates).
rad = 0.5
radius = np.sqrt((loc ** 2).sum(axis=1))
ring_profile = np.exp(-(radius - rad) ** 2)
# NOTE(review): this multiplies by the profile's max rather than normalising
# by it (cf. `pi` below, which divides by the max) — confirm this is intended.
N_init = (ring_profile * 1000 * ring_profile.max()).astype(int)
N_init.reshape((n, n))

In [None]:
# Per-cell parameter s: an isotropic Gaussian exp(-4 * |loc|^2) centred on
# the origin, shown laid out on the n x n grid.
s = np.exp(-4 * (loc ** 2).sum(axis=1))
s.reshape((n, n))

In [None]:
# Per-cell parameter pi: the initial population rescaled so its maximum is 0.1.
peak_population = N_init.max()
pi = N_init / peak_population / 10
pi.reshape((n, n))

In [None]:
# Array of distances between cells: d[i, j] is the Euclidean distance between
# loc[i] and loc[j]. Broadcasting replaces the O(num_cells**2) Python loop.
d = (((loc[:, np.newaxis, :] - loc[np.newaxis, :, :]) ** 2).sum(axis=-1)) ** 0.5
assert d.shape == (num_cells, num_cells)
d

In [None]:
# True beta and K passed to the simulator (K is also given to the estimator).
beta, K = 1, 1.5

## Generate fake data

In [None]:
# Forward simulator built from the "true" parameters defined above
# (pi, s, beta, the distance matrix d and K); used to generate fake data.
simulator = forward.ForwardSimulator(pi, s, beta, d, K)

In [None]:
# Number of simulated time steps.
num_steps = 10
# Seed NumPy's global RNG for reproducibility — assumes the simulator draws
# from np.random (the forward module is not visible here; confirm).
np.random.seed(0)
# N: simulated populations; M_true: the true movements behind them
# (presumably M_true[t, i, j] counts moves from cell i to j at step t — TODO confirm).
N, M_true = simulator.simulate(N_init, num_steps)

In [None]:
# Simulated populations returned by the simulator.
N

In [None]:
# Heatmap of the simulated populations; assumes N is (time step, cell),
# which matches N[:-1] being compared against per-step sums of M below.
fig, ax = plt.subplots()
im = ax.imshow(N)
fig.colorbar(im, ax=ax)
ax.set(title="Simulated population N", xlabel="cell", ylabel="time step");

## Estimate movement from fake data

In [None]:
# Build the estimator from the observed populations N, the distance matrix d
# and K only — the true pi, s, beta and M_true are not passed in.
a = akagi.Akagi(N, d, K)

In [None]:
# Snapshot the estimator's initial values before running inference.
# Deep copies are taken because exact_inference may update these attributes
# in place; a plain assignment would alias the "_init" values to the fitted
# ones and make the initial-vs-fitted likelihood comparison meaningless.
M_init = copy.deepcopy(a.M)
pi_init = copy.deepcopy(a.pi)
s_init = copy.deepcopy(a.s)
beta_init = copy.deepcopy(a.beta)

In [None]:
%%prun
# Run inference under IPython's profiler (%%prun prints a profiler report).
# NOTE(review): `result` is not used afterwards; the fitted values are read
# from the attributes of `a` (a.M, a.pi, a.s, a.beta).
result = a.exact_inference()

Does the estimated `M` look similar to the true `M`?

In [None]:
# True movement at step num_steps - 2 (presumably the last step stored in M_true).
M_true[num_steps - 2]

In [None]:
# Estimated movement at the same step; astype(int) truncates toward zero
# for an easier side-by-side comparison.
a.M[num_steps - 2].astype(int)

In [None]:
# Relative L1 error of the estimated movement against the ground truth.
total_abs_error = np.abs(M_true - a.M).sum()
total_abs_error / M_true.sum()

Is the number of people conserved? Summing the estimated `M` over its last axis should approximately recover `N` at every time step except the last.

In [None]:
# Residual between M summed over its last axis and N at the matching step
# (assumes axis 2 of M indexes destination cells — TODO confirm).
conservation_residual = a.M.sum(axis=2) - N[:-1]
conservation_residual

In [None]:
# Likelihood of the initial M under the initial parameters.
a.likelihood(M_init, pi_init, s_init, beta_init)

In [None]:
# Likelihood of the fitted M under the same initial parameters; if inference
# worked this should improve on the value for M_init (assumes higher is better).
a.likelihood(a.M, pi_init, s_init, beta_init)

Have `pi`, `s` and `beta` converged to their true values?

In [None]:
# True pi used to generate the data.
pi

In [None]:
# Estimated pi after inference.
a.pi

In [None]:
# True s, rescaled so its maximum is 1.
s / s.max()

In [None]:
# Estimated s, rescaled the same way so only relative values are compared.
a.s / a.s.max()

In [None]:
# True beta used to generate the data.
beta

In [None]:
# Estimated beta after inference.
a.beta

Save the simulated data and true parameters for later use.

In [None]:
# Persist each array/parameter to "<name>.npy" in the working directory
# (np.save appends the .npy extension).
for var_name, value in [
    ("N", N),
    ("d", d),
    ("K", K),
    ("beta", beta),
    ("s", s),
    ("pi", pi),
    ("M_true", M_true),
]:
    np.save(var_name, value)