In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env, SubprocVecEnv


In [None]:
def preprocess(img):
    """Turn one raw RGB frame into a cropped single-channel image.

    The frame is converted to grayscale with OpenCV and then cropped to
    rows [0, 84) and columns [6, 90), which yields an 84x84 image for a
    large-enough input. (Presumably the crop removes the bottom dashboard
    strip and side margins of the CarRacing frame -- TODO confirm.)

    Args:
        img: RGB image array accepted by ``cv.cvtColor(..., cv.COLOR_RGB2GRAY)``.

    Returns:
        The cropped grayscale image (a slice/view of the converted array).
    """
    gray = cv.cvtColor(img, cv.COLOR_RGB2GRAY)

    row_start, row_stop = 0, 84
    col_start, col_stop = 6, 90
    return gray[row_start:row_stop, col_start:col_stop]

In [None]:
class CustomEnv(gym.Wrapper):
    """Frame-skipping, frame-stacking wrapper around a pixel-observation env.

    Behaviour:
      - ``reset`` takes ``initial_no_op`` steps with action 0 before producing
        the first observation, which is that frame repeated ``stack_frames``
        times.
      - ``step`` repeats the chosen action for up to ``skip_frames`` raw env
        steps (stopping early when the episode ends) and sums their rewards.
      - The summed reward is clipped from above at 1.0.
      - Observations are the ``stack_frames`` most recent ``preprocess``-ed
        frames stacked along axis 0: shape ``(stack_frames, 84, 84)``, uint8.
      - ``truncated`` is forced True after ``max_episode_steps`` agent steps.

    NOTE(review): the wrapped env may already carry its own TimeLimit counted
    in raw env steps (e.g. one added by ``gym.make``). With frame skipping and
    the no-op prefix, that limit can fire before ``max_episode_steps`` agent
    steps are reached -- confirm against the env registration.
    """

    # Spatial size of one preprocessed frame; must match what `preprocess` returns.
    FRAME_SHAPE = (84, 84)

    def __init__(
        self,
        env,
        skip_frames=4,
        stack_frames=4,
        initial_no_op=50,
        max_episode_steps=1000,
        **kwargs
    ):
        """
        Args:
            env: environment to wrap; its raw observations must be accepted
                by ``preprocess``.
            skip_frames: number of raw env steps each agent action is
                repeated for (>= 1).
            stack_frames: number of most recent frames kept in each
                observation (>= 1).
            initial_no_op: number of action-0 steps taken at the start of
                each episode.
            max_episode_steps: agent-step budget after which ``truncated``
                is forced to True.
            **kwargs: forwarded to ``gym.Wrapper.__init__``.

        Raises:
            ValueError: if ``skip_frames`` or ``stack_frames`` is below 1.
        """
        super().__init__(env, **kwargs)

        # With skip_frames < 1 `step` would never obtain a frame, and with
        # stack_frames < 1 the stack could not hold the newest frame.
        if skip_frames < 1:
            raise ValueError(f"skip_frames must be >= 1, got {skip_frames}")
        if stack_frames < 1:
            raise ValueError(f"stack_frames must be >= 1, got {stack_frames}")

        # The stack depth follows `stack_frames` so the declared space always
        # matches the arrays returned by `reset` and `step`.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(stack_frames, *self.FRAME_SHAPE), dtype=np.uint8
        )
        self.initial_no_op = initial_no_op
        self.skip_frames = skip_frames
        self.stack_frames = stack_frames

        # Agent-step budget per episode (counted in wrapper steps, not raw frames).
        self.max_episode_steps = max_episode_steps
        self.step_count = 0

    def reset(self, **kwargs):
        """Reset the wrapped env, idle for ``initial_no_op`` steps, and build the initial stack.

        Args:
            **kwargs: forwarded to the wrapped env's ``reset`` (e.g. ``seed``).

        Returns:
            ``(observation, info)``, where the observation is the last
            preprocessed frame repeated ``stack_frames`` times along axis 0.
        """
        s, info = self.env.reset(**kwargs)
        self.step_count = 0

        # Take action 0 for the no-op prefix. Stop early if the episode has
        # already ended, because stepping a finished episode is undefined.
        for _ in range(self.initial_no_op):
            s, _, terminated, truncated, info = self.env.step(0)
            if terminated or truncated:
                break

        frame = preprocess(s)

        # The initial observation is `stack_frames` copies of the same frame.
        self.stacked_state = np.tile(frame, (self.stack_frames, 1, 1))
        return self.stacked_state, info

    def step(self, action):
        """Repeat ``action`` for up to ``skip_frames`` raw steps and return the updated stack.

        Args:
            action: action passed unchanged to the wrapped env's ``step``.

        Returns:
            The gymnasium 5-tuple ``(observation, reward, terminated,
            truncated, info)``. ``reward`` is the sum of the raw rewards,
            clipped from above at 1.0. Once ``max_episode_steps`` agent steps
            have elapsed, ``truncated`` is forced True and
            ``info['truncated']`` is set.
        """
        self.step_count += 1

        reward = 0
        for _ in range(self.skip_frames):
            s, r, terminated, truncated, info = self.env.step(action)
            reward += r
            if terminated or truncated:
                break

        if self.step_count >= self.max_episode_steps:
            truncated = True
            info['truncated'] = True

        # Clip only from above; negative rewards pass through unchanged.
        reward = np.clip(reward, a_min=None, a_max=1.0)

        frame = preprocess(s)

        # Drop the oldest frame and append the newest one at the end.
        self.stacked_state = np.concatenate((self.stacked_state[1:], frame[np.newaxis]), axis=0)

        return self.stacked_state, reward, terminated, truncated, info

