## Data Generation for WoFSCast

1. Limit variables to wind components, temperature, pressure, composite reflectivity, and 2–5 km updraft helicity (UH).
2. Use WRFOUT files from May 2020.
3. Use only the central 150 x 150 grid points of a WoFS domain.

In [1]:
# Data Generation for WoFSCast 
#
# Using the raw WoFS WRFOUT files stored locally on the NSSL machines,
# build a 3D dataset formatted for Google's GraphCast codebase.
#
# Using the training inputs, compute the normalization statistics. 
# 

import sys, os 
sys.path.append('/home/monte.flora/python_packages/frdd-wofs-cast/')

import xarray as xr
import pandas as pd
import numpy as np 

from glob import glob
import os
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
import traceback
from tqdm import tqdm

import dataclasses


# Raw WRFOUT variables to retain. Every other data variable of the reference
# file ends up in `drop_vars`, which is handed to `process_dataset`.
VARS_3D_TO_KEEP = ['U', 'V', 'W', 'T', 'PH', 'PHB']
VARS_2D_TO_KEEP = ['T2', 'RAINNC', 'COMPOSITE_REFL_10CM', 'UP_HELI_MAX']
CONSTANTS = ['HGT', 'XLAND']

VARS_TO_KEEP = [*VARS_3D_TO_KEEP, *VARS_2D_TO_KEEP, *CONSTANTS]
BASE_WRFOUT_PATH = '/work2/wof/realtime/FCST/2020/'


# Main parallel processing script: the (date, init time, member) grid of cases.
init_times = ['2000', '2100', '2200', '2300', '0000', '0100']
mems = range(1, 18)
dates = sorted(os.listdir(BASE_WRFOUT_PATH))[5:25]

# For Debugging! Shrink the job to a handful of cases.
init_times = ['2000', '2100']
dates = dates[:6]

total_files = len(mems) * len(init_times) * len(dates)
max_workers = min(50, total_files)

print(f"Num of Files: {total_files}")

# All WoFS cases are assumed to share one latitude/longitude grid, so a single
# reference file (first date, 2300 init, member 01) supplies both the list of
# variables to drop and the 1-D lat/lon vectors that process_dataset attaches.
this_path = glob(os.path.join(BASE_WRFOUT_PATH, dates[0], '2300', 'ENS_MEM_01', 'wrfwof_d01_*'))[0]
with xr.open_dataset(this_path) as this_ds:
    data_vars = this_ds.data_vars
    drop_vars = [name for name in data_vars if name not in VARS_TO_KEEP]

    # Load into memory, then rename coordinates to the ERA5 naming convention.
    this_ds = this_ds.compute().rename({
        'XLAT': 'latitude',
        'XLONG': 'longitude',
        'south_north': 'lat',
        'west_east': 'lon',
    })

    # Latitude and longitude are expected to be 1-D vectors: take the first
    # column of latitude and the first row of longitude at the first time.
    lat_1d = this_ds['latitude'].isel(Time=0, lon=0)
    lon_1d = this_ds['longitude'].isel(Time=0, lat=0)

# Per-case worker (plus a small destaggering helper) that builds one dataset file.

def _destagger(da, stag_dim, new_dim):
    """Average a WRF staggered field onto the unstaggered grid along one dimension.

    Output element ``i`` is ``0.5 * (da[i] + da[i + 1])`` along ``stag_dim``, so
    that dimension shrinks by one and is then renamed to ``new_dim``.

    Args:
        da: DataArray containing the ``stag_dim`` dimension.
        stag_dim: Name of the staggered dimension, e.g. ``'west_east_stag'``.
        new_dim: Name given to the destaggered dimension, e.g. ``'lon'``.

    Returns:
        The destaggered DataArray.
    """
    # roll(-1) pairs element i with element i + 1; the last element pairs with
    # the wrapped-around first element and is trimmed off below.
    averaged = 0.5 * (da + da.roll({stag_dim: -1}, roll_coords=False))
    return averaged.isel({stag_dim: slice(None, -1)}).rename({stag_dim: new_dim})


