Monte Carlos Search Tree :

In [1]:
import math
import random

# All rolls
rolls = [[]]
# Current number of dice
dice = 8
# Constant for the Upper Confidence Bound
C = 2000
# Number of iterations for the MCST
total_iterations = 100

class MCSTree:
  def __init__(self, t, n):
    # Value of the Node
    self.t = t
    # Number of visits of the node
    self.n = n
    # How far in the tree are we ?
    self.previousRoll = []
    # All the discarded dice
    self.discardedDice = []
    # Parent 
    self.parent = None
    # Valid node
    self.valid = True
    # Child number
    self.childNb = None
    # The six child possibilities for the dice or end turn 
    self.children = [None]*7

  def __str__(self, level=0):
    if level == 5:
      return ""
    ret = "\t"*level+"child : "+repr(self.childNb)+", passed : "+repr(self.n)+", val : "+repr(self.t)+", previous : "+repr(self.previousRoll)+", discarded : "+repr(self.discardedDice)+"\n"
    if self.children[0] != None:
      for child in self.children:
        ret += child.__str__(level+1)
    return ret

def createNode(node, t, n, previousRoll, discardedDice, childNb):
  newNode = MCSTree(t, n)
  newNode.previousRoll.extend(node.previousRoll)
  newNode.previousRoll.append(previousRoll)
  newNode.discardedDice.extend(node.discardedDice)
  newNode.discardedDice.extend(discardedDice)
  newNode.parent = node
  newNode.childNb = childNb+1
  node.children[childNb] = newNode

In [2]:
# Monte Carlo Search Tree 
def MCST(root):
  global total_iterations
  global rolls
  print(rolls)
  while root.n < total_iterations:
    old_rolls = rolls.copy()
    leaf = traverse(root)
    if leaf.valid == False:
      backpropagate(leaf, 0)
      leaf.valid == True
    else:
      simulation_result = rollout(leaf)
      backpropagate(leaf, simulation_result)
    rolls = old_rolls
  return best_child(root)

# Find the total number of simulations done
def root_simulation_number(node):
  while node.parent != None:
    node = node.parent
  return node

# Calculate the value of a node 
def best_ucb(node):
  global C
  old_ucb = -1
  node_num = -1
  myList = list(range(7))
  random.shuffle(myList)
  for i in myList:
    if node.children[i] != None:
      if node.children[i].n == 0:
        return node.children[i]
      N = root_simulation_number(node)
      ucb = node.children[i].t + C * math.sqrt(math.log(N.n)/node.children[i].n)
      if ucb > old_ucb:
        old_ucb = ucb
        node_num = i
  return node.children[node_num]

# Is the node not a leaf node ? 
def non_terminal(node):
  # Is it the root node 
  if node.parent == None:
    return 1
  # Has there already been 6 throws ?
  if len(node.previousRoll) == 6:
    return 0
  # No new numbers in current roll
  if all(x in node.previousRoll for x in rolls[len(node.previousRoll)-1]):
    return 0
  # No dice left
  if len(node.discardedDice) == 8:
    return 0
  # The choice 7 is to stop the turn now (normally already taken care of when checking no new numbers in current roll)
  if node.childNb == 7:
    return 0
  return 1

# Not possible to take this number
def impossible_action(node):
  # Continue if it is the root node
  if node.parent == None:
    return 0
  # Continue if it is the stop turn choice
  if node.childNb == 7:
    return 0
  # If the node number has already been selected previously, then abort
  if node.childNb in node.previousRoll[0:-1]:
    return 1
  # If the node number is not in the rolls, then abort
  if node.childNb not in rolls[len(node.previousRoll)-1]:
    return 1
  return 0

# Is the node a leaf node ? (0: yes, 1 : unvisited child to do, 2: no) 
def fully_expanded(node):
  if impossible_action(node):
    return -1
  if not non_terminal(node):
    return 0
  # Create the children if not already done
  if node.children[0] == None:
    for i in range(6):
      nbDiscarded = [i+1]*rolls[len(node.previousRoll)].count(i+1)
      createNode(node, 0, 0, i+1, nbDiscarded, i)
    createNode(node, 0, 0, None, [], 6)
    return 1
  # If one of the child nodes has not yet been visited
  did_visit = pick_unvisited(node)
  if did_visit > -1:
    return 1
  return 2

# Is there a child that has never been visited ?
def pick_unvisited(node):
  # randomly select one of the seven children
  myList = list(range(7))
  random.shuffle(myList)
  for i in myList:
    if node.children[i].n == 0:
      return i
  return -1

# function for node traversal
def traverse(node):
  isExpanded = fully_expanded(node)
  # While the node is not terminal, not impossible or has been visited
  while (isExpanded == 2):
    node = best_ucb(node)
    rolls.append([random.randint(1, 6) for i in range(8 - len(node.discardedDice))])
    isExpanded = fully_expanded(node)
  # The node is impossible
  if isExpanded == -1:
    node.valid = False
    return node
  # The node is a terminal node
  if isExpanded == 0:
    return node
  # The node has unvisited children
  if isExpanded == 1:
    unVisited = pick_unvisited(node)
    if impossible_action(node.children[unVisited]):
        node.children[unVisited].valid = False
    return node.children[unVisited]

