#### IMPORTS

In [None]:
import numpy as np
import pandas as pd
import glob
import os

import geopandas as gpd
import xarray as xr
import rioxarray as rxr
from pyproj import CRS
from shapely.geometry import mapping
from rasterio.enums import Resampling


import climate_indices
from climate_indices import indices

#### SETTING PARAMETERS

In [None]:
# ---------------------------------------------------------------------------
# SPEI parameters
# ---------------------------------------------------------------------------
# Accumulation scale passed to the SPEI routine.
scale = 3
# Distribution used for fitting (fixed for this workflow).
distribution = indices.Distribution.gamma
# First year of the input series and calibration window bounds.
data_start_year = 1980
calibration_year_initial = 1980
calibration_year_final = 2023
# Temporal resolution of the inputs (fixed for this workflow).
periodicity = climate_indices.compute.Periodicity.monthly

# ---------------------------------------------------------------------------
# Coordinate reference system
# ---------------------------------------------------------------------------
crs_project = CRS.from_epsg(4326)  # WGS84

# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------
ERA5_input_path = 'Insert path to monthly ERA5_land data'
ERA5_daily_input_folder = 'Insert path to daily ERA5_land data'

# ---------------------------------------------------------------------------
# Outputs
# ---------------------------------------------------------------------------
path_out = 'Insert path to the output directory'
SPEI_ouput_file = 'Insert output file name'

#### MASK FUNCTIONS

In [None]:
# Load the shapefile
def load_shape_file(filepath):
    """Read a shapefile that will later be used to mask a grid.

    Args:
        filepath: Path to the *.shp file.

    Returns:
        Whatever `gpd.read_file` returns for `filepath`.
    """
    loaded = gpd.read_file(filepath)
    # Same two-line message as before, spelled out with an explicit newline.
    print(
        "Shapefile loaded. To prepare for masking, run the function\n"
        "        `select_shape`."
    )
    return loaded

#Create the mask
def select_shape(shpfile, col_code='ISO3_CODE', country_codes=('ZAF', 'LSO', 'SWZ')):
    """Build a single-polygon mask from the selected countries of a shapefile.

    Args:
        shpfile: GeoDataFrame loaded through `load_shape_file`.
        col_code: (str) name of the column holding the country codes.
            (Run print(shpfile) to see the available columns.)
        country_codes: iterable of codes to keep from `col_code`.

    Returns:
        GeoDataFrame with one row whose geometry is the union of the
        selected shapes. No CRS is attached here; the caller assigns it.

    Raises:
        ValueError: if no row of `shpfile` matches `country_codes`.
    """
    # Keep only the rows whose code is one of the requested countries
    selected_rows = shpfile[shpfile[col_code].isin(list(country_codes))]
    if selected_rows.empty:
        raise ValueError(
            f"No shapes found in column '{col_code}' for codes {list(country_codes)}"
        )

    # Combine the selected polygons into a single geometry.
    # `unary_union` is deprecated from GeoPandas 1.0 in favour of `union_all()`;
    # fall back to the old attribute on older installations.
    geometries = selected_rows.geometry
    if hasattr(geometries, 'union_all'):
        unioned_polygon = geometries.union_all()
    else:
        unioned_polygon = geometries.unary_union

    # Wrap the unioned geometry in a one-row GeoDataFrame
    mask_polygon = gpd.GeoDataFrame(geometry=[unioned_polygon])

    print("""Mask created.""")

    return mask_polygon

#### MASK LAYER

In [None]:
# Load the boundaries shapefile
shpfile = load_shape_file('Insert path to the shp CNTR_RG_01M_2020_4326.shp') #Boundaries

# Create the mask layer (union of the countries selected in `select_shape`)
mask_layer = select_shape(shpfile)

# Assign the project CRS to the mask; `select_shape` builds it without one
mask_layer.crs = crs_project

#### VARIABLE PROCESSING

#### ERA5 TEMPERATURE DATA

In [None]:
# glob.glob returns files in arbitrary order; sort so the files are always
# processed (and later concatenated along time) in a deterministic order.
daily_temp_files = sorted(glob.glob(os.path.join(ERA5_daily_input_folder, '*.nc')))

##### Temperatures

