In [27]:
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax import traverse_util
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import matplotlib.pyplot as plt
import networkx as nx

def init_board():
  return np.zeros((3, 3), dtype=int).flatten()

def flip_board(board):
  return -board

def get_valid_mask(board):
  return board == 0

# #Always assume action is made by player 1
# def get_next_state(board, action):
#   return board.flatten().at[action].add(1).reshape(board.shape)

#Always assume action is made by player 1
#Assume action is valid
def get_next_state(board, action):
  next_board = board.copy()
  next_board[action] = 1
  return next_board

def sample_action(action_dist, rng):
  action_dist = action_dist.flatten() / action_dist.sum()
  return jax.random.choice(rng, jnp.arange(action_dist.shape[0]), p=action_dist)

def disp_board(board):
  plt.imshow(board)

reward_conv = nn.Conv(features=4, kernel_size=(3, 3), use_bias=False, padding='SAME')
stripe_filter = jnp.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]])
reward_kernel = jnp.expand_dims(jnp.stack([jnp.eye(3), jnp.eye(3)[::-1, :], stripe_filter, stripe_filter.T], axis=2), 2)

reward_conv_param = {'params': {'kernel': reward_kernel}}
# print(jnp.expand_dims(jnp.eye(4), (0, 3)))

def get_reward(board):
    board_score = reward_conv.apply(reward_conv_param, jnp.expand_dims(jnp.array(board).reshape((3, 3)), (0, 3)))
    # print(jnp.transpose(board_score[0]))
    is_win = (jnp.max(board_score) >= 3).astype(int)
    is_loss = (jnp.min(board_score) <= -3).astype(int)
    score = float(is_win - is_loss)
    return score, score or (board == 0).sum() == 0

In [30]:
board = init_board()
print(board.reshape((3, 3)))

rng = jax.random.PRNGKey(40)
for i in range(10):
  _, rng = jax.random.split(rng, 2)
  action_dist = jnp.ones(board.shape) * get_valid_mask(board)
  next_action = sample_action(action_dist, rng)
  board = flip_board(get_next_state(board, next_action))
  reward, game_over = get_reward(board)

print(board, reward, game_over)

[[0 0 0]
 [0 0 0]
 [0 0 0]]
[-1  1 -1 -1  1 -1 -1  1  1] 0.0 True


In [3]:
def toy_model(state):
    return np.ones(state.shape), 0.1

In [31]:

