# Strategy Base

> Base class containing the core methods of CRLD agents in strategy space

In [None]:
#| default_exp Agents/StrategyBase

In [None]:
#| hide
# Imports for the nbdev development environment
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import numpy as np
import itertools as it
from functools import partial

# import jax
from jax import jit
import jax.numpy as jnp
from typing import Iterable
from fastcore.utils import *

from pyCRLD.Agents.Base import abase
from pyCRLD.Utils.Helpers import *

In [None]:
#| export
class strategybase(abase):
    """
    Base class for deterministic strategy-average independent (multi-agent)
    temporal-difference reinforcement learning in strategy space.

    Each learning step multiplies the current joint strategy by
    `exp(alpha * TDerror)` and renormalises over actions; `reverse_step`
    applies the same update with the sign of the TD error flipped.
    """
    
    def __init__(self,
                 env, # An environment object
                 learning_rates:Union[float, Iterable], # agents' learning rates
                 discount_factors:Union[float, Iterable], # agents' discount factors
                 choice_intensities:Union[float, Iterable]=1.0, # agents' choice intensities
                 use_prefactor=False,  # use the 1-DiscountFactor prefactor
                 opteinsum=True,  # optimize einsum functions
                 **kwargs):  # extra keyword arguments are accepted and ignored here
        # Keep the environment around; `id` queries `self.env.id()` later.
        self.env = env
        Tt = env.T
        # Each transition distribution must sum to one along its last axis.
        assert np.allclose(Tt.sum(-1), 1), \
            "env.T must sum to 1 along its last axis (row-stochastic transitions)"
        Rt = env.R
        super().__init__(Tt, Rt, discount_factors, use_prefactor, opteinsum)
        # The environment's `F` array, stored as a JAX array.
        self.F = jnp.array(env.F)

        # learning rates; indexed along the first (agent) axis of `Xisa` in `step`
        self.alpha = make_variable_vector(learning_rates, self.N)

        # intensity of choice, built the same way as the learning rates
        self.beta = make_variable_vector(choice_intensities, self.N)

        # Prediction error driving `step` and `reverse_step`.
        self.TDerror = self.RPEisa

    # NOTE(review): `static_argnums=0` makes `self` a static, hashed argument,
    # so the compiled function is cached per instance. Attributes such as
    # `alpha` that are changed after the first call are presumably not picked
    # up by the cached trace. Confirm before mutating agents in place.
    @partial(jit, static_argnums=0)
    def step(self,
             Xisa  # Joint strategy
            ) -> tuple:  # (Updated joint strategy, Prediction error)
        """
        Performs a learning step along the reward-prediction/temporal-difference error
        in strategy space, given joint strategy `Xisa`.
        """
        return self._exp_update(Xisa, reverse=False)
    
    @partial(jit, static_argnums=0)
    def reverse_step(self,
                    Xisa  # Joint strategy
                    ) -> tuple:  # (Updated joint strategy, Prediction error)
        """
        Performs a reverse learning step in strategy space,
        given joint strategy `Xisa`.
        
        This is useful to compute the separatrix of a multistable regime. 
        """
        return self._exp_update(Xisa, reverse=True)

    def _exp_update(self, Xisa, reverse:bool):
        """
        Shared body of `step`/`reverse_step`.

        Returns the renormalised `Xisa * exp(alpha * (+/-)TDe)` and the
        (unsigned) prediction error `TDe = self.TDerror(Xisa)`.
        """
        TDe = self.TDerror(Xisa)
        n = jnp.newaxis
        signedTDe = -TDe if reverse else TDe
        exponent = self.alpha[:, n, n] * signedTDe
        # Subtract the per-(agent, state) maximum over actions. The shift
        # cancels in the normalisation below, but it keeps `exp` from
        # overflowing to inf, which would turn the quotient into nan for
        # large learning rates or prediction errors.
        exponent = exponent - exponent.max(axis=-1, keepdims=True)
        XexpaTDe = Xisa * jnp.exp(exponent)
        return XexpaTDe / XexpaTDe.sum(-1, keepdims=True), TDe

Further optional parameters inherited from `abase`:

|  | Type | Default |  Details |
| -- | -- | -- | -- |
| use_prefactor | bool | False |  use the 1-DiscountFactor prefactor |
| opteinsum | bool | True |  optimize einsum functions |

In [None]:
show_doc(strategybase.step)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/StrategyBase.py#L52){target="_blank" style="float:right; font-size:smaller"}

### strategybase.step

>      strategybase.step (Xisa)

Performs a learning step along the reward-prediction/temporal-difference error
in strategy space, given joint strategy `Xisa`.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| Xisa |  | Joint strategy |
| **Returns** | **tuple** | **(Updated joint strategy, Prediction error)** |

In [None]:
show_doc(strategybase.reverse_step)

---

[source](https://github.com/wbarfuss/pyCRLD/blob/main/pyCRLD/Agents/StrategyBase.py#L65){target="_blank" style="float:right; font-size:smaller"}

### strategybase.reverse_step

>      strategybase.reverse_step (Xisa)

Performs a reverse learning step in strategy space,
given joint strategy `Xisa`.

This is useful to compute the separatrix of a multistable regime.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| Xisa |  | Joint strategy |
| **Returns** | **tuple** | **(Updated joint strategy, Prediction error)** |

In [None]:
#| export
@patch
def zero_intelligence_strategy(self:strategybase):
    """
    Returns strategy `Xisa` with equal action probabilities.

    Every agent, in every state, plays each of its `M` actions with
    probability `1/M`; the result has shape `(N, Z, M)`.
    """
    n_actions = self.M
    shape = (self.N, self.Z, n_actions)
    return jnp.ones(shape) / float(n_actions)

In [None]:
#| export
@patch
def random_softmax_strategy(self:strategybase,
                            rng=None  # optional `np.random.Generator`; NumPy's global RNG if None
                           ):
    """
    Returns softmax strategy `Xisa` with random action probabilities.

    Standard-normal values of shape `(N, Z, M)` are pushed through a softmax
    over the last (action) axis. With `rng=None` the draw comes from NumPy's
    global RNG, exactly as before, so `np.random.seed` still controls it. Pass
    a `np.random.Generator` to get a reproducible stream that leaves global
    state untouched.
    """
    shape = (self.N, self.Z, self.M)
    if rng is None:
        Q = np.random.randn(*shape)
    else:
        Q = rng.standard_normal(shape)
    expQ = np.exp(Q)
    X = expQ / expQ.sum(axis=-1, keepdims=True)
    return jnp.array(X)

In [None]:
#| export
@patch
def id(self:strategybase
      ) -> str:  # id
    """
    Returns an identifier to handle simulation runs.

    The identifier is the environment's own `id()` followed by the agents'
    class name and their `alpha`, `gamma`, `beta` and `use_prefactor` settings.
    """
    ident = self.env.id() + "__"
    ident += "j" + self.__class__.__name__ + "_"
    # Agents that carry both an `O` and a `Q` attribute are tagged 'PartObs_'.
    if hasattr(self, 'O') and hasattr(self, 'Q'):
        ident += 'PartObs_'
    ident += "_".join(str(param) for param in (self.alpha, self.gamma, self.beta))
    # NOTE: there is no separator between the beta value and "pre". This is
    # kept unchanged so that identifiers stay identical to earlier ones.
    ident += f"pre{self.use_prefactor}"
    return ident

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module named by `#| default_exp` above
import nbdev; nbdev.nbdev_export()