In [1]:
import time

import numpy as np
import pandas as pd
from pandas import DataFrame
import math

from matplotlib.pyplot import savefig

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.optim import AdamW
import random

import matplotlib.pyplot as plt
import scipy.sparse as sp

In [2]:
def _read_timestamped_csv(path):
    """Read a CSV file and index it by its parsed 'Timestamp' column."""
    frame = pd.read_csv(path)
    frame['Timestamp'] = pd.to_datetime(frame['Timestamp'])
    return frame.set_index('Timestamp')


training_flow = _read_timestamped_csv('data/training_flow.csv')
training_speed = _read_timestamped_csv('data/training_speed.csv')
test_flow = _read_timestamped_csv('data/test_flow.csv')
test_speed = _read_timestamped_csv('data/test_speed.csv')

# One data column per link.
num_link = training_flow.shape[1]

In [3]:
# Input window length and forecast horizon, counted in 5-minute steps.
time_leg = 24
prediction_horizon = 6

# Seed every RNG used below (Python `random` drives mini-batch sampling,
# torch drives weight initialisation) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Slide a window of `time_leg + prediction_horizon` rows over the training
# speed table, keeping only windows whose timestamps advance in exact
# 5-minute steps (i.e. no gaps in the record).
window_len = time_leg + prediction_horizon
step = pd.Timedelta(minutes=5)

training_times = training_speed.index
continuous_training_speed = []
continuous_training_time = []

for start in range(len(training_times) - window_len + 1):
    first_time = training_times[start]
    is_contiguous = all(
        training_times[start + offset] == first_time + step * offset
        for offset in range(1, window_len)
    )
    if not is_contiguous:
        continue
    continuous_training_speed.append(training_speed.iloc[start:start + window_len].values)
    continuous_training_time.append(first_time)

speed_data_training = np.array(continuous_training_speed)

speed_data_training.shape

(67752, 30, 150)

In [7]:
# Same windowing as for the training set: keep only windows of
# `time_leg + prediction_horizon` rows spaced exactly 5 minutes apart.
window_len = time_leg + prediction_horizon
step = pd.Timedelta(minutes=5)

test_times = test_speed.index
continuous_test_speed = []
continuous_test_time = []

for start in range(len(test_times) - window_len + 1):
    first_time = test_times[start]
    is_contiguous = all(
        test_times[start + offset] == first_time + step * offset
        for offset in range(1, window_len)
    )
    if not is_contiguous:
        continue
    continuous_test_speed.append(test_speed.iloc[start:start + window_len].values)
    continuous_test_time.append(first_time)

speed_data_test = np.array(continuous_test_speed)

speed_data_test.shape

(4891, 30, 150)

In [8]:
# float32 tensors of shape (windows, time_leg + prediction_horizon, num_link).
training_set_v, test_set_v = (
    torch.from_numpy(windows).to(torch.float32)
    for windows in (speed_data_training, speed_data_test)
)

print(training_set_v.shape)
print(test_set_v.shape)

torch.Size([67752, 30, 150])
torch.Size([4891, 30, 150])


In [9]:
def _load_param_vector(path):
    """Read a CSV of per-link parameters and flatten it to a float32 1-D tensor."""
    values = pd.read_csv(path).values
    return torch.from_numpy(values).to(torch.float32).reshape(-1)


v_f = _load_param_vector('parameter/v_f.csv')
k_c = _load_param_vector('parameter/k_c.csv')
mm = _load_param_vector('parameter/mm.csv')

v_f.shape

torch.Size([150])

==============================   Prediction model   =================================

In [10]:
# Proximity and kNN graphs stay as NumPy arrays here (they are row-normalised
# by adj_def later); the correlation graph is converted straight to a tensor.
# NOTE(review): the correlation graph never goes through adj_def — confirm
# that using it unnormalised is intended.
adj = pd.read_csv('data/adj_matrix/proximity_matrix.csv').values
adj_correlation_v = pd.read_csv('data/adj_matrix/correlation_matrix_speed.csv').values
adj_knn_v = pd.read_csv('data/adj_matrix/knn_matrix_speed.csv').values

adj_correlation_v = torch.tensor(adj_correlation_v).to(torch.float32) if False else torch.from_numpy(adj_correlation_v).to(torch.float32)

In [12]:
# Adjacency matrix normalization
def adj_def(adj):
    """Row-normalise an adjacency matrix after adding self-loops.

    Computes ``D^-1 (A + I)``, where ``D`` is the diagonal matrix of row sums
    of ``A + I``. A row whose sum is zero is left as all zeros instead of
    producing inf/NaN.

    Parameters
    ----------
    adj : array-like or scipy.sparse matrix, shape (n, n)
        Square adjacency matrix.

    Returns
    -------
    torch.Tensor
        Dense float32 tensor of shape (n, n).
    """
    # Work on a plain float ndarray: avoids np.matrix (deprecated) and the
    # sparse `.A` attribute (removed in recent SciPy releases).
    if sp.issparse(adj):
        adj = adj.toarray()
    adj = np.asarray(adj, dtype=np.float64)
    adj = adj + np.eye(adj.shape[0])

    rowsum = adj.sum(axis=1)
    # Reciprocal of the row sums; zero rows stay 0 without a divide warning.
    r_inv = np.zeros_like(rowsum)
    np.divide(1.0, rowsum, out=r_inv, where=rowsum != 0)

    adj = r_inv[:, None] * adj
    return torch.from_numpy(adj).to(torch.float32)