def process_dataset(date, init_time, mem, drop_vars, *, size=150):
    """Convert one WoFS member's WRFOUT time series into a GraphCast-style netCDF file.

    Reads ``wrfwof_d01_*`` files for ``date``/``init_time``/``ENS_MEM_{mem:02d}``,
    subsamples them in time, builds geopotential (PH + PHB), destaggers U, V, W
    and geopotential, renames dims/coords toward the ERA5/GraphCast convention,
    converts time to offsets from the first step, crops the central
    ``size`` x ``size`` grid and writes a zlib-compressed netCDF file.

    Relies on the module-level ``lat_1d`` / ``lon_1d`` vectors read from a
    reference file, i.e. all cases are assumed to share one grid.

    Args:
        date: Case date directory name, e.g. ``'20200521'``.
        init_time: Initialization-time directory name, e.g. ``'0000'``.
        mem: Ensemble member number (directory ``ENS_MEM_XX``).
        drop_vars: Variable names to skip when opening the files. Not modified.
        size: Edge length (grid points) of the central subset that is kept.

    Returns:
        None in all cases; the product is the file written to disk. If no
        input files are found, a message is printed. Any other error prints a
        traceback, so one bad case does not abort a batch run.
    """
    BASE_WRFOUT_PATH = '/work2/wof/realtime/FCST/2020/'

    try:
        fname = f'/work/mflora/wofs-cast-data/eval_datasets/dataset_{date}{init_time}_mem_{mem:02d}.nc'

        # Start with 10-minute resolution
        time_resolution = '10min'

        data_paths = glob(os.path.join(BASE_WRFOUT_PATH, date, init_time,
                                       f'ENS_MEM_{mem:02d}', 'wrfwof_d01_*'))
        print(f"Found {len(data_paths)} files for processing.")
        if not data_paths:
            print("No files found, skipping.")
            return None

        data_paths.sort()

        # Skip the first 12 files (the original note: ~1 hr in, WoFS is more
        # accurate) and keep every 2nd file from index 12 through 68 (29 files).
        # NOTE(review): assumes 5-min output so the stride gives the 10-min
        # spacing declared by time_resolution; confirm.
        data_paths = data_paths[12:68 + 2:2]

        # Build a new list: `+=` would mutate the caller's drop_vars and make it
        # grow on every call. XLAT/XLONG are replaced by lat_1d/lon_1d below.
        vars_to_drop = list(drop_vars) + ['XLAT', 'XLONG']

        # Load everything into memory, then let the context manager close the files.
        with xr.open_mfdataset(data_paths, combine='nested', concat_dim='Time',
                               drop_variables=vars_to_drop, engine='netcdf4') as raw_ds:
            ds = raw_ds.compute()

        # Combine geopotential perturbation + base state
        ds['GEOPOT'] = ds['PH'] + ds['PHB']
        ds = ds.drop_vars(['PH', 'PHB'])

        # Renaming coordinate variables to align with the ERA5 naming convention.
        ds = ds.rename({'Time': 'time', 'bottom_top': 'level',
                        'south_north': 'lat', 'west_east': 'lon'})

        # Destagger winds and geopotential onto the unstaggered grid so every
        # field shares the (time, level, lat, lon) dimensions.
        ds['U'] = _destagger(ds['U'], 'west_east_stag', 'lon')
        ds['V'] = _destagger(ds['V'], 'south_north_stag', 'lat')
        ds['W'] = _destagger(ds['W'], 'bottom_top_stag', 'level')
        ds['GEOPOT'] = _destagger(ds['GEOPOT'], 'bottom_top_stag', 'level')

        # Add 300. to make it properly Kelvins.
        # NOTE(review): WRF 'T' is presumably perturbation potential temperature,
        # so this yields potential temperature (K); confirm that is intended.
        ds['T'] += 300.

        # Replace the integer lat/lon dimension coordinates with the shared
        # 1-D latitude/longitude vectors from the reference file.
        ds = ds.assign_coords(lat=lat_1d, lon=lon_1d)

        # Convert negative longitude values to the 0-360 range.
        ds['lon'] = xr.where(ds['lon'] < 0, ds['lon'] + 360, ds['lon'])

        # Formatting the time dimension for the GraphCast code: absolute times
        # starting at the initialization time, spaced by time_resolution.
        start_time = pd.Timestamp(f'{date}{init_time}')
        time_range = pd.date_range(start=start_time, periods=ds.sizes['time'],
                                   freq=time_resolution)
        ds['time'] = time_range

        # Only adding a fake datetime so that the GraphCast code can drop it :)
        # NOTE(review): passed without dims, this creates its own 'datetime'
        # dimension rather than riding along 'time'; verify that is expected.
        ds = ds.assign_coords(datetime=time_range)

        # Convert 'time' to timedeltas from the first time point.
        ds['time'] = (ds['time'] - ds['time'][0]).astype('timedelta64[ns]')

        # Add level coordinate to the dataset
        ds = ds.assign_coords(level=ds.level)

        # Subset to the central size x size grid.
        n_lat, n_lon = ds.sizes['lat'], ds.sizes['lon']
        start_lat, start_lon = (n_lat - size) // 2, (n_lon - size) // 2
        ds_subset = ds.isel(lat=slice(start_lat, start_lat + size),
                            lon=slice(start_lon, start_lon + size))

        # Define encoding with compression for exactly the variables written.
        encoding = {var: {'zlib': True, 'complevel': 5} for var in ds_subset.data_vars}
        ds_subset.to_netcdf(fname, encoding=encoding)

    except Exception:
        # Deliberate best-effort: report and move on so a batch run continues.
        traceback.print_exc()
        return None

    
