In [19]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""A sample run of the CyberBattle simulation"""

from typing import cast
import gymnasium as gym
import logging
import sys
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

In [20]:
"""Entry point if called as an executable"""

from cyberbattle._env import cyberbattle_env


root = logging.getLogger()
root.setLevel(logging.CRITICAL)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.CRITICAL)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
root.addHandler(handler)


env = cast(
    CyberBattleEnv,
    gym.make(
        "SimpleNetwork-v1",
        attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0),
    ).unwrapped,
)

In [21]:
env.get__owned_nodes_indices()

[0]

In [22]:
env.render()

Unnamed: 0_level_0,status,properties,local_attacks,remote_attacks
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Starting_Client,owned,[Win11],"[ReadSourceCode_LeakedNode, Trap, GetBackUpCre...",[]


In [23]:
print(f"valid actions: {env.get_all_valid_actions()}")

valid actions: [{'local_vulnerability': array([0, 0], dtype=int32)}, {'local_vulnerability': array([0, 1], dtype=int32)}, {'local_vulnerability': array([0, 2], dtype=int32)}]


In [24]:
import copy
import math
import numpy as np


class Node:
    def __init__(self, env, parent, action=None, reward=0, done=False):
        self.env = env
        self.parent = parent
        self.children = []
        self.visits = 0
        self.action = action
        self.accumulated_reward = 0
        self.mean_reward = 0
        self.tried_actions = []
        self.value = 0.0
        self.reward = reward
        self.done = done

    def get_valid_untried_action_mask(self):
        valid_actions = self.env.compute_action_mask()
        for action in self.tried_actions:
            for key, val in action.items():
                if key == "local_vulnerability":
                    valid_actions["local_vulnerability"][tuple(val)] = 0
                if key == "remote_vulnerability":
                    valid_actions["remote_vulnerability"][tuple(val)] = 0
                if key == "connect":
                    valid_actions["connect"][tuple(val)] = 0
        return valid_actions

    def is_fully_expanded(self):
        actions = self.get_valid_untried_action_mask()
        return not (
            actions["local_vulnerability"].any()
            or actions["remote_vulnerability"].any()
            or actions["connect"].any()
        )

    def is_terminal(self):
        return self.env.is_done()

    def backpropagate(self, result):
        self.visits += 1
        self.accumulated_reward = self.accumulated_reward + result
        self.mean_reward = self.accumulated_reward / self.visits
        if self.parent:
            self.parent.backpropagate(result)

    def get_all_valid_untried_actions(self):
        mask = self.get_valid_untried_action_mask()
        actions = []

        lv_idx = np.argwhere(mask["local_vulnerability"] == 1)
        for idx in lv_idx:
            actions.append({"local_vulnerability": idx})

        rv_idx = np.argwhere(mask["remote_vulnerability"] == 1)
        for idx in rv_idx:
            actions.append({"remote_vulnerability": idx})

        conn_idx = np.argwhere(mask["connect"] == 1)
        for idx in conn_idx:
            actions.append({"connect": idx})

        return actions

    def expand(self):
        for action in self.get_all_valid_untried_actions():
            new_state = copy.deepcopy(self.env)
            observation, reward, done, truncated, info = new_state.step(action)
            self.tried_actions.append(action)
            child = Node(
                new_state, parent=self, action=action, reward=reward, done=done
            )
            if new_state.is_done():
                child.done = True
            self.children.append(child)

    def get_random_untried_action(self):
        untried_action = self.env.sample_valid_action()
        while not self.env.apply_mask(
            untried_action, self.get_valid_untried_action_mask()
        ):
            untried_action = self.env.sample_valid_action()
        return untried_action

    def eval(self):
        if self.visits == 0:
            return float("inf")
        exploit = self.accumulated_reward / self.visits
        explore = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploit  # + explore

    def get_node_probabilities(self):
        temperature = 1.0
        visits = np.array([x.visits if x else 0 for x in self.children])
        if len(visits) == 0:
            return np.array([])
        visits = np.array([x ** (1.0 / temperature) for x in visits])
        probabilities = visits / np.sum(visits)
        probabilities[-1] = max(0, 1 - np.sum(probabilities[0:-1]).item())
        return probabilities

In [25]:
class MCTS:
    def __init__(self, env):
        self.env = env
        self.root = Node(env=env, parent=None)

    def random_playout(self, node):
        print("Starting random playout")
        env = copy.deepcopy(node.env)
        reward = 0
        counter = 0
        action_t0 = None
        if node.done:
            print(
                f"Starting random playout from terminal node, backpropagating {node.reward}"
            )
            return node.reward

        while not env.is_done():
            action = np.random.choice(env.get_all_valid_actions())
            print(f"\t Played: {action}, mean_r: {reward/(counter+1)}")
            _, rt, _, _, _ = env.step(action)
            reward += rt
            if counter == 0:
                action_t0 = action
            counter += 1
            if counter >= 20:
                break
        if counter > 0:
            return reward / counter
        else:
            return reward

    def step(self, current_node: Node):
        print("Enter step")
        if current_node == None:
            print("BUG")
            return
        if current_node.done:
            print("Node is done ", current_node.accumulated_reward)
            current_node.backpropagate(current_node.accumulated_reward)
            return
        if len(current_node.children) == 0:
            current_node.expand()
            random_child = np.random.choice(current_node.children)

            value = self.random_playout(random_child)
            random_child.backpropagate(value)
            return
        else:
            print("going down one level")
            children_scores = [c.eval() for c in current_node.children]
            print(children_scores)
            best_children = np.argmax(children_scores)
            # np.random.choice(
            #     np.argwhere(children_scores == np.amax(children_scores)).flatten()
            # )
            print(f"picked action {best_children}")
            current_node = current_node.children[best_children]
            self.step(current_node)
        return

In [26]:
mcts = MCTS(env=env)
current_node = mcts.root

In [27]:
mcts.step(current_node)

Enter step
Starting random playout
Starting random playout from terminal node, backpropagating 13


In [28]:
env.is_done()

False

In [29]:
mcts.root.accumulated_reward

13

In [35]:
mcts.step(current_node)

Enter step
going down one level
[26.0, 0.0, 0.0]
picked action 0
Enter step
Node is done  104
