In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [3]:
import algos
import features
import parametric
import policy
import chicken
from agents import OffPolicyAgent, OnPolicyAgent
from rlbench import *

In [42]:
class HordeAgent:
    """Wrap a learning algorithm so it can learn from a shared transition stream.

    Several HordeAgents can be updated from the same sequence of transitions
    (generated by one behavior policy), each with its own target policy,
    feature mapping and update parameters.

    Parameters
    ----------
    algo : object
        Learning algorithm. It must provide ``update(x, r, xp, *args)``, an
        ``update_params`` sequence naming (in order) the extra positional
        arguments ``update`` expects, and a ``theta`` weight vector.
    pol : object
        Target policy; ``pol.prob(s, a)`` gives the probability of action
        ``a`` in state ``s``.
    phi : callable or None
        Feature mapping applied to states before they reach ``algo``.
        ``None`` means the raw state is passed through unchanged.
    update_params : dict, optional
        Default update parameters keyed by name. Each value is converted with
        ``parametric.to_parameter`` and evaluated on the current state at every
        update. Defaults to no parameters.
    behavior : object, optional
        Behavior policy with a ``prob(s, a)`` method. When given, ``rho`` is the
        importance-sampling ratio ``pi(a|s) / mu(a|s)``. When omitted, ``rho``
        is only the target probability ``pi(a|s)`` (the historical behavior),
        which equals the true ratio only when ``mu(a|s) == 1``.
    """

    def __init__(self, algo, pol, phi, update_params=None, behavior=None):
        self.algo = algo
        self.pol = pol
        self.behavior = behavior
        if phi is None:
            self.phi = lambda x: x
        else:
            self.phi = phi
        # None instead of a shared mutable ``dict()`` default
        if update_params is None:
            update_params = {}
        # default parameter functions to use for updates
        self.param_funcs = {k: parametric.to_parameter(v)
                            for k, v in update_params.items()}

    def _rho(self, s, a):
        """Return the value used as ``rho`` for taking action ``a`` in state ``s``."""
        pi_prob = self.pol.prob(s, a)
        if self.behavior is None:
            # no behavior policy known: fall back to the target probability only
            return pi_prob
        # the action was drawn from the behavior policy, so mu(a|s) should be > 0
        return pi_prob / self.behavior.prob(s, a)

    def update(self, s, a, r, sp, **params):
        """Apply one learning update for the transition ``(s, a, r, sp)``.

        Keyword arguments in ``params`` override both the state-dependent
        defaults and the computed ``rho``. Only the entries named in
        ``self.algo.update_params`` are forwarded to the algorithm, in that
        order; a missing name raises ``KeyError``.

        Returns whatever ``self.algo.update`` returns.
        """
        # evaluate the state-dependent default parameters at the current state
        update_params = {k: v(s) for k, v in self.param_funcs.items()}
        update_params['rho'] = self._rho(s, a)
        # explicitly supplied values take precedence over the defaults
        update_params.update(params)
        # positional arguments in the order the algorithm declares them
        args = [update_params[k] for k in self.algo.update_params]

        # function approximation
        x = self.phi(s)
        xp = self.phi(sp)

        return self.algo.update(x, r, xp, *args)

    @property
    def theta(self):
        """The wrapped algorithm's weight vector."""
        return self.algo.theta

    def get_values(self, states):
        """Return a dict mapping each state to its value ``theta . phi(state)``."""
        theta = self.theta
        return {s: np.dot(theta, self.phi(s)) for s in states}

In [43]:
def run_many(agent_lst, behavior, env, max_steps):
    """Run one episode under a shared behavior policy, updating every agent.

    The environment is reset, then actions are drawn from ``behavior`` until
    the environment reports a terminal state or ``max_steps`` steps have been
    taken. Every agent in ``agent_lst`` is updated with each transition.

    Parameters
    ----------
    agent_lst : list
        Agents with an ``update(s, a, r, sp)`` method.
    behavior : object
        Policy with a ``choose(s, actions)`` method.
    env : object
        Environment exposing ``reset()``, ``state``, ``is_terminal()``,
        ``actions`` and ``do(a) -> (r, sp)``.
    max_steps : int
        Maximum number of transitions to generate.

    Returns
    -------
    dict
        ``'transitions'``: list of ``(s, a, r, sp)`` tuples in order;
        ``'deltas'``: one list per step holding each agent's ``update`` return
        value, in ``agent_lst`` order; ``'num_steps'``: steps actually taken.
    """
    transitions = []
    deltas = []
    t = 0

    # reset the environment and get initial state
    env.reset()
    s = env.state
    while not env.is_terminal() and t < max_steps:
        a = behavior.choose(s, env.actions)
        r, sp = env.do(a)

        # update every agent from the same transition, keeping their results
        deltas.append([agent.update(s, a, r, sp) for agent in agent_lst])

        # record the transition
        transitions.append((s, a, r, sp))

        # prepare for next iteration
        t += 1
        s = sp

    return {'transitions': transitions, 'deltas': deltas, 'num_steps': t}

