In [2]:
from collections import deque
from random import sample
import gymnasium as gym
import ale_py
import torch.nn as nn
import torch

In [3]:
class replayBuffer:
    def __init__(self, maxlength : int = 1000):
        self.buffer = deque(maxlen=maxlength)
    
    def sample(self, batch_size: int):
        return sample(self.buffer, batch_size)
    
    def push(self, state, action, reward, next_state):
        self.buffer.append((state, action, reward, next_state))
        

In [67]:
class DQN_Model(nn.Module):
    def __init__(self, input_shape, action_space):
        super(DQN_Model, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=8, stride=4)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=8)
        self.fc1 = nn.Linear(64 * 42 * 30, 512)
        self.fc2 = nn.Linear(512, action_space)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = nn.functional.tanh(self.fc1(x))
        x = self.fc2(x)
        return x


In [38]:
env = gym.make("ALE/MsPacman-v5", render_mode="rgb_array", obs_type="grayscale", frameskip=(2,5))
env.reset()

(array([[  0,   0,   0, ...,   0,   0,   0],
        [146, 146, 146, ..., 146, 146, 146],
        [146, 146, 146, ..., 146, 146, 146],
        ...,
        [  0,   0,   0, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8),
 {'lives': 3, 'episode_frame_number': 0, 'frame_number': 0})

In [62]:
model = DQN_Model((1, 210, 160), 4)

In [64]:
sum(param.numel() for param in model.parameters())

41685668

In [77]:
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
model(torch.from_numpy(observation.reshape(-1, 1, 210, 160)).float())

tensor([[ 0.8680, -0.7164,  0.1347, -0.0941]], grad_fn=<AddmmBackward0>)

In [16]:
env = gym.make("ALE/MsPacman-v5", render_mode="rgb_array", obs_type="grayscale", frameskip=(2,5))
env.reset()

for i in range(100):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)

    print(observation)
    
    if terminated or truncated:
        observation, info = env.reset()
env.close()



[[  0   0   0 ...   0   0   0]
 [146 146 146 ... 146 146 146]
 [146 146 146 ... 146 146 146]
 ...
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]]
[[  0   0   0 ...   0   0   0]
 [146 146 146 ... 146 146 146]
 [146 146 146 ... 146 146 146]
 ...
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]]
[[  0   0   0 ...   0   0   0]
 [146 146 146 ... 146 146 146]
 [146 146 146 ... 146 146 146]
 ...
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]]
[[  0   0   0 ...   0   0   0]
 [146 146 146 ... 146 146 146]
 [146 146 146 ... 146 146 146]
 ...
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]]
[[  0   0   0 ...   0   0   0]
 [146 146 146 ... 146 146 146]
 [146 146 146 ... 146 146 146]
 ...
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]
 [  0   0   0 ...   0   0   0]]
[[  0   0   0 ...   0   0   0]
 [146 146