In [70]:
import logging
from collections import namedtuple
import random
from copy import deepcopy
from itertools import accumulate
from operator import xor


# A single move: take `num_objects` objects from the row at index `row`.
Nimply = namedtuple("Nimply", "row, num_objects")


class Nim:
    """Mutable Nim position whose rows initially hold 1, 3, 5, ... objects.

    Args:
        num_rows: number of rows; row ``i`` starts with ``2 * i + 1`` objects.
        k: optional upper bound on how many objects a single move may take
            (``None`` means no bound).

    The instance is truthy while at least one object is left on the board.
    """

    def __init__(self, num_rows: int, k: int = None):
        self._rows = [i * 2 + 1 for i in range(num_rows)]
        self._k = k

    def __bool__(self):
        # The game is still running as long as any object remains.
        return sum(self._rows) > 0

    def __str__(self):
        return "<" + " ".join(str(_) for _ in self._rows) + ">"

    @property
    def rows(self) -> tuple:
        """Immutable snapshot of the current row sizes."""
        return tuple(self._rows)

    @property
    def k(self):
        """Per-move cap on removed objects, or ``None`` when uncapped."""
        return self._k

    def nimming(self, ply: Nimply):
        """Apply a move in place by removing ``num_objects`` from ``row``.

        Args:
            ply: a ``(row, num_objects)`` pair (a ``Nimply`` or plain tuple).

        Raises:
            AssertionError: if the move takes fewer than one object, more
                objects than the row holds, or more than the cap ``k``.
        """
        row, num_objects = ply
        # A move must remove something: without this check a zero-sized move
        # would pass the turn and a negative one would *add* objects.
        assert num_objects >= 1, "a move must take at least one object"
        assert self._rows[row] >= num_objects
        assert self._k is None or num_objects <= self._k
        self._rows[row] -= num_objects


In [71]:
def dumb_PCI(state: Nim):
    """Return the smallest available move: one object from the first non-empty row.

    Every (row, count) pair with ``1 <= count <= row size`` is enumerated and
    the pair with the lowest row index, then the lowest count, is selected.
    The cap ``state.k`` is not consulted.

    Raises:
        ValueError: if no row has any object left (``max`` of an empty list).
    """
    candidates = []
    for row_index, size in enumerate(state.rows):
        for take in range(1, size + 1):
            candidates.append((row_index, take))

    def preference(move):
        # Negating both fields makes max() favour the lowest row, then the lowest count.
        return (-move[0], -move[1])

    chosen_row, chosen_take = max(candidates, key=preference)
    return Nimply(chosen_row, chosen_take)

def pure_random(state: Nim):
    """Pick a random non-empty row, then a random legal number of objects from it.

    The row is drawn uniformly among rows that still hold objects; the count is
    drawn uniformly between 1 and the row size, limited to ``state.k`` when the
    game has a per-move cap so the returned move is always accepted by
    ``Nim.nimming``.

    Returns:
        Nimply: the chosen ``(row, num_objects)`` move.

    Raises:
        IndexError: if every row is empty (``random.choice`` of an empty list).
    """
    row = random.choice([r for r, c in enumerate(state.rows) if c > 0])
    upper = state.rows[row]
    if state.k is not None:
        # Respect the per-move cap; otherwise nimming() would reject the move.
        upper = min(upper, state.k)
    num_objects = random.randint(1, upper)
    return Nimply(row, num_objects)

"""optimal strategy"""
def nim_sum(state: "Nim"):
    """Return the bitwise XOR of all row sizes of ``state``.

    Args:
        state: any object exposing ``rows`` as an iterable of ints.

    Returns:
        int: the XOR of the row sizes; 0 for a board with no rows (XOR's
        identity element) instead of failing on the empty sequence.
    """
    total = 0
    for size in state.rows:
        total ^= size
    return total

