In [None]:
# Global Imports
import os
import warnings

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xagg as xa
import xarray as xr
from shapely import wkt
from shapely.geometry import mapping

warnings.filterwarnings("ignore", message=".*initial implementation of Parquet.*")

In [None]:
# Load Datasets

In [None]:
# Load CHIRPS
# Open the processed CHIRPS NetCDF file as an xarray Dataset.
chirps_ds = xr.open_dataset("../data/processed/chirps.nc")

In [None]:
# Load Terra
# Open the processed Terra NetCDF file as an xarray Dataset.
terra_ds = xr.open_dataset("../data/processed/terra.nc")

In [None]:
# Load the study-region boundary; `gdf` is the clipping geometry used by every
# overlay / clip step further down.
region_file = "../data/raw/region/region.geojson"
# NOTE(review): `crs` is forwarded to the underlying I/O engine rather than being a
# documented `read_file` argument — confirm it has the intended effect (or use
# `.set_crs(...)` / `.to_crs(...)` on the result instead).
gdf = gpd.read_file(region_file, crs="EPSG:4326")

In [None]:
# Loading GRACE dataset
# Keep only the "lwe_thickness" variable and wrap it back into a Dataset so the
# rest of the notebook can treat it like the other sources.
grace = xr.open_dataset("../data/processed/grace.nc")["lwe_thickness"].to_dataset()

In [None]:
# Generate Target Grid
# Flatten GRACE into a table (one row per lat/lon/time); the "WGS84" column
# (presumably the CRS helper variable — TODO confirm) is not needed in tabular form.
final_grace_df = grace.to_dataframe().reset_index().drop(columns=["WGS84"])
grace_coords_df = final_grace_df[["lat", "lon"]].drop_duplicates()

# Bounding box of the WHOLE region. `total_bounds` is (minx, miny, maxx, maxy)
# across every geometry, whereas `bounds.<col>[0]` only looked at the row labelled
# 0: wrong extent for a multi-feature region, KeyError if that label is absent.
lon_min, lat_min, lon_max, lat_max = gdf.total_bounds

# Regular grid with 0.008 spacing in the region's coordinate units.
# np.arange excludes the upper bound, so the last cell starts below lat_max/lon_max.
target_resolution = 0.008
target_lats = np.arange(lat_min, lat_max, target_resolution)
target_lons = np.arange(lon_min, lon_max, target_resolution)

target_ds = xr.Dataset({"lat": (["lat"], target_lats), "lon": (["lon"], target_lons)})

In [None]:
# Give the target grid the same time axis as GRACE.
target_ds = target_ds.assign_coords(time=grace.time)

In [None]:
# Inspect the target grid after attaching the GRACE time coordinate.
target_ds

In [None]:
# Number of latitude steps in the target grid.
len(target_ds.lat)

In [None]:
# Attach a placeholder "precipitation" variable so the grid has a data variable
# with (lat, lon, time) dimensions that can be clipped and written out below.
# The values are random filler; a fixed seed keeps the saved target.nc identical
# across Restart & Run All instead of changing on every execution.
placeholder_shape = (len(target_ds.lat), len(target_ds.lon), len(target_ds.time))
placeholder_rng = np.random.default_rng(0)
final_target_ds = target_ds.assign(
    precipitation=(
        ("lat", "lon", "time"),
        # Draw directly in the final shape instead of a flat vector + reshape.
        placeholder_rng.random(placeholder_shape),
    )
)

In [None]:
# Tag the placeholder variable with EPSG:4326 so rioxarray can clip it against
# the region geometry in the next step.
final_target_ds["precipitation"] = final_target_ds["precipitation"].rio.write_crs(
    "epsg:4326"
)

In [None]:
# Tell rioxarray which dimensions are x/y (they are named lon/lat here, not x/y).
# NOTE: inplace=True mutates final_target_ds, so this cell changes upstream state.
final_target_ds.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)
# Clip the grid to the region outline. Per the rioxarray docs, all_touched=True
# keeps every cell the geometry touches and drop=True trims the extent to the
# clipped area — confirm against the installed version.
clipped_target_ds = final_target_ds.rio.clip(
    gdf.geometry.apply(mapping), gdf.crs, all_touched=True, drop=True
)

