In [2]:
!pip install typed-argument-parser
!pip install stable-baselines3
!pip install sb3_contrib


Collecting typed-argument-parser
  Downloading typed_argument_parser-1.10.1-py3-none-any.whl.metadata (32 kB)
Collecting typing-inspect>=0.7.1 (from typed-argument-parser)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect>=0.7.1->typed-argument-parser)
  Downloading mypy_extensions-1.1.0-py3-none-any.whl.metadata (1.1 kB)
Downloading typed_argument_parser-1.10.1-py3-none-any.whl (30 kB)
Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Downloading mypy_extensions-1.1.0-py3-none-any.whl (5.0 kB)
Installing collected packages: mypy-extensions, typing-inspect, typed-argument-parser
Successfully installed mypy-extensions-1.1.0 typed-argument-parser-1.10.1 typing-inspect-0.9.0
Collecting stable-baselines3
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-p

In [4]:
import random
import json
import itertools
from typing import Dict, Literal, Optional, List
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import gymnasium as gym
from gymnasium import spaces
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.env_util import SubprocVecEnv, make_vec_env, DummyVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

In [5]:
# Configure model and training parameters.

class Args(object):
    """Plain container for the experiment configuration.

    All settings are class attributes with defaults; an instance simply exposes
    them as ``args.<name>``. Several fields (``no_wandb``,
    ``checkpoint_callback_freq``, ``s_transformer_num_layers``,
    ``s_entropy_coef``, ``s_n_eval_episodes``, ``s_eval_freq``, ``s_n_seeds``,
    ``s_n_steps``) are not read by the code visible in this notebook.
    """

    no_wandb: bool = False
    total_timesteps: int = 4_000_000
    # Drawn once, when the class body is executed; every Args() instance created
    # afterwards in the same kernel session shares this value.
    seed: int = random.randint(0, 2**32 - 1)
    n_envs: int = 32
    device: Literal['cpu', 'mps', 'cuda'] = 'cuda'
    # Only 'tfh_fast' and 'none' are handled by create_model in this notebook.
    feature_extractor: Literal['none', 'tfh_small', 'tfh_big', 'tf', 'tfh_fast'] = 'tfh_fast'
    env_num_inserts: int = 6
    env_num_deletes: int = 6
    env_max_tree_values: int = 24
    env_max_values_per_node: int = 4
    checkpoint_callback_freq: int = 50_000
    # Shared class-level dict: treat as read-only, mutating it affects all instances.
    s_net_arch: Dict[str, list[int]] = {'pi': [512, 512], 'vf': [512, 512]}
    s_transformer_features_dim: int = 64
    s_transformer_num_layers: int = 2
    s_transformer_nhead: int = 2
    s_n_epochs: int = 10
    s_learning_rate: float = 1e-4
    # A coefficient, hence float (was annotated as int).
    s_entropy_coef: float = 0.0
    s_n_eval_episodes: int = 10_000
    s_eval_freq: int = 50_000
    s_n_seeds: int = 40
    s_batch_size: int = 512
    s_n_steps: int = 1000

    def __repr__(self) -> str:
        """Return ``Args(name=value, ...)`` listing every annotated setting.

        Without this, printing the configuration only shows the default
        ``<__main__.Args object at 0x...>``, which hides the actual values.
        """
        fields = ", ".join(
            f"{name}={getattr(self, name)!r}" for name in type(self).__annotations__
        )
        return f"{type(self).__name__}({fields})"


args = Args()

In [6]:
# B+Tree adapted from https://gist.github.com/benben233/2c8a2a8ab44a7beabad0df1b6658232e
# In the calculation functions, it contains the logic to compute the cost of an executed operation

