In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler
import numpy as np

class TrafficDataset(Dataset):
    """Sliding-window samples over an edge-level traffic time series.

    The series is expected to be a NumPy array of shape ``(T_total, E, C_all)``
    (time steps, edges, channels); slices of it are converted with
    ``torch.from_numpy``.  Only the first three channels are used (per the
    original notes: volume, density, flow).

    Each sample is anchored at a step ``t0`` and contains:

    * ``past_edges``   -- ``(window, E, 5)``: the 3 traffic channels for steps
      ``t0-window+1 .. t0`` plus time-of-day (scaled to ``[0, 24)``) and
      day-of-week (``0..6``) derived from the step index.
    * ``future_edges`` -- ``(3, E, 3)``: the 3 traffic channels at
      ``t0+3``, ``t0+6`` and ``t0+12``.
    * ``hist_long``    -- ``(window, E, K, 3)``: for each of the next
      ``window`` steps (``t0+1 .. t0+window``), the traffic channels observed
      exactly ``k`` weeks earlier, ``k = 1..K`` with
      ``K = T_total // week_steps``.  Slots whose source step would lie before
      the start of the series are left at zero.

    Anchors are restricted to the tail of the series: they start after
    ``K - 1`` full weeks so that at least that much weekly history exists.
    """

    # Forecast horizons, in steps after the anchor step t0.
    FUTURE_OFFSETS = (3, 6, 12)
    # Number of leading channels kept as traffic features.
    NUM_TRAFFIC_CHANNELS = 3

    def __init__(self, traffic_data, window=12, week_steps=480*7):
        """Index the valid anchor steps of ``traffic_data``.

        Args:
            traffic_data: array of shape ``(T_total, E, C_all)``.
            window: number of past steps per sample (also the number of
                look-ahead steps covered by ``hist_long``).
            week_steps: number of time steps in one week; one day is
                ``week_steps // 7`` steps (480 by default, noted in the
                original as 8 hours * 60 minutes).
        """
        super().__init__()
        self.traffic = traffic_data          # (T_total, E, C_all)
        self.window = window
        self.week_steps = week_steps
        self.day_steps = week_steps // 7
        self.E = traffic_data.shape[1]       # number of edges

        T_total = traffic_data.shape[0]
        full_weeks = T_total // week_steps
        # Anchor only inside the tail of the series.  The clamp keeps the
        # first anchor at >= window - 1 when the series is shorter than one
        # week, so the past window never uses negative (wrapping) indices.
        self.min_start = max(full_weeks - 1, 0) * week_steps + (self.window - 1)
        # The furthest future offset must still be a valid index.
        self.max_start = T_total - max(self.FUTURE_OFFSETS) - 1
        self.starts = list(range(self.min_start, self.max_start + 1))

    def __len__(self):
        """Return the number of anchor steps (samples)."""
        return len(self.starts)

    def __getitem__(self, idx):
        """Build the sample anchored at ``self.starts[idx]``.

        Returns:
            dict with float32 tensors ``past_edges`` ``(window, E, 5)``,
            ``future_edges`` ``(3, E, 3)`` and ``hist_long``
            ``(window, E, K, 3)``.
        """
        t0 = self.starts[idx]
        n_ch = self.NUM_TRAFFIC_CHANNELS

        past_idxs = np.arange(t0 - self.window + 1, t0 + 1)               # (window,)
        fut_idxs = t0 + np.asarray(self.FUTURE_OFFSETS, dtype=np.int64)   # (3,)

        Xp = torch.from_numpy(self.traffic[past_idxs][..., :n_ch]).float()  # (window, E, 3)
        # Targets carry only the traffic channels; no time features are
        # attached on the future side.
        future_edges = torch.from_numpy(self.traffic[fut_idxs][..., :n_ch]).float()

        # Time features derived purely from the step index.
        tod = ((past_idxs % self.day_steps) * 24.0 / self.day_steps).astype(np.float32)
        dow = ((past_idxs // self.day_steps) % 7).astype(np.float32)

        # Broadcast each per-step scalar over all edges -> (window, E, 1).
        tod_feat = torch.from_numpy(tod)[:, None, None].expand(-1, self.E, -1)
        dow_feat = torch.from_numpy(dow)[:, None, None].expand(-1, self.E, -1)

        past_edges = torch.cat([Xp, tod_feat, dow_feat], dim=-1)           # (window, E, 5)

        return {
            'past_edges':   past_edges,    # (window, E, 5): 0 volume, 1 density, 2 flow, 3 tod, 4 dow
            'future_edges': future_edges,  # (3, E, 3)
            'hist_long':    self._gather_weekly_history(t0),  # (window, E, K, 3)
        }

    def _gather_weekly_history(self, t0):
        """Collect same-time-of-week frames from each earlier week.

        Entry ``[tau-1, :, k-1]`` holds the traffic channels at step
        ``t0 + tau - k * week_steps``.  Steps that fall before the start of
        the series stay zero: indexing with a negative step would wrap around
        to the END of the array and leak later (possibly target) frames into
        the history.
        """
        n_ch = self.NUM_TRAFFIC_CHANNELS
        K = self.traffic.shape[0] // self.week_steps

        taus = np.arange(1, self.window + 1)[:, None]          # (window, 1)
        weeks = np.arange(1, K + 1)[None, :]                   # (1, K)
        hist_idxs = t0 + taus - weeks * self.week_steps        # (window, K)

        missing = hist_idxs < 0
        safe_idxs = np.where(missing, 0, hist_idxs)            # placeholder index, masked below

        frames = self.traffic[safe_idxs][..., :n_ch]           # (window, K, E, 3)
        frames = np.where(missing[:, :, None, None], 0, frames)
        frames = np.ascontiguousarray(frames.transpose(0, 2, 1, 3))  # (window, E, K, 3)
        return torch.from_numpy(frames).float()