In [None]:
# Persist the clipped target grid (placeholder values included) as NetCDF.
clipped_target_ds.to_netcdf("../data/processed/target.nc")

In [None]:
# Save target polygons. Generating them is very slow, so the output is written to disk for reuse.

In [None]:
# Build one polygon per pixel of the clipped target grid; xagg returns a mapping
# whose "gdf_pixels" entry holds the pixel GeoDataFrame.
target_polygons = xa.core.create_raster_polygons(clipped_target_ds)

# Write the pixel polygons to CSV without the helper "pix_idx" column.
target_pixels_for_export = target_polygons["gdf_pixels"].drop(columns=["pix_idx"])
target_pixels_for_export.to_csv("../data/processed/target_polygons.csv", index=False)

In [None]:
# Shorthand for the target pixel GeoDataFrame (still carries "pix_idx").
target_gdf = target_polygons["gdf_pixels"]

In [None]:
%%time
# Overlay the target pixels with the region outline; only the pix_idx values
# that survive are used afterwards to select whole pixels.
clip_target = target_gdf.overlay(gdf)

In [None]:
# Keep the full (uncut) polygon of every target pixel that overlaps the region.
target_pixel_in_region = target_gdf["pix_idx"].isin(clip_target["pix_idx"])
final_target_gdf = target_gdf[target_pixel_in_region]

In [None]:
# Save the region-overlapping target pixels; this CSV is read back later in the
# notebook to build the target grid.
final_target_gdf.drop(columns=["pix_idx"]).to_csv(
    "../data/processed/clipped_target_polygons.csv", index=False
)

In [None]:
# Generate GRACE Polygons
# One polygon per GRACE pixel; the pixel GeoDataFrame is under "gdf_pixels".
grace_polygons = xa.core.create_raster_polygons(grace)

In [None]:
%%time
# Overlay GRACE pixels with the region; used only to learn which pix_idx overlap it.
clip_grace = grace_polygons["gdf_pixels"].overlay(gdf)

In [None]:
# Keep the full polygon of every GRACE pixel that overlaps the region.
grace_pixels = grace_polygons["gdf_pixels"]
final_grace_gdf = grace_pixels[grace_pixels["pix_idx"].isin(clip_grace["pix_idx"])]

In [None]:
# Save the region-overlapping GRACE pixels without the helper "pix_idx" column.
final_grace_gdf.drop(columns=["pix_idx"]).to_csv(
    "../data/processed/grace_polygons.csv", index=False
)

In [None]:
# Generate Terra Polygons
# One polygon per Terra pixel; the pixel GeoDataFrame is under "gdf_pixels".
terra_polygons = xa.core.create_raster_polygons(terra_ds)

In [None]:
# Shorthand for the Terra pixel GeoDataFrame (still carries "pix_idx").
terra_poly_gdf = terra_polygons["gdf_pixels"]

In [None]:
%%time
# Overlay Terra pixels with the region; used only to learn which pix_idx overlap it.
clip_terra = terra_poly_gdf.overlay(gdf)

In [None]:
# Keep the full polygon of every Terra pixel that overlaps the region.
terra_pixel_in_region = terra_poly_gdf["pix_idx"].isin(clip_terra["pix_idx"])
final_terra_gdf = terra_poly_gdf[terra_pixel_in_region]

In [None]:
# Save the region-overlapping Terra pixels without the helper "pix_idx" column.
final_terra_gdf.drop(columns=["pix_idx"]).to_csv(
    "../data/processed/terra_polygons.csv", index=False
)

In [None]:
# Generate CHIRPS Polygons
# One polygon per CHIRPS pixel; the pixel GeoDataFrame is under "gdf_pixels".
chirps_polygons = xa.core.create_raster_polygons(chirps_ds)

In [None]:
# Shorthand for the CHIRPS pixel GeoDataFrame (still carries "pix_idx").
chirps_poly_gdf = chirps_polygons["gdf_pixels"]

In [None]:
%%time
# Overlay CHIRPS pixels with the region; used only to learn which pix_idx overlap it.
clip_chirps = chirps_poly_gdf.overlay(gdf)

