## Breakout-V4

#### All generated graphs are logged to an experiment-tracking platform, Weights & Biases (wandb).

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import gymnasium as gym
import numpy as np
from stable_baselines3.common.atari_wrappers import (
    EpisodicLifeEnv,
    FireResetEnv,
)
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Dict, Any
import ale_py

# Make the ALE (Atari) environments shipped with ale_py known to Gymnasium.
gym.register_envs(ale_py)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every model and tensor created below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log in to Weights & Biases up front, before any run is initialized.
wandb.login()

class BreakoutPPONet(nn.Module):
    """
    Actor-critic neural network for the PPO algorithm in the Breakout environment.

    Three convolutional layers feed a shared 512-unit hidden layer. Two linear
    heads sit on top of it: ``actor`` (one logit per action) and ``critic``
    (a single state-value estimate).
    """
    def __init__(self, input_shape: Tuple[int, int, int], num_actions: int):
        """
        Initializes the Breakout PPO network.

        Args:
            input_shape (Tuple[int, int, int]): Shape of the input (channels, height, width).
            num_actions (int): Number of actions in the action space.
        """
        super().__init__()
        channels, height, width = input_shape
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        def conv2d_size_out(size: int, kernel_size: int, stride: int) -> int:
            # Output length of an unpadded convolution along one axis.
            return (size - (kernel_size - 1) - 1) // stride + 1

        def conv_stack_size_out(size: int) -> int:
            # Apply the (kernel_size, stride) pairs of conv1, conv2 and conv3 in order.
            for kernel_size, stride in ((8, 4), (4, 2), (3, 1)):
                size = conv2d_size_out(size, kernel_size, stride)
            return size

        conv_h = conv_stack_size_out(height)
        conv_w = conv_stack_size_out(width)
        linear_input_size = conv_h * conv_w * 64

        self.fc = nn.Linear(linear_input_size, 512)
        self.actor = nn.Linear(512, num_actions)
        self.critic = nn.Linear(512, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Batch of states, shaped (batch, channels, height, width).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Actor logits of shape (batch, num_actions)
                and critic values of shape (batch, 1).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Collapse the channel and spatial axes, keeping the batch axis.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc(x))
        return self.actor(x), self.critic(x)

    def get_action_and_value(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample an action and return its log probability, the policy entropy, and the value.

        Args:
            state (torch.Tensor): Batch of states, shaped (batch, channels, height, width).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: Action, log probability, entropy, and value.
        """
        logits, value = self.forward(state)
        # Build the distribution straight from the logits. Going through
        # softmax first makes Categorical clamp tiny probabilities, which
        # distorts the log-probabilities of unlikely actions.
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value

class PPOTrainer:
    """
    Trainer for the PPO algorithm in the Breakout environment.

    Every training iteration collects a fixed-length rollout, converts it into
    GAE advantages and returns, and then runs several epochs of clipped-surrogate
    minibatch updates. Metrics are logged to Weights & Biases.
    """
    def __init__(self, env: gym.Env, params: Dict[str, Any]):
        """
        Initializes the PPO trainer.

        Args:
            env (gym.Env): The environment for training.
            params (Dict[str, Any]): Training parameters. The keys "learning_rate",
                "gamma", "gae_lambda", "clip_epsilon", "entropy_coef", "value_coef"
                and "max_grad_norm" are read here; "num_episodes" and
                "update_epochs" are read by train() and update_policy().
        """
        # Only start a run when none is active, so a caller that already
        # called wandb.init() does not trigger a second initialization.
        if wandb.run is None:
            wandb.init(project="breakout-ppo", config=params)
        self.env = env
        self.params = params
        self.num_actions = env.action_space.n
        # The network input shape is fixed to four stacked 84x84 frames.
        self.agent = BreakoutPPONet((4, 84, 84), self.num_actions).to(device)
        self.optimizer = torch.optim.Adam(self.agent.parameters(), lr=params["learning_rate"])
        self.gamma = params["gamma"]
        self.lam = params["gae_lambda"]
        self.clip_epsilon = params["clip_epsilon"]
        self.entropy_coef = params["entropy_coef"]
        self.value_coef = params["value_coef"]
        self.max_grad_norm = params["max_grad_norm"]
        self.frames = []
        self.episode_returns = []
        self.episode_losses = []
        self.episode_lengths = []
        self.moving_avg_window = 15
        # Critic estimate for the state that follows the most recent rollout;
        # set by collect_rollout() and used by train() to bootstrap GAE.
        self.last_value = 0.0

    def preprocess_obs(self, obs: np.ndarray) -> torch.Tensor:
        """
        Preprocesses the observation for the neural network.

        Args:
            obs (np.ndarray): Raw observation from the environment.

        Returns:
            torch.Tensor: Float32 tensor on ``device`` with a leading batch axis
                of size 1 and values divided by 255.
        """
        frames = torch.as_tensor(np.asarray(obs), dtype=torch.float32, device=device)
        return frames.unsqueeze(0) / 255.0

    def calculate_moving_average(self, data: List[float]) -> float:
        """
        Calculates the mean of the most recent ``moving_avg_window`` entries.

        Only the tail of ``data`` is touched, so the cost does not grow with the
        length of the history. When fewer than ``moving_avg_window`` entries
        exist, the mean of the available entries is returned.

        Args:
            data (List[float]): Data series to calculate the moving average.

        Returns:
            float: Moving average of the data, or NaN when ``data`` is empty.
        """
        recent = data[-self.moving_avg_window:]
        if len(recent) == 0:
            return float("nan")
        return float(np.mean(recent))

    def collect_rollout(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float], List[torch.Tensor], List[torch.Tensor], List[bool], List[torch.Tensor]]:
        """
        Collects a rollout of 2048 environment steps.

        Side effects: appends to ``episode_returns`` and ``episode_lengths`` for
        every episode that ends (or is cut off) during the rollout, and stores
        the critic's estimate for the state after the final step in
        ``self.last_value`` (0.0 when the final step ended an episode).

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor], List[float], List[torch.Tensor], List[torch.Tensor], List[bool], List[torch.Tensor]]:
                States, actions, rewards, values, log probabilities, done flags, and entropies from the rollout.
        """
        states, actions, rewards, values, log_probs, dones, entropies = [], [], [], [], [], [], []
        max_steps = 2048
        obs, _ = self.env.reset()
        obs = self.preprocess_obs(obs)
        episode_reward = 0.0
        episode_length = 0

        for _ in range(max_steps):
            # Sampling happens under no_grad, so the stored tensors carry no graph.
            with torch.no_grad():
                action, log_prob, entropy, value = self.agent.get_action_and_value(obs)
            next_obs, reward, done, truncated, _ = self.env.step(action.item())
            episode_over = bool(done or truncated)

            states.append(obs)
            actions.append(action)
            rewards.append(float(reward))
            values.append(value)
            log_probs.append(log_prob)
            entropies.append(entropy)
            dones.append(episode_over)
            episode_reward += reward
            episode_length += 1

            if episode_over:
                self.episode_returns.append(episode_reward)
                # Record the length of this episode, not the running step count.
                self.episode_lengths.append(episode_length)
                episode_reward = 0.0
                episode_length = 0
                next_obs, _ = self.env.reset()
            obs = self.preprocess_obs(next_obs)

        # Value of the state that follows the last stored transition.
        if dones[-1]:
            self.last_value = 0.0
        else:
            with torch.no_grad():
                _, bootstrap_value = self.agent(obs)
            self.last_value = bootstrap_value.item()

        # The next rollout begins with env.reset(), so the running tally cannot
        # carry over; record what was accumulated. Nothing is recorded when the
        # rollout stopped exactly on an episode boundary, which would otherwise
        # add a spurious zero-return episode.
        if episode_length > 0:
            self.episode_returns.append(episode_reward)
            self.episode_lengths.append(episode_length)
        return states, actions, rewards, values, log_probs, dones, entropies

    def compute_gae(self, rewards: List[float], values: List[torch.Tensor], dones: List[bool], last_value: "float | None" = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes Generalized Advantage Estimation (GAE).

        Args:
            rewards (List[float]): List of rewards from the rollout.
            values (List[torch.Tensor]): List of value estimates, one single-element tensor per step.
            dones (List[bool]): List of done flags.
            last_value (float | None): Value estimate of the state after the final
                step, used to bootstrap when the rollout stops mid-episode. When
                omitted, the final entry of ``values`` is used instead.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Computed advantages and returns,
                as float32 tensors on ``device``.
        """
        num_steps = len(rewards)
        if num_steps == 0:
            empty = torch.zeros(0, dtype=torch.float32, device=device)
            return empty, empty.clone()

        reward_arr = np.asarray(rewards, dtype=np.float64)
        value_arr = np.array([v.item() for v in values], dtype=np.float64)
        # 1.0 where the episode continues after step t, 0.0 where it ended.
        not_done = 1.0 - np.asarray(dones, dtype=np.float64)

        next_value = value_arr[-1] if last_value is None else float(last_value)
        advantages = np.zeros(num_steps, dtype=np.float64)
        last_gae = 0.0
        # Walk backwards and fill a preallocated array; this avoids the
        # quadratic cost of inserting at the front of a list.
        for t in reversed(range(num_steps)):
            delta = reward_arr[t] + self.gamma * next_value * not_done[t] - value_arr[t]
            last_gae = delta + self.gamma * self.lam * not_done[t] * last_gae
            advantages[t] = last_gae
            next_value = value_arr[t]
        returns = advantages + value_arr
        return (torch.as_tensor(advantages, dtype=torch.float32, device=device),
                torch.as_tensor(returns, dtype=torch.float32, device=device))

    def update_policy(self, states: List[torch.Tensor], actions: List[torch.Tensor], returns: torch.Tensor, advantages: torch.Tensor, old_log_probs: List[torch.Tensor]):
        """
        Updates the policy network using PPO.

        Runs ``params["update_epochs"]`` passes over the rollout in shuffled
        minibatches of 64 and appends every minibatch loss to ``episode_losses``.

        Args:
            states (List[torch.Tensor]): Collected states from the rollout.
            actions (List[torch.Tensor]): Actions taken.
            returns (torch.Tensor): Returns computed from GAE.
            advantages (torch.Tensor): Advantages computed from GAE.
            old_log_probs (List[torch.Tensor]): Log probabilities of the old actions.
        """
        states = torch.cat(states).to(device)
        actions = torch.cat(actions).to(device)
        old_log_probs = torch.cat(old_log_probs).to(device)
        batch_size = 64
        num_samples = states.shape[0]
        indices = np.arange(num_samples)
        # Normalize once over the whole rollout; the epsilon guards against a zero std.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.params["update_epochs"]):
            np.random.shuffle(indices)
            for start in range(0, num_samples, batch_size):
                batch_indices = indices[start:start + batch_size]
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]

                logits, values = self.agent(batch_states)
                # Same construction as get_action_and_value, so the old and new
                # log-probabilities are computed the same way.
                dist = Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(values.squeeze(-1), batch_returns)
                entropy = dist.entropy().mean()
                loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.episode_losses.append(loss.item())

    def save_model(self):
        """
        Saves the trained model to disk and logs it to W&B.
        """
        artifact = wandb.Artifact('ppo_model', type='model')
        path = 'ppo_breakout_model.pth'
        torch.save({
            'model_state_dict': self.agent.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'episode_returns': self.episode_returns,
            'episode_losses': self.episode_losses,
            'episode_lengths': self.episode_lengths,
            # Store a plain dict so the checkpoint does not embed the params
            # object itself (a wandb config when launched from the script).
            'params': dict(self.params)
        }, path)
        artifact.add_file(path)
        wandb.log_artifact(artifact)

    def train(self):
        """
        Main training loop for the PPO algorithm.

        Note that one loop iteration is one rollout plus one policy update, not
        one game episode, even though the parameter is named "num_episodes".
        """
        for episode in range(self.params["num_episodes"]):
            states, actions, rewards, values, log_probs, dones, entropies = self.collect_rollout()
            advantages, returns = self.compute_gae(rewards, values, dones, last_value=self.last_value)
            self.update_policy(states, actions, returns, advantages, log_probs)
            episode_reward = self.episode_returns[-1]
            reward_moving_avg = self.calculate_moving_average(self.episode_returns)
            loss_moving_avg = self.calculate_moving_average(self.episode_losses)
            entropy_avg = np.mean([entropy.item() for entropy in entropies])
            wandb.log({
                "Reward": {"Raw": episode_reward, "Moving Avg": reward_moving_avg},
                "Loss": {"Raw": self.episode_losses[-1], "Moving Avg": loss_moving_avg},
                "Entropy": entropy_avg,
            })
            print(f"Episode {episode + 1}: Total Reward = {episode_reward}, Moving Avg Reward = {reward_moving_avg:.2f}, Entropy = {entropy_avg:.4f}")
        self.save_model()
        wandb.finish()

