# Base

> Base class containing the core methods of CRLD agents

In [1]:
#| default_exp Agents/MultipleObservationsAgentHandler

In [2]:
#| hide
# Imports for the nbdev development environment
from nbdev.showdoc import *
from fastcore.test import *

In [3]:
#| hide
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import numpy as np
import itertools as it
from functools import partial

import jax
from jax import jit
import jax.numpy as jnp

from typing import Iterable
from fastcore.utils import *

from pyCRLD.Utils.Helpers import *

## The agent base class
contains core methods to compute the strategy-average reward-prediction error. 

In [5]:
#| export
class HeteroObsAgentBase(object):
    """
    Base class for deterministic strategy-average independent (multi-agent)
    temporal-difference reinforcement learning. This class provides foundational
    functionality for multi-agent reinforcement learning (MARL) systems, handling
    the computation of average transition models, rewards, and value functions
    based on the collective strategies of all agents in the environment.
    """
    
    def __init__(self, 
                 TransitionTensor: np.ndarray,  # Transition model of the environment.
                 RewardTensor: np.ndarray,      # Reward model of the environment.
                 DiscountFactors: Iterable[float],  # Discount factors for each agent.
                 ObservationTensors: List[np.ndarray],
                 DiscountFactors: Iterable[float],
                 use_prefactor=False,  # Whether to scale values by (1 - discount factor).
                 opteinsum=True):      # Whether to optimize einsum computations.
        """
        Initializes the base class for MARL systems with essential models and parameters.
        
        Args:
            TransitionTensor (np.ndarray): A tensor representing the transition probabilities
                between states in the environment, given the actions of all agents.
            RewardTensor (np.ndarray): A tensor representing the rewards received by agents,
                given the current state, the actions of all agents, and the next state.
            DiscountFactors (Iterable[float]): A collection of discount factors for each agent,
                indicating how future rewards are valued relative to immediate rewards.
            ObservationTensors: List[np.ndarray],
            DiscountFactors: Iterable[float],
            use_prefactor (bool): If True, scales value function computations by (1 - discount factor)
                to keep values on the same scale as immediate rewards.
            opteinsum (bool): If True, enables optimization of einsum operations for efficiency.
        """

        # Initialize observation tensors per agent.
        self.observations = [jnp.array(O_i) for O_i in ObservationTensors]
        self.n_observations_per_agent = [O_i.shape[-1] for O_i in self.O]
        
        # Convert input tensors to JAX numpy arrays for efficiency in computation.
        self.rewards = jnp.array(RewardTensor)
        self.transitions = jnp.array(TransitionTensor)
    
        # Extract the number of agents from the first dimension of the reward tensor.
        self.n_agents = self.rewards.shape[0]  
        assert len(T.shape[1:-1]) == self.n_agents, "Inconsistent number of agents with the transition tensor."
        assert len(R.shape[2:-1]) == self.n_agents, "Inconsistent number of agents with the reward tensor."
        
        # Determine the number of actions for each agent from the transition tensor.
        self.n_agent_actions = self.transitions.shape[1] 
        assert np.allclose(self.transitions.shape[1:-1], n_agent_actions), 'Inconsistent number of actions across dimensions.'
        assert np.allclose(self.rewards.shape[2:-1], n_agent_actions), 'Inconsistent number of actions with the reward tensor.'
        
        # Determine the number of states from the dimensions of the transition tensor.
        self.n_states = self.transitions.shape[0] 
        assert self.transitions.shape[-1] == self.n_states, 'Inconsistent number of states in the transition tensor.'
        assert self.rewards.shape[-1] == self.n_states, 'Inconsistent number of states in the reward tensor.'
        assert self.rewards.shape[1] == self.n_states, 'Inconsistent number of states across dimensions of the reward tensor.'
        
        # Store class variables for use in later computations.
        self.Q = self.n_states
        
        # Convert discount factors to a JAX numpy array and store.
        self.gamma = make_variable_vector(DiscountFactors, N)

        # Determine whether to scale value function computations by (1 - discount factor).
        self.pre = 1 - self.gamma if use_prefactor else jnp.ones(N)
        self.use_prefactor = use_prefactor

        # Pre-compute a tensor for summing actions of other agents, for use in various computations.
        self.Omega = self._OtherAgentsActionsSummationTensor()
        self.has_last_statdist = False
        self._last_statedist = jnp.ones(Z) / Z
        
        # Store the flag indicating whether to optimize einsum operations.
        self.opti = opteinsum    

    @partial(jit, static_argnums=0)
    def Tss(self, policies_per_agent):
        """
        In a game like MultipleObsSocialDilemma with simple transitions but multiple observations,
        this method would theoretically map observations to agent policies and then to actions.
        Since the game's state transitions are straightforward, the focus is on policy derivation.
        
        Args:
            policies_per_agent (List[jnp.ndarray]): Policies derived from each agent's observations.
        """
        # Assuming a single state, focus on how observations influence these policies.
        # For a more complex game with actual state transitions based on observations:
        
        # Initialize an array for combined transitions (if more complex transitions were involved).
        combined_transitions = jnp.zeros((self.n_states, self.n_states))
        
        for agent_idx, policy in enumerate(policies_per_agent):
            # For each agent, derive how their observation influences their policy.
            # Here, we directly use the policy as observation influences are embedded in policy derivation.
            # In a more complex scenario, you'd map observations to states/actions here.
            
            for action_idx in range(self.n_agent_actions):
                # Simplified as direct influence; in more complex scenarios, you'd use observation tensors.
                action_probabilities = policy[:, action_idx]
                transition_probs = self.transitions[:, action_idx, :]
                
                # Weight transition probabilities by action probabilities (assuming more complex transitions).
                weighted_transitions = jnp.einsum('i,ij->ij', action_probabilities, transition_probs)
                
                # Aggregate transitions (if more than one state and complex transitions were involved).
                combined_transitions += weighted_transitions / self.n_agent_actions
        
        # For MultipleObsSocialDilemma, the transition is inherently simple.
        # The method above outlines how you might extend this for more complex scenarios.
        return combined_transitions
  
    @partial(jit, static_argnums=0)    
    def Ris(self,
            Xisa:jnp.ndarray,  # Joint strategy array.
            Risa:jnp.ndarray=None  # Optional pre-computed rewards for speed-up.
           ) -> jnp.ndarray:  # Average reward for state transitions.
        """
        Computes the average reward `Ris`, given a joint strategy `Xisa`.
        This method calculates the expected rewards for state transitions, averaged
        over the joint strategies of all agents, optionally using a pre-computed
        reward tensor for efficiency.
        
        Args:
            Xisa (jnp.ndarray): A joint strategy array.
            Risa (jnp.ndarray, optional): An optional pre-computed reward tensor
                to speed up calculations.
                
        Returns:
            jnp.ndarray: An array representing the average rewards for state transitions
                under the given joint strategy.
        """
        if Risa is None:  # Calculate Ris from scratch if Risa is not provided.
            # Variables for einsum computation.
            i = 0; s = 1; sprim = 2; b2d = list(range(3, 3+self.N))
        
            # Prepare arguments for einsum operation.
            X4einsum = list(it.chain(*zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))

            # Compute average rewards.
            args = X4einsum + [self.T, [s]+b2d+[sprim], self.R, [i, s]+b2d+[sprim], [i, s]]
            return jnp.einsum(*args, optimize=self.opti)
        
        else:  # Use pre-computed Risa to calculate Ris.
            # Indices for agent, state, and action.
            i=0; s=1; a=2
            
            # Perform einsum operation using pre-computed Risa.
            args = [Xisa, [i, s, a], Risa, [i, s, a], [i, s]]
            return jnp.einsum(*args, optimize=self.opti)

       
    @partial(jit, static_argnums=0)    
    def Risa(self,
             Xisa:jnp.ndarray  # Joint strategy array.
            ) -> jnp.ndarray:  # Average reward for state-action pairs.
        """
        Computes the average reward `Risa`, given a joint strategy `Xisa`.
        This function determines the expected rewards for taking specific actions
        in specific states, averaged over the joint strategies of all agents.
        
        Args:
            Xisa (jnp.ndarray): A joint strategy array, indicating the probability
                of each agent choosing each action in each state.
                
        Returns:
            jnp.ndarray: An array representing the average rewards for state-action pairs
                under the given joint strategy.
        """
        # Variables for einsum computation.
        i = 0; a = 1; s = 2; s_ = 3  
        # Indices for other agents.
        j2k = list(range(4, 4+self.N-1))  
        # Indices for actions of all agents.
        b2d = list(range(4+self.N-1, 4+self.N-1 + self.N))  
        # Indices for actions of other agents.
        e2f = list(range(3+2*self.N, 3+2*self.N + self.N-1))  

        # Prepare indices for summation over other agents' actions.
        sumsis = [[j2k[l], s, e2f[l]] for l in range(self.N-1)]  
        otherX = list(it.chain(*zip((self.N-1)*[Xisa], sumsis)))

        # Perform einsum operation to compute average rewards for state-action pairs.
        args = [self.Omega, [i]+j2k+[a]+b2d+e2f] + otherX + [self.T, [s]+b2d+[s_], self.R, [i, s]+b2d+[s_], [i, s, a]]
        return jnp.einsum(*args, optimize=self.opti)
   
       
    @partial(jit, static_argnums=0)            
    def Vis(self,
            Xisa:jnp.ndarray,  # Joint strategy array.
            Ris:jnp.ndarray=None,  # Optional average rewards for speed-up.
            Tss:jnp.ndarray=None,  # Optional average transitions for speed-up.
            Risa:jnp.ndarray=None  # Optional rewards for state-action pairs for speed-up.
           ) -> jnp.ndarray:  # Average state values.
        """
        Computes the average state values `Vis`, given a joint strategy `Xisa`.
        This method calculates the value of being in each state, taking into account
        the expected future rewards based on the joint strategy of all agents.
        
        Args:
            Xisa (jnp.ndarray): A joint strategy array.
            Ris (jnp.ndarray, optional): Pre-computed average rewards for state transitions.
            Tss (jnp.ndarray, optional): Pre-computed average transition probabilities.
            Risa (jnp.ndarray, optional): Pre-computed rewards for state-action pairs.
                
        Returns:
            jnp.ndarray: An array representing the value of each state under the given
                joint strategy.
        """
        # Compute Ris and Tss if not provided.
        Ris = self.Ris(Xisa, Risa=Risa) if Ris is None else Ris
        Tss = self.Tss(Xisa) if Tss is None else Tss
        
        # Indices for agent and states.
        i = 0; s = 1; sp = 2

        # Compute the inverse of the matrix needed for solving the system of linear equations.
        n = np.newaxis
        Miss = np.eye(self.Z)[n,:,:] - self.gamma[:, n, n] * Tss[n,:,:]
        invMiss = jnp.linalg.inv(Miss)
        
        # Solve the system of linear equations to find the state values.
        return self.pre[:,n] * jnp.einsum(invMiss, [i, s, sp], Ris, [i, sp], [i, s], optimize=self.opti)


    @partial(jit, static_argnums=0)        
    def Qisa(self,
             Xisa:jnp.ndarray,  # Joint strategy array.
             Risa:jnp.ndarray=None,  # Optional rewards for speed-up.
             Vis:jnp.ndarray=None,  # Optional state values for speed-up.
             Tisas:jnp.ndarray=None  # Optional transitions for speed-up.
            ) -> jnp.ndarray:  # Average state-action values.
        """
        Computes the average state-action values `Qisa`, given a joint strategy `Xisa`.
        This function estimates the value of taking specific actions in specific states,
        considering the future rewards as influenced by the joint strategy of all agents.
        
        Args:
            Xisa (jnp.ndarray): A joint strategy array.
            Risa (jnp.ndarray, optional): Pre-computed rewards for state-action pairs.
            Vis (jnp.ndarray, optional): Pre-computed values of states.
            Tisas (jnp.ndarray, optional): Pre-computed average transitions from state-action pairs.
                
        Returns:
            jnp.ndarray: An array representing the value of taking specific actions in
                specific states under the given joint strategy.
        """
        # Compute necessary components if not provided.
        Risa = self.Risa(Xisa) if Risa is None else Risa
        Vis = self.Vis(Xisa, Risa=Risa) if Vis is None else Vis
        Tisas = self.Tisas(Xisa) if Tisas is None else Tisas

        # Compute the expected future state values based on the action taken.
        nextQisa = jnp.einsum(Tisas, [0,1,2,3], Vis, [0,3], [0,1,2], optimize=self.opti)

        # Apply discounting and add immediate rewards to find state-action values.
        n = np.newaxis
        return self.pre[:,n,n] * Risa + self.gamma[:,n,n]*nextQisa

    
    
    # === Helper ===
    @partial(jit, static_argnums=0)  
    def _jaxPs(self,
               Xisa,  # Joint strategy
               pS0):  # Last stationary state distribution 
        """
        Compute stationary distribution `Ps`, given joint strategy `Xisa`
        using JAX.
        """
        Tss = self.Tss(Xisa)
        _pS = compute_stationarydistribution(Tss)
        nrS = jnp.where(_pS.mean(0)!=-10, 1, 0).sum()

        @jit
        def single_dist(pS):
            return jnp.max(jnp.where(_pS.mean(0)!=-10,
                                     jnp.arange(_pS.shape[0]), -1))
        @jit
        def multi_dist(pS):
            ix = jnp.argmin(jnp.linalg.norm(_pS.T - pS0, axis=-1))
            return ix
            
        ix = jax.lax.cond(nrS == 1, single_dist, multi_dist, _pS)

        pS = _pS[:, ix]
        return pS
        