In [None]:
# Keep the full polygon of every CHIRPS pixel that overlaps the region.
chirps_pixel_in_region = chirps_poly_gdf["pix_idx"].isin(clip_chirps["pix_idx"])
final_chirps_gdf = chirps_poly_gdf[chirps_pixel_in_region]

In [None]:
# Save the region-overlapping CHIRPS pixels without the helper "pix_idx" column.
final_chirps_gdf.drop(columns=["pix_idx"]).to_csv(
    "../data/processed/chirps_polygons.csv", index=False
)

In [None]:
# Number of region-overlapping pixels kept for GRACE, Terra and CHIRPS, in that order.
len(final_grace_gdf), len(final_terra_gdf), len(final_chirps_gdf)

In [None]:
# Side-by-side view of the three source grids over the study region.
# constrained_layout=True already manages the spacing, so fig.tight_layout() is
# deliberately NOT called: the two layout engines are mutually exclusive, and
# calling tight_layout() here replaces the constrained engine and emits a warning.
fig, ax = plt.subplots(
    1, 3, figsize=(20, 25), constrained_layout=True, sharex=True, sharey=True
)
grid_panels = [
    ("GRACE Grid Cells", final_grace_gdf),
    ("CHIRPS Grid Cells", final_chirps_gdf),
    ("TERRA Grid Cells", final_terra_gdf),
]
for panel_ax, (panel_title, cells_gdf) in zip(ax, grid_panels):
    cells_gdf.plot(color="none", edgecolor="blue", ax=panel_ax)
    # Region outline is drawn last so it stays visible on top of the grid lines.
    gdf.plot(ax=panel_ax, color="none", edgecolor="red")
    panel_ax.set_axis_off()
    panel_ax.set_title(panel_title)
# Uncomment to export the figure:
# plt.savefig("../reports/grid_cells.png", bbox_inches="tight")

In [None]:
# Intersect CHIRPS pixels (left) with Terra pixels (right). Columns shared by
# both frames are suffixed by the overlay: "_1" for CHIRPS, "_2" for Terra
# (e.g. lat_1/lon_1 and lat_2/lon_2, which later merges rely on).
chirps_cells = final_chirps_gdf.drop(columns="pix_idx")
terra_cells = final_terra_gdf.drop(columns="pix_idx")
first_int = chirps_cells.overlay(terra_cells, how="intersection")

In [None]:
# Inspect the CHIRPS x Terra intersection pieces.
first_int

In [None]:
# Intersect GRACE pixels (left) with the CHIRPS x Terra pieces (right). The
# result is later filtered on plain "lat"/"lon" alongside lat_1/lat_2, i.e. the
# GRACE coordinates keep their unsuffixed names.
grace_cells = final_grace_gdf.drop(columns="pix_idx")
second_int = grace_cells.overlay(first_int, how="intersection")

In [None]:
# Inspect the GRACE x CHIRPS x Terra intersection pieces.
second_int

In [None]:
# Spot check: all intersection pieces that fall inside the GRACE pixel at
# (lat 32.75, lon -89.75).
second_int[(second_int.lat == 32.75) & (second_int.lon == -89.75)]

In [None]:
# Distinct lat_2 / lon_2 values (Terra side of the overlay) inside the GRACE
# pixel at (32.75, -89.75). The pixel mask is evaluated once and reused.
pieces_in_grace_pixel = second_int[(second_int.lat == 32.75) & (second_int.lon == -89.75)]
lat_2_unique = pieces_in_grace_pixel.lat_2.unique()
lon_2_unique = pieces_in_grace_pixel.lon_2.unique()

In [None]:
# Distinct lat_1 / lon_1 values (CHIRPS side of the overlay) inside the GRACE
# pixel at (32.75, -89.75). The pixel mask is evaluated once and reused.
pieces_in_grace_pixel = second_int[(second_int.lat == 32.75) & (second_int.lon == -89.75)]
lat_1_unique = pieces_in_grace_pixel.lat_1.unique()
lon_1_unique = pieces_in_grace_pixel.lon_1.unique()

