# Lab 3 - Nim  
## Task3.4: An agent using reinforcement learning  

In [1]:
import logging
import random
from copy import deepcopy

from nim import Nimply, Nim
from play_nim import opponents

In [2]:
logging.basicConfig(format="%(message)s", level=logging.INFO)

## Implementation

In [3]:
def hash_id(state: list, player: int):
  assert player == 1 or player == 0
  return hash(tuple(sorted(state)) + (player, ))

#### Node class from Task 3

In [4]:
class Node():
  
  def __init__(self, state: list, player: int):
    assert player == 1 or player == 0
    
    self.id = hash_id(state, player)
    self.state = deepcopy(state)
    self.player = player # Me (0) -> max ; Opponent (1) -> min
    
    self.reward = 0 
    self.children = []
    self.parents = []
    self.possible_acitions() # creates self.actions


  def __eq__(self, other):
    return isinstance(other, Node) and self.state == other.state and self.player == other.player


  def link_parent(self, parent):
    assert isinstance(parent, Node)
    assert self.player != parent.player
    if parent not in self.parents:
      self.parents.append(parent)


  def link_child(self, child):
    assert isinstance(child, Node)
    assert self.player != child.player
    if child not in self.children:
      self.children.append(child)


  def is_game_over(self):
    return sum(self.state) == 0

  
  def give_reward(self):
    """
    - not end -> reward = -1  
    - win -> reward = 100  
    - lose -> reward = -100  
    """
    if not self.is_game_over():
      return -1
    if self.player == 0: # I lose
      return -100
    return 100 # I win


  def possible_acitions(self, k=None):
    self.actions = []
    
    if self.is_game_over():
      return

    not_zero_rows = [(r, n) for r, n in enumerate(self.state) if n > 0]
    for row, num_obj in not_zero_rows:  
      while num_obj > 0:
        if k and num_obj > k:
          num_obj = k
          continue
        self.actions.append(Nimply(row, num_obj))
        num_obj -= 1


#### Game Tree (builded recursively, such us in task 3)

In [5]:
class GameTree():
  
  def __init__(self, nim: Nim, start_player=0):
    self.k = nim._k
    self.start_player = start_player
    self.dict_id_node = {}    
    self.dict_id_reward = {} 
    
    self.root = Node(nim._rows, start_player)
    self.dict_id_node[self.root.id] = self.root
    logging.info(f'Building the tree...')
    self.build_tree()
    logging.info('Done')

  
  def build_tree(self):
    def recursive(node: Node):
      # Stop condition
      if node.id in self.dict_id_reward:
        return
      
      if node.is_game_over():
        node.reward = node.give_reward()
        self.dict_id_reward[node.id] = node.reward
        return


      # Recursive part
      for ply in node.actions:
        row, num_obj = ply
        
        # Check rules
        assert node.state[row] >= num_obj
        assert self.k is None or num_obj <= self.k

        # Create the child
        child_state = deepcopy(node.state)
        child_state[row] -= num_obj # nimming
        child_id = hash_id(child_state, 1 - node.player)
        if child_id in self.dict_id_node: # node already exists
          child = self.dict_id_node[child_id]
        else: # create the new node
          child = Node(child_state, 1 - node.player)
          self.dict_id_node[child_id] = child
        
        # Link parent and child
        node.link_child(child)
        child.link_parent(node)

        # Recursion
        recursive(child)
          
      # Reward of the node (-1)
      node.reward = node.give_reward()
      self.dict_id_reward[node.id] = node.reward
    
      return 

    recursive(self.root)
    self.root.reward = self.root.give_reward()
  

  def player_states(self, player):
    assert player == 0 or player == 1

    dict_player = {}
    for (id, node) in self.dict_id_node.items():
      if node.player == player:
        dict_player[id] = node
    
    return dict_player


#### Agent

