In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize as opt

In [None]:
from movement import Akagi
from movement.data import load_csv
from movement.data import distances

In [None]:
import pandas as pd

Load data

In [None]:
# Build the SA2 area table from the 2018 SA2 centroid file and the
# SA2 -> higher-geographies lookup file (contents as the file names describe).
centroid_data_fname = "data/area/statistical-area-2-2018-centroid-true.csv"
hierarchy_data_fname = "data/area/statistical-area-2-higher-geographies-2018-generalised.csv"
areas = distances.AreaSubset(centroid_data_fname, hierarchy_data_fname)
areas.data

Filter by distance from some SA2 (optional; not applied in this notebook)

Filter by region (regional council name). The commented-out line shows how to filter by territorial authority instead.

In [None]:
# Keep only SA2s belonging to these regions; the return value is unused, so
# `areas` is presumably filtered in place.
regions_of_interest = ['Auckland Region', 'Waikato Region']
areas.filter_rc_name(regions_of_interest)
# Alternative: restrict to a territorial authority instead of regions.
# areas.filter_ta_name(['Nelson City'])
areas.data

In [None]:
# Telco population-count export (file name dated 2020-04-01).
telco_data_fname = "data/telco/pop_data_2020-04-01.dat"

In [None]:
# Load the telco counts; later cells rely on its "time", "sa2_code" and "count" columns.
telco = load_csv.load_telco_data(telco_data_fname)

Filter to only data from times of interest

In [None]:
date_first = datetime(2020, 2, 18, 7)
date_last = datetime(2020, 2, 18, 9)

# Keep only rows stamped at exactly one of the two snapshot times.
is_first_time = telco["time"] == date_first
is_last_time = telco["time"] == date_last
telco_subset_time = telco[is_first_time | is_last_time]

Find SA2 regions with too few people (a count at or below the threshold at either selected time) and remove them

In [None]:
# An SA2 is kept only if its count exceeds this threshold at every selected time.
count_threshold = 10

current_codes = set(areas.sa2_codes())
# SA2 codes whose count is at or below the threshold at some selected time
disallowed_codes = set(
    telco_subset_time.loc[telco_subset_time["count"] <= count_threshold, "sa2_code"].to_list()
)
# SA2 codes that have a row at every selected time. A code missing at one of the
# times would pass the threshold test above yet leave a NaN in the pivoted table.
n_times = telco_subset_time["time"].nunique()
times_per_code = telco_subset_time.groupby("sa2_code")["time"].nunique()
complete_codes = set(times_per_code.index[times_per_code == n_times].to_list())
allowed_codes = complete_codes - disallowed_codes

sa2_codes_to_remove = current_codes - allowed_codes

areas.remove_sa2(sa2_codes_to_remove)
areas.data

In [None]:
# Names of the SA2s that remain after filtering.
areas.sa2_names()

Keep only telco data from the remaining regions, and pivot it so that each row is a time and each column is an SA2 code (in the same order as `areas`).

In [None]:
# Restrict to SA2s still in `areas`, pivot to one row per time and one column
# per SA2 code, then order the columns exactly as areas.sa2_codes().
in_areas = telco_subset_time["sa2_code"].isin(areas.sa2_codes())
telco_subset = (
    telco_subset_time[in_areas]
    .pivot_table(index="time", columns="sa2_code", values="count")
    .reindex(areas.sa2_codes(), axis='columns')
)

N = telco_subset.to_numpy()

In [None]:
# The columns of the pivoted table (and hence of N) must follow the SA2 order of `areas`.
assert telco_subset.columns.to_list() == areas.sa2_codes()

In [None]:
# Pairwise distance table between the retained SA2s, in km; show the largest value.
d = areas.distance_table(units='km')
d.max()

In [None]:
# Smallest count in N; after filtering it should exceed count_threshold.
N.min()

In [None]:
# No NaN/inf counts (a NaN would appear if an SA2 were missing at one of the times).
assert np.all(np.isfinite(N))

In [None]:
# Every entry of N is nonzero.
assert np.count_nonzero(N) == N.size

In [None]:
# (number of selected times, number of retained SA2s)
N.shape

In [None]:
# Distance cutoff in km (compared against the km distance table `d` in the plots
# below); presumably the neighbour radius used by Akagi — confirm.
K = 80

In [None]:
# Total count over all retained SA2s at each selected time (the rows of N).
fig, ax = plt.subplots()
ax.plot(N.sum(axis=1), marker="o")
ax.set(xlabel="time index", ylabel="total count", title="Total count across retained SA2s");

# Estimate movement

