In [1]:
import numpy as np

from collections import deque

import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

from PIL import Image, ImageOps

from tqdm import tqdm

# Gym
import gymnasium as gym

import cv2

In [2]:
# Gymnasium id of the Atari environment used throughout this notebook.
env_id = "ALE/Assault-v5"

# render_mode="rgb_array" is requested because the playback cells pass the
# result of env.render() to OpenCV for display.
env = gym.make(env_id, render_mode="rgb_array")

In [3]:
# Sanity check: play at most 500 steps with randomly sampled actions and show
# every frame in an OpenCV window. Press "q" in the window to stop early.
state, info = env.reset()

total_reward = 0

try:
    for _ in range(500):
        action = env.action_space.sample()
        n_state, reward, terminated, truncated, info = env.step(action)

        # The rendered frame is RGB; OpenCV expects BGR for display.
        frame = env.render()
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        frame = cv2.resize(frame, (320, 420))
        frame = cv2.putText(frame, f'Action taken: {action}  Reward: {reward}', (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        cv2.imshow("gameplay", frame)

        # waitKey both paces the playback (60 ms per frame) and polls the keyboard.
        pressedKey = cv2.waitKey(60) & 0xFF
        if pressedKey == ord('q'):
            break

        total_reward += reward
        if terminated or truncated:
            break

        state = n_state
finally:
    # Close the window even when the loop is interrupted or raises, so a
    # stale "gameplay" window is never left behind.
    cv2.destroyAllWindows()

print(total_reward)

  logger.warn(


189.0


In [4]:
def process_image(img, size=(84, 84)):
    """Turn a frame into a resized grayscale float tensor.

    Args:
        img: Array accepted by ``PIL.Image.fromarray`` (the notebook passes
            environment observations).
        size: Two-element target size forwarded to ``Image.resize``.

    Returns:
        A ``torch.float`` tensor holding the resized grayscale pixels
        divided by 255.
    """
    target = (size[0], size[1])
    gray = ImageOps.grayscale(Image.fromarray(img))
    pixels = np.array(gray.resize(target))
    return torch.tensor(pixels, dtype=torch.float) / 255.0

In [5]:
class Policy(torch.nn.Module):
    """Small convolutional policy network producing action probabilities.

    Three Conv(3x3) -> MaxPool(2) -> ReLU stages (4, 8 and 16 channels) feed a
    two-layer fully connected head. The head expects 1024 flattened features,
    which is what the conv stack yields for an 84x84 input
    (16 channels x 8 x 8).

    Args:
        in_dim: Number of input channels.
        out_dim: Number of discrete actions; defaults to the size of the
            notebook-level environment's action space.
    """

    def __init__(self, in_dim=1, out_dim=env.action_space.n):
        super().__init__()

        def conv_stage(channels_in, channels_out):
            # One feature-extraction stage: 3x3 conv, 2x2 max-pool, ReLU.
            return [
                nn.Conv2d(channels_in, channels_out, 3),
                nn.MaxPool2d(2),
                nn.ReLU(),
            ]

        # Layer order (and therefore state_dict keys) matches saved checkpoints.
        self.conv_net = nn.Sequential(
            *conv_stage(in_dim, 4),
            *conv_stage(4, 8),
            *conv_stage(8, 16),
        )

        self.fc = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
        )

    def forward(self, x):
        """Return per-action probabilities (softmax over dim 1) for batch ``x``."""
        features = torch.flatten(self.conv_net(x), start_dim=1)
        return F.softmax(self.fc(features), dim=1)

    def act(self, state):
        """Sample an action for a single-state batch.

        Returns:
            Tuple ``(action, log_prob)``: the sampled action as a Python
            number and the log-probability tensor of that action.
        """
        dist = Categorical(self.forward(state))
        sampled = dist.sample()
        # .item() requires exactly one sampled element, i.e. a batch of one.
        return sampled.item(), dist.log_prob(sampled)

In [6]:
def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):
    """Train ``policy`` with REINFORCE (Monte-Carlo policy gradient).

    Every episode is rolled out in the notebook-level ``env`` (observations
    are preprocessed with ``process_image``). Afterwards one gradient step is
    taken on ``-sum_t log_prob_t * G_t``, where ``G_t`` are the discounted
    returns, standardised to zero mean and unit variance.

    Args:
        policy: Module exposing ``act(state) -> (action, log_prob)`` where
            ``log_prob`` is a 1-element tensor, plus ``parameters()``.
        optimizer: Optimizer updating ``policy``'s parameters.
        n_training_episodes: Number of episodes to train for.
        max_t: Maximum number of environment steps per episode; ``None``
            disables the cap.
        gamma: Discount factor applied to future rewards.
        print_every: Print the running average score every this many episodes.

    Returns:
        List with the undiscounted reward sum of every training episode.
    """
    scores_deque = deque(maxlen=100)  # window behind the printed average
    scores = []
    eps = np.finfo(np.float32).eps.item()  # keeps the std division finite

    # Feed inputs on whatever device the policy's weights live on.
    policy_device = next(policy.parameters()).device

    for i_episode in tqdm(range(1, n_training_episodes + 1)):
        saved_log_probs = []
        rewards = []
        state, info = env.reset()

        # Roll out one episode, stopping at termination/truncation or after
        # max_t steps, whichever comes first.
        steps = 0
        while max_t is None or steps < max_t:
            steps += 1
            processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(policy_device)
            action, log_prob = policy.act(processed_state)
            saved_log_probs.append(log_prob)
            state, reward, terminated, truncated, info = env.step(action)
            rewards.append(reward)
            if terminated or truncated:
                break

        episode_score = sum(rewards)
        scores_deque.append(episode_score)
        scores.append(episode_score)

        if saved_log_probs:
            # G_t = r_t + gamma * G_{t+1}, accumulated from the last step back.
            returns = []
            running_return = 0.0
            for r in reversed(rewards):
                running_return = r + gamma * running_return
                returns.append(running_return)
            returns.reverse()

            returns = torch.tensor(returns, dtype=torch.float32, device=policy_device)
            # std() of a single value is NaN and would poison the weights, so
            # only standardise when there are at least two returns.
            if returns.numel() > 1:
                returns = (returns - returns.mean()) / (returns.std() + eps)

            policy_loss = -(torch.cat(saved_log_probs) * returns).sum()

            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

        if i_episode % print_every == 0:
            print('Episode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))

    return scores

In [7]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression so the notebook displays which device was selected.
device

'cuda'

In [13]:
# Training / evaluation settings read by the cells below.
hyperparameters = {
    "n_training_episodes": 300,  # episodes passed to reinforce()
    "n_evaluation_episodes": 10,  # episodes passed to evaluate_agent()
    "max_t": 1000,  # passed as reinforce(max_t=...) and evaluate_agent(max_steps=...)
    "gamma": 1.0,  # discount factor; 1.0 leaves future rewards undiscounted
    # NOTE(review): the recorded run prints the same average score (285.00) at
    # episodes 200 and 300 and the evaluation std is 0.0, which may indicate
    # the policy collapsed to a fixed behaviour. A smaller learning rate may
    # be worth trying; confirm.
    "lr": 1e-2,  # Adam learning rate
    "env_id": env_id,
}

In [14]:
# Build the policy with its default constructor arguments and move it to the
# selected device.
policy = Policy().to(device)

# Adam over all policy parameters, using the configured learning rate.
optimizer = optim.Adam(policy.parameters(), lr=hyperparameters["lr"])

In [15]:
# Run training; keyword arguments make each setting explicit at the call site.
scores = reinforce(
    policy=policy,
    optimizer=optimizer,
    n_training_episodes=hyperparameters["n_training_episodes"],
    max_t=hyperparameters["max_t"],
    gamma=hyperparameters["gamma"],
    print_every=100,
)

 33%|██████████████████████████▋                                                     | 100/300 [06:26<14:25,  4.33s/it]

Episode 100	Average Score: 225.95


 67%|█████████████████████████████████████████████████████▎                          | 200/300 [13:43<07:09,  4.30s/it]

Episode 200	Average Score: 285.00


100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [20:50<00:00,  4.17s/it]

Episode 300	Average Score: 285.00





In [16]:
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(policy.state_dict(), "policy assault v5 VERSION_2.pt")

In [17]:
def evaluate_agent(env, max_steps, n_eval_episodes, policy):
    """Run ``policy`` for several episodes and summarise the episode rewards.

    Actions are sampled from the policy's distribution via ``policy.act``.
    Observations are preprocessed with the notebook-level ``process_image``.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)`` returning
            ``(obs, reward, terminated, truncated, info)``.
        max_steps: Maximum number of steps per episode; ``None`` disables
            the cap.
        n_eval_episodes: Number of episodes to run.
        policy: Module exposing ``act(state) -> (action, log_prob)`` and
            ``parameters()``.

    Returns:
        Tuple ``(mean_reward, std_reward)`` over the evaluated episodes.
    """
    # Follow the policy's own device so evaluation keeps working after the
    # policy has been moved (e.g. to the CPU).
    policy_device = next(policy.parameters()).device

    episode_rewards = []
    for _ in tqdm(range(n_eval_episodes)):
        state, info = env.reset()
        total_rewards_ep = 0

        # No gradients are needed for evaluation; skipping them avoids
        # building an autograd graph for every step.
        with torch.no_grad():
            steps = 0
            while max_steps is None or steps < max_steps:
                steps += 1
                processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(policy_device)
                action, _log_prob = policy.act(processed_state)
                new_state, reward, terminated, truncated, info = env.step(action)
                total_rewards_ep += reward

                if terminated or truncated:
                    break
                state = new_state

        episode_rewards.append(total_rewards_ep)

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    return mean_reward, std_reward

In [18]:
# Evaluate the trained policy; the returned (mean, std) tuple is displayed.
evaluate_agent(
    env=env,
    max_steps=hyperparameters["max_t"],
    n_eval_episodes=hyperparameters["n_evaluation_episodes"],
    policy=policy,
)

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:35<00:00,  3.51s/it]


(285.0, 0.0)

In [13]:
# Fresh Policy instance with default constructor arguments; it replaces the
# previously trained `policy` object under the same name.
policy = Policy()

In [14]:
# Load saved weights into the policy. map_location="cpu" lets a checkpoint
# that contains CUDA tensors load on a machine without a GPU.
# NOTE(review): this reads "policy assault v5.pt", not the
# "policy assault v5 VERSION_2.pt" file written earlier in this notebook.
# Confirm which checkpoint is intended.
# torch.load unpickles the file, so only load checkpoints you trust.
policy.load_state_dict(torch.load("policy assault v5.pt", map_location="cpu"))

<All keys matched successfully>

In [15]:
# Keep the policy on the CPU for the playback cell; as the last expression
# this also displays the module structure.
policy.to('cpu')

Policy(
  (conv_net): Sequential(
    (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU()
    (6): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=7, bias=True)
  )
)

In [16]:
# Play one episode with the loaded policy and show every frame in an OpenCV
# window. Press "q" in the window to stop early.
state, info = env.reset()

episode_reward = 0

# Feed inputs on whatever device the policy's weights live on.
policy_device = next(policy.parameters()).device

try:
    while True:
        processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(policy_device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action, _ = policy.act(processed_state)
        n_state, reward, terminated, truncated, info = env.step(action)

        # The rendered frame is RGB; OpenCV expects BGR for display.
        frame = env.render()
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        frame = cv2.resize(frame, (320, 420))
        frame = cv2.putText(frame, f'Action taken: {action}  Reward: {reward}', (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        cv2.imshow("gameplay", frame)

        # waitKey both paces the playback (60 ms per frame) and polls the keyboard.
        pressedKey = cv2.waitKey(60) & 0xFF
        if pressedKey == ord('q'):
            break

        episode_reward += reward

        if terminated or truncated:
            break

        state = n_state
finally:
    # Close the window even when the loop is interrupted or raises, so a
    # stale "gameplay" window is never left behind.
    cv2.destroyAllWindows()

print(episode_reward)

357.0
