In [3]:
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from math import pi
import pandas as pd
import torch
from torch.utils.data import Dataset
import random

In [6]:
class CSVDataset(Dataset):
    """Random-stride sliding-window dataset built from one CSV log.

    Each sample pairs:

    * an input window of ``window_size`` consecutive rows of
      ``input_orientation_{yaw,pitch,roll}`` and ``acceleration_{x,y,z}``;
    * a target window of the same length of
      ``input_orientation_{yaw,pitch,roll}``, starting ``target_offset``
      rows after the input window.

    Windows start at row 0 and advance by a random stride drawn uniformly
    from ``[1, window_size]``. Consecutive windows therefore overlap by a
    random amount (DeepRNNFramework-style). Only start positions where *both*
    the input and the target window fit inside the file are used, so every
    sample has the same shape.

    Args:
        fpath: Path or file-like object accepted by ``pandas.read_csv``.
            It must contain the six columns named above.
        window_size: Number of rows per window; must be >= 1.
        target_offset: Rows by which the target window is shifted forward
            relative to the input window. Defaults to 18. (Presumably a
            prediction horizon in samples -- TODO confirm its meaning.)
        seed: Optional seed for the stride RNG. With ``None`` (default), the
            global :mod:`random` state is used, so ``random.seed(...)``
            still controls reproducibility.

    Raises:
        ValueError: If ``window_size < 1`` or ``target_offset < 0``.
    """

    def __init__(self, fpath, window_size, target_offset=18, seed=None):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if target_offset < 0:
            raise ValueError(f"target_offset must be >= 0, got {target_offset}")

        csv = pd.read_csv(fpath)

        # Convert each column to a NumPy array once. Array slicing is
        # positional and far cheaper than slicing the Series on every step.
        yaw_col = csv['input_orientation_yaw'].to_numpy()
        pitch_col = csv['input_orientation_pitch'].to_numpy()
        roll_col = csv['input_orientation_roll'].to_numpy()
        accx_col = csv['acceleration_x'].to_numpy()
        accy_col = csv['acceleration_y'].to_numpy()
        accz_col = csv['acceleration_z'].to_numpy()

        # Fall back to the global module so existing random.seed() calls
        # keep working when no explicit seed is given.
        rng = random if seed is None else random.Random(seed)

        self.window_size = window_size
        self.target_offset = target_offset

        # Input windows.
        self.yaw, self.pitch, self.roll = [], [], []
        self.accx, self.accy, self.accz = [], [], []
        # Target windows (orientation shifted forward by target_offset rows).
        self.iyaw, self.ipitch, self.iroll = [], [], []

        w = window_size
        off = target_offset
        # The target window [i+off, i+off+w) must fit as well as the input
        # window [i, i+w). Otherwise the target slice is silently truncated
        # and samples end up with ragged shapes.
        last_start = len(csv) - w - off
        i = 0
        while i <= last_start:
            self.yaw.append(yaw_col[i:i + w])
            self.pitch.append(pitch_col[i:i + w])
            self.roll.append(roll_col[i:i + w])
            self.accx.append(accx_col[i:i + w])
            self.accy.append(accy_col[i:i + w])
            self.accz.append(accz_col[i:i + w])
            self.iyaw.append(yaw_col[i + off:i + off + w])
            self.ipitch.append(pitch_col[i + off:i + off + w])
            self.iroll.append(roll_col[i + off:i + off + w])
            # Random stride in [1, w] gives a random overlap between windows.
            i += rng.randint(1, w)

    def __len__(self):
        """Return the number of windows extracted from the CSV."""
        return len(self.yaw)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` float32 tensors for window ``idx``.

        ``x`` has shape ``(6, window_size)`` with rows ordered as
        yaw, pitch, roll, accx, accy, accz. ``y`` has shape
        ``(3, window_size)`` with rows ordered as yaw, pitch, roll of the
        shifted target window.
        """
        x = np.stack((
            self.yaw[idx], self.pitch[idx], self.roll[idx],
            self.accx[idx], self.accy[idx], self.accz[idx],
        ))
        x = torch.tensor(x, dtype=torch.float32)

        y = np.stack((self.iyaw[idx], self.ipitch[idx], self.iroll[idx]))
        y = torch.tensor(y, dtype=torch.float32)

        return x, y

In [8]:
# Inputs for this run -- change them here rather than inline.
DATA_PATH = 'data/1108/user1_scene1_0.csv'
WINDOW_SIZE = 48
SEED = 0

# CSVDataset draws its window strides from the global `random` module.
# Seed it so Restart & Run All produces the same set of windows.
random.seed(SEED)
ds = CSVDataset(DATA_PATH, WINDOW_SIZE)

In [12]:
# Fetch the first sample once and show the input/target tensor shapes.
sample_x, sample_y = ds[0]
sample_x.shape, sample_y.shape

(torch.Size([6, 48]), torch.Size([3, 48]))