In [1]:
import os
import numpy as np
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict, Optional, List
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import spaces, ObservationWrapper, RewardWrapper
from gymnasium.wrappers import NormalizeReward
import unittest
import time
import pygame
from stable_baselines3.common.env_checker import check_env
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete, InvalidActionEnvMultiDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks



### Which spaces to use?

I think you can use any observation space, but the action space cannot be a dict or a tuple. `Box` space might be better as there seems to be more algorithms that support it. Though `MultiDiscrete` seems to be easier to describe.

In [15]:
# keep: np.array # action type 1 and 2, array of length 6 saying which dice we keep
# announce: bool = False # roll_number / action type 1
# announce_row: ROW = ROW.YAMB # roll_number / action type 1
# row_to_fill: ROW = ROW.YAMB # roll_number / action type 3 
# col_to_fill: COL = COL.DOLJE # roll_number / action type 3

# the action space can't be a tuple or dictionary, which is tricky
number_of_ones_to_keep_range = {"low" : 0, "high": 5}
number_of_twos_to_keep_range = {"low" : 0, "high": 5}
number_of_threes_to_keep_range = {"low" : 0, "high": 5}
number_of_fours_to_keep_range = {"low" : 0, "high": 5}
number_of_fives_to_keep_range = {"low" : 0, "high": 5}
number_of_sixes_to_keep_range = {"low" : 0, "high": 5}
announce_range = {"low" : 0, "high": 1}
announce_row_range = {"low" : 0, "high": 13}
row_to_fill_range = {"low": 0, "high": 13}
col_to_fill_range = {"low": 0, "high": 3}

low = np.array(
[
    number_of_ones_to_keep_range["low"],
    number_of_twos_to_keep_range["low"],
    number_of_threes_to_keep_range["low"],
    number_of_fours_to_keep_range["low"],
    number_of_fives_to_keep_range["low"],
    number_of_sixes_to_keep_range["low"],
    announce_range["low"],
    announce_row_range["low"],
    row_to_fill_range["low"],
    col_to_fill_range["low"],
]
)

high = np.array(
[
    number_of_ones_to_keep_range["high"],
    number_of_twos_to_keep_range["high"],
    number_of_threes_to_keep_range["high"],
    number_of_fours_to_keep_range["high"],
    number_of_fives_to_keep_range["high"],
    number_of_sixes_to_keep_range["high"],
    announce_range["high"],
    announce_row_range["high"],
    row_to_fill_range["high"],
    col_to_fill_range["high"],
]
)

action_space = spaces.Box(low=low, high=high, dtype=int)
action_space.sample()

array([ 5,  2,  4,  5,  0,  2,  1, 11, 10,  0])

In [2]:
action_space = spaces.MultiDiscrete(np.array([6, 6, 6, 6, 6, 6, 2, 14, 14, 4]))
action_space.sample()

array([ 3,  4,  1,  4,  0,  0,  1, 11,  9,  2], dtype=int64)

### Maskable action space
Might need this for Maskable PPO from `sb3_contrib`

In [11]:
num1s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
num2s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
num3s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
num4s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
num5s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
num6s = np.array([1, 1, 1, 0, 0, 0], dtype=np.int8)
announce = np.array([1, 0], dtype=np.int8)
announce_row = np.array([1]*14, dtype=np.int8)
row_to_fill = np.array([1] + 13 * [0], dtype=np.int8)
col_to_fill = np.array([1, 0, 0, 0], dtype=np.int8)
mask = (num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_to_fill, col_to_fill)
action_space.sample(mask=mask)

array([0, 1, 0, 2, 1, 2, 0, 1, 0, 0], dtype=int64)

In [41]:
spaces.Dict({
            "turn_number": spaces.Discrete(14*3,start=0),
            "roll_number": spaces.Discrete(3,start=0),
            "grid": spaces.Box(low=-145, high=145, shape=(14, 3), dtype=int),
            "roll": spaces.Box(low=0, high=5, shape=(6,), dtype=int),
            "announced": spaces.Discrete(2,start=0),
            "announced_row": spaces.Discrete(14, start=0),
}).sample()
        

OrderedDict([('announced', 0),
             ('announced_row', 5),
             ('grid',
              array([[  92,  113,  106],
                     [   3,  -34,   57],
                     [ 106,  -35,  -79],
                     [  49,  -12,  -74],
                     [ 107,   -8,   67],
                     [ -70,  133,  127],
                     [  47,  113,  -79],
                     [ -10,  -70,  132],
                     [ -57,  120,  -64],
                     [ 132,  121,  -93],
                     [  57,    5,  128],
                     [  99,  130, -143],
                     [ -42,  136, -136],
                     [-112,   33,  127]])),
             ('roll', array([5, 1, 1, 4, 1, 5])),
             ('roll_number', 1),
             ('turn_number', 2)])

# `YambEnv`

In [2]:
class ROW(Enum):
    """Enum representing each row in yamb
    """
    ONES=0
    TWOS=1
    THREES=2
    FOURS=3
    FIVES=4
    SIXES=5
    MAX=6
    MIN=7
    DVAPARA=8
    TRIS=9
    SKALA=10
    FULL=11
    POKER=12
    YAMB=13
    
class COL(Enum):
    """Enum representing each col in yamb
    """
    DOLJE=0
    GORE=1
    SLOBODNO=2
    NAJAVA=3
    

