# Goal
The goal of this notebook is the development of the Data Loader class with (hopefully) best practices

In [56]:
import pandas as pd
from netCDF4 import Dataset as DS
import numpy as np
import torch
import glob
import logging
from torch.utils.data import DataLoader, Dataset


## Initialize class

In [9]:
file_path = '/hkfs/work/workspace/scratch/ke4365-pangu/PANGU_ERA5_data_v0/'

In [202]:
class GetDataset(Dataset):
    def __init__(self, params, location, train):
        self.params = params
        self.file_path= file_path
        self.train = train
        print("params", params)
        self.dt = params['dt']
        self.n_history = params['n_history']
        self.in_channels = np.array(params['in_channels'])
        self.out_channels = np.array(params['out_channels'])
        self.n_in_channels  = 0#len(self.in_channels)
        self.n_out_channels = 0#len(self.out_channels)
        self.roll = params['roll']
        self._get_files_stats(file_path)
        self.add_noise = params['add_noise'] if train else False
        
        try:
            self.normalize = params.normalize
        except:
            self.normalize = True #by default turn on normalization if not specified in config

    def _get_files_stats(self, file_path, dt=6):
        self.files_paths_pressure = glob.glob(file_path + "/????.h5") # indicates file paths for pressure levels
        self.files_paths_surface = glob.glob(file_path + "/single_????.h5") # indicates file paths for pressure levels

        
        self.files_paths_pressure.sort()
        self.files_paths_surface.sort()
        assert len(self.files_paths_pressure) == len(self.files_paths_surface), "Number of years not identical in pressure vs. surface level data."
    
        self.n_years = len(self.files_paths_pressure)
        with h5py.File(self.files_paths_pressure[0], 'r') as _f:
            logging.info("Getting file stats from {}".format(self.files_paths_pressure[0]))
            self.n_samples_per_year = _f['fields'].shape[0]
            #original image shape (before padding)
            self.img_shape_x = _f['fields'].shape[2]
            self.img_shape_y = _f['fields'].shape[3]
            self.n_in_channels = 13 #TODO
        self.n_samples_total = self.n_years * self.n_samples_per_year
        self.files_pressure = [None for _ in range(self.n_years)]
        self.files_surface = [None for _ in range(self.n_years)]
        
        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(file_path, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels))
        logging.info("Delta t: {} hours".format(6*self.dt))
        
    def _open_pressure_file(self, year_idx):
        _file = h5py.File(self.files_paths_pressure[year_idx], 'r')
        self.files_pressure[year_idx] = _file
        
    def _open_surface_file(self, year_idx):
        print(self.files_paths_surface[year_idx])
        _file = h5py.File(self.files_paths_surface[year_idx], 'r')
        self.files_surface[year_idx] = _file
      
    def __len__(self):
        return self.n_samples_total
    
    def __getitem__(self, global_idx):
        year_idx  = int(global_idx/self.n_samples_per_year) #which year we are on
        local_idx = int(global_idx%self.n_samples_per_year) #which sample in that year we are on - determines indices for centering
        print("year idx", year_idx)    
        print("local idx", local_idx)    
        y_roll = np.random.randint(0, 1440) if self.train else 0#roll image in y direction
    
        #open image file
        if self.files_pressure[year_idx] is None:
            self._open_pressure_file(year_idx)
    
        if self.files_surface[year_idx] is None:
            self._open_surface_file(year_idx)
        
        step = self.dt
        #if we are not at least self.dt*n_history timesteps into the prediction
        if local_idx < self.dt*self.n_history:
            local_idx += self.dt*self.n_history
    
            #if we are on the last image in a year predict identity, else predict next timestep
            step = 0 if local_idx >= self.n_samples_per_year-self.dt else self.dt
        
        if self.train and self.roll:
          y_roll = random.randint(0, self.img_shape_y)
        else:
          y_roll = 0

        return self.files_pressure[year_idx]['fields'][local_idx], \
               self.files_pressure[year_idx]['fields'][local_idx+step], \
               self.files_surface[year_idx]['fields'][local_idx], \
               self.files_surface[year_idx]['fields'][local_idx+step]
                
        #return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, self.params, y_roll, self.train, self.normalize, self.add_noise), \
        #       reshape_fields(self.files[year_idx][local_idx + step, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, self.params, y_roll, self.train, self.normalize)


