In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 3

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3382925, 3), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 0
Node 0
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 8.328685226158738 5.429875924958618 0.08769796202071951 0.8749219004125711 274.6234927177429
Node 1
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 15.446384759071634 8.290043235390383 0.12390448963361052 0.768713163875233 75.76295113563538
Node 2
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.005 )
( 12.967767888338432 8.15823505548368 0.12257644814501197 0.8051230490893208 288.5596241950989
Node 3
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.009 )
( 10.144471452797013 6.2981759283069705 0.09828100866950229 0.8462816461074736 229.86034655570984
Node 4
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.786456414932635 10.063314361403734 0.1915501973957745 0.7727353266499768 243.29205346107483
Node 5
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.014 )
( 12.276460863316059 7.065478304742019 0.13273349172809903 0.8037914762684333 171.91705560684204
Node 6
( (41677, 12) (4167

( 11.676524566142172 7.292549534551501 0.15380481463659432 0.8105912308431764 222.83732199668884
Node 52
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 22.37358562433259 15.966868654383456 0.5324359119155827 0.5953673585690135 168.77750754356384
Node 53
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.003 )
( 17.744858443711582 12.224098292397736 0.37073352787744235 0.700427213090064 227.96997356414795
Node 54
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.0 )
( 11.464792401385349 5.838970926692169 0.08805460942962347 0.8285358081852988 217.69500064849854
Node 55
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.001 )
( 20.707779641052 10.932889453620911 0.16433387518483555 0.6893630549943733 103.31873917579651
Node 56
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 16.81417374930043 9.659099435052328 0.14792075245040734 0.7466409333135464 72.19391393661499
Node 57
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.092 )
( 12.748419585549017 7.06492294922477

( 22.556053218969467 15.926439243637091 0.5911278166629795 0.5813996242158137 236.92649149894714
Node 104
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 12.236097042309918 7.501150063661463 0.14231383684377263 0.80254935407576 260.10078263282776
Node 105
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 15.110743398014568 10.448526475368437 0.19839313430883432 0.7537048744767704 105.4953818321228
Node 106
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 22.031751126596216 15.455261201903388 0.504263371278351 0.5968199772916151 120.96744132041931
Node 107
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 20.168719115716215 13.064522484176939 0.37099550222182237 0.664835962295883 168.88354682922363
Node 108
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.001 )
( 12.452369481064418 5.544168544773819 0.0840351587493131 0.8129968737493427 131.38764357566833
Node 109
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.007 )
( 15.851992947672931 8.55445356

( 5.8740526652545935 3.6780132304456927 0.05530647910560707 0.911997640572208 486.06861090660095
Node 155
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.004 )
( 5.821047522886329 3.676584193355945 0.055517720112822355 0.9122789220822879 361.20150542259216
Node 156
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.003 )
( 9.354845196666899 6.0665606994118475 0.11644117409518318 0.8538275969661769 322.3872766494751
Node 157
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.005 )
( 16.033718958579893 11.077704072150135 0.2862374196214407 0.7221005955832245 262.7022008895874
Node 158
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 8.404371266266248 6.37845861160148 0.10946445465712906 0.8637437418715843 356.3050742149353
Node 159
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.021 )
( 9.498385408577706 5.504694052968176 0.11693488297943612 0.8436336911961483 496.43307876586914
Node 160
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.032 )
( 20.596878376158784 14.26644192

( 9.943374973452615 6.47350144370708 0.11058384994670252 0.844275852640608 201.62714457511902
Node 206
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.042 )
( 6.440539186782378 4.669509498693121 0.07373980039540956 0.9017652432679125 462.2105052471161
Node 207
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.069 )
( 7.553218690578255 5.304240678313195 0.0827433335168576 0.886021693768489 361.58043098449707
Node 208
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.025 )
( 11.33148101434809 7.941168551908167 0.1182559365769526 0.8304091724294731 246.86135506629944
Node 209
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.088 )
( 4.017349950149848 2.9542605959919643 0.044890086507385574 0.9396835518689832 475.7988865375519
Node 210
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.023 )
( 8.159767375570762 5.695299138289731 0.08480304238865725 0.8778419798890241 329.01636362075806
Node 211
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.128 )
( 9.91425788422133 6.775557882927781

( 15.417948756454452 10.192980869317536 0.1928394643000627 0.7472868719833423 136.6586136817932
Node 258
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 15.215080733341331 9.62280897029732 0.2756420487795928 0.7407747894366752 458.24859499931335
Node 259
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.012 )
( 21.324878577696225 14.430166320805357 0.5183002050934918 0.6123691732586314 481.72198009490967
Node 260
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.071 )
( 10.743711314981482 6.32587064852682 0.14572236414522724 0.83078881094819 451.2524621486664
Node 261
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 5.7801107598131924 3.286061260156775 0.050955215732597786 0.9127582808501029 459.14146852493286
Node 262
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.01 )
( 4.735084563195071 2.8867146056928523 0.04397125987499013 0.9289532114028742 495.2508239746094
Node 263
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.008 )
( 4.876244208307443 3.2610351121770

( 11.99399643897975 9.567847576653412 0.14389120541885178 0.8192085331394924 393.6959500312805
Node 309
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.021 )
( 13.123130855022598 7.788895914104818 0.16776180796410026 0.7903935658021498 219.52004742622375
Node 310
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.006 )
( 19.739821561427814 13.100292875993382 0.3757316834878108 0.6466525954428073 180.7419137954712
Node 311
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.0 )
( 15.742334727855496 11.231750031161255 0.1966397358879143 0.7431700059268229 89.83375573158264
Node 312
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.048 )
( 13.389496626551875 8.873614051212032 0.21463823389708828 0.7755015596642518 275.65904808044434
Node 313
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.016 )
( 20.937544522481026 14.926788367936794 0.5573782243247616 0.6104848294475673 414.1707320213318
Node 314
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.038 )
( 13.40448876140436 7.528618330

### Unobserved Node = 5%

In [11]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 833856
Node 0
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 8.328685226158738 5.429875924958618 0.08769796202071951 0.8749219004125711 274.0811462402344
Node 1
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 15.446384759071634 8.290043235390383 0.12390448963361052 0.768713163875233 75.78335309028625
Node 2
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.005 )
( 12.967767888338432 8.15823505548368 0.12257644814501197 0.8051230490893208 288.2922854423523
Node 3
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.009 )
( 10.144471452797013 6.2981759283069705 0.09828100866950229 0.8462816461074736 229.95194816589355
Node 4
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.786456414932635 10.063314361403734 0.1915501973957745 0.7727353266499768 243.1783902645111
Node 5
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.014 )
( 12.276460863316059 7.065478304742019 0.13273349172809903 0.8037914762684333 171.8874650001526
Node 6
( (41677, 12) (4

( 22.37358562433259 15.966868654383456 0.5324359119155827 0.5953673585690135 168.82114386558533
Node 53
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.003 )
( 17.744858443711582 12.224098292397736 0.37073352787744235 0.700427213090064 228.24528670310974
Node 54
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.0 )
( 11.464792401385349 5.838970926692169 0.08805460942962347 0.8285358081852988 217.56735968589783
Node 55
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.6623206736627 66.60304864380161 1.0 0.0 0.14579272270202637
Node 56
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.36499719234031 66.28684792006905 1.0 0.0 0.14605164527893066
Node 57
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 64.851587345327 64.40700035225949 1.0 0.0 0.14658784866333008
Node 58
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.006 )
( 20.397784320722 14.110621157488508 0.43071603120151797 0.6328776800847853 210.36419010162354
Node 59
( (41677, 12) (41677, 3) (10409, 1

( 15.110743398014568 10.448526475368437 0.19839313430883432 0.7537048744767704 105.51092004776001
Node 106
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 22.031751126596216 15.455261201903388 0.504263371278351 0.5968199772916151 120.88689541816711
Node 107
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 20.168719115716215 13.064522484176939 0.37099550222182237 0.664835962295883 169.0252709388733
Node 108
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.001 )
( 12.452369481064418 5.544168544773819 0.0840351587493131 0.8129968737493427 131.1624095439911
Node 109
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.007 )
( 15.851992947672931 8.554453567791755 0.12992495607566398 0.761794756946814 127.56783962249756
Node 110
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.004 )
( 6.073320049316174 3.5541212594471077 0.05531684162059324 0.9085373088505172 332.10919404029846
Node 111
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 9.71744209963618 5.5164110388

( 9.354845196666899 6.0665606994118475 0.11644117409518318 0.8538275969661769 322.6491050720215
Node 157
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.005 )
( 16.033718958579893 11.077704072150135 0.2862374196214407 0.7221005955832245 262.18944478034973
Node 158
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 8.404371266266248 6.37845861160148 0.10946445465712906 0.8637437418715843 356.33926463127136
Node 159
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.021 )
( 9.498385408577706 5.504694052968176 0.11693488297943612 0.8436336911961483 495.84913420677185
Node 160
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.032 )
( 20.596878376158784 14.266441925259715 0.5118808404406242 0.6198495981569946 275.10255312919617
Node 161
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.114737372618444 7.841921470041235 0.1848224194959557 0.7906762003725853 138.62436985969543
Node 162
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 13.38082032128147 8.641363907

( 11.33148101434809 7.941168551908167 0.1182559365769526 0.8304091724294731 246.46639251708984
Node 209
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.60455107408721 66.543391936465 1.0 0.0 0.14488935470581055
Node 210
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.023 )
( 8.159767375570762 5.695299138289731 0.08480304238865725 0.8778419798890241 328.69460463523865
Node 211
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.128 )
( 9.91425788422133 6.775557882927781 0.11624979568831816 0.8503803937385735 184.30499243736267
Node 212
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.043 )
( 6.97663028707149 3.848524873016792 0.10839641334919747 0.8911627108111557 497.2797601222992
Node 213
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 63.670075398863446 63.36897556601671 1.0 0.0 0.1452622413635254
Node 214
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.007 )
( 11.13105826434843 7.164571445934203 0.14136653319229855 0.8267933421403404 242.09589886665344
Node 21

( 10.743711314981482 6.32587064852682 0.14572236414522724 0.83078881094819 449.40016508102417
Node 261
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 5.7801107598131924 3.286061260156775 0.050955215732597786 0.9127582808501029 458.03787326812744
Node 262
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.01 )
( 4.735084563195071 2.8867146056928523 0.04397125987499013 0.9289532114028742 494.2493848800659
Node 263
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.008 )
( 4.876244208307443 3.2610351121770136 0.049638302452278606 0.9266632192785139 535.7718362808228
Node 264
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.028 )
( 7.048052143327056 4.165989158524912 0.06311002914997267 0.8940962186186959 441.2544252872467
Node 265
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.045 )
( 6.349722355519494 3.9730573267184837 0.07418753146392822 0.9022947737808698 387.0401635169983
Node 266
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.056 )
( 7.996574062808876 4.9273897734

( 15.742334727855496 11.231750031161255 0.1966397358879143 0.7431700059268229 89.52443718910217
Node 312
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.048 )
( 13.389496626551875 8.873614051212032 0.21463823389708828 0.7755015596642518 274.33635663986206
Node 313
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.016 )
( 20.937544522481026 14.926788367936794 0.5573782243247616 0.6104848294475673 413.0934867858887
Node 314
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.038 )
( 13.40448876140436 7.528618330811916 0.21127799982528198 0.783567741266659 288.402405500412
Node 315
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.025 )
( 5.523903336756153 3.5642468106773193 0.05361021720951145 0.9173004022471387 343.8297629356384
Node 316
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.019 )
( 7.1770536431527985 3.9660723286943096 0.060515779484339526 0.8919644853275697 293.2922155857086
Node 317
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.02 )
( 9.486528666975294 4.9944727850