In [None]:
def getTemperaturesData(temp_data):
    """Compute the masked monthly mean temperature used as input to the PET/SPEI step.

    Only the mean temperature is produced; no Tmax/Tmin is computed here.

    Args:
        temp_data: object opened from a temperature netCDF file. Its `attrs`
            must contain 'scale_factor' and 'add_offset', and it must have
            `time` and `y` coordinates.
            NOTE(review): presumably hourly ERA5-Land 2 m temperature in
            Kelvin - confirm against the input files.

    Returns:
        Monthly mean temperature (after subtracting 273.15), with the `y`
        order reversed, clipped to the module-level `mask_layer` and squeezed.
    """
    # Unpack the stored values using the packing attributes
    scale_factor =  temp_data.attrs['scale_factor']
    offset = temp_data.attrs['add_offset']
    temp_data = (temp_data * scale_factor) + offset #Rescaling the values

    # Daily mean; 273.15 is subtracted (presumably Kelvin -> Celsius)
    temp_tmean = temp_data.resample(time ='D').mean()-273.15

    # Resample the daily means into monthly means

    temp_tmean = temp_tmean.resample(time ='M').mean()

    # Monthly resample of the unpacked data, used only as a container with
    # the monthly structure
    temperatures = temp_data.resample(time='M').mean()

    # Attach the monthly mean under the key 'tmean'
    temperatures['tmean'] = temp_tmean

    # Pull the 'tmean' entry back out; the cell that consumes this function
    # later calls drop_vars('tmean') on the concatenated result
    tmean = temperatures['tmean']
    tmean = tmean.reindex(y=list(reversed(tmean['y']))) # Reverse the order of the Y values (the original note says ERA5 and other climate datasets need this to get increasing values)
    tmean.rio.write_crs(crs_project, inplace=True)

    # Clip to the country mask (module-level `mask_layer`)
    tmean_masked = tmean.rio.clip(mask_layer.geometry.apply(mapping), crs=mask_layer.crs, all_touched=True, from_disk=True).squeeze()

    return tmean_masked

In [None]:
tmean_list = []

# Apply the function to each temperature file and collect the masked
# monthly mean temperature of every file
for file in daily_temp_files:
    data = rxr.open_rasterio(file, masked=True)
    tmean_masked = getTemperaturesData(data)
    tmean_list.append(tmean_masked)

# Fail early with a clear message instead of the opaque error xr.concat
# raises on an empty list
if not tmean_list:
    raise ValueError(
        f"No '*.nc' temperature files found in '{ERA5_daily_input_folder}'"
    )

# Build a single array along the time dimension
Tmean = xr.concat(tmean_list, dim='time')

# Change the time format
Tmean['time'] = Tmean['time'].astype('datetime64[ns]')

# The files may not have been read in chronological order; make sure the
# series is ordered in time before it is used downstream
Tmean = Tmean.sortby('time')

# Cleaning: remove the helper 'tmean' entry attached by getTemperaturesData
Tmean = Tmean.drop_vars('tmean')

#### ERA5 DATA

In [None]:
#Loading the data
# Load the monthly ERA5-Land data
data = rxr.open_rasterio(ERA5_input_path, masked=True)
# Assign the project CRS
data.rio.write_crs(crs_project, inplace=True)
# Convert to datetime so the time axis is comparable with the temperature data
data['time'] = data['time'].astype('datetime64[ns]')

# Move every timestamp to the last day of its month. MonthEnd(0) rolls
# forward to the month end but leaves dates already at month end untouched,
# whereas MonthEnd(1) would push those into the following month.
data = data.assign_coords(time=pd.to_datetime(data.time.values) + pd.offsets.MonthEnd(0))

##### Precipitation

In [None]:
#Extract and give shape to the precipitation data
def getPrecipData(precip_data):
    """Prepare precipitation as input for the SPEI function.

    Args:
        precip_data: precipitation variable read from the netCDF data. Its
            `attrs` must hold 'scale_factor' and 'add_offset', and it must
            have `time`, `y` and `x` coordinates.

    Returns:
        DataArrayGroupBy grouped over `point` (stacked y and x coordinates).
    """
    # Days in each month, used to turn a daily mean into a monthly total
    days_per_month = precip_data.time.dt.days_in_month

    # Unpack the stored values with the packing attributes
    gain = precip_data.attrs['scale_factor']
    bias = precip_data.attrs['add_offset']
    unpacked = (precip_data * gain) + bias

    # x1000 (the original note states metres -> millimetres), then scale by
    # the number of days in the month
    monthly_mm = unpacked * 1000 * days_per_month

    # Reverse the order of the Y values; done through a temporary lat/lon
    # renaming, which the original author marked as a necessary step
    renamed = monthly_mm.rename({'y': 'lat', 'x': 'lon'})
    reversed_lat = list(reversed(renamed['lat']))
    flipped = renamed.reindex(lat=reversed_lat).rename({'lat': 'y', 'lon': 'x'})

    # Clip to the country mask (module-level `mask_layer`)
    shapes = mask_layer.geometry.apply(mapping)
    clipped = flipped.rio.clip(
        shapes,
        crs=mask_layer.crs,
        all_touched=True,
        from_disk=True,
    ).squeeze()

    # One group per grid point so the index can be computed point by point
    precip_grouped = clipped.stack(point=('y', 'x')).groupby('point')
    print("""Precipitation data is prepared to serve
        as input for the SPEI index.""")

    return precip_grouped


