In [108]:
import torch
import torch.nn as nn
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [131]:
#
# Hyperparameters
#
alpha = 5e-3                # learning rate handed to the optimizer as `lr`
gamma = 0.99                # discount factor used when accumulating episode returns
entropy_coefficient = 1e-2  # weight of the entropy term in the training loss

In [137]:
#
# Policy network and game counter
#
# One hidden layer (Linear -> LayerNorm -> Tanh) followed by a 7-way softmax
# head, so a forward pass yields probabilities directly rather than logits.
# The 126 inputs are presumably the flattened board encoding and the 7 outputs
# one per column -- TODO confirm against connect4.py.
policy_layers = [
    nn.Linear(126, 294),
    nn.LayerNorm(294),
    nn.Tanh(),
    nn.Linear(294, 7),
    nn.Softmax(dim=-1),
]
model = nn.Sequential(*policy_layers)

# Number of training games played so far with this model.
games = 0

In [138]:
# SGD over every parameter of the policy network; `alpha` is the learning rate.
optimizer = torch.optim.SGD(params=model.parameters(), lr=alpha)

In [114]:
#
# Resume from a saved checkpoint
#
# Sets the game counter to the checkpoint's label, then restores the network
# weights and the optimizer state from the matching `connect4-<games>.nn` file.
# NOTE: torch.load unpickles the file -- only load checkpoints you created.
games = 250000
cp = torch.load(f'connect4-{games}.nn')
model.load_state_dict(cp['model_state_dict'])
optimizer.load_state_dict(cp['optimizer_state_dict'])

In [135]:
@torch.no_grad()
def checkpoint(step):
    """Save a training checkpoint and export the current model to ONNX.

    Writes the model and optimizer state dicts to ``connect4-{step}.nn`` and
    (over)writes ``connect4.onnx`` with an export of the model, using the
    state of a freshly constructed ``Connect4`` as the example input.

    If the model is in training mode on entry it is switched to eval mode for
    the duration of the call and switched back afterwards -- also when saving
    or exporting raises, so a failed checkpoint never leaves the model stuck
    in eval mode.

    Args:
        step: Label for this checkpoint (the training loop passes the number
            of games played); used in the file name and the progress messages.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        # The ONNX file name carries no step, so each call replaces the
        # previous export.
        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, "connect4.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        if was_training:
            model.train()


In [139]:
#
# TRAINING
#
# Policy-gradient training: one optimizer step per generated episode. The loss
# is the return-weighted negative log-probability of the actions taken, minus
# an entropy bonus scaled by `entropy_coefficient`.
#
log_interval = 10000
validation_interval = 50000
validation_games = 5000
checkpoint_interval = 50000
training_games = 2000000   # number of episodes played by this cell
numerical_eps = 1e-8       # keeps divisions and logarithms away from zero
losses = []

for _ in range(training_games):
    # NOTE(review): `env` is not used anywhere in this loop; it is kept only in
    # case another cell reads it -- confirm and remove.
    env = Connect4()
    games += 1

    episode = connect4.generateEpisode(model)
    states, actions, rewards = zip(*episode)

    # Discounted returns, accumulated from the last step backwards.
    # The running return is *subtracted* at each step; presumably consecutive
    # steps belong to opposing players, so the sign flips every move --
    # TODO confirm against connect4.generateEpisode.
    returns = []
    R = 0
    for reward in reversed(rewards):
        R = reward - gamma * R
        returns.append(R)
    returns.reverse()

    states_tensor = torch.stack(states)         # [t,126]
    actions_tensor = torch.LongTensor(actions)  # [t]
    returns_tensor = torch.FloatTensor(returns) # [t]

    # Normalize returns to zero mean / unit variance. The sample std is NaN for
    # a single-step episode and 0 when all returns are equal (e.g. every reward
    # is zero); dividing by it unguarded would turn the loss -- and, after one
    # optimizer step, the weights -- into NaN.
    if returns_tensor.numel() > 1:
        returns_std = returns_tensor.std()
    else:
        returns_std = returns_tensor.new_zeros(())
    returns_tensor = (returns_tensor - returns_tensor.mean()) / (returns_std + numerical_eps)
    # The normalized returns already have zero mean, so no separate baseline
    # needs to be subtracted below.

    # Re-asserted every iteration; presumably episode generation or validation
    # may leave the model in eval mode -- harmless otherwise.
    model.train()
    action_probs = model(states_tensor) # [t, 7]
    chosen_probs = action_probs.gather(1, actions_tensor.unsqueeze(1)) # unsqueeze(1) -> [t,1]
    # Epsilon for numerical stability; squeeze only dim 1 so a one-step
    # episode keeps shape [1] instead of collapsing to a scalar.
    log_probs = torch.log(chosen_probs + numerical_eps).squeeze(1) # [t,1] -> [t]

    policy_loss = -returns_tensor * log_probs

    # Clamp inside the log so a probability that underflowed to exactly 0
    # contributes 0 to the entropy instead of 0 * -inf = NaN.
    entropy = -torch.sum(action_probs * torch.log(action_probs.clamp_min(numerical_eps)), dim=1)

    loss = policy_loss.mean() - entropy_coefficient * entropy.mean()
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []

    if games % checkpoint_interval == 0:
        checkpoint(games)
    if games % validation_interval == 0:
        print(f'{games}: validating...')
        connect4.validate(model, validation_games)


## Policy-gradient training for Connect 4
### One hidden layer with 294 neurons, LayerNorm, and tanh

| Games     | Mode      | Player    | Wins  | Draws | Losses
| :-------: | :-------: | :-------: | :---: | :---: | :-----:
| 200,000   | prob      | Red       | -     | -     | -     
|           |           | Yellow    | -     | -     | -     
|           | determ    | Red       | -     | -     | -     
|           |           | Yellow    | -     | -     | -    
