# Title of this first notebook

<img src="path or URL to some visual here"
     width="30%"
     alt="MPAS-JEDI visual"
     align="right"
/>

### In this section, you'll learn:

* Utilizing ...

### Related Documentation

* [URL title](URL)
* 

### Prerequisites

| Concepts | Importance | Notes |
| --- | --- | --- |
| [URL title](URL) | Necessary OR Helpful?  | |

**Time to learn**: 30 minutes?

-----

## Import packages

In [None]:
%%time 

# autoload external python modules if they changed
%load_ext autoreload
%autoreload 2

# add ../funcs to the current path
import sys
import os

sys.path.append(os.path.join(os.getcwd(), ".."))

# import modules
import math

import cartopy.crs as ccrs
import geoviews.feature as gf
import holoviews as hv
import matplotlib.pyplot as plt

import s3fs

import numpy as np
import uxarray as ux

## Configure visualization tools

In [None]:
# hv.extension("bokeh")
# hv.extension("matplotlib")

# Map overlays shared by the plots below: coastlines and state borders,
# both requested with a PlateCarree projection at the "50m" scale.
coast_lines = gf.coastline(
    scale="50m",
    line_width=1,
    projection=ccrs.PlateCarree(),
)
state_lines = gf.states(
    scale="50m",
    line_width=1,
    line_color="gray",
    projection=ccrs.PlateCarree(),
)

## Convenience functions

In [None]:
# contour horizontal cross sections
def hslice_contour(
    ux_hslice,
    title,
    cmin=None,
    cmax=None,
    width=800,
    height=500,
    clevs=20,
    cmap="coolwarm",
    symmetric_cmap=False,
):
    """Draw a filled-contour plot of a horizontal slice.

    Parameters
    ----------
    ux_hslice
        Object providing ``.min()``, ``.max()`` and ``.plot()``. The notebook
        passes a UXarray variable reduced to one time and one vertical level.
    title : str
        Plot title; the data minimum and maximum are appended to it.
    cmin, cmax : number, optional
        Color limits. When omitted they default to ``floor(min)`` and
        ``ceil(max)`` of the data.
    width, height : int
        Plot size forwarded to ``.opts()``.
    clevs : int
        Number of contour levels, evenly spaced between ``cmin`` and ``cmax``.
    cmap
        Colormap forwarded to ``.opts(cmap=...)``. Presumably a colormap name
        or a matplotlib colormap object -- confirm what the active HoloViews
        backend accepts.
    symmetric_cmap : bool
        If True, the limits are replaced by ``(-m, m)`` with
        ``m = max(abs(cmin), cmax)`` so that zero sits at the colormap center.

    Returns
    -------
    The object returned by ``hv.operation.contours(...).opts(...)``.
    """
    # Get min and max
    amin = ux_hslice.min().item()
    amax = ux_hslice.max().item()
    title += f" min={amin:.1f} max={amax:.1f}"
    if cmin is None:
        cmin = math.floor(amin)
    if cmax is None:
        cmax = math.ceil(amax)
    if symmetric_cmap:  # to get a symmetric color map when cmin < 0, cmax > 0
        cmax = max(abs(cmin), cmax)
        cmin = -cmax

    # generate contour plot
    contour_plot = hv.operation.contours(
        ux_hslice.plot(),
        levels=np.linspace(cmin, cmax, num=clevs),
        filled=True,
    ).opts(
        line_color=None,
        width=width,
        height=height,
        # Forward the caller's colormap. The previous code hard-coded "coolwarm"
        # here, so the `cmap` argument had no effect on the plot.
        cmap=cmap,
        clim=(cmin, cmax),
        colorbar=True,
        show_legend=False,
        tools=["hover"],
        title=title,
    )

    return contour_plot

    # OERO: See the vertical cross section part below re how to use UXarray's new function


# # contour vertical cross sections along a given latitude or longitude
# def vslice_contour(uxvar, lon=None, lat=None, cmin=None, cmax=None, width=600, height=530, clevels=20):
#     if lon is None and lat is None:
#         print("Need to specify either a const lon or a const lat")
#         return

#     if lon is not None:  # along constant lon
#         ux_vslice = uxvar.isel(Time=0).cross_section.constant_longitude(lon)
#     elif lat is not None:  # along constant lat
#         ux_vslice = uxvar.isel(Time=0).cross_section.constant_latitude(lat)

