
## Imports

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (such as the local `utils` module) are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import os
import hvplot.pandas
import hvplot.xarray
from utils import *
import pandas as pd
import geoviews as gv
import geoviews.feature as gf
from sklearn.metrics import accuracy_score, f1_score

## Set environment variables

In [None]:
# Root folder of the data. A value already exported in the shell takes precedence;
# the path below is only the fallback used when AA_DATA_DIR is not set.
os.environ.setdefault(
    'AA_DATA_DIR',
    "/home/daniele/Documents/CHD/Python_projects/pa-aa-toolbox-folder/"
)

# The IRI authentication token is a credential and must not be stored in the notebook.
# Export IRI_AUTH in the environment that starts Jupyter instead. Only a message is
# emitted here (no exception) so the remaining cells can still be executed.
if not os.environ.get('IRI_AUTH'):
    print(
        "WARNING: the IRI_AUTH environment variable is not set. "
        "Export it before starting Jupyter if the IRI data have to be downloaded."
    )

# Country configuration file for Chad (admin level 0)
tchad_config_file = "config/countries/tcd_adm0.yaml"

## Create country configuration

In [None]:
# create_area returns two objects for the Chad configuration file: the country
# configuration and `geobb` (presumably the geographic bounding box - confirm in utils)
country_config, geobb = create_area(
    tchad_config_file,
)

## Create IRI terciles

In [None]:
# IRI forecast: request all the terciles (only_dominant=False) and pass the result
# to calculate_dominant_tercile to obtain the dominant-tercile dataset
ds_iri_terciles = iri_terciles_create(only_dominant=False, geobb=geobb, country_config=country_config)
ds_iri_tercile_dominant = calculate_dominant_tercile(ds_iri_terciles)

## Create CHIRPS terciles

In [None]:
# CHIRPS observations: the IRI tercile dataset is passed as reference (ds_ref) together
# with adapt_coordinates=True; the result is then reduced to the dominant tercile
ds_chirps_terciles = chirps_terciles_create(
    ds_ref=ds_iri_terciles,
    adapt_coordinates=True,
    geobb=geobb,
    country_config=country_config,
)
ds_chirps_tercile_dominant = calculate_dominant_tercile(ds_chirps_terciles)

## Create ECMWF terciles

In [None]:
# ECMWF forecast: same call pattern as for CHIRPS, with the IRI tercile dataset passed
# as reference (ds_ref) and adapt_coordinates=True
ds_ecmwf_terciles = ecmwf_terciles_create(
    ds_ref=ds_iri_terciles,
    adapt_coordinates=True,
    geobb=geobb,
    country_config=country_config,
)
ds_ecmwf_tercile_dominant = calculate_dominant_tercile(ds_ecmwf_terciles)

## Adapt coordinates for IRI

In [None]:
# Keep only the IRI grid points whose latitude and longitude labels are present in the
# ECMWF dominant-tercile dataset (latitude first, then longitude)
ds_iri_tercile_dominant = (
    ds_iri_tercile_dominant
    .sel(latitude=ds_ecmwf_tercile_dominant.latitude)
    .sel(longitude=ds_ecmwf_tercile_dominant.longitude)
)

## Create datasets for visualising terciles

In [None]:
# Signed probability used for plotting the ECMWF forecast:
#   dominant tercile 'middle' -> 0
#   dominant tercile 'upper'  -> +terc_prob
#   anything else             -> -terc_prob
ds_ecmwf_vis = xr.where(
    ds_ecmwf_tercile_dominant.dominant_terc=='middle',
    0,
    xr.where(
        ds_ecmwf_tercile_dominant.dominant_terc=='upper',
        ds_ecmwf_tercile_dominant.terc_prob,
        -ds_ecmwf_tercile_dominant.terc_prob
    )
)

