In [30]:
import sys
import os

# Set the main path in the root folder of the project.
sys.path.append(os.path.join('..'))

In [31]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
SEED = 42
set_random_seed(SEED)

In [33]:
import torch

# Set the device for training and querying the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Loading the Data

In [34]:
import os

BASE_DATA_DIR = os.path.join('..', 'data', 'metr-la')

In [35]:
import pickle
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [36]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'raw', 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix, the node indices and the
# matrix itself.
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Get the STGNN and load the checkpoints.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

stgnn_checkpoints = torch.load(stgnn_checkpoints_path)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the STGNN in evaluation mode.
spatial_temporal_gnn.eval();

# Get the Navigator and load the checkpoints.
navigator = Navigator(DEVICE)

navigator_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                          'navigator_metr_la.pth')

navigator_checkpoints = torch.load(navigator_checkpoints_path)
navigator.load_state_dict(navigator_checkpoints['model_state_dict'])

# Set the Navigator in evaluation mode.
navigator.eval();



In [37]:
import pickle

# Get the data scaler.
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [38]:
import os
import numpy as np

# Get the data and the values predicted by the STGNN.
x_train = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_train.npy'))
y_train = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_train.npy'))
x_val = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_val.npy'))
y_val = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_val.npy'))
x_test = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_test.npy'))
y_test = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_test.npy'))

In [39]:
x_sample, y_sample = x_train[0], y_train[0]

In [45]:
explanation_size = min((y_sample.flatten() != 0).sum() * 2, 500)

In [47]:
print(explanation_size)

24


In [48]:
from src.explanation.monte_carlo.search import get_best_input_subset

subset = get_best_input_subset(
    x_sample,
    y_sample,
    adj_matrix,
    spatial_temporal_gnn,
    navigator,
    scaler,
    verbose=True,
    maximum_leaf_size=explanation_size)


Execution 1/50
mae: 9.096492767333984
Execution 2/50
mae: 9.091655731201172
Execution 3/50
mae: 9.091655731201172
Execution 4/50
mae: 9.091655731201172
Execution 5/50
mae: 9.091655731201172
Execution 6/50
mae: 9.091655731201172
Execution 7/50
mae: 9.091655731201172
Execution 8/50
mae: 9.091655731201172
Execution 9/50
mae: 9.091655731201172
Execution 10/50
mae: 9.091655731201172
Execution 11/50
mae: 9.091655731201172
Execution 12/50
mae: 9.091655731201172
Execution 13/50
mae: 9.091655731201172
Execution 14/50
mae: 9.091655731201172
Execution 15/50
mae: 9.091655731201172
Execution 16/50
mae: 9.091655731201172
Execution 17/50
mae: 9.091655731201172
Execution 18/50
mae: 9.091655731201172
Execution 19/50
mae: 9.091655731201172
Execution 20/50
mae: 9.091655731201172
Execution 21/50
mae: 9.091655731201172
Execution 22/50
mae: 9.091655731201172
Execution 23/50
mae: 9.091655731201172
Execution 24/50
mae: 9.091655731201172
Execution 25/50
mae: 9.091655731201172
Execution 26/50
mae: 9.09165573120

In [50]:
print(subset.input_events)

[(10, 3, 61.111111111111114, 0.03474635163307853, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 62, 34.5, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (9, 3, 58.125, 0.03127171646977067, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (10, 62, 31.555555555555557, 0.03474635163307853, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 18, 54.125, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (10, 14, 54.111111111111114, 0.03474635163307853, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 22, 52.125, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 9, 58.25, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 0, 62.25, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 14, 55.25, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (11, 3, 60.0, 0.03822098679638638, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (10, 21, 49.55555555555556, 0.03474635163307853, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), (5, 18, 41.333333333333336, 0.017373175816539264, 0.0, 0.0, 0.0, 1.0, 0.0, 0.