Implementierung von AlphaZero nach

In [1]:
import numpy as np
print(np.__version__)

import torch
print(torch.__version__)

import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

from tqdm.notebook import trange

import random
import math
import pandas as pd
from games import ConnectFour

1.24.3
2.0.1+cpu


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Definition of the Game Environment

In [3]:
class ConnectFour:
    def __init__(self):
        self.row_count = 6
        self.column_count = 7
        self.action_size = self.column_count
        self.in_a_row = 4
        
    def __repr__(self):
        return "ConnectFour"
        
    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count))
    
    # def get_next_state(self, state, action, player):
    #     row = np.max(np.where(state[:, action] == 0))
    #     state[row, action] = player
    #     return state

    def get_next_state(self, state, action, player):
        """
        Get the game state after a given action.
        The player drops a piece in the specified column (action).
        """
        # Find the lowest available row in the column
        row = np.max(np.where(state[:, action] == 0), initial=-1)
        if row == -1:
            # If the column is full, return the current state
            return state

        next_state = state.copy()
        next_state[row, action] = player
        return next_state


    
    def get_valid_moves(self, state):
        if len(state.shape) == 3:
            return (state[:, 0] == 0).astype(np.uint8)
        return (state[0] == 0).astype(np.uint8)
    
    def check_win(self, state, action):
        if action == None:
            return False
        
        row = np.min(np.where(state[:, action] != 0))
        column = action
        player = state[row][column]

        def count(offset_row, offset_column):
            for i in range(1, self.in_a_row):
                r = row + offset_row * i
                c = action + offset_column * i
                if (
                    r < 0 
                    or r >= self.row_count
                    or c < 0 
                    or c >= self.column_count
                    or state[r][c] != player
                ):
                    return i - 1
            return self.in_a_row - 1

        return (
            count(1, 0) >= self.in_a_row - 1 # vertical
            or (count(0, 1) + count(0, -1)) >= self.in_a_row - 1 # horizontal
            or (count(1, 1) + count(-1, -1)) >= self.in_a_row - 1 # top left diagonal
            or (count(1, -1) + count(-1, 1)) >= self.in_a_row - 1 # top right diagonal
        )
    
    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player
    
    def get_opponent_value(self, value):
        return -value
    
    def change_perspective(self, state, player):
        return state * player
    
    def get_encoded_state(self, state):
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)
        
        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)
        
        return encoded_state
    
    def get_single_array_state(self, state):
        """
        Return the state as a single 2D array. MCTS player's pieces are marked as -1, 
        the random actor's pieces as 1, and all other fields as 0.
        """
        single_array_state = np.zeros_like(state[0])
        single_array_state += state[0] * -1
        single_array_state += state[1]
        return single_array_state


# Definition of Neural Network for AlphaZero

In [4]:
class ResNet(nn.Module):
    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()
        
        self.device = device
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )
        
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * game.row_count * game.column_count, game.action_size)
        )
        
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.row_count * game.column_count, 1),
            nn.Tanh()
        )
        
        self.to(device)
        
    def forward(self, x):
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value
        
        
class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)
        return x

# MCTS Implementation for AlphaZero

In [5]:
class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        
        self.children = []
        
        self.visit_count = visit_count
        self.value_sum = 0
        
    def is_fully_expanded(self):
        return len(self.children) > 0
    
    def select(self):
        best_child = None
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
                
        return best_child
    
    def get_ucb(self, child):
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
    
    def expand(self, policy):
        for action, prob in enumerate(policy):
            if prob > 0:
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)
                
        return child
            
    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1
        
        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)  


class MCTS:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        
    @torch.no_grad()
    def search(self, state):
        root = Node(self.game, self.args, state, visit_count=1)
        
        policy, _ = self.model(
            torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
        )
        policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        
        valid_moves = self.game.get_valid_moves(state)
        policy *= valid_moves
        policy /= np.sum(policy)
        root.expand(policy)
        
        for search in range(self.args['num_searches']):
            node = root
            
            while node.is_fully_expanded():
                node = node.select()
                
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)
            
            if not is_terminal:
                policy, value = self.model(
                    torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
                )
                policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
                valid_moves = self.game.get_valid_moves(node.state)
                policy *= valid_moves
                policy /= np.sum(policy)
                
                value = value.item()
                
                node.expand(policy)
                
            node.backpropagate(value)    
            
            
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

