## Compute the normalization statistics for the GraphCast code

In [1]:
import xarray as xr 
import numpy as np
from glob import glob

import random 
import os

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

from wofscast import graphcast_lam as graphcast
import dask 

from wofscast.data_generator import (add_local_solar_time, 
                                     to_static_vars, 
                                     load_chunk, 
                                     dataset_to_input,
                                     ZarrDataGenerator, 
                                     WRFZarrFileProcessor,
                                     WoFSDataProcessor
                                    )
from wofscast import data_utils
from wofscast.wofscast_task_config import (DBZ_TASK_CONFIG, 
                                           WOFS_TASK_CONFIG, 
                                           DBZ_TASK_CONFIG_1HR,
                                           DBZ_TASK_CONFIG_FULL,
                                           WOFS_TASK_CONFIG_1HR,
                                           WOFS_TASK_CONFIG_5MIN, 
                                          )
from os.path import join

In [2]:
import random

def get_random_subset(input_list, subset_size, seed=123):
    """
    Get a random subset of a specified size from the input list.

    Parameters:
    -----------
    input_list : list
        The original list from which to draw the subset.
    subset_size : int
        The size of the subset to be drawn.
    seed : int, optional
        The seed for the random number generator. Default is None.

    Returns:
    --------
    list
        A random subset of the input list.
    """
    if subset_size > len(input_list):
        raise ValueError("subset_size must be less than or equal to the length of the input list")
    
    if seed is not None:
        random.seed(seed)

    return random.sample(input_list, subset_size)

In [3]:
import os
import dask
import dask.array as da
from dask.distributed import Client, progress
from dask.diagnostics import ProgressBar

def compute_normalization_stats(paths, gpu_batch_size, task_config, save_path, 
                                preprocess_fn=None, n_workers=4, threads_per_worker=1): 
    #client = Client(n_workers=n_workers, threads_per_worker=threads_per_worker)
    #print(client)

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        
        # Load dataset in chunks
        full_dataset = load_chunk(paths, gpu_batch_size, preprocess_fn) 

        # Re-chunk the dataset for efficient computation
        full_dataset = full_dataset.chunk({'lat': 50, 'lon': 50, 'batch': 128})
        
        # Persist the dataset in memory to avoid reloading and recomputing
        full_dataset = full_dataset.persist()
        #progress(full_dataset)

        # Compute mean and standard deviation by level
        mean_by_level = full_dataset.mean(dim=['time', 'lat', 'lon', 'batch']).persist()
        stddev_by_level = full_dataset.std(dim=['time', 'lat', 'lon', 'batch'], ddof=1).persist()

        # Compute standard deviation of time differences
        time_diffs = full_dataset.diff(dim='time')
        diffs_stddev_by_level = time_diffs.std(dim=['time', 'lat', 'lon', 'batch'], ddof=1).persist()

        # Save results to NetCDF files (this triggers the computation)
        with ProgressBar():
            mean_by_level.to_netcdf(os.path.join(save_path, 'mean_by_level.nc'))
            stddev_by_level.to_netcdf(os.path.join(save_path, 'stddev_by_level.nc'))
            diffs_stddev_by_level.to_netcdf(os.path.join(save_path, 'diffs_stddev_by_level.nc'))

        # Close all datasets
        all_datasets = [full_dataset, mean_by_level, stddev_by_level, diffs_stddev_by_level]
        
        for ds in all_datasets:
            ds.close()

### Compute the normalization statistics for the WOFS_TASK_CONFIG

In [4]:
%%time
# Save to NetCDF files
save_path = '/work/mflora/wofs-cast-data/norm_stats_5min/'

#%%time

import os
from os.path import join
from concurrent.futures import ThreadPoolExecutor

base_path = '/work/mflora/wofs-cast-data/datasets_5min'#_zarr'
years = ['2019', '2020']

def get_files_for_year(year):
    year_path = join(base_path, year)
    with os.scandir(year_path) as it:
        return [join(year_path, entry.name) for entry in it 
                if entry.is_dir() and entry.name.endswith('ens_mem_09.zarr')]
        #return [join(year_path, entry.name) for entry in it if entry.is_file()]
    
with ThreadPoolExecutor() as executor:
    paths = []
    for files in executor.map(get_files_for_year, years):
        paths.extend(files)

print(len(paths))

random_paths = get_random_subset(paths, 512)



def preprocess_fn(dataset):
    latlon_path = '/work/mflora/wofs-cast-data/datasets_zarr/2021/wrfwof_2021-05-15_040000_to_2021-05-15_043000__10min__ens_mem_09.zarr'
    preprocess = WoFSDataProcessor(latlon_path=latlon_path)
    dataset = preprocess(dataset)
    dataset = add_local_solar_time(dataset)
    
    dataset = dataset.drop_dims('datetime')
    
    return dataset

compute_normalization_stats(random_paths, 
                            gpu_batch_size=len(random_paths), 
                            task_config=WOFS_TASK_CONFIG_5MIN, 
                            save_path=save_path, 
                           preprocess_fn=preprocess_fn)

# 128 paths, gpu_batch=4 : 32s
# 128 paths, gpu_batch=32 : 11.1s

1003
[########################################] | 100% Completed | 102.05 ms
[########################################] | 100% Completed | 101.61 ms
[########################################] | 100% Completed | 101.57 ms
CPU times: user 4min 32s, sys: 2min 21s, total: 6min 53s
Wall time: 4min 14s


### Compute normalization statistics from DBZ_TASK_CONFIG_1HR