In [1]:
%load_ext autoreload
%autoreload 2
%pylab inline
import tensorflow as tf
import numpy as np
import gym
from tqdm import tqdm, trange
import os, sys
sys.path.append(os.getcwd())

Populating the interactive namespace from numpy and matplotlib


In [2]:
class TicTacToe():
    """Minimal gym-style tic-tac-toe environment used to exercise MuZero.

    State layout (a plain list of 11 ints):
      * ``state[0:9]`` -- board cells in row-major order: 0 empty, 1 / -1 for
        the two players.
      * ``state[-2]`` -- written by ``dynamics`` with ``value(...)`` of the
        position after a move; non-zero once a line has been completed, after
        which further moves are ignored.
      * ``state[-1]`` -- the player to move (1 or -1); player 1 starts.
    """

    # The eight winning lines: three rows, three columns, two diagonals.
    _LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )

    def __init__(self, state=None):
        """Start from an empty board, or from ``state`` when one is given."""
        self.reset()
        if state is not None:
            self.state = state

    def reset(self):
        """Clear the board, give player 1 the move and return the new state."""
        self.done = False
        self.state = [0] * 11
        self.state[-1] = 1
        return self.state

    class observation_space():
        shape = (11,)

    class action_space():
        n = 9

    def render(self):
        """Print whose turn it is, then the board as a 3x3 grid.

        The turn is printed so positions the model considers lost can be
        inspected from the right player's point of view.
        """
        print("turn %d" % self.state[-1])
        print(np.array(self.state[0:9]).reshape(3, 3))

    def value(self, s):
        """Return the winner of ``s`` multiplied by the player to move ``s[-1]``.

        The winner is 1 or -1 if that player owns a complete line, else 0. If
        both players somehow own a line, player 1 takes precedence.

        NOTE: this is not a game-theoretic value; it only reports whether the
        position is already won.
        """
        winner = 0
        for turn in (-1, 1):
            if any(all(s[i] == turn for i in line) for line in self._LINES):
                winner = turn
        return winner * s[-1]

    def dynamics(self, s, act):
        """Apply action ``act`` to a copy of state ``s``; ``self`` is untouched.

        Returns ``(reward, next_state)``:
          * empty cell, unfinished game: the mover's mark is placed and the
            reward is ``value`` of the new position (1 when the mover has just
            completed a line, otherwise 0);
          * occupied cell, unfinished game: reward -10, board unchanged;
          * any move once ``s[-2]`` is non-zero (game over): reward 0, board
            unchanged.
        In every case the turn passes to the other player.
        """
        rew = 0
        s = s.copy()
        if s[act] != 0 or s[-2] != 0:
            # don't move in taken spots or in finished games
            rew = -10
        else:
            s[act] = s[-1]
            rew += self.value(s)
        if s[-2] != 0:
            # the game was already over: moves earn nothing
            rew = 0
        else:
            # record whether this move finished the game
            s[-2] = self.value(s)
        s[-1] = -s[-1]
        return rew, s

    def step(self, act):
        """Play ``act`` on the live game; return ``(state, reward, done, None)``.

        The episode ends on any non-zero reward (a win or an illegal move) or
        when all nine cells are occupied. Once set, ``done`` stays True.
        """
        rew, self.state = self.dynamics(self.state, act)
        board_full = np.all(np.array(self.state[0:9]) != 0)
        if rew != 0 or board_full:
            self.done = True
        return self.state, rew, self.done, None

In [3]:
# Scripted demo: player 1 opens in the centre and player -1 ends up completing
# the top row, so the last step reports reward 1 and done=True.
env = TicTacToe()
print(env.reset())
for move in (4, 0, 3, 2, 6, 1):
    print(env.step(move))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -1], 0, False, None)
([-1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], 0, False, None)
([-1, 0, 0, 1, 1, 0, 0, 0, 0, 0, -1], 0, False, None)
([-1, 0, -1, 1, 1, 0, 0, 0, 0, 0, 1], 0, False, None)
([-1, 0, -1, 1, 1, 0, 1, 0, 0, 0, -1], 0, False, None)
([-1, -1, -1, 1, 1, 0, 1, 0, 0, 1, 1], 1, True, None)


In [4]:
class MockModel():
    """Mock of ``MuModel`` that uses the real game instead of learned networks.

    It exists to make it easier to understand how ``MuModel`` is called:
      * ``ht(s)`` returns the observation unchanged,
      * ``gt(s, a)`` returns ``env.dynamics(s, a)``, i.e. ``(reward, next_state)``,
      * ``ft(s)`` returns a uniform policy over the 9 actions and ``env.value(s)``.

    Args:
        env: game providing ``dynamics`` and ``value``. When omitted, the
            notebook-global ``env`` is used, looked up at call time as before.
    """

    def __init__(self, env=None):
        self._env = env

    def _game(self):
        # Resolve lazily so a MockModel built before `env` exists still works.
        return self._env if self._env is not None else env

    def ht(self, s):
        return s

    def gt(self, s, a):
        return self._game().dynamics(s, a)

    def ft(self, s):
        # `self` was previously missing, so `model.ft(s)` raised TypeError.
        return [1/9]*9, self._game().value(s)

In [5]:
# Example position in TicTacToe's 11-slot layout: nine board cells (row-major),
# the game-over slot (0: no line completed) and the player to move (-1).
# The previous literal had only 10 entries, so obs[-2] aliased board cell 8,
# and env.dynamics would have overwritten that cell with the game-over flag.
obs = [-1, -1, 0, 1, 1, 0, 1, 0, 0, 0, -1]

In [6]:
# First pass shows the board left over from the demo game; the second shows
# the board after a reset.
for needs_reset in (False, True):
    if needs_reset:
        env.reset()
    env.render()