## Strategy averaging
Core methods to compute the strategy-average reward-prediction error

In [6]:
show_doc(abase.Tss)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L90){target="_blank" style="float:right; font-size:smaller"}

### abase.Tss

>      abase.Tss (Xisa:jax.Array)

Computes the average transition model `Tss`, given a joint strategy `Xisa`.
This method calculates how the environment's state is expected to change
on average, given the current policies of all agents.

Args:
    Xisa (jnp.ndarray): A joint strategy array, indicating the probability
        of each agent choosing each action in each state.

Returns:
    jnp.ndarray: An array representing the average transition probabilities
        between states under the given joint strategy.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| Xisa | Array | Joint strategy as a numpy array. |
| **Returns** | **Array** | **Returns the average transition matrix.** |

In [7]:
show_doc(abase.Tisas)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L120){target="_blank" style="float:right; font-size:smaller"}

### abase.Tisas

>      abase.Tisas (Xisa:jax.Array)

Computes the average transition model `Tisas`, given a joint strategy `Xisa`.
This function calculates the transition probabilities from state-action pairs
to subsequent states, averaged over the joint strategies of all agents.

Args:
    Xisa (jnp.ndarray): A joint strategy array, indicating the probability
        of each agent choosing each action in each state.

Returns:
    jnp.ndarray: An array representing the average transition probabilities
        from state-action pairs to next states under the given joint strategy.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| Xisa | Array | Joint strategy array. |
