In [1]:
import torch
import torch.nn as nn
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [2]:
#
# Hyper parameters
# 
alpha = 0.002               # learning rate handed to the SGD optimizer
gamma = 0.99                # discount factor used when accumulating episode returns
entropy_coefficient = 2.5   # weight of the mean policy entropy subtracted from the loss
stabilizer = 1e-8           # small epsilon guarding the std division and log() in training

In [3]:
#
# Build the policy network (the optimizer is created in the next cell)
#
policy_layers = [
    nn.Linear(126, 294),
    nn.ReLU(),
    nn.Linear(294, 294),
    nn.ReLU(),
    nn.Linear(294, 7),
    nn.Softmax(dim=-1),  # normalises the 7 outputs into a probability distribution
]
model = nn.Sequential(*policy_layers)

# Number of games trained so far; a freshly created model starts at zero.
games = 0

In [4]:
# Plain SGD over every model parameter, stepping with the `alpha` learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=alpha)

In [5]:
#
# Resume model and optimizer from a saved checkpoint
#
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
games = 6000000  # game count embedded in the checkpoint file name
cp = torch.load(f'connect4-{games}.nn')
model.load_state_dict(cp['model_state_dict'])
optimizer.load_state_dict(cp['optimizer_state_dict'])

In [7]:
@torch.no_grad()
def checkpoint(step):
    """Save a training checkpoint for `step` and refresh the ONNX export.

    Writes the model and optimizer state dicts to ``connect4-{step}.nn`` and
    exports the model to ``connect4.onnx`` (overwritten on every call), using
    the state of a fresh ``Connect4()`` as the example input.

    The model is switched to eval mode for the duration of the export and is
    put back into training mode afterwards if it was training on entry, even
    when saving or exporting raises.

    Args:
        step: value interpolated into the log messages and the checkpoint
            file name (the training loop passes the current game count).
    """
    was_training = model.training
    model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, "connect4.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        # Restore the caller's mode so a failed save/export cannot leave
        # the training loop running with the model stuck in eval mode.
        if was_training:
            model.train()


In [8]:
#
# TRAINING
#
# Policy-gradient training on episodes produced by connect4.generateEpisode:
# one optimizer step per game, with an entropy bonus weighted by
# `entropy_coefficient`.
log_interval = 50000          # print the mean loss every this many games
validation_interval = 100000  # call connect4.validate every this many games
validation_games = 10000      # second argument passed to connect4.validate
checkpoint_interval = 100000  # call checkpoint() every this many games
target_games = 10000000       # stop once `games` reaches this count
losses = []

for _ in range(target_games - games):
    games += 1

    episode = connect4.generateEpisode(model)

    # Accumulate returns back-to-front, then flip into episode order.
    # NOTE(review): the recursion subtracts the discounted future return;
    # presumably consecutive steps belong to opposing players — confirm
    # against connect4.generateEpisode.
    returns = []
    R = 0
    for _state, _action, reward in reversed(episode):
        R = reward - gamma * R
        returns.append(R)
    returns.reverse()

    states, actions, _rewards = zip(*episode)
    states_tensor = torch.stack(states)         # [t,126]
    actions_tensor = torch.LongTensor(actions)  # [t]
    returns_tensor = torch.FloatTensor(returns) # [t]

    # normalize returns
    returns_tensor = (returns_tensor - returns_tensor.mean()) / (returns_tensor.std() + stabilizer)

    baseline = returns_tensor.mean().detach()

    model.train()
    action_probs = model(states_tensor) # [t, 7]
    chosen_probs = action_probs.gather(1, actions_tensor.unsqueeze(1)) # unsqueeze(1) -> [t,1]
    # squeeze(1) rather than squeeze() so a one-step episode stays 1-D.
    log_probs = torch.log(chosen_probs + stabilizer).squeeze(1) # [t,1] -> [t]

    policy_loss = -(returns_tensor - baseline) * log_probs

    # The stabilizer keeps log() finite when a probability underflows to 0;
    # without it 0 * log(0) evaluates to NaN and corrupts every weight on
    # the next optimizer step.
    entropy = -torch.sum(action_probs * torch.log(action_probs + stabilizer), dim=1)

    loss = policy_loss.mean() - entropy_coefficient * entropy.mean()
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []

    if games % checkpoint_interval == 0:
        checkpoint(games)
    if games % validation_interval == 0:
        print(f'{games}: validating...')
        connect4.validate(model, validation_games)


#### 126-294-ReLU-294-ReLU-7
- alpha: 0.002
- gamma: 0.99
- entropy_coefficient:
    - 1.0
    - 2.0 since 4,000,000
    - 2.5 since 6,000,000
- validation opponent samples its moves from the model (since game 6,000,000)

#### Results
|Games|Loss|Red|Yellow|Notes|
|--:|:--|:-:|:-:|:-:|
| 2,000,000 | -1.9643527344226837 | 99.51 0.00 0.49 | 97.00 0.02 2.98 |
| 4,000,000 | -1.9719070252478124 | 99.53 0.01 0.46 | 99.17 0.01 0.82 | entropy coeff **2.0** |
| 6,000,000 | -3.9055097838521005 | 99.81 0.00 0.19 | 99.41 0.02 0.57 | entropy coeff **2.5**, sampling opponent |
