In [1]:
import pyarrow.parquet as pq
from config import settings , parse_datetime_strings
from pathlib import Path
# Number embedded in the PSD file name (presumably the nperseg used for the
# PSD estimate -- confirm against the code that wrote the file).
psd_nperseg = 8192
# <processed_data dir>/PSD_<nperseg>.parquet
path_psd = Path(settings.default.path['processed_data']).joinpath(f'PSD_{psd_nperseg}.parquet')
# Load the complete parquet file into memory as an Arrow table.
table = pq.read_table(path_psd)

In [2]:
# Take the 'time' column from the already-loaded table as a pandas Series.
time = table.column('time').to_pandas()

# Boundaries of the training split, expressed as POSIX timestamps.
datetime_range = parse_datetime_strings(settings.split['train'])
start = datetime_range['start'].timestamp()
end = datetime_range['end'].timestamp()

# Keep the index labels of the rows whose time lies in [start, end]
# (both ends included).
time = time.index[time.between(start, end)]

In [59]:
# Quick check of get_sensor_axis on a sample sensor name.
# NOTE(review): the output recorded below ('ACCX') does not match the
# definition in cell [69], which returns only the last character of the part
# before the first underscore ('X'). Presumably the output predates that
# definition -- re-run the cell to confirm.
get_sensor_axis('ACCX_0')

'ACCX'

In [69]:
from typing import Tuple, List, Union
from datetime import datetime
from pathlib import Path
import torch
import numpy as np
import pyarrow.parquet as pq
from torch.utils.data import Dataset, DataLoader
from functools import cached_property

def get_sensor_axis(sensor_name: str):
    """Return the axis character of a sensor name.

    The name is cut at its first underscore and the last character of the
    leading part is returned, e.g. ``'ACCX_0'`` gives ``'X'``.

    Raises:
        IndexError: if nothing precedes the first underscore.
    """
    prefix, _, _ = sensor_name.partition('_')
    return prefix[-1]
def get_sensor_position(sensor_name: str):
    """Return the position token of a sensor name.

    The name is split on underscores and the second field is returned,
    e.g. ``'ACCX_0'`` gives ``'0'``.

    Raises:
        IndexError: if the name contains no underscore.
    """
    fields = sensor_name.split('_')
    return fields[1]