#     # temporary uxarray bug fix and it will be updated in a new UXarray fix soon
#     del ux_vslice.uxgrid._ds['edge_node_connectivity']
#     del ux_vslice.uxgrid._ds['edge_lon']
#     del ux_vslice.uxgrid._ds['face_lon']

#     # sort lats or lons
#     if lon is not None:  # along constant lon
#         sort_indices = ux_vslice.uxgrid.face_lat.argsort()
#     elif lat is not None:  # along constant lat
#         sort_indices = ux_vslice.uxgrid.face_lon.argsort()
#     sorted_lons = ux_vslice.uxgrid.face_lon[sort_indices]
#     sorted_lats = ux_vslice.uxgrid.face_lat[sort_indices]

#     # remap faces
#     face_indices = []
#     for mylon, mylat in zip(sorted_lons, sorted_lats):
#         face_idx = ux_vslice.uxgrid.get_faces_containing_point(points=np.array([mylon.item(), mylat.item()]))
#         face_indices.append(face_idx)

#     face_indices = np.array(face_indices).squeeze()
#     if lon is not None:  # along constant lon
#         face_DataArray = xr.DataArray(data=np.array(face_indices), dims=['lat'])
#     elif lat is not None:  # along constant lat
#         face_DataArray = xr.DataArray(data=np.array(face_indices), dims=['lon'])

#     ux_vslice_selected = ux_vslice.isel(n_face=face_DataArray, ignore_grid=True)
#     # Get min and max
#     amin = ux_vslice_selected.min().item()
#     amax = ux_vslice.max().item()
#     if cmin is None:
#         cmin = math.floor(amin)
#     if cmax is None:
#         cmax = math.ceil(amax)

#     levels = np.linspace(cmin, cmax, num=clevels)
#     if lon is not None:  # along constant lon
#         title = f"constant_lon={lon} min={amin:.1f} max={amax:.1f}"
#     if lat is not None:  # along constant lat
#         title = f"constant_lat={lat} min={amin:.1f} max={amax:.1f}"

#     # ux_vslice_selected.to_xarray().transpose().plot.contourf()  # plot using matplotlib
#     # plt.title(f"lon = {lon}")  # add a title to matplotlib

#     # return the slice array with a lat dim
#     # ux_vslice_selected = ux_vslice_selected.assign_coords(lats=xr.DataArray(data=sorted_lats, dims=['lat']))
#     # return ux_vslice_selected.to_xarray().transpose() # return the slice array

#     return ux_vslice_selected.to_xarray().transpose().hvplot.contourf(levels=levels, width=width, height=height, title=title)  # aspect=1

## Load MPAS data
Depending on the network, the data loading process may take a few minutes.    

There are two ways to load MPAS data:
1. Download all example data from JetStream2 to the local disk and then load it locally. This approach allows you to download the data once and reuse it across notebooks.
2. Access the JetStream2 S3 objects on demand. In this case, each notebook (including each restart of a notebook) downloads the required data as needed, which may lead to repeated downloads.

In [None]:
# Selects which of the data-loading cells below does the work:
#   1 -> download the example data to local disk and read the local files
#   2 -> open the JetStream2 S3 objects on demand
data_load_method = 1  # or 2

### Download all example data to your local disk

In [None]:
%%time
# This cell only needs to run once in a machine and can be converted to a MarkDown cell before publishing the cookbook

if data_load_method == 1:
    jetstream_url = "https://js2.jetstream-cloud.org:8001/"
    fs = s3fs.S3FileSystem(
        anon=True, asynchronous=False, client_kwargs=dict(endpoint_url=jetstream_url)
    )
    conus12_path = "s3://pythia/mpas/conus12km"
    local_dir = "/tmp"

    # The notebook reads the data from <local_dir>/conus12km, so an existing
    # directory there means the download has already run on this machine.
    # Skipping it keeps "Restart & Run All" from re-downloading everything.
    # If a previous download was interrupted, delete the directory and re-run.
    local_data_dir = os.path.join(local_dir, "conus12km")
    if os.path.isdir(local_data_dir):
        print(f"Found {local_data_dir}; skipping download (delete it to download again)")
    else:
        fs.get(conus12_path, local_dir, recursive=True)

