## Tree approach

Policies are enumerated with a tree built from the action space (e.g. $\{\texttt{c}, \texttt{d}\}$)
and the policy length, which is the depth of the tree. The tree is constructed recursively, and each
root-to-leaf path is one policy (a sequence of actions).

In [3]:
import torch
import torch.nn.functional as F

import sys
sys.path.append('../')

from games import *

## SCAFFOLDING: toy generative model used by the cells below ------------
num_actions = 2  # actions per agent (also the size of each observation/state space)
num_agents = 3   # one factor per agent

# Observation model A[f, o, s]: all-ones parameters normalised over the
# observation axis (dim 1), so every column is a uniform distribution.
A_params = torch.ones(num_agents, num_actions, num_actions)
A = A_params.div(A_params.sum(dim=1, keepdim=True))

# Transition model B[f, u, s', s]: the identity matrix for every agent and
# every action (actions leave the hidden state unchanged), normalised over
# the next-state axis (dim 2).
B_params = torch.eye(num_actions).repeat(num_agents, num_actions, 1, 1)
B = B_params.div(B_params.sum(dim=2, keepdim=True))

# Log preferences over joint observations (presumably provided by the
# `from games import *` above).
# NOTE(review): the name suggests a 2-player game while num_agents = 3;
# compute_efe contracts log_C over num_agents dimensions -- confirm its shape.
log_C = prisoners_dilemma_2player

In [7]:
def compute_efe(u, q_s_u, A, log_C):
    '''
    Compute the Expected Free Energy (EFE) of a given action.

    The EFE is the sum of two terms:
      * ambiguity -- the entropy of each column of the observation model,
        weighted by the predicted state beliefs and summed over factors;
      * risk -- sum_o q(o|u) [log q(o|u) - log_C(o)] over the *joint*
        observation space of all factors (a KL divergence when log_C is a
        normalised log-distribution).

    Factor 0 is treated as the ego agent: its predicted observation is
    replaced by a one-hot vector on its own action ``u``.

    Args:
        u (torch.Tensor): ego action, an integer tensor holding a single
            element (e.g. ``torch.tensor([a])`` or ``torch.tensor(a)``)
        q_s_u (torch.Tensor): predicted state posterior per factor given the
            action, shape (f, s)
        A (torch.Tensor): observation likelihood model, shape (f, o, s);
            A[f, :, s] is a distribution over observations
        log_C (torch.Tensor): log preference over joint observations, with
            one axis of size o per factor, i.e. shape (o,) * f

    Returns:
        EFE (torch.Tensor): scalar Expected Free Energy for action ``u``

    Notes:
        Uses the constant ``EPSILON`` defined outside this function
        (presumably via ``from games import *``) to avoid log(0).
    '''
    # Sizes come from the model itself rather than notebook globals
    num_factors, num_obs = A.shape[0], A.shape[1]

    # Predictive observation posterior per factor ---------------------------
    # q(o_f | u) = E_{q(s_f | u)}[p(o_f | s_f)]
    q_o_u = torch.einsum(
        'fos,fs->fo',
        A,         # (f, o, s)
        q_s_u      # (f, s)
    )              # (f, o)

    # If ego was to take action u, its own observation would be guaranteed
    # to be o_ego = u, so replace q(o_ego | u) with one_hot(u). The reshape
    # accepts u both as a scalar and as a one-element tensor.
    q_o_u[0] = F.one_hot(u, num_obs).to(q_o_u.dtype).reshape(num_obs)

    # Ambiguity (per factor, then summed) -----------------------------------
    # Entropy of each column of A[f]: H[f, s] = -sum_o A[f,o,s] log A[f,o,s].
    # (-diag(A[f] @ log A[f]) is NOT equivalent unless A[f] is symmetric.)
    H = -(A * torch.log(A + EPSILON)).sum(dim=1)  # (f, s)
    ambiguity = (H * q_s_u).sum()

    # Joint predictive observation posterior --------------------------------
    # q(o_1, ..., o_F | u) as the outer product of the per-factor marginals,
    # e.g. 'a,b,c->abc' for three factors.
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[:num_factors]
    q_o_joint_u = torch.einsum(f"{','.join(letters)}->{letters}", *q_o_u)

    # Risk (joint) ----------------------------------------------------------
    # i.e. KL[q(o|u) || p*(o)], contracted over all num_factors axes
    risk = torch.tensordot(
        torch.log(q_o_joint_u + EPSILON) - log_C,
        q_o_joint_u,
        dims=num_factors
    )

    return ambiguity + risk

In [8]:
import torch

class TreeNode:
    '''Node of a policy tree; the path from the root to a node is an action sequence.

    Attributes:
        u: action taken at this node (a one-element tensor when created by
           build_policy_tree), or None for the root.
        EFE_u: EFE of taking ``u`` at this node; filled in by collect_policies.
        q_s_u: predicted state beliefs after ``u``; filled in by collect_policies.
        children: list of child nodes.
        depth: distance from the root (the root has depth 0).
    '''

    def __init__(self, action=None, depth=0):
        self.u = action  # Tensor or action at this node
        # Float placeholder so its dtype matches the EFEs computed later
        self.EFE_u = torch.tensor(0.0)
        # Declared here so the attribute always exists; set by collect_policies
        self.q_s_u = None
        self.children = []  # List to hold child nodes
        self.depth = depth  # Depth of the node

    def add_child(self, action):
        '''Create a child one level deeper holding ``action``, attach it, and return it.'''
        child = TreeNode(action, self.depth + 1)
        self.children.append(child)
        return child

    def __repr__(self):
        return f"TreeNode(action={self.u}, depth={self.depth}, children={len(self.children)})"
    

