In [1]:
import argparse
import gym
import numpy as np
from itertools import count
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Hyperparameters are hard-coded in the cells below rather than read with
# argparse, whose parse_args() would trip over the kernel's own command-line
# arguments when run inside Jupyter.
SEED = 0

env = gym.make('CartPole-v1')
env.reset(seed=SEED)     # seed the environment's RNG for reproducible episodes
torch.manual_seed(SEED)  # seed torch: weight init, dropout masks, action sampling


class Policy(nn.Module):
    """Two-layer MLP policy mapping observations to action probabilities.

    The defaults (4 inputs, 2 actions) match the CartPole-v1 environment used in
    this notebook; pass other sizes to reuse the network elsewhere.

    Args:
        n_inputs: size of the observation vector.
        n_actions: number of discrete actions.
        n_hidden: width of the hidden layer.
        dropout: dropout probability applied after the first layer.

    Attributes:
        saved_log_probs: log-probabilities of the actions taken in the current
            episode (filled during rollout, cleared after each update).
        rewards: rewards collected in the current episode.
    """

    def __init__(self, n_inputs=4, n_actions=2, n_hidden=128, dropout=0.6):
        super(Policy, self).__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.affine1 = nn.Linear(n_inputs, n_hidden)
        self.dropout = nn.Dropout(p=dropout)
        self.affine2 = nn.Linear(n_hidden, n_actions)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return action probabilities for observation(s) ``x``.

        Dropout is active in training mode, so outputs are stochastic unless
        ``.eval()`` has been called.

        Args:
            x: tensor whose last dimension has size ``n_inputs``; may be a
                single observation or a batch.

        Returns:
            Tensor of the same leading shape with last dimension ``n_actions``,
            summing to 1 along that dimension.
        """
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # Normalise over the action axis (last dim) so both batched and
        # unbatched inputs work; identical to dim=1 for 2-D input.
        return F.softmax(action_scores, dim=-1)

In [3]:
from refl.utils import PolicyNet

In [4]:
# The single policy network shared by select_action / finish_episode / main,
# optimised with Adam.
policy = Policy()
optimizer = optim.Adam(params=policy.parameters(), lr=1e-2)

# float32 machine epsilon, added to the std when normalising returns so the
# division can never be by exactly zero.
eps = float(np.finfo(np.float32).eps)


def select_action(state):
    """Sample an action from the global ``policy`` for one observation.

    The log-probability of the sampled action is appended to
    ``policy.saved_log_probs`` so the episode's loss can be built later.

    Args:
        state: a single observation as a NumPy array.

    Returns:
        int: index of the sampled action.
    """
    # Convert to float32 and add a leading batch axis of size 1.
    batch = torch.from_numpy(state).to(torch.float32)[None]
    dist = Categorical(policy(batch))
    sampled = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(sampled))
    return sampled.item()

In [5]:
def finish_episode(gamma=1.0):
    """Run one REINFORCE update from the episode stored on the global ``policy``.

    Computes the discounted return-to-go for every step, standardises the
    returns (zero mean, unit std), takes one optimizer step on
    ``-sum(log_prob * return)`` and clears the episode buffers.

    Args:
        gamma: discount factor. The default 1.0 (undiscounted) preserves the
            notebook's original behaviour.
    """
    if not policy.rewards:
        # Nothing to learn from; stacking an empty loss list would raise.
        policy.saved_log_probs.clear()
        return

    R = 0.0
    returns = deque()
    for r in reversed(policy.rewards):
        R = r + gamma * R
        returns.appendleft(R)
    returns = torch.tensor(list(returns), dtype=torch.float32)
    returns = returns - returns.mean()
    if returns.numel() > 1:
        # std() of a single element is NaN (Bessel correction), which would
        # propagate NaN into every parameter; only scale when it is defined.
        returns = returns / (returns.std() + eps)

    # stack (not cat) works whether each log_prob is 0-dim or shape (1,).
    policy_loss = torch.stack(
        [-log_prob * ret for log_prob, ret in zip(policy.saved_log_probs, returns)]
    ).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    policy.rewards.clear()
    policy.saved_log_probs.clear()

In [6]:
def main(num_episodes=200, log_interval=5, smoothing=0.05):
    """Train the global ``policy`` with REINFORCE on the global ``env``.

    Each episode is rolled out until the environment reports termination or
    truncation. Then ``finish_episode`` performs one policy-gradient update.

    Args:
        num_episodes: number of training episodes.
        log_interval: print a progress line every this many episodes
            (0 disables printing).
        smoothing: weight of the newest episode in the exponential moving
            average of the episode return.

    Returns:
        list[dict]: one record per episode with keys ``'Episode'``,
        ``'AvgReturn'`` (moving average including that episode) and
        ``'Return'`` (that episode's undiscounted return).
    """
    running_reward = 10
    exp_avgs = []
    for i_episode in range(num_episodes):
        state, _ = env.reset()
        ep_reward = 0
        done = False
        while not done:
            action = select_action(state)
            # Gym >= 0.26 returns (obs, reward, terminated, truncated, info).
            # Time limits are reported via `truncated`. Ignoring it lets an
            # episode keep stepping well past the environment's step limit.
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            policy.rewards.append(reward)
            ep_reward += reward
        running_reward = smoothing * ep_reward + (1 - smoothing) * running_reward
        # Record after the update so each row describes the episode it is
        # labelled with (and matches the printed log line).
        exp_avgs.append({'Episode': i_episode, 'AvgReturn': running_reward,
                         'Return': ep_reward})
        finish_episode()
        if log_interval and i_episode % log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                  i_episode, ep_reward, running_reward))
    return exp_avgs

In [7]:
# Run the training loop. `exp_avgs` is the list of per-episode records
# (dicts with 'Episode' and 'AvgReturn' keys) turned into a DataFrame below.
exp_avgs = main()

  if not isinstance(terminated, (bool, np.bool8)):


Episode 0	Last reward: 16.00	Average reward: 10.30
Episode 5	Last reward: 15.00	Average reward: 12.33
Episode 10	Last reward: 11.00	Average reward: 14.93
Episode 15	Last reward: 10.00	Average reward: 15.25
Episode 20	Last reward: 26.00	Average reward: 15.48
Episode 25	Last reward: 9.00	Average reward: 14.02
Episode 30	Last reward: 27.00	Average reward: 14.10
Episode 35	Last reward: 27.00	Average reward: 14.83
Episode 40	Last reward: 18.00	Average reward: 15.21
Episode 45	Last reward: 25.00	Average reward: 15.67
Episode 50	Last reward: 17.00	Average reward: 19.90
Episode 55	Last reward: 23.00	Average reward: 24.68
Episode 60	Last reward: 26.00	Average reward: 30.18
Episode 65	Last reward: 89.00	Average reward: 37.07
Episode 70	Last reward: 157.00	Average reward: 46.21
Episode 75	Last reward: 106.00	Average reward: 64.37
Episode 80	Last reward: 135.00	Average reward: 76.20
Episode 85	Last reward: 184.00	Average reward: 89.29
Episode 90	Last reward: 76.00	Average reward: 101.76
Episode 95

In [8]:
import pandas as pd
import plotly.express as px

In [9]:
# One row per training episode, built directly from main()'s list of dicts.
df = pd.DataFrame(exp_avgs)

In [10]:
# Exponential moving average of the episode return recorded by main().
fig = px.line(
    df,
    x="Episode",
    y="AvgReturn",
    title="REINFORCE on CartPole-v1: moving-average episode return",
    labels={"Episode": "Episode", "AvgReturn": "Moving-average return"},
)
fig.show()