In [1]:
import torch
import torch.nn as nn
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [2]:
# --- Hyper-parameters (read by the model, optimizer and training cells) ---
alpha = 1e-3               # SGD learning rate
gamma = 0.9                # discount factor used when accumulating returns
entropy_coefficient = 0.3  # weight of the entropy bonus subtracted from the loss
stabilizer = 1e-8          # epsilon guarding log(0) and division by a zero std

In [11]:
#
# Create the policy model
#
def build_policy_model(hidden_sizes=(294,), in_features=126, n_actions=7):
    """Build a feed-forward policy network that outputs action probabilities.

    The network is ``Linear -> ReLU`` for each entry of ``hidden_sizes``,
    followed by ``Linear(-> n_actions)`` and a ``Softmax`` over the last dim.

    Layers are created in the same order, and end up at the same
    ``nn.Sequential`` indices, as the hand-written equivalent. State-dict keys
    (``'0.weight'``, ``'2.weight'``, ...) therefore match, and checkpoints saved
    from a hand-written model with the same widths still load.

    Args:
        hidden_sizes: widths of the hidden layers, e.g. ``(294,)`` for
            126-294-ReLU-7 or ``(294, 294)`` for 126-294-ReLU-294-ReLU-7.
        in_features: size of the flattened board encoding fed to the model.
        n_actions: number of output actions (probabilities).

    Returns:
        An ``nn.Sequential`` mapping ``[..., in_features]`` to ``[..., n_actions]``.
    """
    layers = []
    width_in = in_features
    for width in hidden_sizes:
        layers += [nn.Linear(width_in, width), nn.ReLU()]
        width_in = width
    layers += [nn.Linear(width_in, n_actions), nn.Softmax(dim=-1)]
    return nn.Sequential(*layers)


model = build_policy_model(hidden_sizes=(294,))
games = 0  # number of training games played so far

In [12]:
# Plain SGD over every model parameter, using the learning rate `alpha`.
optimizer = torch.optim.SGD(params=model.parameters(), lr=alpha)

In [13]:
#
# Load model from checkpoint
#
# Resumes from a file written by checkpoint(step). The file holds only the
# model and optimizer state dicts, not the game counter, so `games` is set by
# hand to the step encoded in the file name.
games = 100000
# weights_only=True restricts unpickling to tensors/primitives (a state dict
# needs nothing else); map_location keeps loading working on CPU-only machines.
cp = torch.load(f'connect4-{games}.nn', map_location='cpu', weights_only=True)
model.load_state_dict(cp['model_state_dict'])
optimizer.load_state_dict(cp['optimizer_state_dict']);

