In [2]:
import sys
sys.path.append('/home/zwt/Bigscity-LibCity')
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from logging import getLogger
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model import loss
from scipy.sparse.linalg import eigs

In [3]:
from Embedding import GNN_Based_Layers as GNN
from ST_backbone import CNN_Based_Models as TCNN
from ST_backbone import STAttention_Based_Models as Attn
from Output import LinearTransform as LT


class ASTGCNBlock(nn.Module):
    """One spatio-temporal block of ASTGCN.

    The block runs, in order: temporal attention, spatial attention, a
    Chebyshev graph convolution that receives the spatial attention scores,
    and a strided convolution along the time axis. A strided 1x1 convolution
    of the raw input is added as a residual shortcut, followed by ReLU and a
    LayerNorm over the channel axis.
    """

    def __init__(self, device, in_channels, k, nb_chev_filter, nb_time_filter,
                 time_strides, cheb_polynomials, num_of_nodes, num_of_timesteps):
        """
        Args:
            device: device handed to the two attention layers.
            in_channels: number of input features per node (F_in).
            k: order argument forwarded to the Chebyshev graph convolution.
            nb_chev_filter: output channels of the graph convolution.
            nb_time_filter: output channels of the temporal and residual convolutions.
            time_strides: stride of both convolutions along the time axis.
            cheb_polynomials: Chebyshev polynomials forwarded to the graph convolution.
            num_of_nodes: number of graph nodes (N).
            num_of_timesteps: length of the input time axis (T).
        """
        super().__init__()
        self.TAt = Attn.TemporalAttention(device, in_channels, num_of_nodes, num_of_timesteps)
        self.SAt = Attn.SpatialAttention(device, in_channels, num_of_nodes, num_of_timesteps)
        self.cheb_conv_SAt = GNN.ChebConvWithSAt(k, cheb_polynomials, in_channels, nb_chev_filter)
        # Temporal convolution: kernel 3, padding 1, stride time_strides.
        # With T = time_strides * output_window the output length is
        #   floor((T + 2 * 1 - 3) / time_strides) + 1 = output_window,
        # so the input length must be an exact multiple of output_window.
        self.time_conv = nn.Conv2d(
            nb_chev_filter,
            nb_time_filter,
            kernel_size=(1, 3),
            stride=(1, time_strides),
            padding=(0, 1),
        )
        # Residual shortcut: kernel 1, stride time_strides, hence
        #   floor((T - 1) / time_strides) + 1 = output_window
        # and both branches end up with the same time length.
        self.residual_conv = nn.Conv2d(
            in_channels,
            nb_time_filter,
            kernel_size=(1, 1),
            stride=(1, time_strides),
        )
        # LayerNorm acts on the last axis, so channels are permuted there in forward().
        self.ln = nn.LayerNorm(nb_time_filter)

    def forward(self, x):
        """
        Args:
            x: (batch_size, N, F_in, T)

        Returns:
            torch.tensor: (batch_size, N, nb_time_filter, output_window)
        """
        n_batch, n_nodes, n_feats, n_steps = x.shape

        # Temporal attention scores, (B, T, T).
        time_scores = self.TAt(x)

        # Apply them: (B, N*F_in, T) x (B, T, T) -> (B, N*F_in, T) -> (B, N, F_in, T).
        flattened = x.reshape(n_batch, -1, n_steps)
        x_tat = torch.matmul(flattened, time_scores).reshape(n_batch, n_nodes, n_feats, n_steps)

        # Spatial attention scores computed on the time-attended signal, (B, N, N).
        node_scores = self.SAt(x_tat)

        # Graph convolution on the raw input, weighted by the spatial attention:
        # (B, N, nb_chev_filter, T).
        graph_out = self.cheb_conv_SAt(x, node_scores)

        # (B, N, C, T) -> (B, C, N, T), then the (1, 3) convolution along time
        # -> (B, nb_time_filter, N, T').
        conv_branch = self.time_conv(graph_out.permute(0, 2, 1, 3))

        # Shortcut: (B, N, F_in, T) -> (B, F_in, N, T) -> (B, nb_time_filter, N, T').
        skip_branch = self.residual_conv(x.permute(0, 2, 1, 3))

        # (B, C, N, T') -> (B, T', N, C) for LayerNorm -> back to (B, N, C, T').
        activated = F.relu(skip_branch + conv_branch)
        normalised = self.ln(activated.permute(0, 3, 2, 1))
        return normalised.permute(0, 2, 3, 1)