In [6]:
class Agent():
  
  def __init__(self, game_tree: GameTree, alpha=0.5, random_factor=0.2):
    self.alpha = alpha
    self.random_factor = random_factor
    #self.dict_id_node = game_tree.dict_id_node
    
    self.state_history = [game_tree.root] # node -> inside has state and reward
    self.G = {} # (k, v) = id_node, expected reward
    for id, node in game_tree.dict_id_node.items():
        self.G[id] = random.uniform(1.0, 0.1)


  def choose_action(self, node: Node):
    maxG = -10e15
    next_move = None
    
    # Random action
    if random.random() < self.random_factor:
      next_move = random.choice(node.actions)
    # Action with highest G (reward)
    else: 
      for a in node.actions:  # a is a Nimply obj
        new_state = deepcopy(node.state)
        new_state[a.row] -= a.num_objects
        new_state_id = hash_id(new_state, player=1) # opponent's state
        if self.G[new_state_id] >= maxG:
          next_move = a
          maxG = self.G[new_state_id]

    return next_move      
        

  def update_history(self, node: Node):
    self.state_history.append(node)


  def learn(self):
    target = 0

    for node in reversed(self.state_history):
      self.G[node.id] = self.G[node.id] + self.alpha * (target - self.G[node.id])
      target += node.reward

    self.state_history = []     # Restart
    self.random_factor -= 10e-5 # Decrease random factor each episode of play

    

#### Reinforcement Learning algorithm

In [16]:
def RL_nim(nim: Nim, game_tree: GameTree, agent: Agent, opponent: callable, episodes = 5000 ):

  for e in range(episodes):
    # Play a game
    episode_nim = deepcopy(nim)
    state = game_tree.root
    while not episode_nim.is_game_over():
      # My turn
      if state.player == 0:
        my_action = agent.choose_action(state) # Choose an action
        print(f'0: {my_action} on {episode_nim}')
        episode_nim.nimming(my_action)         # Update the state
        
        # Get new state and reward
        new_state_id = hash_id(episode_nim._rows, player = 1)
        state = game_tree.dict_id_node[new_state_id]
        agent.update_history(state)
      
      # Opponent turn
      else:
        opp_action = opponent(episode_nim)
        print(f'1: {opp_action} on {episode_nim}')
        episode_nim.nimming(opp_action)
        new_state_id = hash_id(episode_nim._rows, player = 0)
        state = game_tree.dict_id_node[new_state_id]
      
      # Log
      if e % 50 == 0:
        logging.info(f'Episode {e}: player {(1 - state.player)} on {state.state}')
      



## Play

In [18]:
nim = Nim(5)
game_tree = GameTree(nim, start_player=0)
agent = Agent(game_tree)

Building the tree...
Done


In [19]:
opponent = opponents[3]
RL_nim(nim, game_tree, agent, opponent=opponent )

Episode 0: player 0 on [1, 3, 4, 7, 9]
Episode 0: player 1 on [0, 3, 4, 7, 9]
Episode 0: player 0 on [0, 3, 3, 7, 9]
Episode 0: player 1 on [0, 0, 3, 7, 9]
Episode 0: player 0 on [0, 0, 1, 3, 9]
Episode 0: player 1 on [0, 0, 1, 3, 6]
Episode 0: player 0 on [0, 0, 1, 3, 3]
Episode 0: player 1 on [0, 0, 1, 2, 3]
Episode 0: player 0 on [0, 0, 1, 1, 2]
Episode 0: player 1 on [0, 0, 1, 1, 1]
Episode 0: player 0 on [0, 0, 0, 1, 1]
Episode 0: player 1 on [0, 0, 0, 0, 1]


0: Nimply(row=2, num_objects=1) on <1 3 5 7 9>
1: Nimply(row=0, num_objects=1) on <1 3 4 7 9>
0: Nimply(row=2, num_objects=1) on <0 3 4 7 9>
1: Nimply(row=1, num_objects=3) on <0 3 3 7 9>
0: Nimply(row=3, num_objects=6) on <0 0 3 7 9>
1: Nimply(row=4, num_objects=3) on <0 0 3 1 9>
0: Nimply(row=4, num_objects=3) on <0 0 3 1 6>
1: Nimply(row=2, num_objects=1) on <0 0 3 1 3>
0: Nimply(row=4, num_objects=2) on <0 0 2 1 3>
1: Nimply(row=2, num_objects=1) on <0 0 2 1 1>
0: Nimply(row=4, num_objects=1) on <0 0 1 1 1>
1: Nimply(row=2, num_objects=1) on <0 0 1 1 0>
0: Nimply(row=4, num_objects=1) on <0 0 0 1 0>


AssertionError: 