In [2]:
import importlib

# Import the project helper module by name and display its version tag; the tag
# is embedded in every processed-data filename opened in this notebook.
cf = importlib.import_module("calc_funcs_v1i")
cf.calc_funcs_ver

'cfv1i'

In [None]:
import calc_funcs_v1i as cf
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean
import logging
import copy

In [None]:
# Global plotting configuration plus validation constants shared by the
# plotting helpers below.
plt.rcParams.update({'text.usetex': True})  # needs a working LaTeX install

# Required dimension order of every field passed to the map helpers.
da_dims_valid = ("latitude", "longitude")
# Variable names plotted with cyclic (hour-of-day) colour maps and levels.
da_names_cyclic_valid = ["hour_max", "hour_min"]
# Default figure width in inches; single-panel maps use half of it.
figwidth_default = 20

In [None]:
def create_pcolormesh(da, extents=None, ax=None):
    """Plot a (latitude, longitude) DataArray as a filled map with a colour bar.

    Parameters
    ----------
    da : xarray.DataArray
        Field to plot. Must satisfy ``da.dims == da_dims_valid`` and carry the
        attrs ``"full_name"`` (title), ``"abbreviation"`` and ``"units"``
        (colour-bar label).
    extents : list or tuple of 4 numbers, optional
        Map extent as ``[W, E, S, N]`` in degrees. Defaults to the longitude
        and latitude bounds of ``da``.
    ax : cartopy GeoAxes, optional
        Axes to draw into. If omitted, a PlateCarree figure sized to the
        aspect ratio of ``extents`` is created, laid out and shown.

    Colour scaling
    --------------
    * ``da.name`` in ``da_names_cyclic_valid``: cyclic colour map with one
      level per hour (-12..12 for differences, 0..24 otherwise).
    * ``full_name`` starting with "Difference": diverging colour map.
    * Otherwise: viridis, limited to the data range inside ``extents``.

    Raises
    ------
    AssertionError
        If ``da``, ``extents`` or ``ax`` fail validation.
    """
    from cartopy.mpl.geoaxes import GeoAxes

    # `and` short-circuits, so a non-DataArray input fails the assertion
    # instead of raising AttributeError on `da.dims`.
    assert isinstance(da, xr.DataArray) and da.dims == da_dims_valid, \
        f"da must be an xarray.DataArray with da.dims == {da_dims_valid}"

    if extents:
        # Short-circuiting also avoids IndexError for too-short sequences.
        assert (isinstance(extents, (list, tuple)) and len(extents) == 4
                and extents[1] > extents[0] and extents[3] > extents[2]), \
            "extents must be a 4 element list or tuple in [W, E, S, N] format or None"
        extents = list(extents)
    else:
        extents = [float(da.longitude.min()), float(da.longitude.max()),
                   float(da.latitude.min()), float(da.latitude.max())]

    # Only a figure created here is laid out and shown at the end; a
    # caller-supplied ax belongs to the caller's figure.
    ax_input = ax
    if ax_input is not None:
        # Older cartopy's GeoAxesSubplot subclasses GeoAxes, so this accepts
        # axes from both old and new cartopy versions.
        assert isinstance(ax, GeoAxes), "ax must be a cartopy GeoAxes or None"
    else:
        figwidth = figwidth_default / 2
        figheight = figwidth * (extents[3] - extents[2]) / (extents[1] - extents[0])
        fig, ax = plt.subplots(1, 1, figsize=(figwidth, figheight),
                               subplot_kw={"projection": ccrs.PlateCarree()})

    vmin = vmax = levels = None
    is_difference = da.attrs["full_name"].split(" ")[0] == "Difference"
    if da.name in da_names_cyclic_valid:
        # Hour values wrap around, so use a cyclic colour map with one
        # discrete level per hour.
        if is_difference:
            cmap = "twilight_shifted"
            levels = np.arange(-12, 13)
        else:
            cmap = cmocean.cm.phase
            levels = np.arange(0, 25)
    elif is_difference:
        cmap = cmocean.cm.balance
    else:
        cmap = "viridis"
        # Scale the colour bar to the data actually shown. A coordinate mask
        # (instead of label slices) works whichever way latitude is ordered.
        in_view = da.where((da.longitude >= extents[0])
                           & (da.longitude <= extents[1])
                           & (da.latitude >= extents[2])
                           & (da.latitude <= extents[3]))
        vmin = float(in_view.min())
        vmax = float(in_view.max())

    ax.set_extent(extents=extents, crs=ccrs.PlateCarree())
    da.plot.pcolormesh(ax=ax, cmap=cmap, transform=ccrs.PlateCarree(),
                       vmin=vmin, vmax=vmax, levels=levels,
                       cbar_kwargs={"label": "{} [{}]".format(
                           da.attrs["abbreviation"], da.attrs["units"])})
    ax.set_title(da.attrs["full_name"])
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines(draw_labels=True, x_inline=False, y_inline=False)

    if ax_input is None:
        fig.tight_layout()
        plt.show()

