In [17]:
from collections import namedtuple
from enum import IntEnum
from gymnasium.spaces import Box
import numpy as np
import gymnasium as gym
import ale_py

gym.register_envs(ale_py)


def create_abstract_pong(render_mode=None):
    env = gym.make("PongNoFrameskip-v4", render_mode=render_mode)
    env = AbstractPong(env)
    return env


class PongAction(IntEnum):
    NOOP = 0
    FIRE = 1
    RIGHT = 2
    LEFT = 3
    RIGHTFIRE = 4
    LEFTFIRE = 5


class AbstractPong(gym.Wrapper):
    ArenaView = namedtuple("ArenaView", ["low", "high"])
    DataView = namedtuple("DataView", ["column", "color", "color_value"])
    _ARENA = ArenaView(34, -16)
    _OPP_VIEW = DataView(18, 0, 213)
    _AGENT_VIEW = DataView(141, 1, 186)
    _BALL_VIEW = DataView(None, 1, 236)

    def __init__(self, env):
        super().__init__(env)
        assert env.spec.id == "PongNoFrameskip-v4"
        self.env = env
        self.frame_skip = 4
        self.observation_space = Box(
            low=np.zeros((4, 4), dtype=np.int64),
            high=np.tile((210, 210, 210, 160), (4, 1)),
            dtype=np.int64,
        )
        self.framebuf = None
        self.posbuf = np.zeros((4, 4), dtype=np.float32)

    def reset(self):
        self.posbuf[:] = 0
        self.framebuf, info = self.env.reset()
        # Random starts
        noops = np.random.random_integers(1, 30)
        for _ in range(noops):
            self.framebuf, _, _, _, info = self.env.step(PongAction.NOOP)
        self.posbuf[-1] = self._get_positions()
        return self._get_obs(), info

    def step(self, action: PongAction):
        rew = 0
        for _ in range(self.frame_skip):
            self.framebuf, r, terminated, truncated, info = self.env.step(action)
            rew += r
            if terminated or truncated:
                break
        self.posbuf[:3] = self.posbuf[1:]
        self.posbuf[-1] = self._get_positions()
        return self._get_obs(), rew, terminated, truncated, {}

    def _arena_slice(self):
        return self.framebuf[self._ARENA.low : self._ARENA.high]

    def _paddle_slice(self, view: DataView):
        return self._arena_slice()[:, view.column, view.color]

    def _paddle_pos(self, view: DataView) -> int:
        slice = self._paddle_slice(view)
        if slice[0] == view.color_value:
            idx = np.argmax(slice != view.color_value) - 8 + self._ARENA.low
        else:
            idx = np.argmax(slice == view.color_value)
            if idx == 0:
                return idx
            else:
                idx += 8 + self._ARENA.low
        return idx

    def _ball_pos(self):
        arena = self._arena_slice()[:, :, self._BALL_VIEW.color]
        idx = np.argmax(arena == self._BALL_VIEW.color_value)
        if idx == 0:
            return 0, 0
        row = idx // 160
        column = idx - row * 160
        return row + self._ARENA.low + 1, column

    def _get_positions(self):
        agent_pos = self._paddle_pos(self._AGENT_VIEW)
        opp_pos = self._paddle_pos(self._OPP_VIEW)
        ball_pos = self._ball_pos()
        positions = np.array([agent_pos, opp_pos, *ball_pos])
        return positions

    def _get_obs(self):
        return np.array(self.posbuf)


In [18]:
import torch
import torch.nn as nn

