# OSNAP data extraction

In [1]:
%matplotlib inline
import cosima_cookbook as cc
import numpy as np
import pandas as pd
import xarray as xr
import flox  # for faster groupby in xarray with dask
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.patches as patches
from dask.distributed import Client
from datetime import timedelta, date
import calendar
import os
from collections import OrderedDict
import cartopy.crs as ccrs
import cmocean as cm
import logging
logging.captureWarnings(True)
logging.getLogger('py.warnings').setLevel(logging.ERROR)
logging.getLogger('distributed.utils_perf').setLevel(logging.ERROR)

In [2]:
# start the dask client used for the computations below
# (the saved cell output reports a distributed.LocalCluster with 7 workers)
import climtas.nci
climtas.nci.GadiClient(malloc_trim_threshold='64kib')

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/44595/status,

0,1
Dashboard: /proxy/44595/status,Workers: 7
Total threads: 7,Total memory: 32.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:41403,Workers: 7
Dashboard: /proxy/44595/status,Total threads: 7
Started: Just now,Total memory: 32.00 GiB

0,1
Comm: tcp://127.0.0.1:36497,Total threads: 1
Dashboard: /proxy/44083/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:43999,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-cd86g_30,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-cd86g_30

0,1
Comm: tcp://127.0.0.1:43881,Total threads: 1
Dashboard: /proxy/45017/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:37669,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-q08v7nj4,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-q08v7nj4

0,1
Comm: tcp://127.0.0.1:39961,Total threads: 1
Dashboard: /proxy/41589/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:45985,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-5snm_uhe,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-5snm_uhe

0,1
Comm: tcp://127.0.0.1:35095,Total threads: 1
Dashboard: /proxy/44759/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:38935,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-ovw9qwft,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-ovw9qwft

0,1
Comm: tcp://127.0.0.1:36285,Total threads: 1
Dashboard: /proxy/34429/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:43769,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-r1fdju0x,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-r1fdju0x

0,1
Comm: tcp://127.0.0.1:34163,Total threads: 1
Dashboard: /proxy/38705/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:44913,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-gjrzntt3,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-gjrzntt3

0,1
Comm: tcp://127.0.0.1:33361,Total threads: 1
Dashboard: /proxy/41885/status,Memory: 4.57 GiB
Nanny: tcp://127.0.0.1:35519,
Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-4yh11ocg,Local directory: /jobfs/112699668.gadi-pbs/tmpzkmktn4pdask-worker-space/dask-scratch-space/worker-4yh11ocg


In [3]:
# cookbook database session; loadalldata passes it to every cc.querying.getvar call
session = cc.database.create_session()


## Initialise data structure and define helper functions

In [4]:
# WARNING! FORGETS ALL LOADED DATA!
# Re-running this cell rebinds `data` to an empty OrderedDict, so every experiment
# entry and everything stored under it has to be registered/loaded again.
data = OrderedDict() # init nested dict of experiments and their analyses

In [5]:
def addexpt(k, d):
    """
    Register entry d in data under key k

    An existing key is never overwritten: a message is printed and d is discarded.

    k: key of the new entry
    d: item to be stored under k
    """
    if k not in data:
        data[k] = d
        return
    print('skipped {}: already exists'.format(k))

In [6]:
def dictget(d, l):
    """
    Look up an item in a nested dict by following a list of keys

    d: nested dict
    l: list of keys, outermost first (must not be empty)
    """
    node = d[l[0]]  # an empty key list raises IndexError here
    for key in l[1:]:
        node = node[key]
    return node

In [7]:
def dictknown(d, l):
    """
    Return true if list of keys is valid in nested dict

    Only mappings are descended into. If the path reaches a non-mapping item
    (e.g. a string or a list) before all keys are used, the result is False;
    a membership test on such a leaf would otherwise match substrings or list
    values and then fail when subscripted.

    d: nested dict
    l: list of keys (an empty list is always valid)
    """
    from collections.abc import Mapping

    for key in l:
        if not isinstance(d, Mapping) or key not in d:
            return False
        d = d[key]
    return True