class ASTGCNSubmodule(nn.Module):
    """A stack of ASTGCNBlock layers followed by an output head.

    The first block consumes the raw features and reduces the time axis with
    stride ``time_strides``; every later block works on ``nb_time_filter``
    channels with stride 1. A final convolution maps the channel axis to
    ``output_dim`` and the result is passed through ``LT.FusionLayer``.
    """

    def __init__(self, device, nb_block, in_channels, k, nb_chev_filter, nb_time_filter,
                 time_strides, cheb_polynomials, output_window, output_dim, num_of_vertices):
        """
        Args:
            device: device forwarded to the blocks and the fusion layer.
            nb_block: total number of stacked blocks.
            in_channels: number of input features per node.
            k: order argument forwarded to each block.
            nb_chev_filter: graph-convolution channels inside each block.
            nb_time_filter: output channels of each block.
            time_strides: temporal stride of the first block.
            cheb_polynomials: Chebyshev polynomials forwarded to each block.
            output_window: number of predicted time steps.
            output_dim: number of predicted features per node.
            num_of_vertices: number of graph nodes.
        """
        super().__init__()

        # Entry block: input length is time_strides * output_window.
        stages = [ASTGCNBlock(device, in_channels, k, nb_chev_filter,
                              nb_time_filter, time_strides, cheb_polynomials,
                              num_of_vertices, time_strides * output_window)]
        # Remaining blocks keep the time length (stride 1) and the channel width.
        for _ in range(nb_block - 1):
            stages.append(ASTGCNBlock(device, nb_time_filter, k, nb_chev_filter,
                                      nb_time_filter, 1, cheb_polynomials,
                                      num_of_vertices, output_window))
        self.BlockList = nn.ModuleList(stages)

        # A kernel of width nb_time_filter - output_dim + 1 over the channel axis
        # leaves exactly output_dim positions.
        head_kernel = (1, nb_time_filter - output_dim + 1)
        self.final_conv = nn.Conv2d(output_window, output_window, kernel_size=head_kernel)

        self.fusionlayer = LT.FusionLayer(output_window, num_of_vertices, output_dim, device)

    def forward(self, x):
        """
        Args:
            x: (B, T_in, N_nodes, F_in)

        Returns:
            torch.tensor: (B, T_out, N_nodes, out_dim)
        """
        # (B, T_in, N, F_in) -> (B, N, F_in, T_in), the layout the blocks expect.
        hidden = x.permute(0, 2, 3, 1)
        for stage in self.BlockList:
            # Each block emits nb_time_filter channels.
            hidden = stage(hidden)
        # (B, N, C, T_out) -> (B, T_out, N, C) -> conv -> (B, T_out, N, out_dim).
        projected = self.final_conv(hidden.permute(0, 3, 1, 2))
        return self.fusionlayer(projected)

In [4]:

class ASTGCN(AbstractTrafficStateModel):
    """ASTGCN traffic-state model built from up to three ASTGCNSubmodule branches.

    The input time axis is the concatenation of a closeness ("hours"), a
    period ("days") and a trend ("weeks") segment, in that order. Each
    segment with a positive length gets its own submodule; the branch outputs
    are summed to form the prediction.
    """

    def __init__(self, config, data_feature):
        """
        Args:
            config: mapping read for ``output_window``, ``device``, ``nb_block``,
                ``K``, ``nb_chev_filter`` and ``nb_time_filter``.
            data_feature: mapping read for ``num_nodes``, ``feature_dim``,
                ``len_closeness``, ``len_period``, ``len_trend``, ``output_dim``,
                ``adj_mx`` and ``scaler``.

        Raises:
            ValueError: if all three segment lengths are zero, or if an enabled
                segment length is not a multiple of ``output_window``.
        """
        super().__init__(config, data_feature)

        self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.feature_dim = self.data_feature.get('feature_dim', 1)
        self.len_period = self.data_feature.get('len_period', 0)
        self.len_trend = self.data_feature.get('len_trend', 0)
        self.len_closeness = self.data_feature.get('len_closeness', 0)
        if self.len_period == 0 and self.len_trend == 0 and self.len_closeness == 0:
            raise ValueError('Num of days/weeks/hours are all zero! Set at least one of them not zero!')
        self.output_dim = self.data_feature.get('output_dim', 1)

        self.output_window = config.get('output_window', 1)
        self.device = config.get('device', torch.device('cpu'))
        self.nb_block = config.get('nb_block', 2)
        self.K = config.get('K', 3)
        self.nb_chev_filter = config.get('nb_chev_filter', 64)
        self.nb_time_filter = config.get('nb_time_filter', 64)

        # Chebyshev polynomials of the scaled Laplacian, shared by every branch.
        adj_mx = self.data_feature.get('adj_mx')
        l_tilde = GNN.scaled_laplacian(adj_mx)
        self.cheb_polynomials = [torch.from_numpy(i).type(torch.FloatTensor).to(self.device)
                                 for i in GNN.cheb_polynomial(l_tilde, self.K)]
        self._logger = getLogger()
        self._scaler = self.data_feature.get('scaler')

        # Attribute names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) keep loading.
        if self.len_closeness > 0:
            self.hours_ASTGCN_submodule = self._build_submodule(self.len_closeness, 'len_closeness')
        if self.len_period > 0:
            self.days_ASTGCN_submodule = self._build_submodule(self.len_period, 'len_period')
        if self.len_trend > 0:
            self.weeks_ASTGCN_submodule = self._build_submodule(self.len_trend, 'len_trend')
        self._init_parameters()

    def _build_submodule(self, segment_len, feature_name):
        """Create the ASTGCNSubmodule for one input segment of ``segment_len`` steps.

        The temporal stride of the branch is ``segment_len // output_window``.
        A length that is not an exact multiple would silently truncate (or give
        stride 0) and leave the block parameters sized for a different number
        of time steps than the data, so it is rejected up front.

        Raises:
            ValueError: if ``segment_len`` is not a multiple of ``output_window``.
        """
        if segment_len % self.output_window != 0:
            raise ValueError(
                '{} ({}) must be a positive multiple of output_window ({})'.format(
                    feature_name, segment_len, self.output_window))
        return ASTGCNSubmodule(self.device, self.nb_block, self.feature_dim,
                               self.K, self.nb_chev_filter, self.nb_time_filter,
                               segment_len // self.output_window, self.cheb_polynomials,
                               self.output_window, self.output_dim, self.num_nodes)

    def _init_parameters(self):
        """Re-initialise every parameter of the model.

        Tensors with more than one dimension get Xavier-uniform values; all
        others (e.g. biases) are drawn from ``nn.init.uniform_`` with its
        default range.
        """
        # NOTE(review): this also overwrites the LayerNorm weight/bias defaults — confirm intended.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, batch):
        """Sum the predictions of all enabled branches.

        Args:
            batch: mapping whose ``'X'`` entry is (B, Tw+Td+Th, N_nodes, F_in).

        Returns:
            (B, Tp, N_nodes, F_out)
        """
        x = batch['X']
        # Segments are laid out along dim 1 in closeness / period / trend order:
        #   [0, len_closeness)
        #   [len_closeness, len_closeness + len_period)
        #   [len_closeness + len_period, len_closeness + len_period + len_trend)
        segments = (
            (self.len_closeness, 'hours_ASTGCN_submodule'),
            (self.len_period, 'days_ASTGCN_submodule'),
            (self.len_trend, 'weeks_ASTGCN_submodule'),
        )
        output = 0
        begin_index = 0
        for segment_len, attr_name in segments:
            if segment_len > 0:
                end_index = begin_index + segment_len
                output += getattr(self, attr_name)(x[:, begin_index:end_index, :, :])
            # The offset advances even for disabled segments, as the layout above requires.
            begin_index += segment_len
        return output

    def calculate_loss(self, batch):
        """Masked MSE between prediction and ``batch['y']`` after undoing the scaling.

        Only the first ``output_dim`` features of both tensors are compared.
        """
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mse_torch(y_predicted, y_true)

    def predict(self, batch):
        """Return the model output for ``batch``; identical to ``forward``."""
        return self.forward(batch)

