# Stochastic compute GCM fluxes
Clone of the GCM (MiMA) compute fluxes, however now using a stochastic implementation.

Care needs to be taken in this case to isolate blocks.

In [2]:
### Master Variables
dc = 0.4 # Spectral resolution
exclude_unbroken = False # Exclude waves that are unbroken at top level from the momentum flux calculations? 
use_intrinsic_c = 'always' # always center distributon on zero intrinsic phase speed vs never (center on zero ground relative speed) or switch at latitude ? 
Fs0 = 4.3e-3 # Pa, initial momentum flux at source level
cw = 35 # m/s, gravity wave phase speed source distribution half width at half maximum 
expname=f'{use_intrinsic_c}_{"exclude" if exclude_unbroken else "include"}_unbroken_1year_nice'
expname

'always_include_unbroken_1year_nice'

## dask/SLURM setup

In [3]:
from dask_jobqueue import SLURMCluster
NCORES = 8
NPROCESS = 8
NCORESPERPROCESS = NCORES//NPROCESS
constraints = ['-C \"CLASS:SH4_CBASE|CLASS:SH4_CPERF\"'] # SH4 nodes are the fastest, and mixing node gens seems to cause ib0 issues.
cluster = SLURMCluster(queue='serc',memory='48GiB',cores=NCORES,processes=NPROCESS,walltime='02:00:00',job_extra_directives=constraints,log_directory='/scratch/users/robcking/dask_worker_logs')
cluster.scale(jobs=25) # roughly but tune to scheduler 
cluster

0,1
Dashboard: http://10.20.10.27:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.10.27:35020,Workers: 0
Dashboard: http://10.20.10.27:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [4]:
from dask.distributed import Client 
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.20.10.27:8787/status,

0,1
Dashboard: http://10.20.10.27:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.10.27:35020,Workers: 0
Dashboard: http://10.20.10.27:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Load Data

In [5]:
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from ad99py.ad99 import AlexanderDunkerton1999
from ad99py.masks import mask_dataset,load_mask
from ad99py.constants import GRAV,R_DRY,C_P


DATA_PATH = 'data' # Load in raw data from GCM, here we use MiMA that has been interpolated to the ERA5 pressure levels. 
ds = xr.open_dataset(os.path.join(DATA_PATH,'atmos_4xdaily_interp_1year.nc'),chunks={'time':16})
ds = ds[['ucomp','temp','vcomp','height']]
ds = ds.isel(level=slice(None,None,-1))
ds = ds.rename(lat='latitude',lon='longitude')
ds = ds.transpose('time','longitude','latitude','level')
ds = ds.chunk({'time':16})
ds

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


Load Data

In [6]:
us = ds.ucomp.data
vs = ds.vcomp.data
zs = ds.height.data
temps = ds.temp.data
level = ds.level.data

In [7]:
import dask.array as da
from ad99py.variables import bouyancy_freq_squared,density

Ns = bouyancy_freq_squared(temps,zs)**0.5
rho = density(temps,level)
lat=da.broadcast_to(ds.latitude.data[None,None,:].data,us.shape[:-1])[...,None].rechunk((16,-1,-1,-1))


## Define Blocks

In [8]:
SEED = 42 # for reproducibility
master = np.random.SeedSequence(SEED)
n_blocks = us.shape[0]
seeds = master.spawn(n_blocks)
seeds

