In [139]:
from collections import deque
from random import sample
import gymnasium as gym
import ale_py
import torch.nn as nn
import torch.nn.functional as functional
import itertools
import torch
import numpy as np
import random

In [140]:
class replayBuffer:
    def __init__(self, maxlength : int = 1000):
        self.buffer = deque(maxlen=maxlength)
    
    def sample(self, batch_size: int):
        batch_size = min(batch_size, len(self.buffer))
        start_idx = random.randint(0, len(self.buffer) - batch_size)
        print(start_idx, batch_size, len(self.buffer))
        print(list(itertools.islice(self.buffer, start_idx, start_idx + batch_size)))
        return list(itertools.islice(self.buffer, start_idx, start_idx + batch_size))
        #return self.buffer[start_idx: start_idx + batch_size]
    
    def push(self, state, action, reward, next_state):
        self.buffer.append((state, action, reward, next_state))
        

In [141]:
class DQN_Model(nn.Module):
    def __init__(self, input_shape, action_space):
        super(DQN_Model, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=8, stride=4)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=8)
        self.fc1 = nn.Linear(64 * 42 * 30, 512)
        self.fc2 = nn.Linear(512, action_space)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = nn.functional.tanh(self.fc1(x))
        x = self.fc2(x)
        return x


In [142]:
env = gym.make("ALE/MsPacman-v5", render_mode="rgb_array", obs_type="grayscale", frameskip=(2,5))
env.reset()

(array([[  0,   0,   0, ...,   0,   0,   0],
        [146, 146, 146, ..., 146, 146, 146],
        [146, 146, 146, ..., 146, 146, 146],
        ...,
        [  0,   0,   0, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8),
 {'lives': 3, 'episode_frame_number': 0, 'frame_number': 0})

In [143]:
model = DQN_Model((1, 210, 160), 4)

In [144]:
sum(param.numel() for param in model.parameters())

41685668

In [145]:
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
model(torch.from_numpy(observation.reshape(-1, 1, 210, 160)).float())

tensor([[-0.2496,  0.1276,  0.6892,  0.8275]], grad_fn=<AddmmBackward0>)

In [None]:
state, _ = env.reset()
state = torch.from_numpy(state.reshape(-1, 1, 210, 160)).float()

buffer = replayBuffer()


q_model = DQN_Model((1, 210, 160), 4)
y_model = DQN_Model((1, 210, 160), 4)
optimizer = torch.optim.Adam(q_model.parameters(), lr=0.0001)

ACTION_SPACE = [1, 2, 3, 4]
epsilon = 0.1
c_update = 10

def get_action(state):
    if random.random() < epsilon:
        return random.randint(0, 3)
    
    res = q_model(state).argmax()
    print(res)
    return res.numpy()

max_episode_steps = 1000
number_episodes = 1
for episode in range(number_episodes):
    # done = False
    steps_cnt = 0
    while steps_cnt < max_episode_steps:
        action = get_action(state)
        print(action)
        next_state, reward, terminated, truncated, info = env.step(ACTION_SPACE[action])
        if terminated or truncated:
            break
        buffer.push(state, action, reward, next_state)
        batch = buffer.sample(4)
        r = torch.FloatTensor([b[2] for b in batch])
        states = torch.from_numpy(np.asarray([b[0] for b in batch])).reshape(-1, 1, 210, 160).float()
        # states = torch.cat(states, dim=0)
        next_states = torch.from_numpy(np.asarray([b[3] for b in batch])).reshape(-1, 1, 210, 160).float()
        # next_states = torch.cat(next_states, dim=0)
        print(type(next_states))
        y = y_model(next_states).max() + r
        q = q_model(states).max()
        loss = functional.mse_loss(q, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        state =  torch.from_numpy(next_state.reshape(-1, 1, 210, 160)).float()
        if steps_cnt % c_update == 0:
            y_model.load_state_dict(q_model.state_dict())
        steps_cnt += 1
        

tensor(1)
1
0 1 1
[(tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]]), array(1), 0.0, array([[  0,   0,   0, ...,   0,   0,   0],
       [146, 146, 146, ..., 146, 146, 146],
       [146, 146, 146, ..., 146, 146, 146],
       ...,
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8))]
<class 'torch.Tensor'>


  loss = functional.mse_loss(q, y)


tensor(1)
1
0 2 2
[(tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]]), array(1), 0.0, array([[  0,   0,   0, ...,   0,   0,   0],
       [146, 146, 146, ..., 146, 146, 146],
       [146, 146, 146, ..., 146, 146, 146],
       ...,
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8)), (tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0

  loss = functional.mse_loss(q, y)


0
0 3 3
[(tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]]), array(1), 0.0, array([[  0,   0,   0, ...,   0,   0,   0],
       [146, 146, 146, ..., 146, 146, 146],
       [146, 146, 146, ..., 146, 146, 146],
       ...,
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8)), (tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]

  loss = functional.mse_loss(q, y)


tensor(1)
1
0 4 4
[(tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]]), array(1), 0.0, array([[  0,   0,   0, ...,   0,   0,   0],
       [146, 146, 146, ..., 146, 146, 146],
       [146, 146, 146, ..., 146, 146, 146],
       ...,
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8)), (tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0

  loss = functional.mse_loss(q, y)


tensor(1)
1
1 4 5
[(tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]]]), array(1), 0.0, array([[  0,   0,   0, ...,   0,   0,   0],
       [146, 146, 146, ..., 146, 146, 146],
       [146, 146, 146, ..., 146, 146, 146],
       ...,
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=uint8)), (tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          [146., 146., 146.,  ..., 146., 146., 146.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0