## Augment Missing LitPop Data with GEG-15

Some countries are missing from LitPop. This notebook fills in those areas with GEG-15 exposure data regridded onto the LitPop grid, and saves the blended grid as a single parquet file.

In [None]:
%load_ext autoreload

%autoreload 2

import dask.dataframe as ddf
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy_groupies as npg
import pandas as pd
import regionmask
import rhg_compute_tools.kubernetes as rhgk
import xarray as xr
import rioxarray
import xesmf as xe
from cartopy import crs as ccrs
from cartopy import feature as cfeature

from sliiders.geography import get_iso_geometry
from sliiders.spatial import grid_ix_to_val, grid_val_to_ix
from sliiders import settings as sset

import zipfile
from sliiders import __file__
from pathlib import Path
import os

In [None]:
# Start a (micro) compute cluster via rhg_compute_tools and get a client for it.
client, cluster = rhgk.get_micro_cluster()

In [None]:
# Number of workers to scale to; also reused below as the LitPop partition count.
nworkers = 16
cluster.scale(nworkers)
cluster

In [None]:
# Ship the local `sliiders` package to the workers: zip the package directory
# (archive paths are relative to its parent, so the zip contains a top-level
# package folder) and upload it through the client.
# `with` guarantees the archive is closed even if a write fails.
sliiders_dir = Path(__file__).parent
with zipfile.ZipFile("sliiders.zip", "w", zipfile.ZIP_DEFLATED) as zipf:
    for root, _dirs, files in os.walk(sliiders_dir):
        for file in files:
            path = os.path.join(root, file)
            zipf.write(
                path,
                os.path.relpath(path, os.path.join(sliiders_dir, "..")),
            )
client.upload_file("sliiders.zip")

## Load Datasets

In [None]:
# # convert litpop from csvs to a parquet file
litpop = ddf.read_csv(str(sset.PATH_LITPOP_RAW))

In [None]:
litpop = litpop.rename(columns={"latitude": "lat", "longitude": "lon"})

In [None]:
litpop["value"] = litpop["value"].astype(np.float32)
litpop["lat"] = litpop["lat"].astype(np.float32)
litpop["lon"] = litpop["lon"].astype(np.float32)

In [None]:
litpop = litpop.repartition(npartitions=nworkers).persist()

In [None]:
litpop_meta = pd.read_csv(sset.DIR_LITPOP_RAW / "_metadata_countries_v1_2.csv")

## Create GeoDataFrame for Countries with Missing LitPop Data

In [None]:
missing_countries = litpop_meta[litpop_meta["included"] == 0].copy()

In [None]:
missing_countries = gpd.GeoDataFrame(
    missing_countries,
    geometry=get_iso_geometry(missing_countries["iso3"].to_numpy()),
)

missing_countries["iso3"].unique()

In [None]:
region_id_to_iso = litpop_meta.set_index("region_id")[["iso3"]]

litpop = litpop.join(region_id_to_iso, on="region_id").persist()

In [None]:
geg15 = pd.read_parquet(sset.PATH_GEG15_INT, columns=["lon", "lat", "iso3", "tot_val"])

In [None]:
geg15["tot_val"] = geg15["tot_val"] * 1e6

In [None]:
lp_iso3 = litpop["iso3"].unique().compute()
geg_iso3 = geg15["iso3"].unique()

## Helper Functions

In [None]:
# retrieve geg data to regrid
def subset_relevant_geg_data(poly, geg15, buffer=1 / 48):
    """Return the GEG-15 cells around ``poly`` as an xarray Dataset, or None.

    Rows of ``geg15`` whose lon/lat fall inside the polygon's bounding box,
    expanded by ``buffer`` on every side (bounds inclusive), are pivoted onto
    a (lat, lon) grid, and a ``mask`` variable from ``poly_mask`` is added.

    Returns None when no rows fall in the box, or when the sum of ``tot_val``
    over cells with ``mask == 1`` is not positive.
    """
    minx, miny, maxx, maxy = poly.bounds
    in_box = geg15.lon.between(minx - buffer, maxx + buffer) & geg15.lat.between(
        miny - buffer, maxy + buffer
    )
    geg15_sub = geg15.loc[in_box, ["lon", "lat", "tot_val"]].reset_index(drop=True)
    if geg15_sub.empty:
        return None

    subset = geg15_sub.set_index(["lat", "lon"]).to_xarray()
    subset["mask"] = poly_mask(poly, subset)

    masked_total = subset.tot_val.where(subset.mask == 1).sum()
    if masked_total <= 0:
        return None
    return subset


