In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 12

In [8]:
path_0 = 'data/METR-LA/speed_la_0.csv'
path_5 = 'data/METR-LA/speed_la_5.csv'
path_10 = 'data/METR-LA/speed_la_10.csv'
path_20 = 'data/METR-LA/speed_la_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(34272, 207) 0


((1414017, 12), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 685440
Node 0
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 22.46590989448668 14.809909925230324 0.6052393200827606 0.5996709408821872 531.7773349285126
Node 1
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.008 )
( 16.369998085694352 9.62728734655152 0.36223680766280003 0.7183419763689802 971.9055335521698
Node 2
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.016 )
( 13.954387515173028 8.99717301574813 0.2051055760072642 0.76978251954308 448.74230790138245
Node 3
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 17.663613972970705 13.894068976288702 0.41332293398215314 0.684144030335341 864.0362594127655
Node 4
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 24.30594300148395 19.804676471648843 0.6290707851491036 0.5522429287262018 583.3231151103973
Node 5
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 14.307503321427916 11.58365098571094 0.19046487285567457 0.7743458527280562 944.9260218143463
Node 6
( (27393, 12) (27393, 1

( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.017 )
( 13.98387365251959 9.408050751121264 0.17168749367354996 0.7822207498929153 670.5534138679504
Node 55
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 11.451803740730176 7.750567051261693 0.18894603101342952 0.8166312230554171 1110.0499987602234
Node 56
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.156 )
( 23.54962115675845 18.289445702138558 0.31232458004535824 0.6292501722747075 811.5533835887909
Node 57
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 14.43198841842356 10.640801720239617 0.18338289418354461 0.7741341282867684 542.0399901866913
Node 58
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 9.218312488708044 6.41723473800581 0.13535845055693396 0.8562341914668634 811.8851146697998
Node 59
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 62.73339257743232 62.41439984055719 1.0 0.0 0.15608978271484375
Node 60
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.089 )
( 19.92968922378718

( 14.721027082526716 8.375704744903917 0.25223447835299023 0.7571989426223773 874.3375895023346
Node 107
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 24.26492947920706 16.255375282700083 0.668369082099083 0.5601548483362125 954.8273935317993
Node 108
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.006 )
( 23.5215803254931 15.783546805896682 0.609488250053604 0.5772678045999733 755.4489455223083
Node 109
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.017 )
( 15.365308358919878 11.208840858982944 0.32490044470227164 0.7428724306930907 945.7476930618286
Node 110
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.012 )
( 13.094748960503203 7.294152472570102 0.21608209659713784 0.7899550562102844 768.6581254005432
Node 111
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.004 )
( 16.95141091520483 9.814082189885255 0.2511762536799874 0.7285094529353582 478.62315130233765
Node 112
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.004 )
( 16.882253147478657 12.597454059691

( 18.20556197036279 12.872411993494469 0.2860502794136222 0.6934395795239836 706.9466328620911
Node 159
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 25.51181424308482 18.978619537972353 0.45610640868934 0.5494680734675471 369.3043887615204
Node 160
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.58037981277479 23.321530234041745 0.7469628494486876 0.4156552435526585 356.7027678489685
Node 161
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.001 )
( 25.32832835710435 17.856707458956535 0.799718788039263 0.5232552247891401 513.7611346244812
Node 162
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.016 )
( 17.834333072232187 10.760872163937348 0.2839682550740564 0.7079194465598779 410.49827885627747
Node 163
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.428188106821878 21.690957370054956 0.4181602419373811 0.5224681208465485 274.7805588245392
Node 164
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 62.186779806794654 61.61528242028451 1.0

### Unobserved Node = 20%

In [11]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 1405152
Node 0
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 56.118608886384024 53.399124381612275 1.0 0.0 0.1139061450958252
Node 1
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 58.12011983418416 56.44941079377471 1.0 0.0 0.1116938591003418
Node 2
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 60.61393551644 59.85311039232138 1.0 0.0 0.11066007614135742
Node 3
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 17.663613972970705 13.894068976288702 0.41332293398215314 0.684144030335341 828.264018535614
Node 4
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 24.30594300148395 19.804676471648843 0.6290707851491036 0.5522429287262018 556.5575308799744
Node 5
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 14.307503321427916 11.58365098571094 0.19046487285567457 0.7743458527280562 905.6521325111389
Node 6
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.029 )
( 13.444152799545728 9.031709994212202 0.1498686146430938 0.784

( 12.503875359331378 11.155077631196233 0.18345452525395642 0.8090503966927881 1083.4863531589508
Node 54
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.017 )
( 13.98387365251959 9.408050751121264 0.17168749367354996 0.7822207498929153 623.3970518112183
Node 55
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 11.451803740730176 7.750567051261693 0.18894603101342952 0.8166312230554171 1050.2097566127777
Node 56
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.156 )
( 23.54962115675845 18.289445702138558 0.31232458004535824 0.6292501722747075 765.2821226119995
Node 57
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 14.43198841842356 10.640801720239617 0.18338289418354461 0.7741341282867684 510.1839168071747
Node 58
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 64.12033975785896 63.7014747631596 1.0 0.0 0.11130213737487793
Node 59
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 62.73339257743232 62.41439984055719 1.0 0.0 0.11031126976013184
Node 60


( 24.26492947920706 16.255375282700083 0.668369082099083 0.5601548483362125 897.6498305797577
Node 108
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 55.64180013125059 52.772091167018715 1.0 0.0 0.1435234546661377
Node 109
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.017 )
( 15.365308358919878 11.208840858982944 0.32490044470227164 0.7428724306930907 885.6337349414825
Node 110
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 62.34260498844857 61.2478510906549 1.0 0.0 0.14773774147033691
Node 111
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 62.43830990980543 61.33459857372886 1.0 0.0 0.14802074432373047
Node 112
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.004 )
( 16.882253147478657 12.597454059691318 0.27330626412811104 0.7296170704834 983.9182538986206
Node 113
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.006 )
( 16.860798509662065 11.186712478027344 0.25733069930248964 0.729960683849098 663.764306306839
Node 114
( (27393, 12) (27393, 12) (68

( 25.32832835710435 17.856707458956535 0.799718788039263 0.5232552247891401 462.19549894332886
Node 162
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.016 )
( 17.834333072232187 10.760872163937348 0.2839682550740564 0.7079194465598779 368.3841269016266
Node 163
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.428188106821878 21.690957370054956 0.4181602419373811 0.5224681208465485 252.3613040447235
Node 164
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.008 )
( 16.416224737663352 11.00447391487122 0.2090893071395994 0.7360174495501104 529.894504070282
Node 165
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 61.73759774482919 61.193247069847324 1.0 0.0 0.10858654975891113
Node 166
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.004 )
( 11.83396567649907 7.414679380250033 0.18457977732830302 0.8131390167505262 396.222181558609
Node 167
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.009 )
( 15.924894989497503 9.104836664851542 0.2993046345050212 0.736675587