In [1]:
import time

import numpy as np
import pandas as pd
from pandas import DataFrame
import math

from matplotlib.pyplot import savefig

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.optim import AdamW
import random

import matplotlib.pyplot as plt
import scipy.sparse as sp

In [2]:
training_flow = pd.read_csv('data/training_flow.csv')
training_flow['Timestamp'] = pd.to_datetime(training_flow['Timestamp'])
training_flow.set_index('Timestamp', inplace=True)

training_speed = pd.read_csv('data/training_speed.csv')
training_speed['Timestamp'] = pd.to_datetime(training_speed['Timestamp'])
training_speed.set_index('Timestamp', inplace=True)

test_flow = pd.read_csv('data/test_flow.csv')
test_flow['Timestamp'] = pd.to_datetime(test_flow['Timestamp'])
test_flow.set_index('Timestamp', inplace=True)

test_speed = pd.read_csv('data/test_speed.csv')
test_speed['Timestamp'] = pd.to_datetime(test_speed['Timestamp'])
test_speed.set_index('Timestamp', inplace=True)

num_link = training_flow.shape[1]

In [3]:
time_leg = 24
prediction_horizon = 6

In [4]:
continuous_training_flow = []
continuous_training_speed = []

continuous_training_time = []

training_times = training_flow.index

for i in range(len(training_times) - time_leg - prediction_horizon + 1):
    
    # Check if the 30 time points starting from the current line are consecutive
    if all(training_times[i + j] == training_times[i] + pd.Timedelta(minutes=5 * j) for j in range(1,time_leg + prediction_horizon)):
        
        # If consecutive, extract the data from these 30 lines
        data_slice_flow = training_flow.iloc[i:i + time_leg + prediction_horizon].values
        data_slice_speed= training_speed.iloc[i:i + time_leg + prediction_horizon].values
        
        # Add the data to the list
        continuous_training_flow.append(data_slice_flow)
        continuous_training_speed.append(data_slice_speed)
        
        continuous_training_time.append(training_times[i])
        
flow_data_training = np.array(continuous_training_flow)
speed_data_training = np.array(continuous_training_speed)

flow_data_training.shape

(67752, 30, 150)

In [5]:
continuous_test_flow = []
continuous_test_speed = []

continuous_test_time = []

test_times = test_flow.index

for i in range(len(test_times) - time_leg - prediction_horizon + 1):
    
    # Check if the 30 time points starting from the current line are consecutive
    if all(test_times[i + j] == test_times[i] + pd.Timedelta(minutes=5 * j) for j in range(1,time_leg + prediction_horizon)):
        
        # If consecutive, extract the data from these 30 lines
        data_slice_flow = test_flow.iloc[i:i + time_leg + prediction_horizon].values
        data_slice_speed= test_speed.iloc[i:i + time_leg + prediction_horizon].values
        
        # Add the data to the list
        continuous_test_flow.append(data_slice_flow)
        continuous_test_speed.append(data_slice_speed)
        
        continuous_test_time.append(test_times[i])
        
flow_data_test = np.array(continuous_test_flow)
speed_data_test = np.array(continuous_test_speed)

flow_data_test.shape

(4891, 30, 150)

In [6]:
training_set_q = torch.from_numpy(flow_data_training).to(torch.float32)
test_set_q = torch.from_numpy(flow_data_test).to(torch.float32)

training_set_v = torch.from_numpy(speed_data_training).to(torch.float32)
test_set_v = torch.from_numpy(speed_data_test).to(torch.float32)

print(training_set_q.shape)
print(test_set_q.shape)
print(training_set_v.shape)
print(test_set_v.shape)

torch.Size([67752, 30, 150])
torch.Size([4891, 30, 150])
torch.Size([67752, 30, 150])
torch.Size([4891, 30, 150])


In [7]:
v_f = pd.read_csv('parameter/v_f.csv').values
k_c = pd.read_csv('parameter/k_c.csv').values
mm = pd.read_csv('parameter/mm.csv').values

v_f = torch.from_numpy(v_f).to(torch.float32).reshape(-1)
k_c = torch.from_numpy(k_c).to(torch.float32).reshape(-1)
mm = torch.from_numpy(mm).to(torch.float32).reshape(-1)

v_f.shape

torch.Size([150])

==============================   Prediction model   =================================

In [9]:
adj = pd.read_csv('data/adj_matrix/proximity_matrix.csv').values
adj_correlation_q = pd.read_csv('data/adj_matrix/correlation_matrix_flow.csv').values
adj_knn_q = pd.read_csv('data/adj_matrix/knn_matrix_flow.csv').values
adj_correlation_v = pd.read_csv('data/adj_matrix/correlation_matrix_speed.csv').values
adj_knn_v = pd.read_csv('data/adj_matrix/knn_matrix_speed.csv').values

In [10]:
adj_correlation_q = torch.from_numpy(adj_correlation_q).to(torch.float32)
adj_correlation_v = torch.from_numpy(adj_correlation_v).to(torch.float32)

In [11]:
# Adjacency matrix normalization
def adj_def(adj):
   
    adj = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0
    r_mat_inv = sp.diags(r_inv)
    adj = r_mat_inv.dot(adj).A
    
    adj = torch.from_numpy(adj).to(torch.float32)

    return adj

In [12]:
adj_n = adj_def(adj)
print(adj_n.shape)
adj_knn_q = adj_def(adj_knn_q)
adj_knn_v = adj_def(adj_knn_v)

torch.Size([150, 150])


In [13]:
max_q = training_set_q.max(axis=0)[0].max(axis=0)[0]
min_q = training_set_q.min(axis=0)[0].min(axis=0)[0]

max_v = training_set_v.max(axis=0)[0].max(axis=0)[0]
min_v = training_set_v.min(axis=0)[0].min(axis=0)[0]

max_q.shape

torch.Size([150])

In [14]:
# Define the normalization function
def norm(x,max_x,min_x):
    x0 = 2*(x-min_x)/(max_x-min_x)-1
    return x0

In [15]:
# Define the inverse normalization function
def r_norm(x0,max_x,min_x):
    x = (x0+1) * (max_x-min_x)/2 + min_x
    return x

In [16]:
# Normalize flow and speed data
training_set_q0 = norm(training_set_q,max_q,min_q)
test_set_q0 = norm(test_set_q,max_q,min_q)

training_set_v0 = norm(training_set_v,max_v,min_v)
test_set_v0 = norm(test_set_v,max_v,min_v)

In [17]:
# Training set input and output
training_input_q = training_set_q0[:,:time_leg,:]
training_output_q = training_set_q0[:,time_leg:time_leg+prediction_horizon,:]

training_input_v = training_set_v0[:,:time_leg,:]
training_output_v = training_set_v0[:,time_leg:time_leg+prediction_horizon,:]

# Test set input and output
test_input_q = test_set_q0[:,:time_leg,:]
test_output_q = test_set_q0[:,time_leg:time_leg+prediction_horizon,:]

test_input_v = test_set_v0[:,:time_leg,:]
test_output_v = test_set_v0[:,time_leg:time_leg+prediction_horizon,:]

# Ground truth
training_true_q = training_set_q[:,time_leg:time_leg+prediction_horizon,:]
training_true_v = training_set_v[:,time_leg:time_leg+prediction_horizon,:]

test_true_q = test_set_q[:,time_leg:time_leg+prediction_horizon,:]
test_true_v = test_set_v[:,time_leg:time_leg+prediction_horizon,:]

In [19]:
num_link = training_set_q0.shape[2]
num_link

150

