In [None]:
# Install dependencies
%pip install tensorflow numpy

In [None]:
# Import necessary libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
import math
from collections import deque 

In [None]:
# Arguments for training the model
args = {
    'numIterations': 10000,  # Number of self-play game + training-step iterations
    'numSimulations': 500,  # MCTS simulations per move
    'epochs': 10,  # Epochs per training step
    'batchSize': 64,  # Samples drawn from memory per training step (and the fit batch size)
    'maxBufferSize': 50000,  # Maximum replay buffer (memory) size
    'cPuct': 1.25,  # Exploration constant in the MCTS PUCT selection score
    'temperatureMoves': 12,  # Opening-move threshold: early moves are sampled from MCTS visit counts, later moves take the most-visited move
    'dirichletAlpha': 0.3,  # Concentration of the Dirichlet noise mixed into the root priors
    'dirichletEpsilon': 0.25,  # Weight of that Dirichlet noise in the root priors
}

In [None]:
# Main game engine
class Game:
    """Rules engine for 3-D Connect Four.

    A state is a NumPy ``int8`` array of shape ``(height, width, length)``
    indexed ``state[y][x][z]``. Axis 0 (``y``) is the stacking axis: a dropped
    piece fills the lowest empty layer of its column. Cells hold 0 (empty),
    1 or -1 (the two players). A move is an ``(x, z)`` column.
    """

    def __init__(self, width=7, length=7, height=6, winLength=4):
        self.width = width
        self.length = length
        self.height = height
        self.winLength = winLength
        # The 13 line directions (dy, dx, dz) checked for a win: 3 axes,
        # 6 planar diagonals and 4 space diagonals. Opposite directions are
        # unnecessary because gameEnded scans forward from every occupied
        # cell, so each line is found starting from one of its ends.
        self.directions = [
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (1, 1, 0),
            (1, -1, 0),
            (1, 0, 1),
            (1, 0, -1),
            (0, 1, 1),
            (0, -1, 1),
            (1, 1, 1),
            (1, 1, -1),
            (1, -1, 1),
            (-1, 1, 1),
        ]

    def getInitialState(self):
        """Return an empty board of shape ``(height, width, length)``."""
        return np.zeros((self.height, self.width, self.length), dtype=np.int8)

    def getValidMoves(self, state):
        """Return the ``(x, z)`` columns whose top cell is empty, in x-major order."""
        height, width, length = state.shape
        top = state[height - 1]
        return [(x, z) for x in range(width) for z in range(length) if top[x][z] == 0]

    def getNextState(self, state, player, action):
        """Return a copy of ``state`` with ``player``'s piece dropped into column ``action``.

        The input array is not modified.

        Raises:
            ValueError: if column ``action`` is already full. Raising, rather than
                returning an unchanged copy, keeps an illegal move from silently
                acting as a pass.
        """
        x, z = action
        for layer in range(state.shape[0]):
            if state[layer][x][z] == 0:
                nextState = state.copy()
                nextState[layer][x][z] = player
                return nextState
        raise ValueError(f"Column {action} is full")

    def gameEnded(self, state):
        """Return the outcome of ``state``.

        Returns:
            ``1`` or ``-1`` (a plain ``int``) if that player has ``winLength``
            pieces in a straight line, ``0`` for a draw (no valid moves and no
            winner), or ``None`` if the game is still in progress.
        """
        height, width, length = state.shape
        span = self.winLength - 1

        for y in range(height):
            for x in range(width):
                for z in range(length):
                    player = state[y, x, z]
                    if player == 0:
                        continue

                    for dy, dx, dz in self.directions:
                        # The board is a box, so if the far end of the line is
                        # on the board every cell in between is too. Otherwise
                        # the line cannot reach winLength and is skipped.
                        ey, ex, ez = y + span * dy, x + span * dx, z + span * dz
                        if not (0 <= ey < height and 0 <= ex < width and 0 <= ez < length):
                            continue
                        if all(
                            state[y + i * dy, x + i * dx, z + i * dz] == player
                            for i in range(1, self.winLength)
                        ):
                            return int(player)

        if not self.getValidMoves(state):
            return 0

        return None


