In [None]:
import random

import numpy as np
import torch

# Global seed for the whole notebook. Every RNG the experiment might draw from is seeded
# so that Restart Kernel -> Run All reproduces the plots.
# NOTE(review): originally only torch was seeded; whether GridWorldEnv or the agents use
# Python's `random` or NumPy is not visible here. Seeding them too is harmless.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

In [None]:
from refl.agents import ReinforceAgent, BaselineReinforceAgent
from refl.envs import GridWorldEnv
import pandas as pd
import plotly.express as px

In [None]:
# A single GridWorld instance, reused to train and evaluate every agent below.
# NOTE(review): any state it keeps between runs is shared across agents; confirm it resets per episode.
env = GridWorldEnv()

In [None]:
# Discount factors swept in each section; one agent is trained per value.
GAMMAS = [
    1.0,
    0.99,
    0.95,
]

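## REINFORCE

Train a REINFORCE agent for each discount factor in `GAMMAS`, then plot the training curves and the evaluation returns.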

In [None]:
# Training budget and empty result accumulators for the REINFORCE sweep.
N_EPISODES = 500
exp_avgs, final_returns = [], []

In [None]:
# Train one REINFORCE agent per discount factor, then record its evaluation returns.
# The accumulators are reset here so that re-running this cell does not append duplicate rows.
exp_avgs = []
final_returns = []
for gamma in GAMMAS:
    # Re-seed torch before every run so each gamma starts from the same RNG state,
    # no matter which cells ran before. This matches the global seed of 0.
    # NOTE(review): assumes agent init and action sampling draw from torch's RNG.
    torch.manual_seed(0)
    agent = ReinforceAgent(n_state_dims=env.n_state_dims, n_latent_dims=128, n_actions=env.n_actions, gamma=gamma)
    # NOTE(review): 9.5 is an unnamed positional argument of learn(). It is presumably a
    # target average return; confirm and give it a named constant.
    avgs = agent.learn(env, N_EPISODES, 9.5)
    exp_avgs.extend(avgs)
    # Presumably 10 evaluation episodes; the enumerate index becomes the plot's x-axis.
    eval_returns = agent.evaluate(env, 10)
    final_returns.extend(
        {'Gamma': gamma, 'Return': ret, 'Episode': ep}
        for ep, ret in enumerate(eval_returns)
    )

In [None]:
# One row per training record returned by learn(); the training plot needs these three columns.
df = pd.DataFrame.from_records(exp_avgs)
assert {"Episode", "AvgReturn", "Gamma"} <= set(df.columns), f"unexpected training columns: {list(df.columns)}"

In [None]:
# Training curves: the `AvgReturn` column per training episode, one line per gamma.
# The title names the algorithm so this figure stands alone when the notebook is skimmed.
fig = px.line(
    df,
    x="Episode",
    y="AvgReturn",
    color="Gamma",
    title=f"REINFORCE training (N_EPISODES={N_EPISODES})",
)
fig.show()

In [None]:
# Per-episode evaluation returns of each trained REINFORCE agent, one line per gamma.
# A separate name (instead of overwriting `df`) keeps the training frame intact,
# so the training plot cell can be re-run safely.
eval_df = pd.DataFrame.from_records(final_returns)
fig = px.line(eval_df, x="Episode", y="Return", color="Gamma", title="REINFORCE evaluation")
fig.show()

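## REINFORCE with baseline

Repeat the same sweep over `GAMMAS` with `BaselineReinforceAgent` and a larger training budget (`N_EPISODES = 800`).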

In [None]:
# Larger training budget and fresh result accumulators for the baseline sweep.
N_EPISODES = 800
exp_avgs, final_returns = [], []

In [None]:
# Train one baseline-REINFORCE agent per discount factor, then record its evaluation returns.
# The accumulators are reset here so that re-running this cell does not append duplicate rows.
exp_avgs = []
final_returns = []
for gamma in GAMMAS:
    # Re-seed torch before every run so each gamma starts from the same RNG state,
    # independent of whether the REINFORCE section ran first. This matches the global seed of 0.
    # NOTE(review): assumes agent init and action sampling draw from torch's RNG.
    torch.manual_seed(0)
    agent = BaselineReinforceAgent(n_state_dims=env.n_state_dims, n_latent_dims=128, n_actions=env.n_actions, gamma=gamma)
    # NOTE(review): 9.5 is an unnamed positional argument of learn(). It is presumably a
    # target average return; confirm and give it a named constant.
    avgs = agent.learn(env, N_EPISODES, 9.5)
    exp_avgs.extend(avgs)
    # Presumably 10 evaluation episodes; the enumerate index becomes the plot's x-axis.
    eval_returns = agent.evaluate(env, 10)
    final_returns.extend(
        {'Gamma': gamma, 'Return': ret, 'Episode': ep}
        for ep, ret in enumerate(eval_returns)
    )

In [None]:
# One row per training record returned by learn(); the training plot needs these three columns.
df = pd.DataFrame.from_records(exp_avgs)
assert {"Episode", "AvgReturn", "Gamma"} <= set(df.columns), f"unexpected training columns: {list(df.columns)}"

In [None]:
# Training curves: the `AvgReturn` column per training episode, one line per gamma.
# The title names the algorithm so this figure is distinguishable from the REINFORCE one.
fig = px.line(
    df,
    x="Episode",
    y="AvgReturn",
    color="Gamma",
    title=f"Baseline REINFORCE training (N_EPISODES={N_EPISODES})",
)
fig.show()

In [None]:
# Per-episode evaluation returns of each trained baseline agent, one line per gamma.
# A separate name (instead of overwriting `df`) keeps the training frame intact,
# so the training plot cell can be re-run safely.
eval_df = pd.DataFrame.from_records(final_returns)
fig = px.line(eval_df, x="Episode", y="Return", color="Gamma", title="Baseline REINFORCE evaluation")
fig.show()