# MDTF Example Diagnostic POD Notebook for Multiple Cases / Experiments

================================================================================ <br>
This notebook does a simple diagnostic calculation to illustrate how to adapt code
for use in the MDTF-diagnostics framework. The main change is to read input/output
paths, variable names, etc. from shell environment variables that the framework
provides, instead of hard-coding them.

This notebook consists of three parts: (1) a header template that POD
developers must include in their POD's main driver script, (2) the actual code, and
(3) extensive in-line comments.<br>
================================================================================ 
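As a minimal sketch of this pattern, a POD can look up each setting in `os.environ` and fail with a readable message when the framework has not set it. The helper name `require_env` is hypothetical and not part of the MDTF API; `WORK_DIR` is one of the variables this notebook reads further down.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`.

    Hypothetical helper: raises a KeyError with a readable message
    instead of the bare KeyError that os.environ[name] would produce.
    """
    try:
        return os.environ[name]
    except KeyError:
        raise KeyError(
            f"environment variable {name!r} is not set; "
            "is this POD being run by the MDTF framework?"
        ) from None


# Instead of hard-coding a path such as "/home/me/output", ask the environment:
# work_dir = require_env("WORK_DIR")
```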

This file is part of the Example Diagnostic POD of the MDTF code package (see mdtf/MDTF-diagnostics/LICENSE.txt)
 
## Example Diagnostic POD

   Last update: 8/23/2024
 
   This is an example POD that you can use as a template for your diagnostics.
   If this were a real POD, you'd place a one-paragraph synopsis of your 
   diagnostic here (like an abstract). 

   ### Version & Contact info
 
   Here you should describe who contributed to the diagnostic, and who should be
   contacted for further information:
 
   - Version/revision information: version 1 (5/06/2020)
   - PI (name, affiliation, email)
   - Developer/point of contact (name, affiliation, email)
   - Other contributors
 
   ### Open source copyright agreement
 
   The MDTF framework is distributed under the LGPLv3 license (see LICENSE.txt). 
   Unless you've distributed your script elsewhere, you don't need to change this.
 
   ### Functionality
 
   In this section you should summarize the stages of the calculations your 
   diagnostic performs, and how they translate to the individual source code files 
   provided in your submission. This helps maintainers who need to fix bugs, and 
   people with questions about how your code works, know where to look.
 
   ### Required programming language and libraries
 
   In this section you should summarize the programming languages and third-party 
   libraries used by your diagnostic. You also provide this information in the 
   ``settings.jsonc`` file, but here you can give helpful comments to human 
   maintainers (e.g., "We need at least version 1.5 of this library because we call
   this function.")
   
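   * Python >= 3.12
   * xarray
   * matplotlib
   * intake, with the intake-esm plugin (provides ``intake.open_esm_datastore``)
   * yaml (installed as the PyYAML package)
   * numpy
   * sys and os (Python standard library; no installation needed)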
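Of the libraries above, `os` and `sys` ship with Python, so only the others need installing. Below is a minimal pre-flight sketch, assuming the import names listed in `REQUIRED_MODULES` (the `yaml` module comes from the PyYAML distribution, and `intake.open_esm_datastore` comes from the intake-esm plugin, imported as `intake_esm`). The helper names are hypothetical, not part of MDTF.

```python
import importlib.util
import sys

# Import names (not distribution names) for the third-party libraries above.
REQUIRED_MODULES = ["xarray", "matplotlib", "intake", "intake_esm", "yaml", "numpy"]
MIN_PYTHON = (3, 12)


def missing_modules(names):
    """Return the subset of top-level module `names` not found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_ok(min_version=MIN_PYTHON, version_info=sys.version_info):
    """True if (major, minor) of `version_info` meets `min_version`."""
    return tuple(version_info[:2]) >= tuple(min_version)


if __name__ == "__main__":
    print("Python OK:", python_ok())
    print("Missing modules:", missing_modules(REQUIRED_MODULES) or "none")
```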
 
   ### Required model output variables

   In this section you should describe each variable in the input data your 
   diagnostic uses. You also need to provide this in the ``settings.jsonc`` file, 
   but here you should go into detail on the assumptions your diagnostic makes 
   about the structure of the data.
   
   * tas - Surface (2-m) air temperature (CF: air_temperature), daily frequency
 
   ### References
 
   Here you should cite the journal articles providing the scientific basis for 
   your diagnostic.
 
      Maloney, E. D., and Co-authors, 2019: Process-oriented evaluation of climate
         and weather forecasting models. BAMS, 100(9), 1665-1686,
         doi:10.1175/BAMS-D-18-0042.1.


In [1]:
# Import modules used in the POD
import os
import matplotlib

matplotlib.use('Agg')  # non-X windows backend

import matplotlib.pyplot as plt
import numpy as np
import intake
import sys
import yaml

# Part 1: Read in the model data

In [None]:
# Receive a dictionary of case information from the framework
print("reading case_info")
case_env_file = os.environ["case_env_file"]
assert os.path.isfile(case_env_file), f"case environment file {case_env_file} not found"
with open(case_env_file, 'r') as stream:
    try:
        case_info = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        # report the parse error, then stop: nothing below can run without case_info
        print(exc)
        raise

cat_def_file = case_info['CATALOG_FILE']
case_list = case_info['CASE_LIST']
# all cases share variable names and dimension coords in this example, so just use the first case
first_case = next(iter(case_list.values()))
tas_var = first_case['tas_var']
time_coord = first_case['time_coord']
lat_coord = first_case['lat_coord']
lon_coord = first_case['lon_coord']
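
The cell above takes the variable and coordinate names from the first case, relying on the comment that all cases share them. A small sketch that makes this assumption explicit and checks it follows. `shared_setting` is a hypothetical helper, and the `case_info` contents (other than the key names this notebook uses) are invented for illustration.

```python
# Hypothetical case_info, shaped like the dictionary yaml.safe_load returns for
# the case environment file. Key names come from this notebook; case names,
# paths and values are made up.
case_info = {
    "CATALOG_FILE": "/path/to/catalog.json",
    "CASE_LIST": {
        "CASE_A": {"tas_var": "tas", "time_coord": "time", "lat_coord": "lat", "lon_coord": "lon"},
        "CASE_B": {"tas_var": "tas", "time_coord": "time", "lat_coord": "lat", "lon_coord": "lon"},
    },
}


def shared_setting(case_list, key):
    """Return case_list[<any case>][key], checking that every case uses the same value."""
    values = {case[key] for case in case_list.values()}
    if len(values) != 1:
        raise ValueError(f"expected one shared value for {key!r}, found {sorted(values)}")
    return values.pop()


tas_var = shared_setting(case_info["CASE_LIST"], "tas_var")
time_coord = shared_setting(case_info["CASE_LIST"], "time_coord")
```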

## What is in the data catalog?

In [None]:
# open the csv file using information provided by the catalog definition file
cat = intake.open_esm_datastore(cat_def_file)
cat

In [None]:
cat.df

## Search the catalog for daily TAS output for the POD

In [None]:
tas_subset = cat.search(variable_id=tas_var, frequency="day")
tas_subset 

In [None]:
tas_subset.df
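
Conceptually, an intake-esm catalog is a table with one row per file, and `cat.search(...)` returns the subset of rows whose columns match the query. The sketch below mimics only exact-match filtering, using the standard `csv` module on a made-up table. The `variable_id` and `frequency` columns are the ones this notebook searches on; the `path` and `case_name` columns and all values are hypothetical. The real `search` also accepts lists of values, among other query forms, so treat this as an illustration rather than a reimplementation.

```python
import csv
import io

# A toy catalog table; only variable_id and frequency come from the notebook.
CATALOG_CSV = """\
path,case_name,variable_id,frequency
/data/CASE_A/tas_day.nc,CASE_A,tas,day
/data/CASE_A/tas_mon.nc,CASE_A,tas,mon
/data/CASE_A/pr_day.nc,CASE_A,pr,day
/data/CASE_B/tas_day.nc,CASE_B,tas,day
"""


def search_rows(rows, **criteria):
    """Keep the rows whose columns exactly match every keyword criterion."""
    return [row for row in rows if all(row[col] == val for col, val in criteria.items())]


rows = list(csv.DictReader(io.StringIO(CATALOG_CSV)))
tas_rows = search_rows(rows, variable_id="tas", frequency="day")
for row in tas_rows:
    print(row["case_name"], row["path"])
```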

In [None]:
# convert tas_subset catalog to an xarray dataset dict
tas_dict = tas_subset.to_dataset_dict(
    progressbar=False,
    aggregate=False,
    xarray_open_kwargs={"decode_times": True, "use_cftime": True}
)

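# rename the keys in tas_dict to those found in case_list. This assumes the
# search returned exactly one dataset per case, in the same order as case_list.
tas_keys = list(tas_dict)
case_keys = list(case_list)
assert len(tas_keys) == len(case_keys), \
    f"found {len(tas_keys)} datasets for {len(case_keys)} cases"
# build a new dict rather than renaming in place, so no entry can be overwritten
tas_dict = {case_key: tas_dict[tas_key] for tas_key, case_key in zip(tas_keys, case_keys)}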
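
The renaming step above relies on the catalog search returning one dataset per case, in the same order as `CASE_LIST`. Separately from that ordering assumption, renaming keys in place (assigning `d[new] = d.pop(old)` in a loop) can silently overwrite an entry when a new key equals an old key that has not been renamed yet. A toy demonstration with hypothetical keys; both helper names are illustrative only, and both still depend on the ordering assumption.

```python
def rename_in_place(d, old_keys, new_keys):
    """Pop-and-assign renaming: can overwrite entries when keys collide."""
    for old, new in zip(old_keys, new_keys):
        d[new] = d.pop(old)
    return d


def rename_to_new_dict(d, old_keys, new_keys):
    """Build a fresh dict, so no existing entry can be overwritten mid-loop."""
    if len(old_keys) != len(new_keys):
        raise ValueError("need exactly one new key per old key")
    return {new: d[old] for old, new in zip(old_keys, new_keys)}


# Hypothetical keys where each new name coincides with the other old name.
data = {"a": 1, "b": 2}
print(rename_in_place(dict(data), ["a", "b"], ["b", "a"]))     # one value is lost
print(rename_to_new_dict(dict(data), ["a", "b"], ["b", "a"]))  # both values kept
```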

## Let us do some calculations

In [None]:
# Part 2: Do some calculations (time and zonal means)
# ---------------------------------------------------

tas_arrays = {}

# Loop over cases
for k, v in tas_dict.items():
    # select the tas variable from case k's dataset
    print("case:", k)
    arr = v[tas_var]

    # take the time mean
    arr = arr.mean(dim=time_coord)

    # this block shuffles the data to make each case look more
    # interesting.  ** DELETE THIS ** once we test with real data

    arr.load()
    values = arr.to_masked_array().flatten()
    np.random.shuffle(values)
    values = values.reshape(arr.shape)
    arr.values = values

    # convert to anomalies relative to the (unweighted) mean over all grid points
    arr = arr - arr.mean()

    # take the zonal mean
    arr = arr.mean(dim=lon_coord)

    tas_arrays[k] = arr
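
Stripped of xarray and of the temporary shuffling block, the loop above applies three averaging steps to each case. A dependency-free sketch on a made-up 2-by-2 grid with two time steps follows; the nested-list layout `data[time][lat][lon]` is an assumption for illustration, since the notebook itself operates on xarray DataArrays.

```python
from statistics import fmean

# Toy data indexed as data[time][lat][lon]; values are made up.
data = [
    [[1.0, 3.0], [5.0, 7.0]],  # time step 0
    [[3.0, 5.0], [7.0, 9.0]],  # time step 1
]

n_lat, n_lon = len(data[0]), len(data[0][0])

# 1. time mean at each grid point
time_mean = [[fmean(step[j][i] for step in data) for i in range(n_lon)] for j in range(n_lat)]

# 2. anomaly relative to the unweighted mean over all grid points
grid_mean = fmean(v for row in time_mean for v in row)
anomaly = [[v - grid_mean for v in row] for row in time_mean]

# 3. zonal mean: average over longitude, leaving one value per latitude
zonal_mean = [fmean(row) for row in anomaly]
print(zonal_mean)  # [-2.0, 2.0]
```

Like the notebook, this uses unweighted means; an area-weighted mean on a latitude-longitude grid would weight each latitude by cos(latitude).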

### Compare all cases on one plot


In [None]:
# Part 3: Make a plot that contains results from each case
# --------------------------------------------------------
print("Let's plot!")
print("--------------------------------------")

# set up the figure
fig = plt.figure(figsize=(12, 4))
ax = plt.subplot(1, 1, 1)

# loop over cases
for k, v in tas_arrays.items():
    v.plot(ax=ax, label=k)

# add legend
plt.legend()

# add title
plt.title("Zonal Mean Surface Air Temperature Anomaly")

In [None]:
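# save the plot in the right location
work_dir = os.environ["WORK_DIR"]
assert os.path.isdir(f"{work_dir}/model/PS"), f'{work_dir}/model/PS not found'

# Save through the explicit figure handle and keep the Agg backend selected at
# the top: switching backends (e.g. with %matplotlib inline) closes all open
# figures, after which plt.savefig would write a new, empty figure.
fig.savefig(f"{work_dir}/model/PS/example_multicase_plot.eps", bbox_inches="tight")
plt.close(fig)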
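
This notebook saves its plot under `$WORK_DIR/model/PS`. A minimal sketch that wraps the directory check in a reusable function follows; the name `pod_output_path` is hypothetical. It raises a regular exception, which, unlike an `assert` statement, is not skipped when Python runs with optimizations enabled (`python -O`).

```python
import os


def pod_output_path(work_dir, filename, subdir=os.path.join("model", "PS")):
    """Return work_dir/subdir/filename, failing early if the directory is missing.

    Hypothetical helper; the model/PS layout mirrors the directory this
    notebook checks for before saving.
    """
    out_dir = os.path.join(work_dir, subdir)
    if not os.path.isdir(out_dir):
        raise FileNotFoundError(f"output directory {out_dir} not found")
    return os.path.join(out_dir, filename)


# Usage inside the POD (WORK_DIR is set by the framework):
# fig.savefig(pod_output_path(os.environ["WORK_DIR"], "example_multicase_plot.eps"),
#             bbox_inches="tight")
```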

In [None]:
# Part 4: Close the catalog files and
# release variable dict reference for garbage collection
# ------------------------------------------------------
cat.close()
tas_dict = None

# Part 5: Confirm POD executed successfully
# ----------------------------------------
print("Last log message by example_multicase POD: finished successfully!")