**Before declaring the environment, we define the message object model and the two configuration models (environment and reward) that it depends on.**

In [61]:
"""
Each Message Object contains the following attributes:
    size: size of the message
    source: source task [Task where the message is generated]
    destination: destination task [Task where the message is meant to be sent]
"""


class MessageObject:
    def __init__(self, size, source, destination):
        self.size = size
        self.source = source
        self.destination = destination

    def to_array(self):
        return [self.size, self.source, self.destination]
    
    def __getsize__(self):
        return self.size

In [62]:
class RewardConfig:
    """Names and scaling factors of all reward terms used by the environment.

    Negative rewards (penalties):
        message_pool_invalid_index_reward: the message pool is accessed with an invalid index
        message_pool_no_change_reward: penalty for a message-pool action that produces no change
        comm_bus_invalid_index_reward: the communication bus is accessed with an invalid index
        comm_bus_overflow_reward: queuing a message would overflow the communication bus
        comm_bus_no_change_reward: penalty for a communication-bus action that produces no change
    Positive rewards:
        message_store_reward: a message is moved from the message pool onto the communication bus
        message_send_reward: a message is sent from the communication bus to its destination
        message_pool_cost_decrease_reward: scale of the size-based bonus when the message pool shrinks
        comm_bus_cost_decrease_reward: scale of the size-based bonus when the communication bus shrinks

    Constructor keywords omit the ``_reward`` suffix (except the two
    ``*_no_change_reward`` keywords); attribute names and JSON keys always carry it.
    """

    # Attribute names read by load_from_file; they double as the JSON keys.
    _FIELDS = (
        "message_pool_invalid_index_reward",
        "message_pool_no_change_reward",
        "comm_bus_invalid_index_reward",
        "comm_bus_overflow_reward",
        "comm_bus_no_change_reward",
        "message_store_reward",
        "message_send_reward",
        "message_pool_cost_decrease_reward",
        "comm_bus_cost_decrease_reward",
    )

    def __init__(
        self,
        message_pool_invalid_index=-1,
        comm_bus_invalid_index=-1,
        comm_bus_overflow=-1,
        message_pool_no_change_reward = -0.1,
        comm_bus_no_change_reward = -0.1,
        message_store=1,
        message_send=1,
        message_pool_cost_decrease=1,
        comm_bus_cost_decrease=1,
    ):
        # Negative rewards
        self.message_pool_invalid_index_reward = message_pool_invalid_index
        self.message_pool_no_change_reward = message_pool_no_change_reward
        self.comm_bus_invalid_index_reward = comm_bus_invalid_index
        self.comm_bus_overflow_reward = comm_bus_overflow
        self.comm_bus_no_change_reward = comm_bus_no_change_reward
        # Positive rewards
        self.message_store_reward = message_store
        self.message_send_reward = message_send
        self.message_pool_cost_decrease_reward = message_pool_cost_decrease
        self.comm_bus_cost_decrease_reward = comm_bus_cost_decrease

    def load_from_file(self, file_path):
        """Override reward values from a JSON file containing an object.

        Every attribute listed in ``_FIELDS`` (including both no-change rewards)
        is read from the key of the same name. Missing keys keep their current
        value; unknown keys are ignored.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            AttributeError: if the top-level JSON value is not an object.
        """
        import json  # local import: the notebook only imports json in a later cell

        with open(file_path, "r") as file:
            data = json.load(file)
        for field in self._FIELDS:
            setattr(self, field, data.get(field, getattr(self, field)))

In [63]:
import json

class EnvConfig:
    """Environment configuration for MessageNetworkEnv.

    Attributes:
        min_message_pool: minimum number of messages in the message pool.
        max_message_pool: maximum number of messages in the message pool.
        min_message_obj_size: minimum size of a message object.
        max_message_obj_size: maximum size of a message object.
        comm_bus_size: size of the communication bus.

    Values come from the constructor (manual input) and can be overridden from a
    JSON file with ``load_from_file``.
    """

    def __init__(
        self,
        min_message_pool=10,
        max_message_pool=25,
        min_message_obj_size=50,
        max_message_obj_size=300,
        comm_bus_size=1500,
    ):
        # Message-pool bounds
        self.min_message_pool = min_message_pool
        self.max_message_pool = max_message_pool
        # Message-size bounds
        self.min_message_obj_size = min_message_obj_size
        self.max_message_obj_size = max_message_obj_size
        # Bus size
        self.comm_bus_size = comm_bus_size

    def load_from_file(self, file_path):
        """Override settings from a JSON object; absent keys keep their current value."""
        setting_names = (
            "min_message_pool",
            "max_message_pool",
            "min_message_obj_size",
            "max_message_obj_size",
            "comm_bus_size",
        )
        with open(file_path, "r") as file:
            overrides = json.load(file)
            for name in setting_names:
                current = getattr(self, name)
                setattr(self, name, overrides.get(name, current))