def create_grid(subset, resolution, add_cell_corners=False):
    """Build an empty regular lat/lon grid covering the valid cells of ``subset``.

    Valid cells are those with ``mask > 0`` and non-null ``tot_val``. Their
    lat/lon extent is expanded outward to whole degrees, so cell edges fall on
    integer lat/lon values and cell centers sit half a cell inside them.

    Parameters
    ----------
    subset : xarray.Dataset
        Dataset with ``lat``/``lon`` coordinates and ``mask``/``tot_val`` variables.
    resolution : float
        Cell width in degrees.
    add_cell_corners : bool
        If True, also add ``lat_b``/``lon_b`` cell-boundary coordinates
        (one more entry than the corresponding centers).

    Returns
    -------
    xarray.Dataset
        Coordinate-only dataset with ``lat``, ``lon`` (and optionally
        ``lat_b``, ``lon_b``).

    Raises
    ------
    ValueError
        If ``subset`` has no valid cell.
    """
    valid = (subset.mask > 0) & subset.tot_val.notnull()
    if not bool(valid.any()):
        raise ValueError(
            "create_grid: subset has no cells with mask > 0 and non-null tot_val"
        )

    masked_lon = subset.lon.where(valid)
    masked_lat = subset.lat.where(valid)

    # Reduce the masked *values*. Their `.lon`/`.lat` attributes are the full,
    # unmasked coordinates, so reducing those would ignore the mask.
    x1, y1 = np.floor((masked_lon.min().item(), masked_lat.min().item()))
    x2, y2 = np.ceil((masked_lon.max().item(), masked_lat.max().item()))

    lat = np.arange(y1 + resolution / 2, y2, resolution)
    lon = np.arange(x1 + resolution / 2, x2, resolution)

    ds_out = xr.Dataset(
        coords={
            "lat": lat,
            "lon": lon,
        }
    )

    if add_cell_corners:
        ds_out.coords["lat_b"] = (ds_out.lat.min().item() - resolution / 2) + np.arange(
            len(ds_out.lat) + 1
        ) * resolution
        ds_out.coords["lon_b"] = (ds_out.lon.min().item() - resolution / 2) + np.arange(
            len(ds_out.lon) + 1
        ) * resolution

    return ds_out


def poly_mask(poly, grid):
    """Rasterize ``poly`` onto the lat/lon grid of ``grid``.

    Only ``grid``'s coordinates are used. Its data variables are dropped
    before clipping, so they are not copied, clipped and compared needlessly.

    Returns
    -------
    tuple
        ``(dims, values)``, suitable for assigning directly as a Dataset
        variable. ``values`` is an int32 array that is 1 for every cell the
        polygon touches (``all_touched=True``) and 0 elsewhere.

    NOTE(review): the grid is tagged EPSG:4326, so ``poly`` is assumed to be
    in lon/lat degrees -- confirm for geometries from ``get_iso_geometry``.
    """
    # drop_vars returns a new Dataset, so `grid` itself is left untouched.
    mask_grid = grid.drop_vars(list(grid.data_vars))
    mask_grid["mask"] = (
        ["lat", "lon"],
        np.full((len(mask_grid.lat), len(mask_grid.lon)), 1, np.int32),
    )

    mask_grid = mask_grid.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)
    mask_grid = mask_grid.rio.write_crs("epsg:4326", inplace=True)

    # drop=False keeps the full grid; cells outside the polygon are filled and
    # therefore no longer equal 1.
    clipped = mask_grid.rio.clip([poly], drop=False, all_touched=True)
    mask = (clipped["mask"] == 1).astype(np.int32)

    return mask.dims, mask.values


