In [1]:
import import_ipynb


In [None]:
import os
from collections.abc import Iterable
from pathlib import Path
import scipy.sparse as sp
import pandas as pd
import numpy as np
import json
import import_ipynb

import torch
from torch_geometric.utils import dense_to_sparse
from torch.utils.data import random_split
from data.traffic_dataset import TrafficDataset

from src.data.wrapper import wrap_traffic_dataset
from src.data.utils import (
    generate_regression_task,
    generate_split, StandardScaler
)


def get_connectivity(adj_matrix_path):
    """Load a sparse adjacency matrix and return its edges and edge weights.

    The ``.npz`` file at ``adj_matrix_path`` is read with
    ``scipy.sparse.load_npz``, densified, and handed to
    ``dense_to_sparse`` to obtain the edge index and the per-edge values.

    Returns:
        A pair ``(edge_index, edge_weight)`` where ``edge_weight`` is the
        element-wise reciprocal of the values stored in the matrix.
    """
    dense_adj = torch.tensor(sp.load_npz(adj_matrix_path).toarray())
    edge_index, raw_weight = dense_to_sparse(dense_adj)
    # The division takes the reciprocal of every stored weight and, for
    # integer inputs, yields a floating-point tensor.
    # NOTE(review): the original comment says weights are in [0, 1]; for a
    # binary adjacency the reciprocal leaves them at 1.0 -- confirm that
    # inverting non-binary weights is intended.
    return edge_index, 1 / raw_weight


def get_raw_data(dataset_path, split_ratio, n_hist, n_pred, norm):
    """Return train/val/test features and targets plus the fitted scaler.

    For each of the ``train``, ``val`` and ``test`` splits a cached file
    matching ``{split}*_hist{n_hist}_pred{n_pred}.npz`` is looked up under
    ``dataset_path``. If all three exist they are loaded and returned. As
    soon as one split has no cached file, everything is regenerated from the
    raw data (the first ``*.h5`` file if one exists, otherwise ``vel.csv``),
    written back as compressed ``.npz`` files and returned immediately.

    Args:
        dataset_path: Directory holding the raw and/or preprocessed files.
        split_ratio: Forwarded unchanged to ``generate_split``.
        n_hist: Forwarded to ``generate_regression_task``; also part of the
            cache file name.
        n_pred: Forwarded to ``generate_regression_task``; also part of the
            cache file name.
        norm: Must be truthy; forwarded unchanged to ``generate_split``.

    Returns:
        Seven items in the order ``train_x, val_x, test_x, train_y, val_y,
        test_y, scaler``. The cached path returns them as a list, the
        regeneration path as a tuple.

    Raises:
        AssertionError: If ``norm`` is falsy.
    """
    # Kept as an assert so callers see the same exception type as before.
    assert norm, 'Traffic data should be normalized for better performance'

    X_s, y_s = list(), list()
    scaler = None
    for split in ['train', 'val', 'test']:
        # sorted() makes the pick deterministic when several cache variants
        # (e.g. with and without a "_day" suffix) match the pattern.
        data_path = sorted(
            Path(dataset_path).glob(f'{split}*_hist{n_hist}_pred{n_pred}.npz')
        )

        if data_path:
            # np.load on an .npz keeps the archive open until closed; the
            # arrays are materialised on access, so read them inside the
            # context manager and let it release the file handle.
            with np.load(data_path[0]) as data:
                X_s.append(data['x'])
                y_s.append(data['y'])
                if split == 'train':
                    scaler = StandardScaler(mean=data['mean'], std=data['std'])
            continue

        print(f"preprocessed data not found at {dataset_path}, generating new data")
        h5_path = sorted(Path(dataset_path).glob('*.h5'))
        add_time_in_day, add_time_in_week = None, None
        if h5_path:
            print(f'Loading data from {h5_path[0]}')
            raw = pd.read_hdf(h5_path[0])
            add_time_in_day, add_time_in_week = True, False
            task_kwargs = {
                'add_time_in_day': add_time_in_day,
                'add_day_in_week': add_time_in_week,
            }
        else:
            data_csv_path = os.path.join(dataset_path, 'vel.csv')
            print(f'Loading data from {data_csv_path}')
            # Original author's note: the CSV is [time slices, nodes].
            raw = pd.read_csv(data_csv_path).to_numpy()
            if len(raw.shape) == 2:
                # Insert a middle axis so the transpose below moves it last.
                raw = np.expand_dims(raw, 1)
            raw = raw.transpose((0, 2, 1)).astype(np.float32)
            # The CSV path relies on generate_regression_task's defaults.
            task_kwargs = {}

        features, targets = generate_regression_task(
            raw, n_hist, n_pred, **task_kwargs
        )
        features_fill, targets_fill = generate_regression_task(
            raw, n_hist, n_pred, replace_drops=True, **task_kwargs
        )

        (
            (train_x, val_x, test_x,
             train_y, val_y, test_y, scaler),
            train_idx, val_idx, test_idx,
        ) = generate_split(
            (features, features_fill),
            (targets, targets_fill),
            split_ratio,
            norm
        )

        # Cache file names record which time features were requested
        # (only set on the h5 path) and the window sizes.
        suffix = ''
        if add_time_in_day is not None:
            if add_time_in_day:
                suffix += '_day'
            if add_time_in_week:
                suffix += '_week'
        suffix += f'_hist{n_hist}_pred{n_pred}'

        # Only the train file carries the scaler statistics; the cached
        # path above rebuilds the scaler from it.
        np.savez_compressed(
            os.path.join(dataset_path, f'train{suffix}.npz'),
            x=train_x,
            y=train_y,
            idx=train_idx,
            mean=scaler.mean,
            std=scaler.std
        )
        for name, split_x, split_y, split_idx in (
            ('val', val_x, val_y, val_idx),
            ('test', test_x, test_y, test_idx),
        ):
            np.savez_compressed(
                os.path.join(dataset_path, f'{name}{suffix}.npz'),
                x=split_x,
                y=split_y,
                idx=split_idx
            )
        return train_x, val_x, test_x, train_y, val_y, test_y, scaler

    return X_s + y_s + [scaler]


