In [None]:
import os
import torch as T
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Echo the selected device as the cell output.
device

# Инициализация нейронной сети

In [None]:
class DeepQNetwork(nn.Module):
    """Fully connected Q-network bundled with its optimizer, loss and checkpoint I/O.

    Args:
        lr: Learning rate passed to the RMSprop optimizer.
        n_actions: Width of the output layer (one Q-value per action).
        name: Checkpoint file name, joined onto ``chkpt_dir``.
        input_dims: Sequence whose first element is the number of input features.
        chkpt_dir: Directory in which the checkpoint file lives.
    """

    def __init__(self, lr, n_actions, name, input_dims, chkpt_dir):
        super(DeepQNetwork, self).__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

        # you may want to play around with this and forward()
        self.fc1 = nn.Linear(input_dims[0], 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, n_actions)

        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)

        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    # you may want to play around with this
    def forward(self, state):
        """Return the Q-values produced by the three linear layers for ``state``."""
        flat1 = F.relu(self.fc1(state))
        flat2 = F.relu(self.fc2(flat1))
        actions = self.fc3(flat2)
        return actions

    def save_checkpoint(self):
        """Write the network weights to ``self.checkpoint_file``."""
        print('... saving checkpoint ...')
        # torch.save does not create missing parent directories, so make sure
        # the target folder exists before writing.
        target_dir = os.path.dirname(self.checkpoint_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Restore the network weights from ``self.checkpoint_file``."""
        print('... loading checkpoint ...')
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        self.load_state_dict(T.load(self.checkpoint_file,
                                    map_location=self.device))

# Буфер воспроизведения

In [None]:
import numpy as np

class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, next state, done) records."""

    def __init__(self, max_size, input_shape, n_actions):
        # n_actions is part of the signature but is not used by the buffer.
        self.mem_size = max_size
        self.mem_cntr = 0

        state_shape = (self.mem_size, *input_shape)
        self.state_memory = np.zeros(state_shape, dtype=np.float32)
        self.new_state_memory = np.zeros(state_shape, dtype=np.float32)

        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition, overwriting the oldest slot once the buffer is full."""
        slot = self.mem_cntr % self.mem_size
        columns = (
            (self.state_memory, state),
            (self.new_state_memory, state_),
            (self.action_memory, action),
            (self.reward_memory, reward),
            (self.terminal_memory, done),
        )
        for storage, value in columns:
            storage[slot] = value
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random."""
        # Only the slots that have been written so far are eligible.
        filled = min(self.mem_cntr, self.mem_size)
        picks = np.random.choice(filled, batch_size, replace=False)
        return (
            self.state_memory[picks],
            self.action_memory[picks],
            self.reward_memory[picks],
            self.new_state_memory[picks],
            self.terminal_memory[picks],
        )

# DDQN агент

In [None]:
from CybORG.Agents.SimpleAgents.BaseAgent import BaseAgent

class DQNAgent(BaseAgent):
    """Double-DQN agent: an online network, a target network and a replay buffer.

    Args:
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability of picking a uniformly random action.
        lr: Learning rate handed to both networks.
        n_actions: Number of discrete actions.
        input_dims: Observation shape; its first entry sizes the network input.
        mem_size: Replay buffer capacity in transitions.
        batch_size: Number of transitions sampled per training step.
        eps_min: Floor that epsilon is clamped to while training.
        eps_dec: Amount subtracted from epsilon after every training step.
        replace: Copy the online weights into the target network every
            ``replace`` training steps; ``None`` disables the copy.
        algo: Algorithm label used in the checkpoint file names.
        env_name: Environment label used in the checkpoint file names.
        chkpt_dir: Directory holding the checkpoint files.
        load: Accepted but not used by this class; call ``load_models()``
            explicitly to restore checkpoints.
    """

    def __init__(self, gamma=0.9, epsilon=0, lr=0.1, n_actions=41, input_dims=(52,),
                 mem_size=1000, batch_size=32, eps_min=0.01, eps_dec=5e-7,
                 replace=1000, algo='DDQN', env_name='Scenario1b', chkpt_dir='chkpt', load=False):
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.replace_target_cnt = replace
        self.algo = algo
        self.env_name = env_name
        self.chkpt_dir = chkpt_dir
        self.action_space = list(range(n_actions))
        self.learn_step_counter = 0

        self.memory = ReplayBuffer(mem_size, input_dims, n_actions)

        # q_eval is the online network that gets trained; q_next is the target
        # network that is periodically overwritten with q_eval's weights.
        self.q_eval = DeepQNetwork(self.lr, self.n_actions,
                                   input_dims=self.input_dims,
                                   name=self.env_name+'_'+self.algo+'_q_eval',
                                   chkpt_dir=self.chkpt_dir)
        self.q_next = DeepQNetwork(self.lr, self.n_actions,
                                   input_dims=self.input_dims,
                                   name=self.env_name+'_'+self.algo+'_q_next',
                                   chkpt_dir=self.chkpt_dir)

    # if epsilon=0 it will just use the model
    def get_action(self, observation, action_space=None):
        """Choose an action epsilon-greedily; ``action_space`` is ignored."""
        if np.random.random() > self.epsilon:
            # Build a single ndarray first: creating a tensor from a list of
            # ndarrays is a slow path that PyTorch warns about.
            batch = np.array([observation], dtype=np.float32)
            state = T.tensor(batch, dtype=T.float).to(self.q_eval.device)
            # Inference only, so no autograd graph is needed.
            with T.no_grad():
                actions = self.q_eval.forward(state)
            action = T.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)

        return action

    def store_transition(self, state, action, reward, state_, done):
        """Append one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def sample_memory(self):
        """Sample a batch from the replay buffer as tensors on the network's device."""
        state, action, reward, new_state, done = \
                                self.memory.sample_buffer(self.batch_size)

        states = T.tensor(state).to(self.q_eval.device)
        rewards = T.tensor(reward).to(self.q_eval.device)
        dones = T.tensor(done).to(self.q_eval.device)
        actions = T.tensor(action).to(self.q_eval.device)
        states_ = T.tensor(new_state).to(self.q_eval.device)

        return states, actions, rewards, states_, dones

    def replace_target_network(self):
        """Copy the online weights into the target network on the configured schedule."""
        if self.replace_target_cnt is not None and \
           self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def decrement_epsilon(self):
        """Lower epsilon by ``eps_dec`` without ever dropping below ``eps_min``."""
        if self.epsilon > self.eps_min:
            # Clamp so the final decrement cannot undershoot the floor.
            self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)
        else:
            self.epsilon = self.eps_min

    def train(self):
        """Run one Double-DQN update; a no-op until a full batch has been stored."""
        if self.memory.mem_cntr < self.batch_size:
            return
        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()
        states, actions, rewards, states_, dones = self.sample_memory()
        indices = np.arange(self.batch_size)
        q_pred = self.q_eval.forward(states)[indices, actions]

        # The target is a constant for this update. Computing it without
        # autograd keeps gradients from piling up in the target network's
        # parameters, which no optimizer ever steps or clears.
        with T.no_grad():
            q_next = self.q_next.forward(states_)
            # Double DQN: the online network selects the next action, the
            # target network evaluates it.
            max_actions = T.argmax(self.q_eval.forward(states_), dim=1)
            q_next[dones] = 0.0
            q_target = rewards + self.gamma*q_next[indices, max_actions]

        # MSE is symmetric, so passing the prediction first gives the same loss
        # value and the same gradient with respect to q_pred.
        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrement_epsilon()

    def end_episode(self):
        """No per-episode bookkeeping is performed."""
        pass

    def set_initial_values(self, action_space, observation):
        """No initial values are used by this agent."""
        pass

    def save_models(self):
        """Checkpoint both the online and the target network."""
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        """Restore both the online and the target network from their checkpoints."""
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