In [None]:
# First lat_2 value inside the GRACE pixel at (32.75, -89.75), shown at full
# float precision.
second_int[(second_int.lat == 32.75) & (second_int.lon == -89.75)].lat_2.iloc[0]

In [None]:
# Terra pixels whose latitude equals 32.520833333333336 exactly. This is a float
# equality test, so it only matches the stored coordinate bit-for-bit.
final_terra_gdf[final_terra_gdf.lat == 32.520833333333336]

In [None]:
# Re-display the three-way intersection for reference.
second_int

In [None]:
# Isolate the single GRACE pixel at (lat 32.75, lon -89.75) for the pairwise
# overlay checks that follow.
is_probe_lat = final_grace_gdf.lat == 32.75
is_probe_lon = final_grace_gdf.lon == -89.75
single_grace_pixel = final_grace_gdf[is_probe_lat & is_probe_lon]

In [None]:
# Overlay of the probe GRACE pixel with the Terra pixels.
grace_terra_int = single_grace_pixel.overlay(final_terra_gdf)

In [None]:
# Overlay of the probe GRACE pixel with the CHIRPS pixels.
grace_chirps_int = single_grace_pixel.overlay(final_chirps_gdf)

In [None]:
# CHIRPS x Terra overlay that keeps "pix_idx" on both sides, so the result has
# pix_idx_1 and pix_idx_2 columns for the lookups below.
chirps_terra_int = final_chirps_gdf.overlay(final_terra_gdf)

In [None]:
# Inspect the CHIRPS x Terra overlay with pixel indices retained.
chirps_terra_int

In [None]:
# Pieces whose pix_idx_1 is 115 (the "_1" columns presumably come from the left
# frame of the overlay, i.e. CHIRPS — confirm).
chirps_terra_int[chirps_terra_int.pix_idx_1 == 115]

In [None]:
# Pieces whose pix_idx_2 is 136 (the "_2" columns presumably come from the right
# frame of the overlay, i.e. Terra — confirm).
chirps_terra_int[chirps_terra_int.pix_idx_2 == 136]

In [None]:
# Quick plot of the pieces with pix_idx_2 == 136, drawn with no fill colour.
chirps_terra_int[chirps_terra_int.pix_idx_2 == 136].plot(color="none")

In [None]:
# Inspect the probe GRACE pixel x Terra overlay.
grace_terra_int

In [None]:
# Visual check for the GRACE pixel at (32.75, -89.75): the three-way intersection
# pieces (red) against the pairwise GRACE x Terra overlay (left panel, black) and
# the pairwise GRACE x CHIRPS overlay (right panel, black).
fig, ax = plt.subplots(1, 2, figsize=(10, 15))
probe_pixel_pieces = second_int[(second_int.lat == 32.75) & (second_int.lon == -89.75)]
for panel_ax, pairwise_int in zip(ax, (grace_terra_int, grace_chirps_int)):
    probe_pixel_pieces.plot(edgecolor="red", color="none", ax=panel_ax)
    pairwise_int.plot(edgecolor="black", color="none", ax=panel_ax)

In [None]:
# Generate grid for the training data
# One copy of the three-way intersection grid per distinct GRACE time value,
# each stamped with that value in a "time" column.
timestep_list = [
    second_int.assign(time=grace_time) for grace_time in final_grace_df.time.unique()
]

In [None]:
# Stack the per-time-step grids into one long table. The index is not reset, so
# index labels repeat once per time step.
grids_df = pd.concat(timestep_list)

In [None]:
# Flatten CHIRPS to a table with its coordinates as ordinary columns.
chirps_df = chirps_ds.to_dataframe().reset_index()

In [None]:
# Flatten Terra to a table with its coordinates as ordinary columns.
terra_df = terra_ds.to_dataframe().reset_index()

In [None]:
# Inspect the flattened CHIRPS table.
chirps_df

In [None]:
# Inspect the flattened GRACE table.
final_grace_df

In [None]:
# Inspect the stacked per-time-step grid.
grids_df