In [None]:
# Signed probability used for plotting the IRI forecast:
#   dominant tercile 'middle' -> 0
#   dominant tercile 'upper'  -> +terc_prob
#   anything else             -> -terc_prob
ds_iri_vis = xr.where(
    ds_iri_tercile_dominant.dominant_terc=='middle',
    0,
    xr.where(
        ds_iri_tercile_dominant.dominant_terc=='upper',
        ds_iri_tercile_dominant.terc_prob,
        -ds_iri_tercile_dominant.terc_prob
    )
)

In [None]:
# Signed probability used for plotting the CHIRPS observations:
#   dominant tercile 'middle' -> 0
#   dominant tercile 'upper'  -> +terc_prob
#   anything else             -> -terc_prob
ds_chirps_vis = xr.where(
    ds_chirps_tercile_dominant.dominant_terc=='middle',
    0,
    xr.where(
        ds_chirps_tercile_dominant.dominant_terc=='upper',
        ds_chirps_tercile_dominant.terc_prob,
        -ds_chirps_tercile_dominant.terc_prob
    )
)

## Visualize terciles

In [None]:
# Map of the signed ECMWF tercile probability, with one frame per `time` value
map_plot = ds_ecmwf_vis.hvplot(
    # Axes and animation dimension
    x='longitude', y='latitude', groupby='time',
    # Geographic rendering
    geo=True, coastline=True, features=['borders'],
    # Diverging colour map over the fixed range [-1, 1]
    cmap='BrBG', clim=(-1, 1),
    # Figure size and time widget
    width=600, height=500,
    widget_type='scrubber', widget_location='bottom',
    label='ECMWF,',
)

# Last expression of the cell: displays the plot
map_plot

In [None]:
# NOTE(review): this cell repeats the ECMWF map of the previous cell with identical
# options - confirm whether a different dataset was meant to be plotted here
map_plot = ds_ecmwf_vis.hvplot(
    # Axes and animation dimension
    x='longitude', y='latitude', groupby='time',
    # Geographic rendering
    geo=True, coastline=True, features=['borders'],
    # Diverging colour map over the fixed range [-1, 1]
    cmap='BrBG', clim=(-1, 1),
    # Figure size and time widget
    width=600, height=500,
    widget_type='scrubber', widget_location='bottom',
    label='ECMWF,',
)

# Last expression of the cell: displays the plot
map_plot

In [None]:
# Map of the signed IRI tercile probability, with one frame per `time` value
map_plot = ds_iri_vis.hvplot(
    # Axes and animation dimension
    x='longitude', y='latitude', groupby='time',
    # Geographic rendering
    geo=True, coastline=True, features=['borders'],
    # Diverging colour map over the fixed range [-1, 1]
    cmap='BrBG', clim=(-1, 1),
    # Figure size and time widget
    width=600, height=500,
    widget_type='scrubber', widget_location='bottom',
    label='IRI,',
)

# Last expression of the cell: displays the plot
map_plot

In [None]:
# Map of the signed CHIRPS tercile probability, with one frame per `time` value
map_plot = ds_chirps_vis.hvplot(
    # Axes and animation dimension
    x='longitude', y='latitude', groupby='time',
    # Geographic rendering
    geo=True, coastline=True, features=['borders'],
    # Diverging colour map over the fixed range [-1, 1]
    cmap='BrBG', clim=(-1, 1),
    # Figure size and time widget
    width=600, height=500,
    widget_type='scrubber', widget_location='bottom',
    label='CHIRPS,',
)

# Last expression of the cell: displays the plot
map_plot

In [None]:
## Calculate metrics: accuracy and F1 score

In [None]:
# Minimum number of valid (observation, forecast) pairs required at a grid point: with fewer
# pairs the two time series are not compared and the score at that point is left as NaN
tolerance = 4

# Forecast products that are scored against the CHIRPS observations
forecast_products = {
    'iri': ds_iri_tercile_dominant,
    'ecmwf': ds_ecmwf_tercile_dominant,
}