turn 1
[[-1 -1 -1]
 [ 1  1  0]
 [ 1  0  0]]
turn 1
[[0 0 0]
 [0 0 0]
 [0 0 0]]


In [7]:
from muzero.model import MuModel
from muzero.game import Game, ReplayBuffer
from muzero.mcts import naive_search, mcts_search

# Model sized from the environment's observation shape and action count.
m = MuModel(
    env.observation_space.shape,
    env.action_space.n,
    s_dim=64,
    K=5,
    lr=0.001,
)
print(env.observation_space.shape, env.action_space.n)

# Replay buffer built with the model's unroll length m.K.
replay_buffer = ReplayBuffer(50, 128, m.K)
# Total reward of every self-play game; filled by the training loop.
rews = []

(11,) 9


In [8]:
def play_game(env, m, num_simulations=30, discount=0.99):
    """Play one self-play episode of ``env``, choosing moves with MCTS over ``m``.

    Args:
        env: environment wrapped by ``Game`` (e.g. a ``TicTacToe``).
        m: model passed to ``mcts_search``.
        num_simulations: third positional argument of ``mcts_search``;
            presumably the number of simulations per move -- confirm in
            ``muzero.mcts``. Defaults to the previously hard-coded 30.
        discount: discount factor given to ``Game`` (previously fixed at 0.99).

    Returns:
        The finished ``Game``.
    """
    game = Game(env, discount=discount)
    while not game.terminal():
        # TODO: Do we need to limit the depth of the MCTS search?
        # A cheaper baseline is naive_search(m, game.observation, T=1).
        policy, _ = mcts_search(m, game.observation, num_simulations)
        game.act_with_policy(policy)
    return game

In [None]:
# NOTE(review): reformat_batch and collections are not used in this cell.
from muzero.model import reformat_batch
import collections

# Alternate self-play and learning: 30 rounds, each playing 10 games into the
# replay buffer and then making 10 calls to m.train_on_batch.
for j in range(30):
    # Self-play: store each finished game and record its total reward for plotting.
    for i in range(10):
        game = play_game(env, m)
        replay_buffer.save_game(game)
        rew = sum(game.rewards)
        rews.append(rew)
    # Learning: one update per batch sampled from the replay buffer.
    for i in range(10):
        m.train_on_batch(replay_buffer.sample_batch())
    # Progress line for the round's last game: its length, total reward, move
    # sequence, and m.losses[-1][0] (presumably the total loss -- confirm in MuModel).
    print(len(game.history), rew, game.history, m.losses[-1][0])

7 1 [4, 6, 2, 1, 8, 3, 0] 120.75796508789062
2 -10 [4, 4] 87.0246810913086
2 -10 [4, 4] 41.70781326293945
2 -10 [4, 4] 191.06112670898438
3 -10 [4, 7, 7] 271.1387023925781
2 -10 [4, 4] 25091.55078125
3 -10 [4, 2, 4] 197.02207946777344
5 1 [0, 7, 6, 4, 3] 203.1337127685547
4 -10 [3, 7, 4, 4] 178.56204223632812
5 1 [6, 3, 4, 0, 2] 161.0956573486328
3 -10 [6, 4, 4] 158.32313537597656
3 -10 [4, 7, 7] 123.99928283691406
4 -10 [2, 7, 5, 7] 138.25648498535156
3 -10 [0, 2, 0] 131.96624755859375
3 -10 [4, 7, 7] 117.0551528930664


In [None]:
# Total reward of each self-play game (rews holds sum(game.rewards) per game).
fig_rew, ax_rew = plt.subplots()
ax_rew.plot(rews)
ax_rew.set(title="Total reward per self-play game", xlabel="game", ylabel="sum of rewards")

# Selected columns of m.losses on a log scale. Column 0 is the loss printed by
# the training loop; the meaning of columns 1 and -3 is defined by MuModel (confirm there).
fig_loss, ax_loss = plt.subplots()
ax_loss.set_yscale("log")
for col in (0, 1, -3):
    ax_loss.plot([x[col] for x in m.losses], label="losses[%d]" % col)
ax_loss.set(title="Training losses", xlabel="entry in m.losses", ylabel="loss")
ax_loss.legend();

In [None]:
from functools import lru_cache
@lru_cache(maxsize=None)
def _minimax_cached(key):
    """Memoized minimax core over hashable tuple states; see ``minimax``."""
    s = list(key)  # env.dynamics calls s.copy(), which tuples do not have
    result = env.value(s)
    if result != 0:
        # value() is relative to the player to move; multiplying by s[-1]
        # again yields the absolute winner (1 or -1).
        return s[-1] * result
    # Only try empty cells: env.dynamics answers an illegal move by merely
    # passing the turn, which would otherwise recurse forever.
    legal = [a for a in range(9) if s[a] == 0]
    if not legal:
        return 0  # full board and nobody has a line: draw
    children = [_minimax_cached(tuple(env.dynamics(s, a)[1])) for a in legal]
    # Player 1 maximizes, player -1 minimizes.
    return max(children) if s[-1] == 1 else min(children)


def minimax(s):
    """Game-theoretic value of tic-tac-toe state ``s`` under perfect play.

    Returns 1 if player 1 can force a win, -1 if player -1 can, 0 for a draw.

    ``s`` may be a list (as returned by ``env.reset()``) or any sequence. It is
    converted to a tuple so results can be memoized with ``lru_cache``, which
    requires hashable arguments.
    """
    return _minimax_cached(tuple(s))

In [None]:
# Evaluate the opening position: reset() gives an empty board with player 1 to move.
s = env.reset()
opening_value = minimax(s)
opening_value