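# Minimax Tree Search with Policy/Value Functions

Sanity check for `MinimaxSearch`: search small random game trees and verify that replaying the returned move sequence reaches a leaf whose value equals the search result.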

In [77]:
# Re-import edited modules (e.g. domoku.minimax) before each cell runs,
# so changes to the search code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from typing import List
import numpy as np
from domoku.minimax import MinimaxSearch

### Test fixtures: a trivial policy and random game trees

In [52]:
class Policy:
    """Trivial policy/value stand-in used to exercise the minimax search.

    All methods are static and only read attributes of the state passed in.
    """

    @staticmethod
    def sample(state, _n=0) -> List[int]:
        """Return every move index available from ``state``.

        ``_n`` is accepted but unused: this policy never subsamples moves.
        """
        num_moves = len(state.successors)
        return [move_idx for move_idx in range(num_moves)]

    @staticmethod
    def eval(state):
        """Evaluate ``state`` as its stored ``value``."""
        score = state.value
        return score

    @staticmethod
    def is_terminated(state):
        """Report whether ``state`` is a leaf, via its ``terminal`` flag."""
        finished = state.terminal
        return finished


class State:
    """
    A node in a hand-built game tree used to test the minimax search.

    A terminal node carries a ``value``; a non-terminal node carries a list of
    ``successors``, and a move is simply an index into that list.
    """

    def __init__(self, terminal: bool = False, successors=None, value=0):
        self.terminal = terminal
        # List of child States; left as None for leaves built without children.
        self.successors = successors
        # Node payoff; create_state only sets it explicitly on terminal nodes.
        self.value = value

    def move(self, move: int) -> 'State':
        """
        Return the successor reached by playing ``move``.

        :param move: index into ``self.successors``.
        :return: the child State at that index.
        :raises AssertionError: if the state has no successors, or ``move`` is
            not a valid index (negative indices are rejected, not wrapped).
        """
        num_moves = 0 if self.successors is None else len(self.successors)
        assert num_moves > 0, "Cannot move from a state without successors"
        # Range membership (rather than bare indexing) rejects negative moves
        # that would otherwise silently wrap around to the last successor.
        assert move in range(num_moves), \
            f"Move must be one of 0..{num_moves - 1}, got {move}"
        return self.successors[move]

    def __str__(self):
        """Leaves print their value; inner nodes print their successor list in brackets."""
        if self.terminal:
            return str(self.value)
        return f"[{str(self.successors)}]"

    __repr__ = __str__


def create_state(depth, min_depth=0, num_successors=2, rng=None):
    """
    Build a random game tree of ``State`` nodes for testing.

    :param depth: remaining number of levels; a node with ``depth <= 0`` is
        always a leaf.
    :param min_depth: a node may additionally become an early leaf (chance 1/7)
        only while ``depth < min_depth``, i.e. within the bottom levels of the
        tree. With the default 0, every leaf sits exactly ``depth`` levels below
        the root.
    :param num_successors: number of children of every inner node.
    :param rng: optional ``numpy.random.Generator`` / ``RandomState`` used for
        reproducible trees; defaults to the global ``np.random`` state.
    :return: the root ``State``; each leaf's value is drawn uniformly from 1..6.
    """
    if rng is None:
        # The np.random module exposes the same ``choice`` API, so the default
        # path draws exactly the same numbers as before ``rng`` existed.
        rng = np.random

    # For every node with depth > 0 the early-leaf coin is drawn before the
    # min_depth check; keep this order so seeded runs remain unchanged.
    if depth <= 0 or (rng.choice([False] * 6 + [True]) and min_depth > depth):
        return State(True, value=rng.choice([1, 2, 3, 4, 5, 6]))

    successors = [create_state(depth - 1, min_depth, num_successors, rng)
                  for _ in range(num_successors)]

    return State(False, successors)

In [53]:
# Search engine wired to the trivial test Policy for both move sampling
# and state evaluation.
minimax = MinimaxSearch(
    policy=Policy(),
    value=Policy(),
    max_depth=5,
    max_breadth=3,
)

In [134]:
# Build a random tree, search it, and check that replaying the returned move
# sequence reaches a leaf whose value equals the search result.
SEED = 0
np.random.seed(SEED)  # create_state draws from the global NumPy RNG
state = create_state(9, 2, num_successors=3)

search = MinimaxSearch(policy=Policy(), value=Policy(), max_depth=5, max_breadth=3)
# NOTE(review): positional args presumably (state, depth, alpha, beta, maximizing) — confirm
value, history = search.minimax(state, 10, -float('inf'), float('inf'), True)

# Replay the principal variation from the root, stopping at the first leaf.
other = state
for move in history:
    if other.terminal:
        break
    other = other.move(move)

assert other.terminal, f"history {history} does not end at a leaf"
assert other.value == value, f"leaf value {other.value} != search value {value}"

print(value, history)
# print(state)



4 [0, 0, 0, 2, 0, 2, 2, 1, 2]
