In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 12

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3380000, 12), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 0
Node 0
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 7.535285487996442 4.94151428795418 0.07972857724880893 0.8870779903949048 1049.8796529769897
Node 1
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 15.175368021945536 8.169190360498108 0.12214403934403976 0.7725417462955099 480.23684072494507
Node 2
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.005 )
( 11.873133164568015 7.6333650236816615 0.11440546717839482 0.8217802183263296 1207.6277973651886
Node 3
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.01 )
( 9.972937648425004 6.321861782765244 0.10216123663977943 0.8472606499036938 1096.1857397556305
Node 4
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.012 )
( 13.396977204264713 9.756582009798533 0.18633352756839738 0.7792280990164643 1159.6018216609955
Node 5
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.014 )
( 11.659197392334741 6.759114931395188 0.12960810203446668 0.813879899586561 855.9328200817108
Node 6
( (4166

( 12.233275187521699 7.614571607368049 0.17369558918696051 0.7998961052501115 988.0054266452789
Node 52
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 22.89679245224199 16.414631945798007 0.5636329444831562 0.5824677036223607 874.4495935440063
Node 53
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.003 )
( 16.398129516578038 11.11838812578096 0.31767377332034236 0.7289912322878916 1116.2621150016785
Node 54
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 10.700146326791618 5.489664656646951 0.08264263222114725 0.840134012133253 980.1258282661438
Node 55
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 20.425299487482306 10.817095826645454 0.162696066573154 0.6933188581858287 636.0662271976471
Node 56
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 16.63450893603398 9.514300283081951 0.14552051762130827 0.7496023338848479 422.30413150787354
Node 57
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.093 )
( 13.21471140436982 7.50829709

( 11.410612176008945 7.017117667567015 0.16651931480117518 0.8209393995612377 1550.6123003959656
Node 103
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.001 )
( 22.24892769345743 15.615681070544236 0.5777007388181983 0.58864125403271 1047.687560558319
Node 104
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.533206534154145 7.0543094236067985 0.133254088692306 0.814381274285661 1035.1422991752625
Node 105
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 15.256645992730478 10.59269857825067 0.20737032429782856 0.7502852299441264 544.96994972229
Node 106
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 22.119166602730594 15.49970433839095 0.5151187546329281 0.5940935463249652 580.3560092449188
Node 107
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 19.312365570150178 12.330736974087307 0.34098194695054757 0.6820503811815299 920.6553754806519
Node 108
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.6487357253528 5.22770

( 5.5900288479111815 3.8815478242468155 0.05962720937845815 0.9164610803189939 1753.1441395282745
Node 154
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 5.553312912255799 3.5278626903324315 0.05305060057572159 0.9168031128655896 1664.2000024318695
Node 155
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.005 )
( 5.401135853467993 3.484129693758977 0.05252884368109954 0.9186436275461525 1431.1689729690552
Node 156
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.003 )
( 9.4448579114075 6.068170859105218 0.1197542867324049 0.8521103046572355 1214.518325805664
Node 157
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.005 )
( 15.599377548886721 10.686785923234261 0.2764700917444156 0.7302614180043505 1045.980843782425
Node 158
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 8.02775955892293 6.020966389577564 0.10533891627988339 0.8697540838997242 1425.475775718689
Node 159
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.021 )
( 9.135467339699039 5.3

( 7.919033424092646 4.889693091342884 0.13483672760562318 0.8738558992025336 1823.544225692749
Node 205
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.01 )
( 9.778832658610979 6.371479144644479 0.11056235316617428 0.8466782969177473 944.3165647983551
Node 206
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.043 )
( 6.072597552446899 4.418162255929025 0.06962304329159494 0.9073457187640551 1740.8830437660217
Node 207
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.068 )
( 7.233885626934147 5.043745157347203 0.07922884179284385 0.8907791546383154 1398.900417804718
Node 208
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.025 )
( 10.843142885369415 7.709414487702229 0.11479440505653331 0.8376929100221729 1067.5207080841064
Node 209
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.088 )
( 3.9123250294970853 2.8963371003887186 0.0439574584681695 0.941263114108703 1669.7855505943298
Node 210
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.022 )
( 7.894089746105442 5

( 7.398034562219583 4.417727072208543 0.08701622454371961 0.8853069696357926 1719.0929114818573
Node 256
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.014 )
( 15.126788658880239 10.137050841319402 0.2860585489222007 0.7370809559403296 941.8328058719635
Node 257
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.015 )
( 15.37226472062522 10.368784946517236 0.2080205917535868 0.7466205036722493 718.6878736019135
Node 258
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 14.320590540706121 9.007065364414329 0.24842487444427652 0.7578372517267578 1706.137621641159
Node 259
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.012 )
( 21.20003777415256 14.298551625236273 0.5132485573340286 0.6151208347968509 1779.0230059623718
Node 260
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.07 )
( 12.304582346297948 7.09239311249614 0.18309990765225717 0.803478605876455 1682.9187519550323
Node 261
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 5.390612040649916 3.1

( 4.642463904635918 2.720998791360905 0.04190907958216162 0.9302516054767913 944.3092465400696
Node 307
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.024 )
( 11.584656797237423 8.414594482598707 0.12689888666370933 0.8255635901678074 1029.6882736682892
Node 308
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.031 )
( 11.681486044069755 9.425925149729668 0.14217032570463642 0.8236198486616559 1506.4421346187592
Node 309
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.02 )
( 11.353555359214363 6.4844312598818705 0.13052757465891004 0.822005640308866 1178.324322938919
Node 310
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.006 )
( 19.40695158955879 12.899278968611549 0.36940170712916487 0.6530384629532926 882.0891284942627
Node 311
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.001 )
( 16.285982206610537 11.631682411251093 0.2222777379588883 0.7316831274247371 552.4530897140503
Node 312
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.049 )
( 11.7529201303824

### Unobserved Node = 5%

In [10]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 833856
Node 0
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 7.535285487996442 4.94151428795418 0.07972857724880893 0.8870779903949048 1040.1556732654572
Node 1
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 15.175368021945536 8.169190360498108 0.12214403934403976 0.7725417462955099 476.04922819137573
Node 2
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.005 )
( 11.873133164568015 7.6333650236816615 0.11440546717839482 0.8217802183263296 1191.4446485042572
Node 3
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.01 )
( 9.972937648425004 6.321861782765244 0.10216123663977943 0.8472606499036938 1081.995956659317
Node 4
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.012 )
( 13.396977204264713 9.756582009798533 0.18633352756839738 0.7792280990164643 1145.7217471599579
Node 5
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.014 )
( 11.659197392334741 6.759114931395188 0.12960810203446668 0.813879899586561 847.1058287620544
Node 6
( (

( 12.233275187521699 7.614571607368049 0.17369558918696051 0.7998961052501115 977.6524147987366
Node 52
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 22.89679245224199 16.414631945798007 0.5636329444831562 0.5824677036223607 867.9900524616241
Node 53
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.003 )
( 16.398129516578038 11.11838812578096 0.31767377332034236 0.7289912322878916 1106.741576910019
Node 54
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 10.700146326791618 5.489664656646951 0.08264263222114725 0.840134012133253 973.2931423187256
Node 55
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.60109378312774 66.53918589743576 1.0 0.0 0.19742727279663086
Node 56
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.43236414345883 66.35385817307683 1.0 0.0 0.1983199119567871
Node 57
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 63.9182214393608 63.263151442307766 1.0 0.0 0.19765663146972656
Node 58
( (41668, 12) (41668, 1

( 22.24892769345743 15.615681070544236 0.5777007388181983 0.58864125403271 1037.180728673935
Node 104
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.533206534154145 7.0543094236067985 0.133254088692306 0.814381274285661 1024.0406358242035
Node 105
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 15.256645992730478 10.59269857825067 0.20737032429782856 0.7502852299441264 540.0098373889923
Node 106
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 22.119166602730594 15.49970433839095 0.5151187546329281 0.5940935463249652 575.1897082328796
Node 107
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 19.312365570150178 12.330736974087307 0.34098194695054757 0.6820503811815299 911.2209949493408
Node 108
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.6487357253528 5.227703851285329 0.0791416896407393 0.8252023652699514 605.7237801551819
Node 109
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.008 )
( 15.491975019975044 8.35408

( 5.553312912255799 3.5278626903324315 0.05305060057572159 0.9168031128655896 1650.8259642124176
Node 155
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.005 )
( 5.401135853467993 3.484129693758977 0.05252884368109954 0.9186436275461525 1423.2157833576202
Node 156
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.003 )
( 9.4448579114075 6.068170859105218 0.1197542867324049 0.8521103046572355 1207.68590259552
Node 157
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.005 )
( 15.599377548886721 10.686785923234261 0.2764700917444156 0.7302614180043505 1038.3217957019806
Node 158
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 8.02775955892293 6.020966389577564 0.10533891627988339 0.8697540838997242 1419.8761224746704
Node 159
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.021 )
( 9.135467339699039 5.33078890025303 0.11514046103219279 0.8495626363345758 1587.0786554813385
Node 160
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.031 )
( 20.47112679668617 14.10

( 9.778832658610979 6.371479144644479 0.11056235316617428 0.8466782969177473 938.7374424934387
Node 206
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 65.54038811205037 65.42214903846177 1.0 0.0 0.19709062576293945
Node 207
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.068 )
( 7.233885626934147 5.043745157347203 0.07922884179284385 0.8907791546383154 1391.8958494663239
Node 208
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.025 )
( 10.843142885369415 7.709414487702229 0.11479440505653331 0.8376929100221729 1062.3691003322601
Node 209
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.60763454054293 66.54714663461515 1.0 0.0 0.19861340522766113
Node 210
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.022 )
( 7.894089746105442 5.588921335854985 0.08324878111699581 0.8818100809003936 1325.278811454773
Node 211
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.128 )
( 9.585108131740334 6.537289102580183 0.11067043309642645 0.8554688937704783 815.6801

( 15.37226472062522 10.368784946517236 0.2080205917535868 0.7466205036722493 715.2309041023254
Node 258
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 14.320590540706121 9.007065364414329 0.24842487444427652 0.7578372517267578 1699.8686504364014
Node 259
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.012 )
( 21.20003777415256 14.298551625236273 0.5132485573340286 0.6151208347968509 1765.8543193340302
Node 260
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.07 )
( 12.304582346297948 7.09239311249614 0.18309990765225717 0.803478605876455 1673.6387751102448
Node 261
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 5.390612040649916 3.1333599095356885 0.048818934340383285 0.9185766704948991 1730.1504600048065
Node 262
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.01 )
( 4.389131236062994 2.723176933526501 0.041490287547515686 0.9341580749325549 1762.3377301692963
Node 263
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.008 )
( 4.372622811229427 2

( 11.681486044069755 9.425925149729668 0.14217032570463642 0.8236198486616559 1498.5886371135712
Node 309
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.02 )
( 11.353555359214363 6.4844312598818705 0.13052757465891004 0.822005640308866 1172.7907490730286
Node 310
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.006 )
( 19.40695158955879 12.899278968611549 0.36940170712916487 0.6530384629532926 877.4679009914398
Node 311
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.001 )
( 16.285982206610537 11.631682411251093 0.2222777379588883 0.7316831274247371 549.6761474609375
Node 312
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.049 )
( 11.752920130382412 7.846965300639492 0.17702963505426286 0.8051184729080659 1167.382034778595
Node 313
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.016 )
( 20.071867327188126 14.054309884490232 0.5181334880122508 0.6294926325717182 1464.1920711994171
Node 314
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.038 )
( 15.709350787225