| **Returns** | **Array** | **Average transition Tisas.** |

In [8]:
show_doc(abase.Ris)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L160){target="_blank" style="float:right; font-size:smaller"}

### abase.Ris

>      abase.Ris (Xisa:jax.Array, Risa:jax.Array=None)

Computes the average reward `Ris`, given a joint strategy `Xisa`.
This method calculates the expected rewards for state transitions, averaged
over the joint strategies of all agents, optionally using a pre-computed
reward tensor for efficiency.

Args:
    Xisa (jnp.ndarray): A joint strategy array.
    Risa (jnp.ndarray, optional): An optional pre-computed reward tensor
        to speed up calculations.

Returns:
    jnp.ndarray: An array representing the average rewards for state transitions
        under the given joint strategy.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| Xisa | Array |  | Joint strategy array. |
| Risa | Array | None | Optional pre-computed rewards for speed-up. |
| **Returns** | **Array** |  | **Average reward for state transitions.** |

In [9]:
show_doc(abase.Risa)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L200){target="_blank" style="float:right; font-size:smaller"}

### abase.Risa

>      abase.Risa (Xisa:jax.Array)

Computes the average reward `Risa`, given a joint strategy `Xisa`.
This function determines the expected rewards for taking specific actions
in specific states, averaged over the joint strategies of all agents.

