## Transformer

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import random

class FixedSubblockExtractor(layers.Layer):
    """Cut a board tensor into fixed 3x3 windows and flatten each into a token.

    Input is indexed as (batch, row, col, channel); each window is flattened
    to ``flat_dim`` = 3 * 3 * 2 = 18 values, so the board is expected to
    carry 2 channels. Output shape is (batch, len(coords), flat_dim).

    Args:
        coords: optional iterable of (row, col) top-left corners of the
            3x3 windows. Defaults to ``DEFAULT_COORDS``.

    Note:
        The previous hard-coded layout used window rows 0 and 2, which only
        covers board rows 0-4: row 5 of a 6-row board was never fed to the
        model. The default now uses rows 0 and 3 so all six rows are
        covered (columns 0-6 are covered by the overlapping windows at
        columns 0, 2 and 4). Weights trained with the old layout saw
        different inputs; retrain, or pass the old corners explicitly via
        ``coords`` to reproduce them.
    """

    # Top-left corners that tile a 6x7 board: rows 0-2 and 3-5,
    # columns 0-2, 2-4 and 4-6.
    DEFAULT_COORDS = (
        (0, 0), (0, 2), (0, 4),
        (3, 0), (3, 2), (3, 4))

    def __init__(self, coords=None, **kwargs):
        super().__init__(**kwargs)
        chosen = self.DEFAULT_COORDS if coords is None else coords
        # Normalise to tuples of ints so values restored from a saved
        # config (lists) behave the same as the defaults.
        self.coords = [(int(r), int(c)) for r, c in chosen]
        self.flat_dim = 18  # 3 * 3 * 2

    def call(self, x):
        blocks = []
        for r, c in self.coords:
            block = x[:, r:r+3, c:c+3, :]
            block = tf.reshape(block, (-1, self.flat_dim))
            blocks.append(block)
        out = tf.stack(blocks, axis=1)
        return tf.ensure_shape(out, (None, len(self.coords), self.flat_dim))

    def get_config(self):
        # Persist the window layout so a saved model reloads with the same
        # patches it was trained on.
        config = super().get_config()
        config["coords"] = [list(rc) for rc in self.coords]
        return config


