In [22]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps
from tqdm.auto import tqdm
import gymnasium as gym
import cv2
from torch.distributions import Categorical

In [3]:
env_id = "ALE/SpaceInvaders-v5"

env = gym.make(env_id, render_mode="rgb_array")

In [4]:
def process_image(img, size=(128, 128)):
    img = Image.fromarray(img)
    img = ImageOps.grayscale(img).resize((size[0], size[1]))
    img = np.array(img)
    return torch.tensor(img, dtype=torch.float) / 255.0

In [5]:
class Policy(torch.nn.Module):
    def __init__(self, in_dim=1, out_dim=env.action_space.n):
        super(Policy, self).__init__()
        self.conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_dim, 4, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(4, 8, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(16, 32, 3),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU()
        )

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(1152, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, out_dim)
        )

    def forward(self, x):
        conv_out = self.conv_net(x)
        flattened = torch.flatten(conv_out, start_dim=1)
        fc_out = self.fc(flattened)
        return torch.nn.functional.softmax(fc_out, dim=1)
        return fc_out

In [91]:
model = Policy().to(device)

In [92]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

729350


In [93]:
def update_model(rewards, log_probs, gamma, model, optimizer):
    rewards = torch.tensor(rewards)
    log_probs = torch.cat(log_probs).sum()
    returns = rewards * torch.pow(gamma, torch.arange(len(rewards)))
    returns = (returns.sum() - returns.mean()) / (returns.std()+1e-12)
    loss = -1*returns*log_probs

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    rewards = []
    log_probs = []

    step = 0

    return rewards, log_probs, step

In [94]:
gamma = 0.95
lr = 1e-2
batch = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [100]:
def train(model, optimizer, env, n_episodes, gamma, batch):
    for episode in tqdm(range(n_episodes)):
        state, info = env.reset()
        terminated = False
        truncated = False

        episode_rewards = []
        episode_log_probs = []

        step = 0

        while True:
            step+=1
            processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(device)
            probs = model(processed_state)
            m = Categorical(probs)
            action = m.sample()
            log_prob = m.log_prob(action)
            n_state, reward, terminated, truncated, info = env.step(action.item())

            if terminated or truncated:
                episode_rewards, episode_log_probs, step = update_model(episode_rewards, episode_log_probs, gamma, model, optimizer)
                break
    
            episode_rewards.append(reward)
            episode_log_probs.append(log_prob)

            if step == batch:
                episode_rewards, episode_log_probs, step = update_model(episode_rewards, episode_log_probs, gamma, model, optimizer)
                
            state = n_state

    return model

In [101]:
trained_model = train(model, optimizer, env, 100, gamma, batch)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [10:19<00:00,  6.20s/it]


In [86]:
def evaluate_agent(env, n_eval_episodes, policy):
    episode_rewards = []
    for episode in tqdm(range(n_eval_episodes)):
        state, info = env.reset()
        step = 0
        terminated = False
        truncated = False
        total_rewards_ep = 0
        
        while True:
            processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(device)
            probs = policy(processed_state)
            m = Categorical(probs)
            action = m.sample()
            log_prob = m.log_prob(action)
            new_state, reward, terminated, truncated, info = env.step(action.item())
            total_rewards_ep+= reward
            
            if terminated or truncated:
                break
            state = new_state
            
        episode_rewards.append(total_rewards_ep)
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    
    return mean_reward, std_reward

In [102]:
mean_reward, std_reward = evaluate_agent(env, 10, trained_model)

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:33<00:00,  3.33s/it]


In [103]:
mean_reward, std_reward

(270.0, 0.0)

In [104]:
state, info = env.reset()
terminated = False
truncated = False
step = 0

episode_reward = 0

while True:
    
    processed_state = process_image(state).unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = trained_model(processed_state)
        m = Categorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)
    n_state, reward, terminated, truncated, info = env.step(action.item())

    frame = env.render()
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    frame = cv2.resize(frame, (320, 420))
    frame = cv2.putText(frame, f'Action taken: {action.item()}  Reward: {reward}', (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    cv2.imshow("gameplay", frame)
    pressedKey = cv2.waitKey(60) & 0xFF
    if pressedKey == ord('q'):
        break

    episode_reward += reward

    if terminated or truncated:
        break

    state = n_state

cv2.destroyAllWindows()
print(episode_reward)

270.0


Training... 	 Episode: 1	:   7%|███▍                                             | 7/100 [1:52:22<24:52:53, 963.16s/it]