In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 12

In [8]:
path_0 = '../Data/PEMS-BAY/speed_bay_0.csv'
path_5 = '../Data/PEMS-BAY/speed_bay_5.csv'
path_10 = '../Data/PEMS-BAY/speed_bay_10.csv'
path_20 = '../Data/PEMS-BAY/speed_bay_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(52116, 325) 0


((3380000, 12), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 1667712
Node 0
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.72999811417121 66.62970592948712 1.0 0.0 0.16617369651794434
Node 1
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 15.175368021945536 8.169190360498108 0.12214403934403976 0.7725417462955099 479.71165657043457
Node 2
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.005 )
( 11.873133164568015 7.6333650236816615 0.11440546717839482 0.8217802183263296 1205.6857645511627
Node 3
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.01 )
( 9.972937648425004 6.321861782765244 0.10216123663977943 0.8472606499036938 1095.733115196228
Node 4
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.012 )
( 13.396977204264713 9.756582009798533 0.18633352756839738 0.7792280990164643 1159.6079466342926
Node 5
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 62.643408027588066 62.287137019230755 1.0 0.0 0.23123693466186523
Node 6
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.002 )
( 15.

( 12.233275187521699 7.614571607368049 0.17369558918696051 0.7998961052501115 988.0590734481812
Node 52
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 22.89679245224199 16.414631945798007 0.5636329444831562 0.5824677036223607 874.0934870243073
Node 53
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 60.507745395151524 58.83215625 1.0 0.0 0.23064136505126953
Node 54
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 10.700146326791618 5.489664656646951 0.08264263222114725 0.840134012133253 980.2717816829681
Node 55
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.60109378312774 66.53918589743576 1.0 0.0 0.2277069091796875
Node 56
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 16.63450893603398 9.514300283081951 0.14552051762130827 0.7496023338848479 422.6558117866516
Node 57
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 63.9182214393608 63.263151442307766 1.0 0.0 0.19728469848632812
Node 58
( (41668, 12) (41668, 12) (1040

( 11.533206534154145 7.0543094236067985 0.133254088692306 0.814381274285661 1034.7106130123138
Node 105
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 15.256645992730478 10.59269857825067 0.20737032429782856 0.7502852299441264 544.9578337669373
Node 106
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 22.119166602730594 15.49970433839095 0.5151187546329281 0.5940935463249652 580.6150107383728
Node 107
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 19.312365570150178 12.330736974087307 0.34098194695054757 0.6820503811815299 919.8174347877502
Node 108
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.6487357253528 5.227703851285329 0.0791416896407393 0.8252023652699514 611.6295027732849
Node 109
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.008 )
( 15.491975019975044 8.35408306809204 0.12704246418441062 0.7671274112692027 605.8837575912476
Node 110
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 5.233283941724305 3.2134

( 9.4448579114075 6.068170859105218 0.1197542867324049 0.8521103046572355 1218.327353477478
Node 157
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.005 )
( 15.599377548886721 10.686785923234261 0.2764700917444156 0.7302614180043505 1045.1329617500305
Node 158
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 8.02775955892293 6.020966389577564 0.10533891627988339 0.8697540838997242 1425.0346376895905
Node 159
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.021 )
( 9.135467339699039 5.33078890025303 0.11514046103219279 0.8495626363345758 1593.5587327480316
Node 160
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 54.31609350037085 51.88418910256397 1.0 0.0 0.19799494743347168
Node 161
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.012 )
( 13.115155122977285 7.834151288827944 0.19152218404176435 0.7902713437357584 811.7195916175842
Node 162
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 13.039798840436168 8.527900069541731 0.1272238544566191

( 7.233885626934147 5.043745157347203 0.07922884179284385 0.8907791546383154 1399.1915140151978
Node 208
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.025 )
( 10.843142885369415 7.709414487702229 0.11479440505653331 0.8376929100221729 1067.283400297165
Node 209
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.088 )
( 3.9123250294970853 2.8963371003887186 0.0439574584681695 0.941263114108703 1668.286949634552
Node 210
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.022 )
( 7.894089746105442 5.588921335854985 0.08324878111699581 0.8818100809003936 1332.3688373565674
Node 211
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.128 )
( 9.585108131740334 6.537289102580183 0.11067043309642645 0.8554688937704783 819.2371590137482
Node 212
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.043 )
( 6.926221950196757 3.8039360385581795 0.10720106888258606 0.8920908605776559 1744.0350363254547
Node 213
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.124 )
( 9.852771101078496 5

( 14.320590540706121 9.007065364414329 0.24842487444427652 0.7578372517267578 1705.8430631160736
Node 259
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.012 )
( 21.20003777415256 14.298551625236273 0.5132485573340286 0.6151208347968509 1778.892656326294
Node 260
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.07 )
( 12.304582346297948 7.09239311249614 0.18309990765225717 0.803478605876455 1685.6359086036682
Node 261
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 5.390612040649916 3.1333599095356885 0.048818934340383285 0.9185766704948991 1734.929405450821
Node 262
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.6616480542906 66.60389743589734 1.0 0.0 0.19704055786132812
Node 263
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.008 )
( 4.372622811229427 2.9754683763857614 0.04531713178484503 0.9342195699892438 1871.4959254264832
Node 264
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.59429329784759 66.52443108974387 1.0 0.0 0.19306302

( 19.40695158955879 12.899278968611549 0.36940170712916487 0.6530384629532926 881.8834817409515
Node 311
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.001 )
( 16.285982206610537 11.631682411251093 0.2222777379588883 0.7316831274247371 552.9250960350037
Node 312
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.049 )
( 11.752920130382412 7.846965300639492 0.17702963505426286 0.8051184729080659 1174.1169102191925
Node 313
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.016 )
( 20.071867327188126 14.054309884490232 0.5181334880122508 0.6294926325717182 1468.682029247284
Node 314
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.038 )
( 15.709350787225475 9.114626109854564 0.28889247478162006 0.739591499679102 1219.4666802883148
Node 315
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.025 )
( 5.057404315429675 3.3039167820312145 0.05010282288460396 0.9241463880068884 1481.3573966026306
Node 316
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.019 )
( 6.5445208418915

### Unobserved Node = 20%

In [10]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(52116, 325) 3387540
Node 0
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 7.535285487996442 4.94151428795418 0.07972857724880893 0.8870779903949048 1034.6472074985504
Node 1
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 15.175368021945536 8.169190360498108 0.12214403934403976 0.7725417462955099 472.8891694545746
Node 2
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.005 )
( 11.873133164568015 7.6333650236816615 0.11440546717839482 0.8217802183263296 1190.2848944664001
Node 3
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.01 )
( 9.972937648425004 6.321861782765244 0.10216123663977943 0.8472606499036938 1081.372454404831
Node 4
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.012 )
( 13.396977204264713 9.756582009798533 0.18633352756839738 0.7792280990164643 1143.8376364707947
Node 5
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.014 )
( 11.659197392334741 6.759114931395188 0.12960810203446668 0.813879899586561 844.5400457382202
Node 6
( (

( 16.398129516578038 11.11838812578096 0.31767377332034236 0.7289912322878916 1104.3820600509644
Node 54
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.93197514727444 66.87259615384637 1.0 0.0 0.19748973846435547
Node 55
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 20.425299487482306 10.817095826645454 0.162696066573154 0.6933188581858287 628.2065110206604
Node 56
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 16.63450893603398 9.514300283081951 0.14552051762130827 0.7496023338848479 417.62158846855164
Node 57
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.093 )
( 13.21471140436982 7.50829709236404 0.1586152805693903 0.793255958836298 1515.1595380306244
Node 58
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 55.80089385599664 53.8732804487179 1.0 0.0 0.19931960105895996
Node 59
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.001 )
( 14.443792226567494 9.097263122626542 0.17373201183502449 0.7659753610962101 768.960657119751


( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 11.6487357253528 5.227703851285329 0.0791416896407393 0.8252023652699514 604.797623872757
Node 109
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.008 )
( 15.491975019975044 8.35408306809204 0.12704246418441062 0.7671274112692027 598.5411691665649
Node 110
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.467986957213 66.4029158653845 1.0 0.0 0.19937610626220703
Node 111
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.002 )
( 10.673669735141845 6.026554959218898 0.1519385824743746 0.8326655815005588 1122.048416376114
Node 112
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.027 )
( 23.88339142104978 17.16031911196051 0.6421545470802074 0.5570846100639499 735.4035198688507
Node 113
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.006 )
( 16.088362209486064 9.427546150229643 0.20084311369520758 0.7369758398610492 609.5244009494781
Node 114
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 16.3

( 20.47112679668617 14.10911715060287 0.5051278939234826 0.6231112092671685 1127.9809730052948
Node 161
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 62.53392052659322 61.570851762820745 1.0 0.0 0.19925165176391602
Node 162
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.003 )
( 13.039798840436168 8.527900069541731 0.12722385445661916 0.8044625994579552 824.9181523323059
Node 163
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.016 )
( 10.894153962356908 7.011534935518282 0.10550683352043054 0.8363175228942754 940.2454221248627
Node 164
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.001 )
( 3.5372842514017524 2.585231057172669 0.039162916674503036 0.9468723652467383 1792.4841063022614
Node 165
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.107 )
( 14.370554493960542 8.879844320617195 0.18452217062417411 0.7748187426057556 488.649924993515
Node 166
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.057 )
( 22.295505182258747 15.435247455081429 0.623964780

( 64.18568424578305 63.85480208333319 1.0 0.0 0.19965171813964844
Node 213
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.124 )
( 9.852771101078496 5.460838860237463 0.11272603619628267 0.8451944173899442 963.6824328899384
Node 214
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.007 )
( 10.790597966858096 6.940927294941459 0.13737521159758895 0.8320402859844116 1096.9908142089844
Node 215
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.015 )
( 6.1124024267183525 4.452513019021656 0.072826041059071 0.9060175473115427 1731.0114588737488
Node 216
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.0 )
( 6.658563051351844 3.7392494144567574 0.057084577212929234 0.899176824514621 944.3928382396698
Node 217
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.027 )
( 14.609961116190007 7.227627200530761 0.10903619078138739 0.7803396417876381 627.2305421829224
Node 218
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.035 )
( 8.868004684845173 5.416827961799664 0.08136520331003

( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.045 )
( 5.711902878203843 3.5787064574407093 0.06475911203056649 0.9123782862064316 1531.860105752945
Node 266
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 63.18539426542799 62.85070913461525 1.0 0.0 0.19840383529663086
Node 267
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 8.658558686797493 5.558485809854049 0.11385042828371268 0.8611921420902332 1428.5368492603302
Node 268
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 63.66983118373053 63.48176522435909 1.0 0.0 0.19914817810058594
Node 269
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 65.45688285230504 65.34744791666652 1.0 0.0 0.19947028160095215
Node 270
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.019 )
( 6.403865116591984 3.9812542068730483 0.061329630816722165 0.9033369164583476 1262.59467959404
Node 271
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.003 )
( 5.876304267941636 3.4610988036800556 0.0524792567780704

( 6.544520841891592 3.750853233748726 0.05713403574435852 0.9015969742005325 1199.026612997055
Node 317
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 66.36728873611376 66.30350240384625 1.0 0.0 0.1984858512878418
Node 318
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.004 )
( 7.4315148285397035 5.815320704393486 0.09357343204669154 0.8856259858595636 1621.7532784938812
Node 319
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.0 )
( 59.91357693880895 59.282011217948764 1.0 0.0 0.19862031936645508
Node 320
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.008 )
( 12.671170867632162 7.91191781086905 0.14588305346080158 0.7932493391881795 1188.4743683338165
Node 321
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) 0.052 )
( 15.426903325904464 10.616594879644508 0.2738566985278224 0.7391148835301423 1295.8869717121124
Node 322
( (41668, 12) (41668, 12) (10400, 12) (10400, 12) -0.004 )
( 18.610373586632402 12.240593955270514 0.40537549449028615 0.6714624094850297 1256