In [5]:
import numpy as np
from itertools import combinations
import math
from tqdm.notebook import trange

In [6]:
def are_collinear(p1, p2, p3):
    """Return True if the three 2-D points lie on a single straight line.

    Adapted from https://github.com/kitft/funsearch.

    Each argument is unpacked into exactly two coordinates. The check
    compares cross-multiplied coordinate differences taken relative to
    ``p1``, so no division is performed and axis-parallel lines need no
    special case. Repeated points make the products zero on both sides and
    are therefore reported as collinear.
    """
    (ax, ay), (bx, by), (cx, cy) = p1, p2, p3
    left_product = (ay - by) * (ax - cx)
    right_product = (ay - cy) * (ax - bx)
    return left_product == right_product

class N3il:
    """Rules of the No-3-In-Line puzzle on a rectangular grid.

    A state is a ``(row_count, column_count)`` array in which 1 marks a
    placed point and 0 an empty cell. An action is a flat cell index in
    ``range(action_size)``, decoded as ``(action // column_count,
    action % column_count)``.
    """

    def __init__(self, grid_size):
        """Store the grid dimensions; ``grid_size`` is indexed as (rows, columns)."""
        self.row_count = grid_size[0]
        self.column_count = grid_size[1]
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return an all-zero (empty) grid of the configured shape."""
        shape = (self.row_count, self.column_count)
        return np.zeros(shape)

    def get_next_state(self, state, action):
        """Mark the cell addressed by ``action`` and return the grid.

        The array passed in is modified in place and the very same object
        is returned; pass a copy if the original must be kept.
        """
        row, column = divmod(action, self.column_count)
        state[row, column] = 1
        return state

    def get_valid_moves(self, state):
        """Return a flat ``uint8`` mask with 1 at every empty cell."""
        is_empty = state.reshape(-1) == 0
        return is_empty.astype(np.uint8)

    def check_collinear(self, state, action):
        """Count the collinear triples present after placing ``action``.

        The placement is tried on a copy, so ``state`` is left untouched.
        Every triple of occupied cells in the resulting grid is examined,
        not only the triples that involve the new point.

        Returns:
            The number of collinear triples (0 means the placement is legal).
        """
        row, column = divmod(action, self.column_count)
        trial_state = state.copy()
        trial_state[row, column] = 1

        # (row, column) coordinates of every occupied cell.
        occupied = np.argwhere(trial_state == 1)

        # Every collinear triple is counted rather than stopping at the first
        # one. Whether a larger count should be punished harder is an open
        # design question: on a 3x3 grid holding 6 points, one extra point can
        # create 3 collinear triples at once without the 6-point
        # configuration itself being a bad one.
        return sum(
            1
            for first, second, third in combinations(occupied, 3)
            if are_collinear(first, second, third)
        )

    def get_value_and_terminated(self, state):
        """Return ``(number of placed points, True)``.

        The second element is always ``True``: collinearity is not evaluated
        here, so callers have to decide termination via ``check_collinear``.
        """
        placed_points = np.sum(state.reshape(-1) == 1)
        return placed_points, True

In [7]:
'''
Adapted from
foersterrobert/AlphaZeroFromScratch
'''

class Node:
    """One node of the Monte Carlo search tree.

    A node stores the grid reached after ``action_taken`` was applied to its
    parent's grid, the children created so far, and the running statistics
    (``visit_count``, ``value_sum``) that drive UCB selection.

    Args:
        game: Object providing ``get_valid_moves``, ``get_next_state``,
            ``check_collinear`` and ``get_value_and_terminated``.
        args: Mapping holding at least the exploration constant ``'C'``.
        state: Grid represented by this node.
        parent: Parent node, or ``None`` for the root.
        action_taken: Flat action index that led here, or ``None`` for the root.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # 0/1 mask of the actions that have not been turned into a child yet.
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True once every valid move has a child and at least one exists."""
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (``None`` without children).

        Ties go to the child that was created first.
        """
        return max(self.children, key=self.get_ucb, default=None)

    def get_ucb(self, child):
        """Return the UCB1 score of ``child`` as seen from this node.

        The exploitation term is the child's mean backed-up value; the
        exploration term is ``C * sqrt(ln(parent visits) / child visits)``.
        """
        if child.visit_count == 0:
            # An unvisited child has no mean yet; rank it first instead of
            # dividing by zero.
            return math.inf
        q_value = child.value_sum / child.visit_count
        exploration = math.sqrt(math.log(self.visit_count) / child.visit_count)
        return q_value + self.args['C'] * exploration

    def expand(self):
        """Create, register and return a child for a random untried move.

        Must only be called while ``expandable_moves`` still contains a 1;
        otherwise ``np.random.choice`` raises ``ValueError``.
        """
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        # get_next_state writes into the array it receives, so hand it a copy
        # to keep this node's own grid intact.
        child_state = self.game.get_next_state(self.state.copy(), action)

        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child

    def simulate(self):
        """Play random moves from this node and return the rollout value.

        If the move that created this node already produced a collinear
        triple, the value of the parent's grid (the last valid
        configuration) is returned immediately. Otherwise random empty cells
        are filled until the next move would create a collinear triple, or
        until no empty cell is left; the value of the last valid grid is
        returned.

        Must not be called on the root, which has no parent.
        """
        if self.game.check_collinear(self.parent.state, self.action_taken) > 0:
            value, _ = self.game.get_value_and_terminated(self.parent.state)
            return value

        rollout_state = self.state.copy()

        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            candidates = np.where(valid_moves == 1)[0]

            if candidates.size == 0:
                # Board is full and still valid (possible on small grids):
                # the rollout is over. Without this guard np.random.choice
                # raises ValueError on the empty candidate array.
                value, _ = self.game.get_value_and_terminated(rollout_state)
                return value

            action = np.random.choice(candidates)

            if self.game.check_collinear(rollout_state, action) > 0:
                # The chosen move would break the rule: score the grid
                # without it.
                value, _ = self.game.get_value_and_terminated(rollout_state)
                return value

            rollout_state = self.game.get_next_state(rollout_state, action)

    def backpropagate(self, value):
        """Add ``value`` and one visit to this node and to every ancestor."""
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            self.parent.backpropagate(value)