# Thresholds for tercile probabilities. The first two entries are NaN placeholders for the two
# cases that do not use a threshold. The lists do not depend on metric, product or season, so
# they are built once, outside the loops.
# NOTE(review): np.arange with a float step can return a last value close to the stop (0.9)
# because of rounding - check the generated labels if only 0.3 ... 0.8 are expected.
threshold_num_list = [np.nan, np.nan]+[t for t in np.arange(0.3, 0.9, 0.1)]
threshold_str_list = [
    'All terciles considered',
    'Only below considered: no threshold'
]+[f'Only below considered: threshold {t:.1f}' for t in threshold_num_list[2:]]

# Initiate score array (one dataArray per metric)
score_array = {}

# Loop over metrics
for metrics in ['accuracy', 'f1 score']:

    print(metrics)

    # DataArrays of this metric, one per forecast product
    scores_per_data = []

    # Loop over forecast products
    for data, forecast_full in forecast_products.items():

        print(data)

        # DataArrays of this product, one per season
        scores_per_season = []

        # Loop over relevant seasons
        for season in ['MJJ', 'JJA', 'JAS', 'ASO']:

            print(season)

            # Restrict forecast dataset to season
            forecast_dataset = forecast_full.sel(time=forecast_full['time'].str.contains(season))

            # Restrict chirps dataset to time points (season-year) present in the forecast dataset.
            # The selection always starts from the full CHIRPS dataset, so that every
            # product/season combination is independent of the previous iterations.
            chirps_dataset = ds_chirps_tercile_dominant.sel(
                time=ds_chirps_tercile_dominant['time'].isin(forecast_dataset['time'])
            )

            # Drop forecast time points without observations: masks and time series below are
            # paired by position, so both datasets need the same time points
            forecast_dataset = forecast_dataset.sel(
                time=forecast_dataset['time'].isin(chirps_dataset['time'])
            )

            # Create masks for chirps and forecast (NaN values are replaced by -999)
            mask_chirps = np.asarray(chirps_dataset.terc_prob.fillna(-999))
            mask_forecast = forecast_dataset.terc_prob.fillna(-999)

            # True where both observation and forecast have a probability value
            valid = (mask_chirps!=-999) & (mask_forecast!=-999)

            # DataArrays of this season, one per threshold
            scores_per_threshold = []

            # Loop over thresholds
            for t, (threshold, threshold_label) in enumerate(zip(threshold_num_list, threshold_str_list)):

                # 'All terciles considered': the dominant tercile labels of forecast and
                # observation are compared as they are
                if t==0:
                    forecast = forecast_dataset.dominant_terc
                    chirps = chirps_dataset.dominant_terc
                # 'Only below considered: no threshold': every label different from 'lower' is
                # replaced by 'not_lower' in both forecast and observation, so the comparison
                # is between two classes only
                elif t==1:
                    forecast = forecast_dataset.dominant_terc.where(
                        forecast_dataset.dominant_terc=='lower',
                        'not_lower'
                    )
                    chirps = chirps_dataset.dominant_terc.where(
                        chirps_dataset.dominant_terc=='lower',
                        'not_lower'
                    )
                # 'Only below considered: threshold ...': binary comparison
                # - forecast: 1 if the dominant tercile is 'lower' and its probability is higher
                #             than the threshold, 0 otherwise
                # - observation: 1 if the dominant tercile is 'lower', 0 otherwise (the threshold
                #                is applied to the forecast only)
                else:
                    forecast = xr.where(
                        (forecast_dataset.terc_prob>threshold) & (forecast_dataset.dominant_terc=='lower'),
                        1,
                        0
                    )
                    chirps = xr.where(
                        chirps_dataset.dominant_terc=='lower',
                        1,
                        0
                    )

                # Reapply the mask previously calculated
                forecast = forecast.where(valid)
                chirps = chirps.where(valid)

                # Create score array: one value per grid point, NaN until a score is computed.
                # The first dimension of the data is assumed to be time, followed by latitude
                # and longitude.
                score = np.full(np.shape(chirps)[1:], np.nan)

                # Loop over latitude and longitude
                for i in range(np.shape(chirps)[1]):
                    for j in range(np.shape(chirps)[2]):

                        # Read values of lon and lat
                        latitude = forecast.latitude.values[i]
                        longitude = forecast.longitude.values[j]

                        # Create time series for predicted (forecast) and true (observations)
                        y_pred = forecast.sel(latitude=latitude, longitude=longitude).values
                        y_true = chirps.sel(latitude=latitude, longitude=longitude).values

                        # Pair time series, and exclude nan values (this is done to keep only
                        # non-nan values and return nan in case there are not many comparisons)
                        y_pair = [
                            (x, y) for (x, y) in zip(list(y_true), list(y_pred))
                            if not pd.isna(x) and not pd.isna(y)
                        ]

                        # If the number of pairs is lower than the tolerance, the score stays NaN
                        if len(y_pair)<tolerance:
                            continue

                        # Re-extract time series
                        y_true = [p[0] for p in y_pair]
                        y_pred = [p[1] for p in y_pair]

                        # Calculate metrics
                        if metrics == 'accuracy':
                            score[i, j] = accuracy_score(y_true, y_pred)
                        else:
                            score[i, j] = f1_score(y_true, y_pred, average='weighted')

                # Initialise score dataArray (same grid as the observations, without time)
                score_thre = chirps.isel(time=0).drop('time').copy()

                # Assign score array to variable of the dataArray
                score_thre.data = score.copy()

                # Expand dimensions of the dataArray, to include other coordinates. This is done
                # to be able to create the final dataArray
                score_thre = score_thre.expand_dims(['season', 'threshold', 'data'])\
                                       .assign_coords({
                    'season': [season],
                    'threshold': [threshold_label],
                    'data': [data],
                })

                scores_per_threshold.append(score_thre)

            # Concatenate the dataArrays of all thresholds for this season
            score_season = xr.concat(scores_per_threshold, dim='threshold')
            scores_per_season.append(score_season)

        # Concatenate the dataArrays of all seasons for this product
        score_data = xr.concat(scores_per_season, dim='season')
        scores_per_data.append(score_data)

    # Concatenate the dataArrays of all products for this metric
    score_array[metrics] = xr.concat(scores_per_data, dim='data')