def make_env() -> gym.Env:
    """
    Build the Breakout environment with the wrapper stack used for training.

    Wrappers are applied in this order: episodic-life handling, an automatic
    FIRE on reset (only when the game exposes a FIRE action), resizing to
    84x84, grayscale conversion, and stacking of 4 frames.

    Returns:
        gym.Env: The wrapped Breakout environment.
    """
    wrapped = EpisodicLifeEnv(gym.make("Breakout-v4", render_mode="rgb_array"))

    action_meanings = wrapped.unwrapped.get_action_meanings()
    if "FIRE" in action_meanings:
        wrapped = FireResetEnv(wrapped)

    wrapped = gym.wrappers.ResizeObservation(wrapped, (84, 84))
    wrapped = gym.wrappers.GrayscaleObservation(wrapped)
    return gym.wrappers.FrameStackObservation(wrapped, 4)

if __name__ == "__main__":
    # Default hyperparameters; they become the run's W&B config.
    config_defaults = {
        "learning_rate": 3e-4,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_epsilon": 0.2,
        "entropy_coef": 0.01,
        "value_coef": 0.500,
        "max_grad_norm": 0.5,
        "num_episodes": 4000,
        "update_epochs": 10,
    }

    # Initialize a W&B run with the default configuration
    wandb.init(project="breakout-ppo", config=config_defaults)

    # Read the configuration back from the run
    config = wandb.config

    # Initialize the environment and the trainer
    env = make_env()
    try:
        trainer = PPOTrainer(env, config)

        # Start training
        trainer.train()
    finally:
        # Always release the emulator, even when setup or training raises.
        env.close()