In [None]:
# Attach the three data sources to the grid. Every merge uses pandas' defaults:
# join on all column names the two frames share, inner join. Terra and CHIRPS
# coordinates are renamed first so they line up with the overlay's suffixed
# columns (lat_2/lon_2 for Terra, lat_1/lon_1 for CHIRPS).
terra_on_grid_keys = terra_df.rename(columns={"lat": "lat_2", "lon": "lon_2"})
chirps_on_grid_keys = chirps_df.rename(columns={"lat": "lat_1", "lon": "lon_1"})

grid_with_grace = grids_df.merge(final_grace_df)
grid_with_terra = grid_with_grace.merge(terra_on_grid_keys)
final_gdf = grid_with_terra.merge(chirps_on_grid_keys)

In [None]:
# Inspect the fully merged table.
final_gdf

In [None]:
# Columns exported to the model-input file, in output order: first the
# coordinate/time keys of the three grids, then the data variables
# ("lwe_thickness" is the GRACE variable; the rest come from the other merges).
_key_columns = ["lat_1", "lon_1", "lat_2", "lon_2", "lat", "lon", "time"]
_value_columns = [
    "lwe_thickness",
    "aet", "def", "pdsi", "pet", "pr", "srad", "ro", "soil", "swe",
    "precip",
]
final_columns = [*_key_columns, *_value_columns]

In [None]:
# Keep only the export columns (this also leaves the geometry column behind).
model_df = final_gdf[final_columns]

In [None]:
# Write the model-input table to CSV without the index.
model_df.to_csv("../data/processed/model_inputs.csv", index=False)

In [None]:
# Generate Target GRID

In [None]:
# Reload the clipped target pixels saved earlier. NOTE: this rebinds
# `target_polygons`, which previously held the xagg result, to a plain
# DataFrame whose "geometry" column is still WKT text.
target_polygons = pd.read_csv("../data/processed/clipped_target_polygons.csv")

In [None]:
# Parse the WKT strings back into shapely geometries and rebuild a GeoDataFrame
# in EPSG:4326.
parsed_target_geometry = target_polygons["geometry"].map(wkt.loads)
target_polygons = gpd.GeoDataFrame(
    target_polygons,
    geometry=parsed_target_geometry,
    crs="EPSG:4326",
)

In [None]:
# Intersect the CHIRPS x Terra pieces with the target pixels, so every target
# piece carries the lat_1/lon_1 and lat_2/lon_2 keys used for the merges below.
target_int = gpd.overlay(first_int, target_polygons, how="intersection")

In [None]:
# Save the target intersection grid (geometry included) to CSV.
target_int.to_csv("../data/processed/target_grid.csv", index=False)

In [None]:
# Generating parquet files for the target grid
# One file per distinct GRACE time value, named after the date (YYYY-MM-DD, no
# extension), holding the target grid joined to that date's Terra and CHIRPS rows.
target_grid_dir = "../data/processed/target_grid"
# to_parquet does not create missing parent directories, so on a fresh checkout
# the first write would fail; create the output folder up front.
os.makedirs(target_grid_dir, exist_ok=True)

for time_step in final_grace_df.time.unique():
    ts_df = target_int.copy()
    ts_df["time"] = time_step
    output_date_str = pd.to_datetime(time_step).strftime("%Y-%m-%d")
    terra_subset = terra_df[terra_df.time == time_step]
    chirps_subset = chirps_df[chirps_df.time == time_step]
    # Skip time steps for which Terra has no rows.
    # NOTE(review): CHIRPS is not checked; an empty chirps_subset yields an empty
    # inner merge and therefore an empty file — confirm that is intended.
    if len(terra_subset) > 0:
        # Default merges: inner join on every shared column (the renamed
        # coordinate keys plus "time").
        output_df = ts_df.merge(
            terra_subset.rename(columns={"lat": "lat_2", "lon": "lon_2"})
        ).merge(chirps_subset.rename(columns={"lat": "lat_1", "lon": "lon_1"}))
        # Source-grid keys and geometry are dropped from the saved table.
        output_df.drop(
            columns=["geometry", "lat_1", "lon_1", "lat_2", "lon_2"]
        ).to_parquet(
            os.path.join(target_grid_dir, f"{output_date_str}"),
            index=False,
            compression="gzip",
        )