# Item Collection

## Environment definition

In [11]:
from typing import Set
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import random as rnd
import math
from gymnasium.spaces import Discrete, Box, Sequence, Tuple
import numpy as np
from IPython.display import clear_output

from ray.rllib.utils.typing import AgentID

class EnvironmentConfiguration:
    """Plain settings holder for the item-collection environment.

    Attributes:
        width: number of grid columns.
        height: number of grid rows.
        n_agents: number of agents placed on the grid.
        n_items: number of items scattered on the grid at reset.
        max_steps: step limit after which an episode is truncated, or
            ``None`` for no limit.
    """

    def __init__(self, width, height, n_agents, n_items, max_steps=None):
        self.width, self.height = width, height
        self.n_agents, self.n_items = n_agents, n_items
        self.max_steps = max_steps

class ItemCollectionEnv(MultiAgentEnv):
    """Grid world in which several agents move around to collect items.

    Agents and items occupy integer cells ``(x, y)`` with ``0 <= x < width``
    and ``0 <= y < height``; ``y = 0`` is the top row when rendered. An item
    is collected (removed from the grid) as soon as any agent ends a step on
    its cell. The episode terminates once every item has been collected and
    is truncated after ``max_steps`` steps when a limit is configured.
    """

    # Human-readable names of the discrete actions; entry i names actions_dict[i].
    actions_encoding = ["UP", "DOWN", "RIGHT", "LEFT"]
    # (dx, dy) displacement of each action; y grows downwards, as in render().
    actions_dict = [(0, -1), (0, 1), (1, 0), (-1, 0)]

    def __init__(self, config: EnvironmentConfiguration):
        """Create the environment described by ``config``.

        Positions are only initialised by :meth:`reset`, which must be called
        before :meth:`step` or :meth:`render`.
        """
        super().__init__()
        self.width = config.width
        self.height = config.height
        self.n_agents = config.n_agents
        self.n_items = config.n_items
        self.max_steps = config.max_steps
        # Private RNG created by reset(seed=...). While it is None the global
        # ``random`` module is used, matching the unseeded behaviour.
        self._rng = None
        # All agents share the same spaces, so they are built once here. These
        # instance attributes shadow the observation_space/action_space methods
        # on this instance (presumably so RLlib can read them as attributes).
        self.observation_space = self.observation_space('')
        self.action_space = self.action_space('')
        self.agents = ['agent-' + str(i) for i in range(self.n_agents)]

    def observation_space(self, agent):
        """Return the observation space of ``agent`` (the same for every agent).

        An observation is a tuple of the agent's own ``(x, y)``, a stacked
        array of the other agents' positions, and a stacked array of the
        positions of the items still on the grid.
        """
        coordinates_space = Box(
            low=np.array([0, 0], dtype=np.int32),
            high=np.array([self.width - 1, self.height - 1], dtype=np.int32),
            dtype=np.int32,
        )
        return Tuple((
            coordinates_space,  # observing agent
            Sequence(coordinates_space, stack=True),  # other agents
            Sequence(coordinates_space, stack=True),  # items still on the grid
        ))

    def action_space(self, agent):
        """Return the action space of ``agent``: one index per entry of ``actions_dict``."""
        return Discrete(len(self.actions_dict))

    def __get_random_point(self):
        """Return a uniformly random grid cell ``(x, y)``."""
        rng = self._rng if self._rng is not None else rnd
        return (rng.randint(0, self.width - 1), rng.randint(0, self.height - 1))

    def __get_observation(self, agent):
        """Build ``agent``'s observation from the current positions.

        The position lists are reshaped to ``(k, 2)`` so that an empty list
        (a single agent, or no items left) still has two columns instead of
        collapsing to shape ``(0,)``.
        """
        own = np.array(self.agent_pos[agent], dtype=np.int32)
        others = np.array([self.agent_pos[a] for a in self.agents if a != agent],
                          dtype=np.int32).reshape(-1, 2)
        items = np.array(self.item_pos, dtype=np.int32).reshape(-1, 2)
        return own, others, items

    def __get_reward(self, agent):
        """Return ``agent``'s reward for the current step.

        Currently a constant -0.5 per step for every agent, independent of
        position, so longer episodes accumulate a larger penalty.
        """
        return -0.5

    def __get_global_reward(self):
        """Return 10 for every item lying on a cell currently occupied by an agent.

        Items sharing a cell are counted individually. Not used by
        :meth:`step` at the moment.
        """
        occupied = set(self.agent_pos.values())
        return sum(pos in occupied for pos in self.item_pos) * 10

    def __update_agent_position(self, agent, x, y):
        """Move ``agent`` by ``(x, y)``, clamping the result to the grid."""
        cur_x, cur_y = self.agent_pos[agent]
        self.agent_pos[agent] = (max(min(cur_x + x, self.width - 1), 0),
                                 max(min(cur_y + y, self.height - 1), 0))

    def reset(self, seed=None, options=None):
        """Place agents and items at random cells and start a new episode.

        Cells are drawn independently, so items may share a cell with each
        other or with an agent.

        Args:
            seed: if given, re-seeds this environment's private RNG so that
                this layout (and later resets without a seed) is reproducible.
            options: accepted for API compatibility; ignored.

        Returns:
            ``(observations, infos)``: ``observations`` maps every agent id to
            its observation and ``infos`` is an empty dict.
        """
        if seed is not None:
            self._rng = rnd.Random(seed)
        self.agent_pos = {agent: self.__get_random_point() for agent in self.agents}
        self.item_pos = [self.__get_random_point() for _ in range(self.n_items)]
        self.steps = 0
        return {agent: self.__get_observation(agent) for agent in self.agents}, {}

    def step(self, actions):
        """Apply one joint action and advance the episode by one step.

        Args:
            actions: dict mapping agent id to an index into ``actions_dict``.
                Agents missing from the dict stay in place and get no entries
                in the returned per-agent dicts.

        Returns:
            ``(observations, rewards, terminated, truncated, infos)`` keyed by
            agent id; ``terminated`` and ``truncated`` also carry ``'__all__'``.
        """
        self.steps += 1

        for agent, action in actions.items():
            dx, dy = self.actions_dict[action]
            self.__update_agent_position(agent, dx, dy)

        # Rewards are evaluated while the items collected this step are still
        # on the grid.
        rewards = {agent: self.__get_reward(agent) for agent in actions}

        # Collect every item that now shares a cell with any agent. Filtering
        # the list (instead of a set difference) keeps the remaining items in
        # a stable order for the observations.
        occupied = set(self.agent_pos.values())
        self.item_pos = [pos for pos in self.item_pos if pos not in occupied]

        # Observations reflect the grid after collection, so items picked up
        # this step are no longer reported.
        observations = {agent: self.__get_observation(agent) for agent in actions}
        terminated = {agent: False for agent in actions}
        truncated = {agent: False for agent in actions}
        infos = {agent: {} for agent in actions}

        terminated['__all__'] = not self.item_pos
        truncated['__all__'] = (self.max_steps is not None
                                and self.steps >= self.max_steps
                                and not terminated['__all__'])

        return observations, rewards, terminated, truncated, infos

    def render(self, mode='text'):
        """Print the grid: ``o`` is an agent, ``x`` an item, blank is empty.

        An agent standing on an item is drawn as ``o``. ``mode`` is accepted
        for API compatibility and ignored.
        """
        agent_cells = set(self.agent_pos.values())
        item_cells = set(self.item_pos)
        lines = ['_' * (self.width + 2)]
        for y in range(self.height):
            row = ''.join(
                'o' if (x, y) in agent_cells else 'x' if (x, y) in item_cells else ' '
                for x in range(self.width)
            )
            lines.append('|' + row + '|')
        lines.append('‾' * (self.width + 2))
        print('\n'.join(lines))

    def get_agent_ids(self):
        """Return the ids of all agents, in creation order."""
        return self.agents
    
# Quick visual check: build a 50x10 grid with 2 agents and 10 items,
# scatter them at random with reset(), then print the layout.
env = ItemCollectionEnv(
    EnvironmentConfiguration(
        width=50,
        height=10,
        n_agents=2,
        n_items=10,
    )
)
env.reset()
env.render()

____________________________________________________
|                                                  |
|                                                  |
|       x                         o                |
|                                      x           |
|                    x                             |
|    x                                           x |
|                                      x           |
|               x                  x        x      |
|                     o                x           |
|                                                  |
‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
