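## Compute normalization statistics for the GraphCast-based WoFSCast model

This notebook computes the mean, standard deviation, and standard deviation of time differences of the WoFS Zarr datasets (reduced over time, lat, lon, and batch). It writes them as NetCDF files to `save_path`.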

In [1]:
import xarray as xr 
import numpy as np
from glob import glob

import random 
import os

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

from wofscast import graphcast_lam as graphcast
import dask 

from wofscast.data_generator import (add_local_solar_time, 
                                     to_static_vars, 
                                     load_chunk, 
                                     dataset_to_input,
                                     ZarrDataGenerator)
from wofscast import data_utils
from wofscast.wofscast_task_config import DBZ_TASK_CONFIG, WOFS_TASK_CONFIG
from os.path import join
import dask 
from dask.distributed import Client

# Output directory for the normalization-statistics NetCDF files
# (mean_by_level.nc, stddev_by_level.nc, diffs_stddev_by_level.nc).
save_path = '/work/mflora/wofs-cast-data/full_normalization_stats/'

In [2]:
%%time
import os
from os.path import join
from concurrent.futures import ThreadPoolExecutor

# Root directory holding one sub-directory per year of Zarr datasets.
base_path = '/work/mflora/wofs-cast-data/datasets_zarr'
# Year sub-directories to scan; their order is the order in which files are appended to `paths` below.
years = ['2019', '2020']

def get_files_for_year(year, base_dir=None):
    """Return the sorted paths of the ``*.zarr`` directories for one year.

    Args:
        year: Name of the year sub-directory (e.g. ``'2019'``).
        base_dir: Root directory containing the year sub-directories.
            Defaults to the notebook-level ``base_path``.

    Returns:
        list[str]: ``join(base_dir, year, name)`` for every entry that is a
        directory whose name ends with ``.zarr``, sorted lexicographically.
        ``os.scandir`` yields entries in arbitrary, filesystem-dependent order,
        so sorting is what makes later slices such as ``paths[:N]`` reproducible.
    """
    root = base_path if base_dir is None else base_dir
    year_path = join(root, year)
    with os.scandir(year_path) as entries:
        # Check the (free) name suffix before is_dir(), which may require a stat call.
        zarr_dirs = [join(year_path, entry.name) for entry in entries
                     if entry.name.endswith('.zarr') and entry.is_dir()]
    return sorted(zarr_dirs)
    
with ThreadPoolExecutor() as executor:
    # executor.map returns results in the order of `years`, so `paths` is grouped by year.
    files_per_year = list(executor.map(get_files_for_year, years))
paths = [path for year_files in files_per_year for path in year_files]

# Fail fast here instead of handing an empty list to load_chunk later on.
if not paths:
    raise FileNotFoundError(f'No *.zarr directories found under {base_path} for years {years}')

print(len(paths))

18053
CPU times: user 34.5 ms, sys: 29.9 ms, total: 64.3 ms
Wall time: 41.3 ms


In [3]:
def compute_normalization_stats(paths, gpu_batch_size, task_config, save_path):
    """Compute normalization statistics and write them as NetCDF files.

    Three files are written to ``save_path``:

    * ``mean_by_level.nc``: mean over ``('time', 'lat', 'lon', 'batch')``.
    * ``stddev_by_level.nc``: sample standard deviation (``ddof=1``) over the
      same dimensions.
    * ``diffs_stddev_by_level.nc``: sample standard deviation (``ddof=1``) of
      the first differences along ``time``, over the same dimensions.

    Args:
        paths: Dataset paths passed to ``load_chunk``.
        gpu_batch_size: Forwarded to ``load_chunk``.
        task_config: Currently unused; kept for interface compatibility.
            NOTE(review): presumably intended to select the variables to
            normalize. Confirm before relying on it.
        save_path: Output directory. It is created if it does not exist.
    """
    os.makedirs(save_path, exist_ok=True)
    reduce_dims = ['time', 'lat', 'lon', 'batch']

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        full_dataset = load_chunk(paths, gpu_batch_size=gpu_batch_size)
        try:
            # Build the (lazy, if dask-backed) reductions.
            mean_by_level = full_dataset.mean(dim=reduce_dims)
            stddev_by_level = full_dataset.std(dim=reduce_dims, ddof=1)
            diffs_stddev_by_level = full_dataset.diff(dim='time').std(dim=reduce_dims, ddof=1)

            # Evaluate all three in one pass so the input data is read once.
            # Separate to_netcdf calls would each trigger their own computation
            # and re-read it. Non-dask objects are returned unchanged.
            mean_by_level, stddev_by_level, diffs_stddev_by_level = dask.compute(
                mean_by_level, stddev_by_level, diffs_stddev_by_level)

            outputs = (
                ('mean_by_level.nc', mean_by_level),
                ('stddev_by_level.nc', stddev_by_level),
                ('diffs_stddev_by_level.nc', diffs_stddev_by_level),
            )
            for filename, stats in outputs:
                stats.to_netcdf(os.path.join(save_path, filename))
        finally:
            # Release the underlying stores even if a computation or write fails.
            full_dataset.close()

In [4]:
# Use a reproducible random subset of files. `paths` is grouped by year, so
# taking its head (paths[:N]) could draw every file from the first year.
# Its within-year order also comes from the directory listing.
N_STAT_FILES = 4096  # number of files used to estimate the statistics
SAMPLE_SEED = 42     # fixes which files are sampled
stat_paths = random.Random(SAMPLE_SEED).sample(sorted(paths), k=min(N_STAT_FILES, len(paths)))

compute_normalization_stats(stat_paths, gpu_batch_size=32,
                            task_config=WOFS_TASK_CONFIG, save_path=save_path)

CPU times: user 20min 34s, sys: 9min 50s, total: 30min 24s
Wall time: 9min 51s
