In [None]:
import numpy as np
import torch 
import torch.nn as nn 
from torch.distributions import Categorical
import matplotlib.pyplot as plt 
from matplotlib import rcParams
rcParams['font.size'] = 24
rcParams['figure.figsize'] = (16, 8)
from tqdm import tqdm 

import importlib 
import ipywidgets
from ipywidgets import interact
from IPython.display import Image
import IPython

from rllib.dataset.datatypes import Observation
from rllib.util.utilities import get_entropy_and_log_p

from rllib.util.training.agent_training import train_agent
from rllib.environment import GymEnvironment
from rllib.environment.mdps import EasyGridWorld
from rllib.policy import TabularPolicy
from rllib.value_function import TabularQFunction, TabularValueFunction

In [None]:
def extract_policy(q_function):
    """Return the greedy tabular policy induced by ``q_function``.

    For each state index the action with the largest Q-value is stored in a
    freshly created ``TabularPolicy`` of matching size.
    """
    greedy = TabularPolicy(num_states=q_function.num_states,
                           num_actions=q_function.num_actions)
    for s in range(greedy.num_states):
        best_action = q_function(torch.tensor(s).long()).argmax()
        greedy.set_value(s, best_action)
    return greedy

def integrate_q(q_function, policy):
    """Average ``q_function`` over ``policy`` to get a tabular value function.

    For every state ``s`` this stores ``sum_a pi(a|s) * Q(s, a)``, where
    ``pi(.|s)`` is the categorical distribution whose logits are ``policy(s)``.
    """
    v_table = TabularValueFunction(num_states=q_function.num_states)
    for s in range(policy.num_states):
        s_tensor = torch.tensor(s).long()
        action_probs = Categorical(logits=policy(s_tensor)).probs
        expected_q = sum(
            action_probs[a] * q_function(s_tensor, torch.tensor(a).long())
            for a in range(policy.num_actions)
        )
        v_table.set_value(s_tensor, expected_q)
    return v_table


# Grid world used by the tabular demos below.
environment = EasyGridWorld()
# `Image(...)` is not the last expression of this cell, so it must be
# displayed explicitly or the picture never appears in the output.
IPython.display.display(Image("images/grid_world.png"))

# Plotters

def policy2str(policy):
    """Map an action index to an arrow glyph for plotting.

    0 -> down arrow, 1 -> up arrow, 2 -> right arrow, 3 -> left arrow; any
    other value yields an empty string. (Whether these glyphs match the
    environment's action semantics is not checked here.)
    """
    glyphs = (
        (0, u'\u2193'),  # down
        (1, u'\u2191'),  # up
        (2, u'\u2192'),  # right
        (3, u'\u2190'),  # left
    )
    return "".join(glyph for code, glyph in glyphs if code == policy)

def plot_value_function(value_function, ax, fontsize=24):
    """Draw a 2-D value table as an image with each cell's value printed on it.

    Args:
        value_function: 2-D array indexed as ``value_function[row, col]``.
        ax: matplotlib Axes to draw on.
        fontsize: font size of the printed values.
    """
    ax.imshow(value_function)
    rows, cols = value_function.shape
    for row in range(rows):
        for col in range(cols):
            # imshow places array column `col` at x=col and row `row` at y=row,
            # so the label for entry [row, col] goes at (x=col, y=row).
            ax.text(col, row, f"{value_function[row, col]:.1f}",
                    ha="center", va="center", color="w", fontsize=fontsize)

def plot_policy(policy, ax, fontsize=24):
    """Draw one action arrow per grid cell on ``ax``.

    Args:
        policy: 2-D array of action indices, indexed as ``policy[row, col]``.
        ax: matplotlib Axes to draw on.
        fontsize: font size of the arrow glyphs.
    """
    rows, cols = policy.shape
    # Blank background so the arrows sit on the same pixel grid as imshow.
    ax.imshow(np.zeros((rows, cols)))
    # Iterate over the array's own shape rather than the global `environment`,
    # which later cells rebind.
    for row in range(rows):
        for col in range(cols):
            ax.text(col, row, policy2str(policy[row, col]),
                    ha="center", va="center", color="r", fontsize=fontsize)


