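## Compute normalization statistics for the GraphCast-based WoFSCast model

This notebook computes three sets of per-level statistics from a random sample of WoFS zarr files:

- the mean,
- the standard deviation,
- the standard deviation of time differences.

Each set is saved to its own NetCDF file.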

In [1]:
import xarray as xr 
import numpy as np
from glob import glob

import random 
import os

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

from wofscast import graphcast_lam as graphcast
import dask 

from wofscast.data_generator import (add_local_solar_time, 
                                     to_static_vars, 
                                     load_chunk, 
                                     dataset_to_input,
                                     ZarrDataGenerator, 
                                     WRFZarrFileProcessor,
                                     WoFSDataProcessor
                                    )
from wofscast import data_utils
from wofscast.wofscast_task_config import (DBZ_TASK_CONFIG, 
                                           WOFS_TASK_CONFIG, 
                                           DBZ_TASK_CONFIG_1HR,
                                           DBZ_TASK_CONFIG_FULL
                                          )
from os.path import join

# Default output directory for the NetCDF normalization-statistics files.
# NOTE(review): the compute cell further down reassigns `save_path` before it is
# passed to compute_normalization_stats, so this value is not what gets used there.
save_path = '/work/mflora/wofs-cast-data/full_normalization_stats/'

In [2]:
import random

def get_random_subset(input_list, subset_size, seed=123):
    """
    Get a reproducible random subset of a specified size from the input list.

    Parameters:
    -----------
    input_list : list
        The original list from which to draw the subset. It is not modified.
    subset_size : int
        The size of the subset to be drawn.
    seed : int or None, optional
        Seed for a private random number generator. Default is 123.
        When a seed is given, the global ``random`` module state is left
        untouched. If None, the (unseeded) global ``random`` generator is used.

    Returns:
    --------
    list
        A random subset of the input list, in selection order.

    Raises:
    -------
    ValueError
        If ``subset_size`` is larger than ``len(input_list)`` (or negative).
    """
    if subset_size > len(input_list):
        raise ValueError("subset_size must be less than or equal to the length of the input list")

    # A dedicated Random instance yields exactly the same draw as
    # random.seed(seed) followed by random.sample(...), but without resetting
    # the notebook-wide random state as a hidden side effect.
    rng = random if seed is None else random.Random(seed)
    return rng.sample(input_list, subset_size)

In [3]:
def compute_normalization_stats(paths, gpu_batch_size, task_config, save_path, 
                                batch_over_time=False, preprocess_fn=None): 
    """
    Compute per-level normalization statistics and write them to NetCDF files.

    The following files are written into ``save_path``:

    * ``mean_by_level.nc``: mean over ('time', 'lat', 'lon', 'batch').
    * ``stddev_by_level.nc``: sample standard deviation (ddof=1) over the
      same dimensions.
    * ``diffs_stddev_by_level.nc``: sample standard deviation (ddof=1) of
      the first differences along 'time', over the same dimensions.

    Parameters:
    -----------
    paths : list
        Dataset paths, passed to ``load_chunk``. Must be non-empty.
    gpu_batch_size : int
        Passed to ``load_chunk``.
    task_config : object
        Currently unused; kept so existing callers keep working.
    save_path : str
        Output directory. Created if it does not already exist.
    batch_over_time : bool, optional
        Passed to ``load_chunk``. Default is False.
    preprocess_fn : callable, optional
        Passed to ``load_chunk``. Default is None.

    Raises:
    -------
    ValueError
        If ``paths`` is empty.
    """
    if not paths:
        raise ValueError("paths must contain at least one dataset path")

    reduce_dims = ['time', 'lat', 'lon', 'batch']
    os.makedirs(save_path, exist_ok=True)

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        full_dataset = load_chunk(paths, batch_over_time, 
                                  gpu_batch_size, preprocess_fn) 
        try:
            # Build the (possibly lazy) reductions.
            mean_by_level = full_dataset.mean(dim=reduce_dims)
            stddev_by_level = full_dataset.std(dim=reduce_dims, ddof=1)
            diffs_stddev_by_level = full_dataset.diff(dim='time').std(dim=reduce_dims, ddof=1)

            # Evaluate all three statistics in one pass so the shared input data are
            # read once, rather than once per to_netcdf call. Objects that are not
            # dask-backed are returned unchanged by dask.compute.
            mean_by_level, stddev_by_level, diffs_stddev_by_level = dask.compute(
                mean_by_level, stddev_by_level, diffs_stddev_by_level)

            outputs = {
                'mean_by_level.nc': mean_by_level,
                'stddev_by_level.nc': stddev_by_level,
                'diffs_stddev_by_level.nc': diffs_stddev_by_level,
            }
            for filename, stats in outputs.items():
                stats.to_netcdf(os.path.join(save_path, filename))
        finally:
            # Release the underlying files even if computation or writing fails.
            full_dataset.close()

### Compute normalization statistics for WOFS_TASK_CONFIG

_No cells are currently run for this configuration._

### Compute normalization statistics for DBZ_TASK_CONFIG_FULL (full domain)

In [4]:
%%time 
# Usage
base_path = '/work2/wofs_zarr/'
years = ['2019', '2020']
resolution_minutes = 10

# Specify the restrictions for testing
restricted_dates = None
restricted_times = ['1900', '2000', '2100', '2200', '2300', '0000', '0100', '0200', '0300']
restricted_members = ['ENS_MEM_1', 'ENS_MEM_12', 'ENS_MEM_17', 'ENS_MEM_5']#, 'ENS_MEM_10', 'ENS_MEM_11']

processor = WRFZarrFileProcessor(base_path, years, 
                             resolution_minutes, 
                             restricted_dates, 
                             restricted_times, restricted_members)

paths = processor.run()

CPU times: user 1.08 s, sys: 165 ms, total: 1.25 s
Wall time: 1.34 s


In [5]:
len(paths)  # number of file paths returned by processor.run()

2516

In [6]:
random_paths = get_random_subset(input_list=paths, subset_size=6, seed=123)  # 123 == default seed

In [7]:
%%time
save_path = '/work/mflora/wofs-cast-data/normalization_stats_full_domain'

preprocessor = WoFSDataProcessor()

compute_normalization_stats(random_paths, gpu_batch_size=32, 
                            task_config=DBZ_TASK_CONFIG_FULL, 
                            save_path=save_path, batch_over_time=True, 
                            preprocess_fn=preprocessor)

CPU times: user 3min 5s, sys: 1min 41s, total: 4min 47s
Wall time: 33.6 s
