In [None]:
import h5py
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader


In [None]:
class BellingeDataset(Dataset):
    """Map-style dataset over consecutive timestamp pairs of one HDF5 group.

    Item ``idx`` is whatever ``load_dataframe`` returns when asked for the
    range ``timestamps[idx]`` .. ``timestamps[idx + 1]`` of ``data_dir``.

    Args:
        H5_file_path: Path of the HDF5 file; opened read-only and kept open.
        data_dir: Name/path of the group inside the file; it must contain a
            ``'timestamps'`` dataset.
        series: Accepted for interface compatibility and stored as
            ``self.series``; nothing in this class reads it.

    NOTE(review): the h5py handle lives on the instance. Confirm this is safe
    for your setup before using ``DataLoader(num_workers > 0)``.
    """

    def __init__(self, H5_file_path, data_dir, series=None):
        # Defined first so close()/__del__ stay safe if opening the file fails.
        self.hdf_file = None
        self.H5_file_path = H5_file_path
        self.data_dir = data_dir
        self.series = series
        self.return_type = 'np'
        self.hdf_file = h5py.File(H5_file_path, 'r')  # Keep the file open
        try:
            # Read from the handle opened above (the original referenced an
            # undefined name here).
            self.timestamps = pd.to_datetime(
                self.hdf_file[self.data_dir]['timestamps'][:]
            )
        except Exception:
            # Do not leak the handle when the group/dataset is missing.
            self.close()
            raise

    def __len__(self):
        # One item per pair of consecutive timestamps; never negative, since
        # len() raises on negative values.
        return max(len(self.timestamps) - 1, 0)

    def __getitem__(self, idx):
        """Return the 6-tuple produced by ``load_dataframe`` for window ``idx``.

        Negative indices count from the end. Raises ``IndexError`` when
        ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n_windows}"
            )

        starttime = self.timestamps[idx]
        endtime = self.timestamps[idx + 1]
        # load_dataframe indexes its first argument as an open HDF5 object,
        # so pass the handle rather than the path string.
        data, timestamps, columns, start_idx, end_idx, column_indices = load_dataframe(
            self.hdf_file, self.data_dir, return_type=self.return_type,
            starttime=starttime, endtime=endtime
        )
        return data, timestamps, columns, start_idx, end_idx, column_indices

    def close(self):
        """Close the HDF5 file if it is open; safe to call more than once."""
        hdf_file = getattr(self, 'hdf_file', None)
        if hdf_file is not None:
            hdf_file.close()
            self.hdf_file = None

    def __del__(self):
        self.close()



In [None]:
import numpy as np
import pandas as pd
import h5py
from torch.utils.data import Dataset

class BellingeDataset(Dataset):
    """Map-style dataset over consecutive timestamp pairs of one HDF5 group.

    Item ``idx`` is whatever ``load_dataframe`` returns when asked for the
    range ``timestamps[idx]`` .. ``timestamps[idx + 1]`` of ``data_dir``.

    Args:
        H5_file_path: Path of the HDF5 file; opened read-only and kept open
            until ``close()`` is called or the instance is deleted.
        data_dir: Name/path of the group inside the file; it must contain a
            ``'timestamps'`` dataset.
        return_type: Forwarded unchanged to ``load_dataframe``.

    NOTE(review): the h5py handle lives on the instance. Confirm this is safe
    for your setup before using ``DataLoader(num_workers > 0)``.
    """

    def __init__(self, H5_file_path: str, data_dir: str, return_type: str = 'np'):
        # Defined first so close()/__del__ stay safe if opening the file fails.
        self.hdf_file = None
        self.H5_file_path = H5_file_path
        self.data_dir = data_dir
        self.return_type = return_type
        self.hdf_file = h5py.File(H5_file_path, 'r')  # Keep the file open
        try:
            self.timestamps = pd.to_datetime(self.hdf_file[data_dir]['timestamps'][:])
        except Exception:
            # Do not leak the handle when the group/dataset is missing.
            self.close()
            raise

    def __len__(self) -> int:
        # One item per pair of consecutive timestamps. Clamp at zero because
        # len() raises ValueError on a negative result (empty timestamps).
        return max(len(self.timestamps) - 1, 0)

    def __getitem__(self, idx: int):
        """Return the 6-tuple produced by ``load_dataframe`` for window ``idx``.

        Negative indices count from the end. Raises ``IndexError`` when
        ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            # Without this, idx=-1 would pair the last timestamp with the
            # first one and request an inverted (start > end) range.
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n_windows}"
            )

        # Get the start and end time for the current index
        starttime = self.timestamps[idx]
        endtime = self.timestamps[idx + 1]

        # Use load_dataframe to retrieve data from the already-open file
        data, timestamps, columns, start_idx, end_idx, column_indices = load_dataframe(
            self.hdf_file, self.data_dir, return_type=self.return_type,
            starttime=starttime, endtime=endtime
        )
        return data, timestamps, columns, start_idx, end_idx, column_indices

    def close(self) -> None:
        """Close the HDF5 file if it is open; safe to call more than once."""
        hdf_file = getattr(self, 'hdf_file', None)
        if hdf_file is not None:
            hdf_file.close()
            self.hdf_file = None

    def __del__(self):
        # getattr-guarded via close(): __del__ also runs when __init__ raised
        # before the handle attribute existed.
        self.close()


