In [7]:
from copy import deepcopy

# @dataclass(frozen=True)
# class GameStateTree():
#     def __init__(self, game, node_lookup=None):
#         self._game_state = self.game.get_state()
#         self.init_children(game)
#         self.calc_opt_winner()

#     def __getattr__(self, name):
#         """GameStateTree inherits from self.game_state"""
#         return getattr(self._game_state, name)
    
#     def init_children(self, game, node_lookup=None):
#         if not node_lookup:
#             node_lookup = {}
#         node_lookup[self._game_state] = self
        
#         self.children = []
#         for move in self.legal_moves:
#             game_copy = deepcopy(game)
#             game_copy.move(move)
#             if game_copy.get_state() in node_lookup:
#                 self.children.append(node_lookup[game_copy.get_state()])
#             else:
#                 self.children.append(GameTree(game_copy, node_lookup))
    
#     def get_steps_to_outcomes(self):
#         if len(self.children) == 0:
#             return {self.get_winner()*self.get_next_player(): [0]}
#         steps_to_outcomes = {}
#         for child in self.children:
#             child_steps = child.get_steps_to_outcomes()
#             for k, v in child_steps.items():
#                 player_change = child.get_next_player()*self.get_next_player()
#                 steps_to_outcomes[k*player_change] = steps_to_outcomes.get(k*player_change, []) + [steps+1 for steps in v]
#         return steps_to_outcomes

#     def calc_opt_winner(self):
#         if len(self.children) == 0:
#             self.opt_winner = self.get_winner()
#         else:
#             opt_winners = np.array([child.opt_winner for child in self.children])
#             player_turn = self.get_next_player()
#             if np.any(opt_winners == player_turn):
#                 self.opt_winner = player_turn
#             elif np.all(opt_winners == -1*player_turn):
#                 self.opt_winner = -1*player_turn
#             else:
#                 self.opt_winner = 0

#     def get_opt_moves(self):
#         outcomes = self.get_next_player()*np.array([child.opt_winner for child in self.children])
#         max_outcome = np.max(outcomes)
#         return [self.moves[idx] for idx in np.arange(len(outcomes))[outcomes == max_outcome]]
        
#     def critical_point_type(self):
#         opt_winners = np.array([child.opt_winner for child in self.children])
#         player_turn = self.get_next_player()
#         if len(np.unique(opt_winners)) <= 1:
#             return 0 # outcome is guaranteed
#         elif np.any(opt_winners == player_turn):
#             return 1 # there are paths for current player to win or draw/lose
#         return -1 # there are paths for current player to lose or draw

#     def critical_point_difficulty(self):
#         critical_pt_type = self.critical_point_type()
#         if critical_pt_type == 0:
#             return None
#         steps_to_outcomes = self.get_steps_to_outcomes()
#         if critical_pt_type == 1:
#             return np.min(steps_to_outcomes[1])
#         return np.max(steps_to_outcomes[-1])
    
#     def visualize(self, node_lookup=None, id_str="0", viz_fxn=lambda x: x._game_state.visualize()):
#         if not node_lookup:
#             node_lookup = {}
#         print("NODE ID", id_str)
#         viz_fxn(self)
#         print()
#         node_lookup[self.state] = (id_str, self)
#         for child_idx, child in enumerate(self.children):
#             print(f"CHILD OF {id_str}")
#             if child.state in node_lookup:
#                 node_id, node = node_lookup[child.state]
#                 print("NODE ID", node_id)
#                 viz_fxn(node)
#             else:
#                 child.visualize(node_lookup, f"{id_str}.{child_idx}", viz_fxn=viz_fxn)
#             print()

#     def mapfilter_traverse(self, already_seen=None, filter_fxn=lambda x: True, map_fxn=lambda x: x):
#         if not already_seen:
#             already_seen = set()
#         already_seen.add(id(self))
#         filtered = []
#         if filter_fxn(self):
#             filtered.append(map_fxn(self))
#         for child in self.children:
#             if not id(child) in already_seen:
#                  filtered += child.mapfilter_traverse(already_seen=already_seen, filter_fxn=filter_fxn, map_fxn=map_fxn)
#         return filtered

#     def get_size(self):
#         return len(self.mapfilter_traverse(filter_fxn=lambda x: True, map_fxn=lambda x: 1))

In [8]:
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np