In [None]:
#Get precipitation data
# Select the 'tp' variable from the monthly dataset and prepare it
# (rescaled, clipped to the mask and grouped by grid point) for the SPEI step
precip_data = data['tp']
precips_mm = getPrecipData(precip_data)


##### Mean daylight hours (N) for the 15th of each month at latitude 30°S, used as an approximate average for the study area
###### https://www.fao.org/3/x0490e/x0490e0j.htm#annex%202.%20meteorological%20tables


In [None]:
# Mean daylight hours per month, keyed by zero-padded month number
# ("01" = January ... "12" = December).
mdh = dict(zip(
    (f"{month:02d}" for month in range(1, 13)),
    (13.7, 13, 12.2, 11.3, 10.5, 10.1, 10.2, 10.9, 11.8, 12.7, 13.5, 13.9),
))

In [None]:
# Sum of the twelve monthly mean daylight values
tota_hours_year = sum(mdh.values())
# Share of that sum contributed by each month, expressed as a percentage
mdh_perc = dict(
    (key, (value / tota_hours_year) * 100)
    for key, value in mdh.items()
)


In [None]:
# Display the monthly daylight percentages
mdh_perc

In [None]:
# Zero-padded month number ('01'..'12') of every time step in Tmean
# ("meses" = months)
meses = Tmean['time'].dt.strftime('%m').values
# Display the result
meses

In [None]:
# Daylight percentage matching the month of every time step
# ("array_resultado" = result array)
array_resultado = np.array([mdh_perc[mes] for mes in meses])
# Display the result
array_resultado

#### Calculating PET

In [None]:
# ET0 function BLANEY-CRIDDLE EQUATION
def get_pet_mm(Tmean, mdh_perc):
    """Estimate PET with the Blaney-Criddle equation.

    Args:
        Tmean: mean temperature with a `time` coordinate.
        mdh_perc: mapping from zero-padded month number ('01'..'12') to the
            daylight percentage of that month.

    Returns:
        `ro * (0.46 * Tmean + 8.13)`, where `ro` is the daylight percentage
        of the month of each time step. The result is also printed.
    """
    # Month key of every time step, used to look up the daylight percentage
    month_codes = Tmean['time'].dt.strftime('%m').values
    daylight_share = xr.DataArray(
        [mdh_perc[code] for code in month_codes],
        dims=('time',),
    )

    temperature_term = (0.46 * Tmean) + 8.13
    pet_mm = daylight_share * temperature_term
    print(pet_mm)

    return pet_mm

In [None]:
# Compute PET from the mean temperature and the monthly daylight percentages
pet_mm = get_pet_mm(Tmean, mdh_perc)

# Stack y/x into a single 'point' dimension and group by it, so PET has the
# same per-point layout as the precipitation input
pet_mm_grouped = pet_mm.stack({'point': ('y', 'x')}).groupby('point')

#### APPLY SPEI FUNCTION

In [None]:
#####https://github.com/monocongo/climate_indices
spei_values = xr.apply_ufunc(indices.spei,
                            precips_mm,
                            pet_mm_grouped, 
                            scale,
                            distribution,
                            periodicity,
                            data_start_year,
                            calibration_year_initial,
                            calibration_year_final
                            )                 

# Unstack the array back into original dimensions
spei_results = spei_values.unstack('point')         


In [None]:
# Assign the project CRS (same one used for the inputs, instead of a
# hard-coded EPSG string) and reproject to match the source data grid
spei_results = spei_results.rio.write_crs(crs_project)
spei_results = spei_results.rio.reproject_match(data, resampling = Resampling.bilinear, nodata=np.nan)

#### EXPORTING

In [None]:
# Join directory and file name with the OS separator; plain concatenation
# produced a wrong path whenever `path_out` had no trailing slash
spei_results.to_netcdf(os.path.join(path_out, SPEI_ouput_file))