In [2]:
from libcity.data import get_dataset
from libcity.utils import get_executor
from libcity.utils import get_model
from libcity.utils import get_logger

# Experiment configuration for the LibCity reference ASTGCN on PEMSD4.
# Keys, values and insertion order are unchanged; only the layout is new.
config = {
    # --- task and components ---
    'task': 'traffic_state_pred', 'model': 'ASTGCN', 'dataset': 'PEMSD4',
    'saved_model': True, 'train': True,
    'dataset_class': 'ASTGCNDataset', 'executor': 'TrafficStateExecutor',
    'evaluator': 'TrafficStateEvaluator',
    # --- model size ---
    'nb_block': 2, 'K': 3, 'nb_chev_filter': 64, 'nb_time_filter': 64,
    # --- data preparation ---
    'scaler': 'standard', 'load_external': False, 'normal_external': False,
    'ext_scaler': 'none', 'add_time_in_day': False, 'add_day_in_week': False,
    'train_rate': 0.6, 'eval_rate': 0.2,
    # --- optimisation ---
    'max_epoch': 100, 'learner': 'adam', 'learning_rate': 0.0001,
    'lr_decay': False, 'clip_grad_norm': False, 'use_early_stop': False,
    'batch_size': 64, 'cache_dataset': True, 'num_workers': 0,
    'pad_with_last_sample': True,
    # --- input / output windows ---
    'input_window': 12, 'output_window': 12,
    'len_closeness': 2, 'len_period': 1, 'len_trend': 2,
    'interval_period': 1, 'interval_trend': 7,
    # --- hardware ---
    'gpu': True, 'gpu_id': 0,
    # --- optimiser / scheduler details ---
    'train_loss': 'none', 'epoch': 0, 'weight_decay': 0,
    'lr_epsilon': 1e-08, 'lr_beta1': 0.9, 'lr_beta2': 0.999, 'lr_alpha': 0.99,
    'lr_momentum': 0, 'lr_scheduler': 'multisteplr', 'lr_decay_ratio': 0.1,
    'steps': [5, 20, 40, 70], 'step_size': 10, 'lr_T_max': 30, 'lr_eta_min': 0,
    'lr_patience': 10, 'lr_threshold': 0.0001, 'max_grad_norm': 1.0,
    'patience': 50,
    # --- logging and evaluation ---
    'log_level': 'INFO', 'log_every': 1, 'load_best_epoch': True,
    'hyper_tune': False,
    'metrics': ['MAE', 'MAPE', 'MSE', 'RMSE', 'masked_MAE', 'masked_MAPE',
                'masked_MSE', 'masked_RMSE', 'R2', 'EVAR'],
    'mode': 'single', 'save_modes': ['csv'],
    # --- raw-data schema ---
    'geo': {'including_types': ['Point'], 'Point': {}},
    'rel': {'including_types': ['geo'], 'geo': {'cost': 'num'}},
    'dyna': {'including_types': ['state'],
             'state': {'entity_id': 'geo_id', 'traffic_flow': 'num',
                       'traffic_occupancy': 'num', 'traffic_speed': 'num'}},
    'data_col': ['traffic_flow', 'traffic_occupancy', 'traffic_speed'],
    'weight_col': 'cost', 'data_files': ['PEMSD4'], 'geo_file': 'PEMSD4',
    'rel_file': 'PEMSD4', 'output_dim': 3, 'time_intervals': 300,
    'init_weight_inf_or_zero': 'zero', 'set_weight_link_or_dist': 'link',
    'calculate_weight_adj': False, 'weight_adj_epsilon': 0.1,
}

import os
import random

# Restrict the process to physical GPU 1; this has to happen before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch

# Seed every RNG so that shuffling and weight initialisation are repeatable
# across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config['device'] = torch.device("cuda" if torch.cuda.is_available() and config['gpu'] else "cpu")

