In [1]:
import math
import time
import pandas as pd
import numpy as np
import numpy.linalg as la

from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor

In [2]:
def load_data(path):
    df = pd.read_csv(path)
    data = df.drop(columns=['Unnamed: 0'])
    
    return data

In [3]:
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    
    trainX, trainY, testX, testY = [], [], [], []
    for i in range(len(train_data) - seq_len - pre_len):
        a = train_data[i: i + seq_len + pre_len]
        trainX.append(a[0 : seq_len])
        trainY.append(a[seq_len : seq_len + pre_len])
    for i in range(len(test_data) - seq_len -pre_len):
        b = test_data[i: i + seq_len + pre_len]
        testX.append(b[0 : seq_len])
        testY.append(b[seq_len : seq_len + pre_len])
        
    return trainX, trainY, testX, testY

In [4]:
def getTestY(data):
    tmp_scaler = StandardScaler()
    tmp_data = tmp_scaler.fit_transform(data)
    tmp_data = tmp_scaler.inverse_transform(tmp_data)
    _, _, _, testY = preprocess_data(tmp_data, data.shape[0], train_rate, seq_len, pre_len)
    testY = np.array(testY)
    testY = np.reshape(testY, [-1, pre_len])
    
    return testY

In [5]:
 def evaluation(a,b):
    rmse = math.sqrt(mean_squared_error(a,b))
    mae = mean_absolute_error(a, b)
    mape = mean_absolute_percentage_error(a, b)
    F_norm = la.norm(a-b)/la.norm(a)
    
    return rmse, mae, mape, 1-F_norm

In [6]:
def predictSVR(data, testY):
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    rmses, maes, mapes, accs = [], [], [], []
    for i in range(data.shape[1]):
        print('Node', i)
        start = time.time()
        a = data[:, i]
        aX, aY, tX, tY = preprocess_data(a, data.shape[0], train_rate, seq_len, pre_len)
        
        aX = np.array(aX)
        aX = np.reshape(aX, [-1, seq_len])
        aY = np.array(aY)
        aY = np.reshape(aY, [-1, pre_len])

        tX = np.array(tX)
        tX = np.reshape(tX, [-1, seq_len])
        tY = np.array(tY)
        tY = np.reshape(tY, [-1, pre_len])
        
        print('(', aX.shape, aY.shape, tX.shape, tY.shape, round(aX.mean(),3), ')')
        reg = MultiOutputRegressor(SVR(kernel='linear'))
        reg.fit(aX, aY)
        pred = reg.predict(tX)
        
        mean = scaler.mean_[i]
        std = np.sqrt(scaler.var_[i])
        pred = pred*std + mean
        tY = tY*std + mean
        if i==data.shape[1]-1: tY = testY[tY.shape[0]*i:]
        else: tY = testY[tY.shape[0] * i:tY.shape[0] * (i+1)]
        
        rmse, mae, mape, acc = evaluation(tY, pred)
        rmses.append(rmse)
        maes.append(mae)
        mapes.append(mape)
        accs.append(acc)
        
        print('(', rmse, mae, mape, acc, time.time() - start)
    
    print('RMSE: ' + str(np.mean(rmses)) + ', MAE: ' + str(np.mean(maes)) + ', MAPE: ' + str(np.mean(mapes)) + ', ACC: ' + str(np.mean(accs)))

In [7]:
train_rate = 0.8
seq_len = 12
pre_len = 6

In [8]:
path_0 = 'data/METR-LA/speed_la_0.csv'
path_5 = 'data/METR-LA/speed_la_5.csv'
path_10 = 'data/METR-LA/speed_la_10.csv'
path_20 = 'data/METR-LA/speed_la_20.csv'

In [9]:
masterData = load_data(path_0)
print(masterData.shape, (masterData == 0).sum().sum())
testY = getTestY(masterData)
testY.shape, (testY == 0).sum().sum()

(34272, 207) 0


((1415259, 6), 0)

### Unobserved Node = 10%

