In [2]:
import os
from collections import deque, namedtuple
import numpy as np
import cv2
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import ale_py

class AtariPreprocess:
    def __init__(self, width=84, height=84):
        self.width = width
        self.height = height

    def __call__(self, obs):
        img = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return np.ascontiguousarray(img, dtype=np.uint8)

class FrameStack:
    def __init__(self, k):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        for _ in range(self.k):
            self.frames.append(frame)
        return self._get_obs()

    def append(self, frame):
        self.frames.append(frame)
        return self._get_obs()

    def _get_obs(self):
        return np.stack(self.frames, axis=0)

class SkipEnvWrapper(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            done = terminated or truncated
            if done:
                break
        return obs, total_reward, terminated, truncated, info

def make_env(env_id="ALE/AirRaid-v5", seed=0, skip=4, stack=4):
    raw = gym.make(env_id, render_mode=None)
    raw.reset(seed=seed)
    pre = AtariPreprocess()
    skipw = SkipEnvWrapper(raw, skip=skip)
    stacker = FrameStack(stack)

    class EnvObj:
        def __init__(self, e, pre, stacker):
            self.e = e
            self.pre = pre
            self.stacker = stacker
            self.action_space = e.action_space
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(stacker.k, pre.height, pre.width), dtype=np.uint8
            )

        def reset(self):
            obs, info = self.e.reset()
            obs_proc = self.pre(obs)
            stacked = self.stacker.reset(obs_proc)
            return stacked, info

        def step(self, action):
            obs, reward, terminated, truncated, info = self.e.step(action)
            obs_proc = self.pre(obs)
            stacked = self.stacker.append(obs_proc)
            reward = np.clip(reward, -1.0, 1.0)
            return stacked, reward, terminated, truncated, info

        def render(self, *args, **kwargs):
            return self.e.render(*args, **kwargs)

        def close(self):
            return self.e.close()

    return EnvObj(skipw, pre, stacker)

Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done', 'log_prob', 'value'))

class ActorCritic(nn.Module):
    def __init__(self, in_channels, n_actions):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self._conv_out = self._get_conv_out(in_channels)
        self.fc_shared = nn.Linear(self._conv_out, 512)
        self.actor = nn.Linear(512, n_actions)
        self.critic = nn.Linear(512, 1)

    def _get_conv_out(self, in_channels):
        x = torch.zeros(1, in_channels, 84, 84)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return int(np.prod(x.size()))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc_shared(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

class PPOBuffer:
    def __init__(self, capacity, device):
        self.capacity = capacity
        self.device = device
        self.buffer = []

    def push(self, *args):
        self.buffer.append(Experience(*args))

    def get_all(self):
        if len(self.buffer) == 0:
            return None
        states = torch.from_numpy(np.stack([e.state for e in self.buffer])).float().to(self.device) / 255.0
        actions = torch.tensor([e.action for e in self.buffer], dtype=torch.long, device=self.device)
        rewards = torch.tensor([e.reward for e in self.buffer], dtype=torch.float32, device=self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in self.buffer])).float().to(self.device) / 255.0
        dones = torch.tensor([float(e.done) for e in self.buffer], dtype=torch.float32, device=self.device)
        log_probs = torch.tensor([e.log_prob for e in self.buffer], dtype=torch.float32, device=self.device)
        values = torch.tensor([e.value for e in self.buffer], dtype=torch.float32, device=self.device)
        return states, actions, rewards, next_states, dones, log_probs, values

    def clear(self):
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

