In [1]:
import notebook_importer
from cfr import CFRTree, SolveWithSampleCFR
import capnp
from functools import reduce

importing Jupyter notebook from cfr.ipynb
importing Jupyter notebook from trees.ipynb


In [None]:
# Schemas are loaded explicitly with capnp.load below, so pycapnp's import hook
# is removed first (NOTE(review): presumably to keep it from interfering with
# notebook_importer's own hook -- confirm).
capnp.remove_import_hook()
# Compile the Cap'n Proto schema that describes the serialized game files.
kuhn_schema = capnp.load("kuhn_instances/game.capnp")

In [None]:
# Decode the packed game stored in kuhn_rank3.bin using the schema's Game type.
# The context manager closes the file at the same point the explicit close()
# did, but also when read_packed raises, so the handle can no longer leak.
with open("kuhn_instances/kuhn_rank3.bin", "rb") as f:
    game = kuhn_schema.Game.read_packed(f)

In [None]:
class KuhnNode:
    """Tree node that knows its parent, its ordered children and its owner.

    ``information_set`` and ``incoming_action`` start at -1 and ``utility`` at
    None until something assigns them.
    """

    def __init__(self, id, player):
        self.id = id
        self.player = player
        # Tree links; populated through addChild.
        self.parent = None
        self.children = []
        # -1 means "not assigned yet".
        self.information_set = -1
        self.incoming_action = -1
        self.utility = None

    def addChild(self, node):
        """Append ``node`` as the last child and record how it is reached.

        The child's ``incoming_action`` becomes its position among this node's
        children and its ``parent`` becomes this node.
        """
        node.parent = self
        node.incoming_action = len(self.children)
        self.children.append(node)

    def isLeaf(self):
        """Return True when this node has no children."""
        return not self.children

    def getSequence(self, player):
        """Collect the actions taken at ancestors owned by ``player``.

        Returns a new dict that maps, for every ancestor on the path from the
        root to this node whose ``player`` equals ``player``, that ancestor's
        ``information_set`` to the ``incoming_action`` of the next node on the
        path. The root (no parent) yields an empty dict.
        """
        # Gather every non-root node between this node and the root.
        path = []
        current = self
        while current.parent is not None:
            path.append(current)
            current = current.parent

        # Replay the path root-first so entries are inserted top-down.
        sequence = {}
        for step in reversed(path):
            if step.parent.player != player:
                continue
            sequence[step.parent.information_set] = step.incoming_action
        return sequence

In [None]:
class KuhnTree:
    """Minimal tree handle: a root node plus a player count."""

    def __init__(self, root):
        # The player count is fixed at 1 regardless of the tree's contents.
        self.numOfPlayers = 1
        self.root = root

In [None]:
def build_tree_from_treeplex(treeplex, player):
    """Build a KuhnTree whose nodes are the sequences of one treeplex.

    One KuhnNode is created per sequence id. For every information set, each
    sequence id in [startSequenceId, endSequenceId] is attached as a child of
    the node at parentSequenceId and tagged with the information set's
    position in ``treeplex.infosets``.

    Args:
        treeplex: object exposing ``infosets``; every entry must provide
            ``parentSequenceId``, ``startSequenceId`` and ``endSequenceId``.
        player: value stored as ``player`` on every created node.

    Returns:
        A KuhnTree rooted at the lowest-id node that was never attached to a
        parent. The root's ``information_set`` is set to ``len(infosets)``,
        an id no real information set uses.

    Raises:
        ValueError: if the treeplex has no information sets.
    """
    infosets = treeplex.infosets
    if len(infosets) == 0:
        raise ValueError("treeplex has no information sets")

    # Size the node list by the largest id referenced as a parent OR as a
    # child: sizing by parent ids alone fails with an IndexError whenever a
    # child sequence id is larger than every parent sequence id.
    max_sequence_id = max(
        max(iset.parentSequenceId, iset.endSequenceId) for iset in infosets
    )
    nodes = [KuhnNode(seq_id, player) for seq_id in range(max_sequence_id + 1)]

    for iset_id, iset in enumerate(infosets):
        parent = nodes[iset.parentSequenceId]
        # endSequenceId is inclusive.
        for child in range(iset.startSequenceId, iset.endSequenceId + 1):
            nodes[child].information_set = iset_id
            parent.addChild(nodes[child])

    # The root is the first node nobody adopted.
    root = [node for node in nodes if node.parent is None][0]
    root.information_set = len(infosets)

    return KuhnTree(root)

