# Connect4 GPU Training (Policy CNN)

Train a CNN to predict the best move from 6x7x2 boards.

In [None]:
# Mount Google Drive
# Colab-only: mounts the user's Drive so datasets and checkpoints persist
# across runtime resets.
from google.colab import drive
import os

drive.mount('/content/drive')

# Two Drive folders are in play: datasets are looked up under PROJECT_DIR,
# while trained models are written under COMBINED_DIR.
PROJECT_DIR = '/content/drive/MyDrive/Connect4_Project'
COMBINED_DIR = '/content/drive/MyDrive/Connect4_Combined'
DATASET_DIR = f'{PROJECT_DIR}/datasets'
MODEL_DIR = f'{COMBINED_DIR}/models'  # save CNN & Transformer to same folder
# Create the model folder up front so later checkpoint/save calls cannot fail
# on a missing directory.
os.makedirs(MODEL_DIR, exist_ok=True)

print('Drive mounted')
print('Dataset dir:', DATASET_DIR)
print('Model dir:', MODEL_DIR)

In [None]:
# Install/Import
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print('TF version:', tf.__version__)

In [None]:
# Config — uses combined unique dataset from Connect4_Combined
DATASET_PATH = '/content/drive/MyDrive/Connect4_Combined/datasets/connect4_combined_unique.npz'
BATCH_SIZE = 256
EPOCHS = 40
# VAL_SPLIT / TEST_SPLIT are fractions of the data; they are only used when
# the archive does not already contain a validation/test split.
VAL_SPLIT = 0.1
TEST_SPLIT = 0.1
# Seed shared by the NumPy permutation and the tf.data shuffles.
SEED = 42

# Mixed precision for GPU speed
USE_MIXED_PRECISION = True

In [None]:
# Mixed precision
# Must run before the model is built: the global policy is picked up by layers
# at construction time. The model's output layer is created with
# dtype='float32', so the softmax itself is not computed in float16.
if USE_MIXED_PRECISION:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
    print('Mixed precision enabled')

In [None]:
# Load dataset (combined unique has pre-split train/val/test)
# np.load on an .npz returns a lazy NpzFile that keeps the archive open; the
# context manager closes it as soon as the arrays have been read into memory.
with np.load(DATASET_PATH) as npz:
    if 'X_val' in npz:
        X_train = npz['X_train']
        y_train = npz['y_train']
        X_val = npz['X_val']
        y_val = npz['y_val']
        X_test = npz['X_test']
        y_test = npz['y_test']
        print('Loaded pre-split: train/val/test')
    else:
        # No validation arrays in the archive: keep everything in X / y and
        # leave the split variables as None so the next cell performs the split.
        X = npz['X_train']
        y = npz['y_train']
        X_train, X_val, X_test, y_train, y_val, y_test = None, None, None, None, None, None

print('X_train:', X_train.shape if X_train is not None else X.shape)
print('y_train:', y_train.shape if y_train is not None else y.shape)

In [None]:
# Train/val/test split (only if not pre-split)
# X_train is None only when the loading cell took its fallback branch and put
# the full dataset into X / y.
if X_train is None:
    # Seed NumPy's global RNG so the shuffle (and therefore the split) is
    # reproducible.
    np.random.seed(SEED)
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    n = len(X)
    ntest, nval = int(n * TEST_SPLIT), int(n * VAL_SPLIT)
    # Shuffled layout: [ test | val | train ].
    X_test, y_test = X[:ntest], y[:ntest]
    X_val, y_val = X[ntest:ntest+nval], y[ntest:ntest+nval]
    X_train, y_train = X[ntest+nval:], y[ntest+nval:]

print('Train:', X_train.shape, y_train.shape)
print('Val:', X_val.shape, y_val.shape)
print('Test:', X_test.shape, y_test.shape)

In [None]:
# tf.data pipeline (one-hot labels for label smoothing)
NUM_CLASSES = 7


def to_one_hot(x, y):
    """Turn an (input, integer label) pair into (input, one-hot label).

    The label is cast to int32 and expanded to a vector of length NUM_CLASSES;
    the input is passed through untouched.
    """
    label = tf.cast(y, tf.int32)
    return x, tf.one_hot(label, NUM_CLASSES)


def _encode_and_batch(ds):
    """One-hot encode the labels, then batch and prefetch the dataset."""
    encoded = ds.map(to_one_hot, num_parallel_calls=tf.data.AUTOTUNE)
    return encoded.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Only the training split is shuffled (before encoding and batching);