In [None]:
def create_quiver(da_u, da_v, extents=None, ax=None):
    """Plot a vector field as unit-length arrows coloured by magnitude.

    Parameters
    ----------
    da_u, da_v : xarray.DataArray
        Zonal and meridional components of the same vector parameter, each
        with ``dims == da_dims_valid``. Their ``full_name`` attrs must match
        once "Zonal Component of " / "Meridional Component of " is removed,
        and ``da_u``'s abbreviation must reduce to an entry of ``cf.vectors``.
    extents : list or tuple of 4 numbers, optional
        Map extent as ``[W, E, S, N]`` in degrees. Defaults to the bounds of
        ``da_u``.
    ax : cartopy GeoAxes, optional
        Axes to draw into. If omitted, a new PlateCarree figure is created,
        laid out and shown.

    Arrows are normalised to unit length so direction stays readable
    everywhere. Magnitude (``cf.get_magnitude``) is shown by colour instead.

    Raises
    ------
    AssertionError
        If the inputs, ``extents`` or ``ax`` fail validation.
    """
    from cartopy.mpl.geoaxes import GeoAxes

    # `and` short-circuits, so a non-DataArray input fails the assertion
    # instead of raising AttributeError on `.dims`.
    for label, comp in (("da_u", da_u), ("da_v", da_v)):
        assert isinstance(comp, xr.DataArray) and comp.dims == da_dims_valid, \
            f"{label} must be an xarray.DataArray with {label}.dims == {da_dims_valid}"

    # Strip the component wording so both inputs describe the same quantity.
    # The stripped u-attrs supply the title and colour-bar label.
    attrs_u = copy.deepcopy(da_u.attrs)
    attrs_u["abbreviation"] = attrs_u["abbreviation"].replace("_u", "")
    attrs_u["full_name"] = attrs_u["full_name"].replace("Zonal Component of ", "")
    full_name_v = da_v.attrs["full_name"].replace("Meridional Component of ", "")

    assert attrs_u["full_name"] == full_name_v, \
        ("da_u and da_v must be the zonal and meridional components " +
         "of the same variable")

    # Reduce the abbreviation to a parameter key and check that cf knows it as
    # a vector. NOTE(review): this depends on the abbreviation format produced
    # by cf (text in the last "(...)", before any "^" or "$", "U" -> "WV").
    vector_test = (attrs_u["abbreviation"]
                   .split("(")[-1].split(")")[0].split("^")[0].split("$")[0]
                   .replace("U", "WV")
                   .lower())
    assert vector_test in cf.vectors, \
        "da_u and da_v must be the components of a vector parameter"

    if extents:
        # Short-circuiting also avoids IndexError for too-short sequences.
        assert (isinstance(extents, (list, tuple)) and len(extents) == 4
                and extents[1] > extents[0] and extents[3] > extents[2]), \
            "extents must be a 4 element list or tuple in [W, E, S, N] format or None"
        extents = list(extents)
    else:
        extents = [float(da_u.longitude.min()), float(da_u.longitude.max()),
                   float(da_u.latitude.min()), float(da_u.latitude.max())]

    # Only a figure created here is laid out and shown at the end.
    ax_input = ax
    if ax_input is not None:
        # Older cartopy's GeoAxesSubplot subclasses GeoAxes.
        assert isinstance(ax, GeoAxes), "ax must be a cartopy GeoAxes or None"
    else:
        figwidth = figwidth_default / 2
        figheight = figwidth * (extents[3] - extents[2]) / (extents[1] - extents[0])
        fig, ax = plt.subplots(1, 1, figsize=(figwidth, figheight),
                               subplot_kw={"projection": ccrs.PlateCarree()})

    if attrs_u["full_name"].split(" ")[0] == "Difference":
        cmap = cmocean.cm.tempo
    else:
        cmap = cmocean.cm.speed

    da_mag = xr.DataArray(cf.get_magnitude(da_u, da_v), name="mag")
    da_u_unit = xr.DataArray(da_u / da_mag, name="u_unit")
    da_v_unit = xr.DataArray(da_v / da_mag, name="v_unit")
    ds = xr.merge([da_mag, da_u_unit, da_v_unit])

    ax.set_extent(extents=extents, crs=ccrs.PlateCarree())
    ds.plot.quiver(x="longitude", y="latitude", ax=ax,
                   u="u_unit", v="v_unit",
                   hue="mag", cmap=cmap, transform=ccrs.PlateCarree(),
                   cbar_kwargs={"label": "{} [{}]".format(
                       attrs_u["abbreviation"], attrs_u["units"])})
    ax.set_title(attrs_u["full_name"])
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines(draw_labels=True, x_inline=False, y_inline=False)

    if ax_input is None:
        fig.tight_layout()
        plt.show()

