In [1]:
import random
import torch
import torch.optim as optim

from PUCT import PUCT
from PUCTTrainer import PUCTTrainer
from TicTacToe import TicTacToe
from TicTacToeNet import TicTacToeNet

# Run-wide configuration: `seed` is fed to the RNG seeding calls and `device` is handed to the network constructor.
seed, device = 0, "cpu"

In [2]:
# Seed both Python's and PyTorch's global RNGs before any object is constructed,
# so everything created below starts from the same random state on every run.
random.seed(seed)
torch.manual_seed(seed)

game = TicTacToe()

# The network is sized from what the game reports about itself.
action_count = game.max_num_actions()
player_count = game.num_players()
predictor = TicTacToeNet(action_count, player_count, device = device)

# Adam's weight_decay term provides the L2 regularization.
optimizer = optim.Adam(
    predictor.parameters(),
    lr = 3e-3,
    weight_decay = 1e-4,
)

# NOTE(review): buffer_size looks like the replay-buffer cap (the training log tops out at 1000 samples) - confirm in PUCTTrainer.
trainer = PUCTTrainer(
    game,
    predictor,
    optimizer,
    buffer_size = 1000,
    batch_size = 32,
)

In [3]:
# Training schedule: per the progress bars in the recorded output, this runs 3 outer iterations,
# each made of 50 self-play episodes followed by 100 weight updates.
trainer.train_predictor(
    iterations = 3,
    playouts_per_node = 250,
    episodes = 50,
    gradient_updates = 100,
)

Training Predictor:   0%|          | 0/3 [00:00<?, ?it/s]

Self-Play:   0%|          | 0/50 [00:00<?, ?it/s]

Updating Weights:   0%|          | 0/100 [00:00<?, ?it/s]

[Iteration 1] Samples in Replay Buffer: 361 - Policy Loss: 1.23 - Value Loss: 0.325 - Total Loss: 1.56


Self-Play:   0%|          | 0/50 [00:00<?, ?it/s]

Updating Weights:   0%|          | 0/100 [00:00<?, ?it/s]

[Iteration 2] Samples in Replay Buffer: 769 - Policy Loss: 0.92 - Value Loss: 0.208 - Total Loss: 1.13


Self-Play:   0%|          | 0/50 [00:00<?, ?it/s]

Updating Weights:   0%|          | 0/100 [00:00<?, ?it/s]

[Iteration 3] Samples in Replay Buffer: 1000 - Policy Loss: 0.827 - Value Loss: 0.159 - Total Loss: 0.987


In [4]:
# Put the network into evaluation mode before querying it.
predictor.eval()

# Gradients are not needed for a read-only query, so the forward pass runs under no_grad.
with torch.no_grad():
    # Each argument is wrapped in a one-element list; presumably predict() expects batched inputs - confirm in TicTacToeNet.
    policy_logits, values = predictor.predict(
        [game.current_player()],
        [game.get_state()],
        [game.possible_actions()],
    )
    # Drop the leading dimension and round for display; `values` is not used in this cell.
    rounded_logits = [round(logit, 2) for logit in policy_logits.squeeze(dim = 0).tolist()]

print("Policy logits for first action in TicTacToe:", rounded_logits)

Policy logits for first action in TicTacToe: [-0.92, -0.79, -0.72, -0.82, 2.42, -1.84, -1.07, -0.84, 2.73]


In [5]:
# Build a fresh search tree around the trained predictor and let it play a full game against itself.
tree = PUCT(game, predictor)

# render = True prints the board after each move (see the recorded output).
# The call is the cell's last expression, so its return value is displayed as the cell output.
tree.self_play(
    playouts_per_node = 250,
    training = False,
    render = True,
)

 | | 
-----
 | | 
-----
 | | 

 | | 
-----
 |X| 
-----
 | | 

 | | 
-----
 |X| 
-----
 | |O

 | | 
-----
 |X| 
-----
 |X|O

 |O| 
-----
 |X| 
-----
 |X|O

 |O|X
-----
 |X| 
-----
 |X|O

 |O|X
-----
 |X| 
-----
O|X|O

 |O|X
-----
X|X| 
-----
O|X|O

 |O|X
-----
X|X|O
-----
O|X|O

X|O|X
-----
X|X|O
-----
O|X|O



