In [None]:
import sys
sys.path.append('../')
import datetime
import numpy as np
from board import Connect4Board
from agent import Connect4Agent, createStateTensor
from validation import validate

def log(message):
    """Print ``message`` prefixed with the current local wall-clock time as ``[HH:MM:SS]``."""
    timestamp = datetime.datetime.now().strftime('%H:%M:%S')
    print(f"[{timestamp}] {message}")

In [None]:
#
# Hyper parameters
#
lr = 1e-3              # passed to the agent as `lr`
gamma = 0.9            # passed as `gamma` (presumably the discount factor — confirm in agent)
epsilon = 0.5          # passed as `epsilon` (starting exploration rate, presumably)
eps_min = 0.01         # passed as `epsilon_end`
eps_dec = 1e-6         # passed as `epsilon_decay`
batch_size = 512       # passed as `batch_size`
memory_size = 64_000   # passed as `memory_size` (replay-buffer capacity, presumably)

In [None]:
# Build a fresh agent from the hyper-parameter cell. `games` counts self-play
# games played so far; the training cell increments it and uses it to name
# checkpoints.
games = 0
agent = Connect4Agent(
    gamma=gamma,
    epsilon=epsilon,
    lr=lr,
    batch_size=batch_size,
    memory_size=memory_size,
    epsilon_end=eps_min,
    epsilon_decay=eps_dec,
)
# Shown as the cell output (presumably the network's parameter count).
agent.numberOfParameters

In [None]:
# Load agent from checkpoint: set `games` to the game count the checkpoint was
# saved at, so training resumes counting from there.
# NOTE(review): the training cell only saves when `games` is a positive multiple
# of its validation interval, so with games == 0 this cell needs a 'connect4-0'
# checkpoint produced elsewhere — skip this cell when starting from scratch.
games = 0
checkpoint_name = f'connect4-{games}'
agent.loadCheckpoint(checkpoint_name)

In [None]:
#
# TRAINING
#
# Self-play: the single `agent` chooses every move of each game. Every move is
# stored as a transition, and the agent learns once per finished game. Every
# `validation_interval` games a checkpoint named f'connect4-{games}' is saved
# and a validation run is performed.
#
gamesToGo = 100000

log_interval = 5000

validation_interval = 10000
validation_games = 1000
omega = 1  # passed straight through to validate(); its meaning is defined there — TODO document

log(f"Starting training at {games} games for {gamesToGo} games.")

for _ in range(gamesToGo):
    games += 1
    env = Connect4Board()

    next_state = createStateTensor(env)

    while not env.Finished:
        state = next_state
        action = agent.getTrainingAction(state, env.ValidMovesMask)
        env.move(action)
        next_state = createStateTensor(env)
        # Reward for this move: 1 once the board has a winner, 0 if the board is
        # full with no winner (draw), otherwise a small -0.1 per-move penalty.
        reward = 1 if env.Winner != Connect4Board.EMPTY else 0 if env.Full else -0.1
        # The valid-move mask stored here is the one *after* the move, i.e. for next_state.
        agent.store_transition(state, action, next_state, env.ValidMovesMask, env.Finished, reward)

    agent.learn()

    if games % log_interval == 0:
        log(f'{games} games')
    if games % validation_interval == 0:
        log('Validation:')
        agent.saveCheckpoint(f'connect4-{games}')
        validate(agent, validation_games, omega)