In [24]:
# Define the GraphConvolution layer
class GraphConvolution(Module):

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.matmul(input, self.weight)
        output = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

# Define the LSTM unit
class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, return_sequence):
        super(LSTM, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.return_sequence = return_sequence
        #self.seq_length = seq_length
        
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, x):
        
        h_0 = Variable(torch.zeros(1, x.size(0), self.hidden_size))
        c_0 = Variable(torch.zeros(1, x.size(0), self.hidden_size))
        
        # Propagate input through LSTM
        output, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        if self.return_sequence==False:
            output = output[:,-1,:]
            output = output.contiguous().view(output.shape[0],-1)
        return output    

In [25]:
class GCN(Module):

    def __init__(self, nfea_gcn, nhid_gcn1, nhid_gcn2, nhid_gcn3, n_latent_gcn, num_link):
        super(GCN, self).__init__()
        
        self.nfea_gcn = nfea_gcn
        self.num_link = num_link
        
        self.nhid_gcn1 = nhid_gcn1
        self.nhid_gcn2 = nhid_gcn2
        self.nhid_gcn3 = nhid_gcn3
        
        self.gc11 = GraphConvolution(nfea_gcn, nhid_gcn1[0])
        self.gc12 = GraphConvolution(nhid_gcn1[0], nhid_gcn1[1])
        self.fc_n = nn.Linear(nhid_gcn1[2] * num_link, nhid_gcn1[2])
        
        
        self.gc21 = GraphConvolution(nfea_gcn, nhid_gcn2[0])
        self.gc22 = GraphConvolution(nhid_gcn2[0], nhid_gcn2[1])
        self.fc_m = nn.Linear(nhid_gcn2[2] * num_link, nhid_gcn2[2])
    
        self.gc31 = GraphConvolution(nfea_gcn, nhid_gcn3[0])
        self.gc32 = GraphConvolution(nhid_gcn3[0], nhid_gcn3[1])
        self.fc_s = nn.Linear(nhid_gcn3[2] * num_link, nhid_gcn3[2])
        
        self.fc = nn.Linear(nhid_gcn1[2], n_latent_gcn)
        
    def forward(self, x, adj_n, adj_c, adj_k):
        
        x1 = torch.tanh(self.gc11(x, adj_n))
        x1 = torch.tanh(self.gc12(x1, adj_n))
        x1 = x1.view(x1.shape[0],-1)
        x1 = torch.tanh(self.fc_n(x1))
        
        x2 = torch.tanh(self.gc21(x, adj_c))
        x2 = torch.tanh(self.gc22(x2, adj_c))
        x2 = x2.view(x2.shape[0],-1)
        x2 = torch.tanh(self.fc_m(x2))
        
        x3 = torch.tanh(self.gc31(x, adj_k))
        x3 = torch.tanh(self.gc32(x3, adj_k))
        x3 = x3.view(x3.shape[0],-1)
        x3 = torch.tanh(self.fc_s(x3))
        
        x0 = (x1 + x2 + x3)/3
        x0 = torch.tanh(self.fc(x0))
        
        return x0

In [32]:
class PIML(Module):
    
    def __init__(self, nfea_lstm, nfea_gcn, prediction_horizon, num_link,
                 nhid_lstm, nhid_gcn, n_latent_gcn, nhid_linear):
        super(PIML, self).__init__()
   
        self.nfea_lstm = nfea_lstm
        self.nfea_gcn = nfea_gcn
        self.prediction_horizon = prediction_horizon
        self.num_link = num_link
        self.nhid_linear = nhid_linear  

        # three-layer LSTM
        self.lstm1 = LSTM(nfea_lstm, nhid_lstm[0], return_sequence=True)
        self.lstm2 = LSTM(nhid_lstm[0], nhid_lstm[1], return_sequence=True)
        self.lstm3 = LSTM(nhid_lstm[1], nhid_lstm[2], return_sequence=False)
        
        # MGCN
        self.MGCN = GCN(nfea_gcn, nhid_gcn[0], nhid_gcn[1], nhid_gcn[2], n_latent_gcn, num_link)
        
        self.fc1 = nn.Linear(nhid_lstm[-1]+n_latent_gcn, nhid_linear[0])
        self.fc2 = nn.Linear(nhid_linear[0], prediction_horizon*nfea_lstm)
        
    def forward(self, fea, adj_n, adj_c, adj_k):
        
        x_t = self.lstm1(fea)
        x_t = self.lstm2(x_t)
        x_t = self.lstm3(x_t)

        x_s = fea.transpose(1,2)
        x_s = self.MGCN(x_s, adj_n, adj_c, adj_k)        

        x = torch.cat((x_t, x_s),dim=1)

        x = torch.tanh(self.fc1(x)) 
        x = torch.tanh(self.fc2(x))
        
        x = x.view(x.shape[0], self.prediction_horizon, self.nfea_lstm)

        return x

In [38]:
piml_q=PIML(nfea_lstm=num_link, nfea_gcn=time_leg, prediction_horizon=prediction_horizon,num_link = num_link,
            nhid_lstm=[256,256,256], nhid_gcn=[[32,32,32],[32,32,32],[32,32,32]], n_latent_gcn=512, nhid_linear=[512])

In [39]:
piml_v=PIML(nfea_lstm=num_link, nfea_gcn=time_leg, prediction_horizon=prediction_horizon,num_link = num_link,
            nhid_lstm=[256,256,256], nhid_gcn=[[32,32,32],[32,32,32],[32,32,32]], n_latent_gcn=512, nhid_linear=[512])

In [40]:
# Loss function
nn_loss = nn.MSELoss()
# Optimizers
optimizer = AdamW([
    {'params': piml_q.parameters(), 'lr': 0.001, 'weight_decay': 1e-2},
    {'params': piml_v.parameters(), 'lr': 0.001, 'weight_decay': 1e-2}
])
n_epochs = 2000
batch_size = 256
sample_size = training_set_q.shape[0]

In [41]:
def train(n_epochs, batch_size, sample_size):
    t0 = time.time()
    for epoch in range(n_epochs):
        
        idx = random.sample(range(0, sample_size),batch_size)
        
        input_q = training_input_q[idx]
        output_q = training_output_q[idx]
        
        input_v = training_input_v[idx]
        output_v = training_output_v[idx]
        
        optimizer.zero_grad()
        
        pred_q0 = piml_q(input_q, adj_n, adj_correlation_q, adj_knn_q)
        pred_v0 = piml_v(input_v, adj_n, adj_correlation_v, adj_knn_v)
        
        loss_q = nn_loss(output_q, pred_q0)
        loss_v = nn_loss(output_v, pred_v0)
        
        # S3
        pred_q = r_norm(pred_q0, max_q, min_q)
        pred_v = r_norm(pred_v0, max_v, min_v)
        v_cal = v_f / (1 + (pred_q/(pred_v*k_c))**mm) ** (2/mm)
        loss_fd_target = Variable(torch.Tensor(v_cal.shape).fill_(1.0), requires_grad=False)
        loss_fd = nn_loss(v_cal/pred_v, loss_fd_target)
        
        total_loss = loss_q + loss_v + 0.2 * loss_fd
            
        total_loss.backward()
        optimizer.step()
        
        print(
            "[Epoch %d/%d] [loss_q: %f] [loss_v: %f] [loss_fd: %f] [total_loss: %f] [time: %f]"
            % (epoch, n_epochs, loss_q.item(), loss_v.item(), loss_fd.item(), total_loss.item(), time.time()-t0)
        )

In [42]:
train(n_epochs = n_epochs, batch_size = batch_size, sample_size = sample_size)

