In [9]:
import sys
import os

# Make the project root (the parent folder, which holds the `src` package
# imported by later cells) importable. Resolve it to an absolute path so a
# later change of working directory does not break imports, and add it only
# once so that re-running this cell does not keep appending duplicates.
PROJECT_ROOT = os.path.abspath(os.path.join('..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [10]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Fixed random seed, so that operations are deterministic across runs.
SEED = 42

from src.utils.seed import set_random_seed

set_random_seed(SEED)

In [12]:
import torch

# Train and query the model on the GPU whenever one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Loading the Data

In [13]:
import os

# Root folder of the METR-LA data, relative to the working directory.
BASE_DATA_DIR = os.path.join(os.pardir, 'data', 'metr-la')

In [14]:
import pickle

# Load the fitted scaler saved with the processed data.
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# scaler files from a trusted source.
scaler_path = os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl')
with open(scaler_path, 'rb') as f:
    scaler = pickle.load(f)

In [15]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix, the node indices and the
# matrix itself.
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Get the STGNN and load the checkpoints.
# NOTE(review): these positional hyperparameters presumably have to match
# the ones the checkpoint was trained with (otherwise `load_state_dict`
# fails on shape mismatches); consider naming them as constants.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

# Map the stored tensors onto DEVICE so that checkpoints saved on a GPU can
# also be loaded on a CPU-only machine.
stgnn_checkpoints = torch.load(stgnn_checkpoints_path, map_location=DEVICE)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the STGNN in evaluation mode.
spatial_temporal_gnn.eval();

# Get the Navigator and load the checkpoints.
navigator = Navigator(DEVICE)

navigator_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                          'navigator_metr_la.pth')

# Same device mapping as for the STGNN checkpoint.
navigator_checkpoints = torch.load(navigator_checkpoints_path,
                                   map_location=DEVICE)
navigator.load_state_dict(navigator_checkpoints['model_state_dict'])

# Set the Navigator in evaluation mode.
navigator.eval();



In [16]:
import os
import numpy as np

# Load the data splits together with the values predicted by the STGNN,
# all stored as `.npy` files in the `explainable` folder.
EXPLAINABLE_DATA_DIR = os.path.join(BASE_DATA_DIR, 'explainable')
x_train, y_train, x_val, y_val, x_test, y_test = (
    np.load(os.path.join(EXPLAINABLE_DATA_DIR, file_name))
    for file_name in (
        'x_train.npy', 'y_train.npy',
        'x_val.npy', 'y_val.npy',
        'x_test.npy', 'y_test.npy'))

In [None]:
from typing import List, Optional, Set
from abc import ABC, abstractmethod
from collections import defaultdict
import math

from src.spatial_temporal_gnn.metrics import MAE
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.explanation.navigator.model import Navigator
from src.explanation.events import remove_features_by_events


class Node():
    """
    A node of the Monte Carlo Tree Search over subsets of input events.

    Each node holds a subset of the input events. A child node is obtained
    by removing exactly one event from its parent's subset; the removed
    event and the average Navigator correlation score computed for it are
    stored on the child. Two nodes are equal when they hold the same set of
    input events, so nodes can be used as dictionary keys.
    """
    def __init__(
        self, input_events: np.ndarray,
        removed_event: Optional[np.ndarray] = None,
        correlation_score: float = 1.
        ) -> None:
        """Initialize the node.

        Parameters
        ----------
        input_events : ndarray
            The events kept by this node, one event per row (events are
            indexed along axis 0).
        removed_event : ndarray, optional
            The event removed from the parent node to obtain this node,
            by default None (e.g. for a root node).
        correlation_score : float, optional
            The average Navigator correlation score between the removed
            event and the target events, by default 1.
        """
        self.input_events = input_events
        self.removed_event = removed_event
        self.correlation_score = correlation_score

    def find_children(
        self, target_events: List[np.ndarray],
        navigator: 'Navigator') -> List['Node']:
        """
        Get all possible successors of the current node.

        One child is created per input event, holding every input event
        except that one. Each child is scored with the average Navigator
        correlation between the removed event and the target events.

        Parameters
        ----------
        target_events : list of ndarray
            The target events to correlate the input events with.
        navigator : Navigator
            The Navigator model; it is called as
            `navigator(input_batch, target_batch)` and must expose a
            `device` attribute.

        Returns
        -------
        list of Node
            All possible successors of the current node, in the order of
            the input events.
        """
        device = navigator.device
        n_targets = len(target_events)
        # The target batch is the same for every candidate event, so it is
        # built once instead of once per loop iteration.
        target_batch = torch.as_tensor(
            np.asarray(target_events), dtype=torch.float32, device=device)

        children = []
        # Only inference is needed: do not build an autograd graph.
        with torch.no_grad():
            for i, e in enumerate(self.input_events):
                # Get the subset of input events without the event at index i.
                input_events_subset = np.delete(self.input_events, i, axis=0)
                # Repeat the input event e once per target event to build a
                # batch of input events aligned with the target batch.
                input_batch = torch.as_tensor(
                    np.repeat(e[np.newaxis, :], n_targets, axis=0),
                    dtype=torch.float32, device=device)

                # Average correlation score between e and the target events.
                correlation_score_avg = navigator(
                    input_batch, target_batch).mean().item()

                # Add the events set except the one at index i to the children.
                children.append(
                    Node(input_events_subset, e, correlation_score_avg))
        return children

    def expand_best_child(self, navigator: 'Navigator') -> 'Node':
        """
        Expand the best child of the current node according to the
        Navigator model.

        Not implemented yet: child expansion is currently driven by
        `MCTS._expand` through `find_children`.

        Parameters
        ----------
        navigator : Navigator
            The Navigator model.

        Raises
        ------
        NotImplementedError
            Always; this is a placeholder.
        """
        # Raising explicitly instead of silently returning None, which
        # would violate the documented `Node` return type.
        raise NotImplementedError(
            'Node.expand_best_child is not implemented.')

    def is_terminal(self, leaf_size: int) -> bool:
        """
        Returns True if the node has less than or equal to `leaf_size`
        events.

        Parameters
        ----------
        leaf_size : int
            The maximum number of events allowed in a leaf node.

        Returns
        -------
        bool
            Whether or not the node is terminal.
        """
        return len(self.input_events) <= leaf_size

    def reward(
        self, spatial_temporal_gnn: 'SpatialTemporalGNN',
        x: torch.FloatTensor, y: torch.FloatTensor) -> float:
        """Get the reward of the current node in terms of the negative
        Mean Absolute Error (MAE) between the predicted output data,
        given the subset of input events expressed by the current node,
        and the actual output data.

        Parameters
        ----------
        spatial_temporal_gnn : SpatialTemporalGNN
            The Spatial Temporal GNN model used to predict the output
            events. It must expose a `device` attribute.
        x : FloatTensor
            The input data.
        y : FloatTensor
            The output data.

        Returns
        -------
        float
            The negative MAE between the predicted output data and the
            actual output data.
        """
        # Set the MAE criterion.
        mae_criterion = MAE()
        # Get the device of the spatial temporal GNN.
        device = spatial_temporal_gnn.device

        # Set the input events as a list.
        input_events = list(self.input_events)
        # Restrict the input data according to the node's input events.
        # NOTE(review): confirm whether `remove_features_by_events` keeps or
        # drops the features of the given events.
        x_subset = remove_features_by_events(x, input_events)

        # Only inference is needed: do not build an autograd graph.
        with torch.no_grad():
            # Predict the output events.
            y_pred = spatial_temporal_gnn(x_subset.to(device))
            # The reward is the negative MAE between the predicted output
            # events and the actual output events.
            reward = - mae_criterion(y_pred, y.to(device)).item()
        return reward

    def _events_key(self) -> frozenset:
        """Return the set of input events in a hashable form, with each
        event converted to a tuple of its (flattened) values."""
        return frozenset(
            tuple(np.asarray(e).ravel().tolist()) for e in self.input_events)

    def __hash__(self) -> int:
        """Hash the node by its set of input events.

        Returns
        -------
        int
            The hash of the node.
        """
        # ndarrays and sets are unhashable, so hash a frozenset of tuples.
        return hash(self._events_key())

    def __eq__(self, other: object) -> bool:
        """Get whether or not two nodes are equal.
        A node is equal to another node if they have the same input
        events set.

        Parameters
        ----------
        other : Node
            The node to compare with.

        Returns
        -------
        bool
            Whether or not the two nodes are equal (NotImplemented when
            `other` is not a Node).
        """
        if not isinstance(other, Node):
            return NotImplemented
        # Comparing the arrays with `==` would give an elementwise array
        # (or fail on shape mismatch); compare the event sets instead, which
        # is consistent with __hash__.
        return self._events_key() == other._events_key()

class MCTS:
    """Monte Carlo Tree Search over subsets of input events.

    Every rollout descends from a root `Node` to a terminal node (a node
    with at most `maximum_leaf_size` events). At each visited node one more
    child is expanded, namely the not-yet-expanded child with the highest
    Navigator correlation score; the child to descend into is then chosen by
    an Upper Confidence Bound for Trees (UCT) rule. The terminal node is
    scored with the Spatial Temporal GNN and its reward is backpropagated
    along the path. The best terminal node found so far is kept in
    `best_leaf`.
    """

    def __init__(
        self, spatial_temporal_gnn: 'SpatialTemporalGNN',
        navigator: 'Navigator', maximum_leaf_size: int = 20,
        exploration_weight: float = 1,
        x: Optional[torch.FloatTensor] = None,
        y: Optional[torch.FloatTensor] = None,
        target_events: Optional[List[np.ndarray]] = None) -> None:
        """Initialize the MCTS.

        Parameters
        ----------
        spatial_temporal_gnn : SpatialTemporalGNN
            The Spatial Temporal Graph Neural Network used to get the
            reward of a leaf node.
        navigator : Navigator
            The Navigator used to select which node to expand during
            the tree search.
        maximum_leaf_size : int, optional
            The maximum number of input events of a terminal node,
            by default 20.
        exploration_weight : float, optional
            The exploration weight used in the Upper Confidence Bound
            for Trees (UCT) formula, by default 1.
        x : FloatTensor, optional
            The input data passed to `Node.reward`. Must be set before
            calling `rollout`.
        y : FloatTensor, optional
            The output data passed to `Node.reward`. Must be set before
            calling `rollout`.
        target_events : list of ndarray, optional
            The target events passed to `Node.find_children`. Must be set
            before calling `rollout`.
        """
        # Total reward of each node.
        self.C = defaultdict(float)
        # Total visit count of each node.
        self.N = defaultdict(int)
        # Children of each node that have not been expanded yet.
        self.children = dict()
        # Children of each node that have already been expanded.
        self.expanded_children = dict()
        # Best found leaf node along with its reward (negative MAE).
        self.best_leaf = (None, -math.inf)
        # Exploration weight of the UCT formula.
        self.exploration_weight = exploration_weight
        # Spatial Temporal Graph Neural Network.
        self.spatial_temporal_gnn = spatial_temporal_gnn
        # Navigator.
        self.navigator = navigator
        # Maximum leaf size.
        self.maximum_leaf_size = maximum_leaf_size
        # Data needed to expand nodes and to compute leaf rewards.
        self.x = x
        self.y = y
        self.target_events = target_events

    def rollout(self, node: 'Node') -> None:
        """Do a Monte Carlo Tree Search rollout starting from the given
        root node and reaching a leaf node. After the rollout, the leaf node
        is saved as the best leaf node if it has a higher reward (lower
        error) than the current best leaf node. Moreover, the reward is
        backpropagated from the leaf node to the root node in order to
        update the total reward and total visit count of each node.

        Parameters
        ----------
        node : Node
            The root node of the tree search.

        Raises
        ------
        ValueError
            If `x`, `y` or `target_events` have not been set.
        """
        if self.x is None or self.y is None or self.target_events is None:
            raise ValueError(
                'MCTS.x, MCTS.y and MCTS.target_events must be set before '
                'running a rollout.')
        # Get the path from the root node to the leaf node, applying node
        # expansion through the Navigator and node selection through UCT.
        path = self._select(node)
        # Get the reward of the leaf node.
        reward = self._simulate(path[-1])
        # Backpropagate the reward from the leaf node to the root node.
        self._backpropagate(path, reward)

    def _select(self, node: 'Node') -> List['Node']:
        """Descend from `node` to a leaf, expanding one child per visited
        node and choosing the child to descend into by UCT.

        Parameters
        ----------
        node : Node
            The node to start the descent from.

        Returns
        -------
        list of Node
            The visited nodes, from `node` to the reached leaf.
        """
        path = []
        while True:
            path.append(node)

            # A terminal node ends the descent.
            if node.is_terminal(self.maximum_leaf_size):
                return path
            # Expand one more child of the node.
            self._expand(node)
            # A node without any child cannot be descended further.
            if not self.expanded_children.get(node):
                return path
            # Explore the child that maximizes the UCT formula.
            node = self._get_node_by_upper_confidence_bound(node)

    def _expand(self, node: 'Node') -> None:
        """Move the best not-yet-expanded child of `node`, according to
        its Navigator correlation score, to its expanded children.

        On the first call for `node` its children are computed; once every
        child has been expanded, calls are no-ops.

        Parameters
        ----------
        node : Node
            The node to expand.
        """
        if node not in self.children:
            # The node has never been expanded yet.
            self.children[node] = node.find_children(
                self.target_events, self.navigator)
            self.expanded_children[node] = []

        unexpanded = self.children[node]
        if not unexpanded:
            # The node has been fully expanded.
            return
        # Index of the child with the highest correlation score.
        best_child_idx = max(range(len(unexpanded)),
                             key=lambda i: unexpanded[i].correlation_score)
        # Move the best child from the unexpanded to the expanded children.
        self.expanded_children[node].append(unexpanded.pop(best_child_idx))

    def _simulate(self, node: 'Node') -> float:
        """Get the reward of a leaf node and keep track of the best leaf.

        The leaf is evaluated directly (no random playout), since every
        leaf returned by `_select` is already a final candidate.

        Parameters
        ----------
        node : Node
            The leaf node to evaluate.

        Returns
        -------
        float
            The reward of the leaf node.
        """
        reward = node.reward(self.spatial_temporal_gnn, self.x, self.y)
        if reward > self.best_leaf[1]:
            self.best_leaf = (node, reward)
        return reward

    def _backpropagate(self, path: List['Node'], reward: float) -> None:
        """Backpropagate the reward from the node to its ancestors.

        Every node on the path gets its visit count incremented and the
        same reward added to its total reward (single-agent search, so the
        reward is not transformed between levels).

        Parameters
        ----------
        path : list of Node
            The path from the root node to the leaf node.
        reward : float
            The reward score of the leaf node to backpropagate.
        """
        for node in reversed(path):
            self.N[node] += 1
            self.C[node] += reward

    def _get_node_by_upper_confidence_bound(self, node: 'Node') -> 'Node':
        """
        Get a child node of the given node by the Upper Confidence Bound
        for Trees (UCT) algorithm, balancing exploration & exploitation.

        The score of a child `n` is
        `C[n] / N[n] + exploration_weight * sqrt(sum of N over the expanded
        children) / (1 + N[n])`; children never visited get an infinite
        score so they are tried first.

        Parameters
        ----------
        node : Node
            The parent node to get a child of to explore.

        Returns
        -------
        Node
            The child node to explore.

        Raises
        ------
        ValueError
            If `node` has no expanded children.
        """
        candidates = self.expanded_children.get(node)
        if not candidates:
            raise ValueError('The node has no expanded children to select.')

        # Square root of the total visit count of the expanded children.
        sqrt_n_sum = math.sqrt(sum(self.N[c] for c in candidates))

        def get_upper_confidence_bound(n: 'Node') -> float:
            """Get the UCT score of a child node."""
            visits = self.N[n]
            if visits == 0:
                # The average reward is undefined: explore it first.
                return math.inf
            return self.C[n] / visits + self.exploration_weight * (
                sqrt_n_sum / (1 + visits))

        return max(candidates, key=get_upper_confidence_bound)
