In [1]:
import sys
from logging import getLogger

import torch
import torch.nn as nn

import Embedding.GNN_Based_Layers as GNN
import ST_backbone.RNN_Based_Models as RNN

# NOTE(review): machine-specific absolute path; it must be appended before the
# libcity imports below. Consider reading it from an environment variable.
sys.path.append('/home/zwt/Bigscity-LibCity')
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model import loss


In [2]:
import torch.nn as nn
class AGCRN(AbstractTrafficStateModel):
    """GRU-based traffic-state forecaster with a convolutional output head.

    The input window is encoded by ``RNN.TemporalGRU``. The hidden state of the
    last time step is mapped by ``end_conv`` to ``output_window * output_dim``
    channels. Those channels are reshaped into one prediction per output step,
    node and output feature.

    NOTE(review): the adaptive graph convolution (``GNN.AVWGCN`` wrapped in
    ``GNN.TimedistributedGCN``) is currently disabled. This model therefore uses
    only the temporal GRU, despite its name.

    Args:
        config: dict-like experiment config. ``num_nodes`` and ``feature_dim``
            are written into it (taken from ``data_feature``) before the
            base-class constructor runs.
        data_feature: dataset metadata; ``num_nodes``, ``feature_dim``,
            ``output_dim`` and ``scaler`` are read from it.
    """

    def __init__(self, config, data_feature):
        self.num_nodes = data_feature.get('num_nodes', 1)
        self.feature_dim = data_feature.get('feature_dim', 1)
        # Publish the dataset dimensions in `config`, presumably so that
        # sub-modules built from it (TemporalGRU below) can read them.
        config['num_nodes'] = self.num_nodes
        config['feature_dim'] = self.feature_dim

        super().__init__(config, data_feature)
        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        self.output_dim = self.data_feature.get('output_dim', 1)
        self.hidden_dim = config.get('rnn_units', 64)
        self.embed_dim = config.get('embed_dim', 10)

        # Temporal encoder. NOTE(review): the meaning of the two trailing 0
        # arguments is defined in RNN_Based_Models; confirm there. Assumes its
        # hidden size equals config['rnn_units'], because end_conv's kernel
        # width is hidden_dim.
        self.trnn = RNN.TemporalGRU(config, 0, 0)

        # Output head: a (1, hidden_dim) kernel over a single input channel
        # collapses the hidden axis into output_window * output_dim channels.
        self.end_conv = nn.Conv2d(1, self.output_window * self.output_dim,
                                  kernel_size=(1, self.hidden_dim), bias=True)

        self.device = config.get('device', torch.device('cpu'))
        self._logger = getLogger()
        self._scaler = self.data_feature.get('scaler')
        self._init_parameters()

    def _init_parameters(self):
        """Re-initialise every parameter of the model, including those of ``trnn``.

        Parameters with more than one dimension get Xavier-uniform init. 1-D
        parameters (e.g. biases) get ``nn.init.uniform_`` with its default
        U(0, 1) range.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, batch):
        """Predict the next ``output_window`` steps from ``batch['X']``.

        Args:
            batch: mapping with key ``'X'``, assumed to be (B, T_in, N, D).
                TODO confirm against the dataset.

        Returns:
            Tensor of shape (B, output_window, N, output_dim).
        """
        source = batch['X']

        # Assumes TemporalGRU returns (B, T, N, hidden_dim).
        hidden = self.trnn(source)
        # Keep only the last time step. The size-1 time axis then serves as
        # end_conv's single input channel: (B, 1, N, hidden_dim).
        last_step = hidden[:, -1:, :, :]

        # CNN-based predictor -> (B, output_window * output_dim, N, 1)
        output = self.end_conv(last_step)
        output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes)
        # (B, output_window, output_dim, N) -> (B, output_window, N, output_dim)
        return output.permute(0, 1, 3, 2)

    def calculate_loss(self, batch):
        """Masked MAE between prediction and ``batch['y']`` in unscaled units.

        Target and prediction are both truncated to ``output_dim`` features and
        inverse-transformed with the dataset scaler before the loss. The final
        ``0`` passed to ``loss.masked_mae_torch`` is presumably the null/mask
        value (targets equal to 0 ignored); confirm in ``libcity.model.loss``.
        """
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mae_torch(y_predicted, y_true, 0)

    def predict(self, batch):
        """Return ``forward(batch)``; also used by ``calculate_loss``."""
        return self.forward(batch)
    

In [3]:
# Train and evaluate the AGCRN model end to end on METR_LA with LibCity.
from libcity.data import get_dataset
from libcity.utils import get_executor
# NOTE(review): get_model appears unused in this notebook.
from libcity.utils import get_model
from libcity.utils import get_logger


config = {
    'log_level': 'INFO',

    'dataset': 'METR_LA',
    'model': 'AGCRN',
    'evaluator': 'TrafficStateEvaluator',
    'executor': 'TrafficStateExecutor',
    'dataset_class': 'TrafficStatePointDataset',
    'metrics': ['MAE', 'MAPE', 'MSE', 'RMSE', 'masked_MAE', 'masked_MSE', 'masked_RMSE', 'masked_MAPE', 'R2', 'EVAR'],
    'weight_col': 'cost',
    'data_col': ['traffic_speed'],
    'calculate_weight': True,
    'adj_epsilon': 0.1,
    'add_time_in_day': False,
    'add_day_in_week': False,
    'pad_with_last_sample': False,
    'scaler': "standard",

    'num_workers': 1,
    'cache_dataset': True,
    'gpu': True,
    # Used below as CUDA_VISIBLE_DEVICES.
    'gpu_id': '1',
    'batch_size': 64,

    'input_window': 12,
    'output_window': 12,
    'tod': False,
    'column_wise': False,
    'default_graph': True,
    'embed_dim': 10,
    'rnn_units': 64,
    'num_layers': 2,
    'cheb_order': 2,

    'train_rate': 0.7,
    'eval_rate': 0.1,
    'learning_rate': 0.003,
    'learner': 'adam',
    # NOTE(review): with lr_decay False the scheduler settings below appear
    # inactive (the recorded logs show lr constant at 0.003 for all epochs).
    'lr_decay': False,
    'lr_decay_rate': 0.3,
    'steps': [5, 20, 40, 70],
    'lr_scheduler': 'multisteplr',
    'epoch': 0,
    'max_epoch': 100,
    'clip_grad_norm': False,
    # With early stopping off, training runs all max_epoch epochs (the log
    # reports 100 epochs trained); patience is presumably unused.
    'use_early_stop': False,
    'max_grad_norm': 5,
    'patience': 15,
}

import os
# Restrict CUDA to the configured GPU. torch is already imported earlier (via
# `import torch.nn` in the model cell), so this presumably only takes effect
# if no CUDA context has been created yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = config['gpu_id']
import torch
config['device'] = torch.device("cuda" if torch.cuda.is_available() and config['gpu'] else "cpu")
config['task']='traffic_state_pred'

logger = get_logger(config)
# Load the dataset
dataset = get_dataset(config)
# Convert the data and split it into train / validation / test loaders
train_data, valid_data, test_data = dataset.get_data()
# Per split: number of batches, number of samples, shapes of one sample's X and y, batch size
print(len(train_data), len(train_data.dataset), train_data.dataset[0][0].shape, train_data.dataset[0][1].shape, train_data.batch_size)
print(len(valid_data), len(valid_data.dataset), valid_data.dataset[0][0].shape, valid_data.dataset[0][1].shape, valid_data.batch_size)
print(len(test_data), len(test_data.dataset), test_data.dataset[0][0].shape, test_data.dataset[0][1].shape, test_data.batch_size)

data_feature = dataset.get_data_feature()
print(data_feature['adj_mx'].shape)
# NOTE(review): the recorded output of this sum is inf, i.e. adj_mx contains
# infinite entries. The model below does not read adj_mx, but graph-based
# models would.
print(data_feature['adj_mx'].sum())

model = AGCRN(config,data_feature)

# Build the executor; the cache file name is derived from config['model'] and config['dataset']
model_cache_file = './libcity/cache/model_cache/' + config['model'] + '_' + config['dataset'] + '.m'
executor = get_executor(config, model)
# Train, then save the weights and reload them from disk
executor.train(train_data, valid_data)
executor.save_model(model_cache_file)
executor.load_model(model_cache_file)
# Evaluate; the results are written under cache/evaluate_cache
executor.evaluate(test_data)


2021-12-03 15:17:06,062 - INFO - Log directory: ./libcity/log
2021-12-03 15:17:06,806 - INFO - Loaded file METR_LA.geo, num_nodes=207
2021-12-03 15:17:06,815 - INFO - set_weight_link_or_dist: dist
2021-12-03 15:17:06,816 - INFO - init_weight_inf_or_zero: inf
2021-12-03 15:17:06,848 - INFO - Loaded file METR_LA.rel, shape=(207, 207)
2021-12-03 15:17:06,849 - INFO - Loading ./libcity/cache/dataset_cache/point_based_METR_LA_12_12_0.7_0.1_standard_64_False_False_False_False.npz
2021-12-03 15:17:09,960 - INFO - train	x: (23974, 12, 207, 1), y: (23974, 12, 207, 1)
2021-12-03 15:17:09,963 - INFO - eval	x: (3425, 12, 207, 1), y: (3425, 12, 207, 1)
2021-12-03 15:17:09,963 - INFO - test	x: (6850, 12, 207, 1), y: (6850, 12, 207, 1)
2021-12-03 15:17:10,312 - INFO - StandardScaler mean: 54.40592829587626, std: 19.493739270573098
2021-12-03 15:17:10,313 - INFO - NoneScaler
375 23974 (12, 207, 1) (12, 207, 1) 64
54 3425 (12, 207, 1) (12, 207, 1) 64
108 6850 (12, 207, 1) (12, 207, 1) 64
(207, 207)
inf

2021-12-03 15:24:28,760 - INFO - epoch complete!
2021-12-03 15:24:28,763 - INFO - evaluating now!
2021-12-03 15:24:30,135 - INFO - Epoch [18/100] train_loss: 2.7341, val_loss: 2.8841, lr: 0.003000, 22.10s
2021-12-03 15:24:30,154 - INFO - Saved model at 18
2021-12-03 15:24:30,155 - INFO - Val loss decrease from 2.8940 to 2.8841, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch18.tar
2021-12-03 15:24:53,551 - INFO - epoch complete!
2021-12-03 15:24:53,555 - INFO - evaluating now!
2021-12-03 15:24:54,959 - INFO - Epoch [19/100] train_loss: 2.7212, val_loss: 2.9116, lr: 0.003000, 24.80s
2021-12-03 15:25:16,166 - INFO - epoch complete!
2021-12-03 15:25:16,169 - INFO - evaluating now!
2021-12-03 15:25:17,493 - INFO - Epoch [20/100] train_loss: 2.7100, val_loss: 2.9225, lr: 0.003000, 22.53s
2021-12-03 15:25:41,683 - INFO - epoch complete!
2021-12-03 15:25:41,686 - INFO - evaluating now!
2021-12-03 15:25:43,253 - INFO - Epoch [21/100] train_loss: 2.6982, val_loss: 2.9247, lr: 0.003000

2021-12-03 15:39:25,247 - INFO - evaluating now!
2021-12-03 15:39:26,580 - INFO - Epoch [57/100] train_loss: 2.5026, val_loss: 2.9717, lr: 0.003000, 22.67s
2021-12-03 15:39:44,855 - INFO - epoch complete!
2021-12-03 15:39:44,858 - INFO - evaluating now!
2021-12-03 15:39:46,268 - INFO - Epoch [58/100] train_loss: 2.4973, val_loss: 2.9696, lr: 0.003000, 19.69s
2021-12-03 15:40:06,619 - INFO - epoch complete!
2021-12-03 15:40:06,621 - INFO - evaluating now!
2021-12-03 15:40:07,988 - INFO - Epoch [59/100] train_loss: 2.4973, val_loss: 2.9687, lr: 0.003000, 21.72s
2021-12-03 15:40:28,187 - INFO - epoch complete!
2021-12-03 15:40:28,190 - INFO - evaluating now!
2021-12-03 15:40:29,558 - INFO - Epoch [60/100] train_loss: 2.4936, val_loss: 2.9748, lr: 0.003000, 21.57s
2021-12-03 15:40:54,040 - INFO - epoch complete!
2021-12-03 15:40:54,043 - INFO - evaluating now!
2021-12-03 15:40:55,388 - INFO - Epoch [61/100] train_loss: 2.4914, val_loss: 3.0012, lr: 0.003000, 25.83s
2021-12-03 15:41:16,073 

2021-12-03 15:54:20,974 - INFO - evaluating now!
2021-12-03 15:54:22,361 - INFO - Epoch [97/100] train_loss: 2.4270, val_loss: 3.0077, lr: 0.003000, 22.11s
2021-12-03 15:54:42,798 - INFO - epoch complete!
2021-12-03 15:54:42,801 - INFO - evaluating now!
2021-12-03 15:54:44,187 - INFO - Epoch [98/100] train_loss: 2.4265, val_loss: 2.9997, lr: 0.003000, 21.82s
2021-12-03 15:55:04,196 - INFO - epoch complete!
2021-12-03 15:55:04,198 - INFO - evaluating now!
2021-12-03 15:55:05,531 - INFO - Epoch [99/100] train_loss: 2.4255, val_loss: 3.0201, lr: 0.003000, 21.34s
2021-12-03 15:55:05,534 - INFO - Trained totally 100 epochs, average train time is 21.300s, average eval time is 1.387s
2021-12-03 15:55:05,553 - INFO - Loaded model at 18
2021-12-03 15:55:05,555 - INFO - Saved model at ./libcity/cache/model_cache/AGCRN_METR_LA.m
2021-12-03 15:55:05,570 - INFO - Loaded model at ./libcity/cache/model_cache/AGCRN_METR_LA.m
2021-12-03 15:55:05,577 - INFO - Start evaluating ...
2021-12-03 15:55:15,363

Unnamed: 0,MAE,MAPE,MSE,RMSE,masked_MAE,masked_MSE,masked_RMSE,masked_MAPE,R2,EVAR
1,9.186357,inf,452.979889,21.283323,2.371144,17.55316,4.189649,0.059056,0.12644,0.232163
2,9.507854,inf,468.304962,21.640354,2.62991,24.420654,4.941726,0.067678,0.096909,0.207817
3,9.707434,inf,475.000458,21.794506,2.814184,29.374212,5.419798,0.074508,0.08402,0.197607
4,9.862527,inf,480.417267,21.918423,2.961941,33.820026,5.815499,0.08005,0.073605,0.189818
5,9.980296,inf,485.030396,22.023405,3.072748,37.672756,6.137814,0.084825,0.064737,0.182019
6,10.071779,inf,488.467834,22.101309,3.165912,40.857956,6.392023,0.088015,0.058129,0.17569
7,10.1478,inf,491.265381,22.164507,3.244794,43.598019,6.60288,0.090828,0.052762,0.172713
8,10.23679,inf,495.110748,22.251083,3.324201,45.979401,6.780811,0.093793,0.045374,0.167092
9,10.315548,inf,498.925995,22.336651,3.39562,48.645607,6.97464,0.096167,0.038044,0.160617
10,10.357146,inf,498.763031,22.333004,3.460344,51.144669,7.15155,0.09843,0.03839,0.161081


In [None]:
class TGCN(AbstractTrafficStateModel):
    """GRU-based traffic-state forecaster with a convolutional output head.

    FIXME(review): this is currently a verbatim copy of the AGCRN model cell.
    It uses ``RNN.TemporalGRU`` plus ``end_conv`` and has no T-GCN-specific
    graph convolution (the ``GNN`` encoder is disabled). Results labelled
    "TGCN" therefore measure the same architecture as "AGCRN".

    The input window is encoded by ``RNN.TemporalGRU``. The hidden state of the
    last time step is mapped by ``end_conv`` to ``output_window * output_dim``
    channels. Those channels are reshaped into one prediction per output step,
    node and output feature.

    Args:
        config: dict-like experiment config. ``num_nodes`` and ``feature_dim``
            are written into it (taken from ``data_feature``) before the
            base-class constructor runs.
        data_feature: dataset metadata; ``num_nodes``, ``feature_dim``,
            ``output_dim`` and ``scaler`` are read from it.
    """

    def __init__(self, config, data_feature):
        self.num_nodes = data_feature.get('num_nodes', 1)
        self.feature_dim = data_feature.get('feature_dim', 1)
        # Publish the dataset dimensions in `config`, presumably so that
        # sub-modules built from it (TemporalGRU below) can read them.
        config['num_nodes'] = self.num_nodes
        config['feature_dim'] = self.feature_dim

        super().__init__(config, data_feature)
        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        self.output_dim = self.data_feature.get('output_dim', 1)
        self.hidden_dim = config.get('rnn_units', 64)
        self.embed_dim = config.get('embed_dim', 10)

        # Temporal encoder. NOTE(review): the meaning of the two trailing 0
        # arguments is defined in RNN_Based_Models; confirm there. Assumes its
        # hidden size equals config['rnn_units'], because end_conv's kernel
        # width is hidden_dim.
        self.trnn = RNN.TemporalGRU(config, 0, 0)

        # Output head: a (1, hidden_dim) kernel over a single input channel
        # collapses the hidden axis into output_window * output_dim channels.
        self.end_conv = nn.Conv2d(1, self.output_window * self.output_dim,
                                  kernel_size=(1, self.hidden_dim), bias=True)

        self.device = config.get('device', torch.device('cpu'))
        self._logger = getLogger()
        self._scaler = self.data_feature.get('scaler')
        self._init_parameters()

    def _init_parameters(self):
        """Re-initialise every parameter of the model, including those of ``trnn``.

        Parameters with more than one dimension get Xavier-uniform init. 1-D
        parameters (e.g. biases) get ``nn.init.uniform_`` with its default
        U(0, 1) range.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, batch):
        """Predict the next ``output_window`` steps from ``batch['X']``.

        Args:
            batch: mapping with key ``'X'``, assumed to be (B, T_in, N, D).
                TODO confirm against the dataset.

        Returns:
            Tensor of shape (B, output_window, N, output_dim).
        """
        source = batch['X']

        # Assumes TemporalGRU returns (B, T, N, hidden_dim).
        hidden = self.trnn(source)
        # Keep only the last time step. The size-1 time axis then serves as
        # end_conv's single input channel: (B, 1, N, hidden_dim).
        last_step = hidden[:, -1:, :, :]

        # CNN-based predictor -> (B, output_window * output_dim, N, 1)
        output = self.end_conv(last_step)
        output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes)
        # (B, output_window, output_dim, N) -> (B, output_window, N, output_dim)
        return output.permute(0, 1, 3, 2)

    def calculate_loss(self, batch):
        """Masked MAE between prediction and ``batch['y']`` in unscaled units.

        Target and prediction are both truncated to ``output_dim`` features and
        inverse-transformed with the dataset scaler before the loss. The final
        ``0`` passed to ``loss.masked_mae_torch`` is presumably the null/mask
        value (targets equal to 0 ignored); confirm in ``libcity.model.loss``.
        """
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mae_torch(y_predicted, y_true, 0)

    def predict(self, batch):
        """Return ``forward(batch)``; also used by ``calculate_loss``."""
        return self.forward(batch)
    

