In [1]:
import xarray as xr 
import numpy as np

from dask import delayed
from tqdm import tqdm
from dask.diagnostics import ProgressBar
import dask
from dask import delayed, compute
import glob
from dask.distributed import Client

import zarr
from numcodecs import Blosc

import gc

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

import wofscast.my_graphcast as graphcast

from wofscast.data_generator import (load_wofscast_data, 
                                    wofscast_data_generator, 
                                    wofscast_batch_generator, 
                                    to_static_vars,
                                    add_local_solar_time
                                    
                                    )
#from wofscast.wofscast_task_config import WOFS_TASK_CONFIG, train_lead_times, TARGET_VARS
from wofscast import data_utils
import dataclasses

from dask.distributed import performance_report

In [2]:
from datetime import datetime
import pandas as pd

def add_time_dim(dataset, paths):
    """Add time dimensions/coords to make use of GraphCast data utils"""
    # wrfout or wrfwof  
    start_str = os.path.basename(paths[0]).split('_')[0] # wrfout or wrfwof 

    dts = [datetime.strptime(os.path.basename(f), 
                             f'{start_str}_d01_%Y-%m-%d_%H:%M:%S.zarr')
           for f in paths]
    
    time_range = [pd.Timestamp(dt) for dt in dts]
    dataset['time'] = time_range
    
    dataset = dataset.assign_coords(datetime=time_range)

    # Convert 'time' dimension to timedeltas from the first time point
    time_deltas = (dataset['time'] - dataset['time'][0]).astype('timedelta64[ns]')
    dataset['time'] = time_deltas
        
    return dataset  

In [3]:
paths = glob.glob('/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof*')
paths.sort() 

dataset = xr.open_mfdataset(paths, concat_dim='Time', combine='nested', 
                            parallel=True, chunks={}, engine='zarr')  

dataset = dataset.rename({'Time': 'time'})
dataset = add_time_dim(dataset, paths)
dataset = add_local_solar_time(dataset)

In [15]:
paths

['/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:00:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:05:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:10:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:15:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:20:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:25:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:30:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:35:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:40:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:45:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_MEM_07/wrfwof_d01_2020-03-02_20:50:00.zarr',
 '/work2/wofs_zarr/2020/20200302/2000/ENS_M

In [4]:
%%time 
dataset = dataset.compute() 

CPU times: user 1.32 s, sys: 12.7 s, total: 14 s
Wall time: 6.28 s


In [9]:
# the number of gridpoints in one direction; square domain.
DOMAIN_SIZE = 300

VARS_3D = ['U', 'V', 'W', 'T', 'GEOPOT', 'QVAPOR']
VARS_2D = ['T2', 'COMPOSITE_REFL_10CM', 'UP_HELI_MAX']
STATIC_VARS = ['XLAND', 'HGT']

INPUT_VARS = VARS_3D + VARS_2D + STATIC_VARS
TARGET_VARS = VARS_3D + VARS_2D

# I compute this myself rather than using the GraphCast code. 
FORCING_VARS = (
            'local_solar_time_sin',
            'local_solar_time_cos',
            'toa_radiation'
        )
# Not pressure levels, but just vertical array indices at the moment. 
# When I created the wrfwof files, I pre-sampled every 3 levels. 
PRESSURE_LEVELS = np.arange(50)

# Loads data from the past 20 minutes (2 steps) and 
# creates a target over the next 10-60 min. 
INPUT_DURATION = '10min'
train_lead_times = '5min'

task_config = graphcast.TaskConfig(
      input_variables=INPUT_VARS,
      target_variables=TARGET_VARS,
      forcing_variables=FORCING_VARS,
      pressure_levels=PRESSURE_LEVELS,
      input_duration=INPUT_DURATION,
      n_vars_2D = len(VARS_2D),
      domain_size = DOMAIN_SIZE
 )

def batch_extract_inputs_targets_forcings(dataset, 
                                          n_input_steps, 
                                          n_target_steps, 
                                          target_lead_times): 
    '''
    Based on an input dataset with multiple timesteps, this function 
    returns rollouts multiple, mutually exclusive input/output pairs
    concatenating them along a 'batch' dimension. 
    '''
    inputs = []
    targets = []
    forcings = [] 

    n_total_steps = 3 # 2 input steps + 1 target step

    for i in range(0, dataset.time.size-n_total_steps, n_total_steps+1):
        print(i, i+n_total_steps, dataset.time.size)
        _inputs, _targets, _forcings = data_utils.extract_inputs_targets_forcings(
                dataset.isel(time=slice(i, i+n_total_steps), datetime=slice(i,i+n_total_steps)), 
                target_lead_times=train_lead_times,
                **dataclasses.asdict(task_config)
            )
        inputs.append(_inputs)
        targets.append(_targets)
        forcings.append(_forcings)
    
    inputs = xr.concat(inputs, dim='batch')
    targets = xr.concat(targets, dim='batch')
    forcings = xr.concat(forcings, dim='batch')
    
    return inputs, targets, forcings 

0 3 14
4 7 14
8 11 14


In [10]:
inputs

In [14]:
targets#.isel(batch=-1)['COMPOSITE_REFL_10CM'].plot()