In [27]:
import sys
import os

# Set the main path in the root folder of the project.
sys.path.append(os.path.join('..'))

In [28]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
SEED = 42
set_random_seed(SEED)

In [30]:
import torch

# Set the device for training and querying the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Loading the Data

In [31]:
import os

BASE_DATA_DIR = os.path.join('..', 'data', 'metr-la')

In [32]:
import pickle
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [33]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix, the node indices and the
# matrix itself.
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Get the STGNN and load the checkpoints.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

stgnn_checkpoints = torch.load(stgnn_checkpoints_path)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the STGNN in evaluation mode.
spatial_temporal_gnn.eval();

# Get the Navigator and load the checkpoints.
navigator = Navigator(DEVICE)

navigator_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                          'navigator_metr_la.pth')

navigator_checkpoints = torch.load(navigator_checkpoints_path)
navigator.load_state_dict(navigator_checkpoints['model_state_dict'])

# Set the Navigator in evaluation mode.
navigator.eval();



In [34]:
import pickle

# Get the data scaler.
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [35]:
import os
import numpy as np

# Get the data and the values predicted by the STGNN.
x_train = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_train.npy'))
y_train = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_train.npy'))
x_val = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_val.npy'))
y_val = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_val.npy'))
x_test = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'x_test.npy'))
y_test = np.load(os.path.join(BASE_DATA_DIR, 'explainable', 'y_test.npy'))

In [36]:
x_train[0][x_train[0] > 0]

array([6.43750000e+01, 1.00000000e+00, 6.76250000e+01, ...,
       5.94285714e+01, 3.82209868e-02, 1.00000000e+00])

In [37]:
y_train[0][y_train[0] > 0]

array([68.54322 , 68.69324 , 69.817474, 70.409035, 69.05133 , 69.580765,
       68.64413 , 69.151245, 69.29832 , 69.345566, 69.57638 , 68.65461 ],
      dtype=float32)

In [38]:
### TRY COMPUTING THE INPUT EVENT SCORES BEFORE BEFORE GOING THROUGH THE MCTS
### ALGORITHM.


In [39]:
from copy import deepcopy
from typing import List, Tuple
from collections import defaultdict
import math

from src.spatial_temporal_gnn.metrics import MAE
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.explanation.events import (
    remove_features_by_events, get_largest_event_set)


class Node():
    """
    A representation of a single state.
    MCTS works by constructing a tree of these Nodes.
    Could be e.g. a chess or checkers board state.
    """
    def __init__(
        self, input_events: List[List[int]]) -> None:
        self.input_events = input_events

    def find_children(self) -> List[List[int]]:
        """
        Get all possible successors of the current node.

        Returns
        -------
        set of Node
            All possible successors of the current node.
        """
        children = []

        for i, _ in enumerate(self.input_events):
            input_events_subset = self.input_events[:i] + self.input_events[i+1:]
            children.append(Node([ e for e in input_events_subset ]))

        return children

    def is_terminal(self, leaf_size: int) -> bool:
        """
        Returns True if the node has less than or equal to `leaf_size`
        events.

        Parameters
        ----------
        leaf_size : int
            The maximum number of events allowed in a leaf node.

        Returns
        -------
        bool
            Whether or not the node is terminal.
        """
        return len(self.input_events) <= leaf_size

    def reward(
        self, spatial_temporal_gnn: SpatialTemporalGNN,
        x: torch.FloatTensor, y: torch.FloatTensor) -> float:
        """Get the reward of the current node in terms of the negative
        Mean Absolute Error (MAE) between the predicted output data,
        given the subset of input events expressed by the current node,
        and the actual output data.

        Parameters
        ----------
        spatial_temporal_gnn : SpatialTemporalGNN
            The Spatial Temporal GNN model used to predict the output
            events.
        x : FloatTensor
            The input data.
        y : FloatTensor
            The output data.

        Returns
        -------
        float
            The negative MAE between the predicted output data and the
            actual output data.
        """
        x = x.clone()
        # Set the MAE criterion.
        mae_criterion = MAE()
        # Get the device of the spatial temporal GNN.
        device = spatial_temporal_gnn.device

        # Set the input events as a list.
        input_events = [[0, e[0], e[1]] for e in self.input_events]
        # Remove the features corresponding to the input events in
        # the input data.
        x_subset = remove_features_by_events(x, input_events)
        x_subset = scaler.scale(x_subset)

        # Predict the output events.
        y_pred = spatial_temporal_gnn(x_subset.unsqueeze(0).to(device))
        
        y_pred[y.unsqueeze(0) == 0] = 0
        y_pred = scaler.un_scale(y_pred)

        # Compute the reward as the negative MAE between the predicted
        # output events and the actual output events.
        reward = - mae_criterion(y_pred, y.unsqueeze(0).to(device)).item()
        return reward

    def __hash__(self) -> int:
        """Hash the node by the input events.

        Returns
        -------
        int
            The hash of the node.
        """
        return hash(frozenset(self.input_events))

    def __eq__(node1: 'Node', node2: 'Node') -> bool:
        """Get whether or not two nodes are equal.
        A node is equal to another node if they have the same input
        events set.

        Parameters
        ----------
        node1 : Node
            The first node to compare.
        node2 : Node
            The second node to compare.

        Returns
        -------
        bool
            Whether or not the two nodes are equal.
        """
        return frozenset(node1.input_events) == frozenset(node2.input_events)
        #return node1.input_events == node2.input_events