# Debug run: regenerate a single (date, init time, member) file, then rebuild
# the stacked input/target/forcing arrays from every file in eval_datasets/.
date = '20200521'
init_time = '0000'
mem = 2
process_dataset(date, init_time, mem, drop_vars)

from wofscast import data_utils
from wofscast import my_graphcast as graphcast

# Create the training and target datasets.
input_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
target_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
forcing_variables = ["XLAND"]
# Not pressure levels, but just vertical array indices at the moment.
pressure_levels = np.arange(0, 40) #list(np.arange(0,40,2))
# Loads data from the past 20 minutes (2 steps) and
# creates a target over the next 10-60 min.
input_duration = '20min'
train_lead_times = slice('10min', '60min')


task_config = graphcast.TaskConfig(
    input_variables=input_variables,
    target_variables=target_variables,
    forcing_variables=forcing_variables,
    pressure_levels=pressure_levels,
    input_duration=input_duration,
)

# Example usage: (these chunking settings are not used in this cell)
chunk_size = 3  # The size of each chunk
overlap = 2     # The overlap between consecutive chunks

data_paths = sorted(glob('/work/mflora/wofs-cast-data/eval_datasets/dataset*.nc'))
if not data_paths:
    # Fail with a clear message instead of xr.concat's generic error on an empty list.
    raise FileNotFoundError(
        'No files matched /work/mflora/wofs-cast-data/eval_datasets/dataset*.nc')

train_input_list = []
train_target_list = []
train_forcing_list = []

datasets = []

for path in data_paths:
    dataset = xr.load_dataset(path)

    # Add a leading singleton 'batch' dimension, then split the example into
    # inputs, targets and forcings according to task_config.
    example_batch = dataset.expand_dims(dim='batch', axis=0)

    _train_inputs, _train_targets, _train_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch, target_lead_times=train_lead_times,
        **dataclasses.asdict(task_config))

    train_input_list.append(_train_inputs)
    train_target_list.append(_train_targets)
    train_forcing_list.append(_train_forcings)

# Stack all examples along the batch dimension.
train_inputs = xr.concat(train_input_list, dim='batch')
train_targets = xr.concat(train_target_list, dim='batch')
train_forcings = xr.concat(train_forcing_list, dim='batch')

base_path = '/work/mflora/wofs-cast-data/train_datasets'
# Save to NetCDF files (eval_* names, written into the train_datasets directory).
print('Saving the training datasets...')
train_inputs.to_netcdf(os.path.join(base_path, 'eval_inputs.nc'))
train_targets.to_netcdf(os.path.join(base_path, 'eval_targets.nc'))
train_forcings.to_netcdf(os.path.join(base_path, 'eval_forcings.nc'))