Args:
    Xisa (jnp.ndarray): A joint strategy array, indicating the probability
        of each agent choosing each action in each state.

Returns:
    jnp.ndarray: An array representing the average rewards for state-action pairs
        under the given joint strategy.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| Xisa | Array | Joint strategy array. |
| **Returns** | **Array** | **Average reward for state-action pairs.** |

In [10]:
show_doc(abase.Vis)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L235){target="_blank" style="float:right; font-size:smaller"}

### abase.Vis

>      abase.Vis (Xisa:jax.Array, Ris:jax.Array=None, Tss:jax.Array=None,
>                 Risa:jax.Array=None)

Computes the average state values `Vis`, given a joint strategy `Xisa`.
This method calculates the value of being in each state, taking into account
the expected future rewards based on the joint strategy of all agents.

Args:
    Xisa (jnp.ndarray): A joint strategy array.
    Ris (jnp.ndarray, optional): Pre-computed average rewards for state transitions.
    Tss (jnp.ndarray, optional): Pre-computed average transition probabilities.
    Risa (jnp.ndarray, optional): Pre-computed rewards for state-action pairs.

Returns:
    jnp.ndarray: An array representing the value of each state under the given
        joint strategy.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| Xisa | Array |  | Joint strategy array. |
| Ris | Array | None | Optional average rewards for speed-up. |
| Tss | Array | None | Optional average transitions for speed-up. |
| Risa | Array | None | Optional rewards for state-action pairs for speed-up. |
| **Returns** | **Array** |  | **Average state values.** |