In [8]:
def dictput(d, l, item):
    """
    Store item in a nested dict at the position given by a list of keys

    Intermediate dicts that do not exist yet are created on the way down.

    d: nested dict
    l: list of keys, outermost first (must not be empty)
    item: item to be put
    """
    node = d
    for key in l[:-1]:  # traverse existing keys, adding new levels as needed
        if key not in node:
            node[key] = dict()
        node = node[key]
    node[l[-1]] = item

In [9]:
# convenience functions: the nested-dict helpers applied to the global data
def dget(l):
    """Get the item stored in data under the list of keys l"""
    return dictget(data, l)


def dknown(l):
    """Return true if the list of keys l is valid in data"""
    return dictknown(data, l)


def dput(l, item):
    """Put item into data under the list of keys l"""
    return dictput(data, l, item)

In [10]:
def showdata():
    """
    Display structure of data

    Prints every key of the nested dict, one per line, indented by nesting
    level. All levels are shown; items that are not mappings are treated as
    leaves and only their key is printed.
    """
    from collections.abc import Mapping

    def show(node, depth):
        for key, child in node.items():
            if depth == 0:
                print(key)
            else:
                # print() adds a separating space, giving 2*depth+1 leading spaces
                print('  ' * depth, key)
            if isinstance(child, Mapping):
                show(child, depth + 1)

    show(data, 0)

## Set experiments, regions, date ranges, variables, frequencies etc
1deg_jra55_iaf_omip2_cycle6

1deg_jra55_iaf_omip2_cycle6_jra55v150_extension

025deg_jra55_iaf_omip2_cycle6

025deg_jra55_iaf_omip2_cycle6_jra55v150_extension

01deg_jra55v140_iaf_cycle4

01deg_jra55v140_iaf_cycle4_jra55v150_extension

In [14]:
# 1 degree configuration; 'expts' lists the main run followed by its extension.
# NOTE: addexpt skips keys that already exist, so data must be re-initialised
# for a changed entry to take effect.
addexpt('1', {'model': 'access-om2',  # was 'access-om2-025', copied from the 0.25 degree entry
              'expts': ['1deg_jra55_iaf_omip2_cycle6',
                        '1deg_jra55_iaf_omip2_cycle6_jra55v150_extension'],
              'gridpaths': ['/g/data/ik11/grids/ocean_grid_10.nc']})

In [15]:
# 0.25 degree configuration; 'expts' lists the main run followed by its extension
addexpt(
    '025',
    dict(
        model='access-om2-025',
        expts=[
            '025deg_jra55_iaf_omip2_cycle6',
            '025deg_jra55_iaf_omip2_cycle6_jra55v150_extension',
        ],
        gridpaths=['/g/data/ik11/grids/ocean_grid_025.nc'],
    ),
)

In [16]:
# 0.1 degree configuration; 'expts' lists the main run followed by its extension.
# Several grid files are given; loadalldata takes each grid variable from the
# first file that contains it.
# NOTE: addexpt skips keys that already exist, so data must be re-initialised
# for a changed entry to take effect.
addexpt('01', {'model':'access-om2-01',
               'expts': ['01deg_jra55v140_iaf_cycle4',  # was '..._cycle4_cycle4' (duplicated suffix)
                         '01deg_jra55v140_iaf_cycle4_jra55v150_extension'],
               'gridpaths': ['/g/data/ik11/grids/ocean_grid_01.nc', 
                             '/g/data/cj50/access-om2/raw-output/access-om2-01/01deg_jra55v140_iaf/output000/ocean/ocean-2d-area_t.nc',
                             '/g/data/cj50/access-om2/raw-output/access-om2-01/01deg_jra55v140_iaf/output000/ocean/ocean-2d-area_u.nc']
              })

skipped 01: already exists


In [17]:
# list the keys currently held in data, to check the registered experiments
showdata()

