In [1]:
from datetime import datetime, timedelta
import xarray as xr
import numpy as np
import torch
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
import yaml
from credit.parser import credit_main_parser, predict_data_check
import multiprocessing as mp
from credit.datasets import setup_data_loading
from credit.forecast import load_forecasts
from credit.transforms import load_transforms
from torch.utils.data import Dataset, DataLoader, Sampler, DistributedSampler
import pandas as pd
from pydantic import BaseModel, Extra
from glob import glob
from os.path import join
import itertools

In [2]:
class ERA5Dataset(Dataset):

    """ Pytorch Dataset for processed ERA5 data. Relies on a configuration dictionary to define:
            1) 2D / 3D variables
            2) Start, End and Frequency of Datetimes
            3) the base path to the directory where the data is stored.
            4) Example YAML Format:

                data:
                  source:
                    ERA5:
                      vars_3D: ['T', 'U', 'V', 'Q']
                      vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
                      vars_persist: None
                      path: "/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_cesm_stage1/all_in_one/"

                  start_datetime: "2017-01-01"
                  end_datetime: "2019-12-31"
                  time_step: "6h"

            Note: `vars_persist` is not read by this class. A bare `None` in YAML is
            loaded as the string 'None' by `yaml.safe_load`; use `null` for a real None.

        Assumptions:
            1) The data must be stored in yearly zarr files with a unique 4-digit year (YYYY) in the file name
            2) "time" dimension / coordinate is present with the datetime64[ns] datatype
            3) "level" dimension name representing the vertical level
            4) Dimension order of ('time', 'level', 'latitude', 'longitude') for 3D vars (remove level for 2D)
            5) Stored Zarr data should be chunked efficiently for a fast read (recommend small chunks across time dimension).

        Args:
            config: nested configuration dictionary laid out as in the YAML example above.
            forecast_step: number of trailing timesteps excluded from `len(dataset)` so that
                an index plus its forecast offsets stays inside `self.datetimes`.
            return_target: when True, `__getitem__` returns an (input, target) pair where the
                target is the next timestep; when False it returns (input, timestamp, index).

            """
    def __init__(self, config, forecast_step=5, return_target=False):

        self.source_name = "ERA5"
        source_config = config['data']['source'][self.source_name]
        self.base_path = source_config['path']
        self.file_list = sorted(glob(join(self.base_path, "*")))
        self.start_datetime = config['data']['start_datetime']
        self.end_datetime = config['data']['end_datetime']
        self.time_step = config['data']['time_step']
        self.datetimes = pd.date_range(self.start_datetime, self.end_datetime, freq=self.time_step)
        self.vars_2D = source_config['vars_2D']
        self.vars_3D = source_config['vars_3D']
        self.forecast_step = forecast_step
        self.return_target = return_target
        self.files = self._map_files()

    def __len__(self):

        """ Number of valid starting timesteps (never negative, so `len()` cannot raise). """

        return max(len(self.datetimes) - self.forecast_step, 0)

    def __getitem__(self, args):

        """ Load one timestep.

        Args:
            args: either an (index, mode) pair as yielded by the batch samplers in this
                notebook (only the index is used), or a bare integer-like index.

        Returns:
            (x, y) float tensors when `return_target` is True, otherwise
            (x, "YYYYmmdd_HH00" timestamp string, integer index).
        """

        # The index may arrive as a 0-d tensor inside a tuple or as a plain integer.
        idx = int(args[0]) if isinstance(args, (tuple, list)) else int(args)

        dataset_x = self._open_file(self.files[idx], self.datetimes[idx])
        data_array_x = self._reshape_and_concat(dataset_x)

        if self.return_target:

            # The target is the timestep immediately after the input.
            idx_y = idx + 1
            dataset_y = self._open_file(self.files[idx_y], self.datetimes[idx_y])
            data_array_y = self._reshape_and_concat(dataset_y)

            return torch.from_numpy(data_array_x).float(), torch.from_numpy(data_array_y).float()

        else:

            return torch.from_numpy(data_array_x).float(), self.datetimes[idx].strftime("%Y%m%d_%H00"), idx

    def _map_files(self):

        """ Create a list of files that contain the data for a given time step.

        Also stores the {year: file} lookup on `self.file_map`. The year is matched
        against the file NAME only, so a year-like string in a parent directory cannot
        map every file to that year. If several files match a year, the last one in
        `self.file_list` wins.

        Raises:
            KeyError: if a year present in `self.datetimes` has no matching file.
        """

        # Only `join` is imported from os.path at file level.
        from os.path import basename

        # Scan unique years rather than every timestep.
        years = sorted({int(y) for y in self.datetimes.year})

        self.file_map = {}
        for f in self.file_list:
            file_name = basename(f.rstrip("/"))
            for year in years:
                if str(year) in file_name:
                    self.file_map[year] = f

        missing_years = [year for year in years if year not in self.file_map]
        if missing_years:
            raise KeyError(f"No data file found for year(s): {missing_years}")

        return [self.file_map[int(y)] for y in self.datetimes.year]

    def _open_file(self, filename, datetime):

        """ Open a specific file and subset a specific time step. """

        data = xr.open_zarr(filename).sel(time=datetime)

        return data

    def _reshape_and_concat(self, data):

        """ Stack 3D variables along level and variable, concatenate with 2D variables, and reorder dimensions.

        Assumes `data` is a single-timestep selection following the layout documented on
        the class, i.e. 3D variables are (level, latitude, longitude) and 2D variables are
        (latitude, longitude) — TODO confirm against the stored zarr files.
        """

        # Stacking puts the combined (variable, level) axis last; move it to the front.
        data_3D = data[self.vars_3D].to_array().stack({'level_var':['variable', 'level']}).values
        # Insert a singleton axis at position 1 (presumably a time axis — TODO confirm).
        data_3D = np.expand_dims(data_3D.transpose(2, 0, 1), axis=1)

        data_2D = np.expand_dims(data[self.vars_2D].to_array().values, axis=1)

        # Join along the leading axis: stacked 3D channels first, then 2D variables.
        combined_data = np.concatenate([data_3D, data_2D])

        return combined_data

