In [1]:
import sys, os 
sys.path.append('/home/monte.flora/python_packages/frdd-wofs-cast/')

from wofscast import data_utils
from wofscast.wofscast_task_config import WOFS_TASK_CONFIG, train_lead_times
from wofscast.data_generator import to_static_vars, add_local_solar_time, load_wofscast_data

import xarray as xr
from glob import glob
import numpy as np
import dataclasses
import random 


import random
from collections import defaultdict

In [2]:
def check_for_nans(dataset):
    """Print the index positions of null values for every variable in ``dataset``.

    For each ``(name, variable)`` pair yielded by ``dataset.items()`` a header
    line is printed, followed by one line per dimension of the null mask
    listing the indices along that dimension where values are null. The
    header is printed for every variable, including those without any nulls
    (their index arrays are simply empty).

    Parameters
    ----------
    dataset : mapping-like
        Object exposing ``.items()`` whose values provide ``.isnull()``; the
        returned mask must expose ``.dims``.

    Returns
    -------
    None
        The report is written to stdout only.
    """
    for name, variable in dataset.items():
        # Boolean mask that is True wherever the variable is null.
        is_missing = variable.isnull()

        # np.where on a mask gives one index array per dimension; pair each
        # with its dimension name to build the report for this variable.
        report = [f"NaN locations in {name}:"]
        report.extend(
            f"  {dim}: {positions}"
            for dim, positions in zip(is_missing.dims, np.where(is_missing))
        )
        print("\n".join(report))

### Randomly sample the different cases and ensemble members to improve training dataset diversity 

In [3]:
# Collect the WoFS netCDF files used to build the training set.
# Only the 2019 and 2020 directories are globbed; 2021 is not read here
# (TODO confirm: 2021 is intended to be held out as the evaluation dataset).
data_paths = []
for year in ['2019', '2020']:
    # glob returns matches in arbitrary, filesystem-dependent order; sorting
    # keeps the file list (and anything derived from it) stable across runs.
    data_paths.extend(sorted(glob(f'/work/mflora/wofs-cast-data/datasets/{year}/wrf*.nc')))

# Function to parse date and ensemble member from a file name
def parse_file_info(file_name):
    """Split a file name into its date and ensemble-member tokens.

    The name is split on underscores. The second token is returned as the
    date, and the last token, truncated at its first '.', is returned as the
    ensemble member.

    Parameters
    ----------
    file_name : str
        File name (not a full path) containing at least one underscore.

    Returns
    -------
    tuple of (str, str)
        ``(date, ens_mem)``.

    Raises
    ------
    IndexError
        If ``file_name`` contains no underscore.
    """
    tokens = file_name.split('_')
    # Everything before the first '.' of the final token (drops the extension).
    member, _, _ = tokens[-1].partition('.')
    return tokens[1], member

# Group the file paths as {date: {ensemble member: [paths]}} so that sampling
# can be stratified over every (date, ensemble member) combination.
files_by_date_and_ens = defaultdict(lambda: defaultdict(list))
for file_path in data_paths:
    file_name = os.path.basename(file_path)
    date, ens_mem = parse_file_info(file_name)
    files_by_date_and_ens[date][ens_mem].append(file_path)

# Maximum number of files drawn for each (date, ensemble member) combination.
samples_per_date_and_ens = 3

# Use a dedicated, seeded generator so that re-running the notebook selects
# the same files, without touching the global `random` state.
sampling_seed = 42
rng = random.Random(sampling_seed)

# Sample files. Groups and their file lists are visited in sorted order so
# the draw does not depend on the order in which the paths were discovered.
sampled_files = []
for date in sorted(files_by_date_and_ens):
    ens_members = files_by_date_and_ens[date]
    for ens_mem in sorted(ens_members):
        files = sorted(ens_members[ens_mem])
        if len(files) >= samples_per_date_and_ens:
            sampled_files.extend(rng.sample(files, samples_per_date_and_ens))
        else:
            sampled_files.extend(files)  # Keep all if fewer files than desired samples

# The list above is grouped by date. Shuffle it so that taking a prefix
# (e.g. sampled_files[:N]) yields a random subset across all dates and
# members instead of silently dropping the latest dates.
rng.shuffle(sampled_files)

# sampled_files now contains the randomly sampled files
print(len(sampled_files))

4860