def make_land_weights(subset, poly, out_resolution, in_resolution):
    """Compute a land weight for each ``in_resolution`` cell covering ``subset``.

    Two grids are built over ``subset`` with ``create_grid``:
    - a fine grid at ``out_resolution``;
    - a coarse grid at ``in_resolution``.

    Each fine cell is tagged with the id of the nearest coarse cell. It is
    flagged 1 if its center point lies in ``poly`` (regionmask) and 0
    otherwise. The flags are summed per coarse cell and divided by
    ``(in_resolution / out_resolution) ** 2``, the number of fine cells per
    coarse cell.

    Returns
    -------
    xarray.Dataset
        The coarse grid (with ``lat_b``/``lon_b`` corners) carrying ``id5x``
        (cell id) and ``land_weights``.
    """
    print("Creating grids...")
    # Both grids start from the same whole-degree origin, so they nest.
    fine_grid = create_grid(subset, resolution=out_resolution, add_cell_corners=True)
    coarse_grid = create_grid(subset, resolution=in_resolution, add_cell_corners=True)

    # Number the coarse cells in row-major (lat, lon) order.
    n_lat, n_lon = coarse_grid.lat.shape[0], coarse_grid.lon.shape[0]
    coarse_grid["id5x"] = (
        ["lat", "lon"],
        np.arange(n_lat * n_lon).reshape(n_lat, n_lon),
    )

    # Tag each fine cell with the id of its nearest coarse cell center.
    nearest_coarse = coarse_grid.reindex_like(
        fine_grid, method="nearest", tolerance=in_resolution / 2
    )
    fine_grid["idx5"] = nearest_coarse.id5x

    print("Creating land mask...")
    regions = regionmask.Regions([poly], numbers=[1])
    fine_grid["mask"] = regions.mask(
        fine_grid.lon.values, fine_grid.lat.values
    ).fillna(0)

    print("Constructing land weights...")
    fine_per_coarse = (in_resolution / out_resolution) ** 2
    land_cell_counts = npg.aggregate(
        group_idx=fine_grid.idx5.values.flatten(),
        a=fine_grid.mask.values.flatten(),
        func="sum",
        fill_value=0,
    )
    coarse_grid["land_weights"] = (
        ["lat", "lon"],
        land_cell_counts.reshape(coarse_grid.id5x.shape) / fine_per_coarse,
    )

    return coarse_grid


def prep_geg_for_regrid(
    poly, geg15, geg_res=sset.GEG_GRID_WIDTH, litpop_res=sset.LITPOP_GRID_WIDTH
):
    """Prepare the GEG-15 cells around ``poly`` for regridding to LitPop.

    Steps:
    1. Subset GEG-15 to ``poly`` (``subset_relevant_geg_data``).
    2. Compute per-GEG-cell land weights at LitPop resolution
       (``make_land_weights``).
    3. Conservatively regrid those weights onto the GEG subset grid.
    4. Add ``tot_val_norm`` = ``tot_val`` / (GEG cell area in square
       degrees), kept only where the regridded land weight is positive.

    Parameters
    ----------
    poly
        Polygon with a ``.bounds`` attribute (presumably shapely) in lon/lat.
    geg15 : pandas.DataFrame
        GEG-15 cells with ``lon``, ``lat`` and ``tot_val`` columns.
    geg_res, litpop_res : float
        GEG and LitPop cell widths in degrees.

    Returns
    -------
    xarray.Dataset or None
        The GEG subset with ``lat_b``/``lon_b`` corners and ``tot_val_norm``.
        Returns None if there is no relevant GEG data in ``poly``, or if no
        masked cell keeps a non-null ``tot_val_norm``.
    """
    # get relevant geg data given poly of interest
    subset = subset_relevant_geg_data(poly, geg15, geg_res / 2)
    if subset is None:
        return None

    land = make_land_weights(subset, poly, litpop_res, geg_res)

    # Add cell corners for the conservative regrid.
    # NOTE(review): the corners assume subset.lat/lon are evenly spaced at
    # geg_res with no missing rows/columns -- confirm; gaps (e.g. between
    # islands) would misplace them.
    subset.coords["lat_b"] = (subset.lat.min().item() - geg_res / 2) + np.arange(
        len(subset.lat) + 1
    ) * geg_res
    subset.coords["lon_b"] = (subset.lon.min().item() - geg_res / 2) + np.arange(
        len(subset.lon) + 1
    ) * geg_res

    # regrid land weights onto geg grid
    regridder = xe.Regridder(land, subset, "conservative")
    land_weights_regrid = regridder(land)

    # Convert to value per square degree; cells with no land become NaN.
    # NOTE(review): only the cell area is divided out here. The land weight
    # acts purely as a > 0 filter, not as a normalizer -- confirm intended.
    cell_area = geg_res**2
    subset["tot_val_norm"] = (
        subset.tot_val.where(land_weights_regrid.land_weights > 0) / cell_area
    )

    # drop out if all null data --> no asset value on relevant land
    if not bool(((subset.mask > 0) & subset.tot_val_norm.notnull()).any()):
        return None

    return subset


