In [3]:
import unittest
import numpy as np
from YambEnv import YambEnv, ROW, COL, Action

class TestYambEnv(unittest.TestCase):
    """Unit tests for ``YambEnv``: column helpers, scoring, announcing and ``step``."""

    # Shape of the score sheet used by the hand-written fixtures below
    # (14 rows by 4 columns, matching the literal grid in ``test_get_score``).
    GRID_SHAPE = (14, 4)

    # Values written into rows 0-5 of column 3 by two of the fixtures.
    TOP_SECTION_VALUES = (2, 4, 3, 12, 15, 24)

    @classmethod
    def _grid_with(cls, cells=None):
        """Return a float grid of ``GRID_SHAPE`` that is NaN everywhere except ``cells``.

        ``cells`` maps ``(row_index, col_index)`` to the value to store there.
        """
        grid = np.full(cls.GRID_SHAPE, np.nan)
        for (row_idx, col_idx), value in (cells or {}).items():
            grid[row_idx, col_idx] = value
        return grid

    @classmethod
    def _grid_with_top_of_last_column(cls):
        """Return a grid whose only filled squares are rows 0-5 of column 3."""
        return cls._grid_with(
            {(row_idx, 3): value for row_idx, value in enumerate(cls.TOP_SECTION_VALUES)}
        )

    def _assert_step_rejected(self, reward, terminated, truncated):
        """Assert the (reward, terminated, truncated) triple this suite expects for a rejected action."""
        self.assertEqual(reward, -1000)
        self.assertFalse(terminated)
        self.assertTrue(truncated)

    def test_get_next_dolje(self):
        """``get_next_dolje`` follows the DOLJE column from the first row to the last."""
        env = YambEnv()

        # start of game nothing is filled out
        self.assertEqual(ROW.ONES, env.get_next_dolje())

        # when we add stuff to other columns nothing should change
        for other_col in (COL.GORE, COL.SLOBODNO):
            env.grid[ROW.ONES.value, other_col.value] = 1
        self.assertEqual(ROW.ONES, env.get_next_dolje())

        # when we fill out the rows in order, check the function works as expected
        rows = list(ROW)
        for row in rows[:-1]:
            env.grid[row.value, COL.DOLJE.value] = 0
            # relies on rows[i].value == i, i.e. the enum values double as grid indices
            self.assertEqual(rows[row.value + 1], env.get_next_dolje())

        # once we've filled everything out check that this returns nan
        env.grid[rows[-1].value, COL.DOLJE.value] = 0
        self.assertTrue(np.isnan(env.get_next_dolje()))

    def test_get_next_gore(self):
        """``get_next_gore`` follows the GORE column from the last row to the first."""
        env = YambEnv()

        # start of game nothing is filled out
        self.assertEqual(ROW.YAMB, env.get_next_gore())

        # when we add stuff to other columns nothing should change
        for other_col in (COL.DOLJE, COL.SLOBODNO):
            env.grid[ROW.YAMB.value, other_col.value] = 1
        self.assertEqual(ROW.YAMB, env.get_next_gore())

        # when we fill out the rows in reverse order, check the function works as expected
        rows = list(ROW)
        for row in reversed(rows[1:]):
            env.grid[row.value, COL.GORE.value] = 0
            # relies on rows[i].value == i, i.e. the enum values double as grid indices
            self.assertEqual(rows[row.value - 1], env.get_next_gore())

        # once we've filled everything out check that this returns nan
        env.grid[rows[0].value, COL.GORE.value] = 0
        self.assertTrue(np.isnan(env.get_next_gore()))

    def test_get_score(self):
        """``get_score`` on a fresh env and on a series of hand-built grids."""
        env = YambEnv()
        self.assertEqual(0, env.get_score())

        # a completely empty (all-NaN) sheet scores nothing
        env.grid = self._grid_with()
        self.assertEqual(0, env.get_score())

        # a lone value in row 7 is not counted on its own
        # (rows 6 and 7 are presumably MAX and MIN -- confirm against ROW)
        sparse_cells = {(0, 0): 1, (0, 2): 2, (7, 2): 10, (13, 1): 55}
        env.grid = self._grid_with(sparse_cells)
        self.assertEqual(1 + 55 + 2, env.get_score())

        # with rows 6 and 7 both present, their difference times row 0 is added
        env.grid = self._grid_with({**sparse_cells, (6, 2): 20})
        self.assertEqual(1 + 55 + 2 + 2 * (20 - 10), env.get_score())

        # sheet full of ones (integer dtype, like the literal it replaces)
        env.grid = np.ones(self.GRID_SHAPE, dtype=int)
        self.assertEqual(6 * 4 + 6 * 4, env.get_score())

        # rows 0-5 of column 3 sum to 60; the expected 90 includes a further 30
        env.grid = self._grid_with_top_of_last_column()
        self.assertEqual(90, env.get_score())

        # fully populated sheet; expected value is listed as one term per column
        env.grid = np.array(
            [[1, 0, 6, 0],
             [4, 2, 12, 0],
             [3, 3, 15, 0],
             [12, 16, 20, 0],
             [15, 20, 25, 0],
             [24, 6, 30, 0],
             [20, 30, 10, 0],
             [10, 5, 5, 0],
             [16, 0, 20, 0],
             [33, 0, 0, 0],
             [45, 0, 0, 0],
             [55, 0, 0, 0],
             [54, 0, 0, 0],
             [65, 0, 0, 0]]
        )
        self.assertEqual(337 + 47 + 188 + 0, env.get_score())

    def test_valid_announce_row(self):
        """``valid_announce_row`` when rows 0-5 of column 3 are already filled."""
        env = YambEnv()
        env.grid = self._grid_with_top_of_last_column()

        # 14 is one past the last row index of the 14-row grid
        self.assertFalse(env.valid_announce_row(14))
        # presumably rejected because the SIXES square in column 3 is already filled -- confirm
        self.assertFalse(env.valid_announce_row(ROW.SIXES))
        self.assertTrue(env.valid_announce_row(ROW.MAX))

    def test_get_grid_square_value(self):
        """``get_grid_square_value`` for every row type, as a table of cases.

        Each case is ``(row, dice, expected)``. ``dice`` has six entries and
        appears to hold the number of dice showing each face (index 0 = ones);
        confirm against ``YambEnv``. Each case runs in its own ``subTest`` so a
        single failure does not hide the remaining cases.
        """
        cases = [
            (ROW.ONES, [1, 1, 1, 1, 1, 0], 1),
            (ROW.TWOS, [0, 2, 1, 1, 1, 0], 2),
            (ROW.THREES, [5, 0, 0, 0, 0, 0], 0),
            (ROW.FOURS, [4, 0, 0, 1, 0, 0], 4),
            (ROW.FIVES, [0, 0, 0, 0, 5, 0], 25),
            (ROW.SIXES, [0, 1, 0, 0, 0, 3], 18),

            (ROW.MAX, [1, 1, 1, 1, 1, 0], 15),
            (ROW.MIN, [4, 1, 0, 0, 0, 0], 6),

            (ROW.DVAPARA, [4, 1, 0, 0, 0, 0], 0),
            (ROW.DVAPARA, [2, 3, 0, 0, 0, 0], 10 + 6),
            (ROW.DVAPARA, [1, 0, 0, 0, 2, 2], 10 + 22),

            (ROW.TRIS, [4, 1, 0, 0, 0, 0], 20 + 3),
            (ROW.TRIS, [2, 3, 0, 0, 0, 0], 20 + 6),
            (ROW.TRIS, [1, 0, 0, 0, 2, 2], 0),

            (ROW.SKALA, [1, 1, 1, 1, 1, 0], 45),
            (ROW.SKALA, [1, 1, 1, 1, 0, 1], 0),
            (ROW.SKALA, [0, 1, 1, 1, 1, 1], 50),
            (ROW.SKALA, [0, 1, 2, 0, 1, 1], 0),

            (ROW.FULL, [0, 0, 0, 0, 0, 5], 0),
            (ROW.FULL, [0, 0, 0, 0, 2, 3], 40 + 28),
            (ROW.FULL, [0, 0, 0, 1, 2, 2], 0),
            (ROW.FULL, [0, 0, 0, 0, 3, 2], 40 + 27),

            (ROW.POKER, [0, 4, 1, 0, 0, 0], 50 + 8),
            (ROW.POKER, [0, 5, 0, 0, 0, 0], 50 + 8),
            (ROW.POKER, [0, 0, 0, 0, 2, 3], 0),

            (ROW.YAMB, [0, 0, 0, 0, 0, 5], 60 + 30),
            (ROW.YAMB, [0, 5, 0, 0, 0, 0], 60 + 10),
            (ROW.YAMB, [0, 0, 0, 0, 1, 4], 0),
        ]
        for row, dice, expected in cases:
            with self.subTest(row=row, dice=dice):
                actual = YambEnv.get_grid_square_value(row, np.array(dice))
                self.assertEqual(actual, expected)

    def test_step(self):
        """Drive ``step`` through an announced turn and the start of the next one.

        NOTE(review): the "roll should differ" checks and the ``score == 0``
        check after filling YAMB depend on the dice the env rolls. No seeding
        is done here, so they can fail by chance unless the env itself is
        deterministic; consider seeding the env's random source.
        """
        env = YambEnv()
        observation = env.reset()

        # step should fail because trying to keep more dice than we have
        # (np.array copies, so editing `keep` cannot write through to an array the env holds)
        keep = np.array(observation["roll"])
        keep[0] += 1
        action = Action(roll_number=1, keep=keep)
        observation, reward, terminated, truncated, _ = env.step(action)
        self._assert_step_rejected(reward, terminated, truncated)

        # step should pass
        keep[0] -= 1
        action = Action(roll_number=1, keep=keep, announce=True, announce_row=ROW.YAMB)
        observation, reward, terminated, truncated, _ = env.step(action)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 2)
        np.testing.assert_array_equal(observation["roll"], keep)  # should be the same because we kept everything
        self.assertEqual(observation["score"], 0)
        self.assertTrue(np.isnan(observation["grid_square_value"]))
        self.assertEqual(reward, 0)
        self.assertFalse(terminated)
        self.assertFalse(truncated)
        # snapshot, so a later in-place re-roll cannot make this compare equal to itself
        last_roll = np.array(observation["roll"])

        # step should fail because roll number of action does not match roll number of game state
        action = Action(roll_number=1, keep=keep)
        with self.assertRaises(ValueError):
            env.step(action)

        # step should pass
        action = Action(roll_number=2, keep=np.zeros(6))
        observation, reward, terminated, truncated, _ = env.step(action)
        self.assertEqual(observation["turn_number"], 1)
        self.assertEqual(observation["roll_number"], 3)
        with np.testing.assert_raises(AssertionError):
            # roll should be different because we didn't keep anything (probabilistic)
            np.testing.assert_array_equal(observation["roll"], last_roll)
        self.assertEqual(observation["score"], 0)
        self.assertEqual(reward, 0)
        self.assertFalse(terminated)
        self.assertFalse(truncated)
        last_roll = np.array(observation["roll"])

        # step should fail because we're trying to fill out column which isn't najava
        action = Action(roll_number=3, keep=np.zeros(6), row_to_fill=ROW.YAMB, col_to_fill=COL.GORE)
        observation, reward, terminated, truncated, _ = env.step(action)
        self._assert_step_rejected(reward, terminated, truncated)

        # step should fail because we're trying to fill out row which isn't the one announced
        action = Action(roll_number=3, keep=np.zeros(6), row_to_fill=ROW.ONES, col_to_fill=COL.NAJAVA)
        observation, reward, terminated, truncated, _ = env.step(action)
        self._assert_step_rejected(reward, terminated, truncated)

        # step should pass - note that keep is completely ignored
        action = Action(roll_number=3, keep=np.full(77, 7), row_to_fill=ROW.YAMB, col_to_fill=COL.NAJAVA)
        observation, reward, terminated, truncated, _ = env.step(action)
        self.assertEqual(observation["turn_number"], 2)
        self.assertEqual(observation["roll_number"], 1)
        with np.testing.assert_raises(AssertionError):
            # since we're moving onto next turn roll should differ (probabilistic)
            np.testing.assert_array_equal(observation["roll"], last_roll)
        self.assertEqual(observation["score"], 0)  # score should be zero because we probs didn't get a yamb
        self.assertFalse(terminated)
        self.assertFalse(truncated)
        self.assertEqual(observation["grid"][ROW.YAMB.value, COL.NAJAVA.value], observation["score"])
        last_roll = np.array(observation["roll"])

        # try to announce row which has already been announced - should fail
        action = Action(roll_number=1, keep=observation["roll"], announce=True, announce_row=ROW.YAMB)
        observation, reward, terminated, truncated, _ = env.step(action)
        np.testing.assert_array_equal(observation["roll"], last_roll)  # should be the same because we failed action
        self._assert_step_rejected(reward, terminated, truncated)

        # step should pass
        action = Action(roll_number=1, keep=np.zeros(6))
        observation, reward, terminated, truncated, _ = env.step(action)
        self.assertEqual(observation["turn_number"], 2)
        self.assertEqual(observation["roll_number"], 2)
        

In [4]:
# Run the notebook's test cases inside the running kernel.
# argv=[''] keeps unittest from parsing the kernel's own command-line arguments;
# exit=False stops unittest from calling sys.exit() once the run finishes.
# The trailing semicolon suppresses the TestProgram repr in the cell output.
unittest.main(argv=[''], verbosity=2, exit=False);

test_get_grid_square_value (__main__.TestYambEnv.test_get_grid_square_value) ... ok
test_get_next_dolje (__main__.TestYambEnv.test_get_next_dolje) ... ok
test_get_next_gore (__main__.TestYambEnv.test_get_next_gore) ... ok
test_get_score (__main__.TestYambEnv.test_get_score) ... ok
test_step (__main__.TestYambEnv.test_step) ... ok
test_valid_announce_row (__main__.TestYambEnv.test_valid_announce_row) ... ok

----------------------------------------------------------------------
Ran 6 tests in 0.025s

OK


<unittest.main.TestProgram at 0x15cd8fe0b90>