# A Knucklebones AI

Knucklebones is a highly random dice game developed by the studio [Massive Monster](https://massivemonster.co) available to play online [here](https://knucklebones.io). There you can also find a summary of the rules of the game; a more detailed explanation also given on the [Fandom Wiki](https://cult-of-the-lamb.fandom.com/wiki/Knucklebones). The goal of this project is to train an AI for Knucklebones using various methods and compare the resulting models.

In [1]:
using POMDPs
using Random
using StaticArrays
using MCTS
using LinearAlgebra
using DataStructures
using Flux
using Random

### Environment Definition

We start by modelling the game as a Markov Decision Process using the [POMDPs.jl interface](https://github.com/JuliaPOMDP/POMDPs.jl). Our models will provide a ranking of their preferences for the three options (encoded as one of the 6 permutations) and the game will chooses the valid actions with the highest preference. This prevents the AI from making invalid actions, so placing a dice in an already full column.

In [2]:
struct KnucklebonesState
    board::SMatrix{2,3,SVector{3,Int},6}
    current_player::Int
    current_dice::Int
end

struct KnucklebonesMDP <: MDP{KnucklebonesState, Int}
    discount_factor::Float64
end

POMDPs.discount(mdp::KnucklebonesMDP) = mdp.discount_factor

# Define permutations
const PERMUTATIONS = [
    [1, 2, 3],
    [1, 3, 2],
    [2, 1, 3],
    [2, 3, 1],
    [3, 1, 2],
    [3, 2, 1]
]

function POMDPs.gen(mdp::KnucklebonesMDP, s::KnucklebonesState, a::Int, rng::AbstractRNG)
    new_board = copy(s.board)
    player = s.current_player
    opponent = 3 - player

    # Find the first valid column based on the action's permutation
    column = first(filter(c -> any(new_board[player, c] .== 0), PERMUTATIONS[a]))

    # Place dice
    for row in 1:3
        if new_board[player, column][row] == 0
            new_board = setindex(new_board, setindex(new_board[player, column], s.current_dice, row), player, column)
            break
        end
    end

    # Remove opponent's dice
    opponent_column = new_board[opponent, column]
    new_opponent_column = sort(filter(x -> x != s.current_dice, opponent_column), rev=true)
    new_opponent_column = vcat(new_opponent_column, zeros(Int, 3 - length(new_opponent_column)))
    new_board = setindex(new_board, SVector{3,Int}(new_opponent_column), opponent, column)

    # Calculate reward
    player_score = calculate_score(new_board[player, :])
    opponent_score = calculate_score(new_board[opponent, :])
    reward = player_score - opponent_score

    # New state
    new_state = KnucklebonesState(new_board, opponent, rand(rng, 1:6))

    return (sp=new_state, r=reward)
end

function calculate_score(player_board)
    score = 0
    for column in player_board
        unique_vals = unique(filter(!iszero, column))
        for val in unique_vals
            c = count(==(val), column)
            score += val * c * c
        end
    end
    return score
end

POMDPs.actions(mdp::KnucklebonesMDP, s::KnucklebonesState) = 1:6

function POMDPs.isterminal(mdp::KnucklebonesMDP, s::KnucklebonesState)
    return playerFinished(mdp, s, 1) || playerFinished(mdp, s, 2)
end

function playerFinished(mdp::KnucklebonesMDP, s::KnucklebonesState, p::Int)
    for row in 1:3
        for col in 1 : 3
            if iszero(s.board[p, col][row])
                return false  # player p can still play
            end
        end
    end

    return true
end

function POMDPs.initialstate(mdp::KnucklebonesMDP)
    return KnucklebonesState(
        @SMatrix[SVector(0,0,0) SVector(0,0,0) SVector(0,0,0);
                 SVector(0,0,0) SVector(0,0,0) SVector(0,0,0)],
        1,
        rand(1:6)
    )
end

As a reward function we will use the difference between the score of the model player and the score of the opponent in the next round. The aim is to encourage the model to maximize its own score but also prevent the opponent from achieving a high score. Now we need a function to evaluate two models playing against each other:

In [3]:
function play_game(mdp::KnucklebonesMDP, policy1, policy2)
    s = initialstate(mdp)

    while !isterminal(mdp, s)
        a = s.current_player == 1 ? action(policy1, s) : action(policy2, s)
        s, _ = gen(mdp, s, a, Random.GLOBAL_RNG)
    end

    player1_score = calculate_score(s.board[1, :])
    player2_score = calculate_score(s.board[2, :])
    return player1_score - player2_score
end

function evaluate_policies(mdp::KnucklebonesMDP, policy1, policy2, n_games::Int)
    scores = [play_game(mdp, policy1, policy2) for _ in 1:n_games]
    policy1_wins = count(s -> s > 0, scores)
    policy2_wins = count(s -> s < 0, scores)
    draws = count(iszero, scores)

    println("Policy 1 wins: $policy1_wins")
    println("Policy 2 wins: $policy2_wins")
    println("Draws: $draws")
end


evaluate_policies (generic function with 1 method)

To check that our game implementation is working correctly, we can have a look at an example game between models making uniformly random actions:

In [4]:
function print_board(s::KnucklebonesState)
    for player in 1:2
        println("Player $player:")
        for row in 1:3
            println(join([s.board[player, col][row] for col in 1:3], " "))
        end
        println("Score: $(calculate_score(s.board[player, :]))")
        println()
    end
    println("Current player: $(s.current_player), Current dice: $(s.current_dice)")
end

function play_example_game(mdp::KnucklebonesMDP, policy1, policy2)
    s = initialstate(mdp)
    
    while !isterminal(mdp, s)
        print_board(s)
        a = s.current_player == 1 ? action(policy1, s) : action(policy2, s)
        println("Player $(s.current_player) chooses action $a")
        s, r = gen(mdp, s, a, Random.GLOBAL_RNG)
        println("Reward: $r")
        println("------------------------")
    end
    
    print_board(s)
    println("Game Over")
    
    player1_score = calculate_score(s.board[1, :])
    player2_score = calculate_score(s.board[2, :])
    println("Final Scores - Player 1: $player1_score, Player 2: $player2_score")
    if player1_score > player2_score
        println("Player 1 wins!")
    elseif player2_score > player1_score
        println("Player 2 wins!")
    else
        println("It's a draw!")
    end
end

struct RandomPolicy <: Policy end
POMDPs.action(::RandomPolicy, s::KnucklebonesState) = rand(1:6)

mdp = KnucklebonesMDP(0.95)
random_policy = RandomPolicy()

play_example_game(mdp, random_policy, random_policy)

Player 1:
0 0 0
0 0 0
0 0 0
Score: 0

Player 2:
0 0 0
0 0 0
0 0 0
Score: 0

Current player: 1, Current dice: 5
Player 1 chooses action 2
Reward: 5
------------------------
Player 1:
5 0 0
0 0 0
0 0 0
Score: 5

Player 2:
0 0 0
0 0 0
0 0 0
Score: 0

Current player: 2, Current dice: 2
Player 2 chooses action 4
Reward: -3
------------------------
Player 1:
5 0 0
0 0 0
0 0 0
Score: 5

Player 2:
0 2 0
0 0 0
0 0 0
Score: 2

Current player: 1, Current dice: 1
Player 1 chooses action 5
Reward: 4
------------------------
Player 1:
5 0 1
0 0 0
0 0 0
Score: 6

Player 2:
0 2 0
0 0 0
0 0 0
Score: 2

Current player: 2, Current dice: 4
Player 2 chooses action 6
Reward: 0
------------------------
Player 1:
5 0 1
0 0 0
0 0 0
Score: 6

Player 2:
0 2 4
0 0 0
0 0 0
Score: 6

Current player: 1, Current dice: 5
Player 1 chooses action 5
Reward: 5
------------------------
Player 1:
5 0 1
0 0 5
0 0 0
Score: 11

Player 2:
0 2 4
0 0 0
0 0 0
Score: 6

Current player: 2, Current dice: 1
Player 2 chooses action 1
R

### Monte Carlo tree search

For our first model, we will employ the heuristic search algorithm [MCTS](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), which is very well suited for highly random games like Knucklebones.

In [5]:
function train_mcts(mdp::KnucklebonesMDP, n_iterations::Int, exploration_constant::Float64)
    solver = MCTSSolver(n_iterations=n_iterations, depth=20, exploration_constant=exploration_constant)
    return solve(solver, mdp)
end

mcts_policy = train_mcts(mdp, 1000, 5.0)

MCTSPlanner{KnucklebonesMDP, KnucklebonesState, Int64, MCTS.SolvedRolloutEstimator{POMDPTools.Policies.RandomPolicy{Random._GLOBAL_RNG, KnucklebonesMDP, POMDPTools.BeliefUpdaters.NothingUpdater}, Random._GLOBAL_RNG}, Random._GLOBAL_RNG}(MCTSSolver(1000, Inf, 20, 5.0, Random._GLOBAL_RNG(), RolloutEstimator(POMDPTools.Policies.RandomSolver(Random._GLOBAL_RNG()), 50, 0.0), 0.0, 0, false, false, MCTS.var"#5#7"()), KnucklebonesMDP(0.95), MCTS.MCTSTree{KnucklebonesState, Int64}(Dict{KnucklebonesState, Int64}(), Vector{Int64}[], Int64[], KnucklebonesState[], Int64[], Float64[], Int64[], Dict{Pair{Int64, Int64}, Int64}()), MCTS.SolvedRolloutEstimator{POMDPTools.Policies.RandomPolicy{Random._GLOBAL_RNG, KnucklebonesMDP, POMDPTools.BeliefUpdaters.NothingUpdater}, Random._GLOBAL_RNG}(POMDPTools.Policies.RandomPolicy{Random._GLOBAL_RNG, KnucklebonesMDP, POMDPTools.BeliefUpdaters.NothingUpdater}(Random._GLOBAL_RNG(), KnucklebonesMDP(0.95), POMDPTools.BeliefUpdaters.NothingUpdater()), Random._GLOBAL

In [6]:
# Evaluate MCTS against Random
println("Evaluating MCTS against Random:")
evaluate_policies(mdp, mcts_policy, random_policy, 200)

println("\nPlaying an example game (MCTS vs Random):")
play_example_game(mdp, mcts_policy, random_policy)

Evaluating MCTS against Random:
Policy 1 wins: 163
Policy 2 wins: 35
Draws: 2

Playing an example game (MCTS vs Random):
Player 1:
0 0 0
0 0 0
0 0 0
Score: 0

Player 2:
0 0 0
0 0 0
0 0 0
Score: 0

Current player: 1, Current dice: 1
Player 1 chooses action 5
Reward: 1
------------------------
Player 1:
0 0 1
0 0 0
0 0 0
Score: 1

Player 2:
0 0 0
0 0 0
0 0 0
Score: 0

Current player: 2, Current dice: 4
Player 2 chooses action 5
Reward: 3
------------------------
Player 1:
0 0 1
0 0 0
0 0 0
Score: 1

Player 2:
0 0 4
0 0 0
0 0 0
Score: 4

Current player: 1, Current dice: 4
Player 1 chooses action 6
Reward: 5
------------------------
Player 1:
0 0 1
0 0 4
0 0 0
Score: 5

Player 2:
0 0 0
0 0 0
0 0 0
Score: 0

Current player: 2, Current dice: 1
Player 2 chooses action 3
Reward: -4
------------------------
Player 1:
0 0 1
0 0 4
0 0 0
Score: 5

Player 2:
0 1 0
0 0 0
0 0 0
Score: 1

Current player: 1, Current dice: 5
Player 1 chooses action 1
Reward: 9
------------------------
Player 1:
5 0 1
0 

We already observe a very strong performance against a model making random moves, due to the big influence of the dice rolls even a perfect player has always a chance of loosing. Now let's see if we can improve the performance using Reinforcement Learning.

### Q-Learning

[Q-Learning](https://en.wikipedia.org/wiki/Q-learning) is one of the standard algorithms in Reinforcement Learning and estimates the expected value of each action state pair. The necessary functions for the training are quickly implemented:

In [7]:
# Define Q-learning policy
mutable struct QLearningPolicy <: Policy
    Q::Dict{Tuple{KnucklebonesState, Int}, Float64}
    ϵ::Float64  # Epsilon for ϵ-greedy exploration
end

function QLearningPolicy(ϵ::Float64 = 0.1)
    return QLearningPolicy(Dict{Tuple{KnucklebonesState, Int}, Float64}(), ϵ)
end

function POMDPs.action(policy::QLearningPolicy, s::KnucklebonesState)
    if rand() < policy.ϵ
        return rand(1:6)
    else
        return argmax(a -> get(policy.Q, (s, a), 0.0), 1:6)
    end
end

# Q-learning update function
function update!(policy::QLearningPolicy, s::KnucklebonesState, a::Int, r, s_next::KnucklebonesState, α::Float64, γ::Float64)
    current_q = get(policy.Q, (s, a), 0.0)
    next_max_q = maximum(get(policy.Q, (s_next, a_next), 0.0) for a_next in 1:6)
    policy.Q[(s, a)] = current_q + α * (r + γ * next_max_q - current_q)
end

# Training function for Q-learning
function train_qlearning(mdp::KnucklebonesMDP, n_episodes::Int, α::Float64, γ::Float64, ϵ::Float64)
    policy = QLearningPolicy(ϵ)
    
    for episode in 1:n_episodes
        s = initialstate(mdp)
        while !isterminal(mdp, s)
            a = action(policy, s)
            s_next, r = gen(mdp, s, a, Random.GLOBAL_RNG)
            update!(policy, s, a, r, s_next, α, γ)
            s = s_next
        end
        
        if episode % 10000 == 0
            println("Completed episode $episode")
        end
    end
    
    return policy
end

train_qlearning (generic function with 1 method)

In [8]:
# Train Q-learning policy
println("Training Q-learning policy...")
qlearning_policy = train_qlearning(mdp, 100000, 0.1, 0.95, 0.1)

# Evaluate Q-learning against Random
println("\nEvaluating Q-learning against Random:")
evaluate_policies(mdp, qlearning_policy, random_policy, 200)

Training Q-learning policy...
Completed episode 10000
Completed episode 20000
Completed episode 30000
Completed episode 40000
Completed episode 50000
Completed episode 60000
Completed episode 70000
Completed episode 80000
Completed episode 90000
Completed episode 100000

Evaluating Q-learning against Random:
Policy 1 wins: 109
Policy 2 wins: 86
Draws: 5


It quickly shows that the Q-Learning approach leads to a weaker model than MCTS. The probable cause is the vastness of the possible game states, making identical action state pairs across repeated games very unlikely and Q-Learning ineffective at modelling the true expected value for the reward.

### Deep Q-Network

To counteract the problem of a vast state space, DeepMind developed the (DQN algorithm)[https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf], which combines traditional Reinforcement Learning with Deep Neural Networks. We will use the Machine Learning library [Flux.jl](https://fluxml.ai/Flux.jl/stable/) to implement the Deep Neural Network.

In [12]:
mutable struct DQNPolicy
    q_network::Chain
    ϵ::Float64  # Exploration rate
end

# Epsilon-greedy action selection
function POMDPs.action(policy::DQNPolicy, s::KnucklebonesState)
    if rand() < policy.ϵ
        return rand(1:6)
    else
        q_values = policy.q_network(state_to_vector(s))
        return argmax(q_values)
    end
end

# Update epsilon to reduce exploration over time
function update_epsilon!(policy::DQNPolicy, min_ϵ::Float64, decay::Float64)
    policy.ϵ = max(min_ϵ, policy.ϵ * decay)
end

# State to vector function
function state_to_vector(s::KnucklebonesState)
    # Flatten the board (2 players * 3 columns * 3 rows = 18 elements)
    board_floats = Float32[s.board[player, col][row] for player in 1:2, col in 1:3, row in 1:3]
    
    # Add current player and current dice (2 more elements)
    return vcat(vec(board_floats), Float32(s.current_player), Float32(s.current_dice))
end

# DQN training function
function train_dqn(mdp, n_episodes::Int, batch_size::Int, ϵ::Float64)
    state_dim = 20  # Assuming state is a 20-dimensional vector
    action_dim = 6   # Assuming 6 possible actions
    
    # Define the Q-network (input: state_dim, output: action_dim)
    q_network = Chain(
        Dense(state_dim, 64, relu),
        Dense(64, 64, relu),
        Dense(64, action_dim)
    )
    
    target_network = deepcopy(q_network)
    optimizer = ADAM(0.001)
    
    replay_buffer = []
    policy = DQNPolicy(q_network, ϵ)

    # Training loop
    for episode in 1:n_episodes
        s = initialstate(mdp)  # Get the initial state of the episode
        
        while !isterminal(mdp, s)
            # Select action based on epsilon-greedy policy
            state_vec = state_to_vector(s)
            a = select_action(policy, state_vec)
            
            # Take the action, observe the next state and reward
            s_next, r = gen(mdp, s, a, Random.GLOBAL_RNG)
            
            # Add experience to replay buffer
            push!(replay_buffer, (state_to_vector(s), a, r, state_to_vector(s_next), isterminal(mdp, s_next)))
            if length(replay_buffer) > 10000
                popfirst!(replay_buffer)  # Maintain buffer size
            end
            
            s = s_next  # Move to the next state
            
            # Perform training if buffer has enough experiences
            if length(replay_buffer) >= batch_size
                batch_indices = rand(1:length(replay_buffer), batch_size)
                batch = replay_buffer[batch_indices]

                # Prepare batch data (states, actions, rewards, next states, and terminal flags)
                states = hcat([b[1] for b in batch]...)
                actions = [b[2] for b in batch]
                rewards = [b[3] for b in batch]
                next_states = hcat([b[4] for b in batch]...)
                dones = [b[5] for b in batch]

                # Calculate current Q-values
                current_q_values = q_network(states)

                # Calculate target Q-values using the target network
                next_q_values = target_network(next_states)
                max_next_q_values = [maximum(next_q) for next_q in eachcol(next_q_values)]
                targets = [r + (1 - d) * 0.99 * max_q for (r, d, max_q) in zip(rewards, dones, max_next_q_values)]
                
                # Update Q-values for the taken actions
                q_updates = copy(current_q_values)
                for i in 1:batch_size
                    q_updates[actions[i], i] = targets[i]
                end

                # Define the loss function
                function compute_loss()
                    return Flux.mse(current_q_values, q_updates)
                end

                # Manually calculate gradients and update parameters
                grads = Flux.gradient(() -> compute_loss(), Flux.params(q_network))

                # Perform the optimizer step
                Flux.Optimise.update!(optimizer, Flux.params(q_network), grads)

            end
        end

        # Update the target network every 100 episodes
        if episode % 100 == 0
            target_network = deepcopy(q_network)
        end
        
        # Decay epsilon to reduce exploration over time
        update_epsilon!(policy, 0.1, 0.995)
        
        # Print progress every 1000 episodes
        if episode % 10000 == 0
            println("Completed episode $episode, ϵ: $(policy.ϵ)")
        end
    end
    
    return policy
end


train_dqn (generic function with 1 method)

In [14]:
# Train DQN policy
println("Training DQN policy...")
dqn_policy = train_dqn(mdp, 100000, 32, 0.1)

# Evaluate DQN against Random
println("\nEvaluating DQN against Random:")
evaluate_policies(mdp, dqn_policy, random_policy, 200)

Training DQN policy...
Completed episode 10000, ϵ: 0.1
Completed episode 20000, ϵ: 0.1
Completed episode 30000, ϵ: 0.1
Completed episode 40000, ϵ: 0.1
Completed episode 50000, ϵ: 0.1
Completed episode 60000, ϵ: 0.1
Completed episode 70000, ϵ: 0.1
Completed episode 80000, ϵ: 0.1
Completed episode 90000, ϵ: 0.1
Completed episode 100000, ϵ: 0.1

Evaluating DQN against Random:
Policy 1 wins: 95
Policy 2 wins: 98
Draws: 7


The deep neural network did not seem to improve the model performance of Q-Learning at all, so both models should be no match for the original MCTS model:

In [15]:
println("\nEvaluating policies (MCTS vs Q-learning)")
evaluate_policies(mdp, mcts_policy, qlearning_policy, 200)

println("\nEvaluating policies (Q-learning vs MCTS)")
evaluate_policies(mdp, qlearning_policy, mcts_policy, 200)

println("\nEvaluating policies (MCTS vs DQN)")
evaluate_policies(mdp, mcts_policy, dqn_policy, 200)

println("\nEvaluating policies (DQN vs MCTS)")
evaluate_policies(mdp, dqn_policy, mcts_policy, 200)

println("\nEvaluating policies (Q-learning vs DQN)")
evaluate_policies(mdp, qlearning_policy, dqn_policy, 200)

println("\nEvaluating policies (DQN vs Q-learning)")
evaluate_policies(mdp, dqn_policy, qlearning_policy, 200)



Evaluating policies (MCTS vs Q-learning)
Policy 1 wins: 150
Policy 2 wins: 47
Draws: 3

Evaluating policies (Q-learning vs MCTS)
Policy 1 wins: 48
Policy 2 wins: 149
Draws: 3

Evaluating policies (MCTS vs DQN)
Policy 1 wins: 157
Policy 2 wins: 41
Draws: 2

Evaluating policies (DQN vs MCTS)
Policy 1 wins: 46
Policy 2 wins: 153
Draws: 1

Evaluating policies (Q-learning vs DQN)
Policy 1 wins: 93
Policy 2 wins: 99
Draws: 8

Evaluating policies (DQN vs Q-learning)
Policy 1 wins: 115
Policy 2 wins: 83
Draws: 2


We can observe that DQN outperforms Q-learning, but both are no match for MCTS.

(c) Mia Müßig