In [None]:
"""Base class containing the core methods of CRLD agents"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/Agents/99_ABase.ipynb.

# %% auto 0
__all__ = ['abase']

# %% ../../nbs/Agents/99_ABase.ipynb 4
import numpy as np
import itertools as it
from functools import partial

import jax
from jax import jit
import jax.numpy as jnp

from typing import Iterable
from fastcore.utils import *

from ..Utils.Helpers import *

# %% ../../nbs/Agents/99_ABase.ipynb 6
class abase(object):
    """
    Base class for deterministic strategy-average independent (multi-agent)
    temporal-difference reinforcement learning.
    """
    
    def __init__(self, 
                 TransitionTensor: np.ndarray, # transition model of the environment
                 RewardTensor: np.ndarray,  # reward model of the environment
                 DiscountFactors: Iterable[float],  # the agents' discount factors
                 use_prefactor=False,  # use the 1-DiscountFactor prefactor
                 opteinsum=True):  # optimize einsum functions

        '''idk why otiesm is eree'''
                
        R = jnp.array(RewardTensor)
        T = jnp.array(TransitionTensor)

        '''R and T is defined as the reward tensor and the transituon tensor '''
    
        # number of agents
        N = R.shape[0]  
        
        '''what is shape pooR  of agents x joint behaviour profile x current state x transition state'''
        assert len(T.shape[1:-1]) == N, "Inconsistent number of agents"
        assert len(R.shape[2:-1]) == N, "Inconsistent number of agents" 
        
        # number of actions for each agent 

        '''what is transition tensor shape?'''

        M = T.shape[1] 
        assert np.allclose(T.shape[1:-1], M), 'Inconsisten number of actions'
        assert np.allclose(R.shape[2:-1], M), 'Inconsisten number of actions'
        
        # number of states
        Z = T.shape[0] 
        assert T.shape[-1] == Z, 'Inconsisten number of states'
        assert R.shape[-1] == Z, 'Inconsisten number of states'
        assert R.shape[1] == Z, 'Inconsisten number of states'
        
        self.R, self.T, self.N, self.M, self.Z, self.Q = R, T, N, M, Z, Z
        
        '''make variable vector'''
        # discount factors
        self.gamma = make_variable_vector(DiscountFactors, N)

        # use (1-DiscountFactor) prefactor to have values on scale of rewards
        self.pre = 1 - self.gamma if use_prefactor else jnp.ones(N)        
        self.use_prefactor = use_prefactor

        # 'load' the other agents actions summation tensor for speed
        self.Omega = self._OtherAgentsActionsSummationTensor()
        self.has_last_statdist = False
        self._last_statedist = jnp.ones(Z) / Z
        
        # use optimized einsum method
        self.opti = opteinsum  


        '''Tss acerega trasnition model '''

    @partial(jit, static_argnums=0)    
    def Tss(self, 
            Xisa:jnp.ndarray  # Joint strategy
           ) -> jnp.ndarray: # Average transition matrix
        """Compute average transition model `Tss`, given joint strategy `Xisa`"""
        # i = 0  # agent i (not needed)
        s = 1  # state s
        sprim = 2  # next state s'
        b2d = list(range(3, 3+self.N))  # all actions

        X4einsum = list(it.chain(*zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))
        args = X4einsum + [self.T, [s]+b2d+[sprim], [s, sprim]]
        return jnp.einsum(*args, optimize=self.opti)
    
    @partial(jit, static_argnums=0)    
    def Tisas(self,
              Xisa:jnp.ndarray  # Joint strategy
             ) -> jnp.ndarray:  #  Average transition Tisas
        """Compute average transition model `Tisas`, given joint strategy `Xisa`"""      
        i = 0  # agent i
        a = 1  # its action a
        s = 2  # the current state
        s_ = 3  # the next state
        j2k = list(range(4, 4+self.N-1))  # other agents
        b2d = list(range(4+self.N-1, 4+self.N-1 + self.N))  # all actions
        e2f = list(range(3+2*self.N, 3+2*self.N + self.N-1))  # all other acts

        sumsis = [[j2k[l], s, e2f[l]] for l in range(self.N-1)]  # sum inds
        otherX = list(it.chain(*zip((self.N-1)*[Xisa], sumsis)))

        args = [self.Omega, [i]+j2k+[a]+b2d+e2f] + otherX\
            + [self.T, [s]+b2d+[s_], [i, s, a, s_]]
        return jnp.einsum(*args, optimize=self.opti)

    @partial(jit, static_argnums=0)    
    def Ris(self,
            Xisa:jnp.ndarray, # Joint strategy
            Risa:jnp.ndarray=None # Optional reward for speed-up
           ) -> jnp.ndarray: # Average reward
        """Compute average reward `Ris`, given joint strategy `Xisa`""" 
        if Risa is None:  # for speed up
            # Variables      
            i = 0; s = 1; sprim = 2; b2d = list(range(3, 3+self.N))
        
            X4einsum = list(it.chain(*zip(Xisa,
                                    [[s, b2d[a]] for a in range(self.N)])))

            args = X4einsum + [self.T, [s]+b2d+[sprim],
                               self.R, [i, s]+b2d+[sprim], [i, s]]
            return jnp.einsum(*args, optimize=self.opti)
        
        else:  # Compute Ris from Risa 
            i=0; s=1; a=2
            args = [Xisa, [i, s, a], Risa, [i, s, a], [i, s]]
            return jnp.einsum(*args, optimize=self.opti)
        
        '''Ris reward across all ACTIONS/'''
       
    @partial(jit, static_argnums=0)    
    def Risa(self,
             Xisa:jnp.ndarray # Joint strategy
            ) -> jnp.ndarray:  # Average reward
        """Compute average reward `Risa`, given joint strategy `Xisa`"""
        i = 0; a = 1; s = 2; s_ = 3  # Variables
        j2k = list(range(4, 4+self.N-1))  # other agents
        b2d = list(range(4+self.N-1, 4+self.N-1 + self.N))  # all actions
        e2f = list(range(3+2*self.N, 3+2*self.N + self.N-1))  # all other acts
 
        sumsis = [[j2k[l], s, e2f[l]] for l in range(self.N-1)]  # sum inds
        otherX = list(it.chain(*zip((self.N-1)*[Xisa], sumsis)))

        args = [self.Omega, [i]+j2k+[a]+b2d+e2f] + otherX\
            + [self.T, [s]+b2d+[s_], self.R, [i, s]+b2d+[s_],
               [i, s, a]]
        return jnp.einsum(*args, optimize=self.opti)       
       
    @partial(jit, static_argnums=0)            
    def Vis(self,
            Xisa:jnp.ndarray, # Joint strategy
            Ris:jnp.ndarray=None, # Optional reward for speed-up
            Tss:jnp.ndarray=None, # Optional transition for speed-up
            Risa:jnp.ndarray=None  # Optional reward for speed-up
           ) -> jnp.ndarray:  # Average state values
        """Compute average state values `Vis`, given joint strategy `Xisa`"""
        # For speed up
        Ris = self.Ris(Xisa, Risa=Risa) if Ris is None else Ris
        Tss = self.Tss(Xisa) if Tss is None else Tss
        
        i = 0  # agent i
        s = 1  # state s
        sp = 2  # next state s'

        n = np.newaxis
        Miss = np.eye(self.Z)[n,:,:] - self.gamma[:, n, n] * Tss[n,:,:]
        
        invMiss = jnp.linalg.inv(Miss)
               
        return self.pre[:,n] * jnp.einsum(invMiss, [i, s, sp], Ris, [i, sp],
                                          [i, s], optimize=self.opti)
    

    @partial(jit, static_argnums=0)        
    def Qisa(self,
             Xisa:jnp.ndarray, # Joint strategy
             Risa:jnp.ndarray=None, #  Optional reward for speed-up
             Vis:jnp.ndarray=None, # Optional values for speed-up
             Tisas:jnp.ndarray=None, # Optional transition for speed-up
            ) -> jnp.ndarray:  # Average state-action values
        """Compute average state-action values Qisa, given joint strategy `Xisa`"""
        # For speed up
        Risa = self.Risa(Xisa) if Risa is None else Risa
        Vis = self.Vis(Xisa, Risa=Risa) if Vis is None else Vis
        Tisas = self.Tisas(Xisa) if Tisas is None else Tisas

        nextQisa = jnp.einsum(Tisas, [0,1,2,3], Vis, [0,3], [0,1,2],
                              optimize=self.opti)

        n = np.newaxis
        return self.pre[:,n,n] * Risa + self.gamma[:,n,n]*nextQisa
    
    
    # === Helper ===
    @partial(jit, static_argnums=0)  
    def _jaxPs(self,
               Xisa,  # Joint strategy
               pS0):  # Last stationary state distribution 
        """
        Compute stationary distribution `Ps`, given joint strategy `Xisa`
        using JAX.
        """
        Tss = self.Tss(Xisa)
        _pS = compute_stationarydistribution(Tss)
        nrS = jnp.where(_pS.mean(0)!=-10, 1, 0).sum()

        @jit
        def single_dist(pS):
            return jnp.max(jnp.where(_pS.mean(0)!=-10,
                                     jnp.arange(_pS.shape[0]), -1))
        @jit
        def multi_dist(pS):
            ix = jnp.argmin(jnp.linalg.norm(_pS.T - pS0, axis=-1))
            return ix
            
        ix = jax.lax.cond(nrS == 1, single_dist, multi_dist, _pS)

        pS = _pS[:, ix]
        return pS
        

# %% ../../nbs/Agents/99_ABase.ipynb 15
@patch
def Ps(self:abase,
       Xisa:jnp.ndarray # Joint strategy
       ) -> jnp.ndarray: # Stationary state distribution
    """Compute stationary state distribution `Ps`, given joint strategy `Xisa`."""
    
    # To make it work with JAX just-in-time compilation
    if self.has_last_statdist: # Check whether we found a previous Ps
        # If so, use jited computation
        Ps =  self._jaxPs(Xisa, self._last_statedist)
    else:
        # If not, use the slower numpy implementation once
        Ps = jnp.array(self._numpyPs(Xisa))
        self.has_last_statdist = True

    self._last_statedist = Ps
    return Ps


@patch
def _numpyPs(self:abase, Xisa):
    """
    Compute stationary distribution `Ps`, given joint strategy `Xisa`
    just using numpy and without using JAX.
    """
    Tss = self.Tss(Xisa)
    _pS = np.array(compute_stationarydistribution(Tss))

    # clean _pS from unwanted entries 
    _pS = _pS[:, _pS.mean(0)!=-10]
    if len(_pS[0]) == 0:  # this happens when the tollerance can distinquish 
        assert False, 'No _statdist return - must not happen'
    elif len(_pS[0]) > 1:  # Should not happen, in an ideal world
        # sidenote: This means an ideal world is ergodic ;)
        print("More than 1 state-eigenvector found")

        if hasattr(self, '_last_statedist'):  # if last exists
            # take one that is closesd to last
            # Sidenote: should also not happen, because for this case
            # we are using the jitted implementation `_jaxPS`.
            pS0 = self._last_statedist
            choice = np.argmin(np.linalg.norm(_pS.T - pS0, axis=-1))
            print('taking closest to last')
        else: # if no last_Ps exists
            # take a random one.
            print(_pS.round(2))
            nr = len(_pS[0])
            choice = np.random.randint(nr)
            print("taking random one: ", choice)
        _pS = _pS[:, choice] 
        
    return _pS.flatten() # clean

# %% ../../nbs/Agents/99_ABase.ipynb 20
@patch
def Ri(self:abase,
       Xisa:jnp.ndarray # Joint strategy `Xisa`
      ) -> jnp.ndarray: # Average reward `Ri`
    """Compute average reward `Ri`, given joint strategy `Xisa`.""" 
    i, s = 0, 1
    return jnp.einsum(self.Ps(Xisa), [s], self.Ris(Xisa), [i, s], [i])

# %% ../../nbs/Agents/99_ABase.ipynb 22
@patch
def trajectory(self:abase,
               Xinit:jnp.ndarray,  # Initial condition
               Tmax:int=100, # the maximum number of iteration steps
               tolerance:float=None, # to determine if a fix point is reached 
               verbose=False,  # Say something during computation?
               **kwargs) -> tuple: # (`trajectory`, `fixpointreached`)
    """
    Compute a joint learning trajectory.
    """
    traj = []
    t = 0
    X = Xinit.copy()
    fixpreached = False

    while not fixpreached and t < Tmax:
        print(f"\r [computing trajectory] step {t}", end='') if verbose else None 
        traj.append(X)

        X_, TDe = self.step(X)
        if np.any(np.isnan(X_)):
            fixpreached = True
            break

        if tolerance is not None:
            fixpreached = np.linalg.norm(X_ - X) < tolerance

        X = X_
        t += 1

    print(f" [trajectory computed]") if verbose else None

    return np.array(traj), fixpreached

# %% ../../nbs/Agents/99_ABase.ipynb 24
@patch
def _OtherAgentsActionsSummationTensor(self:abase):
    """
    To sum over the other agents and their respective actions using `einsum`.
    """
    dim = np.concatenate(([self.N],  # agent i
                          [self.N for _ in range(self.N-1)],  # other agnt
                          [self.M],  # agent a of agent i
                          [self.M for _ in range(self.N)],  # all acts
                          [self.M for _ in range(self.N-1)]))  # other a's
    Omega = np.zeros(dim.astype(int), int)

    for index, _ in np.ndenumerate(Omega):
        I = index[0]
        notI = index[1:self.N]
        A = index[self.N]
        allA = index[self.N+1:2*self.N+1]
        notA = index[2*self.N+1:]

        if len(np.unique(np.concatenate(([I], notI)))) is self.N:
            # all agents indices are different

            if A == allA[I]:
                # action of agent i equals some other action
                cd = allA[:I] + allA[I+1:]  # other actionss
                areequal = [cd[k] == notA[k] for k in range(self.N-1)]
                if np.all(areequal):
                    Omega[index] = 1

    return jnp.array(Omega)


In [70]:
import numpy as np
import itertools as it
import jax.numpy as jnp
import jax
from functools import partial

class SimpleAgent:
    def __init__(self, TransitionTensor: np.ndarray, num_agents: int, optimize=True):
        self.T = jnp.array(TransitionTensor)  # Transition tensor
        self.N = num_agents  # Number of agents
        self.opti = optimize  # Whether to optimize einsum

    def Tss(self, Xisa: jnp.ndarray) -> jnp.ndarray:
        """Compute average transition matrix `Tss`, given joint strategy `Xisa`"""
        s = 1  # Current state index
        sprim = 2  # Next state index
        b2d = list(range(3, 3+self.N))  # Indices for agent actions

        # Convert indices to Python integers
        s = int(s)
        sprim = int(sprim)
        b2d = [int(x) for x in b2d]

        # Debugging: Print intermediate indices
        print(f"State index (s): {s}")           #
        print(f"Next state index (sprim): {sprim}")
        print(f"Action indices (b2d): {b2d}")


        print('tests for X4einsum')

        print([[s, b2d[a]] for a in range(self.N)])  
        print(list(zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))

        # Pair each agent's strategy with their respective indices
        X4einsum = list(it.chain(*zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))

        # Debugging: Print how X4einsum is constructed
        print("\nX4einsum (flattened strategy + indices):")
        for item in X4einsum:
            if isinstance(item, list):
                print(np.array(item).tolist())  # Convert JAX array to list
            else:
                print(item)

        # Prepare arguments for jnp.einsum
        args = X4einsum + [self.T, [s] + b2d + [sprim], [s, sprim]]

        # Debugging: Print final einsum arguments
        print("\nArguments passed to einsum:")
        for arg in args:
            if isinstance(arg, list):
                print(np.array(arg).tolist())  # Convert JAX array to list
            else:
                print(arg)

        # Compute Tss using einsum
        result = jnp.einsum(*args, optimize=self.opti)
        
        # Convert result to NumPy array for clean printing
        result_np = np.array(result)
        print("\nComputed Tss:\n", result_np)

        return result

# ======= Example Setup =======
# Define a simple transition tensor for 2 states and 2 agents
T = np.array([
    # If in state 0
    [[[0.8, 0.2],  # If agents take (action 0, action 0)
      [0.7, 0.3]], # If agents take (action 0, action 1)
     [[0.6, 0.4],  # If agents take (action 1, action 0)
      [0.5, 0.5]]],# If agents take (action 1, action 1)
    
    # If in state 1
    [[[0.4, 0.6],
      [0.3, 0.7]],
     [[0.2, 0.8],
      [0.1, 0.9]]]
])

#agent0 and agent1 ; 0.8, 0.2

print(T.shape)

# Define a simple joint strategy: probability of choosing each action in each state
Xisa = np.array([
    [[0.9, 0.1],  # Agent 1: prefers action 0 in state 0, action 1 in state 1
     [0.2, 0.8]],
    
    [[0.3, 0.7],  # Agent 2: prefers action 1 in state 0, action 0 in state 1
     [0.6, 0.4]]
])

# Convert Xisa to JAX array
Xisa_jax = jnp.array(Xisa)

# Create an instance of the class and compute Tss
agent = SimpleAgent(T, num_agents=2)
Tss_result = agent.Tss(Xisa_jax)

# Print final result
print("\nFinal Tss (state transition matrix):")
print(Tss_result)


(2, 2, 2, 2)
State index (s): 1
Next state index (sprim): 2
Action indices (b2d): [3, 4]
tests for X4einsum
[[1, 3], [1, 4]]
[(Array([[0.9, 0.1],
       [0.2, 0.8]], dtype=float32), [1, 3]), (Array([[0.3, 0.7],
       [0.6, 0.4]], dtype=float32), [1, 4])]

X4einsum (flattened strategy + indices):
[[0.9 0.1]
 [0.2 0.8]]
[1, 3]
[[0.3 0.7]
 [0.6 0.4]]
[1, 4]

Arguments passed to einsum:
[[0.9 0.1]
 [0.2 0.8]]
[1, 3]
[[0.3 0.7]
 [0.6 0.4]]
[1, 4]
[[[[0.8 0.2]
   [0.7 0.3]]

  [[0.6 0.4]
   [0.5 0.5]]]


 [[[0.4 0.6]
   [0.3 0.7]]

  [[0.2 0.8]
   [0.1 0.9]]]]
[1, 3, 4, 2]
[1, 2]

Computed Tss:
 [[0.71000004 0.29      ]
 [0.20000002 0.8       ]]

Final Tss (state transition matrix):
[[0.71000004 0.29      ]
 [0.20000002 0.8       ]]


In [2]:
import numpy as np
import itertools as it

class SimpleAgent:
    def __init__(self, TransitionTensor: np.ndarray, num_agents: int):
        self.T = np.array(TransitionTensor)  # Transition tensor (NumPy)
        self.N = num_agents  # Number of agents

    def Tss(self, Xisa: np.ndarray) -> np.ndarray:
        """Compute average transition matrix `Tss`, given joint strategy `Xisa`"""
        s = 1  # Current state index
        sprim = 2  # Next state index
        b2d = list(range(3, 3 + self.N))  # Indices for agent actions

        # Convert to plain Python integers
        s, sprim = int(s), int(sprim)
        b2d = [int(x) for x in b2d]

        # Print intermediate indices
        print(f"State index (s): {s}")
        print(f"Next state index (sprim): {sprim}")
        print(f"Action indices (b2d): {b2d}")


        print(list(zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))

        # Pair each agent's strategy with their respective indices
        X4einsum = list(it.chain(*zip(Xisa, [[s, b2d[a]] for a in range(self.N)])))

        # Print how X4einsum is constructed
        print("\nX4einsum ( strategy + indices):")
        for item in X4einsum:
            if isinstance(item, (list, tuple)):  # If it's an index pair
                print(item)
            else:  # If it's a NumPy array
                print(item.tolist())

        # Prepare arguments for np.einsum
        args = X4einsum + [self.T, [s] + b2d + [sprim], [s, sprim]]

        # Print final einsum arguments
        print("\nArguments passed to einsum:")
        for arg in args:
            if isinstance(arg, (list, tuple)):
                print(arg)
            else:
                print(arg.tolist())


        # Compute Tss using NumPy einsum
        result = np.einsum(*args)

        # Print final computed Tss
        print("\nComputed Tss:")
        print(result.tolist())

        return result

# ======= Example Setup =======
# Define a simple transition tensor for 2 states and 2 agents
T = np.array([
    # If in state 0
    [[[0.8, 0.2],  # If agents take (action 0, action 0)
      [0.7, 0.3]], # If agents take (action 0, action 1)
     [[0.6, 0.4],  # If agents take (action 1, action 0)
      [0.5, 0.5]]],# If agents take (action 1, action 1)
    
    # If in state 1
    [[[0.4, 0.6],
      [0.3, 0.7]],
     [[0.2, 0.8],
      [0.1, 0.9]]]
])

# Define a simple joint strategy: probability of choosing each action in each state
Xisa = np.array([
    [[0.9, 0.1],  # Agent 1: prefers action 0 in state 0, action 1 in state 1
     [0.2, 0.8]],
    
    [[0.3, 0.7],  # Agent 2: prefers action 1 in state 0, action 0 in state 1
     [0.6, 0.4]]
])

# Create an instance of the class and compute Tss
agent = SimpleAgent(T, num_agents=2)
Tss_result = agent.Tss(Xisa)

# Print final result
print("\nFinal Tss (state transition matrix):")
print(Tss_result.tolist())  # Ensures clean printing





State index (s): 1
Next state index (sprim): 2
Action indices (b2d): [3, 4]
[(array([[0.9, 0.1],
       [0.2, 0.8]]), [1, 3]), (array([[0.3, 0.7],
       [0.6, 0.4]]), [1, 4])]

X4einsum ( strategy + indices):
[[0.9, 0.1], [0.2, 0.8]]
[1, 3]
[[0.3, 0.7], [0.6, 0.4]]
[1, 4]

Arguments passed to einsum:
[[0.9, 0.1], [0.2, 0.8]]
[1, 3]
[[0.3, 0.7], [0.6, 0.4]]
[1, 4]
[[[[0.8, 0.2], [0.7, 0.3]], [[0.6, 0.4], [0.5, 0.5]]], [[[0.4, 0.6], [0.3, 0.7]], [[0.2, 0.8], [0.1, 0.9]]]]
[1, 3, 4, 2]
[1, 2]

Computed Tss:
[[0.7100000000000001, 0.29], [0.2, 0.8]]

Final Tss (state transition matrix):
[[0.7100000000000001, 0.29], [0.2, 0.8]]


In [4]:
print(T.shape[1:-1])

(2, 2)


In [14]:
array = np.zeros((3,3))
for i, _ in np.ndenumerate(array):
    print(i)



(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)
(2, 0)
(2, 1)
(2, 2)


In [None]:
import numpy as np
import jax.numpy as jnp

class abase:
    def __init__(self, N, M):
        self.N = N  # Number of agents
        self.M = M  # Number of actions per agent

    def _OtherAgentsActionsSummationTensor(self):
        """
        To sum over the other agents and their respective actions using `einsum`.
        """
        # Construct dimension sizes
        dim = np.concatenate(([self.N],  # agent i
                              [self.N for _ in range(self.N-1)],  # other agents
                              [self.M],  # action a of agent i
                              [self.M for _ in range(self.N)],  # all actions
                              [self.M for _ in range(self.N-1)]))  # other agents' actions

        print("Dimension sizes (dim):", dim)  # Debugging step

        # Initialize Omega tensor with zeros
        Omega = np.zeros(dim.astype(int), int)
        print("Initialized Omega tensor with shape:", Omega.shape)

        # Iterate over all indices of the tensor
        for index, _ in np.ndenumerate(Omega):

            I = index[0]  # Agent i, I 
            notI = index[1:self.N]  # Other agents
            A = index[self.N]  # Action of agent i
            allA = index[self.N+1:2*self.N+1]  # Actions of all agents
            notA = index[2*self.N+1:]  # Other agents' actions

            print("\nCurrent Index:", index)
            print("Agent I:", I)
            print("Other agents:", notI)
            print("Action of I:", A)
            print("All actions:", allA)
            print("Other agents' actions:", notA)

            #self.N+1: 2*self.N + 1 

            # Check if all agent indices are unique
            unique_agents = np.unique(np.concatenate(([I], notI)))
            # print("Unique Agents:", unique_agents)

            if len(unique_agents) == self.N:
                # print("✅ All agents are unique")

                # Check if the action of agent i matches the corresponding entry in allA
                if A == allA[I]:
                    # print("✅ Action of agent i matches allA[I]")

                    cd = allA[:I] + allA[I+1:]  # Other actions
                    # print("cd (other actions):", cd)

                    areequal = [cd[k] == notA[k] for k in range(self.N-1)]
                    # print("Comparison result (areequal):", areequal)

                    if np.all(areequal):
                        # print("✅  Setting Omega[", index, "] = 1")
                        Omega[index] = 1

        # print("\nFinal Omega tensor:\n", Omega)
        return jnp.array(Omega)

# Example Usage
N = 3  # Number of agents
M = 2  # Number of actions per agent

obj = abase(N, M)
Omega_result = obj._OtherAgentsActionsSummationTensor()


Dimension sizes (dim): [3 3 3 2 2 2 2 2 2]
Initialized Omega tensor with shape: (3, 3, 3, 2, 2, 2, 2, 2, 2)

Current Index: (0, 0, 0, 0, 0, 0, 0, 0, 0)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 0)
Other agents' actions: (0, 0)

Current Index: (0, 0, 0, 0, 0, 0, 0, 0, 1)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 0)
Other agents' actions: (0, 1)

Current Index: (0, 0, 0, 0, 0, 0, 0, 1, 0)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 0)
Other agents' actions: (1, 0)

Current Index: (0, 0, 0, 0, 0, 0, 0, 1, 1)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 0)
Other agents' actions: (1, 1)

Current Index: (0, 0, 0, 0, 0, 0, 1, 0, 0)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 1)
Other agents' actions: (0, 0)

Current Index: (0, 0, 0, 0, 0, 0, 1, 0, 1)
Agent I: 0
Other agents: (0, 0)
Action of I: 0
All actions: (0, 0, 1)
Other agents' actions: (0, 1)

Current Index: (0, 0, 0, 0,

In [None]:
import numpy as np
import jax.numpy as jnp

def compute_Vis(Xisa: jnp.ndarray, gamma: jnp.ndarray, pre: jnp.ndarray,
                Ris: jnp.ndarray = None, Tss: jnp.ndarray = None, 
                Risa: jnp.ndarray = None) -> jnp.ndarray:
    """
    Compute average state values `Vis`, given a joint strategy `Xisa`,
    reward matrix `Ris`, and transition probabilities `Tss`.

    Parameters:
        Xisa (jnp.ndarray): Joint strategy matrix.
        gamma (jnp.ndarray): Discount factor for future rewards.
        pre (jnp.ndarray): Pre-multiplier for state values.
        Ris (jnp.ndarray, optional): Reward matrix. If None, defaults to a sample.
        Tss (jnp.ndarray, optional): State transition probabilities. If None, defaults to a sample.
        Risa (jnp.ndarray, optional): Additional reward parameter (not used here).

    Returns:
        jnp.ndarray: Computed average state values `Vis`.
    """
    
    Z = Xisa.shape[1]  # Number of states
    n = np.newaxis  # Expands dimensions for broadcasting

    print("\n==== Step 1: Initialize Parameters ====")
    print(f"Number of Agents: {Xisa.shape[0]}, Number of States: {Z}")
    print(f"Gamma (Discount Factor): \n{gamma}")
    print(f"Pre-multipliers: \n{pre}")

    # Default random reward and transition matrices if not provided
    if Ris is None:
        Ris = jnp.array(np.random.rand(Xisa.shape[0], Z))  # Random rewards (num_agents x num_states)
        print("\n==== Step 2: Generating Random Reward Matrix ====")
    print(f"Reward Matrix (Ris): \n{Ris}")

    if Tss is None:
        Tss = jnp.array(np.random.rand(Z, Z))  # Random transition matrix (num_states x num_states)
        print("\n==== Step 3: Generating Random Transition Matrix ====")
    print(f"Transition Matrix (Tss): \n{Tss}")

    # Compute matrix for Bellman-like equation
    Miss = np.eye(Z)[n, :, :] - gamma[:, n, n] * Tss[n, :, :]

    print(np.eye(Z)[n, :, :], "np.eye(Z)[n, :, :]")
    print(gamma[:, n, n].shape, "gamma[:, n, n]")
    print(Tss[n, :, :].shape, "Tss[n, :, :],")
    print((gamma[:, n, n] * Tss[n, :, :]).shape, "gamma[:, n, n] * Tss[n, :, :]")
    

  
    print("\n==== Step 4: Compute Miss Matrix ====")
    print(f"Miss Matrix: \n{Miss}")

    # Compute the inverse of Miss
    invMiss = jnp.linalg.inv(Miss)
    print("\n==== Step 5: Compute Inverse of Miss ====")
    print(f"Inverse Miss Matrix: \n{invMiss}")

    # Compute Vis using tensor contraction (Einstein summation)
    Vis = pre[:, n] * jnp.einsum(invMiss, [0, 1, 2], Ris, [0, 2], [0, 1], optimize=True)
    print("\n==== Step 6: Compute Final Vis Values ====")
    print(f"Computed Vis (Average State Values): \n{Vis}")

    return Vis

# Example Usage:
num_agents = 2
num_states = 3

# Generate random joint strategy, discount factors, and pre-multipliers
Xisa = jnp.array(np.random.rand(num_agents, num_states, num_states))
gamma = jnp.array([0.5]*num_agents)  # Discount factor per agent
pre = jnp.array([0.5]*num_agents)    # Pre-multiplier per agent

# Call the function
Vis_values = compute_Vis(Xisa, gamma, pre)

# Print final output
print("\n==== Final Output ====")
print("Computed Vis (Average State Values):\n", Vis_values)



==== Step 1: Initialize Parameters ====
Number of Agents: 2, Number of States: 3
Gamma (Discount Factor): 
[0.5 0.5]
Pre-multipliers: 
[0.5 0.5]

==== Step 2: Generating Random Reward Matrix ====
Reward Matrix (Ris): 
[[0.78257215 0.36220554 0.6678415 ]
 [0.85826486 0.508726   0.25482106]]

==== Step 3: Generating Random Transition Matrix ====
Transition Matrix (Tss): 
[[0.23162678 0.9234872  0.7622641 ]
 [0.27162108 0.1299726  0.43560246]
 [0.00553137 0.5795066  0.13084036]]
[[[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]] np.eye(Z)[n, :, :]
(2, 1, 1) gamma[:, n, n]
(1, 3, 3) Tss[n, :, :],
(2, 3, 3) gamma[:, n, n] * Tss[n, :, :]

==== Step 4: Compute Miss Matrix ====
Miss Matrix: 
[[[ 0.8841866  -0.4617436  -0.38113204]
  [-0.13581054  0.9350137  -0.21780123]
  [-0.00276568 -0.2897533   0.93457985]]

 [[ 0.8841866  -0.4617436  -0.38113204]
  [-0.13581054  0.9350137  -0.21780123]
  [-0.00276568 -0.2897533   0.93457985]]]

==== Step 5: Compute Inverse of Miss ====
Inverse Miss Matrix: 
[[[1.26

In [21]:
test_array = jnp.array([5,0,5,4])
print(jnp.where(test_array > 1, jnp.arange(4), 0))

[0 0 2 3]


In [1]:
import jax
import jax.numpy as jnp

def compute_stationarydistribution(Tkk: jnp.ndarray):
    """Compute stationary distribution for transition matrix Tkk."""
    
    # Step 1: Eigenvalues and Eigenvectors
    oeival, oeivec = jnp.linalg.eig(Tkk.T)
    oeival = oeival.real
    oeivec = oeivec.real
    
    print("Eigenvalues (oeival):", oeival)
    print("Eigenvectors (oeivec):\n", oeivec)
    
    # Step 2: Define mask for eigenvalues close to 1
    get_mask = lambda tol: jnp.abs(oeival - 1) < tol
  
    # Step 3: Generate tolerances from 0.1 to 1e-15
    tolerances = jax.lax.map(lambda x: 0.1**x, jnp.arange(1, 16, 1))
    print("\nTolerances:", tolerances)
    
    # Step 4: Apply mask for each tolerance
    masks = jax.lax.map(get_mask, tolerances)
    print("\nMasks (for each tolerance):\n", masks)
    

    # Step 5: Select the strictest tolerance with at least one match
    ix = jnp.max(jnp.where(masks.sum(-1) >= 1, jnp.arange(len(masks)), -1))

    print("masks.sum(-1) >= 1", masks.sum(-1))
    print(jnp.where(masks.sum(-1) >= 1, jnp.arange(len(masks)), -1), "ix")
    
    mask = masks[ix]
    tol = tolerances[ix]
    
    print("\nSelected tolerance index:", ix)
    print("Selected tolerance value:", tol)
    print("Selected mask:", mask)
    
    # Step 6: Extract the eigenvector corresponding to eigenvalue 1
    meivec = jnp.where(mask, oeivec, -42)
    print("\nMasked eigenvectors (meivec):\n", meivec)
    
    # Step 7: Normalize to get stationary distribution
    dist = meivec / meivec.sum(axis=0, keepdims=True)
    dist = jnp.where(dist < tol, 0, dist)
    dist = dist / dist.sum(axis=0, keepdims=True)
    
    print("\nNormalized stationary distribution before final cleanup:\n", dist)
    
    # Step 8: Handle invalid entries
    final_dist = jnp.where(meivec == -42, -10, dist)
    print("\nFinal stationary distribution:\n", final_dist)
    
    return final_dist


# Example: Simple 3x3 transition matrix
Tkk = jnp.array([[0.9, 0.075, 0.025],
                 [0.15, 0.8, 0.05],
                 [0.25, 0.25, 0.5]])

# Run the function
compute_stationarydistribution(Tkk)


Eigenvalues (oeival): [1.0000001  0.7414212  0.45857865]
Eigenvectors (oeivec):
 [[ 0.8908709   0.7365801  -0.27569187]
 [ 0.44543526 -0.67339206 -0.52773327]
 [ 0.08908706 -0.06318857  0.80342495]]

Tolerances: [1.00000001e-01 1.00000007e-02 1.00000005e-03 1.00000005e-04
 1.00000007e-05 1.00000011e-06 1.00000008e-07 1.00000008e-08
 1.00000008e-09 1.00000015e-10 1.00000017e-11 1.00000021e-12
 1.00000019e-13 1.00000024e-14 1.00000022e-15]

Masks (for each tolerance):
 [[ True False False]
 [ True False False]
 [ True False False]
 [ True False False]
 [ True False False]
 [ True False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]
 [False False False]]
masks.sum(-1) >= 1 [1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]
[ 0  1  2  3  4  5 -1 -1 -1 -1 -1 -1 -1 -1 -1] ix

Selected tolerance index: 5
Selected tolerance value: 1.0000001e-06
Selected mask: [ True False False]

Masked

Array([[  0.6250001 , -10.        , -10.        ],
       [  0.3124999 , -10.        , -10.        ],
       [  0.06249999, -10.        , -10.        ]], dtype=float32)