def get_dataset(
        mode: str = 'pretrain',
        data_dir: str = None,
        dataset_name: str = 'traffic',  # change this to your dataset name
        n_hist: int = 12,
        n_pred: int = 3,
        split_ratio=(0.7, 0.2, 0.1),
        graph_token=True,
        seed=0,
        task='pred',
        norm=True,
):
    """Build train/validation/test subsets of a ``TrafficDataset``.

    The array stored at ``data_dir`` is loaded with ``np.load``, wrapped in
    a ``TrafficDataset`` and randomly split according to ``split_ratio``.
    Graph metadata imported from ``traffic_dataset`` is attached to each
    subset as attributes.

    Args:
        mode: One of ``pretrain``, ``finetune``, ``valid``, ``test`` or
            ``debug``; only validated here.
        data_dir: Path of the file passed to ``np.load``. Required.
        dataset_name: Not used by this function.
        n_hist: Passed to ``TrafficDataset`` as ``window``.
        n_pred: Not used by this function.
        split_ratio: The first two entries give the train and validation
            fractions; the test split receives the remaining samples.
        graph_token: Not used by this function.
        seed: Seed for the generator that drives the random split, so the
            same seed gives the same partition. ``None`` falls back to
            torch's global RNG.
        task: Must be ``'pred'``.
        norm: Not used by this function.

    Returns:
        A dict with the keys ``train_dataset``, ``valid_dataset`` and
        ``test_dataset``.

    Raises:
        AssertionError: If ``mode`` or ``task`` has an unsupported value.
        ValueError: If ``data_dir`` is empty or ``None``.
    """
    assert mode in ["pretrain", "finetune", "valid", "test", "debug"]
    assert task == 'pred'

    # 1) Load the raw data.
    if not data_dir:
        raise ValueError("data_dir must point to your .npy file")
    traffic_data = np.load(data_dir)  # e.g. 'dataset/traffic_dataset_13.npy'

    # 2) Build the TrafficDataset.
    # NOTE(review): week_steps is hard-coded to 480*7 -- presumably 480
    # steps per day; confirm against the data's sampling interval.
    full_ds = TrafficDataset(traffic_data, window=n_hist, week_steps=480*7)

    # 3) Split into train/val/test. The test size is the remainder so the
    #    three lengths always sum to the dataset length.
    total = len(full_ds)
    n_train = int(total * split_ratio[0])
    n_val = int(total * split_ratio[1])
    n_test = total - n_train - n_val
    # Honour the seed argument: a dedicated generator makes the partition
    # reproducible without touching the global RNG state.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(int(seed))
    train_ds, val_ds, test_ds = random_split(
        full_ds, [n_train, n_val, n_test], **split_kwargs
    )

    # 4) Graph encoding info: edge_adj_mat, edge_degree, edge_spd.
    #    These are assumed to be precomputed and exposed at module level.
    # NOTE(review): TrafficDataset is imported from `data.traffic_dataset`
    # at the top of the file, while this import uses `traffic_dataset` --
    # confirm which module path is correct.
    from traffic_dataset import edge_adj_mat, edge_degree_list, edge_spd

    # Attach the graph metadata to each split.
    for ds in (train_ds, val_ds, test_ds):
        ds.edge_adj_mat = edge_adj_mat
        ds.edge_degree = edge_degree_list
        ds.edge_spd = edge_spd

    # 5) Return the splits.
    return {
        'train_dataset': train_ds,
        'valid_dataset': val_ds,
        'test_dataset': test_ds,
    }