In [None]:
import os

# Point PROJ_LIB at the share/proj directory of the 'gp' mambaforge environment.
# This is done before the geospatial imports below to avoid a basemap import error.
os.environ.update(PROJ_LIB='/home/users/jdha/mambaforge/envs/gp/share/proj')

## Extracting timeseries from model data using Shapefiles

An example notebook to show the extraction and plotting of a timeseries using polygons extracted from a shapefile.

In [None]:
import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.io.shapereader as shpreader
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask 
import xarray as xr
import datetime

from xarray import DataArray, Dataset
from shapely.geometry import Point
from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar, Union, cast
from matplotlib import cm
from matplotlib.axes import Axes
from mpl_toolkits.axes_grid1 import make_axes_locatable

# render matplotlib figures inline in the notebook (the `%matplotlib inline` line magic)
get_ipython().run_line_magic("matplotlib", "inline")
# ask the inline backend for the 'retina' figure format (the `%config` line magic)
get_ipython().run_line_magic("config", "InlineBackend.figure_format = 'retina'")

# generic type variable bound to any callable; it is not referenced by the code shown in this notebook
F = TypeVar('F', bound=Callable[..., Any])

### Functions

In [None]:
def shp_extract(
    shp_gdf: gpd.GeoDataFrame,
    ds: Dataset,
    shp_name: str='LME_NAME',
    lat_name: str='nav_lat',
    lon_name: str='nav_lon',
) -> Dataset:
    
    """
    Label every grid point of a Dataset with the shapefile polygon that
    contains it, and store the labels in the Dataset.
    Parameters
    ----------
    shp_gdf: GeoDataFrame
        GeoDataFrame created from a .shp file
    ds: DataSet
        xarray dataset containing coordinate information; the longitude
        and latitude arrays must have the same shape
    lat_name, lon_name: str
        name of latitude and longitude coordinates
    shp_name: str
        name of the shp_gdf column holding the polygon names; also used
        as the name of the variable added to ds
    Returns
    -------
    Dataset
        the input dataset (modified in place) with a new variable
        ``shp_name`` on the dimensions of the latitude array, holding the
        polygon name for each grid point (missing where a point lies in
        no polygon)
    Raises
    ------
    KeyError
        if shp_gdf has no column called ``shp_name``
    Notes
    -----
    Tested only with LME data.
    The grid points are treated as epsg:4326 longitude/latitude values.
    """    
    
    # fail early, before the (potentially slow) spatial join
    if shp_name not in shp_gdf.columns:
        raise KeyError(f"'{shp_name}' is not a column of the shapefile GeoDataFrame")

    lat = ds[lat_name]

    # flatten the longitude-latitude coordinates
    x_rav = ds[lon_name].values.ravel()
    y_rav = lat.values.ravel()
    
    # build the point geometries in one vectorised call and
    # match the points to the polygons in the shp file
    pnts = gpd.GeoDataFrame(
        {'lon': x_rav, 'lat': y_rav},
        geometry=gpd.points_from_xy(x_rav, y_rav),
        crs='epsg:4326',
    )
    rois = gpd.tools.sjoin(pnts, shp_gdf, predicate='within', how='left')
    
    # a point lying inside overlapping polygons produces one row per match:
    # keep the first match only and restore the original point order so that
    # the labels can be reshaped back onto the grid
    labels = rois[shp_name]
    labels = labels[~labels.index.duplicated(keep='first')].reindex(pnts.index)
    
    # add a new variable containing the gridded polygon names, using the
    # dimensions of the latitude array rather than assuming ('y', 'x')
    ds[shp_name] = (lat.dims, labels.values.reshape(lat.shape))
   
    return ds

# --------------------------------------------------------------------------

def shp_plot(
    ax: Axes, 
    shp_gdf: gpd.GeoDataFrame, 
    shp_name: str, 
    roi_list: list, 
    projection: ccrs.Projection, 
) -> Axes:
    
    """
    Generalised function to plot the extent of the polygons 
    in a given shp file.
    Polygons named in roi_list are filled with a colour and annotated
    with their row position in shp_gdf; all others are drawn in light grey.
    Parameters
    ----------
    ax: plt.ax
        figure axes handle
    shp_gdf: GeoDataFrame
        GeoDataFrame created from a .shp file
    shp_name: str
        name of the shp_gdf column holding the polygon names
    roi_list: list
        list of str containing the names of the 
        polygons required for plotting
    projection: ccrs.Projection
        passed to ax.add_geometries as the coordinate system of the polygons,
        for example: ccrs.PlateCarree(central_longitude=0)
    Returns
    -------
    plt.ax
        updated axes handle
    Notes
    -----
    Tested only with LME data.
    """     
    
    # adding land features
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.COASTLINE)
   
    # one colour per shapefile entry, indexed by row position
    palette = cm.tab20b(np.linspace(0, 1, len(shp_gdf)))

    # loop over polygon entries by row position, so that entries sharing a
    # name are each drawn exactly once
    for counter, name in enumerate(shp_gdf[shp_name].values):
        
        # select polygon (kept as a one-row GeoDataFrame)
        shp_gdf_sel = shp_gdf.iloc[[counter]]
        
        # is it a roi?
        if name in roi_list:

            # add the geometry and fill it with color
            ax.add_geometries(shp_gdf_sel['geometry'], 
                              projection,
                              facecolor=palette[counter], 
                              edgecolor='k')

            # centroid computed once in an equal-area projection, then
            # converted back to the CRS of the shapefile
            centre = shp_gdf_sel.to_crs('+proj=cea').centroid.to_crs(shp_gdf_sel.crs)

            # NOTE(review): the label is placed in plain data coordinates, so it
            # presumably only lines up when the axes use the same coordinates as
            # the shapefile - confirm for other map projections
            ax.annotate(text=str(counter), 
                        xy=(float(centre.x.iloc[0]), float(centre.y.iloc[0])), 
                        color='white',
                        fontsize=10)
        else:
            ax.add_geometries(shp_gdf_sel['geometry'], 
                              projection,
                              facecolor='LightGray', 
                              edgecolor='k')

    return ax