def to_tuple(lst):
    return tuple(to_tuple(i) if isinstance(i, list) else i for i in lst)

class GameBase(ABC):
    @abstractmethod
    def get_next_player(self):
        pass

    @abstractmethod
    def get_legal_moves(self):
        pass

    @abstractmethod
    def get_winner(self):
        """
        returns None if still in play, 0 if draw, 1 if player 1 wins, -1 if player 1 loses
        """
        pass

class GameState(GameBase, ABC):
    def init_properties(self):
        self.next_player = self.get_next_player()
        self.legal_moves = self.get_legal_moves()
        self.winner = self.get_winner()

    @abstractmethod
    def __eq__(self, other):
         pass

    @abstractmethod
    def __hash__(self):
         pass
    
    @abstractmethod
    def __str__(self):
        pass

    @abstractmethod
    def export_for_prompt(self):
        pass
         
class Game(GameBase, ABC):
    @abstractmethod
    def move(self, move):
        pass

    @abstractmethod
    def get_state(self):
        """
        returns hashable version of game state
        """
        pass

    def __eq__(self, other):
        return self.get_state() == other.get_state()

    def __str__(self):
        return self.get_state().__str__()


class TicTacToeBase(GameBase):
    def __init__(self, dim=3, init_board=None):
        self.turn_idx = 0
        self.dim = dim
    
        if init_board is None:
            self.board = np.array([[0]*self.dim]*self.dim)
        else:
            self.board = np.array(init_board)

    def get_next_player(self):
        # 1 always starts
        return 1 if self.turn_idx % 2 == 0 else -1

    def get_legal_moves(self):
        if self.get_winner():
            return []
        else:
            X, Y = np.mgrid[:self.dim, :self.dim]
            unplayed_squares = self.board == 0
            positions = np.array([X[unplayed_squares], Y[unplayed_squares]]).T
            return [tuple(pos) for pos in positions.tolist()]

    def get_winner(self):
        col_sums = np.sum(self.board, axis=0)
        row_sums = np.sum(self.board, axis=1)
        major_sum = np.sum(self.board[np.arange(self.dim), np.arange(self.dim)], keepdims=True)
        minor_sum = np.sum(self.board[np.arange(self.dim), np.arange(self.dim)[::-1]], keepdims=True)
        all_sums = np.concatenate([col_sums, row_sums, major_sum, minor_sum])
        
        if np.any(all_sums == self.dim):
            return 1
        elif np.any(all_sums == -self.dim):
            return -1
        elif np.any(self.board == 0):
            return None
        else:
            return 0

class TicTacToeState(GameState, TicTacToeBase):
    def __init__(self, board, turn_idx, dim):
        ## state
        self.board = board
        self.turn_idx = turn_idx
        self.dim = dim
        
        ## properties
        self.init_properties()

    def __eq__(self, other):
        return to_tuple(self.board.tolist()) == to_tuple(other.board.tolist())

    def __hash__(self):
        return hash(to_tuple(self.board.tolist()))

    def __str__(self):
        row_strs = []
        for row in self.board:
            row_str = "["
            for pos in row:
                pos_str = str(pos)
                if len(pos_str) == 1:
                    pos_str = " "+pos_str
                pos_str = "  " + pos_str
                row_str += pos_str
            row_str += "  ]"
            row_strs.append(row_str)
        return "\n".join(row_strs)

    def export_for_prompt(self):
        return self.board.tolist(), self.next_player

class TicTacToe(Game, TicTacToeBase):
    def move(self, pos):
        if self.get_winner():
            return
        assert self.board[*pos] == 0
        self.board[*pos] = self.get_next_player()
        self.turn_idx += 1
    
    def get_state(self):
        return TicTacToeState(self.board, self.turn_idx, self.dim)

