In [221]:
import datetime
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import connect4
importlib.reload(connect4);
from connect4 import Connect4

def log(message):
    """Print ``message`` prefixed with the current local time as ``[HH:MM:SS]``."""
    stamp = datetime.datetime.now().strftime('%H:%M:%S')
    print(f"[{stamp}] {message}")


In [241]:
#
# Create the model and optimizer
# 
class Connect4Cnn(nn.Module):
    """Convolutional network producing 7 values (one per move index) for a board.

    Architecture: three 3x3 convolutions with padding 1 (3 -> 64 -> 128 -> 128
    channels), each followed by ReLU, then a single linear layer that maps the
    flattened 128*6*7 feature map to 7 outputs. Because the convolutions keep
    the spatial size, every sample must contain 6*7 = 42 cells and 3 channels.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.out = nn.Linear(128 * 6 * 7, 7)

    def forward(self, x):
        """Run the network on one board or on a batch of boards.

        Args:
            x: a single sample of shape (3, H, W) or (1, 3, H, W), or a batch
                of shape (N, 3, H, W); H*W must equal 42.

        Returns:
            A tensor of shape (7,) for a single sample (including a batch of
            exactly one, which keeps the original behaviour), or (N, 7) for a
            batch with N != 1.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        if x.dim() == 4 and x.size(0) != 1:
            # Real batch: flatten each sample separately so the batch
            # dimension survives (a plain view(-1) would merge all samples
            # into one vector and make the linear layer fail).
            return self.out(x.flatten(start_dim=1))
        # Single sample: flatten everything, yielding a 1-D vector of 7 values.
        return self.out(x.view(-1))

# Instantiate the network and a plain SGD optimizer over all of its parameters.
model = Connect4Cnn()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)
# Cell output: total number of trainable parameters.
sum(weight.numel() for weight in model.parameters() if weight.requires_grad)

260871

In [242]:
#
# Load model from checkpoint
#
# Number of games already played. Leave at 0 to start from scratch, or set it
# to the step of a saved checkpoint to resume from connect4-<games>.nn.
games = 0
if games > 0:
    # NOTE(review): depending on the PyTorch version / `weights_only` setting,
    # torch.load may unpickle arbitrary objects - only load trusted files.
    saved = torch.load(f'connect4-{games}.nn')
    model.load_state_dict(saved['model_state_dict'])
    optimizer.load_state_dict(saved['optimizer_state_dict'])

