In [1]:
from wrapper.turbo_wrapper_greedy import TGraphWrapper

from CybORG import CybORG
from CybORG.Simulator.Scenarios.EnterpriseScenarioGenerator import EnterpriseScenarioGenerator
from CybORG.Agents import SleepAgent, EnterpriseGreenAgent, FiniteStateRedAgent
from collections import deque
import random
import numpy as np
from submission import Submission
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import traceback
from joblib import Parallel, delayed
import os
import torch.multiprocessing as mp
from types import SimpleNamespace

In [2]:
# class DQN(nn.Module):
#     def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, output_dim):
#         super().__init__()
#         self.fc = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim_1),
#             nn.ReLU(),
#             nn.Linear(hidden_dim_1, hidden_dim_2),
#             nn.ReLU(),
#             nn.Linear(hidden_dim_2, output_dim)
#         )

#     def forward(self, x):
#         return self.fc(x)

class DQN(nn.Module):
    """Single-hidden-layer MLP producing one output per action.

    The layers are kept in ``self.fc`` (an ``nn.Sequential`` of
    Linear -> ReLU -> Linear), so state-dict keys are ``fc.0.*`` for the
    hidden layer and ``fc.2.*`` for the output layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in forward order (hidden first, then output).
        hidden_layer = nn.Linear(input_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, output_dim)
        self.fc = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Run ``x`` through the MLP and return the resulting tensor."""
        q_values = self.fc(x)
        return q_values