# Result from the simulation
def result(node):
  # Player needs dice 6 to be able to win and that the sum of his points is bigger or equal to an available token
  if 6 in node.previousRoll:
    values = [node.discardedDice[j] if node.discardedDice[j] != 6 else 5 for j in range(len(node.discardedDice))]
    score = sum(values)
    i = 0
    while score-i  > 20:
      if score-i in [i[0] for i in available_pieces]:
        return available_pieces[[y[0] for y in available_pieces].index(score-i)][1]
      i += 1
  return 0

# Is the node not a leaf node ?
def rollout_non_terminal(node):
  # Has there already been 6 throws ?
  if len(node.previousRoll) == 6:
    return 0
  # No new numbers in current roll
  if all(x in node.previousRoll for x in rolls[len(node.previousRoll)-1]):
    return 0
  # No dice left
  if len(node.discardedDice) == 8:
    return 0
  # Decided to not continue playing
  if node.childNb == 7:
    return 0
  return 1

# function for the result of the simulation
def rollout(node):
  # Save values that will be changed for the rollout
  global rolls
  oldRolls = rolls.copy()
  oldNodeRoll = node.previousRoll.copy()
  oldNodeDis = node.discardedDice.copy()
  # While the node is not terminal and not impossible, navigate through the children with random rolls
  impossible = impossible_action(node)
  while not impossible and rollout_non_terminal(node):
    rollout_policy(node)
    rolls.append([random.randint(1, 6) for i in range(8 - len(node.discardedDice))])
    impossible = impossible_action(node)
  # get the result of the final node
  myResult = 0 if impossible else result(node)
  # Load values that have been changed 
  node.previousRoll = oldNodeRoll
  node.discardedDice = oldNodeDis
  rolls = oldRolls
  return myResult

# function for randomly selecting a child node
def rollout_policy(node):
  chosen = -1
  # choose a dice that has not yet been selected
  while chosen == -1 or rolls[len(node.previousRoll)-1][chosen] in node.previousRoll:
    chosen = random.randint(0, len(rolls[len(node.previousRoll)-1])-1)
  chosenValue = rolls[len(node.previousRoll)-1][chosen]
  node.discardedDice.extend([chosenValue]*rolls[len(node.previousRoll)-1].count(chosenValue))
  node.previousRoll.append(chosenValue)

# Update the state of the current node 
def update_stats(node, result):
  node.t += result
  node.n += 1

# function for backpropagation
def backpropagate(node, result):
  update_stats(node, result)
  if (node.parent == None): return
  backpropagate(node.parent, result)

# function for selecting the best child
# node with highest number of visits
def best_child(node):
  best = [-1, -1]
  for i in range(6):
    if node.children[i] != None:
      if best[1] < node.children[i].n:
        best = [i, node.children[i].n]
  return best[0] + 1

In [3]:
random.seed(0)
dice = 8
C = 500
total_iterations = 100000
rolls = [[random.randint(1, 6) for i in range(dice)]]
available_pieces =  [(21, 1), (22, 1), (23, 1), (24, 1), (25, 2), (26, 2), (27, 2), (28, 2), (29, 3), (30, 3), (31, 3), (32, 3), (33, 4), (34, 4), (35, 4), (36, 4)]
tree = MCSTree(0,0)
print(MCST(tree))
print(tree)

[[4, 4, 1, 3, 5, 4, 4, 3]]
4
child : None, passed : 100000, val : 48296, previous : [], discarded : []
	child : 1, passed : 557, val : 0, previous : [1], discarded : [1]
		child : 1, passed : 79, val : 0, previous : [1, 1], discarded : [1, 1]
		child : 2, passed : 79, val : 0, previous : [1, 2], discarded : [1, 2, 2]
			child : 1, passed : 8, val : 0, previous : [1, 2, 1], discarded : [1, 2, 2]
			child : 2, passed : 7, val : 0, previous : [1, 2, 2], discarded : [1, 2, 2, 2]
			child : 3, passed : 8, val : 0, previous : [1, 2, 3], discarded : [1, 2, 2, 3]
				child : 1, passed : 0, val : 0, previous : [1, 2, 3, 1], discarded : [1, 2, 2, 3, 1]
				child : 2, passed : 1, val : 0, previous : [1, 2, 3, 2], discarded : [1, 2, 2, 3, 2]
				child : 3, passed : 0, val : 0, previous : [1, 2, 3, 3], discarded : [1, 2, 2, 3]
				child : 4, passed : 0, val : 0, previous : [1, 2, 3, 4], discarded : [1, 2, 2, 3, 4]
				child : 5, passed : 0, val : 0, previous : [1, 2, 3, 5], discarded : [1, 2, 2, 3,