In [None]:


import os
import random
import time
import gymnasium as gym
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import wandb
import cv2
import imageio
import ale_py

from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

# Hand the ale_py module to Gymnasium's registration hook before any
# environment is created (presumably what makes Atari ids such as
# Config.env_id resolvable by gym.make -- confirm against the ale-py docs).
gym.register_envs(ale_py)
# from vizdoom import gymnasium_wrapper # Ensure ViZDoom is registered

# ===== CONFIGURATION =====
class Config:
    """Hyper-parameters and run settings for the custom vectorised PPO script.

    Everything is a plain class attribute; the two rollout sizes at the bottom
    are derived on access from the attributes above them.
    """

    # ---- experiment ----
    exp_name = "PPO-Vectorized-Atari"
    seed = 42
    env_id = "PongNoFrameskip-v4"
    total_timesteps = 10_000_000  # budget counted in environment steps

    # ---- optimisation / rollout ----
    lr = 2.5e-4
    gamma = 0.99  # discount factor
    num_envs = 8  # parallel environments
    max_steps = 128  # rollout length per environment (a.k.a. num_steps)
    num_minibatches = 4
    PPO_EPOCHS = 4  # passes over each rollout
    clip_value = 0.1  # clip range applied to the policy ratio
    clip_coeff = 0.1  # clip range applied to the value prediction
    ENTROPY_COEFF = 0.01
    VALUE_COEFF = 0.5

    # ---- logging / saving ----
    capture_video = True
    use_wandb = True
    wandb_project = "cleanRL"

    # ---- advantage estimation and optimiser extras ----
    GAE = 0.95  # lambda used by Generalized Advantage Estimation
    anneal_lr = True  # linearly decay the learning rate
    max_grad_norm = 0.5  # gradient-norm clipping threshold

    # ---- derived sizes ----
    @property
    def batch_size(self):
        """Number of transitions collected per rollout."""
        per_env_steps = self.max_steps
        return self.num_envs * per_env_steps

    @property
    def minibatch_size(self):
        """Number of transitions per gradient step (floor division)."""
        rollout_size = self.batch_size
        return rollout_size // self.num_minibatches