In [None]:
# paths to the locally downloaded MPAS data
if data_load_method == 1:
    grid_file = local_dir + "/conus12km/conus12km.invariant.nc_L60_GFS"
    # Analysis files live under ana/ and background files under bkg/, the same
    # layout used by the on-demand S3 URLs. The two directories were swapped
    # here before, which flipped the sign of the analysis increments.
    ana_file = local_dir + "/conus12km/ana/mpasout.2024-05-06_01.00.00.nc"
    bkg_file = local_dir + "/conus12km/bkg/mpasout.2024-05-06_01.00.00.nc"

### access JetStream2 and S3 objects on demand  

In [None]:
%%time
## **!! skip this section if data has been downloaded to local in the above !!**
if data_load_method != 2:
    print("No action here, as example data has been downloaded to local in the above")
else:
    # anonymous, synchronous connection to the JetStream2 object store
    jetstream_url = "https://js2.jetstream-cloud.org:8001/"
    fs = s3fs.S3FileSystem(
        anon=True,
        asynchronous=False,
        client_kwargs={"endpoint_url": jetstream_url},
    )
    conus12_path = "s3://pythia/mpas/conus12km"

    # remote locations of the grid, background and analysis files
    grid_url = f"{conus12_path}/conus12km.invariant.nc_L60_GFS"
    bkg_url = f"{conus12_path}/bkg/mpasout.2024-05-06_01.00.00.nc"
    ana_url = f"{conus12_path}/ana/mpasout.2024-05-06_01.00.00.nc"

    # file-like handles, opened in the order grid, analysis, background
    grid_file, ana_file, bkg_file = (
        fs.open(url) for url in (grid_url, ana_url, bkg_url)
    )

### Open UXarray datasets

In [None]:
%%time 
# Pair the grid file with each data file:
# uxds_a is opened from ana_file, uxds_b from bkg_file.
uxds_a = ux.open_dataset(grid_file, ana_file)
uxds_b = ux.open_dataset(grid_file, bkg_file)

