# Transformer Model Code

## Load data from pickle, flip all boards, convert to 6x7x2, remove duplicate boards and keep most commonly recommended move only

In [None]:
#not needed for Kaggle
#from google.colab import drive
#drive.mount('/content/drive')

In [24]:
import pickle

In [25]:
with open('/kaggle/input/mcts7500-pool/mcts7500_pool.pickle', 'rb') as f:
    data = pickle.load(f)

In [27]:
#dictionary has three keys
board, move, note_from_dan = data.keys()  # Extract keys

# Assign values to separate objects
board = data[board]
move = data[move]
note = data[note_from_dan] #this is just a note on how Dan created the dataset

In [28]:
import numpy as np

# Suppose 'board' has shape (N, 6, 7) and 'move' has shape (N,), where
#  - board[n, i, j] is +1, -1, or  0 for row i, column j of the nth board
#  - move[n] is the recommended column (0..6) for +1 to play on the nth board

#
# 1) Flip the boards and moves
#
# Flipping a Connect 4 board for augmentation typically means reversing columns
# left-to-right.  We can use np.flip along axis=2 (the columns).
# For the recommended move, flipping column c goes to column (6 - c).
#
# Convert move (Python list) to a NumPy array:
move = np.array(move, dtype=np.int32)

board_flipped = np.flip(board, axis=2)        # Flip each board left-right
move_flipped  = 6 - move                      # Flip the recommended column

# Now double the data by concatenating the original and flipped versions:
board_aug = np.concatenate([board, board_flipped], axis=0)
move_aug  = np.concatenate([move,  move_flipped],  axis=0)

#
# 2) Convert to 6x7x2 format
#
# We create two “channels.”
#   Channel 0 (last dimension = 0) indicates +1 pieces
#   Channel 1 (last dimension = 1) indicates -1 pieces
#

def convert_to_6x7x2(boards_6x7):
    # boards_6x7 has shape (M, 6, 7)
    # We want shape (M, 6, 7, 2)

    result = np.zeros((len(boards_6x7), 6, 7, 2), dtype=np.int8)

    # Wherever the board is +1, set channel 0 to 1
    result[boards_6x7 == 1, 0] = 1

    # Wherever the board is -1, set channel 1 to 1
    result[boards_6x7 == -1, 1] = 1

    return result

board_6x7x2 = convert_to_6x7x2(board_aug)

# board_6x7x2 now has shape (2N, 6, 7, 2)
# move_aug has shape (2N,)


In [29]:
import numpy as np
from collections import Counter

def remove_duplicates_and_keep_common_move(boards_6x7x2, moves):
    """
    boards_6x7x2: shape (N, 6, 7, 2)
    moves: shape (N,)

    Returns:
      unique_boards: shape (M, 6, 7, 2)
      unique_moves: shape (M,)
      where M <= N, with one record per unique board,
      and the move is the most common one among duplicates.
    """
    # Dictionary from a board's bytes representation -> list of moves
    board_dict = {}
    for b, m in zip(boards_6x7x2, moves):
        # Convert b to a bytes object so it can be used as a dict key.
        b_bytes = b.tobytes()
        board_dict.setdefault(b_bytes, []).append(m)

    unique_boards_list = []
    unique_moves_list = []

    # For each unique board, pick the most common move
    for b_bytes, move_list in board_dict.items():
        # Convert bytes back into a (6,7,2) array
        board_arr = np.frombuffer(b_bytes, dtype=boards_6x7x2.dtype).reshape(6, 7, 2)
        # Count how many times each move occurs and pick the most frequent
        move_counter = Counter(move_list)
        most_common_move = move_counter.most_common(1)[0][0]

        unique_boards_list.append(board_arr)
        unique_moves_list.append(most_common_move)

    # Convert lists to NumPy arrays
    unique_boards = np.stack(unique_boards_list, axis=0)  # shape (M, 6, 7, 2)
    unique_moves = np.array(unique_moves_list)            # shape (M,)

    return unique_boards, unique_moves


