# In [2]:
import gymnasium as gym
import ale_py
import numpy as np
import cv2
import torch
from torch import nn, optim
import torch.nn.functional as F
import random
from collections import namedtuple, deque
from PIL import Image

# Training hyperparameters.
# NOTE(review): the names look like the hyperparameter table of the Nature DQN
# paper (Mnih et al., 2015) -- confirm. Several of them are not referenced by
# the code visible in this file.

# Replay / batching
MINIBATCH_SIZE = 32
REPLAY_MEMORY_SIZE = 100_000
REPLAY_START_SIZE = 50_000

# Frame handling and update cadence
AGENT_HISTORY_LENGTH = 4
ACTION_REPEAT = 4
UPDATE_FREQUENCY = 4
TARGET_NETWORK_UPDATE_FREQUENCY = 10_000
NO_OP_MAX = 30

# Return discounting
DISCOUNT_FACTOR = 0.99

# Optimizer settings
LEARNING_RATE = 2.5e-4
GRADIENT_MOMENTUM = 0.95
SQUARED_GRADIENT_MOMENTUM = 0.95
MIN_SQUARED_GRADIENT = 0.01

# Exploration schedule (presumably epsilon-greedy -- TODO confirm against the
# training loop, which is not visible here)
INITIAL_EXPLORATION = 1
FINAL_EXPLORATION = 0.1
FINAL_EXPLORATION_FRAME = 1_000_000

def pre_process(pre_frame, current_frame, size=(84, 84)):
    """Merge two frames into a single resized luminance image.

    The element-wise maximum of the two frames is converted from RGB to YUV,
    only the Y (luminance) channel is kept, and that channel is resized.

    Args:
        pre_frame: The earlier frame. It is handed to OpenCV as an RGB image,
            so it is assumed to be an H x W x 3 array -- TODO confirm against
            the environment's observation format.
        current_frame: The later frame, same layout as ``pre_frame``.
        size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(84, 84)`` because ``QNet`` in this file sizes its first linear
            layer (64 * 7 * 7 features) for 84x84 input by default; the
            previously hard-coded 80x80 output flattens to 64 * 6 * 6 and
            cannot be fed to that network.

    Returns:
        A 2-D array holding the resized Y channel, with ``size[1]`` rows and
        ``size[0]`` columns.
    """
    # Per-pixel maximum over the two frames.
    merged = np.maximum(pre_frame, current_frame)
    # Only the luminance plane is needed; the chroma planes are discarded.
    luma, _, _ = cv2.split(cv2.cvtColor(merged, cv2.COLOR_RGB2YUV))
    return cv2.resize(luma, size)

# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class QNet(nn.Module):
    """Convolutional network producing one output value per action.

    Three convolution layers are followed by a 512-unit linear layer and a
    linear output layer with ``n_actions`` units.

    Args:
        n_actions: Number of units in the output layer.
        in_channels: Number of channels of the input tensor. Defaults to 4,
            the value that was previously hard-coded in the first layer.
        input_size: ``(height, width)`` of the input. Defaults to
            ``(84, 84)``, which gives the original 64 * 7 * 7 flattened size.

    Raises:
        ValueError: If ``input_size`` is too small for the convolution stack.
    """

    def __init__(self, n_actions, in_channels=4, input_size=(84, 84)):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Derive the flattened feature size from the layers themselves rather
        # than hard-coding it for one resolution.
        # Per layer: out = (in - kernel_size + 2 * padding) // stride + 1
        # For an 84x84 input: 84 -> 20 -> 9 -> 7, i.e. a 7x7 map, 64 channels.
        height, width = input_size
        for conv in (self.conv1, self.conv2, self.conv3):
            height = self._conv_out_size(
                height, conv.kernel_size[0], conv.stride[0], conv.padding[0]
            )
            width = self._conv_out_size(
                width, conv.kernel_size[1], conv.stride[1], conv.padding[1]
            )
        self.in_features_dim = self.conv3.out_channels * height * width

        self.full_connect = nn.Linear(self.in_features_dim, 512)
        self.output = nn.Linear(512, n_actions)

    @staticmethod
    def _conv_out_size(size, kernel_size, stride, padding):
        """Return the output length of one conv dimension (dilation 1)."""
        padded = size + 2 * padding
        if padded < kernel_size:
            raise ValueError(
                f"input dimension {size} is too small for a convolution "
                f"with kernel_size={kernel_size}"
            )
        return (padded - kernel_size) // stride + 1

    def forward(self, x):
        """Run the network on ``x`` of shape (batch, in_channels, H, W).

        Returns a tensor of shape (batch, n_actions).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten the convolution output to (batch, features).
        x = x.view(x.size(0), -1)

        x = F.relu(self.full_connect(x))
        x = self.output(x)
        return x

# One stored experience; the field order is the positional order for push().
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory:
    """Bounded buffer of ``Transition`` records.

    The buffer is a ``deque`` created with ``maxlen=capacity``, so appending
    to a full buffer drops the oldest entry.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds
        the number of stored transitions.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

if __name__ == "__main__":
    # Smoke test: build the environment, networks, optimizer and replay
    # memory, take one random step and display the preprocessed frame.
    # Everything stays at module level so the names remain globals.
    gym.register_envs(ale_py)

    env = gym.make("ALE/Seaquest-v5")
    try:
        n_actions = env.action_space.n

        state, info = env.reset()

        # The target network starts as an exact copy of the policy network.
        policy_net = QNet(n_actions).to(device)
        target_net = QNet(n_actions).to(device)
        target_net.load_state_dict(policy_net.state_dict())

        # NOTE(review): GRADIENT_MOMENTUM, SQUARED_GRADIENT_MOMENTUM and
        # MIN_SQUARED_GRADIENT are defined in this file but are not passed to
        # this optimizer -- confirm that AdamW is the intended choice.
        optimizer = optim.AdamW(policy_net.parameters(), lr=LEARNING_RATE, amsgrad=True)
        memory = ReplayMemory(REPLAY_MEMORY_SIZE)

        new_frame, reward, terminated, truncated, info = env.step(env.action_space.sample())

        # Combine the reset observation with the first stepped frame and show
        # the result.
        im = Image.fromarray(pre_process(state, new_frame))
        im.show()
    finally:
        # Close the environment even if any step above raises.
        env.close()