### let's try hand-coding the CFR algorithm for Kuhn Poker in openspiel

In [42]:
import pyspiel
import numpy as np
from collections import defaultdict

In [2]:
game = pyspiel.load_game("kuhn_poker")

In [124]:
from typing import Dict

cumulative_regrets = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]
# cumulative_regrets[P][I][A] is the cumulative regret, for player P (0 or 1) at infostate I, of not taking action A
strategy_sums = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]

InfoSet = str

def get_strategy_from_regrets(regrets: Dict[InfoSet, float]):
    positive_regrets = {action: max(value, 0) for action, value in regrets.items()}
    denom = sum(positive_regrets.values())
    if denom <= 0:
        return {action: 1/len(regrets) for action in regrets}
    return {action: value/denom for action, value in positive_regrets.items()}

"""
returns the expected value of the state for player 1 (not counterfactual, so the caller should do the scaling)
p1prob is the probability that we reach this state if *p2* plays to reach the state.
"""
def cfr(state, player, p1prob, p2prob):
    if state.is_terminal():
        return state.returns()[0] # take the 0th item here bc that's the payoff for player 1
    if state.is_chance_node():
        expected_value = 0
        for action, p in state.chance_outcomes():
            ev = cfr(state.child(action), player, p1prob*p, p2prob*p)
            expected_value += p*ev
        return expected_value
    else:
        infostate_str = state.information_state_string()
        legal_actions = state.legal_actions()
        # hacky way to lazily initialize cumulative regrets:
        for action in legal_actions:
            cumulative_regrets[state.current_player()][infostate_str][action] += 0
        
        strategy = get_strategy_from_regrets(cumulative_regrets[state.current_player()][infostate_str])
        # the following two variables store ev for the state.current_player().
        expected_value = 0
        expected_values_per_action = dict()
        for action in legal_actions:
            p = strategy[action]
            if state.current_player() == 0:
                ev = cfr(state.child(action), player, p1prob*p, p2prob)
            else:
                ev = cfr(state.child(action), player, p1prob, p2prob*p) * -1
            expected_values_per_action[action] = ev
            expected_value += ev * p
        if state.current_player() == player:
            if player == 0:
                pi_not_i, pi_i = p2prob, p1prob
            else:
                pi_not_i, pi_i = p1prob, p2prob
            for a, v in expected_values_per_action.items():
                counterfactual_regret = (v - expected_value)*pi_not_i
                strategy_sums[player][state.information_state_string()][a] += p1prob*p2prob * strategy[a]
                cumulative_regrets[player][state.information_state_string()][a] += counterfactual_regret
        # since expected_value stores the ev for the state.current_player(), we invert it before returning
        if state.current_player() == 0:
            return expected_value
        else:
            return expected_value * -1
        
    

In [125]:
ITS = 4000
cumulative_regrets = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]
strategy_sums = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]

for it in range(ITS):
    for player in [0,1]:
        state = game.new_initial_state()
        payoff = cfr(state, player, 1, 1)
        if it%500 == 0:
            print(it, payoff)

0 0.3125
0 0.25
500 -0.062502406696407
500 -0.0514975003093206
1000 -0.06118483326617913
1000 -0.05399961160211475
1500 -0.05941286367004783
1500 -0.053638321238782205
2000 -0.058844011113143246
2000 -0.05399614485905907
2500 -0.058567310653070304
2500 -0.05419790241062444
3000 -0.05850255541149996
3000 -0.05449023838493211
3500 -0.05818587349793236
3500 -0.05440124857762285


In [94]:
def get_strategy_from_strategy_sums(strategy_sums):
    denom = sum(strategy_sums.values())
    if denom <= 0:
        return {action: 1/len(strategy_sums) for action in strategy_sums}
    return {action: p/denom for action, p in strategy_sums.items()}
    

def get_payoff_with_final_strat(state):
    if state.is_terminal():
        return state.returns()[0]
    elif state.is_chance_node():
        ev = 0
        for action, p in state.chance_outcomes():
            ev += p*get_payoff_with_final_strat(state.child(action))
        return ev
    else:
        infostate_str = state.information_state_string()
        legal_actions = state.legal_actions()
        print(' '*state.move_number(), state.history_str())
        strategy = get_strategy_from_strategy_sums(strategy_sums[state.current_player()][infostate_str])
        ev = 0
        for action in legal_actions:
            print(' '*state.move_number(),infostate_str,'->',state.action_to_string(action), strategy[action])
            ev += strategy[action] * get_payoff_with_final_strat(state.child(action))
        return ev