def regrid_geg(
    poly, geg15, geg_res=sset.GEG_GRID_WIDTH, litpop_res=sset.LITPOP_GRID_WIDTH
):
    """Regrid GEG-15 values inside ``poly`` onto a LitPop-resolution grid.

    The prepared GEG subset (``prep_geg_for_regrid``) is regridded with
    xESMF ``nearest_s2d`` onto a ``litpop_res`` grid (``create_grid``). The
    per-square-degree ``tot_val_norm`` is then scaled back to a per-cell
    ``tot_val`` by multiplying by ``litpop_res**2``, and kept only on cells
    touched by ``poly``.

    Returns
    -------
    xarray.Dataset or None
        The regridded dataset with a ``tot_val`` variable, or None when
        ``prep_geg_for_regrid`` finds nothing to regrid.
    """
    geg_sub = prep_geg_for_regrid(poly, geg15, geg_res, litpop_res)
    if geg_sub is None:
        return None

    out_grid = create_grid(geg_sub, resolution=litpop_res)
    regridder = xe.Regridder(geg_sub, out_grid, "nearest_s2d")
    geg_regridded = regridder(geg_sub)

    # Wrap the mask as a labeled DataArray so `where` aligns it with
    # tot_val_norm by dimension name instead of by raw array axis order.
    mask_dims, mask = poly_mask(poly, geg_regridded[["lat", "lon"]])
    in_poly = (
        xr.DataArray(
            mask,
            dims=mask_dims,
            coords={dim: geg_regridded[dim].values for dim in mask_dims},
        )
        == 1
    )
    geg_regridded["tot_val"] = (geg_regridded.tot_val_norm * (litpop_res**2)).where(
        in_poly
    )

    return geg_regridded

## Regrid GEG for Missing Countries in LitPop

In [None]:
# Regrid GEG-15 onto the LitPop grid for every territory GEG-15 covers but LitPop does not.
out_dict = {}
for territory in sset.ISOS_IN_GEG_NOT_LITPOP:
    print(territory)
    territory_shape = (
        missing_countries[missing_countries["iso3"] == territory].iloc[0].geometry
    )
    regridded = regrid_geg(territory_shape, geg15)
    if regridded is None:
        # regrid_geg returns None when it finds nothing to regrid. Later cells
        # call .copy() / index into every out_dict value, so store only real
        # results.
        print(f"  no GEG-15 value found for {territory}; skipping")
        continue
    out_dict[territory] = regridded

## Check That the Regridding Looks Good

