In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize as opt
from datetime import datetime

In [None]:
from movement import Akagi
from movement.data import load_csv
from movement.data import distances

In [None]:
import pandas as pd

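# Load data
Load the SA2 centroid and hierarchy files, restrict them to Christchurch City and Waimakariri District, and load the telco population counts for those areas.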

In [None]:
# Paths to the Statistical Area 2 (2018) centroid and higher-geography
# hierarchy CSVs, relative to the notebook's working directory.
centroid_data_fname = "data/area/statistical-area-2-2018-centroid-true.csv"
hierarchy_data_fname = "data/area/statistical-area-2-higher-geographies-2018-generalised.csv"
# Build the area subset from both files and display its underlying data.
areas = distances.AreaSubset(centroid_data_fname, hierarchy_data_fname)
areas.data

In [None]:
# Restrict the subset to these two areas by TA name. The return value is not
# captured, so this presumably filters `areas` in place — TODO confirm.
areas.filter_ta_name(['Christchurch City', 'Waimakariri District'])
areas.data

In [None]:
# Distance table between the selected areas, in kilometres
# (presumably square, one row/column per SA2 — TODO confirm).
d = areas.distance_table(units='km')

# SA2 codes of the selected areas; used below to filter the telco data.
# NOTE(review): assumed to share the row/column order of `d` — confirm.
sa2_codes = areas.sa2_codes()

In [None]:
# Display the names of the selected SA2 areas.
areas.sa2_names()

In [None]:
# Largest distance (km) in the table, as a sanity check on the subset's extent
# (per column if `d` is a DataFrame).
d.max()

In [None]:
# Telco population-count file. NOTE(review): the name says 2020-04-01 but the
# timestamps selected below are on 2020-02-18 — confirm the file covers them.
telco_data_fname = "data/telco/pop_data_2020-04-01.dat"

In [None]:
# The two snapshot times (same day, 07:00 and 09:00) used to select telco rows.
date_first, date_last = (datetime(2020, 2, 18, hour=h) for h in (7, 9))

In [None]:
# Load the telco counts and pivot them into a (time x SA2) matrix restricted to
# the selected areas and the two timestamps of interest.
telco = load_csv.load_telco_data(telco_data_fname)

telco_subset = (
    pd.pivot_table(
        telco[
            (telco["sa2_code"].isin(sa2_codes)) &
            ((telco["time"] == date_first) | (telco["time"] == date_last))
        ],
        index="time",
        columns="sa2_code",
        values="count",
    )
    # pivot_table sorts its columns by sa2_code; put them back in `sa2_codes`
    # order, which presumably matches the row/column order of the distance
    # table `d` (both come from `areas`) — TODO confirm. An SA2 with no telco
    # rows now shows up as a NaN column instead of shifting later columns.
    .reindex(columns=sa2_codes)
)

# Fail loudly rather than feed misaligned or incomplete counts to the model.
assert not telco_subset.isna().any().any(), (
    "missing telco counts for some SA2/time combinations"
)

N = telco_subset.to_numpy()

In [None]:
# (number of timestamps, number of SA2s), given the pivot above.
N.shape

In [None]:
# dtype of the count matrix (pivot_table aggregates by mean, so likely float).
N.dtype

In [None]:
# Model parameter passed as the third argument to Akagi below.
# NOTE(review): meaning and units not visible here (neighbourhood size?
# distance cutoff?) — document once confirmed.
K = 80

In [None]:
# Total observed count (summed over all selected SA2s) at each timestamp.
# Markers make the few time points visible; the trailing `;` hides the repr.
plt.plot(N.sum(axis=1), marker='o')
plt.xlabel("time step")
plt.ylabel("total count")
plt.title("Observed total count per timestamp");

# Estimate movement
Fit the `Akagi` movement model to the observed counts `N` and the distance table `d`.

In [None]:
# Set up the movement model from the observed counts N, distance table d and K.
a = Akagi(N, d, K)

In [None]:
# Model weighting parameter; spelled `lamda`, presumably because `lambda` is a
# Python keyword. NOTE(review): exact role in the objective — confirm.
a.lamda = 1e1

In [None]:
# Run exact inference and report wall time. The argument 1e-3 is presumably a
# convergence tolerance — TODO confirm. Estimates are read from `a.M` below.
%time result = a.exact_inference(1e-3)

In [None]:
# First slice of the estimated movement `a.M`, rounded to whole people.
np.rint(a.M[0]).astype(int)

Are there approximately the right number of people in the end? Compare the observed totals in `N` with the corresponding sums of the estimated movement `a.M`.

In [None]:
# Observed total count at each timestamp.
N.sum(axis=1)

In [None]:
# Estimated movement summed over its last axis, rounded to the nearest whole
# person. A bare .astype(int) truncates toward zero (99.9 -> 99), which biases
# the comparison with the observed counts; np.rint matches the other cells.
np.rint(a.M.sum(axis=2)).astype(int)

In [None]:
# Observed count matrix (timestamps x SA2s) for comparison with the estimates.
N

In [None]:
# `a.M` summed over its last axis minus the observed counts without the final
# timestamp (N[:-1]), rounded; values near zero mean the estimate accounts
# for everyone observed at the start of each step.
np.rint((a.M.sum(axis=2) - N[:-1])).astype(int)

In [None]:
# Observed totals per timestamp again, for reference next to the plot below.
N.sum(axis=1)

In [None]:
# Total population over time: observed counts N at integer time steps, and the
# estimated movement a.M (one slice per step, cf. the N[:-1] comparison above)
# drawn half-way between the steps it presumably connects.
# x positions are derived from the data: N holds only the two selected
# timestamps, so the former np.arange(3) / np.arange(0.5, 2.5) did not match
# the y lengths and plt.plot raised a ValueError.
fig, ax = plt.subplots()
ax.plot(np.arange(len(N)), N.sum(axis=1), label="observed (N)")
ax.plot(np.arange(len(a.M)) + 0.5, a.M.sum(axis=(1, 2)), label="estimated (a.M)")
ax.set(xlabel="time step", ylabel="total people", title="Total population")
ax.legend();

In [None]:
from cycler import cycler
import itertools

# Compare observed (solid) and estimated (dotted) counts for the first few
# regions. Each region is drawn twice, so every colour appears twice in a row
# in the cycle and both lines of a region share a colour.
color_list = ['b', 'orange', 'r', 'green', 'k', 'gray']
color_cycle = list(itertools.chain.from_iterable((c, c) for c in color_list))

# Never index past the available regions (N has one column per SA2).
num_regions = min(6, N.shape[1])

fig, ax = plt.subplots()
# Scope the colour cycle to this axes; plt.rc would change the global
# rcParams and recolour every axes created afterwards in the notebook.
ax.set_prop_cycle(cycler('color', color_cycle))

# Observed counts sit on integer time steps; estimates half-way between them.
# Both are derived from the data rather than hard-coded (N has two rows here).
t_observed = np.arange(len(N))
t_estimated = np.arange(len(a.M)) + 0.5

for i in range(num_regions):
    ax.plot(
        t_observed,
        N[:, i],
        linestyle='-',
    )
    ax.plot(
        t_estimated,
        a.M[:, i].sum(axis=1),
        linestyle=':',
    )

ax.set(
    xlabel="time step",
    ylabel="people",
    title=f"Observed (solid) vs estimated (dotted), first {num_regions} regions",
);