In [None]:
from CybORG import CybORG
from CybORG.Agents import RedMeanderAgent, B_lineAgent
from CybORG.Agents.Wrappers import ChallengeWrapper
from CybORG.Agents.Wrappers.TrueTableWrapper import true_obs_to_table
import inspect
import os

In [None]:
# Locate the Scenario1b definition shipped inside the installed CybORG package.
# The package directory is derived with os.path.dirname rather than by slicing
# a fixed number of characters off the module path, so the result does not
# depend on the length of the module's file name.
PATH = os.path.join(os.path.dirname(str(inspect.getfile(CybORG))),
                    'Shared', 'Scenarios', 'Scenario1b.yaml')

In [None]:
def cuda():
    """Print whether PyTorch reports a usable CUDA device."""
    available = torch.cuda.is_available()
    print(f"CUDA: {available}")

# Получение таблиц состояний среды

In [None]:
def get_tables(eps_len=100, chkpt_dir="model_meander", red_agent=RedMeanderAgent):
    """Roll out one episode with a checkpointed Blue DQN agent and log every step.

    A CybORG simulation is built from ``PATH`` with ``red_agent`` playing Red,
    wrapped in ``ChallengeWrapper`` for Blue, and driven by a ``DQNAgent``
    whose weights are loaded from ``<cwd>/Models/<chkpt_dir>``. After each
    step the Blue action, the reward and the true-state table are recorded,
    and the whole episode is then written to
    ``visualisation/logs_to_vis/results.txt`` (relative to the current working
    directory, overwriting any previous file).

    Args:
        eps_len: Number of environment steps to run. This value now controls
            the rollout length; the loop was previously fixed at 100 steps.
        chkpt_dir: Sub-directory of ``Models`` that holds the checkpoints.
        red_agent: Agent class passed to CybORG for the Red side.

    Returns:
        The sum of the per-step rewards.
    """
    # The hyperparameters are only needed to construct the agent: nothing is
    # trained here, the function only calls agent.get_action().
    lr = 0.0001
    eps_dec = 0.000005
    eps_min = 0.05
    gamma = 0.99
    batch_size = 32
    mem_size = 5000
    replace = 1000

    cyborg = CybORG(PATH, 'sim', agents={
        'Red': red_agent
    })
    wrapped_cyborg = ChallengeWrapper(env=cyborg, agent_name="Blue")

    model_dir = os.path.join(os.getcwd(), "Models", chkpt_dir)
    print(model_dir)
    # epsilon=0 makes the agent act greedily on the loaded model.
    agent = DQNAgent(gamma=gamma, epsilon=0, lr=lr,
                     input_dims=(wrapped_cyborg.observation_space.shape),
                     n_actions=wrapped_cyborg.action_space.n, mem_size=mem_size, eps_min=eps_min,
                     batch_size=batch_size, replace=replace, eps_dec=eps_dec,
                     chkpt_dir=model_dir, algo='DDQNAgent',
                     env_name='Scenario1b')
    # gets the checkpoint from model_dir
    agent.load_models()

    table_file = 'visualisation/logs_to_vis/results.txt'
    observation = wrapped_cyborg.reset()
    agent_name = 'Blue'
    action_space = wrapped_cyborg.get_action_space(agent_name)
    # Appears to be left over from a hierarchical controller: nothing in this
    # function updates it, so the summary printed at the end is always 0 / 0.
    count_agent_dist = [0, 0]

    blue_actions = []
    tables = []
    rewards = []
    total_reward = 0
    for _ in range(eps_len):
        action = agent.get_action(observation, action_space)
        observation, rew, _, _ = wrapped_cyborg.step(action)

        blue_actions.append(str(wrapped_cyborg.get_last_action('Blue')))
        true_state = cyborg.get_agent_state('True')
        tables.append(true_obs_to_table(true_state, cyborg))
        total_reward += rew
        rewards.append(rew)

    # open() does not create missing folders, so make sure the log directory exists.
    os.makedirs(os.path.dirname(table_file), exist_ok=True)
    with open(table_file, 'w') as table_out:
        table_out.write('\n')
        for blue_action, reward, table in zip(blue_actions, rewards, tables):
            table_out.write('\n----------------------------------------------------------------------------\n')
            table_out.write('Blue Action: {}\n'.format(blue_action))
            # "Episode reward" is the total for the finished episode, repeated on every step.
            table_out.write('Reward: {}, Episode reward: {}\n'.format(reward, total_reward))
            table_out.write('Network state:\n')
            table_out.write(str(table))
            table_out.write('\n.\n\n')

    print('Controller distribution: {} b_lineAgent, {}, RedMeanderAgent'.format(count_agent_dist[0], count_agent_dist[1]))
    return total_reward