In [230]:
@torch.no_grad()
def checkpoint(step):
    """Save the current model/optimizer state and export the model to ONNX.

    Writes two files into the working directory:
      * ``connect4-{step}.nn``   - dict with ``model_state_dict`` and
        ``optimizer_state_dict`` (via ``torch.save``);
      * ``connect4-{step}.onnx`` - ONNX export of the model, using the state
        of a fresh ``Connect4()`` game as the example input.

    The model is switched to eval mode for the duration of the call and its
    previous mode is restored afterwards, even if saving or exporting raises.

    Args:
        step: value embedded in the file names and log messages (the training
            loop passes the number of games played).
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        log(f"{step}: checkpoint...")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'connect4-{step}.nn')

        dummy_input = Connect4().state
        torch.onnx.export(model, dummy_input, f"connect4-{step}.onnx")

        log(f"{step}: checkpoint saved.")
    finally:
        # Restore train mode on every exit path so a failed save/export does
        # not silently leave the live model in eval mode.
        if was_training:
            model.train()


In [231]:
#
# Validation
#
@torch.no_grad()
def validate(numberOfGames):
    """Evaluate the global ``model`` via ``connect4.validate`` and log the result.

    ``connect4.validate(model, numberOfGames)`` is expected to return six
    counts: three for the model playing Cross and three for it playing Circle
    (presumably wins / draws / losses, in that order - confirm against the
    ``connect4`` module).

    Args:
        numberOfGames: passed straight through to ``connect4.validate``.
    """
    log('Validation...')
    x_first, x_second, x_third, o_first, o_second, o_third = connect4.validate(model, numberOfGames)
    cross_games = x_first + x_second + x_third
    circle_games = o_first + o_second + o_third

    # Cross: share of the first count only.
    cross_pct = 100*x_first/cross_games
    log(f"Cross: {cross_pct:.2f}% of {cross_games} ({x_first}/{x_second}/{x_third})")

    # Circle: share of the first *and* second counts.
    # NOTE(review): this differs from the Cross formula above - confirm that
    # counting the second bucket for Circle only is intentional.
    circle_pct = 100*(o_first+o_second)/circle_games
    log(f"Circle: {circle_pct:.2f}% of {circle_games} ({o_first}/{o_second}/{o_third})")

In [243]:
#
# TRAINING
#

log(f"Starting training at {games} games.")

# --- Hyperparameters ---------------------------------------------------------
gamma = 0.9            # discount applied to the bootstrapped value of the next position
step_penalty = 0.1     # subtracted from the target of every non-terminal move
total_games = 1000000  # number of self-play games played by this cell

# --- Reporting cadence (all measured in games) -------------------------------
log_interval = 5000
validation_interval = 50000
validation_games = 10000
checkpoint_interval = 50000

losses = []
model.train()
for _ in range(total_games):
    env = Connect4()
    games += 1

    # One (q, action, validmoves) entry per move, in playing order. `q` keeps
    # its autograd graph so the loss built after the game reaches the weights.
    qstack = []

    # --- Self-play: the same network chooses the moves for both sides --------
    while env.winner == 0 and not env.full:
        q = model(env.state)
        validmoves = [a for a in range(7) if env.is_valid(a)]

        if len(validmoves) == 1:
            action = validmoves[0]
        else:
            # Exploration: sample a move from a softmax over the standardized
            # Q-values of the valid moves (no gradient needed for sampling).
            with torch.no_grad():
                validqs = q[validmoves]
                normalizedqs = (validqs - validqs.mean()) / (validqs.std() + 1e-8)
                probs = F.softmax(normalizedqs, dim=0)
                action = validmoves[torch.multinomial(probs, num_samples=1).item()]

        qstack.append((q, action, validmoves))
        env.move(action)

    # --- Build targets, walking the game backwards from the final move -------
    # Only the entry of the action actually played differs from the prediction,
    # so the other entries contribute zero loss.
    (q, action, validmoves) = qstack.pop()
    targetq = q.clone().detach()
    # Final move: 1 if the game ended with a winner, 0 if the board filled up.
    targetq[action] = 1 if env.winner != 0 else 0
    qlist = [q]
    targetlist = [targetq]

    while qstack:
        next_q, nextvalidmoves = targetq, validmoves
        (q, action, validmoves) = qstack.pop()
        targetq = q.clone().detach()
        # Best value available in the following position, taken from that
        # position's already-updated target and negated (presumably because
        # the following position is evaluated from the opponent's side).
        next_max = -next_q[nextvalidmoves].max().item()
        targetq[action] = -step_penalty + gamma * next_max
        qlist.append(q)
        targetlist.append(targetq)

    # --- One optimizer step per game ------------------------------------------
    loss = F.mse_loss(torch.stack(qlist), torch.stack(targetlist))
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # --- Periodic reporting ---------------------------------------------------
    if games % log_interval == 0:
        log(f'{games}: average loss: {sum(losses)/len(losses)}')
        losses = []
    if games % validation_interval == 0:
        validate(validation_games)
    if games % checkpoint_interval == 0:
        checkpoint(games)


[22:19:35] Starting training at 0 games.
[22:22:22] 5000: average loss: 0.01738047182228911
[22:24:58] 10000: average loss: 0.01573566099146847
[22:27:38] 15000: average loss: 0.012980273686035071
[22:30:20] 20000: average loss: 0.011195772932897672
[22:33:08] 25000: average loss: 0.010129340354594752
[22:35:59] 30000: average loss: 0.00914073964621639
[22:38:57] 35000: average loss: 0.00862481269808195
[22:42:03] 40000: average loss: 0.008505100530570053
[22:45:07] 45000: average loss: 0.008503882083877761
[22:48:12] 50000: average loss: 0.008606729895681202
[22:48:12] Validation...
[22:49:06] Cross: 99.76% of 5000 (4988/0/12)
[22:49:06] Circle: 98.90% of 5000 (4945/0/55)
[22:49:06] 50000: checkpoint...
[22:49:06] 50000: checkpoint saved.
[22:52:19] 55000: average loss: 0.00831319514827701
[22:55:31] 60000: average loss: 0.008230780222194154
[22:58:46] 65000: average loss: 0.007850818548285315
[23:02:04] 70000: average loss: 0.007655467816264718
[23:05:22] 75000: average loss: 0.00705

### Connect4 CNN training 

#### Layers: C64-C128-C128
- Optimizer: SGD, learning rate 0.1
- Validation: against an opponent playing 100% random moves
- Exploration: softmax sampling over the Q-values of the valid moves
- Discount factor &gamma;: 0.9

| Games     | Loss                  | Cross     | Circle    | Remarks
| :-------: | :-----------:         | :----:    | :-----:   | :-------:
| 0         |  .                    | 81.82%    | 73.46%    | 
| 50000     | 0.008606729895681202  | 99.76%    | 98.90%
| 100000    | 0.006011312045413797  | 99.92%    | 99.66%
| 150000    | 0.004590907993056498  | 99.92%    | 99.82%
| 200000    | 0.0038953403555755358 | 99.98%    | 99.96% 
| 250000    | 0.0034589882022273742 | 100.00%   | 100.00%
| 300000    | 0.003244458292566924  | 99.98%    | 99.96% 
| 350000    | 0.0030313615739049056 | 100.00%   | 100.00%
| 400000    | 0.0030451344474326106 | 100.00%   | 99.98% 
| 450000    | 0.0029125683935615596 | 100.00%   | 100.00%
| 500000    | 0.002829687898804059  | 99.98%    | 99.96%
| 550000    | 0.002746553306104215  | 100.00%   | 99.96%
| 600000    | 0.002667246424300356  | 100.00%   | 100.00%
| 650000    | 0.0027019029050018616 | 100.00%   | 100.00%
| 700000    | 0.0024376917252463956 | 99.98%    | 100.00%
| 750000    | 0.0024801731352255955 | 100.00%   | 100.00%
| 800000    | 0.002431928619570681  | 100.00%   | 99.96%
| 850000    | 0.0024910759856210463 | 99.98%    | 99.94%
| 900000    | 0.002337201290548728  | 100.00%   | 100.00%
| 950000    | 0.0023458601327791258 | 100.00%   | 100.00%
| 1000000   | 0.002380479886350804  | 99.98%    | 99.96% 