# NOTE: Everything below is disabled by wrapping it in a triple-quoted string
# literal (not comments). It holds the full batch pipeline:
#   1. run process_dataset for every (date, member, init time) in parallel
#      with ProcessPoolExecutor;
#   2. build train_inputs/targets/forcings from .../wofs-cast-data/datasets/;
#   3. compute per-level mean, stddev and stddev of time differences for
#      normalization and save them under .../normalization_stats/.
# Because the string is the cell's last expression, the notebook echoes it as
# cell output.
# NOTE(review): the normalization step concatenates train_inputs and
# train_forcings along 'batch' even though they hold different variables, which
# may pad with NaN; verify before re-enabling.
'''    
    
#process_dataset(dates[0], init_times[0], 1, drop_vars)    
with ProcessPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(process_dataset, date, init_time, 
                               mem, drop_vars) for date in dates for mem in mems for init_time in init_times]
    for future in tqdm(futures, desc="Processing datasets", total=total_files):
        future.result()  # Wait for all futures to complete

print('Done!')

from wofscast import data_utils
from wofscast import my_graphcast as graphcast

# Create the training and target datasets. 
input_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
target_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
forcing_variables = ["XLAND"]
# Not pressure levels, but just vertical array indices at the moment. 
pressure_levels = np.arange(0, 40) #list(np.arange(0,40,2))
# Loads data from the past 20 minutes (2 steps) and 
# creates a target over the next 10-60 min. 
input_duration = '20min'
train_lead_times = slice('10min', '60min') 


task_config = graphcast.TaskConfig(
      input_variables=input_variables,
      target_variables=target_variables,
      forcing_variables=forcing_variables,
      pressure_levels=pressure_levels,
      input_duration=input_duration,
  )

# Example usage:
chunk_size = 3  # The size of each chunk
overlap = 2     # The overlap between consecutive chunks

data_paths = glob(os.path.join('/work/mflora/wofs-cast-data/datasets/dataset*.nc'))
data_paths.sort()

train_input_list = []
train_target_list = []
train_forcing_list = []

datasets = []

for path in data_paths:
    dataset = xr.load_dataset(path)
    
    # @title Extract training and eval data
    example_batch = dataset.expand_dims(dim='batch', axis=0)

    _train_inputs, _train_targets, _train_forcings = data_utils.extract_inputs_targets_forcings(
            example_batch, target_lead_times=train_lead_times,
            **dataclasses.asdict(task_config))

    train_input_list.append(_train_inputs)
    train_target_list.append(_train_targets)
    train_forcing_list.append(_train_forcings)
    
train_inputs = xr.concat(train_input_list, dim='batch')
train_targets = xr.concat(train_target_list, dim='batch')
train_forcings = xr.concat(train_forcing_list, dim='batch')
    
print("All Examples:  ", example_batch.dims.mapping)
print("*"*80)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)

base_path = '/work/mflora/wofs-cast-data/train_datasets'
# Save to NetCDF files
print('Saving the training datasets...')
train_inputs.to_netcdf(os.path.join(base_path, 'train_inputs.nc'))
train_targets.to_netcdf(os.path.join(base_path, 'train_targets.nc'))
train_forcings.to_netcdf(os.path.join(base_path, 'train_forcings.nc'))

print('Computing and saving the normalization datasets...')
full_dataset = xr.concat([train_inputs, train_forcings], dim='batch')

# Compute the global mean and standard deviation by level
mean_by_level = full_dataset.mean(dim=['time', 'lat', 'lon', 'batch'])
stddev_by_level = full_dataset.std(dim=['time', 'lat', 'lon', 'batch'])

# For differences, first compute differences by time within each dataset, then concatenate
time_diffs = full_dataset.diff(dim='time')

# Compute the global standard deviation of the differences by level
diffs_stddev_by_level = time_diffs.std(dim=[ 'time', 'lat', 'lon', 'batch'])

# Save to NetCDF files
base_path = '/work/mflora/wofs-cast-data/normalization_stats/'

mean_by_level.to_netcdf(os.path.join(base_path, 'mean_by_level.nc'))
stddev_by_level.to_netcdf(os.path.join(base_path, 'stddev_by_level.nc'))
diffs_stddev_by_level.to_netcdf(os.path.join(base_path, 'diffs_stddev_by_level.nc'))
'''

Num of Files: 204
Found 73 files for processing.
Saving the training datasets...


'    \n    \n#process_dataset(dates[0], init_times[0], 1, drop_vars)    \nwith ProcessPoolExecutor(max_workers=max_workers) as executor:\n    futures = [executor.submit(process_dataset, date, init_time, \n                               mem, drop_vars) for date in dates for mem in mems for init_time in init_times]\n    for future in tqdm(futures, desc="Processing datasets", total=total_files):\n        future.result()  # Wait for all futures to complete\n\nprint(\'Done!\')\n\nfrom wofscast import data_utils\nfrom wofscast import my_graphcast as graphcast\n\n# Create the training and target datasets. \ninput_variables = [\'U\', \'V\', \'W\', \'T\']#, \'P\', \'REFL_10CM\', \'UP_HELI_MAX\']\ntarget_variables = [\'U\', \'V\', \'W\', \'T\']#, \'P\', \'REFL_10CM\', \'UP_HELI_MAX\']\nforcing_variables = ["XLAND"]\n# Not pressure levels, but just vertical array indices at the moment. \npressure_levels = np.arange(0, 40) #list(np.arange(0,40,2))\n# Loads data from the past 20 minutes (2 steps) a