# Connect 4 - CNN-Only Training (v2)

**Project deliverables**: Train a **CNN (v2)** that classifies the move as one of 7 classes (one per board column).
- Path auto-detect: Colab (Drive) or local `Connect4_Combined`
- On-the-fly horizontal flip augmentation
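- Residual CNN architecture with a linear-warmup + cosine-decay LR schedule, plus an optional Phase 2 fine-tune that runs only when Phase 1 validation accuracy is below `TARGET_VAL_ACC`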
- Saves: `connect4_cnn_v2_best.keras`, `connect4_cnn_v2_final.keras`, and compatibility alias `connect4_cnn_final.keras`


## 1. Setup, Paths & Config

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Mount Google Drive (Colab only); outside Colab the import fails and is ignored.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    pass

# Path auto-detect: use the Drive folder when it exists, otherwise a
# Connect4_Combined folder under the current working directory.
if os.path.exists('/content/drive/MyDrive/Connect4_Combined'):
    BASE = '/content/drive/MyDrive/Connect4_Combined'
else:
    BASE = os.path.join(os.getcwd(), 'Connect4_Combined')

COMBINED_DATASET = f'{BASE}/datasets/connect4_combined_unique.npz'
MODEL_DIR = f'{BASE}/models'
os.makedirs(MODEL_DIR, exist_ok=True)

# Configurable parameters
TARGET_VAL_ACC = 0.72       # Phase 2 fine-tune runs only if Phase 1 best val accuracy is below this
EPOCHS_PHASE1 = 60          # max epochs for Phase 1 (also the length of the LR schedule)
EPOCHS_FINETUNE = 10        # epochs for the optional Phase 2 fine-tune
BATCH_SIZE = 256
INIT_LR = 1e-3              # peak LR reached at the end of warmup
FINETUNE_LR = 1e-5          # cosine floor in Phase 1 and constant LR in Phase 2
WARMUP_EPOCHS = 3
EARLY_STOP_PATIENCE = 10
AUG_FLIP_PROB = 0.5         # probability of a horizontal flip per training sample
SEED = 42
NUM_CLASSES = 7

np.random.seed(SEED)
tf.random.set_seed(SEED)

print('Loading dataset...')
# Use the NpzFile as a context manager so its file handle is closed. Indexing
# an NpzFile reads each array fully into memory, so the arrays stay valid
# after the file is closed.
with np.load(COMBINED_DATASET) as npz:
    X_train = npz['X_train']
    y_train = npz['y_train']
    X_val = npz['X_val']
    y_val = npz['y_val']
    X_test = npz['X_test']
    y_test = npz['y_test']
print(f'Train: {X_train.shape[0]:,} | Val: {X_val.shape[0]:,} | Test: {X_test.shape[0]:,}')
print(f'Model dir: {MODEL_DIR}')

## 2. Data Pipeline (augmentation + one-hot)

In [None]:
def augment_and_one_hot(x, y):
    """Randomly mirror a board left-right, then one-hot encode its label.

    With probability AUG_FLIP_PROB the board is flipped along its width axis
    and the move label is mirrored to match (class c -> NUM_CLASSES - 1 - c).
    The tensor-valued ``if`` relies on AutoGraph, which tf.data applies to
    functions passed to ``Dataset.map``.

    Args:
        x: board tensor; assumed (rows, cols, channels) matching the model
            input (6, 7, 2) - TODO confirm against the dataset.
        y: integer move label in [0, NUM_CLASSES).

    Returns:
        Tuple ``(x, one_hot_y)`` where ``one_hot_y`` has length NUM_CLASSES.
    """
    if tf.random.uniform(()) < AUG_FLIP_PROB:
        x = tf.image.flip_left_right(x)
        # Mirror the label to the flipped board; derived from NUM_CLASSES
        # instead of a hard-coded 6 so it stays consistent with the one-hot width.
        y = (NUM_CLASSES - 1) - y
    return x, tf.one_hot(tf.cast(y, tf.int32), NUM_CLASSES)

def to_one_hot(x, y):
    """Return *x* unchanged together with a NUM_CLASSES-wide one-hot of label *y* (no augmentation)."""
    label_ids = tf.cast(y, tf.int32)
    encoded = tf.one_hot(label_ids, NUM_CLASSES)
    return x, encoded

_AUTOTUNE = tf.data.AUTOTUNE

# Training stream: shuffle, then random flip + one-hot, then batch and prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(20000, seed=SEED)
    .map(augment_and_one_hot, num_parallel_calls=_AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(_AUTOTUNE)
)


def _make_eval_ds(features, labels):
    """Build a non-shuffled, non-augmented, one-hot, batched evaluation dataset."""
    return (
        tf.data.Dataset.from_tensor_slices((features, labels))
        .map(to_one_hot, num_parallel_calls=_AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(_AUTOTUNE)
    )


val_ds = _make_eval_ds(X_val, y_val)
test_ds = _make_eval_ds(X_test, y_test)
print('Data pipeline ready (augmentation on train only)')

## 3. Mixed Precision & LR Schedule

In [None]:
from tensorflow.keras import mixed_precision

# mixed_float16 only speeds things up on GPU; on CPU it runs noticeably slower.
# This notebook also supports local runs, so keep the default float32 policy
# when no GPU is visible. The final Dense layer is pinned to float32 either way.
if tf.config.list_physical_devices('GPU'):
    mixed_precision.set_global_policy('mixed_float16')
    print('Mixed precision enabled')
else:
    print('No GPU detected - keeping float32 policy (mixed precision disabled)')

def warmup_cosine_lr(epoch, lr):
    """Return the learning rate for *epoch*: linear warmup, then cosine decay.

    Epochs 0..WARMUP_EPOCHS-1 ramp linearly from INIT_LR / WARMUP_EPOCHS up to
    INIT_LR. After that the rate follows a half cosine from INIT_LR down to
    FINETUNE_LR over EPOCHS_PHASE1 - WARMUP_EPOCHS epochs, and holds at
    FINETUNE_LR after that.

    Args:
        epoch: 0-based epoch index supplied by ``LearningRateScheduler``.
        lr: current learning rate; required by the callback signature, unused.

    Returns:
        The learning rate as a Python float.
    """
    if epoch < WARMUP_EPOCHS:
        return INIT_LR * (epoch + 1) / WARMUP_EPOCHS
    decay_epochs = max(1, EPOCHS_PHASE1 - WARMUP_EPOCHS)
    # Clamp so that running past EPOCHS_PHASE1 stays at the FINETUNE_LR floor
    # instead of letting the cosine climb back toward INIT_LR.
    progress = min(1.0, (epoch - WARMUP_EPOCHS) / decay_epochs)
    return float(FINETUNE_LR + 0.5 * (INIT_LR - FINETUNE_LR) * (1 + np.cos(np.pi * progress)))

## 4. Build CNN (10 res blocks, 256 filters, BatchNorm)

In [None]:
L2_REG = 1e-4


def _conv_bn(tensor, filters):
    """Apply a 3x3 same-padded, L2-regularised Conv2D followed by BatchNormalization."""
    conv_out = layers.Conv2D(filters, 3, padding='same',
                             kernel_regularizer=keras.regularizers.l2(L2_REG))(tensor)
    return layers.BatchNormalization()(conv_out)


def res_block(x, filters=256):
    """Residual block: conv-BN-ReLU, conv-BN, identity skip add, then ReLU.

    The skip connection is a plain identity, so *x* must already have
    *filters* channels for the Add to be shape-compatible.
    """
    skip = x
    hidden = layers.ReLU()(_conv_bn(x, filters))
    hidden = _conv_bn(hidden, filters)
    merged = layers.Add()([hidden, skip])
    return layers.ReLU()(merged)

# Input board; 6x7 with 2 planes - presumably one plane per player (TODO confirm encoding).
inputs = keras.Input(shape=(6, 7, 2))

# Stem: project the input planes to 256 channels so res_block's identity skips line up.
x = layers.Conv2D(256, 3, padding='same', kernel_regularizer=keras.regularizers.l2(L2_REG))(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

# Residual tower of 10 blocks at 256 filters.
for _ in range(10):
    x = res_block(x, 256)

# Classification head: global pooling, then two Dense+Dropout stages.
x = layers.GlobalAveragePooling2D()(x)
for _head_units in (256, 128):
    x = layers.Dense(_head_units, activation='relu', kernel_regularizer=keras.regularizers.l2(L2_REG))(x)
    x = layers.Dropout(0.3)(x)
# Softmax output kept in float32 so probabilities are not computed in half precision.
outputs = layers.Dense(NUM_CLASSES, activation='softmax', dtype='float32')(x)

cnn_model = keras.Model(inputs, outputs)

_train_metrics = [
    keras.metrics.CategoricalAccuracy(name='accuracy'),
    keras.metrics.TopKCategoricalAccuracy(k=2, name='top2'),
]
cnn_model.compile(
    optimizer=keras.optimizers.Adam(INIT_LR),
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
    metrics=_train_metrics,
)
cnn_model.summary()

## 5. Train CNN — Phase 1, then Phase 2 fine-tune (only if val accuracy < `TARGET_VAL_ACC`)

In [None]:
cnn_ckpt = f'{MODEL_DIR}/connect4_cnn_v2_best.keras'
cnn_callbacks = [
    # Keep only the best epoch (by validation accuracy) on disk.
    keras.callbacks.ModelCheckpoint(cnn_ckpt, monitor='val_accuracy', save_best_only=True),
    # Per-epoch LR from warmup_cosine_lr (overrides the optimizer's LR).
    keras.callbacks.LearningRateScheduler(warmup_cosine_lr),
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=EARLY_STOP_PATIENCE,
        min_delta=0.002, restore_best_weights=True
    ),
]

print('Training CNN Phase 1...')
cnn_history = cnn_model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_PHASE1, callbacks=cnn_callbacks)
cnn_val_acc = max(cnn_history.history['val_accuracy'])
print(f'CNN Phase 1 best val_accuracy: {cnn_val_acc:.4f}')

if cnn_val_acc < TARGET_VAL_ACC:
    print(f'Phase 2: Fine-tuning (val_acc {cnn_val_acc:.2%} < {TARGET_VAL_ACC:.0%})...')
    # Start fine-tuning from the best Phase 1 weights.
    cnn_model.load_weights(cnn_ckpt)
    # Recompile with a small constant LR (and a fresh optimizer state).
    cnn_model.compile(
        optimizer=keras.optimizers.Adam(FINETUNE_LR),
        loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
        metrics=[keras.metrics.CategoricalAccuracy(name='accuracy'), keras.metrics.TopKCategoricalAccuracy(k=2, name='top2')]
    )
    # Overwrite the "best" checkpoint only if a fine-tune epoch beats the
    # Phase 1 best. Previously the last fine-tuned weights were saved
    # unconditionally, even when they were worse.
    finetune_callbacks = [
        keras.callbacks.ModelCheckpoint(
            cnn_ckpt, monitor='val_accuracy', save_best_only=True,
            initial_value_threshold=cnn_val_acc,
        ),
    ]
    ft_history = cnn_model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_FINETUNE, callbacks=finetune_callbacks)
    cnn_val_acc = max(cnn_val_acc, max(ft_history.history['val_accuracy']))
    # Continue with whichever weights are now the best checkpoint on disk.
    cnn_model.load_weights(cnn_ckpt)
    print(f'CNN best val_accuracy after Phase 2: {cnn_val_acc:.4f}')

cnn_metrics = cnn_model.evaluate(test_ds)
print(f'CNN test metrics: {cnn_metrics}')


## 6. Save Final CNN v2 Model & Summary


In [None]:
cnn_v2_final = f'{MODEL_DIR}/connect4_cnn_v2_final.keras'
cnn_compat_final = f'{MODEL_DIR}/connect4_cnn_final.keras'
# Save the same in-memory model under the v2 name and the legacy alias name.
for _out_path in (cnn_v2_final, cnn_compat_final):
    cnn_model.save(_out_path)

_rule = '=' * 60
_summary_lines = [
    _rule,
    'DELIVERABLES SAVED',
    _rule,
    f'CNN v2 best:  {cnn_ckpt}',
    f'CNN v2 final: {cnn_v2_final}',
    f'Compatibility alias: {cnn_compat_final}',
    _rule,
]
for _line in _summary_lines:
    print(_line)