class YambEnv(gym.Env):
    """
    :param turn_number: This tells us which turn we are on. There are 14 * 4 turns in yamb each consisting of 3 rolls.
    :param roll_number: Each round in yamb consists of three rolls. This tells you which roll we are on.
    :param grid: This is the 14 * 4 grid in yamb which needs to be filled out. -145 indicates not filled.
    :param roll: This tells us the roll we just had in multinomial format. This means roll[2] is the number of 3s.
    :param announced: This tells us whether we have announced in our current turn.
    :param announced_row: This tells us the row we have announced in our current turn.
    """
    RENDER_FPS = 10
    NAN = -145
    ACTION_ANNOUNCE_IDX = 6
    ACTION_ANNOUNCE_ROW_IDX = 7
    ACTION_ROW_COL_FILL_IDX = 8
    SCREEN_WIDTH = 640
    SCREEN_HEIGHT = 480
    
    def __init__(self):
        super().__init__()
        self.turn_number = 0
        self.roll_number = 0
        self.grid = np.full((len(ROW), len(COL)), self.NAN)
        self.roll = np.array([0, 0, 0, 0, 0, 0])
        self.announced = 0
        self.announced_row = 0
        
        self.observation_space = spaces.Dict({
            "turn_number": spaces.Discrete(len(ROW)*len(COL),start=0),
            "roll_number": spaces.Discrete(3,start=0),
            "grid": spaces.Box(low=-145, high=145, shape=(len(ROW), len(COL)), dtype=int),
            "roll": spaces.Box(low=0, high=5, shape=(6,), dtype=int),
            "announced": spaces.Discrete(2,start=0),
            "announced_row": spaces.Discrete(len(ROW), start=0),
        })
        
        # [num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_col_fill]
        self.action_space = spaces.MultiDiscrete(np.array([6, 6, 6, 6, 6, 6, 2, len(ROW), len(ROW) * len(COL)]))
        
        self.truncation_penalty = -1000
        
        # pygame parameters and objects
        self.screen = None
        self.clock = None
    
    def reset(self, seed=None, options=None) -> Tuple[dict, dict]:
        """Reset environment to initial state - remember this also includes rolling the dice
        
        :return: observation of the initial state along with auxiliary information
        """
        np.random.seed(seed)
        self.turn_number = 0
        self.roll_number = 0
        for row in ROW:
            for col in COL:
                self.grid[row.value, col.value] = self.NAN
        self.roll = np.random.multinomial(5, [1/6.]*6)
        self.announced = 0
        self.announced_row = 0
        return self.get_observation(), {}
    
    def step(self, action : NDArray[np.int64]) -> Tuple[dict, float, bool, bool, dict]:
        """ Run one timestep of the environment's dynamics - in total there are 56 turns, and 3 rolls within each turn
        
        :param action: numpy array of length 9
            [num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_col_fill]
        
        :return: observation, reward, terminated, truncated, info
            observation:dict
            reward:float score - prev score or self.truncation_penalty if the game finishes because of a bad action
            terminated:bool true when the game finished properly
            truncated:bool true when the game finished due to unforseen circumstances; action was out of bounds
            info:dict other relevant information for example the score / why the game truncated?
        """
        
        prev_score = self.get_score()
        
        valid, info = True, {}
        if self.roll_number == 0:
            valid = self.step_1_valid(action, info)
        
        if self.roll_number == 1:
            valid = self.step_2_valid(action, info)
        
        if self.roll_number == 2:
            valid = self.step_3_valid(action[YambEnv.ACTION_ROW_COL_FILL_IDX], info)
        
        if not valid:
            return self.get_observation(), self.truncation_penalty, False, True, info
        
        # if the action is valid, we can mutate the state
        if self.roll_number == 0:
            self.roll_number += 1
            keep = action[:self.ACTION_ANNOUNCE_IDX]
            number_of_dice_to_roll = sum(self.roll - keep)
            self.roll = np.random.multinomial(number_of_dice_to_roll, [1/6.]*6) + keep
            self.announced = action[self.ACTION_ANNOUNCE_IDX]
            self.announced_row = action[self.ACTION_ANNOUNCE_ROW_IDX]
        elif self.roll_number == 1:
            self.roll_number += 1
            keep = action[:self.ACTION_ANNOUNCE_IDX]
            number_of_dice_to_roll = sum(self.roll - keep)
            self.roll = np.random.multinomial(number_of_dice_to_roll, [1/6.]*6) + keep
        elif self.roll_number == 2:
            # we are moving on to the next turn
            self.roll_number = 0
            self.turn_number += 1
            self.announced = 0
            self.announced_row = 0
            r, c = YambEnv.convert_row_col_fill(action[YambEnv.ACTION_ROW_COL_FILL_IDX])
            self.grid[r, c] = self.get_grid_square_value(r, self.roll)
            self.roll = np.random.multinomial(5, [1/6.]*6)
        
        info["score"] = self.get_score()
        reward = info["score"] - prev_score
        terminated = True if self.turn_number >= len(ROW)*len(COL) else False
        
        if self.render_mode == "human":
            self.render()
            
        return self.get_observation(), reward, terminated, False, info
    
    def render(self):
        """Displays the state of the game - there is no interaction with the user here
        """
        if self.screen is None:
            pygame.init()
            self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT), pygame.RESIZABLE)
            
        if self.clock is None:
            self.clock = pygame.time.Clock()
            
        cell_size = 32
        dot_radius = 2
        black = (0, 0, 0)
        white = (255, 255, 255)
        self.screen.fill(white)
        
        font = pygame.font.Font(None, 24)
        text = font.render(f"TURN: {self.turn_number}, ROLL: {self.roll_number}", True, black)
        text_rect = text.get_rect()
        text_rect.topleft = (self.SCREEN_WIDTH//2, cell_size)
        self.screen.blit(text, text_rect)
        
        text = font.render(f"SCORE: {self.get_score()}", True, black)
        text_rect = text.get_rect()
        text_rect.topleft = (self.SCREEN_WIDTH//2, 2*cell_size)
        self.screen.blit(text, text_rect)
        
        text = font.render(f"ANNOUNCED: {bool(self.announced)}", True, black)
        text_rect = text.get_rect()
        text_rect.topleft = (self.SCREEN_WIDTH//2, 4*cell_size)
        self.screen.blit(text, text_rect)
        
        if self.announced == 1:
            text = font.render(f"ANNOUNCED ROW: {ROW(self.announced_row)}", True, black)
            text_rect = text.get_rect()
            text_rect.topleft = (self.SCREEN_WIDTH//2, 5*cell_size)
            self.screen.blit(text, text_rect)
        
        
        
        # draw rows
        font = pygame.font.Font(None, 12)
        for row in ROW:
            rect = pygame.Rect(0, (1+row.value)*cell_size, 2*cell_size, cell_size)
            text = font.render(str(row.value) + ". " + row.name, True, black)
            pygame.draw.rect(self.screen, black, rect, 1)
            self.screen.blit(text, rect.move(cell_size//4, cell_size//4))
            
        # draw cols
        for col in COL:
            rect = pygame.Rect((2+2*col.value)*cell_size, 0, 2*cell_size, cell_size)
            text = font.render(str(col.value) + ". " + col.name, True, black)
            pygame.draw.rect(self.screen, black, rect, 1)
            self.screen.blit(text, rect.move(cell_size//4, cell_size//4))
            
        # fill out the grid
        font = pygame.font.Font(None, 20)
        for row in ROW:
            for col in COL:
                rect = pygame.Rect((2+2*col.value)*cell_size, (1+row.value)*cell_size, 2*cell_size, cell_size)
                s = "" if self.grid[row.value, col.value] == self.NAN else str(self.grid[row.value, col.value])
                text = font.render(s, True, black)
                pygame.draw.rect(self.screen, black, rect, 1)
                self.screen.blit(text, rect.move(cell_size//4, cell_size//4))
                
        # draw roll
        def draw_dot(position):
            pygame.draw.circle(self.screen, black, position, dot_radius)

        def draw_one(x, y):
            draw_dot((x + cell_size//2, y + cell_size//2))

        def draw_two(x, y):
            draw_dot((x + cell_size//4, y + cell_size//4))
            draw_dot((x + cell_size*3//4, y + cell_size*3//4))

        def draw_three(x, y):
            draw_dot((x + cell_size//4, y + cell_size//4))
            draw_dot((x + cell_size//2, y + cell_size//2))
            draw_dot((x + cell_size*3//4, y + cell_size*3//4))

        def draw_four(x, y):
            draw_dot((x + cell_size//4, y + cell_size//4))
            draw_dot((x + cell_size*3//4, y + cell_size//4))
            draw_dot((x + cell_size//4, y + cell_size*3//4))
            draw_dot((x + cell_size*3//4, y + cell_size*3//4))

        def draw_five(x, y):
            draw_dot((x + cell_size//4, y + cell_size//4))
            draw_dot((x + cell_size*3//4, y + cell_size//4))
            draw_dot((x + cell_size//2, y + cell_size//2))
            draw_dot((x + cell_size//4, y + cell_size*3//4))
            draw_dot((x + cell_size*3//4, y + cell_size*3//4))

        def draw_six(x, y):
            draw_dot((x + cell_size//4, y + cell_size//4))
            draw_dot((x + cell_size*3//4, y + cell_size//4))
            draw_dot((x + cell_size//4, y + cell_size//2))
            draw_dot((x + cell_size*3//4, y + cell_size//2))
            draw_dot((x + cell_size//4, y + cell_size*3//4))
            draw_dot((x + cell_size*3//4, y + cell_size*3//4))
        
        top_left_x, top_left_y = YambEnv.SCREEN_WIDTH//2, YambEnv.SCREEN_HEIGHT//2
        for i, count in enumerate(self.roll):
            for die in range(count):
                rect = pygame.Rect(top_left_x, top_left_y, cell_size, cell_size)
                pygame.draw.rect(self.screen, (0, 0, 0), rect, 1)
                
                if (i+1)==1: draw_one(top_left_x, top_left_y)
                if (i+1)==2: draw_two(top_left_x, top_left_y)
                if (i+1)==3: draw_three(top_left_x, top_left_y)
                if (i+1)==4: draw_four(top_left_x, top_left_y)
                if (i+1)==5: draw_five(top_left_x, top_left_y)
                if (i+1)==6: draw_six(top_left_x, top_left_y)
                    
                top_left_x += cell_size
                
            
        self.clock.tick(YambEnv.RENDER_FPS)
        pygame.display.flip()
        
    
    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
        
    def get_observation(self) -> dict:
        observation = {
            "turn_number": self.turn_number,
            "roll_number": self.roll_number,
            "grid": self.grid,
            "roll": self.roll,
            "announced": self.announced,
            "announced_row": self.announced_row,
        }
        return observation
    
    def action_masks(self) -> List[bool]:
        """Returns a one hot encoded list whether an action is valid
        
        :return: list of size sum([6, 6, 6, 6, 6, 6, 2, 14, 56])
        """
        num1s = [False] * 6
        num2s = [False] * 6
        num3s = [False] * 6
        num4s = [False] * 6
        num5s = [False] * 6
        num6s = [False] * 6
        announce = [False] * 2
        announce_row = [False] * len(ROW)
        row_col_fill = [False] * len(ROW) * len(COL)
        
        # can only keep dice that you have
        if (self.roll_number == 0) or (self.roll_number == 1):
            num1s = [True] * (self.roll[0] + 1) + [False] * (5 - self.roll[0])
            num2s = [True] * (self.roll[1] + 1) + [False] * (5 - self.roll[1])
            num3s = [True] * (self.roll[2] + 1) + [False] * (5 - self.roll[2])
            num4s = [True] * (self.roll[3] + 1) + [False] * (5 - self.roll[3])
            num5s = [True] * (self.roll[4] + 1) + [False] * (5 - self.roll[4])
            num6s = [True] * (self.roll[5] + 1) + [False] * (5 - self.roll[5])
        
        if self.roll_number == 0:
            announce_row = [self.valid_announce_row(i) for i in range(len(ROW))]
            announce = [not self.need_to_announce(), any(announce_row)]
            
        if self.roll_number == 2:
            for i in range(len(COL)*len(ROW)):
                row_col_fill[i] = self.step_3_valid(i, {})
                
        mask = num1s + num2s + num3s + num4s + num5s + num6s + announce + announce_row + row_col_fill
        return mask
        
    def step_1_valid(self, action: NDArray[np.int64], info: dict) -> bool:
        """Checks whether an action of type 1 is valid
        
        :param action: numpy array of length 9
            [num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_col_fill]
        
        :return: whether the action was valid or not, truncation reason will be added to info dict
        
        unused: row_col_fill
        """
        keep = action[:self.ACTION_ANNOUNCE_IDX]
        if any( (self.roll - keep) < 0 ):
            info["truncation_reason"] = f"Can't keep {keep} when you only have {self.roll}"
            return False
        
        if action[self.ACTION_ANNOUNCE_IDX] == 1:
            if not self.valid_announce_row(action[self.ACTION_ANNOUNCE_ROW_IDX]):
                info["truncation_reason"] = f"Announce row {ROW(action[self.ACTION_ANNOUNCE_ROW_IDX])} not valid"
                return False
            
        if action[self.ACTION_ANNOUNCE_IDX] == 0:
            if self.need_to_announce():
                info["truncation_reason"] = "Only najava column left so must use it"
                return False
            
        return True
    
    def step_2_valid(self, action: NDArray[np.int64], info: dict) -> bool:
        """Checks whether an action of type 2 is valid
        
        :param action: numpy array of length 9
            [num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_col_fill]
        
        :return: whether the action was valid or not, truncation reason will be added to info dict
        
        unused: announce, announce_row, row_col_fill
        """
        keep = action[:self.ACTION_ANNOUNCE_IDX]
        if any( (self.roll - keep) < 0 ):
            info["truncation_reason"] = f"Can't keep {keep} when you only have {self.roll}"
            return False
        
        return True
    
    def step_3_valid(self, row_col_fill: int, info: dict) -> bool:
        """Checks whether an action of type 3 is valid
        
        :param row_col_fill: int indicating which row and col of the grid to fill out
        
        :return: whether the action was valid or not, truncation reason will be added to info dict
        """
        r, c = YambEnv.convert_row_col_fill(row_col_fill)
        if self.grid[r, c] != self.NAN:
            info["truncation_reason"] = f"{r}, {c} already filled in "
            return False
        
        if (c == COL.GORE.value) and (r != self.get_next_gore()):
            info["truncation_reason"] = f"Gore needed {ROW(self.get_next_gore())} but trying {ROW(r)}"
            return False
        
        if (c == COL.DOLJE.value) and (r != self.get_next_dolje()):
            info["truncation_reason"] = f"Dolje needed {ROW(self.get_next_dolje())} but trying {ROW(r)}"
            return False
        
        if self.announced and ((c != COL.NAJAVA.value) or (r != self.announced_row)):
            info["truncation_reason"] = f"Announced {ROW(self.announced_row)} but trying to fill {ROW(r)}, {COL(c)}"
            return False
        
        if (self.announced==0) and (c == COL.NAJAVA.value):
            info["truncation_reason"] = f"Have not announced so cannot fill out najava column"
            return False
        
        return True
    
    def valid_announce_row(self, row: int) -> bool:
        """
        :param row: a row which you which to announce
        
        :return: bool indicated whether you can actually announce that row
        """
        if row not in ROW:
            return False
        
        if self.grid[row, COL.NAJAVA.value] == self.NAN:
            return True
        else:
            return False
        
    def need_to_announce(self) -> bool:
        """You must announce on your first roll when the rest of the grid has been filled
        :return: whether you need to announce on your first roll
        """
        for col in [COL.DOLJE, COL.GORE, COL.SLOBODNO]:
            for row in ROW:
                if self.grid[row.value, col.value] == self.NAN:
                    # there's something you can fill out
                    return False
                
        return True
    
    def get_score(self) -> int:
        """
        :return: game score thus far, anything with an nan will be assigned zero
        """
        result = 0
        for col in COL:
            A = 0
            for row in [ROW.ONES, ROW.TWOS, ROW.THREES, ROW.FOURS, ROW.FIVES, ROW.SIXES]:
                A += 0 if self.grid[row.value, col.value] == self.NAN else self.grid[row.value, col.value]
                
            if (A >= 60):
                A += 30
            
            if (self.grid[ROW.MAX.value, col.value] == self.NAN) or \
            (self.grid[ROW.MIN.value, col.value] == self.NAN) or \
            (self.grid[ROW.ONES.value, col.value] == self.NAN):
                B = 0
            else:
                B = self.grid[ROW.MAX.value, col.value] - self.grid[ROW.MIN.value, col.value]
                B *= self.grid[ROW.ONES.value, col.value] 
            
            C = 0
            for row in [ROW.DVAPARA, ROW.TRIS, ROW.SKALA, ROW.FULL, ROW.POKER, ROW.YAMB]:
                C += 0 if self.grid[row.value, col.value] == self.NAN else self.grid[row.value, col.value]
                
            result = result + A + B + C
                
        return result
    
    def get_next_dolje(self) -> Optional[int]:
        """
        Gets the next row we need to fill out in the dolje column, if we've completed it return nan
        """
        for row in ROW:
            if self.grid[row.value, COL.DOLJE.value] == self.NAN: return row.value
        
        return np.nan
    
    def get_next_gore(self) -> Optional[int]:
        """
        Gets the next row we need to fill out in the gore column, if we've completed it return nan
        """
        for row in reversed(ROW):
            if self.grid[row.value, COL.GORE.value] == self.NAN: return row.value
        
        return np.nan
    
    @staticmethod
    def get_grid_square_value(row: int, cnts: np.array) -> int:
        """
        :param row: which row do you want the grid square value for
        :param cnts: array of size six which tells you tells you mapping of face value to how many dice
        :return: grid square value
        """
        
        if row == ROW.ONES.value:
            return 1 * cnts[0]
        elif row == ROW.TWOS.value:
            return 2 * cnts[1]
        elif row == ROW.THREES.value:
            return 3 * cnts[2]
        elif row == ROW.FOURS.value:
            return 4 * cnts[3]
        elif row == ROW.FIVES.value:
            return 5 * cnts[4]
        elif row == ROW.SIXES.value:
            return 6 * cnts[5]
        elif row == ROW.MAX.value:
            return sum( (i+1)*item for i, item in enumerate(cnts) )
        elif row == ROW.MIN.value:
            return sum( (i+1)*item for i, item in enumerate(cnts) )
        elif row == ROW.DVAPARA.value:
            return YambEnv.dvapara(cnts)
        elif row == ROW.TRIS.value:
            return YambEnv.tris(cnts)
        elif row == ROW.SKALA.value:
            return YambEnv.skala(cnts)
        elif row == ROW.FULL.value:
            return YambEnv.full(cnts)
        elif row == ROW.POKER.value:
            return YambEnv.poker(cnts)
        elif row == ROW.YAMB.value:
            return YambEnv.yamb(cnts)
        else:
            raise IndexError(f"Row {row} not found in possible rows")
    
    @staticmethod
    def dvapara(cnts : np.array) -> int:
        if not (sum(cnts >= 2) >= 2): return 0
        s = 0
        for i, cnt in enumerate(cnts):
            if cnt >= 2:
                s += 2 * (i+1)
        return s + 10
    
    @staticmethod
    def tris(cnts : np.array) -> int:
        if not any(cnts >= 3): return 0
        s = 0
        for i, cnt in enumerate(cnts):
            if cnt >= 3:
                s += 3 * (i+1)
                
        return s + 20
    
    @staticmethod
    def skala(cnts : np.array) -> int:
        if all(np.array([1,1,1,1,1,0]) == cnts):
            return 45
        elif all(np.array([0,1,1,1,1,1]) == cnts):
            return 50
        else:
            return 0
    
    @staticmethod
    def full(cnts : np.array) -> int:
        if not (any(cnts == 3) * any(cnts == 2)): return 0
        s = sum( (i+1)*item for i, item in enumerate(cnts) )
        return s + 40
    
    @staticmethod
    def poker(cnts : np.array) -> int:
        if not any(cnts >= 4): return 0
        s = 0
        for i, cnt in enumerate(cnts):
            if cnt >= 4:
                s += 4 * (i+1)
        
        return s + 50
    
    @staticmethod
    def yamb(cnts : np.array) -> int:
        if not any(cnts >= 5): return 0
        s = sum( (i+1)*item for i, item in enumerate(cnts) )
        return s + 60
    
    @staticmethod
    def convert_row_col_fill(row_col_to_fill: int) -> Tuple[int, int]:
        """Converts a single index representing a grid square to fill in into two indices
        representing a row and column to fill
        
        :param row_col_to_fill: single index representing a grid square we want to fill
        
        :return: row we want to fill in, col we want to fill in
        """
        assert 0 <= row_col_to_fill < len(ROW) * len(COL)
        col_to_fill, row_to_fill = divmod(row_col_to_fill, len(ROW))
        return row_to_fill, col_to_fill
    
    @staticmethod
    def convert_row_fill_col_fill(row_to_fill: int, col_to_fill: int) -> int:
        """Converts two indices representing a row and column to fill into a single
        index representing a grid square to fill
        
        :param row_to_fill: row we want to fill in
        :param col_to_fill: col we want to fill in
        
        :return: single index representing a grid square we want to fill
        """
        assert 0 <= row_to_fill < len(ROW)
        assert 0 <= col_to_fill < len(COL)
        return row_to_fill + len(ROW) * col_to_fill
    

In [3]:
class TestYambEnv(unittest.TestCase):
    def test_convert_row_col_fill(self):
        row, col = YambEnv.convert_row_col_fill(0)
        self.assertEqual(0, row)
        self.assertEqual(0, col)
        
        row, col = YambEnv.convert_row_col_fill(14)
        self.assertEqual(0, row)
        self.assertEqual(1, col)
        
        row, col = YambEnv.convert_row_col_fill(55)
        self.assertEqual(13, row)
        self.assertEqual(3, col)
        
    def test_convert_row_fill_col_fill(self):
        idx = YambEnv.convert_row_fill_col_fill(1, 0)
        self.assertEqual(1, idx)
        
        idx = YambEnv.convert_row_fill_col_fill(0, 1)
        self.assertEqual(14, idx)
        
        idx = YambEnv.convert_row_fill_col_fill(13, 3)
        self.assertEqual(55, idx)
    
    def test_get_next_dolje(self):
        env = YambEnv()
        
        # start of game nothing is filled out
        self.assertEqual(ROW.ONES, ROW(env.get_next_dolje()))
        
        
        # when we add stuff to other columns nothing should change
        env.grid[ROW.ONES.value, COL.GORE.value] = 1
        env.grid[ROW.ONES.value, COL.SLOBODNO.value] = 1
        self.assertEqual(ROW.ONES, ROW(env.get_next_dolje()))
        
        # when we fill out the rows in order, check the function works as expected
        rows = list(ROW)
        for row in rows[:-1]:
            env.grid[row.value, COL.DOLJE.value] = 0
            self.assertEqual(rows[row.value+1], ROW(env.get_next_dolje()))
        
        # once we've filled everything out check that this returns nan
        env.grid[rows[-1].value, COL.DOLJE.value] = 0
        self.assertTrue(np.isnan(env.get_next_dolje()))
        
    def test_get_next_gore(self):
        env = YambEnv()
        
        # start of game nothing is filled out
        self.assertEqual(ROW.YAMB, ROW(env.get_next_gore()))
        
        
        # when we add stuff to other columns nothing should change
        env.grid[ROW.YAMB.value, COL.DOLJE.value] = 1
        env.grid[ROW.YAMB.value, COL.SLOBODNO.value] = 1
        self.assertEqual(ROW.YAMB, ROW(env.get_next_gore()))
        
        # when we fill out the rows in order, check the function works as expected
        rows = list(ROW)
        for row in reversed(rows[1:]):
            env.grid[row.value, COL.GORE.value] = 0
            self.assertEqual(rows[row.value-1], ROW(env.get_next_gore()))
        
        # once we've filled everything out check that this returns nan
        env.grid[rows[0].value, COL.GORE.value] = 0
        self.assertTrue(np.isnan(env.get_next_gore()))
        
    def test_get_score(self):
        env = YambEnv()
        self.assertEqual(0, env.get_score())
        
        env.grid = np.array(
        [[-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145]]
        )
        self.assertEqual(0, env.get_score())
        
        env.grid = np.array(
        [[1     , -145, 2     , -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, 10    , -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, 55    , -145, -145]]
        )
        self.assertEqual(1+55+2, env.get_score())
        
        env.grid = np.array(
        [[1     , -145, 2     , -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, 20    , -145],
         [-145, -145, 10    , -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, 55    , -145, -145]]
        )
        self.assertEqual(1+55+2+2*(20-10), env.get_score())
        
        env.grid = np.array(
        [[1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],
         [1     , 1     , 1     , 1     ],]
        )
        self.assertEqual(6*4 + 6*4, env.get_score())
        
        env.grid = np.array(
        [[-145, -145, -145, 2     ],
         [-145, -145, -145, 4     ],
         [-145, -145, -145, 3     ],
         [-145, -145, -145, 12    ],
         [-145, -145, -145, 15    ],
         [-145, -145, -145, 24    ],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145]]
        )
        self.assertEqual(90, env.get_score())
        
        env.grid = np.array(
        [[1     , 0     , 6     , 0     ],
         [4     , 2     , 12    , 0     ],
         [3     , 3     , 15    , 0     ],
         [12    , 16    , 20    , 0     ],
         [15    , 20    , 25    , 0     ],
         [24    , 6     , 30    , 0     ],
         [20    , 30    , 10    , 0     ],
         [10    , 5     , 5     , 0     ],
         [16    , 0     , 20    , 0     ],
         [33    , 0     , 0     , 0     ],
         [45    , 0     , 0     , 0     ],
         [55    , 0     , 0     , 0     ],
         [54    , 0     , 0     , 0     ],
         [65    , 0     , 0     , 0     ]]
        )
        self.assertEqual(337+47+188+0, env.get_score())
        
    def test_valid_announce_row(self):
        env = YambEnv()
        env.grid = np.array(
        [[-145, -145, -145, 2     ],
         [-145, -145, -145, 4     ],
         [-145, -145, -145, 3     ],
         [-145, -145, -145, 12    ],
         [-145, -145, -145, 15    ],
         [-145, -145, -145, 24    ],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145],
         [-145, -145, -145, -145]]
        )
        self.assertFalse(env.valid_announce_row(14))
        self.assertFalse(env.valid_announce_row(ROW.SIXES.value))
        self.assertTrue(env.valid_announce_row(ROW.MAX.value))
        self.assertTrue(env.valid_announce_row(10))
        
    def test_need_to_announce(self):
        env = YambEnv()
        env.grid = np.array(
        [[1     , 0     , 6     , 0     ],
         [4     , 2     , 12    , 0     ],
         [3     , 3     , 15    , 0     ],
         [12    , 16    , 20    , 0     ],
         [15    , 20    , 25    , 0     ],
         [24    , 6     , 30    , 0     ],
         [20    , 30    , 10    , 0     ],
         [10    , 5     , 5     , 0     ],
         [16    , 0     , 20    , 0     ],
         [33    , 0     , 0     , 0     ],
         [45    , 0     , 0     , 0     ],
         [55    , 0     , 0     , 0     ],
         [54    , 0     , 0     , 0     ],
         [65    , 0     , 0     , -145     ]]
        )
        self.assertTrue(env.need_to_announce())
        env.grid = np.array(
        [[1     , 0     , 6     , 0     ],
         [4     , 2     , 12    , 0     ],
         [3     , 3     , 15    , 0     ],
         [12    , 16    , 20    , 0     ],
         [15    , 20    , 25    , 0     ],
         [24    , 6     , 30    , 0     ],
         [20    , 30    , 10    , 0     ],
         [10    , 5     , 5     , 0     ],
         [16    , 0     , 20    , 0     ],
         [33    , 0     , 0     , 0     ],
         [45    , 0     , 0     , 0     ],
         [55    , 0     , -145     , 0     ],
         [54    , 0     , 0     , 0     ],
         [65    , 0     , 0     , -145     ]]
        )
        self.assertFalse(env.need_to_announce())
        
    def test_get_grid_square_value(self):
        self.assertEqual(YambEnv.get_grid_square_value(ROW.ONES.value, np.array([1,1,1,1,1,0])), 1)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.TWOS.value, np.array([0,2,1,1,1,0])), 4)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.THREES.value, np.array([5,0,0,0,0,0])), 0)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FOURS.value, np.array([4,0,0,1,0,0])), 4)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FIVES.value, np.array([0,0,0,0,5,0])), 25)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.SIXES.value, np.array([0,1,0,0,0,3])), 18)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.MAX.value, np.array([1,1,1,1,1,0])), 15)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.MIN.value, np.array([4,1,0,0,0,0])), 6)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.DVAPARA.value, np.array([4,1,0,0,0,0])), 0)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.DVAPARA.value, np.array([2,3,0,0,0,0])), 10+6)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.DVAPARA.value, np.array([1,0,0,0,2,2])), 10+22)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.TRIS.value, np.array([4,1,0,0,0,0])), 20+3)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.TRIS.value, np.array([2,3,0,0,0,0])), 20+6)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.TRIS.value, np.array([1,0,0,0,2,2])), 0)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.SKALA.value, np.array([1,1,1,1,1,0])), 45)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.SKALA.value, np.array([1,1,1,1,0,1])), 0)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.SKALA.value, np.array([0,1,1,1,1,1])), 50)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.SKALA.value, np.array([0,1,2,0,1,1])), 0)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FULL.value, np.array([0,0,0,0,0,5])), 0)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FULL.value, np.array([0,0,0,0,2,3])), 40 + 28)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FULL.value, np.array([0,0,0,1,2,2])), 0)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.FULL.value, np.array([0,0,0,0,3,2])), 40 + 27)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.POKER.value, np.array([0,4,1,0,0,0])), 50+8)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.POKER.value, np.array([0,5,0,0,0,0])), 50+8)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.POKER.value, np.array([0,0,0,0,2,3])), 0)
        
        self.assertEqual(YambEnv.get_grid_square_value(ROW.YAMB.value, np.array([0,0,0,0,0,5])), 60+30)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.YAMB.value, np.array([0,5,0,0,0,0])), 60+10)
        self.assertEqual(YambEnv.get_grid_square_value(ROW.YAMB.value, np.array([0,0,0,0,1,4])), 0)
        
        
    def test_step(self):
        env = YambEnv()
        observation, _ = env.reset()
        
        # step should fail because trying to keep more dice than we have
        action = np.array([
            observation["roll"][0]+1,
            observation["roll"][1],
            observation["roll"][2],
            observation["roll"][3],
            observation["roll"][4],
            observation["roll"][5],
            0,
            0,
            0 + 0*14,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(reward, env.truncation_penalty)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, True)
        
        # step should pass
        action = np.array([
            observation["roll"][0],
            observation["roll"][1],
            observation["roll"][2],
            observation["roll"][3],
            observation["roll"][4],
            observation["roll"][5],
            1,
            ROW.YAMB.value,
            0 + 0*14,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(observation["turn_number"], 0)
        self.assertEqual(observation["roll_number"], 1)
        for i in range(6):
            self.assertEqual(observation["roll"][i], action[i]) # should be the same because we kept everything
        
        self.assertEqual(info["score"], 0)
        self.assertEqual(reward, 0)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, False)
        last_roll = np.copy(observation["roll"])
        
        # step should pass
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            -145,
            -145,
            -145,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(observation["turn_number"], 0)
        self.assertEqual(observation["roll_number"], 2)
        with np.testing.assert_raises(AssertionError):
            np.testing.assert_array_equal(observation["roll"], last_roll) # roll should be different because we didn't keep anything
        self.assertEqual(info["score"], 0)
        self.assertEqual(reward, 0)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, False)
        last_roll = np.copy(observation["roll"])
        
        # step should fail because we're trying to fill out column which isn't najava
        action = np.array([
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            ROW.YAMB.value + 14 * COL.GORE.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(reward, env.truncation_penalty)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, True)
        
        # step should fail because we're trying to fill out row which isn't the one announced
        action = np.array([
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            ROW.ONES.value + 14 * COL.NAJAVA.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(reward, env.truncation_penalty)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, True)
        
        # step should pass
        action = np.array([
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            -145,
            ROW.YAMB.value + 14 * COL.NAJAVA.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 0)
        with np.testing.assert_raises(AssertionError):
            np.testing.assert_array_equal(observation["roll"], last_roll) # since we're moving onto next turn roll should differ
        self.assertEqual(info["score"], 0) # score very likely to be zero because we probs didn't get a yamb
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, False)
        self.assertEqual(observation["grid"][ROW.YAMB.value, COL.NAJAVA.value], info["score"])
        last_roll = np.copy(observation["roll"])
        
        # try to announce row which has already been announced - should fail
        action = np.array([
            observation["roll"][0],
            observation["roll"][1],
            observation["roll"][2],
            observation["roll"][3],
            observation["roll"][4],
            observation["roll"][5],
            1,
            ROW.YAMB.value,
            0 + 14 * 0,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        np.testing.assert_array_equal(observation["roll"], last_roll)  # should be the same because we failed action
        self.assertEqual(reward, env.truncation_penalty)
        self.assertEqual(terminated, False)
        self.assertEqual(truncated, True)
        
        # step should pass
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0 + 14 * 0,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 1)
        
        # step should pass
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0 + 14 * 0,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 2)
        
        # step should fail as we are trying to fill out halfway down the dolje col
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            ROW.DVAPARA.value + 14 * COL.DOLJE.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertTrue(truncated)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 2)
        
        # step should fail as we are trying to fill out halfway up the gore col
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            ROW.DVAPARA.value + 14 * COL.GORE.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertTrue(truncated)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 2)
        
        # step should fail as we are trying to fill out najava but we didn't announce
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            ROW.DVAPARA.value + 14 * COL.NAJAVA.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertTrue(truncated)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 2)
        
        # step should pass
        action = np.array([
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            ROW.DVAPARA.value + 14 * COL.SLOBODNO.value,
        ], dtype=np.int64)
        observation, reward, terminated, truncated, info = env.step(action)
        self.assertFalse(truncated)
        self.assertEqual(observation["turn_number"], 2)
        self.assertEqual(observation["roll_number"], 0)
        

