# Visualization of JEDI analysis with UXarray in the model space

<img src="images/jedi-mpas.png"
     width="30%"
     alt="jedi-mpas"
     align="right"
/>

### In this section, you'll learn:

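* How to compute and visualize JEDI analysis increments (analysis minus background) on the native MPAS mesh with UXarray, and use them to assess the performance of the data assimilation.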

### Related Documentation

* [UXarray documentation](https://uxarray.readthedocs.io/)
* [HoloViews documentation](https://holoviews.org/)

### Prerequisites

| Concepts | Importance | Notes |
| --- | --- | --- |
| [UXarray](https://uxarray.readthedocs.io/) | Necessary | |

**Time to learn**: 30 minutes?

-----

## Import packages

In [None]:
%%time 

# autoload external python modules if they changed
%load_ext autoreload
%autoreload 2

# add ../funcs to the current path
import sys, os
sys.path.append(os.path.join(os.getcwd(), "..")) 

# import modules
import warnings
import math

import cartopy.crs as ccrs
import geoviews as gv
import geoviews.feature as gf
import holoviews as hv
import hvplot.xarray
from holoviews import opts
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import s3fs

import geopandas as gp
import numpy as np
import uxarray as ux
import xarray as xr

## Configure visualization tools

In [None]:
# Pick a HoloViews plotting backend explicitly if needed:
# hv.extension("bokeh")
# hv.extension("matplotlib")

# Border overlays (coastlines and state boundaries) reused by the maps below.
border_style = dict(projection=ccrs.PlateCarree(), line_width=1, scale="50m")
coast_lines = gf.coastline(**border_style)
state_lines = gf.states(line_color='gray', **border_style)


## Load MPAS data
Depending on the network, the data loading process may take a few minutes.    

There are two ways to load MPAS data:
1. Download all example data from JetStream2 to your local disk and then load it locally. This approach lets you download the data once and reuse it across notebooks.
2. Access the JetStream2 S3 objects on demand. In this case, each notebook (including after a kernel restart) downloads the required data as needed, which may lead to repeated downloads.

In [None]:
data_load_method: int = 1  # 1 -> download everything to local disk first; 2 -> open the S3 objects on demand

### Download all example data to your local disk

In [None]:
%%time
# This cell only needs to run once in a machine and can be converted to a MarkDown cell before publishing the cookbook

if data_load_method == 1:
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(anon=True, asynchronous=False,client_kwargs=dict(endpoint_url=jetstream_url))
    conus12_path = 's3://pythia/mpas/conus12km'
    local_dir="/tmp"
    fs.get(conus12_path, local_dir, recursive=True)

In [None]:
# Paths to the MPAS data downloaded under `local_dir` (method 1).
if data_load_method == 1:
    grid_file = os.path.join(local_dir, "conus12km", "conus12km.invariant.nc_L60_GFS")
    # Analysis and background files share a name; only the directory (ana/ vs bkg/) differs,
    # so each variable must point at its own directory or the increments flip sign.
    ana_file = os.path.join(local_dir, "conus12km", "ana", "mpasout.2024-05-06_01.00.00.nc")
    bkg_file = os.path.join(local_dir, "conus12km", "bkg", "mpasout.2024-05-06_01.00.00.nc")

### Access JetStream2 S3 objects on demand

In [None]:
%%time
## **!! skip this section if data has been downloaded to local in the above !!**
if data_load_method == 2:
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(anon=True, asynchronous=False,client_kwargs=dict(endpoint_url=jetstream_url))
    conus12_path = 's3://pythia/mpas/conus12km'
    
    grid_url=f"{conus12_path}/conus12km.invariant.nc_L60_GFS"
    bkg_url=f"{conus12_path}/bkg/mpasout.2024-05-06_01.00.00.nc"
    ana_url=f"{conus12_path}/ana/mpasout.2024-05-06_01.00.00.nc"
    
    grid_file = fs.open(grid_url)
    ana_file = fs.open(ana_url)
    bkg_file = fs.open(bkg_url)
else:
    print("No action here, as example data has been downloaded to local in the above")

### Open UXarray datasets

In [None]:
%%time 
uxds_a = ux.open_dataset(grid_file, ana_file)
uxds_b = ux.open_dataset(grid_file, bkg_file)

In [None]:
# Inspect the analysis dataset's variables and dimensions.
uxds_a

In [None]:
# Default single-slice indices: `nVertLevels` index and `Time` index.
i_lev, i_time = 0, 0

## Plot potential temperature (theta) analysis increments at different levels

In [None]:
%%time 
# Select variable of interest
var_name = "theta"
uxdiff0 = uxds_a[var_name] - uxds_b[var_name]
uxvar = uxdiff0

In [None]:
# contour horizontal cross sections
def hslice_contour(ux_hslice, title, cmin=None, cmax=None, width=800, height=500, clevs=20, cmap="coolwarm",
                   symmetric_cmap=False, colorbar=True):
    """Return a filled-contour HoloViews plot of one horizontal slice.

    Parameters
    ----------
    ux_hslice : UxDataArray
        A single horizontal slice (e.g. after ``.isel(Time=..., nVertLevels=...)``).
        Its ``.plot()`` output is passed to ``hv.operation.contours``.
    title : str
        Plot title; the slice's min and max are appended to it.
    cmin, cmax : float, optional
        Contour range. Default: the slice's min/max rounded outward to integers.
    width, height : int
        Plot size in pixels.
    clevs : int
        Number of contour levels (edges), evenly spaced over ``[cmin, cmax]``;
        the plot has ``clevs - 1`` filled bands.
    cmap : str or matplotlib Colormap
        Colormap name or object.
    symmetric_cmap : bool
        If True, widen the range to ``[-c, c]`` with ``c = max(|cmin|, |cmax|)``.
        The middle of a diverging colormap then sits at zero.
    colorbar : bool
        Whether to draw a colorbar.

    Returns
    -------
    The element produced by ``hv.operation.contours`` with the styling options applied.

    Raises
    ------
    ValueError
        - The slice has no finite values and ``cmin``/``cmax`` must be derived from it.
        - ``cmin > cmax``.
        - ``clevs < 2``.
    """
    # Data range for the title (xarray float reductions skip NaN by default).
    amin = ux_hslice.min().item()
    amax = ux_hslice.max().item()
    title += f" min={amin:.1f} max={amax:.1f}"

    # Default range: the data range rounded outward to whole numbers.
    if cmin is None:
        if not math.isfinite(amin):
            raise ValueError(f"{title}: slice has no finite values; pass cmin explicitly")
        cmin = math.floor(amin)
    if cmax is None:
        if not math.isfinite(amax):
            raise ValueError(f"{title}: slice has no finite values; pass cmax explicitly")
        cmax = math.ceil(amax)
    if cmin > cmax:
        raise ValueError(f"cmin ({cmin}) must not exceed cmax ({cmax})")
    if symmetric_cmap:
        # Center the range on zero so a diverging colormap's neutral color marks zero.
        cmax = max(abs(cmin), abs(cmax))
        cmin = -cmax
    if cmin == cmax:
        # Constant field: widen the range so the contour levels are strictly increasing.
        cmin, cmax = cmin - 1, cmax + 1
    if clevs < 2:
        raise ValueError(f"clevs must be >= 2 to form at least one band (got {clevs})")

    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # generate contour plot
    contour_plot = hv.operation.contours(
        ux_hslice.plot(),
        levels=np.linspace(cmin, cmax, num=clevs),
        filled=True,
    ).opts(
        line_color=None,
        width=width, height=height,
        cmap=cmap,
        clim=(cmin, cmax),
        colorbar=colorbar,
        show_legend=False, tools=['hover'], title=title,
    )

    return contour_plot

In [None]:
from matplotlib.colors import ListedColormap, BoundaryNorm, to_rgba

# Discrete diverging colormap for the theta increments (K).
# There is one color per 0.5-wide bin between consecutive `edges`:
# blues for negative values, reds for positive values.
edges = [-4, -3.5, -3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1.0, 1.5, 2, 2.5, 3, 3.5, 4]
colors = [
    "#313695",  # [-4,-3.5]
    "#3f72b4",  # [-3.5,-3]
    "#5a9bd5",  # [-3,-2.5]
    "#81bfe0",  # [-2.5,-2]
    "#a6d8e7",  # [-2,-1.5]
    "#cae6ef",  # [-1.5,-1]
    "#e4f1f5",  # [-1,-0.5]
    "#f2f9fc",  # [-0.5,0] slightly pale blue (|incr| < 0.1 is masked in the plots, so effectively [-0.5,-0.1])
    "#fcf2f2",  # [0,0.5]  slightly pale pink (effectively [0.1,0.5])
    "#f9d6d4",  # [0.5,1.0]
    "#f5b5b1",  # [1.0,1.5]
    "#ee8a85",  # [1.5,2.0]
    "#e75e5a",  # [2.0,2.5]
    "#d73027",  # [2.5,3.0]
    "#a50026",  # [3.0,3.5]
    "#67001f",  # [3.5,4.0]
]
# ListedColormap knows nothing about `edges`; make sure each bin has exactly one color.
assert len(colors) == len(edges) - 1, "colors must have one entry per bin between consecutive edges"
cmap = ListedColormap(colors)


In [None]:
%%time

nt=0  # time dimension
plot_levels = [0, 19, 29, 39, 42, 49, 58]  # [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

zero_shift = 0.0

plots = []
for lev in plot_levels:
    dat = uxvar.isel(Time=nt, nVertLevels=lev)
    tmp = hslice_contour(
        dat.where((dat > 0.1) | (dat < -0.1)),
        title=f'lev={lev}',
        cmap=cmap,
        colorbar=True,
        cmax=4,
        cmin=-4
    ) 
    
    plots.append(tmp * coast_lines * state_lines)
from IPython.display import display, Markdown

display(Markdown(
    r"**Small increments** $\left[ -0.1 \ \text{to} \ 0.1 \right]$ **neglected**  <br>"
    r"Indicated in white spaces"
))
for p in plots:
   display(p)

## Zoom into Colorado using UXarray's subset capability

In [None]:
# Analysis `theta` converted from Kelvin to Celsius, for plotting the full field.
# It is stored under its own name: assigning to `uxvar` here would be immediately
# overwritten by the subset cell below.
theta_a_celsius = uxds_a['theta'] - 273.15

In [None]:
%%time

lon_center = -105.03
lat_center = 39.0
lon_incr = 5 # degree
lat_incr = 3 # degree
lon_bounds = (lon_center - lon_incr, lon_center + lon_incr)
lat_bounds = (lat_center - lat_incr, lat_center + lat_incr)

### subset to a small domain
uxdiff1 = uxdiff0.subset.bounding_box(lon_bounds, lat_bounds,)
uxvar = uxdiff1


nt=0  # time dimension
plot_levels = [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

plots = []
for lev in plot_levels:
    tmp = hslice_contour(uxvar.isel(Time=nt, nVertLevels=lev), title=f'lev={lev}', width=700, height=500)  # for the subdomain
    
    # overlay state_lines
    plots.append(tmp * coast_lines * state_lines)  
    
    # overlay county lines, this takes longer time to render
    # plots.append(tmp * coast_lines.opts(xlim=(lon_bounds[0], lon_bounds[1]), ylim=(lat_bounds[0], lat_bounds[1])))

# plots share one toolbar, which facilitates doing sync'ed zoom-in/out
# hv.Layout(plots).cols(1)

# each plot has its own toolbar, which facilitates controlling each plot individually
for p in plots:
   display(p)

In [None]:
# Vertical cross-sections along a Great Circle Arc (GCA) between two fixed points

In [None]:
%%time

start_point = (-110, 20)
end_point = (-70, 50)
var_name = "theta"
uxdiff0 = uxds_a[var_name].isel(Time=0) - uxds_b[var_name].isel(Time=0)
uxvar = uxdiff0
cross_section_gca = uxvar.cross_section(start=start_point, end=end_point, steps=100)

In [None]:
def _lat_lon_label(lat, lon):
    """Two-line tick label (latitude over longitude) with hemisphere letters."""
    ns = 'N' if lat >= 0 else 'S'
    ew = 'E' if lon >= 0 else 'W'
    return f"{abs(lat):.1f}°{ns}\n{abs(lon):.1f}°{ew}"


# One label per sample point along the arc.
hlabelticks = [
    _lat_lon_label(lat, lon)
    for lat, lon in zip(cross_section_gca['lat'], cross_section_gca['lon'])
]

In [None]:
# cross_section_gca.isel(Time=0).transpose().plot.contourf()
%matplotlib inline


fig= plt.figure(figsize=(8,3))
gs= fig.add_gridspec(1,1)
ax = fig.add_subplot(gs[0,0])
cf=ax.contourf(cross_section_gca.transpose(),cmap='Reds',extend='both')
tick_stride = 10
ax.set_xticks(cross_section_gca['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
plt.savefig('cross_section_increments.png')

In [None]:
lon=-83.3
cross_section_lon = uxvar.cross_section(lon=lon, steps=100)

hlabelticks = [
    f"{abs(lat):.1f}°{'N' if lat >= 0 else 'S'}" for lat in cross_section_lon['lat']
]

%matplotlib inline
fig= plt.figure(figsize=(8,3))
gs= fig.add_gridspec(1,1)
ax = fig.add_subplot(gs[0,0])
cf=ax.contourf(cross_section_lon.transpose(),cmap='Reds',extend='both')

ax.set_xticks(cross_section_lon['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
plt.savefig("cross_section_increments2.png")

In [None]:
# Constant-longitude cross-section at lon = -60°.
# NOTE(review): confirm -60° lies inside the regional mesh; outside it the section may be all-NaN.
cross_section_lon = uxvar.cross_section(steps=100, lon=-60.0)
cross_section_lon

## Save plots to files

In [None]:
# Save each Colorado-subset horizontal slice (with coastline and state overlays) as its own PNG.
# Previously only the last `tmp` (one level, no overlays) was saved, under the misleading name 'vslice.png'.
# NOTE(review): PNG export from the bokeh backend usually needs selenium and a webdriver; confirm.
for lev, p in zip(plot_levels, plots):
    hv.save(p, f"hslice_colorado_lev{lev}.png")