# Strategy Base (part. Obs.)

>Base class containing the core methods of CRLD agents that learn in strategy space under partial observability.

In [None]:
#| default_exp Agents/POStrategyBase

In [None]:
#| hide
# Imports for the nbdev development environment
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import jax
import numpy as np
import itertools as it

import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial

from fastcore.utils import *

from pyCRLD.Agents.Base import abase
from pyCRLD.Agents.POBase import aPObase
from pyCRLD.Agents.StrategyBase import strategybase
from pyCRLD.Utils.Helpers import *

In [None]:
#|export
class POstrategybase(aPObase, strategybase):
    """
    Base Class for
    deterministic policy-average independent (multi-agent) partially observable
    temporal-difference reinforcement learning in policy space.
    """

    def __init__(self, env, learning_rates, discount_factors,
                 choice_intensities=1, **kwargs):
        """
        Parameters
        ----------
        env : environment object
            Must provide the tensors ``T`` (transitions), ``R`` (rewards),
            ``O`` (observations) and ``F`` (final-state indicator).
        learning_rates : the learning rate(s) for the agents
        discount_factors : the discount factor(s) for the agents
        choice_intensities : inverse temperature of softmax / exploitation level
        **kwargs : forwarded unchanged to the base-class initializer

        Raises
        ------
        AssertionError
            If ``env.T`` does not sum to one over its last axis, or if
            ``env.F`` is not all zeros (final states are not supported
            for partially observable learning).
        """
        self.env = env

        # Validate explicitly instead of with `assert` statements so the
        # checks are not stripped under `python -O`. AssertionError is kept
        # so callers/tests that already expect it keep working.
        Tt = env.T
        if not np.allclose(Tt.sum(-1), 1):
            raise AssertionError(
                'Transition tensor env.T must sum to 1 over its last axis.')
        Rt = env.R
        Ot = env.O
        # Checked before the base-class setup so an invalid env fails fast.
        if not np.allclose(env.F, 0):
            raise AssertionError('PO learning w final state not def.')

        super().__init__(Tt, Rt, Ot, discount_factors, **kwargs)

        # learning rates; `self.N` is set by the base-class initializer.
        # make_variable_vector presumably expands a scalar to one value per
        # agent -- TODO confirm against pyCRLD.Utils.Helpers.
        self.alpha = make_variable_vector(learning_rates, self.N)

        # intensity of choice
        self.beta = make_variable_vector(choice_intensities, self.N)

        # Use the reward-prediction error `RPEioa` (provided by a base class,
        # presumably aPObase) as this agent's TD error.
        self.TDerror = self.RPEioa

In [None]:
#|export
@patch
def random_softmax_policy(self:POstrategybase):
    """Return a softmax policy built from random logits.

    Logits of shape ``(self.N, self.Q, self.M)`` are drawn from a standard
    normal using NumPy's global random state. They are exponentiated and
    normalised over the last axis, so every slice along that axis sums to
    one. The axes are presumably agents x observations x actions; confirm
    against the base class.
    """
    logits = np.random.randn(self.N, self.Q, self.M)
    weights = jnp.exp(logits)
    normaliser = weights.sum(axis=-1, keepdims=True)
    return weights / normaliser

In [None]:
#|export
@patch
def zero_intelligence_policy(self:POstrategybase):
    """Return the uniform policy.

    Builds an array of shape ``(self.N, self.Q, self.M)`` in which every
    entry equals ``1 / self.M``, i.e. all entries along the last axis are
    equally likely.
    """
    n_agents, n_obs, n_actions = self.N, self.Q, self.M
    uniform = jnp.ones((n_agents, n_obs, n_actions))
    return uniform / float(n_actions)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()