In [3]:
import torch
import torch.nn as nn
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [1]:
#
# Hyper parameters
#
alpha = 5e-4               # learning rate handed to the SGD optimizer
gamma = 0.9                # discount factor used when accumulating episode returns
entropy_coefficient = 0.1  # weight of the entropy term subtracted from the policy loss
stabilizer = 1e-8          # small epsilon added to denominators / log arguments

In [4]:
#
# Create the policy network (the optimizer is created in the next cell)
#
hidden_units = 294
layers = [
    nn.Linear(126, hidden_units),           # 126 input features per state
    nn.ReLU(),
    nn.Linear(hidden_units, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, 7),             # 7 outputs, one per action
    nn.Softmax(dim=-1),                     # turn the outputs into a probability distribution
]
# Positional modules keep the state_dict keys as "0.*", "2.*", "4.*",
# which is what previously saved checkpoints contain.
model = nn.Sequential(*layers)

# Running count of games trained so far; incremented by the training loop.
games = 0

In [5]:
# SGD over every model parameter, using alpha as the learning rate.
optimizer = torch.optim.SGD(params=model.parameters(), lr=alpha)

In [6]:
#
# Load model from checkpoint
#
# `games` selects which checkpoint file to resume from and also becomes the
# starting value of the game counter used by the training loop.
games = 750000
checkpoint_path = f'connect4-{games}.nn'

# NOTE: torch.load unpickles the file, so only load checkpoints you created yourself.
cp = torch.load(checkpoint_path)
model.load_state_dict(cp['model_state_dict'])
# NOTE(review): the optimizer state dict carries its own param_groups, so the
# learning rate stored in the checkpoint presumably replaces `alpha` - confirm
# if alpha is ever changed between runs.
optimizer.load_state_dict(cp['optimizer_state_dict'])

In [7]:
@torch.no_grad()
def checkpoint(step):
    """Save the current model/optimizer state and export the model to ONNX.

    Writes two files in the working directory:
      * ``connect4-{step}.nn`` - a dict with ``model_state_dict`` and
        ``optimizer_state_dict`` (the format the "load from checkpoint" cell reads);
      * ``connect4.onnx`` - an ONNX export of the model, overwritten on every call.

    Uses the notebook-level ``model`` and ``optimizer``.

    Args:
        step: Value embedded in the checkpoint file name and the log lines
            (the training loop passes the current game count).
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        # A fresh game's state serves as the example input for tracing the export.
        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, "connect4.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        # Restore training mode even if saving or exporting raised, so a failed
        # checkpoint cannot leave the model stuck in eval mode.
        if was_training:
            model.train()


In [None]:
#
# TRAINING
#
# Policy-gradient training: play one game per iteration with the current
# policy, compute discounted returns, and take one SGD step on
# (policy loss - entropy bonus).
#
num_training_games = 250000
log_interval = 10000
validation_interval = 50000
validation_games = 5000
checkpoint_interval = 50000
losses = []

for _ in range(num_training_games):
    games += 1

    # One full game; each entry is unpacked below as (state, action, reward).
    episode = connect4.generateEpisode(model)

    # Accumulate returns from the last step backwards. The future return is
    # subtracted rather than added.
    # NOTE(review): presumably because consecutive moves belong to opposing
    # players - confirm against connect4.generateEpisode.
    returns = []
    R = 0
    for _state, _action, reward in reversed(episode):
        R = reward - gamma * R
        returns.append(R)
    returns.reverse()  # back into chronological order

    states, actions, rewards = zip(*episode)
    states_tensor = torch.stack(states)         # [t,126]
    actions_tensor = torch.LongTensor(actions)  # [t]
    returns_tensor = torch.FloatTensor(returns) # [t]

    # normalize returns
    returns_tensor = (returns_tensor - returns_tensor.mean()) / (returns_tensor.std() + stabilizer)

    # The returns were just normalised, so this mean is ~0 and the baseline
    # has almost no effect; kept so the loss expression reads as "return - baseline".
    baseline = returns_tensor.mean().detach()

    model.train()
    action_probs = model(states_tensor) # [t, 7]
    chosen_probs = action_probs.gather(1, actions_tensor.unsqueeze(1)) # unsqueeze(1) -> [t,1]
    # squeeze(1) drops only the gathered column, so a one-step episode stays 1-D.
    log_probs = torch.log(chosen_probs + stabilizer).squeeze(1) # [t,1] -> [t]

    policy_loss = -(returns_tensor - baseline) * log_probs

    # The stabilizer keeps log() finite when a probability underflows to 0;
    # without it 0 * log(0) evaluates to NaN and poisons the loss and weights.
    entropy = -torch.sum(action_probs * torch.log(action_probs + stabilizer), dim=1)

    loss = policy_loss.mean() - entropy_coefficient * entropy.mean()
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []

    if games % checkpoint_interval == 0:
        checkpoint(games)
    if games % validation_interval == 0:
        print(f'{games}: validating...')
        connect4.validate(model, validation_games)


In [None]:
# Final evaluation of the current model over 10,000 validation games.
connect4.validate(model, 10_000)

#### Architecture: 126-294-ReLU-294-ReLU-7-Softmax
- alpha: 0.0005
- gamma: 0.9
- entropy_coefficient: 0.05 up to game 750000, 0.1 afterwards

#### Results
|Games|Loss|Red|Yellow|Notes|
|:-:|:-:|:-:|:-:|:-:|
|550000| -0.1869731299161911 | 94.10 0.00 5.90 | 86.04 0.02 13.94|
|600000| -0.18033885388188065 | 93.48 0.02 6.50 | 87.14 0.02 12.84|
|650000| -0.1750567243643105 | 94.20 0.00 5.80 | 86.08 0.00 13.92 |
|700000| -0.1734218504589051 | 93.38 0.00 6.62 | 86.76 0.04 13.20 |
|750000| -0.17464535967707634 |96.34 0.02 3.64|86.84 0.04 13.12|
|800000|-0.23563749998174607|96.24 0.00 3.76|90.04 0.00 9.96| _doubled entropy coefficient to 0.1_ |
|850000|-0.24879778488092125|97.34 0.00 2.66|92.16 0.00 7.84|
|900000|-0.24367169531211258|98.50 0.00 1.50|92.62 0.00 7.38|
|950000|-0.25199465897753837|97.78 0.00 2.22|92.62 0.02 7.36|
|1000000|-0.25451473476402464|98.24 0.00 1.76|92.98 0.04 6.98|