In [71]:
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing
import ale_py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm

# Prefer Apple's MPS backend, then CUDA, and fall back to CPU so the notebook
# still runs on machines without the original training hardware.
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

def irelu_e_decay(start=1.0, end=0.1, steps=1000):
    """Build a linearly decaying epsilon schedule.

    The returned callable maps a step count ``x`` to ``start`` at ``x == 0``,
    decreasing linearly to ``end`` at ``x == steps`` and staying at ``end``
    for every later step.

    Args:
        start: value at step 0.
        end: value reached at ``steps`` and held afterwards.
        steps: number of steps over which to decay. A non-positive value means
            "no decay": the schedule is constantly ``end`` (previously this
            raised ZeroDivisionError).

    Returns:
        A function ``f(x) -> float``.
    """
    if steps <= 0:
        return lambda x: end
    rate = (start - end) / steps
    return lambda x: start - rate * x if x < steps else end


class SimpleQFunction(nn.Module):
    """Small convolutional Q-network: stacked frames -> one Q-value per action.

    Input is channels-last ``(batch, height, width, in_channels)``; the forward
    pass permutes it to PyTorch's channels-first layout before the convolutions.
    Output is ``(batch, action_space_size)``.

    Args:
        action_space_size: number of discrete actions (width of the output layer).
        in_channels: number of stacked frames per state (default 4).
        screen_size: side length of the square input frames (default 84). The
            flattened feature size feeding ``fc1`` is derived from it instead of
            being hard-coded (84 -> 32 * 20 * 20 = 12800, as before).
    """

    def __init__(self, action_space_size, in_channels=4, screen_size=84):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=8, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1)
        # Spatial side length after both convolutions (inputs are square).
        side = self._conv_out(self._conv_out(screen_size, self.conv1), self.conv2)
        self.fc1 = nn.Linear(in_features=self.conv2.out_channels * side * side, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=action_space_size)

    @staticmethod
    def _conv_out(size, conv):
        """Output side length of ``conv`` for a square input of side ``size`` (dilation 1)."""
        return (size + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1

    def forward(self, x):
        # channels-last (B, H, W, C) -> channels-first (B, C, H, W) for Conv2d
        x = x.permute(0, 3, 1, 2)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)  # (B, 32 * side * side)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# --- Configuration -------------------------------------------------------------
# Seed every RNG the notebook draws from so runs are repeatable:
# Python `random` (minibatch sampling), NumPy (epsilon-greedy), torch (weight init).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

num_parallel_envs = 48
# AtariPreprocessing performs the frame skipping (frame_skip=4), so the base env
# is created with frameskip=1 to avoid skipping frames twice.
envs = [gym.make('ALE/Breakout-v5', frameskip=1) for _ in range(num_parallel_envs)]
envs = [AtariPreprocessing(env,
    screen_size=84,
    grayscale_obs=True,
    frame_skip=4,
    noop_max=30
) for env in envs]
# Seed each environment once with a distinct seed; later `reset()` calls without
# a seed continue from that RNG state.
for env_index, env in enumerate(envs):
    env.reset(seed=seed + env_index)

num_actions = envs[0].action_space.n
model = SimpleQFunction(num_actions).to(device)  # online Q-network (the one trained)
target_model = SimpleQFunction(num_actions).to(device)  # target network, periodically synced from `model`
target_model.load_state_dict(model.state_dict())
gamma = 0.99  # discount factor
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
timesteps = 100_000_000  # NOTE(review): the training loop uses this as its outer-iteration count
epsilon_schedule = irelu_e_decay(start=1.0, end=0.1, steps=1_000_000)
# one replay buffer (a plain list of Experience objects) per env
replay_buffers = [[] for _ in range(num_parallel_envs)]
minibatch_size = 32  # transitions sampled per env per update
buffer_size = 1_000_000  # capacity of EACH per-env buffer
# NOTE(review): 48 buffers x 1M transitions, each holding float32 frames on `device`,
# far exceeds typical device memory -- consider a smaller capacity or uint8 frames.
losses = []
min_experiences = 10000  # warm-up: no training until this many loop steps have run


