In [None]:
from dlgo.data.parallel_processor import GoDataProcessor
from dlgo.encoders.simple import SimpleEncoder
from layers import layers

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os

Data Processing

In [None]:
import glob

# Show which files match the `*_features_*.npy` pattern under GoBot/data, so an
# empty list makes a missing or misplaced data directory obvious straight away.
feature_files = glob.glob('GoBot/data/*_features_*.npy')
print(feature_files)


In [None]:
# Only meant for the case where the .npy files already exist.
def colab_safe_map_to_workers(self, data_type, samples):
    """No-op stand-in for ``GoDataProcessor.map_to_workers``.

    Prints a notice and does nothing else; both arguments are accepted only so
    the signature matches the method it replaces.

    Args:
        self: The processor instance (unused).
        data_type: Ignored.
        samples: Ignored.

    Returns:
        None.
    """
    notice = ">> [Colab] Skipping map_to_workers to avoid multiprocessing."
    print(notice)
    return None

# Monkey-patch the class (not an instance) so every GoDataProcessor created
# afterwards uses the no-op stub instead of its own map_to_workers.
setattr(GoDataProcessor, "map_to_workers", colab_safe_map_to_workers)


In [None]:
# Board geometry: one output class per board intersection.
board_size = 19
num_classes = board_size ** 2

# Passed to load_go_data for both the train and the test split.
num_games = 100

board_shape = (board_size, board_size)
encoder = SimpleEncoder(board_shape)

# The processor is handed the encoder's name rather than the encoder object.
processor = GoDataProcessor(data_directory='data', encoder=encoder.name())

# use_generator=True: the returned objects are consumed through .generate()
# further down instead of being used as in-memory arrays.
generator = processor.load_go_data('train', num_games, use_generator=True)
test_generator = processor.load_go_data('test', num_games, use_generator=True)

In [None]:
# Smoke test: pull a single batch from each split (train first, then test) and
# print the feature/label array shapes.
for data_generator in (generator, test_generator):
    x_batch, y_batch = next(data_generator.generate(batch_size=128))
    print(x_batch.shape, y_batch.shape)

In [None]:
class GoDatasetWrapper(Dataset):
    """Expose a batch-yielding data generator through the ``Dataset`` interface.

    Every item returned by this dataset is one *whole batch* taken from the
    wrapped generator, not a single sample. Consequently ``len()`` reports the
    number of batches per pass over the data, and a ``DataLoader`` built on top
    should not batch again (use ``batch_size=None``).

    Args:
        generator: Object providing ``generate(batch_size, num_classes)`` (an
            iterator of ``(X, y)`` array pairs) and ``get_num_samples()``.
        batch_size: Number of samples the generator puts in each batch.
        num_classes: Forwarded to ``generator.generate``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, generator, batch_size, num_classes):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.batch_size = batch_size
        self.generator = generator.generate(batch_size, num_classes)
        # NOTE(review): get_num_samples() is called without arguments, so the
        # count relies on whatever defaults the generator uses - confirm they
        # match `batch_size` when it is not the generator's default.
        self.num_samples = generator.get_num_samples()

    def __len__(self):
        # One item == one batch, so an epoch is num_samples / batch_size items.
        # Returning num_samples here made every epoch batch_size times too long.
        return self.num_samples // self.batch_size

    def __getitem__(self, idx):
        # `idx` is deliberately ignored: batches come in generator order, so
        # shuffling indices in a DataLoader does not reorder the data.
        X, y = next(self.generator)
        X = torch.tensor(X, dtype=torch.float32)
        # Labels are only cast to integers here; they keep the layout the
        # generator produced (presumably one-hot - the training loop applies
        # argmax to obtain class indices for CrossEntropyLoss).
        y = torch.tensor(y, dtype=torch.long)
        return X, y


Defining the Model

In [None]:
class BetterGoCNN(nn.Module):
    """Thin ``nn.Module`` wrapper around the network built by ``layers``.

    Args:
        board_size: Side length of the (square) board; sets the spatial size
            of the input shape handed to ``layers``.
        num_planes: Number of input feature planes. Defaults to 11, the value
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, board_size=19, num_planes=11):
        super().__init__()
        # Channels-first shape of a single encoded position (no batch axis).
        input_shape = (num_planes, board_size, board_size)

        # NOTE(review): `layers` must return a callable module (it is invoked
        # directly in forward) - presumably an nn.Sequential; confirm.
        self.model = layers(input_shape)

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        return self.model(x)
        

Training Loop

In [None]:
# Size of the batches produced by the data generators (one generator step is
# one training batch).
batch_size = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BetterGoCNN(board_size)
model = model.to(device)

train_dataset = GoDatasetWrapper(generator, batch_size=batch_size, num_classes=num_classes)
test_dataset = GoDatasetWrapper(test_generator, batch_size=batch_size, num_classes=num_classes)

# Each dataset item is already a full batch, so automatic batching is disabled
# with batch_size=None. With batch_size=1 the loader stacked every batch into
# an extra leading axis, e.g. (1, 128, C, H, W), which breaks the model's
# expected 4-D input and makes `argmax(dim=1)` reduce over the wrong axis.
# GoDatasetWrapper ignores the requested index, so `shuffle` only permutes
# indices and does not change the order in which batches are produced.
train_loader = DataLoader(train_dataset, batch_size=None, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=None, shuffle=False)

# CrossEntropyLoss receives class-index targets in the training loop.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Report how many items each loader will yield per epoch.
print(
    f"Train loader has {len(train_loader)} batches",
    f"Test loader has {len(test_loader)} batches",
    sep="\n",
)

In [None]:
epochs = 20


def _run_epoch(loader, desc, training):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. Weights are updated only when ``training`` is true; otherwise
    the model is put in eval mode and gradients are not tracked.

    Args:
        loader: Iterable of ``(inputs, one_hot_targets)`` batches.
        desc: Label shown on the tqdm progress bar.
        training: Whether to back-propagate and step the optimizer.

    Returns:
        Tuple of mean per-batch loss and the fraction of correct predictions
        (both 0.0 when the loader yields nothing).
    """
    model.train(training)
    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, targets in tqdm(loader, desc=desc):
            inputs, targets = inputs.to(device), targets.to(device)

            # The dataset already yields whole batches. If the DataLoader also
            # batches (e.g. batch_size=1) there is an extra leading axis; fold
            # it away so the model gets (N, C, H, W) and targets are
            # (N, num_classes).
            if inputs.dim() == 5:
                inputs = inputs.flatten(0, 1)
                targets = targets.flatten(0, 1)

            targets = targets.argmax(dim=1)  # convert one-hot to class index

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            num_correct += (predicted == targets).sum().item()
            num_seen += targets.size(0)

    # Average rather than sum the loss so train and validation numbers are
    # comparable even when the two loaders have different lengths; the max()
    # guards keep an empty loader from raising ZeroDivisionError.
    mean_loss = loss_sum / max(num_batches, 1)
    accuracy = num_correct / max(num_seen, 1)
    return mean_loss, accuracy


for epoch in range(epochs):
    train_loss, train_accuracy = _run_epoch(
        train_loader, f"Epoch {epoch+1}", training=True
    )

    # === Validation phase ===
    val_loss, val_accuracy = _run_epoch(
        test_loader, f"Validation Epoch {epoch+1}", training=False
    )

    print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")
    print(f"Val   Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")
    print("-" * 40)

    # Checkpoint after every epoch so an interrupted run keeps its progress.
    torch.save(model.state_dict(), f"small_model_epoch_{epoch+1}.pth")