# AlphaZero Model

In [6]:
class AlphaZero:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.policy_losses = list()
        self.value_losses = list()
        self.losses = list()
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)

            memory.append((neutral_state, action_probs, player))

            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            action = np.random.choice(self.game.action_size, p=action_probs) # change to temperature_action_probs

            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory

            player = self.game.get_opponent(player)

    def train(self, memory):
        random.shuffle(memory)
        self.train_results = pd.DataFrame(columns=['policy_loss', 'value_loss', 'total_loss'])
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:min(len(memory) - 1, batchIdx + self.args['batch_size'])] # Change to memory[batchIdx:batchIdx+self.args['batch_size']] in case of an error
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.policy_losses.append(policy_loss.item())
            self.value_losses.append(value_loss.item())
            self.losses.append(loss.item())

            self.train_results.loc[len(self.train_results)] = [policy_loss.item(), value_loss.item(), loss.item()]

            optimizer.zero_grad() # change to self.optimizer
            loss.backward()
            optimizer.step() # change to self.optimizer

        return loss.item()


    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                loss = self.train(memory)  # modified to get the loss

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

## AlphaZero Training

In [7]:
game = ConnectFour()

model = ResNet(game, 9, 128, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'C': 2,
    'num_searches': 600,
    'num_iterations': 8,
    'num_selfPlay_iterations': 500,
    'num_parallel_games': 100,
    'num_epochs': 4,
    'batch_size': 128,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

# alphaZero = AlphaZero(model, optimizer, game, args)
# alphaZero.learn()

## Alpha Zero Loss Evaluation

In [8]:
# df = pd.DataFrame({
#     'Policy Loss': alphaZero.policy_losses,
#     'Value Loss': alphaZero.value_losses,
#     'Total Loss': alphaZero.losses
# })

In [9]:
# def visualize_loss():
#         sns.set(style="darkgrid", rc={'figure.figsize':(10, 6)})
#         ax = sns.lineplot(data=df)# alphaZero.train_results)
#         ax.set(xlabel='Batch', ylabel='Loss', title='Training Loss')
#         ax.legend(title='Loss Type', loc='upper right', labels=['Policy Loss', 'Value Loss', 'Total Loss'])
#         plt.show()

# visualize_loss()

# MCTS Implementation for MCTS Standalone

In [10]:
# Define the Node class which will be used in MCTS
class StandaloneNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.rewards = 0
        self.unvisited_actions = list(range(self.state.shape[1]))  # initially all actions are unvisited
        self.action = action  # action that led to this state

# Define the MCTS class
class StandaloneMCTS:
    def __init__(self, game, exploration_weight=1):
        self.game = game
        self.exploration_weight = exploration_weight
        self.root = StandaloneNode(game.get_initial_state())
        self.node_lookup = {}

    def uct(self, node):
        """
        Calculate UCT value for a node.
        """
        if node.visits == 0:
            return np.inf  # encourage exploration of unvisited nodes
        else:
            return node.rewards / node.visits + self.exploration_weight * np.sqrt(
                2 * np.log(node.parent.visits) / node.visits)

    def select(self, node):
        """
        Selection phase of MCTS.
        Traverse the tree until we reach a leaf node.
        """
        while len(node.unvisited_actions) == 0 and len(node.children) > 0:
            node = max(node.children.values(), key=self.uct)
        return node

    def expand(self, node):
        """
        Expansion phase of MCTS.
        Add a new child for the current node.
        """
        valid_moves = self.game.get_valid_moves(node.state)
        node.unvisited_actions = [action for action in node.unvisited_actions if valid_moves[action] == 1]
        if len(node.unvisited_actions) == 0:
            return node  # No valid moves left, return the current node
        action = node.unvisited_actions.pop()
        next_state = self.game.get_next_state(node.state.copy(), action, 1)
        child_node = StandaloneNode(next_state, parent=node, action=action)
        node.children[action] = child_node
        return child_node


    def simulate(self, node):
        """
        Simulation phase of MCTS.
        Play out a random game from the current node.
        """
        player = 1
        state = node.state.copy()
        while True:
            valid_moves = self.game.get_valid_moves(state)
            if valid_moves.sum() == 0:
                return 0  # draw
            action = np.random.choice(valid_moves.nonzero()[0])
            state = self.game.get_next_state(state, action, player)
            if self.game.check_win(state, action):
                return player
            player = self.game.get_opponent(player)

    def backpropagate(self, node, reward):
        """
        Backpropagation phase of MCTS.
        Propagate the results of the simulation back up the tree.
        """
        while node is not None:
            node.visits += 1
            node.rewards += reward
            node = node.parent

    def get_best_move(self):
        """
        Return the best move according to the MCTS.
        """
        return max(self.root.children.values(), key=lambda node: node.visits).action
    
    def run(self, state, num_iterations):
        """
        Run the MCTS algorithm for a certain number of iterations.
        """
        self.root = StandaloneNode(state)
        for _ in range(num_iterations):
            leaf_node = self.select(self.root)
            if leaf_node.visits == 0 or len(leaf_node.unvisited_actions) > 0:
                new_node = self.expand(leaf_node)
                reward = self.simulate(new_node)
                self.backpropagate(new_node, reward)
            else:
                reward = 0 if leaf_node.rewards == 0 else 1 if leaf_node.rewards > 0 else -1
                self.backpropagate(leaf_node, reward)


# Gameplay

In [11]:
# Load the model and initialize the AlphaZero MCTS
model.load_state_dict(torch.load("org_ConnectFour.pt", map_location=device))
model = ResNet(game, 9, 128, device)
model.eval()

# initialize the standalone MCTS
standalone_mcts = StandaloneMCTS(game)

## AlphaZero vs User

In [None]:
game = ConnectFour()
player = 1

args = {
    'C': 2,
    'num_searches': 100,
    'dirichlet_epsilon': 0.,
    'dirichlet_alpha': 0.3
}


mcts = MCTS(game, args, model)

state = game.get_initial_state()

while True:
    print(state)
    
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("valid_moves", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}:"))

        if valid_moves[action] == 0:
            print("action not valid")
            continue
            
    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)
        
    state = game.get_next_state(state, action, player)
    
    value, is_terminal = game.get_value_and_terminated(state, action)
    
    if is_terminal:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break
        
    player = game.get_opponent(player)