# Пример получения таблицы стратегии с желаемой наградой

In [None]:
# Keep rolling out episodes until one no longer scores below -50; its log is
# the one left in the results file.
while True:
    a = get_tables()
    if not a < -50:
        break

In [None]:
# Show the total reward of the accepted episode.
a

# Визуализация таблиц состояний

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib inline
import numpy as np
from PIL import Image

In [None]:
# Work from inside the visualisation folder so the relative './img' and
# './logs_to_vis' paths used by the following cells resolve. The guard keeps
# the cell idempotent: re-running it does not try to descend a second time.
if os.path.basename(os.getcwd()) != 'visualisation':
    os.chdir('./visualisation/')
print(os.getcwd())

In [None]:
# Preview the base diagram image on a wide figure.
fig, ax = plt.subplots(figsize=(20, 8))
# Read the PNG from the current working directory's img folder.
img = mpl.image.imread("./img/figure1.png")
# Draws on the current axes, i.e. the ones created just above.
plt.imshow(img)

In [None]:
im = Image.open('./img/figure1.png')

# Figure and axes that host the diagram
fig, ax = plt.subplots(figsize=(20, 8))

# Draw the base image
ax.imshow(im)

# Face colours (RGBA) and edge colours used for the overlays
known_shade, unknown_shade = (1, 0, 0, 0.5), (0, 0, 1, 0.5)
access_none, access_user, access_priv = 'b', 'y', 'r'