## Pong

In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import gymnasium as gym
import numpy as np
from stable_baselines3.common.atari_wrappers import (
    EpisodicLifeEnv,
    FireResetEnv,
)
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Dict, Any
import ale_py

# Make the ALE (Atari) environments shipped with ale_py known to Gymnasium.
gym.register_envs(ale_py)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every model and tensor created below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log in to Weights & Biases up front, before any run is initialized.
wandb.login()

class BreakoutPPONet(nn.Module):
    """
    Actor-critic neural network model for PPO on stacked Atari frames.

    The class keeps its original name, although it is used by the Pong cell.

    Attributes:
        conv1, conv2, conv3: Convolutional layers.
        fc: Fully connected layer for feature extraction.
        actor: Fully connected layer for policy output (one logit per action).
        critic: Fully connected layer for value estimation (one scalar per state).
    """
    def __init__(self, input_shape: Tuple[int, int, int], num_actions: int):
        """
        Initialize the PPO network.

        Args:
            input_shape (Tuple[int, int, int]): The shape of the input (channels, height, width).
            num_actions (int): The number of possible actions.
        """
        super().__init__()
        channels, height, width = input_shape
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        def conv2d_size_out(size: int, kernel_size: int, stride: int) -> int:
            # Output length of an unpadded convolution along one axis.
            return (size - (kernel_size - 1) - 1) // stride + 1

        def conv_stack_size_out(size: int) -> int:
            # Apply the (kernel_size, stride) pairs of conv1, conv2 and conv3 in order.
            for kernel_size, stride in ((8, 4), (4, 2), (3, 1)):
                size = conv2d_size_out(size, kernel_size, stride)
            return size

        conv_h = conv_stack_size_out(height)
        conv_w = conv_stack_size_out(width)
        linear_input_size = conv_h * conv_w * 64

        self.fc = nn.Linear(linear_input_size, 512)
        self.actor = nn.Linear(512, num_actions)
        self.critic = nn.Linear(512, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): The input batch, shaped (batch, channels, height, width).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The logits for actions, shaped
                (batch, num_actions), and value estimates, shaped (batch, 1).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Collapse the channel and spatial axes, keeping the batch axis.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc(x))
        return self.actor(x), self.critic(x)

    def get_action_and_value(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample an action and return its log probability, the policy entropy, and the value.

        Args:
            state (torch.Tensor): The input batch of states.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The action, log probability, entropy, and value.
        """
        logits, value = self.forward(state)
        # Build the distribution straight from the logits. Going through
        # softmax first makes Categorical clamp tiny probabilities, which
        # distorts the log-probabilities of unlikely actions.
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value


class PPOTrainer:
    """
    Trainer class for PPO on an Atari environment (used by the Pong cell).

    Every training iteration collects a fixed-length rollout, converts it into
    GAE advantages and returns, and then runs several epochs of clipped-surrogate
    minibatch updates. Metrics are logged to Weights & Biases.

    Attributes:
        env: The game environment.
        params: Training parameters.
        agent: The PPO network.
        optimizer: Optimizer for training the network.
        gamma, lam: Discount factors for rewards and GAE.
        clip_epsilon: Clipping epsilon for PPO.
        entropy_coef, value_coef: Coefficients for entropy and value loss.
        max_grad_norm: Maximum gradient norm for clipping.
        frames, episode_returns, episode_losses, episode_lengths: Tracking data.
        moving_avg_window: Window size for moving average calculation.
        last_value: Critic estimate for the state after the most recent rollout.
    """
    def __init__(self, env: gym.Env, params: Dict[str, Any]):
        """
        Initialize the PPO trainer.

        Args:
            env (gym.Env): The game environment.
            params (Dict[str, Any]): Training parameters. The keys "learning_rate",
                "gamma", "gae_lambda", "clip_epsilon", "entropy_coef", "value_coef"
                and "max_grad_norm" are read here; "num_episodes" and
                "update_epochs" are read by train() and update_policy().
        """
        # Only start a run when none is active, so a caller that already
        # called wandb.init() does not trigger a second initialization. The
        # fallback project matches the one this cell's entry point uses.
        if wandb.run is None:
            wandb.init(project="pong-ppo", config=params)
        self.env = env
        self.params = params
        self.num_actions = env.action_space.n
        # The network input shape is fixed to four stacked 84x84 frames.
        self.agent = BreakoutPPONet((4, 84, 84), self.num_actions).to(device)
        self.optimizer = torch.optim.Adam(self.agent.parameters(), lr=params["learning_rate"])
        self.gamma = params["gamma"]
        self.lam = params["gae_lambda"]
        self.clip_epsilon = params["clip_epsilon"]
        self.entropy_coef = params["entropy_coef"]
        self.value_coef = params["value_coef"]
        self.max_grad_norm = params["max_grad_norm"]
        self.frames = []
        self.episode_returns = []
        self.episode_losses = []
        self.episode_lengths = []
        self.moving_avg_window = 15
        # Set by collect_rollout() and used by train() to bootstrap GAE.
        self.last_value = 0.0

    def preprocess_obs(self, obs: np.ndarray) -> torch.Tensor:
        """
        Preprocess the observation for the model.

        Args:
            obs (np.ndarray): The raw observation.

        Returns:
            torch.Tensor: Float32 tensor on ``device`` with a leading batch axis
                of size 1 and values divided by 255.
        """
        frames = torch.as_tensor(np.asarray(obs), dtype=torch.float32, device=device)
        return frames.unsqueeze(0) / 255.0

    def calculate_moving_average(self, data: List[float]) -> float:
        """
        Calculate the mean of the most recent ``moving_avg_window`` entries.

        Only the tail of ``data`` is touched, so the cost does not grow with the
        length of the history. When fewer than ``moving_avg_window`` entries
        exist, the mean of the available entries is returned.

        Args:
            data (List[float]): The data series.

        Returns:
            float: The moving average, or NaN when ``data`` is empty.
        """
        recent = data[-self.moving_avg_window:]
        if len(recent) == 0:
            return float("nan")
        return float(np.mean(recent))

    def collect_rollout(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float], List[torch.Tensor], List[torch.Tensor], List[bool], List[torch.Tensor]]:
        """
        Collect a rollout of 2048 environment steps.

        Side effects: appends to ``episode_returns`` and ``episode_lengths`` for
        every episode that ends (or is cut off) during the rollout, and stores
        the critic's estimate for the state after the final step in
        ``self.last_value`` (0.0 when the final step ended an episode).

        Returns:
            Tuple containing states, actions, rewards, values, log_probs, dones, and entropies.
        """
        states, actions, rewards, values, log_probs, dones, entropies = [], [], [], [], [], [], []
        max_steps = 2048
        obs, _ = self.env.reset()
        obs = self.preprocess_obs(obs)
        episode_reward = 0.0
        episode_length = 0

        for _ in range(max_steps):
            # Sampling happens under no_grad, so the stored tensors carry no graph.
            with torch.no_grad():
                action, log_prob, entropy, value = self.agent.get_action_and_value(obs)
            next_obs, reward, done, truncated, _ = self.env.step(action.item())
            episode_over = bool(done or truncated)

            states.append(obs)
            actions.append(action)
            rewards.append(float(reward))
            values.append(value)
            log_probs.append(log_prob)
            entropies.append(entropy)
            dones.append(episode_over)
            episode_reward += reward
            episode_length += 1

            if episode_over:
                self.episode_returns.append(episode_reward)
                # Record the length of this episode, not the running step count.
                self.episode_lengths.append(episode_length)
                episode_reward = 0.0
                episode_length = 0
                next_obs, _ = self.env.reset()
            obs = self.preprocess_obs(next_obs)

        # Value of the state that follows the last stored transition.
        if dones[-1]:
            self.last_value = 0.0
        else:
            with torch.no_grad():
                _, bootstrap_value = self.agent(obs)
            self.last_value = bootstrap_value.item()

        # The next rollout begins with env.reset(), so the running tally cannot
        # carry over; record what was accumulated. Nothing is recorded when the
        # rollout stopped exactly on an episode boundary, which would otherwise
        # add a spurious zero-return episode.
        if episode_length > 0:
            self.episode_returns.append(episode_reward)
            self.episode_lengths.append(episode_length)
        return states, actions, rewards, values, log_probs, dones, entropies

    def compute_gae(self, rewards: List[float], values: List[torch.Tensor], dones: List[bool], last_value: "float | None" = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the Generalized Advantage Estimate (GAE).

        Args:
            rewards (List[float]): List of rewards.
            values (List[torch.Tensor]): List of value estimates, one single-element tensor per step.
            dones (List[bool]): List of done flags.
            last_value (float | None): Value estimate of the state after the final
                step, used to bootstrap when the rollout stops mid-episode. When
                omitted, the final entry of ``values`` is used instead.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Advantages and returns, as float32
                tensors on ``device``.
        """
        num_steps = len(rewards)
        if num_steps == 0:
            empty = torch.zeros(0, dtype=torch.float32, device=device)
            return empty, empty.clone()

        reward_arr = np.asarray(rewards, dtype=np.float64)
        value_arr = np.array([v.item() for v in values], dtype=np.float64)
        # 1.0 where the episode continues after step t, 0.0 where it ended.
        not_done = 1.0 - np.asarray(dones, dtype=np.float64)

        next_value = value_arr[-1] if last_value is None else float(last_value)
        advantages = np.zeros(num_steps, dtype=np.float64)
        last_gae = 0.0
        # Walk backwards and fill a preallocated array; this avoids the
        # quadratic cost of inserting at the front of a list.
        for t in reversed(range(num_steps)):
            delta = reward_arr[t] + self.gamma * next_value * not_done[t] - value_arr[t]
            last_gae = delta + self.gamma * self.lam * not_done[t] * last_gae
            advantages[t] = last_gae
            next_value = value_arr[t]
        returns = advantages + value_arr
        return (torch.as_tensor(advantages, dtype=torch.float32, device=device),
                torch.as_tensor(returns, dtype=torch.float32, device=device))

    def update_policy(self, states: List[torch.Tensor], actions: List[torch.Tensor], returns: torch.Tensor, advantages: torch.Tensor, old_log_probs: List[torch.Tensor]):
        """
        Update the policy using PPO.

        Runs ``params["update_epochs"]`` passes over the rollout in shuffled
        minibatches of 64 and appends every minibatch loss to ``episode_losses``.

        Args:
            states (List[torch.Tensor]): Collected states.
            actions (List[torch.Tensor]): Actions taken.
            returns (torch.Tensor): Returns computed from GAE.
            advantages (torch.Tensor): Computed advantages.
            old_log_probs (List[torch.Tensor]): Old log probabilities of actions.
        """
        states = torch.cat(states).to(device)
        actions = torch.cat(actions).to(device)
        old_log_probs = torch.cat(old_log_probs).to(device)
        batch_size = 64
        num_samples = states.shape[0]
        indices = np.arange(num_samples)
        # Normalize once over the whole rollout; the epsilon guards against a zero std.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.params["update_epochs"]):
            np.random.shuffle(indices)
            for start in range(0, num_samples, batch_size):
                batch_indices = indices[start:start + batch_size]
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]

                logits, values = self.agent(batch_states)
                # Same construction as get_action_and_value, so the old and new
                # log-probabilities are computed the same way.
                dist = Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(values.squeeze(-1), batch_returns)
                entropy = dist.entropy().mean()
                loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.episode_losses.append(loss.item())

    def save_model(self):
        """
        Save the trained model to disk and log it to W&B.
        """
        artifact = wandb.Artifact('ppo_model', type='model')
        # NOTE(review): same file name as the Breakout cell uses; running both
        # in one directory overwrites the other checkpoint. Confirm whether a
        # Pong-specific name is wanted before changing it.
        path = 'ppo_breakout_model.pth'
        torch.save({
            'model_state_dict': self.agent.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'episode_returns': self.episode_returns,
            'episode_losses': self.episode_losses,
            'episode_lengths': self.episode_lengths,
            # Store a plain dict so the checkpoint does not embed the params
            # object itself (a wandb config when launched from the script).
            'params': dict(self.params)
        }, path)
        artifact.add_file(path)
        wandb.log_artifact(artifact)

    def train(self):
        """
        Main training loop for PPO.

        Note that one loop iteration is one rollout plus one policy update, not
        one game episode, even though the parameter is named "num_episodes".
        """
        for episode in range(self.params["num_episodes"]):
            states, actions, rewards, values, log_probs, dones, entropies = self.collect_rollout()
            advantages, returns = self.compute_gae(rewards, values, dones, last_value=self.last_value)
            self.update_policy(states, actions, returns, advantages, log_probs)
            episode_reward = self.episode_returns[-1]
            reward_moving_avg = self.calculate_moving_average(self.episode_returns)
            loss_moving_avg = self.calculate_moving_average(self.episode_losses)
            entropy_avg = np.mean([entropy.item() for entropy in entropies])
            wandb.log({
                "Reward": {"Raw": episode_reward, "Moving Avg": reward_moving_avg},
                "Loss": {"Raw": self.episode_losses[-1], "Moving Avg": loss_moving_avg},
                "Entropy": entropy_avg,
            })
            print(f"Episode {episode + 1}: Total Reward = {episode_reward}, Moving Avg Reward = {reward_moving_avg:.2f}, Entropy = {entropy_avg:.4f}")
        self.save_model()
        wandb.finish()


