<a href="https://colab.research.google.com/github/Calcifer777/learn-rl/blob/main/actor_critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install gymnasium

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gymnasium
  Downloading gymnasium-0.28.1-py3-none-any.whl (925 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m925.5/925.5 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
Collecting jax-jumpy>=1.0.0 (from gymnasium)
  Downloading jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, jax-jumpy, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.28.1 jax-jumpy-1.0.0


In [23]:
from collections import namedtuple
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym

In [58]:
DISCOUNT = 0.99

In [28]:
SEED = 42
torch.random.manual_seed(SEED)

<torch._C.Generator at 0x7fae2417a570>

In [24]:
SavedAction = namedtuple("SavedAction", ["log_prob", "value"])

In [25]:
class Policy(nn.Module):
  def __init__(self, inputs_dim, hidden_dim, outputs_dim):
    super(Policy, self).__init__()
    self.l1 = nn.Linear(inputs_dim, out_features=hidden_dim)
    self.l2 = nn.Linear(hidden_dim, out_features=hidden_dim)
    
    self.action_head = nn.Linear(hidden_dim, outputs_dim)
    self.value_head = nn.Linear(hidden_dim, 1)
    
    self.saved_actions: List[SavedAction] = []
    self.rewards = []

  def forward(self, x):
    h = self.l1(x)
    h = F.relu(h)
    h = self.l2(h)
    action_probs = F.softmax(self.action_head(h), dim=-1)
    values = self.value_head(h)
    return action_probs, values

In [18]:
policy = Policy(inputs_dim=4, hidden_dim=64, outputs_dim=2)

In [19]:
sample = torch.rand((1, 4))

In [20]:
policy(sample)

(tensor([[0.5603, 0.4397]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.0392]], grad_fn=<AddmmBackward0>))

In [26]:
def select_action(policy, state):
  state = torch.from_numpy(state).float()
  probs, value = policy(state)

  action_distr = Categorical(probs)
  action = action_distr.sample()

  policy.saved_actions.append(
      SavedAction(action_distr.log_prob(action), value)
  )

  return action.item()

In [29]:
env = gym.make("CartPole-v1")
env.reset(seed=SEED)

(array([ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ], dtype=float32),
 {})

In [32]:
state, _ = env.reset()

In [45]:
action = select_action(policy, state)

In [44]:
NUM_EPISODES = 1000
T = 475

In [46]:
env.step(action)

(array([-0.03963102,  0.24230015,  0.0266861 , -0.25572422], dtype=float32),
 1.0,
 False,
 False,
 {})

In [47]:
def run_episode(env, policy: Policy, max_time_steps: int=T):
  state, _ = env.reset()
  ep_reward = 0

  for t in range(max_time_steps):
    action = select_action(policy, state)
    state, reward, done, _, _ = env.step(action)
    policy.rewards.append(reward)

    if done:
      break

In [55]:
l = [4,2,3,1]
list(reversed(l))

[1, 3, 2, 4]

In [None]:
def backprop(policy: Policy):
  actions = policy.saved_actions
  rewards = policy.rewards

  losses_policy, losses_value = [], []

  # For each step, compute its value
  values = []
  accumulated_return = 0
  for reward in reversed(policy.rewards):
    accumulated_return = reward + DISCOUNT*accumulated_return
    values.append(accumulated_return)
  values = reversed(values)
  values = torch.tensor(values)
  values = (values - values.mean()) / values.std()

  # For each step, compute the loss
  for (logit, value_hat), value in zip(actions, values):
    advantage = value - value_hat.item()
    losses_policy.append(-logit*advantage)
    losses_value.append(F.smooth_l1_loss(value_hat, torch.tensor([value])))
    


In [50]:
for ep_idx in range(2):

  run_episode(env, policy)