# DataModule

## 0. imports

In [1]:
%load_ext jupyter_black

In [2]:
import sys

sys.path.append("..")

In [3]:
import os
import re
import glob

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

In [7]:
from src.dataset.dataset import ETTDataset
from src.dataset.preprocess import load_data, trn_val_tst_split

## 1. DataModule

In [16]:
class ETTDataModule:
    """Bundle the train/val/test ``ETTDataset`` splits with their DataLoaders.

    The constructor stores the configuration, optionally creates a
    ``StandardScaler``, defines the row ranges of each split and immediately
    calls :meth:`setup`, so the three datasets exist as soon as the object
    is built.

    Args:
        data_path: Location of the data file, forwarded to ``load_data``.
        task: Forwarded to ``load_data`` as ``task``.
        freq: Forwarded to every ``ETTDataset``.
        target: Forwarded to ``load_data`` as ``target``.
        seq_len: Forwarded to every ``ETTDataset``; also used to move the
            start of the val/test ranges back by ``seq_len`` rows.
        label_len: Forwarded to every ``ETTDataset``.
        pred_len: Forwarded to every ``ETTDataset``.
        use_scaler: When true, a ``StandardScaler`` is created and handed to
            ``trn_val_tst_split``; otherwise ``None`` is passed.
        use_time_enc: Forwarded to every ``ETTDataset``.
        batch_size: Batch size used by all three DataLoaders.
    """

    def __init__(
        self,
        data_path: str,
        task: str = "M",
        freq: str = "h",
        target: str = "OT",
        seq_len: int = 96,
        label_len: int = 48,
        pred_len: int = 96,
        use_scaler: bool = True,
        use_time_enc: bool = True,
        batch_size: int = 32,
    ):
        self.data_path = data_path
        self.task = task
        self.freq = freq
        self.target = target
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.use_scaler = use_scaler
        self.use_time_enc = use_time_enc

        self.batch_size = batch_size

        self.scaler = None
        if self.use_scaler:
            self.scaler = StandardScaler()

        # [start, end] row indices per split. The val and test ranges start
        # ``seq_len`` rows before the previous split ends.
        # NOTE(review): presumably this gives the first val/test window a full
        # input history, and 12 * 30 * 24 means 12 months of hourly rows --
        # confirm, since the boundaries do not depend on ``freq``.
        self.split_idx_dict = {
            "train": [0, 12 * 30 * 24],
            "val": [12 * 30 * 24 - seq_len, 12 * 30 * 24 + 4 * 30 * 24],
            "test": [12 * 30 * 24 + 4 * 30 * 24 - seq_len, 12 * 30 * 24 + 8 * 30 * 24],
        }

        self.setup()

    def setup(self) -> None:
        """Load the data, split it and build ``trainset``/``valset``/``testset``."""
        df = load_data(self.data_path, task=self.task, target=self.target)
        data_dict = trn_val_tst_split(df, self.split_idx_dict, self.scaler)

        self.trainset = self._build_dataset(data_dict["train"])
        self.valset = self._build_dataset(data_dict["val"])
        self.testset = self._build_dataset(data_dict["test"])

    def _build_dataset(self, split_data):
        """Wrap one split in an ``ETTDataset`` using the shared window settings."""
        return ETTDataset(
            split_data,
            self.seq_len,
            self.label_len,
            self.pred_len,
            self.freq,
            self.use_time_enc,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training split (partial batch dropped)."""
        return DataLoader(
            self.trainset, batch_size=self.batch_size, shuffle=True, drop_last=True
        )

    def val_dataloader(self) -> DataLoader:
        """Return a loader over the validation split.

        NOTE(review): shuffling and dropping the last partial batch are kept
        as in the original; confirm this is intended for validation.
        """
        return DataLoader(
            self.valset, batch_size=self.batch_size, shuffle=True, drop_last=True
        )

    def test_dataloader(self) -> DataLoader:
        """Return an ordered loader over the test split, keeping every sample."""
        # Fix: this previously wrapped ``self.valset``, so the test split was
        # never evaluated.
        return DataLoader(
            self.testset, batch_size=self.batch_size, shuffle=False, drop_last=False
        )

In [17]:
# Constructor arguments for the data module, pointing at the ETTh1 CSV.
dm_params = dict(
    data_path="../data/ETT-small/ETTh1.csv",
    task="M",
    freq="h",
    target="OT",
    seq_len=96,
    label_len=48,
    pred_len=96,
    use_scaler=True,
    use_time_enc=True,
    batch_size=32,
)

# Building the module also runs setup(), so the datasets are ready afterwards.
dm = ETTDataModule(**dm_params)

In [19]:
# One loader per split, created in train -> val -> test order.
train_dataloader, val_dataloader, test_dataloader = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)

In [20]:
# Pull a single batch from the training loader to inspect its structure.
train_batch = next(iter(train_dataloader))

In [21]:
# Show which fields a batch carries (the recorded output lists four keys).
train_batch.keys()

dict_keys(['past_values', 'past_time_features', 'future_values', 'future_time_features'])

In [25]:
# Recorded output is [32, 96, 7]; the first two entries equal batch_size=32 and
# seq_len=96 from dm_params -- presumably (batch, seq_len, features), TODO confirm.
train_batch["past_values"].shape

torch.Size([32, 96, 7])