[Epoch 0/2000] [loss_q: 0.218846] [loss_v: 0.332435] [loss_fd: 0.098412] [total_loss: 0.570964] [time: 2.304641]
[Epoch 1/2000] [loss_q: 0.197262] [loss_v: 0.312892] [loss_fd: 0.093555] [total_loss: 0.528865] [time: 4.372113]
[Epoch 2/2000] [loss_q: 0.169944] [loss_v: 0.235443] [loss_fd: 0.048844] [total_loss: 0.415156] [time: 6.453763]
[Epoch 3/2000] [loss_q: 0.126482] [loss_v: 0.169402] [loss_fd: 0.020569] [total_loss: 0.299997] [time: 8.334646]
[Epoch 4/2000] [loss_q: 0.097790] [loss_v: 0.175606] [loss_fd: 0.010359] [total_loss: 0.275467] [time: 10.301326]
[Epoch 5/2000] [loss_q: 0.085147] [loss_v: 0.147039] [loss_fd: 0.008378] [total_loss: 0.233861] [time: 11.882453]
[Epoch 6/2000] [loss_q: 0.074019] [loss_v: 0.143698] [loss_fd: 0.010973] [total_loss: 0.219912] [time: 13.104787]
[Epoch 7/2000] [loss_q: 0.069823] [loss_v: 0.137580] [loss_fd: 0.023337] [total_loss: 0.212070] [time: 14.131069]
[Epoch 8/2000] [loss_q: 0.060726] [loss_v: 0.118680] [loss_fd: 0.017595] [total_loss: 0.1829

[Epoch 72/2000] [loss_q: 0.023015] [loss_v: 0.047451] [loss_fd: 0.004723] [total_loss: 0.071412] [time: 65.485081]
[Epoch 73/2000] [loss_q: 0.022336] [loss_v: 0.047347] [loss_fd: 0.004867] [total_loss: 0.070657] [time: 66.257632]
[Epoch 74/2000] [loss_q: 0.021841] [loss_v: 0.047461] [loss_fd: 0.004717] [total_loss: 0.070245] [time: 67.032183]
[Epoch 75/2000] [loss_q: 0.022299] [loss_v: 0.046698] [loss_fd: 0.004728] [total_loss: 0.069943] [time: 67.789722]
[Epoch 76/2000] [loss_q: 0.022402] [loss_v: 0.047127] [loss_fd: 0.005130] [total_loss: 0.070555] [time: 68.564274]
[Epoch 77/2000] [loss_q: 0.022044] [loss_v: 0.047815] [loss_fd: 0.004720] [total_loss: 0.070803] [time: 69.363281]
[Epoch 78/2000] [loss_q: 0.020503] [loss_v: 0.042557] [loss_fd: 0.004415] [total_loss: 0.063943] [time: 70.158848]
[Epoch 79/2000] [loss_q: 0.022140] [loss_v: 0.047537] [loss_fd: 0.004608] [total_loss: 0.070598] [time: 70.901085]
[Epoch 80/2000] [loss_q: 0.023589] [loss_v: 0.047879] [loss_fd: 0.004918] [total

[Epoch 143/2000] [loss_q: 0.021235] [loss_v: 0.041246] [loss_fd: 0.004927] [total_loss: 0.063466] [time: 120.595840]
[Epoch 144/2000] [loss_q: 0.022662] [loss_v: 0.045488] [loss_fd: 0.005230] [total_loss: 0.069196] [time: 121.373393]
[Epoch 145/2000] [loss_q: 0.020880] [loss_v: 0.040208] [loss_fd: 0.004958] [total_loss: 0.062080] [time: 122.193977]
[Epoch 146/2000] [loss_q: 0.021704] [loss_v: 0.042516] [loss_fd: 0.004867] [total_loss: 0.065193] [time: 122.932503]
[Epoch 147/2000] [loss_q: 0.019850] [loss_v: 0.039467] [loss_fd: 0.004682] [total_loss: 0.060253] [time: 123.708055]
[Epoch 148/2000] [loss_q: 0.019705] [loss_v: 0.042278] [loss_fd: 0.004817] [total_loss: 0.062946] [time: 124.516630]
[Epoch 149/2000] [loss_q: 0.019947] [loss_v: 0.041769] [loss_fd: 0.004598] [total_loss: 0.062636] [time: 125.271167]
[Epoch 150/2000] [loss_q: 0.019859] [loss_v: 0.040051] [loss_fd: 0.004890] [total_loss: 0.060887] [time: 126.026706]
[Epoch 151/2000] [loss_q: 0.020514] [loss_v: 0.039117] [loss_fd:

[Epoch 214/2000] [loss_q: 0.018337] [loss_v: 0.032837] [loss_fd: 0.005095] [total_loss: 0.052193] [time: 175.846889]
[Epoch 215/2000] [loss_q: 0.018345] [loss_v: 0.034561] [loss_fd: 0.005213] [total_loss: 0.053948] [time: 176.696494]
[Epoch 216/2000] [loss_q: 0.017731] [loss_v: 0.034156] [loss_fd: 0.005582] [total_loss: 0.053003] [time: 177.466042]
[Epoch 217/2000] [loss_q: 0.017433] [loss_v: 0.034609] [loss_fd: 0.005182] [total_loss: 0.053078] [time: 178.289629]
[Epoch 218/2000] [loss_q: 0.018446] [loss_v: 0.035047] [loss_fd: 0.005483] [total_loss: 0.054589] [time: 179.047167]
[Epoch 219/2000] [loss_q: 0.018015] [loss_v: 0.034142] [loss_fd: 0.005172] [total_loss: 0.053191] [time: 179.781916]
[Epoch 220/2000] [loss_q: 0.017724] [loss_v: 0.035549] [loss_fd: 0.005101] [total_loss: 0.054293] [time: 180.611506]
[Epoch 221/2000] [loss_q: 0.017890] [loss_v: 0.035693] [loss_fd: 0.005638] [total_loss: 0.054710] [time: 181.389060]
[Epoch 222/2000] [loss_q: 0.017737] [loss_v: 0.034538] [loss_fd:

[Epoch 285/2000] [loss_q: 0.016901] [loss_v: 0.031454] [loss_fd: 0.005721] [total_loss: 0.049499] [time: 246.656916]
[Epoch 286/2000] [loss_q: 0.016561] [loss_v: 0.030578] [loss_fd: 0.005670] [total_loss: 0.048273] [time: 248.514240]
[Epoch 287/2000] [loss_q: 0.016088] [loss_v: 0.030805] [loss_fd: 0.005481] [total_loss: 0.047989] [time: 250.280498]
[Epoch 288/2000] [loss_q: 0.016906] [loss_v: 0.031626] [loss_fd: 0.005632] [total_loss: 0.049659] [time: 251.904654]
[Epoch 289/2000] [loss_q: 0.016494] [loss_v: 0.031011] [loss_fd: 0.005772] [total_loss: 0.048660] [time: 253.568836]
[Epoch 290/2000] [loss_q: 0.016073] [loss_v: 0.030593] [loss_fd: 0.005412] [total_loss: 0.047749] [time: 255.263551]
[Epoch 291/2000] [loss_q: 0.016927] [loss_v: 0.031270] [loss_fd: 0.005391] [total_loss: 0.049275] [time: 256.769622]
[Epoch 292/2000] [loss_q: 0.016348] [loss_v: 0.031392] [loss_fd: 0.005373] [total_loss: 0.048814] [time: 258.367764]
[Epoch 293/2000] [loss_q: 0.016314] [loss_v: 0.031284] [loss_fd:

[Epoch 356/2000] [loss_q: 0.015485] [loss_v: 0.028434] [loss_fd: 0.005815] [total_loss: 0.045082] [time: 364.397766]
[Epoch 357/2000] [loss_q: 0.014579] [loss_v: 0.028206] [loss_fd: 0.006051] [total_loss: 0.043996] [time: 366.116986]
[Epoch 358/2000] [loss_q: 0.015994] [loss_v: 0.027780] [loss_fd: 0.006222] [total_loss: 0.045019] [time: 367.925274]
[Epoch 359/2000] [loss_q: 0.015549] [loss_v: 0.027936] [loss_fd: 0.005772] [total_loss: 0.044640] [time: 369.685529]
[Epoch 360/2000] [loss_q: 0.015824] [loss_v: 0.028583] [loss_fd: 0.006266] [total_loss: 0.045660] [time: 371.492813]
[Epoch 361/2000] [loss_q: 0.014763] [loss_v: 0.027071] [loss_fd: 0.005980] [total_loss: 0.043031] [time: 373.230233]
[Epoch 362/2000] [loss_q: 0.015047] [loss_v: 0.027479] [loss_fd: 0.005680] [total_loss: 0.043662] [time: 374.865397]
[Epoch 363/2000] [loss_q: 0.015206] [loss_v: 0.028234] [loss_fd: 0.006558] [total_loss: 0.044752] [time: 376.447528]
[Epoch 364/2000] [loss_q: 0.014330] [loss_v: 0.026419] [loss_fd:

[Epoch 427/2000] [loss_q: 0.014106] [loss_v: 0.023873] [loss_fd: 0.006213] [total_loss: 0.039222] [time: 480.589375]
[Epoch 428/2000] [loss_q: 0.014729] [loss_v: 0.024682] [loss_fd: 0.005870] [total_loss: 0.040585] [time: 482.282104]
[Epoch 429/2000] [loss_q: 0.014825] [loss_v: 0.025181] [loss_fd: 0.006818] [total_loss: 0.041370] [time: 484.050363]
[Epoch 430/2000] [loss_q: 0.015013] [loss_v: 0.025943] [loss_fd: 0.006209] [total_loss: 0.042199] [time: 485.882674]
[Epoch 431/2000] [loss_q: 0.013930] [loss_v: 0.024675] [loss_fd: 0.006425] [total_loss: 0.039890] [time: 487.734986]
[Epoch 432/2000] [loss_q: 0.014744] [loss_v: 0.025905] [loss_fd: 0.006490] [total_loss: 0.041947] [time: 489.652554]
[Epoch 433/2000] [loss_q: 0.014454] [loss_v: 0.024634] [loss_fd: 0.006053] [total_loss: 0.040299] [time: 491.490858]
[Epoch 434/2000] [loss_q: 0.015142] [loss_v: 0.025227] [loss_fd: 0.006125] [total_loss: 0.041594] [time: 493.034958]
[Epoch 435/2000] [loss_q: 0.014471] [loss_v: 0.026541] [loss_fd:

[Epoch 498/2000] [loss_q: 0.013513] [loss_v: 0.024050] [loss_fd: 0.006329] [total_loss: 0.038829] [time: 602.632206]
[Epoch 499/2000] [loss_q: 0.013244] [loss_v: 0.021722] [loss_fd: 0.006385] [total_loss: 0.036243] [time: 604.394461]
[Epoch 500/2000] [loss_q: 0.013645] [loss_v: 0.022145] [loss_fd: 0.006388] [total_loss: 0.037068] [time: 606.154737]
[Epoch 501/2000] [loss_q: 0.013929] [loss_v: 0.023427] [loss_fd: 0.006914] [total_loss: 0.038739] [time: 607.818900]
[Epoch 502/2000] [loss_q: 0.014031] [loss_v: 0.023356] [loss_fd: 0.006259] [total_loss: 0.038639] [time: 609.656317]
[Epoch 503/2000] [loss_q: 0.013751] [loss_v: 0.023106] [loss_fd: 0.006387] [total_loss: 0.038134] [time: 611.473614]
[Epoch 504/2000] [loss_q: 0.014459] [loss_v: 0.024002] [loss_fd: 0.006504] [total_loss: 0.039762] [time: 613.320925]
[Epoch 505/2000] [loss_q: 0.013827] [loss_v: 0.024680] [loss_fd: 0.006398] [total_loss: 0.039787] [time: 614.947083]
[Epoch 506/2000] [loss_q: 0.013542] [loss_v: 0.023017] [loss_fd:

[Epoch 569/2000] [loss_q: 0.012726] [loss_v: 0.022009] [loss_fd: 0.006325] [total_loss: 0.035999] [time: 721.172899]
[Epoch 570/2000] [loss_q: 0.013056] [loss_v: 0.021791] [loss_fd: 0.006284] [total_loss: 0.036104] [time: 722.914535]
[Epoch 571/2000] [loss_q: 0.013647] [loss_v: 0.022520] [loss_fd: 0.006900] [total_loss: 0.037547] [time: 724.811327]
[Epoch 572/2000] [loss_q: 0.013295] [loss_v: 0.022170] [loss_fd: 0.006803] [total_loss: 0.036826] [time: 726.519544]
[Epoch 573/2000] [loss_q: 0.013198] [loss_v: 0.022136] [loss_fd: 0.006620] [total_loss: 0.036658] [time: 728.185707]
[Epoch 574/2000] [loss_q: 0.013130] [loss_v: 0.022062] [loss_fd: 0.006493] [total_loss: 0.036491] [time: 729.701784]
[Epoch 575/2000] [loss_q: 0.012905] [loss_v: 0.021610] [loss_fd: 0.006501] [total_loss: 0.035815] [time: 731.179523]
[Epoch 576/2000] [loss_q: 0.012910] [loss_v: 0.021392] [loss_fd: 0.006634] [total_loss: 0.035629] [time: 732.598532]
[Epoch 577/2000] [loss_q: 0.013545] [loss_v: 0.022456] [loss_fd:

[Epoch 640/2000] [loss_q: 0.012347] [loss_v: 0.020830] [loss_fd: 0.006701] [total_loss: 0.034517] [time: 792.312705]
[Epoch 641/2000] [loss_q: 0.012517] [loss_v: 0.020070] [loss_fd: 0.006372] [total_loss: 0.033862] [time: 793.091259]
[Epoch 642/2000] [loss_q: 0.012169] [loss_v: 0.020564] [loss_fd: 0.006553] [total_loss: 0.034043] [time: 793.850800]
[Epoch 643/2000] [loss_q: 0.012449] [loss_v: 0.021045] [loss_fd: 0.006611] [total_loss: 0.034816] [time: 794.625599]
[Epoch 644/2000] [loss_q: 0.012852] [loss_v: 0.020446] [loss_fd: 0.006457] [total_loss: 0.034589] [time: 795.416161]
[Epoch 645/2000] [loss_q: 0.012968] [loss_v: 0.021546] [loss_fd: 0.007332] [total_loss: 0.035980] [time: 796.191713]
[Epoch 646/2000] [loss_q: 0.012745] [loss_v: 0.021691] [loss_fd: 0.006566] [total_loss: 0.035749] [time: 797.002290]
[Epoch 647/2000] [loss_q: 0.013396] [loss_v: 0.021751] [loss_fd: 0.007127] [total_loss: 0.036573] [time: 797.741817]
[Epoch 648/2000] [loss_q: 0.012225] [loss_v: 0.020738] [loss_fd:

[Epoch 711/2000] [loss_q: 0.012561] [loss_v: 0.020475] [loss_fd: 0.006784] [total_loss: 0.034393] [time: 847.528361]
[Epoch 712/2000] [loss_q: 0.011882] [loss_v: 0.020343] [loss_fd: 0.007031] [total_loss: 0.033631] [time: 848.370962]
[Epoch 713/2000] [loss_q: 0.011905] [loss_v: 0.019728] [loss_fd: 0.006200] [total_loss: 0.032873] [time: 849.084469]
[Epoch 714/2000] [loss_q: 0.012300] [loss_v: 0.020575] [loss_fd: 0.006471] [total_loss: 0.034169] [time: 849.896331]
[Epoch 715/2000] [loss_q: 0.011993] [loss_v: 0.019204] [loss_fd: 0.006847] [total_loss: 0.032566] [time: 850.637859]
[Epoch 716/2000] [loss_q: 0.012504] [loss_v: 0.020037] [loss_fd: 0.007104] [total_loss: 0.033962] [time: 851.389394]
[Epoch 717/2000] [loss_q: 0.012569] [loss_v: 0.020016] [loss_fd: 0.006695] [total_loss: 0.033924] [time: 852.341070]
[Epoch 718/2000] [loss_q: 0.012603] [loss_v: 0.019238] [loss_fd: 0.006430] [total_loss: 0.033127] [time: 853.437851]
[Epoch 719/2000] [loss_q: 0.012142] [loss_v: 0.020478] [loss_fd:

[Epoch 782/2000] [loss_q: 0.011589] [loss_v: 0.018765] [loss_fd: 0.006233] [total_loss: 0.031601] [time: 966.085850]
[Epoch 783/2000] [loss_q: 0.012323] [loss_v: 0.019706] [loss_fd: 0.007569] [total_loss: 0.033543] [time: 967.993675]
[Epoch 784/2000] [loss_q: 0.011796] [loss_v: 0.019309] [loss_fd: 0.006716] [total_loss: 0.032448] [time: 969.708896]
[Epoch 785/2000] [loss_q: 0.011668] [loss_v: 0.019650] [loss_fd: 0.007141] [total_loss: 0.032747] [time: 971.214968]
[Epoch 786/2000] [loss_q: 0.011755] [loss_v: 0.019203] [loss_fd: 0.006601] [total_loss: 0.032278] [time: 972.504893]
[Epoch 787/2000] [loss_q: 0.012497] [loss_v: 0.020219] [loss_fd: 0.007055] [total_loss: 0.034127] [time: 973.818821]
[Epoch 788/2000] [loss_q: 0.011770] [loss_v: 0.018510] [loss_fd: 0.006643] [total_loss: 0.031609] [time: 974.982649]
[Epoch 789/2000] [loss_q: 0.012275] [loss_v: 0.019489] [loss_fd: 0.006505] [total_loss: 0.033065] [time: 976.032396]
[Epoch 790/2000] [loss_q: 0.011767] [loss_v: 0.018981] [loss_fd:

[Epoch 852/2000] [loss_q: 0.011463] [loss_v: 0.018800] [loss_fd: 0.006793] [total_loss: 0.031622] [time: 1024.829620]
[Epoch 853/2000] [loss_q: 0.010932] [loss_v: 0.018097] [loss_fd: 0.007028] [total_loss: 0.030434] [time: 1025.767287]
[Epoch 854/2000] [loss_q: 0.011690] [loss_v: 0.018790] [loss_fd: 0.007233] [total_loss: 0.031927] [time: 1027.014174]
[Epoch 855/2000] [loss_q: 0.011914] [loss_v: 0.018927] [loss_fd: 0.007127] [total_loss: 0.032267] [time: 1028.465954]
[Epoch 856/2000] [loss_q: 0.011388] [loss_v: 0.017951] [loss_fd: 0.007131] [total_loss: 0.030765] [time: 1030.080101]
[Epoch 857/2000] [loss_q: 0.010986] [loss_v: 0.017955] [loss_fd: 0.006525] [total_loss: 0.030246] [time: 1031.555153]
[Epoch 858/2000] [loss_q: 0.011591] [loss_v: 0.018777] [loss_fd: 0.006853] [total_loss: 0.031739] [time: 1032.875092]
[Epoch 859/2000] [loss_q: 0.011219] [loss_v: 0.018776] [loss_fd: 0.007241] [total_loss: 0.031443] [time: 1034.267082]
[Epoch 860/2000] [loss_q: 0.010930] [loss_v: 0.017363] [

[Epoch 922/2000] [loss_q: 0.011763] [loss_v: 0.019395] [loss_fd: 0.007051] [total_loss: 0.032568] [time: 1139.392675]
[Epoch 923/2000] [loss_q: 0.011012] [loss_v: 0.017462] [loss_fd: 0.006504] [total_loss: 0.029774] [time: 1141.154773]
[Epoch 924/2000] [loss_q: 0.011401] [loss_v: 0.018292] [loss_fd: 0.006759] [total_loss: 0.031045] [time: 1143.015097]
[Epoch 925/2000] [loss_q: 0.011433] [loss_v: 0.017423] [loss_fd: 0.006818] [total_loss: 0.030220] [time: 1144.708300]
[Epoch 926/2000] [loss_q: 0.011925] [loss_v: 0.019214] [loss_fd: 0.007515] [total_loss: 0.032642] [time: 1146.143342]
[Epoch 927/2000] [loss_q: 0.011077] [loss_v: 0.018312] [loss_fd: 0.006908] [total_loss: 0.030771] [time: 1147.854540]
[Epoch 928/2000] [loss_q: 0.010818] [loss_v: 0.017501] [loss_fd: 0.006737] [total_loss: 0.029666] [time: 1149.502710]
[Epoch 929/2000] [loss_q: 0.012253] [loss_v: 0.019395] [loss_fd: 0.007324] [total_loss: 0.033113] [time: 1151.158894]
[Epoch 930/2000] [loss_q: 0.011732] [loss_v: 0.019790] [

[Epoch 992/2000] [loss_q: 0.011357] [loss_v: 0.017767] [loss_fd: 0.007066] [total_loss: 0.030538] [time: 1259.126458]
[Epoch 993/2000] [loss_q: 0.011068] [loss_v: 0.017021] [loss_fd: 0.006428] [total_loss: 0.029376] [time: 1260.965768]
[Epoch 994/2000] [loss_q: 0.010920] [loss_v: 0.017272] [loss_fd: 0.006949] [total_loss: 0.029582] [time: 1262.861117]
[Epoch 995/2000] [loss_q: 0.011108] [loss_v: 0.017705] [loss_fd: 0.007300] [total_loss: 0.030273] [time: 1264.648912]
[Epoch 996/2000] [loss_q: 0.011548] [loss_v: 0.018370] [loss_fd: 0.007172] [total_loss: 0.031353] [time: 1266.518247]
[Epoch 997/2000] [loss_q: 0.010880] [loss_v: 0.017427] [loss_fd: 0.007146] [total_loss: 0.029737] [time: 1268.375566]
[Epoch 998/2000] [loss_q: 0.010699] [loss_v: 0.016141] [loss_fd: 0.006656] [total_loss: 0.028172] [time: 1270.224885]
[Epoch 999/2000] [loss_q: 0.010964] [loss_v: 0.017953] [loss_fd: 0.007230] [total_loss: 0.030362] [time: 1272.027168]
[Epoch 1000/2000] [loss_q: 0.011609] [loss_v: 0.017731] 

[Epoch 1061/2000] [loss_q: 0.010484] [loss_v: 0.017222] [loss_fd: 0.007063] [total_loss: 0.029118] [time: 1371.016470]
[Epoch 1062/2000] [loss_q: 0.010686] [loss_v: 0.016760] [loss_fd: 0.007227] [total_loss: 0.028891] [time: 1371.856065]
[Epoch 1063/2000] [loss_q: 0.010803] [loss_v: 0.017025] [loss_fd: 0.006867] [total_loss: 0.029202] [time: 1372.636620]
[Epoch 1064/2000] [loss_q: 0.010693] [loss_v: 0.016839] [loss_fd: 0.006812] [total_loss: 0.028895] [time: 1373.473215]
[Epoch 1065/2000] [loss_q: 0.011055] [loss_v: 0.017126] [loss_fd: 0.007312] [total_loss: 0.029643] [time: 1374.224751]
[Epoch 1066/2000] [loss_q: 0.011219] [loss_v: 0.017895] [loss_fd: 0.006984] [total_loss: 0.030511] [time: 1375.004375]
[Epoch 1067/2000] [loss_q: 0.010724] [loss_v: 0.016469] [loss_fd: 0.006910] [total_loss: 0.028575] [time: 1375.798940]
[Epoch 1068/2000] [loss_q: 0.011260] [loss_v: 0.018505] [loss_fd: 0.007635] [total_loss: 0.031292] [time: 1376.544471]
[Epoch 1069/2000] [loss_q: 0.011104] [loss_v: 0.

[Epoch 1130/2000] [loss_q: 0.011009] [loss_v: 0.016825] [loss_fd: 0.006802] [total_loss: 0.029194] [time: 1433.654999]
[Epoch 1131/2000] [loss_q: 0.010709] [loss_v: 0.017432] [loss_fd: 0.007297] [total_loss: 0.029600] [time: 1435.096022]
[Epoch 1132/2000] [loss_q: 0.010591] [loss_v: 0.016117] [loss_fd: 0.006970] [total_loss: 0.028102] [time: 1436.506026]
[Epoch 1133/2000] [loss_q: 0.010636] [loss_v: 0.016756] [loss_fd: 0.006886] [total_loss: 0.028769] [time: 1437.913641]
[Epoch 1134/2000] [loss_q: 0.010953] [loss_v: 0.017989] [loss_fd: 0.007649] [total_loss: 0.030472] [time: 1439.503775]
[Epoch 1135/2000] [loss_q: 0.011202] [loss_v: 0.017293] [loss_fd: 0.007495] [total_loss: 0.029993] [time: 1441.056881]
[Epoch 1136/2000] [loss_q: 0.010375] [loss_v: 0.016085] [loss_fd: 0.006541] [total_loss: 0.027768] [time: 1442.586969]
[Epoch 1137/2000] [loss_q: 0.010905] [loss_v: 0.016721] [loss_fd: 0.007407] [total_loss: 0.029107] [time: 1444.272171]
[Epoch 1138/2000] [loss_q: 0.010124] [loss_v: 0.

[Epoch 1199/2000] [loss_q: 0.010919] [loss_v: 0.016974] [loss_fd: 0.007120] [total_loss: 0.029316] [time: 1544.697109]
[Epoch 1200/2000] [loss_q: 0.010791] [loss_v: 0.017060] [loss_fd: 0.007359] [total_loss: 0.029322] [time: 1546.278233]
[Epoch 1201/2000] [loss_q: 0.010553] [loss_v: 0.016456] [loss_fd: 0.006543] [total_loss: 0.028317] [time: 1547.799317]
[Epoch 1202/2000] [loss_q: 0.010858] [loss_v: 0.016783] [loss_fd: 0.007298] [total_loss: 0.029101] [time: 1549.560572]
[Epoch 1203/2000] [loss_q: 0.010924] [loss_v: 0.016581] [loss_fd: 0.007645] [total_loss: 0.029034] [time: 1551.345496]
[Epoch 1204/2000] [loss_q: 0.010415] [loss_v: 0.016461] [loss_fd: 0.007116] [total_loss: 0.028298] [time: 1553.272871]
[Epoch 1205/2000] [loss_q: 0.010821] [loss_v: 0.016162] [loss_fd: 0.006801] [total_loss: 0.028343] [time: 1555.131190]
[Epoch 1206/2000] [loss_q: 0.010530] [loss_v: 0.016452] [loss_fd: 0.007175] [total_loss: 0.028417] [time: 1557.001550]
[Epoch 1207/2000] [loss_q: 0.011130] [loss_v: 0.

[Epoch 1268/2000] [loss_q: 0.010506] [loss_v: 0.016544] [loss_fd: 0.006857] [total_loss: 0.028421] [time: 1661.007895]
[Epoch 1269/2000] [loss_q: 0.010097] [loss_v: 0.015484] [loss_fd: 0.007010] [total_loss: 0.026983] [time: 1662.867224]
[Epoch 1270/2000] [loss_q: 0.010304] [loss_v: 0.015604] [loss_fd: 0.007208] [total_loss: 0.027350] [time: 1664.451347]
[Epoch 1271/2000] [loss_q: 0.010496] [loss_v: 0.016427] [loss_fd: 0.007147] [total_loss: 0.028353] [time: 1665.958548]
[Epoch 1272/2000] [loss_q: 0.010783] [loss_v: 0.016385] [loss_fd: 0.007421] [total_loss: 0.028652] [time: 1667.592712]
[Epoch 1273/2000] [loss_q: 0.010283] [loss_v: 0.016613] [loss_fd: 0.007683] [total_loss: 0.028433] [time: 1669.310936]
[Epoch 1274/2000] [loss_q: 0.010443] [loss_v: 0.016199] [loss_fd: 0.006899] [total_loss: 0.028022] [time: 1671.173266]
[Epoch 1275/2000] [loss_q: 0.010140] [loss_v: 0.015364] [loss_fd: 0.007019] [total_loss: 0.026907] [time: 1673.020578]
[Epoch 1276/2000] [loss_q: 0.010116] [loss_v: 0.

[Epoch 1337/2000] [loss_q: 0.010548] [loss_v: 0.016202] [loss_fd: 0.007004] [total_loss: 0.028151] [time: 1751.594868]
[Epoch 1338/2000] [loss_q: 0.010401] [loss_v: 0.016273] [loss_fd: 0.007290] [total_loss: 0.028132] [time: 1752.330391]
[Epoch 1339/2000] [loss_q: 0.010479] [loss_v: 0.015848] [loss_fd: 0.007813] [total_loss: 0.027889] [time: 1753.099940]
[Epoch 1340/2000] [loss_q: 0.010483] [loss_v: 0.016255] [loss_fd: 0.006953] [total_loss: 0.028128] [time: 1753.867486]
[Epoch 1341/2000] [loss_q: 0.009661] [loss_v: 0.015435] [loss_fd: 0.007094] [total_loss: 0.026514] [time: 1754.633992]
[Epoch 1342/2000] [loss_q: 0.010146] [loss_v: 0.015400] [loss_fd: 0.006819] [total_loss: 0.026909] [time: 1755.395535]
[Epoch 1343/2000] [loss_q: 0.009986] [loss_v: 0.015468] [loss_fd: 0.007147] [total_loss: 0.026883] [time: 1756.196104]
[Epoch 1344/2000] [loss_q: 0.010363] [loss_v: 0.016446] [loss_fd: 0.006909] [total_loss: 0.028191] [time: 1756.937785]
[Epoch 1345/2000] [loss_q: 0.010275] [loss_v: 0.

[Epoch 1406/2000] [loss_q: 0.010289] [loss_v: 0.015590] [loss_fd: 0.007193] [total_loss: 0.027317] [time: 1805.059707]
[Epoch 1407/2000] [loss_q: 0.009988] [loss_v: 0.015436] [loss_fd: 0.006900] [total_loss: 0.026805] [time: 1805.815244]
[Epoch 1408/2000] [loss_q: 0.009999] [loss_v: 0.015317] [loss_fd: 0.006910] [total_loss: 0.026698] [time: 1806.589796]
[Epoch 1409/2000] [loss_q: 0.009723] [loss_v: 0.014575] [loss_fd: 0.007150] [total_loss: 0.025728] [time: 1807.394368]
[Epoch 1410/2000] [loss_q: 0.009524] [loss_v: 0.014834] [loss_fd: 0.006959] [total_loss: 0.025750] [time: 1808.129892]
[Epoch 1411/2000] [loss_q: 0.010486] [loss_v: 0.015739] [loss_fd: 0.007290] [total_loss: 0.027684] [time: 1808.901441]
[Epoch 1412/2000] [loss_q: 0.010375] [loss_v: 0.015681] [loss_fd: 0.007370] [total_loss: 0.027530] [time: 1809.669988]
[Epoch 1413/2000] [loss_q: 0.010264] [loss_v: 0.015436] [loss_fd: 0.007160] [total_loss: 0.027132] [time: 1810.463552]
[Epoch 1414/2000] [loss_q: 0.009597] [loss_v: 0.

[Epoch 1475/2000] [loss_q: 0.010086] [loss_v: 0.014522] [loss_fd: 0.006866] [total_loss: 0.025981] [time: 1915.580724]
[Epoch 1476/2000] [loss_q: 0.009848] [loss_v: 0.015564] [loss_fd: 0.006781] [total_loss: 0.026768] [time: 1917.468068]
[Epoch 1477/2000] [loss_q: 0.009999] [loss_v: 0.015098] [loss_fd: 0.007011] [total_loss: 0.026500] [time: 1919.196005]
[Epoch 1478/2000] [loss_q: 0.009682] [loss_v: 0.015216] [loss_fd: 0.007050] [total_loss: 0.026308] [time: 1920.866193]
[Epoch 1479/2000] [loss_q: 0.010471] [loss_v: 0.015633] [loss_fd: 0.007023] [total_loss: 0.027509] [time: 1922.253184]
[Epoch 1480/2000] [loss_q: 0.010120] [loss_v: 0.014918] [loss_fd: 0.006635] [total_loss: 0.026365] [time: 1923.981411]
[Epoch 1481/2000] [loss_q: 0.009752] [loss_v: 0.014755] [loss_fd: 0.007028] [total_loss: 0.025912] [time: 1925.610569]
[Epoch 1482/2000] [loss_q: 0.009898] [loss_v: 0.014739] [loss_fd: 0.006923] [total_loss: 0.026022] [time: 1927.200702]
[Epoch 1483/2000] [loss_q: 0.009903] [loss_v: 0.

[Epoch 1544/2000] [loss_q: 0.009825] [loss_v: 0.014407] [loss_fd: 0.006885] [total_loss: 0.025609] [time: 2008.439425]
[Epoch 1545/2000] [loss_q: 0.009863] [loss_v: 0.014922] [loss_fd: 0.007451] [total_loss: 0.026275] [time: 2010.032557]
[Epoch 1546/2000] [loss_q: 0.009375] [loss_v: 0.014313] [loss_fd: 0.007164] [total_loss: 0.025120] [time: 2011.331485]
[Epoch 1547/2000] [loss_q: 0.009340] [loss_v: 0.013992] [loss_fd: 0.007091] [total_loss: 0.024749] [time: 2012.655838]
[Epoch 1548/2000] [loss_q: 0.010299] [loss_v: 0.014540] [loss_fd: 0.007305] [total_loss: 0.026300] [time: 2014.127886]
[Epoch 1549/2000] [loss_q: 0.010187] [loss_v: 0.015715] [loss_fd: 0.007526] [total_loss: 0.027406] [time: 2015.726025]
[Epoch 1550/2000] [loss_q: 0.009444] [loss_v: 0.014682] [loss_fd: 0.006976] [total_loss: 0.025521] [time: 2017.258115]
[Epoch 1551/2000] [loss_q: 0.009554] [loss_v: 0.014939] [loss_fd: 0.006920] [total_loss: 0.025877] [time: 2018.818226]
[Epoch 1552/2000] [loss_q: 0.009821] [loss_v: 0.

[Epoch 1613/2000] [loss_q: 0.009791] [loss_v: 0.014536] [loss_fd: 0.007136] [total_loss: 0.025754] [time: 2124.598446]
[Epoch 1614/2000] [loss_q: 0.010297] [loss_v: 0.016008] [loss_fd: 0.007655] [total_loss: 0.027836] [time: 2126.352694]
[Epoch 1615/2000] [loss_q: 0.010019] [loss_v: 0.014699] [loss_fd: 0.007049] [total_loss: 0.026128] [time: 2127.999868]
[Epoch 1616/2000] [loss_q: 0.009609] [loss_v: 0.014815] [loss_fd: 0.007110] [total_loss: 0.025845] [time: 2129.710084]
[Epoch 1617/2000] [loss_q: 0.009899] [loss_v: 0.015067] [loss_fd: 0.007318] [total_loss: 0.026430] [time: 2131.444322]
[Epoch 1618/2000] [loss_q: 0.009648] [loss_v: 0.014002] [loss_fd: 0.007003] [total_loss: 0.025050] [time: 2133.219611]
[Epoch 1619/2000] [loss_q: 0.009813] [loss_v: 0.015121] [loss_fd: 0.007240] [total_loss: 0.026381] [time: 2135.025874]
[Epoch 1620/2000] [loss_q: 0.009768] [loss_v: 0.014679] [loss_fd: 0.007469] [total_loss: 0.025941] [time: 2136.792130]
[Epoch 1621/2000] [loss_q: 0.009763] [loss_v: 0.

[Epoch 1682/2000] [loss_q: 0.009399] [loss_v: 0.013619] [loss_fd: 0.006979] [total_loss: 0.024414] [time: 2241.943086]
[Epoch 1683/2000] [loss_q: 0.009836] [loss_v: 0.014317] [loss_fd: 0.007400] [total_loss: 0.025634] [time: 2243.525215]
[Epoch 1684/2000] [loss_q: 0.009630] [loss_v: 0.014919] [loss_fd: 0.007014] [total_loss: 0.025952] [time: 2245.345510]
[Epoch 1685/2000] [loss_q: 0.009545] [loss_v: 0.014176] [loss_fd: 0.007044] [total_loss: 0.025130] [time: 2247.086747]
[Epoch 1686/2000] [loss_q: 0.010283] [loss_v: 0.014699] [loss_fd: 0.007461] [total_loss: 0.026475] [time: 2248.835994]
[Epoch 1687/2000] [loss_q: 0.009429] [loss_v: 0.013831] [loss_fd: 0.006925] [total_loss: 0.024645] [time: 2250.696320]
[Epoch 1688/2000] [loss_q: 0.009651] [loss_v: 0.014262] [loss_fd: 0.007128] [total_loss: 0.025339] [time: 2252.587664]
[Epoch 1689/2000] [loss_q: 0.009386] [loss_v: 0.014102] [loss_fd: 0.006702] [total_loss: 0.024828] [time: 2254.487015]
[Epoch 1690/2000] [loss_q: 0.009840] [loss_v: 0.

[Epoch 1751/2000] [loss_q: 0.009547] [loss_v: 0.014353] [loss_fd: 0.007308] [total_loss: 0.025361] [time: 2359.702956]
[Epoch 1752/2000] [loss_q: 0.009632] [loss_v: 0.014573] [loss_fd: 0.007554] [total_loss: 0.025716] [time: 2361.355130]
[Epoch 1753/2000] [loss_q: 0.009343] [loss_v: 0.013867] [loss_fd: 0.007414] [total_loss: 0.024693] [time: 2363.104405]
[Epoch 1754/2000] [loss_q: 0.009150] [loss_v: 0.014419] [loss_fd: 0.006881] [total_loss: 0.024946] [time: 2364.797580]
[Epoch 1755/2000] [loss_q: 0.009465] [loss_v: 0.013748] [loss_fd: 0.006934] [total_loss: 0.024600] [time: 2366.523812]
[Epoch 1756/2000] [loss_q: 0.009383] [loss_v: 0.014238] [loss_fd: 0.006778] [total_loss: 0.024976] [time: 2368.479227]
[Epoch 1757/2000] [loss_q: 0.009519] [loss_v: 0.014257] [loss_fd: 0.007375] [total_loss: 0.025251] [time: 2370.335524]
[Epoch 1758/2000] [loss_q: 0.009560] [loss_v: 0.014228] [loss_fd: 0.007141] [total_loss: 0.025216] [time: 2372.003279]
[Epoch 1759/2000] [loss_q: 0.009333] [loss_v: 0.

[Epoch 1820/2000] [loss_q: 0.010081] [loss_v: 0.015026] [loss_fd: 0.007663] [total_loss: 0.026640] [time: 2475.456562]
[Epoch 1821/2000] [loss_q: 0.009248] [loss_v: 0.013960] [loss_fd: 0.007308] [total_loss: 0.024670] [time: 2477.019674]
[Epoch 1822/2000] [loss_q: 0.008940] [loss_v: 0.013028] [loss_fd: 0.006803] [total_loss: 0.023328] [time: 2478.607807]
[Epoch 1823/2000] [loss_q: 0.009363] [loss_v: 0.013808] [loss_fd: 0.006952] [total_loss: 0.024562] [time: 2480.407085]
[Epoch 1824/2000] [loss_q: 0.009507] [loss_v: 0.013991] [loss_fd: 0.007347] [total_loss: 0.024967] [time: 2482.129309]
[Epoch 1825/2000] [loss_q: 0.009479] [loss_v: 0.013640] [loss_fd: 0.006661] [total_loss: 0.024451] [time: 2483.842276]
[Epoch 1826/2000] [loss_q: 0.009791] [loss_v: 0.014301] [loss_fd: 0.007389] [total_loss: 0.025570] [time: 2485.501457]
[Epoch 1827/2000] [loss_q: 0.009267] [loss_v: 0.013107] [loss_fd: 0.006867] [total_loss: 0.023748] [time: 2487.249704]
[Epoch 1828/2000] [loss_q: 0.009555] [loss_v: 0.

[Epoch 1889/2000] [loss_q: 0.009355] [loss_v: 0.013583] [loss_fd: 0.007415] [total_loss: 0.024421] [time: 2582.655919]
[Epoch 1890/2000] [loss_q: 0.009305] [loss_v: 0.013670] [loss_fd: 0.006945] [total_loss: 0.024364] [time: 2583.408454]
[Epoch 1891/2000] [loss_q: 0.009613] [loss_v: 0.014286] [loss_fd: 0.007459] [total_loss: 0.025391] [time: 2584.185007]
[Epoch 1892/2000] [loss_q: 0.008887] [loss_v: 0.013454] [loss_fd: 0.007091] [total_loss: 0.023759] [time: 2584.980574]
[Epoch 1893/2000] [loss_q: 0.009390] [loss_v: 0.013521] [loss_fd: 0.006936] [total_loss: 0.024299] [time: 2585.745118]
[Epoch 1894/2000] [loss_q: 0.009119] [loss_v: 0.013652] [loss_fd: 0.007110] [total_loss: 0.024193] [time: 2586.502657]
[Epoch 1895/2000] [loss_q: 0.009832] [loss_v: 0.014052] [loss_fd: 0.006987] [total_loss: 0.025282] [time: 2587.290218]
[Epoch 1896/2000] [loss_q: 0.009032] [loss_v: 0.013009] [loss_fd: 0.007290] [total_loss: 0.023499] [time: 2588.044032]
[Epoch 1897/2000] [loss_q: 0.009341] [loss_v: 0.

[Epoch 1958/2000] [loss_q: 0.009371] [loss_v: 0.014424] [loss_fd: 0.007325] [total_loss: 0.025260] [time: 2635.932638]
[Epoch 1959/2000] [loss_q: 0.009629] [loss_v: 0.013366] [loss_fd: 0.007073] [total_loss: 0.024409] [time: 2636.706189]
[Epoch 1960/2000] [loss_q: 0.009291] [loss_v: 0.013586] [loss_fd: 0.007223] [total_loss: 0.024322] [time: 2637.471898]
[Epoch 1961/2000] [loss_q: 0.009731] [loss_v: 0.014080] [loss_fd: 0.007520] [total_loss: 0.025315] [time: 2638.243447]
[Epoch 1962/2000] [loss_q: 0.009395] [loss_v: 0.013691] [loss_fd: 0.006758] [total_loss: 0.024438] [time: 2639.012995]
[Epoch 1963/2000] [loss_q: 0.009296] [loss_v: 0.013508] [loss_fd: 0.007022] [total_loss: 0.024208] [time: 2639.768532]
[Epoch 1964/2000] [loss_q: 0.009506] [loss_v: 0.014215] [loss_fd: 0.007196] [total_loss: 0.025160] [time: 2640.555092]
[Epoch 1965/2000] [loss_q: 0.009510] [loss_v: 0.013900] [loss_fd: 0.007185] [total_loss: 0.024846] [time: 2641.289615]
[Epoch 1966/2000] [loss_q: 0.009705] [loss_v: 0.

In [43]:
test_pred_q0 = piml_q(test_input_q, adj_n, adj_correlation_q, adj_knn_q)
test_pred_q = r_norm(test_pred_q0, max_q, min_q)

test_pred_v0 = piml_v(test_input_v, adj_n, adj_correlation_v, adj_knn_v)
test_pred_v = r_norm(test_pred_v0, max_v, min_v)

In [44]:
# 0.2
print(nn_loss(test_pred_q, test_true_q)**0.5)
print(nn_loss(test_pred_v, test_true_v)**0.5)

l1_loss = nn.L1Loss()
print(l1_loss(test_pred_q, test_true_q))
print(l1_loss(test_pred_v, test_true_v))

mape_q = torch.abs(test_pred_q-test_true_q)/test_true_q
mape_q = mape_q[mape_q<1000]

mape_v = torch.abs(test_pred_v-test_true_v)/test_true_v
mape_v = mape_v[mape_v<1000]

print(torch.mean(mape_q))
print(torch.mean(mape_v))

tensor(112.8850, grad_fn=<PowBackward0>)
tensor(5.1594, grad_fn=<PowBackward0>)
tensor(80.5469, grad_fn=<MeanBackward0>)
tensor(3.3322, grad_fn=<MeanBackward0>)
tensor(0.0944, grad_fn=<MeanBackward0>)
tensor(0.0801, grad_fn=<MeanBackward0>)
