In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 6

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3381950, 6), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 0
Node 0
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 8.086414310949818 5.28257205022424 0.08526686387938014 0.8786408095009826 537.8583295345306
Node 1
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 15.386915420790967 8.257516004923767 0.12339640776819694 0.7695197566260219 187.22418761253357
Node 2
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.005 )
( 12.694174127897817 8.017809480522788 0.12038009121347648 0.8092914451450123 581.6705358028412
Node 3
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.01 )
( 10.038969832505513 6.295291906116039 0.09908497132413023 0.8474078646855691 503.40339183807373
Node 4
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.693137615730508 9.967945096244929 0.19040078586234432 0.7741615834955231 525.2432930469513
Node 5
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.014 )
( 12.143914196981628 6.976166856786729 0.13184655044104865 0.8060187226795344 387.0471053123474
Node 6
( (41674, 12) (41674,

( 11.90181452780362 7.4112316125841575 0.16001487518035856 0.8064095462989815 470.4709634780884
Node 52
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 22.595928389886947 16.161531541897094 0.5461707917887713 0.5900690672428718 392.0153331756592
Node 53
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.003 )
( 17.29537079309526 11.86801107548789 0.35148337687628506 0.7101995027515782 494.00748586654663
Node 54
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.0 )
( 11.288936880994623 5.760531031872838 0.08677595571419296 0.8312154856296929 458.40769386291504
Node 55
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.001 )
( 20.63356385195718 10.89826775573849 0.1638193016263615 0.6903907097071611 257.6645083427429
Node 56
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.001 )
( 16.74489856225525 9.606151086794762 0.1472332625981359 0.7477324515205368 174.37146425247192
Node 57
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.093 )
( 12.855131004353275 7.1819923406193915 

( 22.544562281046264 15.905415046871624 0.5906411440332583 0.5818127030990823 487.4690647125244
Node 104
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 12.02619588347197 7.375787301145305 0.13907765668452968 0.8061603278070211 529.6016767024994
Node 105
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 15.1973330948796 10.513572791570622 0.20155719300236594 0.7519536855439471 242.41566133499146
Node 106
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 22.042288865184073 15.444393885919672 0.5075391154291871 0.5961588347348405 266.1298382282257
Node 107
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 19.904455081045544 12.832200086218384 0.3608899719689012 0.6703187758874696 402.0558228492737
Node 108
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.001 )
( 12.254341500468033 5.460658376249256 0.08270741699121388 0.8160126445678064 279.8398380279541
Node 109
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 15.740319269233456 8.49218235482

( 5.744758794617271 3.627645296463038 0.05454795758727879 0.9139320023171588 942.5625398159027
Node 155
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 5.682760532547693 3.613008073258053 0.05451836717220827 0.9143782290934064 714.6346864700317
Node 156
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.003 )
( 9.3657154641899 6.063690432949797 0.1171585274785303 0.8535634530602805 629.0737237930298
Node 157
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.005 )
( 15.89868271612715 10.956705378758713 0.28337177320440216 0.7246079240502779 528.1783602237701
Node 158
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 8.283148839645149 6.2750309994621665 0.10828313456236728 0.8656822017266775 716.82883644104
Node 159
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.021 )
( 9.37652835078859 5.441159039126315 0.1161256763765725 0.8456404941783817 876.1877098083496
Node 160
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.032 )
( 20.559319102873683 14.228729869997508 0.5

( 6.300784110826001 4.5699160176567934 0.07207753896515366 0.9038860769011643 895.9401383399963
Node 207
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.069 )
( 7.4289614503805605 5.205356813990121 0.08144980428070576 0.8878757021818167 714.9600372314453
Node 208
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.025 )
( 11.17630735426837 7.870798575593905 0.11720912878465067 0.8327254578458919 534.8653926849365
Node 209
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.088 )
( 3.9647272762143766 2.9272981715891953 0.04444907001034389 0.9404732054586035 875.5700221061707
Node 210
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.022 )
( 8.061349586487125 5.660504024080734 0.08431698215699236 0.8793123069335083 660.3460829257965
Node 211
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.128 )
( 9.788026750982207 6.678988493400355 0.11396619924358813 0.8523333397258336 384.37401151657104
Node 212
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.043 )
( 6.974157287759394 3.8530257296

( 14.906639271731828 9.41228626526602 0.2660054764358734 0.7466632165309546 870.046238899231
Node 259
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.012 )
( 21.3291809996714 14.42500078444308 0.5188080546691826 0.6122516223965004 914.9421858787537
Node 260
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.071 )
( 11.248348411074467 6.557026542458907 0.15678274321173036 0.8221179167302991 857.4880740642548
Node 261
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 5.656999361103358 3.229252776698278 0.05013236952886054 0.9145824913117746 877.6312687397003
Node 262
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.01 )
( 4.600869523153057 2.8182161889948625 0.04293229346842372 0.9309737848548207 925.8744790554047
Node 263
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 4.678623731183447 3.1499808894163217 0.04795764327740809 0.9296261748114614 1010.3485431671143
Node 264
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.028 )
( 6.814367604649277 4.068769360557162 0.

( 12.580195316232924 7.354737640964747 0.15450236719282243 0.8004620927245998 461.86721181869507
Node 310
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.006 )
( 19.761219389967913 13.151962382958734 0.3783615256012729 0.6459611600544588 421.0409595966339
Node 311
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 15.87629627118373 11.327517281546656 0.2023868196381392 0.7403148518895679 221.7296063899994
Node 312
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.048 )
( 12.84093701068008 8.5272634350178 0.20132073265743902 0.7855284760901693 540.9047985076904
Node 313
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.016 )
( 20.669177083263094 14.661408514461334 0.5465116601102639 0.616102357620937 777.3014914989471
Node 314
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.038 )
( 14.1593384628157 8.025150024578295 0.2348639249757595 0.769514021922314 574.8753354549408
Node 315
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.025 )
( 5.335390073136595 3.4580598663424524 0.0

### Unobserved Node = 5%

In [11]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 833856
Node 0
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 8.086414310949818 5.28257205022424 0.08526686387938014 0.8786408095009826 532.3421235084534
Node 1
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 15.386915420790967 8.257516004923767 0.12339640776819694 0.7695197566260219 185.87498664855957
Node 2
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.005 )
( 12.694174127897817 8.017809480522788 0.12038009121347648 0.8092914451450123 577.3955385684967
Node 3
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.01 )
( 10.038969832505513 6.295291906116039 0.09908497132413023 0.8474078646855691 500.2025272846222
Node 4
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.693137615730508 9.967945096244929 0.19040078586234432 0.7741615834955231 521.2571790218353
Node 5
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.014 )
( 12.143914196981628 6.976166856786729 0.13184655044104865 0.8060187226795344 384.4397237300873
Node 6
( (41674, 12) (41

( 22.595928389886947 16.161531541897094 0.5461707917887713 0.5900690672428718 390.1860945224762
Node 53
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.003 )
( 17.29537079309526 11.86801107548789 0.35148337687628506 0.7101995027515782 492.0255494117737
Node 54
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.0 )
( 11.288936880994623 5.760531031872838 0.08677595571419296 0.8312154856296929 456.448303937912
Node 55
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.64387826489718 66.5839467614836 1.0 0.0 0.16130876541137695
Node 56
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 66.3775371155932 66.29929207508478 1.0 0.0 0.1622638702392578
Node 57
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 64.57462363363686 64.06846050355578 1.0 0.0 0.15991759300231934
Node 58
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.006 )
( 20.298237519994746 14.036505767468213 0.42882535053039444 0.6348636930602138 445.87389874458313
Node 59
( (41674, 12) (41674, 6) (10406, 12

( 15.1973330948796 10.513572791570622 0.20155719300236594 0.7519536855439471 241.48423624038696
Node 106
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 22.042288865184073 15.444393885919672 0.5075391154291871 0.5961588347348405 264.9306104183197
Node 107
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.004 )
( 19.904455081045544 12.832200086218384 0.3608899719689012 0.6703187758874696 401.0472409725189
Node 108
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.001 )
( 12.254341500468033 5.460658376249256 0.08270741699121388 0.8160126445678064 278.75977778434753
Node 109
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 15.740319269233456 8.492182354826472 0.12904272720485013 0.7634488793631263 280.8009331226349
Node 110
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 5.771301071448847 3.437749292508844 0.053543414802190804 0.9131135037964414 641.3798871040344
Node 111
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 10.002629383584884 5.675821622

( 15.89868271612715 10.956705378758713 0.28337177320440216 0.7246079240502779 526.6677937507629
Node 158
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.002 )
( 8.283148839645149 6.2750309994621665 0.10828313456236728 0.8656822017266775 714.8998618125916
Node 159
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.021 )
( 9.37652835078859 5.441159039126315 0.1161256763765725 0.8456404941783817 872.0159080028534
Node 160
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.032 )
( 20.559319102873683 14.228729869997508 0.5100743322158758 0.6207743883668358 554.7946672439575
Node 161
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.012 )
( 13.16773789052588 7.852817064729462 0.18725926610758867 0.7897071751897303 347.1434600353241
Node 162
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 13.269817560640686 8.610192730154578 0.1284554277823218 0.8010058049815829 359.55000615119934
Node 163
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.016 )
( 11.025301135740056 7.081982546536

( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.022 )
( 8.061349586487125 5.660504024080734 0.08431698215699236 0.8793123069335083 657.7024402618408
Node 211
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.128 )
( 9.788026750982207 6.678988493400355 0.11396619924358813 0.8523333397258336 383.14161229133606
Node 212
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.043 )
( 6.974157287759394 3.853025729665834 0.10860275359926784 0.8912458349384884 917.5150406360626
Node 213
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.0 )
( 63.66038513822558 63.35696873598574 1.0 0.0 0.16074514389038086
Node 214
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.007 )
( 11.010058888825967 7.078123402744662 0.13992000814085 0.8286570473301572 517.254807472229
Node 215
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.015 )
( 6.259008644976315 4.553997337704355 0.07400876546060296 0.9037904196596369 970.7362198829651
Node 216
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.0 )
( 7.1525010285076

( 4.600869523153057 2.8182161889948625 0.04293229346842372 0.9309737848548207 921.9664454460144
Node 263
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.008 )
( 4.678623731183447 3.1499808894163217 0.04795764327740809 0.9296261748114614 1006.414023399353
Node 264
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.028 )
( 6.814367604649277 4.068769360557162 0.06163620955567539 0.8976341471202958 850.6605904102325
Node 265
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.045 )
( 6.130025323849087 3.833816095621772 0.070700281590748 0.9057852595824104 777.8379406929016
Node 266
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.056 )
( 7.95640596488521 4.891793082868614 0.11476769477962974 0.8739094738962465 804.9382696151733
Node 267
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.004 )
( 8.598804024483778 5.530852931999078 0.11181754991231214 0.8621956339102919 684.8217284679413
Node 268
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.028 )
( 8.615481240871528 5.960181235521529 0

( 14.1593384628157 8.025150024578295 0.2348639249757595 0.769514021922314 574.4036490917206
Node 315
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.025 )
( 5.335390073136595 3.4580598663424524 0.05208263083646698 0.9200801783957018 701.9879221916199
Node 316
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.019 )
( 6.972855834075799 3.902139919246878 0.05951688222378587 0.8950749293639721 590.0990166664124
Node 317
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.02 )
( 9.331676635060239 4.919781005358642 0.07411904444909058 0.8595630429685066 537.8296847343445
Node 318
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) 0.003 )
( 7.682818911475401 5.834998833042398 0.09620182459457227 0.8807837780113177 855.8936274051666
Node 319
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.015 )
( 12.04979670231141 8.314755547133055 0.18856957726470688 0.7984247374912765 971.2434585094452
Node 320
( (41674, 12) (41674, 6) (10406, 12) (10406, 6) -0.008 )
( 12.793924053873925 7.951906481348826 0