In [None]:

def _decode_column_names(raw_names):
    """Return column names as a numpy ``str`` array, decoding ``bytes`` items."""
    raw_names = np.asarray(raw_names)
    if raw_names.dtype.kind == 'O':
        # Object arrays of bytes would stringify to "b'name'" under
        # astype(str); decode each element explicitly instead.
        return np.array(
            [name.decode('utf-8') if isinstance(name, bytes) else str(name)
             for name in raw_names.ravel()],
            dtype=str,
        )
    return raw_names.astype(str)


def _find_column(all_columns, col):
    """Return the index of the first column named ``col`` or raise IndexError."""
    matches = np.flatnonzero(all_columns == col)
    if matches.size == 0:
        raise IndexError(f"Column {col!r} not found; available: {list(all_columns)}")
    return matches[0]


def _read_columns(dataset, start_idx, end_idx, column_indices):
    """Read rows ``start_idx:end_idx`` and the requested columns of ``dataset``."""
    if isinstance(column_indices, slice):
        return dataset[start_idx:end_idx, column_indices]
    wanted = np.asarray(column_indices, dtype=np.intp)
    if wanted.size == 0:
        return dataset[start_idx:end_idx, 0:0]
    # h5py list selections must be strictly increasing, so read the sorted
    # unique columns once and restore the caller's order (and duplicates)
    # in memory.
    unique_sorted, inverse = np.unique(wanted, return_inverse=True)
    block = dataset[start_idx:end_idx, unique_sorted.tolist()]
    return block[:, inverse]


def load_dataframe(hdf_file, group_path: str, return_type: str = 'df',
                   starttime=None, endtime=None, columns=None, verbose: bool = True):
    """Load a time/column slice of one group from an open HDF5-like object.

    The group is expected to expose ``'timestamps'``, ``'columns'`` and
    ``'data'`` entries. Rows are selected with a binary search on the
    timestamps (assumes they are sorted ascending - TODO confirm for the data
    files in use); ``endtime`` is inclusive (``side='right'``).

    Args:
        hdf_file: Open object indexable by ``group_path`` (e.g. ``h5py.File``).
        group_path: Key of the group inside ``hdf_file``.
        return_type: ``'np'`` to return the raw array, ``'df'`` to wrap it in a
            ``DataFrame`` indexed by time.
        starttime: First timestamp to include; ``None`` means from the start.
        endtime: Last timestamp to include; ``None`` means through the end.
        columns: Iterable of column names to load, in the order wanted;
            ``None`` loads every column.
        verbose: When true (default), print a line naming the loaded group.

    Returns:
        ``(data_or_df, timestamps, columns, start_idx, end_idx, column_indices)``
        where ``column_indices`` is ``slice(None)`` if ``columns`` was ``None``,
        otherwise a list of integer positions.

    Raises:
        ValueError: ``return_type`` is neither ``'np'`` nor ``'df'``.
        IndexError: a requested column name does not exist in the group.
    """
    # Fail before touching the file rather than after reading the data.
    if return_type not in ('np', 'df'):
        raise ValueError(f"Unknown return type: {return_type}")

    group = hdf_file[group_path]  # Directly access the group

    # Get the timestamps and columns
    timestamps = pd.to_datetime(group['timestamps'][:])
    all_columns = _decode_column_names(group['columns'][:])

    # Identify the row indices based on starttime and endtime. Compare with
    # None explicitly: truthiness of timestamp-like values is not a reliable
    # "was it provided" test.
    start_idx = timestamps.searchsorted(starttime) if starttime is not None else 0
    end_idx = (timestamps.searchsorted(endtime, side='right')
               if endtime is not None else len(timestamps))

    # Filter column indices if specified
    if columns is not None:
        column_indices = [_find_column(all_columns, col) for col in columns]
    else:
        column_indices = slice(None)  # Select all columns if none are specified

    # Load the data
    data = _read_columns(group['data'], start_idx, end_idx, column_indices)
    timestamps = timestamps[start_idx:end_idx]
    columns = all_columns[column_indices]

    if verbose:
        print(f"        Data loaded from group '{group.name}'")

    if return_type == 'np':
        return data, timestamps, columns, start_idx, end_idx, column_indices

    df = pd.DataFrame(data, columns=columns, index=timestamps)
    df.index.name = 'time'
    return df, timestamps, columns, start_idx, end_idx, column_indices