In [4]:
unittest.main(argv=[""], verbosity=2, exit=False)

test_convert_row_col_fill (__main__.TestYambEnv.test_convert_row_col_fill) ... ok
test_convert_row_fill_col_fill (__main__.TestYambEnv.test_convert_row_fill_col_fill) ... ok
test_get_grid_square_value (__main__.TestYambEnv.test_get_grid_square_value) ... ok
test_get_next_dolje (__main__.TestYambEnv.test_get_next_dolje) ... ok
test_get_next_gore (__main__.TestYambEnv.test_get_next_gore) ... ok
test_get_score (__main__.TestYambEnv.test_get_score) ... ok
test_need_to_announce (__main__.TestYambEnv.test_need_to_announce) ... ok
test_step (__main__.TestYambEnv.test_step) ... ok
test_valid_announce_row (__main__.TestYambEnv.test_valid_announce_row) ... ok

----------------------------------------------------------------------
Ran 9 tests in 0.039s

OK


<unittest.main.TestProgram at 0x2852f3a58b0>

In [14]:
def process_text(s: str) -> NDArray[np.int64]:
    num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_to_fill, col_to_fill = 0,0,0,0,0,0,0,0,0,0
    
    for c in s:
        if c.isnumeric():
            if int(c)==1: num1s += 1
            if int(c)==2: num2s += 1
            if int(c)==3: num3s += 1
            if int(c)==4: num4s += 1
            if int(c)==5: num5s += 1
            if int(c)==6: num6s += 1
        else:
            break
            
    try:
        index_a = s.index("a")
        if s[index_a+1:index_a+3].isnumeric():
            announce_row = int(s[index_a+1:index_a+3])
        else:
            announce_row = int(s[index_a+1:index_a+2])
        announce = 1
    except ValueError as e:
        announce = 0
        announce_row = 0
        
    try:
        index_r = s.index("r")
        if s[index_r+1:index_r+3].isnumeric():
            row_to_fill = int(s[index_r+1:index_r+3])
        else:
            row_to_fill = int(s[index_r+1:index_r+2])
    except ValueError as e:
        row_to_fill = 0
        
    try:
        index_c = s.index("c")
        if s[index_c+1:index_c+3].isnumeric():
            col_to_fill = int(s[index_c+1:index_c+3])
        else:
            col_to_fill = int(s[index_c+1:index_c+2])
    except ValueError as e:
        col_to_fill = 0
    row_col_fill = YambEnv.convert_row_fill_col_fill(row_to_fill, col_to_fill)
    result = np.array([num1s, num2s, num3s, num4s, num5s, num6s, announce, announce_row, row_col_fill])
    return result