def cook_status(state: Nim):
    """Collect summary facts about ``state`` for the strategy functions.

    Returns:
        dict with the keys
        - ``"possible_moves"``: every ``(row, count)`` move allowed by the row
          sizes and by the cap ``state.k`` (when set);
        - ``"active_rows_number"``: how many rows still hold objects;
        - ``"shortest_row"``: index of the smallest non-empty row, or ``None``
          when no row holds objects (previously this case raised ValueError);
        - ``"longest_row"``: index of the largest row;
        - ``"nim_sum"``: XOR of all row sizes;
        - ``"brute_force"``: list of ``(move, nim_sum_after_move)`` pairs, in
          the same order as ``"possible_moves"``.
    """
    rows = state.rows
    cap = state.k

    moves = [
        (r, o) for r, c in enumerate(rows) for o in range(1, c + 1) if cap is None or o <= cap
    ]
    non_empty = [(r, c) for r, c in enumerate(rows) if c > 0]
    current_sum = nim_sum(state)

    cooked = dict()
    cooked["possible_moves"] = moves
    cooked["active_rows_number"] = len(non_empty)
    # A finished game has no non-empty row; report None rather than crash in min().
    cooked["shortest_row"] = min(non_empty, key=lambda y: y[1])[0] if non_empty else None
    cooked["longest_row"] = max(enumerate(rows), key=lambda y: y[1])[0]
    cooked["nim_sum"] = current_sum

    brute_force = list()
    for move in moves:
        row, taken = move
        before = rows[row]
        # XOR out the old row size and XOR in the new one: same value as
        # recomputing the nim-sum on a deep copy of the game, without the copy.
        brute_force.append((move, current_sum ^ before ^ (before - taken)))
    cooked["brute_force"] = brute_force

    return cooked

def optimal_startegy(state: Nim):
    """Return a move that leaves a nim-sum of 0, or a random move if none exists.

    Moves are scanned in the order produced by ``cook_status``; the first one
    whose resulting nim-sum is 0 is returned as a ``(row, count)`` tuple.

    Raises:
        IndexError: if there is no possible move (``random.choice`` of an
            empty list).
    """
    scored_moves = cook_status(state)["brute_force"]
    # The random fallback is drawn up-front on every call, so the random
    # number stream advances whether or not a zero-nim-sum move exists.
    fallback = random.choice(scored_moves)
    for candidate in scored_moves:
        if candidate[1] == 0:
            return candidate[0]
    return fallback[0]


In [72]:
import numpy as np
# Let the root logger emit every record from DEBUG level upwards.
logging.getLogger().setLevel(logging.DEBUG)

# Step size used by learn() when moving a stored value toward its target.
alpha=0.15
# Probability that choose_action() plays a random move instead of the
# best-valued one.
random_factor=0.2  # 20% explore, 80% exploit
# (rows, reward) pairs appended by reinforcement_learning() after each agent move.
state_history = []
# Value table: maps a rows tuple (board position) to its current estimate.
rewards = {}
# Number of evaluation games played in the test phase.
NUM_MATCHES = 100

def give_rewards(state: Nim):
    """Return the reward for having reached ``state``.

    Returns:
        int: 0 when the game is over (``state`` is falsy, i.e. no objects
        are left), otherwise -1.
    """
    return -1 if state else 0

def init_reward(nim, _seen=None):
    """Give every position reachable from ``nim`` a random initial value.

    Each position reachable by one or more moves gets a value drawn uniformly
    from [0.1, 1.0) and stored in the module-level ``rewards`` table under its
    rows tuple. ``nim`` itself is not assigned a value and is not modified.

    Args:
        nim: the starting position.
        _seen: internal set of positions already initialised during this call
            tree; callers should leave it at its default.
    """
    if _seen is None:
        _seen = set()
    if not nim:
        return

    for move in allowed_moves(nim):
        successor = deepcopy(nim)
        successor.nimming(move)
        new_state = tuple(successor.rows)
        if new_state in _seen:
            # Already initialised via another move order; walking its subtree
            # again would make the recursion exponential in the board size.
            continue
        _seen.add(new_state)
        # Bounds given as low < high; the reversed order is undefined per numpy docs.
        rewards[new_state] = np.random.uniform(low=0.1, high=1.0)
        init_reward(successor, _seen)
    