In [7]:
# Train and evaluate the TGCN model.
# Run with a copy of the shared config whose 'model' entry names this model.
# The cache file below, the executor's per-epoch checkpoints and the
# evaluation CSV are all named after config['model']. Reusing 'AGCRN' would
# therefore overwrite the AGCRN run's files.
tgcn_config = {**config, 'model': 'TGCN'}
model = TGCN(tgcn_config, data_feature)

# Build the executor
model_cache_file = './libcity/cache/model_cache/' + tgcn_config['model'] + '_' + tgcn_config['dataset'] + '.m'
executor = get_executor(tgcn_config, model)
# Train, then save the weights and reload them from disk
executor.train(train_data, valid_data)
executor.save_model(model_cache_file)
executor.load_model(model_cache_file)
# Evaluate; the results are written under cache/evaluate_cache
executor.evaluate(test_data)


2021-12-02 17:20:12,017 - INFO - Loaded file METR_LA.geo, num_nodes=207
2021-12-02 17:20:12,017 - INFO - Loaded file METR_LA.geo, num_nodes=207
2021-12-02 17:20:12,017 - INFO - Loaded file METR_LA.geo, num_nodes=207
2021-12-02 17:20:12,027 - INFO - set_weight_link_or_dist: dist
2021-12-02 17:20:12,027 - INFO - set_weight_link_or_dist: dist
2021-12-02 17:20:12,027 - INFO - set_weight_link_or_dist: dist
2021-12-02 17:20:12,029 - INFO - init_weight_inf_or_zero: inf
2021-12-02 17:20:12,029 - INFO - init_weight_inf_or_zero: inf
2021-12-02 17:20:12,029 - INFO - init_weight_inf_or_zero: inf
2021-12-02 17:20:12,067 - INFO - Loaded file METR_LA.rel, shape=(207, 207)
2021-12-02 17:20:12,067 - INFO - Loaded file METR_LA.rel, shape=(207, 207)
2021-12-02 17:20:12,067 - INFO - Loaded file METR_LA.rel, shape=(207, 207)
2021-12-02 17:20:12,070 - INFO - Loading ./libcity/cache/dataset_cache/point_based_METR_LA_12_12_0.7_0.1_standard_64_False_False_False_False.npz
2021-12-02 17:20:12,070 - INFO - Loadin

