# Item Collection

## Environment definition

In [16]:
from typing import Set
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import random as rnd
import math
from gymnasium.spaces import Discrete, Box, Sequence, Tuple
import numpy as np
from IPython.display import clear_output
from gymnasium.spaces.utils import flatten_space, flatten

from ray.rllib.utils.typing import AgentID

class EnvironmentConfiguration():
    def __init__(self, width, height, n_agents, n_items, max_steps=None):
        self.width = width
        self.height = height
        self.n_agents = n_agents
        self.n_items = n_items
        self.max_steps = max_steps

class ItemCollectionEnv(MultiAgentEnv):

    actions_encoding = ["UP","DOWN","RIGHT","LEFT"]
    actions_dict = [(0,-1),(0,1),(1,0),(-1,0)]

    def __init__(self, config: EnvironmentConfiguration):
        self.width = config.width
        self.height = config.height
        self.n_agents = config.n_agents
        self.n_items = config.n_items
        self.max_steps = config.max_steps
        self.original_observation_space = Tuple((
            Box(low=0, high=1, shape=(self.width, self.height), dtype=np.bool_), # agent
            Box(low=0, high=1, shape=(self.width, self.height), dtype=np.bool_), # other agents
            Box(low=0, high=1, shape=(self.width, self.height), dtype=np.bool_)  # items
        ))
        self.observation_space = self.observation_space('agent-0')
        self.action_space = self.action_space('agent-0')
        

        self.agents = ['agent-' + str(i) for i in range(self.n_agents)]

    def observation_space(self, agent):
        #coordinates_space = Box(low=np.array([0, 0]), high=np.array([self.width-1, self.height-1]), dtype=np.int32)
        return flatten_space(self.original_observation_space)

    def action_space(self, agent):
        return Discrete(len(self.actions_dict))

    def __get_random_point(self):
        return (rnd.randint(0, self.width-1), rnd.randint(0, self.height-1))
    
    def __get_observation(self, agent):
        myPosition = np.zeros((self.width, self.height), np.bool_)
        otherAgentsPositions = np.zeros((self.width, self.height), np.bool_)
        itemsPositions = np.zeros((self.width, self.height), np.bool_)
        myPosition[self.agent_pos[agent][0], self.agent_pos[agent][1]] = True
        for other_agent, position in self.agent_pos.items():
            if other_agent != agent:
                otherAgentsPositions[position[0], position[1]] = True
        for position in self.item_pos:
            itemsPositions[position[0], position[1]] = True

        return flatten(self.original_observation_space, (myPosition, otherAgentsPositions, itemsPositions))
        return (np.array(self.agent_pos[agent], dtype=np.int32),
                np.array([self.agent_pos[a] for a in self.agents if a != agent], dtype=np.int32),
                np.array(self.item_pos, dtype=np.int32))

    def __get_reward(self, agent):
        return -0.5 + self.__get_global_reward()
        """
        if self.agent_pos[agent] in self.item_pos:
            return 10/self.agent_pos.count(self.agent_pos[agent])
        else:
            return -1
        """
    
    def __get_global_reward(self):
        return (len((self.item_pos)) - len(set(self.item_pos) - set(self.agent_pos.values()))) * 10
    
    def __update_agent_position(self, agent, x, y):
        self.agent_pos[agent] = (max(min(self.agent_pos[agent][0] + x, self.width-1), 0),
                                 max(min(self.agent_pos[agent][1] + y, self.height-1), 0))

    def reset(self, seed=None, options=None):
        self.agent_pos = {agent: self.__get_random_point() for agent in self.agents}
        self.item_pos = [self.__get_random_point() for _ in range(self.n_items)]
        self.steps = 0;
        return {agent: self.__get_observation(agent) for agent in self.agents}, {}
     
    def step(self, actions):
        self.steps += 1
        observations, rewards, terminated, truncated, infos = {}, {}, {}, {}, {}

        for agent, action in actions.items():
            self.__update_agent_position(agent, self.actions_dict[action][0], self.actions_dict[action][1])

        for agent in actions.keys():
            observations[agent] = self.__get_observation(agent)
            rewards[agent] = self.__get_reward(agent)
            terminated[agent] = False
            truncated[agent] = False
            infos[agent] = {}
        
        self.item_pos = list(set(self.item_pos) - set(self.agent_pos.values()))

        terminated['__all__'] = len(self.item_pos) == 0
        if self.max_steps != None and self.steps == self.max_steps and not terminated['__all__']:
            truncated['__all__'] = True
        else:
            truncated['__all__'] = False

        return observations, rewards, terminated, truncated, infos
     
    def render(self, mode='text'):
        str = '_' * (self.width+2) + '\n'
        for i in range(self.height):
            str = str + "|"
            for j in range(self.width):
                if (j,i) in self.agent_pos.values():
                    str = str + 'o'
                elif (j,i) in self.item_pos:
                    str = str + 'x'
                else:
                    str = str + ' '
            str = str + '|\n'
        str = str + '‾' * (self.width+2)
        print(str)

    def get_agent_ids(self):
       return self.agents
    