# validation and test keep their stored order.
train_ds = _encode_and_batch(
    tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000, seed=SEED)
)
val_ds = _encode_and_batch(tf.data.Dataset.from_tensor_slices((X_val, y_val)))
test_ds = _encode_and_batch(tf.data.Dataset.from_tensor_slices((X_test, y_test)))

In [None]:
# Model
# Input is one board as a 6x7 grid with 2 channels.
inputs = keras.Input(shape=(6, 7, 2))

# Stem convolution: lifts the 2 input channels to 64 feature maps; 'same'
# padding keeps the 6x7 spatial size.
x = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)

def residual_block(x, filters=64):
    """Run `x` through two 3x3 'same' convolutions and add `x` back in.

    The first convolution is followed by ReLU, the second is linear; ReLU is
    applied again after the skip connection is added.

    Args:
        x: Keras tensor entering the block.
        filters: Number of filters in both convolutions. The addition adds the
            convolution output to `x` unchanged, so this presumably has to
            equal the channel count of `x` — no projection is applied.

    Returns:
        The Keras tensor produced by the final ReLU.
    """
    skip = x
    branch = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    branch = layers.Conv2D(filters, 3, padding='same')(branch)
    merged = layers.Add()([branch, skip])
    return layers.Activation('relu')(merged)

# Trunk: six residual blocks, all 64 filters wide (matching the stem).
for _ in range(6):
    x = residual_block(x, 64)

# Head: pool the feature maps to a single vector, then classify over the
# 7 columns.
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation='relu')(x)
# dtype='float32' keeps the softmax output in full precision even when the
# global mixed_float16 policy is active.
outputs = layers.Dense(7, activation='softmax', dtype='float32')(x)

model = keras.Model(inputs, outputs)
# CategoricalCrossentropy expects one-hot targets, which is why the tf.data
# pipelines map labels through to_one_hot.
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
    metrics=[keras.metrics.CategoricalAccuracy(name='accuracy'), keras.metrics.TopKCategoricalAccuracy(k=2, name='top2')]
)

model.summary()

In [None]:
# Callbacks
ckpt_path = f"{MODEL_DIR}/connect4_cnn_best.keras"
callbacks = [
    # Keep only the checkpoint with the best validation accuracy so far.
    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_accuracy', save_best_only=True),
    # Halve the learning rate after 3 epochs without val_loss improvement,
    # but never go below 1e-5.
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-5),
    # Stop after 6 epochs without val_accuracy improvement and roll the model
    # back to its best weights.
    keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=6, restore_best_weights=True),
]

In [None]:
# Train
# Runs for at most EPOCHS epochs; the EarlyStopping callback may end training
# sooner.
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=callbacks)

In [None]:
# Evaluate
# Ask Keras for a name -> value mapping directly. Zipping the result list with
# model.metrics_names is unreliable: on Keras 3 metrics_names collapses all
# compiled metrics into a single 'compile_metrics' entry, so the zip would
# mislabel and truncate the results.
test_results = model.evaluate(test_ds, return_dict=True)
# Keep the list form under the original name for any later cell that uses it.
test_metrics = list(test_results.values())
print('Test metrics:', test_results)

In [None]:
# Save final model
# This is the in-memory model after training; it is written to a separate
# file from the ModelCheckpoint output (connect4_cnn_best.keras).
final_path = f"{MODEL_DIR}/connect4_cnn_final.keras"
model.save(final_path)
print('Saved:', final_path)

## Evaluation vs Weak MCTS

In [None]:
# Simple Connect4 engine for evaluation
import numpy as np
import random