01
   model
   desc
   gridpaths
1
   model
   gridpaths
025
   model
   gridpaths


In [24]:
# set climatology date range

# both bounds are parsed from a bare year, i.e. they fall on 1 January
tstart = pd.to_datetime('1958', format='%Y')
tend = pd.to_datetime('2023', format='%Y')
# tend = tstart + pd.DateOffset(years=30)
timerange = slice(tstart, tend)

# complete years covered, assuming tstart and tend are both 1 January
firstyear = tstart.year
lastyear = tend.year - 1
yearrange = '{}-{}'.format(firstyear, lastyear)

print('yearrange =', yearrange, 'complete years')
print('tstart =', tstart)
print('tend =', tend)

yearrange = 1958-2022 complete years
tstart = 1958-01-01 00:00:00
tend = 2023-01-01 00:00:00


In [25]:
# variables to load; must be 2d fields
varnames = [
    'u',
    'v',
    'pot_temp',
    'salt',
    'pot_rho_0',
    'pot_rho_2',
    'sea_level',
    # heat: https://forum.access-hive.org.au/t/net-surface-heat-and-freshwater-flux-variables/993/2
    'net_sfc_heating',
    'frazil_3d_int_z',
    # water
    'pme_river',
    # salt
    'sfc_salt_flux_ice',
    'sfc_salt_flux_restore',
    # currently disabled:
    # 'mh_flux',  # sea ice melt
    # 'sfc_hflux_coupler',
    # 'sfc_hflux_from_runoff',
    # 'sfc_hflux_pme',
    # 'net_sfc_heating', 'frazil_3d_int_z',  # Net surface heat flux into ocean is net_sfc_heating + frazil_3d_int_z: https://github.com/COSIMA/access-om2/issues/139#issuecomment-639278547
    # 'swflx',
    # 'lw_heat',
    # 'sens_heat',
    # 'evap_heat',
    # 'fprec_melt_heat',
]

In [20]:
# output frequencies to load; loadalldata passes each one as `frequency` to cc.querying.getvar
frequencies = ['1 monthly']

In [21]:
# for the North Atlantic: 70W-0E,40N-70N
regions = OrderedDict(
    NA=dict(lon=slice(-70, 0), lat=slice(40, 70)),
)

## Calculations

### Load data