STATE_DIM = 9 #Dimension of 4x4 tic-tac-toe board
ACTION_DIM = 9
MAX_SIZE = int(1e3)
C_BASE, C_INIT = 1.0, 1.0
class MCTS:
    def __init__(self):
        self.state = np.zeros((MAX_SIZE, STATE_DIM))
        self.state_lookup = {} #Maps state representation to index
        self.expanded = []

        self.visit_count = np.zeros(MAX_SIZE)
        self.action_visits  = np.zeros((MAX_SIZE, ACTION_DIM), dtype=int)
        self.action_total_value = np.zeros((MAX_SIZE, ACTION_DIM))
        self.action_mean_value = np.zeros((MAX_SIZE, ACTION_DIM))
        self.action_prior = np.zeros((MAX_SIZE, ACTION_DIM))
        # self.action_children = np.zeros((MAX_SIZE, ACTION_DIM), dtype=int) - 1
    
    #Assumes state is already expanded, and uses MCTS info to pick best action
    def select_action(self, state, state_index):
        state_visits = self.visit_count[state_index]
        exp_rate = np.log((1+state_visits + C_BASE)/C_BASE) + C_INIT
        model_prior = self.action_prior[state_index]
        sa_visits = self.action_visits[state_index]
        sa_mean_value = self.action_mean_value[state_index]
        action_distr = (sa_mean_value + exp_rate*np.sqrt(state_visits)*model_prior/(1+sa_visits))*get_valid_mask(state)
        return np.argmax(action_distr)

    def get_action_prob(self, state_index, temperature=1):
        """
        Select action according to the visit count distribution and the temperature.
        """
        action_visits = self.action_visits[state_index]
        if temperature == 0:
            a = np.argmax(action_visits)
            r = np.zeros(action_visits.shape)
            r[a] = 1.0
            return r
        elif temperature == float("inf"):
            return np.ones(action_visits.shape)/action_visits.shape[0]
        else:
            # See paper appendix Data Generation
            visit_count_distribution = np.power(action_visits, 1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(visit_count_distribution)
            return visit_count_distribution
            
    def expand_node(self, state, action_probs):
        state_index = len(self.expanded)
        # print("state index", state_index)
        self.expanded.append(True)
        self.state_lookup[state.tobytes()] = state_index

        self.state[state_index] = state
        # print(self.action_prior[state_index].shape, action_probs.shape)
        self.action_prior[state_index] = action_probs
        return state_index

    def search_iter(self, state_index, model):
        search_path = []
        path_actions = []

        curr_index = state_index
        curr_state = self.state[state_index]

        # Loop until reaching an untracked state
        while curr_index >= 0:
            search_path.append(curr_index)
            action = self.select_action(curr_state, curr_index)
            path_actions.append(action)

            curr_state = flip_board(get_next_state(curr_state, action))
            nsr = curr_state.tobytes()
            if(nsr in self.state_lookup):
                curr_index = self.state_lookup[nsr]
            else:
                curr_index = -1

        # The value of the new state from the perspective of the other player
        next_state = curr_state
        value, game_over = get_reward(next_state)
        value = -value
        if not game_over:
            # If the game has not ended:
            # EXPAND
            action_probs, value = model(next_state)
            valid_moves = get_valid_mask(next_state)
            action_probs = action_probs * valid_moves  # mask invalid moves
            action_probs /= np.sum(action_probs)
            self.expand_node(next_state, action_probs)
        
        #Backpropagate MCTS search path
        for i in range(len(search_path)-1, -1, -1):
            si, a = search_path[i], path_actions[i]
            self.visit_count[si] += 1
            self.action_visits[si, a] += 1
            self.action_total_value[si, a] += value
            self.action_mean_value[si, a] = self.action_total_value[si, a] / self.action_visits[si, a]
            value *= -1

        # print("search path", search_path)
        # print("search visits after update", [self.visit_count[s] for s in search_path])
        
    def mcts_eval(self, state, model, num_sims):
        # print("Expanded len", len(self.expanded))
        root_state = state
        action_prior, value_est = model(root_state)
        root_index = self.expand_node(root_state, action_prior)
        # print("Root index", root_index)
        for _ in range(num_sims):
            self.search_iter(root_index, model)
        return root_index
    
    def print_tree(self, root_index=0):
        for state_index in self.state_lookup.values():
            state = self.state[state_index]
            print(state.reshape((3, 3)))
            visited_actions = self.action_visits[state_index].nonzero()[0]
            action_values = self.action_mean_value[state_index, visited_actions]
            print("Action", visited_actions)
            print("Action value", action_values)

    def visualize_tree(self):
        G = nx.Graph()
        node_labels = {}
        edge_labels = {}
        for state_index in self.state_lookup.values():
            visited_actions = self.action_visits[state_index].nonzero()[0]
            action_values = self.action_mean_value[state_index, visited_actions]
            state = self.state[state_index]
            for action, value in zip(visited_actions, action_values):
                child_state = flip_board(get_next_state(state, action))
                # print(child_state.reshape((4,4)))
                child_index = self.state_lookup[child_state.tobytes()]
                G.add_edge(state_index, child_index)
                edge_labels[(state_index, child_index)] = value
            node_labels[state_index] = f"{state_index}, {self.action_visits[state_index].sum()}" 
        
        pos = nx.spring_layout(G)
        plt.figure()
        nx.draw(
            G, pos, edge_color='black', width=1, linewidths=1, node_size=10, 
            node_color='pink', alpha=0.9,
            # labels=node_labels
        )
        # nx.draw_networkx_edge_labels(
        #     G, pos,
        #     edge_labels=edge_labels,
        #     font_color='red'
        # )
        plt.axis('off')
        plt.show()

root_state = init_board()
root_state[1] = 1
root_state[7] = 1
print(root_state.reshape((3, 3)))
# root_state[:2] = 1
mcts = MCTS()
mcts.mcts_eval(root_state, toy_model, 50)
print("Finished mcts")
# mcts.visualize_tree()

[[0 1 0]
 [0 0 0]
 [0 1 0]]
Finished mcts


In [59]:
def self_play_episode(model, num_sims=50, temp_threshold=6):
    train_examples = []
    board = init_board()
    step = 0
    while True:
        mcts = MCTS()
        root_index = mcts.mcts_eval(board, model, num_sims=num_sims)
        # print(root_index, mcts.state[root_index].reshape((3,3)))
        # print(board.reshape((3,3)))
        # print(mcts.state[0].reshape((3,3)))
        # print(mcts.action_visits[0].reshape((3,3)))
        
        temp = int(step < temp_threshold)
        pi = mcts.get_action_prob(root_index, temperature=temp)


        train_examples.append((board, pi, step))

        action = np.random.choice(ACTION_DIM, p=pi)
        print("action", action)

        board = flip_board(get_next_state(board, action))
        #Reward is always negative because board is flipped after move. 
        r, game_over = get_reward(board) 
        if(r != 0):
            return [(b, p, (-1)**(step-s)) for b, p, s in train_examples]
        step += 1

def batch_examples(train_examples):
    state_batch = jnp.stack([t[0] for t in train_examples])
    pa_batch = jnp.stack([t[1] for t in train_examples])
    r_batch = jnp.stack([t[2] for t in train_examples]).reshape((-1, 1))
    return state_batch, pa_batch, r_batch

train_examples = self_play_episode(toy_model)
state_b, pa_b, r_b = batch_examples(train_examples)

action 3
action 2
action 8
action 6
action 7
action 0
action 1
action 4


In [80]:
from jax import jit

class TTTModel(nn.Module):
  """A simple MLP model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=64)(x)
    body = nn.Dense(features=32)(x)
    x = nn.Dense(features=9)(body)
    value = nn.tanh(nn.Dense(features=1)(body)) #Value estimate between -1 and 1
    return x, value

model = TTTModel()
rng = jax.random.PRNGKey(42)

model.apply(params, state_b)[1].shape

@jit
def model_agent(x, params):
  logits, value = TTTModel().apply(params, x)
  return nn.softmax(logits), value

model_agent(state_b, params)


(DeviceArray([[0.11111111, 0.11111111, 0.11111111, 0.11111111, 0.11111111,
               0.11111111, 0.11111111, 0.11111111, 0.11111111],
              [0.07653027, 0.09736775, 0.07992052, 0.13066143, 0.15686639,
               0.10830156, 0.10120884, 0.16185613, 0.08728711],
              [0.08857683, 0.12516937, 0.10361782, 0.09504247, 0.12469787,
               0.12033694, 0.11011213, 0.11450323, 0.11794329],
              [0.13931678, 0.117563  , 0.10381063, 0.1656441 , 0.08276459,
               0.05339085, 0.10845144, 0.11083571, 0.11822291],
              [0.07295639, 0.13741463, 0.09114759, 0.05642284, 0.19417347,
               0.15284966, 0.11017194, 0.10711964, 0.07774387],
              [0.12512428, 0.09865574, 0.10199723, 0.23405725, 0.04898063,
               0.07368465, 0.08672816, 0.13338219, 0.09738986],
              [0.08199129, 0.11197416, 0.11482912, 0.05512401, 0.2190445 ,
               0.13658142, 0.11735156, 0.05933076, 0.10377317],
              [0.12995099, 

In [71]:
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  model = TTTModel()
  params = model.init(rng, jnp.ones([1, 9]))['params']

  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step(state, state_b, pa_b, r_b):
  """Train for a single step."""
  def loss_fn(params):
    logits, exp_value = TTTModel().apply({'params': params}, state_b)
    loss = optax.softmax_cross_entropy(logits, pa_b).mean() + jnp.square(r_b - exp_value).mean()
    return loss, (logits, exp_value)
  grad_fn = jax.grad(loss_fn, has_aux=True)
  grads, aux = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state#, metrics

state = create_train_state(rng, 0.1, 0.1)
state_b, pa_b, r_b = batch_examples(train_examples)
state = train_step(state, state_b, pa_b, r_b)

In [75]:
model = TTTModel()

model.apply({'params': state.params}, board)

(DeviceArray([ 0.11060315,  0.39404774,  0.11323913,  0.15138267,
               0.03048529, -0.2978174 ,  0.09324761,  0.09060495,
               0.19323306], dtype=float32),
 DeviceArray([0.5425823], dtype=float32))

In [None]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    return x

cnn = CNN()
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']