In [None]:
# Counts are multiplied by this before fitting (N * scale) and the estimate is
# divided by it afterwards (M_est = a.M / scale).
scale = 1000

In [None]:
# Model object built from the scaled counts, the km distance table and the cutoff K.
a = Akagi(N * scale, d, K)

In [None]:
# Set Akagi's `lamda` parameter to 10 / scale.
# NOTE(review): presumably a regularisation weight divided by scale to offset the
# scaled counts passed to Akagi — confirm.
a.lamda = 1e1 / scale

In [None]:
# Perturb the initial a.M with random noise weighted elementwise by a.gamma
# (presumably a mask of allowed moves — confirm). A fixed seed makes the starting
# point, and hence the fit, reproducible across runs.
# NOTE: this is an in-place `+=`, so re-running this cell alone adds more noise;
# re-run from the `Akagi(...)` cell (which presumably resets a.M) instead.
rng = np.random.default_rng(0)
a.M += a.gamma * rng.random(size=a.M.shape) * 10.

In [None]:
# Check that all regions have at least one neighbour.
# NOTE(review): `> 1` requires every row of gamma_exc to sum to more than 1. If
# gamma_exc is a 0/1 mask that excludes the region itself, this demands at least
# two neighbours rather than one — confirm which is intended.
assert np.all(a.gamma_exc.sum(axis=1) > 1)

In [None]:
# Run the fit with argument 1e-3 (presumably a convergence tolerance — confirm) and time it.
# `result` is not used below; later cells read a.M and a.beta, which this call
# presumably updates.
%time result = a.exact_inference(1e-3)

In [None]:
# Undo the count scaling applied when constructing Akagi (N * scale).
M_est = a.M / scale

In [None]:
# Save the SA2 codes next to the model output, in the column order of N
# (see the column-order assert), as a single line of "code, " entries.
sa2_codes_path = os.path.join(a.save_options.output_dir, "sa2_codes")
with open(sa2_codes_path, "w") as f:
    # str() so non-string codes (e.g. ints) don't make f.write raise TypeError;
    # the output format is identical to writing each code followed by ", ".
    f.write("".join(str(code) + ", " for code in areas.sa2_codes()))
    f.write("\n")

In [None]:
# Value of beta after inference.
a.beta

In [None]:
# Bounds on beta; presumably a list of (low, high) pairs — confirm.
a.beta_bounds

Check if beta bounds are saturated

In [None]:
# True if beta sits at the first bound entry (presumably the lower bound).
np.isclose(a.beta, a.beta_bounds[0][0])

In [None]:
# True if beta sits at the second bound entry (presumably the upper bound).
np.isclose(a.beta, a.beta_bounds[0][1])

Are there approximately the right number of people in the end?

In [None]:
# Observed total count at each selected time.
N.sum(axis=1)

In [None]:
# Estimated total count for each leading index of M_est, summed over both SA2 axes.
# Rounded to the nearest integer (astype(int) alone truncates, e.g. 1999.9 -> 1999),
# consistent with the residual cell below.
np.rint(M_est.sum(axis=(1, 2))).astype(int)

In [None]:
# Row sums of each M_est[t] minus the observed counts at all but the last time,
# rounded; presumably entries near 0 mean each region's population is conserved.
row_sum_residual = M_est.sum(axis=2) - N[:-1]
np.rint(row_sum_residual).astype(int)

What does `M` look like?

In [None]:
# First slice of the estimate in people (M_est = a.M / scale), rounded for display.
# a.M itself is in the scaled units used for fitting (x scale), so it is not shown.
np.rint(M_est[0]).astype(int)

In [None]:
# Distances (km) between retained SA2s, showing only pairs closer than K km.
fig, ax = plt.subplots()
im = ax.imshow(np.where(d < K, d, np.nan))
fig.colorbar(im, ax=ax, label="distance (km)")
ax.set(xlabel="SA2 index", ylabel="SA2 index", title=f"SA2 pairs within {K} km");

In [None]:
# log of the first slice of the estimate, for SA2 pairs closer than K km.
# Zero/negative entries have no finite log: mask them explicitly and silence the
# divide-by-zero / invalid-value warnings np.log would otherwise emit.
M_first = M_est[0]
with np.errstate(divide="ignore", invalid="ignore"):
    log_M_first = np.log(M_first)
fig, ax = plt.subplots()
im = ax.imshow(np.where((d < K) & (M_first > 0), log_M_first, np.nan))
fig.colorbar(im, ax=ax, label="log(estimated count)")
ax.set(xlabel="SA2 index", ylabel="SA2 index", title=f"log M_est[0] for SA2 pairs within {K} km");