[SeedSequence(
     entropy=42,
     spawn_key=(0,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(1,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(2,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(3,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(4,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(5,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(6,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(7,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(8,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(9,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(10,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(11,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(12,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(13,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(14,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(15,),
 ),
 SeedSequence(
     entropy=42,
     spawn_key=(16,),
 ),
 SeedSequence(
     entr

In [10]:
from ad99py.ad99stochastic import AlexanderDunkerton1999Stochastic
from functools import partial 

def dask_ad99_map_block(ublock,vblock,Nblock,zblock,rhoblock,latblock,seeds=None,block_id=None,distribution='lognormal'):
    """
    Map indivudal blocks by linearly running them through the parameterization. 
    Not very fast or efficient right now but could be optimized further in future
    """
    ## Use same seed for both u and v to ensure consistent sampling?
    if block_id is None:
        seed = 0 # for meta pass

    else:
        seed = seeds[block_id[0]]

    rng_u = np.random.default_rng(seed)
    rng_v = np.random.default_rng(seed)
    ad99_u = AlexanderDunkerton1999Stochastic(
        rng=rng_u,
        Fs0 = Fs0,
        Fs0_sigma = 2.7*1e-3, # NEW CONSTRAINED PARAMETER! 
        cw=cw,
        exclude_unbroken=exclude_unbroken,
        use_intrinsic_c=use_intrinsic_c,
        distribution=distribution
    )
    ad99_v = AlexanderDunkerton1999Stochastic(
        rng=rng_v,
        Fs0 = Fs0,
        Fs0_sigma = 2.7*1e-3, # NEW CONSTRAINED PARAMETER! 
        cw=cw,
        exclude_unbroken=exclude_unbroken,
        use_intrinsic_c=use_intrinsic_c,
        distribution=distribution
    )
    batch_shape = ublock.shape[:-1]
    nlevels = ublock.shape[-1]
    
    ublock_flt = ublock.reshape((-1,nlevels))
    vblock_flt = vblock.reshape((-1,nlevels))
    Nblock_flt = Nblock.reshape((-1,nlevels))
    zblock_flt = zblock.reshape((-1,nlevels))
    rhoblock_flt = rhoblock.reshape((-1,nlevels))
    latblock_flt = latblock.ravel()
    results_u = np.array([
        ad99_u.momentum_flux_neg_ptv(u,N,z,rho,lat) for u,N,z,rho,lat in zip(ublock_flt,Nblock_flt,zblock_flt,rhoblock_flt,latblock_flt)
    ])
    results_v = np.array([
        ad99_v.momentum_flux_neg_ptv(v,N,z,rho,lat) for v,N,z,rho,lat in zip(vblock_flt,Nblock_flt,zblock_flt,rhoblock_flt,latblock_flt)
    ])
    
    result_u_shp = results_u.reshape((*batch_shape,2,nlevels))
    result_u_shp = np.moveaxis(result_u_shp,-2,0)

    result_v_shp = results_v.reshape((*batch_shape,2,nlevels))
    result_v_shp = np.moveaxis(result_v_shp,-2,0)

    result_rtn = np.concatenate([result_u_shp,result_v_shp],axis=0)
    return result_rtn 

def run_ad99_stochastic_lognormal(ublock,vblock,Nblock,zblock,rhoblock,latblock,block_id=None):
    return dask_ad99_map_block(ublock,vblock,Nblock,zblock,rhoblock,latblock,seeds=seeds,block_id=block_id,distribution='lognormal')

def run_ad99_stochastic_normal(ublock,vblock,Nblock,zblock,rhoblock,latblock,block_id=None):
    return dask_ad99_map_block(ublock,vblock,Nblock,zblock,rhoblock,latblock,seeds=seeds,block_id=block_id,distribution='normal')

def run_ad99_stochastic_uniform(ublock,vblock,Nblock,zblock,rhoblock,latblock,block_id=None):
    return dask_ad99_map_block(ublock,vblock,Nblock,zblock,rhoblock,latblock,seeds=seeds,block_id=block_id,distribution='uniform')



In [11]:
results_lognormal = da.map_blocks(
    run_ad99_stochastic_lognormal,
    us,
    vs,
    Ns,
    zs,
    rho,
    lat,
    new_axis=[0],
    dtype=us.dtype,
    chunks= (4,) + tuple(c[0] for c in us.chunks) )

ntv_u_flux_lognormal = results_lognormal[0]
ptv_u_flux_lognormal = results_lognormal[1]
ntv_v_flux_lognormal = results_lognormal[2]
ptv_v_flux_lognormal  = results_lognormal[3]

results_normal = da.map_blocks(
    run_ad99_stochastic_normal,
    us,
    vs,
    Ns,
    zs,
    rho,
    lat,
    new_axis=[0],
    dtype=us.dtype,
    chunks= (4,) + tuple(c[0] for c in us.chunks) )

ntv_u_flux_normal = results_normal[0]
ptv_u_flux_normal= results_normal[1]
ntv_v_flux_normal = results_normal[2]
ptv_v_flux_normal  = results_normal[3]

results_uniform = da.map_blocks(
    run_ad99_stochastic_uniform,
    us,
    vs,
    Ns,
    zs,
    rho,
    lat,
    new_axis=[0],
    dtype=us.dtype,
    chunks= (4,) + tuple(c[0] for c in us.chunks) )

ntv_u_flux_uniform = results_uniform[0]
ptv_u_flux_uniform= results_uniform[1]
ntv_v_flux_uniform = results_uniform[2]
ptv_v_flux_uniform  = results_uniform[3]




  warn(f"`source` is not set, using default Gaussian source spectrum, with `cw={cw}` and `Bm={Bm}`.")
  warn(f"`source` is not set, using default Gaussian source spectrum, with `cw={cw}` and `Bm={Bm}`.")
  warn(f"`source` is not set, using default Gaussian source spectrum, with `cw={cw}` and `Bm={Bm}`.")


In [12]:
import os
dims = ('time','longitude','latitude','level')
coords={
    'time':('time',ds.time.values),
    'longitude':('longitude',ds.longitude.values),
    'latitude':('latitude',ds.latitude.values),
    'level':('level',ds.level.values)
}
ds_data_lognormal = xr.Dataset({
    'u':(dims,us),
    'v':(dims,vs),
    'Ns':(dims,Ns),
    'z':(dims,zs),
    'rho':(dims,rho),
    'gw_flux_westward':(dims,ntv_u_flux_lognormal),
    'gw_flux_eastward':(dims,ptv_u_flux_lognormal),
    'gw_flux_northward':(dims,ptv_v_flux_lognormal),
    'gw_flux_southward':(dims,ntv_v_flux_lognormal)
    },coords=coords)

ds_data_normal = xr.Dataset({
    'u':(dims,us),
    'v':(dims,vs),
    'Ns':(dims,Ns),
    'z':(dims,zs),
    'rho':(dims,rho),
    'gw_flux_westward':(dims,ntv_u_flux_normal),
    'gw_flux_eastward':(dims,ptv_u_flux_normal),
    'gw_flux_northward':(dims,ptv_v_flux_normal),
    'gw_flux_southward':(dims,ntv_v_flux_normal)
    },coords=coords)

ds_data_uniform = xr.Dataset({
    'u':(dims,us),
    'v':(dims,vs),
    'Ns':(dims,Ns),
    'z':(dims,zs),
    'rho':(dims,rho),
    'gw_flux_westward':(dims,ntv_u_flux_uniform),
    'gw_flux_eastward':(dims,ptv_u_flux_uniform),
    'gw_flux_northward':(dims,ptv_v_flux_uniform),
    'gw_flux_southward':(dims,ntv_v_flux_uniform)
    },coords=coords)


DATA_PATH = os.path.expandvars('$SCRATCH/data_ad99')

OUTPUT_PATH_LOGNORMAL = f'{DATA_PATH}/mima_gwf_{expname}_stochastic.nc'
OUTPUT_PATH_NORMAL = f'{DATA_PATH}/mima_gwf_{expname}_normal_stochastic.nc'
OUTPUT_PATH_UNIFORM = f'{DATA_PATH}/mima_gwf_{expname}_uniform_stochastic.nc'
delayed_lognormal = ds_data_lognormal.to_netcdf(OUTPUT_PATH_LOGNORMAL,compute=False)
delayed_normal = ds_data_normal.to_netcdf(OUTPUT_PATH_NORMAL,compute=False)
delayed_uniform = ds_data_uniform.to_netcdf(OUTPUT_PATH_UNIFORM,compute=False)
OUTPUT_PATH_LOGNORMAL,OUTPUT_PATH_NORMAL,OUTPUT_PATH_UNIFORM

('/scratch/users/robcking/data_ad99/mima_gwf_always_include_unbroken_1year_nice_stochastic.nc',
 '/scratch/users/robcking/data_ad99/mima_gwf_always_include_unbroken_1year_nice_normal_stochastic.nc',
 '/scratch/users/robcking/data_ad99/mima_gwf_always_include_unbroken_1year_nice_uniform_stochastic.nc')

In [13]:
import dask
dask.compute(delayed_lognormal,delayed_normal,delayed_uniform)