In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import tictactoe
importlib.reload(tictactoe);
from tictactoe import TicTacToe


In [2]:
#
# Hyper parameters
#
alpha = 0.2    # AdamW learning rate. NOTE(review): 0.2 is very large for Adam-style optimizers — confirm it is intended
gamma = 0.9    # discount applied to the bootstrapped next-state value in the TD target
epsilon = 0.3  # probability of playing a uniformly random legal move during training

# Seed every RNG the notebook touches so a Restart & Run All reproduces the
# same weight initialisation, exploration moves and validation games.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
#
# Build the Q-network and its optimizer
#
# Layer widths. NOTE(review): the 27 inputs presumably encode the 9 squares x 3
# possible contents in TicTacToe.stateTensor — confirm. The 9 outputs are one
# Q-value per square (callers index them with the same 0-8 move numbers as env.move).
N_INPUTS, N_HIDDEN, N_ACTIONS = 27, 81, 9
model = nn.Sequential(
    nn.Linear(N_INPUTS, N_HIDDEN), nn.ReLU(),
    nn.Linear(N_HIDDEN, N_HIDDEN), nn.ReLU(),
    nn.Linear(N_HIDDEN, N_ACTIONS),
)

# `alpha` from the hyper-parameter cell is used as the AdamW learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=alpha)

# Running count of training games; the training cell increments it and uses
# it to schedule logging and validation.
games = 0

In [12]:
#
# Validation
#
def validate(iterations):
    """Evaluate the greedy Q-network against a uniformly random opponent.

    Plays ``iterations`` fresh games. In each game the network is randomly
    assigned to one side and always plays its highest-Q legal move. The other
    side plays a uniformly random legal move. Prints the percentage of games
    not lost (wins + draws) followed by the raw ``wins/draws/losses`` counts.
    The model's previous train/eval mode is restored afterwards, even if a
    game raises.

    Args:
        iterations: number of games to play. A value <= 0 plays no games and
            reports 0.00% instead of dividing by zero.

    Returns:
        ``(wins, draws, losses)`` counts from the network's point of view.
    """
    was_training = model.training
    model.eval()
    # `defeats` rather than `losses`, so it is not confused with the training
    # cell's list of loss values.
    wins = draws = defeats = 0
    try:
        # Evaluation only: skip autograd bookkeeping for every forward pass.
        with torch.no_grad():
            for _ in range(iterations):
                env = TicTacToe()
                qplayer = random.choice([env.player, env.opponent])
                while True:
                    valid = [a for a in range(9) if env.is_valid(a)]
                    if qplayer == env.player:
                        q = model(env.stateTensor)
                        action = max(valid, key=lambda a: q[a])
                    else:
                        action = random.choice(valid)
                    env.move(action)
                    if env.is_won():
                        # Presumably move() swaps env.player/env.opponent, so
                        # the side that just moved (and won) is now env.opponent.
                        if qplayer == env.opponent:
                            wins += 1
                        else:
                            defeats += 1
                        break
                    if env.is_full():
                        draws += 1
                        break
    finally:
        if was_training:
            model.train()

    score = 100 * (wins + draws) / iterations if iterations > 0 else 0.0
    print(f'Result: {score:.2f}% ({wins}/{draws}/{defeats})')
    return wins, draws, defeats

In [None]:
#
# TRAINING
#
# Self-play Q-learning: one network chooses the moves for both sides. At each
# step only the Q-value of the move actually played gets a new target. The
# per-move MSE losses of a game are averaged and applied in one optimizer step.
# `games` (initialised in the model cell) keeps counting across re-runs of
# this cell, so the logging/validation cadence follows the running total.
n_games = 2_000_000      # self-play games per run of this cell
log_interval = 5000      # print the mean per-game loss every `log_interval` games
eval_interval = 10000    # run validate() every `eval_interval` games
eval_iterations = 10000  # games played per validate() call
step_reward = -0.1       # immediate reward for a move that does not end the game
losses = []
model.train()
for _ in range(n_games):
    env = TicTacToe()
    done = False
    games += 1
    loss = 0
    moves = 0

    while not done:
        state = env.stateTensor
        q = model(state)
        # The target equals the prediction everywhere except the played move,
        # so only that move's Q-value receives gradient from the MSE.
        targetq = q.detach().clone()
        valid = [a for a in range(9) if env.is_valid(a)]

        # Epsilon-greedy behaviour policy (random.uniform(0, 1) kept so the
        # RNG stream matches earlier runs).
        if random.uniform(0, 1) < epsilon:
            action = random.choice(valid)
        else:
            action = max(valid, key=lambda a: q[a])

        env.move(action)

        if env.is_won():
            targetq[action] = 1  # the move just played won the game
            done = True
        elif env.is_full():
            targetq[action] = 0  # board full without a winner: draw
            done = True
        else:
            # Negamax bootstrap: the next position is evaluated for the
            # opponent, so its best legal Q-value is negated.
            # NOTE(review): assumes stateTensor is encoded from the
            # side-to-move's perspective — confirm in TicTacToe.
            model.eval()
            with torch.no_grad():
                next_q = model(env.stateTensor)
                next_max = -max(next_q[a] for a in range(9) if env.is_valid(a))
            model.train()
            targetq[action] = step_reward + gamma * next_max

        loss += F.mse_loss(q, targetq)
        moves += 1

    loss /= moves
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []
    if games % eval_interval == 0:
        print(f'{games}: validating...')
        validate(eval_iterations)


In [158]:
# Start a new game for playing against the trained model by hand.
env = TicTacToe()

In [None]:
# Manual move on square 4 (presumably the centre if squares are numbered 0-8
# row-major — confirm in TicTacToe), then render the board.
env.move(4)
env.render()

In [None]:
# Let the network reply to the current position with its greedy (highest-Q) legal move.
state = env.stateTensor
valid = [a for a in range(9) if env.is_valid(a)]
if not valid:
    # max() over an empty list would raise ValueError; report instead.
    print('No valid moves left')
else:
    with torch.no_grad():  # inference only: no autograd graph needed
        q = model(state)
    action = max(valid, key=lambda x: q[x])
    env.move(action)
    print(action)
    env.render()