In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader

In [14]:
class GoCNN(nn.Module):
    def __init__(self, board_size=9):
        super(GoCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 48, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(48, 48, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(48 * (board_size // 2)**2, 512)
        self.fc2 = nn.Linear(512, board_size * board_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [15]:
X = np.load('../generated_games/features-40k.npy').astype(np.float32)
Y = np.load('../generated_games/labels-40k.npy').astype(np.float32)

X = X.reshape(-1, 1, 9, 9)  # PyTorch expects [N, C, H, W]
dataset = TensorDataset(torch.tensor(X), torch.tensor(Y))
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [16]:
model = GoCNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(100):
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        outputs = model(inputs)
        labels = torch.argmax(targets, dim=1)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}, Accuracy = {accuracy:.2f}%")

torch.save(model.state_dict(), 'go_cnn.pt')

Epoch 1: Loss = 2802.2307, Accuracy = 2.15%
Epoch 2: Loss = 2769.5065, Accuracy = 2.37%
Epoch 3: Loss = 2756.2215, Accuracy = 2.41%
Epoch 4: Loss = 2735.1461, Accuracy = 2.32%
Epoch 5: Loss = 2701.7996, Accuracy = 2.62%
Epoch 6: Loss = 2675.8367, Accuracy = 2.76%
Epoch 7: Loss = 2660.6933, Accuracy = 2.69%
Epoch 8: Loss = 2652.6607, Accuracy = 2.91%
Epoch 9: Loss = 2646.1720, Accuracy = 3.11%
Epoch 10: Loss = 2640.3110, Accuracy = 3.09%
Epoch 11: Loss = 2636.9563, Accuracy = 3.03%
Epoch 12: Loss = 2633.3259, Accuracy = 3.26%
Epoch 13: Loss = 2627.5433, Accuracy = 3.25%
Epoch 14: Loss = 2624.9889, Accuracy = 3.30%
Epoch 15: Loss = 2619.7612, Accuracy = 3.48%
Epoch 16: Loss = 2619.0295, Accuracy = 3.27%
Epoch 17: Loss = 2615.7024, Accuracy = 3.43%
Epoch 18: Loss = 2610.1059, Accuracy = 3.34%
Epoch 19: Loss = 2608.4570, Accuracy = 3.40%
Epoch 20: Loss = 2606.2626, Accuracy = 3.61%
Epoch 21: Loss = 2603.4620, Accuracy = 3.60%
Epoch 22: Loss = 2599.6278, Accuracy = 3.51%
Epoch 23: Loss = 25