logger = get_logger(config)
# Load the dataset.
dataset = get_dataset(config)
# Transform the data and split it into train / validation / test loaders.
train_data, valid_data, test_data = dataset.get_data()
# Per split: number of batches, number of samples, x shape, y shape, batch size.
for loader in (train_data, valid_data, test_data):
    print(len(loader), len(loader.dataset), loader.dataset[0][0].shape,
          loader.dataset[0][1].shape, loader.batch_size)

data_feature = dataset.get_data_feature()
print(data_feature['adj_mx'].shape)
print(data_feature['adj_mx'].sum())

# Reference implementation resolved by LibCity from config['model'].
model = get_model(config, data_feature)

# Build the executor.
model_cache_file = './libcity/cache/model_cache/' + config['model'] + '_' + config['dataset'] + '.m'
executor = get_executor(config, model)

# Train, then round-trip the weights through the cache file.
executor.train(train_data, valid_data)
executor.save_model(model_cache_file)
executor.load_model(model_cache_file)
# Evaluate; the results are written under cache/evaluate_cache.
executor.evaluate(test_data)

2021-12-03 21:23:48,078 - INFO - Log directory: ./libcity/log
2021-12-03 21:23:48,740 - INFO - Loaded file PEMSD4.geo, num_nodes=307
2021-12-03 21:23:48,744 - INFO - set_weight_link_or_dist: link
2021-12-03 21:23:48,745 - INFO - init_weight_inf_or_zero: zero
2021-12-03 21:23:48,749 - INFO - Loaded file PEMSD4.rel, shape=(307, 307)
2021-12-03 21:23:48,750 - INFO - Loading file PEMSD4.dyna
2021-12-03 21:23:52,566 - INFO - Loaded file PEMSD4.dyna, shape=(16992, 307, 3)
2021-12-03 21:23:57,150 - INFO - closeness: (12949, 24, 307, 3)
2021-12-03 21:23:57,659 - INFO - period: (12949, 12, 307, 3)
2021-12-03 21:23:58,649 - INFO - trend: (12949, 24, 307, 3)
2021-12-03 21:24:04,998 - INFO - Dataset created
2021-12-03 21:24:04,999 - INFO - x shape: (12949, 60, 307, 3), y shape: (12949, 12, 307, 3)
2021-12-03 21:24:05,016 - INFO - train	x: (7769, 60, 307, 3), y: (7769, 12, 307, 3)
2021-12-03 21:24:05,017 - INFO - eval	x: (2590, 60, 307, 3), y: (2590, 12, 307, 3)
2021-12-03 21:24:05,018 - INFO - tes

2021-12-03 21:30:17,251 - INFO - hours_ASTGCN_submodule.BlockList.0.ln.bias	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,251 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.U1	torch.Size([307])	cuda:0	True
2021-12-03 21:30:17,252 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.U2	torch.Size([64, 307])	cuda:0	True
2021-12-03 21:30:17,252 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.U3	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,253 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.be	torch.Size([1, 12, 12])	cuda:0	True
2021-12-03 21:30:17,253 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.Ve	torch.Size([12, 12])	cuda:0	True
2021-12-03 21:30:17,254 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W1	torch.Size([12])	cuda:0	True
2021-12-03 21:30:17,255 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W2	torch.Size([64, 12])	cuda:0	True
2021-12-03 21:30:17,255 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W3	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,256 - INFO - hours_AST

2021-12-03 21:30:17,301 - INFO - weeks_ASTGCN_submodule.BlockList.0.cheb_conv_SAt.Theta.0	torch.Size([3, 64])	cuda:0	True
2021-12-03 21:30:17,302 - INFO - weeks_ASTGCN_submodule.BlockList.0.cheb_conv_SAt.Theta.1	torch.Size([3, 64])	cuda:0	True
2021-12-03 21:30:17,303 - INFO - weeks_ASTGCN_submodule.BlockList.0.cheb_conv_SAt.Theta.2	torch.Size([3, 64])	cuda:0	True
2021-12-03 21:30:17,303 - INFO - weeks_ASTGCN_submodule.BlockList.0.time_conv.weight	torch.Size([64, 64, 1, 3])	cuda:0	True
2021-12-03 21:30:17,304 - INFO - weeks_ASTGCN_submodule.BlockList.0.time_conv.bias	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,304 - INFO - weeks_ASTGCN_submodule.BlockList.0.residual_conv.weight	torch.Size([64, 3, 1, 1])	cuda:0	True
2021-12-03 21:30:17,305 - INFO - weeks_ASTGCN_submodule.BlockList.0.residual_conv.bias	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,306 - INFO - weeks_ASTGCN_submodule.BlockList.0.ln.weight	torch.Size([64])	cuda:0	True
2021-12-03 21:30:17,306 - INFO - weeks_ASTGCN_su

2021-12-03 21:46:55,724 - INFO - epoch complete!
2021-12-03 21:46:55,727 - INFO - evaluating now!
2021-12-03 21:47:07,557 - INFO - Epoch [11/100] train_loss: 555.8788, val_loss: 528.2083, lr: 0.000100, 78.72s
2021-12-03 21:47:07,606 - INFO - Saved model at 11
2021-12-03 21:47:07,607 - INFO - Val loss decrease from 543.9979 to 528.2083, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch11.tar
2021-12-03 21:48:15,602 - INFO - epoch complete!
2021-12-03 21:48:15,606 - INFO - evaluating now!
2021-12-03 21:48:26,923 - INFO - Epoch [12/100] train_loss: 536.2228, val_loss: 519.0206, lr: 0.000100, 79.32s
2021-12-03 21:48:26,972 - INFO - Saved model at 12
2021-12-03 21:48:26,973 - INFO - Val loss decrease from 528.2083 to 519.0206, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch12.tar
2021-12-03 21:49:41,628 - INFO - epoch complete!
2021-12-03 21:49:41,632 - INFO - evaluating now!
2021-12-03 21:49:53,586 - INFO - Epoch [13/100] train_loss: 520.4568, val_loss: 510.3232, lr: 0.00