In [None]:
def plot_exposure(ax, title, data, poly, vmin=None, vmax=None):
    """Plot gridded exposure ``data`` on a cartopy axis, zoomed to ``poly``.

    The extent is ``poly``'s bounding box padded by 1 degree on each side.
    Coastlines and Natural Earth land borders are drawn in orange and the
    polygon outline in red. Values above 1e-7 are drawn on a log color scale;
    ``vmin``/``vmax`` are passed to ``LogNorm``.
    """
    minx, miny, maxx, maxy = poly.bounds
    pad = 1
    ax.set_extent([minx - pad, maxx + pad, miny - pad, maxy + pad])
    ax.coastlines("10m", linewidth=0.5, edgecolor="tab:orange")

    borders = cfeature.NaturalEarthFeature(
        category="cultural",
        name="admin_0_boundary_lines_land",
        scale="10m",
        facecolor="none",
    )
    ax.add_feature(borders, edgecolor="tab:orange", linewidth=0.1)

    # Hide zero / near-zero cells so the log scale stays finite.
    positive = data.where(data > 1e-7)
    positive.plot(
        ax=ax,
        cmap="YlGnBu",
        norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax),
        cbar_kwargs={"shrink": 0.5, "label": ""},
    )

    outline_crs = ccrs.PlateCarree()
    ax.add_geometries(
        [poly], outline_crs, facecolor="none", edgecolor="r", linewidth=0.3
    )
    ax.set_title(title)

In [None]:
# How does the regridding look?
%matplotlib inline

plot_dict = out_dict

pc_transform = ccrs.PlateCarree()
fig, axs = plt.subplots(
    figsize=((3 * 3), (3 * 4)),
    dpi=500,
    ncols=3,
    nrows=3,
    subplot_kw={"projection": pc_transform},
)

axs = axs.flatten()
for ax, tup in zip(axs, plot_dict.items()):
    iso = tup[0]
    out = tup[1]
    row = missing_countries[missing_countries.iso3 == iso].iloc[0]
    poly = row.geometry
    plot_exposure(ax, iso, out["tot_val"], poly)

## Add Regridded Data into LitPop

In [None]:
# Swap lat/lon from coordinate values to integer grid indices (via grid_val_to_ix)
# so LitPop cells can be matched exactly against the regridded GEG cells.
for coord in ("lat", "lon"):
    litpop[coord] = litpop[coord].map_partitions(
        grid_val_to_ix, cell_size=sset.LITPOP_GRID_WIDTH
    )
litpop = litpop.persist()

litpop

In [None]:
# add geg data into litpop dask dataframe
# For each territory, inside the bounding box of its regridded grid:
# - GEG-15 overrides LitPop value/iso3 wherever GEG-15 has a value;
# - GEG-15 cells absent from LitPop are added.
# lat/lon are integer grid indices at this point (converted in the cell above).
for iso, _add in out_dict.items():
    print(iso)
    if _add is None:
        # regrid_geg returns None when it finds nothing for a territory.
        continue
    add = _add.copy()
    add.coords["lat"] = grid_val_to_ix(add.lat.values, sset.LITPOP_GRID_WIDTH)
    add.coords["lon"] = grid_val_to_ix(add.lon.values, sset.LITPOP_GRID_WIDTH)

    # Bounding box (grid-index units), built once and reused for both the
    # inside and outside selections.
    in_bbox = (
        (litpop.lon >= add.lon.min().item())
        & (litpop.lon <= add.lon.max().item())
        & (litpop.lat >= add.lat.min().item())
        & (litpop.lat <= add.lat.max().item())
    )
    litpop_sub = litpop[in_bbox].compute()

    # Mask out all MAR values below the MAR-ESH border (this border is defined by its latitude).
    # litpop_sub["lat"] holds grid indices here, so convert it back to degrees
    # before comparing with the border latitude; comparing indices to degrees
    # never matches.
    if iso == "ESH":
        esh_north = get_iso_geometry("ESH").bounds[3]
        lat_deg = np.asarray(
            grid_ix_to_val(litpop_sub["lat"], cell_size=sset.LITPOP_GRID_WIDTH)
        )
        mar_south_of_border = (litpop_sub["iso3"] == "MAR").to_numpy() & (
            lat_deg <= esh_north
        )
        litpop_sub = litpop_sub.loc[~mar_south_of_border].copy()

    litpop_sub = litpop_sub.set_index(["lat", "lon"]).to_xarray()

    add = add.rename({"tot_val": "value"})

    add["iso3"] = (["lat", "lon"], np.where((~np.isnan(add["value"])), iso, None))

    # Where GEG-15 has a value, it overrides LitPop's value and iso3.
    litpop_sub["new_iso3"] = add["iso3"]
    litpop_sub["iso3"] = xr.where(
        litpop_sub["new_iso3"].isnull(), litpop_sub["iso3"], litpop_sub["new_iso3"]
    )
    litpop_sub["new_value"] = add["value"]
    litpop_sub["value"] = xr.where(
        litpop_sub["new_value"].isnull(), litpop_sub["value"], litpop_sub["new_value"]
    )

    # Merging with `add` also brings in GEG-15 cells LitPop lacks.
    mmed = xr.merge([litpop_sub[["value", "iso3"]], add[["value", "iso3"]]])

    litpop_m_sub = litpop[~in_bbox]

    to_append = mmed[["value", "iso3"]].to_dataframe().dropna().reset_index()

    # TODO figure out what's going on here--sometimes index isn't automatically named by `to_dataframe()`
    to_append = to_append.rename(columns={"level_0": "lat", "level_1": "lon"})

    # DataFrame.append is deprecated/removed in dask and pandas; concat is the
    # supported equivalent.
    litpop = ddf.concat(
        [litpop_m_sub, ddf.from_pandas(to_append, npartitions=1)]
    ).persist()