2021-12-02 17:20:58,797 - INFO - evaluating now!
2021-12-02 17:21:01,324 - INFO - Epoch [0/100] train_loss: 4.3200, val_loss: 3.4323, lr: 0.003000, 44.69s
2021-12-02 17:21:01,324 - INFO - Epoch [0/100] train_loss: 4.3200, val_loss: 3.4323, lr: 0.003000, 44.69s
2021-12-02 17:21:01,324 - INFO - Epoch [0/100] train_loss: 4.3200, val_loss: 3.4323, lr: 0.003000, 44.69s
2021-12-02 17:21:01,372 - INFO - Saved model at 0
2021-12-02 17:21:01,372 - INFO - Saved model at 0
2021-12-02 17:21:01,372 - INFO - Saved model at 0
2021-12-02 17:21:01,376 - INFO - Val loss decrease from inf to 3.4323, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch0.tar
2021-12-02 17:21:01,376 - INFO - Val loss decrease from inf to 3.4323, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch0.tar
2021-12-02 17:21:01,376 - INFO - Val loss decrease from inf to 3.4323, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch0.tar
2021-12-02 17:21:40,438 - INFO - epoch complete!
2021-12-02 17:21:40,438 - INFO -

2021-12-02 17:27:29,978 - INFO - evaluating now!
2021-12-02 17:27:29,978 - INFO - evaluating now!
2021-12-02 17:27:29,978 - INFO - evaluating now!
2021-12-02 17:27:32,558 - INFO - Epoch [9/100] train_loss: 2.8764, val_loss: 2.9523, lr: 0.003000, 43.57s
2021-12-02 17:27:32,558 - INFO - Epoch [9/100] train_loss: 2.8764, val_loss: 2.9523, lr: 0.003000, 43.57s
2021-12-02 17:27:32,558 - INFO - Epoch [9/100] train_loss: 2.8764, val_loss: 2.9523, lr: 0.003000, 43.57s
2021-12-02 17:27:32,601 - INFO - Saved model at 9
2021-12-02 17:27:32,601 - INFO - Saved model at 9
2021-12-02 17:27:32,601 - INFO - Saved model at 9
2021-12-02 17:27:32,603 - INFO - Val loss decrease from 2.9536 to 2.9523, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch9.tar
2021-12-02 17:27:32,603 - INFO - Val loss decrease from 2.9536 to 2.9523, saving to ./libcity/cache/model_cache/AGCRN_METR_LA_epoch9.tar
2021-12-02 17:27:32,603 - INFO - Val loss decrease from 2.9536 to 2.9523, saving to ./libcity/cache/model_cache

