In [1]:
import os
import random
import time

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# Read-only handle on the Phoenix HDF5 file for the quick inspection cells below.
path = "/glade/scratch/priyamm/transfer_learning_data/phoenix_data.h5"
hf = h5py.File(path, mode='r')

In [10]:
hf["phoenix"][0, 9:].reshape(8, 33, 33).shape

(8, 33, 33)

### Vanilla Loader

In [11]:
class hdfLoader(Dataset):
    """Map-style Dataset over every row of the first dataset in an HDF5 file.

    Each row is split as ``[:7]`` -> forcing, ``[7]`` -> month,
    ``[8]`` -> lst and ``[9:]`` -> image reshaped to ``(8, 33, 33)``.

    The h5py handle is opened lazily and re-opened in every process that
    reads from it. DataLoader workers therefore never reuse a handle
    inherited from the parent process.

    Args:
        path: HDF5 file to read (defaults to the Phoenix transfer-learning file).
    """

    def __init__(self, path="/glade/scratch/priyamm/transfer_learning_data/phoenix_data.h5"):
        self.path = path
        self._hf = None
        self._hf_pid = None
        # Short-lived handle for metadata only; readers open their own lazily.
        with h5py.File(path, 'r') as f:
            self.key = list(f.keys())[0]
            self._length = f[self.key].shape[0]

    @property
    def hf(self):
        """h5py.File owned by the current process, opened on first use."""
        pid = os.getpid()
        if self._hf is None or self._hf_pid != pid:
            # A handle inherited via fork() is not reused; open a fresh one.
            self._hf = h5py.File(self.path, 'r')
            self._hf_pid = pid
        return self._hf

    def __getstate__(self):
        # h5py objects cannot be pickled (spawn-started workers pickle the
        # dataset), so drop the handle and let the receiving process reopen it.
        state = self.__dict__.copy()
        state['_hf'] = None
        state['_hf_pid'] = None
        return state

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return ``(forcing, image, month, lst)`` for row ``index``."""
        sample = self.hf[self.key][index, :]
        forcing = sample[:7]
        month = sample[7]
        lst = sample[8]
        image = sample[9:].reshape(8, 33, 33)
        return forcing, image, month, lst
        

### MultiDataset Access

In [12]:
# Dataset over every row of the HDF5 file; reused by the cells below.
dataset = hdfLoader()

In [None]:
# Five references to the same dataset object, to exercise repeated access.
datasets = [dataset] * 5

In [None]:
# Use a distinct loop name so the global `dataset` (used by later cells)
# is not silently rebound to the last element of `datasets`.
for ds in datasets:
    # Only the first sample of each dataset is read; show its `lst` value.
    for forcing, image, month, lst in ds:
        print(lst)
        break

### Preshuffled DataLoader

In [13]:
# Sequential (unshuffled) loader with large batches and many worker processes.
loader = DataLoader(
    dataset,
    batch_size=1024,
    shuffle=False,
    num_workers=48,
)

In [None]:
# Time one full pass over `loader`. perf_counter() is monotonic and
# high-resolution, unlike time.time(), so system clock adjustments
# cannot distort the measured interval.
t0 = time.perf_counter()

for forcing, image, month, lst in tqdm(loader, total=len(loader)):
    pass

t1 = time.perf_counter()
total = t1 - t0
print("Time Elapsed", total / 60, "minutes")


### MultiDataloader Access

In [15]:
_PHOENIX_H5_PATH = "/glade/scratch/priyamm/transfer_learning_data/phoenix_data.h5"


class _LazyH5Rows(Dataset):
    """Shared base: per-process lazily opened h5py handle plus row parsing.

    Subclasses call ``_init_handle`` from ``__init__`` and implement
    ``__len__``/``__getitem__`` on top of ``_read_row``.
    """

    def _init_handle(self, path):
        """Record ``path`` and the first dataset's key; return its row count."""
        self.path = path
        self._hf = None
        self._hf_pid = None
        # Short-lived handle for metadata only; readers open their own lazily.
        with h5py.File(path, 'r') as f:
            self.key = list(f.keys())[0]
            return f[self.key].shape[0]

    @property
    def hf(self):
        """h5py.File owned by the current process, opened on first use."""
        pid = os.getpid()
        if self._hf is None or self._hf_pid != pid:
            # A handle inherited via fork() is not reused; open a fresh one.
            self._hf = h5py.File(self.path, 'r')
            self._hf_pid = pid
        return self._hf

    def __getstate__(self):
        # h5py objects cannot be pickled (spawn-started workers pickle the
        # dataset), so drop the handle and let the receiving process reopen it.
        state = self.__dict__.copy()
        state['_hf'] = None
        state['_hf_pid'] = None
        return state

    def _read_row(self, row):
        """Split file row ``row`` into ``(forcing, image, month, lst)``."""
        sample = self.hf[self.key][row, :]
        forcing = sample[:7]
        month = sample[7]
        lst = sample[8]
        image = sample[9:].reshape(8, 33, 33)
        return forcing, image, month, lst


class hdfLoaderTop(_LazyH5Rows):
    """Dataset over the first ``length`` rows of the HDF5 file.

    Args:
        length: number of leading rows to expose; clamped to the rows available.
        path: HDF5 file to read.
    """

    def __init__(self, length=500000, path=_PHOENIX_H5_PATH):
        total_samples = self._init_handle(path)
        # Clamp so indices past the end of the file are never requested.
        self.length = min(length, total_samples)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self._read_row(index)


class hdfLoaderRandom(_LazyH5Rows):
    """Dataset over ``length`` distinct rows drawn at random from the HDF5 file.

    Args:
        length: number of rows to draw (without replacement); clamped to the
            rows available.
        path: HDF5 file to read.
        seed: optional seed for a private RNG. ``None`` keeps the previous
            behaviour of drawing from the global ``random`` module state.
    """

    def __init__(self, length=500000, path=_PHOENIX_H5_PATH, seed=None):
        # Row count comes from this instance's own file/key, not a notebook global.
        self.total_samples = self._init_handle(path)
        self.length = min(length, self.total_samples)
        self.rand_index = self._sample_indices(self.total_samples, self.length, seed)

    @staticmethod
    def _sample_indices(total, length, seed=None):
        """Draw ``length`` distinct row indices from ``range(total)``."""
        population = range(total)
        if seed is None:
            return random.sample(population, length)
        return random.Random(seed).sample(population, length)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the dataset position to its pre-drawn random file row.
        return self._read_row(self.rand_index[index])

In [None]:
# Random-subset dataset: each index maps to a pre-drawn random row of the file.
dataset = hdfLoaderRandom()
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=24,
    shuffle=False,
)
for forcing, image, month, lst in tqdm(loader, total=len(loader)):
    pass

In [18]:
dataset = hdfLoaderTop()
loader = DataLoader(dataset, batch_size=8, shuffle=False)
# The same DataLoader object repeated 100 times, consumed together by the zip below.
dataloaders = [loader] * 100

In [None]:
# zip() stops at the shortest loader, so size the progress bar from the
# shortest one rather than assuming every loader matches `loader`.
n_steps = min((len(dl) for dl in dataloaders), default=0)
for dataloads in tqdm(zip(*dataloaders), total=n_steps):
    # `dataloads` holds one batch from each loader for this step.
    for forcing, image, month, lst in dataloads:
        pass

  0%|          | 303/62500 [02:31<8:37:16,  2.00it/s]