In [25]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("ALE/Asteroids-v5")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# Pick the best available accelerator (CUDA, then Apple-silicon MPS, then CPU)
# instead of hard-coding one backend, so the notebook also runs on machines
# that have no MPS device.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(env.observation_space.shape)
print(env.action_space.n)

(210, 160, 3)
14


In [26]:
# One stored environment step. In the training loop below, next_state is
# None when the step terminated the episode.
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)


class ReplayMemory:
    """Bounded FIFO store of Transition records used for experience replay.

    When more than ``capacity`` transitions have been pushed, the oldest
    ones are evicted automatically by the underlying bounded deque.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack the positional fields into a Transition and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions, drawn uniformly at random."""
        picked = random.sample(self.memory, k=batch_size)
        return picked

    def __len__(self):
        """Number of transitions currently held."""
        return self.memory.__len__()

In [27]:
class DQN(nn.Module):
    """Convolutional Q-network: one image observation -> one Q-value per action.

    Three stages of Conv2d(5x5, stride 1, padding 2) + ReLU + MaxPool2d(2),
    followed by a 256-unit hidden layer and a linear output head.

    Args:
        n_actions: number of discrete actions, i.e. the output width.
            Defaults to 14, the Asteroids action count printed at the top of
            this notebook.
        input_shape: (channels, dim1, dim2) of a single observation as passed
            to ``forward``. Defaults to (3, 160, 210): a 210x160 RGB frame
            after ``permute(2, 1, 0)``.

    The layer names (``conv_layers``, ``fc_layers``) are unchanged, so state
    dicts saved from earlier versions still load with the default arguments.
    """

    def __init__(self, n_actions=14, input_shape=(3, 160, 210)):
        super().__init__()
        in_channels = input_shape[0]
        # padding=2 with a 5x5 kernel keeps the spatial size; each 2x2 pool
        # halves it (floor). Shapes shown for the default 3x160x210 input.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=1, padding=2),  # 16x160x210
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                                     # 16x80x105
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),           # 32x80x105
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                                     # 32x40x52
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),           # 64x40x52
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                                     # 64x20x26
        )

        # Derive the flattened feature count from a dummy pass rather than
        # hard-coding it, so other input sizes work (default: 64 * 20 * 26).
        with torch.no_grad():
            self.num_features = self.conv_layers(torch.zeros(1, *input_shape)).numel()

        self.fc_layers = nn.Sequential(
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions).

        Accepts a batch (B, C, H, W) or a single unbatched observation
        (C, H, W); the latter is treated as a batch of one. Raw pixel values
        are used as-is (no scaling happens here).
        """
        if x.dim() == 3:
            # Conv2d/MaxPool2d accept unbatched input, so add the batch axis
            # explicitly instead of relying on a reshape to fold it back in.
            x = x.unsqueeze(0)
        x = self.conv_layers(x)
        # Flatten per sample: a size mismatch now fails in the Linear layer
        # instead of being silently re-batched by view(-1, ...).
        x = torch.flatten(x, start_dim=1)
        return self.fc_layers(x)


In [28]:
def select_action(state):
    """Epsilon-greedy action for a single observation.

    The exploration rate decays exponentially with the global step counter
    ``steps_done`` (read, not modified, here):

        eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)

    Args:
        state: one unbatched observation tensor, e.g. shape (3, 160, 210) as
            built in the training loop. A batch axis is added before it is
            passed to ``policy_model``.

    Returns:
        A ``torch.long`` tensor of shape (1, 1) holding the action index. This
        is the per-transition shape that optimize_model concatenates and feeds
        to ``gather``.
    """
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    if random.random() > eps_threshold:
        with torch.no_grad():
            q_values = policy_model(state.unsqueeze(0))  # (1, n_actions)
            # argmax over the action axis; keepdim keeps the (1, 1) shape.
            return q_values.argmax(dim=1, keepdim=True)
    # Explore: uniform random action, same (1, 1) long shape as above.
    return torch.tensor([[random.randrange(env.action_space.n)]], device=device, dtype=torch.long)


In [29]:
def optimize_model():
    """Run one DQN gradient step on a random minibatch drawn from ``memory``.

    Does nothing until the buffer holds at least BATCH_SIZE transitions.
    Targets use the one-step Bellman backup

        y = r + GAMMA * max_a' Q_target(s', a')      (y = r when s' is None)

    The policy network is fit with the Huber (smooth L1) loss, and its
    gradient norm is clipped at 100.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    # next_state is None for terminal transitions; those get no bootstrap term.
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], device=device, dtype=torch.bool
    )
    non_final_next = [s for s in batch.next_state if s is not None]

    state_batch = torch.stack(batch.state)    # (B, C, H, W)
    action_batch = torch.cat(batch.action)    # (B, 1) long, from select_action
    reward_batch = torch.cat(batch.reward)    # (B,)

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_model(state_batch).gather(1, action_batch)

    # max_a' Q_target(s', a'); stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Check for emptiness *before* stacking: stacking an empty list raises,
    # which happens when every sampled transition is terminal.
    if non_final_next:
        with torch.no_grad():
            next_state_values[non_final_mask] = (
                target_model(torch.stack(non_final_next)).max(1).values
            )

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 100)
    optimizer.step()

