In [4]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""A sample run of the CyberBattle simulation"""

from typing import cast
from cyberbattle.mcts.node import Node
import gymnasium as gym
import numpy as np
import logging
import sys
import math
import copy
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

In [5]:

"""Entry point if called as an executable"""

# Build a stdout handler that stamps every record with time, logger name and level.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)

# Attach the handler to the root logger so every library logger is routed through it.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
root.addHandler(handler)

# Disable logging: a threshold one above CRITICAL filters out every standard level.
# Comment out the two lines below to get DEBUG output on stdout.
root.setLevel(logging.CRITICAL + 1)
handler.setLevel(logging.CRITICAL + 1)


In [None]:
class MCTS:
    """Monte Carlo tree search driver over a tree of ``Node`` objects.

    The class only relies on the following ``Node`` members: the constructor
    ``Node(env=..., parent=...)``, the attributes ``env`` and ``children``, and
    the methods ``expand()``, ``eval()`` and ``backpropagate(value)``.

    Args:
        env: Environment the root node is built from. Playouts call
            ``is_done()``, ``sample_valid_action()`` and ``step(action)`` on a
            deep copy of a node's environment.
        max_playout_steps: Maximum number of random actions taken in one
            playout before it is abandoned.
        timeout_penalty: Value returned by a playout that is abandoned because
            it reached ``max_playout_steps`` without finishing.
    """

    def __init__(self, env, max_playout_steps=10, timeout_penalty=-100):
        self.env = env
        self.max_playout_steps = max_playout_steps
        self.timeout_penalty = timeout_penalty
        self.root = Node(
            env=env,
            parent=None
        )

    def random_playout(self, node):
        """Simulate random valid actions from ``node`` and return a value estimate.

        The simulation runs on a deep copy of ``node.env``, so the node's own
        environment is never advanced.

        Returns:
            The reward of the last step taken if the episode finished within
            ``max_playout_steps`` actions (``0`` if it was already finished),
            otherwise ``timeout_penalty``.
        """
        print("Starting random playout")
        env = copy.deepcopy(node.env)
        reward = 0
        steps_taken = 0
        # NOTE(review): only the final step's reward is returned, not the sum over
        # the playout -- confirm this is the value Node.backpropagate expects.
        while not env.is_done():
            # The budget is checked only after is_done() has been re-evaluated, so
            # an episode that finishes exactly on its last allowed step reports
            # its real reward instead of the timeout penalty.
            if steps_taken >= self.max_playout_steps:
                return self.timeout_penalty
            action = env.sample_valid_action()
            _, reward, _, _, _ = env.step(action)
            steps_taken += 1
            print(f"Action: {action}, reward: {reward}")
        return reward

    def step(self, current_node):
        """Run one search iteration starting at ``current_node``.

        Descends through the child with the highest ``eval()`` score (ties are
        broken uniformly at random) until a node without children is reached,
        then expands that node, runs a random playout from it and
        backpropagates the playout value.

        Returns:
            None. Prints ``"BUG"`` and returns early if a ``None`` node is met.
        """
        # Iterative descent: equivalent to the recursive formulation but immune
        # to Python's recursion limit on deep trees.
        while True:
            if current_node is None:
                print("BUG")
                return
            if len(current_node.children) == 0:
                break
            children_scores = np.asarray([c.eval() for c in current_node.children])
            best_children = np.random.choice(
                np.flatnonzero(children_scores == children_scores.max())
            )
            current_node = current_node.children[best_children]

        current_node.expand()
        value = self.random_playout(current_node)
        current_node.backpropagate(value)
        return


In [21]:
# Create the environment and seed the search tree with its initial state.
# Alternative scenario: gym.make("CyberBattleToyCtf-v0")
env = cast(CyberBattleEnv, gym.make("SimpleNetwork-v1"))
obs, _ = env.reset()
mcts = MCTS(env=env)
current_node = mcts.root

if not env.is_done():
    # Grow the tree from the current node; raise the range to run more iterations.
    for _ in range(1):
        mcts.step(current_node)
    # TODO: act on the search result and loop until the episode ends.
    # Kept disabled, exactly as in the original cell:
    #   pi = current_node.get_node_probabilities()
    #   chosen_child = np.random.choice(len(current_node.children), p=pi)
    #   env.step(current_node.children[chosen_child].action)
    #   print(current_node.children[chosen_child].action)
    #   env.render()
    #   current_node = current_node.children[chosen_child]
    #   print(pi)
else:
    # Nothing left to search: just show the final state.
    env.render()

Starting random playout
Action: {'local_vulnerability': array([0, 1], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 1, 0, 0], dtype=int32)}, reward: 0.0
Action: {'local_vulnerability': array([0, 1], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 0, 0, 0], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 0, 0, 0], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 1, 0, 0], dtype=int32)}, reward: 0.0
Action: {'local_vulnerability': array([0, 0], dtype=int32)}, reward: 6
Action: {'local_vulnerability': array([0, 1], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 1, 0, 0], dtype=int32)}, reward: 0.0
Action: {'connect': array([0, 0, 0, 0], dtype=int32)}, reward: 0.0


In [8]:
# Inspect the tree after the search iteration(s) of the previous cell.
# NOTE(review): the saved execution counts are out of order (this cell is In [8],
# the search cell In [21]), so stored outputs may be stale -- re-run top to bottom.
print(current_node.children)
print(mcts.root.children[0].get_node_probabilities())
# Accumulated reward at the root, then at its first two children (one per line).
print(mcts.root.accumulated_reward)
print(*(mcts.root.children[i].accumulated_reward for i in (0, 1)), sep="\n")

[<cyberbattle.mcts.node.Node object at 0x75e32f0ff2b0>, <cyberbattle.mcts.node.Node object at 0x75e32f0ffe50>]
[]
5000.0
0
0
