In [1]:
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces


In [115]:
from typing import List, Tuple, Dict, Any


class UNOCommunicationEnv(gym.Env):
    """
    Multi-agent UNO environment in which agents must communicate to play the game.

    The joint action holds one entry per agent: a discrete game action plus one
    message vector addressed to each of the other agents.

    Action space (one entry per agent):
        Tuple(Discrete(62) game action,
              Box([-1, 1], shape=(n_agents - 1, message_dim)) outgoing messages)

    Observation space (one entry per receiving agent):
        - Sequence of (action id, actor index) pairs: the shared action history
        - MultiBinary(62): the actions currently available to this agent
        - MultiBinary(108): which physical cards (indices into ``get_deck_names()``)
          the agent holds
        - Sequence of message records, each a tuple of n_agents - 1
          (message vector, sender index) pairs

    Game state:
        - ``deck`` (list of card indices, drawn from the end) and ``discard_pile``
          (list of card indices)
        - ``players_hands`` (one MultiBinary vector per player), ``top_card``
          (one-hot over cards), ``play_direction``, ``current_agent``
        - ``message_buffers`` (one list of message records per receiver) and
          ``game_history``
    """

    _COLORS = ('r', 'g', 'b', 'y')
    _WILD_SYMBOLS = ('wild', 'wild-draw4')

    def __init__(self, n_agents: int, message_dim: int, cards_per_hand: int = 7):
        """
        Args:
            n_agents: number of players; every player is also a message sender/receiver.
            message_dim: length of each message vector.
            cards_per_hand: cards dealt to each player on ``reset`` (standard UNO deals 7).
        """
        super().__init__()

        # Number of agents and dimensions for messages and actions
        self.n_agents = n_agents
        self.message_dim = message_dim
        self.cards_per_hand = cards_per_hand

        # Game constants
        self.cards = self.get_deck_names()
        self.actions = self.get_actions_names()
        self.message_routes = self._get_route_map()
        self.n_cards = len(self.cards)        # 108
        self.n_actions = len(self.actions)    # 62
        self.draw_action = self.actions.index('draw')
        self.pass_action = self.actions.index('pass')
        # Pre-parsed (color, symbol) pairs, e.g. 'r-skip' -> ('r', 'skip'), 'wild' -> (None, 'wild')
        self._card_parts = [self._split_name(name) for name in self.cards]
        self._action_parts = [self._split_name(name) for name in self.actions]

        # Action space consists of discrete game action and continuous message vectors for each agent
        self.action_space = spaces.Tuple([
            spaces.Tuple([
                spaces.Discrete(self.n_actions),  # Action encoding
                spaces.Box(low=-1, high=1, shape=(n_agents - 1, message_dim), dtype=np.float32),  # Messages
            ])
            for _ in range(n_agents)
        ])

        # Observation space consists of played actions, available actions, cards on hand, and messages for each agent
        self.observation_space = spaces.Tuple([
            spaces.Tuple([
                spaces.Sequence(
                    spaces.Tuple([
                        spaces.Discrete(self.n_actions),  # Action encoding
                        spaces.Discrete(self.n_agents),   # Actor's positional encoding
                    ])
                ),
                spaces.MultiBinary(self.n_actions),  # Agent's available actions
                spaces.MultiBinary(self.n_cards),    # Agent's cards on hand
                spaces.Sequence(
                    spaces.Tuple([
                        spaces.Tuple([
                            spaces.Box(low=-1, high=1, shape=(self.message_dim,), dtype=np.float32),  # Message
                            spaces.Discrete(self.n_agents),  # Sender's positional encoding
                        ])
                        for _ in range(n_agents - 1)  # For each sender
                    ])
                ),
            ])
            for _ in range(n_agents)  # For each receiver
        ])

        # Game state (populated by reset())
        self.deck = []
        self.discard_pile = []
        self.players_hands = []
        self.top_card = None
        self.message_buffers = []
        self.game_history = []
        self.current_agent = None
        self.play_direction = 1

        self.reset()

    def reset(self, *, seed=None, options=None):
        """
        Start a new game: shuffle a fresh deck, deal hands, flip the top card and
        clear the history and message buffers.

        Args:
            seed: optional seed for ``self.np_random``, which shuffles the deck.
            options: accepted for gymnasium compatibility; unused.

        Returns:
            The per-agent observation tuple (see the class docstring).
            NOTE: gymnasium's API expects ``(observation, info)``. Only the observation
            is returned here, to stay compatible with existing callers.
        """
        super().reset(seed=seed, options=options)
        self.current_agent = 0
        self.play_direction = 1
        # Clear the discard pile *before* dealing, so an exhausted deck never
        # reshuffles the previous game's discards.
        self.discard_pile = []
        self.deck = self._prepare_deck(np.arange(self.n_cards)).tolist()
        self._deal_cards()
        self.top_card = self._draw_card()
        self.message_buffers = [[] for _ in range(self.n_agents)]
        self.game_history = []
        return self._get_observation()

    def _prepare_deck(self, deck) -> np.ndarray:
        """Return a shuffled copy of ``deck`` (array-like of card indices), using the env's seeded RNG."""
        return self.np_random.permutation(np.asarray(deck))

    def _deal_cards(self):
        """Deal ``cards_per_hand`` cards to every player, one card at a time in seat order."""
        self.players_hands = [np.zeros(self.n_cards, dtype=np.int8) for _ in range(self.n_agents)]
        for _ in range(self.cards_per_hand):
            for hand in self.players_hands:
                hand += self._draw_card()

    def _update_current_agent(self):
        """Advance ``current_agent`` one seat in ``play_direction`` (or start at seat 0)."""
        if self.current_agent is None:
            self.current_agent = 0
        else:
            self.current_agent = (self.current_agent + self.play_direction) % self.n_agents

    def _draw_card(self) -> np.ndarray:
        """
        Remove the last card from the deck and return it as a one-hot vector over cards.

        When the deck is empty, the discard pile is shuffled in as the new deck.

        Raises:
            RuntimeError: if both the deck and the discard pile are empty.
        """
        if len(self.deck) == 0:
            if len(self.discard_pile) == 0:
                raise RuntimeError("Cannot draw a card: both the deck and the discard pile are empty")
            # .tolist() keeps the deck a Python list so .pop() is available
            self.deck = self._prepare_deck(self.discard_pile).tolist()
            self.discard_pile = []
        card = np.zeros(self.n_cards, dtype=np.int8)
        card[self.deck.pop()] = 1
        return card

    def _get_route_map(self) -> List[np.ndarray]:
        """For each sender, the receiver indices its message rows go to (every other agent, ascending)."""
        everyone = np.arange(self.n_agents)
        return [np.delete(everyone, sender) for sender in range(self.n_agents)]

    def _route_messages(
        self, received_actions: Tuple[Tuple[int, np.ndarray], ...]
    ) -> Tuple[List[Tuple[np.ndarray, int]], ...]:
        """
        Deliver every agent's outgoing messages to their receivers.

        Args:
            received_actions: joint action, one ``(game_action, messages)`` pair per sender,
                where ``messages[k]`` is addressed to ``self.message_routes[sender][k]``.

        Returns:
            One list per receiver of ``(message, sender_idx)`` pairs, in ascending sender order.

        Raises:
            ValueError: if a sender supplies a different number of messages than it has receivers.
        """
        inboxes = [[] for _ in range(self.n_agents)]
        for sender_idx, ((_, messages), routes) in enumerate(zip(received_actions, self.message_routes)):
            if len(messages) != len(routes):
                raise ValueError(
                    f"Agent {sender_idx} sent {len(messages)} messages, expected {len(routes)}"
                )
            for receiver_idx, message in zip(routes, messages):
                inboxes[receiver_idx].append((message, sender_idx))
        return tuple(inboxes)

    def _get_observation(self) -> Tuple[tuple, ...]:
        """Build the per-agent observation tuple described in the class docstring."""
        return tuple(
            (
                tuple(self.game_history),
                self._get_available_actions(agent_idx),
                self.players_hands[agent_idx].copy(),
                tuple(self.message_buffers[agent_idx]),
            )
            for agent_idx in range(self.n_agents)
        )

    def _get_active_color_and_symbol(self) -> Tuple[Any, str]:
        """
        Return ``(color, symbol)`` that the next play must match.

        A wild card on top has no printed color. In that case the color is taken from
        the most recent wild action in ``game_history``. It is ``None`` (any color)
        if there is none.
        NOTE(review): assumes step() appends each play as ``(action_idx, actor)`` to
        ``game_history`` — confirm once step() exists.
        """
        top_color, top_symbol = self._card_parts[int(np.argmax(self.top_card))]
        if top_color is None:
            for action_idx, _ in reversed(self.game_history):
                color, symbol = self._action_parts[action_idx]
                if symbol in self._WILD_SYMBOLS:
                    top_color = color
                    break
        return top_color, top_symbol

    def _get_available_actions(self, player_idx: int) -> np.ndarray:
        """
        Return a MultiBinary mask over actions that ``player_idx`` may take now.

        Players other than the current agent may only pass. The current agent may:
        - play a colored card it holds that matches the active color or the top symbol;
        - play any wild / wild-draw4 it holds, naming any color;
        - draw.

        TODO: the official rule restricting wild-draw4 to hands with no matching-color
        card is not enforced.
        """
        res = np.zeros(self.n_actions, dtype=np.int8)
        if player_idx != self.current_agent:
            res[self.pass_action] = 1
            return res

        active_color, top_symbol = self._get_active_color_and_symbol()
        hand_names = {self.cards[i] for i in np.flatnonzero(self.players_hands[player_idx])}
        for action_idx, (color, symbol) in enumerate(self._action_parts):
            if color is None:
                continue  # 'draw' / 'pass' are handled explicitly
            if symbol in self._WILD_SYMBOLS:
                # Wild cards are colorless in the deck; the action names the chosen color.
                playable = symbol in hand_names
            else:
                playable = f'{color}-{symbol}' in hand_names and (
                    active_color is None or color == active_color or symbol == top_symbol
                )
            res[action_idx] = playable
        res[self.draw_action] = 1
        return res

    @classmethod
    def _split_name(cls, name: str) -> Tuple[Any, str]:
        """Split a card/action name into (color or None, symbol): 'r-wild-draw4' -> ('r', 'wild-draw4')."""
        head, sep, tail = name.partition('-')
        if sep and head in cls._COLORS:
            return head, tail
        return None, name

    def print_observation(self, obs_sample: tuple):
        """Print a human-readable rendering of a joint observation."""
        for agent_idx, agent_obs in enumerate(obs_sample):
            print(f"Agent {agent_idx}")
            actions_history, available_actions, cards_on_hand, messages_history = agent_obs
            readable_history = [f'{self.actions[action]} {positional_encoding}' for action, positional_encoding in actions_history]
            print(f"Actions history: {readable_history}")
            readable_available_actions = [self.actions[action_id] for action_id, action in enumerate(available_actions) if action == 1]
            print(f"Available actions: {readable_available_actions}")
            readable_cards_on_hand = [self.cards[card_id] for card_id, card in enumerate(cards_on_hand) if card == 1]
            print(f"Cards on hand: {readable_cards_on_hand}")
            print("Messages history:")
            for record_id, messages_record in enumerate(messages_history):
                print(f"Record {record_id}")
                for message, positional_encoding in messages_record:
                    print(f"  {message} {positional_encoding}")

            print()

    def print_actions(self, actions_sample: tuple):
        """Print a human-readable rendering of a joint action."""
        for agent_idx, (action, messages) in enumerate(actions_sample):
            print(f"Agent {agent_idx}")
            print(f"Action: {self.actions[action]}")
            print("Messages:")
            for message in messages:
                print(f"  {message}")
            print()

    @staticmethod
    def get_deck_names() -> List[str]:
        """Names of the 108 physical cards; a card's index in this list is its id."""
        res = []
        for color in ['r', 'g', 'b', 'y']:
            for number in range(10):
                res.append(f'{color}-{number}')
                if number > 0:
                    # every number except 0 appears twice per color
                    res.append(f'{color}-{number}')
            res.extend([f'{color}-skip'] * 2)
            res.extend([f'{color}-reverse'] * 2)
            res.extend([f'{color}-draw2'] * 2)

        res.extend(['wild'] * 4)
        res.extend(['wild-draw4'] * 4)

        return res

    @staticmethod
    def get_actions_names() -> List[str]:
        """Names of the 62 game actions; wild actions carry the chosen color as a prefix."""
        res = []
        for color in ['r', 'g', 'b', 'y']:
            for number in range(10):
                res.append(f'{color}-{number}')
            res.append(f'{color}-skip')
            res.append(f'{color}-reverse')
            res.append(f'{color}-draw2')
            res.append(f'{color}-wild')
            res.append(f'{color}-wild-draw4')
        res.append('draw')
        res.append('pass')
        return res

