# Plot frames for ACCESS-OM3 surface speed movie

https://github.com/COSIMA/cosima-recipes/blob/main/Tutorials/Model_Agnostic_Analysis.ipynb

In [None]:
%matplotlib inline
import pandas as pd
import intake
import dask
import xarray as xr
import numpy as np
import cf_xarray as cfxr
import pint_xarray
from pint import application_registry as ureg
import cf_xarray.units
import cftime
import xgcm
import os

import matplotlib.pyplot as plt
import cmocean as cm
import cartopy.crs as ccrs
import cartopy.feature as cft

from dask.distributed import Client

In [None]:
# cf_xarray works best when xarray keeps attributes by default, so make keep_attrs=True
# the global xarray option. The trailing `;` suppresses the cell output.
xr.set_options(keep_attrs=True);

In [None]:
# Keyword arguments for xarray's combine step, as suggested in
# https://github.com/COSIMA/cosima-recipes/blob/main/Tutorials/ACCESS-NRI_Intake_Catalog.ipynb
# Pass them as `xarray_combine_by_coords_kwargs=...` to a to_dask or to_dataset_dict call.
# The xarray documentation on reading multi-file datasets explains each argument:
# https://docs.xarray.dev/en/stable/user-guide/io.html#reading-multi-file-datasets
xarray_combine_by_coords_kwargs = {
    "compat": "override",
    "data_vars": "minimal",
    "coords": "minimal",
}

In [None]:
# Create a dask.distributed Client with one thread per worker.
# The bare `client` expression displays its summary as the cell output.
client = Client(threads_per_worker=1)
client

In [None]:
# Directory prefix for saved figures. Later cells build file names as thisdir + <name>,
# so the trailing '/' is required.
thisdir = '/g/data/v45/aek156/notebooks/github/aekiss/access-eval-recipes/ocean/'

In [None]:
# PATH='/scratch/v45/aek156/access-om3/archive/MOM6-CICE6-1deg_jra55do_ryf.iss138/'
# PATH='/scratch/v45/aek156/access-om3/archive/MOM6-CICE6-1deg_jra55do_ryf.testAug2024/'
# datastore = intake.open_esm_datastore(PATH+'intake_datastore.json', columns_with_iterables=['variable'])

In [None]:
#PATH='/g/data/tm70/ml0072/COMMON/git_repos/test-Reichl-2025-04-continue/archive/'
#datastore = intake.open_esm_datastore(PATH+'intake_esm_ds.json', columns_with_iterables=['variable'])

https://access-nri.zulipchat.com/#narrow/dm/784080,784272-dm/near/523926733

Everything is here
/g/data/tm70/ml0072/COMMON/git_repos/control_runs/ctrl_25km_1st_version

where

dev-MC_025deg_jra_ryf_alpha_rel covers year 00-19

test-Reichl-2025-04-continue covers 20-48

ctrl_run_25km_0.5 currently covers 49-62. Ongoing data will be temporarily transferred here as well.

dev-MC_025deg_jra_ryf_alpha_rel and test-Reichl-2025-04-continue share the same topo, but use different MOM_parameters and different timesteps.

test-Reichl-2025-04-continue and ctrl_run_25km_0.5 have different topos, but share the same MOM_parameters.


https://access-nri.zulipchat.com/#narrow/dm/784080,784272-dm/near/524167154

Minghang Li: Hi Andrew, I've finally managed to move the data to ik11 and just created an esm datastore for this. /g/data/ik11/outputs/access-om3-025/MC_25km_jra_ryf_0.5_prerelease I haven’t tested it yet, so please let me know if you notice anything unusual.

Minghang Li: The most up-to-date year currently available is 1988. I’ll continue transferring data and updating the esm-datastore as new data become available.

Minghang Li: It seems the newly generated ESM datastore might not be correctly configured. I’m reaching out to Charles to see if he can help diagnose the issue.

Minghang Li: Hi Andrew, **We’re using different horizontal grids for the runs - years 0–48 use one version, while years 49 onward use a double precision version.** As a result, when generating the ESM datastore, I’m running into an error due to the differing grid definitions: `ValueError: Resulting object does not have monotonic global indexes along dimension yh`. I can apply some postprocessing to standardise the grid across variables.

