In [None]:
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing
import ale_py

# Breakout environment with standard DQN-style preprocessing.
# The emulator's own frame skipping is disabled (frameskip=1) because
# AtariPreprocessing performs the skip itself (frame_skip=4). Observations come
# out as 84x84 grayscale frames, and each reset takes up to 30 random no-ops.
env = AtariPreprocessing(
    gym.make('ALE/Breakout-v5', frameskip=1),
    noop_max=30,
    frame_skip=4,
    screen_size=84,
    grayscale_obs=True,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm

# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
# Hard-coding "mps" made the notebook fail on non-Apple machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def irelu_e_decay(start=1.0, end=0.1, steps=1000):
    """Build a linearly annealed schedule (shaped like an inverted, clipped ReLU).

    The returned callable maps a step count ``x`` to
    ``start - (start - end) * x / steps`` while ``x < steps``, and to ``end`` from
    ``steps`` onwards.

    Args:
        start: value at ``x == 0``.
        end: floor value reached at ``x == steps`` and held afterwards.
        steps: number of steps over which to anneal. ``steps <= 0`` gives a
            schedule that is constant at ``end`` (instead of dividing by zero).

    Returns:
        A function ``schedule(x) -> float``.
    """
    if steps <= 0:
        return lambda x: end

    rate = (start - end) / steps

    def schedule(x):
        return start - rate * x if x < steps else end

    return schedule


class SimpleQFunction(nn.Module):
    """Small convolutional Q-network mapping a stack of frames to per-action values.

    Input: a float tensor laid out channels-last as ``(batch, height, width, frames)``.
    It is permuted to channels-first before the convolutions. NOTE(review): inputs
    are used unscaled; the training code presumably feeds raw 0-255 pixel values.
    Output: ``(batch, env.action_space.n)`` Q-values.

    Args:
        env: environment whose ``action_space.n`` sets the number of outputs.
        in_channels: number of stacked frames (last input dimension); default 4.
        frame_size: ``(height, width)`` of each frame; default 84x84. The first
            fully connected layer is sized from this, so other resolutions work too.
    """

    def __init__(self, env, in_channels=4, frame_size=(84, 84)):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=8, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1)
        # Derive the flattened conv output size instead of hard-coding it
        # (for 84x84 input this is 32 * 20 * 20 = 12800). Zeros consume no RNG,
        # so parameter initialisation order matches the layer creation order.
        height, width = frame_size
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, height, width)
            flat_features = self.conv2(self.conv1(probe)).numel()
        self.fc1 = nn.Linear(in_features=flat_features, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=env.action_space.n)

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for channels-last input ``x``."""
        batch = x.shape[0]
        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W) as Conv2d expects
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.reshape(batch, -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# Train a DQN agent on `env`. The agent acts epsilon-greedily over a stack of
# recent frames and stores transitions in a ring replay buffer. A target network
# is synced periodically, and one Adam update runs per environment step once
# enough experience has been collected.
model = SimpleQFunction(env).to(device)
target_model = SimpleQFunction(env).to(device)
target_model.load_state_dict(model.state_dict())
gamma = 0.99
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
timesteps = 100_000_000  # total environment steps (agent decisions), not episodes
epsilon_schedule = irelu_e_decay(start=1.0, end=0.1, steps=1_000_000)  # indexed by total_steps
total_steps = 0
# Ring buffer of (frames, action, reward, terminated) tuples. `frames` is a tuple
# of frame_stack_size + 1 uint8 (H, W) frames, oldest first:
#   state      = frames[:-1]
#   next state = frames[1:]
# Neighbouring transitions share frame objects by reference, so each frame is
# stored once, as uint8 on the CPU.
replay_buffer = []
replay_insert_pos = 0  # slot that is overwritten next once the buffer is full
minibatch_size = 32
buffer_size = 1_000_000
losses = []
min_experiences = 10000
frame_stack_size = 4  # must equal SimpleQFunction's input channel count
target_update_interval = 12_000  # environment steps between target-network syncs
loss_log_interval = 100  # keep one loss value every this many environment steps
max_grad_norm = 10.0


def _train_on_minibatch(batch):
    """Take one Adam step on the TD(0) MSE loss for `batch`; return the loss as a float."""
    frames, actions, rewards, terminals = zip(*batch)
    # (B, H, W, frame_stack_size + 1) uint8 -> float on device; last axis oldest-first.
    stacked = torch.from_numpy(np.stack([np.stack(f, axis=-1) for f in frames])).to(device).float()
    states = stacked[..., :-1]
    next_states = stacked[..., 1:]
    action_tensor = torch.tensor(actions, dtype=torch.int64, device=device)
    reward_tensor = torch.tensor(rewards, dtype=torch.float32, device=device)
    terminal_tensor = torch.tensor(terminals, dtype=torch.float32, device=device)

    # Q(s, a) for the action actually taken in each transition.
    q_taken = model(states).gather(1, action_tensor.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_max_q = target_model(next_states).max(dim=1).values
        # Only true termination cuts off bootstrapping; time-limit truncation does not.
        td_target = reward_tensor + gamma * next_max_q * (1.0 - terminal_tensor)
    loss = F.mse_loss(q_taken, td_target)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


progress = tqdm(total=timesteps)
while total_steps < timesteps:
    observation, info = env.reset()
    # Start each episode with its first frame repeated, so stacks never mix episodes.
    # np.array copies, guarding against the env reusing an observation buffer.
    frame_stack = [np.array(observation, dtype=np.uint8)] * frame_stack_size
    done = False
    while not done and total_steps < timesteps:
        # Exploration is annealed per environment step (not per episode).
        e_value = epsilon_schedule(total_steps)
        if np.random.rand() < e_value:
            action_index = int(np.random.randint(env.action_space.n))
        else:
            # Act on the current stack (including the latest frame) without tracking grads.
            with torch.no_grad():
                state = torch.from_numpy(np.stack(frame_stack, axis=-1)).to(device).float()
                q_values = model(state.unsqueeze(0)).squeeze(0)
            # Plain int for env.step and for batching actions later.
            action_index = int(torch.argmax(q_values).item())

        new_observation, reward, terminated, truncated, info = env.step(action_index)

        # clip reward to be between -1 and 1 as shown in paper
        reward = float(np.clip(reward, -1.0, 1.0))

        # The episode ends on either flag; only `terminated` masks the bootstrap target.
        done = terminated or truncated
        new_frame = np.array(new_observation, dtype=np.uint8)

        transition = (tuple(frame_stack) + (new_frame,), action_index, reward, bool(terminated))
        # O(1) ring-buffer insert instead of re-slicing the whole list every step.
        if len(replay_buffer) < buffer_size:
            replay_buffer.append(transition)
        else:
            replay_buffer[replay_insert_pos] = transition
        replay_insert_pos = (replay_insert_pos + 1) % buffer_size

        if total_steps % target_update_interval == 0:
            target_model.load_state_dict(model.state_dict())

        # train on minibatch after sufficient data
        if total_steps > min_experiences and len(replay_buffer) >= minibatch_size:
            loss_value = _train_on_minibatch(random.sample(replay_buffer, minibatch_size))
            if total_steps % loss_log_interval == 0:
                losses.append(loss_value)

        # Slide the window: drop the oldest frame, append the newest.
        frame_stack = frame_stack[1:] + [new_frame]
        total_steps += 1
        progress.update(1)
progress.close()

env.close()

In [24]:
# Persist only the learned parameters (the state_dict), not the module object;
# reload them later with SimpleQFunction(...).load_state_dict(torch.load(...)).
torch.save(obj=model.state_dict(), f='breakout_model.pth')


In [None]:
from matplotlib import pyplot as plt
# Plot the recorded training losses. Each point is one entry of `losses`, so the
# x-axis counts logged samples, not environment steps or episodes.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="DQN training loss (Breakout)", xlabel="logged loss sample", ylabel="MSE TD loss")
plt.show()