In [1]:
import random
import cv2
import torch, os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import gymnasium as gym
# import gym
import imageio
import imageio_ffmpeg
import warnings
warnings.filterwarnings("ignore")

In [2]:
# --- Reproducibility / environment -------------------------------------------
SEED = 42
ENV_NAME = "CarRacing-v3"

# --- Observation preprocessing (frames are resized to HEIGHT x WIDTH) --------
TARGET_HEIGHT = 64
TARGET_WIDTH = 64

# --- Rollout sizing -----------------------------------------------------------
MAX_ENVS = 32            # number of environments stepped in parallel
MAX_STEPS = 32           # steps collected per environment before each update
TOTAL_STEPS = 5_000_000  # total environment steps budgeted for training
BATCH_SIZE = MAX_ENVS * MAX_STEPS
NUM_UPDATES = TOTAL_STEPS // BATCH_SIZE

# --- PPO optimisation ---------------------------------------------------------
GAMMA = 0.99             # discount factor
LEARNING_RATE = 3e-4
NUM_MINIBATCHES = 4
MINIBATCHSIZE = BATCH_SIZE // NUM_MINIBATCHES
PPO_EPOCHS = 5           # passes over each collected batch
CLIP_VALUE = 0.2         # clipping range for the probability ratio
VALUE_COEFF = 0.5        # weight of the critic loss in the total loss
ENTROPY_COEFF = 0.01     # weight of the entropy bonus in the total loss

# --- Logging ------------------------------------------------------------------
LOG_EVERY_N_STEPS = 10   # log/evaluate every N updates

print("Num updates: ", NUM_UPDATES)
print("Batch size: ", BATCH_SIZE)
print("Mini batch size: ", MINIBATCHSIZE)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

Num updates:  4882
Batch size:  1024
Mini batch size:  256


device(type='cuda')

In [3]:
# Seed the Python and NumPy global RNGs, then PyTorch. The PyTorch call stays
# last so the cell still displays the returned torch Generator.
for seed_rng in (random.seed, np.random.seed):
    seed_rng(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2417f5ccc50>

In [5]:
class PreprocessAndFrameStack(gym.ObservationWrapper):
    """Stack the last ``num_stack`` frames, then grayscale and resize each one.

    The wrapped environment is first passed through
    ``gym.wrappers.FrameStackObservation`` so that ``observation`` receives all
    stacked frames at once. Each frame is converted to grayscale (only when it
    is an H x W x 3 array) and resized to ``height`` x ``width``.

    Args:
        env: Environment to wrap.
        height: Output frame height in pixels.
        width: Output frame width in pixels.
        num_stack: Number of consecutive frames kept in the stack.

    The advertised observation space is a ``uint8`` box of shape
    ``(num_stack, height, width)`` with values in ``[0, 255]``.
    """

    def __init__(self, env, height, width, num_stack):
        # Stack first: the preprocessing below then runs on every stacked frame.
        env = gym.wrappers.FrameStackObservation(env, num_stack)
        super().__init__(env)
        self.height = height
        self.width = width
        self.num_stack = num_stack

        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self.num_stack, self.height, self.width),
            dtype=np.uint8,
        )

    def observation(self, obs):
        """Return the stacked observation as a ``(num_stack, height, width)`` array.

        Args:
            obs: Stacked frames from the inner frame-stack wrapper; iterated
                along its first axis (presumably ``(num_stack, H, W, 3)`` RGB
                frames for CarRacing - confirm against the environment).

        Returns:
            The processed frames stacked along axis 0.
        """
        # A single conversion is enough; the frames are only read from here on.
        frames = np.asarray(obs, dtype=np.uint8)

        processed_stack = []
        for frame in frames:
            # Only 3-channel (H, W, C) frames are converted to grayscale.
            if frame.ndim == 3 and frame.shape[2] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            # cv2.resize takes the target size as (width, height).
            frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
            processed_stack.append(frame)

        # The frame index becomes the leading (channel) dimension.
        return np.stack(processed_stack, axis=0)


