In [1]:
from gym_runner import GymRunner
from q_func_approx import QualityFuncApprox
from agents.sarsa_agent import SarsaAgent
import pandas as pd
import altair as alt 

In [2]:
runner = GymRunner('CartPole-v1', display_metrics = True)

In [3]:
num_actions = runner.env.action_space.n


In [4]:
num_states = runner.env.observation_space.shape[0]
num_states

4

In [5]:
q_func = QualityFuncApprox(
    num_states=num_states,
    num_actions=num_actions,
    optimizer = 'sgd', 
    loss_func='l1', 
    alpha = 0.01, 
)

In [6]:
agent = SarsaAgent(
    q_func_approx=q_func,
    num_states=num_states,
    num_actions=num_actions,
    gamma = .9225,
    epsilon=1.0,
    epsilon_decay=.9975,
)

In [7]:
rewards = runner.train(agent = agent, num_episodes=1000, plot_state=False)

Epsilon:  0.1
Current Reward:  500.0
Episode:  990


In [8]:
rewards

[18.0,
 22.0,
 11.0,
 13.0,
 17.0,
 33.0,
 15.0,
 17.0,
 13.0,
 14.0,
 20.0,
 31.0,
 23.0,
 26.0,
 21.0,
 23.0,
 12.0,
 17.0,
 20.0,
 21.0,
 18.0,
 32.0,
 22.0,
 42.0,
 35.0,
 31.0,
 11.0,
 15.0,
 23.0,
 36.0,
 13.0,
 56.0,
 14.0,
 25.0,
 36.0,
 25.0,
 12.0,
 22.0,
 73.0,
 10.0,
 13.0,
 37.0,
 14.0,
 12.0,
 46.0,
 17.0,
 22.0,
 10.0,
 12.0,
 15.0,
 12.0,
 15.0,
 58.0,
 34.0,
 41.0,
 19.0,
 18.0,
 39.0,
 9.0,
 36.0,
 51.0,
 18.0,
 36.0,
 14.0,
 26.0,
 27.0,
 17.0,
 16.0,
 31.0,
 26.0,
 21.0,
 55.0,
 13.0,
 18.0,
 29.0,
 20.0,
 24.0,
 36.0,
 14.0,
 33.0,
 27.0,
 17.0,
 13.0,
 67.0,
 16.0,
 91.0,
 44.0,
 90.0,
 66.0,
 44.0,
 56.0,
 27.0,
 12.0,
 36.0,
 14.0,
 26.0,
 33.0,
 41.0,
 47.0,
 19.0,
 15.0,
 28.0,
 38.0,
 41.0,
 16.0,
 40.0,
 86.0,
 36.0,
 21.0,
 12.0,
 32.0,
 66.0,
 45.0,
 110.0,
 22.0,
 21.0,
 48.0,
 52.0,
 12.0,
 26.0,
 84.0,
 31.0,
 20.0,
 28.0,
 35.0,
 16.0,
 47.0,
 23.0,
 25.0,
 30.0,
 16.0,
 19.0,
 32.0,
 26.0,
 28.0,
 31.0,
 18.0,
 18.0,
 35.0,
 23.0,
 17.0,
 44.0,
 29.0,

In [9]:
test_rewards = runner.attempt(agent, num_episodes=100)

In [10]:
test_rewards

array([500., 500., 500., 500., 500., 500., 500., 500., 402., 391., 500.,
       500., 500., 500., 500., 500., 500., 500., 482., 500., 500., 384.,
       500., 500., 382., 500., 294., 500., 500., 500., 500., 500., 332.,
       500., 379., 500., 500., 500., 500., 500., 404., 500., 500., 500.,
       500., 452., 500., 500., 500., 500., 500., 500., 383., 342., 500.,
       363., 296., 500., 443., 500., 430., 339., 500., 376., 352., 500.,
       388., 500., 500., 500., 500., 324., 500., 500., 244., 500., 313.,
       500., 380., 500., 500., 500., 500., 256., 500., 500., 500., 400.,
       500., 500., 400., 500., 440., 384., 410., 500., 500., 500., 500.,
       500.])

In [11]:
rewards = pd.DataFrame(rewards).reset_index()
rewards.columns = ["episode", "reward"]

test_rewards = pd.DataFrame(test_rewards).reset_index()
test_rewards.columns = ["episode", "reward"]

In [12]:
alt.Chart(rewards).mark_point().encode(x="episode", y="reward") | alt.Chart(
    test_rewards
).mark_point().encode(x="episode", y="reward")

In [13]:
test_rewards

Unnamed: 0,episode,reward
0,0,500.0
1,1,500.0
2,2,500.0
3,3,500.0
4,4,500.0
...,...,...
95,95,500.0
96,96,500.0
97,97,500.0
98,98,500.0
