In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 12

In [8]:
path_0 = 'data/METR-LA/speed_la_0.csv'
path_5 = 'data/METR-LA/speed_la_5.csv'
path_10 = 'data/METR-LA/speed_la_10.csv'
path_20 = 'data/METR-LA/speed_la_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(34272, 207) 0


((1414017, 12), 0)

### Unobserved Node = 0%

In [10]:
data = load_data(path_0)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 0
Node 0
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 22.46590989448668 14.809909925230324 0.6052393200827606 0.5996709408821872 528.1569821834564
Node 1
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.008 )
( 16.369998085694352 9.62728734655152 0.36223680766280003 0.7183419763689802 965.1146650314331
Node 2
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.016 )
( 13.954387515173028 8.99717301574813 0.2051055760072642 0.76978251954308 449.2529525756836
Node 3
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 17.663613972970705 13.894068976288702 0.41332293398215314 0.684144030335341 862.8283324241638
Node 4
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 24.30594300148395 19.804676471648843 0.6290707851491036 0.5522429287262018 579.7229969501495
Node 5
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 14.307503321427916 11.58365098571094 0.19046487285567457 0.7743458527280562 944.0308573246002
Node 6
( (27393, 12) (27393, 12) (68

( 17.74058366760148 15.276562049771078 0.2572909781440615 0.7232976009842818 907.9032511711121
Node 52
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.018 )
( 18.241764491853875 12.01976215524247 0.19546424765743584 0.7188989234205028 446.3124647140503
Node 53
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.018 )
( 12.503875359331378 11.155077631196233 0.18345452525395642 0.8090503966927881 1148.5436871051788
Node 54
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.017 )
( 13.98387365251959 9.408050751121264 0.17168749367354996 0.7822207498929153 670.005532503128
Node 55
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 11.451803740730176 7.750567051261693 0.18894603101342952 0.8166312230554171 1113.8374404907227
Node 56
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.156 )
( 23.54962115675845 18.289445702138558 0.31232458004535824 0.6292501722747075 815.0039775371552
Node 57
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 14.43198841842356 10.640801720239617 0

( 13.929062505428321 8.99096877421119 0.16674414027488915 0.7798604811853035 532.368246793747
Node 104
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.01 )
( 11.929485634132186 7.810883306675844 0.15653131188830324 0.8111912030341776 584.9362595081329
Node 105
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 9.056475450476954 5.447476555720023 0.11293511618344959 0.8579458943258511 886.1125609874725
Node 106
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.002 )
( 14.721027082526716 8.375704744903917 0.25223447835299023 0.7571989426223773 864.8058452606201
Node 107
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 24.26492947920706 16.255375282700083 0.668369082099083 0.5601548483362125 948.9503235816956
Node 108
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.006 )
( 23.5215803254931 15.783546805896682 0.609488250053604 0.5772678045999733 750.1178011894226
Node 109
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.017 )
( 15.365308358919878 11.208840858982944 

( 26.15956222621815 17.53386498733229 0.31096463983166484 0.5737399721518499 357.13479471206665
Node 156
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.033 )
( 9.416725968490534 7.213761306953863 0.1457228875626451 0.8434716293735813 887.8183491230011
Node 157
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.016 )
( 16.881957565570996 10.288951978823487 0.2499989112145491 0.722575076656768 515.501344203949
Node 158
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.004 )
( 18.20556197036279 12.872411993494469 0.2860502794136222 0.6934395795239836 705.9232666492462
Node 159
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 25.51181424308482 18.978619537972353 0.45610640868934 0.5494680734675471 367.141544342041
Node 160
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.58037981277479 23.321530234041745 0.7469628494486876 0.4156552435526585 356.1009645462036
Node 161
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.001 )
( 25.32832835710435 17.856707458956535 0.7

### Unobserved Node = 5%

In [11]:
data = load_data(path_5)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 342720
Node 0
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 22.46590989448668 14.809909925230324 0.6052393200827606 0.5996709408821872 492.6568636894226
Node 1
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.008 )
( 16.369998085694352 9.62728734655152 0.36223680766280003 0.7183419763689802 907.5411438941956
Node 2
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.016 )
( 13.954387515173028 8.99717301574813 0.2051055760072642 0.76978251954308 419.58594822883606
Node 3
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 17.663613972970705 13.894068976288702 0.41332293398215314 0.684144030335341 815.4414632320404
Node 4
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 24.30594300148395 19.804676471648843 0.6290707851491036 0.5522429287262018 545.8853962421417
Node 5
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.037 )
( 14.307503321427916 11.58365098571094 0.19046487285567457 0.7743458527280562 886.7333016395569
Node 6
( (27393, 12) (27393, 1

( 18.241764491853875 12.01976215524247 0.19546424765743584 0.7188989234205028 419.494726896286
Node 53
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.018 )
( 12.503875359331378 11.155077631196233 0.18345452525395642 0.8090503966927881 1085.3528008460999
Node 54
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.017 )
( 13.98387365251959 9.408050751121264 0.17168749367354996 0.7822207498929153 615.4114727973938
Node 55
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.0 )
( 11.451803740730176 7.750567051261693 0.18894603101342952 0.8166312230554171 1028.869057416916
Node 56
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.156 )
( 23.54962115675845 18.289445702138558 0.31232458004535824 0.6292501722747075 746.040450334549
Node 57
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 14.43198841842356 10.640801720239617 0.18338289418354461 0.7741341282867684 496.825786113739
Node 58
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.012 )
( 9.218312488708044 6.41723473800581 0.135

( 11.929485634132186 7.810883306675844 0.15653131188830324 0.8111912030341776 529.7046096324921
Node 105
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.005 )
( 9.056475450476954 5.447476555720023 0.11293511618344959 0.8579458943258511 807.729691028595
Node 106
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.002 )
( 14.721027082526716 8.375704744903917 0.25223447835299023 0.7571989426223773 791.4165427684784
Node 107
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.011 )
( 24.26492947920706 16.255375282700083 0.668369082099083 0.5601548483362125 868.4013366699219
Node 108
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.006 )
( 23.5215803254931 15.783546805896682 0.609488250053604 0.5772678045999733 682.9540324211121
Node 109
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.017 )
( 15.365308358919878 11.208840858982944 0.32490044470227164 0.7428724306930907 855.5187001228333
Node 110
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) -0.012 )
( 13.094748960503203 7.29415247257010

( 18.20556197036279 12.872411993494469 0.2860502794136222 0.6934395795239836 607.4441118240356
Node 159
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 25.51181424308482 18.978619537972353 0.45610640868934 0.5494680734675471 319.57831621170044
Node 160
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.58037981277479 23.321530234041745 0.7469628494486876 0.4156552435526585 310.626256942749
Node 161
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.001 )
( 25.32832835710435 17.856707458956535 0.799718788039263 0.5232552247891401 439.9719805717468
Node 162
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.016 )
( 17.834333072232187 10.760872163937348 0.2839682550740564 0.7079194465598779 349.4061794281006
Node 163
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.036 )
( 29.428188106821878 21.690957370054956 0.4181602419373811 0.5224681208465485 238.2559196949005
Node 164
( (27393, 12) (27393, 12) (6831, 12) (6831, 12) 0.008 )
( 16.416224737663352 11.00447391487122 0.