class MonteCarloTreeSearch:
    "Monte Carlo Tree Search. First rollout the tree then choose a move."

    def __init__(
        self, spatial_temporal_gnn: SpatialTemporalGNN, navigator: Navigator,
        x: torch.FloatTensor, y: torch.FloatTensor,
        maximum_leaf_size: int = 20, exploration_weight: int = 1) -> None:
        """Initialize the MCTS.

        Parameters
        ----------
        spatial_temporal_gnn : SpatialTemporalGNN
            The Spatial Temporal Graph Neural Network used to get the
            reward of a leaf node.
        navigator : Navigator
            The Navigator used to select which node to expand during
            the tree search.
        exploration_weight : int, optional
            The exploration weight used in the Upper Confidence Bound
            for Trees (UCT) formula, by default 1.
        """
        # Set dictionary of total reward of each node.
        self.C = defaultdict(int)
        # Set dictionary of total visit count for each node.
        self.N = defaultdict(int)
        # Set dictionary of children of each node.
        self.children = dict()
        # Set dictionary of expanded children of each node.
        self.expanded_children = dict()
        # Set the best found leaf node along with its error.
        self.best_leaf = ( None, - math.inf )
        # Set the exploration weight.
        self.exploration_weight = exploration_weight
        # Set the Spatial Temporal Graph Neural Network.
        self.spatial_temporal_gnn = spatial_temporal_gnn
        # Set the Navigator.
        self.navigator = navigator
        # Set the maximum leaf size.
        self.maximum_leaf_size = maximum_leaf_size
        # Set the inputs
        self.x = x.clone()
        # Set the outputs
        self.y = y.clone()
        # Set the target events.
        self.target_events = [[e[1], e[2], y[e[1], e[2], 0]] 
                              for e in get_largest_event_set(y)]

    def rollout(self, node: Node) -> None:
        """Do a Monte Carlo Tree Search rollout starting from the given
        root node and reaching a leaf node. After the rollout, the leaf node
        is saved as the best leaf node if it has a lower error than the
        current best leaf node. Moreover, the reward is backpropogated
        from the leaf node to the root node in order to update the
        total reward and total visit count of each node.

        Parameters
        ----------
        node : Node
            The root node of the tree search.
        """
        # Get the path from the root node to the leaf node.
        # Apply node expansion through the navigator and node selection
        # through the Upper Confidence Bound applied to Trees (UCT).
        path = self._select(node)
        # Get the leaf node.
        leaf = path[-1]
        #### self._expand(leaf)
        # Get the reward of the leaf node.
        reward = self._simulate(leaf)
        # Backpropogate the reward from the leaf node to the root node.
        self._backpropagate(path, reward)
        #print(path)

    def _select(self, node: Node):
        "Find an unexplored descendent of `node`"
        # Set the rollout path.
        path = []
        while True:
            # Append the node to the path.
            path.append(node)
            # If the node is a terminal node, return the path.
            if node.is_terminal(self.maximum_leaf_size):
                return path
            # Expand the node children.
            self._expand(node)
            # Explore the child node that maximizes the Upper Confidence
            # Bound applied to Trees (UCT) formula.
            node = self._get_node_by_upper_confidence_bound(node)

    def _expand(self, node: Node) -> None:
        "Update the `children` dict with the children of `node`"
        if node not in self.children.keys():
            if node not in self.expanded_children.keys():
                # The node has never been expanded yet.
                self.children[node] = [ e for e in node.input_events ]
                #self.children[node] = node.find_children()
                self.expanded_children[node] = []
            else:
                # The node has been fully expanded.
                return
        # Get the best child of the node according to the correlation score.
        #best_child_idx = np.argmax(
        #    [n.correlation_score for n in self.children[node]])
        # Add the best child to the expanded children of the node.
        input_events = [ e for e in node.input_events ]
        input_events.remove(self.children[node][0])
        self.expanded_children[node].append(Node(input_events))
        # Delete the expanded child from the children of the node.
        del self.children[node][0]#[best_child_idx]

    def _simulate(self, node: Node) -> float:
        "Returns the reward for a random simulation (to completion) of `node`"
        # Assumes node is terminal
        # input_events = [[0, e[0], e[1]] for e in node.input_events]
        #x = remove_features_by_events(self.x, input_events)

        reward = node.reward(self.spatial_temporal_gnn, self.x, self.y)
        #print(reward)
        if reward > self.best_leaf[1]:
            self.best_leaf = (node, reward)

        return reward

    def _backpropagate(self, path: List[Node], reward: float) -> None:
        """Backpropagate the reward from the node to its ancestors.

        Parameters
        ----------
        path : list of Node
            The path from the root node to the leaf node.
        reward : float
            The reward score of the leaf node to backpropagate.
        """
        reward /= 100.
        for node in reversed(path):
            self.N[node] += 1
            self.C[node] += reward
            reward += 1  # Add 1 to the reward for the parent node.

    def _get_node_by_upper_confidence_bound(self, node: Node) -> Node:
        """
        Get a child node of the given node by the Upper Confidence Bound
        for Trees (UCT) algorithm, balancing exploration & exploitation.

        Parameters
        ----------
        node : Node
            The parent node to get a child of to explore.

        Returns
        -------
        Node
            The child node to explore.
        """
        #"Select a child of node, balancing exploration & exploitation"

        # All children of node should already be expanded:
        #assert all(n in self.expanded_children 
        #           for n in self.expanded_children[node])

        #log_N_vertex = math.log(self.N[node])
        # Get the sum of total visit count of each children.
        N_sum = sum([self.N[c] for c in self.expanded_children[node]])

        def get_upper_confidence_bound(n: Node) -> float:
            """
            Get the Upper Confidence Bound for Trees (UCT) of a child
            Node.

            Parameters
            ----------
            n : Node
                The child node to get the UCT of.

            Returns
            -------
            float
                The UCT of the child node.
            """
            #return self.C[n] / self.N[n] + self.exploration_weight * math.sqrt(
            #    log_N_vertex / self.N[n]
            #)
            return self.C[n] / (self.N[n] + 1e-10) + self.exploration_weight * math.sqrt(N_sum) / (self.N[n] + 1)

        return max(self.expanded_children[node], key=get_upper_confidence_bound)