In [None]:
def build_payoff_dict(payoffMatrix):
    """Index the payoff entries by their triple of sequence ids.

    Args:
        payoffMatrix: object exposing ``entries``; each entry provides
            ``sequencePl1..3``, ``payoffPl1..3`` and ``chanceFactor``.

    Returns:
        Dict mapping ``(sequencePl1, sequencePl2, sequencePl3)`` to
        ``((payoffPl1, payoffPl2, payoffPl3), chanceFactor)``. When two entries
        share a key the later one wins.
    """
    return {
        (entry.sequencePl1, entry.sequencePl2, entry.sequencePl3): (
            (entry.payoffPl1, entry.payoffPl2, entry.payoffPl3),
            entry.chanceFactor,
        )
        for entry in payoffMatrix.entries
    }

In [None]:
def sampleTreeplexCFR(node, player, pi, action_plan):
    """Recursively traverse the subtree under ``node`` for one sampled CFR+ pass.

    At every non-leaf node the action named by ``action_plan`` is followed with
    ``pi`` unchanged, while every other action is explored with the acting
    player's entry of ``pi`` temporarily set to 0. When all entries of ``pi``
    are 0 only the sampled action is followed. At information sets owned by
    ``player`` the cumulative regrets (floored at 0) and the cumulative
    strategy are updated.

    Args:
        node: current node; must expose ``visits``, ``isLeaf()``, ``children``
            and ``information_set`` (an object with ``id``, ``player``,
            ``cumulative_regret``, ``cumulative_strategy`` and
            ``current_strategy``).
        player: player whose regrets are updated; also used to index ``pi``.
        pi: mutable list with one weight per player (presumably reach
            probabilities -- confirm against cfr.ipynb). It is modified during
            the traversal and restored before this call exits, including when
            a recursive call raises.
        action_plan: mapping from information-set id to the index of the
            sampled action at that information set.

    Returns:
        The value of the sampled action's subtree. Leaves currently return the
        placeholder 42.
    """
    # Weight this visit by the product of all entries of pi.
    node.visits += reduce(lambda x, y: x * y, pi, 1)

    if node.isLeaf():
        return 42 # TODO: fetch real utility

    iset = node.information_set
    v_alt = [0 for a in node.children]

    sampled_action = action_plan[iset.id]

    # Every weight is zero: nothing below can be updated, so just follow the
    # sampled action.
    if max(pi) == 0:
        return sampleTreeplexCFR(node.children[sampled_action], player, pi, action_plan)

    for a in range(len(node.children)):
        if a == sampled_action:
            v_alt[a] = sampleTreeplexCFR(node.children[a], player, pi, action_plan)
        else:
            # Off-plan action: explore it with the acting player's weight
            # zeroed, and always put the weight back, even on an exception, so
            # the caller's list is never left corrupted.
            old_pi = pi[iset.player]
            pi[iset.player] = 0
            try:
                v_alt[a] = sampleTreeplexCFR(node.children[a], player, pi, action_plan)
            finally:
                pi[iset.player] = old_pi

    v = v_alt[sampled_action]

    if iset.player == player:
        # Product of the weights of every player except the one being updated.
        pi_other = 1
        for i in range(len(pi)):
            if i != player:
                pi_other *= pi[i]

        for a in range(len(node.children)):
            ##### CFR+ #####
            iset.cumulative_regret[a] = max(iset.cumulative_regret[a] + pi_other * (v_alt[a] - v), 0)

            ##### This is useless for NFCCE #####
            iset.cumulative_strategy[a] += pi[player] * iset.current_strategy[a]

    return v

In [None]:
# Build the sequence tree for player 1's treeplex and wrap it in a CFRTree
# (imported from cfr.ipynb).
cfr_tree = CFRTree(build_tree_from_treeplex(game.treeplexPl1, 1))

In [None]:
# Run one sampled pass from the root for player 1, with every entry of pi set
# to 1 and a freshly sampled action plan; the returned value is the cell output.
sampleTreeplexCFR(cfr_tree.root, 1, [1,1,1], cfr_tree.sampleActionPlan())

In [None]:
# Inspect a freshly sampled action plan.
cfr_tree.sampleActionPlan()

In [None]:
# Print one action sampled at the first information set of the CFR tree.
print(cfr_tree.information_sets[0].sampleAction())

In [None]:
# Inspect the children of the second node stored in information set 3.
cfr_tree.information_sets[3].nodes[1].children

In [None]:
# Dump every information set of player 1's treeplex.
list(game.treeplexPl1.infosets)

In [None]:
# Dump every information set of player 2's treeplex.
list(game.treeplexPl2.infosets)

In [None]:
# Dump every information set of player 3's treeplex.
list(game.treeplexPl3.infosets)

In [None]:
# Number of entries in the payoff matrix.
len(game.payoffMatrix.entries)

In [None]:
# Dump every payoff-matrix entry (can be long; the count is in the cell above).
list(game.payoffMatrix.entries)

In [None]:
# Payoff entries in which player 1's sequence id is 14.
list(filter(lambda el: el.sequencePl1 == 14, game.payoffMatrix.entries))