In [None]:
class ActorNet(nn.Module):
    """Convolutional policy network over a discrete action set.

    The layer sizes fix the expected input: 4 input channels, and a spatial
    size that leaves a 64 x 4 x 4 feature map after the three convolutions
    (which is what a 64 x 64 input produces).

    Args:
        action_space: Number of discrete actions (size of the logit vector).
    """

    def __init__(self, action_space):
        super().__init__()
        conv_layers = [
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        dense_layers = [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(),
        ]
        # Attribute names and layer order define the state_dict keys; keep them.
        self.network = nn.Sequential(*conv_layers, *dense_layers)
        self.actor = nn.Linear(512, action_space)

    def forward(self, x):
        """Return the 512-dim feature vector for pixel input scaled by 1/255."""
        scaled = x / 255.0
        return self.network(scaled)

    def get_action(self, x, action=None):
        """Sample (or score) an action under the current policy.

        Args:
            x: Batch of observations accepted by ``forward``.
            action: Optional actions to evaluate; when ``None`` an action is
                sampled from the policy.

        Returns:
            Tuple ``(action, log_prob, entropy)`` where ``log_prob`` is the
            log-probability of ``action`` and ``entropy`` is the entropy of
            the categorical distribution.
        """
        features = self.forward(x)
        policy = torch.distributions.Categorical(logits=self.actor(features))
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy()

In [7]:
class CriticNet(nn.Module):
    """Convolutional state-value network producing one scalar per observation.

    The layer sizes fix the expected input: 4 input channels, and a spatial
    size that leaves a 64 x 4 x 4 feature map after the three convolutions
    (which is what a 64 x 64 input produces).
    """

    def __init__(self):
        super().__init__()
        conv_layers = [
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        dense_layers = [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(),
        ]
        # Attribute names and layer order define the state_dict keys; keep them.
        self.network = nn.Sequential(*conv_layers, *dense_layers)
        self.critic = nn.Linear(512, 1)

    def forward(self, x):
        """Return value estimates of shape ``(batch, 1)`` for pixel input scaled by 1/255."""
        features = self.network(x / 255.0)
        return self.critic(features)

In [8]:
def make_env(env_id, seed, idx, eval_mode=False):
    """Build a zero-argument factory for one preprocessed environment.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Base seed; ``seed + idx`` seeds the action/observation spaces.
        idx: Index of this environment, used to offset the seed.
        eval_mode: When true the env is created with ``render_mode="rgb_array"``;
            otherwise no render mode is requested.

    Returns:
        A callable that creates the environment (discrete actions, episode
        statistics recording, 4-frame grayscale stack) when invoked.
    """
    def thunk():
        env = PreprocessAndFrameStack(
            gym.wrappers.RecordEpisodeStatistics(
                gym.make(
                    env_id,
                    render_mode="rgb_array" if eval_mode else None,
                    continuous=False,
                )
            ),
            height=TARGET_HEIGHT,
            width=TARGET_WIDTH,
            num_stack=4,
        )

        # Give every environment its own space-sampling seed.
        space_seed = seed + idx
        for space in (env.action_space, env.observation_space):
            space.seed(space_seed)
        return env

    return thunk

In [9]:
# Build one render-enabled environment directly (outside the vector env).
test_env = make_env(env_id=ENV_NAME, seed=SEED, idx=99, eval_mode=True)()

In [10]:
# One factory per parallel environment; the index offsets each env's seed.
env_factories = [make_env(ENV_NAME, SEED, env_idx) for env_idx in range(MAX_ENVS)]
envs = gym.vector.SyncVectorEnv(env_factories)

In [11]:
# Per-environment (un-batched) spaces of the vector env.
single_obs_space = envs.single_observation_space
single_act_space = envs.single_action_space

obs_space = single_obs_space.shape   # observation shape tuple
action_space = single_act_space.n    # number of discrete actions
print(f"Observation Space: {obs_space}")
print(f"Action Space: {action_space}")

Observation Space: (4, 64, 64)
Action Space: 5


In [12]:
# Separate actor and critic networks, trained jointly by a single Adam optimizer.
actor_network = ActorNet(action_space).to(DEVICE)
critic_network = CriticNet().to(DEVICE)
optimizer = optim.Adam(
    [*actor_network.parameters(), *critic_network.parameters()],
    lr=LEARNING_RATE,
    eps=1e-5,
)

In [13]:
# Rollout buffers, indexed as [step, env]. Observations are kept as uint8;
# every other buffer uses torch's default float dtype.
rollout_shape = (MAX_STEPS, MAX_ENVS)

obs_storage = torch.zeros(rollout_shape + tuple(obs_space), dtype=torch.uint8).to(DEVICE)
action_storage = torch.zeros(rollout_shape).to(DEVICE)
old_log_probs = torch.zeros(rollout_shape).to(DEVICE)
rewards_storage = torch.zeros(rollout_shape).to(DEVICE)
dones_storage = torch.zeros(rollout_shape).to(DEVICE)
values_storage = torch.zeros(rollout_shape).to(DEVICE)

In [14]:
# Reset every environment once; training continues from these observations.
initial_obs, _ = envs.reset(seed=SEED)
next_obs = torch.Tensor(initial_obs).to(DEVICE)  # float tensor on DEVICE
next_done = torch.zeros(MAX_ENVS).to(DEVICE)     # no environment is done yet
next_obs.shape

torch.Size([32, 4, 64, 64])

In [15]:
# Shape smoke test: push one random batch through both networks (no gradients)
# and print the shapes of everything they return.
test_x = torch.randn((MAX_ENVS, *obs_space), dtype=torch.float32).to(DEVICE)
with torch.no_grad():
    test_action, test_log_probs, test_entropy = actor_network.get_action(test_x)
    test_values = critic_network(test_x)

print(test_action.shape)
print(test_log_probs.shape)
print(test_entropy.shape)
print(test_values.shape)

torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([32, 1])


In [16]:
def evaluate(actor_model, device, num_eval_eps=10, record=False):
    """Run the policy in a fresh render-enabled environment and collect returns.

    Actions are sampled from the policy via ``actor_model.get_action``.

    Args:
        actor_model: Policy network exposing ``get_action(obs_batch)``.
        device: Device the model and observation tensors are placed on.
        num_eval_eps: Number of episodes to run.
        record: When true, one rendered frame is captured before every step.

    Returns:
        Tuple ``(returns, frames)``: the list of per-episode reward sums and
        the list of captured frames (empty when ``record`` is false).
    """
    eval_env = make_env(env_id=ENV_NAME, seed=SEED, idx=100, eval_mode=True)()

    # Remember the caller's mode so it can be restored afterwards.
    was_training = actor_model.training
    actor_model.to(device)
    actor_model.eval()
    returns = []
    frames = []

    try:
        for _ in range(num_eval_eps):
            obs, _ = eval_env.reset()
            done = False
            episode_reward = 0.0

            while not done:
                if record:
                    frames.append(eval_env.unwrapped.render())

                with torch.no_grad():
                    # Add a batch dimension of 1 for the network.
                    obs_tensor = torch.tensor(obs, device=device, dtype=torch.float32).unsqueeze(0)
                    action, _, _ = actor_model.get_action(obs_tensor)
                action_scalar = action.cpu().numpy().item()

                obs, reward, terminated, truncated, info = eval_env.step(action_scalar)
                done = terminated or truncated
                episode_reward += float(reward)

            returns.append(episode_reward)
    finally:
        # Always release the environment and restore the model's mode, even
        # when an episode fails part-way through.
        eval_env.close()
        actor_model.train(was_training)

    return returns, frames

In [17]:
# Parameters of both networks, gathered once for gradient clipping.
trainable_params = list(actor_network.parameters()) + list(critic_network.parameters())

for update in range(1, NUM_UPDATES + 1):
    # ---- Rollout: MAX_STEPS transitions from each of the MAX_ENVS environments ----
    # NOTE(review): how the vector env hands back observations right after an
    # episode ends depends on gymnasium's autoreset mode, which is not set
    # explicitly here - confirm the transitions stored around episode
    # boundaries are the intended ones.
    for step in range(MAX_STEPS):
        obs_storage[step] = next_obs
        # Flags the environments whose previous step ended an episode.
        dones_storage[step] = next_done

        with torch.no_grad():
            action, log_prob, _ = actor_network.get_action(next_obs)
            values = critic_network(next_obs)

        action_storage[step] = action
        old_log_probs[step] = log_prob
        values_storage[step] = values.flatten()

        next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())

        done = np.logical_or(terminated, truncated)

        rewards_storage[step] = torch.tensor(reward).to(DEVICE).view(-1)

        next_obs = torch.Tensor(next_obs).to(DEVICE)
        next_done = torch.Tensor(done).to(DEVICE)

    # ---- Discounted returns, bootstrapped from the critic's value of the last obs ----
    with torch.no_grad():
        returns = torch.zeros_like(rewards_storage).to(DEVICE)
        # squeeze(-1) drops only the value dimension, so a batch of one stays 1-D.
        bootstrap_value = critic_network(next_obs).squeeze(-1)
        gt_next_state = bootstrap_value * (1.0 - next_done)

        for i in reversed(range(MAX_STEPS)):
            rt = rewards_storage[i] + GAMMA * gt_next_state
            returns[i] = rt
            # Zero the carried return where step i started a new episode, so
            # step i - 1 does not bootstrap across the episode boundary.
            gt_next_state = returns[i] * (1.0 - dones_storage[i])

    # Advantages are normalised over the whole batch.
    advantages = returns - values_storage
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ---- Flatten [step, env] into a single batch dimension ----
    b_obs = obs_storage.reshape((-1,) + envs.single_observation_space.shape)
    b_logprobs = old_log_probs.reshape(-1)
    b_actions = action_storage.reshape(-1)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)

    b_idxs = np.arange(BATCH_SIZE)

    # ---- PPO update: several epochs of shuffled minibatches ----
    for epoch in range(PPO_EPOCHS):
        np.random.shuffle(b_idxs)
        for start in range(0, BATCH_SIZE, MINIBATCHSIZE):
            end = start + MINIBATCHSIZE
            min_batch = b_idxs[start: end]

            # Actions are stored as floats; the distribution needs integer indices.
            _, new_log_probs, entropy = actor_network.get_action(b_obs[min_batch], b_actions[min_batch].long())
            ratio = torch.exp(new_log_probs - b_logprobs[min_batch])

            # Clipped surrogate objective.
            pg_loss1 = b_advantages[min_batch] * ratio
            pg_loss2 = b_advantages[min_batch] * torch.clamp(ratio, 1 - CLIP_VALUE, 1 + CLIP_VALUE)
            policy_loss = -torch.min(pg_loss1, pg_loss2).mean()

            # squeeze(-1) keeps the shape equal to the targets for any minibatch size.
            current_values = critic_network(b_obs[min_batch]).squeeze(-1)
            critic_loss = VALUE_COEFF * F.mse_loss(current_values, b_returns[min_batch])

            entropy_loss = entropy.mean()

            # Entropy enters with a minus sign: higher entropy lowers the loss.
            loss = policy_loss - ENTROPY_COEFF * entropy_loss + critic_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(trainable_params, 0.5)
            optimizer.step()

    # ---- Periodic logging plus one recorded evaluation episode ----
    if update % LOG_EVERY_N_STEPS == 0:
        print(f"[STEP]: {update}, [ACTOR_LOSS]: {policy_loss.item()}, [CRITIC_LOSS]: {critic_loss.item()}, [TOTAL_LOSS]: {loss.item()}, [REWARDS]: {rewards_storage.mean()}")
        # Raw f-string: the backslashes are path separators, not escape sequences.
        train_video_path = rf"B:\Pytorch\RL\eval_car_racing\ppo_car_racing_{update}.mp4"
        # Separate name so the `returns` tensor above is not overwritten.
        eval_returns, frames = evaluate(actor_model=actor_network, device=DEVICE, num_eval_eps=1, record=True)

        if frames:
            # Create the output folder if needed so the save cannot fail on a missing directory.
            video_dir = os.path.dirname(train_video_path)
            if video_dir:
                os.makedirs(video_dir, exist_ok=True)
            imageio.mimsave(
                train_video_path,
                frames,
                fps=30,
                codec='libx264',
                macro_block_size=1
            )