In [14]:
# Row-normalise (with self-loops) the proximity and kNN graphs; the
# correlation graph is used as loaded.
adj_n = adj_def(adj)
print(adj_n.shape)
adj_knn_v = adj_def(adj_knn_v)

torch.Size([150, 150])


In [15]:
# Per-link extremes over all training windows and time steps.
max_v = training_set_v.amax(dim=(0, 1))
min_v = training_set_v.amin(dim=(0, 1))

max_v.shape

torch.Size([150])

In [16]:
# Define the normalization function
def norm(x, max_x, min_x):
    """Min-max scale ``x`` from ``[min_x, max_x]`` to ``[-1, 1]``.

    Args:
        x: Values to scale.
        max_x, min_x: Upper and lower bounds, broadcastable against ``x``.

    Returns:
        ``2 * (x - min_x) / (max_x - min_x) - 1``. When the bounds are torch
        tensors, any channel with ``max_x == min_x`` uses a span of 1 instead
        of 0. A constant channel therefore maps to -1 rather than NaN/inf, and
        ``r_norm`` maps it back to ``min_x``.
    """
    span = max_x - min_x
    if torch.is_tensor(span):
        # Guard constant channels against 0/0.
        span = torch.where(span == 0, torch.ones_like(span), span)
    return 2 * (x - min_x) / span - 1

In [17]:
# Define the inverse normalization function
def r_norm(x0, max_x, min_x):
    """Inverse of ``norm``: map values from ``[-1, 1]`` back to ``[min_x, max_x]``.

    Args:
        x0: Normalised values.
        max_x, min_x: The bounds used for normalisation (broadcastable).

    Returns:
        ``(x0 + 1) * (max_x - min_x) / 2 + min_x``.
    """
    half_span = (max_x - min_x) / 2
    return min_x + (x0 + 1) * half_span

In [19]:
# Scale training and test speeds to [-1, 1] using the training-set per-link extremes.
training_set_v0, test_set_v0 = (
    norm(speeds, max_v, min_v) for speeds in (training_set_v, test_set_v)
)

In [35]:
# Only the first step after the input window is predicted. The windows were
# extracted earlier with the longer horizon, so the sample set is unchanged;
# only the target slice is shortened.
prediction_horizon = 1
assert time_leg + prediction_horizon <= training_set_v0.shape[1], "windows too short for horizon"

# Training set input and output (normalised)
training_input_v = training_set_v0[:, :time_leg, :]
training_output_v = training_set_v0[:, time_leg:time_leg + prediction_horizon, :]

# Test set input and output (normalised)
test_input_v = test_set_v0[:, :time_leg, :]
test_output_v = test_set_v0[:, time_leg:time_leg + prediction_horizon, :]

# Ground truth in original units; the same horizon as the model output so the
# metrics compare tensors of identical shape.
training_true_v = training_set_v[:, time_leg:time_leg + prediction_horizon, :]
test_true_v = test_set_v[:, time_leg:time_leg + prediction_horizon, :]

In [36]:
# Links = last axis of the (window, time, link) tensor.
num_link = training_set_v0.size(2)
num_link

150

In [37]:
# Define the GraphConvolution layer
class GraphConvolution(Module):
    """Graph convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        bias: If True (default), add a learnable bias of size ``out_features``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, then bias, from U(-1/sqrt(out_features), 1/sqrt(out_features))."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project node features with ``weight``, then aggregate them with ``adj`` from the left."""
        out = torch.matmul(adj, torch.matmul(input, self.weight))
        return out if self.bias is None else out + self.bias

# Define the LSTM unit
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM wrapper.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        return_sequence: If True, return the output of every time step,
            shape ``(batch, seq, hidden_size)``. Otherwise return only the
            last time step, shape ``(batch, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, return_sequence):
        super(LSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.return_sequence = return_sequence

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, x):
        # Zero initial states on x's device and dtype, so the module keeps
        # working after .to(device) / .double() (torch.zeros would pin them
        # to CPU float32).
        h_0 = x.new_zeros(1, x.size(0), self.hidden_size)
        c_0 = x.new_zeros(1, x.size(0), self.hidden_size)

        # Propagate input through LSTM
        output, _ = self.lstm(x, (h_0, c_0))
        if not self.return_sequence:
            # Keep only the last time step: (batch, hidden_size).
            output = output[:, -1, :].contiguous()
        return output

