In [None]:
# You are allowed to use the following modules
import numpy as np
import matplotlib.pyplot as plt
from mountain_car import MountainCar
import pygame as pg
from itertools import product, count
from collections import deque
from tqdm import tqdm
from IPython.display import clear_output

"""
Comment if you get an error, or install the following library to use latex with matplotlib on Linux:
sudo apt-get install texlive-latex-extra texlive-fonts-recommended dvipng cm-super
"""
# Ask matplotlib to typeset all figure text with LaTeX. This is the line the note
# above refers to: comment it out if rendering fails because LaTeX is not installed.
plt.rcParams['text.usetex'] = True


In [None]:
def _x_s(s: np.array):
    """
    Build the Fourier-basis feature vector x(s) for state ``s``.

    One feature cos(pi * s . c) is produced for every coefficient vector c
    in {0, ..., D}^K, in the order generated by ``itertools.product``, and
    written into a vector of length NUM_FEATURES.
    """
    features = np.zeros(NUM_FEATURES)
    coefficient_vectors = product(range(D + 1), repeat=K)
    for index, coeffs in enumerate(coefficient_vectors):
        # pi scales the state first, then the dot product with c is taken
        features[index] = np.cos((np.pi * s.T) @ np.array(coeffs))
    return features


def x_s(s: np.array):
    """
    Memoised Fourier-basis features x(s).

    The state is turned into a tuple and used as the key into the
    module-level BASIS cache. The features are computed with ``_x_s``
    only on a cache miss and stored for later calls; the cached array
    itself (not a copy) is returned.
    """
    key = tuple(s)
    if key not in BASIS:
        BASIS[key] = _x_s(s)
    return BASIS[key]

def x_sa(s: np.array, a: int):
    """
    State-action features x(s, a).

    The result has NUM_FEATURES * NUM_ACTIONS entries: x(s) is copied into
    the block belonging to action index ``a`` and every other entry is zero.
    """
    state_features = x_s(s)
    stacked = np.zeros(NUM_FEATURES * NUM_ACTIONS)
    offset = a * NUM_FEATURES
    stacked[offset: offset + NUM_FEATURES] = state_features
    return stacked
    
    
def h_s(s: np.array, theta: np.array):
    """
    Action preferences in state ``s``.

    Entry ``a`` of the returned length-NUM_ACTIONS array is the dot
    product of ``theta`` with the state-action features x(s, a).
    """
    preferences = [theta @ x_sa(s, action) for action in range(NUM_ACTIONS)]
    return np.array(preferences, dtype=float)

def pi_s(s: np.array, theta: np.array):
    """Soft-max policy over the action preferences h(s, .) given by ``theta``."""
    preferences = h_s(s, theta)
    # subtract the largest preference before exponentiating so no exponent is positive
    shifted = preferences - np.max(preferences)
    weights = np.exp(shifted)
    return weights / np.sum(weights)

def v_s(s: np.array, w: np.array):
    """Linear state-value estimate: the dot product of weights ``w`` with the features x(s)."""
    features = x_s(s)
    return w @ features

def get_action(s, theta):
    """
    Sample an action index in state ``s`` from the policy defined by ``theta``.

    Returns a pair ``(action_index, policy)``, where ``policy`` is the
    probability vector the action was drawn from.
    """
    action_probabilities = pi_s(s, theta)
    chosen = np.random.choice(range(NUM_ACTIONS), p=action_probabilities)
    return chosen, action_probabilities

def get_pi_gradient(s, a, policy):
    r"""
    Gradient of ln pi(a|s, theta) with respect to theta.

    Computed as x(s, a) - \sum_b \pi(b|s, theta) x(s, b), i.e. the features
    of the chosen action minus the policy-weighted average of the features
    of all actions.

    Parameters
    ----------
    s : state passed through to ``x_sa``.
    a : index of the action that was taken.
    policy : sequence indexed by action; ``policy[b]`` is used as pi(b|s, theta)
        for b in range(NUM_ACTIONS).

    Returns
    -------
    Array with the same layout as ``x_sa(s, a)``.
    """
    # The docstring is a raw string so that "\sum" and "\pi" are not parsed as escapes.
    expected_features = sum(policy[b] * x_sa(s, b) for b in range(NUM_ACTIONS))
    return x_sa(s, a) - expected_features



In [None]:
def actor_critic_et(num_episodes, gamma=1, lambda_w=0.8, lambda_theta=0.8,
                    alpha_w=1e-3, alpha_theta=1e-3):
    """
    Episodic actor-critic with eligibility traces, using a linear state-value
    estimate and a soft-max policy over linear action preferences.

    Parameters
    ----------
    num_episodes : number of episodes to run.
    gamma : discount factor.
    lambda_w, lambda_theta : trace-decay rates for the critic and actor traces.
    alpha_w, alpha_theta : step sizes for the critic and actor weights.

    Returns
    -------
    theta : learned policy weights, length NUM_ACTIONS * NUM_FEATURES.
    traj : first state component of every state visited in the LAST episode,
        starting state included (empty list when ``num_episodes`` is 0).
    steps_per_e : array with the number of environment steps taken in each episode.

    Notes
    -----
    Relies on the module-level ``car``, ``A``, ``xv_to_s`` and ``s_to_xv``.
    An episode only ends when ``car.move`` reports that the goal was reached.
    """
    theta = np.zeros(NUM_ACTIONS * NUM_FEATURES)  # policy weights, one block per action
    W = np.zeros(NUM_FEATURES)  # weights for estimating v_s

    steps_per_e = np.zeros(num_episodes)
    traj = []  # keeps the return statement valid when no episode is run

    for episode in range(num_episodes):
        # initialize s: first component uniform in [-0.6, -0.4), second component 0
        # (presumably position and velocity, given the x, v names used below)
        xv = np.array((np.random.uniform(-0.6, -0.4), 0))
        s = xv_to_s(xv)

        # reset the eligibility traces
        z_theta = np.zeros_like(theta)
        z_w = np.zeros_like(W)

        # reset the accumulated discount gamma^t
        I = 1

        # reset trajectory
        traj = [xv[0]]

        # loop through the episode; t is the zero-based index of the current step
        for t in count():
            # select action
            a, policy = get_action(s, theta)

            # take action, observe reward and next state
            x, v = s_to_xv(s)
            xp, vp, r, goal_reached = car.move(x, v, A[a])
            sp = xv_to_s(np.array((xp, vp)))
            traj.append(xp)

            # TD error; the value of the terminal state is defined as 0
            v_sp = 0 if goal_reached else v_s(sp, W)
            delta = r + gamma * v_sp - v_s(s, W)

            # critic trace: the gradient of the linear value estimate is x(s)
            z_w = gamma * lambda_w * z_w + x_s(s)

            # actor trace: gradient of ln pi weighted by the accumulated discount
            gradient = get_pi_gradient(s, a, policy)
            z_theta = gamma * lambda_theta * z_theta + I * gradient

            # update critic and actor weights
            W += alpha_w * delta * z_w
            theta += alpha_theta * delta * z_theta

            if goal_reached:
                car.goal_reached = False
                break

            I *= gamma
            s = sp

        # t is zero-based, so the number of steps taken is t + 1
        steps_per_e[episode] = t + 1

    return theta, traj, steps_per_e
            