In [203]:
params = {}
params['dt'] = 1
params['n_history'] = 1
params['in_channels'] = 8
params['out_channels'] = 8
params['roll'] = False
params['add_noise'] = False

dataset = GetDataset(params=params, location='/hkfs/work/workspace/scratch/ke4365-pangu/PANGU_ERA5_data_v0/', train=True)
dataset.__getitem__(0)

params {'dt': 1, 'n_history': 1, 'in_channels': 8, 'out_channels': 8, 'roll': False, 'add_noise': False}
year idx 0
local idx 0
/hkfs/work/workspace/scratch/ke4365-pangu/PANGU_ERA5_data_v0/single_1980.h5


(array([[[[ 1.72310974e+03,  1.72310974e+03,  1.72310974e+03, ...,
            1.72310974e+03,  1.72310974e+03,  1.72310974e+03],
          [ 1.70017114e+03,  1.70017114e+03,  1.70017114e+03, ...,
            1.70017114e+03,  1.70017114e+03,  1.70017114e+03],
          [ 1.68050940e+03,  1.68050940e+03,  1.68050940e+03, ...,
            1.68050940e+03,  1.68050940e+03,  1.68050940e+03],
          ...,
          [ 2.74698151e+02,  2.74698151e+02,  2.71421204e+02, ...,
            2.74698151e+02,  2.74698151e+02,  2.74698151e+02],
          [ 2.71421204e+02,  2.71421204e+02,  2.71421204e+02, ...,
            2.71421204e+02,  2.71421204e+02,  2.71421204e+02],
          [ 2.74698151e+02,  2.74698151e+02,  2.74698151e+02, ...,
            2.74698151e+02,  2.74698151e+02,  2.74698151e+02]],
 
         [[ 7.32341602e+03,  7.32341602e+03,  7.32341602e+03, ...,
            7.32341602e+03,  7.32341602e+03,  7.32341602e+03],
          [ 7.29720068e+03,  7.29720068e+03,  7.29720068e+03, ...,
     

In [129]:
def reshape_fields(img, inp_or_tar, crop_size_x, crop_size_y, params, y_roll, train, normalize=True, add_noise=False):
    #Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of size ((n_channels*(n_history+1), crop_size_x, crop_size_y)
    if len(np.shape(img)) == 3:
      img = np.expand_dims(img, 0)

    n_history = np.shape(img)[0] - 1
    img_shape_x = np.shape(img)[-2]
    img_shape_y = np.shape(img)[-1]
    n_channels = np.shape(img)[1] #this will either be N_in_channels or N_out_channels
    channels = params.in_channels if inp_or_tar =='inp' else params.out_channels
    means = np.load(params.global_means_path)[:, channels]
    stds = np.load(params.global_stds_path)[:, channels]
    
    if normalize:
        if params.normalization == 'minmax':
          raise Exception("minmax not supported. Use zscore")
        elif params.normalization == 'zscore':
          img -= means
          img /= stds

    if params.roll:
        img = np.roll(img, y_roll, axis = -1)


    if inp_or_tar == 'inp':
        img = np.reshape(img, (n_channels*(n_history+1), img_shape_x, img_shape_y))
    elif inp_or_tar == 'tar':
        img = np.reshape(img, (n_channels, img_shape_x, img_shape_y))

    if add_noise:
        img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)

    return torch.as_tensor(img)

In [184]:
import h5py

d2 = h5py.File('/hkfs/work/workspace/scratch/ke4365-pangu/PANGU_ERA5_data_v0/single_1980.h5')

In [186]:
d2['fields'].shape

(96, 4, 721, 1440)