In [11]:
show_doc(abase.Qisa)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L273){target="_blank" style="float:right; font-size:smaller"}

### abase.Qisa

>      abase.Qisa (Xisa:jax.Array, Risa:jax.Array=None, Vis:jax.Array=None,
>                  Tisas:jax.Array=None)

Computes the average state-action values `Qisa`, given a joint strategy `Xisa`.
This function estimates the value of taking specific actions in specific states,
considering the future rewards as influenced by the joint strategy of all agents.

Args:
    Xisa (jnp.ndarray): A joint strategy array.
    Risa (jnp.ndarray, optional): Pre-computed rewards for state-action pairs.
    Vis (jnp.ndarray, optional): Pre-computed values of states.
    Tisas (jnp.ndarray, optional): Pre-computed average transitions from state-action pairs.

Returns:
    jnp.ndarray: An array representing the value of taking specific actions in
        specific states under the given joint strategy.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| Xisa | Array |  | Joint strategy array. |
| Risa | Array | None | Optional rewards for speed-up. |
| Vis | Array | None | Optional state values for speed-up. |
| Tisas | Array | None | Optional transitions for speed-up. |
| **Returns** | **Array** |  | **Average state-action values.** |

## Helpers

In [12]:
#| export
@patch
def Ps(self:abase,
       Xisa:jnp.ndarray # Joint strategy
       ) -> jnp.ndarray: # Stationary state distribution
    """Compute stationary state distribution `Ps`, given joint strategy `Xisa`."""
    
    # To make it work with JAX just-in-time compilation
    if self.has_last_statdist: # Check whether we found a previous Ps
        # If so, use jited computation
        Ps =  self._jaxPs(Xisa, self._last_statedist)
    else:
        # If not, use the slower numpy implementation once
        Ps = jnp.array(self._numpyPs(Xisa))
        self.has_last_statdist = True

    self._last_statedist = Ps
    return Ps


@patch
def _numpyPs(self:abase, Xisa):
    """
    Compute stationary distribution `Ps`, given joint strategy `Xisa`
    just using numpy and without using JAX.
    """
    Tss = self.Tss(Xisa)
    _pS = np.array(compute_stationarydistribution(Tss))

    # clean _pS from unwanted entries 
    _pS = _pS[:, _pS.mean(0)!=-10]
    if len(_pS[0]) == 0:  # this happens when the tollerance can distinquish 
        assert False, 'No _statdist return - must not happen'
    elif len(_pS[0]) > 1:  # Should not happen, in an ideal world
        # sidenote: This means an ideal world is ergodic ;)
        print("More than 1 state-eigenvector found")

        if hasattr(self, '_last_statedist'):  # if last exists
            # take one that is closesd to last
            # Sidenote: should also not happen, because for this case
            # we are using the jitted implementation `_jaxPS`.
            pS0 = self._last_statedist
            choice = np.argmin(np.linalg.norm(_pS.T - pS0, axis=-1))
            print('taking closest to last')
        else: # if no last_Ps exists
            # take a random one.
            print(_pS.round(2))
            nr = len(_pS[0])
            choice = np.random.randint(nr)
            print("taking random one: ", choice)
        _pS = _pS[:, choice] 
        
    return _pS.flatten() # clean

`Ps` uses the `compute_stationarydistribution` function.

In [13]:
from pyCRLD.Environments.EcologicalPublicGood import EcologicalPublicGood as EPG
from pyCRLD.Agents.StrategyActorCritic import stratAC

In [14]:
env = EPG(N=2, f=1.2, c=5, m=-5, qc=0.2, qr=0.01, degraded_choice=False)
MAEi = stratAC(env=env, learning_rates=0.1, discount_factors=0.99, use_prefactor=True)

x = MAEi.random_softmax_strategy()
MAEi._numpyPs(x)

array([0.9028608 , 0.09713921], dtype=float32)

In [15]:
MAEi.Ps(x)

Array([0.9028608 , 0.09713921], dtype=float32)

In [16]:
#| export
@patch
def Ri(self:abase,
       Xisa:jnp.ndarray # Joint strategy `Xisa`
      ) -> jnp.ndarray: # Average reward `Ri`
    """Compute average reward `Ri`, given joint strategy `Xisa`.""" 
    i, s = 0, 1
    return jnp.einsum(self.Ps(Xisa), [s], self.Ris(Xisa), [i, s], [i])

