# Stochastic computation of GCM fluxes
A clone of the GCM (MiMA) flux computation, but now using a stochastic implementation.

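Care must be taken here to keep blocks isolated from one another; for example, each block is given its own independent random seed (see *Define Blocks* below).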

In [1]:
### Master Variables
# Spectral resolution.
dc = 0.4
# If True, waves that are still unbroken at the top level are excluded from the momentum-flux calculation.
exclude_unbroken = False
# 'always': centre the source distribution on zero intrinsic phase speed;
# 'never': centre it on zero ground-relative phase speed.
# NOTE(review): a latitude-dependent switch was also mentioned as an option -- confirm supported values.
use_intrinsic_c = 'always'
# Pa, initial momentum flux at the source level.
Fs0 = 4.3e-3
# m/s, half width at half maximum of the gravity-wave phase-speed source distribution.
cw = 35
# Experiment tag, e.g. 'always_include_unbroken_1year'.
expname = '{}_{}_unbroken_1year'.format(use_intrinsic_c, "exclude" if exclude_unbroken else "include")
expname

'always_include_unbroken_1year'

## dask/SLURM setup

In [2]:
import os
from dask_jobqueue import SLURMCluster

# Resources per SLURM job: `cores` cores shared among `processes` dask worker processes.
NCORES = 8
NPROCESS = 8
NCORESPERPROCESS = NCORES//NPROCESS
# Number of SLURM jobs to request; roughly tuned to the scheduler.
NJOBS = 25
# SH4 nodes are the fastest, and mixing node generations seems to cause ib0 issues.
constraints = ['-C "CLASS:SH4_CBASE|CLASS:SH4_CPERF"']
# Worker log directory. Set DASK_WORKER_LOG_DIR to override the user-specific default path.
DASK_LOG_DIR = os.environ.get('DASK_WORKER_LOG_DIR', '/scratch/users/robcking/dask_worker_logs')
cluster = SLURMCluster(
    queue='serc',
    memory='96GiB',
    cores=NCORES,
    processes=NPROCESS,
    walltime='06:00:00',
    job_extra_directives=constraints,
    log_directory=DASK_LOG_DIR,
)
cluster.scale(jobs=NJOBS)
cluster

0,1
Dashboard: http://10.20.3.15:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.3.15:38031,Workers: 0
Dashboard: http://10.20.3.15:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [3]:
from dask.distributed import Client

# Connect a distributed client to the SLURM cluster so that dask computations
# in later cells are scheduled on its workers.
client = Client(address=cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.20.3.15:8787/status,

0,1
Dashboard: http://10.20.3.15:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.20.3.15:38031,Workers: 0
Dashboard: http://10.20.3.15:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Load Data

In [6]:
import os

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from ad99py.ad99 import AlexanderDunkerton1999
from ad99py.masks import mask_dataset, load_mask
from ad99py.constants import GRAV, R_DRY, C_P

# Raw GCM data: MiMA output that has been interpolated to the ERA5 pressure levels.
DATA_PATH = 'data'

# Open lazily in time chunks, keep only the fields needed for the flux calculation,
# reverse the order of the `level` axis, use latitude/longitude dim names, and put
# the vertical axis last.
ds = (
    xr.open_dataset(os.path.join(DATA_PATH, 'atmos_4xdaily_interp_1year.nc'), chunks={'time': 16})
    [['ucomp', 'temp', 'vcomp', 'height']]
    .isel(level=slice(None, None, -1))
    .rename(lat='latitude', lon='longitude')
    .transpose('time', 'longitude', 'latitude', 'level')
    .chunk({'time': 16})
)
ds

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.63 GiB 18.50 MiB Shape (1440, 128, 64, 37) (16, 128, 64, 37) Dask graph 90 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  1  37  64  128,

Unnamed: 0,Array,Chunk
Bytes,1.63 GiB,18.50 MiB
Shape,"(1440, 128, 64, 37)","(16, 128, 64, 37)"
Dask graph,90 chunks in 4 graph layers,90 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


Extract the underlying arrays from the dataset

In [None]:
# Underlying arrays (dask-backed for the chunked data variables) used by the flux code.
us, vs, zs, temps = (ds[name].data for name in ('ucomp', 'vcomp', 'height', 'temp'))
level = ds['level'].data

In [None]:
import dask.array as da
from ad99py.variables import bouyancy_freq_squared,density

# Buoyancy frequency N (square root of the N^2 returned by ad99py) and air density.
Ns = bouyancy_freq_squared(temps,zs)**0.5
rho = density(temps,level)
# Latitude broadcast over the (time, longitude, latitude) axes of `us`, plus a trailing
# singleton axis. The result is chunked exactly like `us` along those axes, so it stays
# aligned with the other inputs if the time chunking changes.
# np.asarray (rather than the old trailing `.data`, which yields a raw memoryview) gives a real ndarray.
lat = da.broadcast_to(
    np.asarray(ds.latitude.data)[None, None, :],
    us.shape[:-1],
    chunks=us.chunks[:-1],
)[..., None]


## Define Blocks

In [None]:
# Each block gets its own child SeedSequence spawned from a fixed master seed.
# Per the numpy docs, spawned children give statistically independent streams,
# and the fixed master seed keeps the whole run reproducible.
SEED = 42 # for reproducibility
master = np.random.SeedSequence(SEED)
n_blocks = us.shape[0]  # one block per entry along the leading (time) axis of `us`
seeds = master.spawn(n_blocks)
# Preview only; the full list holds n_blocks SeedSequence objects and would flood the output.
seeds[:3]