In [1]:
import sys
import os

# Make the parent of the current working directory (the project root)
# importable, so that the project's `src` package can be found.
sys.path.append(os.pardir)

In [2]:
# Settings for autoreloading: edits to the project modules (e.g. the `src`
# package) are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from src.utils.seed import set_random_seed

# Fixed random seed, applied once so that operations are deterministic
# across runs of the notebook.
SEED = 42

set_random_seed(SEED)

In [4]:
import torch

# Train and query the model on the GPU when one is available, otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Loading the Data

In [5]:
import os

# Root folder of the METR-LA data, relative to the notebook's working
# directory.
BASE_DATA_DIR = os.path.join(os.pardir, 'data', 'metr-la')

In [6]:
import pickle

# Load the data scaler. NOTE: `pickle.load` can execute arbitrary code, so
# this file must come from a trusted source (here: the project's own
# processed data).
scaler_path = os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl')
with open(scaler_path, 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)

In [7]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix structure of the METR-LA sensor graph.
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'raw', 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix, the node indices and the
# matrix itself.
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Get the STGNN and load the checkpoints.
# NOTE(review): the positional hyper-parameters (9, 1, 12, 12, ..., 64) must
# match the ones the checkpoint was trained with — confirm against the
# SpatialTemporalGNN signature.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

# `map_location` remaps tensors saved on another device (e.g. a checkpoint
# saved on CUDA) onto the selected DEVICE, so loading also works on the
# CPU-only fallback.
stgnn_checkpoints = torch.load(stgnn_checkpoints_path, map_location=DEVICE)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the STGNN in evaluation mode.
spatial_temporal_gnn.eval();

# Get the Navigator and load the checkpoints.
navigator = Navigator(DEVICE)

navigator_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                          'navigator_metr_la.pth')

# Same device remapping as for the STGNN checkpoint.
navigator_checkpoints = torch.load(navigator_checkpoints_path,
                                   map_location=DEVICE)
navigator.load_state_dict(navigator_checkpoints['model_state_dict'])

# Set the Navigator in evaluation mode.
navigator.eval();



In [8]:
from src.data.data_extraction import get_locations_dataframe

# Dataframe with the latitude and longitude of each sensor, read from a CSV
# file that has a header row.
sensor_locations_path = os.path.join(
    BASE_DATA_DIR, 'raw', 'graph_sensor_locations_metr_la.csv')
locations_df = get_locations_dataframe(sensor_locations_path, has_header=True)

In [9]:
# Invert `node_ids_dict` (presumably sensor id -> node index) into a mapping
# from node position to sensor id.
node_pos_dict = {position: sensor_id
                 for sensor_id, position in node_ids_dict.items()}

In [10]:
# The data scaler (`scaler`) has already been loaded from
# `processed/scaler.pkl` earlier in the notebook and is not modified in
# between, so the same pickle is not read and unpickled a second time here.

In [11]:
import os
import numpy as np


def _load_explainable_array(file_name):
    """Load a `.npy` array from the `explainable` folder of the data dir."""
    return np.load(os.path.join(BASE_DATA_DIR, 'explainable', file_name))


# Inputs (x_*) and the corresponding values predicted by the STGNN (y_*).
x_train = _load_explainable_array('x_train.npy')
y_train = _load_explainable_array('y_train.npy')
x_val = _load_explainable_array('x_val.npy')
y_val = _load_explainable_array('y_val.npy')
x_test = _load_explainable_array('x_test.npy')
y_test = _load_explainable_array('y_test.npy')

# Time intervals associated with the test inputs and outputs.
x_test_time = _load_explainable_array('x_test_time.npy')
y_test_time = _load_explainable_array('y_test_time.npy')

In [12]:
# Explain the first training sample.
x_sample = x_train[0]
y_sample = y_train[0]

In [13]:
# Maximum leaf size for the search: twice the number of non-zero values in
# the sample output, capped at 500.
explanation_size = min((y_sample != 0).sum() * 2, 500)

In [14]:
# Show the chosen maximum explanation size.
print(f'{explanation_size}')

24


In [15]:
# Exploratory check: 12 evenly spaced float32 values in [0.1, 0.12), shown
# in descending order. The result is only displayed here.
np.flip(np.linspace(.1, .12, 12, dtype=np.float32, endpoint=False))

array([0.11833333, 0.11666667, 0.115     , 0.11333334, 0.11166666,
       0.11      , 0.10833333, 0.10666667, 0.105     , 0.10333333,
       0.10166667, 0.1       ], dtype=float32)

In [16]:
from src.explanation.monte_carlo.search import get_best_input_subset

# Run the Monte Carlo search for the subset of input events that best
# explains the STGNN output on the selected sample.
subset = get_best_input_subset(
    x_sample, y_sample, adj_matrix,
    spatial_temporal_gnn, navigator, scaler,
    maximum_leaf_size=explanation_size,
    n_rollouts=10,
    verbose=True)