## AlphaZero vs MCTS (100 games)

In [None]:
# initialize the game
game = ConnectFour()

# initialize the standalone MCTS
standalone_mcts = StandaloneMCTS(game)

# instantiate the model with the same architecture as during training
model = ResNet(game, 9, 128, device)

# load the saved state into the model
model.load_state_dict(torch.load("org_ConnectFour.pt", map_location=device))

# list to store the results
results = []

# play 100 games
for _ in range(100):
    state = game.get_initial_state()
    player = 1  # 1 for AlphaZero, -1 for standalone MCTS

    while True:
        if player == 1:  # AlphaZero's turn
            with torch.no_grad():
                state_tensor = torch.tensor(game.get_encoded_state(state), device=device).unsqueeze(0)
                action_probs, _ = model(state_tensor)
                action_probs = torch.softmax(action_probs, dim=1).squeeze(0).cpu().numpy()
            valid_moves = game.get_valid_moves(state)
            action_probs *= valid_moves
            action_probs /= np.sum(action_probs)
            action = np.random.choice(game.action_size, p=action_probs)
            
            # apply the chosen action to the game state
            state = game.get_next_state(state, action, player)
            
            # check if this action leads to a win
            if game.check_win(state, action):
                results.append("AlphaZero model")
                break
        else:  # standalone MCTS's turn
            # run MCTS and get best action
            standalone_mcts.run(state, num_iterations=500)
            action = standalone_mcts.get_best_move()
            
            # apply the chosen action to the game state
            state = game.get_next_state(state, action, player)
            
            # check if this action leads to a win
            if game.check_win(state, action):
                results.append("Standalone MCTS")
                break

        player = game.get_opponent(player)  # switch player
    else:
        # game has ended with a draw
        results.append("Draw")

    # print the final state of each game
    print("Final state of game", len(results), ":")
    print(state)

