## ERA5 Compute Fluxes
This notebook loads (via GCP) the ERA5 dataset at $1.5^\circ$ resolution and subsamples it in time to 4x daily (every 6th hourly field).

It then runs the AD99 offline parameterization over the data to predict both the zonal and meridional momentum fluxes before saving the results to disk.

## ERA5 Compute Fluxes (2014 only)
This notebook loads (via GCP) the ERA5 dataset at $1.5^\circ$ resolution and subsamples it in time to 4x daily (every 6th hourly field).

It then runs the AD99 offline parameterization over the data to predict both the zonal and meridional momentum fluxes before saving it to the drive.

This resembles the other ERA5 notebook; however, it only uses a single year (2014), which overlaps with the greatest density of Loon datasets. As there is very little interannual variability in our results, the resulting dataset is very similar to the full 2014-2021 dataset but substantially smaller in size. The netCDF file generated by this notebook should need no more than approximately 20GB of storage, versus >200GB for the full dataset.



## dask/SLURM Setup
This is optional and configured for the Sherlock HPC at Stanford.

In [1]:
from dask_jobqueue import SLURMCluster
# Per-job worker layout: each SLURM job gets NCORES cores split across NPROCESS
# dask worker processes. NCORESPERPROCESS is the resulting ratio; it is kept for
# reference and is not passed to the cluster below.
NCORES = 8
NPROCESS = 8
NCORESPERPROCESS = NCORES//NPROCESS

# Pin jobs to one node class; mixing node gens seems to cause ib0 issues.
# NOTE(review): the previous comment called SH4 nodes the fastest, yet this
# constraint selects the SH3 classes -- confirm which generation is intended.
constraints = ['-C \"CLASS:SH3_CBASE|CLASS:SH3_CPERF\"']

cluster = SLURMCluster(
    queue='serc',
    cores=NCORES,
    processes=NPROCESS,
    memory='96GiB',
    walltime='06:00:00',
    job_extra_directives=constraints,
    log_directory='/scratch/users/robcking/dask_worker_logs',
)
cluster.scale(jobs=25)  # roughly but tune to scheduler
cluster




0,1
Dashboard: http://10.19.14.21:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.19.14.21:34018,Workers: 0
Dashboard: http://10.19.14.21:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [2]:
from dask.distributed import Client 
# Connect a distributed client to the SLURM-backed cluster created above so that
# subsequent dask computations are submitted to its workers.
client = Client(cluster)
client  # last expression: render the client summary in the notebook

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.19.14.21:8787/status,

0,1
Dashboard: http://10.19.14.21:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.19.14.21:34018,Workers: 0
Dashboard: http://10.19.14.21:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Load Data

In [3]:
import xarray as xr 
import matplotlib.pyplot as plt 
import numpy as np 
import dask.array as da
import cartopy.crs as ccrs 

# Public ARCO-ERA5 Zarr store on GCS (hourly, 240x121 grid per the store name).
PATH_COARSE = 'gs://gcp-public-data-arco-era5/ar/1959-2022-1h-240x121_equiangular_with_poles_conservative.zarr'
# Fields needed downstream: winds, temperature, and geopotential (for height).
variables = ['u_component_of_wind','v_component_of_wind','temperature','geopotential']

# chunks={} opens the store lazily with dask using the store's own chunking.
ds = xr.open_zarr(PATH_COARSE,chunks={})

# Keep every 6th hourly field (4x daily) for 2014 only.
# Label-based slices are end-inclusive and a date-only string matches the whole
# day, so the stop must be the last day of 2014: a stop of '2015-01-01' would
# also pull in the four 6-hourly fields of 2015-01-01.
ds_4xdaily = ds.sel(time=slice('2014-01-01','2014-12-31',6))
# One time step per chunk so each dask block is a single snapshot.
ds_4xdaily = ds_4xdaily.chunk({'time':1})
ds_4xdaily_reduce = ds_4xdaily[variables]


In [4]:
import xesmf as xe 

# --- Target grid: 128 evenly spaced longitudes x 64 Gaussian latitudes ---
longitudes = np.linspace(0,360,128,endpoint=False)

N = 64  # Number of latitude points
# Gauss-Legendre nodes x lie in [-1, 1] and w are the matching weights;
# arcsin maps the nodes to latitude in radians, which is then converted to degrees.
x, w = np.polynomial.legendre.leggauss(N)
latitudes_radians = np.arcsin(x)
latitudes = np.degrees(latitudes_radians)