class Connect4Eval:
    """Minimal Connect 4 game state used for evaluation and self-play games.

    The board is a 6x7 int8 array holding +1 / -1 for the two players and 0
    for empty cells. Row 0 is the bottom row: a piece dropped into a column
    lands at row ``heights[col]``, so stacks grow towards higher row indices.

    ``winner`` is ``None`` while the game is running, ``+1`` / ``-1`` once
    that player has connected four, and ``0`` when all 42 cells are filled
    without a win.
    """

    def __init__(self):
        self.board = np.zeros((6, 7), dtype=np.int8)
        # Number of pieces already in each column (= next free row index).
        self.heights = np.zeros(7, dtype=np.int8)
        self.current_player = 1
        self.winner = None
        self.move_count = 0

    def copy(self):
        """Return an independent copy of this game state."""
        g = Connect4Eval()
        g.board = self.board.copy()
        g.heights = self.heights.copy()
        g.current_player = self.current_player
        g.winner = self.winner
        g.move_count = self.move_count
        return g

    def legal_moves(self):
        """Return the columns (ascending) that still have room for a piece."""
        return [c for c in range(7) if self.heights[c] < 6]

    def make_move(self, col):
        """Drop a piece for ``current_player`` into ``col``.

        Updates ``winner`` if the move ends the game and then passes the turn.

        Returns:
            ``False`` (state untouched) if the column is full, else ``True``.
        """
        if self.heights[col] >= 6:
            return False
        # Plain int so the index arithmetic below never happens in int8.
        row = int(self.heights[col])
        self.board[row, col] = self.current_player
        self.heights[col] += 1
        self.move_count += 1
        if self._check_win(row, col):
            self.winner = self.current_player
        elif self.move_count >= 42:
            self.winner = 0
        self.current_player *= -1
        return True

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) completes four in a row."""
        player = self.board[row, col]
        # Vertical: the new piece is always the top of its stack, so the only
        # possible run is this piece plus the three directly below it. Cells
        # hold -1/0/+1, so a window sums to 4 * player only if all four match.
        if row >= 3 and np.sum(self.board[row - 3:row + 1, col]) == 4 * player:
            return True
        # Horizontal: every 4-wide window in this row that contains `col`.
        for c in range(max(0, col-3), min(4, col+1)):
            if np.sum(self.board[row, c:c+4]) == 4 * player:
                return True
        # Diagonals: count contiguous own pieces in both directions along
        # each diagonal axis, starting from the new piece.
        for dr, dc in [(1, 1), (1, -1)]:
            count = 1
            for sign in [1, -1]:
                r, c = row + sign*dr, col + sign*dc
                while 0 <= r < 6 and 0 <= c < 7 and self.board[r, c] == player:
                    count += 1
                    r += sign*dr
                    c += sign*dc
            if count >= 4:
                return True
        return False

    def is_terminal(self):
        """Return True once the game has been won or drawn."""
        return self.winner is not None

    def encode(self, perspective=1):
        """Encode the board as a (6, 7, 2) float32 array.

        Channel 0 marks the pieces of ``perspective`` and channel 1 the
        opponent's pieces; any ``perspective`` other than 1 is treated as
        player -1.
        """
        b = self.board if perspective == 1 else -self.board
        enc = np.zeros((6, 7, 2), dtype=np.float32)
        enc[:, :, 0] = (b == 1)
        enc[:, :, 1] = (b == -1)
        return enc

def find_winning_move(game, player):
    """Return a column in which `player` would win immediately, else None.

    Each legal column is tried on a throwaway copy of `game`, with the copy's
    side-to-move forced to `player` (regardless of whose turn it really is).
    Columns are tried in the order given by `game.legal_moves()` and the first
    winning one is returned. `game` itself is never modified.
    """
    for candidate in game.legal_moves():
        trial = game.copy()
        trial.current_player = player
        trial.make_move(candidate)
        if trial.winner == player:
            return candidate
    return None

def policy_move_with_rules(game, model, perspective=1):
    """Choose a column using tactical rules first and the policy net second.

    Order of preference:
      1. a move that wins immediately for the side to move,
      2. a move that occupies the opponent's immediate winning column,
      3. the legal column with the highest network score.

    Args:
        game: Current game state; it is not modified.
        model: Model whose `predict` returns one score per column.
        perspective: Passed to `game.encode` to pick whose pieces fill
            channel 0 of the network input.

    Returns:
        The chosen column as a Python int.
    """
    mover = game.current_player
    for side in (mover, -mover):
        forced = find_winning_move(game, side)
        if forced is not None:
            return forced

    batch = game.encode(perspective=perspective)[None, ...]
    probs = model.predict(batch, verbose=0)[0]

    # Additive mask: 0 on legal columns, a huge negative value elsewhere, so
    # argmax can never land on a full column.
    penalty = np.full(7, -1e9, dtype=np.float32)
    for c in game.legal_moves():
        penalty[c] = 0.0
    return int(np.argmax(probs + penalty))

class WeakMCTS:
    """Small Monte-Carlo tree search opponent with win/block shortcuts.

    Statistics are kept per search in a dict keyed by the hash of the board
    bytes; each entry is ``[visit_count, summed_result]`` with results scored
    from the root player's point of view.
    """

    def __init__(self, sims=100):
        # Number of simulations run per move.
        self.sims = sims

    def get_move(self, game):
        """Return a column for the side to move in `game`.

        Takes an immediate win if there is one, otherwise occupies the
        opponent's immediate winning column, otherwise falls back to the
        search.
        """
        win = find_winning_move(game, game.current_player)
        if win is not None:
            return win
        block = find_winning_move(game, -game.current_player)
        if block is not None:
            return block
        return self._mcts(game, self.sims)

    def _mcts(self, game, sims):
        """Run `sims` simulations from `game` and return the chosen column.

        `game` is not modified; every simulation works on copies.
        """
        root_player = game.current_player
        # Fresh statistics for every call: nothing is reused between moves.
        stats = {}
        for _ in range(sims):
            node = game.copy()
            path = []
            # Selection: descend while positions have already been visited.
            while not node.is_terminal():
                state = hash(node.board.tobytes())
                path.append(state)
                if state not in stats:
                    stats[state] = [0, 0.0]
                moves = node.legal_moves()
                if not moves:
                    break
                best_move = None
                best_ucb = -1e9
                parent_visits = stats[state][0]
                for col in moves:
                    test = node.copy()
                    test.make_move(col)
                    child = hash(test.board.tobytes())
                    if child not in stats:
                        stats[child] = [0, 0.0]
                    visits, value = stats[child]
                    if visits == 0:
                        # Unvisited children get a huge score; because the
                        # comparison below is strict, the first unvisited
                        # column in `moves` is the one selected.
                        ucb = 1e9
                    else:
                        exploit = value / visits
                        explore = 1.4 * np.sqrt(np.log(parent_visits + 1) / visits)
                        ucb = exploit + explore
                    # NOTE(review): `value` is always from root_player's
                    # perspective and is maximised at every depth, including
                    # positions where the opponent is to move. That looks like
                    # it is not a minimax-style search; presumably acceptable
                    # for a deliberately weak opponent — confirm.
                    if ucb > best_ucb:
                        best_ucb = ucb
                        best_move = col
                node.make_move(best_move)
                # Expansion: stop descending at the first never-visited
                # position and include it in the backup path.
                if stats[hash(node.board.tobytes())][0] == 0:
                    path.append(hash(node.board.tobytes()))
                    break
            # rollout
            # Random playout, cut off after 12 moves; an unfinished playout
            # scores 0.0 just like a draw.
            depth = 0
            while not node.is_terminal() and depth < 12:
                moves = node.legal_moves()
                if not moves:
                    break
                node.make_move(random.choice(moves))
                depth += 1
            result = 1.0 if node.winner == root_player else (-1.0 if node.winner == -root_player else 0.0)
            # Backup: credit every position on the path with this result.
            for st in path:
                if st in stats:
                    stats[st][0] += 1
                    stats[st][1] += result
        # choose best
        # Picks the root move with the highest mean result (not the highest
        # visit count); unvisited moves are never preferred.
        best_move, best_val = None, -1e9
        for col in game.legal_moves():
            test = game.copy()
            test.make_move(col)
            st = hash(test.board.tobytes())
            if st in stats and stats[st][0] > 0:
                val = stats[st][1] / stats[st][0]
            else:
                val = -1e9
            if val > best_val:
                best_val, best_move = val, col
        # Falls back to a random legal move when no child collected any visits.
        return best_move if best_move is not None else random.choice(game.legal_moves())

# Shared opponent instance used by the evaluation and self-play cells.
weak_mcts = WeakMCTS(sims=120)

In [None]:
# Evaluate policy model vs weak MCTS (policy_move_with_rules: win/block + legal masking)
# find_winning_move and policy_move_with_rules come from the engine cell above.
# This cell used to carry byte-identical copies of both, which would silently
# shadow any later fix made to the originals, so they are no longer redefined
# here.
EVAL_GAMES = 50

wins = 0
losses = 0
ties = 0

for _ in range(EVAL_GAMES):
    game = Connect4Eval()
    while not game.is_terminal():
        # The policy always plays as player 1 (and therefore moves first);
        # the weak MCTS plays as player -1.
        if game.current_player == 1:
            col = policy_move_with_rules(game, model, perspective=1)
        else:
            col = weak_mcts.get_move(game)
        game.make_move(col)
    if game.winner == 1:
        wins += 1
    elif game.winner == -1:
        losses += 1
    else:
        ties += 1

print(f"Policy vs weak MCTS: W {wins}, L {losses}, T {ties} (games={EVAL_GAMES})")

## Fine-Tune on High-Depth Dataset (Optional)

In [None]:
# Optional high-depth dataset
HIGH_DEPTH_DATASET_PATH = f'{DATASET_DIR}/connect4_high_depth.npz'
FINE_TUNE_EPOCHS = 10
FINE_TUNE_LR = 5e-4

import os
if os.path.exists(HIGH_DEPTH_DATASET_PATH):
    # Close the archive as soon as both arrays have been read into memory.
    with np.load(HIGH_DEPTH_DATASET_PATH) as npz2:
        X_hd = npz2['X_train']
        y_hd = npz2['y_train']

    # Seeded local generator: the train/val membership is the same on every
    # run, in line with the seeded tf.data shuffle below.
    rng = np.random.default_rng(SEED)
    idx = rng.permutation(len(X_hd))
    X_hd = X_hd[idx]
    y_hd = y_hd[idx]

    # Hold out the first 10% of the shuffled data for validation.
    n = len(X_hd)
    nval = int(n * 0.1)
    X_hd_val, y_hd_val = X_hd[:nval], y_hd[:nval]
    X_hd_train, y_hd_train = X_hd[nval:], y_hd[nval:]

    hd_train_ds = tf.data.Dataset.from_tensor_slices((X_hd_train, y_hd_train))
    hd_train_ds = hd_train_ds.shuffle(10000, seed=SEED).map(to_one_hot, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    hd_val_ds = tf.data.Dataset.from_tensor_slices((X_hd_val, y_hd_val))
    hd_val_ds = hd_val_ds.map(to_one_hot, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    # Re-compile with a fresh Adam optimizer at the lower fine-tuning rate;
    # loss and metrics match the main training run.
    model.compile(
        optimizer=keras.optimizers.Adam(FINE_TUNE_LR),
        loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
        metrics=[keras.metrics.CategoricalAccuracy(name='accuracy'), keras.metrics.TopKCategoricalAccuracy(k=2, name='top2')]
    )

    model.fit(hd_train_ds, validation_data=hd_val_ds, epochs=FINE_TUNE_EPOCHS)
else:
    print('High-depth dataset not found, skipping fine-tune')

## Self-Play Improvement Loop

In [None]:
# Number of generate-then-train rounds.
SELFPLAY_ITERATIONS = 3
# Games played per round to produce training positions.
SELFPLAY_GAMES_PER_ITER = 300
# Epochs of training on each round's positions.
SELFPLAY_EPOCHS = 3
# Probability that a self-play move is a uniformly random legal move instead
# of the policy's choice.
SELFPLAY_EPSILON = 0.1


def generate_selfplay_data(model, num_games):
    """Play `num_games` games of the model against itself and record them.

    For every move, the position (encoded from the mover's point of view) and
    the column that was played are stored. With probability SELFPLAY_EPSILON
    the move is a random legal one rather than the policy's choice; those
    exploratory moves are recorded as labels as well.

    Returns:
        (boards, moves): a float32 array of encoded positions and an int8
        array of the columns played, aligned index by index.
    """
    boards = []
    moves = []
    for _ in range(num_games):
        game = Connect4Eval()
        while not game.is_terminal():
            mover = game.current_player
            explore = random.random() < SELFPLAY_EPSILON
            if explore:
                col = random.choice(game.legal_moves())
            else:
                col = policy_move_with_rules(game, model, perspective=mover)
            # Snapshot the position before the move changes it.
            boards.append(game.encode(perspective=mover))
            moves.append(col)
            game.make_move(col)
    return np.array(boards, dtype=np.float32), np.array(moves, dtype=np.int8)

for it in range(SELFPLAY_ITERATIONS):
    print(f"Self-play iteration {it+1}/{SELFPLAY_ITERATIONS}")
    X_sp, y_sp = generate_selfplay_data(model, SELFPLAY_GAMES_PER_ITER)

    # The model is compiled with CategoricalCrossentropy, which needs one-hot
    # targets; y_sp holds integer column indices, so it goes through the same
    # to_one_hot mapping as every other training pipeline in this notebook.
    sp_ds = tf.data.Dataset.from_tensor_slices((X_sp, y_sp))
    sp_ds = (
        sp_ds.shuffle(10000, seed=SEED)
        .map(to_one_hot, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Continues with whatever optimizer/loss the model was last compiled with.
    model.fit(sp_ds, epochs=SELFPLAY_EPOCHS)

    # quick eval vs weak MCTS
    # 20 games, with the policy always playing as player 1 (moving first).
    wins = 0
    losses = 0
    ties = 0
    for _ in range(20):
        game = Connect4Eval()
        while not game.is_terminal():
            if game.current_player == 1:
                col = policy_move_with_rules(game, model, perspective=1)
            else:
                col = weak_mcts.get_move(game)
            game.make_move(col)
        if game.winner == 1:
            wins += 1
        elif game.winner == -1:
            losses += 1
        else:
            ties += 1
    print(f"Eval vs weak MCTS: W {wins}, L {losses}, T {ties}")