In [5]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import tictactoe
importlib.reload(tictactoe);
from tictactoe import TicTacToe


In [45]:
#
# Hyperparameters
#
alpha = 0.2    # learning rate handed to the SGD optimizer
gamma = 0.9    # discount applied to the bootstrapped next-state value in training
epsilon = 0.2  # a uniform draw below this value triggers a random (exploratory) move

In [46]:
#
# Q-network and optimizer
#
# Two-layer perceptron: 27 input features -> 81 hidden units (ReLU) -> 9 outputs.
# The 9 outputs are indexed by board square (0..8) wherever a move is selected.
# NOTE(review): 27 presumably is 9 squares x 3 one-hot states -- confirm in tictactoe.py.
model = nn.Sequential(nn.Linear(27, 81), nn.ReLU(), nn.Linear(81, 9))

# Plain SGD; the step size is the `alpha` hyperparameter.
optimizer = torch.optim.SGD(model.parameters(), lr=alpha)

# Running count of training games played with this model (incremented by the
# training loop; re-running this cell resets it together with the model).
games = 0


In [None]:
#
# Full validation
#
def validategame(env, results):
    """Exhaustively play every opponent line against the greedy model.

    Starting from the position currently held by ``env`` (with the opponent
    to move), every valid opponent move is tried in turn.  After each one the
    model replies with the valid square that has the highest predicted value,
    and the search recurses until the game ends.  Each finished line is
    tallied from the model's point of view.

    Args:
        env: game environment exposing ``state``, ``board``, ``player``,
            ``opponent``, ``stateTensor``, ``is_valid``, ``move``, ``is_won``
            and ``is_full``.  It is mutated while searching and restored to
            the entry position before returning.
        results: dict with integer counters under ``'wins'``, ``'draws'`` and
            ``'losses'``; updated in place.

    Returns:
        None.  The outcome is reported through ``results``.
    """
    # Snapshot of the entry position so each explored branch can be undone.
    # NOTE(review): the restore below only works if ``env.state`` is not the
    # live ``env.board`` list itself -- confirm in tictactoe.py.
    state, player, opponent = env.state, env.player, env.opponent
    for action in [a for a in range(9) if env.is_valid(a)]:
        env.move(action)
        if env.is_won():
            # The game ended on the opponent's move.
            results['losses'] += 1
        elif env.is_full():
            results['draws'] += 1
        else:
            # Pure inference: disable autograd so the (large) game-tree walk
            # does not build and discard a gradient graph at every node.
            with torch.no_grad():
                q = model(env.stateTensor)
            qa = max([a for a in range(9) if env.is_valid(a)], key=lambda x: q[x])
            env.move(qa)
            if env.is_won():
                # The game ended on the model's reply.
                results['wins'] += 1
            elif env.is_full():
                results['draws'] += 1
            else:
                validategame(env, results)
        # Undo this branch before trying the next opponent move.
        env.board, env.player, env.opponent = list(state), player, opponent

def _print_validation_summary(label, results):
    """Print one line with the non-loss rate and the win/draw/loss tallies."""
    nonloss = results['wins'] + results['draws']
    total = results['losses'] + nonloss
    print(f"{label}: {100*nonloss/total:.2f}% of {total} ({results['wins']}/{results['draws']}/{results['losses']})")


def validate():
    """Evaluate the model against every possible opponent line and print a summary.

    Two exhaustive searches are run with ``validategame``:

    * "Cross":  the model makes the first move of the game (its highest-valued
      valid square), then every opponent reply is explored.
    * "Circle": the opponent makes the first move, so the search starts from
      the empty board.

    For each, one line is printed with the percentage of finished games the
    model did not lose, the number of games, and the wins/draws/losses.

    The model is switched to eval mode and autograd is disabled for the
    duration; the previous training mode is restored afterwards, even if the
    evaluation raises.
    """
    train = model.training
    model.eval()
    try:
        # Evaluation is inference only, so no gradient graph is needed.
        with torch.no_grad():
            # Model opens the game.
            env = TicTacToe()
            q = model(env.stateTensor)
            qa = max([a for a in range(9) if env.is_valid(a)], key=lambda x: q[x])
            env.move(qa)
            results = {'wins': 0, 'losses': 0, 'draws': 0}
            validategame(env, results)
            _print_validation_summary("Cross", results)

            # Opponent opens the game.
            env = TicTacToe()
            results = {'wins': 0, 'losses': 0, 'draws': 0}
            validategame(env, results)
            _print_validation_summary("Circle", results)
    finally:
        # Put the model back the way the caller had it.
        if train:
            model.train()

# Run the exhaustive evaluation once for the model in its current state
# (i.e. before training, if the cells are executed top to bottom).
validate()

In [None]:
#
# TRAINING
#
# Self-play Q-learning: the same network chooses the move at every turn of a
# game, a per-move MSE loss is accumulated, and one SGD step is taken per game.
# `games` comes from the model cell and is not reset here, so it keeps counting
# across re-runs of this cell.
#
log_interval = 10000  # print the loss and run validate() every this many games
losses = []           # mean per-move loss of every game played by this cell run
model.train()
for _ in range(1000000):
    env = TicTacToe()
    done = False
    games += 1
    loss = 0
    moves = 0

    while not done:
        q = model(env.stateTensor)
        # Detached copy of the predictions: only the entry of the chosen action
        # is overwritten below, so every other output contributes zero error.
        targetq = q.detach().clone()

        # Epsilon-greedy move selection.
        e = random.uniform(0, 1)
        if moves == 0:
            # Halving the draw on the opening move doubles the chance that it
            # falls below epsilon, i.e. the first move is explored more often.
            e /= 2
        if  e < epsilon:
            action = random.choice([a for a in range(9) if env.is_valid(a)])
        else:
            action = max([a for a in range(9) if env.is_valid(a)], key=lambda x: q[x])
        
        env.move(action)

        if env.is_won():
            # The move just played ended the game with a win: target value 1.
            targetq[action] = 1
            done = True
        elif env.is_full():
            # Board is full without a win: target value 0.
            targetq[action] = 0
            done = True
        else:
            # Bootstrap from the position after the move.  The eval()/train()
            # toggle has no effect on the Linear/ReLU layers defined in this
            # notebook; no_grad keeps the target out of the autograd graph.
            model.eval()
            with torch.no_grad():
                next_q = model(env.stateTensor)
                # Best value available in the next position, negated.
                # NOTE(review): presumably stateTensor is encoded from the
                # perspective of the side to move, making this a negamax-style
                # backup -- confirm in tictactoe.py.
                next_max = -max([next_q[a] for a in range(9) if env.is_valid(a)])
                # Constant -0.1 is added for every non-terminal move.
                targetq[action] = -0.1 + gamma * next_max
            model.train()
            
        loss += F.mse_loss(q, targetq)
        moves += 1

    # Average the accumulated loss over the moves of this game.
    loss /= moves
    losses.append(loss.item())

    # One gradient step per finished game.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        # NOTE(review): `losses` is never cleared, so this is the average over
        # all games since the cell started, not over the last log_interval games.
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        validate()


In [158]:
# Fresh board for the manual play cells that follow.
env = TicTacToe()

In [None]:
# Manually play square 4 (the middle index of 0..8) for the side to move,
# then call render() to show the resulting position.
env.move(4)
env.render()

In [None]:
# Let the network answer: score the current position, take the valid square
# with the highest predicted value, play it, and show which square was chosen.
state = env.stateTensor
q = model(state)
legal_moves = [a for a in range(9) if env.is_valid(a)]
action = max(legal_moves, key=lambda x: q[x])
env.move(action)
print(action)
env.render()