In [None]:
def create_individual_calc_plot(calc_func, region, period_start, period_end, months_subset, param, hour, ax):
    """Placeholder for plotting a single calculated field on ``ax``.

    Not implemented yet. The body previously held only comments, which made
    the cell a syntax error. Open design questions carried over from those notes:

    - set the map extents using ``region``;
    - decide whether vmin and vmax are set here;
    - add an ``output=True/False`` argument, or write the output only if it
      does not exist yet and otherwise continue without terminating.

    Raises
    ------
    NotImplementedError
        Always, until the function is implemented.
    """
    raise NotImplementedError("create_individual_calc_plot is not implemented yet")

In [None]:
def create_individual_comp_plot(
    calc_func, region, period1_start, period1_end, period2_start, period2_end, 
    months_subset, param, hour, output, vmin, vmax, ax_period1, ax_period2, ax_diff
):
    """Placeholder for a period-1 / period-2 / difference comparison plot.

    Intended to draw the field for each period and their difference on
    ``ax_period1``, ``ax_period2`` and ``ax_diff``. Not implemented yet. The
    body was previously empty, which made the cell a syntax error.

    Raises
    ------
    NotImplementedError
        Always, until the function is implemented.
    """
    raise NotImplementedError("create_individual_comp_plot is not implemented yet")

In [None]:
# Load example fields for the plotting helpers: region "wa", period 1
# Jan-1992..Dec-1996 ("calc") and its difference to Jan-2002..Dec-2006 ("diff").
# Each processed file is opened once and the needed variables are taken from it.
_ver = cf.calc_funcs_ver
_calc_tag = "wa_Jan-1992_Dec-1996_all"
_diff_tag = "wa_Jan-1992_Dec-1996_Jan-2002_Dec-2006_all"
_stats_dir = "../data_processed/era5_mdp_clim_stats_given_var_or_dvar"

da_land_elev = xr.open_dataset(
    f"../data_processed/era5_land_elev/{_ver}_calc_global_static_land-elev.nc")["lse"]
da_lai_mean = xr.open_dataset(
    f"../data_processed/glass_mean_clim/{_ver}_calc_{_calc_tag}_glass-mean_avhrr.nc")["mlai"]
da_lai_mean_diff = xr.open_dataset(
    f"../data_processed/glass_mean_clim/{_ver}_diff_{_diff_tag}_glass-mean_avhrr.nc")["mlai"]

ds_mslp_stats = xr.open_dataset(f"{_stats_dir}/{_ver}_calc_{_calc_tag}_era5-mdp_mslp_stats.nc")
ds_mslp_stats_diff = xr.open_dataset(f"{_stats_dir}/{_ver}_diff_{_diff_tag}_era5-mdp_mslp_stats.nc")
ds_wv100_stats = xr.open_dataset(f"{_stats_dir}/{_ver}_calc_{_calc_tag}_era5-mdp_wv100_stats.nc")
ds_wv100_stats_diff = xr.open_dataset(f"{_stats_dir}/{_ver}_diff_{_diff_tag}_era5-mdp_wv100_stats.nc")

da_mslp_mean = ds_mslp_stats["mean"]
da_mslp_mean_diff = ds_mslp_stats_diff["mean"]
da_mslp_hour_max = ds_mslp_stats["hour_max"]
da_mslp_hour_max_diff = ds_mslp_stats_diff["hour_max"]
da_wv100_mean_u = ds_wv100_stats["mean_u"]
da_wv100_mean_v = ds_wv100_stats["mean_v"]
da_wv100_mean_u_diff = ds_wv100_stats_diff["mean_u"]
da_wv100_mean_v_diff = ds_wv100_stats_diff["mean_v"]

In [None]:
create_pcolormesh(da_land_elev, cf.regions["wa"]["extent"])

In [None]:
create_pcolormesh(da_lai_mean)

In [None]:
create_pcolormesh(da_lai_mean_diff)

In [None]:
create_pcolormesh(da_mslp_mean)

In [None]:
create_pcolormesh(da_mslp_mean_diff)

In [None]:
create_pcolormesh(da_mslp_hour_max)

In [None]:
create_pcolormesh(da_mslp_hour_max_diff)

In [None]:
create_quiver(da_wv100_mean_u, da_wv100_mean_v)

In [None]:
create_quiver(da_wv100_mean_u_diff, da_wv100_mean_v_diff)