In [9]:
class GameStateTree():
    def __init__(self, game, node_lookup=None, root=False):
        self._game_state = game.get_state()
        self.init_children(game, node_lookup)
        self.size = self.calc_size()
        if root:
            print(self.size)
        self.optimal_outcome, self.move_to_outcome, self.optimal_num_turns = self.calc_optimality()
        self.win_critical, self.lose_critical = self.calc_critical_point_types()
        self.outcomes_to_num_moves = self.calc_num_moves_to_outcomes()
        self.win_difficulty, self.lose_difficulty = self.calc_critical_point_difficulty()

    def __getattr__(self, name):
        """GameStateTree inherits from self.game_state"""
        return getattr(self._game_state, name)

    def init_children(self, game, node_lookup=None):
        if not node_lookup:
            node_lookup = {}
        node_lookup[game.get_state()] = self
        
        self.children = []
        for move in self.legal_moves:
            game_copy = deepcopy(game)
            game_copy.move(move)
            if game_copy.get_state() in node_lookup:
                self.children.append(node_lookup[game_copy.get_state()])
            else:
                self.children.append(GameStateTree(game_copy, node_lookup=node_lookup))

    def calc_optimality(self):
        """Assuming both players play optimally, return the optimal outcome: 1 = current player wins, -1 = current player loses, 0 = draw"""
        if len(self.children) == 0:
            return self.winner*self.next_player, {}, 0
        else:
            child_outcomes = np.array([child.optimal_outcome if child.next_player==self.next_player else -1*child.optimal_outcome for child in self.children])
            move_to_outcome = dict([(move, outcome) for move, outcome in zip(self.legal_moves, child_outcomes)])
            optimal_outcome = np.max(child_outcomes)
            if optimal_outcome >= 0:
                optimal_num_turns = 1+np.min([child.optimal_num_turns for child, child_outcome in zip(self.children, child_outcomes) if child_outcome == optimal_outcome])
            else:
                optimal_num_turns = 1+np.max([child.optimal_num_turns for child, child_outcome in zip(self.children, child_outcomes)  if child_outcome == optimal_outcome])
            return optimal_outcome, move_to_outcome, optimal_num_turns

    def mapfilter_traverse(self, already_seen=None, filter_fxn=lambda x: True, map_fxn=lambda x: x):
        if not already_seen:
            already_seen = set()
        already_seen.add(id(self))
        filtered = []
        if filter_fxn(self):
            filtered.append(map_fxn(self))
        for child in self.children:
            if not id(child) in already_seen:
                 filtered += child.mapfilter_traverse(already_seen=already_seen, filter_fxn=filter_fxn, map_fxn=map_fxn)
        return filtered

    def calc_size(self):
        return len(self.mapfilter_traverse(filter_fxn=lambda x: True, map_fxn=lambda x: 1))

    def calc_critical_point_types(self):
        """
        Returns if this game state is a critical point for winning and if it is a critical point for losing
        A critical point for winning is a point at which there is at least one move that guarantees victory under optimal play and at least one that does not
        A critical point for winning is a point at which there is at least one move that guarantees loss under optimal play and at least one that does not
        """
        child_outcomes = np.array([child.optimal_outcome if child.next_player==self.next_player else -1*child.optimal_outcome for child in self.children])
        if len(np.unique(child_outcomes)) <= 1: # outcome is guaranteed
            return False, False 
        win_critical = np.any(child_outcomes == 1) # there is at least one path that guarantees victory (with optimal play) and one that does not
        lose_critical = np.any(child_outcomes == -1) # there is at least one path that guarantees loss (with optimal play) and one that does not
        return win_critical, lose_critical

    def calc_num_moves_to_outcomes(self):
        """
        r\Returns a dictionary that maps each possible outcome (1 - win, 0 - draw, -1 - loss) 
        to a list of the number of moves in each possible path to achieve that outcome
        """
        if len(self.children) == 0:
            return {self.winner*self.next_player: [0]}
        outcomes_to_num_moves = {}
        for child in self.children:
            child_outcomes_to_moves = child.outcomes_to_num_moves
            for child_outcome, child_moves in child_outcomes_to_moves.items():
                player_change = child.next_player*self.next_player
                outcomes_to_num_moves[child_outcome*player_change] = outcomes_to_num_moves.get(child_outcome*player_change, []) + [steps+1 for steps in child_moves]
        return outcomes_to_num_moves

    def calc_critical_point_difficulty(self):
        """
        If this game state is a critical point return the difficulty of winning (if win_critical) and losing (if lose_critical)
        Difficulty of winning is defined as the minimum number of turns needed to guarantee a win (under optimal play)
        Difficulty of avoiding the loss is defined as the maximum number of turns needed for the opponent to win
        among any of the moves that lead to a guaranteed loss (under optimal play)
        """
        win_difficulty, lose_difficulty = None, None
        if self.win_critical:
            win_difficulty = 1+np.min([child.optimal_num_turns for child in self.children if child.optimal_outcome*child.next_player==self.next_player])
        if self.lose_critical:
            lose_difficulty = 1+np.max([child.optimal_num_turns for child in self.children if -1*child.optimal_outcome*child.next_player==self.next_player])
        return win_difficulty, lose_difficulty
    
    def visualize(self, node_lookup=None, id_str="0", viz_fxn=lambda x: x._game_state.visualize()):
        if not node_lookup:
            node_lookup = {}
        print("NODE ID", id_str)
        viz_fxn(self)
        print()
        node_lookup[self] = (id_str, self)
        for child_idx, child in enumerate(self.children):
            print(f"CHILD OF {id_str}")
            if child in node_lookup:
                node_id, node = node_lookup[child]
                print("NODE ID", node_id)
                viz_fxn(node)
            else:
                child.visualize(node_lookup, f"{id_str}.{child_idx}", viz_fxn=viz_fxn)
            print()

