In [None]:
from typing import Any, SupportsFloat

import gymnasium as gym
from gymnasium.core import RenderFrame, ObsType, ActType

from jass.game.const import *
from jass.game.game_util import *
from jass.game.game_state_util import *
from jass.game.game_sim import GameSim
from jass.game.game_state import GameState
from jass.game.game_observation import GameObservation
from jass.game.rule_schieber import RuleSchieber



- https://gymnasium.farama.org/content/basic_usage/
- https://github.com/zmcx16/OpenAI-Gym-Hearts/blob/master/src/Hearts/Hearts.py
- https://gymnasium.farama.org/environments/toy_text/blackjack/#references
- https://www.youtube.com/watch?v=YLa_KkehvGw
- https://github.com/bernhard-pfann/uno-card-game-rl/tree/main
- https://towardsdatascience.com/tackling-uno-card-game-with-reinforcement-learning-fad2fc19355c

If I understand correctly, the model can be quite simple: it takes features such as the trump, the current player, the cards in the current trick, the cards played so far, and the player's hand, and returns a softmax over the cards, from which the next card to play is chosen.

The big question is how to train it. A gym wrapper around the existing framework might help but maybe isn't even necessary because there is already a lot implemented with GameSim etc.

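As a first step, this can simply be trained with PPO (or a similar algorithm) through self-play, without any MCTS. The biggest issue here will be estimating the value of a time step before the game has ended, since the score is only known at the end of the round. PPO handles this with a learned value function (the critic), which is used to compute advantage estimates (e.g. GAE) for non-terminal steps.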

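For an AlphaZero-like implementation, you play self-play games using MCTS and train the network on them. The policy head learns to match the MCTS visit-count distribution, and the value head learns to predict the final game outcome. The update itself is a supervised-style loss on these targets, but the targets are produced by the agent's own search rather than by labelled data, so overall it is still reinforcement learning. Note that Jass is an imperfect-information game (the other hands are hidden), so plain MCTS would need something like sampling (determinizing) the unknown hands.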

The UNO example is a bare-bones re-implementation of Q-learning using the Bellman update and epsilon-greedy move selection. I think using a PyTorch implementation of such an algorithm would be better, but that requires understanding the algorithm first.

In [None]:
# Reward given to a player whose chosen action is not marked valid in JassEnv.valid_actions;
# in that case the game is not advanced. The value may need to become tunable.
INVALID_CARD_REWARD = -100

In [None]:
# Kept as a standalone function so alternative reward schemes can be swapped in later.
def get_rewards(state: GameState):
    """Return the zero-sum reward of every player, based on the team point totals.

    Each player receives their own team's points minus the opposing team's points,
    so teammates get identical rewards and the two teams' rewards cancel out.

    Args:
        state: game state whose ``points`` is indexed by team (0 or 1).

    Returns:
        dict mapping player index (0-3) to that player's reward.
    """
    # The reference team only fixes the sign of `lead`; the per-player values come out the
    # same whichever team state.player belongs to.
    reference_team = team[state.player]
    lead = state.points[reference_team] - state.points[1 - reference_team]
    rewards = {}
    for p in range(4):
        same_side = team[p] == reference_team
        rewards[p] = lead if same_side else -lead
    return rewards

In [None]:
def obs_to_vector(obs: GameObservation, valid_cards: np.ndarray):
    """Encode an observation as a flat numeric feature vector (work in progress).

    Only a few scalar features are encoded so far; the planned full encoding is
    sketched in the TODO notes below.

    Args:
        obs: the observation to encode.
        valid_cards: not used yet; accepted so callers can already pass the valid-card mask.

    Returns:
        1-D numpy array ``[obs.player, obs.nr_played_cards, obs.nr_cards_in_trick, obs.trump]``.
    """
    # TODO: planned feature layout (not needed yet).
    # - All cards should be one-hot encoded, so the played cards alone take 36x36:
    #   there are 36 card slots to play and each one could be any of the 36 cards.
    # - Which card is currently being played could be a number between 0 and 1, since it is
    #   also a measure of how far along the game is. The same goes for the number of cards in
    #   the trick. (This is the same as min-max scaling.)
    # - Trump should be one-hot encoded too.
    # - The number of cards each player has in their hand (4x1, scaled to 0..1).
    # - Trick winners should be included -> a player may have used up all their good cards.
    # - Who declared trump -> probably has good cards for that suit/trump.
    # - Whether someone pushed (and who) -> probably has general cards, not good for any
    #   specific suit.
    # - These player features can also be one-hot: one length-4 vector each for who
    #   declared, who pushed, who is me, who is my partner, who are my opponents, who won
    #   trick 0, who won trick 1, ...
    # - For the hand, encode all 9 possible hand cards one-hot (9x36) and also as one
    #   multi-label vector (1x36).
    #
    # Idea: run a convolution over all card stacks so that weights are shared when
    # extracting information about a specific card. Concatenate the trump on top of each
    # card (6 + 36 = 42), stack all card slots (at least the 36 played-card slots + 9 hand
    # cards), then convolve with stride 42. Use multiple filters, each reducing a card stack
    # to one number, giving num_filters new features per card. The extracted features can
    # be concatenated back onto the card stacks before flattening.
    return np.array([
        obs.player,
        obs.nr_played_cards,
        obs.nr_cards_in_trick,
        obs.trump,
    ])

In [None]:
class JassEnv(gym.Env):
    """Gymnasium environment for one round of Schieber jass, built on jass-kit's ``GameSim``.

    All four seats are driven through the same environment: each ``step`` applies one
    action for whichever player is currently to move (``self.player``). The same action
    space is used for trump selection and card play, as interpreted by ``GameSim.action``.

    Args:
        render_mode: ``None`` (default) or ``'human'``; ``render`` only works with ``'human'``.
    """

    # Gymnasium reads the supported render modes from the 'render_modes' key; the legacy
    # gym key 'render.modes' is not recognised by gymnasium.
    metadata = {'render_modes': ['human']}

    def __init__(self, render_mode: str | None = None):
        super().__init__()
        if render_mode is not None and render_mode not in self.metadata['render_modes']:
            raise ValueError(f"Unsupported render mode: {render_mode!r}")
        # gymnasium's Env defaults render_mode to None, so without setting it here render()
        # could never pass its 'human' check.
        self.render_mode = render_mode
        self.game: GameSim = GameSim(RuleSchieber())
        # Both are populated by reset(); None until then.
        self.current_observation: GameObservation | None = None
        self.valid_actions: np.ndarray | None = None

    @property
    def state(self):
        """The underlying jass-kit ``GameState`` of the simulator."""
        return self.game.state

    @property
    def player(self):
        """Index of the player whose turn it is, as stored in the game state."""
        return self.state.player

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ):
        """Start a new round with freshly dealt random hands.

        Args:
            seed: forwarded to ``gym.Env.reset`` to (re)seed ``self.np_random``.
            options: optional dict; ``options['dealer']`` selects the dealer seat.
                If it is missing (or ``options`` is None), the dealer is drawn from
                ``self.np_random`` in the range 0-3.

        Returns:
            ``(observation, info)``: the result of ``GameSim.get_observation()`` and
            ``{'invalid': False}``.
        """
        super().reset(seed=seed)
        # options defaults to None, so it must not be indexed unconditionally.
        dealer = (options or {}).get('dealer')
        if dealer is None:
            # Use the env RNG so that reset(seed=...) stays reproducible.
            dealer = int(self.np_random.integers(4))

        self.game.init_from_cards(dealer=dealer, hands=deal_random_hand(self.np_random))

        self.valid_actions = self.game.rule.get_valid_actions_from_state(self.game.state)

        self.current_observation = self.game.get_observation()

        info = {'invalid': False}
        return self.current_observation, info

    def step(
        self, action: ActType
    ):
        """Apply ``action`` for the current player.

        If ``self.valid_actions[action]`` is not 1, the game is not advanced and the acting
        player receives ``INVALID_CARD_REWARD``. Otherwise the action is applied, and once
        the round is over every player receives ``get_rewards(state)``.

        Args:
            action: action index in jass-kit's encoding, as passed to ``GameSim.action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``:
            ``reward`` is the reward of the player who acted, ``info['rewards']`` holds
            the rewards of all four players, ``info['invalid']`` flags a rejected action,
            and ``truncated`` is always False.
        """
        terminated = False
        info = {'invalid': False}
        rewards = {i: 0 for i in range(4)}
        # Remember the actor before the action advances the turn.
        curr_player = self.player
        if self.valid_actions[action] != 1:
            # Invalid action: penalise the actor and leave the game untouched.
            rewards[curr_player] = INVALID_CARD_REWARD
            info['invalid'] = True
        else:
            self.game.action(action)
            self.valid_actions = self.game.rule.get_valid_actions_from_state(self.game.state)
            self.current_observation = self.game.get_observation()
            terminated = self.game.is_done()
            if terminated:
                rewards = get_rewards(self.game.state)

        info['rewards'] = rewards
        truncated = False
        return self.current_observation, rewards[curr_player], terminated, truncated, info

    def render(self) -> None:
        """Print the round from the current player's view, including all other hands (debug view).

        Raises:
            ValueError: if the environment was not created with ``render_mode='human'``.
        """
        if self.render_mode != 'human':
            raise ValueError("Invalid render mode")

        if self.state.trump == -1:
            # forehand == -1 means the forehand player has not decided yet, so pushing is
            # still possible.
            print(f"Trump: to be defined, push {'NOT ' if self.state.forehand != -1 else ''}available")
        else:
            print(f"Trump: {self.state.trump}")

        to_str = convert_one_hot_encoded_cards_to_str_encoded_list
        print(f"Current player: {self.player}")
        print(f"Their hand: {to_str(self.current_observation.hand)}")
        print(f"Tricks done so far: {self.current_observation.nr_tricks}")
        print(f"Cards in trick: {self.current_observation.nr_cards_in_trick}")

        # The opponents are the players seated right after each member of the current team.
        partner = partner_player[self.player]
        o1 = next_player[self.player]
        o2 = next_player[partner]
        print(f"Partners hand: {to_str(self.state.hands[partner])}")
        print(f"Opponent hand P{o1}: {to_str(self.state.hands[o1])}")
        print(f"Opponent hand P{o2}: {to_str(self.state.hands[o2])}")
        

In [None]:
def add_dicts(a, b):
    """Return a new dict holding ``a[k] + b[k]`` for every key ``k`` of ``a``.

    Keys that exist only in ``b`` are ignored (a left join on ``a``'s keys); a key of
    ``a`` that is missing from ``b`` raises ``KeyError``. Neither input is modified.
    """
    summed = {}
    for key in a:
        summed[key] = a[key] + b[key]
    return summed

In [None]:
def add_to_dict(a, b):
    """Add ``b``'s values into ``a`` in place, for each key of ``a``.

    Every key of ``a`` must also be present in ``b``; extra keys in ``b`` are ignored.
    Returns None (the update happens in place).
    """
    for key in tuple(a):
        a[key] += b[key]

In [None]:
def run_randoms(n_episodes: int):
    """Play ``n_episodes`` rounds with four random agents and print each round's scores.

    The dealer rotates with the episode index (``episode % 4``), so an ``n_episodes``
    divisible by 4 gives every seat the same number of deals. Only the first reset is
    seeded (seed 42).

    Trump selection is handled here, before the card-play loop. If the first chooser
    pushes, the player to move in the new observation must then pick an actual trump.
    """
    from jass.agents.agent_random_schieber import AgentRandomSchieber

    env = JassEnv()
    # A single random agent instance is shared by all four seats.
    agent = AgentRandomSchieber()
    players = [agent, agent, agent, agent]

    for episode in range(n_episodes):
        first_seed = 42 if episode == 0 else None
        obs, *_ = env.reset(seed=first_seed, options=dict(dealer=episode % 4))
        scores = dict.fromkeys(range(4), 0)

        # Trump phase. TODO: work out how to fold this into actual training (PPO etc.).
        trump = players[obs.player].action_trump(obs)
        assert 0 <= trump <= MAX_TRUMP or trump in (PUSH, PUSH_ALT), "Invalid trump selected; fix external trump selection!"
        obs, *_ = env.step(trump_to_full(trump))
        if trump in (PUSH, PUSH_ALT):
            trump = players[obs.player].action_trump(obs)
            assert 0 <= trump <= MAX_TRUMP, "Invalid trump selected; fix external trump selection!"
            obs, *_ = env.step(trump_to_full(trump))

        # Card-play phase: rewards are only non-zero at the end (or on invalid moves).
        done = False
        while not done:
            card = players[obs.player].action_play_card(obs)
            obs, _, done, _, info = env.step(card)
            add_to_dict(scores, info['rewards'])

        print(f"Episode {episode} -> Scores {scores}")
    

In [None]:
# Smoke test: play 10 rounds of random agents and print the resulting scores.
run_randoms(n_episodes=10)