In [4]:
from tqdm.notebook import tqdm
def read_netcdfs(paths, dim, transform_func=None):
    """Open several netCDF files and concatenate them along ``dim``.

    Each file is opened with ``xr.open_dataset``, optionally passed through
    ``transform_func``, and loaded into memory before its file handle is
    closed. Progress is reported with a tqdm bar.

    Parameters
    ----------
    paths : iterable
        Paths of the files to read, in the order they are concatenated.
    dim : str
        Dimension passed to ``xr.concat``.
    transform_func : callable, optional
        Called with each opened dataset; its return value is what gets loaded
        and concatenated. When None, the opened dataset is used as-is.

    Returns
    -------
    The result of ``xr.concat`` over all loaded datasets.
    """
    def _load_single(nc_path):
        # The context manager guarantees the file is closed once we leave it.
        with xr.open_dataset(nc_path) as opened:
            result = opened if transform_func is None else transform_func(opened)
            # Pull the (possibly transformed) data into memory so it remains
            # usable after the underlying file has been closed.
            result.load()
            return result

    loaded = []
    for nc_path in tqdm(paths, desc="Processing files"):
        loaded.append(_load_single(nc_path))

    return xr.concat(loaded, dim)

In [5]:
%%time

# Number of sampled files read into memory for the training set.
n_train_files = 3000

# Each selected file is opened, passed through add_local_solar_time, loaded
# into memory and concatenated along 'batch' by read_netcdfs.
# (The alternative load_wofscast_data(sampled_files, train_lead_times,
#  WOFS_TASK_CONFIG, client) path is left disabled in this notebook.)
# NOTE(review): a prefix slice keeps the ordering of sampled_files; if that
# list is grouped by date, the files beyond the cut-off all come from the
# last groups -- confirm this is intended.
dataset = read_netcdfs(
    sampled_files[:n_train_files],
    dim='batch',
    transform_func=add_local_solar_time,
)

# Split the combined dataset into model inputs, targets and forcings using
# the task configuration fields as keyword arguments.
task_kwargs = dataclasses.asdict(WOFS_TASK_CONFIG)
inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
    dataset,
    target_lead_times=train_lead_times,
    **task_kwargs,
)

Processing files:   0%|          | 0/3000 [00:00<?, ?it/s]

CPU times: user 6min 52s, sys: 1min 12s, total: 8min 4s
Wall time: 20min 33s


In [6]:
# Report the dimension sizes of each training split.
for label, split in (
    ('Train Input Dims: ', inputs),
    ('Train Target Dims: ', targets),
    ('Train Forcing Dims: ', forcings),
):
    print(label, split.dims.mapping)

Train Input Dims:  {'batch': 3000, 'time': 2, 'level': 17, 'lat': 150, 'lon': 150}
Train Target Dims:  {'batch': 3000, 'time': 1, 'level': 17, 'lat': 150, 'lon': 150}
Train Forcing Dims:  {'batch': 3000, 'time': 1, 'lon': 150, 'lat': 150}


In [8]:
# Combined in-memory size of the three training splits, in bytes.
memory_usage_bytes = sum(split.nbytes for split in (targets, inputs, forcings))

# Convert to gigabytes (1024**3 bytes each) for reporting.
memory_usage_gb = memory_usage_bytes / (1024**3)
print(f"Total Memory Usage for Inputs, Targets, and Forcings: {memory_usage_gb:.2f} GB")


# Reference sizes noted for the new reduced (~15-17 MB) files:
#  500 files -> 13.66 GB
# 1000 files -> 27 GB
# 2000 files -> 54 GB
# 4000 files -> 110 GB
# 5000 files -> 135 GB


Total Memory Usage for Inputs, Targets, and Forcings: 81.98 GB


In [9]:
%%time
# Directory that receives the three training netCDF files.
out_path = '/work/mflora/wofs-cast-data/train_datasets'

# Create the directory up front (no-op if it already exists) so that a missing
# directory does not fail the write after the long loading step.
os.makedirs(out_path, exist_ok=True)

inputs.to_netcdf(os.path.join(out_path, 'train_inputs.nc'))
targets.to_netcdf(os.path.join(out_path, 'train_targets.nc'))
forcings.to_netcdf(os.path.join(out_path, 'train_forcings.nc'))

In [10]:
import gc
# Run a full garbage-collection pass; the value displayed under this cell is
# gc.collect()'s return value.
# NOTE(review): none of the visible cells `del` the large objects (dataset,
# inputs, targets, forcings) before this call, so presumably their memory is
# not released here -- confirm whether a `del` was intended first.
gc.collect()

81