class QNet(nn.Module):
    def __init__(self):
        super(QNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) 
        self.conv2 = nn.Conv2d(32, 48, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(48, 96, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(96 * 4 * 4, 192)
        self.fc2 = nn.Linear(192, 6)

        self.dropout1 = nn.Dropout(0.04)  
        self.dropout2 = nn.Dropout(0.04)

    def forward(self, x):
        x = x.unsqueeze(1)

        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.dropout2(x)

        return x


In [19]:
#check if its the right output
batch_size = 64
random_input = torch.randn(batch_size, 4, 4)

random_input = random_input.unsqueeze(1)

qnet = QNet()

output = qnet(random_input) 
print(f"Output shape: {output.shape}") 


Output shape: torch.Size([64, 6])


In [20]:
#check number of weights
def count_weights(model):
    total_weights = 0
    for param in model.parameters():
        total_weights += param.numel()
    return total_weights
qnet = QNet() 
total_weights = count_weights(qnet)
print(f"Number of weights: {total_weights}")

Number of weights: 352022


In [21]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from collections import deque, namedtuple
from torch.nn.functional import mse_loss


Transition = namedtuple(
    "Transition", ("state", "action", "reward", "next_state", "terminated")
)

Batch = namedtuple(
    "Batch", ("states", "actions", "rewards", "next_states", "terminateds")
)

class Agent:
    def __init__(self, device, buffer_size=90000, batch_size=64, lr=0.0003, qnet=None, target_qnet=None,
                 eps_min=0.04, tau=0.988, lam=0.01, gamma=0.99):
        
        self.device = device

        self.qnet = qnet or QNet()
        self.target_qnet = target_qnet or QNet()

        self.device = device
        
        self.qnet = qnet or QNet().to(self.device)
        self.target_qnet = target_qnet or QNet().to(self.device)
        
        self.optimizer = torch.optim.Adam(self.qnet.parameters(), lr=lr)
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.epsilon = 1.0
        self.eps_min = eps_min
        self.tau = tau
        self.lam = lam
        self.gamma = gamma

    def store_transition(self, state, action, reward, next_state, terminated):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        next_state = torch.tensor(next_state, dtype=torch.float32).to(self.device) 
        
        state = state.unsqueeze(0)
        next_state = next_state.unsqueeze(0)
        
        self.replay_buffer.append(Transition(state, action, reward, next_state, terminated))

    def sample(self):
        sample = random.sample(self.replay_buffer, k=self.batch_size)
        states, actions, rewards, next_states, terminateds = list(zip(*sample))

        states = torch.stack([state.to(self.device) for state in states])
        actions = torch.tensor(actions, dtype=torch.int64).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states = torch.stack([next_state.to(self.device) for next_state in next_states])
        terminateds = torch.tensor(terminateds, dtype=torch.float32).to(self.device)

        return Batch(states, actions, rewards, next_states, terminateds)
    
    @torch.no_grad()
    def greedy_action(self, state):
        if np.random.rand() < self.epsilon:
            return np.random.randint(6)
        else:
            state = torch.tensor(np.array(state), dtype=torch.float32).unsqueeze(0)
            state = state.unsqueeze(1)
            state = state.to(self.device)
            q_values = self.qnet(state)
            return q_values.argmax().item()
    
    def DQN_update(self, batch: Batch):
        if len(self.replay_buffer) < self.batch_size:
            return 0.0
        batch = Batch(
                batch.states.to(self.device),
                batch.actions.to(self.device),
                batch.rewards.to(self.device),
                batch.next_states.to(self.device),
                batch.terminateds.to(self.device)
            )
        
        with torch.no_grad():
            max_next_q = self.target_qnet(batch.next_states).max(1)[0]
            targets = batch.rewards + (1 - batch.terminateds) * self.gamma * max_next_q

        current_qs = self.qnet(batch.states).gather(1, batch.actions.unsqueeze(1)).squeeze()

        loss = nn.SmoothL1Loss()(current_qs, targets)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.qnet.parameters(), 1.5)
        self.optimizer.step()

        for target_param, qnet_param in zip(self.target_qnet.parameters(), self.qnet.parameters()):
            target_param.data.copy_((1.0 - self.lam) * target_param.data + self.lam * qnet_param.data)

        return loss.item()



    def save(self, filename: str):
        torch.save(self.qnet.state_dict(), filename)

    def load(self, filename: str):
        self.qnet.load_state_dict(torch.load(filename))

In [22]:
def experiment_DQN(env, agent, num_experiments=3, num_episodes=500) -> pd.DataFrame:
    dfs = []

    for exp in range(num_experiments):
        print(f"Running experiment {exp + 1}/{num_experiments}...", end="\r")
        episode_returns = []
        agent.epsilon = 1
        tot_rew = 0

        agent.qnet = QNet().to(agent.device)
        agent.target_qnet = QNet().to(agent.device)
        agent.optimizer = torch.optim.Adam(agent.qnet.parameters(), lr=0.00015)

        for episode in range(num_episodes):
            observation, info = env.reset()
            terminated = False
            tot_rew = 0
            steps = 0

            while not terminated or truncated:
                steps += 1
                state = observation
                action = agent.greedy_action(state)
                observation, reward, terminated, truncated, info = env.step(action)
                next_state = observation

                tot_rew += reward
                agent.store_transition(state, action, reward, next_state, terminated)
                state = next_state

                if terminated:
                    break

                if len(agent.replay_buffer) >= agent.batch_size:
                    batch = agent.sample()
                    agent.DQN_update(batch)
                

            print(f"Episode {episode + 1}/{num_episodes} - Reward: {tot_rew} - Steps: {steps}")

            agent.epsilon = max(agent.eps_min, agent.tau * agent.epsilon)
            episode_returns.append(tot_rew)

        df_exp = pd.DataFrame(
            {"exp": exp, "episode": np.arange(num_episodes), "return": episode_returns}
        )

        df_exp["moving_avg_return"] = df_exp["return"].rolling(window=10, min_periods=1).mean()
        dfs.append(df_exp)

        agent.save(f"dqn_weights_{exp}.pth")

    df = pd.concat(dfs, ignore_index=True)
    df["algo"] = "DQN"
    return df

In [16]:
#run experiment

import matplotlib.pyplot as plt
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = create_abstract_pong(render_mode=None)


agent = Agent(device,qnet=QNet(), target_qnet=QNet())


results = experiment_DQN(env, agent, num_experiments=3, num_episodes=500)


A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


KeyboardInterrupt: 

In [23]:
#test against opponent
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
agent = Agent(device)
agent.load('dqn5_weights_500.pth')

 
agent.qnet.eval()

env = create_abstract_pong(render_mode="human")

state, info = env.reset()

done = False
total_reward = 0
agent.epsilon = 0

while not done:
    action = agent.greedy_action(state)
    state, reward, done, truncated, info = env.step(action)
    total_reward += reward
    env.render()

print(f"Total reward: {total_reward}")
env.close()

  self.qnet.load_state_dict(torch.load(filename))
MESA: error: ZINK: failed to choose pdev
glx: failed to create drisw screen
  noops = np.random.random_integers(1, 30)


KeyboardInterrupt: 

In [24]:
#plots graphs
sns.lineplot(results, x = "episode", y = "moving_avg_return", hue = "algo", legend = "auto")
plt.show()

NameError: name 'results' is not defined