In [5]:
@torch.no_grad()
def checkpoint(step):
    """Save the module-level ``model``/``optimizer`` and export ``model`` to ONNX.

    Writes ``connect4-{step}.nn``, a dict with ``'model_state_dict'`` and
    ``'optimizer_state_dict'``. It also overwrites ``connect4.onnx`` with an
    export traced on ``Connect4().state`` as the example input.

    If the model was in training mode, it is switched to eval mode for the
    save/export. Training mode is restored afterwards, even if saving or
    exporting raises, so a failed checkpoint does not silently leave the
    training loop running in eval mode.

    Args:
        step: training step (game count) used in the file name and log lines.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        # A fresh board state serves as the example input for tracing.
        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, "connect4.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        if was_training:
            model.train()


In [14]:
#
# TRAINING
#
# Policy-gradient (REINFORCE-style) training with an entropy bonus.
# Each iteration plays one episode with connect4.generateEpisode(model). The
# episode is unpacked as (state, action, reward) triples.
training_games = 1000000
log_interval = 20000
validation_interval = 100000
validation_games = 10000
checkpoint_interval = 50000
losses = []

for _ in range(training_games):
    games += 1

    episode = connect4.generateEpisode(model)

    # Returns are accumulated backwards through the episode. The sign flip
    # (reward - gamma * R) negates the future return at each step, presumably
    # because consecutive moves belong to alternating players; verify against
    # connect4.generateEpisode.
    returns = []
    R = 0
    for _state, _action, reward in reversed(episode):
        R = reward - gamma * R
        returns.append(R)
    returns.reverse()

    states, actions, _rewards = zip(*episode)
    states_tensor = torch.stack(states)         # [t, 126]
    actions_tensor = torch.LongTensor(actions)  # [t]
    returns_tensor = torch.FloatTensor(returns) # [t]

    # Normalize returns. The (unbiased) std of a single element is NaN, which
    # would propagate into every weight, so only multi-step episodes are
    # normalized.
    if returns_tensor.numel() > 1:
        returns_tensor = (returns_tensor - returns_tensor.mean()) / (returns_tensor.std() + stabilizer)

    # After normalization the mean is ~0, so this baseline is close to a no-op.
    # For a single-step episode it makes the advantage exactly 0.
    baseline = returns_tensor.mean().detach()

    model.train()
    action_probs = model(states_tensor)                                  # [t, 7]
    chosen_probs = action_probs.gather(1, actions_tensor.unsqueeze(1))  # [t, 1]
    # Out-of-place add (no in-place op on a graph tensor); epsilon avoids log(0).
    log_probs = torch.log(chosen_probs + stabilizer).squeeze(1)         # [t]

    policy_loss = -(returns_tensor - baseline) * log_probs

    # Per-state policy entropy. Softmax can underflow to exactly 0, and
    # 0 * log(0) = 0 * -inf = NaN, so the log is stabilized here as well.
    entropy = -torch.sum(action_probs * torch.log(action_probs + stabilizer), dim=1)

    # Subtracting the entropy term rewards higher-entropy (more exploratory) policies.
    loss = policy_loss.mean() - entropy_coefficient * entropy.mean()
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []

    if games % checkpoint_interval == 0:
        checkpoint(games)
    if games % validation_interval == 0:
        print(f'{games}: validating...')
        connect4.validate(model, validation_games)


120000: average loss: -0.6108668301939965
140000: average loss: -0.6124337519198656
150000: checkpoint...
verbose: False, log level: Level.ERROR

150000: checkpoint saved.
160000: average loss: -0.6142979606285691
180000: average loss: -0.6143333837628364
200000: average loss: -0.6140183655276894
200000: checkpoint...
verbose: False, log level: Level.ERROR

200000: checkpoint saved.
200000: validating...
Red: 93.32 0.00 6.68
Yellow: 82.26 0.02 17.72
220000: average loss: -0.6165918832048773
240000: average loss: -0.6180396802261472
250000: checkpoint...
verbose: False, log level: Level.ERROR

250000: checkpoint saved.
260000: average loss: -0.6178141894131899
280000: average loss: -0.6186146073177456
300000: average loss: -0.6196055913552642
300000: checkpoint...
verbose: False, log level: Level.ERROR

300000: checkpoint saved.
300000: validating...
Red: 93.58 0.01 6.41
Yellow: 82.11 0.03 17.86
320000: average loss: -0.618423656925559
340000: average loss: -0.6195675715789198
350000: c

#### 126-294-ReLU-7
- alpha: 0.001
- gamma: 0.9
- entropy_coefficient: **0.3**

#### Results
|Games|Loss|Red|Yellow|
|--:|:--|:-:|:-:|
|   100,000 | -0.6103017311692238 | 88.98 0.00 11.02 | 79.90 0.00 20.10 |
|   200,000 |-0.6140183655276894|93.32 0.00 6.68|82.26 0.02 17.72|
|   300,000 |-0.6196055913552642|93.58 0.01 6.41|82.11 0.03 17.86|
|   400,000 |-0.6199633308008313|95.98 0.00 4.02|86.15 0.01 13.84|
|   500,000 |-0.6231693591535091|95.60 0.00 4.40|86.97 0.00 13.03|
|   600,000 |-0.6263493340745568|96.74 0.00 3.26|87.94 0.03 12.03|
|   700,000 |-0.626819719453156|97.11 0.00 2.89|93.06 0.00 6.94|
|   800,000 |-0.6338658374458551|98.59 0.00 1.41|93.58 0.00 6.42|
|   900,000 |-0.6380351661786438|98.05 0.00 1.95|95.94 0.01 4.05|
| 1,000,000 |-0.6418091414019466|98.53 0.00 1.47|95.95 0.01 4.04|
| 1,100,000 ||||
| 1,200,000 ||||
| 1,300,000 ||||
| 1,400,000 ||||
| 1,500,000 ||||
| 1,600,000 ||||
| 1,700,000 ||||
| 1,800,000 ||||
| 1,900,000 ||||
| 2,000,000 ||||


#### 126-294-ReLU-294-ReLU-7
- alpha: 0.0005
- gamma: 0.9
- entropy_coefficient: 0.05 (doubled to 0.1 after the 750,000-game checkpoint)

#### Results
|Games|Loss|Red|Yellow|Notes|
|:-:|:-:|:-:|:-:|:-:|
|550000| -0.1869731299161911 | 94.10 0.00 5.90 | 86.04 0.02 13.94|
|600000| -0.18033885388188065 | 93.48 0.02 6.50 | 87.14 0.02 12.84|
|650000| -0.1750567243643105 | 94.20 0.00 5.80 | 86.08 0.00 13.92 |
|700000| -0.1734218504589051 | 93.38 0.00 6.62 | 86.76 0.04 13.20 |
|750000| -0.17464535967707634 |96.34 0.02 3.64|86.84 0.04 13.12|
|800000|-0.23563749998174607|96.24 0.00 3.76|90.04 0.00 9.96| _doubled entropy coefficient to 0.1_ |
|850000|-0.24879778488092125|97.34 0.00 2.66|92.16 0.00 7.84|
|900000|-0.24367169531211258|98.50 0.00 1.50|92.62 0.00 7.38|
|950000|-0.25199465897753837|97.78 0.00 2.22|92.62 0.02 7.36|
|1000000|-0.25451473476402464|98.24 0.00 1.76|92.98 0.04 6.98|