2021-12-02 17:35:35,926 - INFO - Epoch [20/100] train_loss: 2.6137, val_loss: 2.9889, lr: 0.003000, 43.63s
2021-12-02 17:36:17,751 - INFO - epoch complete!
2021-12-02 17:36:17,751 - INFO - epoch complete!
2021-12-02 17:36:17,751 - INFO - epoch complete!
2021-12-02 17:36:17,756 - INFO - evaluating now!
2021-12-02 17:36:17,756 - INFO - evaluating now!
2021-12-02 17:36:17,756 - INFO - evaluating now!
2021-12-02 17:36:20,371 - INFO - Epoch [21/100] train_loss: 2.5963, val_loss: 3.0145, lr: 0.003000, 44.44s
2021-12-02 17:36:20,371 - INFO - Epoch [21/100] train_loss: 2.5963, val_loss: 3.0145, lr: 0.003000, 44.44s
2021-12-02 17:36:20,371 - INFO - Epoch [21/100] train_loss: 2.5963, val_loss: 3.0145, lr: 0.003000, 44.44s
2021-12-02 17:37:01,761 - INFO - epoch complete!
2021-12-02 17:37:01,761 - INFO - epoch complete!
2021-12-02 17:37:01,761 - INFO - epoch complete!
2021-12-02 17:37:01,766 - INFO - evaluating now!
2021-12-02 17:37:01,766 - INFO - evaluating now!
2021-12-02 17:37:01,766 - INFO - 

