In [None]:
import numpy as np
import xarray as xr
import os
import argparse
import glob
import metpy
import pathlib


def get_mse(data):
    """Compute near-surface moist static energy (MSE).

    Specific humidity is first derived from the dewpoint and surface
    pressure, then combined with temperature to give MSE. The height is
    fixed at 0 m, so the geopotential term of the MSE is zero.

    Parameters
    ----------
    data : mapping of variables (e.g. an ``xarray.Dataset``)
        Must provide ``"d2m"`` (tagged here as kelvin), ``"sp"`` (tagged
        here as pascal) and ``"t2m"`` (tagged here as kelvin).

    Returns
    -------
    Moist static energy as returned by ``metpy.calc.moist_static_energy``
    (MetPy documents the result in kJ/kg -- confirm against the installed
    version).
    """

    ## import the submodules explicitly: a bare ``import metpy`` does not
    ## guarantee that ``metpy.calc`` / ``metpy.units`` are loaded
    import metpy.calc as mpcalc
    from metpy.units import units

    ## specific humidity
    q = mpcalc.specific_humidity_from_dewpoint(
        dewpoint=data["d2m"] * units.K,
        pressure=data["sp"] * units.Pa,
    )

    ## then, MSE (evaluated at zero height)
    mse = mpcalc.moist_static_energy(
        height=0.0 * units.meters,
        temperature=data["t2m"] * units.kelvin,
        specific_humidity=q,
    )

    return mse


def trim_to_PNW(data):
    """Subset ``data`` to the Pacific Northwest box
    of Bartusek et al. (2021).

    The box is selected by coordinate label on the ``latitude`` and
    ``longitude`` dimensions: latitude from 60 down to 40 and longitude
    from 230 to 250.
    """

    ## bounding box, keyed by coordinate name
    ## NOTE(review): the 60 -> 40 slice presumably relies on latitude being
    ## stored in descending order, and longitude on a 0-360 convention
    pnw_box = {
        "latitude": slice(60, 40),
        "longitude": slice(230, 250),
    }

    return data.sel(**pnw_box)


def trim(data, *, only_0utc=False):
    """Trim a dataset to the Pacific Northwest region.

    Passed as the ``preprocess`` callback to ``xr.open_mfdataset`` in
    ``load_data``, so it is applied to every file before they are combined.

    Parameters
    ----------
    data : object accepted by ``trim_to_PNW``
        Dataset to subset; needs a ``time`` coordinate only when
        ``only_0utc`` is true.
    only_0utc : bool, keyword-only, default False
        If true, additionally keep only the time steps whose hour is 0
        (0:00 UTC; 4 PM local time per the original author's note).
        The default keeps every time step, which is what the function
        has always done.

    Returns
    -------
    The spatially (and optionally temporally) trimmed dataset.
    """

    trimmed = trim_to_PNW(data)

    if only_0utc:
        ## only get 0:00 UTC (4 PM local time)
        is_0utc = trimmed.time.dt.hour == 0
        trimmed = trimmed.isel(time=is_0utc)

    return trimmed


def landarea_weighted_mean(data, lsm=None):
    """Average ``data`` over ``latitude`` and ``longitude`` with
    area (and optionally land-fraction) weights.

    Each grid cell is weighted by the cosine of its latitude. When a
    land-sea mask ``lsm`` is supplied, the weights are additionally
    multiplied by it.
    """

    ## start from the cos(lat) weights
    weights = np.cos(np.deg2rad(data.latitude))

    ## fold in the land fraction when a mask was given
    if lsm is not None:
        weights = weights * lsm

    weighted_data = data.weighted(weights=weights)
    return weighted_data.mean(["latitude", "longitude"])


def load_lsm_from_server():
    """Load land-sea mask (constant in time) from WHOI server.

    Opens the 2022 monthly-mean land-sea-mask file and keeps a single
    time step (index 6), dropping the resulting scalar ``time``
    coordinate. The mask is returned on its full grid; callers that
    need a regional subset must trim it themselves.

    Returns
    -------
    xarray.DataArray
        The land-sea mask without a time dimension.
    """

    ## filepath
    lsm_path = "/vortexfs1/share/cmip6/data/era5/reanalysis/single-levels/monthly-means/land_sea_mask/2022_land_sea_mask.nc"

    ## load the data; the mask does not vary in time, so one step suffices
    ## NOTE(review): index 6 is presumably July of the monthly file -- confirm
    return xr.open_dataarray(lsm_path).isel(time=6, drop=True)