# Dataset holding only the destination coordinates; it is used as the regrid
# target and then overwritten with the regridded data below.
ds_4xdaily_regrid = xr.Dataset(
    {
        'longitude': (['longitude'], longitudes, {"units": "degrees_east"}),
        'latitude': (['latitude'], latitudes, {"units": "degrees_north"}),
    }
)

# Bilinear interpolation, treating longitude as periodic.
regridder = xe.Regridder(
    ds_4xdaily_reduce,
    ds_4xdaily_regrid,
    'bilinear',
    periodic=True,
)
# Apply the regridder and keep the result in distributed memory.
ds_4xdaily_regrid = regridder(ds_4xdaily_reduce).persist()

In [5]:
# Reverse the order of the `level` axis and put `level` last so that each
# (time, longitude, latitude) point owns one contiguous vertical column.
# NOTE: this reassigns the same name, so re-running the cell flips `level` again.
ds_4xdaily_regrid = (
    ds_4xdaily_regrid
    .isel(level=slice(None, None, -1))
    .transpose("time", "longitude", "latitude", "level")
)

In [6]:
from ad99py import GRAV,C_P,BFLIM,R_DRY
from ad99py.variables import bouyancy_freq_squared,density

# Pull the underlying arrays out of the regridded dataset; dims are
# (time, longitude, latitude, level).
us = ds_4xdaily_regrid['u_component_of_wind'].data
vs = ds_4xdaily_regrid['v_component_of_wind'].data
temps = ds_4xdaily_regrid['temperature'].data
# Geopotential divided by GRAV gives the height passed to the parameterization.
height = ds_4xdaily_regrid['geopotential'].data / GRAV

# Square root of the squared buoyancy frequency, and density on the level coordinate.
Ns = bouyancy_freq_squared(temps, height)**0.5
rho = density(temps, ds_4xdaily_regrid['level'].data)

# Latitude broadcast over every (time, longitude, latitude) column, chunked one
# time step per block like the other inputs.
lat = da.broadcast_to(
    ds_4xdaily_regrid['latitude'].data[None, None, :],
    us.shape[:-1],
).astype(np.float32)
lat = lat.rechunk((1, -1, -1))
# Trailing singleton axis: shape = (time, longitude, latitude, 1)
lat_4d = lat[..., None]
lat_4d = lat_4d.rechunk((1, -1, -1, -1)).persist()

In [7]:
import dask
from dask.distributed import progress
# Materialise all parameterization inputs on the cluster in a single call, then
# show a progress bar tracking Ns.
us, vs, height, Ns, rho = dask.persist(
    us,
    vs,
    height,
    Ns,
    rho,
)
progress(Ns)


VBox()

In [12]:
from ad99py import AlexanderDunkerton1999
# AD99 parameterization instance; settings follow the MiMA spec.
ad99 = AlexanderDunkerton1999(
    Fs0=4.3e-3,
    cw=35,
    dc=0.4,
    exclude_unbroken=False,
    use_intrinsic_c='always',
)

