In [23]:
import h5py
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
from typing import Dict, Any, Tuple, Optional
import itertools
import timeit

class CustomDataset(Dataset):
    def __init__(self, h5_file_path: str,):
        self._load_data(h5_file_path)
    
    def __len__(self):
        return 40001
    
    def __getitem__(self, index):
        """
        Get a specific chunk of data
        Ideally we want to be given a slice of the data that we want to grab
        Returns:
            tuple: (input_data, targets)
            - input_data: tensor of shape (channels, x_dim, y_dim, chunk_size)
            - targets: dictionary of invariants for this chunk
        """
        data = self.inputs['data'][index]
        gamma_n = self.inputs['gamma_n'][index]
        return {'data': data, 'gamma_n': gamma_n}
    
    def _load_h5_file_with_data(self, file_path:str, derived_data_key:str = "invariants"):
        """Method for loading .h5 files
        
        :returns: dict that contains name of the .h5 file as stored in the .h5 file, as well as a generator of the data
        """
        file = h5py.File(file_path)
        key = list(file.keys())[0]
        data = file[key]
        derived_data = file[derived_data_key]
        gamma_c = file[derived_data_key]['$\\Gamma_c$']
        gamma_n = file[derived_data_key]['$\\Gamma_n$']
        E = file[derived_data_key]['$\\mathcal{D}^E$']
        U = file[derived_data_key]['$\\mathcal{D}^U$']
        energy = file[derived_data_key]['energy']
        enstrophy = file[derived_data_key]['enstrophy']
        time = file[derived_data_key]['time']
        return dict(file=file, data=data, gamma_c=gamma_c, gamma_n=gamma_n, E=E, U=U, energy=energy, enstrophy=enstrophy, time=time)
    
    def _load_data(self, input_file, target_files=None):
        """Loads input data and optional target data into the dataset
        Args:
            input_file (str): Name of the input .h5 file
            target_files (dict): Dictionary mapping task names to target .h5 file names
        """
        self.inputs = self._load_h5_file_with_data(input_file) #Dictionary with keys "file", "data"...

In [27]:
from torch.utils.data import SubsetRandomSampler, DataLoader, BatchSampler
dataset = CustomDataset("/Users/anthonypoole/Repositories/hw_snapshots.h5")
sampler = SubsetRandomSampler(torch.arange(0,40001,1).tolist())
loader = DataLoader(dataset=dataset, batch_size=None, sampler= BatchSampler(sampler, 32, True))
for item in loader:
    print(item)
    break

TypeError: Indexing elements must be in increasing order