2021-12-03 22:16:26,811 - INFO - epoch complete!
2021-12-03 22:16:26,814 - INFO - evaluating now!
2021-12-03 22:16:39,014 - INFO - Epoch [32/100] train_loss: 394.4215, val_loss: 438.7227, lr: 0.000100, 86.98s
2021-12-03 22:16:39,069 - INFO - Saved model at 32
2021-12-03 22:16:39,070 - INFO - Val loss decrease from 442.2785 to 438.7227, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch32.tar
2021-12-03 22:17:52,738 - INFO - epoch complete!
2021-12-03 22:17:52,742 - INFO - evaluating now!
2021-12-03 22:18:05,043 - INFO - Epoch [33/100] train_loss: 391.1686, val_loss: 438.0845, lr: 0.000100, 85.97s
2021-12-03 22:18:05,096 - INFO - Saved model at 33
2021-12-03 22:18:05,097 - INFO - Val loss decrease from 438.7227 to 438.0845, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch33.tar
2021-12-03 22:19:18,945 - INFO - epoch complete!
2021-12-03 22:19:18,949 - INFO - evaluating now!
2021-12-03 22:19:31,340 - INFO - Epoch [34/100] train_loss: 387.9641, val_loss: 434.9087, lr: 0.00

2021-12-03 22:45:14,086 - INFO - epoch complete!
2021-12-03 22:45:14,091 - INFO - evaluating now!
2021-12-03 22:45:25,883 - INFO - Epoch [53/100] train_loss: 350.9382, val_loss: 404.2362, lr: 0.000100, 83.04s
2021-12-03 22:45:25,933 - INFO - Saved model at 53
2021-12-03 22:45:25,934 - INFO - Val loss decrease from 407.4367 to 404.2362, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch53.tar
2021-12-03 22:46:39,408 - INFO - epoch complete!
2021-12-03 22:46:39,412 - INFO - evaluating now!
2021-12-03 22:46:50,312 - INFO - Epoch [54/100] train_loss: 349.8194, val_loss: 406.5793, lr: 0.000100, 84.38s
2021-12-03 22:48:01,876 - INFO - epoch complete!
2021-12-03 22:48:01,880 - INFO - evaluating now!
2021-12-03 22:48:13,523 - INFO - Epoch [55/100] train_loss: 348.6784, val_loss: 407.3378, lr: 0.000100, 83.21s
2021-12-03 22:49:25,971 - INFO - epoch complete!
2021-12-03 22:49:25,974 - INFO - evaluating now!
2021-12-03 22:49:37,300 - INFO - Epoch [56/100] train_loss: 347.7945, val_loss: 40

2021-12-03 23:27:21,786 - INFO - Saved model at 83
2021-12-03 23:27:21,787 - INFO - Val loss decrease from 398.5360 to 398.0460, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch83.tar
2021-12-03 23:28:33,594 - INFO - epoch complete!
2021-12-03 23:28:33,597 - INFO - evaluating now!
2021-12-03 23:28:45,288 - INFO - Epoch [84/100] train_loss: 327.0211, val_loss: 400.1599, lr: 0.000100, 83.50s
2021-12-03 23:29:59,398 - INFO - epoch complete!
2021-12-03 23:29:59,402 - INFO - evaluating now!
2021-12-03 23:30:11,015 - INFO - Epoch [85/100] train_loss: 326.4068, val_loss: 398.9667, lr: 0.000100, 85.73s
2021-12-03 23:31:23,095 - INFO - epoch complete!
2021-12-03 23:31:23,099 - INFO - evaluating now!
2021-12-03 23:31:34,137 - INFO - Epoch [86/100] train_loss: 325.8729, val_loss: 398.1031, lr: 0.000100, 83.12s
2021-12-03 23:32:46,816 - INFO - epoch complete!
2021-12-03 23:32:46,820 - INFO - evaluating now!
2021-12-03 23:32:58,583 - INFO - Epoch [87/100] train_loss: 325.4769, val_loss: 40

Unnamed: 0,MAE,MAPE,MSE,RMSE,masked_MAE,masked_MAPE,masked_MSE,masked_RMSE,R2,EVAR
1,7.100527,inf,309.573944,17.594713,7.07377,7.003664,300.052643,17.322027,0.981823,0.981845
2,7.470409,inf,334.062714,18.277382,7.437992,11.226524,322.472931,17.957531,0.980384,0.980431
3,7.498755,inf,348.67746,18.672907,7.462386,5.166659,335.443512,18.315117,0.979529,0.979553
4,7.704169,inf,364.374573,19.088598,7.66335,6.639912,349.779236,18.702385,0.978605,0.978639
5,7.807232,inf,374.019867,19.339594,7.763651,5.993083,358.362457,18.930464,0.978038,0.978064
6,7.9368,inf,387.088226,19.674559,7.889977,6.365098,370.097748,19.237925,0.977266,0.977284
7,8.01377,inf,396.152313,19.903576,7.964163,6.409945,378.168945,19.446566,0.976732,0.976746
8,8.104656,inf,404.831238,20.120419,8.053809,6.100633,385.923187,19.644928,0.976222,0.976231
9,8.175008,inf,413.438904,20.333197,8.119801,6.7431,393.290192,19.831545,0.975718,0.975739
10,8.350195,inf,423.480988,20.578653,8.292344,6.374527,402.514191,20.062756,0.975125,0.975167


