In [None]:
#!pip install swig 
#!pip install "gymnasium[box2d]"

In [None]:
#pip install ema_pytorch

In [None]:
import os
import imageio
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import *
from ema_pytorch import EMA
from torch.distributions.categorical import Categorical
from gymnasium.wrappers import *
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from stable_baselines3.common.vec_env import DummyVecEnv

In [None]:
def make_env(env_id):
    """Return a zero-argument factory that builds one preprocessed environment.

    The factory creates `env_id` with discrete actions and RGB-array rendering,
    then stacks the wrappers in this order: max-and-skip over 4 frames, resize
    to 84, grayscale conversion, and a stack of the 4 most recent frames.
    """
    def wrapped_env():
        base = gym.make(env_id, render_mode='rgb_array', continuous=False)
        skipped = MaxAndSkipEnv(base, skip=4)
        resized = ResizeObservation(skipped, 84)
        gray = GrayScaleObservation(resized)
        return FrameStack(gray, 4)

    return wrapped_env

# Build the vectorised environments used for training.
env_id = "CarRacing-v2"
num_envs = 4
env_fns = [make_env(env_id) for _ in range(num_envs)]
envs = DummyVecEnv(env_fns)

# Reset once and grab a rendered frame so the starting state can be inspected.
obs = envs.reset()
img = envs.render()

# Show the frame without axes and keep a copy on disk.
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis("off")
fig.savefig('./car_init.png')
plt.show()

In [None]:
# Cache the observation/action space descriptions; later cells size the
# rollout buffers and the policy head from these values.
obs_space, act_space = envs.observation_space, envs.action_space
obs_shape = obs_space.shape
action_shape = act_space.shape
n_action = act_space.n
init_obs = envs.reset()

# Report the batched observation shape followed by the per-space details.
print(init_obs.shape)
for _label, _shown in (
    ('obs_shape:', obs_shape),
    ('action_shape:', action_shape),
    ('n_action:', n_action),
):
    print(_label, _shown)