class Node(object):
    """Index (internal) node of the B+ tree.

    ``keys`` holds the separator keys in ascending order and ``values`` holds
    the child nodes; a node with ``k`` keys has ``k + 1`` children. Structural
    operations (split / fusion) increment counters in the shared ``cost_dict``.

    Attributes:
        keys: separator keys of this node.
        values: child nodes.
        parent: parent node, or None for the root.
        cost_dict: counter dict shared by every node of one tree.
    """

    def __init__(self, cost_dict: dict, parent=None):
        """Create an empty node.

        :param cost_dict: shared dict with the counters "splits",
            "parent_splits", "fusions" and "parent_fusions".
        :type parent: Node
        """
        self.keys: list = []
        self.values: list[Node] = []
        self.parent: Node = parent
        self.cost_dict: dict = cost_dict

    def index(self, key):
        """Return the position of the first stored key that is greater than ``key``.

        This is the index of the child to descend into for ``key``; it equals
        ``len(self.keys)`` when no stored key is greater.
        """
        for i, item in enumerate(self.keys):
            if key < item:
                return i

        return len(self.keys)

    def __getitem__(self, item):
        """Return the child node responsible for key ``item``."""
        return self.values[self.index(item)]

    def __setitem__(self, key, value):
        """Insert separator ``key`` and replace one child by the children in ``value``.

        ``value`` is an iterable of nodes (the ``[left, right]`` pair returned by
        ``split``); the child currently at the key's position is removed and the
        given nodes are spliced in at that position.
        """
        i = self.index(key)
        self.keys[i:i] = [key]
        self.values.pop(i)
        self.values[i:i] = value

    def split(self):
        """Split this node in two and return the pivot key with both halves.

        A new left sibling receives the lower keys and children; this node keeps
        the upper part. The middle key is removed from both and returned so the
        caller can insert it into the parent.

        :return: ``(pivot_key, [left, self])``
        """
        self.cost_dict["splits"] += 1
        self.cost_dict["parent_splits"] += 1

        left = Node(cost_dict=self.cost_dict, parent=self.parent)

        mid = len(self.keys) // 2

        left.keys = self.keys[:mid]
        left.values = self.values[: mid + 1]
        # Children moved to the new node must point at it.
        for child in left.values:
            child.parent = left

        key = self.keys[mid]
        self.keys = self.keys[mid + 1 :]
        self.values = self.values[mid + 1 :]

        return key, [left, self]

    def __delitem__(self, key):
        """Remove the child responsible for ``key`` together with one separator.

        The separator at the child's index is removed, or the last separator
        when the child is the right-most one.
        """
        i = self.index(key)
        del self.values[i]
        if i < len(self.keys):
            del self.keys[i]
        else:
            del self.keys[i - 1]

    def fusion(self):
        """Merge this node's keys and children into a sibling.

        The right sibling is used when one exists, otherwise the left sibling.
        The separator between the two nodes is pulled down from the parent's key
        list into the merged node. This method does not remove this node or the
        separator from the parent; ``BPlusTree.delete`` does that afterwards.
        """
        self.cost_dict["fusions"] += 1
        self.cost_dict["parent_fusions"] += 1

        index = self.parent.index(self.keys[0])
        # merge this node with the next node
        if index < len(self.parent.keys):
            next_node: Node = self.parent.values[index + 1]
            next_node.keys[0:0] = self.keys + [self.parent.keys[index]]
            for child in self.values:
                child.parent = next_node
            next_node.values[0:0] = self.values
        else:  # If self is the last node, merge with prev
            prev: Node = self.parent.values[-2]
            prev.keys += [self.parent.keys[-1]] + self.keys
            for child in self.values:
                child.parent = prev
            prev.values += self.values

    def borrow_key(self, minimum: int):
        """Try to fix an underflow by rotating one key/child in from a sibling.

        The right sibling is tried first; if it does not exist or has no spare
        key (``len(keys) <= minimum``), the left sibling is tried. Previously the
        left sibling was only considered for the right-most child, so a node
        with a full left neighbour was fused unnecessarily; this now matches
        ``Leaf.borrow_key``, which already tries both sides.

        :param minimum: minimum number of keys a sibling must keep.
        :return: True if a key was borrowed, False if neither sibling can lend.
        """
        index = self.parent.index(self.keys[0])
        if index < len(self.parent.keys):
            next_node: Node = self.parent.values[index + 1]
            if len(next_node.keys) > minimum:
                # Rotate left: separator comes down, sibling's first key goes up.
                self.keys += [self.parent.keys[index]]

                borrow_node = next_node.values.pop(0)
                borrow_node.parent = self
                self.values += [borrow_node]
                self.parent.keys[index] = next_node.keys.pop(0)
                return True
        if index != 0:
            prev: Node = self.parent.values[index - 1]
            if len(prev.keys) > minimum:
                # Rotate right: separator comes down, sibling's last key goes up.
                self.keys[0:0] = [self.parent.keys[index - 1]]

                borrow_node = prev.values.pop()
                borrow_node.parent = self
                self.values[0:0] = [borrow_node]
                self.parent.keys[index - 1] = prev.keys.pop()
                return True

        return False