In [64]:
import gym
from gym import spaces
import numpy as np

"""
The Message Network Environment for CADES-Task 2 (Message Passing)
    Observation Space:
        3 - as each message is represented by 3 values (size, source, destination)
        self.config.max_message_pool + self.config.comm_bus_size - as the message pool
        and communication bus are both in observation space
    Action Space:
        [OPERATION, MESSAGE_POOL_INDEX, COMM_BUS_INDEX]
        OPERATION: 0 - Indicates the node manager/Agent pick message from the message pool and queues in the bus. ,
                   1 - Indicates the node manager/Agent pick message from the communication bus and sends it to destination node.
        MESSAGE_POOL_INDEX: Index of the message object in the 'message pool' to be picked for sending 
        COMM_BUS_INDEX: Index of the message object in the 'communication bus' to be picked for sending to destination  
"""


class MessageNetworkEnv(gym.Env):
    def __init__(self, envconfig, reward_config):
        super(MessageNetworkEnv, self).__init__()

        self.config = envconfig
        self.reward_config = reward_config

        # Observation Space
        self.observation_space = spaces.Dict({
            "message_pool": spaces.Box(low=0, high=self.config.max_message_obj_size, shape=(3, self.config.max_message_pool), dtype=np.float32),
            "comm_bus": spaces.Box(low=0, high=self.config.max_message_obj_size, shape=(3, self.config.comm_bus_size), dtype=np.float32)
        })

        # Action Space
        self.action_space = spaces.MultiDiscrete(
            [2, self.config.max_message_pool, self.config.max_message_pool]
        )

        self.message_pool = []
        self.comm_bus = []

    def reset(self):
        num_messages = np.random.randint(
            self.config.min_message_pool, self.config.max_message_pool + 1
        )
        self.message_pool = [self.generate_message() for _ in range(num_messages)]
        self.comm_bus = []

        observation = self.get_observation()
        
        return observation

    # Generates Message Objects
    # Source and Destination tasks are now assigned at random, but we will have to change this later
    def generate_message(self):
        size = np.random.uniform(
            self.config.min_message_obj_size / self.config.comm_bus_size,
            self.config.max_message_obj_size / self.config.comm_bus_size,
        )
        source = np.random.randint(0, 100)
        destination = np.random.randint(0, 100)

        while (
            destination == source
        ):  # To ensure Source and destination are never the same
            destination = np.random.randint(0, 100)

        return MessageObject(size, source, destination)

    """
        The following functions are used to convert the message pool and communication bus into an observation
    """

    def get_message_pool_observation(self):
        message_pool_observation = np.zeros(
            (3, self.config.max_message_pool), dtype=np.float32
        )
        for i, message in enumerate(self.message_pool):
            message_pool_observation[:, i] = message.to_array()
        return message_pool_observation

    def get_comm_bus_observation(self):
        comm_bus_observation = np.zeros(
            (3, self.config.comm_bus_size), dtype=np.float32
        )
        for i, message in enumerate(self.comm_bus):
            comm_bus_observation[:, i] = message.to_array()
        return comm_bus_observation

    def get_observation(self):
        return {
            "message_pool": self.get_message_pool_observation(),
            "comm_bus": self.get_comm_bus_observation(),
        }

    """
        The following functions are used to convert the action into a message object
    """

    def get_message_from_message_pool(self, message_pool_index):
        if 0 <= message_pool_index < len(self.message_pool):
            return self.message_pool[message_pool_index]
        else:
            return None

    def get_message_from_comm_bus(self, comm_bus_index):
        if 0 <= comm_bus_index < len(self.comm_bus):
            return self.comm_bus[comm_bus_index]
        else: 
            return None

    def get_message_from_action(self, action):
        operation, message_pool_index, comm_bus_index = action
        if operation == 0:
            return self.get_message_from_message_pool(message_pool_index)
        elif operation == 1:
            return self.get_message_from_comm_bus(comm_bus_index)

    """
        The following functions are used to generate appropriate negative rewards based on observations
        - MESSAGE_POOL_INVALID_INDEX_REWARD if it chooses an index that is not in the message pool.
        - COMM_BUS_INVALID_INDEX_REWARD if it chooses an index that is not in the communication bus.

        - COMM_BUS_OVERFLOW_REWARD if it tries to place a message object in a occupied communication bus 
        that makes it surpass its capacity.
        
        - MESSAGE_POOL_NO_CHANGE_REWARD if OPERATION=1 AND message objects are in the message pool 
        AND adequate space is available in the communication bus 

        - COMM_BUS_NO_CHANGE_REWARD if OPERATION=0 AND message objects in the communication bus. 

        ==== The below rewards are to be done later =====
        -? MESSAGE_POOL_INVALID_INDEX_REWARD if it chooses a previously picked index from the message pool. 
        -? COMM_BUS_INVALID_INDEX_REWARD if it chooses a previously picked index from the communication bus.
        - SENDER_TASK_NOT_ALLOCATED_REWARD if the sender task is not allocated to the node manager/Agent.
        - RECEIVER_TASK_NOT_ALLOCATED_REWARD if the receiver task is not allocated to the node manager/Agent. 
    """

    def check_message_pool_invalid_index(self, message_pool_index):
        if message_pool_index >= len(self.message_pool):
            return True
        return False

    def check_comm_bus_invalid_index(self, comm_bus_index):
        if comm_bus_index >= len(self.comm_bus):
            return True
        return False

    def check_comm_bus_overflow(self, message):
        if (
            sum(message.size for message in self.comm_bus) + message.size
            > self.config.comm_bus_size
        ):
            return True
        return False

    def check_message_pool_no_change(self, message):
        if message in self.message_pool:
            return True
        return False

    def check_comm_bus_no_change(self, message):
        if message in self.comm_bus:
            return True
        return False

    """
        The following functions are used to generate appropriate positive rewards based on observations
        - MESSAGE_STORE_REWARD if it chooses OPERATION=0 AND message objects are in the message pool
        - MESSAGE_SEND_REWARD if it chooses OPERATION=1 AND message objects are in the communication bus
        - MESSAGE_POOL_COST_DECREASE_REWARD if it chooses OPERATION=0 AND messages are in the message pool 
        AND the size of MESSAGE_POOL decreases. [Reward assigned based on the size of the message chosen]
        - COMM_BUS_COST_DECREASE_REWARD if it chooses OPERATION=1 AND messages are in the communication bus
        AND the size of COMM_BUS decreases. [Reward assigned based on the size of the message chosen]


        ==== The below rewards are to be done later =====
        -? MESSAGE_TRAVEL_DISTANCE_REWARD if OPERATION=1 AND ...
    """

    def check_message_store_reward(self, message):
        if message in self.message_pool:
            return True
        return False

    def check_message_send_reward(self, message):
        if message in self.comm_bus:
            return True
        return False

    def get_message_pool_cost_decrease_reward(self, message):
        message_size_range = (
            self.config.max_message_obj_size - self.config.min_message_obj_size
        )
        normalized_size = (
            message.size * self.config.comm_bus_size - self.config.min_message_obj_size
        ) / message_size_range
        reward = normalized_size * self.reward_config.message_pool_cost_decrease_reward
        return reward

    def get_comm_bus_cost_decrease_reward(self, message):
        message_size_range = (
            self.config.max_message_obj_size - self.config.min_message_obj_size
        )
        normalized_size = (
            message.size * self.config.comm_bus_size - self.config.min_message_obj_size
        ) / message_size_range
        reward = normalized_size * self.reward_config.comm_bus_cost_decrease_reward
        return reward

    """
        The following functions are for the step function
    """

    def step(self, action):
        operation, message_pool_index, comm_bus_index = action
        reward = 0
        done = False
        info = {}

        if operation == 0:
            message = self.get_message_from_action(action)
            if self.check_message_pool_invalid_index(message_pool_index):
                reward += self.reward_config.message_pool_invalid_index_reward
            elif self.check_comm_bus_overflow(message):
                reward += self.reward_config.comm_bus_overflow_reward
            elif self.check_message_pool_no_change(message):
                reward += self.reward_config.message_pool_no_change_reward
            else:
                self.message_pool.remove(message)
                self.comm_bus.append(message)
                reward = (
                    self.reward_config.message_store_reward
                    + self.get_message_pool_cost_decrease_reward(message)
                )
        elif operation == 1:
            message = self.get_message_from_action(action)
            if self.check_comm_bus_invalid_index(comm_bus_index):
                reward += self.reward_config.comm_bus_invalid_index_reward
            elif self.check_comm_bus_no_change(message):
                reward += self.reward_config.comm_bus_no_change_reward
            else:
                self.comm_bus.remove(message)
                reward = (
                    self.reward_config.message_send_reward
                    + self.get_comm_bus_cost_decrease_reward(message)
                )
        
        if len(self.message_pool) == 0 and len(self.comm_bus) == 0:
            done = True
        
        info = {
            "message_pool": self.message_pool,
            "comm_bus": self.comm_bus,
        }
        return self.get_observation(), reward, done, info
    
    def render(self, mode='console'):
        if mode == 'console':
            print("Message Pool: ", self.message_pool)
            print("Communication Bus: ", self.comm_bus)
        else:
            pass
    
    def close(self):
        pass