def dask_ad99_map_block(ublock,Nblock,zblock,rhoblock,latblock,ad99=None):
    """
    Run the parameterization column-by-column over one block of data.

    Every leading ("batch") axis of the inputs is flattened, the parameterization
    is called once per vertical column, and the results are reshaped back.
    Not very fast or efficient right now but could be optimized further in future.

    Parameters
    ----------
    ublock, Nblock, zblock, rhoblock : np.ndarray
        Wind, buoyancy frequency, height and density blocks of shape
        ``(*batch, nlevels)``.
    latblock : np.ndarray
        One latitude per column; any shape with ``prod(batch)`` elements
        (e.g. ``(*batch, 1)``), flattened in C order to line up with the columns.
    ad99 : object
        Parameterization exposing ``momentum_flux_neg_ptv(u, N, z, rho, lat)``,
        expected to return two profiles of length ``nlevels``. Required, despite
        the ``None`` default.

    Returns
    -------
    np.ndarray
        Array of shape ``(*batch, nlevels, 2)``; the last axis holds the two
        profiles in the order ``momentum_flux_neg_ptv`` returns them.

    Raises
    ------
    ValueError
        If ``ad99`` is missing or the inputs do not all describe the same number
        of columns.
    """
    if ad99 is None:
        raise ValueError(
            "`ad99` must be a parameterization object providing `momentum_flux_neg_ptv`."
        )

    batch_shape = ublock.shape[:-1]
    nlevels = ublock.shape[-1]

    ublock_flt = ublock.reshape((-1,nlevels))
    Nblock_flt = Nblock.reshape((-1,nlevels))
    zblock_flt = zblock.reshape((-1,nlevels))
    rhoblock_flt = rhoblock.reshape((-1,nlevels))
    latblock_flt = latblock.ravel()

    # zip() silently truncates to the shortest input, which would pair columns
    # with the wrong latitude (or drop columns); fail loudly instead.
    ncolumns = ublock_flt.shape[0]
    column_counts = {
        'Nblock': Nblock_flt.shape[0],
        'zblock': zblock_flt.shape[0],
        'rhoblock': rhoblock_flt.shape[0],
        'latblock': latblock_flt.shape[0],
    }
    for label, count in column_counts.items():
        if count != ncolumns:
            raise ValueError(
                f"{label} describes {count} columns but ublock has {ncolumns}."
            )

    results = np.array([
        ad99.momentum_flux_neg_ptv(u,N,z,rho,lat)
        for u,N,z,rho,lat in zip(ublock_flt,Nblock_flt,zblock_flt,rhoblock_flt,latblock_flt)
    ])

    # (ncolumns, 2, nlevels) -> (*batch, 2, nlevels) -> (*batch, nlevels, 2)
    result_shp = results.reshape((*batch_shape,2,nlevels))
    result_rtn = np.moveaxis(result_shp,-2,-1)
    return result_rtn


  warn(f"`source` is not set, using default Gaussian source spectrum, with `cw={cw}` and `Bm={Bm}`.")


In [13]:
from functools import partial 
def _ad99_flux_pair(wind):
    """Map the AD99 block function over one wind component and persist the result.

    The output gains a trailing axis of length 2: index 0 is read below as the
    negative flux and index 1 as the positive flux.
    """
    block_fn = partial(dask_ad99_map_block, ad99=ad99)
    # Output blocks keep the input block shape (first chunk size along each
    # axis) plus the new trailing axis of length 2.
    out_chunks = tuple(c[0] for c in wind.chunks) + (2,)
    fluxes = da.map_blocks(
        block_fn,
        wind, Ns, height,
        rho,
        lat_4d,
        new_axis=len(wind.shape),
        dtype=wind.dtype,
        chunks=out_chunks,
    )
    return fluxes.persist()


# Zonal component
result_u = _ad99_flux_pair(us)
ntv_u_flux = result_u[..., 0]
ptv_u_flux = result_u[..., 1]

# Meridional component
result_v = _ad99_flux_pair(vs)
ntv_v_flux = result_v[..., 0]
ptv_v_flux = result_v[..., 1]

In [14]:
dims = ('time','longitude','latitude','level')

# Coordinate values are copied from the regridded dataset, one entry per dim.
coords = {
    dim: (dim, ds_4xdaily_regrid[dim].values)
    for dim in dims
}

# Output variable name -> array. Negative/positive zonal fluxes are stored as
# westward/eastward and negative/positive meridional fluxes as southward/northward.
fields = {
    'u': us,
    'v': vs,
    'Ns': Ns,
    'z': height,
    'rho': rho,
    'gw_flux_westward': ntv_u_flux,
    'gw_flux_eastward': ptv_u_flux,
    'gw_flux_northward': ptv_v_flux,
    'gw_flux_southward': ntv_v_flux,
}
ds_data = xr.Dataset(
    {key: (dims, array) for key, array in fields.items()},
    coords=coords,
)

# Write everything to a single netCDF file and echo where it went.
expname='always_include_unbroken_1year'
OUTPUT_PATH = f'/scratch/users/robcking/era5_{expname}.nc'
ds_data.to_netcdf(OUTPUT_PATH)
print(OUTPUT_PATH)

/scratch/users/robcking/era5_always_include_unbroken_1year.nc


In [9]:
# Cancel the persisted flux results on the cluster now that they are no longer needed.
for flux_result in (result_u, result_v):
    client.cancel(flux_result)