In [5]:
from libcity.data import get_dataset
from libcity.utils import get_executor
from libcity.utils import get_model
from libcity.utils import get_logger

# Experiment configuration for the notebook-defined ASTGCN on PEMSD4.
# Keys, values and insertion order are unchanged; only the layout is new.
config = {
    # --- task and components ---
    'task': 'traffic_state_pred', 'model': 'ASTGCN', 'dataset': 'PEMSD4',
    'saved_model': True, 'train': True,
    'dataset_class': 'ASTGCNDataset', 'executor': 'TrafficStateExecutor',
    'evaluator': 'TrafficStateEvaluator',
    # --- model size ---
    'nb_block': 2, 'K': 3, 'nb_chev_filter': 64, 'nb_time_filter': 64,
    # --- data preparation ---
    'scaler': 'standard', 'load_external': False, 'normal_external': False,
    'ext_scaler': 'none', 'add_time_in_day': False, 'add_day_in_week': False,
    'train_rate': 0.6, 'eval_rate': 0.2,
    # --- optimisation ---
    'max_epoch': 100, 'learner': 'adam', 'learning_rate': 0.0001,
    'lr_decay': False, 'clip_grad_norm': False, 'use_early_stop': False,
    'batch_size': 64, 'cache_dataset': True, 'num_workers': 0,
    'pad_with_last_sample': True,
    # --- input / output windows ---
    'input_window': 12, 'output_window': 12,
    'len_closeness': 2, 'len_period': 1, 'len_trend': 2,
    'interval_period': 1, 'interval_trend': 7,
    # --- hardware ---
    'gpu': True, 'gpu_id': 0,
    # --- optimiser / scheduler details ---
    'train_loss': 'none', 'epoch': 0, 'weight_decay': 0,
    'lr_epsilon': 1e-08, 'lr_beta1': 0.9, 'lr_beta2': 0.999, 'lr_alpha': 0.99,
    'lr_momentum': 0, 'lr_scheduler': 'multisteplr', 'lr_decay_ratio': 0.1,
    'steps': [5, 20, 40, 70], 'step_size': 10, 'lr_T_max': 30, 'lr_eta_min': 0,
    'lr_patience': 10, 'lr_threshold': 0.0001, 'max_grad_norm': 1.0,
    'patience': 50,
    # --- logging and evaluation ---
    'log_level': 'INFO', 'log_every': 1, 'load_best_epoch': True,
    'hyper_tune': False,
    'metrics': ['MAE', 'MAPE', 'MSE', 'RMSE', 'masked_MAE', 'masked_MAPE',
                'masked_MSE', 'masked_RMSE', 'R2', 'EVAR'],
    'mode': 'single', 'save_modes': ['csv'],
    # --- raw-data schema ---
    'geo': {'including_types': ['Point'], 'Point': {}},
    'rel': {'including_types': ['geo'], 'geo': {'cost': 'num'}},
    'dyna': {'including_types': ['state'],
             'state': {'entity_id': 'geo_id', 'traffic_flow': 'num',
                       'traffic_occupancy': 'num', 'traffic_speed': 'num'}},
    'data_col': ['traffic_flow', 'traffic_occupancy', 'traffic_speed'],
    'weight_col': 'cost', 'data_files': ['PEMSD4'], 'geo_file': 'PEMSD4',
    'rel_file': 'PEMSD4', 'output_dim': 3, 'time_intervals': 300,
    'init_weight_inf_or_zero': 'zero', 'set_weight_link_or_dist': 'link',
    'calculate_weight_adj': False, 'weight_adj_epsilon': 0.1,
}

import os
import random

# Restrict the process to physical GPU 1; this has to happen before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch

# Seed every RNG so that shuffling and weight initialisation are repeatable
# across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config['device'] = torch.device("cuda" if torch.cuda.is_available() and config['gpu'] else "cpu")

logger = get_logger(config)
# Load the dataset.
dataset = get_dataset(config)
# Transform the data and split it into train / validation / test loaders.
train_data, valid_data, test_data = dataset.get_data()
# Per split: number of batches, number of samples, x shape, y shape, batch size.
for loader in (train_data, valid_data, test_data):
    print(len(loader), len(loader.dataset), loader.dataset[0][0].shape,
          loader.dataset[0][1].shape, loader.batch_size)

data_feature = dataset.get_data_feature()
print(data_feature['adj_mx'].shape)
print(data_feature['adj_mx'].sum())

# ASTGCN class defined in this notebook, used in place of LibCity's get_model().
model = ASTGCN(config, data_feature)

# Build the executor.
model_cache_file = './libcity/cache/model_cache/' + config['model'] + '_' + config['dataset'] + '.m'
executor = get_executor(config, model)

# Train, then round-trip the weights through the cache file.
executor.train(train_data, valid_data)
executor.save_model(model_cache_file)
executor.load_model(model_cache_file)
# Evaluate; the results are written under cache/evaluate_cache.
executor.evaluate(test_data)

