# In [None]:
import numpy as np
from random import shuffle
import time
import sys


class PokerB:
    """Counterfactual-regret-minimisation trainer for a one-card betting game.

    Player 1 holds ``deck[0]`` and player 2 holds ``deck[1]``; the higher
    card wins at showdown and equal cards split (reward 0).  The two actions
    are pass ('p') and bet ('b'); a history is terminal once it ends in
    'pp', 'bb' or 'bp'.
    """

    # Number of distinct 5-card poker hands in each category, from weakest
    # (high card) to strongest (royal flush).  They sum to C(52, 5) = 2598960.
    HAND_CATEGORY_COUNTS = (1302540, 1098240, 123552, 54912, 10200,
                            5108, 3744, 624, 36, 4)
    # Exclusive upper bound of each category's slice of [0, 2598960).
    _CUMULATIVE_COUNTS = np.cumsum(HAND_CATEGORY_COUNTS)

    def __init__(self):
        self.nodeMap = {}  # information-set key "<card> <history>" -> Node
        self.expected_game_value = 0
        self.n_cards = 10
        self.nash_equilibrium = dict()
        self.current_player = 0
        self.deck = np.arange(self.n_cards)  # cards 0 .. n_cards - 1
        self.n_actions = 2

    @classmethod
    def _hand_rank(cls, r):
        """Map a draw ``r`` in ``[0, 2598960)`` to a hand-category rank 0..9.

        ``side='right'`` sends ``r`` in ``[cum[k-1], cum[k])`` to ``k``, so
        every draw (including exact boundary values) lands in exactly one
        category, with probability proportional to its hand count.
        """
        return int(np.searchsorted(cls._CUMULATIVE_COUNTS, r, side='right'))

    def randomHand(self):
        """Overwrite ``deck[1]`` with a rank drawn by poker-hand frequency.

        Rank ``k`` is chosen with probability
        ``HAND_CATEGORY_COUNTS[k] / 2598960``.
        """
        n_hands = int(self._CUMULATIVE_COUNTS[-1])
        # NOTE(review): only index 1 is redrawn, exactly as the original
        # ``range(1, 2)`` loop did; confirm whether deck[0] should be too.
        for i in range(1, 2):
            self.deck[i] = self._hand_rank(np.random.randint(n_hands))

    def train(self, n_iterations=50000):
        """Run ``n_iterations`` of CFR over random deals, then print results.

        The average root utility for player 1 is stored in
        ``self.expected_game_value`` and passed to ``display_results``.
        """
        total = 0
        for _ in range(n_iterations):
            shuffle(self.deck)
            total += self.cfr('', 1, 1)
            # Fold this iteration's reach-weighted strategy into the averages
            # and switch every node to its new regret-matching strategy.
            for node in self.nodeMap.values():
                node.update_strategy()

        self.expected_game_value = total / n_iterations
        display_results(self.expected_game_value, self.nodeMap)

    def cfr(self, history, pr_1, pr_2):
        """Return the utility of ``history`` for the player to act next.

        pr_1 / pr_2: probability that player 1's / player 2's own strategy
        leads to ``history``.  As a side effect, the regret and reach
        accumulators of every visited node are updated.
        """
        is_player_1 = len(history) % 2 == 0
        if is_player_1:
            player_card, opponent_card = self.deck[0], self.deck[1]
        else:
            player_card, opponent_card = self.deck[1], self.deck[0]

        if self.is_terminal(history):
            return self.get_reward(history, player_card, opponent_card)

        node = self.get_node(player_card, history)
        strategy = node.strategy

        # Counterfactual utility per action.  The child's value is from the
        # opponent's point of view, hence the negation (zero-sum game).
        action_utils = np.zeros(self.n_actions)
        for act in range(self.n_actions):
            next_history = history + node.action_dict[act]
            if is_player_1:
                action_utils[act] = -self.cfr(
                    next_history, pr_1 * strategy[act], pr_2)
            else:
                action_utils[act] = -self.cfr(
                    next_history, pr_1, pr_2 * strategy[act])

        # Utility of the information set under the current strategy.
        util = np.dot(action_utils, strategy)
        regrets = action_utils - util
        # Regrets are weighted by the opponent's reach probability; the
        # acting player's own reach feeds the average-strategy weights.
        if is_player_1:
            node.reach_pr += pr_1
            node.regret_sum += pr_2 * regrets
        else:
            node.reach_pr += pr_2
            node.regret_sum += pr_1 * regrets

        return util

    @staticmethod
    def is_terminal(history):
        """Return True when ``history`` ends the hand ('pp', 'bb' or 'bp')."""
        return history[-2:] in ('pp', 'bb', 'bp')

    @staticmethod
    def get_reward(history, player_card, opponent_card):
        """Return the payoff of a terminal ``history`` for ``player_card``'s holder.

        ``cfr`` passes the card of the player who did NOT make the last
        action.  After a pass that follows a bet, that player collects 1.
        Otherwise the hand is a showdown worth 1 ('pp') or 2 ('bb'), and
        equal cards score 0.  Returns None for a non-terminal history.
        """
        if history.endswith('pp'):
            stake = 1
        elif history.endswith('p'):
            # The opponent folded to a bet.
            return 1
        elif history.endswith('bb'):
            stake = 2
        else:
            return None

        if player_card > opponent_card:
            return stake
        if player_card == opponent_card:
            return 0
        return -stake

    def get_node(self, card, history):
        """Return the Node for (card, history), creating it on first visit."""
        key = str(card) + " " + history
        node = self.nodeMap.get(key)
        if node is None:
            node = Node(key, {0: 'p', 1: 'b'})
            self.nodeMap[key] = node
        return node