In [None]:
# Directory holding the control-run archives. In this cell it is only referenced by the
# commented-out catalog entries below.
PATH = '/g/data/tm70/ml0072/COMMON/git_repos/control_runs/ctrl_25km_1st_version/'
# intake-esm catalog files to open. Order matters: later cells concatenate the datasets
# from these catalogs along `time` in this order. Commented-out entries are kept for reference.
# NOTE(review): the first active entry has no directory component, so it is presumably
# resolved relative to the current working directory — confirm.
catalogs = [
    # 'MC_25km_jra_ryf_0.5_prerelease_0-19.json',
    # 'MC_25km_jra_ryf_0.5_prerelease_20-48.json',
    'MC_25km_jra_ryf_0.5_prerelease_49-62.json', # different grid from here onwards
    # PATH+'test-Reichl-2025-04-continue/archive/intake_esm_ds.json',
    # PATH+'ctrl_run_25km_0.5/archive/intake_esm_ds.json', # faulty
    # 'ctrl_run_25km_0.5_intake_esm_ds.json', # replacement for faulty one above
    '/g/data/ik11/outputs/access-om3-025/MC_25km_jra_ryf_0.5_prerelease/63-ongoing/63-ongoing.json'
]

In [None]:
# Open every catalog as an intake-esm datastore, keeping the order of `catalogs`.
datastores = []
for catalog_path in catalogs:
    datastores.append(
        intake.open_esm_datastore(catalog_path, columns_with_iterables=['variable'])
    )
# datastore = pd.concat(datastores)

In [None]:
# Display the opened datastores as a quick check.
datastores

In [None]:
# get coords from short run without processor masking
# https://github.com/aekiss/MOM6-CICE6/commit/59ed8ffc6ae1d4a79821a951924e7c853d9b788a
# https://xgcm.readthedocs.io/en/latest/xgcm-examples/03_MOM6.html#A-note-on-geographical-coordinates
# `static` supplies the geolon/geolat coordinates used for the maps and is the dataset
# from which the xgcm grid (with its metrics) is built in later cells.
static = xr.open_dataset('/g/data/ik11/outputs/access-om3-025/grid/access-om3.mom6.static.nc')

In [None]:
# Display the static grid dataset to inspect its variables and coordinates.
static

see https://xgcm.readthedocs.io/en/latest/xgcm-examples/03_MOM6.html#xgcm-grid-definition

ACCESS-OM3 uses a non-symmetric memory layout, i.e. all fields have the same i and j sizes. See
https://mom6.readthedocs.io/en/main/api/generated/pages/Horizontal_Indexing.html?highlight=symmetric#declaration-of-variables

`MOM_parameter_doc.layout`:
```
!SYMMETRIC_MEMORY_ = False      !   [Boolean]
                                ! If defined, the velocity point data domain includes every face of the
                                ! thickness points. In other words, some arrays are larger than others,
                                ! depending on where they are on the staggered grid.  Also, the starting index
                                ! of the velocity-point arrays is usually 0, not 1. This can only be set at
                                ! compile time.
```


In [None]:
# xgcm grid for the non-symmetric memory layout (SYMMETRIC_MEMORY_ = False).
# Grid definition follows
# https://xgcm.readthedocs.io/en/latest/xgcm-examples/03_MOM6.html#xgcm-grid-definition
# and the metrics follow
# https://xgcm.readthedocs.io/en/latest/grid_metrics.html#Using-metrics-with-xgcm

# Coordinate names at each position along each axis (no 'Z' axis is defined for now).
axis_positions = {
    'X': {'center': 'xh', 'right': 'xq'},
    'Y': {'center': 'yh', 'right': 'yq'},
    # 'Z': {'inner': 'zl', 'outer': 'zi'},
}

# Names of the variables in `static` that hold the grid metrics for each axis combination.
cell_metrics = {
    ('X',): ['dxt', 'dxCu', 'dxCv'],  # X distances
    ('Y',): ['dyt', 'dyCu', 'dyCv'],  # Y distances
    ('X', 'Y'): ['areacello', 'areacello_cu', 'areacello_cv', 'areacello_bu'],  # Areas
}

grid = xgcm.Grid(static, coords=axis_positions, metrics=cell_metrics, periodic=['X'])

In [None]:
# Background image that the map cells draw with ax.imshow as the land ("blue marble"),
# together with the extent passed to imshow (the image is taken to span the whole globe in
# longitude/latitude degrees).
blue_marble = plt.imread('/g/data/ik11/grids/BlueMarble.tiff')
blue_marble_extent = (-180, 180, -90, 90)

