In [1]:
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))
from wofscast.data_generator import (
                                     add_local_solar_time, 
                                     to_static_vars, 
                                     dataset_to_input,
                                     load_chunk,
                                     shard_xarray_dataset,
                                     extract_datetime_from_path,
                                     ZarrDataGenerator,
                                     WoFSDataProcessor
                                    )
from wofscast import data_utils
from wofscast import wofscast_task_config 
from wofscast import xarray_jax 

from os.path import join
import dataclasses

import xarray as xr
import numpy as np
from dask.distributed import Client, LocalCluster
import pandas as pd

import dask
from dask.diagnostics import ProgressBar
from concurrent.futures import ThreadPoolExecutor

from datetime import datetime
import time

In [2]:
%%time
import os
from os.path import join
from concurrent.futures import ThreadPoolExecutor


def get_paths(base_paths, years=('2019', '2020')):
    """Collect the ``.zarr`` directories found under ``<base_path>/<year>``.

    Every (base_path, year) combination is listed in its own worker thread.
    Results are concatenated in a fixed order: base paths in the order given,
    then years in the order given, then directory names sorted alphabetically.
    The sort matters because ``os.scandir`` yields entries in arbitrary order,
    which would otherwise make any later seeded shuffle of the result
    irreproducible across filesystems and runs.

    Parameters
    ----------
    base_paths : iterable of str (or a single str / path-like)
        Root directories that contain one sub-directory per year.
    years : iterable of str (or a single str)
        Names of the year sub-directories to scan.

    Returns
    -------
    list of str
        Paths of the sub-directories whose name ends with ``.zarr``.
        Regular files are ignored even if their name ends with ``.zarr``.

    Raises
    ------
    OSError
        Propagated from ``os.scandir`` (e.g. ``FileNotFoundError``) when a
        ``<base_path>/<year>`` directory cannot be listed.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(base_paths, (str, os.PathLike)):
        base_paths = [base_paths]
    if isinstance(years, str):
        years = [years]

    def get_files_for_year(year, base_path):
        """ Helper function to get zarr directories for a given year and base path """
        year_path = join(base_path, year)
        with os.scandir(year_path) as it:
            return sorted(join(year_path, entry.name) for entry in it
                          if entry.is_dir() and entry.name.endswith('.zarr'))

    paths = []
    # Use a thread pool to list every (year, base_path) combination in parallel.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(get_files_for_year, year, base_path)
                   for base_path in base_paths for year in years]
        # Consume in submission order so the output order is deterministic;
        # result() re-raises any exception hit by the worker.
        for future in futures:
            paths.extend(future.result())

    return paths

# Example usage. The commented-out `base_paths` assignments below are
# alternative dataset roots; uncomment one (and comment out the active
# assignment) to scan a different dataset.
#base_paths = [
#                '/work/mflora/wofs-cast-data/datasets_zarr',
#              '/work2/mflora/wofscast_datasets/dataset_10min_15min_init_train'
#            ]

# 5-min data
#base_paths = ['/work2/mflora/wofscast_datasets/new_dataset_5min']

#base_paths = ['/work2/mflora/wofscast_datasets/test_new_data_generation/']

# Active dataset root.
base_paths = ['/work2/mflora/wofscast_datasets/dataset_10min_full_domain/']

# Gather the .zarr directories for both years and report how many were found.
paths = get_paths(base_paths, years=['2019', '2020'])
print(len(paths))

630
CPU times: user 2.72 ms, sys: 1.24 ms, total: 3.96 ms
Wall time: 5.21 ms


### TEST TIME FOR LOADING MULTIPLE FILES. 

### Check dataset_to_input

### TEST TIME FOR LOADING A SINGLE FILE. 

In [8]:
%%time

# Fixed seed so the shuffled file order is reproducible.
rs = np.random.RandomState(42)

# Sort before shuffling: the shuffle is in place, so without this a re-run of
# the cell would shuffle the already-shuffled list (and the starting order
# would depend on how the paths were listed). Sorting first makes the result
# a function of the seed and the set of paths only.
paths.sort()
rs.shuffle(paths)

def preprocess_fn(ds):
    """Preprocessing hook handed to ``ZarrDataGenerator``.

    Passes ``ds`` through ``add_local_solar_time`` and returns the result.
    """
    # Disabled step, kept for reference: when ds.dims['time'] == 4, drop the
    # last time step with ds.isel(time=[0,1,2], datetime=[0,1,2]).
    return add_local_solar_time(ds)


# Build the batch generator over the shuffled paths.
generator = ZarrDataGenerator(paths, 
                              wofscast_task_config.WOFS_TASK_CONFIG,#_GC, 
                              target_lead_times=None,
                              batch_size=4, 
                              num_devices=2, 
                              preprocess_fn=preprocess_fn,
                              prefetch_size=2,
                              random_seed=240, 
                              decode_times=False,
                             )

# Time `n_steps` consecutive generate() calls; duration[i] holds the elapsed
# seconds of call i. perf_counter() is used instead of time.time() because it
# is monotonic and high-resolution, so the measurement cannot be skewed by
# system clock adjustments.
n_steps = 4
duration = np.zeros(n_steps)
for i in range(n_steps):
    start_time = time.perf_counter()
    inputs, targets, forcings = generator.generate()
    duration[i] = time.perf_counter() - start_time

CPU times: user 2.93 s, sys: 1.31 s, total: 4.24 s
Wall time: 2.49 s


In [9]:
# Display the `inputs` returned by the last generate() call in the timing loop.
inputs 

In [10]:
# Display the `forcings` returned by the last generate() call in the timing loop.
forcings 

In [6]:
# Average elapsed seconds per generate() call over the timed steps.
duration.mean()

4.734031975269318

In [7]:
# Size of one `inputs` batch in GiB (assuming nbytes is a byte count; 1024**3 bytes per GiB).
inputs.nbytes / 1024**3

0.8609914779663086