In [116]:
test_env = UNOCommunicationEnv(n_agents=4, message_dim=5)

In [117]:
route_map = test_env._get_route_map()
print(route_map)

[array([1, 2, 3]), array([0, 2, 3]), array([0, 1, 3]), array([0, 1, 2])]


In [118]:
act_sample, obs_sample = (
    test_env.action_space.sample(),
    test_env.observation_space.sample(),
)

In [119]:
test_env.print_actions(actions_sample=act_sample)

Agent 0
Action: y-wild
Messages:
  [-0.57202464 -0.83732456 -0.6949524   0.45171896 -0.7756912 ]
  [ 0.74331266  0.60827845 -0.11344    -0.64093965 -0.26389503]
  [-0.9234362   0.5203247   0.87453264  0.81395936  0.30688593]

Agent 1
Action: r-wild
Messages:
  [ 0.27936065 -0.324237    0.09290658 -0.6854672  -0.34026396]
  [ 0.8034983  -0.8687007  -0.9539338  -0.7147616   0.31307963]
  [-0.80043995 -0.6029465   0.74307257 -0.84994894  0.27505732]

Agent 2
Action: r-5
Messages:
  [-0.81831783 -0.79052     0.2338349  -0.24349132 -0.881049  ]
  [-0.9475309  0.6310857 -0.7854138  0.5884438  0.5646693]
  [-0.9893165   0.80585235 -0.56080025 -0.18466997  0.6398543 ]