In [None]:
class ActorCritic(nn.Module):
    """Convolutional actor-critic network with a shared feature trunk.

    Four stride-2 convolutions feed a 512-unit fully connected layer, which
    branches into a policy head (one logit per action) and a value head (one
    scalar per sample). The first linear layer expects 128 * 5 * 5 features,
    which is what an 84x84 input produces after the four convolutions.

    Args:
        num_inputs: Number of input channels.
        num_actions: Number of discrete actions (size of the policy head).
    """

    def __init__(self, num_inputs, num_actions):
        super().__init__()
        # Shared convolutional trunk.
        self.conv1 = nn.Conv2d(num_inputs, 32, kernel_size=7, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        # Shared hidden layer followed by the two output heads.
        self.fc1 = nn.Linear(128 * 5 * 5, 512)
        self.actor_fc = nn.Linear(512, num_actions)
        self.critic_fc = nn.Linear(512, 1)

    def forward(self, x):
        """Return `(Categorical over actions, state values of shape (N, 1))`."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))

        # Flatten to the width the first linear layer was built for.
        features = features.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(features))

        policy = Categorical(logits=self.actor_fc(hidden))  # actor head
        state_values = self.critic_fc(hidden)  # critic head
        return policy, state_values

In [None]:
# --- PPO loss / advantage hyperparameters ---
gamma = 0.99        # discount applied to the bootstrapped next value
gae_lambda = 0.95   # decay of the advantage recursion
clip_coef = 0.2     # half-width of the probability-ratio clipping interval
ent_coef = 0.01     # weight of the entropy term in the total loss
vf_coef = 1         # weight of the value loss in the total loss

# --- Rollout and optimisation sizes ---
num_envs = 4
num_steps = 512
minibatch_size = 32
total_steps = 500_000
batch_size = num_envs * num_steps
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model, its EMA shadow, optimiser and value-loss criterion ---
agent = ActorCritic(4, n_action).to(device)
ema = EMA(agent, beta = 0.995, update_every = 1)
optimizer = optim.Adam(agent.parameters(), lr=1e-4, eps=1e-5)
v_loss_fn = nn.SmoothL1Loss(reduction='mean')

# --- Rollout storage, indexed as [step, env, ...] and allocated on `device` ---
rollout_shape = (num_steps, num_envs)
obs = torch.zeros(rollout_shape + obs_shape, device=device)
actions = torch.zeros(rollout_shape + action_shape, device=device)
logprobs = torch.zeros(rollout_shape, device=device)
rewards = torch.zeros(rollout_shape, device=device)
dones = torch.zeros(rollout_shape, device=device)
values = torch.zeros(rollout_shape, device=device)

# --- Counters and the first observation fed to the policy ---
global_step = 0
next_obs = torch.Tensor(envs.reset()).to(device)
num_updates = total_steps // batch_size

In [None]:
# PPO training: every update collects `num_steps` transitions from each
# environment, computes GAE advantages, runs 4 epochs of clipped-surrogate
# minibatch updates, and saves the EMA weights.
ckpt_dir = './model_weight/'
os.makedirs(ckpt_dir, exist_ok=True)
reward_record = []

for i in tqdm(range(1, num_updates + 1), desc="PPO training progress: ~~"):
    # ---- Rollout: act with the current policy and fill the buffers ----
    for step in range(num_steps):
        global_step += 1 * num_envs
        obs[step] = next_obs

        with torch.no_grad():
            dist, value = agent(next_obs)
            values[step] = value.view(-1)
            action = dist.sample()
            logprob = dist.log_prob(action)

        actions[step] = action
        logprobs[step] = logprob
        next_obs, reward, done, info = envs.step(action.cpu().numpy())
        # dones[step] flags that the transition taken at `step` ended an episode.
        dones[step] = torch.Tensor(done).to(device)
        rewards[step] = torch.tensor(reward).to(device).view(-1)
        next_obs = torch.Tensor(next_obs).to(device)

    # ---- Generalised advantage estimation, computed backwards in time ----
    with torch.no_grad():
        _, next_value = agent(next_obs)
        # The critic head returns (num_envs, 1). Flatten it so it lines up with
        # the (num_envs,) reward/value rows instead of broadcasting to
        # (num_envs, num_envs), which cannot be stored in advantages[t].
        next_value = next_value.view(-1)
        advantages = torch.zeros_like(rewards)
        lastgaelam = 0

        for t in reversed(range(num_steps)):
            # 1 where the episode continues after step t, 0 where it ended;
            # this masks both the bootstrap value and the recursion.
            nonterminal = 1.0 - dones[t]
            nextvalues = next_value if t == num_steps - 1 else values[t + 1]
            delta = rewards[t] + gamma * nextvalues * nonterminal - values[t]
            advantages[t] = lastgaelam = delta + gamma * gae_lambda * nonterminal * lastgaelam
        returns = advantages + values

    # ---- Flatten (step, env) into a single batch dimension ----
    b_obs = obs.view((-1,) + obs_shape)
    b_actions = actions.view((-1,) + action_shape)
    b_logprobs = logprobs.view(-1)
    b_advantages = advantages.view(-1)
    b_returns = returns.view(-1)
    # Cast once per update; log_prob on a Categorical needs integer actions.
    b_actions_long = b_actions.long()

    # ---- Policy / value optimisation ----
    b_inds = np.arange(batch_size)
    for epoch in range(4):
        np.random.shuffle(b_inds)
        for start in range(0, batch_size, minibatch_size):
            end = start + minibatch_size
            mb_inds = b_inds[start:end]

            dist, new_value = agent(b_obs[mb_inds])
            newlogprob = dist.log_prob(b_actions_long[mb_inds])
            entropy = dist.entropy()
            logratio = newlogprob - b_logprobs[mb_inds]
            ratio = logratio.exp()

            # Policy loss: pessimistic maximum of the unclipped and clipped surrogates.
            mb_advantages = b_advantages[mb_inds]
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss: flatten the (N, 1) critic output to match the returns.
            new_value = new_value.view(-1)
            v_loss = v_loss_fn(new_value, b_returns[mb_inds])
            entropy_loss = entropy.mean()
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            optimizer.step()
            ema.update()

    # Track the mean return target of this update and overwrite the checkpoint.
    reward_record.append(b_returns.mean().item())
    torch.save(ema.state_dict(), f"{ckpt_dir}/ema_epoch.pt")

In [None]:
# Plot the value appended to reward_record after each PPO update.
fig, ax = plt.subplots()
ax.plot(reward_record)
ax.set_title('cumalative reward')
plt.show()

In [None]:
# Make sure the output folder for the recording exists.
os.makedirs('./result/', exist_ok=True)

# Build a fresh set of vectorised environments for the recording run.
env_id = "CarRacing-v2"
num_envs = 4
env_fns = [make_env(env_id) for _ in range(num_envs)]
envs = DummyVecEnv(env_fns)

def record_video(env, policy, out_directory, fps=20, max_steps=2000):
    """Roll out `policy` in `env` and save the rendered frames as `test.gif`.

    Actions are sampled from the distribution returned by `policy`. The
    rollout stops as soon as any sub-environment reports done, or after
    `max_steps + 1` steps, whichever comes first.

    Args:
        env: Vectorised environment whose `step` returns
            `(obs, reward, done, info)` with an iterable `done`.
        policy: Callable mapping a batch of observations to
            `(action distribution, value)`.
        out_directory: Directory to write `test.gif` into. `None` keeps the
            previous hard-coded location `./result/`.
        fps: Frame rate passed to `imageio.mimsave`.
        max_steps: Step cap for the rollout.

    Relies on the notebook-level `device` for placing observations.
    """
    save_dir = './result/' if out_directory is None else out_directory
    os.makedirs(save_dir, exist_ok=True)

    state = env.reset()
    images = [env.render()]

    # Inference only: skip autograd bookkeeping while rolling out.
    with torch.no_grad():
        for _step in range(max_steps + 1):
            state = torch.Tensor(state).to(device)
            dist, _ = policy(state)
            action = dist.sample()

            state, reward, done, info = env.step(action.cpu().numpy())
            images.append(env.render())
            # Stop once any of the parallel environments has finished.
            if any(done):
                break

    imageio.mimsave(
        os.path.join(save_dir, 'test.gif'),
        [np.array(frame) for frame in images],
        fps=fps,
    )
# Switch the EMA wrapper to evaluation mode, then record a rollout with it.
ema.eval()
record_video(env=envs, policy=ema, out_directory=None)