# (x, y) anchor points of the individual hosts
e0_loc, e1_loc, e2_loc = (460, 55), (532, 55), (603, 55)
ops_loc = (968, 55)
op_host0_loc, op_host1_loc, op_host2_loc = (908, 315), (973, 315), (1035, 315)
user0_loc, user1_loc, user2_loc = (0, 317), (67, 317), (135, 317)
user3_loc, user4_loc = (204, 317), (273, 317)
defender_loc = (535, 315)

# Rectangle width/height for servers and for ordinary hosts
server_shape_w, server_shape_h = 65, 85
host_shape_w, host_shape_h = 62, 62

# Sample overlay: mark the defender host and attach the patch to the axes
rect = patches.Rectangle(defender_loc, host_shape_w, host_shape_h,
                         facecolor=unknown_shade, edgecolor=access_none,
                         linewidth=3)
ax.add_patch(rect)

plt.show()

In [None]:
# Overlay palette: face colours are RGBA tuples, edge colours are matplotlib
# single-letter colour codes.
known_shade, unknown_shade = (1, 0, 0, 0.5), (0, 0, 1, 0.5)  # facecolor
access_none, access_user, access_priv = 'b', 'y', 'r'  # edgecolor

# Rectangle width/height for servers and for ordinary hosts.
server_shape_w, server_shape_h = 65, 85
host_shape_w, host_shape_h = 62, 62

