# Making a diagnostic
This notebook aims to help you make a new diagnostic function.

In [20]:
from pathlib import Path
from datatree import DataTree
import xarray as xr
import valenspy as vp
import pandas as pd

# Demo data provided by valenspy (already CF-compliant) and the xarray Dataset opened from it
demo_data = vp.demo_data_CF
demo_ds = xr.open_dataset(demo_data)

## Demo data
The demo dataset already follows the CF conventions, so no input conversion or preprocessing is needed.

In [4]:
# Check that the demo data passes valenspy's CF-compliance check; the expected output is True.
vp.cf_checks.is_cf_compliant(demo_data)

True

In [5]:
# Display the demo dataset to inspect its dimensions, coordinates and variables.
demo_ds

## Let's make a new diagnostic

To make a new diagnostic, you need to define two functions:
1. A `diagnostic` function that takes a dataset and returns a diagnostic value.
2. A `diagnostic_plot` function that takes the output of the `diagnostic` function and returns a plot.

There are several types of diagnostics; each type expects different inputs and produces different outputs.

### Model2Ref

This type compares a single model to a reference. Therefore, a Model2Ref diagnostic expects the following inputs:
- `data`: **xarray dataset** of the model data
- `ref`: **xarray dataset** of the reference data

The diagnostic function returns the results (these can be any type of data), and the plot function returns a plot.

#### Let's take a look at the example below

In [50]:

#Diagnostic functions to calculate the area average warming (tas) compared to a reference period and get the time of crossing a certain warming level
def warming_levels(ds, ref, levels=[1.5, 2.0], rol_years=21, freq=None):
    """
    Calculate the crossing times for different warming levels - the time when the area average warming crosses a certain level compared to the reference period

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the model data
    ref : xarray.Dataset
        Dataset containing the reference period
    levels : list
        List of warming levels to get crossing times for
    rol_years : int
        Number of years to use for the rolling mean
    freq : str, optional
        Frequency of the data (following pandas conventions), default is None and will be inferred from the data
    
    Returns
    -------
    crossingtimes : dict
        Dictionary containing the crossing times for each warming level
    temp_warming : xarray.DataArray
        DataArray containing the area average warming compared to the reference period
    """
    if not freq:
        freq = ds.time.to_index().inferred_freq
    ref_temp = ref.tas.mean()
    if not freq:
        #Throw an error
        return None
    else:
        freq = freq_to_times_per_year(freq)
        warming_ds = temp_warming(ds, ref, rol_amount=rol_years*freq)
        crossingtimes = {level : warming_ds.where(warming_ds>level, drop=True).idxmin('time').astype('datetime64[ns]').values for level in levels}
        return crossingtimes, temp_warming

#Small helper function (See diagnostic functions for other already implemented helper functions)
def freq_to_times_per_year(inferred_freq):
    freq_mapping = {'D': 365, 'B': 260, 'W': 52, 'M': 12, 'Q': 4, 'A': 1}
    return freq_mapping.get(inferred_freq, None)


def temp_warming(ds, ref, rol_amount):
    """
    Calculate the area average warming compared to the reference period

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the model data
    ref : xarray.Dataset
        Dataset containing the reference period
    rol_years : int
        Number of years to use for the rolling mean
    
    Returns
    -------
    temp_warming : xarray.DataArray
        DataArray containing the area average warming compared to the reference period
    """
    ref_temp = ref.tas.mean()
    rol = ds.tas.rolling(time=rol_amount).mean().mean(dim=['lat','lon'])
    temp_warming = rol - ref_temp

    return temp_warming

In [51]:
# Unpack into names that do not shadow the `temp_warming` helper function defined above.
crossing_times, area_warming = warming_levels(demo_ds, demo_ds, levels=[1.5, 2.0], rol_years=1, freq='M')
crossing_times

ValueError: All-NaN slice encountered

In [None]:
# Scratch check: 12-step rolling mean of tas, then averaged over lat/lon (no reference subtracted).
(
    demo_ds['tas']
    .rolling(time=12)
    .mean()
    .mean(dim=['lat', 'lon'])
)

NameError: name 'rol_amount' is not defined

In [28]:
# 12-step rolling mean of the demo tas with centred windows
# (each window is labelled at its middle time step rather than its last).
demo_ds.tas.rolling(time=12, center=True).mean()

### Write your own Model2Ref diagnostic and plotting function

In [None]:
# Note that ds and ref are expected to be xarray datasets; add additional arguments to your function if needed
def your_diagnostic_function(ds: xr.Dataset, ref: xr.Dataset) -> DataTree:  # Replace DataTree with your expected output type!
    """Template Model2Ref diagnostic: compute a diagnostic of the model data `ds` against the reference `ref`."""
    pass  # Replace this with your code

def your_diagnostic_plotting_function(result=None):
    """Template plotting function: take the output of your diagnostic function (`result`) and return a plot."""
    pass  # Replace this with your plotting code

Test that the functions perform as expected on the demo data.

In [None]:
# Model2Ref call on the demo data; the demo data doubles as its own reference here.
your_diagnostic_function(ds=demo_ds, ref=demo_ds)

NameError: name 'your_diagnostic_function' is not defined

## Finally make the diagnostic

In [None]:
# Wrap a diagnostic in a vp.Diagnostic. As in the commented template line below, the positional
# arguments are: diagnostic function, plotting function, name, description.
# NOTE(review): example_diag passes None as its plotting function — confirm how vp.Diagnostic handles that.
#your_diag = vp.Diagnostic(your_diagnostic_function, your_diagnostic_plotting_function, 'Your diagnostic name', 'Your diagnostic description')
example_diag = vp.Diagnostic(warming_levels, None, 'Warming levels', 'Calculate the crossing times for different warming levels')

Apply it

In [None]:
# Apply the diagnostic; the demo data is passed both as model data and as reference.
# For progress output on dask-backed data, import `ProgressBar` from `dask.diagnostics`
# and wrap this call in `with ProgressBar():`.
#result = your_diag.apply(demo_ds, demo_ds)
result = example_diag.apply(demo_ds, demo_ds)

NameError: name 'your_diag' is not defined

Plot it

In [None]:
# Plot the result using the diagnostic's plotting function.
# NOTE(review): example_diag was created with None as its plotting function, so this call may not
# produce a plot — confirm the behaviour of vp.Diagnostic.plot in that case.
#your_diag.plot(result)
example_diag.plot(result)

## Congratulations! You have made a new diagnostic function!
You can now add it to the diagnostics_functions.py file so everyone can use it.

TODO: short guideline on how to add a new diagnostic function to the diagnostics_functions.py file.