In [38]:
class GCN(Module):
    """Multi-graph convolution encoder with three parallel GCN branches.

    Each branch applies two tanh GraphConvolution layers with its own
    adjacency matrix. It then flattens the node dimension and projects to a
    latent vector with a tanh Linear layer. The three branch vectors are
    averaged and passed through a final tanh Linear layer.

    Args:
        nfea_gcn: Number of input features per node.
        nhid_gcn1, nhid_gcn2, nhid_gcn3: Three-element sequences
            ``[hidden1, hidden2, latent]`` for the branches applied with
            ``adj_n``, ``adj_c`` and ``adj_k`` respectively. All three
            ``latent`` sizes must be equal because the outputs are averaged.
        n_latent_gcn: Output size of the encoder.
        num_link: Number of graph nodes.

    Raises:
        ValueError: If the three branch latent sizes differ.
    """

    def __init__(self, nfea_gcn, nhid_gcn1, nhid_gcn2, nhid_gcn3, n_latent_gcn, num_link):
        super(GCN, self).__init__()

        # The branch outputs are averaged in forward(), so their widths must agree.
        if not (nhid_gcn1[2] == nhid_gcn2[2] == nhid_gcn3[2]):
            raise ValueError(
                "nhid_gcn1[2], nhid_gcn2[2] and nhid_gcn3[2] must be equal, got "
                f"{nhid_gcn1[2]}, {nhid_gcn2[2]}, {nhid_gcn3[2]}"
            )

        self.nfea_gcn = nfea_gcn
        self.num_link = num_link

        self.nhid_gcn1 = nhid_gcn1
        self.nhid_gcn2 = nhid_gcn2
        self.nhid_gcn3 = nhid_gcn3

        # Each flattening Linear consumes the output of the branch's second
        # graph convolution: nhid[1] features for each of the num_link nodes.
        self.gc11 = GraphConvolution(nfea_gcn, nhid_gcn1[0])
        self.gc12 = GraphConvolution(nhid_gcn1[0], nhid_gcn1[1])
        self.fc_n = nn.Linear(nhid_gcn1[1] * num_link, nhid_gcn1[2])

        self.gc21 = GraphConvolution(nfea_gcn, nhid_gcn2[0])
        self.gc22 = GraphConvolution(nhid_gcn2[0], nhid_gcn2[1])
        self.fc_m = nn.Linear(nhid_gcn2[1] * num_link, nhid_gcn2[2])

        self.gc31 = GraphConvolution(nfea_gcn, nhid_gcn3[0])
        self.gc32 = GraphConvolution(nhid_gcn3[0], nhid_gcn3[1])
        self.fc_s = nn.Linear(nhid_gcn3[1] * num_link, nhid_gcn3[2])

        self.fc = nn.Linear(nhid_gcn1[2], n_latent_gcn)

    @staticmethod
    def _branch(x, gc1, gc2, fc, adj):
        """Two tanh graph convolutions, flatten the nodes, then a tanh Linear projection."""
        x = torch.tanh(gc1(x, adj))
        x = torch.tanh(gc2(x, adj))
        x = x.view(x.shape[0], -1)
        return torch.tanh(fc(x))

    def forward(self, x, adj_n, adj_c, adj_k):
        # x is expected as (batch, num_link, nfea_gcn).
        x1 = self._branch(x, self.gc11, self.gc12, self.fc_n, adj_n)
        x2 = self._branch(x, self.gc21, self.gc22, self.fc_m, adj_c)
        x3 = self._branch(x, self.gc31, self.gc32, self.fc_s, adj_k)

        x0 = (x1 + x2 + x3) / 3
        return torch.tanh(self.fc(x0))

In [39]:
class PIML(Module):
    """Speed predictor combining stacked LSTMs with the multi-graph GCN encoder.

    The temporal path runs three LSTM layers over the input sequence and keeps
    the last time step. The spatial path feeds the transposed input (one row
    per link, one column per input time step) through the GCN encoder. The two
    embeddings are concatenated and decoded by two tanh Linear layers into
    ``prediction_horizon`` steps of ``nfea_lstm`` values.
    """

    def __init__(self, nfea_lstm, nfea_gcn, prediction_horizon, num_link,
                 nhid_lstm, nhid_gcn, n_latent_gcn, nhid_linear):
        super(PIML, self).__init__()

        self.nfea_lstm = nfea_lstm
        self.nfea_gcn = nfea_gcn
        self.prediction_horizon = prediction_horizon
        self.num_link = num_link
        self.nhid_linear = nhid_linear

        # Temporal encoder: three stacked LSTMs; only the last one collapses
        # the sequence to its final step.
        self.lstm1 = LSTM(nfea_lstm, nhid_lstm[0], return_sequence=True)
        self.lstm2 = LSTM(nhid_lstm[0], nhid_lstm[1], return_sequence=True)
        self.lstm3 = LSTM(nhid_lstm[1], nhid_lstm[2], return_sequence=False)

        # Spatial encoder over the three adjacency matrices.
        self.MGCN = GCN(nfea_gcn, nhid_gcn[0], nhid_gcn[1], nhid_gcn[2], n_latent_gcn, num_link)

        # Decoder head: joint embedding -> flattened multi-step prediction.
        self.fc1 = nn.Linear(nhid_lstm[-1] + n_latent_gcn, nhid_linear[0])
        self.fc2 = nn.Linear(nhid_linear[0], prediction_horizon * nfea_lstm)

    def forward(self, fea, adj_n, adj_c, adj_k):
        temporal = fea
        for layer in (self.lstm1, self.lstm2, self.lstm3):
            temporal = layer(temporal)

        spatial = self.MGCN(fea.transpose(1, 2), adj_n, adj_c, adj_k)

        joint = torch.cat((temporal, spatial), dim=1)
        hidden = torch.tanh(self.fc1(joint))
        flat = torch.tanh(self.fc2(hidden))

        return flat.view(flat.shape[0], self.prediction_horizon, self.nfea_lstm)