def get_next_steps(envs, action_indices):
    """Advance each environment by one step with its paired action.

    Returns one ``env.step`` result per (env, action) pair, in order -- for
    Gymnasium envs, ``(observation, reward, terminated, truncated, info)``.
    Pairing follows ``zip`` semantics, so extra items in the longer sequence
    are ignored.
    """
    results = []
    for environment, action in zip(envs, action_indices):
        results.append(environment.step(action))
    return results


class Experience:
    """A single replay transition ``(state, action, reward, next_state, done)``.

    The reward is clipped to ``[-1, 1]`` and stored as float32; every other
    field is stored exactly as passed in.
    """

    def __init__(self, current_observation, action_index, reward, next_observation, done):
        clipped_reward = np.clip(reward, -1, 1, dtype=np.float32)
        self.current_observation = current_observation
        self.action_index = action_index
        self.reward = clipped_reward
        self.next_observation = next_observation
        self.done = done


def store_experience(replay_buffers, envs, action_indices, init_obs):
    """Step every environment once and append one transition per env.

    ``init_obs`` must hold the *current* observation of each env (initially the
    tensors produced from ``reset``). It supplies each transition's
    ``current_observation`` and is updated IN PLACE afterwards, so it always
    tracks the latest observation of every env. (Previously the state was read
    from ``buffer[-1].current_observation``, which pinned every stored state to
    the episode's first frame.)

    An env whose episode ends (terminated or truncated) is reset immediately and
    its ``init_obs`` slot receives the reset observation, so a finished env is
    never stepped again.

    The stored ``done`` flag is ``terminated`` only: a time-limit truncation
    still has future value, so the TD target should keep bootstrapping from it.

    Returns:
        The number of environments whose episode ended on this step.
    """
    next_steps = get_next_steps(envs, action_indices)
    finished = 0
    for i, (observation, reward, terminated, truncated, _info) in enumerate(next_steps):
        next_observation = torch.from_numpy(observation).float().to(device)
        replay_buffers[i].append(
            Experience(init_obs[i], action_indices[i], reward, next_observation, terminated)
        )
        if terminated or truncated:
            finished += 1
            reset_observation, _ = envs[i].reset()
            init_obs[i] = torch.from_numpy(reset_observation).float().to(device)
        else:
            init_obs[i] = next_observation
    return finished


# --- Training loop -------------------------------------------------------------
# Each outer iteration resets all environments, then steps them in lock-step (one
# transition per env per inner iteration) until `num_parallel_envs` episode endings
# have been reported by store_experience. `total_steps` counts the inner iterations
# and drives the epsilon schedule, the target-network sync and the loss logging.
# NOTE(review): `timesteps` bounds the number of outer rounds here, not env steps.
total_steps = 0
# Number of distinct window-end indices random.sample must draw per update.
samples_per_update = minibatch_size * num_parallel_envs