def load_data(save_fp=None):
    """Load June-August ERA5 surface data (plus derived MSE) for the
    Pacific Northwest.

    If a previously saved file exists at ``save_fp`` it is opened and
    returned. Otherwise the 6-hourly 2m temperature, 2m dewpoint and
    surface pressure files are read from the shared ERA5 archive, trimmed
    with ``trim``, loaded into memory, augmented with an ``"mse"``
    variable from ``get_mse``, written to ``save_fp`` and returned.

    Parameters
    ----------
    save_fp : str or pathlib.Path, optional
        Location of the cached NetCDF file. Defaults to
        ``$DATA_FP/whoi_data.nc``.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    KeyError
        If ``save_fp`` is not given and the ``DATA_FP`` environment
        variable is not set.
    """

    ## set default save_fp if not specified; accept plain strings as well
    if save_fp is None:
        save_fp = pathlib.Path(os.environ["DATA_FP"]) / "whoi_data.nc"
    else:
        save_fp = pathlib.Path(save_fp)

    ## reuse the cached file if it exists
    if save_fp.is_file():
        return xr.open_dataset(save_fp)

    ## get filepaths: every June/July/August file of each variable
    ## NOTE(review): the pattern assumes file names encode the month as
    ## "-MM"; it would also match e.g. a "-06" day field -- confirm naming
    era5_path = pathlib.Path(
        "/vortexfs1/share/cmip6/data/era5/reanalysis/single-levels/6hr"
    )
    varnames = ("2m_temperature", "2m_dewpoint_temperature", "surface_pressure")
    months = (6, 7, 8)
    paths = [
        path
        for varname in varnames
        for month in months
        for path in era5_path.glob(f"{varname}/*-{month:02d}*.nc")
    ]

    ## Open data, trimming each file before the files are combined
    data = xr.open_mfdataset(paths, preprocess=trim)

    ## Load to memory and compute MSE
    data.load()
    data["mse"] = get_mse(data)

    ## save to file so later calls can skip the archive read
    data.to_netcdf(save_fp)

    return data


def load_data_spatial_avg(data, save_fp=None):
    """Get the land-area-weighted regional mean of ``data``.

    If a previously saved file exists at ``save_fp`` it is opened and
    returned (``data`` is then not used). Otherwise the land-sea mask is
    loaded and trimmed to the Pacific Northwest, the weighted mean over
    latitude/longitude is computed, and the result is written to
    ``save_fp``.

    Parameters
    ----------
    data : xarray.Dataset
        Gridded data to average, e.g. the output of ``load_data``.
    save_fp : str or pathlib.Path, optional
        Location of the cached NetCDF file. Defaults to
        ``$DATA_FP/whoi_data_avg.nc``.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    KeyError
        If ``save_fp`` is not given and the ``DATA_FP`` environment
        variable is not set.
    """

    ## set default save_fp if not specified; accept plain strings as well
    if save_fp is None:
        save_fp = pathlib.Path(os.environ["DATA_FP"]) / "whoi_data_avg.nc"
    else:
        save_fp = pathlib.Path(save_fp)

    ## reuse the cached file if it exists
    if save_fp.is_file():
        return xr.open_dataset(save_fp)

    ## load lsm on the same regional grid as the data
    lsm = trim_to_PNW(load_lsm_from_server())

    ## compute spatial avg
    data_avg = landarea_weighted_mean(data, lsm=lsm)

    ## save to file so later calls can skip the computation
    data_avg.to_netcdf(save_fp)

    return data_avg

## Load the data

In [None]:
# Gridded Pacific Northwest dataset (read from the on-disk cache when present)
data = load_data()
# Land-area-weighted regional mean of that dataset (also cached on disk)
data_avg = load_data_spatial_avg(data)