class PPOAgent:
    def __init__(self, in_channels, n_actions, device, lr=2.5e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, entropy_coeff=0.01, value_coeff=0.5):
        self.device = device
        self.n_actions = n_actions
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.entropy_coeff = entropy_coeff
        self.value_coeff = value_coeff
        self.actor_critic = ActorCritic(in_channels, n_actions).to(device)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr, eps=1e-5)
        self.buffer = PPOBuffer(50000, device)

    def select_action(self, state):
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device) / 255.0
        with torch.no_grad():
            action_probs, value = self.actor_critic(state_tensor)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action).item()
        return action.item(), log_prob, value.item()

    def compute_gae(self, rewards, values, next_values, dones):
        advantages = torch.zeros_like(rewards)
        gae = 0
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = next_values[-1] if not dones[t] else 0
            else:
                next_value = values[t + 1]
            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages[t] = gae
        returns = advantages + values
        return advantages, returns

    def update(self, ppo_epochs=4, mini_batch_size=256):
        if len(self.buffer) == 0:
            return None, None, None
        states, actions, rewards, next_states, dones, old_log_probs, values = self.buffer.get_all()
        with torch.no_grad():
            _, next_values = self.actor_critic(next_states)
            next_values = next_values.squeeze()
        advantages, returns = self.compute_gae(rewards, values, next_values, dones)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        old_log_probs = torch.tensor(old_log_probs, device=self.device)
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy_loss = 0
        for _ in range(ppo_epochs):
            batch_size = len(states)
            indices = torch.randperm(batch_size)
            for start in range(0, batch_size, mini_batch_size):
                end = start + mini_batch_size
                batch_indices = indices[start:end]
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]
                action_probs, state_values = self.actor_critic(batch_states)
                dist = Categorical(action_probs)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                state_values = state_values.squeeze()
                value_loss = F.mse_loss(state_values, batch_returns)
                total_loss = policy_loss + self.value_coeff * value_loss - self.entropy_coeff * entropy
                self.optimizer.zero_grad()
                total_loss.backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 0.5)
                self.optimizer.step()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy_loss += entropy.item()
        num_updates = ppo_epochs * (batch_size // mini_batch_size + 1)
        return (total_policy_loss / num_updates,
                total_value_loss / num_updates,
                total_entropy_loss / num_updates)

    def save(self, path):
        torch.save({
            'actor_critic_state_dict': self.actor_critic.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        self.actor_critic.load_state_dict(checkpoint['actor_critic_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

def train(env_id='ALE/AirRaid-v5',
          seed=42,
          total_steps=2_000_000,
          update_interval=2048,
          eval_interval=50_000,
          save_path='ppo_airraid.pth'):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = make_env(env_id, seed=seed, skip=4, stack=4)
    eval_env = make_env(env_id, seed=seed+1, skip=4, stack=4)
    obs, _ = env.reset()
    in_channels = obs.shape[0]
    n_actions = env.action_space.n
    agent = PPOAgent(in_channels, n_actions, device,
                    lr=2.5e-4, gamma=0.99, gae_lambda=0.95,
                    clip_epsilon=0.2, entropy_coeff=0.01, value_coeff=0.5)
    state, _ = env.reset()
    episode_reward = 0.0
    episode_cnt = 0
    total_step = 0
    policy_losses = []
    value_losses = []
    entropy_losses = []
    print(f"Starting PPO training on {device}")
    while total_step < total_steps:
        action, log_prob, value = agent.select_action(state)
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        agent.buffer.push(state, action, reward, next_state, done, log_prob, value)
        state = next_state
        episode_reward += reward
        total_step += 1
        if done:
            state, _ = env.reset()
            episode_cnt += 1
            print(f"Step {total_step} | Episode {episode_cnt} ended | Reward {episode_reward:.2f}")
            episode_reward = 0.0
        if len(agent.buffer) >= update_interval:
            policy_loss, value_loss, entropy_loss = agent.update(ppo_epochs=4, mini_batch_size=256)
            if policy_loss is not None:
                policy_losses.append(policy_loss)
                value_losses.append(value_loss)
                entropy_losses.append(entropy_loss)
                print(f"Step {total_step} | Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f} | Entropy: {entropy_loss:.4f}")
            agent.buffer.clear()
        if total_step % eval_interval == 0 and total_step > 0:
            avg_score = evaluate(agent, eval_env, episodes=5)
            print(f"== Eval at step {total_step}: avg score = {avg_score:.2f} ==")
    agent.save(save_path)
    env.close()
    eval_env.close()
    print(f"Training completed! Model saved to {save_path}")

def evaluate(agent: PPOAgent, env, episodes=5, render=False):
    scores = []
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        total = 0.0
        while not done:
            action, _, _ = agent.select_action(state)
            state, reward, terminated, truncated, info = env.step(action)
            total += reward
            done = terminated or truncated
            if render:
                env.render()
        scores.append(total)
    return float(np.mean(scores))

'''
if __name__ == "__main__":
    # Train the model
    train(env_id="ALE/AirRaid-v5",
          seed=42,
          total_steps=2_000_000,
          update_interval=2048,
          eval_interval=50_000,
          save_path="ppo_airraid.pth")
'''


'\nif __name__ == "__main__":\n    # Train the model\n    train(env_id="ALE/AirRaid-v5",\n          seed=42,\n          total_steps=2_000_000,\n          update_interval=2048,\n          eval_interval=50_000,\n          save_path="ppo_airraid.pth")\n'

In [3]:
def play_trained_model(model_path, env_id='ALE/AirRaid-v5', episodes=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    raw_env = gym.make(env_id, render_mode="human")
    env = make_env(env_id, seed=42, skip=4, stack=4)
    env.e = raw_env
    obs, _ = env.reset()
    in_channels = obs.shape[0]
    n_actions = env.action_space.n
    agent = PPOAgent(in_channels, n_actions, device)
    agent.load(model_path)
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        score = 0
        step_count = 0
        while not done:
            action, _, _ = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = next_state
            score += reward
            step_count += 1
            env.render()
        print(f"Episode {ep+1} finished | Score: {score:.2f} | Steps: {step_count}")
    env.close()
play_trained_model("models/ppo_airraid.pth")

  checkpoint = torch.load(path, map_location=self.device)


Episode 1 finished | Score: 51.00 | Steps: 2525
Episode 2 finished | Score: 65.00 | Steps: 2765
Episode 3 finished | Score: 41.00 | Steps: 1731
