In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import torch
import copy
import tqdm
import time
import gc
import cv2
from torch.distributions import Categorical

In [2]:
env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array")

In [3]:
state, info = env.reset()

total_reward = 0

for _ in range(500):
    
    action = env.action_space.sample()
    n_state, reward, terminated, truncated, info = env.step(action)

    frame = env.render()
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    frame = cv2.resize(frame, (320, 420))
    frame = cv2.putText(frame, f'Action taken: {action}  Reward: {reward}', (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 
    cv2.imshow("gameplay", frame)
    pressedKey = cv2.waitKey(60) & 0xFF
    if pressedKey == ord('q'):
        break
        
    total_reward += reward
    if terminated or truncated:
            break
        
    state = n_state
    
cv2.destroyAllWindows()

print(total_reward)

  logger.warn(


15.0


In [4]:
def process_image(img, size=(84, 84)):
    img = Image.fromarray(img)
    img = ImageOps.grayscale(img).resize((size[0], size[1]))
    img = np.array(img)
    return torch.tensor(img, dtype=torch.float) / 255.0

In [5]:
class DQN(torch.nn.Module):
    def __init__(self, in_dim=1, out_dim=env.action_space.n):
        super(DQN, self).__init__()
        self.conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_dim, 4, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(4, 8, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU()
        )

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(1024, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, out_dim)
        )

    def forward(self, x):
        conv_out = self.conv_net(x)
        flattened = torch.flatten(conv_out, start_dim=1)
        fc_out = self.fc(flattened)
        return torch.nn.functional.softmax(fc_out, dim=1)
        #return fc_out

    def act(self, x):
        probs = self.forward(x)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

In [6]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [7]:
dqn = DQN().to(DEVICE)
target_dqn = copy.deepcopy(dqn).to(DEVICE)

In [8]:
pytorch_total_params = sum(p.numel() for p in target_dqn.parameters() if p.requires_grad)
print(pytorch_total_params)

265446


In [9]:
max_epsilon = 1.0
min_epsilon = 0.05
decay_rate = 0.005
n_episodes = 1000
max_steps = 500
gamma = 0.95
alpha = 0.01
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(dqn.parameters(), lr=alpha)

In [10]:
torch.cuda.empty_cache()
gc.collect()

0

# TRAIN

In [50]:
for episode in tqdm.tqdm(range(n_episodes)):

    epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*episode)

    state, info = env.reset()
    terminated = False
    truncated = False
    step = 0

    for step in range(max_steps):

        processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(DEVICE)
        probs = dqn(processed_state)
        m = Categorical(probs)
        action = m.sample()
        n_state, reward, terminated, truncated, info = env.step(action)

        processed_new_state = process_image(n_state).unsqueeze(0).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            target_probs = target_dqn(processed_new_state)
            target = reward + gamma*target_probs.max()

        loss = loss_fn(probs, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if terminated or truncated:
            break
    
        state = n_state

    target_dqn = copy.deepcopy(dqn)

  return F.mse_loss(input, target, reduction=self.reduction)
  3%|██▏                                                                           | 28/1000 [02:27<1:25:37,  5.29s/it]


KeyboardInterrupt: 

# EVALUATION

In [11]:
dqn.load_state_dict(torch.load("dqn space invaders.pt"))

<All keys matched successfully>

In [12]:
dqn.to('cpu')

DQN(
  (conv_net): Sequential(
    (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU()
    (6): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=6, bias=True)
  )
)

In [14]:
state, info = env.reset()
terminated = False
truncated = False
step = 0

episode_reward = 0

while True:
    
    processed_state = process_image(state).unsqueeze(0).unsqueeze(0)#.to(DEVICE)
    with torch.no_grad():
        out_ids = dqn(processed_state)
    m = Categorical(out_ids)
    action = m.sample()
    n_state, reward, terminated, truncated, info = env.step(action)

    frame = env.render()
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    frame = cv2.resize(frame, (320, 420))
    frame = cv2.putText(frame, f'Action taken: {action}  Reward: {reward}', (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    cv2.imshow("gameplay", frame)
    pressedKey = cv2.waitKey(60) & 0xFF
    if pressedKey == ord('q'):
        break

    episode_reward += reward

    if terminated or truncated:
        break

    state = n_state


cv2.destroyAllWindows()
print(episode_reward)

105.0


In [24]:
print(f"reward mean: {reward_mean:.3f} +/- {reward_std:.3f}")

reward mean: 100.000 +/- 0.000
