In [166]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [141]:
#
# Hyper parameters
#
alpha = 0.4    # learning rate handed to the SGD optimizer (lr = alpha)
gamma = 0.9    # discount applied to the next state's value when building the training target
epsilon = 0.3  # exploration rate: a uniform draw below this picks a random valid move in training

In [135]:
#
# Create the model and optimizer
#
# Fully connected network: 126 inputs -> 882 hidden units (ReLU) -> 7 outputs.
# The training loop indexes the 7 outputs by action and treats them as Q-values.
# NOTE(review): 126 presumably equals the length of Connect4().state -- confirm in connect4.py.
model = nn.Sequential(
    nn.Linear(126, 882),
    nn.ReLU(),
    nn.Linear(882, 7)
)
# Count of games played so far; the training loop increments it once per game and
# it is interpolated into the checkpoint file names.
games = 0

In [142]:
# Plain SGD over all model parameters, with the learning rate taken from `alpha`.
optimizer = torch.optim.SGD(model.parameters(), lr = alpha)

In [None]:
#
# Load model from checkpoint
#
# Restores the model and optimizer state from 'connect4-{games}.nn', so `games`
# must already hold the game count of the checkpoint to resume from.
# The loaded dict is deliberately NOT named `checkpoint`: that name belongs to the
# checkpoint(step) function defined further down, and rebinding it here would make
# the training loop's checkpoint(games) call fail if this cell is run afterwards.
saved_state = torch.load(f'connect4-{games}.nn')
model.load_state_dict(saved_state['model_state_dict'])
optimizer.load_state_dict(saved_state['optimizer_state_dict']);

In [145]:
@torch.no_grad()
def checkpoint(step):
    """Save a training checkpoint and an ONNX export of the model.

    Writes two files into the current working directory:
      * 'connect4-{step}.nn'   -- dict holding the model and optimizer state dicts
      * 'connect4-{step}.onnx' -- ONNX export of the model, traced with the state
                                  of a freshly constructed Connect4 game as input

    The model is put into eval mode for the duration of the save and switched back
    to training mode afterwards if it was training on entry -- also when saving or
    exporting raises, so a failed checkpoint never leaves the model in eval mode.

    Args:
        step: value interpolated into both file names and the progress messages
            (the training loop passes the number of games played).
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        # The export needs an example input; a new game's state is used for that.
        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, f"connect4-{step}.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        # Restore the caller's mode even if torch.save / torch.onnx.export failed.
        if was_training:
            model.train()


In [167]:
#
# Validation
#
@torch.no_grad()
def validate(games):
    """Run connect4.validate on the current model and print a one-line summary.

    Args:
        games: forwarded unchanged as the second argument of connect4.validate.

    Prints the share of non-lost games (wins + draws) as a percentage of all
    counted games, followed by the raw wins/draws/losses figures. Gradient
    tracking is disabled for the whole call.
    """
    print('Validation...')
    outcome = connect4.validate(model, games)
    wins, draws, losses = outcome
    total = wins + draws + losses
    nonloss_pct = 100*(wins+draws)/total
    print(f"Result: {nonloss_pct:.2f}% of {total} ({wins}/{draws}/{losses})")

@torch.no_grad()
def validate_full(loginterval = None):
    """Run connect4.validate_full on the current model and print per-side results.

    Args:
        loginterval: forwarded unchanged to connect4.validate_full as the
            `loginterval` keyword argument.

    connect4.validate_full returns a pair of tallies (cross, circle), each indexed
    by 'wins', 'draws' and 'losses'. For each tally one line is printed: the
    percentage of non-lost games, the total, and the raw counts. Gradient tracking
    is disabled for the whole call.
    """
    print('Full validation...')
    cross, circle = connect4.validate_full(model, loginterval=loginterval)
    for label, tally in (('Cross', cross), ('Circle', circle)):
        nonloss = tally['wins'] + tally['draws']
        total = tally['losses'] + nonloss
        print(f"{label}: {100*nonloss/total:.2f}% of {total} ({tally['wins']}/{tally['draws']}/{tally['losses']})")

In [170]:
#
# TRAINING
#
# Self-play Q-learning. Each outer iteration plays one full game: at every move the
# model's 7 outputs are read as Q-values, an epsilon-greedy action is taken, and a
# target vector is built that differs from the prediction only at the chosen action.
# The per-move MSE losses are averaged over the game and one SGD step is taken per game.
log_interval = 10000          # games between average-loss printouts
validation_interval = 20000   # games between validate() runs
validation_games = 20000      # value passed to validate() on each run
checkpoint_interval = 50000   # games between checkpoint() calls
losses = []                   # per-game losses collected since the last printout
model.train()
for _ in range(500000):
    env = Connect4()
    done = False
    games += 1
    loss = 0
    moves = 0

    while not done:
        q = model(env.state)
        # Start from the prediction itself so only the chosen action contributes to the loss.
        targetq = q.detach().clone()

        # Epsilon-greedy choice restricted to the currently valid actions.
        # (Shelved experiment: halving `e` on the first move, i.e. exploring more
        # often at the opening.)
        e = random.uniform(0, 1)
        valid_actions = [a for a in range(7) if env.is_valid(a)]
        if e < epsilon:
            action = random.choice(valid_actions)
        else:
            action = max(valid_actions, key=lambda x: q[x])

        env.move(action)

        if env.winner != 0:
            # The move just played ended the game with a winner.
            targetq[action] = 1
            done = True
        elif env.full:
            # Board is full and nobody won.
            targetq[action] = 0
            done = True
        else:
            model.eval()
            with torch.no_grad():
                next_q = model(env.state)
                # Scan the same 7 actions the model emits; the previous range(9)
                # probed two actions that have no Q-value.
                # The best next-state value is negated -- presumably because
                # env.state is now seen from the opponent's side; confirm against
                # connect4.py.
                next_max = -max([next_q[a] for a in range(7) if env.is_valid(a)])
                # A fixed -0.1 is added to the target of every non-terminal move.
                targetq[action] = -0.1 + gamma * next_max
            model.train()

        loss += F.mse_loss(q, targetq)
        moves += 1

    # Average over the game's moves (at least one move is always played).
    loss /= moves
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []
    if games % validation_interval == 0:
        validate(validation_games)
    if games % checkpoint_interval == 0:
        checkpoint(games)


In [None]:
# Run the full validation of the current model; 100000 is passed as `loginterval`.
validate_full(100000)