In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 9

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3380975, 9), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 0
Node 0
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.004 )
( 7.814306200976785 5.111519658182142 0.08250572824942264 0.882809462539383 794.8200767040253
Node 1
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 15.295689628449892 8.220059352996374 0.12285702741103648 0.770815042278983 320.01222467422485
Node 2
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.005 )
( 12.295152975348282 7.832662719585675 0.11750723495554788 0.8153624511169215 885.1048891544342
Node 3
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.01 )
( 10.016385561397366 6.326934081303493 0.10077496383134915 0.8472090783984898 792.5848526954651
Node 4
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.012 )
( 13.55576734474368 9.877784128326162 0.18886363637640805 0.7764260254257975 844.6009142398834
Node 5
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.014 )
( 11.902737856993193 6.862106196240442 0.1304862435216088 0.809968493697092 611.8510935306549
Node 6
( (41671, 12) (41671, 9) (

( 12.131852956986977 7.5671494227374865 0.16810095298084282 0.8021206898842776 731.6421685218811
Node 52
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 22.773949531623636 16.29898130461548 0.557160927859338 0.58558415019164 626.0483646392822
Node 53
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.003 )
( 16.851310577523883 11.480094823556513 0.33315295214648905 0.7196810655945496 799.1233322620392
Node 54
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.0 )
( 10.991522560132232 5.627920164355936 0.08472090412837054 0.8357238727036508 714.2051672935486
Node 55
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.001 )
( 20.512734005820842 10.855977571522606 0.16335116338916927 0.6920969184192812 446.35442662239075
Node 56
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.001 )
( 16.680516729367696 9.558430383925144 0.14643398243633543 0.7487737558100549 294.3913428783417
Node 57
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.093 )
( 12.982226926468119 7.321011395921469

( 22.37902216722177 15.724446339824892 0.5835288445616835 0.5854345904303639 755.6675279140472
Node 104
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 11.779204805030318 7.22386210423957 0.13620474471391722 0.81029776587198 786.2192583084106
Node 105
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 15.280522948199483 10.58753467591291 0.20530352901627472 0.7502665902553655 392.8503439426422
Node 106
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 22.088678436069216 15.473654186881769 0.5117991733496808 0.5949212365511873 416.71016025543213
Node 107
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.004 )
( 19.61291131790929 12.587099553563288 0.3506720204543196 0.676143363931498 657.7750070095062
Node 108
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.001 )
( 12.000907925402784 5.358341666175927 0.08111328821136404 0.8198696936116545 436.02029275894165
Node 109
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.008 )
( 15.629902909069258 8.42857359810956

( 5.615910052115901 3.5562352971400375 0.053523535977471975 0.9158628152309454 1329.9959335327148
Node 155
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.004 )
( 5.52303375744057 3.540506068743235 0.05342643225884444 0.9167923729183631 1083.3294956684113
Node 156
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.003 )
( 9.402700138810417 6.060421724970611 0.11804837326874495 0.8529201884197319 924.0244860649109
Node 157
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.005 )
( 15.789584143250183 10.860231380168932 0.2812342412219346 0.7265974417560634 794.6635730266571
Node 158
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 8.041681117661787 6.061667295360879 0.10495289210933287 0.869585453856143 1072.9860298633575
Node 159
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.021 )
( 9.2272719172121 5.390537501045563 0.11539553939823342 0.8480886403364258 1231.1471643447876
Node 160
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.031 )
( 20.52579294750548 14.187009270032

( 9.86041652280059 6.413566532320498 0.1107441247147737 0.8454540197798317 699.3524284362793
Node 206
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.043 )
( 6.170770134739376 4.484026219075279 0.07071332198403144 0.9058594219031052 1313.8835015296936
Node 207
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.069 )
( 7.322202717252811 5.115526466787195 0.08019893220108844 0.8894670438107446 1064.7329699993134
Node 208
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.025 )
( 11.024086578272545 7.791129413837457 0.11599197489360918 0.8349933060893293 803.3950107097626
Node 209
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.088 )
( 3.931528252546257 2.9071854681293314 0.04413537740632445 0.9409732223704353 1281.4695258140564
Node 210
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.022 )
( 7.976904392754421 5.621465980919402 0.08372364331431204 0.8805737736632941 992.0274569988251
Node 211
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.128 )
( 9.695258335794954 6.614744202507

( 15.272844852445917 10.282611675714492 0.291806756444603 0.7339631651002273 682.8948276042938
Node 257
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.015 )
( 15.430543443188343 10.316059453872366 0.20162722055380689 0.7462314649170338 530.0655016899109
Node 258
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 14.613394793250066 9.19816810730293 0.25695503669349384 0.7522545404911517 1279.5161662101746
Node 259
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.012 )
( 21.281273765734866 14.370127730197076 0.517155450356839 0.6132453453849763 1353.842830657959
Node 260
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.07 )
( 11.741112793842163 6.797968609758531 0.1687560483350696 0.8134825461774947 1275.5567781925201
Node 261
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.004 )
( 5.521364446147197 3.179939439225951 0.04946732751104503 0.916608992174059 1303.9647660255432
Node 262
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.01 )
( 4.487057440551586 2.784432150178535

( 11.594646355687669 8.407304725623455 0.12690965406691848 0.8253259399815476 745.8499829769135
Node 308
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.031 )
( 11.774986968395257 9.469880380663442 0.14270252302244632 0.8222829882260899 1146.3626902103424
Node 309
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.02 )
( 11.933836751424032 6.887800969848156 0.14132596535160066 0.8118778255123376 776.3558351993561
Node 310
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.006 )
( 19.624069930273244 13.041298941875068 0.37579821965702126 0.6486091383246992 656.9444427490234
Node 311
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.0 )
( 16.05639777278759 11.454754548088074 0.21057347120364958 0.7364810453256594 381.8490753173828
Node 312
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.048 )
( 12.290631065491176 8.178734042286216 0.18853454129762948 0.7954994278157946 826.0616612434387
Node 313
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.016 )
( 20.39669297653457 14.379119504

### Unobserved Node = 5%

In [11]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 833856
Node 0
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.004 )
( 7.814306200976785 5.111519658182142 0.08250572824942264 0.882809462539383 789.037734746933
Node 1
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 15.295689628449892 8.220059352996374 0.12285702741103648 0.770815042278983 317.74370861053467
Node 2
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.005 )
( 12.295152975348282 7.832662719585675 0.11750723495554788 0.8153624511169215 878.1013264656067
Node 3
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.01 )
( 10.016385561397366 6.326934081303493 0.10077496383134915 0.8472090783984898 787.3328659534454
Node 4
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.012 )
( 13.55576734474368 9.877784128326162 0.18886363637640805 0.7764260254257975 839.4999792575836
Node 5
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.014 )
( 11.902737856993193 6.862106196240442 0.1304862435216088 0.809968493697092 607.537840127945
Node 6
( (41671, 12) (41671, 9

( 22.773949531623636 16.29898130461548 0.557160927859338 0.58558415019164 623.849205493927
Node 53
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.003 )
( 16.851310577523883 11.480094823556513 0.33315295214648905 0.7196810655945496 796.4414966106415
Node 54
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.0 )
( 10.991522560132232 5.627920164355936 0.08472090412837054 0.8357238727036508 711.9483897686005
Node 55
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.0 )
( 66.62074929718848 66.55998269729872 1.0 0.0 0.1798851490020752
Node 56
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.0 )
( 66.39639414724543 66.31816997233695 1.0 0.0 0.18001747131347656
Node 57
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.0 )
( 64.27975448207522 63.70782573402981 1.0 0.0 0.1802501678466797
Node 58
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.006 )
( 20.112497951752975 13.882705138067024 0.4248432520890316 0.6385799124815106 693.6455521583557
Node 59
( (41671, 12) (41671, 9) (10403, 12) (

( 15.280522948199483 10.58753467591291 0.20530352901627472 0.7502665902553655 391.1498312950134
Node 106
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 22.088678436069216 15.473654186881769 0.5117991733496808 0.5949212365511873 415.4927098751068
Node 107
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.004 )
( 19.61291131790929 12.587099553563288 0.3506720204543196 0.676143363931498 655.1954226493835
Node 108
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.001 )
( 12.000907925402784 5.358341666175927 0.08111328821136404 0.8198696936116545 434.02263498306274
Node 109
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.008 )
( 15.629902909069258 8.42857359810956 0.12814903595146435 0.7650852039410805 441.31473875045776
Node 110
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.004 )
( 5.485103438705676 3.3156699422296465 0.051656638376307844 0.9174476309752431 954.3974974155426
Node 111
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 10.336783464179774 5.8455709012

( 15.789584143250183 10.860231380168932 0.2812342412219346 0.7265974417560634 792.037983417511
Node 158
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.002 )
( 8.041681117661787 6.061667295360879 0.10495289210933287 0.869585453856143 1069.2329049110413
Node 159
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.021 )
( 9.2272719172121 5.390537501045563 0.11539553939823342 0.8480886403364258 1227.1479377746582
Node 160
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.031 )
( 20.52579294750548 14.187009270032341 0.5085330771009199 0.6215076859667511 858.5969729423523
Node 161
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.012 )
( 13.121936282938256 7.851406348640085 0.18734276965237928 0.7903925253313748 576.6791157722473
Node 162
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 13.184919499226769 8.568744871221817 0.12777608011103092 0.8022796589831926 594.877459526062
Node 163
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.016 )
( 10.990264918261548 7.05895596932619

( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.022 )
( 7.976904392754421 5.621465980919402 0.08372364331431204 0.8805737736632941 991.8608043193817
Node 211
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.128 )
( 9.695258335794954 6.614744202507399 0.11243098716993132 0.8537725186878729 593.6816217899323
Node 212
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.043 )
( 6.961094110174707 3.8374114552705234 0.10818005394606472 0.8914978502785337 1325.159820318222
Node 213
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.0 )
( 63.65367378791492 63.34815064030686 1.0 0.0 0.17378854751586914
Node 214
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.007 )
( 10.901903883517337 7.007652254086221 0.13855351500326826 0.8303172613850679 802.3995184898376
Node 215
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.015 )
( 6.173028399303356 4.49823641120145 0.07328216837048183 0.9051011287781581 1371.8816668987274
Node 216
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.0 )
( 6.904721812

( 4.487057440551586 2.7844321501785356 0.042420051804002795 0.932686307049501 1346.7648873329163
Node 263
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.008 )
( 4.5052359403301105 3.050460070646596 0.04645214915437516 0.9322273300170942 1445.1352798938751
Node 264
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.028 )
( 6.579212046620145 3.9546687246266305 0.05991045704877402 0.9011882605667704 1249.4948041439056
Node 265
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.045 )
( 5.933138651764157 3.728963033095667 0.06796845419257805 0.9089101008244921 1149.2588045597076
Node 266
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.056 )
( 7.925891008745617 4.8602734985051 0.11427558650945718 0.8744665529931512 1165.930805683136
Node 267
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.004 )
( 8.629562054851261 5.539378531545254 0.11285765105983228 0.8616680789799664 1043.8018014431
Node 268
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.029 )
( 8.460228940419295 5.8766381673361

( 14.922940396213125 8.558239851325927 0.26085946427772594 0.7549368292753166 872.8759198188782
Node 315
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.025 )
( 5.172516162222698 3.366392848984281 0.05082444995194999 0.9224745641931381 1072.6851844787598
Node 316
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.019 )
( 6.770397202117159 3.8338705151928867 0.05845280937035202 0.8981569586062063 891.7041869163513
Node 317
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.02 )
( 9.092694010568218 4.8083019937033775 0.07253050435343361 0.8630651204516189 827.2708342075348
Node 318
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) 0.003 )
( 7.554529490893114 5.823727301002854 0.09478084308185264 0.8832820224074031 1244.6964721679688
Node 319
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.015 )
( 11.92267377339927 8.238193821066988 0.18510837731172128 0.8007266880305427 1412.802345752716
Node 320
( (41671, 12) (41671, 9) (10403, 12) (10403, 9) -0.008 )
( 12.745964678749093 7.9392353508