In [48]:
def make_agents(algo_lst, target, phi, update_params):
    """Create one HordeAgent for each algorithm class in ``algo_lst``.

    Every class is instantiated with ``phi.length`` as its only argument. The
    resulting agent receives only those entries of ``update_params`` whose
    keys appear in that algorithm's ``update_params``. This is brittle: it
    assumes every class has that constructor signature and attribute.
    """
    def _build(algo_cls):
        # instantiate the algorithm and keep only the parameters it declares
        algorithm = algo_cls(phi.length)
        relevant = {
            name: value
            for name, value in update_params.items()
            if name in algorithm.update_params
        }
        return HordeAgent(algorithm, target, phi, relevant)

    return [_build(algo_cls) for algo_cls in algo_lst]

In [49]:
# Experiment size: number of states in the Chicken environment.
num_states = 8

# Environment under study.
env = chicken.Chicken(num_states)

# Default update parameters shared by every agent; make_agents forwards to
# each algorithm only the keys listed in its own update_params.
update_params = dict(
    alpha=0.02,
    beta=0.002,
    gm=0.9,
    gm_p=0.9,
    lm=0.0,
    lm_p=0.0,
    interest=1.0,
)

# Target policy: action 0 with probability 1 in every state.
pol_pi = policy.FixedPolicy({state: {0: 1} for state in env.states})

# Feature mapping over the integer states. `num_features` is only relevant to
# the alternative features.RandomBinary mapping, which is not active here.
num_features = 8
phi = features.Int2Unary(num_states)

# One HordeAgent per algorithm, all sharing the target policy and features.
agent_lst = make_agents([algos.TD, algos.ETD], pol_pi, phi, update_params)

In [50]:
# Horde run: every agent learns from one shared stream of behavior transitions.
max_steps = 500

# Behavior policy: always action 0 in states below 4, uniform over {0, 1} elsewhere.
pol_mu = policy.FixedPolicy({
    state: ({0: 1} if state < 4 else {0: 0.5, 1: 0.5})
    for state in env.states
})

# NOTE(review): `run_horde` is not defined in this notebook, so it must come
# from `from rlbench import *`. The locally defined `run_many` has the same
# (agent_lst, behavior, env, max_steps) signature — confirm which is intended.
data = run_horde(agent_lst, pol_mu, env, max_steps)

In [None]:
# set up algorithm parameters
update_params = {
    'alpha': 0.02,
    'beta': 0.002,
    'gm': 0.9,
    'gm_p': 0.9,
    'lm': 0.0,
    'lm_p': 0.0,
    'interest': 1.0,
}

# Target policy: always take action 0.
pol_pi = policy.FixedPolicy({s: {0: 1} for s in env.states})
# Behavior policy: action 0 below state 4, uniform over {0, 1} elsewhere.
pol_mu = policy.FixedPolicy({s: {0: 1} if s < 4 else {0: 0.5, 1: 0.5} for s in env.states})

# Visible range of the error axis (larger errors are clipped from view).
ERROR_YLIM = (0, 2)

# Run all available algorithms
max_steps = 50000
for name, alg in algos.algo_registry.items():
    # Build a fresh agent for each measurement. Assuming run_errors trains the
    # agent in place (TODO confirm), reusing one agent would make the MSPBE run
    # start from the weights already learned during the MSE run.
    mse_agent = OffPolicyAgent(alg(phi.length), pol_pi, pol_mu, phi, update_params)
    mse_lst = run_errors(mse_agent, env, max_steps, mse_values)
    mspbe_agent = OffPolicyAgent(alg(phi.length), pol_pi, pol_mu, phi, update_params)
    mspbe_lst = run_errors(mspbe_agent, env, max_steps, mspbe_values)

    # Plot the errors; size the x-axis from the returned data rather than
    # assuming exactly max_steps entries (episodes may end early).
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(mse_lst)), mse_lst, label='MSE')
    ax.plot(np.arange(len(mspbe_lst)), mspbe_lst, label='MSPBE')

    # Format and label the graph
    ax.set(ylim=ERROR_YLIM, title=name, xlabel='Timestep', ylabel='Error')
    ax.legend()
    plt.show()
    # release the figure; many are created in this loop
    plt.close(fig)