# Transform dict of dataArrays in dataset, with two variables
score_array_total = xr.Dataset(score_array)

## Plot dataset with metrics

In [None]:
# Draw one map per forecast product, metric and season. `threshold` is kept as a key
# dimension of the geoviews dataset, next to longitude and latitude.
for data in ['ECMWF', 'IRI']:
    for score in ['Accuracy', 'F1 score']:
        for season in score_array_total.season.values.tolist():
            # Scores of the current season and product (coordinates are stored in lower case)
            selection = score_array_total.sel(season=season, data=data.lower())
            selection = selection.drop(['season', 'data'])

            # Label attached to the geoviews dataset
            title = f'Score: {score}, Dataset: {data},\nSeason: {season},'

            dataset = gv.Dataset(
                data=selection,
                kdims=['longitude', 'latitude', 'threshold'],
                vdims=score.lower(),
                label=title,
                crs=ccrs.PlateCarree()
            )
            images = dataset.to(gv.Image)

            # Overlay the country borders on the styled image and show the result
            styled = images.opts(cmap='viridis', colorbar=True, width=600, height=500)
            display(styled * gf.borders)

## TO-DO

https://docs.google.com/drawings/d/1CB504-oanD6T2KRVrC6-pwj8cHT_ygc300PKrawQYCw/edit

- Check code for accuracy and F1 score
- Create polygon based on aggregation (Sahel belt)
- Aggregate xarrays based on polygon
- Produce plot
- Add other lead times