In [1]:
import gymnasium as gym
import numpy as np
from itertools import count
from collections import deque
import torch
from torch.distributions import Categorical

In [2]:
# Build the CartPole-v1 environment; every later cell reads spaces from this `env`.
env = gym.make('CartPole-v1')

In [6]:
# Start a fresh episode and unpack the two values reset() returns: the initial
# observation and the auxiliary info object.
# NOTE(review): no seed is passed, so the starting state differs between runs.
state, info = env.reset()

In [16]:
# Inspect the observation layout: the output below is a 4-element float32 array.
# The sampled magnitudes reach ~1e37, so use this for shape/dtype only, not as
# an indication of what real environment states look like.
env.observation_space.sample()

array([ 1.4768261e+00, -1.5968510e+37,  8.8566922e-02,  7.2072305e+37],
      dtype=float32)

In [14]:
# Number of discrete actions (2 here); this is the value passed as the policy's
# output width when the network is constructed.
env.action_space.n

2

In [21]:
class Policy(torch.nn.Module):
    """Single-hidden-layer softmax policy that also buffers one episode of data.

    The network is ``Linear -> Dropout -> ReLU -> Linear -> Softmax`` and turns a
    batch of observations into one row of action probabilities per observation.

    Args:
        in_dim: Number of features in each observation.
        out_dim: Number of discrete actions (width of the probability row).
        hidden_dim: Width of the hidden layer. Defaults to 128.
        dropout_p: Dropout probability applied after the first linear layer.
            Defaults to 0.2.

    Attributes:
        net: The ``torch.nn.Sequential`` described above.
        saved_log_probs: List of log-probability tensors for the actions taken
            in the current episode (``select_action`` appends to it).
        rewards: List of per-step rewards for the current episode
            (``finish_episode`` reads and then clears it).
    """

    def __init__(self, in_dim, out_dim, hidden_dim=128, dropout_p=0.2):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            # Dropout is only active in training mode; call .eval() for
            # deterministic action probabilities.
            torch.nn.Dropout(p=dropout_p),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, out_dim),
            # Normalises over dim 1, so inputs must be 2-D: (batch, in_dim).
            torch.nn.Softmax(dim=1),
        )

        # Per-episode buffers; they live on the module so the acting and the
        # update code can share them without passing extra state around.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return action probabilities for a 2-D batch of observations.

        Args:
            x: Float tensor of shape ``(batch, in_dim)``.

        Returns:
            Tensor of shape ``(batch, out_dim)`` whose rows each sum to 1.
        """
        return self.net(x)

In [22]:
# Size both ends of the network from the environment rather than hard-coding
# the observation width (4 for CartPole-v1) next to a derived action count.
policy = Policy(env.observation_space.shape[0], env.action_space.n)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)
# float32 machine epsilon as a Python float; added to the std of the returns in
# finish_episode so the normalisation never divides by exactly zero.
eps = np.finfo(np.float32).eps.item()

In [25]:
def select_action(state, policy):
    """Sample an action for ``state`` and record its log-probability.

    Args:
        state: NumPy array holding a single observation.
        policy: Callable returning a row of action probabilities for a batched
            float tensor; must expose a ``saved_log_probs`` list.

    Returns:
        The sampled action as a plain Python int.

    Side effects:
        Appends the log-probability tensor of the sampled action to
        ``policy.saved_log_probs``.
    """
    # The policy works on batches, so wrap the observation in a batch of one.
    observation = torch.from_numpy(state).float()
    batch = torch.unsqueeze(observation, 0)

    distribution = Categorical(policy(batch))
    chosen = distribution.sample()

    # Kept as a tensor (not .item()) so gradients can flow back through it later.
    policy.saved_log_probs.append(distribution.log_prob(chosen))
    return chosen.item()

In [141]:
def finish_episode():
    """Apply one REINFORCE update from the episode buffered on ``policy``.

    Reads the module-level names ``policy``, ``optimizer``, ``gamma`` and
    ``eps``; all four must be defined before this is called (``gamma`` is not
    defined alongside the others and has to come from another cell).

    Steps:
        1. Turn ``policy.rewards`` into discounted returns-to-go.
        2. Standardise the returns (zero mean, unit std) when there are at
           least two of them.
        3. Minimise ``sum(-log_prob * return)`` with a single optimizer step.
        4. Clear ``policy.rewards`` and ``policy.saved_log_probs``.

    An episode with no recorded rewards or log-probabilities is skipped (the
    buffers are still cleared) instead of raising inside ``torch.cat``.
    """
    if not policy.rewards or not policy.saved_log_probs:
        # Nothing to learn from; still drop any half-filled buffer so it cannot
        # leak into the next episode.
        del policy.rewards[:]
        del policy.saved_log_probs[:]
        return

    # Walk the rewards backwards so each step's return is
    # reward + gamma * (return of the following step).
    running_return = 0
    returns = []
    for reward in reversed(policy.rewards):
        running_return = reward + gamma * running_return
        returns.append(running_return)
    returns.reverse()

    returns = torch.tensor(returns)
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)
    # else: std() of a single value is NaN (sample std divides by n - 1), which
    # would make the loss NaN and corrupt every weight on optimizer.step(), so
    # a one-step episode uses its raw return.

    # zip stops at the shorter of the two buffers.
    policy_loss = [
        -log_prob * step_return
        for log_prob, step_return in zip(policy.saved_log_probs, returns)
    ]

    optimizer.zero_grad()
    # torch.cat needs 1-D terms; select_action stores log-probs of shape (1,).
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()

    del policy.rewards[:]
    del policy.saved_log_probs[:]
        