In [65]:
from stable_baselines3.common.env_checker import check_env

# Build the environment from the default configs and validate it against the
# Gym API contract expected by stable-baselines3 before training on it.
env_config = EnvConfig()
reward_config = RewardConfig()
env = MessageNetworkEnv(env_config, reward_config)
check_env(env)

In [66]:
# Smoke-test one transition on a fresh episode. Reset first, so the step acts on a
# newly generated message pool rather than on whatever state check_env left behind.
observation = env.reset()
test_action = env.action_space.sample()  # Generates a sample action
observation, reward, done, info = env.step(test_action)
test_action, reward, done

In [69]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

import time
import torch

logdir = f"../logs/{int(time.time())}/"
models_dir = f"../models/{int(time.time())}/"


# Initializes the model
model = PPO(
    "MultiInputPolicy",
    env,
    verbose=1,
    tensorboard_log=logdir,
    batch_size=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Trains the model
model.learn(total_timesteps=100000)

# Saves the model
model.save(models_dir + f"model_{int(time.time())}")


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ../logs/1700694123/PPO_1
-----------------------------
| time/              |      |
|    fps             | 387  |
|    iterations      | 1    |
|    time_elapsed    | 5    |
|    total_timesteps | 2048 |
-----------------------------




In [71]:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Evaluate on a separate environment instance. `env` is the instance PPO trains on,
# and evaluation resets and steps it, which would clobber the in-progress training
# episode. Monitor records the episode statistics that evaluation reports.
eval_env = Monitor(MessageNetworkEnv(env_config, reward_config))

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=models_dir,
    log_path=logdir,
    eval_freq=10000,
    deterministic=True,
    render=False,
)

