In [None]:
import os
import time

import ale_py
import gymnasium as gym
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
gym.register_envs(ale_py)

# Gymnasium id of the Atari game used for both training and playback.
ENV_NAME = "PongNoFrameskip-v4"
# Number of environment steps for each call to train().
TOTAL_TIMESTEPS = 1000 * 1000
# Path that train() loads/saves and play() loads the PPO model from.
MODEL_PATH = "ppo_pong"

def train():
    """Train a PPO ``CnnPolicy`` agent on ENV_NAME and save it to MODEL_PATH.

    Builds a single Atari env (seed 0) whose observations stack the 4 most
    recent frames. If a model was previously saved at MODEL_PATH, it is
    loaded and training continues from it; otherwise a fresh model is
    created. After TOTAL_TIMESTEPS further steps the model is saved back to
    MODEL_PATH, overwriting any earlier save. The env is closed even if
    training raises or is interrupted.
    """
    env = make_atari_env(ENV_NAME, n_envs=1, seed=0)
    env = VecFrameStack(env, n_stack=4)
    try:
        # SB3's save() writes "<path>.zip" when the path has no extension and
        # load() accepts either spelling, so look for both files.
        resuming = os.path.isfile(MODEL_PATH) or os.path.isfile(f"{MODEL_PATH}.zip")
        if resuming:
            model = PPO.load(MODEL_PATH, env=env)
        else:
            # No checkpoint yet: start from a freshly initialised policy.
            model = PPO("CnnPolicy", env, verbose=1)
        # When resuming, keep the timestep counter so logs continue instead
        # of restarting at 0.
        model.learn(total_timesteps=TOTAL_TIMESTEPS, reset_num_timesteps=not resuming)
        model.save(MODEL_PATH)
    finally:
        env.close()

def play(max_steps=None):
    """Render the saved PPO agent playing ENV_NAME.

    Loads the model from MODEL_PATH and builds the same env as train(): a
    single Atari env (seed 0) stacking the 4 most recent frames. Each step
    samples a stochastic action (``deterministic=False``), then renders the
    frame and sleeps 10 ms.

    Args:
        max_steps: stop after this many environment steps. ``None`` (the
            default) runs until interrupted, e.g. via Kernel -> Interrupt.

    The environment is closed on exit, including when interrupted.
    """
    model = PPO.load(MODEL_PATH)
    env = make_atari_env(ENV_NAME, n_envs=1, seed=0)
    env = VecFrameStack(env, n_stack=4)
    try:
        obs = env.reset()
        env.render("human")
        steps_taken = 0
        while max_steps is None or steps_taken < max_steps:
            action, _ = model.predict(obs, deterministic=False)
            obs, _rewards, _dones, _info = env.step(action)
            env.render("human")
            time.sleep(0.01)
            steps_taken += 1
    finally:
        # Previously unreachable after the infinite loop; now always runs.
        env.close()

In [None]:
# Train PPO on ENV_NAME for TOTAL_TIMESTEPS steps and save the model to MODEL_PATH.
train()

In [None]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C

import ale_py

# Self-contained A2C run on Pong. make_atari_env creates the environments
# and applies SB3's Atari wrappers. n_envs=4 runs four copies of the game
# as parallel workers.
A2C_ENV_ID = "PongNoFrameskip-v4"
A2C_SEED = 0
# NOTE: 1 timestep is only a smoke test. SB3's on-policy loop still collects
# one full rollout before checking the budget, but the policy remains
# essentially untrained. Raise this substantially for an agent that plays.
A2C_TIMESTEPS = 1

vec_env = make_atari_env(A2C_ENV_ID, n_envs=4, seed=A2C_SEED)
# Stack the 4 most recent frames into each observation.
vec_env = VecFrameStack(vec_env, n_stack=4)

# Seed the model as well. The env seed alone does not fix policy
# initialisation or action sampling, so runs would not be reproducible.
model = A2C("CnnPolicy", vec_env, verbose=1, seed=A2C_SEED)
model.learn(total_timesteps=A2C_TIMESTEPS)

In [None]:
# Imported here so the A2C cells do not depend on the first cell having run.
import time

# Render the A2C agent from the previous cell playing in all 4 envs.
# The loop runs until the cell is interrupted (Kernel -> Interrupt).
# vec_env is left open so this cell can be re-run.
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs, deterministic=False)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
    time.sleep(0.05)