In [None]:
# Example output file names for mlotst (monthly mean and monthly max), kept as a note.
# As bare text in a code cell they are not valid Python and stop "Run All" with a SyntaxError.
# access-om3.mom6.2d.mlotst.1mon.mean.1966.nc
# access-om3.mom6.2d.mlotst.1mon.max.1966.nc

In [None]:
# `d` is first assigned by the mlotst loading cell further down, so evaluating it here raises
# NameError on a fresh kernel. Kept as a comment; run it after the loading cell to inspect `d`.
# d

In [None]:
# Look at the cell_methods string of the first monthly mlotst entry in the second datastore,
# to see how the `variable_cell_methods` search pattern needs to be written.
mlotst_monthly_df = datastores[1].search(variable='mlotst', frequency='1mon').df
mlotst_monthly_df.head()['variable_cell_methods'][0]

In [None]:
# Load the daily time-mean mlotst from every datastore and join the pieces along time.
data = {}
d = []
for ds in datastores:
    found = ds.search(variable='mlotst', frequency='1day', variable_cell_methods='.*time: mean.*')
    if found:  # avoid .to_dask() for empty datasets
        d.append(found.to_dask())
data['mlotst'] = xr.concat(d, 'time')

In [None]:
# Date range for the mixed-layer-depth statistics, as a slice usable in .sel(time=...).
trange_start = cftime.DatetimeNoLeap(1970, 1, 1)
trange_stop = cftime.DatetimeNoLeap(1980, 1, 1)
trange = slice(trange_start, trange_stop)

In [None]:
# Monthly climatology of mlotst over `trange`: one mean field per calendar month.
mlotst_in_range = data['mlotst']['mlotst'].sel(time=trange)
MLD_monthly_mean = mlotst_in_range.groupby('time.month').mean('time')

