## Compute the normalization statistics for the GraphCast code

In [1]:
import xarray as xr 
import numpy as np
from glob import glob

import random 
import os

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

from wofscast import graphcast_lam as graphcast
import dask 

from wofscast.data_generator import (add_local_solar_time, 
                                     to_static_vars, 
                                     load_chunk, 
                                     dataset_to_input,
                                     ZarrDataGenerator, 
                                     WRFZarrFileProcessor,
                                     WoFSDataProcessor
                                    )
from wofscast import data_utils
from wofscast.wofscast_task_config import (DBZ_TASK_CONFIG, 
                                           WOFS_TASK_CONFIG, 
                                           DBZ_TASK_CONFIG_1HR,
                                           DBZ_TASK_CONFIG_FULL
                                          )
from os.path import join

In [2]:
import random

def get_random_subset(input_list, subset_size, seed=123):
    """
    Get a random subset of a specified size from the input list.

    Parameters:
    -----------
    input_list : list
        The original list from which to draw the subset.
    subset_size : int
        The size of the subset to be drawn.
    seed : int, optional
        The seed for the random number generator. Default is None.

    Returns:
    --------
    list
        A random subset of the input list.
    """
    if subset_size > len(input_list):
        raise ValueError("subset_size must be less than or equal to the length of the input list")
    
    if seed is not None:
        random.seed(seed)

    return random.sample(input_list, subset_size)

In [3]:
from dask.diagnostics import ProgressBar

def compute_normalization_stats(paths, gpu_batch_size, task_config, save_path, 
                                batch_over_time=False, preprocess_fn=None): 

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        
        full_dataset = load_chunk(paths, batch_over_time, 
                                  gpu_batch_size, preprocess_fn) 

        full_dataset = full_dataset.chunk({'lat' : 50, 'lon' : 50, 'batch' : 128})
        
        # Setup computations using scattered data
        mean_by_level = full_dataset.mean(dim=['time', 'lat', 'lon', 'batch'])
        stddev_by_level = full_dataset.std(dim=['time', 'lat', 'lon', 'batch'], ddof=1)

        time_diffs = full_dataset.diff(dim='time')
        diffs_stddev_by_level = time_diffs.std(dim=['time', 'lat', 'lon', 'batch'], ddof=1)

        # Save results to NetCDF files (this triggers the computation)
        # Save results to NetCDF files (this triggers the computation)
        with ProgressBar():
            mean_by_level.to_netcdf(os.path.join(save_path, 'mean_by_level.nc'))
            stddev_by_level.to_netcdf(os.path.join(save_path, 'stddev_by_level.nc'))
            diffs_stddev_by_level.to_netcdf(os.path.join(save_path, 'diffs_stddev_by_level.nc'))

        # Close all datasets
        all_datasets = [full_dataset, mean_by_level, stddev_by_level, diffs_stddev_by_level]
        
        for ds in all_datasets:
            ds.close()

### Compute the normalization statistics for the WOFS_TASK_CONFIG

In [4]:
#%%time

import os
from os.path import join
from concurrent.futures import ThreadPoolExecutor

base_path = '/work/mflora/wofs-cast-data/datasets_zarr'#_zarr'
years = ['2019', '2020']

def get_files_for_year(year):
    year_path = join(base_path, year)
    with os.scandir(year_path) as it:
        return [join(year_path, entry.name) for entry in it if entry.is_dir() and entry.name.endswith('.zarr')]
        #return [join(year_path, entry.name) for entry in it if entry.is_file()]
    
with ThreadPoolExecutor() as executor:
    paths = []
    for files in executor.map(get_files_for_year, years):
        paths.extend(files)

print(len(paths))

random_paths = get_random_subset(paths, 4096)

18053
CPU times: user 26.5 ms, sys: 225 ms, total: 251 ms
Wall time: 772 ms


In [5]:
#%%time
# Save to NetCDF files
save_path = '/work/mflora/wofs-cast-data/test_norm_stats/'

compute_normalization_stats(random_paths, 
                            gpu_batch_size=len(random_paths), 
                            task_config=WOFS_TASK_CONFIG, 
                            save_path=save_path)

# 128 paths, gpu_batch=4 : 32s
# 128 paths, gpu_batch=32 : 11.1s

[########################################] | 100% Completed | 126.55 s
[########################################] | 100% Completed | 30.83 s
[########################################] | 100% Completed | 30.14 s
CPU times: user 8min 25s, sys: 4min 57s, total: 13min 23s
Wall time: 3min 41s


### Compute normalization statistics from DBZ_TASK_CONFIG_1HR