def choose_action(nim, allowedMoves, rewards, epsilon=None):
    """Pick the agent's next move with an epsilon-greedy rule.

    Args:
        nim: current position; it is copied, never modified.
        allowedMoves: non-empty sequence of ``(row, count)`` moves.
        rewards: mapping from a rows tuple to its stored value; it must hold
            an entry for every position reachable by one allowed move.
        epsilon: probability of playing a random move. Defaults to the
            module-level ``random_factor``, read at call time.

    Returns:
        One element of ``allowedMoves``. When exploiting, the move whose
        resulting position has the highest stored value; ties go to the move
        that appears last in ``allowedMoves``.

    Raises:
        ValueError: if ``allowedMoves`` is empty.
        KeyError: if a resulting position is missing from ``rewards``.
    """
    if not allowedMoves:
        # Fail the same way in both branches instead of returning None when
        # exploiting and raising from randint() when exploring.
        raise ValueError("choose_action() needs at least one allowed move")
    if epsilon is None:
        epsilon = random_factor

    if np.random.random() < epsilon:
        # Explore: any allowed move, uniformly at random.
        return allowedMoves[np.random.randint(0, len(allowedMoves))]

    # Exploit: -inf (not a large finite number) so any stored value can win.
    best_value = float("-inf")
    best_move = None
    for action in allowedMoves:
        successor = deepcopy(nim)
        successor.nimming(action)
        value = rewards[successor.rows]
        if value >= best_value:
            best_move, best_value = action, value
    return best_move
    
def allowed_moves(nim):
    """List every ``(row, count)`` move possible from the row sizes of ``nim``.

    For each row, counts run from 1 up to the row size; empty rows contribute
    nothing. Moves are ordered by row index, then by count.
    """
    moves = []
    for row_index, size in enumerate(nim.rows):
        for take in range(1, size + 1):
            moves.append((row_index, take))
    return moves

def learn(rewards, state_history, random_factor, learning_rate=None):
    """Update stored values from a recorded history and return a decayed rate.

    The history is walked from the most recent entry to the oldest. Each
    visited position's value is moved a fraction ``learning_rate`` of the way
    toward a target equal to the sum of the rewards recorded *after* it.

    Args:
        rewards: mapping from rows tuple to value; updated in place.
        state_history: sequence of ``(rows, reward)`` pairs. It is left
            untouched: assigning a new list to the parameter (as the previous
            code did) never affected the caller's list, and
            ``reinforcement_learning`` calls this function after every move,
            so deciding when an episode ends and clearing the history is the
            caller's job.
        random_factor: current exploration probability.
        learning_rate: step size; defaults to the module-level ``alpha``.

    Returns:
        float: ``random_factor`` lowered by ``10e-5`` and clamped at 0. The
        in-function decrement could only change the local name, so the new
        value is returned for the caller to store.
    """
    if learning_rate is None:
        learning_rate = alpha

    target = 0
    for prev, reward in reversed(state_history):
        key = tuple(prev)
        rewards[key] += learning_rate * (target - rewards[key])
        target += reward

    return max(0.0, random_factor - 10e-5)
        
def reinforcement_learning(nim):
    """Play one agent move on ``nim`` and immediately run a learning pass.

    The move is chosen by ``choose_action`` among ``allowed_moves(nim)`` and
    applied to ``nim`` in place. The resulting position and its reward are
    appended to the module-level ``state_history``, then ``learn`` is run on
    the whole history.

    Returns:
        The ``(row, count)`` move that was played.
    """
    action = choose_action(nim, allowed_moves(nim), rewards)
    nim.nimming(action)
    # Record where the move led and what it earned before learning from it.
    state_history.append((nim.rows, give_rewards(nim)))
    learn(rewards, state_history, random_factor)
    return action

