In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 3

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3382925, 3), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 1667712
Node 0
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.58787792292135 66.47829122233931 1.0 0.0 0.11281871795654297
Node 1
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 15.446384759071634 8.290043235390383 0.12390448963361052 0.768713163875233 76.327463388443
Node 2
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.005 )
( 12.967767888338432 8.15823505548368 0.12257644814501197 0.8051230490893208 289.83680748939514
Node 3
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.009 )
( 10.144471452797013 6.2981759283069705 0.09828100866950229 0.8462816461074736 230.64129495620728
Node 4
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.786456414932635 10.063314361403734 0.1915501973957745 0.7727353266499768 244.12071132659912
Node 5
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 62.568438056806755 62.22415537835852 1.0 0.0 0.17021703720092773
Node 6
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.003 )
( 15.828485538407923 9

( 22.37358562433259 15.966868654383456 0.5324359119155827 0.5953673585690135 168.62578916549683
Node 53
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 59.233879775089186 57.24086527684364 1.0 0.0 0.1437089443206787
Node 54
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.0 )
( 11.464792401385349 5.838970926692169 0.08805460942962347 0.8285358081852988 217.08887600898743
Node 55
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.6623206736627 66.60304864380161 1.0 0.0 0.1434948444366455
Node 56
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 16.81417374930043 9.659099435052328 0.14792075245040734 0.7466409333135464 71.99956893920898
Node 57
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 64.851587345327 64.40700035225949 1.0 0.0 0.1451551914215088
Node 58
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 55.561275395712144 53.55888493931533 1.0 0.0 0.14655637741088867
Node 59
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 61.393219636

( 15.110743398014568 10.448526475368437 0.19839313430883432 0.7537048744767704 105.15357637405396
Node 106
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 22.031751126596216 15.455261201903388 0.504263371278351 0.5968199772916151 120.64325547218323
Node 107
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 20.168719115716215 13.064522484176939 0.37099550222182237 0.664835962295883 168.42109179496765
Node 108
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.001 )
( 12.452369481064418 5.544168544773819 0.0840351587493131 0.8129968737493427 131.09016299247742
Node 109
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.007 )
( 15.851992947672931 8.554453567791755 0.12992495607566398 0.761794756946814 127.2112889289856
Node 110
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.004 )
( 6.073320049316174 3.5541212594471077 0.05531684162059324 0.9085373088505172 331.67411041259766
Node 111
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 9.71744209963618 5.516411038

( 16.033718958579893 11.077704072150135 0.2862374196214407 0.7221005955832245 262.18830966949463
Node 158
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 8.404371266266248 6.37845861160148 0.10946445465712906 0.8637437418715843 356.31496500968933
Node 159
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.021 )
( 9.498385408577706 5.504694052968176 0.11693488297943612 0.8436336911961483 494.50448536872864
Node 160
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 54.1808670365812 51.70783616741904 1.0 0.0 0.1427755355834961
Node 161
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.114737372618444 7.841921470041235 0.1848224194959557 0.7906762003725853 138.48969054222107
Node 162
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 13.38082032128147 8.641363907414762 0.12887273083421794 0.7993251845553853 155.1460349559784
Node 163
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.017 )
( 11.104693759356088 7.114500165052359 0.10706631779999365 0.83312

( 4.017349950149848 2.9542605959919643 0.044890086507385574 0.9396835518689832 475.25614166259766
Node 210
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.023 )
( 8.159767375570762 5.695299138289731 0.08480304238865725 0.8778419798890241 329.39169359207153
Node 211
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.128 )
( 9.91425788422133 6.775557882927781 0.11624979568831816 0.8503803937385735 184.39181804656982
Node 212
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.043 )
( 6.97663028707149 3.848524873016792 0.10839641334919747 0.8911627108111557 496.74881958961487
Node 213
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.124 )
( 9.78521378482098 5.451843574338729 0.11001921940008501 0.8463137710530234 178.65810370445251
Node 214
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.007 )
( 11.13105826434843 7.164571445934203 0.14136653319229855 0.8267933421403404 241.54156064987183
Node 215
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.014 )
( 6.361300852944765 4.620941344814

( 5.7801107598131924 3.286061260156775 0.050955215732597786 0.9127582808501029 458.6824195384979
Node 262
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.64741160991242 66.58929452076723 1.0 0.0 0.14463067054748535
Node 263
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.008 )
( 4.876244208307443 3.2610351121770136 0.049638302452278606 0.9266632192785139 537.2656915187836
Node 264
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.55146824220287 66.48181061261123 1.0 0.0 0.14336872100830078
Node 265
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.045 )
( 6.349722355519494 3.9730573267184837 0.07418753146392822 0.9022947737808698 388.302707195282
Node 266
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.056 )
( 7.996574062808876 4.927389773489488 0.11547158113028155 0.8732063787288544 367.140353679657
Node 267
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.004 )
( 8.519688268775464 5.494989879443064 0.1099446700855617 0.8635047163817842 336.42229533195496
No

( 13.40448876140436 7.528618330811916 0.21127799982528198 0.783567741266659 288.27323508262634
Node 315
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.025 )
( 5.523903336756153 3.5642468106773193 0.05361021720951145 0.9173004022471387 344.5964422225952
Node 316
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.019 )
( 7.1770536431527985 3.9660723286943096 0.060515779484339526 0.8919644853275697 294.1873860359192
Node 317
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.02 )
( 9.486528666975294 4.994472785039958 0.0751463504297732 0.857353747581779 270.3048982620239
Node 318
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 7.800068722132524 5.832593166769751 0.09764252250372722 0.8783872859797774 468.8399953842163
Node 319
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.015 )
( 12.148238928722014 8.36651688689266 0.19079815437699424 0.7967038563515676 523.2068955898285
Node 320
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.008 )
( 12.8258224088714 7.958419704241527

### Unobserved Node = 20%

In [11]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 3387540
Node 0
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 8.328685226158738 5.429875924958618 0.08769796202071951 0.8749219004125711 273.76130056381226
Node 1
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 15.446384759071634 8.290043235390383 0.12390448963361052 0.768713163875233 75.52606534957886
Node 2
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.005 )
( 12.967767888338432 8.15823505548368 0.12257644814501197 0.8051230490893208 287.2674684524536
Node 3
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.009 )
( 10.144471452797013 6.2981759283069705 0.09828100866950229 0.8462816461074736 229.14789724349976
Node 4
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.012 )
( 13.786456414932635 10.063314361403734 0.1915501973957745 0.7727353266499768 242.58625197410583
Node 5
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.014 )
( 12.276460863316059 7.065478304742019 0.13273349172809903 0.8037914762684333 171.47734832763672
Node 6
( (41677, 12

( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.001 )
( 20.707779641052 10.932889453620911 0.16433387518483555 0.6893630549943733 102.93429446220398
Node 56
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 16.81417374930043 9.659099435052328 0.14792075245040734 0.7466409333135464 71.88474941253662
Node 57
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.092 )
( 12.748419585549017 7.064922949224777 0.13350214612810535 0.803421626094282 390.1756272315979
Node 58
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 55.561275395712144 53.55888493931533 1.0 0.0 0.1453385353088379
Node 59
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 15.14709802249316 9.550391814130078 0.18834275320877122 0.7532773470384845 152.59548139572144
Node 60
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.003 )
( 14.96941104289762 9.079842160767193 0.1845491890753809 0.755883513079023 129.29031133651733
Node 61
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.008 )
( 22.597145371359

( 15.851992947672931 8.554453567791755 0.12992495607566398 0.761794756946814 126.95336890220642
Node 110
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.40215778683132 66.3375828609857 1.0 0.0 0.14471125602722168
Node 111
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.002 )
( 9.71744209963618 5.516411038819778 0.12917755096195996 0.8489731512364468 265.7799587249756
Node 112
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.026 )
( 24.196049154352334 17.471283062580863 0.6558867533496774 0.5500583813014162 137.02335143089294
Node 113
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.005 )
( 16.765048138498138 9.871337730509497 0.21304961680514758 0.7250224644003915 119.57234477996826
Node 114
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 16.434294509403 10.006426472751093 0.19788152379877888 0.7317319324528234 89.55238366127014
Node 115
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.171 )
( 20.24555429330639 14.923834669422824 0.4717024839238826 0.62608735

( 13.38082032128147 8.641363907414762 0.12887273083421794 0.7993251845553853 154.8305377960205
Node 163
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.017 )
( 11.104693759356088 7.114500165052359 0.10706631779999365 0.8331202279658864 211.71930646896362
Node 164
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 3.7385975042752184 2.7082138479662654 0.040967077743068465 0.9438386150829476 484.2887260913849
Node 165
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.107 )
( 14.584606650178861 9.05083523877205 0.18253038145353487 0.771718698000935 83.50419330596924
Node 166
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.057 )
( 22.78603098582416 15.907745562348476 0.6402553481975347 0.5770855676672191 264.08379793167114
Node 167
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.002 )
( 9.71993999078136 6.211955023969885 0.1398557109001487 0.8408494523388319 279.56393575668335
Node 168
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.007 )
( 10.737239389840628 7.415194580

( 6.361300852944765 4.620941344814713 0.07484729708129075 0.9022243793731324 525.5287413597107
Node 216
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.0 )
( 7.322412274139606 3.965559657377447 0.06014976120867876 0.8892398406109306 239.75883650779724
Node 217
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.026 )
( 14.789263257567866 7.2926575383973535 0.1100488931372992 0.7776338862917136 112.66315865516663
Node 218
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.035 )
( 9.338996965635266 5.650340271912843 0.08485954331057567 0.8591663613914968 299.8289384841919
Node 219
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.022 )
( 13.926409404866284 8.322700759398815 0.17149866290554694 0.777833267735517 180.50719571113586
Node 220
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.013 )
( 21.227980682461194 15.256611963267915 0.5032257428577341 0.608691684064618 171.84483551979065
Node 221
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.003 )
( 9.487381325117315 6.01519040173919

( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 65.54556990671841 65.44036891151875 1.0 0.0 0.14503931999206543
Node 270
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.019 )
( 7.043026514428761 4.2742547372693025 0.06555659629556546 0.893869542960125 333.3000328540802
Node 271
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.002 )
( 6.647647055843654 3.825968435873127 0.05796267814861905 0.9003202396437286 311.7248032093048
Node 272
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.103 )
( 5.866309508697926 4.258003422032321 0.06374766767526126 0.9117250925046825 440.4962749481201
Node 273
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.001 )
( 5.5623311689128325 4.362548560222195 0.06525227729254576 0.9165039575384295 511.0553870201111
Node 274
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.012 )
( 14.489526200510708 8.21279224512221 0.12785702138280164 0.7804497578752462 200.53999042510986
Node 275
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.02 )
( 15.263050

( 16.645014915214695 11.482687442807793 0.3110099551455403 0.7154313476731705 361.9089765548706
Node 322
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.004 )
( 18.382068501265643 12.118237712459093 0.39558073450694625 0.6770776802589511 319.45578932762146
Node 323
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) -0.009 )
( 10.170687451348508 6.223413551916369 0.11524515001927749 0.8427960281359103 261.23873019218445
Node 324
( (41677, 12) (41677, 3) (10409, 12) (10409, 3) 0.0 )
( 66.61107950988561 66.51355237454781 1.0 0.0 0.14419794082641602
RMSE: 23.02613452072738, MAE: 19.146448000067014, MAPE: 0.35030181282809175, ACC: 0.6327720479857909
