In [1]:
import gymnasium as gym
import torch
from torch import nn
from torch import optim
import numpy as np
from torch.distributions import Normal
import matplotlib.pyplot as plt

In [2]:
class ActorCritic(nn.Module):
    """Actor-critic network with a Gaussian policy and a separate value head.

    The actor MLP outputs the mean of a Normal distribution over actions, the
    standard deviation comes from a single learnable ``log_std`` parameter that
    does not depend on the state, and the critic MLP outputs a scalar value.

    Args:
        obs_dim: Size of the observation vector fed to both networks.
        action_dim: Number of action dimensions produced by the actor.
        hidden_dim: Width of the two hidden layers in each network.

    The defaults (2, 1, 64) reproduce the original fixed architecture.
    """

    def __init__(self, obs_dim=2, action_dim=1, hidden_dim=64):
        super().__init__()

        # Policy network: observation -> mean of the action distribution.
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Learnable log standard deviation, initialised to 0.5 per action dim.
        self.log_std = nn.Parameter(torch.ones(action_dim) * 0.5)

        # Value network: observation -> scalar state value.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def get_action_and_value(self, state):
        """Sample an action and evaluate it under the current policy.

        Args:
            state: Observation tensor whose last dimension is ``obs_dim``.

        Returns:
            Tuple ``(action, log_prob, value)`` where ``action`` is sampled
            from ``Normal(actor(state), exp(log_std))``, ``log_prob`` is the
            log-density of that action summed over the last dimension, and
            ``value`` is the critic output with its trailing dimension removed.
        """
        mean = self.actor(state)
        std = torch.exp(self.log_std)

        dist = Normal(mean, std)

        action = dist.sample()
        # Sum per-dimension log-densities to get the joint log-probability.
        log_prob = dist.log_prob(action).sum(-1)
        value = self.critic(state).squeeze(-1)
        return action, log_prob, value

