In [39]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""A sample run of the CyberBattle simulation"""

from typing import cast
import gymnasium as gym
import numpy as np
import logging
import sys
import copy
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

In [None]:

"""Entry point if called as an executable"""

# Open the root logger up to DEBUG so that no record is filtered out by level.
# This cell attaches no handler; to echo records to stdout, add a
# logging.StreamHandler(sys.stdout) carrying a logging.Formatter to `root`.
root = logging.getLogger()
root.setLevel(logging.DEBUG)

# Identifier of the registered environment instantiated for this run.
ENV_ID = "SimpleNetwork-v0"

# cast() only narrows the static type seen by type checkers; at runtime `env`
# is exactly the object returned by gym.make().
env = cast(CyberBattleEnv, gym.make(ENV_ID))


In [41]:
def greedy_search(env, depth: int = 1, verbose: bool = True):
    """Pick the action with the highest cumulative reward over a bounded lookahead.

    Every valid action is tried on its own deep copy of ``env``, so the
    environment passed in is never stepped. When the simulated step ends the
    episode (``done`` or ``truncated``) the lookahead stops there; otherwise
    the best reward reachable in the remaining ``depth - 1`` steps is added.

    Args:
        env: Environment exposing ``get_all_valid_actions()`` and a
            ``step(action)`` that returns a 5-tuple
            ``(obs, reward, done, truncated, info)``. Must be deep-copyable.
        depth: Number of steps to look ahead. Values ``<= 0`` perform no search.
        verbose: When true (the default), print the valid actions considered at
            each level of the search.

    Returns:
        A ``(best_action, best_reward)`` tuple. ``(None, 0)`` is returned when
        ``depth <= 0`` or when no action could be selected, so a dead-end state
        contributes no future reward instead of ``-inf``.
    """
    # `<=` rather than `==` so a negative depth cannot recurse without a depth bound.
    if depth <= 0:
        return None, 0

    # Query the environment once and reuse the result for logging and iteration.
    valid_actions = env.get_all_valid_actions()
    if verbose:
        print(f"All valid actions in depth {depth} {valid_actions}")

    best_action = None
    best_reward = float("-inf")

    for action in valid_actions:
        # Simulate on a private copy so sibling actions start from the same state.
        env_copy = copy.deepcopy(env)
        _, reward, done, truncated, _ = env_copy.step(action)

        if not done and not truncated:
            _, future_reward = greedy_search(env_copy, depth - 1, verbose)
            reward += future_reward

        # Strict comparison: on ties the first action encountered wins.
        if reward > best_reward:
            best_reward = reward
            best_action = action

    if best_action is None:
        # Nothing selectable (e.g. no valid actions). Returning -inf here would
        # poison the cumulative reward of the parent action that led to this state.
        return None, 0

    return best_action, best_reward

In [42]:
# Play one episode, choosing every action with a one-step lookahead.
#
# Safety cap on the number of environment steps: the loop otherwise only stops
# when the environment reports `done` or `truncated`, which a one-step greedy
# policy is not guaranteed to reach. Raise it to allow longer episodes.
MAX_STEPS = 200

obs, _ = env.reset()
total_reward = 0

done = False
truncated = False
steps_taken = 0

try:
    while not done and not truncated and steps_taken < MAX_STEPS:
        best_action, expected_reward = greedy_search(env, depth=1)
        if best_action is None:
            # The search found nothing to play; stepping with None would fail.
            print("No valid action available; stopping the episode.")
            break
        print("Chosen action:", best_action, "| expected cumulative reward:", expected_reward)

        obs, reward, done, truncated, info = env.step(best_action)
        print(reward)
        total_reward += reward
        steps_taken += 1

    if not done and not truncated and steps_taken >= MAX_STEPS:
        print("Stopped after reaching the step limit:", MAX_STEPS)

    print("Total reward collected:", total_reward)
finally:
    # Release the environment even when the cell is interrupted or a step raises.
    env.close()

All valid actions in depth 1 [{'local_vulnerability': array([0, 0], dtype=int32)}]
Chosen action: {'local_vulnerability': array([0, 0], dtype=int32)} | expected cumulative reward: 11.0
11.0
All valid actions in depth 1 [{'local_vulnerability': array([0, 0], dtype=int32)}, {'remote_vulnerability': array([0, 1, 0], dtype=int32)}, {'remote_vulnerability': array([0, 1, 1], dtype=int32)}]
Chosen action: {'remote_vulnerability': array([0, 1, 1], dtype=int32)} | expected cumulative reward: 11.0
11.0
All valid actions in depth 1 [{'local_vulnerability': array([0, 0], dtype=int32)}, {'remote_vulnerability': array([0, 1, 0], dtype=int32)}, {'remote_vulnerability': array([0, 1, 1], dtype=int32)}, {'remote_vulnerability': array([0, 2, 0], dtype=int32)}, {'remote_vulnerability': array([0, 2, 1], dtype=int32)}]
Chosen action: {'remote_vulnerability': array([0, 2, 0], dtype=int32)} | expected cumulative reward: 9.0
9.0
All valid actions in depth 1 [{'local_vulnerability': array([0, 0], dtype=int32)},

KeyboardInterrupt: 