# Kronecker-structured GPs for spatiotemporal modelling

This notebook demonstrates how to use GPJax on a real-world example from epidemiology. By exploiting separable structure in the covariance matrix, here a separation of the spatial and temporal dimensions, we can make inference considerably more efficient.

## Data

We'll be using the Chickenpox Cases in Hungary dataset, a benchmark in the spatiotemporal graph neural network literature ([Rozemberczki 2021a](https://arxiv.org/abs/2102.08100), [Rozemberczki 2021b](https://dl.acm.org/doi/10.1145/3459637.3482014)). The data consist of weekly counts of chickenpox cases for the 20 counties of Hungary over the period 2005-2015.

There's evidence to suggest that graph GPs outperform GNNs on this task ([Nikitin 2022](https://arxiv.org/pdf/2111.08524v1.pdf)).

In [None]:
# Load in the data

## Building a model

### Likelihood

We model the count $y_{st}$ observed in county $s$ during week $t$ with a Poisson likelihood and place a GP prior on the log-rate. The log transform keeps the rates positive:

$$ y_{st} \sim \text{Poisson}(\lambda_{st}) $$

$$ \log(\lambda_{st}) \sim \mathcal{GP}(0, k_{st}(\cdot, \cdot)) $$
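As a rough illustration of this likelihood, the sketch below evaluates the Poisson log-likelihood of a few counts given latent log-rates. It is written in plain Python so that it runs without dependencies; the function name `poisson_log_likelihood` and the toy numbers are assumptions made for this example and are not part of GPJax.

```python
import math


def poisson_log_likelihood(counts, log_rates):
    """Sum of log Poisson(y | exp(f)) over paired counts y and log-rates f."""
    total = 0.0
    for y, f in zip(counts, log_rates):
        # log Poisson(y | lam) = y * log(lam) - lam - log(y!)
        # With lam = exp(f) we have log(lam) = f, so any real f gives a valid rate.
        total += y * f - math.exp(f) - math.lgamma(y + 1)
    return total


# Toy values (made up for illustration): three county-week counts and the
# log-rates that a draw from the GP prior might assign to them.
counts = [3, 0, 7]
log_rates = [1.0, -0.5, 2.0]
print(poisson_log_likelihood(counts, log_rates))
```

Because the GP lives on the log scale, negative latent values simply correspond to rates below one case per week.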

We assume the spatial and temporal processes are separable and use [Kronecker inference](https://sethrf.com/files/fast-hierarchical-GPs.pdf), so the joint covariance is the Kronecker product of a spatial and a temporal covariance matrix:

$$ K_{st} = K_s \otimes K_t $$
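The efficiency gain comes from never forming $K_{st}$ explicitly. With 20 counties and roughly 500 weeks, the full matrix would have on the order of 10,000 rows, whereas the two factors are small. The sketch below checks the standard identity $(K_s \otimes K_t)\,\mathrm{vec}(V) = \mathrm{vec}(K_s V K_t^\top)$ (row-major vec) on toy matrices. It uses plain Python lists to stay dependency-free, and the helper names are assumptions made for this example rather than GPJax functions.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*A)]


def kron(A, B):
    """Dense Kronecker product of A and B (used here only as a check)."""
    return [[a * b for a in row_a for b in row_b] for row_a in A for row_b in B]


def kron_matvec(K_s, K_t, v):
    """Compute (K_s kron K_t) v without forming the Kronecker product."""
    n_s, n_t = len(K_s), len(K_t)
    # Reshape v row-major into an n_s x n_t matrix: one row per county.
    V = [v[i * n_t:(i + 1) * n_t] for i in range(n_s)]
    # Apply the spatial factor on the left and the temporal factor on the right.
    out = matmul(matmul(K_s, V), transpose(K_t))
    return [x for row in out for x in row]


# Toy covariance factors (made up): 2 "counties" and 3 "weeks".
K_s = [[2.0, 0.5], [0.5, 1.0]]
K_t = [[1.0, 0.3, 0.1], [0.3, 1.0, 0.3], [0.1, 0.3, 1.0]]
v = [1.0, -2.0, 0.5, 3.0, 0.0, -1.0]

fast = kron_matvec(K_s, K_t, v)
dense = [sum(k * x for k, x in zip(row, v)) for row in kron(K_s, K_t)]
print(max(abs(a - b) for a, b in zip(fast, dense)))  # difference is at rounding level
```

The same factorisation carries over to eigendecompositions and Cholesky factors, which is what makes Kronecker inference cheap.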


### Spatial effects

Spatial data are naturally correlated: nearby things tend to be more similar than distant things. This is why spatial problems lend themselves so readily to Gaussian processes.

For events which can be given an exact location, such as point processes, it is common to define a Gaussian process using the spatial distance between the events. The data here, however, are areal data aggregated at the county level. Rather than define spatial distances between counties, we will use the adjacency structure between counties. We can define a graph describing the spatial connectivity between counties and use a graph kernel to model the correlations.
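Graph kernels are typically built from the graph Laplacian $L = D - A$, where $A$ is the adjacency matrix and $D$ the diagonal matrix of node degrees. The sketch below constructs it for a hypothetical four-region map. The function name and the toy edge list are assumptions made for illustration, not the Hungarian county graph.

```python
def graph_laplacian(n_nodes, edges):
    """Return the combinatorial Laplacian L = D - A as a list of rows."""
    L = [[0] * n_nodes for _ in range(n_nodes)]
    for i, j in edges:
        # Each shared border contributes -1 off the diagonal
        # and raises the degree of both regions by one.
        L[i][j] -= 1
        L[j][i] -= 1
        L[i][i] += 1
        L[j][j] += 1
    return L


# Hypothetical toy map: four regions in a chain, 0 - 1 - 2 - 3.
edges = [(0, 1), (1, 2), (2, 3)]
L_toy = graph_laplacian(4, edges)
for row in L_toy:
    print(row)
```

Every row sums to zero, and for an unweighted graph this is the same $D - A$ matrix that `nx.laplacian_matrix` returns in sparse form.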

In [None]:
# Plot the counties
# Load in the adjacency matrix
# Plot the county adjacency matrix as a network

In [None]:
import gpjax as gpx
import networkx as nx

# Define the spatial kernel from the Laplacian of G, the county adjacency graph
L = nx.laplacian_matrix(G).toarray()
k_s = gpx.GraphKernel(laplacian=L)

### Temporal effects

In [None]:
# Plot time series facet for each county (see paper)

The time series show strong seasonality, driven by weather patterns and the periodicity of the school year.
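To see how a periodic kernel encodes this kind of seasonality, the sketch below evaluates the textbook periodic kernel $k(t, t') = \sigma^2 \exp\left(-2 \sin^2(\pi |t - t'| / p) / \ell^2\right)$ on weekly time stamps. The 52-week period and the unit hyperparameters are assumptions made for illustration, and the parameterisation used by GPJax's `Periodic` kernel may differ from this form.

```python
import math


def periodic_kernel(t1, t2, period=52.0, lengthscale=1.0, variance=1.0):
    """Textbook periodic kernel; the 52-week period is an assumed value."""
    s = math.sin(math.pi * abs(t1 - t2) / period)
    return variance * math.exp(-2.0 * s * s / lengthscale**2)


# Weeks exactly one year apart are almost perfectly correlated,
# while weeks half a year apart are the least correlated.
print(periodic_kernel(0.0, 52.0))
print(periodic_kernel(0.0, 26.0))
```

Unlike a stationary Matern kernel, the correlation does not decay with the gap in time, so the same week in different years stays strongly tied together.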

In [None]:
# Define the temporal kernel: a Matern-5/2 component plus a periodic component for the seasonality
k_t = gpx.Matern52() + gpx.Periodic()

In [None]:
# k_st = k_s kron k_t

Because the Poisson likelihood is not conjugate to the GP prior, the posterior has no closed form. We will therefore sample from it with BlackJAX.
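To make the idea of sampling-based inference concrete, here is a deliberately tiny sketch: a random-walk Metropolis sampler for a single log-rate $f$ with a standard normal prior and Poisson observations. This is not BlackJAX and not the sampler this notebook will use; it is a plain-Python stand-in, with made-up counts and hypothetical function names, that shows how a non-conjugate posterior can be explored using only its unnormalised log density.

```python
import math
import random


def log_posterior(f, counts, prior_sd=1.0):
    """Unnormalised log posterior of one log-rate f (drops the log(y!) constant)."""
    log_prior = -0.5 * (f / prior_sd) ** 2
    log_lik = sum(y * f - math.exp(f) for y in counts)
    return log_prior + log_lik


def random_walk_metropolis(counts, n_steps=5000, step=0.3, seed=0):
    """Sample f with Gaussian random-walk proposals."""
    rng = random.Random(seed)
    f = 0.0
    lp = log_posterior(f, counts)
    samples = []
    for _ in range(n_steps):
        proposal = f + rng.gauss(0.0, step)
        lp_prop = log_posterior(proposal, counts)
        # Accept with probability min(1, exp(lp_prop - lp)).
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            f, lp = proposal, lp_prop
        samples.append(f)
    return samples


# Made-up weekly counts for a single county.
counts = [4, 6, 5, 7, 3]
samples = random_walk_metropolis(counts)
kept = samples[1000:]  # discard burn-in
posterior_mean = sum(kept) / len(kept)
print(posterior_mean, math.exp(posterior_mean))
```

In the full model the latent vector has one entry per county-week pair, which is where gradient-based samplers and the Kronecker structure become important.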

In [None]:
# Pick some reasonable priors for lengthscales and variances