In [None]:
# Main neural network model for the game
class NeuralNetwork:
    """AlphaZero-style residual network with a policy head and a value head.

    Input: channels-last boards of shape ``(height, width, length, numChannels)``.
    Outputs: ``policyOut``, a softmax over the ``width * length`` columns, and
    ``valueOut``, a single tanh unit in [-1, 1].

    The default arguments build exactly the original architecture (input
    (6, 7, 7, 3), 64 filters, 4 residual blocks, 49 policy units, same layer
    names).

    Args:
        height, width, length: board dimensions.
        numChannels: number of input feature planes.
        filters: convolution width of the stem and the residual tower.
        numResBlocks: number of residual blocks.
    """

    def __init__(self, height=6, width=7, length=7, numChannels=3, filters=64, numResBlocks=4):
        inputs = layers.Input(shape=(height, width, length, numChannels))

        # Stem: a 3x3x3 convolution mixes each cell with its full 3-D neighbourhood.
        x = layers.Conv3D(
            filters=filters, kernel_size=(3, 3, 3), padding="same", name="inputLayer"
        )(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)

        # Residual tower.
        for _ in range(numResBlocks):
            x = self.createResidialBlock(x, filters=filters)

        # Policy head: one probability per column (x, z), i.e. per possible move.
        policyHead = layers.Conv3D(filters=2, kernel_size=(1, 1, 1), padding="same", name="policy")(x)
        policyHead = layers.BatchNormalization()(policyHead)
        policyHead = layers.Activation("relu")(policyHead)
        policyHead = layers.Flatten()(policyHead)
        policyOut = layers.Dense(units=width * length, activation="softmax", name="policyOut")(policyHead)

        # Value head: a single tanh unit used as the predicted game outcome.
        valueHead = layers.Conv3D(filters=1, kernel_size=(1, 1, 1), padding="same", name="value")(x)
        valueHead = layers.BatchNormalization()(valueHead)
        valueHead = layers.Activation("relu")(valueHead)
        valueHead = layers.Flatten()(valueHead)
        valueHead = layers.Dense(units=64, activation="relu")(valueHead)
        valueOut = layers.Dense(units=1, activation="tanh", name="valueOut")(valueHead)

        # Assemble and compile; loss dict keys match the output layer names.
        self.model = keras.Model(inputs=inputs, outputs=[policyOut, valueOut])
        self.model.compile(
            optimizer="adam",
            loss={
                "policyOut": "categorical_crossentropy",
                "valueOut": "mean_squared_error",
            },
        )

    def createResidialBlock(self, inputLayer, filters=64):
        """Append one residual block (conv-BN-ReLU-conv-BN + skip, then ReLU) to ``inputLayer``.

        ``filters`` must equal the channel count of ``inputLayer`` for the
        skip addition to be shape-compatible. (The method name keeps its
        original spelling so existing callers still work.)
        """
        residual = inputLayer
        x = layers.Conv3D(filters=filters, kernel_size=(3, 3, 3), padding="same")(
            inputLayer
        )
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)

        x = layers.Conv3D(filters=filters, kernel_size=(3, 3, 3), padding="same")(x)
        x = layers.BatchNormalization()(x)

        # Skip connection gives gradients a short path through deep towers.
        x = layers.Add()([x, residual])
        x = layers.Activation("relu")(x)

        return x