In [3]:
class MultiStepBatchSamplerSubset(Sampler):

    def __init__(
        self,
        dataset,
        index_subset = None,  # if None, use entire dataset
        batch_size=3,
        num_forecast_steps=4,
        backprop_forecast_steps=None, # list from 1 to forecast_steps
    ) -> None:
        """
        Batch sampler that walks each batch of starting indices forward in time.

        Taking advantage of DistributedSampler class code with this dataset.
        Can be used on its own with index_subset=None.

        Args:
            dataset: dataset to sample from; only its length is used here.
            index_subset: starting indices to iterate over. When None or empty, a
                random permutation of range(len(dataset)) is used instead.
            batch_size: number of starting indices per batch.
            num_forecast_steps: number of steps yielded after each "init" batch.
            backprop_forecast_steps: forecast steps (1..num_forecast_steps) tagged
                "backprop"; the rest are tagged "forcing". When None or empty,
                every forecast step is tagged "backprop".
        """

        self.dataset = dataset

        if index_subset is not None and len(index_subset) > 0:
            # don't need to shuffle because distributed sampler shuffles for us.
            self.index_subset = torch.tensor(
                index_subset
            )  # must all be valid starting times, this is given by the DistributedSampler Wrapper
        else:
            self.index_subset = torch.randperm(len(dataset))

        self.batch_size = batch_size
        self.num_forecast_steps = num_forecast_steps
        if backprop_forecast_steps:
            self.backprop_forecast_steps = backprop_forecast_steps
        else:
            self.backprop_forecast_steps = list(range(1, self.num_forecast_steps + 1))

        # Ceiling division: a final partial batch still counts as a start batch.
        self.num_start_batches = (
            len(self.index_subset) + self.batch_size - 1
        ) // self.batch_size

    def __len__(self):
        # actual number of iters of the sampler: every start batch yields one
        # "init" batch plus one batch per forecast step.
        return self.num_start_batches * (self.num_forecast_steps + 1)

    def __iter__(self):
        """
        Yield lists of (index, mode) pairs.

        For each batch of starting indices k, yields num_forecast_steps + 1 batches
        of (k + i, mode), where mode is "init" for i == 0, "backprop" when i is in
        backprop_forecast_steps, and "forcing" otherwise. Indices stay 0-d tensors.
        """
        index_iter = iter(self.index_subset)

        while True:
            # iterate through batches of valid starting times
            batch = list(itertools.islice(index_iter, self.batch_size))
            if not batch:
                return

            # for each batch of valid starting times,
            # iterate through subsequent valid forecast times
            for i in range(self.num_forecast_steps + 1):
                if i == 0:
                    mode = "init"
                elif i in self.backprop_forecast_steps:
                    mode = "backprop"
                else:
                    mode = "forcing"

                yield [(k + i, mode) for k in batch]


