In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 6

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3381950, 6), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 1667712
Node 0
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.63207193208224 66.52535556409742 1.0 0.0 0.1543281078338623
Node 1
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 15.386915420790967 8.257516004923767 0.12339640776819694 0.7695197566260219 189.10903191566467
Node 2
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.005 )
( 12.694174127897817 8.017809480522788 0.12038009121347648 0.8092914451450123 585.9748809337616
Node 3
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.01 )
( 10.038969832505513 6.295291906116039 0.09908497132413023 0.8474078646855691 506.7389063835144
Node 4
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.693137615730508 9.967945096244929 0.19040078586234432 0.7741615834955231 528.4215903282166
Node 5
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 62.60353764409611 62.25651547184316 1.0 0.0 0.16348981857299805
Node 6
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.002 )
( 15.892748381612224 10.

( 22.595928389886947 16.161531541897094 0.5461707917887713 0.5900690672428718 395.55007576942444
Node 53
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 59.68026610482089 57.798026779422095 1.0 0.0 0.1599407196044922
Node 54
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.0 )
( 11.288936880994623 5.760531031872838 0.08677595571419296 0.8312154856296929 462.0718455314636
Node 55
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.64387826489718 66.5839467614836 1.0 0.0 0.16057038307189941
Node 56
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 16.74489856225525 9.606151086794762 0.1472332625981359 0.7477324515205368 176.17015743255615
Node 57
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 64.57462363363686 64.06846050355578 1.0 0.0 0.16107940673828125
Node 58
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 55.5908495928949 53.60052213466591 1.0 0.0 0.16082215309143066
Node 59
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 61.5189363

( 15.1973330948796 10.513572791570622 0.20155719300236594 0.7519536855439471 244.94236969947815
Node 106
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 22.042288865184073 15.444393885919672 0.5075391154291871 0.5961588347348405 268.50003266334534
Node 107
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 19.904455081045544 12.832200086218384 0.3608899719689012 0.6703187758874696 407.19856882095337
Node 108
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.001 )
( 12.254341500468033 5.460658376249256 0.08270741699121388 0.8160126445678064 283.15310621261597
Node 109
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 15.740319269233456 8.492182354826472 0.12904272720485013 0.7634488793631263 284.5075464248657
Node 110
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 5.771301071448847 3.437749292508844 0.053543414802190804 0.9131135037964414 649.0102505683899
Node 111
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 10.002629383584884 5.6758216

( 8.283148839645149 6.2750309994621665 0.10828313456236728 0.8656822017266775 723.3806173801422
Node 159
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.021 )
( 9.37652835078859 5.441159039126315 0.1161256763765725 0.8456404941783817 884.4421994686127
Node 160
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 54.21395199109422 51.75263629957059 1.0 0.0 0.16356444358825684
Node 161
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.16773789052588 7.852817064729462 0.18725926610758867 0.7897071751897303 352.7804012298584
Node 162
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 13.269817560640686 8.610192730154578 0.1284554277823218 0.8010058049815829 365.07022047042847
Node 163
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.016 )
( 11.025301135740056 7.081982546536799 0.10664385141082981 0.8343230354055452 448.59383893013
Node 164
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 3.6634156212684834 2.669695719888496 0.04041904421751607 0.9449698

( 8.061349586487125 5.660504024080734 0.08431698215699236 0.8793123069335083 666.0669643878937
Node 211
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.128 )
( 9.788026750982207 6.678988493400355 0.11396619924358813 0.8523333397258336 387.11121439933777
Node 212
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.043 )
( 6.974157287759394 3.853025729665834 0.10860275359926784 0.8912458349384884 927.9731333255768
Node 213
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.124 )
( 9.81764208165501 5.461724458883218 0.11106758581119597 0.8457809819978626 442.8749852180481
Node 214
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.007 )
( 11.010058888825967 7.078123402744662 0.13992000814085 0.8286570473301572 525.8350002765656
Node 215
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.015 )
( 6.259008644976315 4.553997337704355 0.07400876546060296 0.9037904196596369 983.0824239253998
Node 216
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.0 )
( 7.152501028507699 3.902618289772519 0.059

( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 4.678623731183447 3.1499808894163217 0.04795764327740809 0.9296261748114614 1016.475307226181
Node 264
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.5687571875869 66.49911108975625 1.0 0.0 0.16283917427062988
Node 265
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.045 )
( 6.130025323849087 3.833816095621772 0.070700281590748 0.9057852595824104 786.7149002552032
Node 266
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.056 )
( 7.95640596488521 4.891793082868614 0.11476769477962974 0.8739094738962465 814.1798374652863
Node 267
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 8.598804024483778 5.530852931999078 0.11181754991231214 0.8621956339102919 693.1719965934753
Node 268
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.028 )
( 8.615481240871528 5.960181235521529 0.09838687645472971 0.8648835732882176 852.8707191944122
Node 269
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.018 )
( 6.89073276269

( 6.972855834075799 3.902139919246878 0.05951688222378587 0.8950749293639721 595.0768702030182
Node 317
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.02 )
( 9.331676635060239 4.919781005358642 0.07411904444909058 0.8595630429685066 543.0560097694397
Node 318
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 7.682818911475401 5.834998833042398 0.09620182459457227 0.8807837780113177 864.7387826442719
Node 319
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.015 )
( 12.04979670231141 8.314755547133055 0.18856957726470688 0.7984247374912765 977.8362462520599
Node 320
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.008 )
( 12.793924053873925 7.951906481348826 0.14297447670869026 0.7917338750687725 572.3315849304199
Node 321
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.052 )
( 16.23175551441262 11.187848966691876 0.29835538068562356 0.7235153491282524 677.1919672489166
Node 322
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 18.5071992495519 12.20103781579429

### Unobserved Node = 20%

In [11]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 3387540
Node 0
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 8.086414310949818 5.28257205022424 0.08526686387938014 0.8786408095009826 536.34521317482
Node 1
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 15.386915420790967 8.257516004923767 0.12339640776819694 0.7695197566260219 187.53626012802124
Node 2
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.005 )
( 12.694174127897817 8.017809480522788 0.12038009121347648 0.8092914451450123 582.3085534572601
Node 3
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.01 )
( 10.038969832505513 6.295291906116039 0.09908497132413023 0.8474078646855691 502.75917172431946
Node 4
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.693137615730508 9.967945096244929 0.19040078586234432 0.7741615834955231 524.3542704582214
Node 5
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.014 )
( 12.143914196981628 6.976166856786729 0.13184655044104865 0.8060187226795344 387.6297969818115
Node 6
( (41674, 12) (41

( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.001 )
( 20.63356385195718 10.89826775573849 0.1638193016263615 0.6903907097071611 258.6487925052643
Node 56
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 16.74489856225525 9.606151086794762 0.1472332625981359 0.7477324515205368 175.33852577209473
Node 57
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.093 )
( 12.855131004353275 7.1819923406193915 0.14016773654615966 0.800925963157189 780.2973258495331
Node 58
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 55.5908495928949 53.60052213466591 1.0 0.0 0.1603703498840332
Node 59
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 14.904386317428946 9.390788430038095 0.18240288624556697 0.7577268530562253 342.48035883903503
Node 60
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.003 )
( 14.964971236397792 9.122115316275215 0.18774788667986372 0.755359060693184 288.56227445602417
Node 61
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 22.75900338571

( 15.740319269233456 8.492182354826472 0.12904272720485013 0.7634488793631263 283.1603002548218
Node 110
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.42345270694068 66.35886667947973 1.0 0.0 0.15805792808532715
Node 111
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 10.002629383584884 5.67582162233148 0.13577659548530632 0.8441396569494025 561.8855566978455
Node 112
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.026 )
( 24.132255797832475 17.409101064991546 0.6528753619207092 0.551309985465444 318.9740755558014
Node 113
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.005 )
( 16.53019690153926 9.708042766318123 0.20792619103743504 0.7292490559041538 268.7774221897125
Node 114
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 16.419172510256384 10.015918545517938 0.19984465571528873 0.731629973802307 204.9697539806366
Node 115
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.171 )
( 20.24116224020527 14.929446951249924 0.47601715088739777 0.6257935

( 11.025301135740056 7.081982546536799 0.10664385141082981 0.8343230354055452 445.9825689792633
Node 164
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 3.6634156212684834 2.669695719888496 0.04041904421751607 0.9449698287557698 928.0844831466675
Node 165
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.107 )
( 14.520093629742192 8.99132825457294 0.182494987828428 0.7726974984845967 199.9276623725891
Node 166
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.057 )
( 22.66214697424756 15.785176946464683 0.6373594968709496 0.5796885726147076 620.3756096363068
Node 167
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.002 )
( 9.72982910827756 6.168620556122424 0.1406396455710802 0.8406173564222944 596.6294164657593
Node 168
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.007 )
( 10.535560056387304 7.279036766917135 0.15067815166089576 0.8280385969185284 742.1118357181549
Node 169
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.005 )
( 21.08158430311832 14.83255761749165

( 7.152501028507699 3.902618289772519 0.059322850332536224 0.8917683353591164 492.4467873573303
Node 217
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.026 )
( 14.718100134922626 7.265738623550148 0.10963963673322698 0.778712792921215 269.1440050601959
Node 218
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.035 )
( 9.19958371674162 5.581716142108277 0.08386844087608537 0.8612573970617212 594.7463085651398
Node 219
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.022 )
( 13.76941251422343 8.207111945308354 0.16714668546343603 0.7809024883579885 410.0350513458252
Node 220
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.013 )
( 21.237274594305575 15.247496895544908 0.5036063827437838 0.6086124300726923 381.3032751083374
Node 221
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 9.606380201354808 6.065980203374429 0.13066244437472965 0.8406982194001034 585.7994821071625
Node 222
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.078 )
( 10.90587081480784 6.771342723943559 0

( 6.864942864098548 4.182965454019274 0.06425450605443085 0.8964787226653934 643.2045946121216
Node 271
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.002 )
( 6.407476527579649 3.714602871642207 0.056325101939725425 0.9039358350142525 626.154224395752
Node 272
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.103 )
( 5.633115777430339 4.097140167874602 0.061343110369116514 0.9152320046821907 839.8799369335175
Node 273
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 5.4727205188694645 4.312083647404257 0.06443040745839924 0.9178516506533037 921.7999713420868
Node 274
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.012 )
( 14.377248061268729 8.147008970091333 0.12659431105638255 0.7823372638530466 448.2757124900818
Node 275
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.02 )
( 15.193160992810705 9.33916832492688 0.15867667870403515 0.7624777458869894 516.0566444396973
Node 276
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 14.45155929943536 8.32659450803197

( 18.5071992495519 12.201037815794294 0.4010912202129811 0.6741097569391095 628.9304957389832
Node 323
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.009 )
( 10.388755124213388 6.369456103595202 0.12410144014788103 0.8386938142299661 561.6927745342255
Node 324
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.56057410589176 66.46232942533165 1.0 0.0 0.1608114242553711
RMSE: 22.941876198257024, MAE: 19.094211959596255, MAPE: 0.3495532391521493, ACC: 0.6340926426033052