def plot_value_and_policy(value_function, policy):
    """Draw a value table and a policy side by side in a new figure.

    Args:
        value_function: 2-D array of state values (left panel).
        policy: 2-D array of action indices, drawn as arrows (right panel).

    Returns:
        The created matplotlib Figure. It is also the current figure, so
        existing callers that use ``plt.gcf()`` keep working.
    """
    fig, (value_ax, policy_ax) = plt.subplots(ncols=2, nrows=1, figsize=(20, 8))

    plot_value_function(value_function, value_ax)
    value_ax.set_title("Value function")

    plot_policy(policy, policy_ax)
    policy_ax.set_title("Policy")

    return fig

# Tabular Q Learning 

In [None]:
def q_learning(gamma=0.9, alpha=0.5, eps=0., optimistic_init=False):
    """Interactive tabular Q-learning on the grid world.

    Creates a fresh tabular Q-function, draws the value function and greedy
    policy it induces, and shows "Step 100" / "Step 1000" buttons that run
    more Q-learning updates and redraw.

    Args:
        gamma: discount factor.
        alpha: step size of the TD update.
        eps: probability of taking a uniformly random action (epsilon-greedy).
        optimistic_init: if True, every Q-value starts at 10 / (1 - gamma)
            instead of 1.
    """
    # Bind the environment once: the global `environment` may be rebound by
    # later cells, and this demo's buttons must keep stepping the same instance.
    env = environment
    q_function = TabularQFunction(
        num_states=env.num_states, num_actions=env.num_actions)
    nn.init.ones_(q_function.nn.head.weight)

    if optimistic_init:
        # Optimistic initialization: every entry becomes 10 / (1 - gamma).
        q_function.nn.head.weight.data = 10 / \
            (1 - gamma) * q_function.nn.head.weight.data

    # Current state, shared with the callbacks through `nonlocal` instead of a
    # module global, so other cells cannot clobber it.
    state = env.reset()

    def select_action(s):
        """Epsilon-greedy action for state ``s`` under the current Q-function."""
        if np.random.rand() < eps:
            return np.random.choice(env.num_actions)
        return torch.argmax(q_function(torch.tensor(s).long())).item()

    def step(num_iter):
        """Run ``num_iter`` Q-learning updates, then redraw."""
        nonlocal state
        # Purely tabular updates: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(num_iter):
                action = select_action(state)
                q_val = q_function(torch.tensor(state).long(),
                                   torch.tensor(action).long())

                next_state, reward, done, _info = env.step(action)

                # Off-policy target: greedy value of the next state, or 0 when
                # the transition ends the episode.
                if done:
                    next_q = 0.
                else:
                    next_q = torch.max(q_function(torch.tensor(next_state).long()))
                td = torch.tensor(reward).double() + gamma * next_q - q_val

                q_function.set_value(state, action, q_val + alpha * td)

                # Start a new episode after a terminal transition.
                state = env.reset() if done else next_state

        plot()

    def plot():
        """Redraw the greedy value function / policy and the step buttons."""
        policy = extract_policy(q_function)
        value_function = integrate_q(q_function, policy)

        # NOTE(review): assumes states are numbered row-major over
        # (height, width) — TODO confirm; for the 25-state grid this matches
        # the previously hard-coded 5x5 layout.
        grid_shape = (env.height, env.width)
        plot_value_and_policy(value_function.table.reshape(*grid_shape).detach().numpy(),
                              policy.table.argmax(0).reshape(*grid_shape).detach().numpy())

        # Replace the previous output, show the new figure, then free it.
        IPython.display.clear_output(wait=True)
        IPython.display.display(plt.gcf())
        plt.close()

        for num_iter in (100, 1000):
            button = ipywidgets.Button(description=f"Step {num_iter}")
            button.on_click(lambda _button, n=num_iter: step(num_iter=n))
            IPython.display.display(button)

    plot()


