In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
N_DICE = (1, 1) # tuple (dice held by player 0, dice held by player 1); indexed by player id

class Die():
    """A single six-sided die whose faces are labelled by strings.

    ``"L"`` (the llama) is listed first, so it has index 0 in ``faces``.
    """
    faces = ["L", "2", "3", "4", "5", "6"] # Llamas first

    def __init__(self):
        # Roll once so that ``self.result`` always holds a face.
        self.roll()

    def roll(self, n=1):
        """Roll the die ``n`` times.

        Args:
            n: number of rolls (non-negative).

        Returns:
            A NumPy array of ``n`` face labels drawn uniformly from
            ``Die.faces``. ``self.result`` is set to the last face rolled;
            when ``n`` is 0 an empty array is returned and ``self.result``
            is left unchanged (previously this raised IndexError).
        """
        result = np.random.choice(Die.faces, n)
        if len(result):
            self.result = result[-1]
        return result
        

class Player():
    """A player holding ``N_DICE[player_id]`` privately rolled dice.

    ``action_space`` (class attribute) lists every action, ordered so that each
    bid outranks the ones before it, and ends with "D" (doubt). A bid is a
    string of one repeated face: its first character is the face and its
    length the claimed quantity.
    """
    # Candidate bids: for each quantity 1..total dice, that many copies of every
    # face (llamas included), then llama-only strings up to twice the total,
    # then the doubt action. np.concatenate (not np.concat) keeps this working
    # on NumPy versions older than 2.0.
    action_space = np.concatenate(
        [[face * i for face in Die.faces] for i in range(1, np.sum(N_DICE) + 1)] +
        [["L" * i for i in range(np.sum(N_DICE) + 1, 2 * np.sum(N_DICE) + 1)]] +
        # [["D", "C"]] # doubt or call
        [["D"]] # just doubt
    )
    # Llama bids: odd-length llama strings are dropped and even-length ones are
    # halved, so k llamas take the slot that 2k llamas had, i.e. after every
    # bid of 2k - 1 dice and before 2k twos (or after all face bids once 2k
    # exceeds the total number of dice). Done in a comprehension so no loop
    # variables leak into the class namespace.
    action_space = np.array([
        el[:len(el) // 2] if el[0] == "L" else el
        for el in action_space
        if el[0] != "L" or len(el) % 2 == 0
    ])
    print(action_space)

    def __init__(self, player_id):
        """Create player ``player_id`` and roll its ``N_DICE[player_id]`` dice."""
        self.player_id = player_id
        self.die = Die()
        self.private = self.die.roll(N_DICE[player_id])


class Node():
    """A node of the two-player bidding game tree.

    ``self.player`` is the player who acts at this node; it alternates with
    depth, starting with player 0 at the root. ``children`` maps each legal
    action string to the node reached by taking it. A node reached by "D"
    (doubt) or "C" (call) is a leaf: it has no logits and stores a
    ``winner`` matrix instead.

    Constructing ``Node(None)`` builds the whole tree below a new root and
    resets the class-level bookkeeping, so each root owns its own counters,
    leaves, deep nodes and per-player logits.
    """
    MAX_DEPTH = 3
    n_nodes = 0  # number of nodes created for the current tree
    leaves = []  # leaf nodes of the current tree
    deep_nodes = [] # these are not leaves; their children are borrowed from a shallower node
    player_logits = [[], []]  # logits tensor of every decision node, grouped by player

    @staticmethod
    def roll2index(roll):
        """Encode a roll (sequence of face labels) as an integer.

        The roll is read as a base-``len(Die.faces)`` number whose digits are
        face indices, with the first die as the least-significant digit.
        """
        index = 0
        for i, face in enumerate(roll):
            d = Die.faces.index(face)
            index += (len(Die.faces) ** i) * d
        return index

    @staticmethod
    def index2roll(index, n_dice):
        """Inverse of ``roll2index``: decode ``index`` into a list of ``n_dice`` face labels."""
        roll = []
        for i in range(n_dice):
            roll.append(Die.faces[index % len(Die.faces)])
            index //= len(Die.faces)
        return roll

    def __init__(self, parent, last_action=None):
        """Create a node and, recursively, the subtree below it.

        Args:
            parent: the parent Node, or None to create a new root. A new root
                resets the class-level bookkeeping and, once the tree is built,
                links the deep nodes to shallower subtrees.
            last_action: the action that led from ``parent`` to this node;
                None for the root.
        """
        self.parent = parent
        if parent is None:
            self.depth = 0
            self.player = 0 # start with player zero
            self.probability = torch.ones([len(Die.faces) ** n for n in N_DICE]) # certain to hit this node
            # A new root starts a new tree: reset every piece of class-level
            # bookkeeping so nodes of a previously built tree are not mixed in
            # (stale leaves would be summed into the payoff and stale deep
            # nodes would be re-linked into this tree).
            Node.player_logits = [[], []]
            Node.leaves = []
            Node.deep_nodes = []
            Node.n_nodes = 0
        else:
            self.depth = self.parent.depth + 1
            self.player = (self.parent.player + 1) % 2
            # Rows index this node's player's private roll, columns the opponent's.
            self.probability = torch.zeros([len(Die.faces) ** n for n in N_DICE])
            if self.player:
                self.probability = self.probability.t()
        Node.n_nodes += 1
        self.last_action = last_action
        self.is_leaf = last_action == "D" or last_action == "C"
        self.children = {} # empty dictionary. Keys are actions and values are Nodes.
        if self.is_leaf:
            self.logits = None
            Node.leaves.append(self)
            self.compute_winner_matrix()
        else:
            self.winner = None
            # Legal actions are the ones strictly after the previous bid in
            # Player.action_space.
            if last_action is None:
                starting_index = 0
            else:
                starting_index = np.where(Player.action_space == last_action)[0][0] + 1
            possible_actions = Player.action_space[starting_index:]
            if last_action is None:
                # The opening bid may not be a llama bid, a doubt or a call.
                possible_actions = np.delete(possible_actions,
                                             np.where([x[0] in ("L", "D", "C") for x in possible_actions])[0])
            n_actions = len(possible_actions)
            n_private = len(Die.faces) ** N_DICE[self.player]

            # One row of action logits per private roll of the acting player.
            self.logits = torch.ones(n_private, n_actions, requires_grad=True)
            Node.player_logits[self.player].append(self.logits)

            if self.depth < Node.MAX_DEPTH:
                for action in possible_actions:
                    self.children[action] = Node(parent=self, last_action=action)
            else:
                # At or beyond MAX_DEPTH, try to reuse an existing subtree: the
                # last MAX_DEPTH - 2 actions get replayed from the root (see the
                # linking step below). That only works if the first replayed
                # action is a legal opening bid (not a llama bid) and the replay
                # ends on a node of the same player, i.e.
                # self.player == MAX_DEPTH % 2. Otherwise keep expanding.
                previous_node = self
                for i in range(Node.MAX_DEPTH - 3):
                    previous_node = previous_node.parent
                would_start_with = previous_node.last_action
                if (would_start_with[0] == "L") or ((Node.MAX_DEPTH - self.player) % 2):
                    for action in possible_actions:
                        self.children[action] = Node(parent=self, last_action=action)
                else:
                    Node.deep_nodes.append(self)

        if parent is None: # start connecting the deep nodes
            print("Connecting deepest nodes")
            for node in Node.deep_nodes:
                previous_actions = []
                previous_node = node
                for i in range(Node.MAX_DEPTH - 2): # minus 2 to match the players
                    previous_actions.append(previous_node.last_action)
                    previous_node = previous_node.parent
                # Replay those actions from the root to find the equivalent
                # shallower node (at depth MAX_DEPTH - 2) and share its children.
                next_node = self
                for i in range(Node.MAX_DEPTH - 2):
                    next_node = next_node.children[previous_actions.pop()]
                node.children = next_node.children

    def propagate_probability(self, probability):
        """Push reach probabilities from this node down to all nodes below it.

        ``probability`` is a matrix with rows indexed by the private roll of
        ``self.player`` and columns by the opponent's private roll; each entry
        is the conditional probability of reaching this node given those rolls.
        Each child's ``probability`` is accumulated (``+=``), so call
        ``reset_probability`` first.
        """
        softmaxed = F.softmax(self.logits, dim=-1)
        for i, child in enumerate(self.children.values()): # correct order verified.
            # Column i: probability of choosing action i for each private roll.
            give_prob = softmaxed[:, i].unsqueeze(1)
            # Transpose so the child's rows index the child's player.
            new_probability = (give_prob * probability).t()
            child.probability += new_probability
            if not child.is_leaf:
                child.propagate_probability(new_probability)

    def reset_probability(self):
        """Zero the reach probabilities of every non-root node in this subtree.

        Children borrowed by deep nodes are shallower than their borrower and
        are skipped here; they are reset through their own parent.
        """
        if self.parent is not None:
            self.probability = torch.zeros([len(Die.faces) ** n for n in N_DICE])
            if self.player:
                self.probability = self.probability.t()
        if not self.is_leaf:
            for child in self.children.values():
                if child.depth > self.depth:
                    child.reset_probability()

    def who_wins(self, my_roll, opponent_roll):
        """Return the index of the winning player at this leaf for the given rolls.

        The parent's last action is the claim made by ``self.player``; this
        leaf's action is the opponent's response. Llamas count as every face
        except when the claim itself is about llamas.

        Raises:
            ValueError: if the response is neither "C" nor "D".
        """
        assert self.is_leaf
        all_dice = my_roll + opponent_roll
        claim = self.parent.last_action # This is my claim
        response = self.last_action # call C or doubt D. My opponent said this.
        face = claim[0]
        quantity = len(claim)
        true_count = all_dice.count(face)
        if face != "L":
            true_count += all_dice.count("L")
        if response == "C":
            # A call succeeds only if the claim is exact.
            if quantity == true_count:
                return 1 - self.player
            return self.player
        if response == "D":
            # A doubt fails if there are at least as many dice as claimed.
            if true_count >= quantity:
                return self.player
            return 1 - self.player
        raise ValueError(f"unexpected leaf response {response!r}")

    def compute_winner_matrix(self):
        """Fill ``self.winner`` with the winning player's index for every pair of rolls.

        Rows index this leaf's player's roll and columns the opponent's, the
        same layout as ``self.probability``.
        """
        assert self.is_leaf
        self.winner = torch.zeros([len(Die.faces) ** n for n in N_DICE]) # zeros or ones does not matter
        if self.player == 1:
            self.winner = self.winner.t()
        shape = self.winner.shape
        for i in range(shape[0]): # my roll index
            for j in range(shape[1]): # opponent roll index
                my_roll = Node.index2roll(i, N_DICE[self.player])
                opponent_roll = Node.index2roll(j, N_DICE[1 - self.player])
                winner = self.who_wins(my_roll, opponent_roll)
                self.winner[i, j] = winner


# One Player per entry of N_DICE (ids 0 and 1), each rolling its own dice.
PLAYERS = list(map(Player, range(len(N_DICE))))
# Build the whole game tree; building the root also links the deep nodes.
grandfather_node = Node(None)

['2' '3' '4' '5' '6' 'L' '22' '33' '44' '55' '66' 'LL' 'D']
Connecting deepest nodes


In [None]:
def get_prob_1_wins():
    """Return the probability that player 1 wins under the current strategies.

    Re-propagates reach probabilities from the root, then weights each leaf's
    reach-probability matrix by its winner matrix. Winner entries are the index
    of the winning player, so only player-1 wins contribute. The total is
    averaged over all equally likely pairs of private rolls.
    """
    root = grandfather_node
    root.reset_probability()
    root.propagate_probability(root.probability)
    weighted = []
    for leaf in Node.leaves:
        contribution = leaf.probability * leaf.winner
        # Player-1 leaves have player 1's roll on the rows; transpose so every
        # matrix is indexed (player-0 roll, player-1 roll) before stacking.
        if leaf.player == 1:
            contribution = contribution.t()
        weighted.append(contribution)
    n_joint_rolls = len(Die.faces) ** np.sum(N_DICE)
    return torch.sum(torch.stack(weighted, dim=0)) / n_joint_rolls


def calculate_equilibrium(lr=0.01, n_iters=2000, log_every=10):
    """Approximate an equilibrium by alternating gradient steps for both players.

    Player 0 takes an Adam step that decreases P(player 1 wins); player 1 then
    takes an Adam step that increases it. Each optimizer holds only its own
    player's logits, so only that player's strategy moves on its step.

    Args:
        lr: Adam learning rate for both players.
        n_iters: number of alternating update rounds.
        log_every: print the current estimate every ``log_every`` rounds;
            0 or a negative value disables logging.
    """
    opt_p0 = torch.optim.Adam(Node.player_logits[0], lr=lr)
    opt_p1 = torch.optim.Adam(Node.player_logits[1], lr=lr)

    for it in range(n_iters):
        # Player 0 minimizes player 1's winning probability.
        opt_p0.zero_grad()
        prob_1_wins = get_prob_1_wins()
        prob_1_wins.backward()
        opt_p0.step()

        # Player 1 maximizes it (ascent via the negated objective). zero_grad
        # also discards the gradient that the backward above left on player 1's
        # logits.
        opt_p1.zero_grad()
        prob_1_wins = get_prob_1_wins()
        (-prob_1_wins).backward()
        opt_p1.step()

        if log_every > 0 and it % log_every == 0:
            print(f"iter {it:4d}   Prob(P1 wins) ≈ {prob_1_wins.item(): .4f}")

# Run the alternating-gradient equilibrium search with its default settings.
calculate_equilibrium()

iter    0   Prob(P1 wins) ≈  0.4877
iter   10   Prob(P1 wins) ≈  0.4899
iter   20   Prob(P1 wins) ≈  0.4915
iter   30   Prob(P1 wins) ≈  0.4925
iter   40   Prob(P1 wins) ≈  0.4928
iter   50   Prob(P1 wins) ≈  0.4921
iter   60   Prob(P1 wins) ≈  0.4901
iter   70   Prob(P1 wins) ≈  0.4872
iter   80   Prob(P1 wins) ≈  0.4846
iter   90   Prob(P1 wins) ≈  0.4839
iter  100   Prob(P1 wins) ≈  0.4859
iter  110   Prob(P1 wins) ≈  0.4896
iter  120   Prob(P1 wins) ≈  0.4933


KeyboardInterrupt: 

In [3]:
# Root strategy: one row per player-0 private roll, one column per opening bid.
print(F.softmax(grandfather_node.logits, dim=-1))
# Node counts: total, leaves, and decision (non-leaf) nodes.
print(Node.n_nodes)
print(len(Node.leaves))
print(Node.n_nodes - len(Node.leaves))
# Number of decision nodes (logits tensors) owned by each player.
print([len(x) for x in Node.player_logits])
# Winner matrix of one leaf: entries are the index of the winning player,
# rows index the leaf's own player's roll.
print(grandfather_node.children["2"].children["33"].children["D"].winner)
# Sanity check: summed over all leaves, the reach probabilities are expected
# to be 1 for every pair of private rolls, since every line of play ends in a leaf.
grandfather_node.reset_probability()
grandfather_node.propagate_probability(grandfather_node.probability)
stacked = torch.stack([leaf.probability.t() if leaf.player == 1
                    else leaf.probability for leaf in Node.leaves],
                    dim=0)
torch.sum(stacked, dim=0)

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000]], grad_fn=<SoftmaxBackward0>)
726
240
486
[121, 365]
tensor([[1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<SumBackward1>)

In [None]:
# Round-trip check: encoding a roll and decoding it should give the roll back.
print(Node.index2roll(Node.roll2index(["L", "5", "4"]), 3))
print(Node.MAX_DEPTH)
# Reach "66" through increasingly long bid sequences. Deep nodes borrow the
# children of a shallower node, so with MAX_DEPTH = 3 the last line is
# expected to print the same object as the second one.
print(grandfather_node.children["66"])
print(grandfather_node.children["55"].children["66"])
print(grandfather_node.children["44"].children["55"].children["66"])
print(grandfather_node.children["33"].children["44"].children["55"].children["66"])

In [None]:
# Scratch experiments.
# expand() broadcasts the size-1 column of `a` to shape (4, 5).
a = torch.randn(4, 1)
b = torch.randn(4, 5)  # unused below
a = a.expand((4,5))
print(a)
Die.faces.index("3")  # value is not displayed: not the cell's last expression
# NOTE: tuples are immutable, so the item assignment below raises TypeError
# and the final `a` is never displayed.
a = (1,2)
a[1] = 11
a



### Insights from von Neumann's work and ChatGPT

- At each step, alternating between the two players, optimize the active player's strategy (its probability distribution over actions) to maximize that player's overall expected probability of winning, holding the opponent's strategy fixed.
- This can be done by computing the expected winning probability over the *entire tree* and then maximizing the total probability of the active player's winning leaves by changing only the parameters of that player's own nodes.

Something like this:
```python
# Optimizers for max (P1) and min (P2)
opt_p1 = torch.optim.Adam([p1_logits], lr=1e-1)
opt_p2 = torch.optim.Adam([p2_logits], lr=1e-1)

for it in range(5000):
    # — P1 update (ascent) —
    opt_p1.zero_grad()
    ev = expected_payoff()
    (-ev).backward(retain_graph=True)   # gradient of –E wrt p1_logits
    opt_p1.step()

    # — P2 update (descent) —
    opt_p2.zero_grad()
    ev = expected_payoff()
    (ev).backward()                     # gradient of  E wrt p2_logits
    opt_p2.step()

    if it % 500 == 0:
        print(f"iter {it:4d}   EV ≈ {ev.item(): .4f}")
```