In [None]:
import sys
sys.path.append('../')
import datetime

import torch as T

from board import Connect4Board
from board2dqn import createStateTensor
from agent import Connect4Agent, calculateReward
from validation import validate
from dqn import exportOnnx

def log(message):
    print(f"[{datetime.datetime.now().strftime('%H:%M:%S')}] {message}")

In [2]:
#
# Hyper parameters
# 
lr = 0.1
gamma = 0.9
epsilon = 0.6
eps_min = 0.2
eps_dec = 1e-7
batch_count = 1
batch_size = 128
memory_size = 128000

target_update_interval = 0

In [3]:
agent = Connect4Agent(
    lr = lr, 
    epsilon = epsilon, 
    epsilon_end = eps_min, 
    epsilon_decay = eps_dec,
    batch_size = batch_size, 
    batch_count = batch_count,
    memory_size = memory_size,
    gamma = gamma,
    targetUpdateInterval=target_update_interval
)
agent.numberOfParameters

4082183

In [4]:
# load agent from checkpoint
agent.loadCheckpoint(f'connect4')

Loaded checkpoint connect4.


In [7]:
#
# TRAINING
#
gamesToGo = 250000

log_interval = 5000

validation_interval = 10000
validation_gamesPerPlayer = 1000
validation_procsPerPlayer = 8
validation_strength = 50

lastLoggedGame = 0
games = set()
allGames = set()

log(f"Starting training for {gamesToGo} games.")

for game in range(1, gamesToGo+1):
    env = Connect4Board()
        
    next_state = createStateTensor(env)
    
    while not env.Finished:
        state = next_state
        action = agent.getTrainingAction(env)
        env.move(action)
        next_state = createStateTensor(env)
        validMovesMask = T.zeros(7, dtype=bool)
        validMovesMask[env.ValidMoves] = True
        reward = calculateReward(env)
        agent.store_transition(state, action, next_state, validMovesMask, env.Finished, reward)

    games.add(env.gameKey)
    allGames.add(env.gameKey)
    
    agent.learn()

    if game % log_interval == 0:
        log(f'{game} games, div: {100*len(games)/(game+1-lastLoggedGame):.2f} / {100*len(allGames)/(game+1):.2f}')
        games.clear()
        lastLoggedGame = game
        agent.printStats()
    if game % validation_interval == 0:
        agent.saveCheckpoint(f'connect4-{game}')
        log(f'Validation:')
        validate(agent.evaluationModel, validation_gamesPerPlayer, validation_procsPerPlayer, validation_strength)

[00:11:36] Starting training for 250000 games.
[00:14:41] 5000 games, div: 99.82 / 99.82
Average loss (last 5000): 0.005971367624751292, last: 0.00495483260601759, epsilon: 0.5745007000134217
[00:17:52] 10000 games, div: 99.76 / 99.70
Average loss (last 5000): 0.006183126469305716, last: 0.012912790291011333, epsilon: 0.5740007000136849
Checkpoint 'connect4-10000' saved.
[00:17:52] Validation:
Validation with 1000 games per player on 8 processes each, MCTS with 50 games.
Player 1: 958 won, 39 lost, 3 draws -> 95.80%, div: 73.70%
Player 2: 929 won, 67 lost, 4 draws -> 92.90%, div: 95.80%
[00:24:11] 15000 games, div: 99.82 / 99.59
Average loss (last 5000): 0.006220744311483577, last: 0.0033492110669612885, epsilon: 0.573500700013948
[00:27:19] 20000 games, div: 99.92 / 99.55
Average loss (last 5000): 0.006057008034735918, last: 0.00430223299190402, epsilon: 0.5730007000142112
Checkpoint 'connect4-20000' saved.
[00:27:19] Validation:
Validation with 1000 games per player on 8 processes ea

In [5]:
validate(agent.evaluationModel, 1000, 8, 50)

Validation with 1000 games per player on 8 processes each, MCTS with 50 games.
Player 1: 963 won, 35 lost, 2 draws -> 96.30%, div: 77.50%
Player 2: 931 won, 68 lost, 1 draws -> 93.10%, div: 96.20%


In [None]:
exportOnnx(agent.evaluationModel, 'connect4')