In [7]:
import os
import json
import torch
import random
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.preprocessing import StandardScaler
from utils import savgol

# Names of the angle target columns, in column order.
# Sensor2AngleDataset.filter_sensor_and_angle uses them as plot titles.
label_names = [
    'AA_SN_X', 'AA_SN_Y', 'AA_SN_Z',
    'GH_AA_X', 'GH_AA_Y', 'GH_AA_Z',
]


class Sensor2AngleDataset(Dataset):
    """Map-style dataset pairing each sensor row with its angle row.

    ``sensor.npy`` and ``angle.npy`` are loaded from ``dataDir``. The sensor
    columns are optionally standardized, and every column of both arrays is
    optionally smoothed along axis 0 with ``savgol``. One
    ``[sensor_row, angle_row]`` pair is then stored per row.

    Args:
        dataDir: directory containing ``sensor.npy`` and ``angle.npy``.
        segment_len: stored on the instance but not used by this class.
            NOTE(review): presumably meant for windowed sampling in
            ``__getitem__``; confirm before relying on it.
        doStandardize: forwarded to ``load_data``.
        doFilter: forwarded to ``load_data``.
    """

    def __init__(self, dataDir, segment_len=128, doStandardize=True, doFilter=True):
        self.dataDir = dataDir
        self.segment_len = segment_len
        self.data = self.load_data(
            dataDir, doStandardize=doStandardize, doFilter=doFilter
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``[sensor_row, angle_row]`` pair stored at ``index``."""
        return self.data[index]

    def load_data(self, dataDir, doStandardize=True, doFilter=True):
        """Load both arrays from ``dataDir`` and pair them row by row.

        Args:
            dataDir: directory containing ``sensor.npy`` and ``angle.npy``.
            doStandardize: standardize sensor columns via ``standardize_sensor``.
            doFilter: smooth both arrays via ``filter_sensor_and_angle``.

        Returns:
            List of ``[sensor_row, angle_row]`` pairs, one per row.

        Raises:
            ValueError: if the two arrays have a different number of rows.
        """
        root = Path(dataDir)
        sensorAll = np.load(root / "sensor.npy")
        angleAll = np.load(root / "angle.npy")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if sensorAll.shape[0] != angleAll.shape[0]:
            raise ValueError(
                f"sensor.npy has {sensorAll.shape[0]} rows but angle.npy has "
                f"{angleAll.shape[0]} rows"
            )
        if doStandardize:
            sensorAll = self.standardize_sensor(sensorAll)
        if doFilter:
            sensorAll, angleAll = self.filter_sensor_and_angle(sensorAll, angleAll)
        return [[sensorRow, angleRow] for sensorRow, angleRow in zip(sensorAll, angleAll)]

    def standardize_sensor(self, sensorAll):
        """Scale each sensor column to zero mean and unit variance.

        The scaler is fit on all rows passed in. It is kept on
        ``self.sensor_scaler`` so the same transform can be applied to new data.
        NOTE(review): fitting on the full dataset before any train/test split
        leaks statistics across splits; confirm this is intended.
        """
        scalerStd = StandardScaler()
        sensorStd = scalerStd.fit_transform(sensorAll)
        self.sensor_scaler = scalerStd
        return sensorStd

    @staticmethod
    def _float_copy(arr):
        """Return a writable copy of ``arr`` with a floating-point dtype."""
        arr = np.asarray(arr)
        # Keep existing float precision; promote ints so smoothed values are
        # not truncated when written back column by column.
        dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
        return arr.astype(dtype, copy=True)

    def filter_sensor_and_angle(self, sensorAll, angleAll, do_plot=False,
                                window_length=51, polyorder=2):
        """Smooth every column of both arrays with ``savgol``.

        The input arrays are not modified. Float copies are filtered and
        returned instead.

        Args:
            sensorAll: 2-D array; each column is filtered independently.
            angleAll: 2-D array; each column is filtered independently.
            do_plot: forwarded to ``savgol``.
            window_length: second positional argument passed to ``savgol``.
            polyorder: third positional argument passed to ``savgol``.

        Returns:
            ``(sensorFlt, angleFlt)``: the filtered copies.
        """
        sensorFlt = self._float_copy(sensorAll)
        angleFlt = self._float_copy(angleAll)
        for i in range(sensorFlt.shape[1]):
            title = f'sensor_{i}'
            sensorFlt[:, i] = savgol(sensorFlt[:, i], window_length, polyorder,
                                     title=title, do_plot=do_plot)
        for i in range(angleFlt.shape[1]):
            # Fall back to a generic title when there are more columns than names.
            title = label_names[i] if i < len(label_names) else f'angle_{i}'
            angleFlt[:, i] = savgol(angleFlt[:, i], window_length, polyorder,
                                    title=title, do_plot=do_plot)
        return sensorFlt, angleFlt

In [8]:
# Build the dataset from ./data, which must contain sensor.npy and angle.npy.
dataset = Sensor2AngleDataset('./data')