<a href="https://colab.research.google.com/github/Vyshnavijulapelly/Reinforcement-Learning/blob/main/RL_Lab_05.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import argparse
import collections
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import cv2


# Replay buffer
class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return (
            np.stack(state),
            np.array(action),
            np.array(reward, dtype=np.float32),
            np.stack(next_state),
            np.array(done, dtype=np.uint8),
        )

    def __len__(self):
        return len(self.buffer)


# Q-networks
class DQN_MLP(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim),
        )

    def forward(self, x):
        return self.net(x)


class DQN_CNN(nn.Module):
    def __init__(self, input_shape, act_dim):
        super().__init__()
        c, h, w = input_shape
        assert h == 84 and w == 84, "CNN expects 84x84 input"
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, act_dim),
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = x / 255.0
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)


# Atari preprocessing wrappers
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip
        self._obs_buffer = np.zeros(
            (2,) + env.observation_space.shape, dtype=np.uint8
        )

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class FrameProcessor(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return np.expand_dims(obs, -1)


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        super().__init__(env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(obs)
        return self._get_obs(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        return np.concatenate(list(self.frames), axis=-1)


# Environment builder
def make_env(env_id: str, seed: int, cnn: bool, frameskip: int, frame_stack: int):
    env = gym.make(env_id, frameskip=1) if "ALE/" in env_id else gym.make(env_id)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    if cnn:
        if "ALE/" not in env_id:
            raise ValueError("--cnn was set but env is not Atari. Use ALE/Breakout-v5 etc.")
        env = MaxAndSkipEnv(env, skip=frameskip)
        env = gym.wrappers.TransformObservation(
            env,
            lambda obs: obs[:, :, ::-1] if obs is not None and obs.ndim == 3 else obs,
            observation_space=env.observation_space,
        )
        env = FrameProcessor(env)
        env = FrameStack(env, k=frame_stack)
    else:
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(
            env,
            lambda x: x.astype(np.float32),
            observation_space=gym.spaces.Box(
                low=-np.inf,
                high=np.inf,
                shape=env.observation_space.shape,
                dtype=np.float32,
            ),
        )

    env = gym.wrappers.RecordEpisodeStatistics(env)
    return env


# Training loop
def train(args):
    env = make_env(args.env, args.seed, args.cnn, args.frameskip, args.frame_stack)
    obs_shape = env.observation_space.shape
    act_dim = env.action_space.n

    if args.cnn:
        obs_shape = (obs_shape[2], obs_shape[0], obs_shape[1])  # HWC->CHW
        q_net = DQN_CNN(obs_shape, act_dim)
        target_q_net = DQN_CNN(obs_shape, act_dim)
    else:
        q_net = DQN_MLP(obs_shape[0], act_dim)
        target_q_net = DQN_MLP(obs_shape[0], act_dim)

    target_q_net.load_state_dict(q_net.state_dict())
    optimizer = optim.Adam(q_net.parameters(), lr=args.lr)
    buffer = ReplayBuffer(args.buffer_size)

    epsilon = args.eps_start
    epsilon_decay = (args.eps_start - args.eps_end) / args.eps_decay_steps
    global_step = 0
    episode_rewards = []

    obs, _ = env.reset(seed=args.seed)
    if args.cnn:
        obs = np.transpose(obs, (2, 0, 1))

    while global_step < args.total_steps:
        # Epsilon-greedy
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
                if args.cnn:
                    obs_t = obs_t.to(next(q_net.parameters()).device)
                q_values = q_net(obs_t)
                action = q_values.argmax(dim=1).item()

        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        if args.cnn:
            next_obs_proc = np.transpose(next_obs, (2, 0, 1))
            buffer.push(obs, action, reward, next_obs_proc, done)
            obs = next_obs_proc
        else:
            buffer.push(obs, action, reward, next_obs, done)
            obs = next_obs

        if done:
            obs, _ = env.reset()
            if args.cnn:
                obs = np.transpose(obs, (2, 0, 1))
            if "episode" in info:
                ep_r = info["episode"]["r"]
                episode_rewards.append(ep_r)
                if len(episode_rewards) % args.log_interval == 0:
                    avg_r = np.mean(episode_rewards[-args.log_interval :])
                    print(f"Step {global_step}, AvgReward {avg_r:.2f}, Eps {epsilon:.3f}")

        # Training step
        if (
            len(buffer) >= args.learning_starts
            and global_step % args.train_freq == 0
        ):
            (
                batch_obs,
                batch_actions,
                batch_rewards,
                batch_next_obs,
                batch_dones,
            ) = buffer.sample(args.batch_size)

            batch_obs_t = torch.tensor(batch_obs, dtype=torch.float32)
            batch_actions_t = torch.tensor(batch_actions, dtype=torch.int64)
            batch_rewards_t = torch.tensor(batch_rewards, dtype=torch.float32)
            batch_next_obs_t = torch.tensor(batch_next_obs, dtype=torch.float32)
            batch_dones_t = torch.tensor(batch_dones, dtype=torch.float32)

            if args.cnn:
                batch_obs_t = batch_obs_t / 255.0
                batch_next_obs_t = batch_next_obs_t / 255.0

            q_values = q_net(batch_obs_t).gather(1, batch_actions_t.unsqueeze(1)).squeeze(1)

            with torch.no_grad():
                if args.double:
                    next_actions = q_net(batch_next_obs_t).argmax(dim=1)
                    next_q = target_q_net(batch_next_obs_t).gather(1, next_actions.unsqueeze(1)).squeeze(1)
                else:
                    next_q = target_q_net(batch_next_obs_t).max(1)[0]
                target = batch_rewards_t + args.gamma * (1 - batch_dones_t) * next_q

            loss = nn.SmoothL1Loss()(q_values, target)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
            optimizer.step()

        if global_step % args.target_update == 0:
            target_q_net.load_state_dict(q_net.state_dict())

        epsilon = max(args.eps_end, epsilon - epsilon_decay)
        global_step += 1

    print("Training finished.")
    env.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="CartPole-v1")
    parser.add_argument("--total-steps", type=int, default=5000)  # shorter for Colab demo
    parser.add_argument("--buffer-size", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--eps-start", type=float, default=1.0)
    parser.add_argument("--eps-end", type=float, default=0.1)
    parser.add_argument("--eps-decay-steps", type=int, default=1000)
    parser.add_argument("--target-update", type=int, default=1000)
    parser.add_argument("--learning-starts", type=int, default=1000)
    parser.add_argument("--train-freq", type=int, default=1)
    parser.add_argument("--cnn", action="store_true")
    parser.add_argument("--frameskip", type=int, default=4)
    parser.add_argument("--frame-stack", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--double", action="store_true")
    parser.add_argument("--log-interval", type=int, default=10)

    # âœ… Ignore Jupyter/Colab args
    args = parser.parse_args(args=[])

    train(args)


Step 200, AvgReward 20.10, Eps 0.820
Step 367, AvgReward 16.70, Eps 0.670
Step 525, AvgReward 15.80, Eps 0.527
Step 660, AvgReward 13.50, Eps 0.406
Step 789, AvgReward 12.90, Eps 0.290
Step 891, AvgReward 10.20, Eps 0.198
Step 1000, AvgReward 10.90, Eps 0.100
Step 1108, AvgReward 10.80, Eps 0.100
Step 1213, AvgReward 10.50, Eps 0.100
Step 1310, AvgReward 9.70, Eps 0.100
Step 1403, AvgReward 9.30, Eps 0.100
Step 1502, AvgReward 9.90, Eps 0.100
Step 1600, AvgReward 9.80, Eps 0.100
Step 1698, AvgReward 9.80, Eps 0.100
Step 1798, AvgReward 10.00, Eps 0.100
Step 1898, AvgReward 10.00, Eps 0.100
Step 1991, AvgReward 9.30, Eps 0.100
Step 2110, AvgReward 11.90, Eps 0.100
Step 2210, AvgReward 10.00, Eps 0.100
Step 2328, AvgReward 11.80, Eps 0.100
Step 2495, AvgReward 16.70, Eps 0.100
Step 2654, AvgReward 15.90, Eps 0.100
Step 2766, AvgReward 11.20, Eps 0.100
Step 2903, AvgReward 13.70, Eps 0.100
Step 3060, AvgReward 15.70, Eps 0.100
Step 3270, AvgReward 21.00, Eps 0.100
Step 3468, AvgReward 19.

In [10]:
#Implementing a policy gradient algorithm
from __future__ import annotations
import argparse
import math
import os
import random
from dataclasses import dataclass
from typing import List, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# -----------------------
# Utilities
# -----------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def discounted_returns(rewards: List[float], gamma: float, reward_to_go: bool) -> List[float]:
    """Compute (reward-to-go) discounted returns for a single episode."""
    if reward_to_go:
        returns = np.zeros(len(rewards), dtype=np.float32)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns.tolist()
    else:
        # full-episode return repeated
        total = 0.0
        for r in rewards:
            total = total + r * (gamma ** 0)  # just sum (no discount structure needed here)
        # but we want discounted sum for the episode: compute properly:
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
        return [float(running)] * len(rewards)


# -----------------------
# Networks
# -----------------------
class MLPPolicy(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.ReLU())
            last = h
        layers.append(nn.Linear(last, act_dim))
        self.logits_net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits_net(x)  # raw logits

    def get_action_and_logp(self, obs: np.ndarray, device: torch.device):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
        if obs_t.dim() == 1:
            obs_t = obs_t.unsqueeze(0)
        logits = self.forward(obs_t)
        dist = Categorical(logits=logits)
        action = dist.sample()
        logp = dist.log_prob(action)
        return int(action.item()), float(logp.item())


class MLPValue(nn.Module):
    def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.ReLU())
            last = h
        layers.append(nn.Linear(last, 1))
        self.v_net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.v_net(x).squeeze(-1)  # returns shape (batch,)


# -----------------------
# Training loops
# -----------------------
@dataclass
class PGConfig:
    env: str = "CartPole-v1"
    algo: str = "reinforce"  # "reinforce" or "actor_critic"
    total_timesteps: int = 200_000
    batch_size: int = 5000  # for REINFORCE: number of timesteps per policy update
    max_episode_len: int = 1000
    gamma: float = 0.99
    lr: float = 1e-3
    hidden_sizes: Tuple[int, int] = (64, 64)
    seed: int = 0
    reward_to_go: bool = True
    normalize_adv: bool = True
    entropy_coef: float = 0.0
    value_lr: float = 1e-3  # for actor-critic
    value_iters: int = 1  # how many value updates per actor update
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    save_path: str = "pg_checkpoint.pt"
    log_interval: int = 10


def train_reinforce(cfg: PGConfig):
    env = gym.make(cfg.env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    policy = MLPPolicy(obs_dim, act_dim, hidden_sizes=cfg.hidden_sizes).to(cfg.device)
    optimizer = optim.Adam(policy.parameters(), lr=cfg.lr)

    total_steps = 0
    ep_returns = []
    ep_lens = []
    logs = []

    while total_steps < cfg.total_timesteps:
        # collect a batch of trajectories (timesteps >= batch_size)
        batch_obs = []
        batch_acts = []
        batch_logps = []
        batch_rets = []  # discounted returns for each timestep
        batch_lens = []

        steps_collected = 0
        while steps_collected < cfg.batch_size:
            obs, _ = env.reset()
            done = False
            ep_rewards = []
            ep_obs = []
            ep_acts = []
            ep_logps = []
            for t in range(cfg.max_episode_len):
                a, logp = policy.get_action_and_logp(obs, device=cfg.device)
                next_obs, r, terminated, truncated, info = env.step(a)
                done = terminated or truncated
                ep_obs.append(obs.copy())
                ep_acts.append(a)
                ep_logps.append(logp)
                ep_rewards.append(r)
                obs = next_obs
                if done:
                    break

            # compute discounted returns
            rets = discounted_returns(ep_rewards, cfg.gamma, reward_to_go=cfg.reward_to_go)
            batch_obs.extend(ep_obs)
            batch_acts.extend(ep_acts)
            batch_logps.extend(ep_logps)
            batch_rets.extend(rets)
            batch_lens.append(len(ep_rewards))
            ep_returns.append(sum(ep_rewards))
            ep_lens.append(len(ep_rewards))
            steps_collected += len(ep_rewards)
            total_steps += len(ep_rewards)

        # convert to tensors
        obs_t = torch.as_tensor(np.array(batch_obs, dtype=np.float32), device=cfg.device)
        acts_t = torch.as_tensor(np.array(batch_acts, dtype=np.int64), device=cfg.device)
        logps_old_t = torch.as_tensor(np.array(batch_logps, dtype=np.float32), device=cfg.device)
        rets_t = torch.as_tensor(np.array(batch_rets, dtype=np.float32), device=cfg.device)

        # optionally normalize returns (baseline of zero)
        if cfg.normalize_adv:
            rets_t = (rets_t - rets_t.mean()) / (rets_t.std() + 1e-8)

        # policy loss = - sum log pi(a|s) * G_t  (averaged)
        logits = policy(obs_t)
        dist = Categorical(logits=logits)
        logp_all = dist.log_prob(acts_t)
        entropy = dist.entropy().mean()
        loss = -(logp_all * rets_t).mean() - cfg.entropy_coef * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if len(ep_returns) >= cfg.log_interval:
            avg_return = np.mean(ep_returns[-cfg.log_interval:])
            avg_len = np.mean(ep_lens[-cfg.log_interval:])
            print(f"[REINFORCE] steps={total_steps} avg_return={avg_return:.2f} avg_len={avg_len:.2f} loss={loss.item():.4f}")
            logs.append((total_steps, avg_return, avg_len))
    env.close()
    # save
    torch.save({"policy_state": policy.state_dict()}, cfg.save_path)
    return logs


def train_actor_critic(cfg: PGConfig):
    env = gym.make(cfg.env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    policy = MLPPolicy(obs_dim, act_dim, hidden_sizes=cfg.hidden_sizes).to(cfg.device)
    value = MLPValue(obs_dim, hidden_sizes=cfg.hidden_sizes).to(cfg.device)

    opt_policy = optim.Adam(policy.parameters(), lr=cfg.lr)
    opt_value = optim.Adam(value.parameters(), lr=cfg.value_lr)

    total_steps = 0
    ep_returns = []
    ep_lens = []
    logs = []

    while total_steps < cfg.total_timesteps:
        # collect batch of transitions (here we collect episodes until steps >= batch_size)
        batch_obs = []
        batch_acts = []
        batch_logps = []
        batch_rets = []
        batch_vals = []
        batch_lens = []

        steps_collected = 0
        while steps_collected < cfg.batch_size:
            obs, _ = env.reset()
            done = False
            ep_rewards = []
            ep_obs = []
            ep_acts = []
            ep_logps = []
            ep_vals = []
            for t in range(cfg.max_episode_len):
                obs_t = torch.as_tensor(obs[None, :], dtype=torch.float32, device=cfg.device)
                logits = policy(obs_t)
                dist = Categorical(logits=logits)
                action = dist.sample().item()
                logp = float(dist.log_prob(torch.tensor(action, device=cfg.device)).item())
                val = float(value(obs_t).item())

                next_obs, r, terminated, truncated, info = env.step(action)
                done = terminated or truncated

                ep_obs.append(obs.copy())
                ep_acts.append(action)
                ep_logps.append(logp)
                ep_vals.append(val)
                ep_rewards.append(r)

                obs = next_obs
                if done:
                    break

            returns = discounted_returns(ep_rewards, cfg.gamma, reward_to_go=cfg.reward_to_go)
            batch_obs.extend(ep_obs)
            batch_acts.extend(ep_acts)
            batch_logps.extend(ep_logps)
            batch_rets.extend(returns)
            batch_vals.extend(ep_vals)
            batch_lens.append(len(ep_rewards))
            ep_returns.append(sum(ep_rewards))
            ep_lens.append(len(ep_rewards))
            steps_collected += len(ep_rewards)
            total_steps += len(ep_rewards)

        # tensors
        obs_t = torch.as_tensor(np.array(batch_obs, dtype=np.float32), device=cfg.device)
        acts_t = torch.as_tensor(np.array(batch_acts, dtype=np.int64), device=cfg.device)
        rets_t = torch.as_tensor(np.array(batch_rets, dtype=np.float32), device=cfg.device)
        vals_t = torch.as_tensor(np.array(batch_vals, dtype=np.float32), device=cfg.device)

        # advantages
        adv_t = rets_t - vals_t
        if cfg.normalize_adv:
            adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)

        # policy loss (with entropy bonus)
        logits = policy(obs_t)
        dist = Categorical(logits=logits)
        logp_all = dist.log_prob(acts_t)
        entropy = dist.entropy().mean()
        policy_loss = -(logp_all * adv_t).mean() - cfg.entropy_coef * entropy

        opt_policy.zero_grad()
        policy_loss.backward()
        opt_policy.step()

        # value loss (MSE between returns and value)
        for _ in range(cfg.value_iters):
            val_pred = value(obs_t)
            value_loss = nn.MSELoss()(val_pred, rets_t)
            opt_value.zero_grad()
            value_loss.backward()
            opt_value.step()

        if len(ep_returns) >= cfg.log_interval:
            avg_return = np.mean(ep_returns[-cfg.log_interval:])
            avg_len = np.mean(ep_lens[-cfg.log_interval:])
            print(f"[A2C] steps={total_steps} avg_return={avg_return:.2f} avg_len={avg_len:.2f} policy_loss={policy_loss.item():.4f} value_loss={value_loss.item():.4f}")
            logs.append((total_steps, avg_return, avg_len))

    env.close()
    torch.save({"policy_state": policy.state_dict(), "value_state": value.state_dict()}, cfg.save_path)
    return logs


# -----------------------
# CLI / Entry
# -----------------------
def parse_args_colab_safe():
    p = argparse.ArgumentParser()
    p.add_argument("--env", type=str, default="CartPole-v1")
    p.add_argument("--algo", type=str, default="reinforce", choices=["reinforce", "actor_critic"])
    p.add_argument("--total-timesteps", type=int, default=200000)
    p.add_argument("--batch-size", type=int, default=5000)
    p.add_argument("--max-episode-len", type=int, default=1000)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--value-lr", type=float, default=1e-3)
    p.add_argument("--hidden-sizes", nargs="+", type=int, default=[64, 64])
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--reward-to-go", action="store_true")
    p.add_argument("--no-normalize-adv", dest="normalize_adv", action="store_false")
    p.add_argument("--entropy-coef", type=float, default=0.0)
    p.add_argument("--value-iters", type=int, default=1)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save-path", type=str, default="pg_checkpoint.pt")
    p.add_argument("--log-interval", type=int, default=10)

    # ignore Jupyter / Colab args if present
    args = p.parse_args(args=[])
    return args


def main():
    args = parse_args_colab_safe()
    cfg = PGConfig(
        env=args.env,
        algo=args.algo,
        total_timesteps=args.total_timesteps,
        batch_size=args.batch_size,
        max_episode_len=args.max_episode_len,
        gamma=args.gamma,
        lr=args.lr,
        hidden_sizes=tuple(args.hidden_sizes),
        seed=args.seed,
        reward_to_go=args.reward_to_go,
        normalize_adv=args.normalize_adv,
        entropy_coef=args.entropy_coef,
        value_lr=args.value_lr,
        value_iters=args.value_iters,
        device=args.device,
        save_path=args.save_path,
        log_interval=args.log_interval,
    )

    set_seed(cfg.seed)

    print(f"Running {cfg.algo} on {cfg.env} for {cfg.total_timesteps} timesteps (device={cfg.device})")

    if cfg.algo == "reinforce":
        train_reinforce(cfg)
    else:
        train_actor_critic(cfg)


if __name__ == "__main__":
    main()


Running reinforce on CartPole-v1 for 200000 timesteps (device=cpu)
[REINFORCE] steps=5027 avg_return=26.70 avg_len=26.70 loss=0.0023
[REINFORCE] steps=10028 avg_return=19.40 avg_len=19.40 loss=0.0002
[REINFORCE] steps=15038 avg_return=24.30 avg_len=24.30 loss=0.0008
[REINFORCE] steps=20046 avg_return=22.20 avg_len=22.20 loss=-0.0017
[REINFORCE] steps=25087 avg_return=22.60 avg_len=22.60 loss=-0.0062
[REINFORCE] steps=30110 avg_return=26.50 avg_len=26.50 loss=-0.0056
[REINFORCE] steps=35124 avg_return=32.00 avg_len=32.00 loss=-0.0049
[REINFORCE] steps=40163 avg_return=24.20 avg_len=24.20 loss=-0.0050
[REINFORCE] steps=45194 avg_return=24.70 avg_len=24.70 loss=-0.0072
[REINFORCE] steps=50195 avg_return=28.00 avg_len=28.00 loss=-0.0095
[REINFORCE] steps=55203 avg_return=24.10 avg_len=24.10 loss=-0.0088
[REINFORCE] steps=60235 avg_return=36.90 avg_len=36.90 loss=-0.0083
[REINFORCE] steps=65247 avg_return=29.40 avg_len=29.40 loss=-0.0131
[REINFORCE] steps=70272 avg_return=23.20 avg_len=23.2