def build_policy_tree(action_space, max_depth, node=None):
    '''Recursively expand ``node`` into a full policy tree of depth ``max_depth``.

    Every node above ``max_depth`` gets one child per action, so each
    root-to-leaf path is one policy (action sequence) of length ``max_depth``.

    Args:
        action_space: iterable of actions (e.g. ``torch.arange(num_actions)``).
            It is materialised once, so one-shot iterables such as generators
            also work.
        max_depth: policy length, i.e. the depth of the leaves.
        node: node to expand; a fresh root ``TreeNode()`` is created if None.

    Returns:
        The expanded ``node``.
    '''
    if node is None:
        node = TreeNode()

    # Base case. ``>=`` rather than ``==`` so a negative max_depth (or a node
    # already deeper than max_depth) terminates instead of recursing forever.
    if node.depth >= max_depth:
        return node

    # Materialise once: every node re-iterates the same actions
    actions = list(action_space)

    # Add a child for each action (recursively until the max depth is reached).
    # Actions are wrapped as one-element tensors so policies can later be
    # concatenated with torch.cat.
    for action in actions:
        child = node.add_child(torch.tensor([action]))
        build_policy_tree(actions, max_depth, node=child)

    return node

def collect_policies(
        node, 
        q_s,
        policy_EFEs=None,
        current_policy=None,
        ):
    '''Traverse a policy tree depth-first, scoring each action and collecting policies.

    For every non-root node, the predicted state beliefs are obtained by
    propagating the parent's beliefs through the module-level transition model
    ``B`` and stored on ``node.q_s_u``. The EFE of the node's action is computed
    with ``compute_efe`` (using the module-level ``A`` and ``log_C``) and stored
    on ``node.EFE_u``.

    Args:
        node: tree node to start from (usually the root, whose ``u`` is None).
        q_s (torch.Tensor): state beliefs before ``node``'s action, shape (f, k).
        policy_EFEs (list[torch.Tensor] | None): per-step EFEs accumulated
            along the path so far.
        current_policy (list[torch.Tensor] | None): actions along the path so far.

    Returns:
        (EFEs, policies): for an internal node, tensors with one row per leaf
        below it. ``EFEs`` holds the per-step EFEs and ``policies`` the action
        sequence. For a leaf node, one-element lists holding those 1-D tensors.
    '''
    # Root node case
    if current_policy is None:
        current_policy = []
    if policy_EFEs is None:
        policy_EFEs = []
    if node.u is None:
        node.q_s_u = q_s
        new_policy_EFEs = policy_EFEs
    # Other nodes
    else:
        # q(s'|u)[f, n] = sum_k B[f, u, n, k] q(s)[f, k]. Select the action
        # first, squeezing ONLY the action axis so that a single factor (or a
        # single state) keeps its axis.
        B_u = B[:, node.u.reshape(-1)].squeeze(1)    # (f, n, k)
        node.q_s_u = torch.einsum('fnk,fk->fn', B_u, q_s)  # (f, n)

        node.EFE_u = compute_efe(node.u, node.q_s_u, A, log_C).unsqueeze(0)
        new_policy_EFEs = policy_EFEs + [node.EFE_u]  # EFEs collected top-down

    # Base case (leaf node)
    if not node.children:
        return [torch.cat(new_policy_EFEs)], [torch.cat(current_policy)]

    # Recursive case
    EFEs = []
    policies = []
    for child in node.children:
        new_policy = current_policy + [child.u]
        child_EFEs, child_policies = collect_policies(
            child, node.q_s_u, new_policy_EFEs, new_policy
        )
        EFEs.extend(child_EFEs)
        policies.extend(child_policies)

    return torch.vstack(EFEs), torch.vstack(policies)

## Usage example

Below we show how the algorithm works for different policy lengths. Because `A` is uniform and `B` is the identity for every action, all policies of a given length get the same EFE, so the output is not very exciting.

In [20]:
# Dummy initial beliefs q(s): one row per factor (agent), one column per state
q_s = torch.tensor([[0.3, 0.7],
                    [0.8, 0.2],
                    [0.3, 0.7]])

# Showcase the algorithm for policies of increasing length
for policy_length in (1, 2, 3):
    print('\nPolicy length:', policy_length)

    # Enumerate every action sequence of this length as a tree, then score it
    root = build_policy_tree(torch.arange(num_actions), policy_length)
    EFEs, policies = collect_policies(root, q_s=q_s)

    # Each row of EFEs holds per-step EFEs; a policy's total EFE is their sum
    for p, EFE in zip(policies, EFEs.sum(dim=1)):
        print(f'EFE{p.tolist()} = {EFE.item():.3f}')


Policy length: 1
EFE[0] = -1.807
EFE[1] = -1.807

Policy length: 2
EFE[0, 0] = -3.614
EFE[0, 1] = -3.614
EFE[1, 0] = -3.614
EFE[1, 1] = -3.614

Policy length: 3
EFE[0, 0, 0] = -5.421
EFE[0, 0, 1] = -5.421
EFE[0, 1, 0] = -5.421
EFE[0, 1, 1] = -5.421
EFE[1, 0, 0] = -5.421
EFE[1, 0, 1] = -5.421
EFE[1, 1, 0] = -5.421
EFE[1, 1, 1] = -5.421