In [None]:
# prep vars for saving: integer grid indices as int16, values as float32.
# NOTE(review): int16 holds at most 32767 -- confirm grid_val_to_ix keeps the
# LitPop lat/lon indices within that range at this cell size.
for src, dst in (("lat", "y_ix"), ("lon", "x_ix")):
    litpop[dst] = litpop[src].astype(np.int16)
litpop["value"] = litpop["value"].astype(np.float32)

In [None]:
litpop = litpop.persist()

In [None]:
out_iso3 = litpop["iso3"].unique().compute()

In [None]:
litpop = litpop[["y_ix", "x_ix", "value"]]
litpop = litpop[litpop["value"] > 0]

In [None]:
df_litpop = litpop.compute()

df_litpop = df_litpop.reset_index(drop=True)

In [None]:
df_litpop["value"] = df_litpop["value"].astype(np.float32)

In [None]:
df_litpop

In [None]:
sset.PATH_EXPOSURE_BLENDED.parent.mkdir(exist_ok=True)

In [None]:
df_litpop.to_parquet(
    sset.PATH_EXPOSURE_BLENDED,
    index=False,
    compression=None,
    engine="fastparquet",
)

In [None]:
# Shut down the client and cluster; the checks below only read the saved parquet locally.
client.close()

In [None]:
cluster.close()

## Check That the GEG Additions Look Good

In [None]:
litpop_int = pd.read_parquet(sset.PATH_EXPOSURE_BLENDED)

In [None]:
litpop_int["lat"] = grid_ix_to_val(litpop_int.y_ix, cell_size=sset.LITPOP_GRID_WIDTH)
litpop_int["lon"] = grid_ix_to_val(litpop_int.x_ix, cell_size=sset.LITPOP_GRID_WIDTH)

In [None]:
# How does the regridding look?
%matplotlib inline

plot_dict = out_dict

pc_transform = ccrs.PlateCarree()
fig, axs = plt.subplots(
    figsize=((3 * 3), (3 * 4)),
    dpi=500,
    ncols=3,
    nrows=4,
    subplot_kw={"projection": pc_transform},
)

axs = axs.flatten()
for ax, tup in zip(axs, plot_dict.items()):
    iso = tup[0]
    add = tup[1]
    row = missing_countries[missing_countries.iso3 == iso].iloc[0]
    poly = row.geometry

    litpop_sub = (
        litpop_int[
            (litpop_int.lon >= add.lon.min().item() - 1)
            & (litpop_int.lon <= add.lon.max().item() + 1)
            & (litpop_int.lat >= add.lat.min().item() - 1)
            & (litpop_int.lat <= add.lat.max().item() + 1)
        ]
        .set_index(["lat", "lon"])
        .to_xarray()
    )

    plot_exposure(ax, row.country_name, litpop_sub.value, poly)