In [112]:
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces


In [115]:
from typing import Any, Dict, List, Optional, Tuple


class UNOCommunicationEnv(gym.Env):
    """
    Alternative version of the UNO communication environment.

    This is a multi-agent environment where agents must communicate to play a game.

    The environment's action space is, per agent:
        Tuple(Discrete(n_actions), Box(-1, 1, shape=(2,)))
    i.e. a game-action index followed by a 2-component positional encoding.

    The environment's observation space is, per agent:
        Tuple(
            Sequence(Tuple(Discrete(n_actions), Box(-1, 1, shape=(2,)))),  # actions history
            MultiBinary(n_actions),                                        # available actions
            MultiBinary(n_cards),                                          # cards on hand
            Sequence(Box(-1, 1, shape=(n_agents - 1, message_dim + 2))),   # messages + positional encodings
        )

    ``n_actions`` and ``n_cards`` are the lengths of ``get_actions_names()`` (62)
    and ``get_deck_names()`` (108).

    The environment's state consists of:
    - Game deck

    ``step`` is not implemented in this class yet.
    """

    def __init__(self, n_agents: int, message_dim: int):
        """
        Build the action/observation spaces and initialise the game state.

        Args:
            n_agents: Number of agents; one action/observation sub-space is created per agent.
            message_dim: Size of a single message vector (before the 2 positional components).
        """
        super().__init__()

        # Number of agents and dimensions for messages and actions
        self.n_agents = n_agents
        self.message_dim = message_dim

        # Game constants: index -> readable name tables. Their lengths size the spaces
        # below so the spaces can never drift out of sync with the name tables.
        self.cards = self.get_deck_names()
        self.actions = self.get_actions_names()
        n_cards = len(self.cards)
        n_actions = len(self.actions)

        # One (game action, positional encoding) pair per agent.
        # NOTE(review): no message component is part of the action here, although
        # messages appear in the observation space — confirm this is intended.
        self.action_space = spaces.Tuple([
            self._encoded_action_space(n_actions)
            for _ in range(n_agents)
        ])

        # Observation space consists of last played actions, available actions,
        # cards on hands, and messages for each agent
        self.observation_space = spaces.Tuple([
            spaces.Tuple([
                spaces.Sequence(self._encoded_action_space(n_actions)),  # Actions history
                spaces.MultiBinary(n_actions),  # Available actions
                spaces.MultiBinary(n_cards),  # Cards on hands
                spaces.Sequence(
                    # Messages with positional encodings
                    spaces.Box(low=-1, high=1, shape=(n_agents - 1, message_dim + 2), dtype=np.float32)
                ),
            ])
            for _ in range(n_agents)
        ])

        # Precompute positional encodings
        self.sin_matrix, self.cos_matrix = self._precompute_positional_encodings()

        # Game state (declared here for discoverability; reset() (re)initialises all of it)
        self.deck = []
        self.discard_pile = []
        self.players_hands = []
        self.top_card = None
        self.message_buffers = []
        self.game_history = []
        self.current_agent = 0
        self.play_direction = 1

        self.reset()

    @staticmethod
    def _encoded_action_space(n_actions: int) -> spaces.Tuple:
        """Return a fresh (action index, 2-component positional encoding) space."""
        return spaces.Tuple([
            spaces.Discrete(n_actions),  # Action encoding
            spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),  # Positional encoding
        ])

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> None:
        """
        Start a new episode: optionally reseed, reshuffle the deck and clear per-episode state.

        Args:
            seed: If given, reseeds ``self.np_random`` so the shuffle is reproducible.
            options: Accepted for compatibility with the Gymnasium ``reset`` signature; unused.

        Note:
            Gymnasium's ``reset`` contract is to return ``(observation, info)``. This
            class does not build observations yet, so nothing is returned.
            TODO: return the initial observation once observation construction exists.
        """
        # Lets gymnasium (re)create the environment-owned random generator when a seed is passed.
        super().reset(seed=seed)

        self.deck = self._prepare_deck(np.arange(len(self.cards)))

        # Drop anything left over from a previous episode.
        self.discard_pile = []
        self.players_hands = []
        self.top_card = None
        self.message_buffers = []
        self.game_history = []
        self.current_agent = 0
        self.play_direction = 1

    def _prepare_deck(self, deck: np.ndarray) -> np.ndarray:
        """
        Return a shuffled copy of ``deck``.

        Uses the environment's own generator (``self.np_random``) rather than the
        global ``np.random`` state, so ``reset(seed=...)`` makes the shuffle reproducible.
        """
        return self.np_random.permutation(deck)

    def _precompute_positional_encodings(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Precomputes sine and cosine positional encodings for all agent pairs.

        Entry ``[i, j]`` encodes the angle ``2 * pi * (j - i) / n_agents``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Sine and cosine encoding matrices,
            each of shape ``(n_agents, n_agents)``.
        """
        agent_ids = np.arange(self.n_agents)
        # Broadcasting builds delta[i, j] = (j - i) / n_agents without a Python double loop.
        delta = (agent_ids[np.newaxis, :] - agent_ids[:, np.newaxis]) / self.n_agents
        angle = 2 * np.pi * delta
        return np.sin(angle), np.cos(angle)

    def print_observation(self, obs: tuple):
        """
        Print a human-readable dump of a joint observation.

        Args:
            obs: One entry per agent, each unpacking into
                ``(actions_history, available_actions, cards_on_hand, messages)``
                as laid out by ``observation_space``.
        """
        for agent_idx, agent_obs in enumerate(obs):
            print(f"Agent {agent_idx}")
            actions_history, available_actions, cards_on_hand, messages = agent_obs

            readable_history = [
                f'{self.actions[action]} {positional_encoding}'
                for action, positional_encoding in actions_history
            ]
            print(f"Actions history: {readable_history}")

            # Binary masks -> names of the positions that are set.
            readable_available_actions = [
                self.actions[action_id]
                for action_id, flag in enumerate(available_actions)
                if flag == 1
            ]
            print(f"Available actions: {readable_available_actions}")

            readable_cards_on_hand = [
                self.cards[card_id]
                for card_id, flag in enumerate(cards_on_hand)
                if flag == 1
            ]
            print(f"Cards on hand: {readable_cards_on_hand}")

            print("Messages:")
            for message in messages:
                print(f"  {message}")
            print()

    @staticmethod
    def get_deck_names() -> List[str]:
        """
        Return the name of every physical card; list index == card id.

        Per colour: one ``0``, two of each ``1``-``9``, two each of skip/reverse/draw2;
        plus four ``wild`` and four ``wild-draw4``.
        """
        # 108 cards in total
        names = []
        for color in ['r', 'g', 'b', 'y']:
            for number in range(10):
                copies = 1 if number == 0 else 2
                names.extend([f'{color}-{number}'] * copies)
            for special in ('skip', 'reverse', 'draw2'):
                names.extend([f'{color}-{special}'] * 2)

        names.extend(['wild'] * 4)
        names.extend(['wild-draw4'] * 4)

        return names

    @staticmethod
    def get_actions_names() -> List[str]:
        """
        Return the name of every game action; list index == action id.

        Per colour: ``0``-``9``, skip, reverse, draw2, wild, wild-draw4
        (the colour prefix on the wild actions is presumably the colour being
        declared — confirm), followed by the colourless ``draw`` and ``pass``.
        """
        # 62 actions in total
        names = []
        for color in ['r', 'g', 'b', 'y']:
            names.extend(f'{color}-{number}' for number in range(10))
            names.extend(
                f'{color}-{kind}'
                for kind in ('skip', 'reverse', 'draw2', 'wild', 'wild-draw4')
            )
        names.extend(['draw', 'pass'])
        return names
        

In [116]:
# Build a 4-agent environment with 18-dimensional messages and draw one random joint action.
test_env = UNOCommunicationEnv(n_agents=4, message_dim=18)
sampled_joint_action = test_env.action_space.sample()
print(sampled_joint_action)

((26, array([0.28266993, 0.53734666], dtype=float32)), (59, array([-0.55601156,  0.38795793], dtype=float32)), (24, array([-0.50963885, -0.91017157], dtype=float32)), (60, array([0.21615107, 0.1241584 ], dtype=float32)))


In [117]:
# Sample one random joint observation and dump each agent's raw tuple.
obs = test_env.observation_space.sample()
for agent_idx, agent_obs in enumerate(obs):
    print(f"Agent {agent_idx}")
    print(agent_obs)
    # A single "\n" plus print's own newline yields the same two blank lines as before.
    print("\n")

Agent 0
(((58, array([-0.53107333,  0.45572653], dtype=float32)), (20, array([0.10348   , 0.39306742], dtype=float32))), array([0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
       0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1], dtype=int8), array([1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0,
       0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0,
       0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
      dtype=int8), (array([[-0.9689777 , -0.6355152 , -0.9103987 ,  0.08724684,  0.75862974,
         0.91940486, -0.8890176 ,  0.78201973,  0.3410355 ,  0.6226186 ,
        -0.32704422, -0.36437106, -0.15431525, -0.3527881 ,  0.24976806,
         0.1264584 , -0.42189786, -0.27713928,  0.88984984,

In [118]:
# Split agent 0's observation into its four components, then display the actions history.
first_agent_obs = obs[0]
actions_history, available_actions, cards_on_hand, messages = first_agent_obs
actions_history

((58, array([-0.53107333,  0.45572653], dtype=float32)),
 (20, array([0.10348   , 0.39306742], dtype=float32)))

In [111]:
# Render the sampled joint observation with action/card names instead of raw indices.
# NOTE(review): this cell's execution count (111) is lower than that of the cell that
# samples `obs` (117), so the stored output likely comes from an earlier `obs`;
# re-run the notebook top to bottom to refresh it.
test_env.print_observation(obs)

Agent 0
Actions history: ['r-6 [0.34373745 0.72838646]', 'g-8 [0.79613036 0.4092018 ]', 'y-wild [ 0.2359576 -0.9290162]', 'g-7 [ 0.74114054 -0.63838166]', 'g-5 [0.41278863 0.7221082 ]', 'b-wild-draw4 [-0.10241742  0.57645077]', 'b-skip [-0.87075406  0.01867758]', 'g-2 [ 0.5081785 -0.6232958]', 'g-9 [ 0.3342837  -0.13040084]', 'r-skip [0.23299184 0.8285388 ]', 'b-3 [ 0.48386756 -0.45380464]']
Available actions: ['r-0', 'r-1', 'r-2', 'r-3', 'r-5', 'r-6', 'r-7', 'r-8', 'r-skip', 'r-reverse', 'g-0', 'g-1', 'g-4', 'g-5', 'g-6', 'g-7', 'g-9', 'g-draw2', 'b-0', 'b-2', 'b-3', 'b-4', 'b-5', 'b-skip', 'b-wild-draw4', 'y-1', 'y-3', 'y-5', 'y-7', 'y-9', 'y-reverse', 'y-draw2', 'y-wild', 'pass']
Cards on hand: ['r-0', 'r-1', 'r-2', 'r-3', 'r-3', 'r-4', 'r-5', 'r-5', 'r-6', 'r-8', 'r-8', 'r-9', 'r-9', 'r-skip', 'r-reverse', 'r-draw2', 'g-0', 'g-1', 'g-1', 'g-2', 'g-3', 'g-3', 'g-4', 'g-4', 'g-5', 'g-5', 'g-6', 'g-7', 'g-8', 'g-9', 'g-skip', 'g-draw2', 'b-0', 'b-3', 'b-4', 'b-4', 'b-5', 'b-6', 'b-6',