In [5]:
import os
import shutil
from typing import List, Tuple

import torch
import pandas as pd
from torch_geometric import data, nn
import torch_geometric.nn.functional as F

In [6]:
# Number of 5-minute slots in one day: 24 h * 60 min/h / 5 min = 288.
N_DAY_SLOT = (24 * 60) // 5

config = dict(
    F=12,  # input (history) steps per sample window -> Data.x
    H=9,  # target steps per sample window -> Data.y
    N_DAYS=44,  # number of days windows are drawn from
    N_DAY_SLOT=N_DAY_SLOT,
)

# Number of complete (F + H)-step windows whose start fits inside a single day.
config["N_SLOT"] = N_DAY_SLOT - (config["F"] + config["H"]) + 1

In [7]:
from typing import List, Tuple, Union


class TrafficDataset(data.InMemoryDataset):
    """PeMSD7 sliding-window traffic samples as an in-memory PyG dataset.

    Each sample is a ``torch_geometric.data.Data`` graph over all sensor nodes:
    ``x`` holds the first ``config["F"]`` time steps of a window and ``y`` the
    following ``config["H"]`` steps (both laid out as ``(num_nodes, steps)``).
    Edges are the non-zero entries of the thresholded Gaussian-kernel weight
    matrix built from the pairwise distance CSV (see ``_distance_to_weight``).

    Args:
        config: dict with keys ``"F"``, ``"H"``, ``"N_DAYS"``, ``"N_DAY_SLOT"``
            and ``"N_SLOT"``.
        root: dataset root; raw CSVs live in ``root/raw`` and the processed
            cache in ``root/processed``.
        gat_version: if True, edge weights are binarised and self-loops added.
        transform, pre_transform: forwarded to ``InMemoryDataset``.

    Note:
        PyG only calls ``process()`` when the processed file is missing, so
        delete ``root/processed/data.pt`` after changing ``config``,
        ``gat_version`` or the processing code.
    """

    # Directory (relative to the working directory) that ``download`` copies the raw CSVs from.
    RAW_SOURCE_DIR = "../data/raw"

    def __init__(
        self,
        config: dict,
        root: str,
        gat_version: bool = True,
        transform=None,
        pre_transform=None,
    ):
        # Both attributes are read by process(), which super().__init__ may call,
        # so they must be set first.
        self.config = config
        self.gat_version = gat_version
        super().__init__(root, transform, pre_transform)
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which
        # refuses pickled Data objects; pass weights_only=False for this trusted
        # local cache if loading fails.
        (
            self._data,
            self.slices,
            self._num_nodes,
            self._mean,
            self._std,
        ) = torch.load(self.processed_paths[0])
        # InMemoryDataset.len()/get() read ``self.slices``; assigning only
        # ``_slice`` made the dataset look like a single graph. The alias keeps
        # the old attribute name working for existing callers.
        self._slice = self.slices

    @property
    def processed_file_names(self) -> str | List[str] | Tuple:
        """Name of the processed cache file, relative to ``processed_dir``."""
        return ["./data.pt"]

    @property
    def raw_file_names(self) -> str | List[str] | Tuple:
        """Raw CSV names (V time-series matrix, W distance matrix), relative to ``raw_dir``.

        PyG joins these with ``raw_dir`` to build ``raw_paths``, so they must be
        bare names; returning already-joined paths made ``raw_paths`` point at a
        non-existent location for a relative ``root`` and re-triggered
        ``download`` on every instantiation.
        """
        return ["PeMSD7_V_228.csv", "PeMSD7_W_228.csv"]

    def download(self):
        """Copy the raw CSVs from ``RAW_SOURCE_DIR`` into ``raw_dir``."""
        for name in self.raw_file_names:
            shutil.copyfile(
                os.path.join(self.RAW_SOURCE_DIR, name),
                os.path.join(self.raw_dir, name),
            )

    def process(self):
        """Build one graph per sliding window and save them to ``processed_paths[0]``.

        The saved tuple is ``(data, slices, num_nodes, mean, std)``, matching
        what ``__init__`` unpacks.
        """
        speed_path, distance_path = self.raw_paths
        df = pd.read_csv(speed_path, header=None)
        weight_df = pd.read_csv(distance_path, header=None)
        # Forward the constructor flag; it used to be stored but never applied.
        W = self._distance_to_weight(
            torch.from_numpy(weight_df.values), gat_version=self.gat_version
        )
        data_ = torch.from_numpy(df.values)  # (num_timesteps, num_nodes)
        # NOTE(review): mean/std are saved but data_ itself is not normalised
        # here; confirm that normalisation happens downstream.
        mean = torch.mean(data_)
        std = torch.std(data_)
        _, num_nodes = data_.shape

        # Edge list of every non-zero weight in row-major (i, j) order, i.e. the
        # same order as a nested i/j scan. Building it directly avoids the old
        # fixed-size buffer + resize_, which reinterpreted storage and scrambled
        # edge destinations and weights.
        edge_index = torch.nonzero(W, as_tuple=False).t().contiguous()
        edge_label = (
            W[edge_index[0], edge_index[1]]
            .to(torch.get_default_dtype())
            .unsqueeze(1)
        )  # (num_edges, 1)

        sequences = self._speed2vec(
            edge_index,
            edge_label,
            num_nodes,
            self.config["N_DAYS"],
            self.config["N_SLOT"],
            data_,
            self.config["F"],
            self.config["H"],
            day_slot=self.config["N_DAY_SLOT"],
        )
        data_, slices = self.collate(sequences)

        torch.save(
            (data_, slices, num_nodes, mean, std),
            self.processed_paths[0],
        )

    def _distance_to_weight(
        self,
        W: torch.Tensor,
        sigma2: float = 0.1,
        epsilon: float = 0.5,
        gat_version: bool = False,
    ) -> torch.Tensor:
        """Turn a pairwise distance matrix into a thresholded Gaussian-kernel weight matrix.

        ``w_ij = exp(-(d_ij / 10000)^2 / sigma2)`` when that value is
        ``>= epsilon`` and ``i != j``; otherwise ``0``.

        Args:
            W: square ``(num_nodes, num_nodes)`` distance matrix.
            sigma2: kernel bandwidth applied to the scaled squared distances.
            epsilon: kernel values below this threshold are zeroed.
            gat_version: if True, set every remaining weight to 1 and add self-loops.

        Returns:
            Weight matrix with the same shape as ``W``.
        """
        num_nodes = W.shape[0]
        # NOTE(review): the 10_000 divisor presumably rescales raw distance units; confirm.
        distance_scale = 10_000.0
        scaled = W / distance_scale
        kernel = torch.exp(-(scaled * scaled) / sigma2)
        off_diagonal = torch.ones([num_nodes, num_nodes]) - torch.eye(num_nodes)
        W = kernel * (kernel >= epsilon) * off_diagonal

        if gat_version:
            W[W > 0] = 1
            W += torch.eye(num_nodes)

        return W

    def _speed2vec(
        self,
        edge_index: torch.Tensor,
        edge_label: torch.Tensor,
        num_nodes: int,
        n_days: int,
        n_slot: int,
        data_: torch.Tensor,
        F: int,
        H: int,
        day_slot: int | None = None,
    ) -> List[data.Data]:
        """Cut ``data_`` into ``n_days * n_slot`` windowed graphs.

        Args:
            edge_index: ``(2, num_edges)`` edge list shared by every graph.
            edge_label: per-edge weights shared by every graph.
            num_nodes: number of nodes (columns of ``data_``).
            n_days: number of days to draw windows from.
            n_slot: number of windows per day.
            data_: ``(num_timesteps, num_nodes)`` series.
            F: input steps per window (``Data.x``).
            H: target steps per window (``Data.y``).
            day_slot: time steps per day; defaults to ``self.config["N_DAY_SLOT"]``.

        Returns:
            List of ``Data`` graphs with ``x`` of shape ``(num_nodes, F)`` and
            ``y`` of shape ``(num_nodes, H)``.

        Raises:
            ValueError: if ``data_`` has too few rows for the requested windows.
        """
        if day_slot is None:
            day_slot = self.config["N_DAY_SLOT"]
        window_length = F + H

        if n_days > 0 and n_slot > 0:
            last_end = (n_days - 1) * day_slot + (n_slot - 1) + window_length
            if last_end > data_.shape[0]:
                raise ValueError(
                    f"need at least {last_end} time steps for {n_days} days x "
                    f"{n_slot} windows, got {data_.shape[0]}"
                )

        sequences = []
        for day in range(n_days):
            for slot in range(n_slot):
                # Each day's windows start at day * day_slot (offsetting by F
                # instead made the "days" overlap and ignored most of the data).
                start = day * day_slot + slot
                window = data_[start : start + window_length].T  # (num_nodes, F + H)
                sequences.append(
                    data.Data(
                        x=window[:, :F],
                        y=window[:, F:],
                        edge_index=edge_index,
                        edge_label=edge_label,
                        num_nodes=num_nodes,
                    )
                )

        return sequences

In [8]:
# Loads root/processed/data.pt, running download()/process() first if it is missing.
dataset = TrafficDataset(root="../data/processed/", config=config)

In [9]:
# Stacked node features of every window graph: expected (num_graphs * num_nodes, F).
dataset._data.x.size()

torch.Size([2688576, 12])