Agent 3
Action: g-3
Messages:
  [ 0.30545318 -0.7121463   0.6975308  -0.63745224 -0.65897626]
  [ 0.5329095   0.75880843  0.4431483   0.04047064 -0.50628066]
  [ 0.720603   -0.5019339  -0.71893126 -0.49008444  0.6764827 ]



In [120]:
routed = test_env._route_messages(act_sample)
for receiver_idx in range(len(routed)):
    print(f"Receiver {receiver_idx}")
    for message, sender_idx in routed[receiver_idx]:
        print(f"  Sender {sender_idx}: {message}")

Receiver 0
  Sender 1: [ 0.27936065 -0.324237    0.09290658 -0.6854672  -0.34026396]
  Sender 2: [-0.81831783 -0.79052     0.2338349  -0.24349132 -0.881049  ]
  Sender 3: [ 0.30545318 -0.7121463   0.6975308  -0.63745224 -0.65897626]
Receiver 1
  Sender 0: [-0.57202464 -0.83732456 -0.6949524   0.45171896 -0.7756912 ]
  Sender 2: [-0.9475309  0.6310857 -0.7854138  0.5884438  0.5646693]
  Sender 3: [ 0.5329095   0.75880843  0.4431483   0.04047064 -0.50628066]
Receiver 2
  Sender 0: [ 0.74331266  0.60827845 -0.11344    -0.64093965 -0.26389503]
  Sender 1: [ 0.8034983  -0.8687007  -0.9539338  -0.7147616   0.31307963]
  Sender 3: [ 0.720603   -0.5019339  -0.71893126 -0.49008444  0.6764827 ]