def make_env(env_id: str = "Pong-v4") -> gym.Env:
    """
    Create and wrap the Atari environment for the Pong experiment.

    This function previously created "Breakout-v4", a leftover from the
    Breakout cell, although this cell is titled Pong and logs to the
    "pong-ppo" project. The default is now "Pong-v4"; pass another id to
    reuse the same wrapper stack for a different game.

    Wrappers are applied in this order: episodic-life handling, an automatic
    FIRE on reset (only when the game exposes a FIRE action), resizing to
    84x84, grayscale conversion, and stacking of 4 frames.

    Args:
        env_id (str): Gymnasium id of the environment to create.

    Returns:
        gym.Env: The wrapped environment.
    """
    env = gym.make(env_id, render_mode="rgb_array")
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayscaleObservation(env)
    env = gym.wrappers.FrameStackObservation(env, 4)
    return env


if __name__ == "__main__":
    # Default hyperparameters; they become the run's W&B config.
    config_defaults = {
        "learning_rate": 3e-4,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_epsilon": 0.2,
        "entropy_coef": 0.01,
        "value_coef": 0.500,
        "max_grad_norm": 0.5,
        "num_episodes": 4000,
        "update_epochs": 10,
    }

    # Initialize a W&B run with the default configuration
    wandb.init(project="pong-ppo", config=config_defaults)

    # Read the configuration back from the run
    config = wandb.config

    # Initialize the environment and the trainer
    env = make_env()
    try:
        trainer = PPOTrainer(env, config)

        # Start training
        trainer.train()
    finally:
        # Always release the emulator, even when setup or training raises.
        env.close()
