In [16]:
import data

In [17]:
metadata: Dict[str, Dict[str,str]] = dict(
    lunar = dict(
        catalog = '../data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv',
        train_path = '../data/mars/training/data',
        test_path = '../data/mars/test/data',
    ),
    mars = dict(
        catalog = '../data/mars/training/catalogs/Mars_InSight_training_catalog_final.csv',
        train_path = '../data/lunar/training/data/S12_GradeA',
        test_path = '../data/lunar/test/data'
    )
)

def recursive_search(parent: str) -> Generator[str, None, None]:
    for child in os.listdir(parent):
        child_path = os.path.join(parent, child)
        if os.path.isdir(child_path):
            yield from recursive_search(child_path)
        elif child.endswith('.csv'):
            yield child_path

In [100]:
from msilib import sequence
from lightning.pytorch import LightningDataModule
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Generator, Tuple
from torch.utils.data import Dataset
from torch import Tensor
import pandas as pd
import torch
import os


class TrainDataset(Dataset):
    def __init__(self, sequence_length: int) -> None:
        self.sequence_length = sequence_length
        self.filepaths = [filename for filename in recursive_search('../data/mars/training/data')] + \
                        [filename for filename in recursive_search('../data/lunar/training/data')]
        self.meta_lunar: pd.DataFrame = pd.read_csv(metadata['lunar']['catalog'], index_col = ['filename'])
        self.meta_mars: pd.DataFrame = pd.read_csv(metadata['mars']['catalog'], index_col = ['filename'])
        self.metadata: pd.DataFrame = pd.concat(
            [
                self.meta_lunar,
                self.meta_mars,
            ], axis = 0
        )

    def zero_padding(self, x: torch.Tensor, length: int = 60) -> torch.Tensor:
        pad_size = length - x.size(-1)
        if pad_size > 0:
            return torch.nn.functional.pad(x, (0, pad_size))
        return x
    
    def preprocessing(self) -> None:
        self.data: List[Tuple[Tensor, Tensor]] = []
        for file in self.filepaths:
            try:
                arrive = self.metadata.loc[['time_rel(sec)'], file]
                out: Tensor = self.get_data(file)
                self.data.extend((input, target)) ### mirar
            except IndexError:
                continue

    def __len__(self) -> int:
        return len(self.metadata)

    def get_data(self, file: str) -> List[Tuple[Tensor, Tensor]]: ## atencion
        ## get target (tensor de 0 y 1s tal que el idx del arrival este coincidiendo con el arrival real)
        print(self.metadata["time_rel(sec)"].loc[os.path.basename(file)])
        ## creas el tensor
        
        df: pd.DataFrame = pd.read_csv(file, parse_dates =['time_abs(%Y-%m-%dT%H:%M:%S.%f)'] ,index_col = ['time_abs(%Y-%m-%dT%H:%M:%S.%f)'])
        velocity: Tensor = torch.from_numpy(df["velocity(c/s)"].resample('s').mean().values)
        target = torch.zeros((velocity.shape))
        # target = 
        max: Tensor = torch.from_numpy(df.resample('s').max().values)
        min: Tensor = torch.from_numpy(df.resample('s').min().values)
        ### add the wavelet / fourier transform is needed

        ### separar de a sequence_length
        ### input(sequence_length, input_size) -> target(sequence_length) (0, 0, 1, 0)
        return tuple(self.zero_padding(i) for i in velocity.split(self.sequence_length)) #target)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        filename: str = self.metadata[idx]
        return self.data[filename]

In [101]:
dataset = TrainDataset(60)
dataset.get_data(dataset.filepaths[0])[-1]

2130.0


tensor([2.3166e+02, 2.2006e+02, 1.7895e+02, 1.9109e+02, 1.7878e+02, 1.6507e+02,
        1.4346e+02, 1.1570e+02, 1.2584e+02, 1.2864e+02, 1.2163e+02, 1.3797e+02,
        9.3658e+01, 9.8850e+01, 9.0091e+01, 8.4632e+01, 9.9423e+01, 1.3291e+02,
        1.6241e+02, 1.6645e+02, 1.8494e+02, 2.0618e+02, 1.7522e+02, 2.0021e+02,
        1.8730e+02, 1.7772e+02, 1.7130e+02, 1.5150e+02, 1.6209e+02, 1.5609e+02,
        1.6114e+02, 1.8311e+02, 1.8972e+02, 1.7930e+02, 1.6509e+02, 1.6191e+02,
        1.3282e+02, 1.3052e+02, 1.2823e+02, 9.9613e+01, 9.2671e+01, 6.3208e+01,
        6.9348e+01, 6.3965e+01, 5.3418e+01, 5.2325e+01, 5.2556e+01, 5.0903e+01,
        4.5167e+01, 3.7608e+01, 3.5124e+01, 2.9001e+01, 2.1964e+01, 1.5242e+01,
        9.7051e+00, 6.1164e+00, 3.2394e+00, 1.7777e+00, 6.4901e-01, 8.0729e-02],
       dtype=torch.float64)