2021-12-06 20:29:11,290 - INFO - Log directory: ./libcity/log
2021-12-06 20:29:11,969 - INFO - Loaded file PEMSD4.geo, num_nodes=307
2021-12-06 20:29:11,972 - INFO - set_weight_link_or_dist: link
2021-12-06 20:29:11,973 - INFO - init_weight_inf_or_zero: zero
2021-12-06 20:29:11,977 - INFO - Loaded file PEMSD4.rel, shape=(307, 307)
2021-12-06 20:29:11,978 - INFO - Loading ./libcity/cache/dataset_cache/point_based_PEMSD4_2_1_2_1_7_12_0.6_0.2_standard_64_False_False_True.npz
2021-12-06 20:29:45,945 - INFO - train	x: (7769, 60, 307, 3), y: (7769, 12, 307, 3)
2021-12-06 20:29:45,947 - INFO - eval	x: (2590, 60, 307, 3), y: (2590, 12, 307, 3)
2021-12-06 20:29:45,948 - INFO - test	x: (2590, 60, 307, 3), y: (2590, 12, 307, 3)
2021-12-06 20:29:48,469 - INFO - StandardScaler mean: 91.72662831592477, std: 127.5481471486703
2021-12-06 20:29:48,471 - INFO - NoneScaler
122 7808 (60, 307, 3) (12, 307, 3) 64
41 2624 (60, 307, 3) (12, 307, 3) 64
41 2624 (60, 307, 3) (12, 307, 3) 64
(307, 307)
340.0
2021