In [40]:
x_sample, y_sample = x_train[0], y_train[0]

In [41]:
x_train[0]

array([[[6.43750000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.76250000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.71250000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [5.92500000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.90000000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.18750000e+01, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

       [[6.26666667e+01, 3.47463516e-03, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.85555556e+01, 3.47463516e-03, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.54444444e+01, 

In [42]:
y_pred = spatial_temporal_gnn(scaler.scale(torch.Tensor(x_sample).unsqueeze(0).float().to('cuda'))).squeeze(0)

In [43]:
y_pred = scaler.un_scale(y_pred)

In [44]:
y_pred

tensor([[[63.1521],
         [64.8162],
         [64.1770],
         ...,
         [54.4360],
         [67.1545],
         [60.1670]],

        [[63.4581],
         [65.2643],
         [64.5045],
         ...,
         [54.2093],
         [67.4320],
         [60.1126]],

        [[64.0989],
         [65.7083],
         [64.9389],
         ...,
         [54.9721],
         [68.2278],
         [60.5601]],

        ...,

        [[63.8779],
         [65.3649],
         [64.8893],
         ...,
         [55.0188],
         [67.8865],
         [59.8245]],

        [[63.6194],
         [65.1281],
         [64.7243],
         ...,
         [54.0267],
         [68.1301],
         [60.4343]],

        [[63.1571],
         [64.3761],
         [63.8549],
         ...,
         [53.0313],
         [67.3756],
         [59.3520]]], device='cuda:0', grad_fn=<AddBackward0>)

In [45]:
y_pred[torch.Tensor(y_sample) == 0] = 0

In [46]:
y_sample[y_sample != 0]

array([68.54322 , 68.69324 , 69.817474, 70.409035, 69.05133 , 69.580765,
       68.64413 , 69.151245, 69.29832 , 69.345566, 69.57638 , 68.65461 ],
      dtype=float32)

In [47]:
y_pred[y_pred != 0]

tensor([68.5432, 68.6932, 69.8175, 70.4090, 69.0513, 69.5808, 68.6441, 69.1512,
        69.2983, 69.3456, 69.5764, 68.6546], device='cuda:0',
       grad_fn=<IndexBackward0>)

In [48]:
torch.all(y_pred.to('cuda').float() == torch.Tensor(y_sample).to('cuda').float())

tensor(True, device='cuda:0')

In [51]:
from src.explanation.events import get_largest_event_set


x_sample, y_sample = x_train[0], y_train[0]
x_sample = x_sample.copy()

target_events = [[e[1], e[2], y_sample[e[1], e[2], 0]] 
                 for e in get_largest_event_set(y_sample)]

'''input_events = get_largest_event_set(x_sample)
input_events = [i for i in input_events if i[0] == 0] # TODO: Get all kind of events


input_events_with_correlation_score = []
for i, e in enumerate(input_events):
    encoded_information = x_sample[e[1], e[2], :]
    e_ = (e[1], e[2], *encoded_information)
    target_batch = torch.FloatTensor(target_events).to(navigator.device)
    # Repeat the input event e for each target event to generate a
    # batch of input events corresponding to the target events.
    e_unsqueezed = np.expand_dims(e_, axis=0)
    input_batch = torch.FloatTensor(
        np.repeat(e_unsqueezed, len(target_events), axis=0)
        ).to(navigator.device)

    # Get the correlation score average between the input events and
    # the target events.
    correlation_score_avg = navigator(
        input_batch, target_batch).mean().item()
    input_events_with_correlation_score.append(
        (e_, correlation_score_avg))

for e in sorted(input_events_with_correlation_score, key=lambda x: x[1]):#[-800:]:
    x_sample[e[0][0], e[0][1], 0] = 0.'''

input_events = get_largest_event_set(x_sample)
input_events = [i for i in input_events if i[0] == 0]

# TODO: repeated code
input_events_with_correlation_score = []
for i, e in enumerate(input_events):
    encoded_information = x_sample[e[1], e[2], :]
    e_ = (e[1], e[2], *encoded_information)
    target_batch = torch.FloatTensor(target_events).to(navigator.device)
    # Repeat the input event e for each target event to generate a
    # batch of input events corresponding to the target events.
    e_unsqueezed = np.expand_dims(e_, axis=0)
    input_batch = torch.FloatTensor(
        np.repeat(e_unsqueezed, len(target_events), axis=0)
        ).to(navigator.device)

    # Get the correlation score average between the input events and
    # the target events.
    correlation_score_avg = navigator(
        input_batch, target_batch).mean().item()
    input_events_with_correlation_score.append(
        (e_, correlation_score_avg))
# TODO: end repeated code;

input_events_with_correlation_score = sorted(
    input_events_with_correlation_score, key=lambda x: x[1])[-500:]

print(len(input_events_with_correlation_score))

T = 12
heuristical_steps = np.linspace(.1, .12, T, dtype=np.float32, endpoint=False)[::-1]
#print(heuristical_steps)

for t, s in enumerate(heuristical_steps):
    for i, e in enumerate(input_events_with_correlation_score):
        input_timestep = e[0][0]
        if input_timestep != t:
            continue
        is_out_of_reach = True
        for target_node in np.unique([ e[1] for e in target_events ]):
            input_node = e[0][1]
            if adj_matrix[input_node, target_node] <= s or adj_matrix[target_node, input_node] <= s:
                is_out_of_reach = False
                break
        if is_out_of_reach:
            del input_events_with_correlation_score[i]

print(len(input_events_with_correlation_score))

monte_carlo_tree_search = MonteCarloTreeSearch(
    spatial_temporal_gnn,
    navigator,
    torch.FloatTensor(x_sample).to(device=spatial_temporal_gnn.device),
    torch.FloatTensor(y_sample).to(device=spatial_temporal_gnn.device),
    maximum_leaf_size=50, exploration_weight=500)

root = Node(input_events=[e for e, _ in input_events_with_correlation_score])

for i in range(50):
    print(f'Execution {i+1}/50')
    monte_carlo_tree_search.rollout(root)
    print('mae:', - monte_carlo_tree_search.best_leaf[1])

500
490
Execution 1/30
mae: 5.347222805023193
Execution 2/30
mae: 5.347222805023193
Execution 3/30
mae: 5.347222805023193
Execution 4/30
mae: 5.347222805023193
Execution 5/30
mae: 5.343201160430908
Execution 6/30
mae: 5.343201160430908
Execution 7/30
mae: 5.343201160430908
Execution 8/30
mae: 5.343201160430908
Execution 9/30
mae: 5.343201160430908
Execution 10/30
mae: 5.343201160430908
Execution 11/30
mae: 5.335117340087891
Execution 12/30
mae: 5.335117340087891
Execution 13/30
mae: 5.335117340087891
Execution 14/30
mae: 5.335117340087891
Execution 15/30
mae: 5.335117340087891
Execution 16/30
mae: 5.335117340087891
Execution 17/30
mae: 5.335117340087891
Execution 18/30
mae: 5.335117340087891
Execution 19/30
mae: 5.335117340087891
Execution 20/30
mae: 5.335117340087891
Execution 21/30
mae: 5.335117340087891
Execution 22/30
mae: 5.335117340087891
Execution 23/30
mae: 5.335117340087891
Execution 24/30
mae: 5.335117340087891
Execution 25/30
mae: 5.335117340087891
Execution 26/30
mae: 5.331

In [67]:
# input_events_with_correlation_score[1]

((0, 203, 69.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0), -6.236620903015137)

In [24]:
print(np.unique([ e[0] for e in monte_carlo_tree_search.best_leaf[0].input_events ]))
print(np.unique([ e[1] for e in monte_carlo_tree_search.best_leaf[0].input_events ]))

[ 0  1  2  5  6  7  8  9 10 11]
[ 0  1  2  3  4  6  7  8  9 10 14 16 18 21 22 35 62]


In [None]:
# print(len(monte_carlo_tree_search.C.items()))

8835


In [None]:
#for i in monte_carlo_tree_search.C.items():
#    print(i)

(<__main__.Node object at 0x000001F12369D930>, 3511.9177753829954)
(<__main__.Node object at 0x000001F1918B9ED0>, 3508.9177753829954)
(<__main__.Node object at 0x000001F19193E4D0>, 3505.9177753829954)
(<__main__.Node object at 0x000001F1919749D0>, 3502.9177753829954)
(<__main__.Node object at 0x000001F1D6FE4D90>, 3499.9177753829954)
(<__main__.Node object at 0x000001F1D6FEF6D0>, 3496.9177753829954)
(<__main__.Node object at 0x000001F1D70430D0>, 3493.9177753829954)
(<__main__.Node object at 0x000001F1DA309DB0>, 3490.9177753829954)
(<__main__.Node object at 0x000001F1DA395930>, 3487.9177753829954)
(<__main__.Node object at 0x000001F1DECD3D30>, 3484.9177753829954)
(<__main__.Node object at 0x000001F1DED3A050>, 3481.9177753829954)
(<__main__.Node object at 0x000001F1DED950F0>, 3478.9177753829954)
(<__main__.Node object at 0x000001F1E4DE4970>, 3475.9177753829954)
(<__main__.Node object at 0x000001F1E4E72230>, 3472.9177753829954)
(<__main__.Node object at 0x000001F1E4E84130>, 3469.9177753829

In [55]:
print(np.argmax(y_sample[5]))

96


In [29]:
print(np.unique([ e[0] for e in monte_carlo_tree_search.best_leaf[0].input_events ]))
print(np.unique([ e[1] for e in monte_carlo_tree_search.best_leaf[0].input_events ]))

[ 0  1  2  5  6  7  8  9 10 11]
[ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 16 17 18 19 21 22 23 29 30 35
 50 62]


In [None]:
import json

