In [37]:
import random
# Seed the module-level generator so the card shuffles drawn with
# random.randrange during training repeat from run to run.
random.seed(1)

In [38]:
# Action indices. The training code appends 'p' to the history for action
# index 0 and 'b' for any other index, which lines up with PASS and BET.
PASS = 0
BET = 1

# Length of every per-action list (regrets, strategy, utilities).
NUM_ACTIONS = 2


In [81]:
class Node():
    """Per-information-set bookkeeping for regret matching.

    Attributes:
        infoSet: Label of the information set; assigned by the caller after
            construction.
        regretSum: Accumulated regret for each action.
        strategy: Strategy produced by the most recent getStrategy() call.
        strategySum: Weighted sum of every strategy returned so far.
    """

    def __init__(self):
        self.infoSet = ''
        self.regretSum = [0] * NUM_ACTIONS
        self.strategy = [0] * NUM_ACTIONS
        self.strategySum = [0] * NUM_ACTIONS

    def getStrategy(self, rweight):
        """Recompute the current strategy from the positive regrets.

        Each action's probability is proportional to its accumulated regret
        when that regret is positive; if no action has positive regret the
        strategy is uniform. The result, scaled by ``rweight``, is also added
        to ``strategySum``.

        Args:
            rweight: Weight applied to this strategy when it is accumulated
                into ``strategySum``.

        Returns:
            ``self.strategy`` itself (updated in place, not a copy).
        """
        actions = range(NUM_ACTIONS)
        current = self.strategy

        # Keep only the positive part of each regret and total it up.
        total = 0
        for act in actions:
            regret = self.regretSum[act]
            current[act] = regret if regret > 0 else 0
            total += current[act]

        has_positive_regret = total > 0
        for act in actions:
            if has_positive_regret:
                current[act] /= total
            else:
                # No positive regret anywhere: fall back to uniform play.
                current[act] = 1 / NUM_ACTIONS
            self.strategySum[act] += rweight * current[act]
        return current

    def getAverageStrategy(self):
        """Return ``strategySum`` normalised to sum to one.

        Returns:
            A new list with one probability per action; uniform when the
            accumulated sum is not positive.
        """
        total = 0
        for act in range(NUM_ACTIONS):
            total += self.strategySum[act]

        if total > 0:
            return [self.strategySum[act] / total
                    for act in range(NUM_ACTIONS)]
        return [1 / NUM_ACTIONS for _ in range(NUM_ACTIONS)]

    def __str__(self):
        """Return the info-set label followed by the average strategy."""
        average = self.getAverageStrategy()
        return self.infoSet + str(average)
    
    
class KuhnTrainer():
    """Self-play trainer for a three-card, two-action (pass/bet) game.

    The recursive ``cfr`` routine accumulates per-information-set regrets in
    ``Node`` objects stored in ``nodeMap``, keyed by the acting player's card
    followed by the action history. This appears to be counterfactual regret
    minimisation for Kuhn poker.
    """

    def __init__(self):
        self.nodeMap = {}

    def train(self, iters):
        """Run ``iters`` training iterations and print the results.

        Each iteration shuffles the deck in place and runs ``cfr`` from the
        empty history. Afterwards the mean of the returned values and every
        stored node are printed. Nothing is returned.

        Args:
            iters: Number of iterations; also the divisor for the printed
                average.
        """
        cards = [1, 2, 3]
        total = 0
        for _ in range(iters):
            # In-place shuffle: walk from the last slot down to slot 1 and
            # swap each with a uniformly chosen slot at or below it.
            for top in range(len(cards) - 1, 0, -1):
                pick = random.randrange(0, top + 1)
                cards[top], cards[pick] = cards[pick], cards[top]
            total += self.cfr(cards, "", 1, 1)

        average = total / iters
        print("Average game value: " + str(average))
        for node in self.nodeMap.values():
            print(node)

    def cfr(self, cards, history, p0, p1):
        """Recursively evaluate ``history`` and update regrets along the way.

        Args:
            cards: Sequence indexed by player; ``cards[0]`` and ``cards[1]``
                are the two players' cards.
            history: Actions taken so far, one character each ('p' or 'b').
                Its length modulo 2 selects the player to act.
            p0: Weight for player 0, multiplied by player 0's action
                probabilities as the recursion descends.
            p1: The same for player 1.

        Returns:
            The payoff from the point of view of the player whose turn it is
            at ``history``: a fixed value at terminal histories, otherwise
            the strategy-weighted sum of the negated child values.
        """
        plays = len(history)
        player = plays % 2
        opponent = 1 - player

        # Terminal histories need at least two actions.
        if plays > 1:
            holds_higher = cards[player] > cards[opponent]
            if history[plays - 1] == 'p':
                if history == 'pp':
                    # Both passed: showdown for 1.
                    return 1 if holds_higher else -1
                # A pass that follows a bet: the side to move collects 1.
                return 1
            if history[plays - 2:plays] == 'bb':
                # Bet and call: showdown for 2.
                return 2 if holds_higher else -2

        infoSet = str(cards[player]) + history
        node = self.nodeMap.get(infoSet)
        if node is None:
            node = Node()
            node.infoSet = infoSet
            self.nodeMap[infoSet] = node

        is_first_player = player == 0
        strategy = node.getStrategy(p0 if is_first_player else p1)

        utility = [0] * NUM_ACTIONS
        nodeUtil = 0
        for a in range(NUM_ACTIONS):
            nextHist = history + ('p' if a == 0 else 'b')
            if is_first_player:
                child = self.cfr(cards, nextHist, p0 * strategy[a], p1)
            else:
                child = self.cfr(cards, nextHist, p0, p1 * strategy[a])
            # The child value is from the other player's point of view.
            utility[a] = -child
            nodeUtil += strategy[a] * utility[a]

        # Regrets are weighted by the other player's weight.
        other_weight = p1 if is_first_player else p0
        for a in range(NUM_ACTIONS):
            node.regretSum[a] += other_weight * (utility[a] - nodeUtil)
        return nodeUtil
            
    
    

In [82]:
# train() prints the average game value and every node; it has no return
# statement, so `a` is bound to None.
a = KuhnTrainer().train(100000)

Average game value: -0.04251207286041456
1[0.709276379186929, 0.29072362081307107]
2p[0.9997873388252149, 0.0002126611747851003]
1pb[0.9999893374501657, 1.0662549834323998e-05]
2b[0.6589651735325275, 0.34103482646747246]
2[0.999838581931507, 0.00016141806849291485]
1p[0.6614967609189046, 0.33850323908109525]
2pb[0.38085189467941355, 0.6191481053205864]
1b[0.9999851539535022, 1.4846046497817632e-05]
3[0.12475552623898861, 0.8752444737610114]
3pb[5.9366589998446434e-05, 0.9999406334100015]
3p[0.00012188804582990523, 0.9998781119541701]
3b[1.5236005728738154e-05, 0.9999847639942713]


In [83]:
# Displays None: `a` holds the result of train(), which returns nothing.
print(a)

None