# One control per keyword argument of `q_learning`; sliders only fire on release.
interact(q_learning, **{
    "gamma": ipywidgets.FloatSlider(value=0.9, min=0., max=0.99, step=1e-2, continuous_update=False),
    "alpha": ipywidgets.FloatSlider(value=0.5, min=0., max=2.0, step=1e-2, continuous_update=False),
    "eps": ipywidgets.FloatSlider(value=0., min=0., max=1.0, step=1e-2, continuous_update=False),
    "optimistic_init": ipywidgets.Checkbox(value=False),
});

# Q Learning with function approximation

- Q Learning: Approximate Q with a parametric function. 
- DQN: Approximate Q with a parametric function and use a target network to compute the TD targets. See https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
- DDQN: Approximate Q with a parametric function; select the maximizing next action with the online network and evaluate it with the target network. See https://arxiv.org/pdf/1509.06461.pdf 

In [None]:
from rllib.policy import EpsGreedy, SoftMax
from rllib.util.parameter_decay import ExponentialDecay

def run(env_name, agent_name, exploration):
    """Train an rllib value-based agent on a Gym environment and plot its returns.

    Args:
        env_name: Gym environment id, e.g. "CartPole-v0".
        agent_name: prefix of an ``rllib.agent`` class; ``f"{agent_name}Agent"``
            is looked up there and built with ``.default(environment)``.
        exploration: "eps-greedy" or "softmax"; selects the behaviour policy
            (``EpsGreedy`` or ``SoftMax``) built on the agent's critic.

    Raises:
        ValueError: if ``exploration`` is not one of the supported names.
    """
    # Episode cap: 200 steps for CartPole, 1000 for the other environments.
    max_steps = 200 if "CartPole" in env_name else 1000

    policy_classes = {"eps-greedy": EpsGreedy, "softmax": SoftMax}
    # Validate before creating the environment so nothing needs cleaning up.
    if exploration not in policy_classes:
        raise ValueError(
            f"Unknown exploration {exploration!r}; expected one of {sorted(policy_classes)}.")

    environment = GymEnvironment(env_name)
    try:
        agent = getattr(
            importlib.import_module("rllib.agent"),
            f"{agent_name}Agent"
        ).default(environment)

        # Exploration parameter schedule (presumably epsilon / softmax
        # temperature) decaying from 1.0 toward 0.01 — see rllib ExponentialDecay.
        decay = ExponentialDecay(start=1.0, end=0.01, decay=500)
        agent.set_policy(policy_classes[exploration](agent.algorithm.critic, decay))

        try:
            train_agent(environment=environment, agent=agent, num_episodes=50,
                        max_steps=max_steps, render=True, plot_flag=False)
        except KeyboardInterrupt:
            # Deliberate: interrupting the kernel stops training early but
            # still shows the learning curve collected so far.
            pass
    finally:
        # Always release the (rendering) environment, even if training fails.
        environment.close()

    IPython.display.clear_output()

    plt.plot(agent.logger.get("train_return-0"), linewidth=16)
    plt.title(f"{agent_name} on {env_name} ({exploration})")
    plt.xlabel("Episode")
    plt.ylabel("Return")
    
# One dropdown per argument of `run`; each list holds the selectable values.
interact(run, **{
    "env_name": ["CartPole-v0", "Acrobot-v1", "MountainCar-v0"],
    "agent_name": ["QLearning", "DQN", "DDQN"],
    "exploration": ["softmax", "eps-greedy"],
});

# Tabular SARSA 

In [None]:
environment = EasyGridWorld()