In [None]:
# Seasonal means (Jan-Mar and Jul-Sep) of the monthly climatology.
# A plain mean over months gives every month equal weight even though months differ in length
# (e.g. February would be weighted slightly too heavily), so each month is weighted by its
# number of days. This matches the mean over all days in the season when the selected period
# covers complete years — TODO confirm for the chosen `trange`.
# The day counts assume a no-leap calendar, consistent with the cftime.DatetimeNoLeap bounds
# used to select the time range.
days_in_month = xr.DataArray(
    [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
    dims='month',
    coords={'month': np.arange(1, 13)},
)

JFM = slice(1, 3)
JAS = slice(7, 9)
MLD_JFM_mean = MLD_monthly_mean.sel(month=JFM).weighted(days_in_month.sel(month=JFM)).mean('month')
MLD_JAS_mean = MLD_monthly_mean.sel(month=JAS).weighted(days_in_month.sel(month=JAS)).mean('month')

In [None]:
# Compute the JFM mean and keep it in memory; the returned object is displayed as the cell output.
MLD_JFM_mean.load()

In [None]:
# Maximum over time of the daily-mean mlotst within `trange`, at every grid point.
MLDmax = data['mlotst']['mlotst'].sel(time=trange).max('time')

In [None]:
# Compute the maximum and keep it in memory; the returned object is displayed as the cell output.
MLDmax.load()

In [None]:
# Map of the maximum daily-mean mixed layer depth,
# to match Treguier et al 2023 fig 1 https://doi.org/10.5194/gmd-16-3849-2023

# Attach the 2D geographic coordinates from the static file so the field can be drawn on a map.
dat = MLDmax.cf.assign_coords( { "longitude": static['geolon'],
                                 "latitude": static['geolat'] })
# Name the file after the variable plotted here. This previously used the loop variable `k`,
# which is only assigned by later cells: on a fresh kernel that is a NameError, and after a
# full run it silently picks up whatever variable name the last loop left behind.
fname = thisdir+'mlotst'+'_'+dat.attrs['long_name'].replace(' ', '_').replace('/', '_')+'.png'
if os.path.isfile(fname):
    print(f'   ---- skipping existing file {fname}')
else:
    fig = plt.figure(figsize=(12, 6))
    ax = plt.axes(projection=ccrs.Robinson(central_longitude=-100))
    dat.plot.contourf(
        ax=ax,
        levels=51,
        vmin=0,
        vmax=500,
        extend="max",
        # cmap=cm.cm.thermal,
        cmap='viridis',
        transform=ccrs.PlateCarree(),
        cbar_kwargs={"label": dat.attrs['units'], "fraction": 0.03, "aspect": 15, "shrink": 0.7},
    )

    # Add blue marble land:
    ax.imshow(
        blue_marble, extent=blue_marble_extent, transform=ccrs.PlateCarree(), origin="upper"
    )

    plt.title(f"Maximum Daily Mean {dat.attrs['long_name']}, {trange.start.strftime('%Y-%m-%d')} - {trange.stop.strftime('%Y-%m-%d')}");

    # Saving is currently disabled; the figure is only shown inline.
    # plt.savefig(fname, dpi=150)
    # print(f'   saved {fname}')


In [None]:
# Load the daily time-mean of each requested field from every datastore and join the pieces
# along time. `data` maps variable name -> concatenated dataset.
# (A longer chain of fallback searches, used when this search fails, can be found in the
# "OLD" section further down.)
fields = ['speed']
data = {}
for k in fields:
    print(k)
    # d = [ ds.search(variable=k).to_dataset_dict(xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for ds in datastores ]
    try:
        d = []
        for ds in datastores:
            found = ds.search(variable=k, frequency='1day', variable_cell_methods='.*time: mean.*')
            if found:  # avoid .to_dask() for empty datasets
                d.append(found.to_dask())
    except ValueError:
        print(f'{k} failed')
        continue
    if not d:
        print(f'no data for {k}')
    else:
        data[k] = xr.concat(d, 'time')
data

### Make movie frames

In [None]:
# Date range of the movie frames, as a slice usable in .sel(time=...).
tstart = cftime.DatetimeNoLeap(1980, 1, 1)
tend = cftime.DatetimeNoLeap(1981, 1, 1)
timerange = slice(tstart, tend)
for bound_name, bound in (('tstart', tstart), ('tend', tend)):
    print(bound_name, '=', bound)

In [None]:
# Write one PNG map per time step of every loaded field within `timerange`.
# Frames that already exist on disk are skipped, so the cell can be re-run to resume.
for k, d in data.items():
    print(k)
    datall = d[k].sel(time=timerange)
    # datall.load()
    for t in datall.time.values:
        # Attach the 2D geographic coordinates so the field can be drawn on a map.
        dat = datall.sel(time=t).cf.assign_coords( { "longitude": static['geolon'],
                                                    "latitude": static['geolat'] })
        fname = thisdir+k+'_'+dat.attrs['long_name'].replace(' ', '_').replace('/', '_')+'_'+t.strftime('%Y-%m-%d')+'.png'
        if os.path.isfile(fname):
            print(f'   ---- skipping existing file {fname}')
            continue

        fig = plt.figure(figsize=(12, 6))
        try:
            ax = plt.axes(projection=ccrs.Robinson(central_longitude=-100))
            dat.plot.contourf(
                ax=ax,
                levels=33,
                vmin=0,
                vmax=1,
                extend="max",
                cmap=cm.cm.thermal,
                transform=ccrs.PlateCarree(),
                cbar_kwargs={"label": dat.attrs['units'], "fraction": 0.03, "aspect": 15, "shrink": 0.7},
            )

            # Add blue marble land:
            ax.imshow(
                blue_marble, extent=blue_marble_extent, transform=ccrs.PlateCarree(), origin="upper"
            )

            plt.title(dat.attrs['long_name']+' '+t.strftime('%Y-%m-%d'));

            try:
                fig.savefig(fname, dpi=150)
                print(f'   saved {fname}')
            except FileNotFoundError:
                print(f'*** FileNotFoundError when saving {fname}')
        finally:
            # One figure is created per frame. Without closing it, every figure stays open
            # until the cell finishes, so memory grows with the number of frames.
            # Closing also means the frames are not displayed inline, only saved.
            plt.close(fig)

In [None]:
! module load ffmpeg
! ffmpeg -r 30 -pattern_type glob -i 'speed_Sea_Surface_Speed_*.png' -c:v libx264 -vf "pad=trunc((iw+1)/2)*2:trunc((ih+1)/2)*2:0:0:white,crop=w=1506:h=692:x=219:y=92" -preset veryslow -tune animation -crf 25 -pix_fmt yuv420p -r 30 Sea_Surface_Speed5.mp4

In [None]:
# Debug: display the time-subset field left over from the last pass of the frame loop.
datall

In [None]:
# Debug: display the longitude field that is assigned as the "longitude" coordinate for the maps.
static['geolon']

In [None]:
# Debug: display a single frame. `t` is the last time value left over from the frame loop.
datall.sel(time=t)

In [None]:
# Debug: display the last available time step of the speed field.
# `speed` is never assigned in this notebook; the loaded dataset is stored in `data['speed']`.
data['speed']['speed'].isel(time=-1)

In [None]:
# NOTE(review): the return value is not assigned, so `datall` itself is left unchanged by
# this cell; it only displays the result. `datall` is an xarray object here, whereas
# dask.array.rechunk presumably expects a dask array — `datall.chunk('auto')` may be what
# was intended; confirm.
dask.array.rechunk(datall, chunks='auto')

In [None]:
# Debug: display `datall` again after the rechunk experiment in the previous cell.
datall

In [None]:
# Debug: display the cf_xarray accessor for `datall`.
datall.cf

In [None]:
# Debug: display `datall` with the geolon/geolat fields from `static` attached as
# "longitude"/"latitude", the same call the plotting cells use.
datall.cf.assign_coords(
        { "longitude": static['geolon'],
         "latitude": static['geolat'] }
    )

In [None]:
# Debug: display the cf_xarray accessor for the last time step of the speed field.
# `speed` is never assigned in this notebook; the loaded dataset is stored in `data['speed']`.
data['speed']['speed'].isel(time=-1).cf

In [None]:
# Map of the last available time step of the speed field.
# `speed` is never assigned in this notebook, so the field is taken from `data['speed']`,
# where the loading cell stores it.
speed_field = data['speed']['speed']
speed_last = speed_field.isel(time=-1).cf.assign_coords({ "longitude": static['geolon'],
                                                          "latitude": static['geolat'] })

fig = plt.figure(figsize=(12, 6))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=-100))