2021-12-02 17:39:28,820 - INFO - Evaluate result is {"MAE@1": 8.890697479248047, "MAPE@1": Infinity, "MSE@1": 420.39337158203125, "RMSE@1": 20.503496170043945, "masked_MAE@1": 2.4045143127441406, "masked_MSE@1": 18.11513328552246, "masked_RMSE@1": 4.256187438964844, "masked_MAPE@1": 0.061000507324934006, "R2@1": 0.18928243592607863, "EVAR@1": 0.2842714190483093, "MAE@2": 9.189087867736816, "MAPE@2": Infinity, "MSE@2": 432.1625671386719, "RMSE@2": 20.78852081298828, "masked_MAE@2": 2.666740655899048, "masked_MSE@2": 25.265684127807617, "masked_RMSE@2": 5.026498317718506, "masked_MAPE@2": 0.06985758244991302, "R2@2": 0.166607016418726, "EVAR@2": 0.2657508850097656, "MAE@3": 9.442975997924805, "MAPE@3": Infinity, "MSE@3": 444.35308837890625, "RMSE@3": 21.079683303833008, "masked_MAE@3": 2.860224962234497, "masked_MSE@3": 30.716676712036133, "masked_RMSE@3": 5.542262554168701, "masked_MAPE@3": 0.07655974477529526, "R2@3": 0.14311973728292582, "EVAR@3": 0.24535918235778809, "MAE@4": 9.68348

2021-12-02 17:39:28,830 - INFO - Evaluate result is saved at ./libcity/cache/evaluate_cache/2021_12_02_17_39_28_AGCRN_METR_LA.csv
2021-12-02 17:39:28,837 - INFO - 
          MAE  MAPE         MSE       RMSE  masked_MAE  masked_MSE  \
1    8.890697   inf  420.393372  20.503496    2.404514   18.115133   
2    9.189088   inf  432.162567  20.788521    2.666741   25.265684   
3    9.442976   inf  444.353088  21.079683    2.860225   30.716677   
4    9.683482   inf  458.266907  21.407169    3.016788   35.383114   
5    9.854841   inf  469.081482  21.658289    3.130225   39.579037   
6    9.988622   inf  477.489777  21.851540    3.226783   42.979313   
7   10.088787   inf  481.891052  21.952017    3.314008   45.636841   
8   10.188615   inf  487.819641  22.086639    3.387670   48.291916   
9   10.257476   inf  490.938507  22.157133    3.451633   50.660896   
10  10.316076   inf  493.472778  22.214247    3.508596   52.876759   
11  10.376493   inf  496.654907  22.285755    3.564339   54.999664

Unnamed: 0,MAE,MAPE,MSE,RMSE,masked_MAE,masked_MSE,masked_RMSE,masked_MAPE,R2,EVAR
1,8.890697,inf,420.393372,20.503496,2.404514,18.115133,4.256187,0.061001,0.189282,0.284271
2,9.189088,inf,432.162567,20.788521,2.666741,25.265684,5.026498,0.069858,0.166607,0.265751
3,9.442976,inf,444.353088,21.079683,2.860225,30.716677,5.542263,0.07656,0.14312,0.245359
4,9.683482,inf,458.266907,21.407169,3.016788,35.383114,5.94837,0.082559,0.116318,0.221254
5,9.854841,inf,469.081482,21.658289,3.130225,39.579037,6.291187,0.08707,0.09549,0.205704
6,9.988622,inf,477.489777,21.85154,3.226783,42.979313,6.555861,0.089972,0.079296,0.190132
7,10.088787,inf,481.891052,21.952017,3.314008,45.636841,6.755505,0.093096,0.070837,0.181941
8,10.188615,inf,487.819641,22.086639,3.38767,48.291916,6.949238,0.095938,0.059432,0.174298
9,10.257476,inf,490.938507,22.157133,3.451633,50.660896,7.117647,0.098257,0.053445,0.169262
10,10.316076,inf,493.472778,22.214247,3.508596,52.876759,7.271641,0.100323,0.04859,0.164712


In [None]:
class TGCLSTM(AbstractTrafficStateModel):
    """GRU-based traffic-state forecaster with a convolutional output head.

    FIXME(review): this is currently a verbatim copy of the AGCRN model cell.
    Despite the name, it uses a GRU (``RNN.TemporalGRU``), not an LSTM, and has
    no graph convolution (the ``GNN`` encoder is disabled).

    The input window is encoded by ``RNN.TemporalGRU``. The hidden state of the
    last time step is mapped by ``end_conv`` to ``output_window * output_dim``
    channels. Those channels are reshaped into one prediction per output step,
    node and output feature.

    Args:
        config: dict-like experiment config. ``num_nodes`` and ``feature_dim``
            are written into it (taken from ``data_feature``) before the
            base-class constructor runs.
        data_feature: dataset metadata; ``num_nodes``, ``feature_dim``,
            ``output_dim`` and ``scaler`` are read from it.
    """

    def __init__(self, config, data_feature):
        self.num_nodes = data_feature.get('num_nodes', 1)
        self.feature_dim = data_feature.get('feature_dim', 1)
        # Publish the dataset dimensions in `config`, presumably so that
        # sub-modules built from it (TemporalGRU below) can read them.
        config['num_nodes'] = self.num_nodes
        config['feature_dim'] = self.feature_dim

        super().__init__(config, data_feature)
        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        self.output_dim = self.data_feature.get('output_dim', 1)
        self.hidden_dim = config.get('rnn_units', 64)
        self.embed_dim = config.get('embed_dim', 10)

        # Temporal encoder. NOTE(review): the meaning of the two trailing 0
        # arguments is defined in RNN_Based_Models; confirm there. Assumes its
        # hidden size equals config['rnn_units'], because end_conv's kernel
        # width is hidden_dim.
        self.trnn = RNN.TemporalGRU(config, 0, 0)

        # Output head: a (1, hidden_dim) kernel over a single input channel
        # collapses the hidden axis into output_window * output_dim channels.
        self.end_conv = nn.Conv2d(1, self.output_window * self.output_dim,
                                  kernel_size=(1, self.hidden_dim), bias=True)

        self.device = config.get('device', torch.device('cpu'))
        self._logger = getLogger()
        self._scaler = self.data_feature.get('scaler')
        self._init_parameters()

    def _init_parameters(self):
        """Re-initialise every parameter of the model, including those of ``trnn``.

        Parameters with more than one dimension get Xavier-uniform init. 1-D
        parameters (e.g. biases) get ``nn.init.uniform_`` with its default
        U(0, 1) range.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, batch):
        """Predict the next ``output_window`` steps from ``batch['X']``.

        Args:
            batch: mapping with key ``'X'``, assumed to be (B, T_in, N, D).
                TODO confirm against the dataset.

        Returns:
            Tensor of shape (B, output_window, N, output_dim).
        """
        source = batch['X']

        # Assumes TemporalGRU returns (B, T, N, hidden_dim).
        hidden = self.trnn(source)
        # Keep only the last time step. The size-1 time axis then serves as
        # end_conv's single input channel: (B, 1, N, hidden_dim).
        last_step = hidden[:, -1:, :, :]

        # CNN-based predictor -> (B, output_window * output_dim, N, 1)
        output = self.end_conv(last_step)
        output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes)
        # (B, output_window, output_dim, N) -> (B, output_window, N, output_dim)
        return output.permute(0, 1, 3, 2)

    def calculate_loss(self, batch):
        """Masked MAE between prediction and ``batch['y']`` in unscaled units.

        Target and prediction are both truncated to ``output_dim`` features and
        inverse-transformed with the dataset scaler before the loss. The final
        ``0`` passed to ``loss.masked_mae_torch`` is presumably the null/mask
        value (targets equal to 0 ignored); confirm in ``libcity.model.loss``.
        """
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mae_torch(y_predicted, y_true, 0)

    def predict(self, batch):
        """Return ``forward(batch)``; also used by ``calculate_loss``."""
        return self.forward(batch)

In [None]:
# Train and evaluate the TGCLSTM model.
# Run with a copy of the shared config whose 'model' entry names this model.
# The cache file below, the executor's per-epoch checkpoints and the
# evaluation CSV are all named after config['model']. Reusing 'AGCRN' would
# therefore overwrite the AGCRN run's files.
tgclstm_config = {**config, 'model': 'TGCLSTM'}
model = TGCLSTM(tgclstm_config, data_feature)

# Build the executor
model_cache_file = './libcity/cache/model_cache/' + tgclstm_config['model'] + '_' + tgclstm_config['dataset'] + '.m'
executor = get_executor(tgclstm_config, model)
# Train, then save the weights and reload them from disk
executor.train(train_data, valid_data)
executor.save_model(model_cache_file)
executor.load_model(model_cache_file)
# Evaluate; the results are written under cache/evaluate_cache
executor.evaluate(test_data)