# --- Preprocessing ---
# NOTE(review): within this cell TARGET_HEIGHT / TARGET_WIDTH are referenced
# only from commented-out code inside make_env; the active preprocessing there
# uses screen_size=84, not 64.
TARGET_HEIGHT = 64
TARGET_WIDTH = 64
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Networks ---
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return the same object.

    The weight tensor receives an orthogonal initialisation with gain ``std``
    and every bias entry is set to ``bias_const``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class Agent(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    The trunk takes 4-channel image stacks; the ``64 * 7 * 7`` size of the
    first linear layer corresponds to 84x84 inputs
    (84 -> 20 -> 9 -> 7 through the three convolutions).  Inputs are fed to
    the convolutions as-is, so they are assumed to be already scaled by the
    environment wrapper (make_env in this cell uses ``scale_obs=True``).

    Args:
        action_space: Number of discrete actions (size of the actor output).
    """

    def __init__(self, action_space):
        super().__init__()
        # Shared CNN feature extractor
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),  # 7x7 feature map for 84x84 input
            nn.ReLU(),
        )
        # Actor head (small init keeps the initial policy close to uniform)
        self.actor = layer_init(nn.Linear(512, action_space), std=0.01)
        # Critic head
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def get_features(self, x):
        """Return the 512-dimensional trunk features for a batch of observations."""
        return self.network(x)

    def get_value(self, x):
        """Return the critic's state-value estimate, shape ``(batch, 1)``."""
        return self.critic(self.get_features(x))

    def _policy_logits(self, x):
        """Return the unnormalised action logits for a batch of observations."""
        return self.actor(self.get_features(x))

    def get_action(self, x, action=None, deterministic=False):
        """Pick (or score) an action for each observation in the batch.

        Args:
            x: Batch of observations.
            action: Optional pre-selected actions; when given (and
                ``deterministic`` is False) they are scored instead of
                sampling new ones.
            deterministic: If True, always take the highest-logit action,
                overriding ``action``.

        Returns:
            Tuple ``(action, log_prob, entropy)``.
        """
        logits = self._policy_logits(x)
        # Build the distribution from logits, exactly like evaluate_get_action,
        # so the log-probabilities stored during the rollout are computed the
        # same way as the ones recomputed in the PPO update.  Going through
        # softmax + ``probs=`` clamps very small probabilities and can make
        # the two disagree.
        dist = torch.distributions.Categorical(logits=logits)
        if deterministic:
            action = torch.argmax(logits, dim=-1)
        elif action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return action, log_prob, entropy

    def evaluate_get_action(self, x, action):
        """Score ``action`` under the current policy.

        Returns:
            Tuple ``(log_prob, entropy)`` for the given observation/action batch.
        """
        dist = torch.distributions.Categorical(logits=self._policy_logits(x))
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return log_prob, entropy

# --- Environment Creation ---
def make_env(env_id, seed, idx, run_name, eval_mode=False):
    """Return a thunk that builds one preprocessed Atari environment.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Base seed; the action/observation spaces are seeded with
            ``seed + idx``.
        idx: Index of this environment inside the vectorised set.
        run_name: Currently unused; kept for interface compatibility.
        eval_mode: When True the env is created with ``render_mode="rgb_array"``
            and episodes are NOT cut at a lost life, so evaluation returns
            cover whole games.

    Returns:
        A zero-argument callable producing the wrapped environment.
    """
    def thunk():
        render_mode = "rgb_array" if eval_mode else None
        env = gym.make(env_id, render_mode=render_mode)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # All-in-one Atari preprocessing: no-op resets, frame skipping,
        # 84x84 resize and grayscale conversion.
        env = gym.wrappers.AtariPreprocessing(
            env,
            noop_max=30,
            frame_skip=4,
            screen_size=84,  # square frames
            # Treating a lost life as terminal is a training-time trick only;
            # during evaluation it would truncate the reported return at the
            # first lost life.
            terminal_on_life_loss=not eval_mode,
            grayscale_obs=True,
            # Observations are rescaled by the wrapper (floats, not uint8);
            # Agent feeds them to the network without further normalisation.
            scale_obs=True,
        )

        # Stack the last 4 preprocessed frames along the leading axis.
        env = gym.wrappers.FrameStackObservation(env, 4)
        env.action_space.seed(seed + idx)
        env.observation_space.seed(seed + idx)
        return env
    return thunk

# --- Evaluation ---
def evaluate(agent_model, device, run_name, num_eval_eps=10, record=False):
    """Run greedy evaluation episodes and return their undiscounted returns.

    Args:
        agent_model: Policy exposing ``get_action(obs, deterministic=True)``.
        device: Torch device the model and observations are moved to.
        run_name: Forwarded to ``make_env``.
        num_eval_eps: Number of episodes to play.
        record: When True, one rendered frame is collected per step
            (across all episodes) and returned for video export.

    Returns:
        Tuple ``(returns, frames)``: a list with one float per episode and a
        list of rendered frames (empty unless ``record`` is True).
    """
    eval_env = make_env(env_id=Config.env_id, seed=Config.seed, idx=0, run_name=run_name, eval_mode=True)()

    agent_model.to(device)
    agent_model.eval()
    returns = []
    frames = []

    # try/finally guarantees the environment is released and the model is put
    # back into training mode even if an episode raises.
    try:
        for _ in tqdm(range(num_eval_eps), desc="Evaluating"):
            obs, _ = eval_env.reset()
            done = False
            episode_reward = 0.0

            while not done:
                if record:
                    # Render from the unwrapped env to get the full-size frame
                    # rather than the preprocessed observation.
                    frames.append(eval_env.unwrapped.render())

                with torch.no_grad():
                    # Add a batch dimension and convert to a tensor.
                    obs_tensor = torch.tensor(obs, device=device, dtype=torch.float32).unsqueeze(0)
                    action, _, _ = agent_model.get_action(obs_tensor, deterministic=True)

                # The single (non-vector) env expects a plain Python int.
                action_scalar = action.cpu().numpy().item()
                obs, reward, terminated, truncated, info = eval_env.step(action_scalar)
                done = terminated or truncated
                episode_reward += float(reward)

            returns.append(episode_reward)
    finally:
        eval_env.close()
        agent_model.train()

    return returns, frames



# --- Main Execution ---
if __name__ == "__main__":
    # PPO training entry point: collect fixed-length rollouts from a set of
    # synchronous environments, compute GAE advantages, run several epochs of
    # clipped policy/value updates, and periodically evaluate the policy.
    args = Config()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

    if args.use_wandb:
        # Every setting is a *class* attribute of Config, so vars(args) would
        # be an empty dict.  Collect the public class attributes instead and
        # add the derived sizes explicitly.
        run_config = {
            key: value
            for key, value in vars(Config).items()
            if not key.startswith("_") and not isinstance(value, property)
        }
        run_config["batch_size"] = args.batch_size
        run_config["minibatch_size"] = args.minibatch_size
        wandb.init(
            project=args.wandb_project,
            sync_tensorboard=True,
            config=run_config,
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    # writer = SummaryWriter(f"runs/{run_name}")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed, i, run_name) for i in range(args.num_envs)]
    )

    actor_network = Agent(envs.single_action_space.n).to(device)
    optimizer = optim.Adam(actor_network.parameters(), lr=args.lr, eps=1e-5)

    # Rollout buffers, indexed as [step, env, ...].
    obs_storage = torch.zeros((args.max_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions_storage = torch.zeros((args.max_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs_storage = torch.zeros((args.max_steps, args.num_envs)).to(device)
    rewards_storage = torch.zeros((args.max_steps, args.num_envs)).to(device)
    dones_storage = torch.zeros((args.max_steps, args.num_envs)).to(device)
    values_storage = torch.zeros((args.max_steps, args.num_envs)).to(device)

    global_step = 0
    start_time = time.time()
    num_updates = args.total_timesteps // args.batch_size

    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for update in tqdm(range(1, num_updates + 1), desc="Training Updates"):

        if args.anneal_lr:
            # Linear decay that starts at the full rate for the first update
            # and never reaches exactly zero, so the last update still learns.
            frac = 1.0 - (update - 1.0) / num_updates
            lr = args.lr * frac
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # === Rollout collection ===
        for step in range(0, args.max_steps):
            global_step += args.num_envs
            obs_storage[step] = next_obs
            dones_storage[step] = next_done

            with torch.no_grad():
                action, logprob, _ = actor_network.get_action(next_obs)
                value = actor_network.get_value(next_obs)

            values_storage[step] = value.flatten()
            actions_storage[step] = action
            logprobs_storage[step] = logprob

            next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
            done = np.logical_or(terminated, truncated)

            rewards_storage[step] = torch.tensor(reward).to(device).view(-1)
            next_obs = torch.Tensor(next_obs).to(device)
            next_done = torch.Tensor(done).to(device)

            if args.use_wandb:
                if "episode" in info:
                    # Layout used by Gymnasium >= 1.0 vector envs: per-env
                    # arrays under "episode" plus a boolean "_episode" mask
                    # marking which envs just finished -- TODO confirm against
                    # the installed Gymnasium version.
                    finished = info.get("_episode", np.ones(args.num_envs, dtype=bool))
                    for env_idx in np.flatnonzero(finished):
                        wandb.log({
                            "charts/episodic_return": float(info["episode"]["r"][env_idx]),
                            "charts/episodic_length": int(info["episode"]["l"][env_idx]),
                            "global_step": global_step,
                        })
                elif "final_info" in info:
                    # Legacy layout: a sequence with one info dict (or None) per env.
                    for item in info["final_info"]:
                        if isinstance(item, dict) and "episode" in item:
                            wandb.log({"charts/episodic_return": item['episode']['r'], "global_step": global_step})
                            wandb.log({"charts/episodic_length": item['episode']['l'], "global_step": global_step})

        # === Advantage Calculation & Returns (GAE) ===
        with torch.no_grad():
            advantages = torch.zeros_like(rewards_storage).to(device)

            # Bootstrap from the value of the observation that follows the
            # last stored step.
            bootstrap_value = actor_network.get_value(next_obs).squeeze()
            lastgae = 0.0

            for t in reversed(range(args.max_steps)):

                if t == args.max_steps - 1:
                    nextnonterminal = (1.0 - next_done)
                    gt_next_state = bootstrap_value * nextnonterminal
                else:
                    nextnonterminal = (1.0 - dones_storage[t + 1])
                    # If the episode ended after step t, the next value is zeroed.
                    gt_next_state = values_storage[t + 1] * nextnonterminal

                delta = (rewards_storage[t] + args.gamma * gt_next_state) - values_storage[t]

                advantages[t] = lastgae = delta + args.GAE * lastgae * nextnonterminal * args.gamma

        # Value targets are the (un-normalised) advantages plus the old values.
        returns = advantages + values_storage

        # === PPO Update Phase ===
        b_obs = obs_storage.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs_storage.reshape(-1)
        b_actions = actions_storage.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values_storage.reshape(-1)

        b_inds = np.arange(args.batch_size)
        for epoch in range(args.PPO_EPOCHS):
            np.random.shuffle(b_inds)

            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                new_log_probs, entropy = actor_network.evaluate_get_action(b_obs[mb_inds], b_actions[mb_inds])
                logratio = new_log_probs - b_logprobs[mb_inds]
                ratio = torch.exp(logratio)
                with torch.no_grad():
                    approx_kl = ((ratio - 1) - logratio).mean()
                if args.use_wandb:
                    wandb.log({"charts/approx_kl": approx_kl.item()})

                # Normalise per minibatch into a local tensor.  Writing the
                # result back into b_advantages would make later epochs
                # re-normalise already-normalised values.
                mb_advantages = b_advantages[mb_inds]
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                pg_loss1 = mb_advantages * ratio
                pg_loss2 = mb_advantages * torch.clamp(ratio, 1 - args.clip_value, 1 + args.clip_value)
                policy_loss = -torch.min(pg_loss1, pg_loss2).mean()

                current_values = actor_network.get_value(b_obs[mb_inds]).squeeze()

                # Value clipping
                v_loss_unclipped = (current_values - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    current_values - b_values[mb_inds], -args.clip_coeff, args.clip_coeff
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                critic_loss = args.VALUE_COEFF * 0.5 * v_loss_max.mean()

                entropy_loss = entropy.mean()
                loss = policy_loss - args.ENTROPY_COEFF * entropy_loss + critic_loss

                optimizer.zero_grad()
                loss.backward()

                if args.use_wandb:
                    # Per-parameter and total gradient norms, measured before clipping.
                    grad_norm_dict = {}
                    total_norm = 0
                    for name, param in actor_network.named_parameters():
                        if param.grad is not None:
                            param_norm = param.grad.data.norm(2)
                            if 'actor' in name or 'critic' in name:
                                grad_norm_dict[f"gradients/norm_{name}"] = param_norm.item()
                            else:
                                grad_norm_dict[f"gradients/shared_norm_{name}"] = param_norm.item()
                            total_norm += param_norm.item() ** 2
                    grad_norm_dict["gradients/total_norm"] = total_norm ** 0.5
                    wandb.log(grad_norm_dict)

                nn.utils.clip_grad_norm_(actor_network.parameters(), args.max_grad_norm)
                optimizer.step()

        if args.use_wandb:
            wandb.log({
                "losses/total_loss": loss.item(),
                "losses/policy_loss": policy_loss.item(),
                "losses/value_loss": critic_loss.item(),
                "losses/entropy": entropy_loss.item(),
                "charts/learning_rate": optimizer.param_groups[0]['lr'],
                # Mean per-step reward of the rollout.  Logged under its own
                # key so it no longer overwrites the true episodic return.
                "charts/mean_step_reward": np.mean(rewards_storage.cpu().numpy()),
                "charts/advantages_mean": b_advantages.mean().item(),
                "charts/advantages_std": b_advantages.std().item(),
                "charts/returns_mean": b_returns.mean().item(),
                "global_step": global_step,
            })
        print(f"Update {update}, Global Step: {global_step}, Policy Loss: {policy_loss.item():.4f}, Value Loss: {critic_loss.item():.4f}")

        if update % 200 == 0:
            # Frames from periodic evaluations are never used, so do not
            # render or store them here.
            episodic_returns, _ = evaluate(actor_network, device, run_name, num_eval_eps=5, record=False)
            avg_return = np.mean(episodic_returns)

            if args.use_wandb:
                wandb.log({
                    "eval/avg_return": avg_return,
                    "global_step": global_step,
                })
            print(f"Evaluation at step {global_step}: Average raw return = {avg_return:.2f}")

    if args.capture_video:
        print("Capturing final evaluation video...")
        episodic_returns, eval_frames = evaluate(actor_network, device, run_name, num_eval_eps=10, record=True)

        if len(eval_frames) > 0:
            video_path = f"videos/final_eval_{run_name}.mp4"
            os.makedirs(os.path.dirname(video_path), exist_ok=True)
            imageio.mimsave(video_path, eval_frames, fps=30, codec='libx264')
            if args.use_wandb:
                wandb.log({"eval/final_video": wandb.Video(video_path, fps=30, format="mp4")})
                print(f"Final evaluation video saved and uploaded to WandB.")

    envs.close()
    if args.use_wandb:
        wandb.finish()

In [None]:
from kaggle_secrets import UserSecretsClient
# Read the secret stored under the name "wandb"; a later cell passes
# secret_value_0 to wandb.login as the API key.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

In [None]:
import wandb
# Authenticate with Weights & Biases.  secret_value_0 is not defined in this
# cell, so the cell that reads it from Kaggle secrets must have run first.
wandb.login(key=secret_value_0)

In [None]:
!pip install  SuperSuit

In [None]:
!pip install pettingzoo[atari]

In [None]:
!pip uninstall stable_baselines3 -y

In [None]:
pip install stable_baselines3==2.6.0

In [None]:
import os
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import wandb
import imageio
import ale_py

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from wandb.integration.sb3 import WandbCallback

gym.register_envs(ale_py)

# ===== CONFIGURATION (Mirrors the custom script) =====
class Config:
    """Settings for the Stable-Baselines3 PPO benchmark run.

    The values mirror the custom PPO script so both runs are comparable; the
    trailing comments name the matching setting in that script.
    """

    # Experiment settings
    exp_name = "PPO-SB3-Atari-Benchmark"
    seed = 42
    env_id = "BoxingNoFrameskip-v4"
    total_timesteps = 10_000_000

    # PPO & Agent settings
    lr = 2.5e-4
    gamma = 0.99
    num_envs = 8  # Number of parallel environments
    n_steps = 128  # Steps per rollout per environment (max_steps in custom)
    num_minibatches = 4
    n_epochs = 4   # PPO_EPOCHS in custom
    clip_range = 0.1  # clip_value in custom
    ent_coef = 0.01  # ENTROPY_COEFF
    vf_coef = 0.5    # VALUE_COEFF
    gae_lambda = 0.95  # GAE
    max_grad_norm = 0.5

    # Logging & Saving
    capture_video = True
    use_wandb = True
    wandb_project = "cleanRL"

    # Evaluation
    eval_freq_updates = 200  # Evaluate every 200 updates
    num_eval_episodes = 10

    # Derived values
    @property
    def batch_size(self):
        """Minibatch size handed to SB3 (SB3 calls this ``batch_size``)."""
        return (self.num_envs * self.n_steps) // self.num_minibatches

    @property
    def eval_freq_steps(self):
        """Total environment steps between two evaluations.

        One update consumes ``n_steps`` steps from each of the ``num_envs``
        environments, so the step count must include ``num_envs``.  Without
        that factor, dividing by ``num_envs`` later (to convert steps into
        callback calls) makes evaluation run every 25 updates instead of
        every ``eval_freq_updates``.
        """
        return self.eval_freq_updates * self.n_steps * self.num_envs

# --- Environment Creation (Mirrors the custom script) ---
def make_env(env_id, seed, idx):
    """Return a thunk that builds one preprocessed Atari environment.

    A callable (not an environment) is returned because both call sites in
    this script need one: ``DummyVecEnv`` is given a list of these results,
    and the evaluation env is created with ``make_env(...)()``.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Base seed; the spaces are seeded with ``seed + idx``.
        idx: Index of this environment inside the vectorised set.

    Returns:
        A zero-argument callable producing the wrapped environment.
    """
    def thunk():
        env = gym.make(env_id)
        # All-in-one Atari wrapper from Gymnasium, configured here for no-op
        # resets, frame skipping, resizing, grayscaling and life-loss
        # terminals.  No reward clipping is applied in this function.
        env = gym.wrappers.AtariPreprocessing(
            env,
            noop_max=30,
            frame_skip=4,
            screen_size=84,
            terminal_on_life_loss=True,
            grayscale_obs=True,
            # Observations are rescaled by the wrapper itself.
            # NOTE(review): confirm SB3 does not divide by 255 a second time
            # for this observation space.
            scale_obs=True,
        )
        # Stack the preprocessed frames
        env = gym.wrappers.FrameStackObservation(env, 4)
        env.action_space.seed(seed + idx)
        env.observation_space.seed(seed + idx)
        return env
    return thunk

# --- Custom Network Architecture (Mirrors the custom script) ---
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Apply orthogonal weights (gain ``std``) and a constant bias, in place.

    Returns the very same ``layer`` object so the call can be nested inside a
    module definition.
    """
    weight_tensor, bias_tensor = layer.weight, layer.bias
    torch.nn.init.orthogonal_(weight_tensor, gain=std)
    torch.nn.init.constant_(bias_tensor, val=bias_const)
    return layer

class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for the SB3 policy.

    Three convolutions followed by one linear layer, using the same layer
    sizes and ``layer_init`` initialisation as the hand-written PPO agent.

    Args:
        observation_space: Channel-first image space; its first dimension is
            taken as the number of input channels.
        features_dim: Width of the feature vector returned by ``forward``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        conv_layers = [
            layer_init(nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Push one sampled observation through the conv stack to learn how
        # wide the flattened output is.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(probe).shape[1]

        projection = layer_init(nn.Linear(n_flatten, features_dim))
        self.linear = nn.Sequential(projection, nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``features_dim``-wide features.

        No scaling happens here.
        NOTE(review): the environments in this script are built with
        ``scale_obs=True``; confirm whether SB3 additionally rescales inputs.
        """
        conv_out = self.cnn(observations)
        return self.linear(conv_out)

# ===== SCRIPT START =====
if __name__ == '__main__':
    # SB3 PPO benchmark entry point: build the vectorised env, create the PPO
    # model with the custom CNN, attach evaluation / W&B callbacks, and train.
    # --- Setup ---
    args = Config()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

    if args.use_wandb:
        # Every setting is a *class* attribute of Config, so vars(args) would
        # be an empty dict.  Collect the public class attributes and add the
        # derived values explicitly.
        run_config = {
            key: value
            for key, value in vars(Config).items()
            if not key.startswith("_") and not isinstance(value, property)
        }
        run_config["batch_size"] = args.batch_size
        run_config["eval_freq_steps"] = args.eval_freq_steps
        run = wandb.init(
            project=args.wandb_project,
            sync_tensorboard=True,
            config=run_config,
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    # --- Set seeds for reproducibility ---
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # --- Create Vectorized Environment ---
    env = DummyVecEnv([make_env(args.env_id, args.seed, i) for i in range(args.num_envs)])

    # --- Define Hyperparameters and Policy Architecture ---
    policy_kwargs = {
        "features_extractor_class": CustomCNN,
        "features_extractor_kwargs": {"features_dim": 512},
        "net_arch": [],  # No extra hidden layers between extractor and heads
        "activation_fn": nn.ReLU,
    }

    # --- Create PPO Model ---
    model = PPO(
        "CnnPolicy",
        env,
        policy_kwargs=policy_kwargs,
        learning_rate=args.lr,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        ent_coef=args.ent_coef,
        vf_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        seed=args.seed,
        tensorboard_log=f"runs/{run.id}" if args.use_wandb else None,
        verbose=1,
    )

    # --- Setup Callbacks ---
    callbacks = []
    # 1. Evaluation Callback
    # Create a separate, non-vectorized env for evaluation
    eval_env = make_env(args.env_id, args.seed, 0)()
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=f'models/{run.id}/' if args.use_wandb else None,
        log_path=f'models/{run.id}/' if args.use_wandb else None,
        # The callback is invoked once per vectorised step, i.e. once every
        # num_envs environment steps, hence the division.
        eval_freq=max(args.eval_freq_steps // args.num_envs, 1),
        n_eval_episodes=args.num_eval_episodes,
        deterministic=True,
        render=False,
    )
    callbacks.append(eval_callback)

    # 2. W&B Callback
    if args.use_wandb:
        wandb_callback = WandbCallback(
            gradient_save_freq=100_000,  # Log gradients periodically
            model_save_path=f"models/{run.id}",
            verbose=2,
        )
        callbacks.append(wandb_callback)

    # --- Train ---
    print("Policy Architecture:")
    print(model.policy)
    print("\nStarting training...")
    try:
        model.learn(
            total_timesteps=args.total_timesteps,
            callback=callbacks,
            progress_bar=True,
        )
    finally:
        # --- Cleanup --- (also runs when training is interrupted or fails)
        env.close()
        eval_env.close()
        if args.use_wandb:
            run.finish()

In [None]:


import os
import random
import time
import gymnasium as gym
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import wandb
import cv2
import imageio
# import ale_py
from pettingzoo.atari import pong_v3
import importlib
import supersuit as ss
from pettingzoo.atari import pong_v3

# ===== CONFIGURATION =====
class Config:
    """Hyper-parameters and run settings for the PettingZoo Pong PPO script.

    Everything is a plain class attribute; the two rollout sizes at the bottom
    are derived on access from the attributes above them.
    """

    # ---- experiment ----
    exp_name = "PPO-PettingZoo-Pong"
    seed = 42
    env_id = "pong_v3"
    total_timesteps = 5_000_000  # budget counted in environment steps

    # ---- optimisation / rollout ----
    lr = 2.5e-4
    gamma = 0.99  # discount factor
    num_envs = 8  # parallel environments
    max_steps = 128  # rollout length per environment (a.k.a. num_steps)
    num_minibatches = 4
    PPO_EPOCHS = 4  # passes over each rollout
    clip_value = 0.1
    clip_coeff = 0.1  # value clipping coefficient
    ENTROPY_COEFF = 0.01
    VALUE_COEFF = 0.5

    # ---- logging / saving ----
    capture_video = False
    use_wandb = True
    wandb_project = "cleanRL"

    # ---- advantage estimation and optimiser extras ----
    GAE = 0.95  # Generalized Advantage Estimation coefficient
    anneal_lr = True  # linearly decay the learning rate
    max_grad_norm = 0.5  # gradient clipping value

    # ---- derived sizes ----
    @property
    def batch_size(self):
        """Number of transitions collected per rollout."""
        per_env_steps = self.max_steps
        return self.num_envs * per_env_steps

    @property
    def minibatch_size(self):
        """Number of transitions per gradient step (floor division)."""
        rollout_size = self.batch_size
        return rollout_size // self.num_minibatches

# --- Preprocessing ---
# NOTE(review): TARGET_HEIGHT / TARGET_WIDTH are not referenced in the visible
# part of this cell; make_env and evaluate resize frames to 84x84 -- confirm
# whether these 64x64 constants are still needed.
TARGET_HEIGHT = 64
TARGET_WIDTH = 64
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Networks ---
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` and fill ``layer.bias``.

    ``std`` is used as the gain of the orthogonal init and ``bias_const`` as
    the constant bias value.  The layer is modified in place and returned.
    """
    initialisers = torch.nn.init
    initialisers.orthogonal_(layer.weight, gain=std)
    initialisers.constant_(layer.bias, val=bias_const)
    return layer

class Agent(nn.Module):
    """Actor-critic network for channel-last, 6-channel observations.

    Public methods take observations laid out as ``(batch, H, W, 6)``; they
    are copied, permuted to channel-first, and the first four channels are
    divided by 255 before entering the shared CNN.  The ``64 * 7 * 7`` size of
    the first linear layer corresponds to 84x84 frames
    (84 -> 20 -> 9 -> 7 through the three convolutions).
    Presumably channels 0-3 are the stacked frames and channels 4-5 the
    agent-indicator planes added by the env wrappers -- confirm against
    make_env.

    Args:
        action_space: Number of discrete actions (size of the actor output).
    """

    def __init__(self, action_space):
        super().__init__()
        # Shared CNN feature extractor
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(6, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),  # 7x7 feature map for 84x84 input
            nn.ReLU(),
        )
        # Actor head
        self.actor = layer_init(nn.Linear(512, action_space), std=0.01)
        # Critic head
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def _prepare_obs(self, x):
        """Return a channel-first copy of ``x`` with channels 0-3 scaled by 1/255.

        The caller's tensor is left untouched because the in-place division
        operates on a clone.
        """
        x = x.clone()
        x = x.permute(0, 3, 1, 2)
        x[:, :4, :, :] /= 255.0
        return x

    def get_features(self, x):
        """Run the shared CNN on an already channel-first, preprocessed batch."""
        return self.network(x)

    def get_value(self, x):
        """Return the critic's state-value estimate, shape ``(batch, 1)``."""
        return self.critic(self.get_features(self._prepare_obs(x)))

    def get_action(self, x, action=None, deterministic=False):
        """Pick (or score) an action for each observation in the batch.

        Args:
            x: Channel-last observation batch.
            action: Optional pre-selected actions; when given (and
                ``deterministic`` is False) they are scored instead of
                sampling new ones.
            deterministic: If True, always take the highest-logit action,
                overriding ``action``.

        Returns:
            Tuple ``(action, log_prob, entropy)``.
        """
        logits = self.actor(self.get_features(self._prepare_obs(x)))
        # Build the distribution from logits, exactly like evaluate_get_action,
        # so rollout log-probabilities match the ones recomputed during the
        # PPO update.  Going through softmax + ``probs=`` clamps very small
        # probabilities and can make the two disagree.
        dist = torch.distributions.Categorical(logits=logits)
        if deterministic:
            action = torch.argmax(logits, dim=-1)
        elif action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return action, log_prob, entropy

    def evaluate_get_action(self, x, action):
        """Score ``action`` under the current policy.

        Returns:
            Tuple ``(log_prob, entropy)`` for the given observation/action batch.
        """
        logits = self.actor(self.get_features(self._prepare_obs(x)))
        dist = torch.distributions.Categorical(logits=logits)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return log_prob, entropy

# --- Environment Creation ---
def make_env(env_id, seed, idx, run_name, eval_mode=False, num_envs=None):
    """Build the vectorised, preprocessed PettingZoo Atari training env.

    Unlike a typical ``make_env`` this returns the finished vector env, not a
    thunk.

    Args:
        env_id: Module name under ``pettingzoo.atari`` (e.g. ``"pong_v3"``).
            The parameter is now honoured instead of reading the global
            ``args.env_id``.
        seed: Currently unused; kept for interface compatibility.
        idx: Currently unused; kept for interface compatibility.
        run_name: Currently unused; kept for interface compatibility.
        eval_mode: Currently unused; kept for interface compatibility.
        num_envs: Total number of vectorised slots.  Defaults to the global
            ``args.num_envs`` (the previous behaviour) when omitted.

    Returns:
        The concatenated vector env, with ``single_observation_space``,
        ``single_action_space`` and ``is_vector_env`` attributes attached.
    """
    if num_envs is None:
        # Backward compatible: fall back to the script-level settings object.
        num_envs = args.num_envs

    env = importlib.import_module(f"pettingzoo.atari.{env_id}").parallel_env()
    env = ss.max_observation_v0(env, 2)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.frame_stack_v1(env, 4)
    env = ss.agent_indicator_v0(env, type_only=False)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    # Half as many game copies as slots -- presumably because each game
    # contributes one slot per player (two for Pong); confirm against SuperSuit.
    envs = ss.concat_vec_envs_v1(env, num_envs // 2, num_cpus=0, base_class="gymnasium")
    # Attach the attributes the training loop reads from a Gymnasium vector env.
    envs.single_observation_space = envs.observation_space
    envs.single_action_space = envs.action_space
    envs.is_vector_env = True

    return envs


# --- NEW AND SIMPLE evaluate function ---
def evaluate(model, device, run_name, num_eval_eps=10, record=False):
    """Play evaluation episodes of the trained agent against a random opponent.

    A single non-vectorized AEC environment is built with the same
    preprocessing wrappers as training (minus vectorization and frame skip).
    The model controls ``first_0`` greedily (``deterministic=True``); every
    other agent samples random actions.

    Args:
        model: Policy exposing ``get_action(obs, deterministic=True)``.
        device: Device the model and observation tensors are moved to.
        run_name: Currently unused; kept for interface compatibility.
        num_eval_eps: Number of episodes to play.
        record: When true, the env is created with ``render_mode="rgb_array"``
            and one rendered frame is collected after every agent step.

    Returns:
        Tuple ``(all_returns, frames)``: the per-episode return of ``first_0``
        (rewards are clipped to [-1, 1] by the wrapper) and the collected
        frames (empty when ``record`` is false).
    """
    learner = 'first_0'

    eval_env = pong_v3.env(render_mode="rgb_array" if record else None)

    # Same preprocessing as training, without the vectorization wrappers.
    eval_env = ss.max_observation_v0(eval_env, 2)
    eval_env = ss.clip_reward_v0(eval_env, lower_bound=-1, upper_bound=1)
    eval_env = ss.color_reduction_v0(eval_env, mode="B")
    eval_env = ss.resize_v1(eval_env, x_size=84, y_size=84)
    eval_env = ss.frame_stack_v1(eval_env, 4)
    eval_env = ss.agent_indicator_v0(eval_env, type_only=False)

    model.to(device)
    model.eval()

    all_returns = []
    frames = []

    try:
        for episode in tqdm(range(num_eval_eps), desc="Evaluating"):
            eval_env.reset(seed=Config.seed + episode)
            episode_return = 0.0

            for agent in eval_env.agent_iter():
                obs, reward, terminated, truncated, info = eval_env.last()

                # Count only the trained agent's reward: summing both players
                # of a zero-sum game cancels out. Read it before the done check
                # so the reward delivered with the terminal signal is kept.
                if agent == learner:
                    episode_return += reward

                if terminated or truncated:
                    # A finished agent must be stepped once more with None.
                    eval_env.step(None)
                    continue

                if agent == learner:
                    with torch.no_grad():
                        # The network expects a leading batch dimension.
                        obs_tensor = torch.Tensor(obs).to(device).unsqueeze(0)
                        action, _, _ = model.get_action(obs_tensor, deterministic=True)
                    eval_env.step(action.cpu().item())
                else:
                    # The opponent acts uniformly at random.
                    eval_env.step(eval_env.action_space(agent).sample())

                if record:
                    frames.append(eval_env.render())

            all_returns.append(episode_return)
    finally:
        # Always release the emulator and put the model back in training mode.
        eval_env.close()
        model.train()

    return all_returns, frames
# --- Main Execution ---
if __name__ == "__main__":
    # Entry point: collect rollouts from the vectorized environment, update the
    # shared actor-critic with clipped PPO, evaluate periodically, and
    # optionally record a final evaluation video.
    args = Config()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

    if args.use_wandb:
        # vars(args) is empty when the settings live on the class, so gather
        # the public, non-callable attributes (including properties) instead.
        run_config = {
            key: getattr(args, key)
            for key in dir(args)
            if not key.startswith("_") and not callable(getattr(args, key))
        }
        wandb.init(
            project=args.wandb_project,
            sync_tensorboard=True,
            config=run_config,
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Fall back to CPU instead of crashing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    envs = make_env(env_id=args.env_id, seed=args.seed, idx=0, run_name=run_name)

    actor_network = Agent(envs.action_space.n).to(device)
    optimizer = optim.Adam(actor_network.parameters(), lr=args.lr, eps=1e-5)

    # Rollout buffers, indexed as [step, env].
    rollout_shape = (args.max_steps, args.num_envs)
    obs_storage = torch.zeros(rollout_shape + envs.observation_space.shape, device=device)
    actions_storage = torch.zeros(rollout_shape + envs.action_space.shape, device=device)
    logprobs_storage = torch.zeros(rollout_shape, device=device)
    rewards_storage = torch.zeros(rollout_shape, device=device)
    dones_storage = torch.zeros(rollout_shape, device=device)
    values_storage = torch.zeros(rollout_shape, device=device)

    # Optional settings fall back to the values that were previously hard-coded.
    anneal_lr = getattr(args, "anneal_lr", True)
    max_grad_norm = getattr(args, "max_grad_norm", 0.5)

    global_step = 0
    start_time = time.time()
    num_updates = args.total_timesteps // args.batch_size

    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for update in tqdm(range(1, num_updates + 1), desc="Training Updates"):

        if anneal_lr:
            # Linear decay starting at the full rate; using (update - 1) keeps
            # the final update from running with a learning rate of zero.
            frac = 1.0 - (update - 1.0) / num_updates
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr * frac

        # === Rollout phase ===
        for step in range(args.max_steps):
            global_step += args.num_envs
            obs_storage[step] = next_obs
            dones_storage[step] = next_done

            with torch.no_grad():
                action, logprob, _ = actor_network.get_action(next_obs)
                value = actor_network.get_value(next_obs)

            values_storage[step] = value.flatten()
            actions_storage[step] = action
            logprobs_storage[step] = logprob

            next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
            done = np.logical_or(terminated, truncated)

            rewards_storage[step] = torch.tensor(reward).to(device).view(-1)
            next_obs = torch.Tensor(next_obs).to(device)
            next_done = torch.Tensor(done).to(device)

            if "final_info" in info:
                for item in info["final_info"]:
                    # Entries are None for sub-envs that did not finish.
                    if item and "episode" in item:
                        print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                        if args.use_wandb:
                            wandb.log({
                                "charts/episodic_return": item['episode']['r'],
                                "charts/episodic_length": item['episode']['l'],
                                "global_step": global_step
                            })

        # === Advantage (GAE) and return computation ===
        with torch.no_grad():
            advantages = torch.zeros_like(rewards_storage)

            # Value of the state reached after the last stored step.
            bootstrap_value = actor_network.get_value(next_obs).squeeze()
            lastgae = 0.0

            for t in reversed(range(args.max_steps)):
                if t == args.max_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    gt_next_state = bootstrap_value * nextnonterminal
                else:
                    nextnonterminal = 1.0 - dones_storage[t + 1]
                    # No bootstrapping across an episode boundary.
                    gt_next_state = values_storage[t + 1] * nextnonterminal

                delta = (rewards_storage[t] + args.gamma * gt_next_state) - values_storage[t]
                advantages[t] = lastgae = delta + args.GAE * lastgae * nextnonterminal * args.gamma

            returns = advantages + values_storage

        # === PPO update phase ===
        b_obs = obs_storage.reshape((-1,) + envs.observation_space.shape)
        b_logprobs = logprobs_storage.reshape(-1)
        b_actions = actions_storage.reshape((-1,) + envs.action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values_storage.reshape(-1)

        b_inds = np.arange(args.batch_size)
        for epoch in range(args.PPO_EPOCHS):
            np.random.shuffle(b_inds)

            for start in range(0, args.batch_size, args.minibatch_size):
                mb_inds = b_inds[start:start + args.minibatch_size]

                new_log_probs, entropy = actor_network.evaluate_get_action(b_obs[mb_inds], b_actions[mb_inds])
                logratio = new_log_probs - b_logprobs[mb_inds]
                ratio = torch.exp(logratio)
                with torch.no_grad():
                    approx_kl = ((ratio - 1) - logratio).mean()

                # Normalize a per-minibatch copy; writing the result back into
                # b_advantages would re-normalize already-normalized values on
                # later epochs and distort the logged advantage statistics.
                mb_advantages = b_advantages[mb_inds]
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                pg_loss1 = mb_advantages * ratio
                pg_loss2 = mb_advantages * torch.clamp(ratio, 1 - args.clip_value, 1 + args.clip_value)
                policy_loss = -torch.min(pg_loss1, pg_loss2).mean()

                current_values = actor_network.get_value(b_obs[mb_inds]).squeeze()

                # Clipped value loss: limit how far the new value estimate may
                # move away from the estimate recorded during the rollout.
                v_loss_unclipped = (current_values - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    current_values - b_values[mb_inds], -args.clip_coeff, args.clip_coeff
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                critic_loss = args.VALUE_COEFF * 0.5 * v_loss_max.mean()

                entropy_loss = entropy.mean()
                loss = policy_loss - args.ENTROPY_COEFF * entropy_loss + critic_loss

                optimizer.zero_grad()
                loss.backward()

                minibatch_log = {"charts/approx_kl": approx_kl.item()}
                if args.use_wandb:
                    # Per-parameter gradient norms, measured before clipping.
                    for name, param in actor_network.named_parameters():
                        if param.grad is None:
                            continue
                        param_norm = param.grad.detach().norm(2).item()
                        if 'actor' in name or 'critic' in name:
                            minibatch_log[f"gradients/norm_{name}"] = param_norm
                        else:
                            minibatch_log[f"gradients/shared_norm_{name}"] = param_norm

                # clip_grad_norm_ returns the total (pre-clipping) gradient norm.
                total_norm = nn.utils.clip_grad_norm_(actor_network.parameters(), max_grad_norm)
                optimizer.step()

                if args.use_wandb:
                    minibatch_log["gradients/total_norm"] = float(total_norm)
                    wandb.log(minibatch_log)

        if args.use_wandb:
            wandb.log({
                "losses/total_loss": loss.item(),
                "losses/policy_loss": policy_loss.item(),
                "losses/value_loss": critic_loss.item(),
                "losses/entropy": entropy_loss.item(),
                "charts/learning_rate": optimizer.param_groups[0]['lr'],
                # NOTE(review): this is the mean per-step reward of the rollout,
                # not an episodic return, although it shares that chart key.
                "charts/episodic_return": np.mean(rewards_storage.cpu().numpy()),
                "charts/advantages_mean": b_advantages.mean().item(),
                "charts/advantages_std": b_advantages.std().item(),
                "charts/returns_mean": b_returns.mean().item(),
                "global_step": global_step,
            })
            print(f"Update {update}, Global Step: {global_step}, Policy Loss: {policy_loss.item():.4f}, Value Loss: {critic_loss.item():.4f}")

        if update % 200 == 0:
            # Frames from periodic evaluations are discarded, so skip rendering.
            episodic_returns, _ = evaluate(actor_network, device, run_name, num_eval_eps=5, record=False)
            avg_return = np.mean(episodic_returns)

            if args.use_wandb:
                wandb.log({
                    "eval/avg_return": avg_return,
                    "global_step": global_step,
                })
            print(f"Evaluation at step {global_step}: Average raw return = {avg_return:.2f}")

    if args.capture_video:
        print("Capturing final evaluation video...")
        episodic_returns, eval_frames = evaluate(actor_network, device, run_name, num_eval_eps=10, record=True)

        if len(eval_frames) > 0:
            video_path = f"videos/final_eval_{run_name}.mp4"
            os.makedirs(os.path.dirname(video_path), exist_ok=True)
            imageio.mimsave(video_path, eval_frames, fps=30, codec='libx264')
            if args.use_wandb:
                wandb.log({"eval/final_video": wandb.Video(video_path, fps=30, format="mp4")})
                print("Final evaluation video saved and uploaded to WandB.")

    envs.close()
    if args.use_wandb:
        wandb.finish()

In [None]:
!pip install autorom

In [None]:
pip install supersuit

In [None]:
!AutoROM --accept-license