In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

In [2]:
# HyperParameters

LEARNING_RATE = 0.0003
GAMMA = 0.99
EPSILON_CLIP = 0.2
ENTROPY_COEFF = 0.01
EPOCHS = 10
BATCH_SIZE = 64
TIMESTEPS = 2048

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Action Critic Network

class PPOActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPOActorCritic, self).__init__()
    
        # actor
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim = -1)
        )

        # critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )
    
    def forward(self):
        raise NotImplementedError

    def get_action_and_value(self, state):

        action_probs = self.actor(state)  # pi(a|s) => left|right => ex : [0.7, 0.3]
        state_values = self.critic(state) # v(s), R + rV(s') - "V(s)"

        dist = Categorical(action_probs)
        action = dist.sample() # left|right = [0.75, 0.25] => 0번 선택
        action_logprobs = dist.log_prob(action) # log(0.75)
        entropy = dist.entropy() # entropy([0.75, 0.25])

        return action, action_logprobs, state_values, entropy

In [4]:
class RolloutBuffer():
    def __init__(self):
        self.actions = []
        self.states = []
        self.log_probs = []
        self.rewards = []
        self.state_values = []
        self.dones = []

    def clear(self):
        self.actions = []
        self.states = []
        self.log_probs = []
        self.rewards = []
        self.state_values = []
        self.dones = []

In [5]:
def train_ppo(buffer, old_model, new_model, optimizer):
    # calculate last state_value for : r + gamma * V(t)
    state = buffer.states[-1]
    done = buffer.dones[-1]
    with torch.no_grad():
        discounted_rewards = 0 if done else old_model.get_action_and_value(torch.FloatTensor(state))

    returns = []
    for reward in reversed(buffer.rewards):
        discounted_rewards = reward + GAMMA * discounted_rewards
        returns.insert(0, discounted_rewards)

    advantages = torch.FloatTensor(returns) - torch.FloatTensor(buffer.state_values)

    for _ in range(EPOCHS):
        for idx in range(0, len(buffer.states), BATCH_SIZE):
            batch_states = torch.FloatTensor(buffer.states[idx : idx + BATCH_SIZE])
            batch_actions = torch.LongTensor(buffer.actions[idx : idx + BATCH_SIZE])

            batch_returns = torch.FloatTensor(returns[idx : idx + BATCH_SIZE])
            batch_advantages = torch.FloatTensor(advantages[idx : idx + BATCH_SIZE])

            # new_model에서 새로운 정책 계산
            new_policy_logits = new_model.actor(batch_states)
            values = new_model.critic(batch_states)
            new_policy_dist = Categorical(logits = new_policy_logits)
            new_log_probs = new_policy_dist.log_prob(batch_actions)
            entropy = new_policy_dist.entropy()

            # old_model에서 이전 정책 계산
            with torch.no_grad():
                old_policy_logits = old_model.actor(batch_states)
                old_policy_dist = Categorical(logits = old_policy_logits)
                old_log_probs = old_policy_dist.log_prob(batch_actions)

            # Compute ratio
            ratios = torch.exp(new_log_probs - old_log_probs)

            # PPO Loss
            surrogate1 = ratios * batch_advantages
            surrogate2 = torch.clamp(ratios, 1 - EPSILON_CLIP, 1 + EPSILON_CLIP) * batch_advantages
            policy_loss = -torch.min(surrogate1, surrogate2).mean()

            value_loss = nn.MSELoss()(values.squeeze(), batch_returns)

            entropy_loss = - ENTROPY_COEFF * entropy.mean()

            loss = policy_loss + value_loss + entropy_loss

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [8]:
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

old_model = PPOActorCritic(state_dim, action_dim)
new_model = PPOActorCritic(state_dim, action_dim)

new_model.load_state_dict(old_model.state_dict()) # Synchronize models initially
optimizer = optim.Adam([
    {'params' : new_model.actor.parameters(), 'lr' : LEARNING_RATE},
    {'params' : new_model.critic.parameters(), 'lr' : LEARNING_RATE}
])

buffer = RolloutBuffer()
buffer.clear()

for episode in range(1000):
    state, _ = env.reset()
    state = torch.FloatTensor(state)
    episode_reward = 0

    buffer.clear()

    for t in range(TIMESTEPS):
        with torch.no_grad():
            action, log_prob, value, _ = old_model.get_action_and_value(state)
            next_state, reward, done, _, _ = env.step(action.item())

        # Store data
        buffer.states.append(state.numpy())
        buffer.actions.append(action.item())
        buffer.log_probs.append(log_prob.item())
        buffer.rewards.append(reward)
        buffer.state_values.append(value.item())
        buffer.dones.append(done)

        state = torch.FloatTensor(next_state)
        episode_reward += reward

        if done:
            state, _ = env.reset()
            state = torch.FloatTensor(state)
            break

    # Train PPO
    train_ppo(buffer, old_model, new_model, optimizer)
    old_model.load_state_dict(new_model.state_dict()) # Update old_model to match new_model

    if (episode % 100 == 0) and (episode != 0):
        print(f"Episode {episode} - Reward: {episode_reward}")

    # early stop
    if (episode_reward > 1000):
        break


  return F.mse_loss(input, target, reduction=self.reduction)


Episode 100 - Reward: 56.0
Episode 200 - Reward: 85.0
Episode 300 - Reward: 216.0


In [9]:
import time
max_ep_len = 300

total_test_episodes = 5
test_running_reward = 0

env = gym.make("CartPole-v1", render_mode = "human")

for episode in range(1, total_test_episodes+1):
    state, _ = env.reset()
    ep_reward = 0

    for t in range(max_ep_len):
        action_probs = new_model.actor(torch.FloatTensor(state))
        dist = Categorical(action_probs)
        action = dist.sample()

        state, reward, done, truncated, _ = env.step(action.item())
        ep_reward += reward

        env.render()
        time.sleep(0.01)

        if done:
            state, info = env.reset()
    
    test_running_reward += ep_reward
    print(f"Episode : {episode} \t\t Reward : {round(ep_reward, 2)}")
    ep_reward = 0

env.close()

Episode : 1 		 Reward : 300.0
Episode : 2 		 Reward : 300.0
Episode : 3 		 Reward : 300.0
Episode : 4 		 Reward : 300.0
Episode : 5 		 Reward : 300.0