def test(nim):
    """Play one agent move on ``nim`` without recording or learning anything.

    The move still comes from ``choose_action``, so it is random with
    probability ``random_factor``. ``nim`` is modified in place.

    Returns:
        The ``(row, count)`` move that was played.
    """
    action = choose_action(nim, allowed_moves(nim), rewards)
    nim.nimming(action)
    return action

    
# Seed the value table: the starting position is fixed at 0, every position
# reachable from it gets a random initial value.
nim = Nim(3)
rewards[nim.rows] =  0
init_reward(nim)
print(f"pesi iniziali = {rewards}")

# Training: the agent (player 0, moves first) plays against optimal_startegy.
NUM_EPISODES = 2000
for i in range(NUM_EPISODES):
    # Start every game with an empty history. learn() only rebinds its own
    # parameter, so without this the list kept growing across games and each
    # learning pass replayed all previous games with ever larger negative targets.
    state_history.clear()
    player = 0
    while nim:
        if player == 0:
            # reinforcement_learning() applies its move to `nim` itself.
            ply = reinforcement_learning(nim)
        else:
            ply = optimal_startegy(nim)
            nim.nimming(ply)

        player = 1 - player
    # `player` is whoever would move next, so the other one took the last object.
    winner = 1 - player
    # Decrease the exploration rate once per episode. This has to happen here:
    # the decrement inside learn() only changes that function's local copy.
    # With 2000 episodes the rate goes from 0.2 down to about 0.
    random_factor = max(0.0, random_factor - 10e-5)
    nim = Nim(3)

print(f"pesi finali = {rewards}")

# Test: the agent (player 0, moves first) plays NUM_MATCHES games against
# optimal_startegy without any further learning.
nim = Nim(3)
won = 0

for i in range(NUM_MATCHES):
    player = 0
    while nim:
        if player:
            ply = optimal_startegy(nim)
            nim.nimming(ply)
        else:
            # test() applies the agent's move to `nim` itself.
            ply = test(nim)
        player ^= 1

    # After the loop `player` is whoever would move next; 1 means the agent
    # (player 0) took the last object and therefore won.
    won += 1 if player == 1 else 0

    nim = Nim(3)

win_rate = won / NUM_MATCHES
print(f"win rate = {win_rate}")

    

pesi iniziali = {(1, 3, 5): 0, (0, 3, 5): 0.7441410996553026, (0, 2, 5): 0.27376311550086085, (0, 1, 5): 0.8161496032366387, (0, 0, 5): 0.35097846058824744, (0, 0, 4): 0.929291075463907, (0, 0, 3): 0.12184547700462944, (0, 0, 2): 0.5319505679423366, (0, 0, 1): 0.7705283343137482, (0, 0, 0): 0.9271047527123472, (0, 1, 4): 0.34334942269621227, (0, 1, 3): 0.7962722039718825, (0, 1, 2): 0.214872950006053, (0, 1, 1): 0.19918964175686193, (0, 1, 0): 0.559542601513439, (0, 2, 4): 0.6302571285264795, (0, 2, 3): 0.2033292254583421, (0, 2, 2): 0.7367409253165338, (0, 2, 1): 0.4385844070344892, (0, 2, 0): 0.6654522557020943, (0, 3, 4): 0.8105791077245333, (0, 3, 3): 0.8977353201699955, (0, 3, 2): 0.2878516006740728, (0, 3, 1): 0.7744985442015562, (0, 3, 0): 0.8607690405367707, (1, 2, 5): 0.5728419082650789, (1, 1, 5): 0.6088933713126825, (1, 0, 5): 0.41824472928933143, (1, 0, 4): 0.1316374242384587, (1, 0, 3): 0.45822864144617037, (1, 0, 2): 0.24654157823266942, (1, 0, 1): 0.2049426865776841, (1,