In [2]:
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from typing import List, Tuple, Dict, Any, Optional

In [136]:
class UNOGameEnv(gym.Env):
    def __init__(self, n_agents: int):
        # Environment configuration parameters
        self.n_agents = n_agents
        self.n_cards_total = 108
        self.n_actions = 62
        self.color_map = {'r': 0, 'g': 1, 'b': 2, 'y': 3, 'unknown': 4}
        self.card_types_map = {
            **{f'{i}': i for i in range(10)},
            "skip": 10,
            "reverse": 11,
            "draw2": 12,
            "wild-default": 13,
            "wild-draw4": 14
        }
        self.action_types_map = {
            **self.card_types_map,
            "draw": 15,
            "pass": 16
        }
        self.card_specs = self._get_card_specs()
        self.actions_specs = self._get_actions_specs(self.card_specs)

        # Rewards
        self.invalid_action_penalty = -1
        self.winning_reward = 1

        # Action space
        self.action_space = spaces.Tuple([
            spaces.Discrete(self.n_actions) for _ in range(self.n_agents)
        ])


        # Observation space
        self.public_observation_space = spaces.Tuple((
            spaces.Discrete(self.n_agents), # number of players
            spaces.Discrete(4),  # game color
            spaces.Discrete(15),  # played card type
            spaces.Discrete(2),   # game direction (0 or 1)
            spaces.Discrete(self.n_agents),  # current player
            spaces.Discrete(self.n_cards_total + 1),  # number of cards in deck
            spaces.MultiDiscrete([self.n_cards_total + 1] * self.n_agents)  # cards in hand for each player
        ))
        self.private_observation_space = spaces.Tuple((
            spaces.MultiBinary(self.n_cards_total),    # cards in hand
            spaces.Discrete(self.n_agents)            # agent's index
        ))
        self.observation_space = spaces.Tuple([
            spaces.Tuple((
                self.public_observation_space,
                self.private_observation_space
            )) for _ in range(self.n_agents)
        ])

        self.reset()

    
    def reset(self) -> List[Tuple]:
        """
        Reset the environment.
        
        :returns: List of observations
        """
        self.game_state = self.default_game_state()
        return self._get_observations()

    
    def default_game_state(self) -> Dict:
        """
        Create a default game state.
        
        :returns: Default game state
        """
        deck = np.random.permutation(self.n_cards_total).tolist()

        hands = [np.zeros(self.n_cards_total) for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            hands[i][deck[:7]] = 1
            deck = deck[7:]
            
        current_agent = np.random.randint(self.n_agents)
        game_direction = 1

        discard_pile = [deck.pop()]
        top_card_info = self.card_specs[discard_pile[-1]]
        game_color = top_card_info[1] if top_card_info[1] != self.color_map['unknown'] else np.random.randint(4)
        played_card_type = top_card_info[2]
        if played_card_type == self.card_types_map["skip"]:
            current_agent = (current_agent + 1) % self.n_agents
        elif played_card_type == self.card_types_map["reverse"]:
            game_direction *= -1
        elif played_card_type == self.card_types_map["draw2"]:
            hands[current_agent][deck[:2]] = 1
            deck = deck[2:]
        elif played_card_type == self.card_types_map["wild-draw4"]:
            hands[current_agent][deck[:4]] = 1
            deck = deck[4:]
            game_color = np.random.randint(4)

        return {
            "deck": deck,
            "discard_pile": discard_pile,
            "current_agent": current_agent,
            "game_color": game_color,
            "game_direction": game_direction,
            "played_card_type": played_card_type,
            "hands": hands
        }


    def _get_observations(self) -> List[Tuple]:
        """
        Get the observations for all agents.
        
        :returns: List of observations
        """
        public_observation = self._get_public_observation()
        return [
            (public_observation, self._get_private_observation(i))
            for i in range(self.n_agents)
        ]
    

    def _get_public_observation(self) -> tuple:
        """
        Get the public observation.
        
        :returns: Public observation
        """
        return (
            self.n_agents, # number of players
            self.game_state["game_color"],  # game color
            self.game_state["played_card_type"],  # played card type
            1 if self.game_state["game_direction"] == 1 else 0,  # game direction
            self.game_state["current_agent"],  # current player
            len(self.game_state["deck"]),  # number of cards in deck
            [sum(hand) for hand in self.game_state["hands"]]  # cards in hand for each player
        )

    
    def _get_private_observation(self, agent_index: int) -> tuple:
        """
        Get the private observation for an agent.
        
        :param agent_index: Index of the agent
        :returns: Private observation
        """
        return (
            self.game_state["hands"][agent_index],  # cards in hand
            agent_index  # agent's index
        )


    def step(self, actions: List[int]) -> Tuple[List[Tuple], List[float], bool, Dict]:
        """
        Executes a step in the environment.

        :param actions: List of action IDs from all agents.
        :return: A tuple containing observations, rewards, done flag, and info dictionary.
        """
        rewards = [0 for _ in range(self.n_agents)]
        done = False
        info = {}

        active_agent = self.game_state["current_agent"]
        active_action = actions[active_agent]

        # Inactive agents must pass
        for i, action in enumerate(actions):
            if i != active_agent:
                if self.actions_specs[action][2] != self.action_types_map["pass"]:
                    rewards[i] += self.invalid_action_penalty
                
        action_name, action_color, action_type, associated_cards_ids = self.actions_specs[active_action]
        if action_type == self.action_types_map["pass"]:
            rewards[active_agent] += self.invalid_action_penalty
        elif action_type == self.action_types_map["draw"]:
            if self.can_draw(1):
                self.game_state["hands"][active_agent][self.draw_cards(1)] = 1
            else:
                rewards[active_agent] += self.invalid_action_penalty
                done = True
        elif action_type == self.action_types_map["wild-draw4"]:
            next_agent = (active_agent + self.game_state["game_direction"]) % self.n_agents
            self.game_state["hands"][next_agent][self.draw_cards(4)] = 1
            self.game_state["game_color"] = action_color
        elif action_type == self.action_types_map["wild-default"]:
            self.game_state["game_color"] = action_color
        else:
            if (self.game_state["game_color"] == action_color or
                self.game_state["played_card_type"] == action_type):
                card_to_play = -1
                for card_id in associated_cards_ids:
                    if self.game_state["hands"][active_agent][card_id] == 1:
                        card_to_play = card_id
                        break
                if card_to_play != -1:
                    self.game_state["discard_pile"].append(card_to_play)
                    self.game_state["hands"][active_agent][card_to_play] = 0
                    self.game_state["played_card_type"] = action_type
                    self.game_state["game_color"] = action_color
                    if action_type == self.card_types_map["skip"]:
                        self.game_state["current_agent"] = (active_agent + self.game_state["game_direction"]) % self.n_agents
                    elif action_type == self.card_types_map["reverse"]:
                        self.game_state["game_direction"] *= -1
                    elif action_type == self.card_types_map["draw2"]:
                        next_agent = (active_agent + self.game_state["game_direction"]) % self.n_agents
                        self.game_state["hands"][next_agent][self.draw_cards(2)] = 1
                else:
                    rewards[active_agent] += self.invalid_action_penalty
            else:
                rewards[active_agent] += self.invalid_action_penalty

        if self.game_state["hands"][active_agent].sum() == 0:
            done = True
            rewards[active_agent] += self.winning_reward

        self.game_state["current_agent"] = (active_agent + self.game_state["game_direction"]) % self.n_agents

        return self._get_observations(), rewards, done, info


    def can_draw(self, n: int) -> bool:
        """
        Check if the deck has enough cards to draw.

        :param n: Number of cards to draw
        :return: True if the deck has enough cards, False otherwise.
        """
        return len(self.game_state["deck"]) + len(self.game_state["discard_pile"]) >= n
    

    def draw_cards(self, n: int) -> List[int]:
        """
        Draw n cards from the deck.

        :param n: Number of cards to draw
        :return: List of card IDs
        """
        if not self.can_draw(n):
            ret = self.game_state["deck"] + self.game_state["discard_pile"]
            self.game_state["deck"] = []
            self.game_state["discard_pile"] = []
            return ret
        if len(self.game_state["deck"]) < n:
            self.game_state["deck"] += np.random.permutation(self.game_state["discard_pile"]).tolist()
            self.game_state["discard_pile"] = []
        ret = self.game_state["deck"][:n]
        self.game_state["deck"] = self.game_state["deck"][n:]
        return ret
    
    
    def _get_card_specs(self) -> List[tuple]:
        """ Get the card specifications. """
        card_specs = []
        for color, color_id in self.color_map.items():
            if color == 'unknown':
                continue
            for number in range(10):
                card_specs.append((
                    f"{color}-{number}",
                    color_id,
                    self.card_types_map[str(number)]
                ))
                if number > 0:
                    card_specs.append((
                        f"{color}-{number}",
                        color_id,
                        self.card_types_map[str(number)]
                    ))
            card_specs.extend([(f"{color}-skip", color_id, self.card_types_map["skip"])] * 2)
            card_specs.extend([(f"{color}-reverse", color_id, self.card_types_map["reverse"])] * 2)
            card_specs.extend([(f"{color}-draw2", color_id, self.card_types_map["draw2"])] * 2)

        card_specs.extend([("wild-default", 4, self.card_types_map["wild-default"])] * 4)
        card_specs.extend([("wild-draw4", 4, self.card_types_map["wild-draw4"])] * 4)
        return card_specs


    def _get_actions_specs(self, cards_specs: List[tuple]) -> List[tuple]:
        """ Get the action specifications. """
        actions = [] # [(action_name, action_color, action_type, associated_cards_ids)]
        for color, color_id in self.color_map.items():
            if color == 'unknown':
                actions.extend([
                    ("draw", color_id, self.action_types_map["draw"], []),
                    ("pass", color_id, self.action_types_map["pass"], [])
                ])
                break
            for number in range(10):
                actions.append((
                    f"{color}-{number}",
                    color_id,
                    self.card_types_map[str(number)],
                    [i for i, card in enumerate(cards_specs) if card[1] == color_id and card[2] == number]
                ))

            actions.extend([
                (f"{color}-skip", color_id, self.action_types_map["skip"], [i for i, card in enumerate(cards_specs) if card[1] == color_id and card[2] == self.card_types_map["skip"]]),
                (f"{color}-reverse", color_id, self.action_types_map["reverse"], [i for i, card in enumerate(cards_specs) if card[1] == color_id and card[2] == self.card_types_map["reverse"]]),
                (f"{color}-draw2", color_id, self.action_types_map["draw2"], [i for i, card in enumerate(cards_specs) if card[1] == color_id and card[2] == self.card_types_map["draw2"]]),
                (f"{color}-wild-default", color_id, self.action_types_map["wild-default"], [i for i, card in enumerate(cards_specs) if card[2] == self.card_types_map["wild-default"]]),
                (f"{color}-wild-draw4", color_id, self.action_types_map["wild-draw4"], [i for i, card in enumerate(cards_specs) if card[2] == self.card_types_map["wild-draw4"]])
            ])

        

        return actions                


In [138]:
env = UNOGameEnv(n_agents=5)
num_episodes = 100

for episode in range(num_episodes):
    observations = env.reset()
    done = False
    while not done:
        actions = []
        for obs in observations:
            # Each agent selects a random valid action
            agent_actions = list(range(env.n_actions))
            action = np.random.choice(agent_actions)
            actions.append(action)
        observations, rewards, done, info = env.step(actions)
    print(f"Episode {episode+1} finished. Rewards: {rewards}")

Episode 1 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 2 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 3 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 4 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 5 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 6 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 7 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 8 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 9 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 10 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 11 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 12 finished. Rewards: [-1, -1, 0, -1, -1]
Episode 13 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 14 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 15 finished. Rewards: [-1, 0, -1, -1, -1]
Episode 16 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 17 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 18 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 19 finished. Rewards: [-1, -1, -1, -1, -1]
Episode 20 finished. Rewards: [-1, -1, -1,

In [185]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Dict, Any
import torch.nn.functional as F


# Helper Functions
def relative_position_encoding(current_player, agents, n_agents):
    distances = agents - current_player
    theta = (distances / n_agents) * 2 * np.pi
    return np.stack((np.sin(theta), np.cos(theta)), axis=-1)  # [n_agents, 2]

def scale_card_count(card_count, max_card_count=108):
    return np.log1p(card_count) / np.log1p(max_card_count)

class FeatureExtractor(nn.Module):
    def __init__(self, n_cards_total, embed_dim):
        super(FeatureExtractor, self).__init__()
        self.embed_dim = embed_dim
        self.n_cards_total = n_cards_total

        # Embedding layers for categorical features
        self.color_embedding = nn.Embedding(4, embed_dim)  # Game color
        self.card_type_embedding = nn.Embedding(15, embed_dim)  # Played card type
        self.direction_embedding = nn.Embedding(2, embed_dim)  # Game direction

        # Linear layers for numerical and binary features
        self.numeric_embedding = nn.Linear(1, embed_dim)  # Deck count
        self.card_hand_embedding = nn.Linear(n_cards_total, embed_dim)  # Private hand cards
        self.card_counts_embedding = nn.Linear(1, embed_dim)  # Cards in hand count
        self.rel_pos_fc = nn.Linear(2, embed_dim)  # Relative position encoding
        self.agent_order_fc = nn.Linear(2, embed_dim)  # Agent order encoding

        self.attention_layer = nn.Linear(embed_dim, 1)  # Compute attention scores

    
    def forward(self, public_obs, private_obs):
        # Unpack public observations
        n_agents = public_obs[0]
        game_color = public_obs[1]
        played_card_type = public_obs[2]
        game_direction = public_obs[3]
        current_player = public_obs[4]
        deck_count = public_obs[5].float()
        cards_in_hand_counts = public_obs[6].float()

        # Unpack private observations
        hand_cards = private_obs[0]
        agent_index = private_obs[1]

        # Process categorical features
        color_emb = self.color_embedding(game_color)  # [embed_dim]
        card_type_emb = self.card_type_embedding(played_card_type)  # [embed_dim]
        direction_emb = self.direction_embedding(game_direction)  # [embed_dim]

        # Process numerical features
        deck_count_scaled = scale_card_count(deck_count)
        deck_emb = self.numeric_embedding(deck_count_scaled.unsqueeze(-1))  # [embed_dim]

        card_counts_scaled = scale_card_count(cards_in_hand_counts)
        card_counts_emb = self.card_counts_embedding(card_counts_scaled.unsqueeze(-1))  # [n_agents, embed]

        # Process private hand cards
        hand_card_emb = self.card_hand_embedding(hand_cards)  # [embed_dim]

        # Relative position encoding
        agent_indices = torch.arange(n_agents)
        rel_pos_enc = relative_position_encoding(agent_index, agent_indices, n_agents)  # [n_agents, 2]
        rel_pos_emb = self.rel_pos_fc(torch.from_numpy(rel_pos_enc))  # [n_agents, embed_dim]
        agent_order_enc = rel_pos_enc[current_player] # [2]
        agent_order_emb = self.agent_order_fc(torch.from_numpy(agent_order_enc))  # [embed_dim]

        # Combine all features
        opponents_emb = card_counts_emb + rel_pos_emb # [n_agents, embed_dim]
        # [n_agents + 6, embed_dim]
        
        all_features = torch.cat([
            color_emb.unsqueeze(0), 
            card_type_emb.unsqueeze(0), 
            direction_emb.unsqueeze(0), 
            deck_emb.unsqueeze(0), 
            hand_card_emb.unsqueeze(0), 
            opponents_emb, 
            agent_order_emb.unsqueeze(0)
            ])

        
        # Apply attention mechanism
        attention_scores = self.attention_layer(all_features).squeeze(-1)  # [n_agents + 6]
        attention_weights = F.softmax(attention_scores, dim=0)  # Normalize scores
        attended_features = torch.sum(attention_weights.unsqueeze(-1) * all_features, dim=0)  # [embed_dim]

        return attended_features


In [None]:
def obs_to_tensor(obs, device):
    public_obs, private_obs = obs
    public_obs = (
        torch.tensor(public_obs[0], dtype=torch.long, device=device),
        torch.tensor(public_obs[1], dtype=torch.long, device=device),
        torch.tensor(public_obs[2], dtype=torch.long, device=device),
        torch.tensor(public_obs[3], dtype=torch.long, device=device),
        torch.tensor(public_obs[4], dtype=torch.long, device=device),
        torch.tensor(public_obs[5], dtype=torch.float, device=device),
        torch.tensor(public_obs[6], dtype=torch.float, device=device)
    )
    private_obs = (
        torch.tensor(private_obs[0], dtype=torch.float, device=device),
        torch.tensor(private_obs[1], dtype=torch.long, device=device)
    )
    return public_obs, private_obs



In [None]:
class RolloutBuffer:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.state_values = []
        self.is_terminals = []


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.feature_extractor = FeatureExtractor(n_cards_total=108, embed_dim=self.state_dim)
        

