In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 6

In [8]:
path_0 = 'data/METR-LA/speed_la_0.csv'
path_5 = 'data/METR-LA/speed_la_5.csv'
path_10 = 'data/METR-LA/speed_la_10.csv'
path_20 = 'data/METR-LA/speed_la_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(34272, 207) 0


((1415259, 6), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 0
Node 0
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 21.942752079139016 14.228088848116748 0.5529735619072663 0.6143398061907871 240.37309789657593
Node 1
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.008 )
( 17.6201409705326 10.569615615232179 0.42364645989671373 0.6932842081613142 523.3573937416077
Node 2
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.016 )
( 13.908776549476851 8.944727539260889 0.19759490845819616 0.7710771339356364 216.4753339290619
Node 3
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.012 )
( 17.484128503036107 13.642478136558514 0.3867677010950514 0.6911352861122729 460.61068415641785
Node 4
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 24.52888942846362 20.000228926110875 0.6506375959991342 0.5400362880085023 296.1851348876953
Node 5
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 14.955161178972348 12.064159689520409 0.20161804163774585 0.7638189230371057 537.655561208725
Node 6
( (27399, 12) (27399, 6) (6837, 12

( 18.72442097204015 12.347813753921201 0.19694796231824374 0.7110050933737064 210.2603108882904
Node 53
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.018 )
( 12.662941103839879 11.237992850626 0.18915322932754874 0.806678545399882 588.9319169521332
Node 54
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.017 )
( 14.230012278842757 9.492910465495981 0.17054865057374224 0.7787490085956176 342.99280881881714
Node 55
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 11.480201966045449 7.773062838205749 0.18968671973217674 0.8161434065914832 569.395592212677
Node 56
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.156 )
( 23.70605799574874 18.36289401630773 0.31524496356026505 0.6263011607107036 397.4082827568054
Node 57
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.005 )
( 14.541061058374476 10.688284003761254 0.18260914362424152 0.7727168954015173 259.7489187717438
Node 58
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 9.777359752777693 6.772563090488759 0.1428756692468

( 12.259744899859097 7.961192087658016 0.15882013345102863 0.8059617816752735 293.8764657974243
Node 105
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.006 )
( 9.056167720649235 5.44435046771531 0.11298713681744639 0.8579690193916952 463.86428141593933
Node 106
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.002 )
( 14.695667939252589 8.388612861118277 0.25063014441172476 0.7575928351062621 430.9281997680664
Node 107
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 24.304647018168346 16.285530639681756 0.6709779692790304 0.5593019113492357 474.2232756614685
Node 108
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.006 )
( 23.506582279199826 15.809080711274014 0.6093410655790531 0.5773182869886242 365.6978530883789
Node 109
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.017 )
( 15.416219837735529 11.263078685288711 0.32345420738548625 0.7422967108847747 533.6334462165833
Node 110
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.013 )
( 13.25639658408047 7.458022850213069 0.219

( 9.431796075021156 7.241813945985818 0.14533608059952305 0.8436463665658702 476.3340675830841
Node 157
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.016 )
( 17.48202614975135 10.670831326490628 0.26090618920040076 0.7118655883481435 252.83797812461853
Node 158
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.004 )
( 18.512495680661047 13.113284029882033 0.28758022715804343 0.6883899277056098 361.8340766429901
Node 159
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 25.992739832384288 19.34739581885315 0.47023344798800926 0.5388140177880644 179.85464191436768
Node 160
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.980267582539355 23.671549761550597 0.7677790860618559 0.4047065968887802 167.92269206047058
Node 161
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 24.920129140058698 17.456598159225784 0.7584644419075307 0.5352456787826938 262.1902275085449
Node 162
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.016 )
( 18.117496625436505 10.872603550378772 0.

### Unobserved Node = 5%

In [11]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 342720
Node 0
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 21.942752079139016 14.228088848116748 0.5529735619072663 0.6143398061907871 233.13194966316223
Node 1
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.008 )
( 17.6201409705326 10.569615615232179 0.42364645989671373 0.6932842081613142 506.6892156600952
Node 2
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.016 )
( 13.908776549476851 8.944727539260889 0.19759490845819616 0.7710771339356364 210.61586952209473
Node 3
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.012 )
( 17.484128503036107 13.642478136558514 0.3867677010950514 0.6911352861122729 449.0275089740753
Node 4
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 24.52888942846362 20.000228926110875 0.6506375959991342 0.5400362880085023 284.8460192680359
Node 5
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 14.955161178972348 12.064159689520409 0.20161804163774585 0.7638189230371057 521.0201258659363
Node 6
( (27399, 12) (27399, 6) (68

( 18.72442097204015 12.347813753921201 0.19694796231824374 0.7110050933737064 203.10964918136597
Node 53
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.018 )
( 12.662941103839879 11.237992850626 0.18915322932754874 0.806678545399882 568.0560421943665
Node 54
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.017 )
( 14.230012278842757 9.492910465495981 0.17054865057374224 0.7787490085956176 332.6990211009979
Node 55
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 11.480201966045449 7.773062838205749 0.18968671973217674 0.8161434065914832 554.7440645694733
Node 56
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.156 )
( 23.70605799574874 18.36289401630773 0.31524496356026505 0.6263011607107036 381.3267776966095
Node 57
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.005 )
( 14.541061058374476 10.688284003761254 0.18260914362424152 0.7727168954015173 251.1396210193634
Node 58
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 9.777359752777693 6.772563090488759 0.142875669246

( 9.056167720649235 5.44435046771531 0.11298713681744639 0.8579690193916952 456.2209384441376
Node 106
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.002 )
( 14.695667939252589 8.388612861118277 0.25063014441172476 0.7575928351062621 421.9783537387848
Node 107
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 24.304647018168346 16.285530639681756 0.6709779692790304 0.5593019113492357 466.1492929458618
Node 108
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.006 )
( 23.506582279199826 15.809080711274014 0.6093410655790531 0.5773182869886242 356.58567786216736
Node 109
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.017 )
( 15.416219837735529 11.263078685288711 0.32345420738548625 0.7422967108847747 521.0873057842255
Node 110
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.013 )
( 13.25639658408047 7.458022850213069 0.21932101527879588 0.7873257558666211 398.68617963790894
Node 111
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.004 )
( 17.280178885312992 9.984604348559552 0.25

( 18.512495680661047 13.113284029882033 0.28758022715804343 0.6883899277056098 347.31033515930176
Node 159
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 25.992739832384288 19.34739581885315 0.47023344798800926 0.5388140177880644 170.9080159664154
Node 160
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.980267582539355 23.671549761550597 0.7677790860618559 0.4047065968887802 159.26699423789978
Node 161
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 24.920129140058698 17.456598159225784 0.7584644419075307 0.5352456787826938 250.22873497009277
Node 162
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.016 )
( 18.117496625436505 10.872603550378772 0.28178238395448757 0.703939644874463 186.8733265399933
Node 163
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.850652559507473 22.070994533604964 0.42567279752097226 0.5162386913259689 117.52667951583862
Node 164
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.008 )
( 16.90411531016045 11.282932517463992 0