In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from collections import deque

In [None]:
class Board:
    """Square board for a two-player "get `win_length` in a row" game.

    ``self.board`` holds 0 for an empty cell, 1 for player 1 and 2 for
    player 2. Player 1 moves first and the turn alternates after every
    successful move.
    """

    def __init__(self, size=6, win_length=4):
        """Create an empty board.

        Args:
            size: side length of the square board.
            win_length: number of consecutive stones along a row, column or
                diagonal needed to win (4 reproduces the original rules).
        """
        self.size = size
        self.win_length = win_length
        self.reset()

    def reset(self):
        """Clear every cell and give the move back to player 1."""
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.current_player = 1

    def _in_bounds(self, x, y):
        """Return True if (x, y) is a cell on the board."""
        return 0 <= x < self.size and 0 <= y < self.size

    def make_move(self, x, y):
        """Place the current player's stone at (x, y) and pass the turn.

        Returns:
            True if the stone was placed; False if (x, y) is off the board or
            already occupied (the board and turn are then left unchanged).
        """
        # Reject off-board coordinates explicitly: negative indices would
        # otherwise wrap around under NumPy indexing and silently place a
        # stone on the opposite edge.
        if not self._in_bounds(x, y) or self.board[x, y] != 0:
            return False
        self.board[x, y] = self.current_player
        self.current_player = 3 - self.current_player  # toggles 1 <-> 2
        return True

    def check_winner(self):
        """Return the winning player (1 or 2), or 0 if nobody has won.

        Every cell is tried as the start of a line in four directions: along
        axis 0, along axis 1, and the two diagonals. That covers every
        possible line, because each line starts at exactly one cell.
        """
        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
        for x in range(self.size):
            for y in range(self.size):
                if any(self._check_direction(x, y, dx, dy) for dx, dy in directions):
                    return self.board[x, y]
        return 0

    def _check_direction(self, x, y, dx, dy):
        """True if `win_length` equal non-empty stones start at (x, y) along (dx, dy)."""
        player = self.board[x, y]
        if player == 0:
            return False
        for step in range(self.win_length):
            nx, ny = x + step * dx, y + step * dy
            if not self._in_bounds(nx, ny) or self.board[nx, ny] != player:
                return False
        return True

    def is_full(self):
        """Return True when no empty (0) cell remains."""
        return np.all(self.board != 0)

In [None]:
class PolicyValueNet:
    """Convolutional policy/value network for a square board.

    Input: encoded positions of shape (batch, board_size, board_size, 2).
    Outputs: ``policy``, a softmax over the board_size**2 cells, and
    ``value``, a single tanh unit in [-1, 1].
    """

    def __init__(self, board_size):
        """Build and compile the Keras model for a board_size x board_size board."""
        self.board_size = board_size
        self.model = self._build_model()

    def _build_model(self):
        """Two 3x3 ReLU conv layers feeding separate dense policy and value heads.

        Compiled with Adam, categorical cross-entropy on the ``policy`` output
        and mean squared error on the ``value`` output.
        """
        input_layer = layers.Input(shape=(self.board_size, self.board_size, 2))
        x = layers.Conv2D(64, kernel_size=3, padding="same", activation="relu")(input_layer)
        x = layers.Conv2D(64, kernel_size=3, padding="same", activation="relu")(x)
        x = layers.Flatten()(x)
        policy_head = layers.Dense(self.board_size**2, activation="softmax", name="policy")(x)
        value_head = layers.Dense(1, activation="tanh", name="value")(x)
        model = models.Model(inputs=input_layer, outputs=[policy_head, value_head])
        model.compile(optimizer="adam", loss={"policy": "categorical_crossentropy", "value": "mse"})
        return model

    def predict(self, state):
        """Run inference on encoded board state(s).

        Args:
            state: array of shape (batch, board_size, board_size, 2), or a
                single unbatched state of shape (board_size, board_size, 2),
                which is given a leading batch axis of size 1.

        Returns:
            ``[policy, value]`` as returned by Keras ``Model.predict``, i.e.
            arrays of shape (batch, board_size**2) and (batch, 1).
        """
        state = np.asarray(state)
        if state.ndim == 3:
            state = state[np.newaxis, ...]
        # verbose=0: self-play calls this at least once per move, and the
        # default progress bar would flood the notebook output.
        return self.model.predict(state, verbose=0)