class PSDDataset(Dataset):
    """Map-style dataset that serves PSD samples stored in a parquet file.

    Every item is obtained by reading one parquet row group. The row-group
    number is taken from the pandas index of the file's ``time`` column.
    NOTE(review): this presumes one sample per row group, numbered like the
    rows of the ``time`` column -- confirm against the code writing the file.
    """

    def __init__(self, filename: Union[Path, str], datetime_range: Tuple[datetime, datetime],
                 drop: Union[List[str], None] = None, transform=None, label_transform=None) -> None:
        """
        Initialize the dataset.

        Args:
            filename: Path of the parquet file holding the samples.
            datetime_range: ``(start, end)`` datetimes; samples whose ``time``
                value lies between the two POSIX timestamps (both ends
                included) are kept.
            drop: Sensor names to exclude from every sample; ``None`` or an
                empty list keeps all sensors.
            transform: Optional callable applied to the PSD tensor.
            label_transform: Optional callable applied to the sensor-index
                tensor.

        Raises:
            IndexError: if no sample falls within ``datetime_range``.
        """
        self.filename = filename
        self.transform = transform
        self.label_transform = label_transform
        self.drop = set(drop) if drop else set()

        # Open the parquet file
        self.parquet_file = pq.ParquetFile(filename)

        # Get keys within the datetime range. Only the 'time' column is
        # needed for that, so do not load the (large) PSD columns.
        time = pq.read_table(filename, columns=['time']).column('time').to_pandas()
        range_start = datetime_range[0].timestamp()
        range_end = datetime_range[1].timestamp()
        self.keys = time.index[(time >= range_start) & (time <= range_end)]

        self.initialize_dicts()

    def __len__(self) -> int:
        """Return the number of samples selected by the datetime range."""
        return len(self.keys)

    def initialize_dicts(self):
        """Build the sensor/axis/position name lists and their index maps.

        The names are taken from the first selected sample, minus the sensors
        listed in ``drop``. All lists are sorted so that the integer assigned
        to a name is the same in every interpreter run (iteration order of a
        set of strings is not).

        Raises:
            IndexError: if no sample was selected.
        """
        if len(self.keys) == 0:
            raise IndexError('no samples fall within the requested datetime range')

        psd_group = self.parquet_file.read_row_group(self.keys[0], columns=['sensor_name'])
        first_names = psd_group['sensor_name'].to_pandas().values[0]

        # Unique sensors of the first sample that are not dropped
        sensor_names = sorted(name for name in np.unique(first_names) if name not in self.drop)

        axis_names = sorted(set(map(get_sensor_axis, sensor_names)))
        position_names = sorted(set(map(get_sensor_position, sensor_names)))

        # Store them as attributes
        self._sensor_names = sensor_names
        self._axis_names = axis_names
        self._position_names = position_names

        self._mapping_dict = {sensor: i for i, sensor in enumerate(self._sensor_names)}
        self._mapping_position_dict = {position: i for i, position in enumerate(self._position_names)}
        self._mapping_axis_dict = {axis: i for i, axis in enumerate(self._axis_names)}

    # Plain properties (not cached ones) so that the accessors always reflect
    # the result of the most recent initialize_dicts() call.
    @property
    def sensor_names(self) -> List[str]:
        """Sorted names of the sensors that are kept."""
        return self._sensor_names

    @property
    def axis_names(self) -> List[str]:
        """Sorted axis identifiers of the kept sensors."""
        return self._axis_names

    @property
    def position_names(self) -> List[str]:
        """Sorted position identifiers of the kept sensors."""
        return self._position_names

    @property
    def mapping_position_dict(self) -> dict:
        """Map from position identifier to its integer index."""
        return self._mapping_position_dict

    @property
    def mapping_axis_dict(self) -> dict:
        """Map from axis identifier to its integer index."""
        return self._mapping_axis_dict

    @property
    def mapping_dict(self) -> dict:
        """Map from sensor name to its integer index."""
        return self._mapping_dict

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(psd, sensor_index)`` for the ``index``-th selected sample.

        ``psd`` has one row per kept sensor; ``sensor_index`` holds, for each
        row, the integer assigned to that sensor in ``mapping_dict``.

        Raises:
            KeyError: if the sample contains a kept sensor that was not
                present in the first selected sample.
        """
        key = self.keys[index]

        # Read the required row group, limited to the two columns used here
        sample = self.parquet_file.read_row_group(key, columns=['ACC_psd', 'sensor_name']).to_pandas()

        psd = sample['ACC_psd'].values[0]
        sensor_name = sample['sensor_name'].values[0]
        # The PSD values are stored flat: one row per sensor after reshaping
        psd = psd.reshape(len(sensor_name), -1)

        # Filter out the rows based on the `drop` attribute
        keep_rows = ~np.isin(sensor_name, list(self.drop))
        psd = psd[keep_rows]
        sensor_name = sensor_name[keep_rows]
        sensor_name = np.array([self._mapping_dict[sensor] for sensor in sensor_name])

        # Convert to tensors
        psd = torch.from_numpy(psd)
        sensor_name = torch.from_numpy(sensor_name)

        # Apply transformations if any
        if self.transform is not None:
            psd = self.transform(psd)
        if self.label_transform is not None:
            sensor_name = self.label_transform(sensor_name)

        return psd, sensor_name

# Example usage
from config import settings, parse_datetime_strings

# Build the dataset for the training split and wrap it in a multi-worker loader.
range_dt = parse_datetime_strings(settings.split['train'])
training_range = range_dt['start'], range_dt['end']
dataset = PSDDataset(path_psd, training_range)
dataloader = DataLoader(dataset, num_workers=4, batch_size=32)

# Print the time taken to fetch each batch, then the time for the whole pass.
from time import time
s = start = time()
for i in dataloader:
    print('time', time() - start)
    start = time()
print('total time', time() - s)
