https://arxiv.org/pdf/1303.4441.pdf  

In [374]:
import pyspiel
import numpy as np
from collections import defaultdict
from open_spiel.python.policy import TabularPolicy
from open_spiel.python.algorithms import cfr
from open_spiel.python.algorithms import exploitability
import random
from open_spiel.python.algorithms import best_response
import copy
from open_spiel.python.algorithms import expected_game_score

In [472]:
game = pyspiel.load_game("kuhn_poker")
TRUNK_DEPTH = 3 # the policy for any states with depth < TRUNK_DEPTH are in the trunk.
# Anything with depth (move_number) >= TRUNK_DEPTH is considered in subgames
# for kuhn poker, trunk_depth = 3 means the first betting action is in the trunk, and the response is a subgame.

#### Re-Solving Strategies in Subgames
solve the game, discard the "subgame strategies" (i.e. discard the strategies that aren't in the trunk), and then re-solve the subgames.  Compare with unsafe resolving.

In [473]:
# generate an approximate strategy by running cfr
cfr_solver = cfr.CFRSolver(game)
ITS = 500
for i in range(ITS):
    cfr_solver.evaluate_and_update_policy()

conv = exploitability.exploitability(game, cfr_solver.average_policy())
print("After {} iterations, exploitability: {}".format(i+1, conv))

After 500 iterations, exploitability: 0.001168582440478766


In [475]:
class AugmentedSubgame():
    def __init__(self, actual_game, root, chance_outcomes):
        self.game = actual_game
        self.root = root
    def __getattr__(self, attr):
        # hacky hacky hacky hacky
        assert attr != 'new_initial_state'
        return self.game.__getattribute__(attr)
    def get_type(self):
        overrides = {
            'provides_information_state_tensor': False,
            'provides_observation_tensor': False,
        }
        all_attrs = ["short_name","long_name","dynamics","chance_mode","information","utility","reward_model","max_num_players","min_num_players","provides_information_state_string","provides_information_state_tensor","provides_observation_string","provides_observation_tensor","parameter_specification"]
        kwargs = {k: (self.game.get_type().__getattribute__(k) if k not in overrides else overrides[k]) for k in all_attrs}
        return pyspiel.GameType(**kwargs)
    def new_initial_state(self):
        return self.root
    def get(self):
        return False

glob_id = 0
class DummyState:
    def __init__(self, game, children, current_player, payoff=None, probabilities=None):
        self.game = game
        self.cur_player = current_player
        self.children = children
        self.payoff = payoff
        self.probabilities = probabilities
        global glob_id
        self._id = str(glob_id)
        glob_id += 1
    def current_player(self):
        return self.cur_player
    def legal_actions(self, player=None):
        if player is not None and player != self.current_player():
            return []
        return list(range(len(self.children)))
    def legal_actions_mask(self, player):
        if player is not None and player != self.current_player():
            return []
        elif self.is_terminal():
            return []
        else:
            length = self.game.max_chance_outcomes() if self.is_chance_node() else self.game.num_distinct_actions()
            action_mask = [0] * length
            for action in self.legal_actions():
                action_mask[action] = 1
            return action_mask
    def move_number(self):
        # HACK: only using this for my own debug_game() function
        return 0
    def chance_outcomes(self):
#         print(list(zip(self.children, self.probabilities)))
        assert len(self.probabilities) == len(self.children)
        return list(zip(range(len(self.probabilities)), self.probabilities))
    def apply_action(self, action):
        """Applies the specified action to the state."""
        next_state = self.children[action]
        # terrible hack
        self.__class__ = next_state.__class__
        self.__dict__ = next_state.__dict__
    def child(self, action):
        return self.children[action]
    def is_chance_node(self):
        return self.probabilities is not None
    def action_to_string(self, arg0, arg1=None):
        """Action -> string. Args either (player, action) or (action)."""
        player = self.current_player() if arg1 is None else arg0
        action = arg0 if arg1 is None else arg1
        row, col = self._coord(action)
        return "{}({},{})".format("x" if player == 0 else "o", row, col)
    def is_terminal(self):
        return len(self.children) == 0
    def returns(self):
        if self.is_terminal():
            return [self.payoff, -self.payoff]
        return [0.0, 0.0]
    def player_return(self, player):
        return self.returns()[player]
    def information_state_string(self, pl=None):
        if pl is None:
            pl = self.current_player()
        return 'dummy' + ','.join(x.information_state_string(pl) for x in self.children)
    def history_str(self):
        hs = 'dummy' + ','.join(x.history_str() for x in self.children) + self._id
        return hs
    def clone(self):
        return copy.deepcopy(self)
    def is_simultaneous_node(self):
        return False
        
def get_CBV(br, infostate, player):
    # for a state which is not the solving player's decision node
    # TODO: make this a method on BestResponsePolicy and submit pull request?
    value = 0
    print('getting CBV for {} (player {})'.format(infostate, player))
    reach_prob_sum = 0
    for state in infostate_to_states[(player, infostate)]:
        # weight values for how likely it is to reach the state if player plays to get there
        print('{}, CF_P: {}, BRV: {}'.format(state.history_str(), history_str_to_reach_probabilities[state.history_str()][1-player], br.value(state)))
        value += history_str_to_reach_probabilities[state.history_str()][1-player] * br.value(state)
        reach_prob_sum += history_str_to_reach_probabilities[state.history_str()][1-player]
    return value/reach_prob_sum
def make_augmented_subgame_root(solving_player, roots: [(float,pyspiel.State)]):
    """
    roots is a list of tuples: [(probability, root),] where root is a state at the root of the subgame,
    and probability is the probability of getting there.
    """
    other_player = 1-solving_player
    root_parents = []
    for probability, root in roots:
        # todo: can move the BestResponsePolicy construction out of this loop and even out of the function
        br = best_response.BestResponsePolicy(game,
                                              player_id=other_player,
                                              policy=cfr_solver.average_policy(),
                                              )
        br_value = get_CBV(br, root.information_state_string(other_player), other_player)
        print('best response value at {} ({}): {}'.format(root.information_state_string(other_player), root.history_str(), br_value))
        alternative_payoff = DummyState(game, children=[], current_player=pyspiel.PlayerId.TERMINAL, payoff=br_value)
        root_parent = DummyState(game, children=[alternative_payoff, root], current_player=other_player)
        root_parents.append(root_parent)
    augmented_subgame_root = DummyState(game, children=root_parents, current_player=-1, probabilities=[x[0] for x in roots])
    return augmented_subgame_root
    
def crawl_game(state, policy):
    global infostate_to_states
    global history_str_to_reach_probabilities
    infostate_to_states = defaultdict(list)
    history_str_to_reach_probabilities = dict()
    crawl_game_dfs(state, np.array([1, 1]), policy)
    
def crawl_game_dfs(state, reach_probabilities, policy):
    """reach_probabilities = [x,y] means a prob of x of getting here if player TWO plays to get here and y prob if p ONE tries to get here"""
    history_str_to_reach_probabilities[state.history_str()] = reach_probabilities
    if state.is_terminal():
        return
    if state.current_player() >= 0:
        infostate_to_states[(0,state.information_state_string(0))].append(state)
        infostate_to_states[(1,state.information_state_string(1))].append(state)
    legal_actions = state.legal_actions()
    for action in legal_actions:
        new_reach_probabilities = np.array(reach_probabilities)
        if state.is_player_node():
            new_reach_probabilities[state.current_player()] *= policy.action_probabilities(state)[action]
        elif state.is_chance_node():
            chance_outcomes = {a:p for a,p in state.chance_outcomes()}
            new_reach_probabilities = new_reach_probabilities * chance_outcomes[action]
        crawl_game_dfs(state.child(action), new_reach_probabilities, policy)
        
def get_all_equivalent_states(state):
    states = set()
    histories = set()
    to_explore = [state]
    while len(to_explore) > 0:
        ex = to_explore.pop()
        reach_probabilities = history_str_to_reach_probabilities[ex.history_str()]
        states.add((reach_probabilities[state.current_player()], ex))
        histories.add(ex.history_str())
        # union-find preprocess instead of doing this uberslow search here if you want to do this on bigger games
        for infostate in [(player, ex.information_state_string(player)) for player in (0,1)]:
            for s in infostate_to_states[infostate]:
                if s.history_str() not in histories:
                    to_explore.append(s)
    return states

In [359]:
#history_str_to_reach_probabilities

In [360]:
def erase_subgame_policy_recursive(state, policy):
    if state.move_number() >= TRUNK_DEPTH and state.is_player_node():
        x = 1/len(s.legal_actions())
        for action in s.legal_actions():
            policy.policy_for_key(state.information_state_string())[action] = x
    for action in state.legal_actions():
        erase_subgame_policy_recursive(state.child(action), policy)

In [361]:
# debug crawl_game, which should fill up infostate_to_states
r = game.new_initial_state()
crawl_game(r, cfr_solver.average_policy())
#print(infostate_to_states)

In [362]:
# debug erase_subgame_policy_recursive, which should erase non-trunk policies from a policy
# np.set_printoptions(suppress=True)
# trunk_policy = copy.copy(cfr_solver.average_policy())
# print(trunk_policy.action_probability_array)
# erase_subgame_policy_recursive(game.new_initial_state(), trunk_policy)
# print(trunk_policy.action_probability_array)
# print('after erasing, exploitable for:', exploitability.nash_conv(game, trunk_policy)/2)

In [363]:
subgame_root = game.new_initial_state()
subgame_root.apply_action(0)
subgame_root.apply_action(1)
subgame_root.apply_action(0)

In [476]:
def get_state_probability(state):
    """how likely is it to get to this state given that the other player tries to get here"""
    pass
def get_policy_value(g, policy):
    return expected_game_score.policy_value(g.new_initial_state(), [policy] * 2)[0]

subgame_solvers = []
def debug_game(state, policy):
    prefix = ' '*state.move_number()
    if not state.is_chance_node() and not state.is_terminal():
        info = ',info: ' + state.information_state_string()
    else:
        info = ''
    print(prefix,'his:',state.history_str(), info, 'pays out {}'.format(state.player_return(0)) if state.is_terminal() else '')
#     print('state is a chance node:', state.is_chance_node())
    for a in state.legal_actions():
        if state.is_chance_node():
            pass
            print(prefix, state.history_str(), '->', a, ':', {a:p for a,p in state.chance_outcomes()}[a])
        else:
            print(prefix, state.history_str(), '->', a, ':', policy.action_probabilities(state)[a])
        debug_game(state.child(a), policy)

def train_all_subgames_recursive(state, combined_trunk_subgame_policy, seen):
    """this method solves subgames, and writes the subgame policies to combined_trunk_subgame_policy.
    (combined_trunk_subgame_policy starts off as just the trunk policy)"""
    if state.move_number() >= TRUNK_DEPTH and state.is_player_node():
        if state.history_str() in seen:
            return
        print('----the subgame rooted at:', state.history_str(), 'and equivalent states')
        roots = get_all_equivalent_states(state)
        for r in roots:
            seen.add(r[1].history_str())
        augmented_subgame_root = make_augmented_subgame_root(state.current_player(), roots)
        augmented_subgame = AugmentedSubgame(game, augmented_subgame_root, len(roots))
        subgame_solver = cfr.CFRSolver(augmented_subgame)
        ITS = 200
        for i in range(ITS):
            subgame_solver.evaluate_and_update_policy()
        conv = exploitability.nash_conv(augmented_subgame, subgame_solver.average_policy())/2
        debug_game(augmented_subgame_root, subgame_solver.average_policy())
        subgame_solvers.append(subgame_solver)
        print("After {} iterations, exploitability on augmented subgame: {}".format(i+1, conv))
        print("value of subgame: {}".format(get_policy_value(augmented_subgame, subgame_solver.average_policy())))
        for s in subgame_solver.average_policy().states:
            if s.information_state_string() in combined_trunk_subgame_policy.state_lookup:
                for action, value in enumerate(subgame_solver.average_policy().policy_for_key(s.information_state_string())):
                    print('writing to combined strategy')
                    print(' at state {}, action {} with probability {}'.format(s.information_state_string(), action, value))
                    combined_trunk_subgame_policy.policy_for_key(s.information_state_string())[action] = value
    else:
        for action in state.legal_actions():
            train_all_subgames_recursive(state.child(action), combined_trunk_subgame_policy, seen)
trunk_policy = copy.copy(cfr_solver.average_policy())
erase_subgame_policy_recursive(game.new_initial_state(), trunk_policy)
seen = set()
train_all_subgames_recursive(game.new_initial_state(), trunk_policy, seen)

----the subgame rooted at: 0 1 0 and equivalent states
getting CBV for 2p (player 0)
2 0 0, CF_P: 0.16666666666666666, BRV: 1.332527195276242
2 1 0, CF_P: 0.16666666666666666, BRV: 1.007
best response value at 2p (2 0 0): 1.1697635976381209
getting CBV for 0p (player 0)
0 1 0, CF_P: 0.16666666666666666, BRV: -1.0
0 2 0, CF_P: 0.16666666666666666, BRV: -1.0
best response value at 0p (0 2 0): -1.0
getting CBV for 1p (player 0)
1 0 0, CF_P: 0.16666666666666666, BRV: 0.33494560944751584
1 2 0, CF_P: 0.16666666666666666, BRV: -1.0
best response value at 1p (1 0 0): -0.3325271952762421
getting CBV for 1p (player 0)
1 0 0, CF_P: 0.16666666666666666, BRV: 0.33494560944751584
1 2 0, CF_P: 0.16666666666666666, BRV: -1.0
best response value at 1p (1 2 0): -0.3325271952762421
getting CBV for 0p (player 0)
0 1 0, CF_P: 0.16666666666666666, BRV: -1.0
0 2 0, CF_P: 0.16666666666666666, BRV: -1.0
best response value at 0p (0 1 0): -1.0
getting CBV for 2p (player 0)
2 0 0, CF_P: 0.16666666666666666, BRV

In [465]:
weird = subgame_solvers[0]

In [466]:
for i, v in weird._info_state_nodes.items():
    print(i, '-', v)

dummydummy,2p - _InfoStateNode(legal_actions=[0, 1], index_in_tabular_policy=3, cumulative_regret=defaultdict(<class 'float'>, {0: 0.1980797474949333, 1: -0.7833900177494169}), cumulative_policy=defaultdict(<class 'float'>, {0: 1928.4079485933441, 1: 71.59205140665682}))
0p - _InfoStateNode(legal_actions=[0, 1], index_in_tabular_policy=8, cumulative_regret=defaultdict(<class 'float'>, {0: 0.4279504013466878, 1: 0.22399101091802534}), cumulative_policy=defaultdict(<class 'float'>, {0: 1333.7232466304233, 1: 666.276753369541}))
2pb - _InfoStateNode(legal_actions=[0, 1], index_in_tabular_policy=2, cumulative_regret=defaultdict(<class 'float'>, {0: -167.31918834238706, 1: 0.25}), cumulative_policy=defaultdict(<class 'float'>, {0: 0.5, 1: 71.09205140665682}))
dummydummy,0p - _InfoStateNode(legal_actions=[0, 1], index_in_tabular_policy=4, cumulative_regret=defaultdict(<class 'float'>, {0: 0.041666666666666664, 1: -0.041666666666666664}), cumulative_policy=defaultdict(<class 'float'>, {0: 199

In [467]:
trunk_policy.action_probability_array

array([[0.80812725, 0.19187275],
       [0.5       , 0.5       ],
       [0.98501584, 0.01498416],
       [0.85821482, 0.14178518],
       [0.42407323, 0.57592677],
       [0.00698402, 0.99301598],
       [0.998     , 0.002     ],
       [0.66127318, 0.33872682],
       [0.0005    , 0.9995    ],
       [0.0005    , 0.9995    ],
       [0.66686162, 0.33313838],
       [0.9995    , 0.0005    ]])

In [468]:
print(cfr_solver.average_policy().action_probability_array)
print([s.information_state_string() for s in cfr_solver.average_policy().states])

[[0.80812725 0.19187275]
 [0.99938129 0.00061871]
 [0.98501584 0.01498416]
 [0.47621272 0.52378728]
 [0.42407323 0.57592677]
 [0.00117904 0.99882096]
 [0.993      0.007     ]
 [0.66350416 0.33649584]
 [0.002      0.998     ]
 [0.001      0.999     ]
 [0.6674728  0.3325272 ]
 [0.999      0.001     ]]
['0', '0pb', '1', '1pb', '2', '2pb', '1p', '1b', '2p', '2b', '0p', '0b']


In [259]:
p = subgame_solver.average_policy()

In [477]:
print(exploitability.exploitability(game, trunk_policy))
trunk_policy.policy_for_key('0pb')[0] = 1
trunk_policy.policy_for_key('0pb')[1] = 0
print(exploitability.exploitability(game, trunk_policy))
trunk_policy.policy_for_key('1pb')[0] = 0.47621272
trunk_policy.policy_for_key('1pb')[1] = 0.52378728
print(exploitability.exploitability(game, trunk_policy))

0.09867399924458736
0.06500203041284203
0.0030334725301733867


0.09867399924458736

In [56]:
tp = p.to_tabular()

they want to know my history 1
they want to know my history 0


In [57]:
tp.action_probability_array

array([[0.5 , 0.5 ],
       [0.99, 0.01],
       [0.5 , 0.5 ]])

In [59]:
from open_spiel.python import policy

game = pyspiel.load_game("kuhn_poker")
test_policy = policy.UniformRandomPolicy(game)
br = best_response.BestResponsePolicy(game, policy=cfr_solver.average_policy(), player_id=0)
br.value(game.new_initial_state())

-0.0542545325460404

In [162]:
test = AugmentedSubgame(game, game.new_initial_state().child(0))
test.new_initial_state()

<pyspiel.State at 0x7f8b854906f0>

In [72]:
s = game.new_initial_state()
s.apply_action(0)
s.apply_action(1)

In [76]:
s.information_state_string(0), s.information_state_string(1)

('0', '1')

In [77]:
s.apply_action(0)

In [78]:
s.information_state_string(0), s.information_state_string(1)

('0p', '1p')

In [128]:
roots = get_all_equivalent_states(subgame_root)

In [129]:
roots

{<pyspiel.State at 0x7fd82fd50370>,
 <pyspiel.State at 0x7fd82fd7f230>,
 <pyspiel.State at 0x7fd8300e5bf0>,
 <pyspiel.State at 0x7fd82dba1ab0>,
 <pyspiel.State at 0x7fd8300755f0>,
 <pyspiel.State at 0x7fd830075130>}

In [132]:
subgame_root.full_history()

[<pyspiel.PlayerAction at 0x7fd82ff57ef0>,
 <pyspiel.PlayerAction at 0x7fd82fcc3df0>,
 <pyspiel.PlayerAction at 0x7fd82ff52570>]

In [133]:
p.action_probabilities()

TypeError: action_probabilities() missing 1 required positional argument: 'state'

[[0.80812725 0.19187275]
 [0.99938129 0.00061871]
 [0.98501584 0.01498416]
 [0.47621272 0.52378728]
 [0.42407323 0.57592677]
 [0.00117904 0.99882096]
 [0.993      0.007     ]
 [0.66350416 0.33649584]
 [0.002      0.998     ]
 [0.001      0.999     ]
 [0.6674728  0.3325272 ]
 [0.999      0.001     ]]
[[0.80812725 0.19187275]
 [0.5        0.5       ]
 [0.98501584 0.01498416]
 [0.5        0.5       ]
 [0.42407323 0.57592677]
 [0.5        0.5       ]
 [0.5        0.5       ]
 [0.5        0.5       ]
 [0.5        0.5       ]
 [0.5        0.5       ]
 [0.5        0.5       ]
 [0.5        0.5       ]]
after erasing, exploitable for: 0.4200273108087714


In [163]:
[state.information_state_string() for state in fullgamepolicy.states]

['0', '0pb', '1', '1pb', '2', '2pb', '1p', '1b', '2p', '2b', '0p', '0b']

0.4200273108087714

In [171]:
s.child(0).information_state_string()

SpielError: /private/var/folders/b4/fpq296zs5yvglb1sb9f660t80000gn/T/pip-install-lxcuryp8/open-spiel/open_spiel/games/kuhn_poker.cc:111 player >= 0
player = -4, 0 = 0

In [189]:
TRUNK_DEPTH

3

In [242]:
asdf = cfr_solver.average_policy()

In [343]:
aaa = game.new_initial_state()
aaa.apply_action(1)
aaa.apply_action(0)
aaa.apply_action(0)

In [353]:
bbb = game.new_initial_state()
bbb.apply_action(1)
bbb.apply_action(2)
bbb.apply_action(0)

In [244]:
asdf.action_probabilities(aaa)

{0: 0.8081272519618873, 1: 0.19187274803811277}

In [347]:
br = best_response.BestResponsePolicy(game,
                                              player_id=0,
                                              policy=cfr_solver.average_policy(),
                                              )

In [350]:
onep = br.infosets['1p']

KeyError: '1p'

In [299]:
br.value(onep[1][0])

-1.0

In [None]:
sum(cf_p * self.q_value(s, a) for s, cf_p in infoset

In [326]:
list(br.decision_nodes(onep[0][0]))

[(<pyspiel.State at 0x7fd8305c6a70>, 1.0)]

In [348]:
get_CBV(br, '1p', 0)

getting CBV for 1p (player 0)
1 0 0, CF_P: 0.16666666666666666, BRV: 0.33494560944751584
1 2 0, CF_P: 0.16666666666666666, BRV: -1.0


-0.11084239842541402

In [352]:
cfr_solver.average_policy().action_probabilities(aaa)

{0: 0.6674728047237579, 1: 0.3325271952762421}

In [354]:
cfr_solver.average_policy().action_probabilities(bbb)

{0: 0.002, 1: 0.998}

In [355]:
br.value(bbb.child(1))

-1.0

In [356]:
br.value(aaa.child(1))

-1.0

In [393]:
p.action_probabilities(aaa)

KeyError: '0p'