In [None]:
# from dask.distributed import Client
# client = Client()
# client

In [None]:
# client.close()

In [None]:
import ipywidgets as ipw
import hvplot.xarray # noqa
import hvplot.pandas # noqa
import panel as pn
import pandas as pd
import panel.widgets as pnw
import xarray as xr

In [None]:
%%time
# Smoke tests of the cf calculation functions. Each call presumably writes a
# processed netCDF file; those files are opened below as test1 ... test5.
# The months are passed unsorted; the file opened as test1 is tagged "6-7-8".
cf.calc_glass_mean_clim("wa", "Jun-2000", "Aug-2005", [8, 6,7])

In [None]:
%%time
cf.calc_era5_mdp_clim_given_var_or_dvar("sa", "Dec-1994", "Feb-2000", "djf", "wv10")

In [None]:
%%time
cf.calc_era5_mdp_clim_stats_given_var_or_dvar("sa", "Dec-1994", "Feb-2000", "jja", "dwv100")

In [None]:
%%time
cf.calc_era5_wsd_clim("wa", "Jun-2000", "Aug-2005", [1,3,5,7,9,11])

In [None]:
%%time
# NOTE(review): period 2 (Jan-1995..Jan-2004) is not the same length as
# period 1 (Jan-1985..Dec-1994) - confirm the "Jan-2004" end month is intended.
cf.calc_diff(cf.calc_era5_mdp_clim_given_var_or_dvar, "ca", "Jan-1985", "Dec-1994", "Jan-1995", "Jan-2004", "all", "dnac")

In [None]:
# Inspect the calc_glass_mean_clim output (region "wa", months 6-7-8).
test1 = xr.open_dataset(f"../data_processed/glass_mean_clim/{cf.calc_funcs_ver}_calc_wa_Jun-2000_Aug-2005_6-7-8_glass-mean_avhrr.nc")
test1

In [None]:
test1["mlai"].plot()

In [None]:
test1["mfapar"].plot()

In [None]:
# Inspect the calc_era5_mdp_clim_given_var_or_dvar output (region "sa", djf, wv10).
test2 = xr.open_dataset(f"../data_processed/era5_mdp_clim_given_var_or_dvar/{cf.calc_funcs_ver}_calc_sa_Dec-1994_Feb-2000_djf_era5-mdp_wv10.nc")
test2

In [None]:
test2["u10"].sel(hour=21).plot()

In [None]:
test2["v10"].sel(hour=21).plot()

In [None]:
# Inspect the calc_diff output for dnac (region "ca").
test3 = xr.open_dataset(f"../data_processed/era5_mdp_clim_given_var_or_dvar/{cf.calc_funcs_ver}_diff_ca_Jan-1985_Dec-1994_Jan-1995_Jan-2004_all_era5-mdp_dnac.nc")
test3

In [None]:
test3["dnse"].sel(hour=21).plot()

In [None]:
test3["dvidmf"].sel(hour=21).plot()

In [None]:
test3["dvidcfwf"].sel(hour=21).plot()

In [None]:
test3["dvidclwf"].sel(hour=21).plot()

In [None]:
test3["dtcwv"].sel(hour=21).plot()

In [None]:
test3["dnac"].sel(hour=21).plot()

In [None]:
# Inspect the calc_era5_mdp_clim_stats_given_var_or_dvar output (region "sa", jja, dwv100).
test4 = xr.open_dataset(f"../data_processed/era5_mdp_clim_stats_given_var_or_dvar/{cf.calc_funcs_ver}_calc_sa_Dec-1994_Feb-2000_jja_era5-mdp_dwv100_stats.nc")
test4

In [None]:
test4["hour_max"].plot()

In [None]:
test4["hour_min"].plot()

In [None]:
test4["max_u"].plot()

In [None]:
test4["max_v"].plot()

In [None]:
test4["min_u"].plot()

In [None]:
test4["min_v"].plot()

In [None]:
test4["mean_u"].plot()

In [None]:
test4["mean_v"].plot()

In [None]:
test4["range"].plot()

In [None]:
# test4["max"].plot()

In [None]:
# test4["min"].plot()

In [None]:
# test4["mean"].plot()

In [None]:
# Inspect the calc_era5_wsd_clim output (region "wa", months 1-3-5-7-9-11).
test5 = xr.open_dataset(f"../data_processed/era5_wsd_clim/{cf.calc_funcs_ver}_calc_wa_Jun-2000_Aug-2005_1-3-5-7-9-11_era5-wsd.nc")
test5

In [None]:
test5["ws100_mean"].plot()

In [None]:
test5["ws100_std"].plot()

In [None]:
test5["c100"].plot()

In [None]:
test5["k100"].plot()

In [None]:
test5["eroe100"].plot()

