https://fivethirtyeight.com/features/how-good-are-you-at-guess-who/

Rule Summary:
- The board has `N` characters.
- Each player chooses one character as their own.
- Players alternate turns.
- At each turn, a player may:
  - Make a specific guess about the opponent's character (correct = win, incorrect = lose).
  - Ask a yes/no question to narrow down the set of characters the opponent could have.

In [1]:
import Pkg; Pkg.activate("..")

[32m[1mActivating[22m[39m environment at `~/vcs/notebooks/Project.toml`


In [2]:
"""
Defines the environment, policy, and state for reinforcement learning.
"""
module GuessWho

import Reinforce: 
    AbstractEnvironment, AbstractPolicy, action, actions, finished, maxsteps, state, step!, reset!, reward  # action

export GuessWhoEnvironment, GuessWhoPolicy, GuessWhoState

"""
    GuessWhoState(ownchars, oppchars)
    GuessWhoState(n)

Game position as seen by the player about to move.

`ownchars` is the size of the pool the mover is still choosing among (a guess succeeds
with probability `1 / ownchars`); `oppchars` is the same count for the other player.
The one-argument form starts both counts at `n`.
"""
struct GuessWhoState
    ownchars::Int
    oppchars::Int
end

function GuessWhoState(n)
    return GuessWhoState(n, n)
end

"""
    flipstate(s::GuessWhoState)

Return `s` with its two counts swapped, i.e. the same position seen by the other player.
"""
function flipstate(s::GuessWhoState)
    return GuessWhoState(s.oppchars, s.ownchars)
end

"""
Supertype of every policy that can play Guess Who, including the opponent stored in a
`GuessWhoEnvironment`.
"""
abstract type GuessWhoPolicy <: AbstractPolicy end

"""
    GuessWhoEnvironment(n, o)

Environment for a game with `n` characters played against the opponent policy `o`.
A new environment starts in `GuessWhoState(n)` with a stored reward of `0`.
"""
mutable struct GuessWhoEnvironment <: AbstractEnvironment
    n::Int  # number of characters on the board
    s::GuessWhoState  # current state
    r::Int  # reward produced by the most recent step
    o::GuessWhoPolicy  # policy that plays the opponent's turns

    function GuessWhoEnvironment(n, o)
        return new(n, GuessWhoState(n), 0, o)
    end
end

"""
    reset!(env::GuessWhoEnvironment)

Put `env` back into its starting state, clear the stored reward and return `env`.
"""
function reset!(env::GuessWhoEnvironment)
    env.s = GuessWhoState(env.n)
    env.r = 0
    return env
end

"""
    makestep(s::GuessWhoState, a::Int) -> (r, s′)

Apply action `a` for the player to move in state `s` and return the reward together
with the resulting state.

- `a == 0` is a guess: the reward is `1` with probability `1 / s.ownchars` and `-1`
  otherwise, and the game ends in `GuessWhoState(0, 0)` either way.
- `a > 0` is a question that splits the mover's pool into groups of `a` and
  `s.ownchars - a`: the pool shrinks to `a` with probability `a / s.ownchars` and to
  `s.ownchars - a` otherwise. The reward is `0` and `s.oppchars` is unchanged.

# Throws
- `ArgumentError` if `a` lies outside `0:s.ownchars`; such an action would otherwise
  silently produce a pool larger than the one it started from.
"""
function makestep(s::GuessWhoState, a::Int)
    if !(0 <= a <= s.ownchars)
        throw(ArgumentError("action must be in 0:$(s.ownchars), got $a"))
    end
    if a == 0
        # Guess and end the game, win or lose.
        r = rand() < (1 / s.ownchars) ? 1 : -1
        return r, GuessWhoState(0, 0)
    end
    # Question: keep whichever side of the split the answer points to.
    remaining = rand() < (a / s.ownchars) ? a : (s.ownchars - a)
    return 0, GuessWhoState(remaining, s.oppchars)
end

"""
    step!(env::GuessWhoEnvironment, s::GuessWhoState, a::Int) -> (r, s′)

Play action `a` from state `s` and, unless that ends the game, let the opponent policy
`env.o` answer with one move of its own.

The opponent acts on the flipped state and is handed `0` as its last reward. When the
opponent moves, the reward reported here is the negation of the opponent's reward. The
resulting reward and state are stored in `env` and returned.
"""
function step!(env::GuessWhoEnvironment, s::GuessWhoState, a::Int)
    own_reward, after_own = makestep(s, a)

    if finished(env, after_own)
        env.r = own_reward
        env.s = after_own
        return (env.r, env.s)
    end

    # Opponent's turn: its policy (not necessarily random) sees the mirrored state.
    opp_view = flipstate(after_own)
    opp_choices = actions(env, opp_view)
    opp_action = action(env.o, 0, opp_view, opp_choices)
    opp_reward, opp_view = makestep(opp_view, opp_action)

    # Report the outcome from the original mover's side again.
    env.r = -1 * opp_reward
    env.s = flipstate(opp_view)
    return (env.r, env.s)
end
"""Return the state currently stored in `env`."""
function state(env::GuessWhoEnvironment)
    return env.s
end

"""Return the reward stored in `env` by the most recent step."""
function reward(env::GuessWhoEnvironment)
    return env.r
end

"""Return the step limit for one episode: one more than the number of characters."""
function maxsteps(env::GuessWhoEnvironment)
    return env.n + 1
end

"""
Return the actions available in `s`: `0` (guess) followed by every question size from
`1` up to half of `s.ownchars`, rounded down.
"""
function actions(env::GuessWhoEnvironment, s::GuessWhoState)
    largest_split = div(s.ownchars, 2)
    return collect(0:largest_split)
end

"""Return `true` once either count in `s′` has reached zero."""
function finished(env::GuessWhoEnvironment, s′::GuessWhoState)
    return s′.ownchars == 0 || s′.oppchars == 0
end

end  # module

Main.GuessWho

In [3]:
"""
Defines a GuessWho policy of playing randomly
"""
module RandomGuessWho

import Reinforce: action
import Main.GuessWho: GuessWhoPolicy, GuessWhoState

export RandomGuessWhoPolicy

"""
Policy that ignores the last reward and the state and draws its action with `rand`
from the offered actions.
"""
struct RandomGuessWhoPolicy <: GuessWhoPolicy end

function action(p::RandomGuessWhoPolicy, r::Int, s::GuessWhoState, A)
    return rand(A)
end

end # module

Main.RandomGuessWho

In [4]:
"""
Defines a policy that can learn from prior game play
"""
module LearningGuessWho

import OnlineStats: fit!, Mean, value
import Reinforce: action
import Main.GuessWho: GuessWhoPolicy, GuessWhoState

export LearningGuessWhoPolicy, learn!

"""
    LearningGuessWhoPolicy(ε)
    LearningGuessWhoPolicy(p::LearningGuessWhoPolicy, ε)

Policy that records, per state and action, a running statistic of past outcomes.

`ε` is the probability of picking a random action instead of the best recorded one.
The first form starts with an empty table. The second form builds a policy with a new
`ε` that shares (does not copy) the table of `p`.
"""
struct LearningGuessWhoPolicy <: GuessWhoPolicy
    ε::Float64    # exploration probability
    states::Dict  # state => (action => running statistic)
end

function LearningGuessWhoPolicy(ε)
    return LearningGuessWhoPolicy(ε, Dict())
end

function LearningGuessWhoPolicy(p::LearningGuessWhoPolicy, ε)
    return LearningGuessWhoPolicy(ε, p.states)
end

"""
    action(p::LearningGuessWhoPolicy, r::Int, s::GuessWhoState, A)

Choose an action for state `s`.

With probability `p.ε`, or whenever `s` has no recorded actions, a random element of `A`
is returned. Otherwise the recorded action with the highest estimate is returned. Ties
go to the action met last while iterating the table, and `0` is returned if no estimate
reaches `0.0`. The last reward `r` is not used.
"""
function action(p::LearningGuessWhoPolicy, r::Int, s::GuessWhoState, A)
    explore = rand() < p.ε || !haskey(p.states, s)
    explore && return rand(A)

    chosen, top = 0, 0.0
    for (candidate, estimate) in pairs(p.states[s])
        v = value(estimate)
        v >= top || continue
        chosen, top = candidate, v
    end
    return chosen
end

"""
    learn!(p::LearningGuessWhoPolicy, r, states, actions)

Fold the outcome `r` of one episode into the table of `p`.

Every `(state, action)` pair visited during the episode is updated with `max(r, 0.0)`,
so a negative reward is recorded as `0.0`. Missing table entries are created on demand.
"""
function learn!(p::LearningGuessWhoPolicy, r, states, actions)
    for (s, a) in zip(states, actions)
        per_action = get!(() -> Dict(), p.states, s)
        estimate = get!(() -> Mean(), per_action, a)
        fit!(estimate, max(r, 0.0))
    end
end

end  # module

Main.LearningGuessWho

### Run Episode

In [5]:
import Reinforce: action, Episode, run_episode
using  Main.GuessWho
using  Main.RandomGuessWho
using  Main.LearningGuessWho

In [6]:
# Four-character board with an opponent that picks its actions at random.
env = GuessWhoEnvironment(4, RandomGuessWhoPolicy())
# Learner that explores with probability ε = 0.1 and starts with an empty table.
p = LearningGuessWhoPolicy(0.1)
# Iterating `ep` (next cell) yields one (s, a, r, s′) tuple per step.
ep = Episode(env, p)

Episode{GuessWhoEnvironment,LearningGuessWhoPolicy,Float64}(GuessWhoEnvironment(4, GuessWhoState(4, 4), 0, RandomGuessWhoPolicy()), LearningGuessWhoPolicy(0.1, Dict{Any,Any}()), 0.0, 0.0, 1, 1, 5)

In [7]:
# Step through one episode, logging every transition and keeping the visited
# states and chosen actions so the policy can learn from them afterwards.
state_array = Vector{GuessWhoState}()
action_array = Vector{Int}()
for (s, a, r, s′) in ep
    # record first, then report
    push!(state_array, s)
    push!(action_array, a)

    println("Initial state: $s")
    println("Action taken : $a")
    println("Result       : $r")
    println("Final state  : $s′")
    println()
end
println("Performed $(ep.niter) iterations with a result of $(ep.total_reward)");

Initial state: GuessWhoState(4, 4)
Action taken : 2
Result       : -1
Final state  : GuessWhoState(0, 0)

Performed 1 iterations with a result of -1.0


In [8]:
# Learn from the episode: credit its total reward to every (state, action) pair
# recorded in the previous cell.
learn!(p, ep.total_reward, state_array, action_array)

### Train Policy

In [9]:
"""
    train!(p::LearningGuessWhoPolicy, env::GuessWhoEnvironment, num_episodes::Int)

Play `num_episodes` episodes of `p` in `env`, calling `learn!` after each one with the
episode's total reward and the states and actions visited during it. Returns `nothing`.
"""
function train!(p::LearningGuessWhoPolicy, env::GuessWhoEnvironment, num_episodes::Int)
    for _ in 1:num_episodes
        visited = GuessWhoState[]
        chosen = Int[]
        total = run_episode(env, p) do transition
            # invoked once per step with the (s, a, r, s′) tuple
            s, a = transition
            push!(visited, s)
            push!(chosen, a)
        end
        learn!(p, total, visited, chosen)
    end
    return nothing
end

train! (generic function with 1 method)

In [17]:
"""
    training(num_policies::Int, policy_episodes::Int, ε::Float64; nchars::Int=4)

Train a chain of `num_policies` learning policies and return them as a vector.

Each policy is created with exploration probability `ε` and trained for
`policy_episodes` episodes on a board of `nchars` characters. The first policy plays
against `RandomGuessWhoPolicy()`; every later policy plays against the policy trained
just before it, which is used as is and therefore still explores with its own `ε`.

# Keywords
- `nchars`: number of characters on the board (defaults to `4`).
"""
function training(num_policies::Int, policy_episodes::Int, ε::Float64; nchars::Int=4)
    policies = [LearningGuessWhoPolicy(ε) for _ in 1:num_policies]
    for i in eachindex(policies)
        println("Training policy $i")
        # The first policy has no predecessor, so it faces the random player.
        opponent = i == 1 ? RandomGuessWhoPolicy() : policies[i - 1]
        env = GuessWhoEnvironment(nchars, opponent)
        train!(policies[i], env, policy_episodes)
    end
    return policies
end

training (generic function with 1 method)

In [23]:
# Train a chain of 10 policies, 100_000 episodes each, with ε = 0.1.
policies = training(10, 100_000, 0.1);

Training policy 1
Training policy 2
Training policy 3
Training policy 4
Training policy 5
Training policy 6
Training policy 7
Training policy 8
Training policy 9
Training policy 10


In [24]:
# Inspect the last policy in the chain.
p = policies[end];

In [26]:
# Opening position: running mean of max(reward, 0) for each action tried from it.
p.states[GuessWhoState(4, 4)]

Dict{Any,Any} with 3 entries:
  0 => Mean: n=3354 | value=0.238819
  2 => Mean: n=3368 | value=0.637173
  1 => Mean: n=93278 | value=0.660027

In [29]:
# Two candidates left for the learner, one left for the opponent.
p.states[GuessWhoState(2, 1)]  # always guess if your opponent will win on the next turn

Dict{Any,Any} with 2 entries:
  0 => Mean: n=17231 | value=0.491382
  1 => Mean: n=859 | value=0.0

In [30]:
# Both players are down to two candidates.
p.states[GuessWhoState(2, 2)]

Dict{Any,Any} with 2 entries:
  0 => Mean: n=2422 | value=0.50578
  1 => Mean: n=5667 | value=0.526204