2021-12-06 20:29:58,622 - INFO - hours_ASTGCN_submodule.BlockList.1.TAt.Ve	torch.Size([12, 12])	cuda:0	True
2021-12-06 20:29:58,622 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W1	torch.Size([12])	cuda:0	True
2021-12-06 20:29:58,623 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W2	torch.Size([64, 12])	cuda:0	True
2021-12-06 20:29:58,624 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.W3	torch.Size([64])	cuda:0	True
2021-12-06 20:29:58,624 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.bs	torch.Size([1, 307, 307])	cuda:0	True
2021-12-06 20:29:58,625 - INFO - hours_ASTGCN_submodule.BlockList.1.SAt.Vs	torch.Size([307, 307])	cuda:0	True
2021-12-06 20:29:58,625 - INFO - hours_ASTGCN_submodule.BlockList.1.cheb_conv_SAt.Theta.0	torch.Size([64, 64])	cuda:0	True
2021-12-06 20:29:58,626 - INFO - hours_ASTGCN_submodule.BlockList.1.cheb_conv_SAt.Theta.1	torch.Size([64, 64])	cuda:0	True
2021-12-06 20:29:58,627 - INFO - hours_ASTGCN_submodule.BlockList.1.cheb_conv_SAt.Theta.2	torch.Size([64, 6

2021-12-06 20:29:58,683 - INFO - weeks_ASTGCN_submodule.BlockList.0.residual_conv.weight	torch.Size([64, 3, 1, 1])	cuda:0	True
2021-12-06 20:29:58,684 - INFO - weeks_ASTGCN_submodule.BlockList.0.residual_conv.bias	torch.Size([64])	cuda:0	True
2021-12-06 20:29:58,684 - INFO - weeks_ASTGCN_submodule.BlockList.0.ln.weight	torch.Size([64])	cuda:0	True
2021-12-06 20:29:58,685 - INFO - weeks_ASTGCN_submodule.BlockList.0.ln.bias	torch.Size([64])	cuda:0	True
2021-12-06 20:29:58,686 - INFO - weeks_ASTGCN_submodule.BlockList.1.TAt.U1	torch.Size([307])	cuda:0	True
2021-12-06 20:29:58,688 - INFO - weeks_ASTGCN_submodule.BlockList.1.TAt.U2	torch.Size([64, 307])	cuda:0	True
2021-12-06 20:29:58,689 - INFO - weeks_ASTGCN_submodule.BlockList.1.TAt.U3	torch.Size([64])	cuda:0	True
2021-12-06 20:29:58,689 - INFO - weeks_ASTGCN_submodule.BlockList.1.TAt.be	torch.Size([1, 12, 12])	cuda:0	True
2021-12-06 20:29:58,690 - INFO - weeks_ASTGCN_submodule.BlockList.1.TAt.Ve	torch.Size([12, 12])	cuda:0	True
2021-12-

2021-12-06 20:47:31,995 - INFO - Saved model at 12
2021-12-06 20:47:31,996 - INFO - Val loss decrease from 520.2737 to 514.1774, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch12.tar
2021-12-06 20:48:42,581 - INFO - epoch complete!
2021-12-06 20:48:42,586 - INFO - evaluating now!
2021-12-06 20:48:53,726 - INFO - Epoch [13/100] train_loss: 458.0413, val_loss: 506.5170, lr: 0.000100, 81.73s
2021-12-06 20:48:53,814 - INFO - Saved model at 13
2021-12-06 20:48:53,815 - INFO - Val loss decrease from 514.1774 to 506.5170, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch13.tar
2021-12-06 20:50:04,417 - INFO - epoch complete!
2021-12-06 20:50:04,421 - INFO - evaluating now!
2021-12-06 20:50:14,871 - INFO - Epoch [14/100] train_loss: 446.5151, val_loss: 504.9375, lr: 0.000100, 81.05s
2021-12-06 20:50:14,952 - INFO - Saved model at 14
2021-12-06 20:50:14,953 - INFO - Val loss decrease from 506.5170 to 504.9375, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch14.tar
202

2021-12-06 21:15:43,641 - INFO - Saved model at 33
2021-12-06 21:15:43,643 - INFO - Val loss decrease from 441.9211 to 439.3347, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch33.tar
2021-12-06 21:16:53,766 - INFO - epoch complete!
2021-12-06 21:16:53,769 - INFO - evaluating now!
2021-12-06 21:17:04,657 - INFO - Epoch [34/100] train_loss: 350.4509, val_loss: 435.9985, lr: 0.000100, 81.01s
2021-12-06 21:17:04,742 - INFO - Saved model at 34
2021-12-06 21:17:04,744 - INFO - Val loss decrease from 439.3347 to 435.9985, saving to ./libcity/cache/model_cache/ASTGCN_PEMSD4_epoch34.tar
2021-12-06 21:18:14,927 - INFO - epoch complete!
2021-12-06 21:18:14,932 - INFO - evaluating now!
2021-12-06 21:18:25,809 - INFO - Epoch [35/100] train_loss: 348.2296, val_loss: 436.3118, lr: 0.000100, 81.06s
2021-12-06 21:19:33,296 - INFO - epoch complete!
2021-12-06 21:19:33,300 - INFO - evaluating now!
2021-12-06 21:19:44,226 - INFO - Epoch [36/100] train_loss: 345.9705, val_loss: 439.0736, lr: 0.00

2021-12-06 21:53:24,169 - INFO - evaluating now!
2021-12-06 21:53:35,566 - INFO - Epoch [61/100] train_loss: 312.7376, val_loss: 422.2078, lr: 0.000100, 82.93s
2021-12-06 21:54:46,091 - INFO - epoch complete!
2021-12-06 21:54:46,098 - INFO - evaluating now!
2021-12-06 21:54:56,926 - INFO - Epoch [62/100] train_loss: 311.8857, val_loss: 420.8501, lr: 0.000100, 81.36s
2021-12-06 21:56:08,822 - INFO - epoch complete!
2021-12-06 21:56:08,827 - INFO - evaluating now!
2021-12-06 21:56:18,870 - INFO - Epoch [63/100] train_loss: 311.2052, val_loss: 423.8804, lr: 0.000100, 81.94s
2021-12-06 21:57:31,892 - INFO - epoch complete!
2021-12-06 21:57:31,895 - INFO - evaluating now!
2021-12-06 21:57:42,944 - INFO - Epoch [64/100] train_loss: 310.5776, val_loss: 421.3298, lr: 0.000100, 84.07s
2021-12-06 21:58:53,745 - INFO - epoch complete!
2021-12-06 21:58:53,750 - INFO - evaluating now!
2021-12-06 21:59:04,222 - INFO - Epoch [65/100] train_loss: 309.3364, val_loss: 421.5234, lr: 0.000100, 81.28s
2021

2021-12-06 22:45:09,659 - INFO - Loaded model at 55
2021-12-06 22:45:09,661 - INFO - Saved model at ./libcity/cache/model_cache/ASTGCN_PEMSD4.m
2021-12-06 22:45:09,739 - INFO - Loaded model at ./libcity/cache/model_cache/ASTGCN_PEMSD4.m
2021-12-06 22:45:09,797 - INFO - Start evaluating ...
2021-12-06 22:45:44,213 - INFO - Note that you select the single mode to evaluate!
2021-12-06 22:45:44,229 - INFO - Evaluate result is saved at ./libcity/cache/evaluate_cache/2021_12_06_22_45_44_ASTGCN_PEMSD4.csv
2021-12-06 22:45:44,242 - INFO - 
         MAE  MAPE         MSE       RMSE  masked_MAE  masked_MAPE  \
1   7.166030   inf  316.750610  17.797489    7.135028     7.481318   
2   7.467807   inf  341.731140  18.485971    7.433187     7.500246   
3   7.653435   inf  359.094147  18.949780    7.613025     7.908596   
4   7.969870   inf  377.259766  19.423176    7.925467    12.665545   
5   7.994197   inf  388.898590  19.720512    7.947711     7.292997   
6   8.171041   inf  404.382202  20.109257 

Unnamed: 0,MAE,MAPE,MSE,RMSE,masked_MAE,masked_MAPE,masked_MSE,masked_RMSE,R2,EVAR
1,7.16603,inf,316.75061,17.797489,7.135028,7.481318,308.166687,17.554678,0.981402,0.981407
2,7.467807,inf,341.73114,18.485971,7.433187,7.500246,331.33905,18.202721,0.979934,0.979942
3,7.653435,inf,359.094147,18.94978,7.613025,7.908596,346.922028,18.625843,0.978918,0.978926
4,7.96987,inf,377.259766,19.423176,7.925467,12.665545,363.563293,19.067335,0.977849,0.977871
5,7.994197,inf,388.89859,19.720512,7.947711,7.292997,373.824371,19.334538,0.977165,0.977185
6,8.171041,inf,404.382202,20.109257,8.120287,8.449827,387.550079,19.686291,0.97625,0.976282
7,8.327589,inf,416.086823,20.398207,8.274202,8.739942,397.86499,19.946554,0.975561,0.975588
8,8.406165,inf,426.58017,20.653818,8.348886,9.109753,406.909607,20.172001,0.974944,0.974977
9,8.527379,inf,437.099945,20.906935,8.468082,9.247981,416.218903,20.401443,0.974328,0.974363
10,8.650572,inf,450.421387,21.223133,8.588638,8.319716,427.740753,20.681894,0.973543,0.973578
