In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 9

In [8]:
path_0 = 'data/METR-LA/speed_la_0.csv'
path_5 = 'data/METR-LA/speed_la_5.csv'
path_10 = 'data/METR-LA/speed_la_10.csv'
path_20 = 'data/METR-LA/speed_la_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(34272, 207) 0


((1414638, 9), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 685440
Node 0
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.011 )
( 22.28368250403314 14.597110158816143 0.5820142233825448 0.6054693616726359 379.5890142917633
Node 1
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.008 )
( 16.976743086936295 10.079912378170588 0.391276893227374 0.7062362346719975 746.1290993690491
Node 2
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.016 )
( 13.930583439245247 8.960767607290757 0.20096199576260568 0.7704828902354471 332.2337181568146
Node 3
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.012 )
( 17.547791931647417 13.748898794197594 0.39897042227071194 0.6882157355581615 668.842931509018
Node 4
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.037 )
( 24.444096381386707 19.914675891741954 0.6422564165843314 0.545333825228411 438.62084126472473
Node 5
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.037 )
( 14.634971978299781 11.816424626462656 0.19567865258076617 0.7690577768958944 755.7001128196716
Node 6
( (27396, 12) (27396, 9) (6834

( 14.157394020055285 9.480528661543888 0.17175080771497436 0.7796842830154891 510.40879106521606
Node 55
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 11.470560662888692 7.763380463281269 0.18940205247191894 0.8163038639423142 838.4772100448608
Node 56
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.156 )
( 23.663383215805354 18.346713404086863 0.3142957119493021 0.6272470831779207 605.5696015357971
Node 57
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.005 )
( 14.48683822549872 10.655196918316868 0.18303423115527803 0.7733983618042352 398.19836926460266
Node 58
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.011 )
( 9.48487355625632 6.587551051833227 0.13874788617937703 0.8520334154337466 611.0304660797119
Node 59
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 62.80562091524117 62.489247925999166 1.0 0.0 0.1285264492034912
Node 60
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.089 )
( 20.30742935032663 15.5323406991654 0.25411508052893583 0.6698089985962948 409.56951

( 24.298076523466158 16.28445324016339 0.6704331808800041 0.5594507785146222 714.4980819225311
Node 108
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.006 )
( 23.51268457150478 15.804213373584288 0.6095413674743098 0.5772865957309489 544.3337705135345
Node 109
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.017 )
( 15.43363618234204 11.269660416125822 0.32490737612712045 0.7418868480678147 747.622635602951
Node 110
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.012 )
( 13.169901452244986 7.380210371490102 0.21783592307917307 0.7887289790779128 593.3752844333649
Node 111
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.004 )
( 17.139263381986375 9.914540544534326 0.25260631540401185 0.725508886998403 353.018607378006
Node 112
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.004 )
( 16.918155001385582 12.610562197005628 0.2738303903430061 0.7290430922339062 771.8624300956726
Node 113
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.006 )
( 17.185665440668895 11.390587267154238 0.260161

( 29.74067535251644 23.451907514229344 0.7566809644559718 0.410817526223474 256.6333336830139
Node 161
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.001 )
( 25.106600888127335 17.657841069326473 0.777875019569251 0.5295986871056735 384.01590943336487
Node 162
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.016 )
( 18.007521232300824 10.84384621739975 0.2832385965550303 0.7054273697610222 294.86001324653625
Node 163
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.036 )
( 29.60874887529984 21.84200039310556 0.42142636896838986 0.5198671658917492 194.8952488899231
Node 164
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 62.13713863938833 61.57026786391734 1.0 0.0 0.09757328033447266
Node 165
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.018 )
( 19.652162432503086 14.093595070502593 0.25382844175337627 0.6819773786916967 641.102874994278
Node 166
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.004 )
( 12.086101391271349 7.625096304659426 0.1875874657853941 0.8091620009073699 323.

### Unobserved Node = 20%

In [11]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 1405152
Node 0
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 56.48150064721487 53.85806789916998 1.0 0.0 0.09910035133361816
Node 1
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 57.79045985464155 56.000012064896 1.0 0.0 0.09660816192626953
Node 2
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 60.69518500618873 59.96399825362247 1.0 0.0 0.0955965518951416
Node 3
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.012 )
( 17.547791931647417 13.748898794197594 0.39897042227071194 0.6882157355581615 651.4965577125549
Node 4
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.037 )
( 24.444096381386707 19.914675891741954 0.6422564165843314 0.545333825228411 427.29720878601074
Node 5
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.037 )
( 14.634971978299781 11.816424626462656 0.19567865258076617 0.7690577768958944 728.5655670166016
Node 6
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.029 )
( 13.618476337341681 9.110149702023127 0.15072258674538783 0.782581980437

( 12.576385159140887 11.19469507738926 0.1877628001026483 0.8079767177179606 837.6375846862793
Node 54
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.017 )
( 14.157394020055285 9.480528661543888 0.17175080771497436 0.7796842830154891 487.1014382839203
Node 55
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 11.470560662888692 7.763380463281269 0.18940205247191894 0.8163038639423142 806.128529548645
Node 56
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.156 )
( 23.663383215805354 18.346713404086863 0.3142957119493021 0.6272470831779207 579.7649433612823
Node 57
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.005 )
( 14.48683822549872 10.655196918316868 0.18303423115527803 0.7733983618042352 381.97424268722534
Node 58
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 64.10145631231589 63.67385220902444 1.0 0.0 0.1242516040802002
Node 59
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 62.80562091524117 62.489247925999166 1.0 0.0 0.12384343147277832
Node 60
( (27396, 12) 

( 15.43363618234204 11.269660416125822 0.32490737612712045 0.7418868480678147 709.5838220119476
Node 110
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 62.33652582718257 61.2390578167879 1.0 0.0 0.09697175025939941
Node 111
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 62.44013948053227 61.33683783076512 1.0 0.0 0.09599494934082031
Node 112
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.004 )
( 16.918155001385582 12.610562197005628 0.2738303903430061 0.7290430922339062 732.757128238678
Node 113
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.006 )
( 17.185665440668895 11.390587267154238 0.26016170011057205 0.7247545170922135 508.9685254096985
Node 114
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.002 )
( 13.580143807316556 7.714208470597096 0.2232417890666241 0.7825018335796669 566.3717167377472
Node 115
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.012 )
( 18.787758538861464 11.243228602060793 0.33601827853436417 0.6908355791704839 364.37352895736694
Node 116
( (2

( 16.73200952245106 11.184234177225767 0.2107396691928787 0.7307244928100898 412.9210512638092
Node 165
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 61.79485708172788 61.24728395964985 1.0 0.0 0.09609651565551758
Node 166
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.004 )
( 12.086101391271349 7.625096304659426 0.1875874657853941 0.8091620009073699 315.9324221611023
Node 167
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) -0.009 )
( 16.138388098773095 9.286849527933919 0.3062313162341496 0.7326112626138118 550.1393494606018
Node 168
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.033 )
( 19.82283000389491 12.794908514407062 0.4491732703924209 0.6564871205731952 477.53509068489075
Node 169
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.0 )
( 56.952492392208086 54.3420664728026 1.0 0.0 0.12464475631713867
Node 170
( (27396, 12) (27396, 9) (6834, 12) (6834, 9) 0.016 )
( 15.972641656773881 11.841094419481482 0.283552236246205 0.7411213590677384 607.3851444721222
Node 171
( (27396