In [None]:
class MCTS:
    """Policy-network move selector used for self-play.

    NOTE(review): despite the name, no tree search is performed. ``search``
    does three things:
      * evaluates the policy network once on the current position;
      * mixes in Dirichlet exploration noise;
      * greedily returns the highest-scoring empty cell.
    No value estimates or visit counts are used.
    """

    def __init__(self, policy_value_net, board_size, dirichlet_alpha=0.3, noise_fraction=0.25):
        """
        Args:
            policy_value_net: object whose ``predict(state)`` returns
                ``(policy, value)`` with ``policy[0]`` holding one score per
                cell in row-major order.
            board_size: side length of the square board.
            dirichlet_alpha: concentration parameter of the Dirichlet noise.
            noise_fraction: weight e in ``(1 - e) * policy + e * noise``.
                0.25 reproduces the original mix; 0 disables noise entirely.

        Raises:
            ValueError: if ``noise_fraction`` is outside [0, 1].
        """
        if not 0.0 <= noise_fraction <= 1.0:
            raise ValueError(f"noise_fraction must be in [0, 1], got {noise_fraction}")
        self.policy_value_net = policy_value_net
        self.board_size = board_size
        self.dirichlet_alpha = dirichlet_alpha
        self.noise_fraction = noise_fraction

    def _encode(self, board):
        """Encode ``board.board`` as a (1, size, size, 2) network input.

        Plane 0 marks player 1's stones and plane 1 marks player 2's.
        """
        planes = np.stack([(board.board == 1).astype(int), (board.board == 2).astype(int)], axis=-1)
        return planes.reshape((1, self.board_size, self.board_size, 2))

    def search(self, board):
        """Choose a move for the current position.

        Returns:
            (row, col) of the empty cell with the highest noisy prior. Ties go
            to the first such cell in row-major order.

        Raises:
            ValueError: if the board has no empty cell.
        """
        valid_moves = np.argwhere(board.board == 0)
        if len(valid_moves) == 0:
            raise ValueError("no legal moves: the board is full")

        policy, _ = self.policy_value_net.predict(self._encode(board))
        priors = policy[0]
        if self.noise_fraction > 0:
            noise = np.random.dirichlet([self.dirichlet_alpha] * len(priors))
            priors = (1.0 - self.noise_fraction) * priors + self.noise_fraction * noise

        # Only empty cells are candidates; priors of occupied cells are ignored.
        flat_indices = valid_moves[:, 0] * self.board_size + valid_moves[:, 1]
        best = valid_moves[int(np.argmax(priors[flat_indices]))]
        return int(best[0]), int(best[1])

In [None]:
def _encode_board_state(board, board_size):
    """Encode ``board.board`` as a (1, board_size, board_size, 2) network input.

    Plane 0 marks player 1's stones and plane 1 marks player 2's. This matches
    the encoding ``MCTS.search`` feeds to the network.
    """
    planes = np.stack([(board.board == 1).astype(int), (board.board == 2).astype(int)], axis=-1)
    return planes.reshape((1, board_size, board_size, 2))


def _policy_entropy(policy):
    """Shannon entropy (nats) of a probability vector; entries are clipped to avoid log(0)."""
    p = np.clip(np.asarray(policy, dtype=float), 1e-12, 1.0)
    return float(-np.sum(p * np.log(p)))


def train_and_record(episodes=100, board_size=6):
    """Play self-play games, train the network after each one, and record metrics.

    Each game is played with ``MCTS`` (greedy noisy-policy moves) until a win
    or a full board. The network is then fit for one epoch on that game's
    positions with these targets:
      * policy: one-hot encoding of the move actually played;
      * value: +1 if the player to move at that position went on to win,
        -1 if they lost, 0 for a draw.

    Args:
        episodes: number of self-play games.
        board_size: side length of the square board.

    Returns:
        (losses, entropies), both per-game lists:
          * ``losses[i]`` is the total Keras training loss (policy
            cross-entropy + value MSE, as compiled) from game i's fit;
          * ``entropies[i]`` is the mean Shannon entropy of the network's raw
            policy over the positions of game i.
    """
    policy_value_net = PolicyValueNet(board_size)
    n_cells = board_size * board_size
    losses, entropies = [], []

    for _ in range(episodes):
        board = Board(board_size)
        mcts = MCTS(policy_value_net, board_size)
        states, policy_targets, movers, step_entropies = [], [], [], []

        while not board.is_full() and board.check_winner() == 0:
            state = _encode_board_state(board, board_size)
            policy, _ = policy_value_net.predict(state)
            step_entropies.append(_policy_entropy(policy[0]))

            x, y = mcts.search(board)
            target = np.zeros(n_cells, dtype=np.float32)
            target[x * board_size + y] = 1.0
            states.append(state[0])
            policy_targets.append(target)
            movers.append(board.current_player)  # player about to move in `state`
            board.make_move(x, y)

        if not states:
            # Degenerate board with no playable cell: nothing to train on.
            losses.append(float("nan"))
            entropies.append(float("nan"))
            continue

        winner = board.check_winner()
        value_targets = np.array(
            [[0.0 if winner == 0 else (1.0 if mover == winner else -1.0)] for mover in movers],
            dtype=np.float32,
        )
        # Targets are passed in model-output order: [policy, value].
        history = policy_value_net.model.fit(
            np.array(states, dtype=np.float32),
            [np.array(policy_targets), value_targets],
            epochs=1,
            verbose=0,
        )
        losses.append(float(history.history["loss"][-1]))
        entropies.append(float(np.mean(step_entropies)))

    return losses, entropies

In [None]:
# Run configuration.
SEED = 0
EPISODES = 100
BOARD_SIZE = 6

# Seed NumPy (Dirichlet exploration noise) and TensorFlow (weight init,
# shuffling) so Restart & Run All reproduces the curves. GPU kernels may
# still introduce small nondeterminism.
np.random.seed(SEED)
tf.random.set_seed(SEED)

losses, entropies = train_and_record(EPISODES, BOARD_SIZE)

fig, (ax_loss, ax_entropy) = plt.subplots(2, 1, figsize=(12, 10))
panels = (
    (ax_loss, losses, "Loss Function", None, "Loss", "Training Loss Over Self-Play"),
    (ax_entropy, entropies, "Policy Entropy", "red", "Entropy", "Policy Entropy Over Self-Play"),
)
for ax, series, label, color, ylabel, title in panels:
    ax.plot(series, label=label, color=color)
    ax.set(xlabel="Self-Play Games", ylabel=ylabel, title=title)
    ax.legend()
fig.tight_layout()
plt.show()