In [74]:
# Continue training in chunks and checkpoint after each one. reset_num_timesteps=False
# keeps the timestep counter continuous across the successive learn() calls.
is_train = True
EPOCHS = 5
TIMESTEPS = 25000
if is_train:
    for iters in range(1, EPOCHS + 1):
        model.learn(
            total_timesteps=TIMESTEPS,
            reset_num_timesteps=False,
            tb_log_name="ppo_Test",
            callback=eval_callback,
        )
        # models_dir already ends with "/", so no extra separator is needed.
        model.save(f"{models_dir}{iters}")

Logging to ../logs/1700694123/ppo_Test_0
------------------------------
| time/              |       |
|    fps             | 398   |
|    iterations      | 1     |
|    time_elapsed    | 5     |
|    total_timesteps | 24096 |
------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 323        |
|    iterations           | 2          |
|    time_elapsed         | 12         |
|    total_timesteps      | 26144      |
| train/                  |            |
|    approx_kl            | 0.02073974 |
|    clip_fraction        | 0.0827     |
|    clip_range           | 0.2        |
|    entropy_loss         | -6.41      |
|    explained_variance   | -0.000276  |
|    learning_rate        | 0.0003     |
|    loss                 | 0.991      |
|    n_updates            | 100        |
|    policy_gradient_loss | -0.0231    |
|    value_loss           | 6.68       |
----------------------------------------
----

KeyboardInterrupt: 