# Modeling Spatial Data with Gaussian Processes in PyMC

In this notebook I will recreate the code from [Luciano Paz](https://www.pymc-labs.io/blog-posts/spatial-gaussian-process-01/) with the goal of predicting the expected radon concentration in a household given the county it is in.

Usually we would picture some kind of continuous latent geographical feature that makes observations taken at nearby places similar to each other. This is different from treating geographical information as a categorical value that groups observations together.

One way to obtain this continuum, and to avoid assuming that neighbouring counties have absolutely nothing in common, is to model it with a Gaussian process (GP). A GP provides a very flexible way of setting a prior that essentially says: "nearby observations should be similar to each other, and as observations get further apart, they become uncorrelated."
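The prior described above becomes concrete through a covariance function. A common choice is the exponentiated quadratic (squared exponential) kernel, $k(x, x') = \eta^2 \exp\left(-\lVert x - x' \rVert^2 / (2\ell^2)\right)$, which PyMC provides as `pm.gp.cov.ExpQuad`. The NumPy sketch below is only an illustration: the amplitude $\eta$, the lengthscale $\ell$ and the 1-D locations are arbitrary assumptions, not values fitted to the radon data. It computes the implied correlation between a reference point and points at increasing distances, then draws one sample from the resulting zero-mean prior.

```python
import numpy as np


def exp_quad_kernel(x1, x2, eta=1.0, lengthscale=1.0):
    """Exponentiated quadratic covariance: eta**2 * exp(-d**2 / (2 * lengthscale**2)).

    x1 has shape (n, d) and x2 has shape (m, d); the result has shape (n, m).
    """
    # Squared Euclidean distance between every pair of points, via broadcasting
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return eta**2 * np.exp(-0.5 * sq_dist / lengthscale**2)


# Hypothetical 1-D locations: a reference point at 0 and points moving away from it
origin = np.array([[0.0]])
locations = np.array([[0.0], [0.5], [1.0], [2.0], [4.0]])

# With eta = 1 the covariance equals the correlation with the reference point
corr = exp_quad_kernel(origin, locations)[0]
print(np.round(corr, 2))  # approx. [1.0, 0.88, 0.61, 0.14, 0.0]

# One draw from the zero-mean GP prior on a dense grid, via a Cholesky factor.
# A small jitter on the diagonal keeps the factorization numerically stable.
grid = np.linspace(0.0, 10.0, 200)[:, None]
cov = exp_quad_kernel(grid, grid) + 1e-6 * np.eye(len(grid))
chol = np.linalg.cholesky(cov)
rng = np.random.default_rng(0)
sample = chol @ rng.standard_normal(len(grid))
```

Shrinking the lengthscale makes the correlation die off over shorter distances; in PyMC that role is played by the `ls` argument of `pm.gp.cov.ExpQuad`.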

## The dataset

In this exercise, the dataset is [the radon dataset from Gelman and Hill 2006](https://www.cambridge.org/highereducation/books/data-analysis-using-regression-and-multilevel-hierarchical-models/32A29531C7FD730C3A68951A17C9D983#overview), which contains radon measurements taken in 919 households from 85 counties of the state of Minnesota. They used a grouping approach by county, which doesn't make much physical sense: why should the radon concentration in the Earth's crust follow county borders?

In [None]:
import arviz as az
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
import numpy as np
import pandas as pd
import pymc as pm
from aesara import tensor as at
from matplotlib import pyplot as plt

Cartopy can be used to get the shape files from public sources such as Natural Earth. The county shapes carry a lot of metadata. One important field is called FIPS, which stands for [Federal Information Processing Standards](https://en.wikipedia.org/wiki/Federal_Information_Processing_Standards). At the time the radon measurements were performed, the counties were identified by their FIPS codes, and we will use these codes to align our observations with the corresponding shape files.

In [None]:
# load the dataset:

df = pd.read_csv('data/radon.csv', index_col=0, dtype={"fips":int})
df['fips'] = "US" + df['fips'].astype(str)
county_idx, counties = df.county.factorize(sort=True)
unique_fips = df.fips.unique()
df
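A county FIPS code has five digits: two for the state and three for the county within it, and Minnesota's state code is 27. The cell above turns each integer code into a `"US"`-prefixed string, which is the form this notebook later compares against the Natural Earth `FIPS` attribute. Because the CSV is read with `dtype={"fips": int}`, any leading zero would be lost. That is harmless for Minnesota, where every code starts with 27, but the sketch below pads to five digits with `str.zfill` so the same conversion would also hold for states with a single-digit code. The helper name `to_natural_earth_fips` is hypothetical and not part of the original code.

```python
import pandas as pd


def to_natural_earth_fips(codes):
    """Hypothetical helper: convert integer FIPS codes to 'US' + 5-digit strings."""
    # zfill(5) restores leading zeros that are lost when FIPS codes are stored as ints
    return "US" + pd.Series(codes).astype(str).str.zfill(5)


# Two Minnesota codes and one code whose state part has a leading zero
keys = to_natural_earth_fips([27003, 27053, 1001])

# Split a key back into its state and county parts
state_code, county_code = keys[0][2:4], keys[0][4:]
print(keys.tolist(), state_code, county_code)
```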

In [None]:
# get the state of minnesota shape file:

reader = shpreader.Reader(
    shpreader.natural_earth(
        resolution="10m", category='cultural',name='admin_1_states_provinces'
    )
)
minnesota = [
    s
    for s in reader.records()
    if s.attributes["admin"] == "United States of America"
    and s.attributes["name"] == "Minnesota"
][0]

In [None]:
# Get Minnesota counties and neighboring counties shape files:

reader = shpreader.Reader(
    shpreader.natural_earth(
        resolution='10m', category='cultural', name='admin_2_counties'
    )
)

# Counties that intersect Minnesota and have less than 1% of their area outside it
minnesota_counties = [
    county
    for county in reader.records()
    if county.geometry.intersects(minnesota.geometry)
    and county.geometry.difference(minnesota.geometry).area / county.geometry.area < 0.01
]

# Counties that intersect Minnesota but have more than half of their area outside it
minnesota_neighbor_counties = [
    county
    for county in reader.records()
    if county.geometry.intersects(minnesota.geometry)
    and county.geometry.difference(minnesota.geometry).area / county.geometry.area > 0.5
]

counties_with_measurements = [c for c in minnesota_counties if c.attributes['FIPS'] in unique_fips]

counties_without_measurements = [c for c in minnesota_counties if c.attributes['FIPS'] not in unique_fips]

len(counties_with_measurements), len(counties_without_measurements), len(minnesota_neighbor_counties)
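The two list comprehensions above classify each county by the fraction of its area that lies outside Minnesota: below 1% it counts as a Minnesota county, above 50% as a neighbouring one, and anything in between ends up in neither list. The toy sketch below reproduces that rule in plain Python with axis-aligned rectangles in place of Shapely geometries; the rectangles, their names and the helper functions are all hypothetical.

```python
def box_area(box):
    """Area of an axis-aligned box (xmin, ymin, xmax, ymax); zero if it is empty."""
    xmin, ymin, xmax, ymax = box
    return max(0.0, xmax - xmin) * max(0.0, ymax - ymin)


def boxes_intersect(a, b):
    """True if the boxes share at least one point (touching edges counts, as in Shapely)."""
    return max(a[0], b[0]) <= min(a[2], b[2]) and max(a[1], b[1]) <= min(a[3], b[3])


def intersection_area(a, b):
    """Area of the overlap between two axis-aligned boxes."""
    return box_area((max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])))


def classify(county, state, inside_tol=0.01, neighbor_min=0.5):
    """Apply the notebook's thresholds to the fraction of county area outside the state."""
    if not boxes_intersect(county, state):
        return "unrelated"
    # Same quantity as county.difference(state).area / county.area
    outside = 1.0 - intersection_area(county, state) / box_area(county)
    if outside < inside_tol:
        return "inside"
    if outside > neighbor_min:
        return "neighbor"
    return "dropped"  # lands in neither list


state = (0.0, 0.0, 10.0, 10.0)  # hypothetical "state"
toy_counties = {
    "interior": (1.0, 1.0, 3.0, 3.0),
    "mostly_outside": (9.0, 5.0, 15.0, 7.0),
    "half_outside": (8.0, 0.0, 12.0, 2.0),
    "touching": (10.0, 0.0, 12.0, 2.0),
    "far_away": (20.0, 20.0, 21.0, 21.0),
}
labels = {name: classify(box, state) for name, box in toy_counties.items()}
print(labels)
```

The `touching` case is worth noticing: a county that only shares a border with the state still intersects it and has all of its area outside, so it is counted as a neighbour.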

In [None]:
counties_without_measurements

In [None]:
# Map each county name to its shape record and its (longitude, latitude) coordinates:

county_fips = {counties[idx]: df.fips.iloc[i] for i, idx in enumerate(county_idx)}
fips_to_records = {
    record.attributes["FIPS"]: record for record in counties_with_measurements
}
county_to_records = {c: fips_to_records[county_fips[c]] for c in counties}
county_lonlat = {
    c: np.array(
        [
            county_to_records[c].attributes["longitude"],
            county_to_records[c].attributes["latitude"],
        ]
    )
    for c in counties
}
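# Names for the counties without observations: unmeasured Minnesota counties keep
# their plain name, while neighbouring counties get their REGION attribute appended
# so that names from different states stay distinct.
cond_counties = [
    c.attributes["NAME"].upper() for c in counties_without_measurements
] + [
    f"{c.attributes['NAME']} - {c.attributes['REGION']}".upper()
    for c in minnesota_neighbor_counties
]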
county_to_records.update(
    {
        name: record
        for name, record in zip(
            cond_counties, counties_without_measurements + minnesota_neighbor_counties
        )
    }
)
cond_county_lonlat = {
    c: np.array(
        [
            county_to_records[c].attributes["longitude"],
            county_to_records[c].attributes["latitude"],
        ]
    )
    for c in cond_counties
}
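A GP over space will need the coordinates as an array whose rows follow the same order as the county index produced by `factorize`. The sketch below shows one way to build such an array from a dictionary like `county_lonlat`, plus a pairwise-distance matrix computed by broadcasting; the county names and coordinates are made up. Longitude and latitude are angles rather than distances: at latitude φ, a degree of longitude spans roughly cos φ times the length of a degree of latitude (roughly 0.65 to 0.73 across Minnesota), so Euclidean distances in raw degrees are only an approximation. The last step applies a rough correction at the mean latitude, which is an assumption of this sketch and not part of the original notebook.

```python
import numpy as np

# Hypothetical stand-ins for `counties` and `county_lonlat` from the cell above
counties = np.array(["A", "B", "C"])
county_lonlat = {
    "B": np.array([-93.0, 45.0]),
    "A": np.array([-94.0, 46.0]),
    "C": np.array([-92.0, 44.5]),
}

# Stack in the order of `counties` so that row i matches county index i
X = np.stack([county_lonlat[c] for c in counties])  # shape (n_counties, 2)

# Pairwise Euclidean distances in degrees, via broadcasting
diff = X[:, None, :] - X[None, :, :]
dist_deg = np.sqrt((diff**2).sum(axis=-1))

# Rough correction: shrink longitude differences by cos(mean latitude)
scale = np.array([np.cos(np.deg2rad(X[:, 1].mean())), 1.0])
dist_scaled = np.sqrt(((diff * scale) ** 2).sum(axis=-1))
```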

In [None]:
expected_radon = df.groupby("county")["log_radon"].mean()
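`expected_radon` is the raw per-county average of `log_radon`, a no-pooling estimate in which each county's value uses only its own households. With 919 measurements over 85 counties there are about 11 households per county on average, and the counts are uneven, so some of these means rest on very few observations. Putting the count next to the mean makes that visible; the sketch below does this on a small made-up frame that only shares the `county` and `log_radon` column names with the real data. On the real data the equivalent call would be `df.groupby("county")["log_radon"].agg(["mean", "count"])`.

```python
import pandas as pd

# Made-up measurements with the same column names as the radon dataset
toy = pd.DataFrame(
    {
        "county": ["COUNTY_A", "COUNTY_A", "COUNTY_B", "COUNTY_B", "COUNTY_B", "COUNTY_C"],
        "log_radon": [0.8, 1.2, 1.0, 0.6, 0.5, 2.0],
    }
)

# Per-county mean together with the number of households behind it
summary = toy.groupby("county")["log_radon"].agg(["mean", "count"])
print(summary)
```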

In [None]:
expected_radon

In [None]:
minnesota.geometry

In [None]:
county_to_records['ANOKA'].geometry

In [None]:
fig = plt.figure(figsize=(12, 7))
projection = ccrs.PlateCarree()
ax = plt.axes(projection=projection)
ax.add_feature(
    cfeature.ShapelyFeature([minnesota.geometry], projection),
    edgecolor="k",
    facecolor="w",
)
# Rescale the county means to [0, 1] and map them onto the viridis colormap
norm = plt.Normalize(vmin=expected_radon.min(), vmax=expected_radon.max())
cmap = plt.get_cmap("viridis")
for county in counties:
    county_record = county_to_records[county]
    val = expected_radon[county]
    ax.add_feature(
        cfeature.ShapelyFeature([county_record.geometry], projection),
        edgecolor="gray",
        facecolor=cmap(norm(val)),
    )

# Attach the colorbar to `ax`; with a real norm its ticks are already in log-radon units
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
cbar.set_label("Observed mean log radon")
ax.add_feature(cfeature.LAKES, alpha=0.5)
ax.add_feature(cfeature.RIVERS)
ax.set_xlim([-99, -87])
ax.set_ylim([42, 50])
ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
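Coloring each county comes down to a linear min-max rescaling of its mean into $[0, 1]$, $t = (x - x_{\min}) / (x_{\max} - x_{\min})$, which the colormap turns into a color; the colorbar labels correspond to the inverse map, $x = x_{\min} + t\,(x_{\max} - x_{\min})$. The small NumPy check below uses hypothetical bounds and confirms that, inside the range, this rescaling matches `np.interp(x, [vmin, vmax], [0, 1])`.

```python
import numpy as np

vmin, vmax = -0.5, 2.5  # hypothetical range of county means


def to_unit(x):
    """Forward map to a colormap position; unlike np.interp it does not clip outside the range."""
    return (np.asarray(x) - vmin) / (vmax - vmin)


def from_unit(t):
    """Inverse map, used to express colorbar positions in log-radon units."""
    return vmin + np.asarray(t) * (vmax - vmin)


ticks = np.linspace(0, 1, 6)
tick_labels = from_unit(ticks)
print(np.round(tick_labels, 2))
```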