# (x, y) anchor point used for each host's overlay rectangle.
_HOST_LOCATIONS = {
    "Enterprise0": (460, 55),
    "Enterprise1": (532, 55),
    "Enterprise2": (603, 55),
    "Op_Server0": (968, 55),
    "Op_Host0": (908, 315),
    "Op_Host1": (973, 315),
    "Op_Host2": (1035, 315),
    "User0": (0, 317),
    "User1": (67, 317),
    "User2": (135, 317),
    "User3": (204, 317),
    "User4": (273, 317),
    "Defender": (535, 315),
}


def get_loc(hostname):
    """Return the (x, y) anchor point for ``hostname``.

    Args:
        hostname: One of the host names listed in ``_HOST_LOCATIONS``.

    Returns:
        A 2-tuple of integer coordinates.

    Raises:
        ValueError: If ``hostname`` is not a known host. An if/elif chain
            without a final ``else`` would instead fail with an
            ``UnboundLocalError`` that hides the offending name.
    """
    try:
        return _HOST_LOCATIONS[hostname]
    except (KeyError, TypeError):
        known = ", ".join(_HOST_LOCATIONS)
        raise ValueError(
            f"Unknown hostname {hostname!r}; expected one of: {known}"
        ) from None

def add_patch(hostname, known_state, access_state, scanned=False):
    """Overlay a coloured rectangle for one host on the module-level ``ax``.

    The fill colour follows ``known_state``. The outline colour follows
    ``access_state``: "None" and "User" have their own colours and every other
    value is drawn with the privileged colour. ``scanned`` is accepted but
    not used.
    """
    # Translate the state flags into colours.
    fill = known_shade if known_state else unknown_shade
    outline = (
        access_none if access_state == "None"
        else access_user if access_state == "User"
        else access_priv  # access_state == "Privileged"
    )

    # Enterprise machines and servers use the larger rectangle.
    is_server = "Enterprise" in hostname or "Server" in hostname
    if is_server:
        width, height = server_shape_w, server_shape_h
    else:
        width, height = host_shape_w, host_shape_h
    anchor = get_loc(hostname)

    # Build the rectangle and attach it to the axes in one step.
    ax.add_patch(
        patches.Rectangle(anchor, width, height, linewidth=3,
                          edgecolor=outline, facecolor=fill)
    )

In [None]:
def parse_results(file_name, dest_file_name=""):
    """Parse a per-step episode log into a list of step dictionaries.

    The log is expected to contain, for every step, a dashed separator line,
    a ``Blue Action: ...`` line, a ``Reward: ..., Episode reward: ...`` line
    and a ``+``/``|`` bordered table whose header row contains ``Subnet``.

    Args:
        file_name: Path of the text log to read.
        dest_file_name: Path of the JSON file to write. When empty, the name
            is derived from ``file_name`` by swapping a trailing ``.txt`` for
            ``.json``.

    Returns:
        A list with one dict per step. Each dict has the keys ``hosts`` (a
        list of per-host dicts with ``subnet``, ``ip_addr``, ``known``,
        ``scanned``, ``access`` and ``hostname``) plus ``blue_action``,
        ``reward`` and ``ep_reward`` (kept as strings) when those lines are
        present.
    """
    import json  # local import: only this function needs it

    separator = '----------------------------------------------------------------------------'

    with open(file_name, 'r') as fp:
        lines = fp.readlines()

    results_json = []
    step = {"hosts": []}
    header = False
    for i, line in enumerate(lines):
        if "Blue Action: " in line:
            # Split on the label itself. str.lstrip takes a *set of characters*,
            # so stripping with the label would also eat leading letters of the
            # action name (e.g. "Analyse" -> "alyse").
            step["blue_action"] = line.split("Blue Action: ", 1)[1].strip()
        elif "Reward: " in line:
            reward_part, _, ep_part = line.partition(",")
            step["reward"] = reward_part.split("Reward: ", 1)[-1].strip()
            step["ep_reward"] = ep_part.split("Episode reward: ", 1)[-1].strip()
        elif "+" in line:
            # A border line directly above the "Subnet" row opens the table
            # header; any other border line closes it. The bounds check covers
            # a border on the very last line of the file.
            next_line = lines[i + 1] if i + 1 < len(lines) else ""
            header = "Subnet" in next_line
        elif not header and "|" in line:
            row = line.split("|")
            if len(row) < 7:
                # Not a complete six-column table row; ignore it.
                continue
            attr = {}  # attributes of one host
            attr["subnet"] = row[1].strip()
            attr["ip_addr"] = row[2].strip()
            attr["known"] = "True" in row[4].strip()
            attr["scanned"] = "True" in row[5].strip()
            attr["access"] = row[6].strip()
            attr["hostname"] = row[3].strip()
            step["hosts"].append(attr)
        elif separator in line and step != {"hosts": []}:
            # A separator starts the next step, so the current one is complete.
            results_json.append(step)
            step = {"hosts": []}

    # The last step has no separator after it; flush it explicitly so it is
    # not silently dropped.
    if step != {"hosts": []}:
        results_json.append(step)

    if dest_file_name == "":
        # Remove only a real ".txt" suffix. str.rstrip(".txt") strips any
        # trailing '.', 't' or 'x' characters ("output.txt" -> "outpu").
        base = file_name[:-len(".txt")] if file_name.endswith(".txt") else file_name
        dest_file_name = base + ".json"
    with open(dest_file_name, 'w') as fp:
        # Write real JSON; str(list) would produce a Python repr instead.
        json.dump(results_json, fp, indent=2)
    return results_json