In [15]:
import json

def json_save(fpath, data):
    with open(fpath,'w') as f:
        json.dump(data, f, cls=NpEncoder)

class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)
        
def get_game_states_for_prompts(game, outpath=None):
    tree = GameStateTree(game)
    def tree_to_prompt_state(tree):
        return {
            "board": tree.board.tolist(),
            "next_player": tree.next_player,
            "moves": [
                {
                    'move': move,
                    'outcome': tree.move_to_outcome[move],
                    'optimal_num_turns_after_move': child.optimal_num_turns,
                    'tree_size_after_move': child.size,
                }
                for move, child in zip(tree.legal_moves, tree.children)
            ],
            "win_critical": tree.win_critical,
            "lose_critical": tree.lose_critical,
            "tree_size": tree.size
        }
    states = tree.mapfilter_traverse(filter_fxn=lambda x: x.win_critical != x.lose_critical, map_fxn=tree_to_prompt_state)
    json_save(outpath, states)
    return states

In [16]:
game = TicTacToe(dim=3, )#init_board=[[0,1,0],[-1,0,0],[0,0,0]]) #init_board=[[1,0],[0,0]] init_board=[[1,-1,0],[-1,-1,0],[1,1,-1]]
tree = GameStateTree(game, root=True)

5478


In [17]:
savepath = "../data/critical_points/tic-tac-toe.json"
prompt_states = get_game_states_for_prompts(game, savepath)

In [724]:
print(len(prompt_states))
print(prompt_states[345])

1247
{'board': [[1, 0, 0], [0, -1, 0], [0, 1, -1]], 'next_player': 1, 'moves': [{'move': (0, 1), 'outcome': -1, 'optimal_num_turns_after_move': 3, 'tree_size_after_move': 29}, {'move': (0, 2), 'outcome': 0, 'optimal_num_turns_after_move': 4, 'tree_size_after_move': 31}, {'move': (1, 0), 'outcome': 0, 'optimal_num_turns_after_move': 4, 'tree_size_after_move': 30}, {'move': (1, 2), 'outcome': 0, 'optimal_num_turns_after_move': 4, 'tree_size_after_move': 34}, {'move': (2, 0), 'outcome': 0, 'optimal_num_turns_after_move': 4, 'tree_size_after_move': 30}], 'win_critical': False, 'lose_critical': True, 'tree_size': 87}


In [707]:
# print(tree)
# print("\n".join([str(x) for x in tree.board]))
# print(tree.next_player)
# print(tree.optimal_outcome)
# print(tree.move_to_outcome)
# print(tree.optimal_num_turns)
# print(tree.win_critical, tree.lose_critical)
# print(tree.outcomes_to_num_moves)
# print(tree.win_difficulty, tree.lose_difficulty)

In [708]:
def viz_fxn(x):
    print("CURRENT PLAYER", x.get_next_player(), "OPT WINNER", x.opt_winner, "CRITICAL PT TYPE", x.critical_point_type())
    print("CHILD OPT WINNERS", [child.opt_winner for child in x.children])
    print(x.get_winner())
    x.visualize_state()