In [3]:
class RolloutBuffer:
    """Append-only store of transitions collected during one rollout.

    Each transition is the tuple
    ``(state, action, log_prob, reward, value, done)`` in the order pushed.
    """

    def __init__(self):
        self.buffer = []

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def clear(self):
        """Drop all stored transitions."""
        self.buffer = []

    def push(self, state, action, log_prob, reward, value, done):
        """Append one transition to the buffer."""
        self.buffer.append((state, action, log_prob, reward, value, done))

    def get(self):
        """Return the stored transitions as six batched tensors.

        Returns:
            Tuple ``(states, actions, log_probs, rewards, values, dones)``.
            States, actions, log-probs and values are combined with
            ``torch.stack``; rewards and dones are converted to float32
            tensors.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        # Without this guard, unpacking zip() of an empty list fails with an
        # unhelpful "not enough values to unpack" error.
        if not self.buffer:
            raise ValueError(
                "RolloutBuffer is empty; push at least one transition before calling get()"
            )

        states, actions, log_probs, rewards, values, dones = zip(*self.buffer)

        return (
            torch.stack(states),
            torch.stack(actions),
            torch.stack(log_probs),
            torch.tensor(rewards, dtype=torch.float32),
            torch.stack(values),
            torch.tensor(dones, dtype=torch.float32),
        )

In [4]:
def compute_returns(rewards, dones, gamma=0.99, last_value=0.0):
    """Compute discounted returns, restarting the sum at episode boundaries.

    Walks the rollout backwards and accumulates ``G = reward + gamma * G``,
    resetting ``G`` to zero at every step flagged as done so rewards from a
    later episode do not flow into an earlier one.

    Args:
        rewards: 1-D tensor or sequence of per-step rewards.
        dones: 1-D tensor or sequence of per-step done flags (truthy = the
            episode ended on that step).
        gamma: Discount factor.
        last_value: Value used to seed the running return for the step after
            the final one (e.g. a critic estimate when the rollout stops
            mid-episode). The default of 0.0 matches the previous behaviour.

    Returns:
        float32 tensor of returns with one entry per step.

    Raises:
        ValueError: If ``rewards`` and ``dones`` have different lengths.
    """
    # Work on plain Python numbers: much faster than per-element tensor math.
    reward_list = rewards.tolist() if isinstance(rewards, torch.Tensor) else list(rewards)
    done_list = dones.tolist() if isinstance(dones, torch.Tensor) else list(dones)

    if len(reward_list) != len(done_list):
        raise ValueError(
            f"rewards and dones must have the same length, got {len(reward_list)} and {len(done_list)}"
        )

    # Fill a preallocated list from the back instead of insert(0, ...),
    # which made the original loop quadratic in the rollout length.
    returns = [0.0] * len(reward_list)
    G = float(last_value)
    for t in range(len(reward_list) - 1, -1, -1):
        if done_list[t]:
            G = 0.0
        G = float(reward_list[t]) + gamma * G
        returns[t] = G

    return torch.tensor(returns, dtype=torch.float32)

In [5]:
def compute_advantages(returns, values):
    """Compute normalised advantages as ``returns - values``.

    Args:
        returns: Tensor of returns.
        values: Tensor of value estimates, same shape as ``returns``. It is
            detached, so no gradient flows back through the advantages.

    Returns:
        Advantages shifted to zero mean and divided by their standard
        deviation (plus 1e-8). With fewer than two elements only the mean is
        subtracted, because the sample standard deviation is undefined there.
    """
    advantages = returns - values.detach()
    centered = advantages - advantages.mean()
    # std() of a single element is NaN, which would poison the whole update.
    if advantages.numel() < 2:
        return centered
    return centered / (advantages.std() + 1e-8)

In [6]:
# --- PPO hyperparameters ---
num_updates = 500        # collect-then-optimise iterations
rollout_steps = 2048     # environment steps gathered per iteration
mini_batch_size = 128    # samples per gradient step
ppo_epochs = 10          # passes over each rollout
clip_eps = 0.2           # clipping range for the probability ratio
value_coef = 0.5         # weight of the critic loss in the total loss
entropy_coef = 0.01      # weight of the entropy bonus in the total loss

# --- Episode bookkeeping ---
episode_no = 1
reward_tracker = []      # total reward of every completed episode
episode_reward = 0

buffer = RolloutBuffer()
env = gym.make("MountainCarContinuous-v0")
model = ActorCritic()
model.train()
# The actor (plus log_std) and the critic are optimised with separate
# optimizers and learning rates.
actor_optimizer = optim.Adam(list(model.actor.parameters())+[model.log_std], lr = 3e-4)
critic_optimizer = optim.Adam(model.critic.parameters(), lr=1e-3)

for iteration in range(num_updates):

    # ---- Rollout collection ----
    buffer.clear()
    state, _ = env.reset()
    # The reset above abandons any episode still running when the previous
    # rollout ended, so its partial reward must not leak into the next one.
    episode_reward = 0

    for rollout in range(rollout_steps):

        state = torch.as_tensor(state, dtype=torch.float32)
        with torch.no_grad():
            action, log_prob, value = model.get_action_and_value(state)
        # The environment receives the clipped action; the buffer keeps the
        # raw sample so stored log-probs match the stored actions.
        clipped_action = torch.clip(action, -1, 1).numpy()

        next_state, reward, terminated, truncated, info = env.step(clipped_action)
        done = terminated or truncated
        episode_reward += reward

        buffer.push(state, action, log_prob, reward, value, done)

        state = next_state
        if done:
            print(f"Episode: {episode_no} | Reward: {episode_reward:.5f}")
            # Record the episode total before zeroing it (the original stored
            # only the final step's reward, and only after the reset).
            reward_tracker.append(episode_reward)
            episode_no += 1
            episode_reward = 0
            state, _ = env.reset()

    # NOTE(review): time-limit truncation is treated like termination and the
    # last, unfinished episode of the rollout is not bootstrapped with a value
    # estimate, so returns near those cut-offs are biased. Confirm intended.
    states, actions, old_log_probs, rewards, values, dones = buffer.get()
    returns = compute_returns(rewards, dones)
    advantages = compute_advantages(returns, values)

    # ---- PPO update ----
    for epoch in range(ppo_epochs):

        # Fresh shuffle each epoch so mini-batches differ between passes.
        indices = torch.randperm(rollout_steps)

        for start in range(0, rollout_steps, mini_batch_size):
            end = start + mini_batch_size
            mb_idx = indices[start:end]

            mb_states = states[mb_idx]
            mb_actions = actions[mb_idx]
            mb_old_log_probs = old_log_probs[mb_idx]
            mb_returns = returns[mb_idx]
            mb_advantages = advantages[mb_idx]

            # Re-evaluate the stored actions under the current policy.
            mean = model.actor(mb_states)
            std = torch.exp(model.log_std)
            dist = Normal(mean, std)

            mb_new_log_probs = dist.log_prob(mb_actions).sum(-1)
            mb_values = model.critic(mb_states).squeeze(-1)
            mb_entropy = dist.entropy().sum(-1)

            # Clipped surrogate objective: take the more pessimistic of the
            # clipped and unclipped terms.
            ratio = torch.exp(mb_new_log_probs - mb_old_log_probs)
            unclipped = ratio * mb_advantages
            clipped = torch.clip(ratio, 1-clip_eps, 1+clip_eps) * mb_advantages

            actor_loss = -torch.min(unclipped, clipped).mean()
            entropy_loss = -mb_entropy.mean()
            critic_loss = (mb_values - mb_returns).pow(2).mean()

            loss = actor_loss + value_coef * critic_loss + entropy_coef * entropy_loss

            actor_optimizer.zero_grad()
            critic_optimizer.zero_grad()
            loss.backward()
            # Global gradient-norm clipping across actor, critic and log_std.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            actor_optimizer.step()
            critic_optimizer.step()


Episode: 1 | Reward: -69.20537
Episode: 2 | Reward: 53.36619
Episode: 3 | Reward: -92.63981
Episode: 4 | Reward: -69.38791
Episode: 5 | Reward: 34.12478
Episode: 6 | Reward: -68.78464
Episode: 7 | Reward: -82.15235
Episode: 8 | Reward: -68.77759
Episode: 9 | Reward: 65.59584
Episode: 10 | Reward: -70.62952
Episode: 11 | Reward: 26.27275
Episode: 12 | Reward: -67.98559
Episode: 13 | Reward: -8.70660
Episode: 14 | Reward: -67.53296
Episode: 15 | Reward: -76.00800
Episode: 16 | Reward: -69.58783
Episode: 17 | Reward: -71.90506
Episode: 18 | Reward: -69.88527
Episode: 19 | Reward: -76.71995
Episode: 20 | Reward: 37.16980
Episode: 21 | Reward: 42.01614
Episode: 22 | Reward: 61.13952
Episode: 23 | Reward: 54.37184
Episode: 24 | Reward: 50.70394
Episode: 25 | Reward: 55.88952
Episode: 26 | Reward: -123.39340
Episode: 27 | Reward: -69.05785
Episode: 28 | Reward: 33.76372
Episode: 29 | Reward: -68.17706
Episode: 30 | Reward: -78.70662
Episode: 31 | Reward: 36.62293
Episode: 32 | Reward: -77.476

KeyboardInterrupt: 

In [7]:
# Evaluate the trained policy for 50 episodes, acting with the actor's mean
# output (no sampling) clamped to [-1, 1].
model.eval()
for episode in range(50):
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        observation = torch.tensor(state).float()
        with torch.no_grad():
            action = torch.clamp(model.actor(observation), -1, 1).numpy()
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done = terminated or truncated
    print(f"Episode: {episode+1} | Reward: {total_reward}")

Episode: 1 | Reward: 93.39745999466662
Episode: 2 | Reward: 93.39596827851504
Episode: 3 | Reward: 93.38360521080449
Episode: 4 | Reward: 93.40240396917235
Episode: 5 | Reward: 93.48236786548297
Episode: 6 | Reward: 93.43309569355344
Episode: 7 | Reward: 93.34751291617319
Episode: 8 | Reward: 93.25329001387455
Episode: 9 | Reward: 93.430956153439
Episode: 10 | Reward: 93.39121156055192
Episode: 11 | Reward: 93.4063443590383
Episode: 12 | Reward: 93.33577466505785
Episode: 13 | Reward: 93.4243808335815
Episode: 14 | Reward: 93.26385178806322
Episode: 15 | Reward: 93.40855039828745
Episode: 16 | Reward: 93.43778818154122
Episode: 17 | Reward: 93.44934490702292
Episode: 18 | Reward: 93.44496974035738
Episode: 19 | Reward: 93.4362928648792
Episode: 20 | Reward: 93.34478848373975
Episode: 21 | Reward: 93.39406024768866
Episode: 22 | Reward: 93.34424470685624
Episode: 23 | Reward: 93.4274757767623
Episode: 24 | Reward: 93.43532107028622
Episode: 25 | Reward: 93.45765994288215
Episode: 26 | R

In [8]:
# Persist only the model's parameters (its state dict) to disk.
checkpoint_path = "ActorCritic.pth"
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# load_state_dict() requires the state dict as an argument; calling it bare
# raises TypeError. Restore the parameters saved to "ActorCritic.pth".
model.load_state_dict(torch.load("ActorCritic.pth"))