try:
    env = YambEnv()
    env.reset()
    env.render()
    run = True
    input_box = pygame.Rect(env.SCREEN_WIDTH//2, env.SCREEN_HEIGHT//2+100, 250, 40)
    input_box_color = (200, 200, 200)
    text = ""
    text_color = (0, 0, 0)
    font = pygame.freetype.Font(None, 20)
    while run:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_RETURN:
                    observation, reward, terminated, truncated, info = env.step(process_text(text))
                    env.render()
                    print(f"Reward:{reward}")
                    text = ""
                elif event.key == pygame.K_BACKSPACE:
                    text = text[:-1]
                else:
                    text += event.unicode
                    
            pygame.draw.rect(env.screen, input_box_color, input_box, 40)
            font.render_to(env.screen, (input_box.x+5, input_box.y+5), text, text_color)
            env.clock.tick(YambEnv.RENDER_FPS)
            pygame.display.flip()
            
    env.close()
except Exception as e:
    print(f"An error occurred: {e}")
    pygame.quit()

Reward:0
Reward:0
Reward:62
Reward:0
Reward:0
Reward:2
Reward:0
Reward:0
Reward:-1000


In [5]:
class FlattenGrid(ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Dict({
            "turn_number": spaces.Discrete(len(ROW)*len(COL),start=0),
            "roll_number": spaces.Discrete(3,start=0),
            "grid": spaces.Box(low=-1, high=1, shape=(len(ROW)*len(COL),), dtype=float),
            "roll": spaces.Box(low=-1, high=1, shape=(6,), dtype=float),
            "announced": spaces.Discrete(2,start=0),
            "announced_row": spaces.Discrete(len(ROW), start=0),
        })

    def observation(self, obs):
        obs["grid"] = obs["grid"].flatten() / 145.0
        obs["roll"] = (obs["roll"] - 1.0) / 5.0 
        return obs
    
env = YambEnv()
env = FlattenGrid(env)
check_env(env)

# Training

In [7]:
seed = 1

model_name = f"klinac_{seed}"
model_path = f"models/{model_name}.zip"
log_path = f"logs/{model_name}_0"

# clean up log directory before training
if os.path.exists(log_path):
    for filename in os.listdir(log_path):
        file_path = os.path.join(log_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.remove(file_path)
                print(f"Deleted {file_path}")
            elif os.path.isdir(file_path):
                print(f"Skipping directory {file_path}")
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
    os.rmdir(log_path)
    
# load model if it has already been trained
if os.path.exists(model_path):
    print("Loading previous model...")
    model = MaskablePPO.load(model_path)
    model.set_env(env)
else:
    print("Creating new model...")
    model = MaskablePPO("MultiInputPolicy", env, seed=seed, verbose=1, tensorboard_log="logs/") 
    
# train
while True:
    model.learn(total_timesteps=100_000, reset_num_timesteps=False, tb_log_name=model_name)
    model.save(model_path)

Deleted logs/klinac_1_0\events.out.tfevents.1720162020.DT-Aled.13704.0
Deleted logs/klinac_1_0\events.out.tfevents.1720162809.DT-Aled.13704.1
Deleted logs/klinac_1_0\events.out.tfevents.1720163232.DT-Aled.13704.2
Deleted logs/klinac_1_0\events.out.tfevents.1720163657.DT-Aled.13704.3
Deleted logs/klinac_1_0\events.out.tfevents.1720164080.DT-Aled.13704.4
Deleted logs/klinac_1_0\events.out.tfevents.1720164504.DT-Aled.13704.5
Deleted logs/klinac_1_0\events.out.tfevents.1720164930.DT-Aled.13704.6
Deleted logs/klinac_1_0\events.out.tfevents.1720165353.DT-Aled.13704.7
Deleted logs/klinac_1_0\events.out.tfevents.1720165776.DT-Aled.13704.8
Deleted logs/klinac_1_0\events.out.tfevents.1720166198.DT-Aled.13704.9
Deleted logs/klinac_1_0\events.out.tfevents.1720166623.DT-Aled.13704.10
Deleted logs/klinac_1_0\events.out.tfevents.1720167047.DT-Aled.13704.11
Deleted logs/klinac_1_0\events.out.tfevents.1720167471.DT-Aled.13704.12
Deleted logs/klinac_1_0\events.out.tfevents.1720167895.DT-Aled.13704.13
De

  logger.warn(


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 168      |
|    ep_rew_mean     | 748      |
| time/              |          |
|    fps             | 145      |
|    iterations      | 1        |
|    time_elapsed    | 14       |
|    total_timesteps | 54757376 |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 168         |
|    ep_rew_mean          | 754         |
| time/                   |             |
|    fps                  | 111         |
|    iterations           | 2           |
|    time_elapsed         | 36          |
|    total_timesteps      | 54759424    |
| train/                  |             |
|    approx_kl            | 0.009520028 |
|    clip_fraction        | 0.065       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.337      |
|    explained_variance   | 0.958       |
|    learning_rate        | 0.

KeyboardInterrupt: 

In [8]:
import torch
print(torch.cuda.is_available())


False


# Load model and test

In [7]:
# del model
model = MaskablePPO.load("models/klinac_1")

try:
    env = YambEnv()
    env = FlattenGrid(env)
    obs, _ = env.reset()
    terminated, truncated = False, False
    env.render()
    while not (terminated or truncated):
        time.sleep(1)
        action_masks = get_action_masks(env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        
    # This will pause the notebook and wait for the user to press Enter
    input("Press Enter to continue...")
    env.close()
except Exception as e:
    print(f"An error occurred: {e}")
    pygame.quit()

Press Enter to continue...


In [20]:
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
print(f"Mean reward: {mean_reward}, Std reward: {std_reward}")

Mean reward: 274.85, Std reward: 77.888558209791


# Reward shaping

In [6]:
class AddStepsToReward(RewardWrapper):
    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        return 100*(self.unwrapped.turn_number*3 + self.unwrapped.roll_number) + reward / 1000.0
    
env = AddStepsToReward(env)