In [None]:
%matplotlib notebook

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

In [2]:
%matplotlib qt5

In [3]:
class Actor(nn.Module):
    def __init__(self, n_state, n_action, hidden_size = 64):
        super(Actor, self).__init__()
        
        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, n_action)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        x = self.fc3(F.relu(x))
        return F.softmax(x, dim=1)

        
class Critic(nn.Module):
    def __init__(self, n_state, hidden_size=64):
        super(Critic, self).__init__()

        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, 1)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        x = self.fc3(F.relu(x))
        return x


def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

class PPO(nn.Module):

    def __init__(self, n_state, n_action, n_hidden = 64, actor_lr=1e-4, critic_lr=1e-4, lmbda=0.1, epochs=10, eps=0.01, gamma=0.99, device="cpu"):
        super(PPO, self).__init__()

        self.actor = Actor(n_state, n_action, hidden_size=n_hidden)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr = actor_lr)

        self.critic = Critic(n_state, hidden_size=n_hidden)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr = critic_lr)

        self.lmbda = lmbda
        self.gamma = gamma
        self.eps = eps
        self.epochs = epochs
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device) 

        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta)

        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)

            l1 = ratio * advantage
            l2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            l3 = - torch.min(l1, l2)

            actor_loss = torch.mean(l3)
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

            self.actor_opt.zero_grad()
            self.critic_opt.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_opt.step()
            self.critic_opt.step()
        

In [4]:
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 500
hidden = 128

gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2

device = "cpu"
env_name = "CartPole-v1"

env = gym.make(env_name)
# env.seed(0)
torch.manual_seed(0)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

agent = PPO(n_state=state_dim, n_action=action_dim, n_hidden=hidden,actor_lr=actor_lr, critic_lr=critic_lr)

In [5]:
def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state, _ = env.reset()
                done , truncated = False, False
                while not done and not truncated:
                    action = agent.take_action(state)
                    next_state, reward, done, truncated, info = env.step(action)

                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

In [6]:
train_on_policy_agent(env, agent, num_episodes)

Iteration 0:   0%|          | 0/50 [00:00<?, ?it/s]

  state = torch.tensor([state], dtype=torch.float).to(self.device)
Iteration 0: 100%|██████████| 50/50 [00:00<00:00, 90.59it/s, episode=50, return=148.800]
Iteration 1: 100%|██████████| 50/50 [00:01<00:00, 40.76it/s, episode=100, return=292.100]
Iteration 2: 100%|██████████| 50/50 [00:01<00:00, 28.16it/s, episode=150, return=341.700]
Iteration 3: 100%|██████████| 50/50 [00:02<00:00, 23.69it/s, episode=200, return=490.100]
Iteration 4: 100%|██████████| 50/50 [00:02<00:00, 23.07it/s, episode=250, return=488.100]
Iteration 5: 100%|██████████| 50/50 [00:02<00:00, 22.29it/s, episode=300, return=500.000]
Iteration 6: 100%|██████████| 50/50 [00:02<00:00, 22.55it/s, episode=350, return=500.000]
Iteration 7: 100%|██████████| 50/50 [00:02<00:00, 22.54it/s, episode=400, return=500.000]
Iteration 8: 100%|██████████| 50/50 [00:02<00:00, 22.42it/s, episode=450, return=500.000]
Iteration 9: 100%|██████████| 50/50 [00:02<00:00, 22.44it/s, episode=500, return=500.000]


[17.0,
 9.0,
 20.0,
 35.0,
 26.0,
 25.0,
 55.0,
 17.0,
 20.0,
 23.0,
 37.0,
 24.0,
 33.0,
 10.0,
 28.0,
 20.0,
 24.0,
 75.0,
 44.0,
 17.0,
 44.0,
 33.0,
 68.0,
 64.0,
 29.0,
 106.0,
 45.0,
 26.0,
 44.0,
 31.0,
 35.0,
 157.0,
 56.0,
 98.0,
 89.0,
 70.0,
 137.0,
 160.0,
 118.0,
 17.0,
 133.0,
 49.0,
 100.0,
 122.0,
 136.0,
 178.0,
 264.0,
 149.0,
 143.0,
 214.0,
 129.0,
 202.0,
 143.0,
 183.0,
 219.0,
 208.0,
 207.0,
 188.0,
 178.0,
 115.0,
 182.0,
 170.0,
 187.0,
 229.0,
 196.0,
 174.0,
 190.0,
 186.0,
 234.0,
 204.0,
 231.0,
 221.0,
 200.0,
 209.0,
 232.0,
 216.0,
 197.0,
 216.0,
 211.0,
 258.0,
 249.0,
 226.0,
 258.0,
 225.0,
 257.0,
 259.0,
 238.0,
 309.0,
 271.0,
 390.0,
 284.0,
 306.0,
 268.0,
 278.0,
 275.0,
 291.0,
 270.0,
 288.0,
 363.0,
 298.0,
 307.0,
 269.0,
 387.0,
 406.0,
 398.0,
 293.0,
 420.0,
 475.0,
 306.0,
 376.0,
 307.0,
 275.0,
 281.0,
 484.0,
 279.0,
 382.0,
 398.0,
 406.0,
 425.0,
 451.0,
 312.0,
 450.0,
 500.0,
 413.0,
 468.0,
 361.0,
 452.0,
 243.0,
 322.0,
 388.

In [7]:
def test_agent(agent):
    env = gym.make("CartPole-v1", render_mode="human")

    state, info = env.reset()

    print(f"Starting observation: {state}")

    episode_over = False
    total_reward = 0

    while not episode_over:
        action = agent.take_action(state)
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        episode_over = terminated or truncated

    print(f"Episode finished! Total reward: {total_reward}")
    env.close()

test_agent(agent)

Starting observation: [ 0.00934569  0.04162183 -0.00678425 -0.04791244]
Episode finished! Total reward: 500.0
