# Stochastic compute ERA5 fluxes
Clone of the 2014 compute fluxes, however now using a stochastic implementation.

Care needs to be taken in this case to isolate blocks.

In [16]:
### Master Variables
dc = 0.4 # Spectral resolution
exclude_unbroken = False # Exclude waves that are unbroken at top level from the momentum flux calculations? 
use_intrinsic_c = 'always' # always center distributon on zero intrinsic phase speed vs never (center on zero ground relative speed) or switch at latitude ? 
Fs0 = 4.3e-3 # Pa, initial momentum flux at source level
cw = 35 # m/s, gravity wave phase speed source distribution half width at half maximum 
expname=f'{use_intrinsic_c}_{"exclude" if exclude_unbroken else "include"}_unbroken_1year'
expname

'always_include_unbroken_1year'

## dask/SLURM setup

In [2]:
from dask_jobqueue import SLURMCluster
NCORES = 8
NPROCESS = 8
NCORESPERPROCESS = NCORES//NPROCESS
constraints = ['-C \"CLASS:SH4_CBASE|CLASS:SH4_CPERF\"'] # SH4 nodes are the fastest, and mixing node gens seems to cause ib0 issues.
cluster = SLURMCluster(queue='serc',memory='96GiB',cores=NCORES,processes=NPROCESS,walltime='06:00:00',job_extra_directives=constraints,log_directory='/scratch/users/robcking/dask_worker_logs')
cluster.scale(jobs=25) # roughly but tune to scheduler 
cluster


0,1
Dashboard: http://10.20.3.15:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.3.15:44659,Workers: 0
Dashboard: http://10.20.3.15:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [3]:
from dask.distributed import Client 
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.20.3.15:8787/status,

0,1
Dashboard: http://10.20.3.15:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.3.15:44659,Workers: 0
Dashboard: http://10.20.3.15:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Load Data

In [4]:
import xarray as xr 
import matplotlib.pyplot as plt 
import numpy as np 
import dask.array as da
import cartopy.crs as ccrs 
## Load in from ERA5 GCP public dataset
PATH_COARSE = 'gs://gcp-public-data-arco-era5/ar/1959-2022-1h-240x121_equiangular_with_poles_conservative.zarr'
variables = ['u_component_of_wind','v_component_of_wind','temperature','geopotential']
ds = xr.open_zarr(PATH_COARSE,chunks={})
ds_4xdaily = ds.sel(time=slice('2014-01-01','2015-01-01',6))
ds_4xdaily = ds_4xdaily.chunk({'time':1})
ds_4xdaily_reduce = ds_4xdaily[variables]


### coarsen ERA5

In [5]:
import xesmf as xe 

longitudes = np.linspace(0,360,128,endpoint=False)

N = 64  # Number of latitude points
x, w = np.polynomial.legendre.leggauss(N)  # x are roots (in [-1, 1]), w are weights
latitudes_radians = np.arcsin(x)  # Convert from Gauss x-space to latitude in radians
latitudes= np.degrees(latitudes_radians)

ds_4xdaily_regrid = xr.Dataset({
    'longitude':(['longitude'],longitudes,{"units": "degrees_east"}),
    'latitude':(['latitude'],latitudes,{"units": "degrees_north"})
})
regridder = xe.Regridder(ds_4xdaily_reduce,ds_4xdaily_regrid,'bilinear',periodic=True)
ds_4xdaily_regrid = regridder(ds_4xdaily_reduce).persist()

ds_4xdaily_regrid = ds_4xdaily_regrid.isel(level=slice(None,None,-1))
ds_4xdaily_regrid= ds_4xdaily_regrid.transpose("time","longitude","latitude","level")

### setup AD99 input variables

In [6]:
from ad99py import GRAV,C_P,BFLIM,R_DRY
from ad99py.variables import bouyancy_freq_squared,density

us = ds_4xdaily_regrid.u_component_of_wind.data
vs = ds_4xdaily_regrid.v_component_of_wind.data
temps = ds_4xdaily_regrid.temperature.data
height = ds_4xdaily_regrid.geopotential.data / GRAV

Ns = bouyancy_freq_squared(temps,height)**0.5
rho = density(temps,ds_4xdaily_regrid.level.data)

lat = da.broadcast_to(ds_4xdaily_regrid.latitude.data[None,None,:],us.shape[:-1]).astype(np.float32)
lat = lat.rechunk((1,-1,-1,))
lat_4d = lat[..., None]         # now shape = (11688, 240, 121, 1)
lat_4d = lat_4d.rechunk((1,-1,-1,-1)).persist() 

In [7]:
import dask
from dask.distributed import progress
us,vs,height,Ns,rho = dask.persist(us,vs,height,Ns,rho)
progress(Ns)

VBox()

## Define blocks

In [8]:
SEED = 42 # for reproducibility
master = np.random.SeedSequence(SEED)
n_blocks = us.shape[0]
seeds = master.spawn(n_blocks)
seeds