speed_last.plot.contourf(
    ax=ax,
    # x="longitude",
    # y="latitude",
    levels=33,
    vmin=0,
    vmax=1,
    extend="max",
    cmap=cm.cm.thermal,
    transform=ccrs.PlateCarree(),
    cbar_kwargs={"fraction": 0.03, "aspect": 15, "shrink": 0.7},
)

# Add blue marble land:
ax.imshow(
    blue_marble, extent=blue_marble_extent, transform=ccrs.PlateCarree(), origin="upper"
)

plt.title(speed_field.attrs['long_name']);

# OLD BELOW

In [None]:
# Load a set of scalar diagnostics from every datastore and join each along time.
# `data` maps variable name -> concatenated dataset.

# fields = ['tosga', 'thetaoga', 'tos', 'sosga', 'soga', 'sos', 'SSH']#, 'KE', 'sss_global', 'volo', 'masso',]
fields = [ # from ncdump -h ctrl_run_25km_0.5/archive/output009/access-om3.mom6.scalar.1day.snap.1962.nc | grep double
    'soga',
    'thetaoga',
    'tosga',
    'sosga',
    'total_salt_Flux_Added',
    'total_salt_Flux_In',
    'total_salt_flux',
    'net_fresh_water_global_adjustment',
    'salt_flux_global_restoring_adjustment',
    'total_wfo',
    'total_evs',
    'total_fsitherm',
    'total_precip',
    'total_prsn',
    'total_lprec',
    'total_ficeberg',
    'total_friver',
    'total_net_massout',
    'total_net_massin',
]

# Extra search criteria to try in turn for each variable, stopping at the first that loads.
# The first entry is the unrestricted search; the others narrow it by frequency and/or
# cell methods.
search_attempts = [
    dict(),
    dict(frequency='1mon', variable_cell_methods='.*time: mean.*'),
    dict(variable_cell_methods='.*time: mean.*'),
    dict(variable_cell_methods='.*time: point.*'),
    dict(variable_cell_methods='.*time: min.*'),
    dict(variable_cell_methods='.*time: max.*'),
]

data = dict()
for k in fields:
    print(k)
    # d = [ ds.search(variable=k).to_dataset_dict(xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for ds in datastores ]
    d = None
    for attempt, search_kwargs in enumerate(search_attempts):
        # As before, only a ValueError from the unrestricted search moves on to the fallbacks
        # (presumably raised when the search matches more than one dataset — confirm), while
        # the fallbacks tolerate any ordinary error. Catching Exception instead of a bare
        # `except:` lets KeyboardInterrupt and SystemExit through, so the cell can be interrupted.
        tolerated = ValueError if attempt == 0 else Exception
        try:
            found = [ ds.search(variable=k, **search_kwargs) for ds in datastores ]
            d = [ ds.to_dask() for ds in found if ds ] # avoid .to_dask() for empty datasets
        except tolerated:
            continue
        else:
            break
    else:
        # every attempt raised
        print(f'{k} failed')
        continue
    if d:
        data[k] = xr.concat(d, 'time')
    else:
        print(f'no data for {k}')