In [None]:
# Helper function to encode states for the neural network
def encodeState(state, playerPerspective):
    """Turn a board into three float32 feature planes, channels first.

    Args:
        state: board array holding 0 for empty cells and +/-1 for the players.
        playerPerspective: the player (1 or -1) the planes are relative to.

    Returns:
        Array of shape ``(3, *state.shape)``: plane 0 flags the cells owned by
        ``playerPerspective``, plane 1 flags the opponent's cells, and plane 2
        is filled with the value ``playerPerspective``.
    """
    ownPieces = (state == playerPerspective).astype(np.float32)
    opponentPieces = (state == -playerPerspective).astype(np.float32)
    turnPlane = np.full(state.shape, playerPerspective, dtype=np.float32)
    return np.stack((ownPieces, opponentPieces, turnPlane))

In [None]:
# A single node in the MCTS tree
class Node:
    """A single node in the MCTS tree.

    Attributes:
        state: board at this node.
        player: player to move at ``state``.
        parent: parent ``Node`` (``None`` for the root).
        move: ``(x, z)`` column played from the parent to reach this node.
        visits: number of simulations backed up through this node.
        value: running sum of backed-up values. The search flips the sign at
            every level, so values are accumulated from the perspective of
            this node's ``player``.
        priorP: prior probability of ``move`` from the policy network.
        children: child nodes, filled when the node is expanded.
        expanded: whether ``children`` has been populated.
    """

    def __init__(self, state, player, parent=None, move=None, priorP=0):
        self.state = state
        self.player = player
        self.parent = parent
        self.move = move
        self.visits = 0
        self.value = 0.0
        self.priorP = priorP
        self.children = []
        self.expanded = False

    def getMeanValue(self):
        """Return the mean backed-up value Q, or 0.0 if the node was never visited."""
        return 0.0 if self.visits == 0 else self.value / self.visits

    def getScore(self, cPuct):
        """Return the PUCT selection score of this node as seen from its parent.

            score = -Q(s, a) + cPuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))

        Q is this node's mean value, P its prior, N(s, a) its visit count and
        N(s) the parent's visit count (1 when there is no parent). Q is negated
        because this node's value is stored for the player to move here, which
        is the parent's opponent. The exploration term is unsigned. Note that a
        parent with 0 visits yields an exploration term of 0.
        """
        qValue = self.getMeanValue()
        parentVisits = self.parent.visits if self.parent is not None else 1
        exploration = cPuct * self.priorP * math.sqrt(parentVisits) / (1 + self.visits)
        return -qValue + exploration

