In [15]:
import torch
import gymnasium as gym
import random

import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.optim import Adam
from ale_py import ALEInterface
from copy import deepcopy

In [16]:
NUM_EPISODES = 1000
GAMMA = 0.999
EPSILON = 0.1

In [17]:
def cat(a, b):
    return torch.cat((a, b), 0)

def copy(m):
    return deepcopy(m)

In [18]:
env = gym.make("ALE/Pong-v5", full_action_space=False, obs_type='grayscale')

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [20]:
class ReplayBuffer(object):
    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape),
                                     dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_shape),
                                         dtype=np.float32)

        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, batch_size, replace=False)

        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        states_ = self.new_state_memory[batch]
        terminal = self.terminal_memory[batch]

        return states, actions, rewards, states_, terminal

In [12]:
class PongDQNNModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv2d1 = nn.Conv2d(2, 32, 5)
        self.conv2d2 = nn.Conv2d(32, 64, 5)
        # self.conv2d3 = nn.Conv2d(64, 64, 5)

        self.linear1 = nn.Linear(116032, 512)
        self.linear2 = nn.Linear(512, 6)

        self.max_pool = nn.MaxPool2d(2, 2)

        # Optimizer
        self.optimizer = Adam(self.parameters(), lr=0.001)

        self.to(device)

        self.loss = nn.MSELoss()

    def forward(self, x):
        x = F.relu(self.conv2d1(x))
        x = self.max_pool(x)
        
        x = F.relu(self.conv2d2(x))
        x = self.max_pool(x)
        
        # x = F.relu(self.conv2d3(x))
        # x = self.max_pool(x)
        
        x = torch.flatten(x, 1)
        
        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x

In [21]:
class PongDQNNAgent():
    def __init__(self):
        # Models
        self.model_main = PongDQNNModel()
        self.model_target = PongDQNNModel()
        
        # Loss
        # self.loss_fn = nn.MSELoss()

        self.memory = ReplayBuffer(20000, (2, 210, 160), 6)

        self.queries = 0

        self.epsilon = 1
        self.eps_min = 0.1
        self.eps_dec = 1e-5
        self.batch_size = 32

    def store_transition(self, state, action, reward, state_, done):
        self.memory.store_transition(state.cpu(), action, reward, state_.cpu(), done)

    def decrement_epsilon(self):
        self.epsilon = self.epsilon - self.eps_dec \
                           if self.epsilon > self.eps_min else self.eps_min

    def sample_memory(self):
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)

        states = torch.tensor(state).to(device)
        rewards = torch.tensor(reward).to(device)
        dones = torch.tensor(done).to(device)
        actions = torch.tensor(action).to(device)
        states_ = torch.tensor(new_state).to(device)

        return states, actions, rewards, states_, dones


    def getAction(self, observations, action_space):
        if random.random() >= self.epsilon:
            a = torch.argmax(self.q)
        else:
            a = action_space.sample()

        print('self.queries: ', self.queries)
        
        return a

    def train(self, observations, r):
        if self.memory.mem_cntr < self.batch_size:
            return
                
        if self.queries % 100 == 0:
            self.model_target.load_state_dict(self.model_main.state_dict())
        
        states, actions, rewards, states_, dones = self.sample_memory()
        
        indices = np.arange(self.batch_size)

        q_pred = self.model_main(states)[indices, actions]
        
        q_next = self.model_target(states_).max(dim=1)[0]
        
        q_next[dones] = 0.0
        q_target = rewards + GAMMA * q_next

        loss = self.model_main.loss(q_target, q_pred).to(device)
        loss.backward()
        self.model_main.optimizer.step()

        self.queries += 1

In [None]:
# Agent
agent = PongDQNNAgent()

for episode in range(NUM_EPISODES):
    observation_prev, _ = env.reset()
    observation, r, done, truncated, info = env.step(env.action_space.sample())

    observation_prev = torch.FloatTensor([observation_prev])
    observation = torch.FloatTensor([observation])

    observations = cat(observation, observation_prev)
    
    episode_reward = 0

    print('episode:', episode)

    while not done:
        a = agent.getAction(observations, env.action_space)
            
        observation_, r, done, truncated, info = env.step(a)

        episode_reward += r
        
        observation_ = torch.FloatTensor([observation_])

        observations_ = cat(observation_, observation)

        agent.store_transition(observations, a, r, observations_, done)

        observations = observations_
        
        agent.train(observations, r)

        observation_prev = observation
        observation = observation_

    print('episode_reward:', episode_reward)

torch.save([model_main.state_dict()], "pong.pth")

episode: 0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  0
self.queries:  1
self.queries:  2
self.queries:  3
self.queries:  4
self.queries:  5
self.queries:  6
self.queries:  7
self.queries:  8
self.queries:  9
self.queries:  10
self.queries:  11
self.queries:  12
self.queries:  13
self.queries:  14
self.queries:  15
self.queries:  16
self.queries:  17
self.queries:  18
self.queries:  19
self.queries:  20
self.queries:  21
self.queries:  22
self.queries:  23
self.queries:  24
self.queries:  25
self

In [23]:
'a'

'a'