data

In [None]:
# List the variables that were actually loaded.
data.keys()

In [None]:
# Plot a time series (plus annual means) for every loaded field and save it as a PNG.
# Fields with spatial coordinates are first reduced to a grid-weighted average.
# Existing files are skipped, so the cell can be re-run to resume.
for k, d in data.items():
    print(k)
    dat = d[k]
    # Replace '/' as well as spaces, as the other plotting cells do: a '/' in long_name would
    # otherwise be read as a directory separator and the save would fail.
    fname = thisdir+k+'_'+dat.attrs['long_name'].replace(' ', '_').replace('/', '_')+'.png'
    if os.path.isfile(fname):
        print(f'---- skipping existing file {fname}')
    else:
        # Grid-aware (weighted) average, see
        # https://xgcm.readthedocs.io/en/latest/grid_metrics.html?highlight=average#Grid-aware-(weighted)-average
        # NOTE(review): `grid` is built without a 'Z' axis (it is commented out where the grid
        # is defined), so the 'depth' branch presumably fails for 3D fields — confirm.
        if 'depth' in dat.cf.coords:
            dat = grid.average(dat, ['X', 'Y', 'Z'])
        elif 'longitude' in dat.cf.coords:
            dat = grid.average(dat, ['X', 'Y'])
        dat.load()
        # Label the raw series from the spacing (in whole days) of the first two time values.
        if int((dat.time[1]-dat.time[0]).values/1e9/60/60/24) == 1:
            label = 'daily'
        else:
            label = 'monthly mean' # possible BUG: plausible guess
        plt.figure(figsize=(10,5))
        dat.plot(label=label)
        dat.cf.resample(time='1YE').mean('time').plot(label='annual mean')
        # dat.cf.rolling(time=12, center=True).mean('time').plot()
        # dat.cf.rolling(time='1YE', center=True).mean('time').plot()
        plt.title(k+': '+dat.attrs['long_name'])#+' ['+dat.attrs['units']+']')
        plt.legend()
        try:
            plt.savefig(fname, dpi=150)
            print(f'saved {fname}')
        except FileNotFoundError:
            print(f'*** FileNotFoundError when saving {fname}')
    # break

In [None]:
# Debug: show the long_name of the last plotted field as it appears in the file name.
dat.attrs['long_name'].replace(' ', '_')

# OLDER BELOW