In [30]:
# hyperparameters
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
# Exploration e-folding time, in agent steps:
#   eps = EPS_END + (EPS_START - EPS_END) * exp(-steps / EPS_DECAY)
# It divides a step count, so it must be on the scale of steps; a value like
# 0.05 would make eps hit EPS_END after a single step.
EPS_DECAY = 10_000
TAU = 0.005
LR = 1e-4
UPDATE_TARGET_EVERY = 2800  # target-network update interval, in steps
SEED = 0

# Seed Python's and PyTorch's RNGs (weight init, replay sampling, epsilon-greedy
# draws) before building anything stochastic. The environment is not seeded here.
random.seed(SEED)
torch.manual_seed(SEED)

# initialize things
# NOTE(review): the training loop stores observations as float32 device tensors
# (3x160x210 ≈ 400 KB each), so a full 10k-transition buffer can take several
# GB of device memory; consider storing uint8 frames if memory is tight.
memory = ReplayMemory(10000)
policy_model = DQN().to(device)
target_model = DQN().to(device)
# Start the target network as an exact copy of the policy network.
target_model.load_state_dict(policy_model.state_dict())

optimizer = optim.AdamW(policy_model.parameters(), lr=LR, amsgrad=True)

# Global count of environment steps taken.
steps_done = 0

In [31]:
def save_model(path):
    """Serialize the policy network's parameters (its state_dict) to ``path``."""
    weights = policy_model.state_dict()
    torch.save(weights, path)

def load_model(path):
    """Restore ``policy_model`` parameters from a state_dict file at ``path``.

    ``map_location=device`` remaps tensors saved on a different device (e.g. a
    CUDA checkpoint opened on an MPS/CPU machine) instead of failing to load.
    Only the policy network is restored; ``target_model`` is left untouched.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location=device)
    policy_model.load_state_dict(state_dict)
    policy_model.to(device)

In [32]:
# Train for longer when a CUDA GPU is available.
num_episodes = 600 if torch.cuda.is_available() else 50

for i_episode in range(num_episodes):
    observation, info = env.reset()
    # HxWxC frame -> CxWxH float tensor, the layout the DQN was sized for.
    state = torch.tensor(observation, dtype=torch.float32, device=device).permute(2, 1, 0)
    current_lives = info["lives"]  # track lives so each loss can be penalised

    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, info = env.step(action.item())
        # Advance the global step counter once per environment step; the
        # epsilon schedule in select_action and the target sync below read it.
        steps_done += 1

        # Reward shaping: -100 for each life lost, +1 for every step survived.
        if info["lives"] < current_lives:
            reward -= 100
            current_lives = info["lives"]
        reward += 1
        reward = torch.tensor([reward], dtype=torch.float32, device=device)

        done = terminated or truncated
        # Only a true terminal stores next_state=None (no bootstrap); a
        # truncated episode keeps its next state so its value is still bootstrapped.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).permute(2, 1, 0)

        memory.push(state, action, next_state, reward)
        state = next_state

        optimize_model()

        # Periodic hard sync of the target network. A Polyak step with
        # TAU=0.005 applied only once per UPDATE_TARGET_EVERY steps would
        # leave the target essentially frozen at its initial weights.
        if steps_done % UPDATE_TARGET_EVERY == 0:
            target_model.load_state_dict(policy_model.state_dict())

        if done:
            # t is 0-based, so the episode took t + 1 agent steps.
            print(f"on episode {i_episode}, which lasted {t + 1} steps")
            break

print('Complete')

KeyboardInterrupt: 

In [None]:
# Persist the target network's parameters (save_model above writes the policy network instead).
torch.save(obj=target_model.state_dict(), f="./rl-asteroids-v1-target.pth")