#### (Experimental) A UXarray dataset can be opened with chunks for better performance
This applies the suggestion [here](https://uxarray.readthedocs.io/en/latest/user-guide/parallel-load-ux-with-dask.html#opening-a-single-data-file). Currently, it does not seem to give significantly better performance, though. The UXarray team needs to look into this.

In [None]:
# %%time

# uxds_a_chunked = ux.open_dataset(grid_file, ana_file, chunks=-1)
# uxds_b_chunked = ux.open_dataset(grid_file, bkg_file, chunks=-1)

## Regional temperature contour

Let us use the `theta` (potential temperature) variable from this dataset to look at a regional, horizontal plot. It has the dimensions `(Time, n_face, nVertLevels)`, i.e. `Time * Number of grid faces * Number of vertical levels`.

In [None]:
# NOTE: `uxvar` is rebound to the analysis increments further down in this notebook,
# so re-run this cell before re-running the plots of this section.
uxvar = uxds_a["theta"] - 273.15  ## Kelvin to Celsius

Also, for simplicity, let's focus on the vertical level and time indices of 0.

In [None]:
# Positional indices for selecting a single horizontal slice
i_lev = 0  # `nVertLevels` index
i_time = 0  # `Time` index

In [None]:
%%time

# Select the slice with the i_time / i_lev indices defined above
# (Time was previously hard-coded to 0, leaving i_time unused).
plot = hslice_contour(
    uxvar.isel(Time=i_time, nVertLevels=i_lev),
    title=f"Contour plot for potential temperature over a region: lev={i_lev}",
)  # , symmetric_cmap=True)
plot * coast_lines * state_lines

## Vertical cross section

In [None]:
# %%time

# tmp = vslice_contour(uxvar, lat=42.63, clevels=10)
# display(tmp)

Let us use UXarray's vertical cross-section function to get a cross-section over a great circle arc:

In [None]:
%%time

# End points of the cross-section; presumably (lon, lat) in degrees -- TODO confirm
# the expected ordering against the UXarray documentation.
start_point = (-120, 20)
end_point = (-70, 50)

# cross-section between the two points, requested with steps=100
cross_section_gca = uxvar.cross_section(start=start_point, end=end_point, steps=100)

UXarray's cross-section returns an `xarray.DataArray` that can then be plotted:

In [None]:
# take the first time index and transpose before drawing filled contours
cross_section_gca.isel(Time=0).transpose().plot.contourf()

## compute the analysis increments from the JEDI data assimilation
JEDI updates the background atmospheric state (`uxds_b`) with observation innovations and gets a new atmospheric state called analysis (`uxds_a`).  
The difference of `uxds_a` - `uxds_b` is called "analysis increments"

In [None]:
%%time 

var_name = "theta"
# analysis minus background = analysis increment
uxdiff0 = uxds_a[var_name] - uxds_b[var_name]
# rebinds `uxvar` (previously theta shifted to Celsius) to the increments
uxvar = uxdiff0

## plot temperature analysis increments at different levels

In [None]:
%%time

nt = 0  # time dimension
plot_levels = [0, 29, 42]  # [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

# Create the colormap
# colors = ['blue', 'white', 'red']
# cmap = LinearSegmentedColormap.from_list('blue_white_red', colors)
# zero_shift = 0.02

# NOTE: `cmap` and `zero_shift` are only referenced by the commented-out
# hslice_contour0() call below; the active hslice_contour() call does not use them.
cmap = plt.get_cmap("coolwarm")
zero_shift = 0.0

plots = []
for lev in plot_levels:
    # Option 1: hslice_contour0(), which by default uses 0 to divide the cool
    # and warm colors in the plot.
    #
    # tmp = hslice_contour0(
    #     uxvar.isel(Time=nt, nVertLevels=lev),
    #     title=f'lev={lev}',
    #     cmap=cmap,
    #     zero_shift=zero_shift,
    #     clevs_multiplier=1,
    # )  # for the whole domain

    # Option 2 (active): hslice_contour() does not divide the cool and warm colors
    # at 0 by default, but this can be achieved by setting symmetric_cmap=True,
    # which sets a symmetric cmax/cmin automatically, or by manually passing a
    # symmetric cmax/cmin.
    #
    tmp = hslice_contour(
        uxvar.isel(Time=nt, nVertLevels=lev),
        title=f"lev={lev}",
        symmetric_cmap=True,
        # clevs=20,
    )  # for the whole domain

    plots.append(tmp * coast_lines * state_lines)

# plots share one toolbar, which facilitates doing sync'ed zoom-in/out
# hv.Layout(plots).cols(1)

# each plot has its own toolbar, which facilitates controlling each plot individually
for p in plots:
    display(p)

## Zoomed into Colorado using the subset capability

In [None]:
%%time

# Center and half-widths of the bounding box used for the subset
lon_center = -105.03
lat_center = 39.0
lon_incr = 5  # degree
lat_incr = 3  # degree
lon_bounds = (lon_center - lon_incr, lon_center + lon_incr)
lat_bounds = (lat_center - lat_incr, lat_center + lat_incr)

### subset to a small domain
uxdiff1 = uxdiff0.subset.bounding_box(
    lon_bounds,
    lat_bounds,
)
# rebinds `uxvar` to the subsetted increments
uxvar = uxdiff1


nt = 0  # time dimension
plot_levels = [42]  # [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

plots = []
for lev in plot_levels:
    tmp = hslice_contour(
        uxvar.isel(Time=nt, nVertLevels=lev), title=f"lev={lev}", width=700, height=500
    )  # for the subdomain

    # overlay state_lines
    # plots.append(tmp * coast_lines * state_lines)

    # NOTE(review): this comment used to say "overlay county lines", but the code
    # overlays `coast_lines` with x/y limits set to the subset bounds -- confirm
    # which overlay is intended.
    plots.append(
        tmp
        * coast_lines.opts(
            xlim=(lon_bounds[0], lon_bounds[1]), ylim=(lat_bounds[0], lat_bounds[1])
        )
    )

# plots share one toolbar, which facilitates doing sync'ed zoom-in/out
# hv.Layout(plots).cols(1)

# each plot has its own toolbar, which facilitates controlling each plot individually
for p in plots:
    display(p)

## vertical cross section of temperature increments

In [None]:
%%time

# tmp = vslice_contour(uxvar, lon=-85.77, clevels=10)
# display(tmp)
# tmp = vslice_contour(uxvar, lat=42.63, clevels=10)
# display(tmp)

## save plots to files

In [None]:
# NOTE(review): `tmp` is whatever the last executed plotting cell assigned. With the
# vslice_contour() calls commented out, that is a horizontal slice returned by
# hslice_contour(), not a vertical slice as the file name suggests -- confirm the
# intended object and file name.
hv.save(tmp, "vslice.png")