for _ in tqdm(range(timesteps)):
    # reset all environments; init_obs[i] is env i's starting observation
    init_obs = [torch.from_numpy(env.reset()[0]).float().to(device) for env in envs]

    env_done_count = 0
    while env_done_count < num_parallel_envs:
        p_value = np.random.rand()
        # Decay epsilon with the number of steps taken (indexing by the outer
        # round counter decayed it once per round, i.e. essentially never).
        e_value = epsilon_schedule(total_steps)

        # Explore, or act randomly until each buffer holds the 4 frames a state needs.
        if p_value < e_value or len(replay_buffers[0]) < 4:
            action_indices = [np.random.randint(env.action_space.n) for env in envs]
        else:
            # State = the 4 most recent frames of each env, stacked channels-last.
            # NOTE(review): windows can straddle an episode boundary right after a reset.
            with torch.no_grad():  # inference only; don't build an autograd graph
                batched_past_observations = torch.stack([
                    torch.stack([experience.next_observation for experience in replay_buffer[-4:]], dim=-1)
                    for replay_buffer in replay_buffers
                ])  # (num_parallel_envs, H, W, 4)
                batched_q_values = model(batched_past_observations)  # (num_parallel_envs, A)
            action_indices = torch.argmax(batched_q_values, dim=-1).cpu().numpy()  # (num_parallel_envs,)

        env_done_count += store_experience(replay_buffers, envs, action_indices, init_obs)

        # Periodically copy the online network into the target network.
        if total_steps % 12000 == 0:
            target_model.load_state_dict(model.state_dict())

        # Train after warm-up, and only once every buffer offers enough distinct
        # window-end indices (>= 3, so a 4-transition window fits) for random.sample.
        if total_steps > min_experiences and len(replay_buffers[0]) - 3 >= samples_per_update:
            # One row of `minibatch_size` distinct window-end indices per env.
            sample_indices = np.array(
                random.sample(range(3, len(replay_buffers[0])), samples_per_update)
            ).reshape(num_parallel_envs, minibatch_size)
            # Flat list of 4-transition windows, ordered env by env.
            windows = [
                replay_buffers[env_index][i - 3 : i + 1]
                for env_index in range(num_parallel_envs)
                for i in sample_indices[env_index]
            ]
            # The transition being learned from is the last one of each window.
            last_experiences = [window[-1] for window in windows]

            # States: the 4 current observations of each window -> (N, H, W, 4)
            batched_total_observations = torch.stack([
                torch.stack([experience.current_observation for experience in window], dim=-1)
                for window in windows
            ])
            batched_actions = torch.tensor([e.action_index for e in last_experiences]).to(device)
            batched_rewards = torch.tensor([e.reward for e in last_experiences]).to(device)
            # Next states: drop the oldest frame, append the last transition's next frame.
            batched_next_frames = torch.stack([e.next_observation for e in last_experiences]).to(device)
            batched_next_observations = torch.cat(
                [batched_total_observations[:, :, :, -3:], batched_next_frames.unsqueeze(-1)], dim=-1
            )
            done_tensor = torch.tensor([e.done for e in last_experiences]).to(device).float()

            # Q(s, a) for the actions actually taken.
            batched_chosen_q_values = model(batched_total_observations).gather(
                1, batched_actions.unsqueeze(1)
            ).squeeze(-1)
            # One-step TD target from the target network; no bootstrapping past `done`.
            with torch.no_grad():
                batched_max_next_q_values = target_model(batched_next_observations).max(dim=1).values
                td_targets = batched_rewards + gamma * batched_max_next_q_values * (1 - done_tensor)
            loss = F.mse_loss(batched_chosen_q_values, td_targets)

            if total_steps % 100 == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            optimizer.step()

            # Evict the oldest transitions in place, only once a buffer is over capacity
            # (rebuilding every buffer with a slice copied all of them on every update).
            for replay_buffer in replay_buffers:
                if len(replay_buffer) > buffer_size:
                    del replay_buffer[:-buffer_size]

        total_steps += 1

# Release the emulator resources.
for env in envs:
    env.close()

In [24]:
# Save the trained Q-network weights. Tensors are copied to CPU first so the
# checkpoint loads on machines without the training device (e.g. no MPS)
# without needing `map_location`.
torch.save({name: tensor.cpu() for name, tensor in model.state_dict().items()}, 'breakout_model.pth')


In [None]:
from matplotlib import pyplot as plt

# Training-loss curve. Entries are appended only periodically inside the training
# loop, so the x-axis counts logged entries rather than environment steps.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="DQN training loss (Breakout)", xlabel="logged entry", ylabel="MSE loss")
plt.show()