In [None]:
test5["tgcf100"].plot()

In [None]:
# Grid point(s) where eroe100 reaches its maximum value.
test5["eroe100"].where(test5["eroe100"]==test5["eroe100"].max(), drop = True).squeeze()

In [None]:
%%time
# Difference files for region "wa" (Jan-1985..Dec-1990 vs Jan-2005..Dec-2010),
# one per calculation type; inspected below as test6 ... test9.
cf.calc_diff(cf.calc_glass_mean_clim, "wa", "Jan-1985", "Dec-1990", "Jan-2005", "Dec-2010", "all")

In [None]:
%%time
cf.calc_diff(cf.calc_era5_mdp_clim_given_var_or_dvar, "wa", "Jan-1985", "Dec-1990", "Jan-2005", "Dec-2010", "all", "nac")

In [None]:
%%time
cf.calc_diff(cf.calc_era5_mdp_clim_stats_given_var_or_dvar, "wa", "Jan-1985", "Dec-1990", "Jan-2005", "Dec-2010", "all", "wv10")

In [None]:
%%time
cf.calc_diff(cf.calc_era5_wsd_clim, "wa", "Jan-1985", "Dec-1990", "Jan-2005", "Dec-2010", "all")

In [None]:
# Inspect the GLASS mean difference file produced above.
test6 = xr.open_dataset(f"../data_processed/glass_mean_clim/{cf.calc_funcs_ver}_diff_wa_Jan-1985_Dec-1990_Jan-2005_Dec-2010_all_glass-mean_avhrr.nc")
test6

In [None]:
test6["mlai"].plot()

In [None]:
test6["mfapar"].plot()

In [None]:
# Inspect the nac difference file produced above.
test7 = xr.open_dataset(f"../data_processed/era5_mdp_clim_given_var_or_dvar/{cf.calc_funcs_ver}_diff_wa_Jan-1985_Dec-1990_Jan-2005_Dec-2010_all_era5-mdp_nac.nc")
test7

In [None]:
test7["nse"].sel(hour=21).plot()

In [None]:
test7["vidmf"].sel(hour=21).plot()

In [None]:
test7["vidcfwf"].sel(hour=21).plot()

In [None]:
test7["vidclwf"].sel(hour=21).plot()

In [None]:
test7["tcwv"].sel(hour=21).plot()

In [None]:
test7["nac"].sel(hour=21).plot()

In [None]:
# Inspect the wv10 stats difference file produced above.
test8 = xr.open_dataset(f"../data_processed/era5_mdp_clim_stats_given_var_or_dvar/{cf.calc_funcs_ver}_diff_wa_Jan-1985_Dec-1990_Jan-2005_Dec-2010_all_era5-mdp_wv10_stats.nc")
test8

In [None]:
test8["hour_max"].plot()

In [None]:
test8["hour_min"].plot()

In [None]:
test8["max_u"].plot()

In [None]:
test8["max_v"].plot()

In [None]:
test8["min_u"].plot()

In [None]:
test8["min_v"].plot()

In [None]:
test8["mean_u"].plot()

In [None]:
test8["mean_v"].plot()

In [None]:
test8["range"].plot()

In [None]:
# Inspect the wind speed distribution (wsd) difference file produced above.
test9 = xr.open_dataset(f"../data_processed/era5_wsd_clim/{cf.calc_funcs_ver}_diff_wa_Jan-1985_Dec-1990_Jan-2005_Dec-2010_all_era5-wsd.nc")
test9

In [None]:
test9["ws100_mean"].plot()

In [None]:
test9["ws100_std"].plot()

In [None]:
test9["c100"].plot()

In [None]:
test9["k100"].plot()

In [None]:
test9["eroe100"].plot()

In [None]:
test9["tgcf100"].plot()

In [None]:
# Grid point(s) where the eroe100 difference reaches its maximum value.
test9["eroe100"].where(test9["eroe100"]==test9["eroe100"].max(), drop = True).squeeze()

In [None]:
%%time
# Static land-surface elevation file; inspected below as test10.
cf.calc_era5_land_elev()

In [None]:
test10 = xr.open_dataset(f"../data_processed/era5_land_elev/{cf.calc_funcs_ver}_calc_global_static_land-elev.nc")
test10

In [None]:
test10["lse"].plot()

In [None]:
%%time
# Short rolling-average run (1984-1987, 7-year window); inspected as test11.
cf.calc_glass_rolling_avg_of_annual_diff("wa", 1984, 1987, 7)

In [None]:
test11 = xr.open_dataset(f"../data_processed/glass_rolling_avg_of_annual_diff/{cf.calc_funcs_ver}_calc_wa_1984_1987_7-year_glass-rolling-diff_pref-avhrr.nc")
test11