class AdditivePositionalEncoding(layers.Layer):
    """Add a learned position embedding to every token of a sequence.

    Args:
        num_tokens: number of tokens in the sequence this layer is applied to.
        embed_dim: size of each token vector.

    The embedding has shape (1, num_tokens, embed_dim), is initialised from
    a random normal distribution, and is broadcast over the batch axis when
    added to the input.
    """

    def __init__(self, num_tokens, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        self.pos_emb = self.add_weight(
            shape=(1, num_tokens, embed_dim),
            initializer="random_normal",
            trainable=True
        )

    def call(self, x):
        return x + self.pos_emb

    def get_config(self):
        # __init__ has required arguments, so they must be part of the saved
        # config for load_model() to be able to rebuild this layer.
        config = super().get_config()
        config.update({
            "num_tokens": self.num_tokens,
            "embed_dim": self.embed_dim,
        })
        return config


class ClassToken(layers.Layer):
    """Prepend a learned classification token to a token sequence.

    Args:
        embed_dim: size of each token vector; the input's last axis must
            have this size.

    The token is a trainable (1, 1, embed_dim) weight initialised to zeros.
    For an input of shape (batch, tokens, embed_dim) the output has shape
    (batch, tokens + 1, embed_dim), with the class token at index 0.
    """

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.embed_dim = embed_dim
        self.cls = self.add_weight(
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True
        )

    def call(self, x):
        batch = tf.shape(x)[0]
        # Repeat the single learned token once per example in the batch.
        cls = tf.broadcast_to(self.cls, [batch, 1, self.embed_dim])
        # Plain op instead of building a new Concatenate layer on every call.
        return tf.concat([cls, x], axis=1)

    def get_config(self):
        # __init__ has a required argument, so it must be part of the saved
        # config for load_model() to be able to rebuild this layer.
        config = super().get_config()
        config["embed_dim"] = self.embed_dim
        return config
    

class TransformerEncoder(layers.Layer):
    """Pre-norm Transformer encoder block: self-attention then an MLP.

    Both sub-layers are applied to a layer-normalised copy of the input and
    added back to it as a residual.

    Args:
        embed_dim: size of each token vector (also the MLP output size).
        num_heads: number of attention heads.
        mlp_dim: hidden size of the MLP.
        dropout: dropout rate used inside attention and after each MLP layer.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        # NOTE(review): key_dim is the per-head size in MultiHeadAttention;
        # passing embed_dim makes every head embed_dim wide rather than
        # embed_dim // num_heads. Left unchanged so weight shapes stay the same.
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=dropout)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            layers.Dense(mlp_dim, activation="gelu"),
            layers.Dropout(dropout),
            layers.Dense(embed_dim),
            layers.Dropout(dropout)])

    def call(self, x, training=False):
        # Normalise once and reuse it as both query and value.
        normed = self.norm1(x)
        attn_out = self.attn(normed, normed, training=training)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x), training=training)
        return x

    def get_config(self):
        # __init__ has required arguments, so they must be part of the saved
        # config for load_model() to be able to rebuild this layer.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "mlp_dim": self.mlp_dim,
            "dropout": self.dropout_rate,
        })
        return config

def build_transformer(config):
    """Build and compile the patch-based Transformer move classifier.

    Pipeline: (6, 7, 2) board -> six flattened 3x3 patches -> linear
    embedding -> learned positional encoding -> prepended class token ->
    ``num_layers`` encoder blocks -> layer norm on the class token ->
    ReLU dense layer -> 7-way softmax.

    Args:
        config: hyperparameters, readable either by key (``config["x"]``)
            or by attribute (``config.x``). Required entries: ``embed_dim``,
            ``num_layers``, ``num_heads``, ``mlp_dim``, ``dropout`` and
            ``learning_rate``.

    Returns:
        A compiled ``tf.keras.Model`` using Adam, sparse categorical
        cross-entropy and an accuracy metric.
    """
    def cfg(key):
        # The original mixed config["..."] and config.attr, which only works
        # for objects supporting both; accept either style here.
        try:
            return config[key]
        except (TypeError, KeyError):
            return getattr(config, key)

    embed_dim = cfg("embed_dim")

    inputs = tf.keras.Input(shape=(6, 7, 2))

    x = FixedSubblockExtractor()(inputs)  # (B, 6, 18)

    x = layers.Dense(embed_dim, use_bias=False)(x)

    x = AdditivePositionalEncoding(num_tokens=6, embed_dim=embed_dim)(x)

    x = ClassToken(embed_dim)(x)

    for _ in range(cfg("num_layers")):
        x = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=cfg("num_heads"),
            mlp_dim=cfg("mlp_dim"),
            dropout=cfg("dropout"))(x)

    cls_out = x[:, 0]  # CLS token

    cls = layers.LayerNormalization(epsilon=1e-6)(cls_out)

    cls = layers.Dense(embed_dim, activation="relu")(cls)
    output = layers.Dense(7, activation="softmax")(cls)

    model = tf.keras.Model(inputs, output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(cfg("learning_rate")),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model

  if not hasattr(np, "object"):


## Hyperparameter Sweep

In [17]:
# Load the pre-cleaned dataset from the current working directory.
# NOTE(review): presumably X holds (N, 6, 7, 2) boards and y integer move
# labels in 0-6, matching build_transformer's input and output -- confirm.
X = np.load('connect4_X_clean.npy')
y = np.load('connect4_Y_clean.npy')

In [None]:
import wandb
from wandb.integration.keras import WandbMetricsLogger
from tensorflow.keras.callbacks import EarlyStopping

def train_test_split_np(X, y, test_size=0.2, seed=42, shuffle=True):
    """Split two aligned arrays into train and test parts.

    Args:
        X: array of samples, indexed along axis 0.
        y: array of labels with the same length as ``X``.
        test_size: fraction of samples assigned to the test part.
        seed: seed for the shuffling generator (used only when shuffling).
        shuffle: when True, permute the sample order before splitting.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.

    Raises:
        AssertionError: if ``X`` and ``y`` differ in length.
    """
    assert len(X) == len(y)

    n_samples = len(X)
    order = np.arange(n_samples)
    if shuffle:
        np.random.default_rng(seed).shuffle(order)

    # The first `cut` positions of the (possibly shuffled) order are training rows.
    cut = int(n_samples * (1 - test_size))
    train_part, test_part = order[:cut], order[cut:]

    return X[train_part], X[test_part], y[train_part], y[test_part]

# Sweep definition handed to wandb.sweep(): Bayesian search that maximises
# the metric named 'val_accuracy'. Entries with 'values' are discrete
# choices, 'value' entries are constants, and the rest are sampled from
# the stated distribution.
TRANSFORMER_SWEEP_CONFIG = dict(
    method='bayes',
    metric=dict(name='val_accuracy', goal='maximize'),
    parameters=dict(
        # Architecture
        num_layers=dict(values=[3, 4, 6]),
        embed_dim=dict(values=[64, 128]),
        num_heads=dict(values=[4, 8]),
        mlp_dim=dict(values=[128, 256]),
        # Regularisation and optimisation
        dropout=dict(distribution='uniform', min=0.1, max=0.3),
        learning_rate=dict(distribution='log_uniform_values', min=1e-4, max=1e-3),
        batch_size=dict(values=[128, 256]),
        # Fixed training budget
        epochs=dict(value=50),
        early_stopping_patience=dict(value=10),
    ),
)

def sweep(config = None):
    """Train and evaluate one sweep trial inside its own W&B run.

    Reads hyperparameters from ``wandb.config``, splits the module-level
    ``X``/``y`` into 70% train, 15% validation and 15% test, trains with
    early stopping on validation accuracy, then logs the best validation
    accuracy and the test metrics.

    Args:
        config: optional default config forwarded to ``wandb.init``; the
            sweep agent's values are read back from ``wandb.config``.
    """
    with wandb.init(config=config):
        config = wandb.config

        print("Config keys:", list(config.keys()))

        # Hold out 30%, then halve the hold-out into validation and test.
        # The second split must operate on the hold-out only: re-splitting the
        # full dataset would put training rows into validation and test.
        x_train, x_holdout, y_train, y_holdout = train_test_split_np(
            X, y, test_size=0.3)

        x_val, x_test, y_val, y_test = train_test_split_np(
            x_holdout, y_holdout, test_size=0.5)

        model = build_transformer(config)

        # The model has one output compiled with metrics=['accuracy'], so the
        # validation metric is named 'val_accuracy'.
        early_stop = EarlyStopping(
            monitor='val_accuracy',
            mode='max',
            min_delta=1e-4,
            restore_best_weights=True,
            patience=config.early_stopping_patience)

        history = model.fit(
            x_train,y_train,
            validation_data=(x_val,y_val),
            batch_size=config.batch_size,
            epochs=config.epochs,
            callbacks=[WandbMetricsLogger(), early_stop],
            verbose=1
        )

        # WandbMetricsLogger logs epoch metrics under an 'epoch/' prefix, so
        # log the un-prefixed 'val_accuracy' the sweep optimises explicitly.
        wandb.log({'val_accuracy': max(history.history['val_accuracy'])})

        results = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
        # Prefix so test metrics cannot be confused with training metrics.
        wandb.log({f"test_{name}": value for name, value in results.items()})

        print(f"\nTest Accuracy: {results['accuracy'] * 100:.2f}%")

# Authenticate with W&B, register the sweep defined by
# TRANSFORMER_SWEEP_CONFIG under the "connect4-transformer" project, then run
# `sweep` for up to 40 trials in this process.
wandb.login()
sweep_id = wandb.sweep(TRANSFORMER_SWEEP_CONFIG, project="connect4-transformer")
wandb.agent(sweep_id, function=sweep, count=40)


## Train Final Model

In [None]:
import matplotlib.pyplot as plt
def model_train(config=None):
    """Train the transformer on the full dataset and plot its training curves.

    Args:
        config: hyperparameter object exposing ``keys()``, ``batch_size`` and
            ``epochs`` plus everything ``build_transformer`` reads. Falls
            back to ``wandb.config`` when omitted (previously the argument
            was ignored and ``wandb.config`` was always used).

    Returns:
        The trained Keras model.

    Side effects:
        Saves ``connect4_transformer_history.png`` in the working directory
        and shows the figure.
    """
    model_name = 'connect4_transformer'
    if config is None:
        config = wandb.config
    print("Config keys:", list(config.keys()))
    model = build_transformer(config)
    # No validation split: every sample in X, y is used for training.
    history = model.fit(
        X, y,
        batch_size=config.batch_size,
        epochs=config.epochs,
        verbose=1)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].plot(history.history['accuracy'], label='Train', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Accuracy', fontsize=12)
    axes[0].set_title(f'{model_name} - Accuracy', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(history.history['loss'], label='Train', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Loss', fontsize=12)
    axes[1].set_title(f'{model_name} - Loss', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{model_name}_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    return model

# Hyperparameters used for the final full-dataset training run.
TRANSFORMER_BEST_CONFIG = {
    'num_layers': 4,
    'embed_dim': 128,
    'num_heads': 4,
    'mlp_dim': 128,
    'dropout': 0.15,
    'learning_rate': 0.00045,
    'batch_size': 256,
    'epochs': 500,

}

# One location for the trained model, so the file written to disk is the
# same file uploaded to W&B (previously a different, never-written file name
# was passed to wandb.save).
FINAL_MODEL_DIR = '/content/drive/MyDrive'
FINAL_MODEL_PATH = f'{FINAL_MODEL_DIR}/connect4_transformer_best.keras'

wandb.login()
wandb.init(config=TRANSFORMER_BEST_CONFIG,name='transformer_train', project="connect4-transformer_full_train")
model = model_train(config=wandb.config)
model.save(FINAL_MODEL_PATH)
# base_path makes the upload keep just the file name relative to the model dir.
wandb.save(FINAL_MODEL_PATH, base_path=FINAL_MODEL_DIR)
wandb.finish()


## Load Model

In [None]:
# Reload the trained model from the working directory. Each custom layer is
# registered under its own class name so Keras can rebuild it.
model = tf.keras.models.load_model(
    "connect4_transformer_best.keras",
    custom_objects={
        layer_cls.__name__: layer_cls
        for layer_cls in (
            FixedSubblockExtractor,
            AdditivePositionalEncoding,
            ClassToken,
            TransformerEncoder,
        )
    },
)






## Helper Functions for Running Connect4


In [3]:
def predict_move(model, board):
    """Return the model's move probabilities for a single board.

    Args:
        model: object with a Keras-style ``predict`` method.
        board: one board without a batch axis; a batch axis is added before
            calling ``predict``.

    Returns:
        1-D ``np.ndarray`` of probabilities, one per output column.

    Notes:
        * The model built by ``build_transformer`` has a single output, so
          the previous ``policy, value = model.predict(...)`` unpacking
          raised ``ValueError``. A list/tuple result (multi-head model) is
          still accepted; its first element is taken as the policy.
        * That single output already ends in a softmax, so softmax is only
          applied when the output is not yet a probability distribution
          (i.e. it looks like raw logits).
    """
    outputs = model.predict(board[None, ...], verbose=0)
    policy = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    scores = np.asarray(policy, dtype=np.float64)[0]

    already_normalised = bool(np.all(scores >= 0)) and bool(
        np.isclose(scores.sum(), 1.0, atol=1e-3))
    if already_normalised:
        return scores

    # Numerically stable softmax over raw logits.
    shifted = np.exp(scores - scores.max())
    return shifted / shifted.sum()

def board_2d_to_3d(board_2d: np.ndarray) -> np.ndarray:
    """
    Expand a 6x7 board of {-1, 0, +1} values into a 6x7x2 one-hot board.

    Plane 0 marks the +1 player's cells, plane 1 the -1 player's cells.

    Raises:
        ValueError: if the input is not shaped (6, 7).
    """
    if board_2d.shape != (6, 7):
        raise ValueError("Input board must have shape (6, 7)")

    # One boolean mask per player, stacked along a new trailing axis.
    player_masks = [board_2d == 1, board_2d == -1]
    return np.stack(player_masks, axis=-1).astype(int)

def board_3d_to_2d(board_3d: np.ndarray) -> np.ndarray:
    """
    Collapse a 6x7x2 one-hot board into a 6x7 board of {-1, 0, +1}.

    A cell set in plane 0 becomes +1 and a cell set in plane 1 becomes -1;
    if both planes are set, plane 1 wins.

    Raises:
        ValueError: if the input is not shaped (6, 7, 2).
    """
    if board_3d.shape != (6, 7, 2):
        raise ValueError("Input board must have shape (6, 7, 2)")

    plus_cells = board_3d[..., 0] == 1
    minus_cells = board_3d[..., 1] == 1

    # The minus mask is tested first so it overrides the plus mask on overlap.
    flat = np.where(minus_cells, -1, np.where(plus_cells, 1, 0))
    return flat.astype(int)

def normalize_board_perspective(board_6x7x2):
    """
    Normalize a Connect-4 board so that the current player
    is always in channel 0.

    The side to move is inferred from piece counts: when channel 0 holds
    exactly one more piece than channel 1, it is the channel-1 player's turn
    and the two channels are swapped. Otherwise the board is returned as is.

    Args:
        board_6x7x2 : np.ndarray of shape (6, 7, 2)

    Returns:
        norm_board : np.ndarray of shape (6, 7, 2). Always a fresh copy, so
        modifying it never changes the input (the swapped case used to
        return a view onto the input array).

    Raises:
        AssertionError: if the input is not shaped (6, 7, 2).
    """
    assert board_6x7x2.shape == (6, 7, 2)

    plus_count  = np.sum(board_6x7x2[:, :, 0])
    minus_count = np.sum(board_6x7x2[:, :, 1])

    # minus player's turn
    if plus_count == minus_count + 1:
        # swap channels; copy so the result does not alias the input
        return board_6x7x2[:, :, ::-1].copy()
    # plus player's turn
    return board_6x7x2.copy()

def update_board(board_temp, color, column):
    """Return a copy of the board with one piece dropped into ``column``.

    The piece lands on the row just above the pieces already in the column
    (row 5 for an empty column). ``'plus'`` places +1; any other colour
    places -1. When the column is already full the copy is returned
    unchanged. The input board is never modified.
    """
    new_board = board_temp.copy()

    # Number of occupied cells in this column (pieces are +1 or -1).
    occupied = sum(abs(new_board[r, column]) for r in range(6))
    landing_row = int(5 - occupied)

    if landing_row < 0:
        # Column is full: nothing to place.
        return new_board

    new_board[landing_row, column] = 1 if color == 'plus' else -1
    return new_board

def find_legal(board):
    """Return the columns (0-6) whose top cell is still empty."""
    open_columns = []
    for col in range(7):
        # A column can take another piece as long as its top row is free.
        if abs(board[0, col]) < 0.1:
            open_columns.append(col)
    return open_columns

def check_for_win(board, col):
    """Check whether the top piece in ``col`` completes four in a row.

    The row of the last move is derived from the number of pieces in the
    column, so the column must be filled from the bottom without gaps.

    Args:
        board: 6x7 array with +1 / -1 pieces and 0 for empty cells.
        col: column of the most recent move.

    Returns:
        ``'v-plus'``/``'v-minus'`` (vertical), ``'h-plus'``/``'h-minus'``
        (horizontal), ``'d-plus'``/``'d-minus'`` (either diagonal), or
        ``'nobody'``. Directions are tested in that order.

        An empty column now returns ``'nobody'``; it used to compute row 6
        and raise ``IndexError`` in the horizontal check.
    """
    n_rows, n_cols = 6, 7

    pieces_in_col = sum(abs(board[r, col]) for r in range(n_rows))
    if pieces_in_col == 0:
        # Nothing has been played in this column, so there is no last move.
        return 'nobody'
    row = int(n_rows - pieces_in_col)

    # (label, row step, column step) in the original order:
    # vertical, horizontal, diagonal down-right, diagonal down-left.
    directions = (('v', 1, 0), ('h', 0, 1), ('d', 1, 1), ('d', 1, -1))

    for label, d_row, d_col in directions:
        # Slide a length-4 window so the last move sits at position k in it.
        for k in range(4):
            r0, c0 = row - k * d_row, col - k * d_col
            r3, c3 = r0 + 3 * d_row, c0 + 3 * d_col
            inside = (0 <= r0 < n_rows and 0 <= r3 < n_rows
                      and 0 <= c0 < n_cols and 0 <= c3 < n_cols)
            if not inside:
                continue
            total = sum(board[r0 + i * d_row, c0 + i * d_col] for i in range(4))
            if total == 4:
                return f'{label}-plus'
            if total == -4:
                return f'{label}-minus'

    return 'nobody'

def display_board(board):
    """Clear the cell output and print the board as ASCII art.

    Cells equal to 1 are drawn as X, cells equal to 0 are left blank, and
    any other value is drawn as O. Column numbers are printed above and
    below the grid.
    """
    # For the project, this should be a better picture of the board...
    clear_output()

    column_header = '   0     1     2     3     4     5     6'
    rule = '-' * (7 * 5 + 8)
    spacer = '|' + (' ' * 5 + '|') * 7

    def render(cell):
        # One five-character cell followed by its right-hand border.
        if cell == 0:
            return ' ' * 5 + '|'
        if cell == 1:
            return '  X  |'
        return '  O  |'

    print(column_header)
    print(rule)
    for row in range(6):
        pieces = '|' + ''.join(render(board[row, col]) for col in range(7))
        for line in (spacer, pieces, spacer, rule):
            print(line)
    print(column_header)

def is_board_full(board):
    """Return True when no cell of the board equals 0."""
    return np.all(np.not_equal(board, 0))
            

## Example Classification

In [4]:
from IPython.display import clear_output

# Play `games` games of the transformer ('plus', moves first) against a human
# typing column numbers ('minus'), and report the transformer's win rate.
games = 1
transformer_wins = 0
for i in range(games):
    winner = 'nobody'
    board = np.zeros((6,7))

    display_board(board)

    player = 'plus'
    tied = False
    while winner == 'nobody':
        # The board is never full here (a full board ends the game below),
        # so there is always at least one legal column.
        legal = find_legal(board)

        if player == 'plus':
            # Pick the most probable column that is still open. An unmasked
            # argmax can select a full column, which update_board silently
            # ignores, costing the model its turn.
            probs = np.asarray(predict_move(model, board_2d_to_3d(board)))
            move = max(legal, key=lambda c: probs[c])
        else:
            # An automated opponent (e.g. mcts(board, player, 1000)) could be
            # plugged in here instead of the keyboard.
            move = None
            while move not in legal:
                try:
                    move = int(input('0-6:'))
                except ValueError:
                    # Non-numeric input: ask again.
                    move = None

        board = update_board(board,player,move)
        display_board(board)
        player = 'minus' if player == 'plus' else 'plus'

        # Check for a win before declaring a tie, so a winning move that
        # also fills the board is not scored as a draw.
        winner = check_for_win(board,move)
        if winner == 'nobody' and is_board_full(board):
            tied = True
            break

    if not tied:
        print('The winner is '+winner)
        if 'plus' in winner:
            transformer_wins += 1
transformer_wins/games

   0     1     2     3     4     5     6
-------------------------------------------
|     |     |     |     |     |     |     |
|     |     |     |     |     |     |     |
|     |     |     |     |     |     |     |
-------------------------------------------
|     |     |     |     |     |     |     |
|     |     |  O  |  X  |  O  |     |     |
|     |     |     |     |     |     |     |
-------------------------------------------
|     |     |     |     |     |     |     |
|  O  |     |  X  |  O  |  X  |     |  X  |
|     |     |     |     |     |     |     |
-------------------------------------------
|     |     |     |     |     |     |     |
|  X  |     |  O  |  X  |  O  |     |  O  |
|     |     |     |     |     |     |     |
-------------------------------------------
|     |     |     |     |     |     |     |
|  X  |     |  X  |  O  |  O  |  O  |  X  |
|     |     |     |     |     |     |     |
-------------------------------------------
|     |     |     |     |     |    

0.0