# --------------------------------------------------------------------------

### Setup

Define a few constants, open the target Dataset and extract the shape information.

In [None]:
# constants
data_dir = '/gws/nopw/j04/class_vol2/senemo/jdha/FINAL_TESTING/'
exp      = 'EXP_MES_WAV_DJC_NTM_TDISSx2'
year_st  = 1980          # first year of the extraction period
year_en  = 1989          # last year of the extraction period
shp_name = 'LME_NAME'    # shapefile column holding the polygon names
var_name = 'thetao_con'  # Dataset variable to extract
roi_list = ['Indonesian Sea',]

# open Dataset (all files matching the pattern are combined into one Dataset)
ds       = xr.open_mfdataset(data_dir+exp+'/SENEMO_1m_*_grid_T_*.nc')

# open the shapefile
shp_file = '/home/users/jdha/shapefiles/LME66/LMEs66.shp'
shp_gdf  = gpd.GeoDataFrame.from_file(shp_file)

# extract shape polygons and add to Dataset; shp_name is passed explicitly so
# the variable created in ds always matches the constant defined above
ds = shp_extract(shp_gdf, ds, shp_name=shp_name)

### Extract the data over the given time period

It is not yet clear whether the following would improve the efficiency of the extraction:
```
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    roi_var = ...
```
Depending on the time slice this is done over, the extraction can take several minutes - could the use of Dask be optimised?

**NB the following has 'Indonesian Sea' hard coded - but it would be easy enough to loop over `roi_list`**

In [None]:
# extract data for the polygon
# - index 0 is taken along the second dimension (presumably the surface level - confirm)
# - the time period is selected before masking, so the polygon mask is only
#   applied to the years that are actually needed (the result is unchanged)
time_window = slice(str(year_st)+"-01-01", str(year_en)+"-12-31")
roi_mask    = ds[shp_name] == 'Indonesian Sea'
roi_var     = ds[var_name][:,0,:,:].sel(time_counter=time_window).where(roi_mask, drop=True)

In [None]:
# to speed future operations load the data into memory now - could create a new Dataset at this point
# the mean and std are computed in a single dask call so the masked field is
# only evaluated once instead of once per statistic
roi_var_mean, roi_var_std = dask.compute(
    roi_var.mean(dim=["x","y"]),  # unweighted mean over the grid points of the region
    roi_var.std(dim=["x","y"]),   # unweighted std over the grid points of the region
)
roi_var_clim = roi_var_mean.groupby('time_counter.month').mean('time_counter',keep_attrs=True) # monthly climatology
roi_var_anom = (roi_var_mean.groupby('time_counter.month') - roi_var_clim) # monthly anomaly

In [None]:
# create timeseries from data
average_period = "Y" # annual averaging (NOTE: newer pandas versions spell this alias "YE")

# offset the labels by a year for plot labelling
# (the `loffset` argument of resample is deprecated/removed in recent xarray,
#  so the same shift is applied to the time index after resampling)
label_shift = pd.offsets.YearEnd(-1)

def _shift_labels(da: DataArray) -> DataArray:
    """Return `da` with its time_counter labels moved by `label_shift`."""
    return da.assign_coords(time_counter=da.indexes['time_counter'] + label_shift)

roi_rs      = roi_var_mean.resample(time_counter=average_period)
roi_ts      = _shift_labels(roi_rs.mean())
roi_rs_std  = roi_var_std.resample(time_counter=average_period)
roi_ts_std  = _shift_labels(roi_rs_std.mean())

roi_rs_anom = roi_var_anom.resample(time_counter=average_period)
roi_ts_anom = _shift_labels(roi_rs_anom.mean())

### Plot up the timeseries along with a map of the ROI

In [None]:
# setup the figure panels
fig = plt.figure(figsize=(12, 8))
projection=ccrs.PlateCarree(central_longitude=0)
roi_name = 'Indonesian Sea' # must match the region extracted into roi_var

# top panel: map of the shapefile polygons with the ROI highlighted
ax1 = fig.add_subplot(211, projection=projection)
ax1 = shp_plot(ax1, shp_gdf, shp_name, [roi_name], projection)
ax1.set_title(roi_name)

# bottom panels: ax2 (left) and ax3 (right, created next to ax2 by the divider)
ax2 = fig.add_subplot(212)
divider = make_axes_locatable(ax2)
ax3 = divider.new_horizontal(size="100%", pad=0.6, axes_class=plt.Axes)
fig.add_axes(ax3)

time = roi_ts.indexes['time_counter'].to_pydatetime()

# timeseries with a +/- one (spatial) standard deviation envelope
ax2.fill_between(time, roi_ts-roi_ts_std, roi_ts+roi_ts_std, color='lightgrey', alpha=0.2)
ax2.plot(time,roi_ts,marker='o',linewidth=1.0)
ax2.set_title("SST [C]")
                     
# anomaly timeseries with the same envelope
ax3.fill_between(time, roi_ts_anom-roi_ts_std, roi_ts_anom+roi_ts_std, color='lightgrey', alpha=0.2,label=None)
ax3.plot(time,roi_ts_anom,marker='o',linewidth=1.0, label='NEMO')
ax3.set_title("SST anomaly [C]")
ax3.legend(); # show the 'NEMO' label; the trailing ';' suppresses the repr output