In [None]:
test11["mlai"].isel(year=0).plot()

In [None]:
test11["mfapar"].isel(year=0).plot()

In [None]:
%%time
# Full 5-year-window run; inspected as test12.
cf.calc_glass_rolling_avg_of_annual_diff("wa", 1983, 2019, 5)

In [None]:
%%time
# Full 7-year-window run; inspected as test13.
cf.calc_glass_rolling_avg_of_annual_diff("wa", 1984, 2018, 7)

In [None]:
test12 = xr.open_dataset(f"../data_processed/glass_rolling_avg_of_annual_diff/{cf.calc_funcs_ver}_calc_wa_1983_2019_5-year_glass-rolling-diff_pref-avhrr.nc")
test12

In [None]:
# Interactive year slider (hvplot .interactive + panel DiscreteSlider); the
# displayed map depends on widget state, not on code alone.
test12["mlai"].interactive.sel(year=pnw.DiscreteSlider).plot(cmap = "RdBu", vmin = -1, vmax = 1)

In [None]:
test13 = xr.open_dataset(f"../data_processed/glass_rolling_avg_of_annual_diff/{cf.calc_funcs_ver}_calc_wa_1984_2018_7-year_glass-rolling-diff_pref-avhrr.nc")
test13

In [None]:
test13["mlai"].interactive.sel(year=pnw.DiscreteSlider).plot(cmap = "RdBu", vmin = -1, vmax = 1)

In [None]:
%%time
# Presumably creates every calc file for each period and then every diff file
# between them. The Jan-1992..Dec-1996 / Jan-2002..Dec-2006 files are the ones
# loaded for the plotting helpers near the top of the notebook.
cf.create_all_possible_calc_data_files("wa", "Jan-1992", "Dec-1996", "all")

In [None]:
%%time
cf.create_all_possible_calc_data_files("wa", "Jan-2002", "Dec-2006", "all")

In [None]:
%%time
cf.create_all_possible_diff_data_files("wa", "Jan-1992", "Dec-1996", "Jan-2002", "Dec-2006", "all")

In [None]:
test14 = xr.open_dataset(f"../data_processed/glass_mean_clim/{cf.calc_funcs_ver}_diff_wa_Jan-1992_Dec-1996_Jan-2002_Dec-2006_all_glass-mean_avhrr.nc")
test14["mlai"].plot()

In [None]:
test15 = xr.open_dataset(f"../data_processed/era5_wsd_clim/{cf.calc_funcs_ver}_diff_wa_Jan-1992_Dec-1996_Jan-2002_Dec-2006_all_era5-wsd.nc")
test15["ws10_mean"].plot()

In [None]:
%%time
cf.calc_diff(cf.calc_era5_mdp_clim_stats_given_var_or_dvar, "wa", "Jan-1992", "Dec-1996", "Jan-2002", "Dec-2006", "all", "wv100")

In [None]:
%%time
cf.calc_diff(cf.calc_era5_mdp_clim_stats_given_var_or_dvar, "wa", "Jan-1992", "Dec-1996", "Jan-2002", "Dec-2006", "all", "ws100")

In [None]:
# Compare the "range" stat of the vector (wv100) and scalar (ws100) diffs.
test16 = xr.open_dataset(f"../data_processed/era5_mdp_clim_stats_given_var_or_dvar/{cf.calc_funcs_ver}_diff_wa_Jan-1992_Dec-1996_Jan-2002_Dec-2006_all_era5-mdp_wv100_stats.nc")
test16["range"].plot()

In [None]:
test17 = xr.open_dataset(f"../data_processed/era5_mdp_clim_stats_given_var_or_dvar/{cf.calc_funcs_ver}_diff_wa_Jan-1992_Dec-1996_Jan-2002_Dec-2006_all_era5-mdp_ws100_stats.nc")
test17["range"].plot()

In [None]:
test16

In [None]:
test17

