In [1]:
from collections import namedtuple, defaultdict
import numpy as np
import pandas as pd

from attrdict import AttrDict

In [2]:
class Distribution:
    def __init__(self):
        self._parameters = None
        
    @property
    def parameters(self):
        raise NotImplementedError
        
    def sample(self):
        raise NotImplementedError

In [3]:
class NormalDistribution(Distribution):
    def __init__(self, mean=0, variance=1):
        self._parameters = AttrDict(mean=mean, variance=variance)
        
    @property
    def parameters(self):
        return self._parameters.copy()
    
    def sample(self):
        return np.random.normal(loc=selr.parameters.mean, scale=np.sqrt(self.parameters.variance))

In [4]:
class Bandit:
    def __init__(self, n_arms):
        self.n_arms = n_arms
        self.distribution = NormalDistribution(mean=0, variance=1)
        
    def step(self):
        return AttrDict(obs=None, rew=self.distribution.sample(), done=False, info={})

In [5]:
class Policy:
    def __init__(self, state_space, action_space):
        self.state_space=state_space
        self.action_space=action_space
        
        self._q_estimate = None
        self._value_estimate = None
        self._advantage_estimate = None
    
    @property
    def q_estimate(self, state, action) -> pd.DataFrame:
        # Do NOT subclass unless you know what you're doing!
        if self.state_space is None:
            assert isinstance(self._q_estimate, pd.Series)
        else:
            assert isinstance(self._q_estimate, pd.DataFrame)
        return self._q_estimate
    
    @property
    def value_estimate(self, state) -> pd.Series:
        # Do NOT subclass unless you know what you're doing!
        assert isinstance(self._value_estimate, pd.Series)
        return self._value_estimate
    
    @property
    def advantage_estimate(self, state, action) -> pd.DataFrame:
        # Do NOT subclass unless you know what you're doing!
        assert isinstance(self._advantage_estimate, pd.DataFrame)
        return self._advantage_estimate
    
    def update_q_estimate(self, state, action, value):
        assert state in self.state_space
        assert action in self.action_space
        self._q_estimate.set_value(state, action, value)
        
    def act(self):
        raise NotImplementedError

In [6]:
class SampleAveragingPolicy(Policy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reward_history = defaultdict(list)
        if self.state_space is None:
            self._q_estimate = pd.Series(data=np.zeros(shape=len(self.action_space)), index=list(self.action_space))
        else:
            self._q_estimate = pd.DataFrame(data=np.zeros(shape=(len(self.state_space), len(self.action_space))), 
                                            index=list(self.state_space), 
                                            columns=list(self.action_space))

    def update_q_estimate(self, state, action, reward):
        self.reward_history[action].append(reward)
        super().update_q_estimate(state, action, np.mean(self.reward_history[action]))

In [7]:
class GreedyPolicy(SampleAveragingPolicy, Policy):
    def act(self, state):
        # Choose the action that has the highest estimated reward
        if self.state_space is None:
            return self.q_estimate.idxmax()
        return self.q_estimate[state].idxmax()

In [8]:
class Agent:
    def __init__(self, env, policy, horizon):
        self.env = env
        self.policy = policy
        self.horizon = horizon
        
    def train(self):
        init_state = self.env.reset()
        for step in self.horizon:
            action = self.policy.act(init_state)
            new_state, rew, _, _ = env.step(action)
            self.policy.update(init_state, action, new_state, reward)

# Test

In [9]:
bandit = Bandit(10)

In [11]:
greedy_policy = GreedyPolicy(state_space=None, action_space=np.arange(9))

In [13]:
greedy_policy.q_estimate

TypeError: q_estimate() missing 2 required positional arguments: 'state' and 'action'