In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn import preprocessing

from gnnad.graphanomaly import GNNAD
from gnnad.plot import plot_test_anomalies, plot_predictions, plot_sensor_error_scores

def normalise(X, scaler_fn, fit_on=None):
    """Scale the columns of ``X`` with a scikit-learn style scaler.

    Parameters
    ----------
    X : pandas.DataFrame
        Data to transform. Its index and column labels are preserved.
    scaler_fn : object
        Scaler exposing ``fit`` (which must return the fitted scaler) and
        ``transform``, e.g. ``preprocessing.StandardScaler()``. Its ``fit``
        method is called, so a scikit-learn scaler is modified by this call.
    fit_on : pandas.DataFrame, optional
        Data to fit the scaler on. Defaults to ``X`` itself (the original
        behaviour). When scaling a test frame, pass the training frame here
        so test-set statistics do not leak into preprocessing.

    Returns
    -------
    pandas.DataFrame
        Scaled values of ``X`` with the same index and columns.
    """
    scaler = scaler_fn.fit(X if fit_on is None else fit_on)
    return pd.DataFrame(scaler.transform(X), index=X.index, columns=X.columns)

In [2]:
# read in training and test data
# The data location defaults to the original NFS path; set the WADI_DATA_DIR
# environment variable to point the notebook at a local copy instead of
# editing the paths below.
DATA_DIR = Path(os.environ.get('WADI_DATA_DIR', '/nfs/home/canzen/gnnad/wadi_data'))

# the last column of each CSV is assumed to be the label; the rest are features
train = pd.read_csv(DATA_DIR / 'train.csv', index_col=0)
X_train = train.iloc[:, :-1]

test = pd.read_csv(DATA_DIR / 'test.csv', index_col=0)
X_test = test.iloc[:, :-1]
y_test = test['attack']

# Guard against the label leaking into the features (positional slicing above
# only drops it if it is the last column) and against train/test exposing
# different sensors (the model is fitted on one and scored on the other).
assert 'attack' not in X_test.columns, "'attack' label must be the last column of test.csv"
assert list(X_train.columns) == list(X_test.columns), "train/test sensor columns differ"

# Normalisation is currently disabled. If it is enabled, fit the scaler on
# X_train only and apply that fitted scaler to X_test, so test-set statistics
# do not leak into preprocessing.
In [3]:
# plot input data
# Disabled: ANOMS (a dict of per-sensor anomaly indices) is not built before
# this cell, and test.csv only provides a single `attack` label column, so this
# call would raise a NameError if uncommented as-is.
#plot_test_anomalies(X_test, ANOMS)

In [4]:
# run model
# Training is stochastic (weight initialisation, and the network contains a
# Dropout layer), so seed every RNG the stack may draw from to make the
# reported metrics reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = GNNAD(threshold_type="max_validation", topk=30,
              slide_win=5, epoch=50, early_stop_win=10,
              slide_stride=1, embed_dim=128,
              out_layer_inter_dim=128, device='cpu')
fitted_model = model.fit(X_train, X_test, y_test)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch (0 / 50) (Loss:0.17127507, ACU_loss:3.59677653)
epoch (1 / 50) (Loss:0.07710462, ACU_loss:1.61919711)
epoch (2 / 50) (Loss:0.06010792, ACU_loss:1.26226642)
epoch (3 / 50) (Loss:0.04943165, ACU_loss:1.03806471)
epoch (4 / 50) (Loss:0.04216208, ACU_loss:0.88540373)
epoch (5 / 50) (Loss:0.03608968, ACU_loss:0.75788337)
epoch (6 / 50) (Loss:0.03083129, ACU_loss:0.64745709)
epoch (7 / 50) (Loss:0.02529539, ACU_loss:0.53120310)
epoch (8 / 50) (Loss:0.02146659, ACU_loss:0.45079838)
epoch (9 / 50) (Loss:0.01815354, ACU_loss:0.38122427)
epoch (10 / 50) (Loss:0.01574942, ACU_loss:0.33073779)
epoch (11 / 50) (Loss:0.01422495, ACU_loss:0.29872385)
epoch (12 / 50) (Loss:0.01277856, ACU_loss:0.26834975)
epoch (13 / 50) (Loss:0.01148485, ACU_loss:0.24118177)
epoch (14 / 50) (Loss:0.01068395, ACU_loss:0.22436302)
epoch (15 / 50) (Loss:0.01016480, ACU_loss:0.21346079)
epoch (16 / 50) (Loss:0.00961395, ACU_loss:0.20189295)
epoch (17 / 50) (Loss:0.00909035, ACU_loss:0.19089733)
epoch (18 / 50) (Los

In [5]:
# model summary: shows a per-layer table (layer type, output shape, parameter
# count) of the fitted network, followed by total parameter counts
fitted_model.summary()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Embedding-1                  [-1, 128]          16,128
         Embedding-2                  [-1, 128]          16,128
            Linear-3                  [-1, 128]             640
    SumAggregation-4               [-1, 1, 128]               0
        GraphLayer-5                  [-1, 128]             128
       BatchNorm1d-6                  [-1, 128]             256
              ReLU-7                  [-1, 128]               0
          GNNLayer-8                  [-1, 128]               0
         Embedding-9                  [-1, 128]          16,128
      BatchNorm1d-10             [-1, 128, 126]             256
          Dropout-11             [-1, 126, 128]               0
           Linear-12               [-1, 126, 1]             129
         OutLayer-13               [-1, 126, 1]               0
Total params: 49,793
Trainable params: 

In [6]:
# GDN+, sensor thresholds
# tau is forwarded to gnnad's sensor-level thresholding.
# NOTE(review): presumably a percentile of each sensor's validation error
# score — TODO confirm against the gnnad documentation.
SENSOR_THRESHOLD_TAU = 99
preds = fitted_model.sensor_threshold_preds(tau=SENSOR_THRESHOLD_TAU)

# prints recall, precision, accuracy, specificity and f1 for `preds`
# (presumably scored against the y_test labels passed to fit — verify)
fitted_model.print_eval_metrics(preds)

recall: 90.1
precision: 9.9
accuracy: 51.2
specificity: 48.8
f1: 17.8


In [None]:
# plot predictions
# ANOMS maps anomaly type -> {sensor column: index of anomalous timestamps}.
# test.csv only provides a single `attack` label column (y_test), not
# per-sensor anomaly indices, so the attack timestamps are marked for every
# sensor. Building it here keeps this cell runnable on a fresh kernel.
# NOTE(review): assumes attacks are labelled 1 in y_test — TODO confirm.
attack_idxs = y_test.index[y_test == 1]
ANOMS = {'type1': {sensor: attack_idxs for sensor in X_test.columns}}
plot_predictions(fitted_model, X_test, ANOMS, preds=preds, figsize=(20, 20))

In [None]:
# plot the fitted model's per-sensor error scores over the test data
# (plot_sensor_error_scores from gnnad.plot)
plot_sensor_error_scores(fitted_model, X_test)