In [101]:
state = game.new_initial_state()
ev = get_payoff_with_final_strat(state)
print(ev)

   0 1
   0 -> Pass 0.764724189393865
    0 1 0
    1p -> Pass 0.9996989973206932
    1p -> Bet 0.00030100267930670705
     0 1 0 1
     0pb -> Pass 0.9999725586952606
     0pb -> Bet 2.7441304739366313e-05
   0 -> Bet 0.23527581060613498
    0 1 1
    1b -> Pass 0.6561151815826567
    1b -> Bet 0.3438848184173433
   0 2
   0 -> Pass 0.764724189393865
    0 2 0
    2p -> Pass 0.0001196960864903472
    2p -> Bet 0.9998803039135097
     0 2 0 1
     0pb -> Pass 0.9999725586952606
     0pb -> Bet 2.7441304739366313e-05
   0 -> Bet 0.23527581060613498
    0 2 1
    2b -> Pass 0.0003615765470801664
    2b -> Bet 0.9996384234529199
   1 0
   1 -> Pass 0.9988066610907277
    1 0 0
    0p -> Pass 0.66381571464462
    0p -> Bet 0.33618428535538003
     1 0 0 1
     1pb -> Pass 0.4286631829955627
     1pb -> Bet 0.5713368170044373
   1 -> Bet 0.0011933389092724158
    1 0 1
    0b -> Pass 0.9998798988746334
    0b -> Bet 0.00012010112536670338
   1 2
   1 -> Pass 0.9988066610907277
    1 2 0
   

The resulting EV looks correct (https://en.wikipedia.org/wiki/Kuhn_poker): "the first player should expect to lose at a rate of −1/18 per hand"

In [96]:
1/18

0.05555555555555555

Ok that looks correct.

Player one should bet 3x more likely with a K (2) than with a J (0): https://upload.wikimedia.org/wikipedia/commons/a/a9/Kuhn_poker_tree.svg

```
0 -> Bet 0.23675878818473314
...
2 -> Bet 0.6698347958294653
```

In [98]:
0.237 * 3

0.711

Hm.... that doesn't look exactly correct.  [edit: Re-ran with 5x more iterations (6000 iterations) and got closer values (0.235 and 0.688), so I'm satisfied]

More checking:  
Player one should check 100% of the time with a Q (1), which it does: `1 -> Pass 0.9928399665443669`  
Player two should fold 100% of the time with a J when bet to, which it does: `0b -> Pass 0.9992716750496424`  
And player two should bet 33% of the time with a J when checked to, which it does: `0p -> Pass 0.6608318109822778`

Alright now let's try using openspiel's CFR solver: [cfr_example.py](https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/examples/cfr_example.py)


In [123]:
from open_spiel.python.algorithms import cfr
from open_spiel.python.algorithms import exploitability

cfr_solver = cfr.CFRSolver(game)
ITS = 4000
for i in range(ITS):
    cfr_solver.evaluate_and_update_policy()
    if i % 500 == 0:
        conv = exploitability.exploitability(game, cfr_solver.average_policy())
        print("Iteration {} exploitability {}".format(i, conv))

Iteration 0 exploitability 0.45833333333333326
Iteration 500 exploitability 0.0012031926490508604
Iteration 1000 exploitability 0.0009701106073763677
Iteration 1500 exploitability 0.0004689275667045245
Iteration 2000 exploitability 0.0005293432339987247
Iteration 2500 exploitability 0.0004632766227801177
Iteration 3000 exploitability 0.0004150519969013944
Iteration 3500 exploitability 0.000248111630169745


## Comparing performance with openspiel's CFR implementation
Aight, now let's try to compare the exploitability of the resulting policy from my hand-coded CFR.

In [128]:
from open_spiel.python.policy import TabularPolicy
# make a TabularPolicy object and fill it with the strategy from strategy_sums
# so we can pass it in to exploitability()
def get_openspiel_tabular_policy(strategy_sums):
    my_policy = TabularPolicy(game)
    for player in [0,1]:
        for info_state in strategy_sums[player]:
            strategy = get_strategy_from_strategy_sums(strategy_sums[player][info_state])
            for action, p in strategy.items():
                my_policy.policy_for_key(info_state)[action] = p
    return my_policy

In [129]:
my_policy = get_openspiel_tabular_policy(strategy_sums)
conv = exploitability.exploitability(game, my_policy)
print("my exploitability: {}".format(conv))

my exploitability: 0.0035180604381704894


hmmmmmm mine looks like it's as good after 6000 iterations as openspiel's is after 500 iterations. Let's look at the learning trajectory

In [132]:
ITS = 6000
cumulative_regrets = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]
strategy_sums = [defaultdict(lambda: defaultdict(float)), defaultdict(lambda: defaultdict(float))]

for it in range(ITS):
    for player in [0,1]:
        state = game.new_initial_state()
        payoff = cfr(state, player, 1, 1)
        if it%500 == 0:
            policy = get_openspiel_tabular_policy(strategy_sums)
            conv = exploitability.exploitability(game, policy)
            print(it, conv)

0 0.4375
0 0.27083333333333337
500 0.00923797545498345
500 0.009225375133536107
1000 0.006859414338126396
1000 0.006853907955201438
1500 0.005748481871713701
1500 0.005750853185816807
2000 0.004997746844913026
2000 0.0049969069600468985
2500 0.00446675227661858
2500 0.004467282753330332
3000 0.004052222222096835
3000 0.004050880850328481
3500 0.0037376126388465047
3500 0.003737465657473571
4000 0.0035178118779868317
4000 0.003518280835981019
4500 0.0033227965633366163
4500 0.0033231556879633473
5000 0.003152277645402768
5000 0.003152170973910906
5500 0.003011400561846067
5500 0.0030109213849761063