In [40]:
piml_v = PIML(
    nfea_lstm=num_link,
    nfea_gcn=time_leg,  # GCN node features = input time steps (input is transposed)
    prediction_horizon=prediction_horizon,
    num_link=num_link,
    nhid_lstm=[256, 256, 256],
    nhid_gcn=[[32, 32, 32], [32, 32, 32], [32, 32, 32]],
    n_latent_gcn=512,
    nhid_linear=[512],
)

In [41]:
# Loss function and optimiser
nn_loss = nn.MSELoss()
optimizer = AdamW(piml_v.parameters(), lr=0.001, weight_decay=1e-2)

# Training schedule: each "epoch" in train() is one random mini-batch step.
n_epochs = 2000
batch_size = 256
sample_size = training_set_v.shape[0]

In [42]:
def train(n_epochs, batch_size, sample_size, log_every=1):
    """Train ``piml_v`` on random mini-batches of the training set.

    Each "epoch" is a single optimisation step on ``batch_size`` samples drawn
    without replacement from ``range(sample_size)``.

    Reads the notebook globals ``piml_v``, ``optimizer``, ``nn_loss``,
    ``training_input_v``, ``training_output_v``, ``adj_n``,
    ``adj_correlation_v`` and ``adj_knn_v``.

    Args:
        n_epochs: Number of optimisation steps.
        batch_size: Samples per step; must not exceed ``sample_size``.
        sample_size: Number of training samples to draw indices from.
        log_every: Print progress every ``log_every`` steps and on the final
            step. The default of 1 prints every step.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")

    piml_v.train()
    t0 = time.time()
    for epoch in range(n_epochs):
        idx = random.sample(range(sample_size), batch_size)

        input_v = training_input_v[idx]
        output_v = training_output_v[idx]

        optimizer.zero_grad()
        pred_v0 = piml_v(input_v, adj_n, adj_correlation_v, adj_knn_v)
        # (prediction, target) order; MSE is symmetric so the value is unchanged.
        loss = nn_loss(pred_v0, output_v)
        loss.backward()
        optimizer.step()

        if epoch % log_every == 0 or epoch == n_epochs - 1:
            print(
                "[Epoch %d/%d] [loss: %f] [time: %f]"
                % (epoch, n_epochs, loss.item(), time.time() - t0)
            )

In [43]:
train(n_epochs, batch_size, sample_size)

[Epoch 0/2000] [loss: 0.330778] [time: 1.126775]
[Epoch 1/2000] [loss: 0.291294] [time: 2.185531]
[Epoch 2/2000] [loss: 0.174156] [time: 3.098179]
[Epoch 3/2000] [loss: 0.145638] [time: 3.977805]
[Epoch 4/2000] [loss: 0.134213] [time: 4.765363]
[Epoch 5/2000] [loss: 0.126310] [time: 5.471868]
[Epoch 6/2000] [loss: 0.126062] [time: 6.239699]
[Epoch 7/2000] [loss: 0.102843] [time: 6.899172]
[Epoch 8/2000] [loss: 0.087778] [time: 7.573648]
[Epoch 9/2000] [loss: 0.079609] [time: 8.211102]
[Epoch 10/2000] [loss: 0.071373] [time: 8.860566]
[Epoch 11/2000] [loss: 0.068058] [time: 9.539071]
[Epoch 12/2000] [loss: 0.064164] [time: 10.169494]
[Epoch 13/2000] [loss: 0.066847] [time: 10.810952]
[Epoch 14/2000] [loss: 0.060598] [time: 11.446537]
[Epoch 15/2000] [loss: 0.061620] [time: 12.105978]
[Epoch 16/2000] [loss: 0.060997] [time: 12.746459]
[Epoch 17/2000] [loss: 0.060055] [time: 13.357867]
[Epoch 18/2000] [loss: 0.056793] [time: 13.990317]
[Epoch 19/2000] [loss: 0.056349] [time: 14.589746]
[E

[Epoch 160/2000] [loss: 0.029226] [time: 101.149243]
[Epoch 161/2000] [loss: 0.029946] [time: 101.718648]
[Epoch 162/2000] [loss: 0.030651] [time: 102.272044]
[Epoch 163/2000] [loss: 0.030294] [time: 102.829439]
[Epoch 164/2000] [loss: 0.030582] [time: 103.397843]
[Epoch 165/2000] [loss: 0.030160] [time: 104.010162]
[Epoch 166/2000] [loss: 0.028219] [time: 104.684644]
[Epoch 167/2000] [loss: 0.030166] [time: 105.432177]
[Epoch 168/2000] [loss: 0.031473] [time: 106.168727]
[Epoch 169/2000] [loss: 0.030441] [time: 106.890212]
[Epoch 170/2000] [loss: 0.026522] [time: 107.566696]
[Epoch 171/2000] [loss: 0.028814] [time: 108.233167]
[Epoch 172/2000] [loss: 0.028913] [time: 108.882379]
[Epoch 173/2000] [loss: 0.029751] [time: 109.539847]
[Epoch 174/2000] [loss: 0.030832] [time: 110.200319]
[Epoch 175/2000] [loss: 0.028389] [time: 110.895814]
[Epoch 176/2000] [loss: 0.030184] [time: 111.611322]
[Epoch 177/2000] [loss: 0.027165] [time: 112.298833]
[Epoch 178/2000] [loss: 0.027984] [time: 112.9

[Epoch 315/2000] [loss: 0.017995] [time: 165.470790]
[Epoch 316/2000] [loss: 0.017071] [time: 165.761205]
[Epoch 317/2000] [loss: 0.017340] [time: 166.058416]
[Epoch 318/2000] [loss: 0.017111] [time: 166.352626]
[Epoch 319/2000] [loss: 0.016993] [time: 166.635827]
[Epoch 320/2000] [loss: 0.016440] [time: 166.919029]
[Epoch 321/2000] [loss: 0.017372] [time: 167.212237]
[Epoch 322/2000] [loss: 0.018288] [time: 167.496440]
[Epoch 323/2000] [loss: 0.016785] [time: 167.786998]
[Epoch 324/2000] [loss: 0.016796] [time: 168.078205]
[Epoch 325/2000] [loss: 0.017884] [time: 168.375416]
[Epoch 326/2000] [loss: 0.017180] [time: 168.657617]
[Epoch 327/2000] [loss: 0.017010] [time: 168.946823]
[Epoch 328/2000] [loss: 0.017364] [time: 169.227023]
[Epoch 329/2000] [loss: 0.016423] [time: 169.514227]
[Epoch 330/2000] [loss: 0.018117] [time: 169.802432]
[Epoch 331/2000] [loss: 0.017188] [time: 170.086782]
[Epoch 332/2000] [loss: 0.016993] [time: 170.371984]
[Epoch 333/2000] [loss: 0.016139] [time: 170.6

[Epoch 470/2000] [loss: 0.012592] [time: 210.071784]
[Epoch 471/2000] [loss: 0.011790] [time: 210.371998]
[Epoch 472/2000] [loss: 0.012917] [time: 210.654198]
[Epoch 473/2000] [loss: 0.012760] [time: 210.937400]
[Epoch 474/2000] [loss: 0.012790] [time: 211.225605]
[Epoch 475/2000] [loss: 0.012533] [time: 211.504804]
[Epoch 476/2000] [loss: 0.013037] [time: 211.796011]
[Epoch 477/2000] [loss: 0.012489] [time: 212.082215]
[Epoch 478/2000] [loss: 0.012467] [time: 212.371119]
[Epoch 479/2000] [loss: 0.011486] [time: 212.651319]
[Epoch 480/2000] [loss: 0.011667] [time: 212.936522]
[Epoch 481/2000] [loss: 0.011857] [time: 213.220724]
[Epoch 482/2000] [loss: 0.012334] [time: 213.501924]
[Epoch 483/2000] [loss: 0.012707] [time: 213.782124]
[Epoch 484/2000] [loss: 0.012443] [time: 214.069328]
[Epoch 485/2000] [loss: 0.011831] [time: 214.354531]
[Epoch 486/2000] [loss: 0.011692] [time: 214.643587]
[Epoch 487/2000] [loss: 0.012623] [time: 214.923787]
[Epoch 488/2000] [loss: 0.012218] [time: 215.2

[Epoch 625/2000] [loss: 0.010000] [time: 255.016551]
[Epoch 626/2000] [loss: 0.009702] [time: 255.305757]
[Epoch 627/2000] [loss: 0.009794] [time: 255.598966]
[Epoch 628/2000] [loss: 0.009238] [time: 255.892175]
[Epoch 629/2000] [loss: 0.009659] [time: 256.180380]
[Epoch 630/2000] [loss: 0.009820] [time: 256.463581]
[Epoch 631/2000] [loss: 0.010546] [time: 256.744782]
[Epoch 632/2000] [loss: 0.009638] [time: 257.028787]
[Epoch 633/2000] [loss: 0.010300] [time: 257.313990]
[Epoch 634/2000] [loss: 0.009727] [time: 257.605197]
[Epoch 635/2000] [loss: 0.009646] [time: 257.889399]
[Epoch 636/2000] [loss: 0.010163] [time: 258.183609]
[Epoch 637/2000] [loss: 0.009985] [time: 258.463808]
[Epoch 638/2000] [loss: 0.009735] [time: 258.750012]
[Epoch 639/2000] [loss: 0.009361] [time: 259.036216]
[Epoch 640/2000] [loss: 0.009648] [time: 259.321156]
[Epoch 641/2000] [loss: 0.011076] [time: 259.603357]
[Epoch 642/2000] [loss: 0.009688] [time: 259.892563]
[Epoch 643/2000] [loss: 0.009215] [time: 260.1

[Epoch 780/2000] [loss: 0.008511] [time: 299.713767]
[Epoch 781/2000] [loss: 0.008716] [time: 299.997970]
[Epoch 782/2000] [loss: 0.008239] [time: 300.282172]
[Epoch 783/2000] [loss: 0.008629] [time: 300.565374]
[Epoch 784/2000] [loss: 0.008621] [time: 300.851578]
[Epoch 785/2000] [loss: 0.008931] [time: 301.146788]
[Epoch 786/2000] [loss: 0.008424] [time: 301.432991]
[Epoch 787/2000] [loss: 0.008331] [time: 301.714975]
[Epoch 788/2000] [loss: 0.008854] [time: 302.001178]
[Epoch 789/2000] [loss: 0.008389] [time: 302.277375]
[Epoch 790/2000] [loss: 0.008389] [time: 302.563579]
[Epoch 791/2000] [loss: 0.008454] [time: 302.848782]
[Epoch 792/2000] [loss: 0.008743] [time: 303.134986]
[Epoch 793/2000] [loss: 0.007993] [time: 303.418187]
[Epoch 794/2000] [loss: 0.008134] [time: 303.707981]
[Epoch 795/2000] [loss: 0.009157] [time: 304.020204]
[Epoch 796/2000] [loss: 0.008524] [time: 304.298905]
[Epoch 797/2000] [loss: 0.008426] [time: 304.582107]
[Epoch 798/2000] [loss: 0.008489] [time: 304.8

[Epoch 935/2000] [loss: 0.007314] [time: 344.386932]
[Epoch 936/2000] [loss: 0.007547] [time: 344.739184]
[Epoch 937/2000] [loss: 0.007401] [time: 345.017381]
[Epoch 938/2000] [loss: 0.007675] [time: 345.298581]
[Epoch 939/2000] [loss: 0.007712] [time: 345.584519]
[Epoch 940/2000] [loss: 0.007639] [time: 345.866720]
[Epoch 941/2000] [loss: 0.007329] [time: 346.171937]
[Epoch 942/2000] [loss: 0.007420] [time: 346.449135]
[Epoch 943/2000] [loss: 0.007420] [time: 346.733336]
[Epoch 944/2000] [loss: 0.007131] [time: 347.018540]
[Epoch 945/2000] [loss: 0.007247] [time: 347.309747]
[Epoch 946/2000] [loss: 0.007254] [time: 347.594309]
[Epoch 947/2000] [loss: 0.007412] [time: 347.878512]
[Epoch 948/2000] [loss: 0.007967] [time: 348.229762]
[Epoch 949/2000] [loss: 0.007346] [time: 348.508960]
[Epoch 950/2000] [loss: 0.007911] [time: 348.795164]
[Epoch 951/2000] [loss: 0.007223] [time: 349.079365]
[Epoch 952/2000] [loss: 0.006995] [time: 349.379579]
[Epoch 953/2000] [loss: 0.008004] [time: 349.6

[Epoch 1088/2000] [loss: 0.007107] [time: 388.568558]
[Epoch 1089/2000] [loss: 0.006929] [time: 388.846756]
[Epoch 1090/2000] [loss: 0.006889] [time: 389.134405]
[Epoch 1091/2000] [loss: 0.007056] [time: 389.427613]
[Epoch 1092/2000] [loss: 0.007209] [time: 389.707813]
[Epoch 1093/2000] [loss: 0.007094] [time: 389.989013]
[Epoch 1094/2000] [loss: 0.006826] [time: 390.282221]
[Epoch 1095/2000] [loss: 0.007233] [time: 390.586438]
[Epoch 1096/2000] [loss: 0.008166] [time: 390.862635]
[Epoch 1097/2000] [loss: 0.007282] [time: 391.152841]
[Epoch 1098/2000] [loss: 0.007419] [time: 391.446500]
[Epoch 1099/2000] [loss: 0.007096] [time: 391.726699]
[Epoch 1100/2000] [loss: 0.006871] [time: 392.020909]
[Epoch 1101/2000] [loss: 0.006944] [time: 392.315118]
[Epoch 1102/2000] [loss: 0.006836] [time: 392.596318]
[Epoch 1103/2000] [loss: 0.007602] [time: 392.881521]
[Epoch 1104/2000] [loss: 0.006781] [time: 393.178733]
[Epoch 1105/2000] [loss: 0.007422] [time: 393.479948]
[Epoch 1106/2000] [loss: 0.0

[Epoch 1240/2000] [loss: 0.006797] [time: 432.418861]
[Epoch 1241/2000] [loss: 0.006776] [time: 432.700946]
[Epoch 1242/2000] [loss: 0.006438] [time: 432.996156]
[Epoch 1243/2000] [loss: 0.006311] [time: 433.277356]
[Epoch 1244/2000] [loss: 0.006557] [time: 433.562559]
[Epoch 1245/2000] [loss: 0.006359] [time: 433.846761]
[Epoch 1246/2000] [loss: 0.006885] [time: 434.128962]
[Epoch 1247/2000] [loss: 0.006280] [time: 434.424172]
[Epoch 1248/2000] [loss: 0.006616] [time: 434.706373]
[Epoch 1249/2000] [loss: 0.006223] [time: 435.059409]
[Epoch 1250/2000] [loss: 0.006715] [time: 435.344612]
[Epoch 1251/2000] [loss: 0.006742] [time: 435.631816]
[Epoch 1252/2000] [loss: 0.006594] [time: 435.955046]
[Epoch 1253/2000] [loss: 0.006698] [time: 436.248255]
[Epoch 1254/2000] [loss: 0.006333] [time: 436.537461]
[Epoch 1255/2000] [loss: 0.006596] [time: 436.885710]
[Epoch 1256/2000] [loss: 0.006612] [time: 437.175915]
[Epoch 1257/2000] [loss: 0.006831] [time: 437.464062]
[Epoch 1258/2000] [loss: 0.0

[Epoch 1392/2000] [loss: 0.006284] [time: 476.275321]
[Epoch 1393/2000] [loss: 0.006581] [time: 476.569531]
[Epoch 1394/2000] [loss: 0.006421] [time: 476.863840]
[Epoch 1395/2000] [loss: 0.006116] [time: 477.147041]
[Epoch 1396/2000] [loss: 0.006765] [time: 477.443252]
[Epoch 1397/2000] [loss: 0.006330] [time: 477.724452]
[Epoch 1398/2000] [loss: 0.006583] [time: 478.005653]
[Epoch 1399/2000] [loss: 0.006889] [time: 478.293857]
[Epoch 1400/2000] [loss: 0.006391] [time: 478.584064]
[Epoch 1401/2000] [loss: 0.006615] [time: 478.859260]
[Epoch 1402/2000] [loss: 0.006727] [time: 479.218951]
[Epoch 1403/2000] [loss: 0.006351] [time: 479.498150]
[Epoch 1404/2000] [loss: 0.006243] [time: 479.790357]
[Epoch 1405/2000] [loss: 0.006346] [time: 480.073559]
[Epoch 1406/2000] [loss: 0.006435] [time: 480.355760]
[Epoch 1407/2000] [loss: 0.006100] [time: 480.642964]
[Epoch 1408/2000] [loss: 0.006666] [time: 480.940176]
[Epoch 1409/2000] [loss: 0.006275] [time: 481.227380]
[Epoch 1410/2000] [loss: 0.0

[Epoch 1544/2000] [loss: 0.005966] [time: 520.020114]
[Epoch 1545/2000] [loss: 0.005674] [time: 520.309321]
[Epoch 1546/2000] [loss: 0.005761] [time: 520.598065]
[Epoch 1547/2000] [loss: 0.006148] [time: 520.910289]
[Epoch 1548/2000] [loss: 0.005873] [time: 521.188485]
[Epoch 1549/2000] [loss: 0.006463] [time: 521.470686]
[Epoch 1550/2000] [loss: 0.006313] [time: 521.762895]
[Epoch 1551/2000] [loss: 0.005929] [time: 522.055103]
[Epoch 1552/2000] [loss: 0.006478] [time: 522.341306]
[Epoch 1553/2000] [loss: 0.005686] [time: 522.629512]
[Epoch 1554/2000] [loss: 0.006118] [time: 522.913150]
[Epoch 1555/2000] [loss: 0.006656] [time: 523.196353]
[Epoch 1556/2000] [loss: 0.005880] [time: 523.478554]
[Epoch 1557/2000] [loss: 0.006297] [time: 523.769760]
[Epoch 1558/2000] [loss: 0.006243] [time: 524.055964]
[Epoch 1559/2000] [loss: 0.005966] [time: 524.354177]
[Epoch 1560/2000] [loss: 0.005831] [time: 524.638378]
[Epoch 1561/2000] [loss: 0.006021] [time: 524.920579]
[Epoch 1562/2000] [loss: 0.0

[Epoch 1696/2000] [loss: 0.005946] [time: 563.897891]
[Epoch 1697/2000] [loss: 0.005732] [time: 564.218678]
[Epoch 1698/2000] [loss: 0.005886] [time: 564.494874]
[Epoch 1699/2000] [loss: 0.006105] [time: 564.789274]
[Epoch 1700/2000] [loss: 0.006216] [time: 565.074477]
[Epoch 1701/2000] [loss: 0.005875] [time: 565.360680]
[Epoch 1702/2000] [loss: 0.005969] [time: 565.647885]
[Epoch 1703/2000] [loss: 0.005667] [time: 565.939092]
[Epoch 1704/2000] [loss: 0.006097] [time: 566.304353]
[Epoch 1705/2000] [loss: 0.005902] [time: 566.589555]
[Epoch 1706/2000] [loss: 0.005459] [time: 566.872756]
[Epoch 1707/2000] [loss: 0.005570] [time: 567.163964]
[Epoch 1708/2000] [loss: 0.005812] [time: 567.444697]
[Epoch 1709/2000] [loss: 0.005921] [time: 567.737908]
[Epoch 1710/2000] [loss: 0.005795] [time: 568.058136]
[Epoch 1711/2000] [loss: 0.006283] [time: 568.339336]
[Epoch 1712/2000] [loss: 0.005988] [time: 568.620536]
[Epoch 1713/2000] [loss: 0.006762] [time: 568.897733]
[Epoch 1714/2000] [loss: 0.0

[Epoch 1848/2000] [loss: 0.005947] [time: 607.921649]
[Epoch 1849/2000] [loss: 0.006014] [time: 608.197846]
[Epoch 1850/2000] [loss: 0.005429] [time: 608.517073]
[Epoch 1851/2000] [loss: 0.005711] [time: 608.817286]
[Epoch 1852/2000] [loss: 0.006493] [time: 609.096485]
[Epoch 1853/2000] [loss: 0.005584] [time: 609.382689]
[Epoch 1854/2000] [loss: 0.005597] [time: 609.671894]
[Epoch 1855/2000] [loss: 0.005664] [time: 609.960100]
[Epoch 1856/2000] [loss: 0.005284] [time: 610.246303]
[Epoch 1857/2000] [loss: 0.005700] [time: 610.528504]
[Epoch 1858/2000] [loss: 0.005717] [time: 610.818712]
[Epoch 1859/2000] [loss: 0.005690] [time: 611.099912]
[Epoch 1860/2000] [loss: 0.005924] [time: 611.393120]
[Epoch 1861/2000] [loss: 0.005736] [time: 611.684327]
[Epoch 1862/2000] [loss: 0.005356] [time: 611.975534]
[Epoch 1863/2000] [loss: 0.005390] [time: 612.273746]
[Epoch 1864/2000] [loss: 0.005371] [time: 612.597977]
[Epoch 1865/2000] [loss: 0.005563] [time: 612.903195]
[Epoch 1866/2000] [loss: 0.0

In [44]:
# Inference only: switch to eval mode and skip autograd, so no graph is built
# over the whole test set.
piml_v.eval()
with torch.no_grad():
    test_pred_v0 = piml_v(test_input_v, adj_n, adj_correlation_v, adj_knn_v)
test_pred_v = r_norm(test_pred_v0, max_v, min_v)

In [47]:
with torch.no_grad():
    # RMSE and MAE in original speed units.
    rmse_v = nn_loss(test_pred_v, test_true_v) ** 0.5

    l1_loss = nn.L1Loss()
    mae_v = l1_loss(test_pred_v, test_true_v)

    # MAPE: skip entries whose ground truth is zero, which would otherwise
    # give inf (or NaN for 0/0).
    nonzero = test_true_v != 0
    mape_v = torch.abs(test_pred_v - test_true_v)[nonzero] / test_true_v[nonzero]
    mape_mean_v = torch.mean(mape_v)

print(f"RMSE: {rmse_v.item():.4f}")
print(f"MAE:  {mae_v.item():.4f}")
print(f"MAPE: {mape_mean_v.item():.4f}")

tensor(2.8566, grad_fn=<PowBackward0>)
tensor(1.9949, grad_fn=<MeanBackward0>)
tensor(0.0462, grad_fn=<MeanBackward0>)


In [49]:
# Convert predicted speed to flow using the per-link parameters v_f, k_c, mm
# loaded from parameter/*.csv.
# NOTE(review): assuming mm > 0, any predicted speed above v_f makes the base
# of the fractional power negative, giving NaN — confirm whether predictions
# should be clipped to v_f first.
speed_ratio_pow = (v_f ** mm) / (test_pred_v ** mm)
test_pred_q = test_pred_v * k_c * (speed_ratio_pow ** 0.5 - 1) ** (1 / mm)