In [20]:
import pandas as pd
import numpy as np
import json
import os
from os.path import join
from tqdm import tqdm

In [21]:
import torch
from torch.utils.data import DataLoader, Dataset

In [22]:
train_df = pd.read_csv(
    "train_df.csv",
    index_col=[0, 1],
    parse_dates=[1],
)

train_target = pd.read_csv(
    "train_target.csv",
    index_col=[0, 1],
    parse_dates=[1]
)

In [23]:
class SlidingWindowDataset(Dataset):
    def __init__(self, df: pd.DataFrame, target: pd.Series, window_size: int, stride: int = 1, get_step_next = False):
        # Convert data and targets to numpy
        self.features = list(df.columns)
        self.data = torch.as_tensor(df.values, dtype=torch.float32)
        self.target = torch.as_tensor(target.values, dtype=torch.float32) if target is not None else None

        # Save multindex for safety checks
        self.index = df.index
        self.window_size = window_size
        self.get_step_next = get_step_next
        if self.get_step_next:
            self.window_size += 1
        # Calculate valid windows for each run_id
        self.valid_windows = self._precompute_valid_windows(stride)

    def _precompute_valid_windows(self, stride):
        valid_windows = []
        run_ids = self.index.get_level_values(0).unique()

        for run_id in tqdm(run_ids, desc="Building safe windows"):
            # Get all indices for running run_id
            run_mask = self.index.get_level_values(0) == run_id
            run_indices = np.where(run_mask)[0]

            # Check run_id has enough points
            if len(run_indices) < self.window_size:
                continue

            # Generate end window indices ONLY in the borders of this run_id
            for end_pos in range(self.window_size, len(run_indices), stride):
                start_pos = end_pos - self.window_size
                start_idx = run_indices[start_pos]
                end_idx = run_indices[end_pos]

                # Extra check
                assert self.index[start_idx][0] == self.index[end_idx-1][0], "Window crosses run_id boundary!"

                valid_windows.append((start_idx, end_idx))

        return valid_windows

    def __len__(self):
        return len(self.valid_windows)

    def __getitem__(self, idx):
        start_idx, end_idx = self.valid_windows[idx]
        sample = self.data[start_idx:end_idx]
        target = self.target[start_idx:end_idx].max() if self.target is not None else sample[-1]
        return sample, target


In [24]:
window_size = 2
stride = 5

train_dataset = SlidingWindowDataset(
            df = train_df,
            target=train_target,
            window_size=window_size,
            stride=stride
        )

Building safe windows: 100%|██████████| 21/21 [00:09<00:00,  2.13it/s]


In [25]:
batch_size = 32
train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )

In [26]:
for data, target in train_loader:
            print(data.shape, target.shape)

torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 2, 30]) torch.Size([32])
torch.Size([32, 

In [29]:
data, target = next(iter(train_loader))

In [30]:
data.shape

torch.Size([32, 2, 30])

In [31]:
target.shape

torch.Size([32])

In [32]:
target

tensor([ 5.,  3.,  7.,  7., 17., 20., 11., 12., 15.,  3., 12.,  1.,  0.,  5.,
        19., 15., 11., 12., 18., 12., 19.,  0.,  1., 11., 11.,  1.,  5.,  5.,
        20., 18., 15., 11.])

In [33]:
data[0]

tensor([[ 6.0358e-01,  6.0339e-01,  6.9003e+01,  3.5773e+05,  1.0000e-01,
          1.0000e-01,  1.6220e+01,  5.9625e+05,  9.0000e-01,  9.0000e-01,
          6.9838e+01,  6.6416e-01,  6.6416e-01,  1.0000e+00,  1.3236e+02,
          5.9625e+05,  1.6123e+00, -4.0025e+02,  5.5177e+01,  5.5184e+01,
          7.3796e-01,  9.0000e-01,  1.0000e+00,  5.7352e+02,  1.0000e+00,
          6.9374e+01,  6.8851e+01,  6.9486e+01,  6.9705e+01,  6.6174e+01],
        [ 6.0320e-01,  6.0301e-01,  6.8985e+01,  3.5773e+05,  1.0000e-01,
          1.0000e-01,  1.6136e+01,  5.9624e+05,  9.0000e-01,  9.0000e-01,
          6.9822e+01,  6.6416e-01,  6.6416e-01,  1.0000e+00,  1.3236e+02,
          5.9624e+05,  1.6123e+00, -4.0025e+02,  5.5177e+01,  5.5184e+01,
          7.3796e-01,  9.0000e-01,  1.0000e+00,  5.7352e+02,  1.0000e+00,
          6.9359e+01,  6.8846e+01,  6.9483e+01,  6.9697e+01,  6.6149e+01]])