In [None]:
# For each listed variable, open the matching entries of `datastore` as a dict of datasets,
# then keep the first dataset of each non-empty result.
# NOTE(review): `datastore` (singular) is not assigned in any active cell of this notebook —
# only `datastores` (a list) is — so this cell raises NameError unless `datastore` is defined
# elsewhere; confirm before relying on it.
fields_mean = ['tosga', 'thetaoga', 'tos', 'sosga', 'soga', 'sos', 'sss_global', 'SSH', 'volo', 'masso',]# 'KE']
data_mean = { k: datastore.search(variable=k, 
                                  # frequency='1mon', 
                                  # variable_cell_methods='.*time: mean.*'
                                 ).to_dataset_dict(
    xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for k in fields_mean }
data_mean = { k: list(v.values())[0] for k, v in data_mean.items() if v } # drop any empty datasets
data_mean

In [None]:
# Same as the previous cell but restricted to monthly output and a shorter variable list;
# it overwrites `fields_mean` and `data_mean`.
# NOTE(review): `datastore` (singular) is not assigned in any active cell of this notebook —
# only `datastores` (a list) is — so this cell raises NameError unless `datastore` is defined
# elsewhere; confirm before relying on it.
fields_mean = ['thetaoga', 'tos', 'soga', 'sos', 'sss_global', 'SSH', 'volo', 'masso',]# 'KE']
data_mean = { k: datastore.search(variable=k, 
                                  frequency='1mon', 
                                  # variable_cell_methods='.*time: mean.*'
                                 ).to_dataset_dict(
    xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for k in fields_mean }
data_mean = { k: list(v.values())[0] for k, v in data_mean.items() if v } # drop any empty datasets
data_mean

In [None]:
# For each listed variable, open the monthly entries of `datastore` whose cell methods
# match a time minimum, then keep the first dataset of each non-empty result.
# NOTE(review): `datastore` (singular) is not assigned in any active cell of this notebook —
# only `datastores` (a list) is — so this cell raises NameError unless `datastore` is defined
# elsewhere; confirm before relying on it.
fields_min = [ 'SSH_min', 'mlotst_min','uh', 'vh' ]
data_min = { k: datastore.search(variable=k, 
                                 frequency='1mon', 
                                 variable_cell_methods='.*time: min.*'
                                ).to_dataset_dict(
    xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for k in fields_min }
data_min = { k: list(v.values())[0] for k, v in data_min.items() if v } # drop any empty datasets
data_min

In [None]:
# For each listed variable, open the monthly entries of `datastore` whose cell methods
# match a time maximum, then keep the first dataset of each non-empty result.
# NOTE(review): `datastore` (singular) is not assigned in any active cell of this notebook —
# only `datastores` (a list) is — so this cell raises NameError unless `datastore` is defined
# elsewhere; confirm before relying on it.
fields_max = [ 'SSH_max', 'mlotst_max', 'uh', 'vh' ]
data_max = { k: datastore.search(
    variable=k, 
    frequency='1mon', 
    variable_cell_methods='.*time: max.*'
).to_dataset_dict(
    xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs) for k in fields_max }
data_max = { k: list(v.values())[0] for k, v in data_max.items() if v } # drop any empty datasets
data_max

In [None]:
# Time series (plus annual means) of every field in `data_mean`, shown inline.
# Fields with spatial coordinates are first reduced to a grid-aware (weighted) average, see
# https://xgcm.readthedocs.io/en/latest/grid_metrics.html?highlight=average#Grid-aware-(weighted)-average
for k, d in data_mean.items():
    dat = d[k]
    if 'depth' in dat.cf.coords:
        average_axes = ['X', 'Y', 'Z']
    elif 'longitude' in dat.cf.coords:
        average_axes = ['X', 'Y']
    else:
        average_axes = None
    if average_axes is not None:
        dat = grid.average(dat, average_axes)
    dat.load()
    # Spacing of the first two time values, in whole days.
    step_days = int((dat.time[1]-dat.time[0]).values/1e9/60/60/24)
    label = 'daily' if step_days == 1 else 'monthly mean'  # BUG: wild guess
    plt.figure(figsize=(10,5))
    dat.plot(label=label)
    annual_mean = dat.cf.resample(time='1YE').mean('time')
    annual_mean.plot(label='annual mean')
    # dat.cf.rolling(time=12, center=True).mean('time').plot()
    # dat.cf.rolling(time='1YE', center=True).mean('time').plot()
    plt.title(k)
    plt.legend()

In [None]:
# Time series of the spatial minimum of every field in `data_min`, with a centred
# 12-step rolling mean drawn on top.
for k, d in data_min.items():
    dat = d[k]
    if 'depth' in dat.cf.coords:
        reduce_dims = ['latitude', 'longitude', 'depth']
    elif 'longitude' in dat.cf.coords:
        reduce_dims = ['latitude', 'longitude']
    else:
        reduce_dims = None
    if reduce_dims is not None:
        dat = dat.cf.min(dim=reduce_dims)
    dat.load()
    plt.figure(figsize=(10,5))
    dat.plot()
    rolling_mean = dat.cf.rolling(time=12, center=True).mean('time')
    rolling_mean.plot()
    plt.title(k)
In [None]:
# Time series of the spatial maximum of every field in `data_max`, with a centred
# 12-step rolling mean drawn on top.
for k, d in data_max.items():
    dat = d[k]
    if 'depth' in dat.cf.coords:
        reduce_dims = ['latitude', 'longitude', 'depth']
    elif 'longitude' in dat.cf.coords:
        reduce_dims = ['latitude', 'longitude']
    else:
        reduce_dims = None
    if reduce_dims is not None:
        dat = dat.cf.max(dim=reduce_dims)
    dat.load()
    plt.figure(figsize=(10,5))
    dat.plot()
    rolling_mean = dat.cf.rolling(time=12, center=True).mean('time')
    rolling_mean.plot()
    plt.title(k)