In [10]:
data = load_data(path_10)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 685440
Node 0
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 21.942752079139016 14.228088848116748 0.5529735619072663 0.6143398061907871 243.1716160774231
Node 1
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.008 )
( 17.6201409705326 10.569615615232179 0.42364645989671373 0.6932842081613142 525.4387438297272
Node 2
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.016 )
( 13.908776549476851 8.944727539260889 0.19759490845819616 0.7710771339356364 217.4322111606598
Node 3
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.012 )
( 17.484128503036107 13.642478136558514 0.3867677010950514 0.6911352861122729 469.5645875930786
Node 4
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 24.52888942846362 20.000228926110875 0.6506375959991342 0.5400362880085023 295.0154604911804
Node 5
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 14.955161178972348 12.064159689520409 0.20161804163774585 0.7638189230371057 537.4303917884827
Node 6
( (27399, 12) (27399, 6) (6837

( 14.230012278842757 9.492910465495981 0.17054865057374224 0.7787490085956176 343.13099455833435
Node 55
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 11.480201966045449 7.773062838205749 0.18968671973217674 0.8161434065914832 572.534773349762
Node 56
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.156 )
( 23.70605799574874 18.36289401630773 0.31524496356026505 0.6263011607107036 398.977331161499
Node 57
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.005 )
( 14.541061058374476 10.688284003761254 0.18260914362424152 0.7727168954015173 260.4702479839325
Node 58
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.011 )
( 9.777359752777693 6.772563090488759 0.14287566924686765 0.8473797480639704 421.8633563518524
Node 59
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 62.866645227959694 62.551305851757384 1.0 0.0 0.10729289054870605
Node 60
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.089 )
( 20.60405850876582 15.755408119751756 0.25779619433710416 0.6648816904611248 272.1

( 24.304647018168346 16.285530639681756 0.6709779692790304 0.5593019113492357 479.42891454696655
Node 108
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.006 )
( 23.506582279199826 15.809080711274014 0.6093410655790531 0.5773182869886242 370.7323434352875
Node 109
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.017 )
( 15.416219837735529 11.263078685288711 0.32345420738548625 0.7422967108847747 538.4210243225098
Node 110
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.013 )
( 13.25639658408047 7.458022850213069 0.21932101527879588 0.7873257558666211 413.66854524612427
Node 111
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.004 )
( 17.280178885312992 9.984604348559552 0.2536083955078186 0.7232499580097149 231.13935375213623
Node 112
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.004 )
( 16.91808837528116 12.595819089901228 0.27347939637467666 0.7290587076627846 489.8331000804901
Node 113
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.006 )
( 17.552931804087695 11.598573896683332 0

( 25.992739832384288 19.34739581885315 0.47023344798800926 0.5388140177880644 183.1616668701172
Node 160
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.980267582539355 23.671549761550597 0.7677790860618559 0.4047065968887802 170.56272339820862
Node 161
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 24.920129140058698 17.456598159225784 0.7584644419075307 0.5352456787826938 264.15870237350464
Node 162
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.016 )
( 18.117496625436505 10.872603550378772 0.28178238395448757 0.703939644874463 195.95208549499512
Node 163
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.850652559507473 22.070994533604964 0.42567279752097226 0.5162386913259689 124.12187933921814
Node 164
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 62.08026864411319 61.517459785419156 1.0 0.0 0.10781574249267578
Node 165
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.018 )
( 19.895184912591276 14.220998095091439 0.25614707159794037 0.6784448587

### Unobserved Node = 20%

In [11]:
data = load_data(path_20)
print(data.shape, (data == 0).sum().sum())
predictSVR(data, testY)

(34272, 207) 1405152
Node 0
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 56.89659558174184 54.39488904714701 1.0 0.0 0.08142209053039551
Node 1
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 57.44777882124741 55.526076754788164 1.0 0.0 0.10300517082214355
Node 2
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 60.757480406375294 60.05083228666306 1.0 0.0 0.1018369197845459
Node 3
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.012 )
( 17.484128503036107 13.642478136558514 0.3867677010950514 0.6911352861122729 451.0477488040924
Node 4
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 24.52888942846362 20.000228926110875 0.6506375959991342 0.5400362880085023 290.71330976486206
Node 5
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.037 )
( 14.955161178972348 12.064159689520409 0.20161804163774585 0.7638189230371057 526.9345541000366
Node 6
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.029 )
( 13.773720534124191 9.190637622922315 0.15168333053977331 0.780442721

( 12.662941103839879 11.237992850626 0.18915322932754874 0.806678545399882 572.1602592468262
Node 54
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.017 )
( 14.230012278842757 9.492910465495981 0.17054865057374224 0.7787490085956176 332.20093727111816
Node 55
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.001 )
( 11.480201966045449 7.773062838205749 0.18968671973217674 0.8161434065914832 554.8874337673187
Node 56
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.156 )
( 23.70605799574874 18.36289401630773 0.31524496356026505 0.6263011607107036 384.55732893943787
Node 57
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.005 )
( 14.541061058374476 10.688284003761254 0.18260914362424152 0.7727168954015173 252.41074514389038
Node 58
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 64.06331812946969 63.62375116371942 1.0 0.0 0.10312747955322266
Node 59
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 62.866645227959694 62.551305851757384 1.0 0.0 0.10473275184631348
Node 60
( (27399,

( 24.304647018168346 16.285530639681756 0.6709779692790304 0.5593019113492357 468.81162452697754
Node 108
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 55.61296255692791 52.74841148922803 1.0 0.0 0.10987305641174316
Node 109
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.017 )
( 15.416219837735529 11.263078685288711 0.32345420738548625 0.7422967108847747 522.8262434005737
Node 110
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 62.331932284976816 61.233238620314324 1.0 0.0 0.10603141784667969
Node 111
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 62.43966129522564 61.336523259296264 1.0 0.0 0.10794687271118164
Node 112
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.004 )
( 16.91808837528116 12.595819089901228 0.27347939637467666 0.7290587076627846 476.8409535884857
Node 113
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.006 )
( 17.552931804087695 11.598573896683332 0.2636027565618508 0.7188658661023635 342.90063977241516
Node 114
( (27399, 12) (27399, 6) (6837

( 24.920129140058698 17.456598159225784 0.7584644419075307 0.5352456787826938 255.47685980796814
Node 162
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.016 )
( 18.117496625436505 10.872603550378772 0.28178238395448757 0.703939644874463 189.35168051719666
Node 163
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.036 )
( 29.850652559507473 22.070994533604964 0.42567279752097226 0.5162386913259689 119.96042490005493
Node 164
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.008 )
( 16.90411531016045 11.282932517463992 0.2118610449997698 0.727705506445752 274.0714147090912
Node 165
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.0 )
( 61.87176742405953 61.32000141426246 1.0 0.0 0.08139157295227051
Node 166
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) 0.004 )
( 12.371617825860854 7.862494183936035 0.19068308441404822 0.8046065782345236 224.44416332244873
Node 167
( (27399, 12) (27399, 6) (6837, 12) (6837, 6) -0.009 )
( 16.34863575130283 9.464389600032629 0.3128441852972918 0.728622839384285