In [6]:
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack
from gymnasium.vector import SyncVectorEnv
import pygame
import numpy as np
import torch

from utils import DQN, ReplayMemory

In [7]:
def make_env(env_id="ALE/Breakout-v5", num_stack=4, screen_size=84, noop_max=30):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    The factory form is what ``SyncVectorEnv`` expects: a list of callables,
    each of which creates a fresh environment when invoked.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        num_stack: Number of consecutive frames stacked by ``FrameStack``.
        screen_size: Side length (pixels) frames are resized to.
        noop_max: Upper bound on random no-op actions taken at reset.

    Returns:
        A callable taking no arguments that returns the wrapped environment.
    """

    def _init():
        env = gym.make(env_id)
        env = AtariPreprocessing(
            env,
            # Keep at 1: ALE/Breakout-v5 already uses frame_skip=4 inside the
            # base env, so skipping again here would compound the skip.
            frame_skip=1,
            screen_size=screen_size,
            grayscale_obs=True,
            noop_max=noop_max,
        )
        return FrameStack(env, num_stack)

    return _init


# A single environment wrapped in the vector API, so observations and actions
# carry a leading batch axis of size 1.
envs = SyncVectorEnv([make_env() for _ in range(1)])


# (4, 84, 84) mirrors make_env's output: 4 stacked 84x84 grayscale frames.
# The trailing 4 is presumably the number of discrete actions — confirm
# against utils.DQN and the environment's action space.
policy_network = DQN((4, 84, 84), 4)
# This notebook only replays a trained agent, so switch any train-only layers
# (dropout / batch-norm, if the model has them) to inference behaviour.
policy_network.eval()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; consider passing
# weights_only=True if the installed torch version supports it.
# Kept as the last expression so the cell still displays the key-match report.
policy_network.load_state_dict(
    torch.load("./models/atari_model.pth", map_location=torch.device("cpu"))
)

<All keys matched successfully>

In [8]:
# Start pygame and open the window the agent's frames are drawn into.
# The window is the same size as one preprocessed frame (84x84 pixels).
window_size = (84, 84)
pygame.init()
screen = pygame.display.set_mode(window_size)
pygame.display.set_caption("Breakout-v5")

In [9]:
import time

# Environment steps shown per second. The base env already repeats each
# action for 4 emulator frames (see make_env), so 15 steps/s is roughly
# real-time if the emulator runs at 60 Hz — tune to taste.
STEPS_PER_SECOND = 15

states, _ = envs.reset()

total_reward = 0.0
running = True
clock = pygame.time.Clock()

try:
    while running:
        # Drain the event queue every iteration so the window stays
        # responsive and the close button works.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        if not running:
            break

        # NOTE(review): the second argument is passed through unchanged. If
        # it is an exploration epsilon, 1 would make the agent act fully at
        # random — confirm against utils.DQN.act.
        with torch.no_grad():  # playback only; no gradients are needed
            actions = policy_network.act(states, 1)
        states, rewards, terminated, truncated, _ = envs.step(actions)

        total_reward += rewards[0]

        # Newest frame of the stack for env 0; indexed as (row, column).
        frame = np.asarray(states[0, -1])
        # pygame.surfarray indexes arrays as (x, y[, rgb]), so transpose the
        # image and replicate the gray channel into RGB; a bare 2-D array
        # would be drawn sideways through an 8-bit palette.
        rgb_frame = np.repeat(frame.T[:, :, np.newaxis], 3, axis=2)
        frame_surface = pygame.surfarray.make_surface(rgb_frame)
        # Scale to whatever size the window actually has.
        screen.blit(pygame.transform.scale(frame_surface, screen.get_size()), (0, 0))
        pygame.display.flip()

        # Index env 0 explicitly: `terminated`/`truncated` are per-env arrays
        # and their bare truth value is ambiguous for more than one env.
        if terminated[0] or truncated[0]:
            print("Done")
            break

        # Pace the loop without burning CPU in a busy-wait.
        clock.tick(STEPS_PER_SECOND)
finally:
    # Release the emulator and the window even if the loop raised.
    envs.close()
    pygame.quit()

total_reward

Done


1.0