In [None]:
# Parse the saved episode log; with no dest_file_name given, parse_results
# also writes a .json file derived from this path.
results_json = parse_results(file_name='./logs_to_vis/-30.1.txt')

In [None]:
# Display the parsed steps as the cell output.
results_json

# Получение картинок из таблиц

In [None]:
# Render one annotated frame per parsed step and save it as a PNG.
frames_dir = "./img/strategy"
# savefig does not create missing folders, so make sure the target exists.
os.makedirs(frames_dir, exist_ok=True)

r = 0.0  # running episode reward shown in the frame title
for i in range(len(results_json)):
    img = results_json[i]
    base = Image.open('./img/figure1.png')
    # add_patch draws on the module-level `ax`, so it must be rebound per frame.
    fig, ax = plt.subplots(figsize=(20, 8))
    # Display the base diagram
    ax.imshow(base)

    for host in img["hosts"]:
        add_patch(host["hostname"], host["known"], host["access"], host["scanned"])
    r += float(img['reward'])
    if i == 0:
        plt.title("Starting...")
    else:
        plt.title(f"Blue Action: {img['blue_action']} \nReward: {img['reward']} \nEp Reward: {r} \nStep:{i}")
    plt.axis('off')
    plt.savefig(f"{frames_dir}/img{i}.png")
    plt.show()
    # Release the figure explicitly so a long episode does not keep every
    # frame alive in memory.
    plt.close(fig)

# Сохранение картинок в GIF

In [None]:
import glob
import re
# Glob pattern of the rendered frames, and the animation file to write.
fp_in = 'img/strategy/img*.png'
fp_out = 'results.gif'

def atoi(text):
    """Convert an all-digit string to ``int``; return any other string unchanged."""
    if text.isdigit():
        return int(text)
    return text

# Sort key that takes the numbers embedded in file names into account
def natural_keys(text):
    """Split ``text`` into digit and non-digit chunks, converting the digit chunks with ``atoi``."""
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

# Sort the frame files in numeric order
file_list = sorted(glob.glob(fp_in), key=natural_keys)
if not file_list:
    # Without this guard the unpacking below fails with an opaque
    # "not enough values to unpack" ValueError.
    raise FileNotFoundError(f"No frame images match {fp_in!r}; render the frames first.")

# The first frame drives the save call; the remaining frames are appended to it.
img, *imgs = [Image.open(f) for f in file_list]
img.save(fp=fp_out, format='GIF', append_images=imgs,
         save_all=True, duration=400, loop=0)