In [None]:
# Class for the Monte Carlo Tree Search algorithm
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network (AlphaZero style).

    ``model(batch, training=False)`` must return ``(policy, value)`` tensors.
    ``policy`` is indexed by ``x * game.length + z`` for a move ``(x, z)``.
    """

    def __init__(self, game, model, args):
        # Rules engine, network and hyper-parameter dict used by every search.
        self.game = game
        self.model = model
        self.args = args

    def search(self, state, player):
        """Run ``args['numSimulations']`` simulations from ``state``.

        Args:
            state: board to search from.
            player: player to move at ``state`` (1 or -1).

        Returns:
            The root ``Node``; its children's ``visits`` are the search policy.
            If the game is already over, an unexpanded node is returned.
        """
        if self.game.gameEnded(state) is not None:  # Nothing to search in a finished game
            return Node(state, player)

        root = Node(state, player, priorP=1.0)
        validMoves = self.game.getValidMoves(root.state)
        policy, rootValue = self._evaluate(root)
        policy = self._maskPolicy(policy, validMoves)

        # Mix Dirichlet noise into the root priors for exploration. The noise is
        # drawn over legal moves only, so none of it lands on full columns.
        # (A non-terminal root always has at least one legal move.)
        legal = [self._moveIndex(move) for move in validMoves]
        noise = np.random.dirichlet([self.args["dirichletAlpha"]] * len(legal))
        epsilon = self.args["dirichletEpsilon"]
        policy[legal] = (1 - epsilon) * policy[legal] + epsilon * noise

        self._expand(root, policy, validMoves)
        # Count the root evaluation as its first visit. With root.visits == 0,
        # sqrt(parent.visits) makes every child's exploration term 0, so the
        # first simulation would ignore the priors and always pick children[0].
        root.visits = 1
        root.value = rootValue

        cPuct = self.args["cPuct"]
        for _ in range(self.args["numSimulations"]):
            node = root

            # 1. Selection: follow the best PUCT score down to an unexpanded node.
            while node.expanded:
                node = max(node.children, key=lambda child: child.getScore(cPuct))

            gameResult = self.game.gameEnded(node.state)
            if gameResult is not None:
                # Terminal leaf: score the real outcome for node.player, the
                # player who would move next (so a win by the previous mover is -1).
                value = float(gameResult * node.player)
            else:
                # 2 + 3. Expansion/evaluation: the network replaces a random rollout.
                policy, value = self._evaluate(node)
                moves = self.game.getValidMoves(node.state)
                self._expand(node, self._maskPolicy(policy, moves), moves)

            # 4. Backpropagation: each node stores values for its own player to
            # move, so the sign flips at every level.
            while node is not None:
                node.visits += 1
                node.value += value
                value = -value
                node = node.parent

        return root

    def _moveIndex(self, move):
        """Flatten an ``(x, z)`` column into the policy index ``x * length + z``."""
        return move[0] * self.game.length + move[1]

    def _evaluate(self, node):
        """Query the network for ``node``; return (policy as float64 1-D array, value as float)."""
        encodedState = encodeState(node.state, node.player)
        # (channels, h, w, l) -> (1, h, w, l, channels): add batch axis, channels last.
        encodedState = np.transpose(encodedState[np.newaxis], (0, 2, 3, 4, 1))
        policy, value = self.model(encodedState, training=False)
        return np.array(policy.numpy()[0], dtype=np.float64), float(value.numpy()[0][0])

    def _maskPolicy(self, policy, validMoves):
        """Zero illegal moves and renormalise; fall back to uniform over legal moves if no mass remains."""
        legal = [self._moveIndex(move) for move in validMoves]
        masked = np.zeros_like(policy)
        if not legal:
            return masked
        masked[legal] = policy[legal]
        total = masked.sum()
        if total > 0:
            masked /= total
        else:
            masked[legal] = 1.0 / len(legal)
        return masked

    def _expand(self, node, policy, validMoves):
        """Add one child per legal move (prior taken from ``policy``) and mark ``node`` expanded."""
        for move in validMoves:
            node.children.append(
                Node(
                    state=self.game.getNextState(node.state, node.player, move),
                    player=-node.player,
                    parent=node,
                    move=move,
                    priorP=policy[self._moveIndex(move)],
                )
            )
        node.expanded = True

In [None]:
# Class to train the model
class Coach:
    """Self-play training loop: play games with MCTS, store examples, fit the network."""

    def __init__(self, game, network, args):
        """
        Args:
            game: the rules engine (``Game``).
            network: the compiled Keras model; used via ``fit``, ``save_weights``
                and, through MCTS, as a callable.
            args: hyper-parameter dict (see the ``args`` cell).
        """
        self.game = game
        self.network = network
        self.args = args
        self.mcts = MCTS(self.game, self.network, self.args)
        # Replay buffer; the oldest examples drop out once maxBufferSize is reached.
        self.memory = deque(maxlen=self.args['maxBufferSize'])

    def run(self):
        """Alternate one self-play game and one training step for ``numIterations`` iterations.

        Weights are saved every ``args['checkpointInterval']`` iterations
        (default 1, i.e. every iteration) and always after the last one.
        """
        numIterations = self.args['numIterations']
        checkpointInterval = self.args.get('checkpointInterval', 1)
        for i in range(1, numIterations + 1):
            print(f"Starting Iteration {i}")

            self.memory.extend(self.executeEpisode())
            self.train()

            if i % checkpointInterval == 0 or i == numIterations:
                self.network.save_weights(f"modelIteration_{i}.weights.h5")

    def executeEpisode(self):
        """Play one self-play game and return its training examples.

        Returns:
            One ``(encodedState, policy, value)`` tuple per move played. Here
            ``encodedState`` is ``encodeState(state, player)`` for the player
            to move, ``policy`` is the MCTS visit distribution over the
            ``width * length`` columns, and ``value`` is the final result from
            that player's perspective (1.0 win, -1.0 loss, 0.0 draw).
        """
        episodeTrainingData = []
        currentPlayer = 1
        state = self.game.getInitialState()

        turn = 0
        while True:
            turn += 1
            print(f"Turn {turn}...")

            root = self.mcts.search(state, currentPlayer)
            policy = self._visitPolicy(root, state)
            # Remember the position; its value target is known only at game end.
            episodeTrainingData.append((state, currentPlayer, policy))

            actionIndex = self._selectAction(policy, turn)
            move = (int(actionIndex // self.game.length), int(actionIndex % self.game.length))
            print(move)

            state = self.game.getNextState(state, currentPlayer, move)
            currentPlayer = -currentPlayer

            gameResult = self.game.gameEnded(state)
            if gameResult is not None:
                # gameResult is the winner (1 / -1) or 0; multiplying by the
                # mover gives +1 for the winner's positions and -1 for the loser's.
                return [
                    (encodeState(histState, histPlayer), histPolicy, float(gameResult * histPlayer))
                    for histState, histPlayer, histPolicy in episodeTrainingData
                ]

    def _visitPolicy(self, root, state):
        """Convert the root's child visit counts into a probability vector over columns.

        Falls back to a uniform distribution over ``state``'s valid moves (with
        a warning) if the search recorded no child visits.
        """
        policy = np.zeros(self.game.width * self.game.length)
        for child in root.children:
            policy[child.move[0] * self.game.length + child.move[1]] = child.visits

        policySum = policy.sum()
        if policySum > 0:
            return policy / policySum

        print("Warning: MCTS search resulted in zero visits. Using uniform policy.")
        validMoves = self.game.getValidMoves(state)
        uniform = np.zeros_like(policy)
        for x, z in validMoves:
            uniform[x * self.game.length + z] = 1 / len(validMoves)
        return uniform

    def _selectAction(self, policy, turn):
        """Pick a flat action index for the 1-based ``turn``.

        During the first ``temperatureMoves`` turns the move is sampled from
        ``policy`` for opening diversity; afterwards the most-visited move is played.
        """
        if turn <= self.args["temperatureMoves"]:
            return np.random.choice(len(policy), p=policy)
        return np.argmax(policy)

    def train(self):
        """Fit the network on one random mini-batch drawn (with replacement) from memory.

        Does nothing until memory holds at least ``batchSize`` examples.
        """
        batchSize = self.args['batchSize']
        if len(self.memory) < batchSize:
            return

        sampleIds = np.random.randint(len(self.memory), size=batchSize)
        batch = [self.memory[i] for i in sampleIds]

        states, policies, values = zip(*batch)
        # (batch, channels, h, w, l) -> (batch, h, w, l, channels) to match the network input.
        states = np.transpose(np.array(states, dtype=np.float32), (0, 2, 3, 4, 1))
        policies = np.array(policies, dtype=np.float32)
        # (batch, 1) to match the valueOut output shape exactly.
        values = np.array(values, dtype=np.float32).reshape(-1, 1)

        self.network.fit(
            states,
            {'policyOut': policies, 'valueOut': values},
            batch_size=batchSize,
            epochs=self.args['epochs'],
            verbose=0
        )

In [None]:
# Running the program

SEED = 0  # Seeds Python, NumPy and TensorFlow RNGs (GPU kernels may still be nondeterministic)
RESUME_WEIGHTS = None  # Set to a saved file, e.g. "modelIteration_100.weights.h5", to resume training

keras.utils.set_random_seed(SEED)

game = Game()
network = NeuralNetwork()

if RESUME_WEIGHTS:
    network.model.load_weights(RESUME_WEIGHTS)

coach = Coach(game, network.model, args)
coach.run()