In [158]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

In [159]:
#
# Hyperparameters
#
alpha = 0.1    # learning rate handed to the SGD optimizer
gamma = 0.9    # discount factor applied to the next position's value in the training target
epsilon = 0.5  # probability of playing a random valid move instead of the greedy one during training

In [160]:
#
# Create the model and optimizer
# 
class Connect4Cnn(nn.Module):
    """Convolutional Q-network producing one value per Connect 4 column.

    Three 3x3 convolutions (padding 1, so height and width are preserved)
    are followed by a single linear layer with 7 outputs. The linear layer
    expects exactly 128*6*7 features and ``forward`` flattens *every*
    dimension, so the network handles one 3-channel 6x7 position at a time
    and returns a 1-D tensor of 7 values; batches larger than one are not
    supported.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys used by saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.out = nn.Linear(in_features=128 * 6 * 7, out_features=7)

    def forward(self, x):
        """Return the 7 output values for the position tensor ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        # Collapse all dimensions (including any leading batch dimension of 1)
        # into a single feature vector for the linear head.
        return self.out(x.view(-1))

# Instantiate the network and a plain SGD optimizer that uses alpha as its learning rate.
model = Connect4Cnn()
optimizer = torch.optim.SGD(model.parameters(), lr = alpha)
# Trailing expression: the cell output shows the number of trainable parameters.
sum(p.numel() for p in model.parameters() if p.requires_grad)

260871

In [161]:
#
# Load model from checkpoint
#
# Number of games already played. Set this to the step of a saved checkpoint
# (connect4-<games>.nn) to resume training from it; 0 starts from scratch.
games = 0
if games > 0:
    cp = torch.load(f'connect4-{games}.nn')
    # Restore the model first, then the optimizer, from the saved state dicts.
    for target, key in ((model, 'model_state_dict'), (optimizer, 'optimizer_state_dict')):
        target.load_state_dict(cp[key])

In [162]:
@torch.no_grad()
def checkpoint(step):
    """Save a training checkpoint and an ONNX export for the given step.

    Writes ``connect4-{step}.nn`` (model and optimizer state dicts) and
    ``connect4-{step}.onnx`` (the model traced on the state of a fresh
    ``Connect4`` game). The model is switched to eval mode while saving and
    its previous training mode is restored afterwards, even if saving or
    exporting raises.

    Args:
        step: Value embedded in the file names and progress messages; the
            training loop passes the number of games played so far.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        print(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        # The exporter needs an example input to trace the model with.
        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, f"connect4-{step}.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        # Never leave the model stuck in eval mode when a save/export fails.
        if was_training:
            model.train()


In [163]:
#
# Validation
#
@torch.no_grad()
def validate(numberOfGames):
    """Run ``connect4.validate`` on the current model and print a summary.

    ``connect4.validate(model, numberOfGames)`` returns six counts, three for
    cross (``xw, xd, xl``) and three for circle (``ow, od, ol``); judging by
    the names and the printed ``(w/d/l)`` triple these are presumably
    wins/draws/losses -- confirm against ``connect4.validate``.

    Note the two percentages are computed differently: the Cross line reports
    ``xw`` over the cross total, while the Circle line reports ``ow + od``
    over the circle total.

    Args:
        numberOfGames: Number of validation games, forwarded unchanged to
            ``connect4.validate``.
    """
    print('Validation...')
    xw, xd, xl, ow, od, ol = connect4.validate(model, numberOfGames)
    xtotal = xw + xd + xl
    ototal = ow + od + ol
    # Guard against a side with no games (e.g. numberOfGames == 0), which
    # previously raised ZeroDivisionError.
    xrate = 100*xw/xtotal if xtotal else 0.0
    orate = 100*(ow+od)/ototal if ototal else 0.0
    print(f"Cross: {xrate:.2f}% of {xtotal} ({xw}/{xd}/{xl})")
    print(f"Circle: {orate:.2f}% of {ototal} ({ow}/{od}/{ol})")

In [164]:
#
# TRAINING
#
# Reporting cadence, all measured in games played.
log_interval = 5000
validation_interval = 50000
validation_games = 10000
checkpoint_interval = 50000
losses = []
model.train()
for _ in range(500000):
    # Each iteration plays one complete game, accumulates the loss of every
    # move, and performs a single optimizer step on the per-move average.
    env = Connect4()
    games += 1
    loss = 0
    moves = 0
    done = False

    while not done:
        q = model(env.state)
        # The target equals the prediction everywhere except the chosen
        # action, so only that entry contributes to the loss.
        targetq = q.detach().clone()

        # Epsilon-greedy selection among the currently valid columns.
        e = random.uniform(0, 1)
        valid_actions = [a for a in range(7) if env.is_valid(a)]
        if e < epsilon:
            action = random.choice(valid_actions)
        else:
            action = max(valid_actions, key=lambda a: q[a])

        env.move(action)

        game_won = env.winner != 0
        if game_won or env.full:
            # Terminal position: 1 when the move produced a winner, 0 when
            # the board filled up without one.
            targetq[action] = 1 if game_won else 0
            done = True
        else:
            # Non-terminal: bootstrap from the negated best value of the
            # position now on the board, plus a fixed per-move reward.
            reward = -0.1
            model.eval()
            with torch.no_grad():
                next_q = model(env.state)
                next_max = -max(next_q[a] for a in range(7) if env.is_valid(a))
                targetq[action] = reward + gamma * next_max
            model.train()

        loss += F.mse_loss(q, targetq)
        moves += 1

    loss /= moves
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if games % log_interval == 0:
        print(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []
    if games % validation_interval == 0:
        validate(validation_games)
    if games % checkpoint_interval == 0:
        checkpoint(games)


5000: average loss: 0.010914109992841259
10000: average loss: 0.010192029919289053
15000: average loss: 0.008831977478676709
20000: average loss: 0.008122944925760385
25000: average loss: 0.007297524475122918
30000: average loss: 0.00657030988917104
35000: average loss: 0.006192681926908699
40000: average loss: 0.005694992028763954
45000: average loss: 0.005463707554696885
50000: average loss: 0.005292380097397108
Validation...
Cross: 89.98% of 5000 (4499/0/501)
Circle: 84.68% of 5000 (4230/4/766)
50000: checkpoint...
50000: checkpoint saved.
55000: average loss: 0.005297303976061812
60000: average loss: 0.005194457059772321
65000: average loss: 0.005532179747063492
70000: average loss: 0.005403866619039763
75000: average loss: 0.005405459809898457
80000: average loss: 0.005199022535253607
85000: average loss: 0.004969472597049389
90000: average loss: 0.005013729686437
95000: average loss: 0.004881055744469631
100000: average loss: 0.004793603768121102
Validation...
Cross: 91.16% of 50

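### Connect4 CNN training

In the tables below, *Loss* is the average loss logged at that game count (averaged over the preceding 5,000 games). *Cross* and *Circle* are the percentages printed by `validate`: `xw / xtotal` for Cross and `(ow + od) / ototal` for Circle.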

#### Layers: C64-C128-C128
&alpha; (learning rate): 0.1, &gamma; (discount factor): 0.9, &epsilon; (exploration rate): 0.5

| Games     | Loss                  | Cross     | Circle
| :-------: | :-----------:         | :-------: | :-------: 
|  50000    | 0.005292380097397108  | 89.98%    | 84.68%
| 100000    | 0.004793603768121102  | 91.16%    | 85.10%
| 150000    | 0.004057415647386188  | 86.90%    | 88.80%
| 200000    | 0.0034653012876227876 | 95.20%    | 90.72%
| 250000    | 0.003226012402748529  | 95.66%    | 93.82% 
| 300000    | 0.00299808547905468   | 94.62%    | 94.32%
| 350000    | 0.0028812737085465415 | 96.62%    | 96.18%
| 400000    ||| 
| 450000    ||| 
| 500000    ||| 

#### Layers: C64
&alpha; (learning rate): 0.1, &gamma; (discount factor): 0.9, &epsilon; (exploration rate): 0.3

| Games     | Loss                  | Cross     | Circle
| :-------: | :-----------:         | :-------: | :-------: 
| 150000    | 0.0064809466980426805 | 92.98%    | 84.62%
| 200000    | 0.006492313634124002  | 92.90%    | 67.74%
| 250000    | 0.006201987464528065  | 91.10%    | 77.10%
| 300000    | 0.006193165494049026  | 91.12%    | 81.44%
| 350000    | 0.006100989708234556  | 90.32%    | 86.66%
| 400000    | 0.00585501702687834   | 88.36%    | 84.28%
| 450000    | 0.005980592772015371  | 86.68%    | 86.90%
| 500000    | 0.0057320972953020825 | 86.78%    | 81.88%