In [None]:
def create_pcolormesh(da, extents=None, ax=None):
    """Plot a non-cyclic (latitude, longitude) DataArray as a filled map.

    Hour-of-day fields (names in ``da_names_cyclic_valid``) are rejected here.
    Plot those with ``create_pcolormesh_cyclic`` instead.

    Parameters
    ----------
    da : xarray.DataArray
        Field to plot. Must satisfy ``da.dims == da_dims_valid`` and carry the
        attrs ``"full_name"`` (title), ``"abbreviation"`` and ``"units"``
        (colour-bar label).
    extents : list or tuple of 4 numbers, optional
        Map extent as ``[W, E, S, N]`` in degrees. Defaults to the longitude
        and latitude bounds of ``da``.
    ax : cartopy GeoAxes, optional
        Axes to draw into. If omitted, a PlateCarree figure sized to the
        aspect ratio of ``extents`` is created, laid out and shown.

    Fields whose ``full_name`` starts with "Difference" use a diverging
    colour map. Others use viridis limited to the data range inside ``extents``.

    Raises
    ------
    AssertionError
        If ``da``, ``extents`` or ``ax`` fail validation.
    """
    from cartopy.mpl.geoaxes import GeoAxes

    # `and` short-circuits, so a non-DataArray input fails the assertion
    # instead of raising AttributeError on `da.dims`.
    assert isinstance(da, xr.DataArray) and da.dims == da_dims_valid, \
        f"da must be an xarray.DataArray with da.dims == {da_dims_valid}"
    assert da.name not in da_names_cyclic_valid, \
        f"da.name must not be one of: {da_names_cyclic_valid}"

    if extents:
        # Short-circuiting also avoids IndexError for too-short sequences.
        assert (isinstance(extents, (list, tuple)) and len(extents) == 4
                and extents[1] > extents[0] and extents[3] > extents[2]), \
            "extents must be a 4 element list or tuple in [W, E, S, N] format or None"
        extents = list(extents)
    else:
        extents = [float(da.longitude.min()), float(da.longitude.max()),
                   float(da.latitude.min()), float(da.latitude.max())]

    # Only a figure created here is laid out and shown at the end; a
    # caller-supplied ax belongs to the caller's figure.
    ax_input = ax
    if ax_input is not None:
        # Older cartopy's GeoAxesSubplot subclasses GeoAxes.
        assert isinstance(ax, GeoAxes), "ax must be a cartopy GeoAxes or None"
    else:
        figwidth = figwidth_default / 2
        figheight = figwidth * (extents[3] - extents[2]) / (extents[1] - extents[0])
        fig, ax = plt.subplots(1, 1, figsize=(figwidth, figheight),
                               subplot_kw={"projection": ccrs.PlateCarree()})

    if da.attrs["full_name"].split(" ")[0] == "Difference":
        cmap = cmocean.cm.balance
        vmin = None
        vmax = None
    else:
        cmap = "viridis"
        # Scale the colour bar to the data actually shown. A coordinate mask
        # (instead of label slices) works whichever way latitude is ordered.
        in_view = da.where((da.longitude >= extents[0])
                           & (da.longitude <= extents[1])
                           & (da.latitude >= extents[2])
                           & (da.latitude <= extents[3]))
        vmin = float(in_view.min())
        vmax = float(in_view.max())

    ax.set_extent(extents=extents, crs=ccrs.PlateCarree())
    da.plot.pcolormesh(ax=ax, cmap=cmap, transform=ccrs.PlateCarree(),
                       vmin=vmin, vmax=vmax,
                       cbar_kwargs={"label": "{} [{}]".format(
                           da.attrs["abbreviation"], da.attrs["units"])})
    ax.set_title(da.attrs["full_name"])
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines(draw_labels=True, x_inline=False, y_inline=False)

    if ax_input is None:
        fig.tight_layout()
        plt.show()

In [None]:
def create_pcolormesh_cyclic(da, extents=None, ax=None):
    """Plot a cyclic hour-of-day field (e.g. hour_max) as a filled map.

    Parameters
    ----------
    da : xarray.DataArray
        Field to plot. ``da.name`` must be in ``da_names_cyclic_valid``, with
        ``da.dims == da_dims_valid`` and the attrs ``"full_name"``,
        ``"abbreviation"`` and ``"units"``.
    extents : list or tuple of 4 numbers, optional
        Map extent as ``[W, E, S, N]`` in degrees. Defaults to the longitude
        and latitude bounds of ``da``.
    ax : cartopy GeoAxes, optional
        Axes to draw into. If omitted, a PlateCarree figure is created, laid
        out and shown.

    A cyclic colour map with one level per hour is used: -12..12 for
    "Difference" fields and 0..24 otherwise.

    Raises
    ------
    AssertionError
        If ``da``, ``extents`` or ``ax`` fail validation.
    """
    from cartopy.mpl.geoaxes import GeoAxes

    # `and` short-circuits, so a non-DataArray input fails the assertion
    # instead of raising AttributeError on `da.dims`.
    assert isinstance(da, xr.DataArray) and da.dims == da_dims_valid, \
        f"da must be an xarray.DataArray with da.dims == {da_dims_valid}"
    assert da.name in da_names_cyclic_valid, \
        f"da.name must be one of: {da_names_cyclic_valid}"

    if extents:
        # Short-circuiting also avoids IndexError for too-short sequences.
        assert (isinstance(extents, (list, tuple)) and len(extents) == 4
                and extents[1] > extents[0] and extents[3] > extents[2]), \
            "extents must be a 4 element list or tuple in [W, E, S, N] format or None"
        extents = list(extents)
    else:
        extents = [float(da.longitude.min()), float(da.longitude.max()),
                   float(da.latitude.min()), float(da.latitude.max())]

    # Only a figure created here is laid out and shown at the end.
    ax_input = ax
    if ax_input is not None:
        # Older cartopy's GeoAxesSubplot subclasses GeoAxes.
        assert isinstance(ax, GeoAxes), "ax must be a cartopy GeoAxes or None"
    else:
        figwidth = figwidth_default / 2
        figheight = figwidth * (extents[3] - extents[2]) / (extents[1] - extents[0])
        fig, ax = plt.subplots(1, 1, figsize=(figwidth, figheight),
                               subplot_kw={"projection": ccrs.PlateCarree()})

    if da.attrs["full_name"].split(" ")[0] == "Difference":
        cmap = "twilight_shifted"
        levels = np.arange(-12, 13)
    else:
        cmap = cmocean.cm.phase
        levels = np.arange(0, 25)

    ax.set_extent(extents=extents, crs=ccrs.PlateCarree())
    da.plot.pcolormesh(ax=ax, cmap=cmap, transform=ccrs.PlateCarree(),
                       levels=levels,
                       cbar_kwargs={"label": "{} [{}]".format(
                           da.attrs["abbreviation"], da.attrs["units"])})
    ax.set_title(da.attrs["full_name"])
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines(draw_labels=True, x_inline=False, y_inline=False)

    if ax_input is None:
        fig.tight_layout()
        plt.show()

