In [21]:
from gym_runner import GymRunner
from q_func_approx import QualityFuncApprox
from agents.sarsa_agent import SarsaAgent
import pandas as pd
import altair as alt 

In [10]:
runner = GymRunner('CartPole-v1', display = True)

In [11]:
num_actions = runner.env.action_space.n


In [12]:
num_states = runner.env.observation_space.shape[0]
num_states

4

In [23]:
q_func = QualityFuncApprox(
    state_space=num_states,
    num_actions=num_actions,
    optimizer = 'sgd', 
    loss_func='l1', 
    alpha = 0.01, 
)

In [24]:
agent = SarsaAgent(
    q_func_approx=q_func,
    state_space=num_states,
    num_actions=num_actions,
    gamma = .9225,
    epsilon=1.0,
    epsilon_decay=.9975,
)

In [25]:
rewards = runner.train(agent = agent, num_episodes=1000)

Epsilon:  0.1
Current Reward:  500.0
Episode:  990


In [26]:
rewards

[30.0,
 12.0,
 21.0,
 19.0,
 14.0,
 37.0,
 17.0,
 16.0,
 29.0,
 10.0,
 13.0,
 28.0,
 20.0,
 34.0,
 21.0,
 56.0,
 14.0,
 29.0,
 21.0,
 17.0,
 29.0,
 31.0,
 16.0,
 21.0,
 16.0,
 18.0,
 16.0,
 23.0,
 16.0,
 11.0,
 10.0,
 13.0,
 17.0,
 11.0,
 24.0,
 25.0,
 12.0,
 15.0,
 13.0,
 27.0,
 37.0,
 15.0,
 12.0,
 16.0,
 48.0,
 74.0,
 57.0,
 14.0,
 24.0,
 12.0,
 14.0,
 38.0,
 26.0,
 14.0,
 15.0,
 26.0,
 15.0,
 36.0,
 24.0,
 9.0,
 13.0,
 23.0,
 36.0,
 33.0,
 51.0,
 45.0,
 25.0,
 26.0,
 14.0,
 28.0,
 12.0,
 33.0,
 66.0,
 33.0,
 13.0,
 28.0,
 30.0,
 73.0,
 78.0,
 47.0,
 15.0,
 36.0,
 51.0,
 10.0,
 36.0,
 22.0,
 80.0,
 31.0,
 11.0,
 15.0,
 19.0,
 45.0,
 54.0,
 21.0,
 42.0,
 27.0,
 11.0,
 38.0,
 31.0,
 23.0,
 26.0,
 22.0,
 38.0,
 48.0,
 19.0,
 18.0,
 17.0,
 58.0,
 14.0,
 45.0,
 16.0,
 55.0,
 54.0,
 57.0,
 34.0,
 34.0,
 49.0,
 20.0,
 35.0,
 27.0,
 34.0,
 27.0,
 14.0,
 54.0,
 54.0,
 76.0,
 45.0,
 13.0,
 28.0,
 13.0,
 147.0,
 37.0,
 17.0,
 39.0,
 28.0,
 98.0,
 26.0,
 14.0,
 24.0,
 24.0,
 34.0,
 27.0,
 26.0,

In [27]:
test_rewards = runner.attempt(agent, num_episodes=100)

In [28]:
rewards = pd.DataFrame(rewards).reset_index()
rewards.columns = ["episode", "reward"]

test_rewards = pd.DataFrame(test_rewards).reset_index()
test_rewards.columns = ["episode", "reward"]

In [29]:
alt.Chart(rewards).mark_point().encode(x="episode", y="reward") | alt.Chart(
    test_rewards
).mark_point().encode(x="episode", y="reward")

In [30]:
test_rewards

Unnamed: 0,episode,reward
0,0,500.0