class MCTS:
    """Monte Carlo Tree Search driver for a single-player placement game.

    Args:
        game: Game object handed to every ``Node``; ``search`` itself uses
            ``check_collinear``, ``get_value_and_terminated`` and
            ``action_size``.
        args: Mapping with ``'num_searches'`` (iterations per search) and
            ``'C'`` (UCB exploration constant). When omitted, a fresh
            ``{'num_searches': 1000, 'C': 1.4}`` dict is created for this
            instance so that instances never share a default.
    """

    def __init__(self, game, args=None):
        self.game = game
        if args is None:
            # Built per instance: a dict literal used as the default value
            # would be shared (and mutable) across all instances.
            args = {
                'num_searches': 1000,
                'C': 1.4
            }
        self.args = args

    def search(self, state):
        """Run ``num_searches`` MCTS iterations starting from ``state``.

        Each iteration descends through fully expanded nodes by UCB, then
        either scores a terminal node or expands one child and scores it
        with a random rollout, and finally backs the value up to the root.

        Returns:
            Array of length ``game.action_size`` holding, for each action,
            the fraction of root-child visits it received. All zeros if no
            child was visited.
        """
        root = Node(self.game, self.args, state)

        for _ in trange(self.args['num_searches']):
            node = root

            # Selection: follow the best-UCB child while every move is tried.
            while node.is_fully_expanded():
                node = node.select()

            if node.parent is not None and self.game.check_collinear(
                    node.parent.state, node.action_taken) > 0:
                # Terminal node: its move created a collinear triple. Score
                # the parent's grid (the last valid configuration), which is
                # what Node.simulate returned on the first visit. Scoring
                # node.state would credit the illegal point as well.
                value, _ = self.game.get_value_and_terminated(node.parent.state)
            elif not np.any(node.expandable_moves):
                # No empty cell left and nothing to expand: a valid, full
                # board. Score it directly instead of calling expand().
                value, _ = self.game.get_value_and_terminated(node.state)
            else:
                # Expansion followed by a random rollout from the new child.
                node = node.expand()
                value = node.simulate()

            # Backpropagation.
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        total_visits = np.sum(action_probs)
        if total_visits > 0:
            # Skip normalisation when nothing was visited to avoid 0/0 -> NaN.
            action_probs /= total_visits
        return action_probs

In [9]:
n3il = N3il(grid_size=(5,5))

args = {
    'C': 1.41,
    'num_searches': 1000
}

mcts = MCTS(n3il, args)

state = n3il.get_initial_state()
num_of_points = 0

# Greedy play: at every step run a fresh search from the current grid, take
# the most visited action, and stop as soon as that action would put three
# points on one line.
while True:
    print("---------------------------")
    print(f"Number of points: {num_of_points}")
    print(state)

    if not np.any(n3il.get_valid_moves(state)):
        # Every cell is occupied and the grid is still valid: there is
        # nothing left to search for.
        print("*******************************************************************")
        print(f"Trial Terminated with {num_of_points} points. No empty cell is left.")
        break

    mcts_probs = mcts.search(state)
    action = np.argmax(mcts_probs)
    # To play by hand instead, replace the two lines above with an input()
    # prompt and validate the answer against n3il.get_valid_moves(state).

    n_of_collinear_triples = n3il.check_collinear(state, action)

    if n_of_collinear_triples > 0:
        print("*******************************************************************")
        print(f"Trial Terminated with {num_of_points} points. Final valid configuration:")
        print(state)
        print(f"The point you give causes {n_of_collinear_triples} triples of 3 points collinear:")
        # get_next_state writes into its argument; pass a copy so that
        # `state` still holds the final valid configuration after this cell.
        print(n3il.get_next_state(state.copy(), action))

        break

    num_of_points += 1
    state = n3il.get_next_state(state, action)

---------------------------
Number of points: 0
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 1
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 2
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 3
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 4
[[1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 5
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 6
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 7
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 8
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 1. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

---------------------------
Number of points: 9
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 1. 0.]]


  0%|          | 0/1000 [00:00<?, ?it/s]

*******************************************************************
Trial Terminated with 9 points. Final valid configuration:
[[1. 0. 1. 0. 0.]
 [0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 1. 0.]]
The point you give causes 2 triples of 3 points collinear:
[[1. 1. 1. 0. 0.]
 [0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 1. 0.]]