Execution 1/10
mae: 9.096492767333984
Execution 2/10
mae: 9.091655731201172
Execution 3/10
mae: 9.091655731201172
Execution 4/10
mae: 9.091655731201172
Execution 5/10
mae: 9.091655731201172
Execution 6/10
mae: 9.091655731201172
Execution 7/10
mae: 9.091655731201172
Execution 8/10
mae: 9.091655731201172
Execution 9/10
mae: 9.091655731201172
Execution 10/10
mae: 9.091655731201172


In [17]:
# Prefix every selected event with a leading 0, keeping its first two
# entries (presumably time step and node index; the leading 0 presumably
# selects the feature channel — confirm in `remove_features_by_events`).
input_events_subset = [(0, event[0], event[1])
                       for event in subset.input_events]

In [18]:
# Debugging aid: uncomment to inspect the selected input events.
#print(input_events_subset)

In [19]:
from src.explanation.events import remove_features_by_events

# Apply the selected input events to a copy of the sample, so `x_sample`
# itself is left untouched. NOTE(review): which features get zeroed is
# defined by `remove_features_by_events` — confirm its semantics there.
x_subset = remove_features_by_events(x_sample.copy(), input_events_subset)

# Keep only the first feature channel.
x_subset = x_subset[..., :1]

In [20]:
# Input subset and output shapes, one per line.
print(x_subset.shape, y_sample.shape, sep='\n')

(12, 207, 1)
(12, 207, 1)


In [21]:
# Conversion factor from miles per hour to kilometers per hour.
MPH_TO_KMH_FACTOR = 1.609344

# Stack the input events subset on top of the output events along axis 0,
# then convert the values (assumed to be speeds in mph) to km/h.
explained_instance = np.concatenate([x_subset, y_sample], axis=0)
explained_instance *= MPH_TO_KMH_FACTOR

In [22]:
# Shape of the concatenated input + output instance.
print(np.shape(explained_instance))

(24, 207, 1)


In [23]:
# Label every entry of the explained instance:
#    0 -> rows coming from the input events subset,
#    1 -> rows coming from the output events,
#   -1 -> zero-valued entries (e.g. removed input features), which override
#         the labels above because they are assigned last.
# The input/output split is derived from the concatenated input length
# instead of a hard-coded 12, so it stays correct if the window changes.
n_input_steps = x_subset.shape[0]

clusters = np.zeros_like(explained_instance)
clusters[n_input_steps:] = 1.
clusters[explained_instance == 0.] = -1

In [24]:
from src.explanation.clustering.analyisis import \
    get_node_values_with_clusters_and_location_dataframe

# Attach the per-node values and cluster labels to the sensor locations
# for map visualization.
location_df_with_clusters = get_node_values_with_clusters_and_location_dataframe(
    explained_instance,
    clusters,
    node_pos_dict,
    locations_df,
)

In [25]:
from keplergl.keplergl import KeplerGl

# Kepler.gl map fed with the clustered sensor dataframe under the 'data' key.
kepler_data = {'data': location_df_with_clusters}
m = KeplerGl(height=800, show_docs=False, data=kepler_data)

In [26]:
# Render the Kepler.gl map widget (last expression of the cell).
m

KeplerGl(data={'data':      sensor_id  latitude  longitude  cluster  speed  datetime
0       773869  34.15497 …

In [27]:
from src.explanation.monte_carlo.search import get_explanations_from_data

# Explain only the first test samples to keep the search time manageable.
n_explained_samples = 500

x_explained, y_explained = get_explanations_from_data(
    x_test[:n_explained_samples],
    y_test[:n_explained_samples],
    adj_matrix,
    spatial_temporal_gnn,
    navigator,
    scaler,
    n_rollouts=10)

In [28]:
import os
import numpy as np

EXPLAINED_DATA_DIR = os.path.join(BASE_DATA_DIR, 'explained')

# `np.save` does not create missing directories, so make sure the output
# folder exists before writing to it.
os.makedirs(EXPLAINED_DATA_DIR, exist_ok=True)

# Save the explained data.
np.save(os.path.join(EXPLAINED_DATA_DIR, 'x_test.npy'), x_explained)
np.save(os.path.join(EXPLAINED_DATA_DIR, 'y_test.npy'), y_explained)

# Save the explained time information of the datasets.
# NOTE: the `[:500]` slice must match the number of test samples that were
# passed to `get_explanations_from_data`.
np.save(os.path.join(EXPLAINED_DATA_DIR, 'x_test_time.npy'), x_test_time[:500])
np.save(os.path.join(EXPLAINED_DATA_DIR, 'y_test_time.npy'), y_test_time[:500])