# Example usage:
# boards_6x7x2 has shape (2N, 6, 7, 2)
# move_aug has shape (2N,)

unique_boards, unique_moves = remove_duplicates_and_keep_common_move(
    board_6x7x2,  # shape (2N, 6, 7, 2) after augmentation
    move_aug      # shape (2N,) after augmentation
)

print("Unique boards shape:", unique_boards.shape)
print("Unique moves shape:",  unique_moves.shape)


Unique boards shape: (453376, 6, 7, 2)
Unique moves shape: (453376,)


### Make sure to use unique_boards and unique_moves for training the models

### Optionally, save to csv and reload from csv

In [None]:
# optionally, here are functions save to csv file and reload it correctly from the csv

def load_boards_csv(filename):
    boards = []
    moves = []
    with open(filename, 'r') as f:
        reader = csv.reader(f)
        rows = list(reader)

    i = 0
    while i < len(rows):
        # 1) Grab header row: [ "BOARD_INDEX", idx, "MOVE", move_val ]
        header = rows[i]
        # parse board index if needed: idx = int(header[1])
        move_val = int(header[3])
        moves.append(move_val)
        i += 1

        # 2) Grab 6 lines for the board
        board_rows = rows[i : i + 6]
        i += 6

        # Convert from shape (6 lines × 14 columns) => (6, 7, 2)
        board_6x7x2 = np.zeros((6, 7, 2), dtype=int)
        for row_idx in range(6):
            row_data = board_rows[row_idx]       # 14 entries
            for col_idx in range(7):
                ch0 = int(row_data[col_idx*2])
                ch1 = int(row_data[col_idx*2 + 1])
                board_6x7x2[row_idx, col_idx, 0] = ch0
                board_6x7x2[row_idx, col_idx, 1] = ch1

        boards.append(board_6x7x2)

        # 3) Skip blank line if it exists
        if i < len(rows) and len(rows[i]) == 0:
            i += 1

    boards = np.stack(boards, axis=0)  # shape (N, 6, 7, 2)
    moves = np.array(moves)
    return boards, moves

import csv
import numpy as np

def load_boards_csv(filename):
    """
    Reads a CSV file where each board is stored as:
      1) A header row: ["BOARD_INDEX", i, "MOVE", moves[i]]
      2) 6 rows, each with 14 integers (7 columns × 2 channels)
      3) A blank line (optional) separating each board

    Returns:
      boards: ndarray of shape (N, 6, 7, 2)
      moves: ndarray of shape (N,)
    """
    boards = []
    moves = []

    with open(filename, 'r', newline='') as f:
        reader = csv.reader(f)
        rows = list(reader)

    i = 0
    while i < len(rows):
        # 1) Read the header row (e.g. ["BOARD_INDEX", "0", "MOVE", "3"])
        header = rows[i]
        move_val = int(header[3])   # parse the move
        moves.append(move_val)
        i += 1

        # 2) Read the 6 lines of board data
        board_rows = rows[i : i + 6]
        i += 6

        # Convert from shape (6, 14) -> (6, 7, 2)
        board_6x7x2 = np.zeros((6, 7, 2), dtype=int)
        for row_idx in range(6):
            row_data = board_rows[row_idx]  # 14 entries
            for col_idx in range(7):
                ch0 = int(row_data[col_idx * 2])
                ch1 = int(row_data[col_idx * 2 + 1])
                board_6x7x2[row_idx, col_idx, 0] = ch0
                board_6x7x2[row_idx, col_idx, 1] = ch1

        boards.append(board_6x7x2)

        # 3) Skip the blank line if present
        if i < len(rows) and len(rows[i]) == 0:
            i += 1

    boards = np.stack(boards, axis=0)  # shape (N, 6, 7, 2)
    moves = np.array(moves)
    return boards, moves

# Example usage:
# unique_boards, unique_moves = load_boards_csv("connect4_data.csv")
# print(boards.shape)  # e.g. (N, 6, 7, 2)
# print(moves.shape)   # e.g. (N,)


# Train Transformer
Note: For ease of use, the data loading and preprocessing code is repeated here so the above cells do not need to be run