[SeedSequence(
     entropy=42,
     spawn_key=(0,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(1,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(2,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(3,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(4,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(5,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(6,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(7,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(8,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(9,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(10,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(11,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(12,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(13,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(14,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(15,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(16,),
 ),
 SeedSequence(
     entr

In [10]:
from ad99py.ad99stochastic import AlexanderDunkerton1999Stochastic
from functools import partial 

def dask_ad99_map_block(ublock,vblock,Nblock,zblock,rhoblock,latblock,seeds=None,block_id=None):
    """
    Map indivudal blocks by linearly running them through the parameterization. 
    Not very fast or efficient right now but could be optimized further in future
    """
    ## Use same seed for both u and v to ensure consistent sampling?
    if block_id is None:
        seed = 0 # for meta pass

    else:
        seed = seeds[block_id[0]]

    rng_u = np.random.default_rng(seed)
    rng_v = np.random.default_rng(seed)
    ad99_u = AlexanderDunkerton1999Stochastic(
        rng=rng_u,
        Fs0 = Fs0,
        Fs0_sigma = 1, # NEW CONSTRAINED PARAMETER! 
        cw=cw,
        exclude_unbroken=exclude_unbroken,
        use_intrinsic_c=use_intrinsic_c,
        Fs0_meaning='mean'
    )
    ad99_v = AlexanderDunkerton1999Stochastic(
        rng=rng_v,
        Fs0 = Fs0,
        Fs0_sigma = 1, # NEW CONSTRAINED PARAMETER! 
        cw=cw,
        exclude_unbroken=exclude_unbroken,
        use_intrinsic_c=use_intrinsic_c,
        Fs0_meaning='mean'
    )
    batch_shape = ublock.shape[:-1]
    nlevels = ublock.shape[-1]
    
    ublock_flt = ublock.reshape((-1,nlevels))
    vblock_flt = vblock.reshape((-1,nlevels))
    Nblock_flt = Nblock.reshape((-1,nlevels))
    zblock_flt = zblock.reshape((-1,nlevels))
    rhoblock_flt = rhoblock.reshape((-1,nlevels))
    latblock_flt = latblock.ravel()
    results_u = np.array([
        ad99_u.momentum_flux_neg_ptv(u,N,z,rho,lat) for u,N,z,rho,lat in zip(ublock_flt,Nblock_flt,zblock_flt,rhoblock_flt,latblock_flt)
    ])
    results_v = np.array([
        ad99_v.momentum_flux_neg_ptv(v,N,z,rho,lat) for v,N,z,rho,lat in zip(vblock_flt,Nblock_flt,zblock_flt,rhoblock_flt,latblock_flt)
    ])
    
    result_u_shp = results_u.reshape((*batch_shape,2,nlevels))
    result_u_shp = np.moveaxis(result_u_shp,-2,0)

    result_v_shp = results_v.reshape((*batch_shape,2,nlevels))
    result_v_shp = np.moveaxis(result_v_shp,-2,0)

    result_rtn = np.concatenate([result_u_shp,result_v_shp],axis=0)
    return result_rtn 

run_ad99_stochastic = partial( dask_ad99_map_block, seeds=seeds)

In [11]:
results = da.map_blocks(
    run_ad99_stochastic,
    us,
    vs,
    Ns,
    height,
    rho,
    lat_4d,
    new_axis=[0],
    dtype=us.dtype,
    chunks= (4,) + tuple(c[0] for c in us.chunks) ).persist()

ntv_u_flux = results[0]
ptv_u_flux = results[1]
ntv_v_flux = results[2]
ptv_v_flux = results[3]


  warn(f"`source` is not set, using default Gaussian source spectrum, with `cw={cw}` and `Bm={Bm}`.")


In [19]:
import os
dims = ('time','longitude','latitude','level')
coords={
    'time':('time',ds_4xdaily_regrid.time.values),
    'longitude':('longitude',ds_4xdaily_regrid.longitude.values),
    'latitude':('latitude',ds_4xdaily_regrid.latitude.values),
    'level':('level',ds_4xdaily_regrid.level.values)
}
ds_data = xr.Dataset({
    'u':(dims,us),
    'v':(dims,vs),
    'Ns':(dims,Ns),
    'z':(dims,height),
    'rho':(dims,rho),
    'gw_flux_westward':(dims,ntv_u_flux),
    'gw_flux_eastward':(dims,ptv_u_flux),
    'gw_flux_northward':(dims,ptv_v_flux),
    'gw_flux_southward':(dims,ntv_v_flux)
    },coords=coords)
DATA_PATH = os.path.expandvars('$SCRATCH/data_ad99')

OUTPUT_PATH = f'{DATA_PATH}/era5_{expname}_stochastic.nc'
ds_data.to_netcdf(OUTPUT_PATH)
print(OUTPUT_PATH)

/scratch/users/robcking/data_ad99/era5_always_include_unbroken_1year_stochastic.nc