class Node:
    """One information set: accumulated regrets and strategy statistics.

    key: information-set label, used only for printing.
    action_dict: maps action index -> history character.
    """

    def __init__(self, key, action_dict, n_actions=2):
        self.key = key
        self.n_actions = n_actions
        self.regret_sum = np.zeros(self.n_actions)    # cumulative regret
        self.strategy_sum = np.zeros(self.n_actions)  # reach-weighted sum
        self.action_dict = action_dict
        self.strategy = self._uniform()
        self.reach_pr = 0      # own reach accumulated this iteration
        self.reach_pr_sum = 0  # own reach accumulated over all iterations

    def _uniform(self):
        """Return the uniform distribution over ``n_actions`` actions."""
        return np.repeat(1 / self.n_actions, self.n_actions)

    def update_strategy(self):
        """Fold this iteration's strategy into the average, then refresh it."""
        self.strategy_sum += self.reach_pr * self.strategy
        self.reach_pr_sum += self.reach_pr
        self.strategy = self.get_strategy()
        self.reach_pr = 0

    def get_strategy(self):
        """Return the regret-matching strategy (positive regrets, normalised).

        Works on a clamped copy: ``regret_sum`` itself keeps its negative
        entries, as vanilla CFR requires.  Falls back to uniform when no
        action has positive regret.
        """
        positive = np.maximum(self.regret_sum, 0)
        normalizing_sum = positive.sum()
        if normalizing_sum > 0:
            return positive / normalizing_sum
        return self._uniform()

    def get_average_strategy(self):
        """Return the reach-weighted average strategy over all iterations.

        Falls back to uniform when nothing has been accumulated yet, which
        would otherwise produce NaN from 0 / 0.
        """
        total = self.strategy_sum.sum()
        if total > 0:
            return self.strategy_sum / total
        return self._uniform()

    def __str__(self):
        strategies = ['{:03.2f}'.format(x)
                      for x in self.get_average_strategy()]
        return '{} {}'.format(self.key.ljust(6), strategies)


def display_results(ev, i_map):
    """Print both players' expected values and their average strategies.

    ev: expected game value for player 1 (player 2's is its negation).
    i_map: mapping from information-set key ``"<card> <history>"`` to a
        node object; each node is printed with ``print``.
    """
    print('player 1 expected value: {}'.format(ev))
    print('player 2 expected value: {}'.format(-1 * ev))

    def player_1_acts(key):
        # Player 1 acts after an even number of actions.  Take the parity of
        # the history part only, so cards with more than one digit (e.g.
        # "10 p") are not assigned to the wrong player.
        _, _, history = key.partition(' ')
        return len(history) % 2 == 0

    sorted_items = sorted(i_map.items(), key=lambda x: x[0])

    print()
    print('player 1 strategies:')
    for key, node in sorted_items:
        if player_1_acts(key):
            print(node)
    print()
    print('player 2 strategies:')
    for key, node in sorted_items:
        if not player_1_acts(key):
            print(node)


if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the wall clock is adjusted, corrupting an elapsed-time reading.
    start = time.perf_counter()
    trainer = PokerB()
    trainer.train(n_iterations=25000)
    print(time.perf_counter() - start)  # elapsed seconds
    # NOTE: sys.getsizeof is shallow. It reports the PokerB object itself,
    # not the Node objects held in trainer.nodeMap.
    print(sys.getsizeof(trainer))