In [None]:
# Notebook setup: extend the import path, configure SDL, import dependencies
# and silence noisy warning categories.
import sys, os
# Put the parent directory first on sys.path (presumably so project-local
# packages such as `common` can be imported from this notebook -- confirm).
sys.path.insert(0, os.path.abspath(".."))
# NOTE(review): set before the imports below, presumably so any library that
# initialises SDL picks up the headless driver -- confirm.
os.environ["SDL_VIDEODRIVER"] = "dummy" # for pygame rendering
import torch
import gym
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# NOTE(review): `sys` is already imported at the top of this cell; this second
# import is redundant.
import sys
from pathlib import Path
from common import helper as h

import warnings
# Ignore every UserWarning and DeprecationWarning from here on. This also hides
# deprecation notices raised by the notebook's own code.
warnings.filterwarnings("ignore", category=UserWarning) 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [None]:
# Location of the saved CartPole-v0 policy, resolved against the current
# working directory: <cwd>/results/CartPole-v0/model
policy_dir = Path.cwd().joinpath("results", "CartPole-v0", "model")
policy_dir  # echo the resolved path as the cell output

In [None]:
# Build the agent and restore it from `policy_dir`.
from rbf_agent import RBFAgent

# NOTE(review): the meaning of the positional argument 2 is not visible here --
# presumably the number of discrete actions; confirm against RBFAgent.__init__.
rbf = RBFAgent(2)
# What `load` reads from the directory is defined by RBFAgent.
rbf.load(policy_dir)

In [None]:
# Evaluate the agent's greedy action on a regular grid of states.
#
# Only the first and third state components vary (named x and theta below);
# the second and fourth are held at 0 -- presumably the velocities in
# CartPole's [x, x_dot, theta, theta_dot] layout (confirm).
#
# Result: actions[i, j] is the index of the Q-function with the largest
# prediction for the state built from x_range[i] and theta_range[j].
resolution = 500
npoints = resolution
x_limit = 4.8
theta_limit = 0.5
x_range = np.linspace(-x_limit, x_limit, npoints)
theta_range = np.linspace(-theta_limit, theta_limit, npoints)
actions = np.zeros((npoints, npoints), dtype=np.int32)

for i, x in enumerate(x_range):
    for j, theta in enumerate(theta_range):
        # Raw state -> feature representation expected by the regressors.
        state = rbf.featurize(np.array([x, 0, theta, 0]))

        # One value per Q-function; start at -inf so an unfilled slot can
        # never win the argmax. (`np.inf`, not `np.Inf`: the capitalised
        # alias was removed in NumPy 2.0.)
        q_values = np.full(len(rbf.q_functions), -np.inf)
        for idx, regressor in enumerate(rbf.q_functions):
            # `.item()` extracts the single predicted value explicitly instead
            # of relying on the deprecated implicit array-to-scalar conversion;
            # like the original assignment, it raises ValueError if the
            # prediction holds more than one element.
            q_values[idx] = np.asarray(regressor.predict(state)).item()

        actions[i, j] = q_values.argmax()


In [None]:
# Sanity check: print the shapes of the two grid axes, then show the shape of
# the action grid as the cell output.
print(x_range.shape, theta_range.shape)
actions.shape

In [None]:
# Plot the greedy policy as a heatmap.
#
# `actions[i, j]` is indexed by (x, theta). seaborn puts array rows on the
# vertical axis and columns on the horizontal axis, so plotting `actions`
# as-is gives x vertically and theta horizontally, matching the tick labels
# and axis labels set below. (Plotting the transpose would swap the data
# while keeping these labels, mislabelling both axes.)
num_ticks = 10
tick_skip = max(1, npoints // num_ticks)

# Heatmap tick positions are in cell-index units: cell k spans [k, k + 1], so
# its centre is at k + 0.5, independent of the physical grid spacing.
# x (vertical axis)
x_tick_shift = 0.5
x_tick_points = np.arange(npoints)[::tick_skip] + x_tick_shift
x_tick_labels = x_range.round(2)[::tick_skip]
# theta (horizontal axis)
theta_tick_shift = 0.5
theta_tick_points = np.arange(npoints)[::tick_skip] + theta_tick_shift
theta_tick_labels = theta_range.round(2)[::tick_skip]

fig, ax = plt.subplots()
sns.heatmap(actions, ax=ax)
ax.set_xticks(theta_tick_points)
ax.set_xticklabels(theta_tick_labels, rotation=45)
ax.set_yticks(x_tick_points)
ax.set_yticklabels(x_tick_labels, rotation=45)
ax.set_xlabel(r"$\theta$")
ax.set_ylabel("x")
ax.set_title("Plot for policy - best action in terms of state")
plt.show()
