In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import gymnasium as gym

from typing import List, Tuple
from torch import nn
from torch.optim import Adam
from torch.distributions.categorical import Categorical

In [213]:
class PolicyAndValueNetwork(nn.Module):
    """Actor-critic MLP: a shared trunk feeding a policy head and a value head.

    Args:
        input_dim: Size of the (flat) observation vector.
        hidden_dim: Width of every hidden layer.
        policy_output_dim: Number of discrete actions (size of the logits vector).
    """

    def __init__(self, input_dim: int, hidden_dim: int, policy_output_dim: int):
        super().__init__()

        # Feature extractor shared by both heads.
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Policy head: features -> un-normalised action logits.
        self.policy_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, policy_output_dim)
        )

        # Value head: features -> one scalar per state.
        self.value_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def value(self, state: torch.Tensor) -> torch.Tensor:
        """Return the value estimate for `state`; the last dimension has size 1."""
        z = self.shared_layers(state)
        return self.value_layers(z)

    def policy(self, state: torch.Tensor) -> torch.Tensor:
        """Return action logits for `state`; the last dimension has size `policy_output_dim`.

        Note: action probabilities are `softmax(logits)`, i.e. `e^logit / sum(e^logits)`.
        """
        z = self.shared_layers(state)
        return self.policy_layers(z)

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(policy_logits, value)` with a single pass through the shared trunk."""
        z = self.shared_layers(state)
        policy_logits = self.policy_layers(z)
        value = self.value_layers(z)
        return policy_logits, value

    def get_action_and_value(self, state: np.ndarray) -> Tuple[Tuple[int, float], float]:
        """Sample an action for one un-batched state.

        Args:
            state: A single observation (array-like of numbers); a batch dimension
                is added internally.

        Returns:
            `((action, action_prob), value)` where `action` is the sampled action
            index, `action_prob` is its probability under the current policy and
            `value` is the value-head output, all as plain Python numbers.
        """
        # as_tensor accepts lists, numpy arrays (any numeric dtype) and tensors alike.
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)

        # Acting never needs gradients, so the whole sampling step runs under no_grad.
        with torch.no_grad():
            logits, value = self.forward(state_tensor)
            probs = F.softmax(logits, dim=1)
            action = Categorical(logits=logits).sample().item()

        return (action, probs[0, action].item()), value.item()


In [214]:
class Trajectory:
    """Per-step rollout buffer.

    Each attribute is a list with one entry per recorded step; `add` appends
    to all of them at once so they always have the same length.
    """

    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = [] # Extra in PPO: value estimate recorded for each step
        self.probs = []  # Probability of the action that was taken
        self.dones = []

    def add(self, state: np.ndarray, action: int, reward: float, value: float, prob: float, done: bool):
        """Append one step to the buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.probs.append(prob)
        self.dones.append(done)

    def _compute_returns(self, gamma: float = 0.99) -> List[float]:
        """Return the discounted return-to-go for every step.

        The running return is reset whenever a step is flagged `done`, so
        rewards never leak backwards across an episode boundary.
        """
        returns = []
        G = 0

        for r, d in zip(reversed(self.rewards), reversed(self.dones)):
            G = r + gamma * G * (1 - int(d))
            returns.append(G)

        return list(reversed(returns))

    def _compute_advantages(self, gamma: float = 0.99, decay: float = 0.97) -> np.ndarray:
        """Compute GAE advantages for every step.

        Args:
            gamma: Discount factor.
            decay: GAE lambda.

        Returns:
            Float array with one advantage per step (empty for an empty buffer).

        Like `_compute_returns`, both the bootstrap value and the accumulated
        advantage are cut at steps flagged `done`. The value after the final
        recorded step is taken to be 0.
        """
        n_steps = len(self.rewards)
        advantages = np.zeros(n_steps, dtype=np.float64)

        next_value = 0.0 # value following the last recorded step
        running_gae = 0.0
        for i in reversed(range(n_steps)):
            not_done = 1 - int(self.dones[i])
            delta = self.rewards[i] + gamma * next_value * not_done - self.values[i]
            running_gae = delta + decay * gamma * not_done * running_gae
            advantages[i] = running_gae
            next_value = self.values[i]

        return advantages

    def get_return(self, gamma: float) -> float:
        """Returns the mean discounted return-to-go (not reward) over the trajectory"""
        returns = self._compute_returns(gamma)
        # If you do sum, numbers would go wild based on length of the game when _compute_returns is applied
        # Mean is larger if game was larger, but the differences aren't wild like when summing
        return np.mean(returns)

    def to_tensor(self, gamma: float = 0.99):
        """Returns `(states, actions, probs, returns)` as tensors.

        `states`, `probs` and `returns` are float32, `actions` is int64;
        `returns` holds the discounted returns-to-go computed with `gamma`.
        """
        states = torch.as_tensor(np.array(self.states), dtype=torch.float32)
        actions = torch.as_tensor(np.array(self.actions), dtype=torch.long)
        # One conversion for the whole list instead of one tensor per element.
        probs = torch.tensor(self.probs, dtype=torch.float32)
        returns = torch.tensor(self._compute_returns(gamma), dtype=torch.float32) # instead of reward

        return states, actions, probs, returns


In [215]:
# Sanity check on a hand-built five-step episode: constant reward/value/prob,
# alternating actions, and only the final step flagged as done.
t = Trajectory()
for _action, _done in [(0, False), (1, False), (0, False), (0, False), (1, True)]:
    t.add(state=np.array([0, 1]), action=_action, reward=1, value=1, prob=0.2, done=_done)

gamma = 0.99
print(t._compute_returns(gamma), t._compute_advantages(), t.get_return(gamma))

states, actions, probs, returns = t.to_tensor()
print(probs.shape)

[4.90099501, 3.9403989999999998, 2.9701, 1.99, 1.0] [3.73036137 2.85365133 1.940697   0.99       0.        ] 2.960298802
torch.Size([5])


In [None]:
class PPO:
    """Minimal PPO (clipped surrogate objective) for discrete-action Gymnasium environments.

    Args:
        env: Environment used to collect training episodes. Its observation
            space must expose `shape[0]` and its action space must expose `n`.
        hidden_dim: Hidden layer width of the policy/value network.
        gamma: Discount factor used for both returns and advantages.
        clip_param: Clipping range epsilon for the probability ratio.
        target_kl_div: Policy updates stop early once the approximate KL
            divergence reaches this value.
        max_policy_train_iters: Maximum gradient steps per policy update.
        max_value_train_iters: Gradient steps per value update.
        policy_lr: Learning rate of the policy optimizer.
        value_lr: Learning rate of the value optimizer.
        gae_lambda: GAE decay (lambda) used when computing advantages.
    """

    def __init__(self, env: gym.Env,
                 hidden_dim: int = 64, gamma: float = 0.99,
                 clip_param: float = 0.2, target_kl_div: float = 0.01,
                 max_policy_train_iters: int = 80, max_value_train_iters: int = 80,
                 policy_lr: float = 3e-4, value_lr: float = 1e-2,
                 gae_lambda: float = 0.97):
        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_param = clip_param
        self.target_kl_div = target_kl_div

        self.max_policy_train_iters = max_policy_train_iters
        self.max_value_train_iters = max_value_train_iters

        self.model = PolicyAndValueNetwork(input_dim=self.state_dim, hidden_dim=hidden_dim, policy_output_dim=self.action_dim)

        # Both optimizers include the shared trunk, so policy and value updates each move it.
        policy_params = list(self.model.shared_layers.parameters()) + list(self.model.policy_layers.parameters())
        self.policy_optim = Adam(policy_params, lr=policy_lr)

        value_params = list(self.model.shared_layers.parameters()) + list(self.model.value_layers.parameters())
        self.value_optim = Adam(value_params, lr=value_lr)

    def collect_trajectories(self, n_trajectories: int) -> List[Trajectory]:
        """Run `n_trajectories` full episodes in `self.env` with the current policy.

        Returns:
            One `Trajectory` per episode, each ending on the step where the
            environment reported `terminated` or `truncated`.
        """
        trajectories = []

        for _ in range(n_trajectories):
            traj = Trajectory()
            state, _ = self.env.reset()
            done = False

            while not done:
                (action, action_prob), value = self.model.get_action_and_value(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                traj.add(state=state, action=action, reward=reward, value=value, prob=action_prob, done=done)
                state = next_state

            trajectories.append(traj)

        return trajectories

    def update_policy(self, states: torch.Tensor, actions: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor):
        """Optimise the clipped PPO surrogate on one batch of steps.

        Args:
            states: Batch of observations.
            actions: Action taken at each step.
            old_log_probs: Log-probability of each taken action under the
                policy that collected the data.
            advantages: Advantage estimate for each step.

        Runs at most `max_policy_train_iters` gradient steps and stops early
        when the approximate KL divergence reaches `target_kl_div`.
        """
        for _ in range(self.max_policy_train_iters):
            self.policy_optim.zero_grad()

            new_logits = self.model.policy(states)
            new_dist = Categorical(logits=new_logits)
            new_log_probs = new_dist.log_prob(actions) # log probs for the same action

            # Derivation: e^(new_log_prob-old_log_prob) (efficient to calculate) => e^new_log_prob / e^old_log_prob => new_prob / old_prob (original formula)
            policy_ratio = torch.exp(new_log_probs - old_log_probs)
            clipped_ratio = policy_ratio.clamp(1 - self.clip_param, 1 + self.clip_param)

            full_loss = policy_ratio * advantages # without clipping
            clipped_loss = clipped_ratio * advantages # with clipping

            # mean reduces over all steps and - sign is required because pytorch is supposed to minimize, not maximize
            policy_loss = -torch.min(full_loss, clipped_loss).mean()

            policy_loss.backward()
            self.policy_optim.step()

            # Sample-based KL estimate; abs() keeps the early-stop check sign-independent.
            # new_log_probs was evaluated before optimizer.step(), so the check
            # reflects the policy as it was at the start of this iteration.
            # This is a pure diagnostic, hence no autograd tracking.
            with torch.no_grad():
                kl_div = (old_log_probs - new_log_probs).mean().abs()
            if kl_div >= self.target_kl_div:
                # Heavy deviations from the data-collecting policy. Early stop here
                break

    def update_value(self, states: torch.Tensor, returns: torch.Tensor):
        """Regress the value head onto `returns` for `max_value_train_iters` gradient steps."""
        # TODO: Allow training with multiple trajectories at once?
        for _ in range(self.max_value_train_iters):
            self.value_optim.zero_grad()

            values = self.model.value(states).reshape(-1, )
            value_loss = ((returns - values) ** 2).mean() # Simple L2 loss

            value_loss.backward()
            self.value_optim.step()

    def train(self, n_episodes: int):
        """Train for `n_episodes` episodes, updating after every episode.

        Each iteration collects one episode, then runs one policy update
        followed by one value update on it. Progress is printed every 10 episodes.

        Returns:
            List with the undiscounted reward sum of every training episode.
        """
        rewards_history = []

        for episode in range(n_episodes):
            traj = self.collect_trajectories(n_trajectories=1)[0]

            # Just for logging:
            episode_reward = sum(traj.rewards)
            rewards_history.append(episode_reward)

            # TODO: Allow training with multiple trajectories at once?
            # Use the configured discount everywhere instead of the helpers' defaults.
            states, actions, probs, returns = traj.to_tensor(gamma=self.gamma)

            advantages = torch.as_tensor(
                traj._compute_advantages(gamma=self.gamma, decay=self.gae_lambda),
                dtype=torch.float32,
            )
            old_log_probs = torch.log(probs)

            self.update_policy(states=states, actions=actions, old_log_probs=old_log_probs, advantages=advantages)
            self.update_value(states=states, returns=returns)

            if (episode + 1) % 10 == 0:
                print("Episode {}, Episode reward: {:.2f}".format(episode+1, episode_reward))

        return rewards_history

    def evaluate(self, env: gym.Env, n_episodes: int = 10, render: bool = False):
        """Run `n_episodes` episodes in `env` with the current policy and report the mean reward.

        Actions are sampled from the policy (same as during training).

        Args:
            env: Environment to evaluate in (may differ from the training env).
            n_episodes: Number of episodes to average over.
            render: If True, call `env.render()` before every step.

        Returns:
            Mean undiscounted episode reward; also printed.
        """
        rewards = []

        for _ in range(n_episodes):
            state, _ = env.reset()
            done = False
            total_reward = 0

            while not done:
                if render:
                    env.render()

                (action, _), _ = self.model.get_action_and_value(state) # TODO: Add .get_action for perf
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward

            rewards.append(total_reward)

        avg_reward = np.mean(rewards)
        print("Evaluation: Average reward over {} episodes: {:.2f}".format(n_episodes, avg_reward))

        return avg_reward


In [217]:
# Create the training environment and the PPO agent used by the cells below

ENV_NAME = 'CartPole-v1' # Possible values: CartPole-v1, Acrobot-v1

env = gym.make(ENV_NAME, max_episode_steps=500)
ppo = PPO(
    env=env,
    clip_param=0.2,
    target_kl_div=0.02,
    max_policy_train_iters=40,
    max_value_train_iters=40,
    policy_lr=3e-4,
    value_lr=1e-3,
)

In [218]:
# Collect one episode with the untrained policy and inspect what the buffer holds
t = ppo.collect_trajectories(1)[0]

states, actions, probs, returns = t.to_tensor()

# Shapes first, then the raw contents (one item per line)
print(*(tensor.shape for tensor in (states, actions, probs, returns)))

print(probs, returns, t.values, sep="\n")

torch.Size([37, 4]) torch.Size([37]) torch.Size([37]) torch.Size([37])
tensor([0.5348, 0.4688, 0.5347, 0.5311, 0.4718, 0.4691, 0.4659, 0.5347, 0.4658,
        0.4652, 0.5313, 0.5353, 0.4653, 0.5354, 0.5349, 0.4685, 0.5348, 0.5315,
        0.4716, 0.5312, 0.4718, 0.5308, 0.4721, 0.5303, 0.4723, 0.5298, 0.4726,
        0.5292, 0.5272, 0.4713, 0.5270, 0.4713, 0.5270, 0.4710, 0.5275, 0.5295,
        0.4677])
tensor([31.0551, 30.3587, 29.6552, 28.9447, 28.2269, 27.5020, 26.7697, 26.0300,
        25.2828, 24.5281, 23.7657, 22.9957, 22.2179, 21.4322, 20.6386, 19.8369,
        19.0272, 18.2093, 17.3831, 16.5486, 15.7057, 14.8542, 13.9942, 13.1254,
        12.2479, 11.3615, 10.4662,  9.5618,  8.6483,  7.7255,  6.7935,  5.8520,
         4.9010,  3.9404,  2.9701,  1.9900,  1.0000])
[-0.12627637386322021, -0.13448408246040344, -0.12569132447242737, -0.1344195306301117, -0.14513495564460754, -0.13459813594818115, -0.12454517185688019, -0.12022320926189423, -0.12495197355747223, -0.12017303705215454

In [219]:
# Train on 300 episodes, then score the resulting policy over 10 evaluation episodes
rewards = ppo.train(n_episodes=300)
avg_reward = ppo.evaluate(env=env, n_episodes=10)

Episode 10, Avg reward: 87.00
Episode 20, Avg reward: 18.00
Episode 30, Avg reward: 132.00
Episode 40, Avg reward: 144.00
Episode 50, Avg reward: 258.00
Episode 60, Avg reward: 292.00
Episode 70, Avg reward: 173.00
Episode 80, Avg reward: 253.00
Episode 90, Avg reward: 315.00
Episode 100, Avg reward: 293.00
Episode 110, Avg reward: 500.00
Episode 120, Avg reward: 229.00
Episode 130, Avg reward: 378.00
Episode 140, Avg reward: 500.00
Episode 150, Avg reward: 500.00
Episode 160, Avg reward: 304.00
Episode 170, Avg reward: 500.00
Episode 180, Avg reward: 335.00
Episode 190, Avg reward: 500.00
Episode 200, Avg reward: 500.00
Episode 210, Avg reward: 500.00
Episode 220, Avg reward: 374.00
Episode 230, Avg reward: 231.00
Episode 240, Avg reward: 500.00
Episode 250, Avg reward: 332.00
Episode 260, Avg reward: 388.00
Episode 270, Avg reward: 275.00
Episode 280, Avg reward: 267.00
Episode 290, Avg reward: 340.00
Episode 300, Avg reward: 486.00
Evaluation: Average reward over 10 episodes: 488.80

In [220]:
## Evaluate: watch one rendered episode
env = gym.make(ENV_NAME, max_episode_steps=500, render_mode='human')
try:
    _ = ppo.evaluate(env, n_episodes=1, render=True)
finally:
    # Always release the render window, even if the episode is interrupted or fails.
    env.close()

Evaluation: Average reward over 1 episodes: 380.00
