# In [5]:
# Calculate Grundy values for a Chomp board.

class Chomp:
    """Memoized evaluator of Grundy values and lose ratios for Chomp positions.

    Both caches are keyed by ``pos_ix(pos)``, so positions that differ only by
    zero-length rows share a single entry.
    """

    def __init__(self):
        # pos_ix key -> Grundy value of that position
        self.pos_grundy = {}
        # pos_ix key -> fraction of that position's followers with non-zero Grundy value
        self.pos_lose_ratio = {}

    def best_move(self, pos):
        """Pick a follower of *pos* to move to.

        The first follower (in ``followers(pos)`` order) with Grundy value 0 is
        returned immediately. Otherwise the follower with the highest lose
        ratio is returned; on ties the later follower wins. Returns None when
        *pos* has no followers.
        """
        chosen = None
        top_ratio = 0

        for candidate in followers(pos):
            # A Grundy-0 follower is counted as a "win" by compute_lose_ratio;
            # take it without looking further.
            if self.get_grundy(candidate) == 0:
                return candidate
            ratio = self.get_lose_ratio(candidate)
            if ratio >= top_ratio:
                chosen, top_ratio = candidate, ratio

        return chosen

    def get_grundy(self, pos):
        """Return the Grundy value of *pos*, computing and caching it on first use."""
        key = pos_ix(pos)
        if key not in self.pos_grundy:
            self.pos_grundy[key] = self.compute_grundy(pos)
        return self.pos_grundy[key]

    def compute_grundy(self, pos):
        """Return the mex of the Grundy values of all followers of *pos* (no caching of *pos* itself)."""
        grundies = [self.get_grundy(nxt) for nxt in followers(pos)]
        return mex(grundies)

    def get_lose_ratio(self, pos):
        """Return the lose ratio of *pos*, computing and caching it on first use."""
        key = pos_ix(pos)
        if key not in self.pos_lose_ratio:
            self.pos_lose_ratio[key] = self.compute_lose_ratio(pos)
        return self.pos_lose_ratio[key]

    def compute_lose_ratio(self, pos):
        """Return the fraction of followers of *pos* whose Grundy value is non-zero.

        Returns 1 when no follower has Grundy value 0. This includes the case
        where *pos* has no followers at all, so there is no division by zero.
        """
        grundies = [self.get_grundy(nxt) for nxt in followers(pos)]
        winning = grundies.count(0)
        if winning == 0:
            return 1
        total = len(grundies)
        return float(total - winning) / total
    
# Helper functions
def mex(vals):
    """Return the minimum excludant of *vals*.

    This is the smallest non-negative integer that does not occur in *vals*.
    *vals* may be any iterable of ints; duplicates and ordering don't matter.

    >>> mex([0, 1, 3])
    2
    """
    # Compare by value (set membership uses ==/hash). The previous `is` check
    # compared object identity, which only worked for CPython's cached small
    # ints (<= 256).
    present = set(vals)
    result = 0
    while result in present:
        result += 1
    return result

def pos_ix(pos):
    """Return the string key used to index *pos* in a dict.

    Zero entries are dropped first, so ``(2, 1, 0)`` and ``(2, 1)`` both map
    to ``'(2, 1)'``.
    """
    stripped = tuple(count for count in pos if count != 0)
    return str(stripped)

def pretty_print_grundy(grundy_pos):
    """Print every ``key: value`` pair of *grundy_pos*, each followed by a blank line."""
    for key in grundy_pos:
        value = grundy_pos[key]
        print(f'{key}: {value}\n')

def visualize_board(pos):
    """Print the board, one line per row, using ``"X "`` for each square.

    The last entry of *pos* is printed first, so row 0 ends up at the bottom.
    Rows of length 0 print as empty lines.
    """
    # Iterating in reverse replaces the manual index loop. String repetition
    # replaces the `+=` loop that also shadowed the builtin `str`.
    for row_length in reversed(pos):
        print("X " * row_length)
            
def followers(pos):
    """Return every position reachable from *pos* in one move.

    *pos* is a sequence of row lengths. Choosing the square at (row, col)
    truncates that row and every later row to at most ``col`` squares. The
    square at (0, 0) is never chosen. A position with exactly one square left
    (sum == 1) is terminal and has no followers.

    Each follower is a tuple of the same length as *pos*. Rows that become
    empty are kept as 0 entries. Followers are listed row by row, then column
    by column.
    """
    if sum(pos) == 1:
        # Terminal position: only the last square remains.
        return []

    next_positions = []
    for row, length in enumerate(pos):
        for col in range(length):
            if row == 0 and col == 0:
                # Illegal move: would remove the last piece.
                continue
            truncated = tuple(min(col, later) for later in pos[row:])
            next_positions.append((*pos[:row], *truncated))
    return next_positions

def valid_move(pos, move):
    """Return True if *move* equals some follower of *pos*.

    Positions are compared by their pos_ix keys, so zero entries are ignored.
    """
    target = pos_ix(move)
    return any(pos_ix(candidate) == target for candidate in followers(pos))

def is_terminal(pos):
    """Return True when exactly one square remains on the board."""
    squares_left = sum(pos)
    return squares_left == 1

def parse_move(input):
    """Parse a comma-separated string such as ``"3,2"`` into a tuple of ints.

    Surrounding whitespace in each field is accepted (``int`` strips it).
    Raises ValueError if any field is not an integer literal.
    """
    fields = input.split(",")
    return tuple(map(int, fields))