Receiver 3
  Sender 0: [-0.9234362   0.5203247   0.87453264  0.81395936  0.30688593]
  Sender 1: [-0.80043995 -0.6029465   0.74307257 -0.84994894  0.27505732]
  Sender 2: [-0.9893165   0.80585235 -0.56080025 -0.18466997  0.6398543 ]


In [107]:
test_env.print_observation(obs_sample=obs_sample)

Agent 0
Actions history: ['y-7 0', 'g-4 3']
Available actions: ['r-0', 'r-4', 'r-6', 'r-7', 'r-8', 'r-9', 'r-skip', 'r-reverse', 'r-draw2', 'r-wild-draw4', 'g-1', 'g-3', 'g-4', 'g-9', 'g-wild-draw4', 'b-1', 'b-2', 'b-3', 'b-4', 'b-6', 'b-reverse', 'b-draw2', 'y-0', 'y-4', 'y-9', 'y-wild-draw4', 'draw', 'pass']
Cards on hand: ['r-1', 'r-1', 'r-2', 'r-3', 'r-3', 'r-4', 'r-5', 'r-6', 'r-7', 'r-skip', 'g-0', 'g-2', 'g-3', 'g-3', 'g-4', 'g-7', 'g-9', 'g-reverse', 'g-draw2', 'g-draw2', 'b-1', 'b-3', 'b-3', 'b-5', 'b-7', 'b-7', 'b-8', 'b-skip', 'b-reverse', 'b-reverse', 'b-draw2', 'y-1', 'y-2', 'y-3', 'y-3', 'y-4', 'y-5', 'y-6', 'y-7', 'y-8', 'y-skip', 'y-reverse', 'y-reverse', 'y-draw2', 'wild', 'wild-draw4', 'wild-draw4', 'wild-draw4']
Messages history:
Record 0
  [ 0.62668407  0.42656392 -0.88180476 -0.5814796  -0.98033136] 3
  [-0.7860557  -0.10057011  0.3483353   0.9526899   0.17233883] 0
  [-0.7399801  -0.05013981  0.8436834   0.45365527 -0.83310074] 1

Agent 1
Actions history: ['g-6 3'