In [None]:
# Build one wrapped env and inspect a single stacked observation.
env = gym.make('CarRacing-v2', continuous=False, render_mode="rgb_array")
env = CustomEnv(env)

s, _ = env.reset()
print("The shape of an observation: ", s.shape)

# One panel per stacked frame, derived from the observation instead of a
# hard-coded 4; squeeze=False keeps `axes` 2-D even for a single frame.
n_frames = s.shape[0]
fig, axes = plt.subplots(1, n_frames, figsize=(5 * n_frames, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(s[i], cmap='gray')
    ax.set_title(f"stacked frame {i}")
    ax.axis('off')
plt.show()

# Check that the wrapper follows the Gymnasium API expected by stable_baselines3,
# then release the env; later cells build their own envs via `make_env`.
check_env(env)
env.close()

In [None]:
def make_env():
    """Create one discrete-action CarRacing env wrapped with `CustomEnv`."""
    return CustomEnv(gym.make('CarRacing-v2', continuous=False, render_mode="rgb_array"))


num_envs = 6
SEED = 0  # passed to PPO, which seeds its RNGs and the vectorized envs for reproducible runs
TOTAL_TIMESTEPS = 10_000_000  # `learn` is typed to take an int number of env steps

# Each worker process builds its own env from `make_env`.
envs = SubprocVecEnv([make_env] * num_envs)
model = PPO("CnnPolicy", envs, verbose=1, seed=SEED)
model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)

In [None]:
from IPython import display

# Roll out one episode with the trained policy and animate it inline.
test_env = make_env()
observation, info = test_env.reset()

fig, ax = plt.subplots()
frame_artist = ax.imshow(test_env.render())
ax.axis('off')

episode_return = 0.0
try:
    while True:
        frame_artist.set_data(test_env.render())
        display.display(fig)
        display.clear_output(wait=True)
        # deterministic=True takes the policy's most likely action, the usual
        # choice for evaluation (the default samples from the policy).
        action, _states = model.predict(observation, deterministic=True)
        observation, reward, terminated, truncated, info = test_env.step(action)
        episode_return += reward
        if terminated or truncated:
            break

    # Show the last frame of the episode so it stays in the cell output.
    frame_artist.set_data(test_env.render())
    display.display(fig)
finally:
    # Release the env and the figure even if the loop is interrupted.
    test_env.close()
    plt.close(fig)

print(f"Episode return (sum of clipped per-step rewards): {episode_return:.2f}")