class Leaf(Node):
    """Leaf node of the B+ tree.

    Stores the actual key/value pairs (``keys[i]`` maps to ``values[i]``) and is
    linked to its neighbouring leaves through ``prev`` and ``next``.
    """

    def __init__(self, cost_dict: dict, parent=None, prev_node=None, next_node=None):
        """
        Create a leaf and splice it into the leaf chain between its neighbours.

        :type prev_node: Leaf
        :type next_node: Leaf
        """
        super().__init__(cost_dict, parent)
        self.next: Leaf = next_node
        self.prev: Leaf = prev_node
        # Make both neighbours (when present) point back at the new leaf.
        if next_node is not None:
            next_node.prev = self
        if prev_node is not None:
            prev_node.next = self

    def __getitem__(self, item):
        """Return the value stored under key ``item`` (list.index raises if absent)."""
        position = self.keys.index(item)
        return self.values[position]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, inserting in sorted position or overwriting."""
        slot = self.index(key)
        if key in self.keys:
            # index() points one past an existing key, so step back by one.
            self.values[slot - 1] = value
            return
        self.keys.insert(slot, key)
        self.values.insert(slot, value)

    def split(self):
        """Move the lower half of the entries into a new left leaf.

        :return: ``(first key of the right half, [left, self])``
        """
        self.cost_dict["splits"] += 1

        sibling = Leaf(
            cost_dict=self.cost_dict,
            parent=self.parent,
            prev_node=self.prev,
            next_node=self,
        )
        half = len(self.keys) // 2

        sibling.keys, self.keys = self.keys[:half], self.keys[half:]
        sibling.values, self.values = self.values[:half], self.values[half:]

        # The separator handed to the parent is the left-most key kept on the right.
        return self.keys[0], [sibling, self]

    def __delitem__(self, key):
        """Remove ``key`` and its value (list.index raises if the key is absent)."""
        position = self.keys.index(key)
        self.keys.pop(position)
        self.values.pop(position)

    def fusion(self):
        """Merge this leaf's entries into a neighbour and unlink it from the chain.

        The next leaf is used when it shares this leaf's parent, otherwise the
        previous leaf receives the entries.
        """
        self.cost_dict["fusions"] += 1

        right, left = self.next, self.prev
        if right is not None and right.parent == self.parent:
            right.keys[0:0] = self.keys
            right.values[0:0] = self.values
        else:
            left.keys.extend(self.keys)
            left.values.extend(self.values)

        # Close the gap in the leaf chain.
        if right is not None:
            right.prev = left
        if left is not None:
            left.next = right

    def borrow_key(self, minimum: int):
        """Try to take one entry from the next leaf, then from the previous leaf.

        :param minimum: number of keys a lending leaf must keep.
        :return: True if an entry was borrowed (the parent separator is updated),
            False otherwise.
        """
        slot = self.parent.index(self.keys[0])
        separators = self.parent.keys

        if slot < len(separators) and len(self.next.keys) > minimum:
            donor = self.next
            self.keys.append(donor.keys.pop(0))
            self.values.append(donor.values.pop(0))
            separators[slot] = donor.keys[0]
            return True

        if slot != 0 and len(self.prev.keys) > minimum:
            donor = self.prev
            self.keys.insert(0, donor.keys.pop())
            self.values.insert(0, donor.values.pop())
            separators[slot - 1] = self.keys[0]
            return True

        return False


class BPlusTree(object):
    """B+ tree object, consisting of nodes.

    Nodes are split in two as soon as they hold more than ``maximum`` keys. When a
    split occurs, a key 'floats' upwards and is inserted into the parent node to act
    as a pivot. Every split and fusion is counted in ``cost_dict``, which is shared
    with all nodes of the tree.

    Attributes:
        maximum (int): The maximum number of keys each node can hold.
        minimum (int): ``maximum // 2``; nodes with fewer keys are rebalanced on delete.
        depth (int): Number of index levels above the leaves (0 while the root is a leaf).
        cost_dict (dict): Counters "splits", "parent_splits", "fusions", "parent_fusions".
    """

    root: Node
    cost_dict: dict

    def __init__(self, maximum=4):
        """Create an empty tree whose root is a single leaf.

        :param maximum: maximum number of keys per node; values of 2 or less are
            replaced by 2.
        """
        self.cost_dict = {
            "splits": 0,
            "parent_splits": 0,
            "fusions": 0,
            "parent_fusions": 0,
        }
        self.root = Leaf(cost_dict=self.cost_dict)
        self.maximum: int = maximum if maximum > 2 else 2
        self.minimum: int = self.maximum // 2
        self.depth = 0

    def find(self, key) -> Leaf:
        """Descend from the root to the leaf responsible for ``key``.

        Returns:
            Leaf: the leaf which should have the key
        """
        node = self.root
        # Traverse tree until leaf node is reached.
        while type(node) is not Leaf:
            node = node[key]

        return node

    def __getitem__(self, item):
        """Return the value stored under ``item``; the leaf lookup raises if it is absent."""
        return self.find(item)[item]

    def query(self, key):
        """Returns a value for a given key, and None if the key does not exist."""
        leaf = self.find(key)
        return leaf[key] if key in leaf.keys else None

    def change(self, key, value):
        """Overwrite the value of an existing key.

        Returns:
            (bool,Leaf): success flag and the leaf responsible for the key; the flag
            is False (and nothing is changed) if the key does not exist.
        """
        leaf = self.find(key)
        if key not in leaf.keys:
            return False, leaf
        else:
            leaf[key] = value
            return True, leaf

    def __setitem__(self, key, value, leaf=None):
        """Inserts a key-value pair after traversing to a leaf node. If the leaf node
        then holds more than ``maximum`` keys, split the leaf node into two.

        :param leaf: optional, already located target leaf; looked up when None.
        """
        if leaf is None:
            leaf = self.find(key)
        leaf[key] = value
        if len(leaf.keys) > self.maximum:
            self.insert_index(*leaf.split())

    def insert(self, key, value):
        """Insert ``key`` only if it is not present yet.

        Returns:
            (bool,Leaf): success flag and the leaf that was looked up before the
            insert. The flag is False if the key already exists.
        """
        leaf = self.find(key)
        if key in leaf.keys:
            return False, leaf
        else:
            self.__setitem__(key, value, leaf)
            return True, leaf

    def insert_index(self, key, values: list[Node]):
        """Insert the pivot ``key`` and the two halves of a split node into the parent.

        ``values`` is the ``[left, right]`` pair returned by ``split``. If the split
        node was the root, a new root is created and ``depth`` grows by one.
        Otherwise the parent is updated and split recursively if it overflows."""
        parent = values[1].parent
        if parent is None:
            # The split node was the root: grow the tree by one level.
            values[0].parent = values[1].parent = self.root = Node(
                cost_dict=self.cost_dict
            )
            self.depth += 1
            self.root.keys = [key]
            self.root.values = values
            return

        parent[key] = values
        # If the node is full, split the  node into two.
        if len(parent.keys) > self.maximum:
            self.insert_index(*parent.split())
        # Once a leaf node is split, it consists of a internal node and two leaf nodes.
        # These need to be re-inserted back into the tree.

    def delete(self, key, node: Node = None):
        """Remove ``key`` from ``node`` and rebalance the tree if the node underflows.

        When ``node`` is None the responsible leaf is looked up first. A node with
        fewer than ``minimum`` keys first tries to borrow from a sibling; if that
        fails it is fused into a sibling and the deletion continues in its parent.
        An index root left without keys is replaced by its only child.

        Deleting a key that is not stored raises from the leaf's ``list.index`` call.
        """
        if node is None:
            node = self.find(key)
        del node[key]

        if len(node.keys) < self.minimum:
            if node == self.root:
                # The root may underflow; it is only collapsed once it has no keys left.
                if len(self.root.keys) == 0 and len(self.root.values) > 0:
                    self.root = self.root.values[0]
                    self.root.parent = None
                    self.depth -= 1
                return

            elif not node.borrow_key(self.minimum):
                node.fusion()
                # Removes the fused node and one separator from the parent.
                self.delete(key, node.parent)

    def show(self, node=None, file=None, _prefix="", _last=True):
        """Print the keys of every node as an indented ASCII tree.

        :param node: subtree root to print; defaults to the tree root.
        :param file: stream passed to ``print``.
        :param _prefix: indentation accumulated by the recursion (internal).
        :param _last: whether this node is the last child of its parent (internal).
        """
        if node is None:
            node = self.root
        print(_prefix, "`- " if _last else "|- ", node.keys, sep="", file=file)
        _prefix += "   " if _last else "|  "

        if type(node) is Node:
            # Recursively print the key of child nodes (if these exist).
            for i, child in enumerate(node.values):
                _last = i == len(node.values) - 1
                self.show(child, file, _prefix, _last)

    def output(self):
        """Return ``(counter values as a tuple, depth)`` without resetting the counters."""
        return tuple(self.cost_dict.values()), self.depth

    def readfile(self, reader):
        """Insert one key/value pair per line of ``reader`` and return the line count.

        Each line is decoded and split once on whitespace into key and value.
        Progress is printed every 1000 lines. Note that 1 is returned for an empty
        reader because ``i`` starts at 0.
        """
        i = 0
        for i, line in enumerate(reader):
            s = line.decode().split(maxsplit=1)
            self[s[0]] = s[1]
            if i % 1000 == 0:
                print("Insert " + str(i) + "items")
        return i + 1

    def leftmost_leaf(self) -> Leaf:
        """Return the first leaf of the leaf chain by always following the first child."""
        node = self.root
        while type(node) is not Leaf:
            node = node.values[0]
        return node

    def get_obs_space_representation(self, max_depth):
        """
        Returns a 1D array representation of the tree:
        - Keys in each node are padded with zeros to `maximum` keys.
        - The entire structure is padded with zeros to account for the maximum possible nodes at each level.

        Levels are concatenated from the root downwards; nodes below ``max_depth``
        are not included.
        """

        max_depth += 1  # Add 1 to account for the root node
        levels = [[] for _ in range(max_depth)]

        def dfs(node: Node, depth: int):
            # Appends one fixed-width slot per (possibly missing) node to its level.
            if depth == max_depth:
                return

            level = levels[depth]

            if node is None:
                # Placeholder for a node that does not exist.
                level += [0] * self.maximum
                children = []
            else:
                level += node.keys.copy() + [0] * (self.maximum - len(node.keys))
                assert len(level) % self.maximum == 0

                if type(node) is Leaf:
                    children = []
                else:
                    children = node.values.copy()


            # Pad to the full fan-out so every level has a fixed width.
            while len(children) < self.maximum + 1:
                children.append(None)

            assert len(children) == self.maximum + 1

            for child in children:
                dfs(child, depth + 1)

        # Start traversal from the root
        dfs(self.root, 0)

        # Make sure the layers are filled correctly
        prev_nodes = 1
        for level in levels[1:]:
            cur_nodes = prev_nodes * (self.maximum + 1)
            assert len(level) == cur_nodes * self.maximum
            prev_nodes = cur_nodes

        flattened_representation = list(itertools.chain(*levels))
        return np.array(flattened_representation).flatten()

    def get_obs_space_feature_representation(self, max_depth):
        """
        Returns a 1D array representation of the tree with feature engineering:
        - Each node is represented by its minimum key, maximum key, and fill percentage.
        - The structure is padded with zeros to account for the maximum possible nodes at each level.

        Levels are concatenated from the root downwards; nodes below ``max_depth``
        are not included.
        """

        max_depth += 1  # Add 1 to account for the root node
        levels = [[] for _ in range(max_depth)]

        def dfs(node: Node, depth: int):
            # Appends three features per (possibly missing) node to its level.
            if depth == max_depth:
                return

            level = levels[depth]

            if node is None:
                # Placeholder for a node that does not exist.
                level += [0, 0, 0]
                children = []
            else:
                min_key = min(node.keys) if node.keys else 0
                max_key = max(node.keys) if node.keys else 0
                fill_percentage = len(node.keys) / self.maximum

                level += [min_key, max_key, fill_percentage]

                if isinstance(node, Leaf):
                    children = []
                else:
                    children = node.values.copy()

            # Pad to the full fan-out so every level has a fixed width.
            while len(children) < self.maximum + 1:
                children.append(None)

            assert len(children) == self.maximum + 1

            for child in children:
                dfs(child, depth + 1)

        dfs(self.root, 0)

        # Make sure the layers are filled correctly
        prev_nodes = 1
        for level in levels[1:]:
            cur_nodes = prev_nodes * (self.maximum + 1)
            assert len(level) == cur_nodes * 3  # 3 features per node (min, max, fill percentage).
            prev_nodes = cur_nodes

        flattened_representation = list(itertools.chain(*levels))
        return np.array(flattened_representation).flatten()


    def reset_cost_dict(self):
        """Set every counter in ``cost_dict`` back to zero (in place)."""
        for key in self.cost_dict.keys():
            self.cost_dict[key] = 0

    def calculate_reward(self):
        """Return the weighted cost accumulated since the last call and reset the counters.

        The result is a non-negative cost (a weighted sum of the counters), not a
        reward; BScheduler.step negates it.
        """
        cost_factors = {
            "splits": 2,
            "parent_splits": 1,
            "fusions": 2,
            "parent_fusions": 1,
        }
        reward = 0
        for key in self.cost_dict.keys():
            reward += cost_factors[key] * self.cost_dict[key]
        self.reset_cost_dict()
        return reward



def calculate_length_max_depth_of_tree(max_tree_values, max_keys):
    """
    Compute the size of the feature-based tree observation and the depth used for it.

    The depth is ``int(1 + log(max_tree_values) / log(max_keys + 1))``. The
    observation holds three features per node slot for a completely filled tree
    with ``max_keys + 1`` children per node, from the root (level 0) down to
    level ``max_depth``.

    :return: ``(total number of feature values, max_depth)``
    """
    features_per_node = 3
    fanout = max_keys + 1

    max_depth = int(1 + np.log(max_tree_values) / (np.log(fanout)))

    # One root slot plus fanout**level slots on every level below it.
    slots_below_root = sum(fanout ** level for level in range(1, max_depth + 1))
    total_features = features_per_node * (1 + slots_below_root)

    return total_features, max_depth


def printTree(tree):
    """Print every node of a tree in pre-order, one block of lines per node.

    The previous version read the attributes ``nextKey`` and ``check_leaf``,
    which the node classes in this file do not define, and recursed on keys
    instead of child nodes, so it always failed with AttributeError.

    :param tree: a tree object exposing ``root``, or a node to start from
        (used by the recursion).

    Leaf nodes are recognised by their ``next`` link, which only Leaf.__init__
    sets; index nodes created by Node.__init__ do not have it.
    """
    current_node = getattr(tree, "root", tree)
    if current_node is None:
        return

    is_leaf = hasattr(current_node, "next")
    parent = current_node.parent

    print("keys:", current_node.keys)
    if is_leaf:
        print("values:", current_node.values)
        next_leaf = current_node.next
        print("next leaf keys:", None if next_leaf is None else next_leaf.keys)
    else:
        # Child nodes have no readable repr, so only report how many there are.
        print("children:", len(current_node.values))
    print("parent keys:", None if parent is None else parent.keys)
    print("leaf:", is_leaf)
    print("\n")

    if not is_leaf:
        for child in current_node.values:
            printTree(child)

In [7]:
# The training environment which uses the previously defined B-tree
# The important pieces are contained in the step function.

class BScheduler(gym.Env):
    """Environment in which an agent orders pending B+ tree operations.

    Every episode starts with a pre-filled tree and a to-do list holding
    ``num_inserts`` keys to insert followed by ``num_deletes`` keys to delete.
    An action is an index into that list; the chosen operation is executed and
    its slot is set to -1. The reward is the negated structural cost (splits and
    fusions) caused by the operation. The episode terminates once every slot
    has been executed.

    Observation: the to-do list followed by the feature representation of the
    tree (min key, max key and fill ratio per node slot).
    """

    def __init__(self, args: Args = Args(), render_mode: Optional[str] = None):
        """Read the environment sizes from ``args`` and build the spaces.

        :param args: configuration; the ``env_*`` fields are used.
        :param render_mode: stored on the instance; no rendering is implemented.
        """
        self.render_mode = render_mode
        self.num_inserts = args.env_num_inserts
        self.num_deletes = args.env_num_deletes
        self.num_operations = self.num_inserts + self.num_deletes
        self.max_tree_values = args.env_max_tree_values
        self.max_values_per_node = args.env_max_values_per_node
        self.action_space = spaces.Discrete(self.num_operations)
        self.low = -2
        self.high = self.max_tree_values + self.num_operations
        self.len_tree_obs_space, self.max_possible_tree_depth = calculate_length_max_depth_of_tree(
            self.max_tree_values, self.max_values_per_node
        )
        self.observation_space = spaces.Box(
            low=self.low,
            high=self.high,
            shape=(self.len_tree_obs_space + self.num_operations,),
            dtype=np.float32,
        )
        self.tree = None
        self.tree_representation = None
        self.operations = None  # populated by reset()
        self.inserts = None
        self.deletes = None
        self.rng = np.random.default_rng(None)

    def _get_obs(self):
        """Return the float32 observation: pending operations, then tree features."""
        self.tree_representation = self.tree.get_obs_space_feature_representation(
            self.max_possible_tree_depth
        )
        assert len(self.tree_representation) == self.len_tree_obs_space
        assert self.tree_representation[-1] == 0
        return np.concatenate([self.operations, self.tree_representation], dtype=np.float32)

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Build a fresh random tree and to-do list.

        ``max_tree_values + num_inserts`` distinct keys are drawn. The first
        ``max_tree_values - num_inserts`` of them are inserted into the tree, the
        keys after position ``max_tree_values`` become the pending inserts, and
        the first ``num_deletes`` drawn keys become the pending deletes.

        :return: ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed=seed)
        self.tree = BPlusTree(maximum=self.max_values_per_node)
        self.tree_numbers = self.rng.choice(
            a=np.arange(1, self.high), size=self.max_tree_values + self.num_inserts, replace=False
        )
        inserts = self.tree_numbers[self.max_tree_values:]
        for i in range(self.max_tree_values - self.num_inserts):
            self.tree.insert(self.tree_numbers[i], self.tree_numbers[i])

        deletes = self.rng.choice(
            self.tree_numbers[:self.num_deletes], self.num_deletes, replace=False
        )

        sorted_deletes = np.sort(deletes)
        sorted_inserts = np.sort(inserts)
        self.tree.calculate_reward()  # to reset counters
        self.operations = np.concatenate([sorted_inserts, sorted_deletes])
        return self._get_obs(), {}

    def step(self, action):
        """Execute the pending operation at index ``action``.

        :return: ``(observation, reward, terminated, truncated, info)``; the reward
            is the negated cost of the operation, ``truncated`` is always False.
        :raises ValueError: if the selected operation was already executed.
        """
        info = {}
        truncated = False
        terminated = False

        # Check for valid action choices
        operation = self.operations[action]
        if operation == -1:
            # Raising a plain string is itself a TypeError in Python 3; raise a
            # real exception that carries the diagnostic information instead.
            raise ValueError(
                f"Operation at index {action} was already executed "
                f"(operations: {self.operations}). "
                "This should not happen if you use MaskablePPO"
            )

        # The to-do list is [inserts..., deletes...], so the first num_inserts
        # indices are inserts. (Using num_operations // 2 here was only correct
        # when the number of inserts equals the number of deletes.)
        if action < self.num_inserts:
            self.tree.insert(operation, operation)
        else:
            self.tree.delete(operation)
        self.operations[action] = -1

        # Once the todo-list is empty, the process finishes
        if (self.operations == -1).all():
            terminated = True

        observation = self._get_obs()

        # The tree library contains the logic to compute the cost of the last executed operation
        reward = -1 * self.tree.calculate_reward()
        return observation, reward, terminated, truncated, info

    def action_masks(self) -> List[bool]:
        """Return a boolean array that is True for operations not executed yet."""
        ret = self.operations != -1
        return ret


gym.register(
    id="BScheduler-v0",
    entry_point=BScheduler,
)

In [9]:
class OptimizedHierarchicalBPlusFeatureExtractor(BaseFeaturesExtractor):
    """
    The feature extractor, which contains the main logic of the hierarchical model described in the text.

    The observation is expected to be laid out as ``num_ops`` operation values
    followed by the tree features, three values per node slot, stored level by
    level from the root to the leaves. Leaves are embedded first; each level
    above is then computed from the embeddings of its children and its own node
    features, until a single root embedding remains. The output is the root
    embedding concatenated with the raw operation values
    (``feature_dim + num_ops`` values per sample).
    """
    def __init__(self,
                 observation_space: gym.spaces.Box,
                 feature_dim: int = 256,
                 values_per_node: int = 4,
                 num_ops: int = 6,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 max_levels: int = 10):
        """
        :param observation_space: flat Box space of the environment.
        :param feature_dim: size of every node embedding.
        :param values_per_node: maximum number of keys per tree node.
        :param num_ops: number of leading observation entries holding the operations.
        :param num_heads: attention heads of the shared transformer layer.
        :param dropout: dropout of the shared transformer layer.
        :param max_levels: number of learnable level embeddings.
        """
        super(OptimizedHierarchicalBPlusFeatureExtractor, self).__init__(
            observation_space,
            feature_dim + num_ops
        )
        self.features_per_node = 3
        self.values_per_node = values_per_node
        self.children_per_node = values_per_node + 1
        self.num_ops = num_ops
        self.feature_dim = feature_dim
        self.max_levels = max_levels
        self.debug = False

        self.level_structure = self._compute_level_structure(observation_space.shape)

        # By default, a TransformerEncoderLayer is used for the computation. It is shared among all levels of the tree.
        self.transformer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=feature_dim,
            dropout=dropout,
            batch_first=True
        )
        # `linear` and `node_norm` are only referenced by the commented-out
        # alternatives in process_level; they are still registered as modules.
        self.linear = nn.Linear(feature_dim, feature_dim)
        self.leaf_embedding = nn.Linear(self.features_per_node, feature_dim)
        self.node_combiner = nn.Linear(feature_dim * self.children_per_node + self.features_per_node, feature_dim)
        self.level_embeddings = nn.Parameter(torch.randn(max_levels, 1, feature_dim))
        self.level_norm = nn.LayerNorm(feature_dim)
        self.node_norm = nn.LayerNorm(feature_dim * self.children_per_node + self.features_per_node)

    def _compute_level_structure(self, obs):
        """
        Pre-compute the structure of each level in the tree.
        Calculate leaf start index and propagate upwards to get all index ranges for all nodes.

        :param obs: shape tuple of the observation space; ``obs[0]`` is its length.
        :return: list of dicts ordered from the leaf level up to the root, each
            with the node count, parent count and the slice of the tree features
            that belongs to the level.
        """
        tree_features = obs[0] - self.num_ops
        # The number of levels follows from the number of node slots, not from the
        # number of feature values (three per node). Comparing against the feature
        # count over-counted the levels whenever children_per_node <= 3.
        total_nodes = tree_features // self.features_per_node

        # The leaf level holds the largest power of the fan-out that fits.
        num_levels = 0
        leaf_count = 1
        while leaf_count * self.children_per_node <= total_nodes:
            leaf_count *= self.children_per_node
            num_levels += 1

        level_structure = []
        current_value_end = tree_features

        # Walk from the leaves (stored last) towards the root (stored first).
        for level in range(num_levels, -1, -1):
            nodes_this_level = leaf_count // (self.children_per_node ** (num_levels - level))
            parent_nodes = nodes_this_level // self.children_per_node

            value_start_idx = current_value_end - nodes_this_level * self.features_per_node

            level_info = {
                'num_nodes': nodes_this_level,
                'num_parents': parent_nodes,
                'value_start_idx': value_start_idx,
                'values_per_level': nodes_this_level * self.features_per_node,
                'value_end_idx': current_value_end
            }

            level_structure.append(level_info)
            current_value_end = value_start_idx

        return level_structure

    def _get_empty_mask(self, node_values):
        """Return a boolean mask that is True where a node's first feature is 0 (empty slot)."""
        empty = node_values[..., 0] == 0
        return empty

    def process_level(self, level_info, current_embeddings, tree_data, level_idx):
        """Compute the embeddings of the parents of the level described by ``level_info``.

        :param level_info: entry of ``level_structure`` for the child level.
        :param current_embeddings: child embeddings, reshaped below into groups
            of ``children_per_node`` per parent.
        :param tree_data: normalised tree features of the batch.
        :param level_idx: index into ``level_embeddings``.
        :return: parent embeddings (zeros for empty parents); the input is returned
            unchanged when the level has no parents (root level).
        """
        # Parse one level of the tree using the data stored in tree_data
        batch_size = current_embeddings.shape[0]
        num_parents = level_info['num_parents']

        if num_parents == 0:  # the root level has no parents to compute
            return current_embeddings

        # The parents' features are stored directly before this level's features.
        parent_values_start = level_info['value_start_idx']
        parent_values = tree_data[:, parent_values_start - num_parents * self.features_per_node:parent_values_start]
        parent_values = parent_values.view(batch_size, num_parents, self.features_per_node)
        empty_mask = self._get_empty_mask(parent_values)

        output_embeddings = torch.zeros(
            batch_size, num_parents, self.feature_dim,
            device=current_embeddings.device
        )

        # Only non-empty parents are pushed through the networks.
        if (~empty_mask).any():
            non_empty_indices = torch.nonzero(~empty_mask)
            non_empty_parents = parent_values[~empty_mask]

            grouped_children = current_embeddings.view(batch_size, -1, self.children_per_node, self.feature_dim)
            non_empty_children = grouped_children[non_empty_indices[:, 0], non_empty_indices[:, 1]]

            children_flat = non_empty_children.view(-1, self.children_per_node * self.feature_dim)

            combined = torch.cat([children_flat, non_empty_parents], dim=1)
            # Optional alternative: combined = self.node_norm(combined)
            parent_embeddings = self.node_combiner(combined)

            level_embedding = self.level_embeddings[level_idx].expand(len(parent_embeddings), -1)
            parent_embeddings = parent_embeddings + level_embedding
            parent_embeddings = self.level_norm(parent_embeddings)

            # Every parent is passed as its own sequence of length one.
            transformed = self.transformer(parent_embeddings.unsqueeze(1)).squeeze(1)
            # Linear alternative: transformed = self.linear(parent_embeddings.unsqueeze(1)).squeeze(1)
            output_embeddings[non_empty_indices[:, 0], non_empty_indices[:, 1]] = transformed

        return output_embeddings

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Embed a batch of observations.

        :param obs: batch of flat observations, indexed as ``[batch, feature]``.
        :return: root embedding concatenated with the un-normalised operations.
        """
        batch_size = obs.shape[0]
        ops = obs[:, :self.num_ops]
        tree_data = obs[:, self.num_ops:]
        # Min-max normalise the tree features per sample.
        tree_data = (tree_data - tree_data.min(dim=1, keepdim=True).values) / (tree_data.max(dim=1, keepdim=True).values - tree_data.min(dim=1, keepdim=True).values + 1e-8)

        # Process leaf nodes separately
        num_leaf_nodes = self.level_structure[0]['num_nodes'] if self.level_structure else (tree_data.shape[1] // self.features_per_node)
        leaf_values = tree_data[:, -num_leaf_nodes * self.features_per_node:].view(
            batch_size, num_leaf_nodes, self.features_per_node
        )

        leaf_embeddings = torch.zeros(batch_size, num_leaf_nodes, self.feature_dim, device=obs.device)

        empty_mask = self._get_empty_mask(leaf_values)
        if (~empty_mask).any():
            non_empty_leaves = leaf_values[~empty_mask]
            non_empty_indices = torch.nonzero(~empty_mask)

            embeddings = self.leaf_embedding(non_empty_leaves)

            level_embedding = self.level_embeddings[0].expand(len(embeddings), -1)
            embeddings = embeddings + level_embedding
            embeddings = self.level_norm(embeddings)

            leaf_embeddings[non_empty_indices[:, 0], non_empty_indices[:, 1]] = embeddings

        current_embeddings = leaf_embeddings

        # Reduce level by level until only the root embedding is left.
        for level_idx, level_info in enumerate(self.level_structure):
            current_embeddings = self.process_level(
                level_info,
                current_embeddings,
                tree_data,
                level_idx
            )

        root_embedding = current_embeddings.squeeze(1)
        return torch.cat([root_embedding, ops], dim=1)

In [10]:
# A few helper functions for initialization
def get_env(args: Args):
    """Create a single environment and reset it once with ``args.seed``."""
    make_env = get_env_func(args)
    single_env = make_env()
    single_env.reset(seed=args.seed)
    return single_env


def get_env_func(args: Args):
    """Return a zero-argument factory that builds a "BScheduler-v0" environment with ``args``."""

    def env_func():
        # The environment is only constructed when the factory is called.
        return gym.make("BScheduler-v0", args=args, render_mode=None)

    return env_func


def get_vec_env(args: Args, n_envs_override: int = None):
    """Build a DummyVecEnv of BScheduler environments.

    :param n_envs_override: number of environments; when it is None (or 0),
        ``args.n_envs`` is used instead.
    """
    n_envs = n_envs_override if n_envs_override else args.n_envs
    return make_vec_env(get_env_func(args), n_envs=n_envs, vec_env_cls=DummyVecEnv)


def create_model(args: Args, env: BScheduler, eval_env: BScheduler):
    """Seed the Python/NumPy RNGs and build a MaskablePPO model for ``env``.

    :param args: configuration; selects the feature extractor and hyper-parameters.
    :param env: training environment handed to MaskablePPO.
    :param eval_env: accepted but not used in this function.
    :raises ValueError: if ``args.feature_extractor`` is neither "tfh_fast" nor "none".
    :return: the constructed (untrained) model.
    """
    print(f"Creating model with args: {args}")
    random.seed(args.seed)
    np.random.seed(args.seed)

    extractor_name = args.feature_extractor
    if extractor_name not in ("tfh_fast", "none"):
        raise ValueError(f"Unknown feature_extractor: {args.feature_extractor}")

    policy_kwargs = {"net_arch": args.s_net_arch}

    if extractor_name == "none":
        print("Training without special feature extractor")
    else:
        # Fast Hierarchical Transformer Feature Extractor
        policy_kwargs.update(
            features_extractor_class=OptimizedHierarchicalBPlusFeatureExtractor,
            features_extractor_kwargs={
                "feature_dim": args.s_transformer_features_dim,
                "values_per_node": args.env_max_values_per_node,
                "num_ops": args.env_num_inserts + args.env_num_deletes,
                "num_heads": args.s_transformer_nhead,
                "dropout": 0.1,
                "max_levels": 5,
            },
        )

    ppo_settings = dict(
        n_epochs=args.s_n_epochs,
        learning_rate=args.s_learning_rate,
        verbose=1,
        gamma=0.999,
        device=args.device,
        policy_kwargs=policy_kwargs,
        seed=args.seed,
        batch_size=args.s_batch_size,
    )
    model = MaskablePPO("MlpPolicy", env, **ppo_settings)

    print("Model policy:", model.policy)

    return model


In [11]:
# Create the environments (eval_env is passed to create_model, which does not use it)
vec_env = get_vec_env(args)
eval_env = get_vec_env(args, 1)

# Initialize the model
model = create_model(args, vec_env, eval_env)

# Start the training process
model.learn(total_timesteps=args.total_timesteps)

# Compute the mean reward of the model for the evaluation, which can be compared to the table in the text.
# The result is stored as `eval_result` rather than `eval`, which would shadow
# Python's builtin eval() for the rest of the kernel session.
eval_result = evaluate_policy(model, vec_env, n_eval_episodes=1000, warn=False)
print("Evaluation (mean/std):", eval_result)

Creating model with args: <__main__.Args object at 0x7a625d172dd0>
Using cpu device
Model policy: MaskableActorCriticPolicy(
  (features_extractor): OptimizedHierarchicalBPlusFeatureExtractor(
    (transformer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_features=64, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=64, out_features=64, bias=True)
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (linear): Linear(in_features=64, out_features=64, bias=True)
    (leaf_embedding): Linear(in_features=3, out_features=64, bias=True)
    (node_combiner): Linear(in_features=323, out_features=64, 

KeyboardInterrupt: 