In [None]:
import gymnasium as gym
from stable_baselines3 import DQN
import ale_py

#### Check all registered environments in Gymnasium

In [None]:
# Collect the IDs of every environment currently registered with Gymnasium.
envs = gym.envs.registry.keys()

# Keep only the IDs containing "NoFrameskip" (the Atari variants this notebook
# looks for) and print that filtered list. Previously the filtered list was
# built but the full, unfiltered registry was printed instead.
atari_envs = sorted(env_id for env_id in envs if "NoFrameskip" in env_id)
print(atari_envs)

#### If the Atari environments are not registered, run the code below

In [None]:
# Register the environments provided by the ale_py package with Gymnasium
# so that they can be created by ID in the cells below.
gym.register_envs(ale_py)

#### Create environment and train model

In [None]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Use the "NoFrameskip" variant of Breakout. make_atari_env wraps each env in
# SB3's AtariWrapper, which performs its own frame skipping; starting from an
# env that already skips frames (e.g. "Breakout-v4") would compound the two.
# The same ID is reused by the evaluation cell, whose AtariPreprocessing
# wrapper also requires an env without built-in frame skipping.
env_name = "BreakoutNoFrameskip-v4"

# Four parallel, seeded training environments with the last four frames
# stacked into each observation.
env = make_atari_env(env_name, n_envs=4, seed=0)
env = VecFrameStack(env, n_stack=4)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./breakout_tensorboard/

In [None]:
# Build the DQN agent on the stacked-frame Atari environment created above.
model = DQN(
    policy='CnnPolicy',
    env=env,
    # Optimisation settings
    learning_rate=2.5e-4,
    gamma=0.99,
    batch_size=32,
    train_freq=4,
    target_update_interval=10_000,
    # Replay buffer settings
    buffer_size=100_000,
    learning_starts=50_000,
    # Exploration schedule
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    # Logging
    tensorboard_log="./Breakout_tensorboard",
    verbose=1,
)

# Train the agent for one million timesteps, then persist it to disk so the
# validation section can reload it.
model.learn(total_timesteps=1_000_000)
model.save("model_Breakout")

#### Validate the model by having it play a live Atari game

In [None]:
# Load the saved model from disk. `load` builds and returns a new model rather
# than updating an existing instance in place, so the result must be bound to
# `model`; calling `model.load(...)` and discarding the return value leaves
# `model` unchanged (or undefined on a fresh kernel).
model = DQN.load("model_Breakout")

In [None]:
from gymnasium.wrappers import RecordVideo
import matplotlib.pyplot as plt

# Build a single evaluation environment: 84x84 grayscale frames, 4-frame skip,
# 4 stacked frames, with the episode recorded to video in the current directory.
# frameskip=1 turns off the emulator's own frame skipping, because
# AtariPreprocessing applies frame_skip=4 itself and rejects a base env that
# already skips frames.
env = gym.make(env_name, render_mode="rgb_array", frameskip=1)
env = gym.wrappers.AtariPreprocessing(env, grayscale_obs=True, screen_size=84, frame_skip=4, scale_obs=False)
env = gym.wrappers.FrameStackObservation(env, stack_size=4)
env = RecordVideo(env, "./")

episode_reward = 0
max_ep_timesteps = 100000

# try/finally guarantees env.close() runs even if the loop raises or is
# interrupted, so the RecordVideo wrapper is always closed.
try:
    env_data = env.reset()
    obs = env_data[0]  # initial observation (reset returns (obs, info))

    for t in range(max_ep_timesteps):
        print(f"timestep: {t}")

        # Ask the trained model for its greedy action on the current observation.
        action, _states = model.predict(obs, deterministic=True)

        # Advance the environment; step returns
        # (obs, reward, terminated, truncated, info).
        env_data = env.step(action)
        obs, reward, terminated, truncated, _info = env_data

        # The episode is over when the game ends (terminated) OR when it is cut
        # short, e.g. by a time limit (truncated). Checking only `terminated`
        # would keep stepping an environment that needs a reset.
        done = terminated or truncated

        episode_reward += reward

        # Show the most recent frame of the stacked observation.
        plt.imshow(obs[-1], cmap="gray")
        plt.show()

        print(f"action: {action}")
        print(f"episode reward: {episode_reward}")
        print(env_data[1:])

        if done:
            break
finally:
    env.close()