In [30]:
import pickle
import numpy as np
from collections import Counter
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# ------------------------------------------------------------------
# 0) LOAD AND PRE-PROCESS YOUR DATA (same as before)
# ------------------------------------------------------------------
with open('/kaggle/input/mcts7500-pool/mcts7500_pool.pickle', 'rb') as f:
    data = pickle.load(f)

board_key, move_key, note_from_dan = data.keys()  # Extract keys
board = data[board_key]   # shape (N, 6, 7)
move = data[move_key]     # shape (N,)
note = data[note_from_dan]

move = np.array(move, dtype=np.int32)

board_flipped = np.flip(board, axis=2)
move_flipped = 6 - move
board_aug = np.concatenate([board, board_flipped], axis=0)
move_aug  = np.concatenate([move,  move_flipped],  axis=0)

def convert_to_6x7x2(boards_6x7):
    result = np.zeros((len(boards_6x7), 6, 7, 2), dtype=np.int8)
    result[boards_6x7 == 1, 0] = 1
    result[boards_6x7 == -1, 1] = 1
    return result

board_6x7x2 = convert_to_6x7x2(board_aug)

def remove_duplicates_and_keep_common_move(boards_6x7x2, moves):
    board_dict = {}
    for b, m in zip(boards_6x7x2, moves):
        b_bytes = b.tobytes()
        board_dict.setdefault(b_bytes, []).append(m)

    unique_boards_list = []
    unique_moves_list  = []
    from collections import Counter

    for b_bytes, move_list in board_dict.items():
        board_arr = np.frombuffer(b_bytes, dtype=boards_6x7x2.dtype).reshape(6, 7, 2)
        most_common_move = Counter(move_list).most_common(1)[0][0]
        unique_boards_list.append(board_arr)
        unique_moves_list.append(most_common_move)

    unique_boards = np.stack(unique_boards_list, axis=0)
    unique_moves  = np.array(unique_moves_list)
    return unique_boards, unique_moves

unique_boards, unique_moves = remove_duplicates_and_keep_common_move(
    board_6x7x2,
    move_aug
)

print("Unique boards shape:", unique_boards.shape)
print("Unique moves shape: ", unique_moves.shape)


# ------------------------------------------------------------------
# 1) DEFINE THE TRANSFORMER ENCODER BLOCK
# ------------------------------------------------------------------
class TransformerEncoder(layers.Layer):
    def __init__(self, d_model, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation='relu'),
            layers.Dense(d_model),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=False, mask=None):
        # Multi-head self-attention
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        
        # Feed-forward
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)


# ------------------------------------------------------------------
# 2) POSITIONAL EMBEDDING LAYER (TRAINABLE)
# ------------------------------------------------------------------
class PositionalEmbedding(layers.Layer):
    def __init__(self, maxlen, d_model):
        super().__init__()
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=d_model)
        self.maxlen = maxlen

    def call(self, x):
        seq_len = tf.shape(x)[1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        pos_embeddings = self.pos_emb(positions)
        return x + pos_embeddings


# ------------------------------------------------------------------
# 3) BUILD THE TRANSFORMER MODEL
# ------------------------------------------------------------------
def create_transformer(
    d_model=256,       # Increased from 128
    num_heads=8,       # Increased from 4
    ff_dim=512,        # Increased from 256
    num_layers=5,      # Increased from 4
    rate=0.05,         # Reduced dropout
    maxlen=42,         # 6x7 = 42 tokens
    num_classes=7
):
    """
    Returns a compiled Transformer model that classifies boards into 7 columns.
    """
    inputs = layers.Input(shape=(6, 7, 2), name="board_input")

    # Flatten the 6x7 grid to 42 tokens; each token has 2 channels
    x = layers.Reshape((maxlen, 2))(inputs)

    # Project each 2-dim token into d_model
    x = layers.Dense(d_model)(x)

    # Add positional embeddings
    x = PositionalEmbedding(maxlen, d_model)(x)

    # Stack multiple TransformerEncoder blocks
    for _ in range(num_layers):
        x = TransformerEncoder(d_model, num_heads, ff_dim, rate)(x)

    # Pool across tokens to get a single vector
    x = layers.GlobalAveragePooling1D()(x)

    # Final classification
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = Model(inputs=inputs, outputs=outputs, name="connect4_transformer")
    
    # Lower initial learning rate (you can also try 3e-4 if stable).
    optimizer = Adam(learning_rate=1e-4)

    model.compile(
        loss="sparse_categorical_crossentropy", 
        optimizer=optimizer, 
        metrics=["accuracy"]
    )
    return model


# ------------------------------------------------------------------
# 4) TRAINING
# ------------------------------------------------------------------
model = create_transformer(
    d_model=256,
    num_heads=8,
    ff_dim=512,
    num_layers=5,
    rate=0.05,
    maxlen=42,
    num_classes=7
)

model.summary()

Unique boards shape: (453376, 6, 7, 2)
Unique moves shape:  (453376,)


In [None]:
# Define callbacks
early_stopping_cb = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)

reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    verbose=1
)


from tensorflow.keras.callbacks import ModelCheckpoint

# Save the best model based on validation loss
checkpoint_cb = ModelCheckpoint(
    filepath='021025_best_transformer_model.keras',  # File to save the model
    monitor='val_loss',                    # Track validation loss
    save_best_only=True,                    # Save only the best model
    save_weights_only=False,                 # Save the entire model
    mode='min',                              # Lower val_loss is better
    verbose=1
)


# Fit the model
history = model.fit(
    x=unique_boards,
    y=unique_moves,
    batch_size=64,       # Adjust if you run out of memory or have more GPU power
    epochs=30, #30 or 50           # More epochs; model is still improving after 30
    validation_split=0.1,
    callbacks=[early_stopping_cb, reduce_lr_cb, checkpoint_cb],
    shuffle=True
)

# Save the entire model (architecture + weights)
model.save("021025_final_transformer_connect4.keras")

Epoch 1/30
[1m 284/6376[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m9:54[0m 98ms/step - accuracy: 0.2076 - loss: 1.9290

# Play the Transformer Against MCTS

In [20]:
import numpy as np
import tensorflow as tf

def board_6x7_to_6x7x2(board_6x7):
    """
    Convert a 6x7 board of +1 (plus), -1 (minus), 0 (empty)
    into a 6x7x2 representation for the Transformer.
      channel 0 = +1 cells (plus)
      channel 1 = -1 cells (minus)
    """
    plus_channel  = (board_6x7 == 1).astype(np.float32)
    minus_channel = (board_6x7 == -1).astype(np.float32)
    # Stack channels along the last axis => shape (6,7,2)
    return np.stack([plus_channel, minus_channel], axis=-1)

def transformer_move(model, board_6x7, color='plus'):
    """
    Given a (6,7) board with +1/-1/0, convert to (6,7,2) 
    and have the Transformer's model pick a column [0..6].
    """
    # Convert from (6,7) => (6,7,2)
    input_2ch = board_6x7_to_6x7x2(board_6x7)
    
    # Reshape to (1, 42, 2) for the model
    #input_2ch = input_2ch.reshape(1, 42, 2)
    input_2ch = np.expand_dims(input_2ch, axis=0)

    
    # model.predict => shape (1,7)
    probs = model.predict(input_2ch, verbose=0)[0]  # => (7,)
    
    # Pick the column with the highest predicted probability
    chosen_col = np.argmax(probs)  # int in [0..6]
    return chosen_col


In [21]:
import numpy as np
from IPython.display import clear_output
import time
import random
# https://www.youtube.com/watch?v=UXW2yZndl7U

def update_board(board_temp,color,column):
    # this is a function that takes the current board status, a color, and a column and outputs the new board status
    # columns 0 - 6 are for putting a checker on the board: if column is full just return the current board...this should be forbidden by the player

    # the color input should be either 'plus' or 'minus'

    board = board_temp.copy()
    ncol = board.shape[1]
    nrow = board.shape[0]

    # this seems silly, but actually faster to run than using sum because of overhead!
    colsum = abs(board[0,column])+abs(board[1,column])+abs(board[2,column])+abs(board[3,column])+abs(board[4,column])+abs(board[5,column])
    row = int(5-colsum)
    if row > -0.5:
        if color == 'plus':
            board[row,column] = 1
        else:
            board[row,column] = -1
    return board


# in this code the board is a 6x7 numpy array.  Each entry is +1, -1 or 0.  You WILL be able to do a better
# job training your neural network if you rearrange this to be a 6x7x2 numpy array.  If the i'th row and j'th
# column is +1, this can be represented by board[i,j,0]=1.  If it is -1, this can be represented by
# board[i,j,1]=1. It's up to you how you represent your board.

def check_for_win_slow(board):
    # this function checks to see if anyone has won on the given board
    nrow = board.shape[0]
    ncol = board.shape[1]
    winner = 'nobody'
    for col in range(ncol):
        for row in reversed(range(nrow)):
            if abs(board[row,col]) < 0.1: # if this cell is empty, all the cells above it are too!
                break
            # check for vertical winners
            if row <= (nrow-4): # can't have a column go from rows 4-7...
                tempsum = board[row,col]+board[row+1,col]+board[row+2,col]+board[row+3,col] # this is WAY faster than np.sum!!!
                if tempsum==4:
                    winner = 'v-plus'
                    return winner
                elif tempsum==-4:
                    winner = 'v-minus'
                    return winner
            # check for horizontal winners
            if col <= (ncol-4):
                tempsum = board[row,col]+board[row,col+1]+board[row,col+2]+board[row,col+3]
                if tempsum==4:
                    winner = 'h-plus'
                    return winner
                elif tempsum==-4:
                    winner = 'h-minus'
                    return winner
            # check for top left to bottom right diagonal winners
            if (row <= (nrow-4)) and (col <= (ncol-4)):
                tempsum = board[row,col]+board[row+1,col+1]+board[row+2,col+2]+board[row+3,col+3]
                if tempsum==4:
                    winner = 'd-plus'
                    return winner
                elif tempsum==-4:
                    winner = 'd-minus'
                    return winner
            # check for top right to bottom left diagonal winners
            if (row <= (nrow-4)) and (col >= 3):
                tempsum = board[row,col]+board[row+1,col-1]+board[row+2,col-2]+board[row+3,col-3]
                if tempsum==4:
                    winner = 'd-plus'
                    return winner
                elif tempsum==-4:
                    winner = 'd-minus'
                    return winner
    return winner


def check_for_win(board,col):
    # this code is faster than the above code, but it requires knowing where the last checker was dropped
    # it may seem extreme, but in MCTS this function is called more than anything and actually makes up
    # a large portion of total time spent finding a good move.  So every microsecond is worth saving!
    nrow = 6
    ncol = 7
    # take advantage of knowing what column was last played in...need to check way fewer possibilities
    colsum = abs(board[0,col])+abs(board[1,col])+abs(board[2,col])+abs(board[3,col])+abs(board[4,col])+abs(board[5,col])
    row = int(6-colsum)
    if row+3<6:
        vert = board[row,col] + board[row+1,col] + board[row+2,col] + board[row+3,col]
        if vert == 4:
            return 'v-plus'
        elif vert == -4:
            return 'v-minus'
    if col+3<7:
        hor = board[row,col] + board[row,col+1] + board[row,col+2] + board[row,col+3]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-1>=0 and col+2<7:
        hor = board[row,col-1] + board[row,col] + board[row,col+1] + board[row,col+2]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-2>=0 and col+1<7:
        hor = board[row,col-2] + board[row,col-1] + board[row,col] + board[row,col+1]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-3>=0:
        hor = board[row,col-3] + board[row,col-2] + board[row,col-1] + board[row,col]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if row < 3 and col < 4:
        DR = board[row,col] + board[row+1,col+1] + board[row+2,col+2] + board[row+3,col+3]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-1>=0 and col-1>=0 and row+2<6 and col+2<7:
        DR = board[row-1,col-1] + board[row,col] + board[row+1,col+1] + board[row+2,col+2]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-2>=0 and col-2>=0 and row+1<6 and col+1<7:
        DR = board[row-2,col-2] + board[row-1,col-1] + board[row,col] + board[row+1,col+1]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-3>=0 and col-3>=0:
        DR = board[row-3,col-3] + board[row-2,col-2] + board[row-1,col-1] + board[row,col]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row+3<6 and col-3>=0:
        DL = board[row,col] + board[row+1,col-1] + board[row+2,col-2] + board[row+3,col-3]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-1 >= 0 and col+1 < 7 and row+2<6 and col-2>=0:
        DL = board[row-1,col+1] + board[row,col] + board[row+1,col-1] + board[row+2,col-2]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-2 >=0 and col+2<7 and row+1<6 and col-1>=0:
        DL = board[row-2,col+2] + board[row-1,col+1] + board[row,col] + board[row+1,col-1]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-3>=0 and col+3<7:
        DL = board[row-3,col+3] + board[row-2,col+2] + board[row-1,col+1] + board[row,col]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    return 'nobody'

def find_legal(board):
    legal = [i for i in range(7) if abs(board[0,i]) < 0.1]
    return legal

def look_for_win(board_,color):
    board_ = board_.copy()
    legal = find_legal(board_)
    winner = -1
    for m in legal:
        bt = update_board(board_.copy(),color,m)
        wi = check_for_win(bt,m)
        if wi[2:] == color:
            winner = m
            break
    return winner

def find_all_nonlosers(board,color):
    if color == 'plus':
        opp = 'minus'
    else:
        opp = 'plus'
    legal = find_legal(board)
    poss_boards = [update_board(board,color,l) for l in legal]
    poss_legal = [find_legal(b) for b in poss_boards]
    allowed = []
    for i in range(len(legal)):
        wins = [j for j in poss_legal[i] if check_for_win(update_board(poss_boards[i],opp,j),j) != 'nobody']
        if len(wins) == 0:
            allowed.append(legal[i])
    return allowed

def back_prop(winner,path,color0,md):
    for i in range(len(path)):
        board_temp = path[i]

        md[board_temp][0]+=1
        if winner[2]==color0[0]:
            if i % 2 == 1:
                md[board_temp][1] += 1
            else:
                md[board_temp][1] -= 1
        elif winner[2]=='e': # tie
            # md[board_temp][1] += 0
            pass
        else:
            if i % 2 == 1:
                md[board_temp][1] -= 1
            else:
                md[board_temp][1] += 1

def rollout(board,next_player):
    winner = 'nobody'
    player = next_player
    while winner == 'nobody':
        legal = find_legal(board)
        if len(legal) == 0:
            winner = 'tie'
            return winner
        move = random.choice(legal)
        board = update_board(board,player,move)
        winner = check_for_win(board,move)

        if player == 'plus':
            player = 'minus'
        else:
            player = 'plus'
    return winner

def mcts(board_temp,color0,nsteps):
    # nsteps is a parameter that determines the skill (and slowness) of the player
    # bigger values of nsteps means the player is better, but also slower to figure out a move.
    board = board_temp.copy()
    ##############################################
    winColumn = look_for_win(board,color0) # check to find a winning column
    if winColumn > -0.5:
        return winColumn # if there is one - play that!
    legal0 = find_all_nonlosers(board,color0) # find all moves that won't immediately lead to your opponent winning
    if len(legal0) == 0: # if you can't block your opponent - just find the 'best' losing move
        legal0 = find_legal(board)
    ##############################################
    # the code above, in between the hash rows, is not part of traditional MCTS
    # but it makes it better and faster - so I included it!
    # MCTS occasionally makes stupid mistakes
    # like not dropping the checker on a winning column, or not blocking an obvious opponent win
    # this avoids a little bit of that stupidity!
    # we could also add this logic to the rest of the MCTS and rollout functions - I just haven't done that yet...
    # feel free to experiment!
    mcts_dict = {tuple(board.ravel()):[0,0]}
    for ijk in range(nsteps):
        color = color0
        winner = 'nobody'
        board_mcts = board.copy()
        path = [tuple(board_mcts.ravel())]
        while winner == 'nobody':
            legal = find_legal(board_mcts)
            if len(legal) == 0:
                winner = 'tie'
                back_prop(winner,path,color0,mcts_dict)
                break
            board_list = []
            for col in legal:
                board_list.append(tuple(update_board(board_mcts,color,col).ravel()))
            for bl in board_list:
                if bl not in mcts_dict.keys():
                    mcts_dict[bl] = [0,0]
            ucb1 = np.zeros(len(legal))
            for i in range(len(legal)):
                num_denom = mcts_dict[board_list[i]]
                if num_denom[0] == 0:
                    ucb1[i] = 10*nsteps
                else:
                    ucb1[i] = num_denom[1]/num_denom[0] + 2*np.sqrt(np.log(mcts_dict[path[-1]][0])/mcts_dict[board_list[i]][0])
            chosen = np.argmax(ucb1)

            board_mcts = update_board(board_mcts,color,legal[chosen])
            path.append(tuple(board_mcts.ravel()))
            winner = check_for_win(board_mcts,legal[chosen])
            if winner[2]==color[0]:
                back_prop(winner,path,color0,mcts_dict)
                break
            if color == 'plus':
                color = 'minus'
            else:
                color = 'plus'
            if mcts_dict[tuple(board_mcts.ravel())][0] == 0:
                winner = rollout(board_mcts,color)
                back_prop(winner,path,color0,mcts_dict)
                break

    maxval = -np.inf
    best_col = -1
    for col in legal0:
        board_temp = tuple(update_board(board,color0,col).ravel())
        num_denom = mcts_dict[board_temp]
        if num_denom[0] == 0:
            compare = -np.inf
        else:
            compare = num_denom[1] / num_denom[0]
        if compare > maxval:
            maxval = compare
            best_col = col
    return (best_col)




In [22]:
def play_game_mcts_vs_transformer(model, nsteps=1000, first_player="transformer"):
    """
    Plays one full Connect 4 game on an empty board:
      - The Transformer model vs. MCTS(nsteps).
      - first_player: "transformer" or "mcts"
      - nsteps: how many MCTS iterations we use on each move.

    Returns: (winner, move_count)
       winner = "transformer" or "mcts" or "tie"
       move_count = number of moves made before game ended
    """

    # Initialize an empty 6x7x2 board (if that's your representation).
    # OLD: board = np.zeros((6,7,2), dtype=np.float32)
    board = np.zeros((6,7), dtype=np.float32)

    # If your code uses string-based colors, let's define them:
    # We'll say: color "plus" = first player, color "minus" = second.
    # (The rest of your MCTS code expects color in {"plus", "minus"}).
    current_color = "plus"  # start with plus

    # We also track a mapping from color -> who is playing.
    # If first_player = "transformer", then "plus" is transformer, "minus" is MCTS
    # else vice versa.
    if first_player == "transformer":
        plus_player = "transformer"
        minus_player = "mcts"
    else:
        plus_player = "mcts"
        minus_player = "transformer"

    move_count = 0
    winner = None

    while True:
        move_count += 1

        # Determine which "agent" is about to move
        if current_color == "plus":
            agent = plus_player
        else:
            agent = minus_player

        # Agent picks a column
        if agent == "transformer":
            col = transformer_move(model, board, current_color)
        else:
            col = mcts(board, current_color, nsteps=nsteps)

        # Update the board
        board = update_board(board, current_color, col)

        # Check if this move caused a win
        check = check_for_win(board, col)
        if check[2] == current_color[0]:
            # current_color is "plus" or "minus", check[2] is 'p' or 'm'
            # If we have a winner, see if it's the transformer or mcts
            if agent == "transformer":
                winner = "transformer"
            else:
                winner = "mcts"
            break

        # Check if board is full or no moves left
        legal = find_legal(board)
        if len(legal) == 0:
            # It's a tie
            winner = "tie"
            break

        # Switch color
        if current_color == "plus":
            current_color = "minus"
        else:
            current_color = "plus"

    return winner, move_count


In [23]:
from collections import Counter

class GetItem(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        super(GetItem, self).__init__(**kwargs)
        # We store args just in case Keras tries to pass them
        # but we won't necessarily use them.
        self._args = args

    def call(self, inputs, *call_args, **call_kwargs):
        # Hard-code the slice that your model used: mid[:, 0, :]
        # If your slice was different, adjust here:
        return inputs[:, 0, :]

class PositionalIndex(tf.keras.layers.Layer):
    def call(self, x):
        bs = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        indices = tf.range(seq_len)
        indices = tf.expand_dims(indices, 0)
        return tf.tile(indices, [bs, 1])

class ClassTokenIndex(tf.keras.layers.Layer):
    def call(self, x):
        bs = tf.shape(x)[0]
        indices = tf.range(1)
        indices = tf.expand_dims(indices, 0)
        return tf.tile(indices, [bs, 1])

class ExtractClassToken(tf.keras.layers.Layer):
    def call(self, inputs):
        # 'inputs' is shape (batch_size, seq_len, hidden_dim)
        # return just the [0]-th token
        return inputs[:, 0, :]

class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, maxlen, d_model, **kwargs):
        # Pass all keyword arguments on to the base Layer
        super(PositionalEmbedding, self).__init__(**kwargs)
        self.maxlen = maxlen
        self.d_model = d_model
        self.pos_emb = tf.keras.layers.Embedding(
            input_dim=maxlen, 
            output_dim=d_model
        )

    def call(self, x):
        seq_len = tf.shape(x)[1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        # (seq_len,) --> shape (1, seq_len)
        pos_embeddings = self.pos_emb(positions)
        return x + pos_embeddings

class TransformerEncoder(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation='relu'),
            tf.keras.layers.Dense(d_model),
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=False, mask=None):
        # Multi-head self-attention
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        
        # Feed-forward
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

# 2) Include 'PositionalEmbedding' in custom_objects
#custom_objects = {"GetItem": GetItem,"PositionalIndex": PositionalIndex,"ClassTokenIndex": ClassTokenIndex,"ExtractClassToken": ExtractClassToken,"PositionalEmbedding": PositionalEmbedding}

def test_transformer_vs_mcts(model_path, mcts_steps_list, games_per_setting=5):
    """
    Loads the Transformer model from disk (with custom layers),
    then plays multiple games for each nsteps, tracking win/loss.
    """
    custom_objects = {
    "GetItem": GetItem,
    "PositionalIndex": PositionalIndex,
    "ClassTokenIndex": ClassTokenIndex,
    "ExtractClassToken": ExtractClassToken,
    "PositionalEmbedding": PositionalEmbedding,
    "TransformerEncoder": TransformerEncoder}    
    model = tf.keras.models.load_model("my_transformer_connect4.h5", custom_objects=custom_objects)
   
    results_summary = {}
    for steps in mcts_steps_list:
        wins = Counter()
        for i in range(games_per_setting):
            winner, moves = play_game_mcts_vs_transformer(
                model,
                nsteps=steps,
                first_player="transformer"
            )
            wins[winner] += 1
        results_summary[steps] = wins
        print(f"Results for MCTS(nsteps={steps}): {wins}")

    return results_summary

# Now call it:
mcts_steps_list = [100, 1000, 5000]
test_transformer_vs_mcts("my_transformer_connect4.h5", mcts_steps_list, games_per_setting=5)


Results for MCTS(nsteps=100): Counter({'mcts': 3, 'transformer': 2})
Results for MCTS(nsteps=1000): Counter({'mcts': 5})
Results for MCTS(nsteps=5000): Counter({'mcts': 5})


{100: Counter({'transformer': 2, 'mcts': 3}),
 1000: Counter({'mcts': 5}),
 5000: Counter({'mcts': 5})}