# Examples of Using Prefect + Funnel with CESM data

In [91]:
import os

# Enable Prefect checkpointing through the environment before prefect is imported
# in the next cell, so task results are written to their configured Result.
os.environ.update({"PREFECT__FLOWS__CHECKPOINTING": "True"})

In [132]:
from esm_collections import calc
import ast
import os
from prefect import Flow, Parameter, task
from funnel import CacheStore, SQLMetadataStore
from funnel.prefect.result import FunnelResult
import intake
import prefect
from distributed import Client
from ncar_jobqueue import NCARCluster
import xcollection as xc
import holoviews as hv
import pop_tools
import hvplot
import xarray as xr
import hvplot.xarray
# Display the active Prefect flow configuration; 'checkpointing' should be True
# because the environment variable was set before prefect was imported.
prefect.context.to_dict()['config']['flows']

<Box: {'eager_edge_validation': False, 'run_on_schedule': True, 'checkpointing': True, 'defaults': {'storage': {'add_default_labels': True, 'default_class': 'prefect.storage.Local'}}}>

In [93]:
# Directory where Funnel stores cached task results and its SQLite metadata DB.
# Set FUNNEL_CACHE_DIR to run this notebook somewhere other than this user's scratch space.
cache_dir = os.environ.get('FUNNEL_CACHE_DIR', '/glade/scratch/mgrover/funnel_cache/funnel_demo_xdev')

In [94]:
# Prefect Result backed by Funnel: outputs are serialized as xcollections under
# cache_dir, and metadata is tracked in a SQLite database in the same directory.
funnel_db_url = f'sqlite:///{cache_dir}/funnel.db'
r = FunnelResult(
    SQLMetadataStore(
        CacheStore(cache_dir),
        serializer='xcollection',
        database_url=funnel_db_url,
    )
)

In [141]:
# Inspect the entries currently recorded in the Funnel metadata database.
funnel_metadata = SQLMetadataStore(
    CacheStore(cache_dir),
    serializer='xcollection',
    database_url=f'sqlite:///{cache_dir}/funnel.db',
)
funnel_metadata.df

Unnamed: 0_level_0,serializer,load_kwargs,dump_kwargs,custom_fields,checksum,created_at
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
long_term_mean/,xcollection,{},{},{},,2021-12-01 21:02:33.131818
annual_mean/,xcollection,{},{},{},,2021-12-01 21:02:33.131818
global_average/,xcollection,{},{},{},,2021-12-01 21:02:33.131818


## Set Up Your Tasks

In [95]:
# Deal with Intake-ESM catalogs
@task
def read_catalog(path, csv_kwargs):
    """Open an intake-esm datastore.

    Parameters
    ----------
    path : str
        Location of the catalog file passed to ``intake.open_esm_datastore``.
    csv_kwargs : dict
        Forwarded unchanged as ``csv_kwargs``.
    """
    esm_datastore = intake.open_esm_datastore(path, csv_kwargs=csv_kwargs)
    return esm_datastore

@task
def subset_catalog(catalog, search_dict):
    """Return ``catalog.search`` called with the entries of ``search_dict`` as keywords."""
    subset = catalog.search(**search_dict)
    return subset

# Calculations
@task
def subset_dates(catalog, search_dict, date_subset):
    """Search ``catalog`` with ``search_dict`` restricted to a subset of its dates.

    Parameters
    ----------
    catalog
        Catalog exposing ``.df.date`` and ``.search(**query)``.
    search_dict : dict
        Base search criteria. It is NOT modified.
    date_subset
        Indexer applied to ``catalog.df.date`` to choose the dates to keep.

    Returns
    -------
    The result of ``catalog.search`` with an added ``date`` criterion.
    """
    # Build a new query instead of calling search_dict.update(): the caller's dict
    # may be a shared Parameter default, and mutating it would leak the 'date'
    # criterion into later searches.
    query = dict(search_dict)
    query['date'] = catalog.df.date[date_subset].values
    return catalog.search(**query)

@task
def convert_to_collection(dsets):
    """Wrap ``dsets`` (e.g. the dict produced by ``load_data``) in an ``xcollection.Collection``."""
    collection = xc.Collection(dsets)
    return collection

@task
def load_data(catalog, cdf_kwargs):
    """Load the catalog entries with ``catalog.to_dataset_dict``, forwarding ``cdf_kwargs``."""
    dataset_dict = catalog.to_dataset_dict(cdf_kwargs=cdf_kwargs)
    return dataset_dict

@task
def center_time(collection):
    """Apply ``calc.center_time`` to every member of ``collection`` via ``Collection.map``."""
    centered = collection.map(calc.center_time)
    return centered

@task(target="long_term_mean/", result=r)
def long_term_mean(collection):
    """Apply ``calc.temporal_average`` to every member of ``collection``.

    The result is persisted through the Funnel result ``r`` under the
    ``long_term_mean/`` target.
    """
    averaged = collection.map(calc.temporal_average)
    return averaged

@task(target="annual_mean/", result=r)
def annual_mean(collection):
    """Apply ``calc.yearly_average`` to every member of ``collection``.

    The result is persisted through the Funnel result ``r`` under the
    ``annual_mean/`` target.
    """
    yearly = collection.map(calc.yearly_average)
    return yearly

@task
def global_average(ds, horizontal_dims, area_field, land_sea_mask, time_dim, include_ms=False):
    """Global mean of ``ds`` computed by ``calc.global_mean`` with ``normalize=True``.

    NOTE(review): ``include_ms`` is accepted but never used. A plain function
    named ``global_average`` defined later in this notebook shadows this task.
    """
    mean_options = dict(horizontal_dims=horizontal_dims, area_field=area_field,
                        land_sea_mask=land_sea_mask, time_dim=time_dim)
    return calc.global_mean(ds, normalize=True, **mean_options)

@task
def global_integral(ds, horizontal_dims, area_field, land_sea_mask, time_dim, include_ms=False):
    """Un-normalized global reduction of ``ds`` via ``calc.global_mean(..., normalize=False)``.

    NOTE(review): ``include_ms`` is accepted but never used.
    """
    integral_options = dict(horizontal_dims=horizontal_dims, area_field=area_field,
                            land_sea_mask=land_sea_mask, time_dim=time_dim)
    return calc.global_mean(ds, normalize=False, **integral_options)

@task
def zonal_average(da, grid, lat_field, ydim, xdim, lat_axis, region_mask=None):
    """Zonal mean of ``da`` computed by ``calc.zonal_mean``; every argument is forwarded by keyword."""
    zonal_kwargs = {
        'da': da,
        'grid': grid,
        'lat_field': lat_field,
        'ydim': ydim,
        'xdim': xdim,
        'lat_axis': lat_axis,
        'region_mask': region_mask,
    }
    return calc.zonal_mean(**zonal_kwargs)


## Set Up the Flow Using a Collection JSON

In [96]:
# Prefect flow: read an intake-esm catalog, subset it, load the matching datasets
# into an xcollection, center the time axis, then compute long-term and annual
# means (both tasks carry Funnel targets, so their results are cached).
with Flow('timeseries_collection_from_json') as timeseries_collection_from_json:
    # Flow inputs supplied to .run(); `path` has no default here.
    path = Parameter('path', )
    csv_kwargs = Parameter('csv_kwargs', default={})
    search_dict = Parameter('search_dict', default={})
    # Forwarded to to_dataset_dict as cdf_kwargs (chunks={} presumably requests dask-backed data — confirm).
    cdf_kwargs = Parameter('cdf_kwargs', default={'chunks':{}})

    # Read in the data catalog
    data_catalog = read_catalog(path, csv_kwargs)

    # Subset the catalog
    catalog_subset = subset_catalog(data_catalog, search_dict)

    # Load the data
    dsets_timeseries = load_data(catalog_subset, cdf_kwargs=cdf_kwargs)

    # Convert to xcollection
    collection_timeseries = convert_to_collection(dsets_timeseries)

    # Center time
    collection_timeseries_center_time = center_time(collection_timeseries)
    
    # Calculate a long-term mean (cached under the "long_term_mean/" target)
    long_term_average = long_term_mean(collection_timeseries_center_time)
    
    # Calculate the annual average (cached under the "annual_mean/" target)
    annual_average = annual_mean(collection_timeseries_center_time)

## Execute our Flow

### Spin up a Dask Cluster

In [8]:
# Number of workers to request from the NCAR batch cluster.
N_WORKERS = 10

cluster = NCARCluster()
cluster.scale(N_WORKERS)
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mgrover/proxy/8787/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mgrover/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.39:46130,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mgrover/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [97]:
catalog_path = "/glade/campaign/cesm/development/omwg/projects/MOMMARBL_vs_POPECO/catalog/MOMvsPOP.json"
pop_search = {'stream': 'pop.h', 'variable': ['TEMP', 'SALT', 'FG_CO2']}
timeseries_collection_from_json.run(path=catalog_path, search_dict=pop_search)

[2021-12-01 14:25:42-0700] INFO - prefect.FlowRunner | Beginning Flow run for 'timeseries_collection_from_json'
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'path': Starting task run...
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'path': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Starting task run...
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'search_dict': Starting task run...
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'search_dict': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Starting task run...
[2021-12-01 14:25:42-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Finished task run for task with final state: 'Success'
[2

[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'load_data': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Starting task run...
[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'center_time': Starting task run...
[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'center_time': Finished task run for task with final state: 'Success'
[2021-12-01 14:25:50-0700] INFO - prefect.TaskRunner | Task 'long_term_mean': Starting task run...
[2021-12-01 14:27:18-0700] INFO - prefect.TaskRunner | Task 'long_term_mean': Finished task run for task with final state: 'Success'
[2021-12-01 14:27:18-0700] INFO - prefect.TaskRunner | Task 'annual_mean': Starting task run...
[2021-12-01 14:29:13-0700] INFO - prefect.TaskRunner | Task 'annual_mean

<Success: "All reference tasks succeeded.">

In [98]:
# Re-open the annual means written to the Funnel cache by the flow above.
annual_mean_path = f'{cache_dir}/annual_mean'
annual_mean_collection = xc.open_collection(annual_mean_path)

In [99]:
# Re-open the long-term means written to the Funnel cache by the flow above.
long_term_mean_path = f'{cache_dir}/long_term_mean'
long_term_mean_collection = xc.open_collection(long_term_mean_path)

In [117]:
pop_control_key = 'ocn.pop.h.pop_control'
long_term_mean_collection[pop_control_key]

## Add a Plotting Function

We start with a function that operates on a single **`xarray.Dataset`**.

In [101]:
def plot_2D(ds, xdim, ydim):
    """Rasterized quadmesh plot of every data variable in ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to plot. It is not modified.
    xdim, ydim : str
        Names used for the plot's x and y axes.

    Returns
    -------
    The object returned by ``ds.hvplot.quadmesh`` with one ``z`` entry per data variable.
    """
    # Work on a shallow copy: the assignments below add coordinates, and the
    # caller's dataset is a member of a collection that other tasks also consume,
    # so it must not be mutated in place.
    ds = ds.copy(deep=False)
    # Re-assigning a dimension to itself materializes it as a coordinate variable
    # (presumably so hvplot has axis values for index-only dims like nlat/nlon — confirm).
    ds[xdim] = ds[xdim]
    ds[ydim] = ds[ydim]
    # data_vars yields the same names as the former set difference, in a stable order.
    variables = list(ds.data_vars)
    return ds.hvplot.quadmesh(x=xdim, y=ydim, z=variables, cmap='magma', levels=20, rasterize=True)

Then, we wrap it in a task that applies it to every dataset in a collection (`xcollection.Collection`) and arranges the resulting plots in a layout.

In [102]:
@task
def apply_plot2D(collection, xdim, ydim, ncols=1):
    """Lay out one ``plot_2D`` plot per member of ``collection``.

    Parameters
    ----------
    collection : xcollection.Collection
        Datasets to plot, visited in key order.
    xdim, ydim : str
        Forwarded to ``plot_2D``.
    ncols : int, default 1
        Number of columns in the returned ``hv.Layout``.
    """
    panels = [plot_2D(collection[member], xdim, ydim) for member in collection.keys()]
    return hv.Layout(panels).cols(ncols)

In [103]:
# Same pipeline as the first flow, with a 2D plot task added after each average.
# The plot tasks (long_term_average_plot, annual_average_plot) are used later to
# pull the rendered layouts out of the flow-run state.
with Flow('timeseries_collection_from_json') as timeseries_collection_from_json:
    # Flow inputs supplied to .run(); `path` has no default here.
    path = Parameter('path', )
    csv_kwargs = Parameter('csv_kwargs', default={})
    search_dict = Parameter('search_dict', default={})
    cdf_kwargs = Parameter('cdf_kwargs', default={'chunks':{}})

    # Read in the data catalog
    data_catalog = read_catalog(path, csv_kwargs)

    # Subset the catalog
    catalog_subset = subset_catalog(data_catalog, search_dict)

    # Load the data
    dsets_timeseries = load_data(catalog_subset, cdf_kwargs=cdf_kwargs)

    # Convert to xcollection
    collection_timeseries = convert_to_collection(dsets_timeseries)

    # Center time
    collection_timeseries_center_time = center_time(collection_timeseries)
    
    # Calculate a long-term mean and plot it on the native (nlat, nlon) grid
    long_term_average = long_term_mean(collection_timeseries_center_time)
    long_term_average_plot = apply_plot2D(long_term_average, xdim='nlon', ydim='nlat')
    
    # Calculate the annual average and plot it on the native (nlat, nlon) grid
    annual_average = annual_mean(collection_timeseries_center_time)
    annual_average_plot = apply_plot2D(annual_average, xdim='nlon', ydim='nlat')
    

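### Run the Flow Again

Tasks whose results are already stored in the Funnel cache are loaded instead of recomputed; note the `Cached` state reported for `annual_mean` in the log below.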

In [104]:
catalog_path = "/glade/campaign/cesm/development/omwg/projects/MOMMARBL_vs_POPECO/catalog/MOMvsPOP.json"
pop_search = {'stream': 'pop.h', 'variable': ['TEMP', 'SALT', 'FG_CO2']}
timeseries_with_plots = timeseries_collection_from_json.run(path=catalog_path, search_dict=pop_search)

[2021-12-01 14:29:42-0700] INFO - prefect.FlowRunner | Beginning Flow run for 'timeseries_collection_from_json'
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'search_dict': Starting task run...
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'search_dict': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Starting task run...
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'path': Starting task run...
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'path': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Starting task run...
[2021-12-01 14:29:42-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Finished task run for task with final state: 'Success'
[2

[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'load_data': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Starting task run...
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'center_time': Starting task run...
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'center_time': Finished task run for task with final state: 'Success'
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'annual_mean': Starting task run...
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'annual_mean': Finished task run for task with final state: 'Cached'
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'long_term_mean': Starting task run...
[2021-12-01 14:29:51-0700] INFO - prefect.TaskRunner | Task 'long_term_mean'



[2021-12-01 14:30:01-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Finished task run for task with final state: 'Success'
[2021-12-01 14:30:01-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Starting task run...




[2021-12-01 14:30:02-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Finished task run for task with final state: 'Success'
[2021-12-01 14:30:02-0700] INFO - prefect.FlowRunner | Flow run SUCCESS: all reference tasks succeeded


In [105]:
# State.result is the public accessor for the task's return value
# (the former `._result.value` reached into a private attribute).
timeseries_with_plots.result[annual_average_plot].result

In [107]:
# State.result is the public accessor for the task's return value
# (the former `._result.value` reached into a private attribute).
timeseries_with_plots.result[long_term_average_plot].result

## Add a Global Mean Operation

As before, we start with a function that works at the `xarray.Dataset` level.

In [136]:
def global_average(ds, horizontal_dims, area_field, land_sea_mask, time_dim, include_ms=False,
                   grid_name='POP_gx1v7'):
    """Normalized global mean of ``ds`` after merging in POP grid variables.

    ``TAREA``, ``KMT`` and ``REGION_MASK`` are taken from
    ``pop_tools.get_grid(grid_name)`` and merged into ``ds``, so ``area_field``
    and ``land_sea_mask`` can name those variables.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to average.
    horizontal_dims, area_field, land_sea_mask, time_dim
        Forwarded to ``calc.global_mean``.
    include_ms : bool, default False
        Accepted for signature compatibility; currently unused.
    grid_name : str, default 'POP_gx1v7'
        pop_tools grid to load. The default is the previously hard-coded grid.

    Returns
    -------
    The result of ``calc.global_mean(..., normalize=True)``.

    NOTE(review): this redefines, and therefore shadows, the ``global_average``
    task defined in the "Setup your Tasks" cell.
    """
    grid = pop_tools.get_grid(grid_name)[['TAREA', 'KMT', 'REGION_MASK']]
    ds_with_grid = xr.merge([ds, grid])
    return calc.global_mean(ds_with_grid, horizontal_dims=horizontal_dims, area_field=area_field,
                            land_sea_mask=land_sea_mask, time_dim=time_dim, normalize=True)

Then, we apply this to the collection of datasets.

In [137]:
@task(target="global_average/", result=r)
def global_average_collection(collection, horizontal_dims, area_field, land_sea_mask, time_dim, include_ms=False):
    """Apply ``global_average`` to every member of ``collection``.

    All keyword arguments are forwarded to ``global_average`` for each dataset.
    The result is persisted through the Funnel result ``r``; the default target
    is ``global_average/`` and can be overridden per call in a flow.
    """
    per_dataset_kwargs = {
        'horizontal_dims': horizontal_dims,
        'area_field': area_field,
        'land_sea_mask': land_sea_mask,
        'time_dim': time_dim,
        'include_ms': include_ms,
    }
    return collection.map(global_average, **per_dataset_kwargs)

Note that `global_average_collection` is called twice in the flow below, so each call needs its own target name; otherwise both results would share the same cached `xcollection` target.

In [142]:
# Same pipeline as the previous flow, extended with global averages of the
# long-term and annual means. Each global-average call gets its own cache target.
with Flow('timeseries_collection_from_json') as timeseries_collection_from_json:
    # Flow inputs supplied to .run(); `path` has no default here.
    path = Parameter('path')
    csv_kwargs = Parameter('csv_kwargs', default={})
    search_dict = Parameter('search_dict', default={})
    cdf_kwargs = Parameter('cdf_kwargs', default={'chunks':{}})

    # Read in the data catalog
    data_catalog = read_catalog(path, csv_kwargs)

    # Subset the catalog
    catalog_subset = subset_catalog(data_catalog, search_dict)

    # Load the data
    dsets_timeseries = load_data(catalog_subset, cdf_kwargs=cdf_kwargs)

    # Convert to xcollection
    collection_timeseries = convert_to_collection(dsets_timeseries)

    # Center time
    collection_timeseries_center_time = center_time(collection_timeseries)

    # Calculate a long-term mean
    long_term_average = long_term_mean(collection_timeseries_center_time)
    long_term_average_plot = apply_plot2D(long_term_average, xdim='nlon', ydim='nlat')

    # Calculate the annual average
    annual_average = annual_mean(collection_timeseries_center_time)
    annual_average_plot = apply_plot2D(annual_average, xdim='nlon', ydim='nlat')

    # Calculate the global averages.
    # The target is passed via `task_args`, which Prefect applies to the per-call
    # copy of the task. Previously the shared module-level task's `.target` was
    # reassigned before each call, which left that task permanently retargeted.
    long_term_global_average = global_average_collection(long_term_average,
                                                         horizontal_dims=('nlat', 'nlon'),
                                                         area_field='TAREA',
                                                         land_sea_mask='KMT',
                                                         time_dim=None,
                                                         task_args={'target': 'long_term_global_average'})

    annual_global_average = global_average_collection(annual_average,
                                                      horizontal_dims=('nlat', 'nlon'),
                                                      area_field='TAREA',
                                                      land_sea_mask='KMT',
                                                      time_dim='year',
                                                      task_args={'target': 'annual_global_average'})

In [143]:
catalog_path = "/glade/campaign/cesm/development/omwg/projects/MOMMARBL_vs_POPECO/catalog/MOMvsPOP.json"
pop_search_with_area = {'stream': 'pop.h', 'variable': ['TEMP', 'SALT', 'FG_CO2', 'TAREA']}
timeseries_global_average = timeseries_collection_from_json.run(path=catalog_path,
                                                                search_dict=pop_search_with_area)

[2021-12-01 14:54:34-0700] INFO - prefect.FlowRunner | Beginning Flow run for 'timeseries_collection_from_json'
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'search_dict': Starting task run...
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'search_dict': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'path': Starting task run...
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'path': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Starting task run...
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'csv_kwargs': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Starting task run...
[2021-12-01 14:54:34-0700] INFO - prefect.TaskRunner | Task 'cdf_kwargs': Finished task run for task with final state: 'Success'
[2

[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'load_data': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Starting task run...
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'convert_to_collection': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'center_time': Starting task run...
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'center_time': Finished task run for task with final state: 'Success'
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'long_term_mean': Starting task run...
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'long_term_mean': Finished task run for task with final state: 'Cached'
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'annual_mean': Starting task run...
[2021-12-01 14:54:45-0700] INFO - prefect.TaskRunner | Task 'annual_mean'



[2021-12-01 14:55:07-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Finished task run for task with final state: 'Success'
[2021-12-01 14:55:07-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Starting task run...




[2021-12-01 14:55:10-0700] INFO - prefect.TaskRunner | Task 'apply_plot2D': Finished task run for task with final state: 'Success'
[2021-12-01 14:55:10-0700] INFO - prefect.TaskRunner | Task 'global_average_collection': Starting task run...
[2021-12-01 14:55:11-0700] INFO - prefect.TaskRunner | Task 'global_average_collection': Finished task run for task with final state: 'Success'
[2021-12-01 14:55:11-0700] INFO - prefect.FlowRunner | Flow run SUCCESS: all reference tasks succeeded


In [145]:
# Re-open both global-average results from the Funnel cache, in the same order as before.
long_term_global_average_collection, annual_global_average_collection = (
    xc.open_collection(f'{cache_dir}/{target_name}')
    for target_name in ('long_term_global_average', 'annual_global_average')
)

In [156]:
pop_control_annual = annual_global_average_collection['ocn.pop.h.pop_control']
# ylim is given high-to-low, presumably to draw the surface (0) at the top.
# NOTE(review): assumes the vertical coordinate is POP depth in cm (~upper 1000 m) — confirm.
pop_control_annual.TEMP.hvplot.contourf(x='year', ylim=(100000, 0), cmap='magma', levels=20)