class DistributedMultiStepBatchSampler(DistributedSampler):
    """
    DistributedSampler that hands this rank's indices to MultiStepBatchSamplerSubset.

    The parent class picks (and optionally shuffles) the starting indices for this
    rank; iteration then yields the multi-step (index, mode) batches built from them.

    Args:
        dataset: dataset to sample from.
        batch_size: number of starting indices per batch.
        num_forecast_steps: number of forecast steps yielded after each "init" batch.
        backprop_forecast_steps: forwarded to MultiStepBatchSamplerSubset; None or
            empty lets that sampler apply its own default.
        num_replicas, rank, shuffle, seed, drop_last: forwarded unchanged to
            torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset: Dataset,
                 batch_size: int,
                 num_forecast_steps: int,
                 backprop_forecast_steps=None,
                 num_replicas = None,
                 rank = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False,
                 ) -> None:

        super().__init__(dataset=dataset, num_replicas=num_replicas,
                         rank=rank, shuffle=shuffle, seed=seed,
                         drop_last=drop_last)

        self.batch_size = batch_size
        self.num_forecast_steps = num_forecast_steps
        self.backprop_forecast_steps = backprop_forecast_steps

    def __iter__(self):

        # Starting indices assigned to this rank by the parent DistributedSampler.
        indices = list(super().__iter__())

        batch_sampler = MultiStepBatchSamplerSubset(self.dataset,
                                                    indices,
                                                    batch_size=self.batch_size,
                                                    num_forecast_steps=self.num_forecast_steps,
                                                    backprop_forecast_steps=self.backprop_forecast_steps,
                                                    )
        return iter(batch_sampler)

    def __len__(self) -> int:

        # Number of batches yielded by __iter__: the rank's indices are grouped into
        # ceil(num_samples / batch_size) start batches, and each start batch yields
        # one "init" batch plus one batch per forecast step.
        num_start_batches = (self.num_samples + self.batch_size - 1) // self.batch_size

        return num_start_batches * (self.num_forecast_steps + 1)

In [2]:
# Load the YAML configuration that drives ERA5Dataset (see the example layout in its docstring).
# NOTE(review): this is a hard-coded absolute path on a specific filesystem; the cell
# will not run elsewhere without editing it.
path = "/glade/work/cbecker/notebooks/credit_test_yaml.yaml"
with open(path) as cnfg:
    # safe_load builds plain Python containers only (dicts, lists, strings, numbers).
    config = yaml.safe_load(cnfg)

In [3]:
# Display the parsed configuration. Note in the output that 'vars_persist' is the
# string 'None', not Python None: YAML only maps null / ~ to None.
config

{'data': {'source': {'ERA5': {'vars_3D': ['T', 'U', 'V', 'Q'],
    'vars_2D': ['T500', 'U500', 'V500', 'Q500', 'Z500', 'tsi', 't2m', 'SP'],
    'vars_persist': 'None',
    'path': '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_cesm_stage1/all_in_one/'}},
  'start_datetime': '2017-01-01',
  'end_datetime': '2019-12-31',
  'time_step': '6h'}}

In [5]:
# Simulate one process (rank 0) of a two-process distributed run. Passing both
# num_replicas and rank explicitly means no process group has to be initialised.
rank = 0
world_size = 2
# NOTE(review): `data` is not consumed by any statement in this cell.
data = list(range(25))
dataset = ERA5Dataset(config=config)
sampler = DistributedMultiStepBatchSampler(
    dataset,
    num_forecast_steps=3,
    batch_size=6,
    num_replicas=world_size,  # use the declared values instead of repeating the literals
    rank=rank,
    shuffle=True,
)
# The sampler yields whole batches of (index, mode) pairs, so it is passed as batch_sampler.
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4, prefetch_factor=2)

In [7]:
# Drain the loader once and keep every collated batch for inspection in later cells.
# NOTE(review): every batch (including its tensor) is retained in `l`, so memory grows
# with the number of batches; the loop variable `data` also rebinds the global `data`
# defined in the loader-setup cell.
l = []
for batch_idx, data in enumerate(loader):
    print(f"Batch {batch_idx}")
    l.append(data)

In [8]:
# Element 1 of each collated batch is the tuple of timestamp strings returned by
# ERA5Dataset.__getitem__ (return_target=False). Printing the first ten batches shows
# each start batch followed by its forecast steps, 6 hours apart.
for i in range(10):
    print(l[i][1])

('20181214_0000', '20181014_0600', '20170810_1800', '20170731_0600', '20170603_0600', '20180228_1800')
('20181214_0600', '20181014_1200', '20170811_0000', '20170731_1200', '20170603_1200', '20180301_0000')
('20181214_1200', '20181014_1800', '20170811_0600', '20170731_1800', '20170603_1800', '20180301_0600')
('20181214_1800', '20181015_0000', '20170811_1200', '20170801_0000', '20170604_0000', '20180301_1200')
('20171020_0000', '20190121_0000', '20190218_0000', '20190226_0000', '20180729_1200', '20190421_1800')
('20171020_0600', '20190121_0600', '20190218_0600', '20190226_0600', '20180729_1800', '20190422_0000')
('20171020_1200', '20190121_1200', '20190218_1200', '20190226_1200', '20180730_0000', '20190422_0600')
('20171020_1800', '20190121_1800', '20190218_1800', '20190226_1800', '20180730_0600', '20190422_1200')
('20170501_0000', '20190728_0600', '20170928_0600', '20180715_0600', '20181031_1800', '20171204_0600')
('20170501_0600', '20190728_1200', '20170928_1200', '20180715_1200', '201