In [None]:
import numpy as np
import pandas as pd
import torch

from pathlib import Path
from typing import List, Dict, Optional, Union
from torch.utils.data import Dataset

In [None]:

class WindowDataset(Dataset):
    """Slice a feature parquet/CSV table into sliding windows for multi-task crypto prediction.

    Each sample i is a dict with:
        • 'inputs' :  history_steps × n_features float32 tensor
        • 'cls', 'p90', 'p10', 'sigma' : scalar target tensors taken from the
          LAST row of the window

    Rows are ordered by timestamp, optionally restricted to a date range, and
    rows containing any NaN are dropped before windows are cut. Windows are
    therefore contiguous in the *filtered* table, not necessarily in time.

    Parameters
    ----------
    path              : Path to parquet or csv produced by make_targets.py (must contain targets).
    history_steps     : Number of rows in the historical window (e.g. 90 d @ 15-min = 8640).
    stride            : Step size (rows) to slide the window. 96≈1 day; 1 for full overlap.
    split             : One of 'train', 'val', 'test' — dataset will filter according to `split_ranges`.
    split_ranges      : Dict {"train": (start_dt, end_dt), ...}. Dates in pandas-parseable string or pd.Timestamp.
                        Both bounds are inclusive (label slicing). If None, the whole table is used.
    feature_cols      : Optional[List[str]] — if None, use all columns minus targets.
    cache_in_memory   : Whether to keep the float32 feature matrix in memory (speed vs RAM).
                        When False, each window is re-extracted from the DataFrame on access.

    Raises
    ------
    ValueError : if `history_steps` or `stride` is < 1, or the file suffix is unsupported.
    KeyError   : if the table has neither a 'datetime' nor a 'time' column, or
                 `split` is missing from `split_ranges`.
    """

    # Target column -> numpy dtype used for the cached target arrays (`y_mem`).
    TARGETS = {
        "cls": np.int64,
        "p90": np.float32,
        "p10": np.float32,
        "sigma": np.float32,
    }

    def __init__(
        self,
        path: Union[str, Path],
        history_steps: int = 8640,
        stride: int = 96,
        split: str = "train",
        split_ranges: Optional[Dict[str, tuple]] = None,
        feature_cols: Optional[List[str]] = None,
        cache_in_memory: bool = True,
    ) -> None:
        super().__init__()

        # Fail fast: a zero stride makes np.arange raise an opaque
        # ZeroDivisionError, and non-positive values yield empty/degenerate windows.
        if history_steps < 1:
            raise ValueError("history_steps must be >= 1, got %r" % (history_steps,))
        if stride < 1:
            raise ValueError("stride must be >= 1, got %r" % (stride,))

        self.history = history_steps
        self.stride = stride

        # 1. load
        path = Path(path)
        if path.suffix == ".parquet":
            df = pd.read_parquet(path)
        elif path.suffix == ".csv":
            df = pd.read_csv(path)
        else:
            raise ValueError("Unsupported file type: %s" % path.suffix)

        # 2. ensure datetime index exists
        if "datetime" in df.columns:
            df["datetime"] = pd.to_datetime(df["datetime"], utc=True)
            df.set_index("datetime", inplace=True)
        elif "time" in df.columns:  # seconds
            # NOTE(review): unlike the 'datetime' branch, the raw 'time' column
            # stays in the frame and so becomes a feature when `feature_cols`
            # is None — confirm this is intended.
            df.index = pd.to_datetime(df["time"], unit="s", utc=True)
        else:
            raise KeyError("Input table must contain 'datetime' or 'time' column.")

        # 3. order by time BEFORE slicing: label-based slicing on a
        #    non-monotonic DatetimeIndex raises or selects the wrong rows.
        df = df.sort_index()

        # 4. select split by date range (inclusive on both ends)
        if split_ranges is not None:
            start, end = split_ranges[split]
            df = df.loc[start:end]

        # 5. drop rows with any NaN (checked across ALL columns), then switch to
        #    a positional 0..n-1 index so window ends can be used as row labels.
        self.df = df.dropna(how="any").reset_index(drop=True)

        # 6. choose feature columns
        all_target_cols = list(self.TARGETS.keys()) + ["r90", "sigma30_real"]  # r90 may be needed in loss
        if feature_cols is None:
            feature_cols = [c for c in self.df.columns if c not in all_target_cols]
        self.feat_cols = feature_cols

        # 7. numpy cache
        X = self.df[self.feat_cols].to_numpy(dtype=np.float32)
        y = {t: self.df[t].to_numpy(dtype=dtype) for t, dtype in self.TARGETS.items()}

        if cache_in_memory:
            self.X_mem = X
            self.y_mem = y
        else:
            self.X_mem = None
            self.y_mem = None

        # 8. precompute window end indices; empty when the table has fewer
        #    than `history_steps` rows.
        first_idx = self.history - 1
        last_idx = len(self.df) - 1
        self.ends = np.arange(first_idx, last_idx + 1, self.stride, dtype=np.int64)

    def __len__(self) -> int:
        """Return the number of windows available."""
        return len(self.ends)

    def _get_slice(self, end_idx: int):
        """Return the feature rows of the window ending (inclusively) at `end_idx`."""
        start = end_idx - self.history + 1
        if self.X_mem is not None:
            # Row slice of the cached matrix: a view, no copy.
            x = self.X_mem[start : end_idx + 1]
        else:
            x = self.df.iloc[start : end_idx + 1][self.feat_cols].to_numpy(np.float32)
        return x

    def __getitem__(self, idx):
        """Return window `idx` as a dict of 'inputs' plus the four target tensors."""
        end = int(self.ends[idx])
        x = self._get_slice(end)
        # NOTE(review): 'sigma' is read from column 'sigma30_real', not from the
        # 'sigma' column listed in TARGETS / cached in `y_mem`. Behaviour is
        # kept as-is; confirm which column is the intended training target.
        target = {
            "cls": torch.tensor(self.df.at[end, "cls"], dtype=torch.long),
            "p90": torch.tensor(self.df.at[end, "p90"], dtype=torch.float32),
            "p10": torch.tensor(self.df.at[end, "p10"], dtype=torch.float32),
            "sigma": torch.tensor(self.df.at[end, "sigma30_real"], dtype=torch.float32),
        }
        return {
            "inputs": torch.from_numpy(x),
            **target,
        }