In [None]:
def create_quiver(da_u, da_v, extents=None):
    """Create and show a standalone PlateCarree quiver map of ``(da_u, da_v)``.

    Parameters
    ----------
    da_u, da_v : xarray.DataArray
        Zonal and meridional components, passed straight to ``add_quiver``.
    extents : sequence of 4 numbers, optional
        ``[W, E, S, N]`` map extent in degrees, forwarded to ``add_quiver``.
        The caller's value is used; the "wa" region extent is no longer
        hard-coded.
    """
    # Re-applied per call (as before) so this figure always renders text
    # with LaTeX, even if the global setting was changed after the config cell.
    plt.rcParams['text.usetex'] = True
    fig, ax = plt.subplots(1, 1, figsize=(10, 5),
                           subplot_kw={"projection": ccrs.PlateCarree()})
    add_quiver(da_u=da_u, da_v=da_v, ax=ax, extents=extents)
    fig.tight_layout()
    plt.show()

In [None]:
def add_quiver(da_u, da_v, ax, extents=None):
    """Draw a vector field on ``ax`` as unit-length arrows coloured by magnitude.

    Parameters
    ----------
    da_u, da_v : xarray.DataArray
        Zonal and meridional components with ``longitude``/``latitude``
        coordinates. The title, colour-bar label and colour map come from
        ``da_u.attrs`` ("full_name", "abbreviation", "units") with the
        zonal-component wording removed.
    ax : cartopy GeoAxes
        Axes to draw into.
    extents : sequence of 4 numbers, optional
        ``[W, E, S, N]`` map extent in degrees (PlateCarree). ``None`` leaves
        the axes' default extent unchanged.
    """
    # Strip the component wording so labels describe the vector quantity.
    attrs = copy.deepcopy(da_u.attrs)
    attrs["abbreviation"] = attrs["abbreviation"].replace("_u", "")
    attrs["full_name"] = attrs["full_name"].replace("Zonal Component of ", "")
    if attrs["full_name"].split(" ")[0] == "Difference":
        cmap = cmocean.cm.tempo
    else:
        cmap = cmocean.cm.speed

    # Normalise arrows to unit length so direction is readable everywhere;
    # the magnitude is shown by colour (hue) instead.
    da_mag = xr.DataArray(cf.get_magnitude(da_u, da_v), name="mag")
    da_u_unit = xr.DataArray(da_u / da_mag, name="u_unit")
    da_v_unit = xr.DataArray(da_v / da_mag, name="v_unit")
    ds = xr.merge([da_mag, da_u_unit, da_v_unit])

    # Previously accepted but never applied.
    if extents is not None:
        ax.set_extent(extents=extents, crs=ccrs.PlateCarree())
    ds.plot.quiver(x="longitude", y="latitude", ax=ax,
                   u="u_unit", v="v_unit",
                   hue="mag", cmap=cmap, transform=ccrs.PlateCarree(),
                   cbar_kwargs={"label": "{} [{}]".format(
                       attrs["abbreviation"], attrs["units"])})
    ax.set_title(attrs["full_name"])
    ax.add_feature(cfeature.COASTLINE)
    ax.gridlines(draw_labels=True, x_inline=False, y_inline=False)