# How Good Are You At Guess Who?

## Overview

Source: https://fivethirtyeight.com/features/how-good-are-you-at-guess-who/

### Rules

- The board has `N` characters.
- Players choose a character as their own.
- Players alternate turns.
- At each turn, a player may:
  - Make a specific guess about the opponent's character.
    - A correct guess counts as a win.
    - An incorrect guess counts as a loss.
    - This is encoded as a `0`.
  - Ask a yes/no question to reduce the possible characters.
    - Players can always phrase a question that splits the remaining characters into groups of any sizes.
    - This is encoded by the size of the smaller group.
    - For example, a question that divides 4 characters into groups of 1 and 3 is encoded as `1`.
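
Under these rules the game can also be written as a recursion. Let $P(a, b)$ be the probability that the player about to move wins when $a$ characters remain on their own board and $b$ on the opponent's. A guess wins with probability $1/a$. A question that splits the $a$ characters into $k$ and $a - k$ leaves $k$ of them with probability $k/a$, after which the opponent moves from the swapped position:

$$
P(1, b) = 1, \qquad
P(a, b) = \max\left\{ \frac{1}{a},\ \max_{1 \le k \le \lfloor a/2 \rfloor} \left[ \frac{k}{a}\bigl(1 - P(b, k)\bigr) + \frac{a-k}{a}\bigl(1 - P(b, a - k)\bigr) \right] \right\}
$$

The first player's winning probability under optimal play is $P(N, N)$.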

### Questions

What's the probability of the first player winning if:
1. `N == 4`?
1. `N == 24` (original game)?
1. `N == 14`?

## Solution

This problem can be solved in other ways, but it is a good fit for reinforcement learning (RL). The following code uses the [Reinforce package](https://github.com/JuliaML/Reinforce.jl).

In [1]:
```julia
import Pkg; Pkg.activate("..");
```

```
Activating environment at `~/vcs/notebooks/Project.toml`
```

In [2]:
```julia
"""
Defines the environment, policy, and state for reinforcement learning.
"""
module GuessWho

import Reinforce:
    AbstractEnvironment, AbstractPolicy, action, actions, finished, maxsteps, state, step!, reset!, reward

export GuessWhoEnvironment, GuessWhoPolicy, GuessWhoState

"""
GuessWho game state as defined by the number of characters remaining on
the player's board as well as the number of characters remaining on the
opponent's board.
"""
struct GuessWhoState
    ownchars::Int
    oppchars::Int
end

GuessWhoState(n) = GuessWhoState(n, n)

# view the same position from the other player's perspective
flipstate(s::GuessWhoState) = GuessWhoState(s.oppchars, s.ownchars)

"""
Abstract supertype for policies that play GuessWho.
"""
abstract type GuessWhoPolicy <: AbstractPolicy end

"""
Reinforcement Learning environment that stores all the information needed
to play GuessWho.
"""
mutable struct GuessWhoEnvironment <: AbstractEnvironment
    n::Int  # number of characters
    s::GuessWhoState  # state
    r::Int  # reward
    o::GuessWhoPolicy  # opponent policy
    GuessWhoEnvironment(n, o) = new(n, GuessWhoState(n), 0, o)
end

reset!(env::GuessWhoEnvironment) = (env.s = GuessWhoState(env.n); env.r = 0; env)

function makestep(s::GuessWhoState, a::Int)
    if a == 0
        # Guess and end game: correct with probability 1 / ownchars
        r = rand() < (1 / s.ownchars) ? 1 : -1
        s = GuessWhoState(0, 0)
    else
        # Ask a question splitting ownchars into groups of a and ownchars - a;
        # the secret character is in the group of size a with probability a / ownchars
        r = 0
        remaining = rand() < (a / s.ownchars) ? a : (s.ownchars - a)
        s = GuessWhoState(remaining, s.oppchars)
    end
    return r, s
end

function step!(env::GuessWhoEnvironment, s::GuessWhoState, a::Int)
    r, s = makestep(s, a)
    if !finished(env, s)
        # let the opponent policy respond from its own perspective
        os = flipstate(s)
        A = actions(env, os)
        oa = action(env.o, 0, os, A)
        or, os = makestep(os, oa)
        r = -1 * or  # the opponent's win is our loss
        s = flipstate(os)
    end
    env.r = r
    env.s = s
    return (env.r, env.s)
end
state(env::GuessWhoEnvironment) = env.s
reward(env::GuessWhoEnvironment) = env.r
maxsteps(env::GuessWhoEnvironment) = env.n + 1
actions(env::GuessWhoEnvironment, s::GuessWhoState) = collect(0:(s.ownchars ÷ 2))
finished(env::GuessWhoEnvironment, s′::GuessWhoState) = (s′.ownchars == 0 || s′.oppchars == 0)

end  # module
```
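
The question rule in `makestep` assumes the opponent's secret character is equally likely to be any of the remaining candidates, so a split into `k` and `a - k` keeps `k` candidates with probability `k/a`, and a guess is right with probability `1/a`. A standalone Monte Carlo sketch of that rule follows; `split_outcome` and `guess_wins` are hypothetical helpers that restate the two branches of `makestep` and do not depend on Reinforce.

```julia
using Random, Statistics

# Hypothetical restatement of the question branch of `makestep`:
# with `a` candidates, a question isolating `k` of them leaves `k`
# candidates with probability k/a and `a - k` otherwise.
split_outcome(rng::AbstractRNG, a::Int, k::Int) = rand(rng) < k / a ? k : a - k

# Hypothetical restatement of the guess branch: correct with probability 1/a.
guess_wins(rng::AbstractRNG, a::Int) = rand(rng) < 1 / a

rng = MersenneTwister(2024)
trials = 100_000

# empirical frequency of ending up with the single candidate after a 1-vs-3 split
frac_small = mean(split_outcome(rng, 4, 1) == 1 for _ in 1:trials)
# empirical frequency of a correct guess among 4 candidates
frac_guess = mean(guess_wins(rng, 4) for _ in 1:trials)

println("P(1 of 4 left) ≈ $frac_small, P(correct guess with 4) ≈ $frac_guess")
```

Both frequencies should be close to 1/4, matching the probabilities that `makestep` encodes.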

In [3]:
```julia
"""
A minimal implementation of a GuessWho policy (random play).
"""
module RandomGuessWho

import Reinforce: action
import Main.GuessWho: GuessWhoPolicy, GuessWhoState

export RandomGuessWhoPolicy

struct RandomGuessWhoPolicy <: GuessWhoPolicy end
action(p::RandomGuessWhoPolicy, r::Int, s::GuessWhoState, A) = rand(A)

end # module
```

In [4]:
```julia
"""
A policy for GuessWho that improves from prior game play.
"""
module LearningGuessWho

import OnlineStats: fit!, Mean, value
import Reinforce: action
import Main.GuessWho: GuessWhoPolicy, GuessWhoState

export LearningGuessWhoPolicy, learn!

# define learning policy
"""
An ε-greedy GuessWhoPolicy
"""
struct LearningGuessWhoPolicy <: GuessWhoPolicy
    ε::Float64
    states::Dict
end

LearningGuessWhoPolicy(ε) = LearningGuessWhoPolicy(ε, Dict())
LearningGuessWhoPolicy(p::LearningGuessWhoPolicy, ε) = LearningGuessWhoPolicy(ε, p.states)

function action(p::LearningGuessWhoPolicy, r::Int, s::GuessWhoState, A)
    if rand() < p.ε || !haskey(p.states, s)
        # explore (or act randomly in a state never seen before)
        return rand(A)
    else
        # exploit: pick the action with the highest estimated win rate
        possible_rewards = [value(get(p.states[s], a, Mean())) for a in A]
        exp_r, best_a = findmax(possible_rewards)  # (max value, index)
        return A[best_a]
    end
end

function learn!(p::LearningGuessWhoPolicy, r, states, actions)
    for (s, a) in Iterators.zip(states, actions)
        if !haskey(p.states, s)
            p.states[s] = Dict()
        end
        if !haskey(p.states[s], a)
            p.states[s][a] = Mean()
        end
        # record a loss (-1) as 0 so each mean estimates the win probability
        fit!(p.states[s][a], max(r, 0.0))
    end
end

end  # module
```

### Run an Episode

At this point, the game and both policies are implemented. The next step is to test the code by running a single episode.

In [5]:
```julia
import Reinforce: action, Episode, run_episode
using  Main.GuessWho
using  Main.RandomGuessWho
using  Main.LearningGuessWho
```

In [6]:
```julia
env = GuessWhoEnvironment(4, RandomGuessWhoPolicy())  # play against a random opponent
p = LearningGuessWhoPolicy(0.1)  # create a learning policy (with ε = 0.1)
ep = Episode(env, p)  # construct the episode
```

```
Episode{GuessWhoEnvironment,LearningGuessWhoPolicy,Float64}(GuessWhoEnvironment(4, GuessWhoState(4, 4), 0, RandomGuessWhoPolicy()), LearningGuessWhoPolicy(0.1, Dict{Any,Any}()), 0.0, 0.0, 1, 1, 5)
```

In [7]:
```julia
# Store states and actions for learning purposes
state_array = GuessWhoState[]
action_array = Int[]

# Iterate through the episode
# See: https://github.com/JuliaML/Reinforce.jl#episode-iterator
for (s, a, r, s′) in ep
    push!(state_array, s)
    println("Initial state: $s")
    push!(action_array, a)
    println("Action taken : $a")
    println("Result       : $r")
    println("Final state  : $s′")
    println()
end
println("Performed $(ep.niter) iterations with a result of $(ep.total_reward)");
```

```
Initial state: GuessWhoState(4, 4)
Action taken : 2
Result       : 0
Final state  : GuessWhoState(2, 3)

Initial state: GuessWhoState(2, 3)
Action taken : 1
Result       : 0
Final state  : GuessWhoState(1, 2)

Initial state: GuessWhoState(1, 2)
Action taken : 0
Result       : 1
Final state  : GuessWhoState(0, 0)

Performed 3 iterations with a result of 1.0
```


The result (`ep.total_reward`) is 1.0 when the player using policy `p` wins, as in the episode above, and -1.0 when it loses. This value, together with `state_array` and `action_array`, can be passed to `learn!` to improve `p`'s future actions.

In [8]:
```julia
# learn from the episode
learn!(p, ep.total_reward, state_array, action_array)
p.states[GuessWhoState(4, 4)]
```

```
Dict{Any,Any} with 1 entry:
  2 => Mean: n=1 | value=1.0
```
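
Each `Mean` stored by `learn!` is a running average of `max(r, 0.0)`, which records a win as 1 and a loss as 0, so its value estimates the probability of winning after taking that action from that state. A plain-Julia sketch of that update follows; `RunningMean` and `update!` are hypothetical stand-ins for OnlineStats' `Mean` and `fit!`, not their actual implementation.

```julia
# Hypothetical stand-in for OnlineStats.Mean: count and current average.
mutable struct RunningMean
    n::Int
    μ::Float64
end
RunningMean() = RunningMean(0, 0.0)

# Incremental mean update: μₙ = μₙ₋₁ + (x - μₙ₋₁) / n
function update!(m::RunningMean, x::Real)
    m.n += 1
    m.μ += (x - m.μ) / m.n
    return m
end

m = RunningMean()
for r in (1.0, -1.0, 1.0, 1.0)   # episode rewards: three wins and one loss
    update!(m, max(r, 0.0))      # a loss (-1) is recorded as 0
end
println("n = $(m.n), estimated win rate ≈ $(m.μ)")
```

With three wins and one loss the estimate is (up to floating-point rounding) 0.75, the empirical win rate, which is why the values printed for each action can be read as win probabilities.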

### Train Policy with Many Episodes

In [9]:
```julia
"""
Trains `p` by running `num_episodes`, learning after each episode
"""
function train!(p::LearningGuessWhoPolicy, env::GuessWhoEnvironment, num_episodes::Int)
    for i in 1:num_episodes
        state_array = GuessWhoState[]
        action_array = Int[]
        R = run_episode(env, p) do (s, a, r, s′)
            # called after each step
            push!(state_array, s)
            push!(action_array, a)
        end
        learn!(p, R, state_array, action_array)
    end
end
```

```
train!
```

In [10]:
```julia
"""
Conducts training of `num_policies` ε-greedy policies with each policy playing
against the previous policy (except the first, which plays against a random
policy)
"""
function training(numchars::Int, num_policies::Int, policy_episodes::Int, ε::Float64)
    opponent = RandomGuessWhoPolicy()
    policies = [LearningGuessWhoPolicy(ε) for i in 1:num_policies]
    for i in 1:num_policies
        println("Training policy $i")
        if i == 1
            opponent = RandomGuessWhoPolicy()
        else
            opponent = LearningGuessWhoPolicy(0.0, policies[i-1].states)
        end
        env = GuessWhoEnvironment(numchars, opponent)
        train!(policies[i], env, policy_episodes)
    end
    return policies
end
```

```
training
```

In [11]:
```julia
# Train 10 policies, each for one million episodes
policies = training(4, 10, 1_000_000, 0.15);
```

```
Training policy 1
Training policy 2
Training policy 3
Training policy 4
Training policy 5
Training policy 6
Training policy 7
Training policy 8
Training policy 9
Training policy 10
```


In [12]:
```julia
# Inspect the final policy's estimated win rate for each opening action
policies[end].states[GuessWhoState(4, 4)]
```

```
Dict{Any,Any} with 3 entries:
  0 => Mean: n=49660 | value=0.25439
  2 => Mean: n=56306 | value=0.618993
  1 => Mean: n=894034 | value=0.657091
```

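The final policy opens with action `1` (a 1-vs-3 split) and estimates that the first player wins about 66% of the time. This estimate is the win rate against the previous policy played greedily (ε = 0), not against an optimal opponent.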
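
For comparison, the optimal-play value for `N == 4` can be worked out by hand. Writing $P(a, b)$ for the probability that the player to move wins with $a$ characters left on their own board and $b$ on the opponent's, the smaller positions give $P(4, 1) = 1/4$ (the opponent must guess, since any question hands a certain win to the player with one candidate left), $P(4, 2) = 1/2$, and $P(4, 3) = 1/2$. After the first player's question, the opponent moves from $(4, \text{remaining})$, so the three opening actions evaluate to:

$$
P(4, 4) = \max\Bigl( \underbrace{\tfrac{1}{4}}_{\text{guess}},\ \underbrace{\tfrac{1}{4}\bigl(1 - \tfrac{1}{4}\bigr) + \tfrac{3}{4}\bigl(1 - \tfrac{1}{2}\bigr)}_{\text{split 1 vs 3}},\ \underbrace{1 - \tfrac{1}{2}}_{\text{split 2 vs 2}} \Bigr) = \max\bigl(\tfrac{1}{4},\ \tfrac{9}{16},\ \tfrac{1}{2}\bigr) = \tfrac{9}{16}
$$

The learned opening (action `1`) is the maximizing one, while the exact value, 9/16 ≈ 56%, is lower than the RL estimate.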

In [13]:
```julia
let n = 24
    policies = training(n, 10, 1_000_000, 0.15);
    policies[end].states[GuessWhoState(n, n)]
end
```

```
Training policy 1
Training policy 2
Training policy 3
Training policy 4
Training policy 5
Training policy 6
Training policy 7
Training policy 8
Training policy 9
Training policy 10

Dict{Any,Any} with 13 entries:
  2  => Mean: n=11687 | value=0.530675
  11 => Mean: n=11716 | value=0.589109
  0  => Mean: n=11642 | value=0.0365058
  7  => Mean: n=11616 | value=0.58032
  9  => Mean: n=11648 | value=0.586453
  10 => Mean: n=860800 | value=0.617142
  8  => Mean: n=11494 | value=0.591787
  6  => Mean: n=11542 | value=0.57737
  4  => Mean: n=11698 | value=0.556334
  3  => Mean: n=11444 | value=0.539759
  5  => Mean: n=11465 | value=0.564501
  12 => Mean: n=11486 | value=0.612833
  1  => Mean: n=11762 | value=0.508757
```

In [14]:
```julia
let n = 14
    policies = training(n, 10, 1_000_000, 0.15);
    policies[end].states[GuessWhoState(n, n)]
end
```

```
Training policy 1
Training policy 2
Training policy 3
Training policy 4
Training policy 5
Training policy 6
Training policy 7
Training policy 8
Training policy 9
Training policy 10

Dict{Any,Any} with 8 entries:
  0 => Mean: n=19079 | value=0.0719639
  7 => Mean: n=788658 | value=0.603865
  4 => Mean: n=18710 | value=0.580866
  2 => Mean: n=18569 | value=0.546664
  3 => Mean: n=98684 | value=0.561773
  5 => Mean: n=18812 | value=0.592707
  6 => Mean: n=18612 | value=0.59381
  1 => Mean: n=18876 | value=0.513774
```

With 24 characters, the final policy opens with action `10` and estimates a first-player win rate of about 62%; with 14 characters, it opens with action `7` at about 60%.

In [15]:
```julia
println("Exact answers: $(9/16), $(5/9), $(55/98)")
```

```
Exact answers: 0.5625, 0.5555555555555556, 0.5612244897959183
```
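
The exact answers above are printed as hard-coded fractions. One way to compute such values is a memoized recursion over the (own, opponent) character counts that tries every legal action and keeps the best one, sketched below. `winprob` and `Memo` are hypothetical names, `Rational{BigInt}` is used to rule out integer overflow, and the rules are assumed to be exactly those of `GuessWhoEnvironment`. For `N == 4` it returns `9//16`; its values for 24 and 14 can be compared with the fractions printed above.

```julia
# Memo table: (own characters, opponent characters) => exact win probability
const Memo = Dict{Tuple{Int,Int},Rational{BigInt}}

# Probability that the player about to move wins under optimal play by both sides.
# a: characters still possible on the mover's board; b: on the opponent's board.
function winprob(a::Int, b::Int, memo::Memo = Memo())
    a == 1 && return big(1) // 1                 # one candidate left: guessing wins
    haskey(memo, (a, b)) && return memo[(a, b)]
    best = big(1) // a                           # action 0: guess now
    for k in 1:(a ÷ 2)                           # action k: split into k and a - k
        # after the answer, the opponent moves from the swapped state (b, remaining)
        q = (k // a) * (1 - winprob(b, k, memo)) +
            ((a - k) // a) * (1 - winprob(b, a - k, memo))
        best = max(best, q)
    end
    memo[(a, b)] = best
    return best
end

for n in (4, 24, 14)
    prob = winprob(n, n)
    println("N = $n: $prob ≈ $(Float64(prob))")
end
```

Each recursive call strictly decreases `a + b`, so the recursion terminates, and the memo table keeps the work to at most `N^2` states.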


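The RL estimates are roughly 4 to 10 percentage points above these exact values. One likely reason is that each policy is evaluated against the previous policy rather than against optimal play, and a win rate against a suboptimal opponent can exceed the optimal-play value. For `N == 4`, the learned opening (a 1-vs-3 split) is nonetheless the optimal one. Parameter tuning and additional episodes could probably bring the estimates closer to the exact results.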