def sarsa(gamma=0.9, alpha=0.5, eps=0., optimistic_init=False):
    """Interactive tabular SARSA on the grid world.

    Creates a fresh tabular Q-function, draws the value function and greedy
    policy it induces, and shows "Step 100" / "Step 1000" buttons that run
    more on-policy SARSA updates and redraw.

    Args:
        gamma: discount factor.
        alpha: step size of the TD update.
        eps: probability of taking a uniformly random action (epsilon-greedy).
        optimistic_init: if True, every Q-value starts at 10 / (1 - gamma)
            instead of 1.
    """
    # Bind the environment once so this demo's buttons keep stepping the same
    # instance even if the global `environment` is rebound later.
    env = environment
    q_function = TabularQFunction(num_states=env.num_states, num_actions=env.num_actions)
    nn.init.ones_(q_function.nn.head.weight)
    if optimistic_init:
        # Optimistic initialization: every entry becomes 10 / (1 - gamma).
        q_function.nn.head.weight.data = 10 / (1 - gamma) * q_function.nn.head.weight.data

    def select_action(s):
        """Epsilon-greedy action for state ``s`` under the current Q-function."""
        if np.random.rand() < eps:
            return np.random.choice(env.num_actions)
        return torch.argmax(q_function(torch.tensor(s).long())).item()

    # Current (state, action) pair, shared with the callbacks through
    # `nonlocal` instead of module globals, so other cells cannot clobber it.
    state = env.reset()
    action = select_action(state)

    def step(num_iter):
        """Run ``num_iter`` SARSA updates, then redraw."""
        nonlocal state, action
        # Purely tabular updates: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(num_iter):
                q_val = q_function(torch.tensor(state).long(), torch.tensor(action).long())

                next_state, reward, done, _info = env.step(action)

                # On-policy target: value of the action actually chosen next,
                # or 0 when the transition ends the episode.
                if done:
                    next_q = 0.
                else:
                    next_action = select_action(next_state)
                    next_q = q_function(torch.tensor(next_state).long(),
                                        torch.tensor(next_action).long())
                td = torch.tensor(reward).double() + gamma * next_q - q_val

                q_function.set_value(state, action, q_val + alpha * td)

                if done:
                    # Start a new episode after a terminal transition.
                    state = env.reset()
                    action = select_action(state)
                else:
                    state, action = next_state, next_action

        plot()

    def plot():
        """Redraw the greedy value function / policy and the step buttons."""
        policy = extract_policy(q_function)
        value_function = integrate_q(q_function, policy)

        # NOTE(review): assumes states are numbered row-major over
        # (height, width) — TODO confirm; for the 25-state grid this matches
        # the previously hard-coded 5x5 layout.
        grid_shape = (env.height, env.width)
        plot_value_and_policy(value_function.table.reshape(*grid_shape).detach().numpy(),
                              policy.table.argmax(0).reshape(*grid_shape).detach().numpy())

        # Replace the previous output, show the new figure explicitly (button
        # callbacks also need it), then close it so it is not left open.
        IPython.display.clear_output(wait=True)
        IPython.display.display(plt.gcf())
        plt.close()

        for num_iter in (100, 1000):
            button = ipywidgets.Button(description=f"Step {num_iter}")
            button.on_click(lambda _button, n=num_iter: step(num_iter=n))
            IPython.display.display(button)

    plot()

# One control per keyword argument of `sarsa`; sliders only fire on release.
interact(sarsa, **{
    "gamma": ipywidgets.FloatSlider(value=0.9, min=0., max=0.99, step=1e-2, continuous_update=False),
    "alpha": ipywidgets.FloatSlider(value=0.5, min=0., max=2.0, step=1e-2, continuous_update=False),
    "eps": ipywidgets.FloatSlider(value=0., min=0., max=1.0, step=1e-2, continuous_update=False),
    "optimistic_init": ipywidgets.Checkbox(value=False),
});