# print the results
print("Results of 100 games:")
print("AlphaZero model wins:", results.count("AlphaZero model"))
print("Standalone MCTS wins:", results.count("Standalone MCTS"))
print("Draws:", results.count("Draw"))


## MCTS vs RandomPlayer (100 games)

In [14]:
# initialize the game
game = ConnectFour()

# initialize the standalone MCTS
standalone_mcts = StandaloneMCTS(game)

# list to store the results
results = []

# play 100 games
for _ in range(100):
    state = game.get_initial_state()
    player = 1  # 1 for MCTS, -1 for random actor

    while True:
        # # check if game has ended
        # _, is_terminal = game.get_value_and_terminated(state, None)
        # if is_terminal:
        #     break

        if player == 1:  # MCTS's turn
            # run MCTS and get best action
            standalone_mcts.run(state, num_iterations=1000)
            action = standalone_mcts.get_best_move()

             # apply the chosen action to the game state
            state = game.get_next_state(state, action, player)

            # check if this action leads to a win
            if game.check_win(state, action):
                results.append("MCTS")
                break
        else:  # random actor's turn
            # get valid moves
            valid_moves = game.get_valid_moves(state)
            # choose a random action
            action = np.random.choice(np.where(valid_moves == 1)[0])

            # apply the chosen action to the game state
            state = game.get_next_state(state, action, player)

            # check if this action leads to a win
            if game.check_win(state, action):
                results.append("Random actor")
                break

        player = game.get_opponent(player)  # switch player

    # game has ended, print the final state and record the winner
    print("Final state of game", len(results) + 1, ":")
    print(state)
    # _, is_terminal = game.get_value_and_terminated(state, None)
    # if is_terminal == 1:
    #     results.append("MCTS")
    # elif is_terminal == -1:
    #     results.append("Random actor")
    # else:
    #     results.append("Draw")

# print the results
print("Results of 100 games:")
print("MCTS wins:", results.count("MCTS"))
print("Random actor wins:", results.count("Random actor"))
print("Draws:", results.count("Draw"))

Final state of game 2 :
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  1.  0.  1.]
 [ 0.  0.  1. -1.  1.  1.  1.]
 [ 0.  0.  1.  1.  1. -1. -1.]
 [-1.  0. -1.  1. -1. -1. -1.]]
Final state of game 3 :
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  1. -1.  0.  0.  0.]
 [ 0.  1.  1.  1.  1.  0.  1.]
 [ 0. -1.  1.  1. -1.  0. -1.]
 [ 1. -1. -1.  1. -1. -1. -1.]]
Final state of game 4 :
[[ 0.  0.  0.  0. -1.  0.  0.]
 [ 0.  0.  0.  0. -1.  0.  0.]
 [ 0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  1.  1.  1.  0.]
 [-1. -1.  0.  1.  1.  1.  1.]
 [-1. -1.  1.  1. -1. -1. -1.]]
Final state of game 5 :
[[ 0.  1.  0. -1.  0.  0.  0.]
 [ 0. -1. -1.  1.  0.  0.  0.]
 [ 0.  1.  1. -1.  1. -1.  0.]
 [-1. -1.  1. -1.  1. -1.  0.]
 [ 1. -1.  1. -1.  1.  1.  0.]
 [ 1. -1. -1.  1.  1.  1. -1.]]
Final state of game 6 :
[[-1. -1.  0.  0.  1.  1.  0.]
 [ 1.  1.  0.  0. -1. -1.  0.]
 [ 1. -1.  0. -1. -1.  1.  1.]
 [-1.  1.  1.  1.  1. -1.  1.]
 [-1.  1