In [17]:
MAEi.Ri(x)

Array([-4.574549 , -4.4456086], dtype=float32)

In [18]:
#| export
@patch
def trajectory(self:abase,
               Xinit:jnp.ndarray,  # Initial condition
               Tmax:int=100, # the maximum number of iteration steps
               tolerance:float=None, # to determine if a fix point is reached 
               verbose=False,  # Say something during computation?
               **kwargs) -> tuple: # (`trajectory`, `fixpointreached`)
    """
    Compute a joint learning trajectory.
    """
    traj = []
    t = 0
    X = Xinit.copy()
    fixpreached = False

    while not fixpreached and t < Tmax:
        print(f"\r [computing trajectory] step {t}", end='') if verbose else None 
        traj.append(X)

        X_, TDe = self.step(X)
        if np.any(np.isnan(X_)):
            fixpreached = True
            break

        if tolerance is not None:
            fixpreached = np.linalg.norm(X_ - X) < tolerance

        X = X_
        t += 1

    print(f" [trajectory computed]") if verbose else None

    return np.array(traj), fixpreached

`trajectory` is an Array containing the time-evolution of the dynamic variable. 
`fixpointreached` is a bool saying whether or not a fixed point has been reached.

In [19]:
#| export
@patch
def _OtherAgentsActionsSummationTensor(self:abase):
    """
    To sum over the other agents and their respective actions using `einsum`.
    """
    dim = np.concatenate(([self.N],  # agent i
                          [self.N for _ in range(self.N-1)],  # other agnt
                          [self.M],  # agent a of agent i
                          [self.M for _ in range(self.N)],  # all acts
                          [self.M for _ in range(self.N-1)]))  # other a's
    Omega = np.zeros(dim.astype(int), int)

    for index, _ in np.ndenumerate(Omega):
        I = index[0]
        notI = index[1:self.N]
        A = index[self.N]
        allA = index[self.N+1:2*self.N+1]
        notA = index[2*self.N+1:]

        if len(np.unique(np.concatenate(([I], notI)))) is self.N:
            # all agents indices are different

            if A == allA[I]:
                # action of agent i equals some other action
                cd = allA[:I] + allA[I+1:]  # other actionss
                areequal = [cd[k] == notA[k] for k in range(self.N-1)]
                if np.all(areequal):
                    Omega[index] = 1

    return jnp.array(Omega)

In [20]:
show_doc(abase._OtherAgentsActionsSummationTensor)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/Base.py#L436){target="_blank" style="float:right; font-size:smaller"}

### abase._OtherAgentsActionsSummationTensor

>      abase._OtherAgentsActionsSummationTensor ()

To sum over the other agents and their respective actions using `einsum`.

To obtain the strategy-average reward-prediction error for agent $i$, we need to average out the probabilities contained in the strategies of all other agents $j \neq i$ and the  transition function $T$, 

$$
\sum_{a^j} \sum_{s'} \prod_{i\neq j} X^j(s, a^j) T(s, \mathbf a, s').
$$

The `_OtherAgentsActionsSummationTensor` enables this summation to be exectued in the efficient `einsum` function. It contains only $0$s and $1$s and is of dimension 

$$
N \times \underbrace{N \times ... \times N}_{(N-1) \text{ times}}
\times M \times \underbrace{M \times ... \times M}_{N \text{ times}}
\times \underbrace{M \times ... \times M}_{(N-1) \text{ times}}
$$

which represent

$$
\overbrace{N}^{\text{the focal agent}} 
\times 
\overbrace{\underbrace{N \times ... \times N}_{(N-1) \text{ times}}}^\text{all other agents}
\times 
\overbrace{M}^\text{focal agent's action} 
\times 
\overbrace{\underbrace{M \times ... \times M}_{N \text{ times}}}^\text{all actions}
\times 
\overbrace{\underbrace{M \times ... \times M}_{(N-1) \text{ times}}}^\text{all other agents' actions}
$$

It contains a $1$ only if

* all agent indices (comprised of the *focal agent* index and *all other agents* indices) are different from each other
* and the *focal agent's action* index matches the focal agents' action index in *all actions* 
* and if *all other agents' action* indices match their corresponding action indices in *all actions*.

Otherwise it contains a $0$.

In [21]:
#| hide
import nbdev; nbdev.nbdev_export()