([[0,
   ((0, 0, 0), (0, 0, 0), (0, 0, 0)),
   [0, 1, 2, 3, 4, 5, 6, 7, 8],
   ([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 1, 2, 5, 239, 0, 0, 0, 0])],
  [1,
   ((0, 0, 0), (0, -1, 0), (0, 0, 0)),
   [0, 1, 2, 3, 5, 6, 7, 8],
   ([0, 1, 2, 3, 5, 6, 7, 8], [2, 2, 3, 13, 10, 9, 1, 448])],
  [0,
   ((0, 0, 0), (0, 1, 0), (0, 0, -1)),
   [0, 1, 2, 3, 5, 6, 7],
   ([0, 1, 2, 3, 5, 6, 7], [8, 52, 52, 15, 97, 8, 465])],
  [1,
   ((0, 0, 0), (0, -1, 0), (0, -1, 1)),
   [0, 1, 2, 3, 5, 6],
   ([0, 1, 2, 3, 5, 6], [12, 689, 3, 5, 3, 2])],
  [0,
   ((0, -1, 0), (0, 1, 0), (0, 1, -1)),
   [0, 2, 3, 5, 6],
   ([0, 2, 3, 5, 6], [23, 447, 220, 236, 12])],
  [1,
   ((0, 1, -1), (0, -1, 0), (0, -1, 1)),
   [0, 3, 5, 6],
   ([0, 3, 5, 6], [10, 4, 13, 669])],
  [0,
   ((0, -1, 1), (0, 1, 0), (-1, 1, -1)),
   [0, 3, 5],
   ([0, 3, 5], [177, 563, 178])],
  [1, ((0, 1, -1), (-1, -1, 0), (1, -1, 1)), [0, 5], ([0, 5], [4, 808])],
  [0, ((0, -1, 1), (1, 1, -1), (-1, 1, -1)), [0], ([0], [1057])]],
 [0, 0])

In [6]:
# Start from a new search tree and play a game with a human as player 0.
tree = PUCT(game, predictor)

# Interactive cell: the recorded output shows the legal actions followed by a `>` prompt for every human move,
# so this call waits for typed input when the notebook is run.
tree.human_play(
    human_player = 0,
    playouts_per_node = 250,
    training = False,
    render = True,
)

 | | 
-----
 | | 
-----
 | | 

Possible actions: [0, 1, 2, 3, 4, 5, 6, 7, 8]


>  4


 | | 
-----
 |X| 
-----
 | | 

 | | 
-----
 |X| 
-----
 | |O

Possible actions: [0, 1, 2, 3, 5, 6, 7]


>  2


 | |X
-----
 |X| 
-----
 | |O

 | |X
-----
 |X| 
-----
O| |O

Possible actions: [0, 1, 3, 5, 7]


>  7


 | |X
-----
 |X| 
-----
O|X|O

 |O|X
-----
 |X| 
-----
O|X|O

Possible actions: [0, 3, 5]


>  3


 |O|X
-----
X|X| 
-----
O|X|O

 |O|X
-----
X|X|O
-----
O|X|O

Possible actions: [0]


>  0


X|O|X
-----
X|X|O
-----
O|X|O



([[0,
   ((0, 0, 0), (0, 0, 0), (0, 0, 0)),
   [0, 1, 2, 3, 4, 5, 6, 7, 8],
   ([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 1, 2, 10, 3, 1, 230, 0, 0])],
  [1,
   ((0, 0, 0), (0, -1, 0), (0, 0, 0)),
   [0, 1, 2, 3, 5, 6, 7, 8],
   ([0, 1, 2, 3, 5, 6, 7, 8], [2, 1, 5, 6, 1, 11, 1, 225])],
  [0,
   ((0, 0, 0), (0, 1, 0), (0, 0, -1)),
   [0, 1, 2, 3, 5, 6, 7],
   ([0, 1, 2, 3, 5, 6, 7], [7, 21, 257, 5, 44, 8, 132])],
  [1,
   ((0, 0, -1), (0, -1, 0), (0, 0, 1)),
   [0, 1, 3, 5, 6, 7],
   ([0, 1, 3, 5, 6, 7], [3, 6, 3, 13, 479, 2])],
  [0,
   ((0, 0, 1), (0, 1, 0), (-1, 0, -1)),
   [0, 1, 3, 5, 7],
   ([0, 1, 3, 5, 7], [1, 1, 1, 12, 713])],
  [1,
   ((0, 0, -1), (0, -1, 0), (1, -1, 1)),
   [0, 1, 3, 5],
   ([0, 1, 3, 5], [3, 940, 4, 15])],
  [0,
   ((0, -1, 1), (0, 1, 0), (-1, 1, -1)),
   [0, 3, 5],
   ([0, 3, 5], [179, 472, 538])],
  [1, ((0, 1, -1), (-1, -1, 0), (1, -1, 1)), [0, 5], ([0, 5], [14, 707])],
  [0, ((0, -1, 1), (1, 1, -1), (-1, 1, -1)), [0], ([0], [956])]],
 [0, 0])