In [26]:
def loadalldata(data, regions, freqs, varnames, timerange=timerange, ncfiles=None):
    """
    Load unreduced global fields and static grid variables into the nested dict

    Fields are stored at data[label]['global'][freq]['unreduced'][varname] and
    grid variables at data[label]['global']['static'][name]. Items that are
    already present are not loaded again.

    data: nested dict keyed by experiment label. Each entry must provide
          'gridpaths' (list of grid files) and may provide 'expts' (list of
          experiment names, loaded in order and concatenated along 'time').
          Without 'expts' the label itself is used as the experiment name.
    regions: not used by this function; kept so existing calls keep working
    freqs: list of frequencies, each passed as `frequency` to cc.querying.getvar
    varnames: list of variable names; 'tau' is replaced by 'tau_x' and 'tau_y'
    timerange: slice; only timerange.start is used, as start_time for the first
               experiment of each entry. timerange.stop is NOT applied.
    ncfiles: either one value (None or an ncfile) used for every variable, or
             a list with one ncfile per element of varnames

    Raises ValueError if ncfiles is a list whose length differs from varnames.
    Relies on the module-level names cc, xr and session.
    """
    region = 'global'
    reduction = 'unreduced'

    varnames = list(varnames)
    if isinstance(ncfiles, list):
        if len(ncfiles) != len(varnames):
            raise ValueError('ncfiles has {} entries but varnames has {}'.format(
                len(ncfiles), len(varnames)))
    else:
        ncfiles = [ncfiles]*len(varnames)  # use the same ncfile for all variables

    # Pair each variable with its ncfile BEFORE expanding 'tau' and removing
    # duplicates, so the pairing cannot be scrambled; the first occurrence wins.
    requests = OrderedDict()
    for varname, ncfile in zip(varnames, ncfiles):
        names = ['tau_x', 'tau_y'] if varname == 'tau' else [varname]
        for name in names:
            requests.setdefault(name, ncfile)

    # (old x dim, old y dim, new x name, new y name): dimension names of the
    # 01deg grid files are mapped onto the ocean coordinate names
    dimfixes = [('grid_x_T', 'grid_y_T', 'xt_ocean', 'yt_ocean'),
                ('grid_x_C', 'grid_y_C', 'xu_ocean', 'yu_ocean')]

    for expt in data.keys():
        print(expt)

        if dictknown(data, [expt, 'expts']):
            exptnames = list(dictget(data, [expt, 'expts']))
        else:
            exptnames = [expt]  # no list of experiments given: the label is the name

        for freq in freqs:
            kkey = [expt, region, freq, reduction]
            for varname, ncfile in requests.items():
                if dictknown(data, kkey+[varname]):
                    continue
                if ncfile is None:
                    print('loading', varname)
                else:
                    print('loading', varname, 'from', ncfile)
                parts = []
                for i, exptname in enumerate(exptnames):
                    kwargs = dict(frequency=freq, ncfile=ncfile, decode_coords=False)
                    if i == 0:
                        # only the first experiment is trimmed at the start;
                        # the following ones (extensions) are loaded in full
                        kwargs['start_time'] = str(timerange.start)
                    parts.append(cc.querying.getvar(exptname, varname, session, **kwargs))
                if len(parts) == 1:
                    dictput(data, kkey+[varname], parts[0])
                else:
                    dictput(data, kkey+[varname], xr.concat(parts, 'time'))

            # if tau:  # calculate stress magnitude tau (tau_x and tau_y already loaded above)
            #     varname = 'tau'
            #     if not dknown(kkey+[varname]):
            #         print('calculating', varname)
            #         tau_da = np.sqrt(dget(kkey+['tau_x'])**2
            #                         +dget(kkey+['tau_y'])**2)
            #         tau_da.attrs = dget(kkey+['tau_x']).attrs
            #         tau_da.attrs['long_name'] = 'wind stress magnitude'
            #         dput(kkey+[varname], tau_da)

        freq = 'static'

        # the datasets stay open because the stored variables are read lazily
        grids = [(p, xr.open_dataset(p, chunks='auto')) for p in dictget(data, [expt, 'gridpaths'])]
        for k in ['xt_ocean', 'yt_ocean', 'geolon_t', 'geolat_t', 'area_t',
                  'xu_ocean', 'yu_ocean', 'geolon_c', 'geolat_c', 'area_u']:
            kkey = [expt, region, freq, k]
            if dictknown(data, kkey):
                continue

            da = None
            for (p, g) in grids:  # take the variable from the first file that has it
                try:
                    da = g[k]
                except KeyError:
                    continue
                print(k, 'loaded from', p)
                break
            if da is None:
                print(k, 'not found in any grid file')
                continue
            dictput(data, kkey, da)

            for xdim, ydim, xname, yname in dimfixes:  # fix for 01deg
                if xdim not in da.dims or ydim not in da.dims:
                    continue
                try:
                    fixed = da.rename({xdim: xname, ydim: yname})
                    fixed.coords[xname] = dictget(data, kkey[0:-1]+[xname]).values
                    fixed.coords[yname] = dictget(data, kkey[0:-1]+[yname]).values
                except (KeyError, ValueError) as err:
                    # keep the variable as loaded rather than failing the whole run
                    print('could not rename dimensions of', k, ':', err)
                    continue
                da = fixed
                dictput(data, kkey, da)

In [None]:
%%time
# load every variable in varnames at the given frequencies, plus the static grid
# variables, for all experiments registered in data (items already loaded are skipped)
loadalldata(data, regions, frequencies, varnames, timerange=timerange, ncfiles=None)

In [None]:
# list the keys held in data after loading
showdata()