env = ItemCollectionEnv(EnvironmentConfiguration(width=50, height=10, n_agents=2, n_items=10))
print(env.reset()[0]['agent-1'])
env.render()


[False False False ... False False False]
____________________________________________________
|                          x        x              |
|                                            x     |
|      x     x                      o              |
|                      x              x            |
|                         x                        |
|                                      x           |
|                                                  |
|                                              x   |
|                                                  |
|                                                o |
‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾


In [22]:
import ray
ray.shutdown()
ray.init()

2024-05-15 15:42:11,853	INFO worker.py:1740 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


0,1
Python version:,3.11.9
Ray version:,2.21.0
Dashboard:,http://127.0.0.1:8265


## Single Agent

In [28]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
from gymnasium.wrappers.time_limit import TimeLimit


register_env("my_env", lambda _: ItemCollectionEnv(EnvironmentConfiguration(width=10, height=10, n_agents=1, n_items=1, max_steps=30)))

algo = (
    PPOConfig()
    .training(gamma=0.99, lr=0.001, kl_coeff=0.5, train_batch_size=4096, sgd_minibatch_size=256, num_sgd_iter=10)
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="my_env")
    .build()
)

for i in range(20):
    result = algo.train()
    print(result["sampler_results"])
    #print(result["info"]["learner"]["default_policy"]["learner_stats"]["total_loss"])
    print(f"[{i}]")
    if i % 5 == 0:
        checkpoint_dir = algo.save().checkpoint.path
        print(f"Checkpoint saved in directory {checkpoint_dir}")



[33m(raylet)[0m [2024-05-15 15:50:11,822 E 17720 17720] (raylet) node_manager.cc:3002: 17 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 5577bdd9e19e4822eb7190edbd5fdec7a5e58167cd6cf69f465a8ec5, IP: 172.18.183.133) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 172.18.183.133`
[33m(raylet)[0m 
[33m(raylet)[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.
[33m(raylet)[0m [2024-05-15 15:51:11,823 E 17720 17720] (raylet) node_manager.cc:30

KeyboardInterrupt: 

In [24]:
from IPython.display import clear_output
import time
import torch
from gymnasium.spaces.utils import flatdim

env = ItemCollectionEnv(EnvironmentConfiguration(width=50, height=10, n_agents=3, n_items=10, max_steps=100))
obs, _ = env.reset()
env.render()

for i in range(100):
    actions = algo.compute_actions(obs)
    print(actions, "\n")
    
    obs, reward, terminated, truncated, info = env.step(actions)
    clear_output()
    print(f"[{i}]")
    env.render()
    print(obs)
    print(reward)
    time.sleep(0.5)

    if terminated['__all__'] or truncated['__all__']:
        break


[30]
____________________________________________________
|            x                                     |
|                          x                       |
|                                                  |
|                              x           x       |
|                                     x            |
|                     o                            |
|                              x               x   |
|                    o    x                 o      |
|                          x                       |
|                                            x     |
‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
{'agent-0': array([False, False, False, ..., False, False, False]), 'agent-1': array([False, False, False, ..., False, False, False]), 'agent-2': array([False, False, False, ..., False, False, False])}
{'agent-0': -0.5, 'agent-1': -0.5, 'agent-2': -0.5}


KeyboardInterrupt: 