In [709]:
critical_pts, win_critical, lose_critical, win_difficulty, lose_difficulty = zip(*tree.mapfilter_traverse(filter_fxn=lambda x: x.win_critical or x.lose_critical, map_fxn=lambda x: (x, x.win_critical, x.lose_critical, x.win_difficulty, x.lose_difficulty)))
win_critical, lose_critical, win_difficulty, lose_difficulty = np.array(win_critical), np.array(lose_critical), np.array(win_difficulty), np.array(lose_difficulty) 

In [710]:
print(len(critical_pts))

3191


In [711]:
print(np.unique(win_difficulty[win_critical], return_counts=True))
print(np.unique(lose_difficulty[lose_critical], return_counts=True))

(array([1, 3, 5], dtype=object), array([1904,  356,  120]))
(array([2, 4, 6], dtype=object), array([2316,  346,   93]))


In [712]:
mask = ~(win_critical & lose_critical)
print(np.unique(win_difficulty[win_critical & mask], return_counts=True))
print(np.unique(lose_difficulty[lose_critical & mask], return_counts=True))

(array([1, 3, 5], dtype=object), array([324, 100,  12]))
(array([2, 4, 6], dtype=object), array([604, 162,  45]))


In [713]:
for difficulty in np.unique(win_difficulty[win_critical & mask]):
    difficulty_mask = win_difficulty == difficulty
    critical_point_idxs = np.arange(len(critical_pts))[difficulty_mask & win_critical & mask]
    np.random.seed(12)
    np.random.shuffle(critical_point_idxs)
    print("DIFFICULTY", difficulty, "COUNT", len(critical_point_idxs))
    print()

    for idx in critical_point_idxs[:5]:
        pt = critical_pts[idx]
        print("NEXT PLAYER", pt.next_player)
        print(pt._game_state)
        print()
        # print(pt.optimal_outcome)
        print(pt.move_to_outcome)
        print("Optimal num turns", pt.optimal_num_turns)
        # print(pt.win_critical, pt.lose_critical)
        # print(pt.outcomes_to_num_moves)
        print(pt.win_difficulty, pt.lose_difficulty)
        print()
    

DIFFICULTY 1 COUNT 324

NEXT PLAYER 1
[   1  -1   1  ]
[  -1   1   0  ]
[   0   0  -1  ]

{(1, 2): 0, (2, 0): 1, (2, 1): 0}
Optimal num turns 1
1 None

NEXT PLAYER 1
[   1  -1  -1  ]
[   0   0   1  ]
[   1   0  -1  ]

{(1, 0): 1, (1, 1): 0, (2, 1): 0}
Optimal num turns 1
1 None

NEXT PLAYER 1
[  -1   0   1  ]
[   1   0   0  ]
[  -1  -1   1  ]

{(0, 1): 0, (1, 1): 0, (1, 2): 1}
Optimal num turns 1
1 None

NEXT PLAYER 1
[   1  -1   1  ]
[   1   0  -1  ]
[   0   0  -1  ]

{(1, 1): 0, (2, 0): 1, (2, 1): 0}
Optimal num turns 1
1 None

NEXT PLAYER -1
[   0  -1   1  ]
[   1  -1   1  ]
[   0   1  -1  ]

{(0, 0): 1, (2, 0): 0}
Optimal num turns 1
1 None

DIFFICULTY 3 COUNT 100

NEXT PLAYER -1
[   1   1  -1  ]
[   0   0   1  ]
[   0   0  -1  ]

{(1, 0): 0, (1, 1): 0, (2, 0): 1, (2, 1): 0}
Optimal num turns 3
3 None

NEXT PLAYER 1
[   0   1  -1  ]
[   0   0  -1  ]
[   1  -1   1  ]

{(0, 0): 1, (1, 0): 0, (1, 1): 0}
Optimal num turns 3
3 None