[STEP]: 10, [ACTOR_LOSS]: -0.047900475561618805, [CRITIC_LOSS]: 2.946540355682373, [TOTAL_LOSS]: 2.8862624168395996, [REWARDS]: 0.10729311406612396
[STEP]: 20, [ACTOR_LOSS]: 2.6768073439598083e-05, [CRITIC_LOSS]: 401.74951171875, [TOTAL_LOSS]: 401.73809814453125, [REWARDS]: -0.8218571543693542
[STEP]: 30, [ACTOR_LOSS]: 0.00026462599635124207, [CRITIC_LOSS]: 3.463057041168213, [TOTAL_LOSS]: 3.451810598373413, [REWARDS]: -0.0063512749038636684
[STEP]: 40, [ACTOR_LOSS]: -0.0004702135920524597, [CRITIC_LOSS]: 35.89255905151367, [TOTAL_LOSS]: 35.87985610961914, [REWARDS]: 0.08224737644195557
[STEP]: 50, [ACTOR_LOSS]: -0.01691450923681259, [CRITIC_LOSS]: 61.61197280883789, [TOTAL_LOSS]: 61.58324432373047, [REWARDS]: -0.24559199810028076
[STEP]: 60, [ACTOR_LOSS]: -0.03334979712963104, [CRITIC_LOSS]: 48.49032974243164, [TOTAL_LOSS]: 48.446205139160156, [REWARDS]: -0.08132374286651611
[STEP]: 70, [ACTOR_LOSS]: 0.028090698644518852, [CRITIC_LOSS]: 47.90833282470703, [TOTAL_LOSS]: 47.924812316894

In [20]:
# Persist the policy weights only; the critic is not saved here.
actor_state = actor_network.state_dict()
torch.save(actor_state, r"B:\Pytorch\RL\models\ppo_carracing.pth")