This notebook generates the training, validation and test datasets that are used to train the baseline models.

It follows exactly the same procedure and random seed that are used when training the deep learning models.

In [1]:
import numpy as np
import torch
import os
import argparse

from torch.utils.data import DataLoader, Dataset, Subset

from utils import seed_torch, seed_worker
from train import create_parser

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reuse the argument-parser factory imported from train.py rather than redefining the options here.
parser = create_parser()

In [3]:
# Use parser.parse_args([]) rather than parser.parse_args(): with no argument, argparse reads
# sys.argv, which raises an error when run inside a Jupyter notebook.
# Passing an empty list makes every option fall back to its default value.
# See https://stackoverflow.com/questions/50360012/python-argparse-error-error-argument-count-invalid-int-value for more details
args = parser.parse_args([])

In [4]:
# For reproducibility: seed via the project's seed_torch helper (imported from utils) using the
# configured seed. NOTE(review): presumably this seeds the Python / NumPy / PyTorch RNGs -- confirm in utils.
seed_torch(args.seed)

In [5]:
# Display the seed that is in effect (the recorded cell output is 912).
args.seed

912

In [6]:
class TrafficData(Dataset):
    """
    Windowed dataset over the 5-minute traffic arrays stored under ``args.data_dir``.

    Two arrays are read at construction time and kept as float tensors:

    * ``np_in_5min.npy``  -> ``self.X``: sequence features of the TMC segments, one row per
      5-minute slot (original author's note: shape (21060, feat_dim)).
    * ``np_out_5min.npy`` -> ``self.Y``: ground truth TMC speed & incident data (original
      author's note: shape (21060, num_seg, 4); last axis is 1. speed of all vehicles,
      2. speed of trucks, 3. speed of personal vehicles, 4. incident status).

    Rows are treated as consecutive groups of 180 slots (one group per day). ``args`` must
    also provide the integer attributes ``seq_len_in`` and ``seq_len_out``.
    """

    def __init__(self, args):
        self.args = args

        features_file = f"{self.args.data_dir}/np_in_5min.npy"
        targets_file = f"{self.args.data_dir}/np_out_5min.npy"

        self.X = torch.from_numpy(np.load(features_file)).float()
        self.Y = torch.from_numpy(np.load(targets_file)).float()

    def __len__(self):
        # One sample per row (5-minute slot) of the feature array.
        return self.X.size(0)

    def __getitem__(self, idx):
        """
        Return the ``(X, Y)`` pair anchored at row ``idx``.

        The slot of ``idx`` within its 180-slot group is first clamped: it is raised to at
        least 6 and at least ``seq_len_in - 1``, then capped at ``186 - seq_len_out``. With
        ``anchor`` being the clamped row:

        * ``X`` holds the ``seq_len_in`` feature rows ending at ``anchor`` (inclusive).
        * ``Y`` holds ``seq_len_out + 1`` target rows starting at ``anchor - 6``; the first of
          them is the starting point and does not count towards the output sequence length.

        Because of the clamping, several ``idx`` values near the edges of a group map to the
        same sample.

        NOTE(review): with ``seq_len_out > 6`` the cap ``186 - seq_len_out`` lets the last target
        row reach ``group_start + 180``, i.e. the first row of the next group (or past the end
        of ``self.Y`` for the final group) -- confirm this is intended.
        """
        in_len = self.args.seq_len_in
        out_len = self.args.seq_len_out

        group_start = (idx // 180) * 180
        slot = idx % 180

        lowest_allowed = max(slot, 6, in_len - 1)
        highest_allowed = 186 - out_len
        anchor = int(min(lowest_allowed, highest_allowed) + group_start)

        # An explicit index list (not a slice) is used for the targets, so an out-of-range row
        # raises instead of being silently truncated.
        first_target = anchor - 6
        target_rows = list(range(first_target, first_target + out_len + 1))

        window = self.X[anchor - in_len + 1:anchor + 1, :]
        targets = self.Y[target_rows, :, :]

        return window, targets


def get_dataset(args):
    """
    Build the full TrafficData set and randomly split it into train / validation / test subsets.

    Reads ``args.data_train_ratio``, ``args.data_val_ratio`` and ``args.seed`` (plus whatever
    ``TrafficData`` itself reads from ``args``). The train and validation sizes are rounded up
    to a whole number of 180-slot days; the test split receives the remaining samples.
    The split is drawn with a dedicated generator seeded with ``args.seed``.

    Returns the tuple ``(train_dataset, val_dataset, test_dataset)`` of dataset subsets
    (not data loaders).
    """
    whole_dataset = TrafficData(args=args)
    total = len(whole_dataset)

    # Size every split in whole days so that the 180 slots of a day are counted together.
    # (Original author's notes record sizes of 14760 / 4320 / 1980 for their run.)
    num_days = total / 180
    train_size = 180 * int(np.ceil(args.data_train_ratio * num_days))
    val_size = 180 * int(np.ceil(args.data_val_ratio * num_days))
    test_size = total - train_size - val_size

    # A private, explicitly seeded generator keeps the split independent of the global RNG state.
    rng = torch.Generator()
    rng.manual_seed(args.seed)

    splits = torch.utils.data.random_split(
        dataset=whole_dataset,
        lengths=[train_size, val_size, test_size],
        generator=rng,
    )
    train_dataset, val_dataset, test_dataset = splits
    return train_dataset, val_dataset, test_dataset


In [7]:
# Build the three random splits with the parsed arguments (get_dataset seeds its own generator from args.seed).
train_dataset, val_dataset, test_dataset = get_dataset(args)

In [8]:
def _split_to_arrays(split):
    """
    Materialise one dataset split as a pair of NumPy arrays.

    ``split`` is iterated exactly once; every sample is expected to be an ``(X, Y)`` pair of
    tensors. Returns ``(inputs, targets)``, where each array stacks the corresponding tensors
    of all samples along a new leading axis.
    """
    inputs, targets = [], []
    for x, y in split:
        inputs.append(x.detach().numpy())
        targets.append(y.detach().numpy())
    return np.array(inputs), np.array(targets)


# Each split is traversed only once: previously every sample was built twice by __getitem__
# (once to extract the inputs and once more for the targets).
np_train_in_5min, np_train_out_5min = _split_to_arrays(train_dataset)
np_val_in_5min, np_val_out_5min = _split_to_arrays(val_dataset)
np_test_in_5min, np_test_out_5min = _split_to_arrays(test_dataset)

In [9]:
# Persist every split next to the source arrays in args.data_dir, one .npy file per array.
outputs = [
    (f"{args.data_dir}/np_train_in_5min.npy", np_train_in_5min),
    (f"{args.data_dir}/np_train_out_5min.npy", np_train_out_5min),
    (f"{args.data_dir}/np_val_in_5min.npy", np_val_in_5min),
    (f"{args.data_dir}/np_val_out_5min.npy", np_val_out_5min),
    (f"{args.data_dir}/np_test_in_5min.npy", np_test_in_5min),
    (f"{args.data_dir}/np_test_out_5min.npy", np_test_out_5min),
]
for out_path, out_array in outputs:
    np.save(out_path, out_array)