In [3]:
class DQNAgent:
    """Independent DQN learner for a single blue agent.

    Bundles a policy network, a target network, an Adam optimizer, a
    Smooth-L1 loss and a bounded replay buffer. ``agent`` is the integer index
    of the blue agent: it indexes the per-agent lists of a joint transition,
    builds the reward key ``blue_agent_<agent>`` and appears in the checkpoint
    file names.
    """

    def __init__(self, agent, input_dim, hidden_dim_1, output_dim, lr, memory_capacity):
        """Build both networks and try to resume from the on-disk checkpoints.

        Args:
            agent: integer index of the blue agent this learner belongs to.
            input_dim: length of the encoded state vector fed to the networks.
            hidden_dim_1: width of the single hidden layer.
            output_dim: number of discrete actions (network outputs).
            lr: Adam learning rate.
            memory_capacity: maximum number of transitions kept in the replay buffer.
        """
        # `device` is the notebook-level torch.device; it has to be defined
        # before the first DQNAgent is instantiated.
        self.device = device
        self.output_dim = output_dim
        self.policy_net = DQN(input_dim, hidden_dim_1, output_dim).to(self.device)
        self.target_net = DQN(input_dim, hidden_dim_1, output_dim).to(self.device)
        self.name = agent
        print(f"agent {self.name} initiation")
        try:
            # map_location remaps the stored tensors onto the current device, so
            # a checkpoint written on a GPU host can be resumed on a CPU-only one
            # instead of failing and being reported as "model not loaded".
            self.policy_net.load_state_dict(
                torch.load(f"dqn_Turbo_Greedy_policy_{agent}_net_Jul25.pth", map_location=self.device)
            )
            self.target_net.load_state_dict(
                torch.load(f"dqn_Turbo_Greedy_target_{agent}_net_Jul25.pth", map_location=self.device)
            )
            self.policy_net.train()
            self.target_net.eval()
            print("model loaded successfully")
        except Exception as err:
            # Best effort: start from the current policy weights when no usable checkpoint exists.
            print("model not loaded")
            if not isinstance(err, FileNotFoundError):
                # A checkpoint exists but is unusable (shape mismatch, corrupt
                # file, ...): say why rather than hiding it.
                print(f"checkpoint for agent {self.name} could not be used: {err!r}")
            self.target_net.load_state_dict(self.policy_net.state_dict())
            self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        # self.criterion = nn.MSELoss()
        self.criterion = nn.SmoothL1Loss()
        self.replay_buffer = deque(maxlen=memory_capacity)

    def select_action(self, state, epsilon):
        """Pick an action index epsilon-greedily.

        Args:
            state: encoded state tensor, either 1-D or with a leading batch
                dimension of 1.
            epsilon: probability of drawing a uniformly random action.

        Returns:
            An int action index. Always 0 when the last feature of ``state`` is
            0 (the original note reads "since this action is None" -- presumably
            that feature flags whether any action is available; TODO confirm).
        """
        # Flatten first so the check yields a single scalar for (1, D) states
        # too; `state[-1] == 0` on a 2-D tensor cannot be used in an `if`.
        if state.reshape(-1)[-1].item() == 0:
            return 0
        if random.random() < epsilon:
            return random.randrange(self.output_dim)
        with torch.no_grad():
            q = self.policy_net(state.to(self.device))
            return q.argmax().item()

    def store_transition(self, *transition):
        """Append this agent's slice of a joint transition to the replay buffer.

        ``transition`` is ``(state, deception, reward, next_state, done)`` where
        ``state``, ``deception``, ``next_state`` and ``done`` are indexed by the
        agent index and ``reward`` is keyed by ``"blue_agent_<index>"``.
        """
        (state, deception, reward, next_state, done) = transition
        i = f"blue_agent_{self.name}"
        agent_transition = (state[self.name], deception[self.name], reward[i], next_state[self.name], done[self.name])
        self.replay_buffer.append(agent_transition)

    def sample_batch(self, batch_size):
        """Return ``batch_size`` transitions drawn without replacement.

        Raises:
            ValueError: if the buffer holds fewer than ``batch_size`` transitions.
        """
        return random.sample(self.replay_buffer, batch_size)

    def update(self, batch, gamma):
        """Run one optimisation step on ``batch`` and return the loss as a float.

        Args:
            batch: iterable of ``(state, action, reward, next_state, done)``
                tuples as stored by ``store_transition``. States are joined with
                ``torch.cat`` and then indexed along dim 1, so each state is
                assumed to be shaped ``(1, input_dim)`` -- TODO confirm against
                the wrapper's encoder.
            gamma: discount factor applied to the bootstrapped next-state value.
        """
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.cat(states).to(self.device)
        next_states = torch.cat(next_states).to(self.device)
        actions = torch.tensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(self.device)

        # Q-value of the action that was actually taken in each transition.
        q_values = self.policy_net(states).gather(1, actions)
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0].unsqueeze(1)
            # NOTE(review): the stored reward is negated, i.e. the learner
            # maximises the negative of the environment reward -- presumably
            # intentional for this setup; confirm.
            target = -rewards + gamma * next_q * (1-dones)

        loss = self.criterion(q_values, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [4]:
# 0. Training configuration: hyper-parameters, thread budget, device and RNG seed
HYPER_PARAMS = SimpleNamespace(
    agent_num = 5,            # number of blue agents; one DQNAgent is built per agent
    episode_length = 500,     # step cap per episode (also passed to the scenario generator)
    episode_num = 2000,       # total number of training episodes
    workers = 1,  # Number of parallel environments
    batch_size = 512,         # transitions per gradient step; also the minimum buffer fill before learning
    memory_capacity = 30000,  # replay buffer size per agent
    lr = 0.0008,              # Adam learning rate
    tau = 0.0001,             # soft target-update coefficient
    target_update_freq = 5,   # NOTE(review): not read anywhere in the visible notebook
    epsilon = 0.1,            # floor of the linearly decaying exploration rate
    gamma = 0                 # discount factor; with 0 the TD target is the immediate term only
)

# os.cpu_count() may return None when the count cannot be determined, so fall
# back to a single thread; the budget is capped at 30 threads.
MAX_THREADS = min(os.cpu_count() or 1, 30)
torch.set_num_threads(MAX_THREADS)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 1234

# Seed every RNG the training code draws from so exploration, replay sampling
# and weight initialisation are repeatable, not just the environments.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [5]:

# Create environments in parallel
def create_env(seed):
    """Build one CybORG simulation seeded with ``seed`` and wrap it in ``TGraphWrapper``.

    The scenario uses sleeping blue agents, ``EnterpriseGreenAgent`` green
    agents and ``FiniteStateRedAgent`` red agents, with the step count taken
    from ``HYPER_PARAMS.episode_length``.
    """
    scenario = EnterpriseScenarioGenerator(
        blue_agent_class=SleepAgent,
        green_agent_class=EnterpriseGreenAgent,
        red_agent_class=FiniteStateRedAgent,
        steps=HYPER_PARAMS.episode_length,
    )
    simulation = CybORG(scenario, "sim", seed=seed)
    wrapped = TGraphWrapper(simulation)
    return wrapped

# One wrapped environment per worker, each seeded with its own offset from `seed`
envs = Parallel(n_jobs=HYPER_PARAMS.workers)(
    delayed(create_env)(seed + worker_idx) for worker_idx in range(HYPER_PARAMS.workers)
)
print("multiple environments created.")

# One independent DQN learner per blue agent (nothing is shared between them).
# The last agent gets a larger input vector and a wider hidden layer than the
# others; all of them choose among 7 actions.
agents = []
for ag_num in range(HYPER_PARAMS.agent_num):
    is_last_agent = not (ag_num < HYPER_PARAMS.agent_num - 1)
    input_dim, hidden_dim = (494, 1024) if is_last_agent else (258, 512)
    agent = DQNAgent(
        ag_num,
        input_dim=input_dim,
        hidden_dim_1=hidden_dim,
        output_dim=7,
        lr=HYPER_PARAMS.lr,
        memory_capacity=HYPER_PARAMS.memory_capacity,
    )
    agents.append(agent)
    print(f"Agent {ag_num} Created!")
print(agents)
# Function to collect experiences from one environment
@torch.no_grad()
def collect_experiences(env_idx, episode, hp):
    """Roll out one episode in ``envs[env_idx]`` and return its joint transitions.

    Args:
        env_idx: index into the notebook-level ``envs`` list.
        episode: global episode number; drives the exploration schedule.
        hp: hyper-parameter namespace; ``workers``, ``epsilon``, ``episode_num``
            and ``episode_length`` are read from it.

    Returns:
        ``(transitions, total_reward)``. Each transition is the joint tuple
        ``(states, deceptions, reward, next_state, done)`` with one entry per
        agent (``reward`` is the dict returned by the environment step).
        ``total_reward`` is the per-step reward averaged over the agents and
        summed over the episode.
    """
    # Split the thread budget between workers, never asking for fewer than one
    # thread (MAX_THREADS // workers is 0 when workers > MAX_THREADS).
    torch.set_num_threads(max(1, MAX_THREADS // hp.workers))

    env = envs[env_idx]
    transitions = []

    obs_all, _ = env.reset()
    states = []
    for agent in agents:
        agent_name = f"blue_agent_{agent.name}"
        dict_obs = env.env.environment_controller.get_last_observation(agent_name).data
        # At reset the same observation is passed in both observation slots of the encoder.
        state = env.encode_dict_obs_expnd_inpt(dict_obs, dict_obs, env.gen_obs[agent_name], env.manp_obs[agent_name]).to('cpu')  # Keep on CPU when collecting
        states.append(state)

    done = np.zeros(len(agents), dtype=bool)
    total_reward = 0
    step_count = 0

    # Exploration rate decays linearly from 1 to hp.epsilon over the first half
    # of training and stays at hp.epsilon afterwards.
    epsilon_dyn = max((1 - ((1-hp.epsilon)*episode / (hp.episode_num/2))), hp.epsilon)

    while not done.any() and step_count < hp.episode_length:
        deceptions = []
        for agent in agents:
            if env.M > 0:
                if random.random() < epsilon_dyn:
                    deception = random.randint(0, 6) #4 + 2 + 1 actions
                else:
                    # Forward pass only when the greedy action is actually needed.
                    q_vals = agent.policy_net(states[agent.name].to(device))
                    deception = q_vals.argmax().item()
            else:
                deception = 0 #no action
            deceptions.append(deception)

        # Execute in environment
        obs_all, manp_dict_obs, reward, term, _, _, _, obs, manp_obs = env.modified_step(deceptions, obs_all)
        next_state = []
        done = []
        step_reward = 0
        for agent in agents:
            agent_name = f"blue_agent_{agent.name}"
            next_dict_obs = env.env.environment_controller.get_last_observation(agent_name).data
            next_state.append(env.encode_dict_obs_expnd_inpt(next_dict_obs, manp_dict_obs[agent_name], obs[agent_name], manp_obs[agent_name]).to('cpu'))
            done.append(term[agent_name])
            step_reward += reward[agent_name]
        done = np.array(done)

        # Store the joint transition; each DQNAgent later extracts its own slice.
        # TODO(author): original note here asked "whose reward?"
        transitions.append((states, deceptions, reward, next_state, done))

        states = next_state
        # Mean reward over all agents for this step.
        total_reward += step_reward / len(agents)
        step_count += 1

    # Original note: "ADD THIS AS AN EDIT!!!"
    # NOTE(review): resets env.M to 500 after every episode -- presumably M is a
    # budget consumed inside the wrapper; confirm the value and whether it
    # should follow hp.episode_length.
    env.M = 500
    return transitions, total_reward

env initialized
in init: 0
multiple environments created.
agent 0 initiation
model not loaded
Agent 0 Created!
agent 1 initiation
model not loaded
Agent 1 Created!
agent 2 initiation
model not loaded
Agent 2 Created!
agent 3 initiation
model not loaded
Agent 3 Created!
agent 4 initiation
model not loaded
Agent 4 Created!
[<__main__.DQNAgent object at 0x15309c1984d0>, <__main__.DQNAgent object at 0x15309c199bb0>, <__main__.DQNAgent object at 0x15304c508620>, <__main__.DQNAgent object at 0x15304c5089b0>, <__main__.DQNAgent object at 0x15304c508d40>]


In [None]:
# Main training loop
# Per-episode curves: mean training loss and mean blue-agent episode reward.
all_loss = np.zeros(HYPER_PARAMS.episode_num)
all_rewards = np.zeros(HYPER_PARAMS.episode_num)

for episode in tqdm(range(0, HYPER_PARAMS.episode_num, HYPER_PARAMS.workers)):
    # Episodes handled in this iteration: one per worker, clipped at episode_num.
    n_rollouts = min(HYPER_PARAMS.workers, HYPER_PARAMS.episode_num - episode)

    # Collect experiences in parallel
    results = Parallel(prefer='processes', n_jobs=HYPER_PARAMS.workers)(
        delayed(collect_experiences)(i % len(envs), episode + i, HYPER_PARAMS)
        for i in range(n_rollouts)
    )

    # Process collected experiences
    all_transitions = []
    for i, (transitions, reward) in enumerate(results):
        all_rewards[episode + i] = reward
        all_transitions.extend(transitions)

    # Add all collected transitions to replay buffer; every agent keeps only its own slice
    for transition in all_transitions:
        for agent in agents:
            agent.store_transition(*transition)

    for agent in agents:
        # Perform learning updates if buffer has enough samples
        if len(agent.replay_buffer) >= HYPER_PARAMS.batch_size:
            # One gradient step per 10 freshly collected transitions
            update_count = len(all_transitions) // 10  # Adjust this ratio as needed
            for _ in range(update_count):
                batch = agent.sample_batch(HYPER_PARAMS.batch_size)
                loss = agent.update(batch, gamma=HYPER_PARAMS.gamma)
                avg_loss = loss / update_count
                # Accumulate the mean over updates and over agents for the
                # episode(s) of this iteration.
                all_loss[episode:episode + n_rollouts] += avg_loss / len(agents)

        # Soft update target network
        with torch.no_grad():
            for target_param, policy_param in zip(agent.target_net.parameters(), agent.policy_net.parameters()):
                target_param.data.copy_((1 - HYPER_PARAMS.tau) * target_param.data + HYPER_PARAMS.tau * policy_param.data)

    # Save checkpoint occasionally
    if episode % 20 == 0:
        # torch.save(agent.policy_net.state_dict(), f"dqn_policy_{agent.name}_net_checkpoint_{episode}.pth")
        # Include the episode(s) just finished: slicing with [:episode] dropped
        # them and wrote an empty array at episode 0.
        completed = episode + n_rollouts
        np.save('dqn_training_loss_Jul25_Turbo_Greedy.npy', all_loss[:completed])
        print(all_loss[episode])
        np.save('dqn_training_rewards_Jul25_Turbo_Greedy.npy', all_rewards[:completed])

# Save final model
for agent in agents:
    torch.save(agent.policy_net.state_dict(), f'dqn_Turbo_Greedy_policy_{agent.name}_net_Jul25.pth')
    torch.save(agent.target_net.state_dict(), f'dqn_Turbo_Greedy_target_{agent.name}_net_Jul25.pth')


  0%|          | 1/2000 [00:32<17:52:06, 32.18s/it]

0.0


  1%|          | 21/2000 [12:47<20:27:21, 37.21s/it]

0.18594790119297638


  2%|▏         | 41/2000 [25:25<20:38:09, 37.92s/it]

0.1935249860797609


  3%|▎         | 61/2000 [38:07<20:41:54, 38.43s/it]

0.19758994786106807


  4%|▍         | 81/2000 [50:46<19:51:29, 37.25s/it]

0.21452846910272322


  5%|▌         | 101/2000 [1:03:04<19:21:09, 36.69s/it]

0.21902220370818165


  6%|▌         | 118/2000 [1:13:44<20:00:52, 38.29s/it]

In [None]:
# Manually write every agent's policy and target weights to the same files the
# training cell writes on completion (and DQNAgent reads on start-up).
for agent in agents:
    policy_path = f'dqn_Turbo_Greedy_policy_{agent.name}_net_Jul25.pth'
    target_path = f'dqn_Turbo_Greedy_target_{agent.name}_net_Jul25.pth'
    torch.save(agent.policy_net.state_dict(), policy_path)
    torch.save(agent.target_net.state_dict(), target_path)


In [None]:
# Reload the per-episode loss curve written by the training loop's checkpoint.
loss_curve_path = "dqn_training_loss_Jul25_Turbo_Greedy.npy"
aaa = np.load(loss_curve_path)
print(aaa)

In [None]:
import matplotlib.pyplot as plt

# Loss curve reloaded from the training checkpoint; labelled so the figure
# can be read on its own.
fig, ax = plt.subplots()
ax.plot(aaa)
ax.set(title="DQN training loss", xlabel="Episode", ylabel="Mean loss per update (averaged over agents)")
plt.show()

In [None]:
# Reload the per-episode reward curve written by the training loop's checkpoint.
bbb = np.load("dqn_training_rewards_Jul25_Turbo_Greedy.npy")
print(bbb)

# Labelled so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(bbb)
ax.set(title="DQN training reward", xlabel="Episode", ylabel="Episode reward (mean over blue agents)")
plt.show()

In [None]:
# Size formula evaluated for i = 1.
# NOTE(review): the notebook does not say what S, H and i stand for --
# presumably this estimates an encoded observation length; confirm. The result
# (92) differs from the input_dim values (258 / 494) used when the agents are built.
S, H = 9, 16
i = 1
per_unit = 3 * S + 2 * H
l = 1 + per_unit * i + 4 * 8
print(l)

In [None]:
# Size formula evaluated for i = 3.
# NOTE(review): the notebook does not say what S, H and i stand for --
# presumably this estimates an encoded observation length; confirm. The result
# (210) differs from the input_dim values (258 / 494) used when the agents are built.
S, H = 9, 16
i = 3
per_unit = 3 * S + 2 * H
l = 1 + per_unit * i + 4 * 8
print(l)