NEXT PLAYER 1
[   0   0  -1  ]
[   0   1   0  ]
[   1 

In [715]:
for difficulty in np.unique(lose_difficulty[lose_critical & mask]):
    difficulty_mask = lose_difficulty == difficulty
    critical_point_idxs = np.arange(len(critical_pts))[difficulty_mask & lose_critical & mask]
    np.random.seed(12)
    np.random.shuffle(critical_point_idxs)
    print("DIFFICULTY", difficulty, "COUNT", len(critical_point_idxs))
    print()

    for idx in critical_point_idxs[:5]:
        pt = critical_pts[idx]
        print("NEXT PLAYER", pt.next_player)
        print(pt._game_state)
        print()
        # print(pt.optimal_outcome)
        print(pt.move_to_outcome)
        print("Optimal num turns", pt.optimal_num_turns)
        # print(pt.win_critical, pt.lose_critical)
        # print(pt.outcomes_to_num_moves)
        print(pt.win_difficulty, pt.lose_difficulty)
        print()
    

DIFFICULTY 2 COUNT 604

NEXT PLAYER 1
[  -1   1  -1  ]
[   0  -1   1  ]
[   0   0   1  ]

{(1, 0): -1, (2, 0): 0, (2, 1): -1}
Optimal num turns 3
None 2

NEXT PLAYER -1
[   0  -1   1  ]
[   0   0   0  ]
[  -1   1   1  ]

{(0, 0): -1, (1, 0): -1, (1, 1): -1, (1, 2): 0}
Optimal num turns 4
None 2

NEXT PLAYER -1
[  -1   0   1  ]
[   1   0   0  ]
[   0  -1   1  ]

{(0, 1): -1, (1, 1): -1, (1, 2): 0, (2, 0): -1}
Optimal num turns 4
None 2

NEXT PLAYER -1
[   1  -1   1  ]
[   0  -1   1  ]
[  -1   1   0  ]

{(1, 0): -1, (2, 2): 0}
Optimal num turns 2
None 2

NEXT PLAYER -1
[   0   1   1  ]
[   1  -1  -1  ]
[  -1   0   1  ]

{(0, 0): 0, (2, 1): -1}
Optimal num turns 2
None 2

DIFFICULTY 4 COUNT 162

NEXT PLAYER -1
[   1  -1   0  ]
[   1   0   0  ]
[  -1   1   0  ]

{(0, 2): -1, (1, 1): 0, (1, 2): 0, (2, 2): 0}
Optimal num turns 4
None 4

NEXT PLAYER 1
[   0   0   0  ]
[  -1   1   1  ]
[   0   0  -1  ]

{(0, 0): 0, (0, 1): 0, (0, 2): -1, (2, 0): 0, (2, 1): 0}
Optimal num turns 5
None 4

NEXT P

In [607]:
for critical_score

pts = [idx for idx in np.arange(len(critical_pts))[(win_difficulty == 3) & ~lose_critical]]
for idx in pts:
    print(win_critical[idx], win_difficulty[idx])
    print(lose_critical[idx], lose_difficulty[idx])
    print()
    print(critical_pt.next_player)
    print("\n".join([str(x) for x in critical_pt.board]))
    print()
    print(critical_pt.move_to_outcome)
    for child in critical_pt.children:
        print(type(child))
        print(child.optimal_num_turns)
    print()
    print()
    break

In [608]:
print()




In [609]:
print(np.unique(lose_difficulty[lose_critical], return_counts=True))

(array([1], dtype=object), array([2755]))


In [606]:
# print(pos_critical_pts)
np.random.seed(12)
for idx in np.random.choice(np.arange(len(critical_pts)), (10,)):
    critical_pt = critical_pts[idx]
    print(win_critical[idx], win_difficulty[idx])
    print(lose_critical[idx], lose_difficulty[idx])
    print()
    print(critical_pt.next_player)
    print("\n".join([str(x) for x in critical_pt.board]))
    print()
    print(critical_pt.move_to_outcome)
    print()
    print()
    break

True 1
True 1

1
[-1  0  1]
[ 1 -1  0]
[ 0 -1  1]

{(0, 1): 0, (1, 2): 1, (2, 0): -1}




In [361]:
count = tree.count_unique_nodes()
print(count)

5478


In [363]:
# tree.visualize(viz_fxn=viz_fxn)

In [85]:
game = TicTacToe(one_starts=False, dim=2, init_state=[[1,0],[-1,0]])

legal_moves = game.get_legal_moves()
while not len(legal_moves) == 0:
    move_idx = np.random.choice(np.arange(len(legal_moves)))
    print("ACTION", legal_moves[move_idx])
    game.move(legal_moves[move_idx])
    print("STATE")
    game.visualize_state()
    print("WINNER?", game.get_winner())
    print("NEXT MOVES", game.get_legal_moves())
    legal_moves = game.get_legal_moves()

ACTION [1, 1]
STATE
[1 0]
[-1 -1]
WINNER? -1
NEXT MOVES []
