In [103]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import training as training
importlib.reload(training);


In [104]:
#
# Hyperparameters (read by the model/optimizer, checkpoint and training cells)
#
gamma = 0.9 # Q-learning's discount factor (passed to training.train as `gamma`), should probably stay constant at 0.9
learning_rate = 0.1 # initial learning rate given to the SGD optimizer
actual_learning_rate = 0.1 # handed to training.loadCheckpoint; used to reset optimizers restored from file
epsilon = 0.01 # passed to training.train -- presumably the exploration rate; confirm in training.py
validation_games = 5000 # NOTE(review): not referenced in the visible cells, and the notes below mention 10000 validation games -- confirm which is current
omega = 1 # the percentage of opponent random moves during validation -- TODO confirm whether 1 means 1% or 100%


In [105]:
class Connect4Cnn(nn.Module):
    """Small CNN mapping a single-channel board tensor to 7 output values.

    Architecture (C3/16-RL-D0.2-C3/32-RL-D0.2 followed by a linear head):
    two 3x3 convolutions with padding 1 (1 -> 16 -> 32 channels), each
    followed by ReLU and, in training mode only, channel-wise dropout with
    p = 0.2. The result is flattened to ``feature_size`` values per sample
    and projected to 7 outputs (presumably one per board column).

    The flatten size 32 * 6 * 7 implies inputs of shape (N, 1, 6, 7).
    """

    def __init__(self):
        super().__init__()

        # 32 output channels of conv2 over a 6 x 7 grid (padding keeps H and W).
        self.feature_size = 32 * 6 * 7

        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.out = nn.Linear(self.feature_size, 7)

    def forward(self, x):
        """Run both conv stages, flatten, and return the 7 linear outputs."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
            # F.dropout2d defaults to training=True, so it must be skipped
            # explicitly when the module is in eval mode.
            if self.training:
                x = F.dropout2d(x, p=0.2)
        flat = x.view(-1, self.feature_size)
        return self.out(flat)


In [106]:

#
# Build a fresh model and its optimizer
#
games = 0  # nothing played yet; the checkpoint cell overwrites this when resuming
model = Connect4Cnn()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# Cell output: total number of trainable parameters.
sum(weight.numel() for weight in model.parameters() if weight.requires_grad)

14215

In [93]:
#
# Load model from checkpoint
#
# Only needed when resuming: `games` selects the checkpoint name
# ('connect4-<games>') and is later used as the game offset for training.
games = 300000
# actual_learning_rate is handed to loadCheckpoint -- presumably to reset the
# learning rate of the restored optimizer; confirm in training.py.
training.loadCheckpoint(model, optimizer, f'connect4-{games}', actual_learning_rate)

In [None]:
#
# TRAINING
#
# gamesToGo is the third positional argument of training.train -- presumably
# the number of games to play in this run; confirm in training.py.
gamesToGo = 300000
# Resume counting from the games already played (gameOffset = games).
training.train(model, optimizer, gamesToGo, epsilon, omega, gameOffset = games, gamma = gamma)
# Advance the running total. NOTE: this cell is not idempotent -- each re-run
# calls training.train again and adds another gamesToGo to `games`.
games += gamesToGo

### Connect4 CNN training 
#### Layers: C3/16-RL-D0.2-C3/32-RL-D0.2
optimizer: SGD, initial lr: 0.1  
validation 10000 games against normalized softmax moves and &omega; random moves  
exploration via normalized softmax and &epsilon;  
loss over last 50.000 games

| Games     | lr        | &epsilon; | Loss                  | &omega;   |Cross      | Circle    | Remarks
| :-------: | :-:       | :---:     | :-----------:         | :----:    |:----:     | :-----:   | :-------:
|           | 0.1       | 0.01      |                       | 1         | 76.14%    | 70.90%    |
| 50.000    |           |           |                       |           |           |           |
| 100.000   |           |           |                       |           |           |           |
| 150.000   |           |           |                       |           |           |           |
| 200.000   |           |           |                       |           |           |           |
| 250.000   |           |           |                       |           |           |           |
| 300.000   |           |           |                       |           |           |           |