In [None]:
import subprocess
import sys

# Install the notebook's dependencies with the interpreter that runs this kernel.
# A bare `!pip` may resolve to a different Python environment and silently ignores
# failures; `sys.executable -m pip` + check_call targets the kernel and fails loudly.
# The install order of the original cell is kept: gymnasium (upgraded to the latest
# version), then swig, then the box2d extra.
for pip_args in (
    ["gymnasium", "--upgrade"],
    ["swig"],
    ["gymnasium[box2d]"],
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *pip_args])

In [None]:
import os

import wandb

# Credentials must never be committed with the notebook: the API key is read from
# the WANDB_API_KEY environment variable instead of being hard-coded here.
# When the variable is unset, `key=None` lets wandb fall back to its own credential
# lookup (stored login / interactive prompt).
# NOTE: the key that used to be embedded in this cell must be treated as leaked and
# should be revoked in the W&B account settings.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
wandb.login(key=WANDB_API_KEY)

In [None]:
# ============================================================
# Cell 2: Imports, Device, Seeding
# ============================================================

import os
import random
from collections import deque, namedtuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim

# ------------------------------------------------------------
# Device
# ------------------------------------------------------------
# Run on the GPU whenever CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# ------------------------------------------------------------
# Seeding utilities
# ------------------------------------------------------------
def set_global_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (plus every CUDA device, if any) with `seed`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class FrameSkip(gym.Wrapper):
    """
    Repeat the same action for `skip` consecutive frames and accumulate the reward.

    Args:
        env: Environment to wrap.
        skip: Number of times each action is applied; must be at least 1.

    Raises:
        ValueError: If `skip` is smaller than 1 (no frame would be stepped, so
            there would be no observation to return).
    """
    def __init__(self, env, skip=4):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """
        Apply `action` up to `skip` times, stopping early when the episode ends.

        Returns:
            The last observation, the summed reward, and the `terminated`,
            `truncated` and `info` values of the last inner step.
        """
        total_reward = 0.0

        # `skip >= 1` is enforced in __init__, so the loop body runs at least once
        # and every returned name is bound.
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        return obs, total_reward, terminated, truncated, info


def make_env(env_name: str, seed: int, render_mode=None):
    """
    Build the continuous-action version of `env_name`, wrapped in a 4-frame
    `FrameSkip`, with the environment and its action space seeded by `seed`.
    """
    wrapped = FrameSkip(
        gym.make(env_name, render_mode=render_mode, continuous=True),
        skip=4,
    )

    # An initial reset seeds the environment; the action-space sampler is seeded separately.
    wrapped.reset(seed=seed)
    wrapped.action_space.seed(seed)
    return wrapped


In [None]:
# ============================================================
# Cell 3: Replay Buffer & Utilities
# ============================================================

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

class ReplayBuffer:
    """
    Fixed-capacity replay buffer for off-policy algorithms (SAC, TD3).

    Transitions are kept in a plain list used as a ring buffer: once `capacity`
    entries are stored, each new transition overwrites the oldest one. A list
    gives O(1) random access, whereas indexing into the middle of a `deque`
    (which `random.sample` does for every drawn element) is O(n).

    Args:
        capacity: Maximum number of transitions retained. A capacity of 0
            discards every pushed transition.

    Raises:
        ValueError: If `capacity` is negative.
    """
    def __init__(self, capacity: int):
        if capacity < 0:
            raise ValueError("capacity must be non-negative")
        self.capacity = capacity
        self.buffer = []
        self._next_idx = 0  # slot that the next overwrite will use once full

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest one when the buffer is full."""
        # state, next_state: np arrays (vector or image)
        if self.capacity == 0:
            return
        transition = Transition(state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self.capacity

    def sample(self, batch_size: int):
        """
        Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            `(state, action, reward, next_state, done)` float32 tensors on
            `DEVICE`; `reward` and `done` carry a trailing singleton dimension.

        Raises:
            ValueError: If `batch_size` exceeds the number of stored transitions
                (raised by `random.sample`).
        """
        transitions = random.sample(self.buffer, batch_size)

        def _stack(field, add_dim=False):
            values = np.array([getattr(t, field) for t in transitions])
            tensor = torch.as_tensor(values, dtype=torch.float32, device=DEVICE)
            return tensor.unsqueeze(-1) if add_dim else tensor

        state_batch = _stack('state')
        action_batch = _stack('action')
        reward_batch = _stack('reward', add_dim=True)
        next_state_batch = _stack('next_state')
        done_batch = _stack('done', add_dim=True)

        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

def get_action_bounds(env):
    """Return the `(low, high)` bounds of the environment's continuous action space."""
    space = env.action_space
    return space.low, space.high

def scale_action(action_unscaled, min_action, max_action):
    """Linearly map an action from [-1, 1] onto [min_action, max_action]."""
    span = max_action - min_action
    shifted = action_unscaled + 1.0
    return min_action + span * shifted / 2.0

def unscale_action(action_scaled, min_action, max_action):
    """Inverse of `scale_action`: linearly map [min_action, max_action] back onto [-1, 1]."""
    offset = action_scaled - min_action
    span = max_action - min_action
    return offset / span * 2.0 - 1.0


In [None]:
# ============================================================
# Cell 4: Shared CNN/MLP Feature Extractor
# ============================================================

# Bounds on the log standard deviation of the SAC policy: StochasticActor squashes
# its raw log-std output with tanh and rescales it into [LOG_STD_MIN, LOG_STD_MAX].
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class FeatureExtractor(nn.Module):
    """
    Observation encoder shared by every policy / value network in this notebook.

    - 1-D `obs_shape` (vector observations, e.g., LunarLander):
      two Tanh-activated linear layers producing 256 features.
    - Any other `obs_shape`, interpreted as (H, W, C) images (e.g., CarRacing):
      three ReLU convolutions followed by a linear layer producing 512 features.

    The resulting feature width is exposed as `output_dim`.
    """
    def __init__(self, obs_shape):
        super().__init__()
        self.obs_shape = obs_shape
        self.is_image = len(obs_shape) != 1

        if self.is_image:
            self._build_image_encoder(*obs_shape)
        else:
            self._build_vector_encoder(obs_shape[0])

    def _build_vector_encoder(self, input_dim):
        """Create the MLP used for flat observations."""
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.Tanh(),
            nn.Linear(256, 256),
            nn.Tanh()
        )
        self.output_dim = 256

    def _build_image_encoder(self, h, w, c):
        """Create the CNN + linear head used for (H, W, C) image observations."""
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Push one all-zero image through the conv stack to learn the flattened size.
        with torch.no_grad():
            probe = torch.zeros(1, c, h, w)
            n_flatten = self.cnn(probe).flatten(start_dim=1).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
        )
        self.output_dim = 512

    def _forward_image(self, x):
        """Encode a single (H, W, C) image or a batch of images."""
        if x.dim() == 3:
            x = x.unsqueeze(0)  # (1, H, W, C)
        n_channels = self.obs_shape[2]
        if x.shape[1] != n_channels:
            # Axis 1 is not the channel axis: treat input as (B, H, W, C) -> (B, C, H, W).
            x = x.permute(0, 3, 1, 2)
        scaled = x / 255.0
        conv_out = self.cnn(scaled)
        return self.fc(torch.flatten(conv_out, start_dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode observations into a (B, output_dim) feature tensor.

        x: tensor of shape
           - (obs_dim,) or (B, obs_dim) for vector observations
           - (H, W, C) or (B, H, W, C) for image observations
        """
        if self.is_image:
            return self._forward_image(x)

        batch = x.unsqueeze(0) if x.dim() == 1 else x
        return self.mlp(batch)


In [None]:
# ============================================================
# Cell 5: Policy & Value Networks (PPO/SAC/TD3)
# ============================================================

class PPOActorCritic(nn.Module):
    """
    Shared-trunk actor-critic network for PPO.

    A `FeatureExtractor` encodes the observation (vector or image); a linear head
    produces the tanh-bounded action mean, a state-independent learnable parameter
    holds the log standard deviation, and a second linear head predicts the value.
    """
    def __init__(self, obs_shape, action_dim):
        super().__init__()
        self.feature = FeatureExtractor(obs_shape)
        hidden = self.feature.output_dim

        self.actor_mean = nn.Linear(hidden, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(1, action_dim))
        self.critic = nn.Linear(hidden, 1)

    def forward(self, state):
        """Return `(mean, std, value)` for a state or a batch of states."""
        latent = self.feature(state)
        mean = self.actor_mean(latent).tanh()  # [-1, 1]
        # One log-std row is broadcast to every element of the batch.
        std = self.actor_log_std.expand_as(mean).exp()
        return mean, std, self.critic(latent)


class StochasticActor(nn.Module):
    """
    Gaussian policy network for SAC.

    Returns the pre-squash mean `mu` and a `log_std` confined to
    [LOG_STD_MIN, LOG_STD_MAX].
    """
    def __init__(self, obs_shape, action_dim):
        super().__init__()
        self.feature = FeatureExtractor(obs_shape)
        hidden = self.feature.output_dim

        self.mu_layer = nn.Linear(hidden, action_dim)
        self.log_std_layer = nn.Linear(hidden, action_dim)

    def forward(self, state):
        """Return `(mu, log_std)` for a state or a batch of states."""
        latent = self.feature(state)
        mu = self.mu_layer(latent)
        # tanh maps the raw output to [-1, 1]; the affine map below stretches that
        # interval onto [LOG_STD_MIN, LOG_STD_MAX].
        squashed = torch.tanh(self.log_std_layer(latent))
        half_range = 0.5 * (LOG_STD_MAX - LOG_STD_MIN)
        log_std = LOG_STD_MIN + half_range * (squashed + 1)
        return mu, log_std


class DeterministicActor(nn.Module):
    """
    Deterministic policy network for TD3: observation -> action in [-1, 1].
    """
    def __init__(self, obs_shape, action_dim):
        super().__init__()
        self.feature = FeatureExtractor(obs_shape)
        hidden = self.feature.output_dim

        self.mu_layer = nn.Linear(hidden, action_dim)

    def forward(self, state):
        """Return the tanh-squashed action for a state or a batch of states."""
        return self.mu_layer(self.feature(state)).tanh()


class QFunction(nn.Module):
    """
    State-action value network Q(s, a) for SAC/TD3.

    The state is encoded by this network's own `FeatureExtractor`; the encoded
    state is concatenated with the action and fed to a three-layer MLP that
    outputs one scalar per sample.
    """
    def __init__(self, obs_shape, action_dim):
        super().__init__()
        self.feature = FeatureExtractor(obs_shape)
        joint_dim = self.feature.output_dim + action_dim
        self.q_net = nn.Sequential(
            nn.Linear(joint_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, state, action):
        """Return Q(state, action) with shape (B, 1); an unbatched action is given a batch axis."""
        state_feat = self.feature(state)
        batched_action = action.unsqueeze(0) if action.dim() == 1 else action
        return self.q_net(torch.cat([state_feat, batched_action], dim=-1))


In [None]:
# ============================================================
# Cell 6: PPO Agent
# ============================================================

class PPOAgent:
    """
    PPO agent (clipped surrogate objective) built on `PPOActorCritic`.

    Rollout data is accumulated in `self.buffer`: `select_action` records the
    state, unscaled action, log-probability and value estimate of every
    exploratory step, and `store_reward` records the reward and the non-terminal
    mask. `learn` consumes the buffer and clears it.

    Args:
        obs_shape: Observation shape forwarded to `PPOActorCritic`.
        action_dim: Number of action dimensions.
        hyperparams: Dict providing 'gamma', 'learning_rate', 'clip_epsilon',
            'ppo_epochs', 'minibatch_size' and 'entropy_coeff'.
        action_bounds: `(low, high)` pair used to map policy actions from
            [-1, 1] to environment units.
    """
    def __init__(self, obs_shape, action_dim, hyperparams, action_bounds):
        self.GAMMA = hyperparams['gamma']
        self.LR = hyperparams['learning_rate']
        self.CLIP_EPSILON = hyperparams['clip_epsilon']
        self.PPO_EPOCHS = hyperparams['ppo_epochs']
        self.MINIBATCH_SIZE = hyperparams['minibatch_size']
        self.ENTROPY_COEFF = hyperparams['entropy_coeff']
        self.ACTION_MIN, self.ACTION_MAX = action_bounds
        self.obs_shape = obs_shape

        self.model = PPOActorCritic(obs_shape, action_dim).to(DEVICE)
        # Snapshot of the policy taken at the start of each update; the loss itself
        # uses the log-probabilities stored at acting time.
        self.old_model = PPOActorCritic(obs_shape, action_dim).to(DEVICE)
        self.old_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.LR)

        self.buffer = {'s': [], 'a': [], 'logp': [], 'v': [], 'r': [], 'm': []}

    def _state_to_tensor(self, state):
        """Convert a raw observation to a float32 tensor on `DEVICE`."""
        return torch.as_tensor(state, dtype=torch.float32, device=DEVICE)

    def select_action(self, state, deterministic=False):
        """
        Choose an action for `state` and return it in environment units.

        Args:
            state: Raw observation.
            deterministic: When True the policy mean is used and NOTHING is
                recorded in the rollout buffer, so evaluation episodes cannot
                corrupt (or endlessly grow) the training data. When False the
                action is sampled and the step is recorded for `learn`.

        Returns:
            Action scaled to `[ACTION_MIN, ACTION_MAX]` and clipped to those
            bounds (a Gaussian sample can fall outside [-1, 1]).
        """
        state_tensor = self._state_to_tensor(state)

        # Acting needs no gradients; learn() re-evaluates the policy on the batch.
        with torch.no_grad():
            mean, std, value = self.model(state_tensor)
            dist = Normal(mean, std)
            action_unscaled = mean if deterministic else dist.sample()
            log_prob = dist.log_prob(action_unscaled).sum(dim=-1, keepdim=True)

        # scale action to env bounds
        action_unscaled_np = action_unscaled.cpu().numpy().flatten()
        action_scaled = np.clip(
            scale_action(action_unscaled_np, self.ACTION_MIN, self.ACTION_MAX),
            self.ACTION_MIN,
            self.ACTION_MAX,
        )

        if not deterministic:
            # The *unclipped* policy-space action is stored so that the stored
            # log-probability matches the stored action.
            self.buffer['s'].append(state)  # store raw state (np)
            self.buffer['a'].append(action_unscaled_np)
            self.buffer['logp'].append(log_prob.cpu().numpy().flatten())
            self.buffer['v'].append(value.cpu().numpy().flatten())

        return action_scaled

    def store_reward(self, reward, done):
        """Record the reward and non-terminal mask of the step chosen by the last `select_action`."""
        self.buffer['r'].append(reward)
        self.buffer['m'].append(1.0 - float(done))

    def clear_buffer(self):
        """Drop all stored rollout data."""
        self.buffer = {'s': [], 'a': [], 'logp': [], 'v': [], 'r': [], 'm': []}

    def _compute_gae_and_returns(self, last_value):
        """
        Compute discounted returns and normalised advantages for the rollout.

        Despite the name, no GAE(lambda) weighting is applied: returns are
        discounted reward sums bootstrapped from `last_value` and reset at
        episode ends via the stored masks; advantages are `returns - values`,
        standardised to zero mean and unit variance.

        Returns:
            `(returns, advantage)` as 1-D NumPy arrays of rollout length.
        """
        rewards = self.buffer['r']
        masks = self.buffer['m']
        values = np.concatenate(self.buffer['v'])

        # discounted returns, accumulated backwards through the rollout
        returns = []
        R = last_value
        for reward, mask in zip(reversed(rewards), reversed(masks)):
            R = reward + self.GAMMA * R * mask
            returns.append(R)
        returns = np.array(list(reversed(returns))).flatten()

        advantage = returns - values
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        return returns, advantage

    def learn(self, last_state):
        """
        Run the PPO update on the buffered rollout, then clear the buffer.

        Args:
            last_state: Observation following the last stored step, used to
                bootstrap the return; pass None when the rollout ended on a
                finished episode (bootstrap value 0).

        Returns:
            Sum of all minibatch losses divided by the number of epochs, or
            None when no rewards have been stored.
        """
        if not self.buffer['r']:
            return None

        if last_state is None:
            last_value = 0.0
        else:
            with torch.no_grad():
                state_tensor = self._state_to_tensor(last_state)
                _, _, last_value_tensor = self.model(state_tensor)
                last_value = last_value_tensor.item()

        returns, advantages = self._compute_gae_and_returns(last_value)

        state_batch = torch.as_tensor(
            np.array(self.buffer['s']), dtype=torch.float32, device=DEVICE
        )
        action_batch = torch.as_tensor(
            np.array(self.buffer['a']), dtype=torch.float32, device=DEVICE
        )
        # Each stored log-prob is a length-1 array; concatenating yields (N,) and
        # the unsqueeze yields (N, 1), matching the (B, 1) `logp` computed below.
        # Stacking them instead would give (N, 1, 1) and silently broadcast the
        # probability ratio to (B, B, 1).
        old_logp_batch = torch.as_tensor(
            np.concatenate(self.buffer['logp']), dtype=torch.float32, device=DEVICE
        ).unsqueeze(-1)
        return_batch = torch.as_tensor(
            returns, dtype=torch.float32, device=DEVICE
        ).unsqueeze(-1)
        advantage_batch = torch.as_tensor(
            advantages, dtype=torch.float32, device=DEVICE
        ).unsqueeze(-1)

        self.old_model.load_state_dict(self.model.state_dict())
        total_loss = 0.0

        data_size = len(state_batch)
        indices = np.arange(data_size)

        for _ in range(self.PPO_EPOCHS):
            np.random.shuffle(indices)
            for start in range(0, data_size, self.MINIBATCH_SIZE):
                end = start + self.MINIBATCH_SIZE
                batch_idx = indices[start:end]

                s = state_batch[batch_idx]
                a = action_batch[batch_idx]
                old_logp = old_logp_batch[batch_idx]
                ret = return_batch[batch_idx]
                adv = advantage_batch[batch_idx]

                mean, std, value = self.model(s)
                dist = Normal(mean, std)
                logp = dist.log_prob(a).sum(dim=-1, keepdim=True)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = torch.exp(logp - old_logp)
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1.0 - self.CLIP_EPSILON, 1.0 + self.CLIP_EPSILON) * adv
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(value, ret)

                loss = actor_loss + critic_loss - self.ENTROPY_COEFF * entropy

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

        self.clear_buffer()
        return total_loss / max(self.PPO_EPOCHS, 1)


In [None]:
# ============================================================
# Cell 7: SAC Agent
# ============================================================

class SACAgent:
    """
    Soft Actor-Critic agent with twin Q-networks, target Q-networks and an
    automatically tuned entropy temperature.

    Args:
        obs_shape: Observation shape forwarded to the networks.
        action_dim: Number of action dimensions.
        hyperparams: Dict providing 'gamma', 'learning_rate' (actor), 'tau',
            'batch_size' and 'memory_size'; optional 'lr_critic' (3e-4),
            'lr_alpha' (1e-4) and 'alpha_start' (0.1).
        action_bounds: `(low, high)` pair used to map policy actions from
            [-1, 1] to environment units.
    """
    def __init__(self, obs_shape, action_dim, hyperparams, action_bounds):
        self.GAMMA = hyperparams['gamma']
        self.LR_ACTOR = hyperparams['learning_rate']
        self.LR_CRITIC = hyperparams.get('lr_critic', 3e-4)
        self.LR_ALPHA = hyperparams.get('lr_alpha', 1e-4)
        self.TAU = hyperparams['tau']
        self.BATCH_SIZE = hyperparams['batch_size']
        self.MEMORY_CAPACITY = hyperparams['memory_size']
        self.ACTION_MIN, self.ACTION_MAX = action_bounds
        self.obs_shape = obs_shape

        # Entropy regularization: log(alpha) is the learnable quantity.
        alpha_start = hyperparams.get('alpha_start', 0.1)
        self.log_alpha = torch.tensor(
            np.log(alpha_start),
            dtype=torch.float32,
            requires_grad=True,
            device=DEVICE,
        )
        # `alpha` is always a plain float detached from `log_alpha`'s graph, so the
        # critic/actor losses never back-propagate into the temperature; only
        # `alpha_loss` in learn() does.
        self.alpha = self.log_alpha.detach().exp().item()
        self.target_entropy = -torch.tensor(action_dim, dtype=torch.float32, device=DEVICE)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.LR_ALPHA)

        # Networks
        self.actor = StochasticActor(obs_shape, action_dim).to(DEVICE)
        self.q1 = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q2 = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q1_target = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q2_target = QFunction(obs_shape, action_dim).to(DEVICE)

        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.LR_ACTOR)
        self.q_params = list(self.q1.parameters()) + list(self.q2.parameters())
        self.q_optimizer = optim.Adam(self.q_params, lr=self.LR_CRITIC)

        self.memory = ReplayBuffer(self.MEMORY_CAPACITY)

    def _state_to_tensor(self, state):
        """Convert a raw observation to a float32 tensor on `DEVICE`."""
        return torch.as_tensor(state, dtype=torch.float32, device=DEVICE)

    def _get_action_and_logp(self, state_tensor):
        """
        Sample a tanh-squashed action with the reparameterisation trick.

        Returns:
            `(action, log_prob)`: action in [-1, 1] and its log-probability with
            shape (B, 1), including the tanh change-of-variables correction.
        """
        mu, log_std = self.actor(state_tensor)
        std = log_std.exp()
        dist = Normal(mu, std)
        z = dist.rsample()
        action = torch.tanh(z)
        log_prob = dist.log_prob(z).sum(dim=-1, keepdim=True)
        # Clamp keeps the log finite when |action| is numerically 1.
        log_prob -= torch.sum(torch.log(torch.clamp(1 - action.pow(2), 1e-6, 1.0)), dim=-1, keepdim=True)
        return action, log_prob

    def _soft_update(self, target_net, source_net):
        """Polyak-average `source_net` into `target_net` with rate TAU."""
        for target_param, param in zip(target_net.parameters(), source_net.parameters()):
            target_param.data.copy_(self.TAU * param.data + (1 - self.TAU) * target_param.data)

    def select_action(self, state, deterministic=False):
        """
        Return an action for `state` in environment units.

        With `deterministic=True` the squashed policy mean is used; otherwise an
        action is sampled from the policy. The actor is evaluated exactly once.
        """
        state_tensor = self._state_to_tensor(state)
        with torch.no_grad():
            if deterministic:
                mu, _ = self.actor(state_tensor)
                action_unscaled = torch.tanh(mu)
            else:
                action_unscaled, _ = self._get_action_and_logp(state_tensor)

            action_unscaled_np = action_unscaled.cpu().numpy().flatten()
            action_scaled = scale_action(action_unscaled_np, self.ACTION_MIN, self.ACTION_MAX)
            return action_scaled

    def store_transition(self, s, a, r, s_prime, done):
        """Store a transition; `a` is given in environment units and saved unscaled."""
        # store *unscaled* action in buffer (-1,1)
        unscaled_a = unscale_action(a, self.ACTION_MIN, self.ACTION_MAX)
        self.memory.push(s, unscaled_a, r, s_prime, done)

    def learn(self):
        """
        Perform one SAC update (critics, actor, temperature, target networks).

        Returns:
            `critic_loss + actor_loss` as a float, or None while the replay
            buffer holds fewer than BATCH_SIZE transitions.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return None

        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.BATCH_SIZE)

        # Critic update: soft Bellman target from the target critics.
        with torch.no_grad():
            next_action, next_log_prob = self._get_action_and_logp(next_state_batch)
            q1_target_val = self.q1_target(next_state_batch, next_action)
            q2_target_val = self.q2_target(next_state_batch, next_action)
            min_q_target = torch.min(q1_target_val, q2_target_val)
            target_v = min_q_target - self.alpha * next_log_prob
            q_target = reward_batch + self.GAMMA * (1 - done_batch) * target_v

        q1_val = self.q1(state_batch, action_batch)
        q2_val = self.q2(state_batch, action_batch)

        q1_loss = F.mse_loss(q1_val, q_target)
        q2_loss = F.mse_loss(q2_val, q_target)
        critic_loss = q1_loss + q2_loss

        self.q_optimizer.zero_grad()
        critic_loss.backward()
        self.q_optimizer.step()

        # Actor update: critics are frozen so only the actor receives gradients.
        for p in self.q_params:
            p.requires_grad = False

        new_action, log_prob = self._get_action_and_logp(state_batch)
        q1_new = self.q1(state_batch, new_action)
        q2_new = self.q2(state_batch, new_action)
        min_q_new = torch.min(q1_new, q2_new)

        actor_loss = (self.alpha * log_prob - min_q_new).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        for p in self.q_params:
            p.requires_grad = True

        # Alpha update: drive the policy entropy towards `target_entropy`.
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().item()

        # Soft update of target networks
        self._soft_update(self.q1_target, self.q1)
        self._soft_update(self.q2_target, self.q2)

        return critic_loss.item() + actor_loss.item()


In [None]:
# ============================================================
# Cell 8: TD3 Agent
# ============================================================

class TD3Agent:
    """
    Twin Delayed DDPG (TD3) agent: deterministic actor, twin critics, target
    policy smoothing and delayed actor / target-network updates.

    Args:
        obs_shape: Observation shape forwarded to the networks.
        action_dim: Number of action dimensions.
        hyperparams: Dict providing 'gamma', 'learning_rate' (actor), 'tau',
            'batch_size' and 'memory_size'; optional 'lr_critic' (1e-3),
            'policy_delay' (2), 'policy_noise' (0.2), 'noise_clip' (0.5) and
            'exploration_noise' (0.1, std of the Gaussian noise added to
            actions while exploring).
        action_bounds: `(low, high)` pair used to map policy actions from
            [-1, 1] to environment units.
    """
    def __init__(self, obs_shape, action_dim, hyperparams, action_bounds):
        self.GAMMA = hyperparams['gamma']
        self.LR_ACTOR = hyperparams['learning_rate']
        self.LR_CRITIC = hyperparams.get('lr_critic', 1e-3)
        self.TAU = hyperparams['tau']
        self.BATCH_SIZE = hyperparams['batch_size']
        self.MEMORY_CAPACITY = hyperparams['memory_size']
        self.POLICY_DELAY = hyperparams.get('policy_delay', 2)
        self.POLICY_NOISE = hyperparams.get('policy_noise', 0.2)
        self.NOISE_CLIP = hyperparams.get('noise_clip', 0.5)
        # Previously a hard-coded 0.1 inside select_action; the default keeps that value.
        self.EXPLORATION_NOISE = hyperparams.get('exploration_noise', 0.1)

        self.ACTION_MIN, self.ACTION_MAX = action_bounds
        self.obs_shape = obs_shape
        self.critic_updates = 0

        self.actor = DeterministicActor(obs_shape, action_dim).to(DEVICE)
        self.actor_target = DeterministicActor(obs_shape, action_dim).to(DEVICE)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.q1 = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q2 = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q1_target = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q2_target = QFunction(obs_shape, action_dim).to(DEVICE)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.LR_ACTOR)
        self.q_params = list(self.q1.parameters()) + list(self.q2.parameters())
        self.q_optimizer = optim.Adam(self.q_params, lr=self.LR_CRITIC)

        self.memory = ReplayBuffer(self.MEMORY_CAPACITY)

    def _state_to_tensor(self, state):
        """Convert a raw observation to a float32 tensor on `DEVICE`."""
        return torch.as_tensor(state, dtype=torch.float32, device=DEVICE)

    def _soft_update(self, target_net, source_net):
        """Polyak-average `source_net` into `target_net` with rate TAU."""
        for target_param, param in zip(target_net.parameters(), source_net.parameters()):
            target_param.data.copy_(self.TAU * param.data + (1 - self.TAU) * target_param.data)

    def select_action(self, state, deterministic=False):
        """
        Return an action for `state` in environment units.

        Unless `deterministic` is True, Gaussian noise with standard deviation
        EXPLORATION_NOISE is added in [-1, 1] action space and the result is
        clamped back to [-1, 1] before scaling.
        """
        state_tensor = self._state_to_tensor(state)
        with torch.no_grad():
            action_unscaled = self.actor(state_tensor)
            if not deterministic:
                noise = torch.randn_like(action_unscaled) * self.EXPLORATION_NOISE
                action_unscaled = torch.clamp(action_unscaled + noise, -1.0, 1.0)
            action_unscaled_np = action_unscaled.cpu().numpy().flatten()
            action_scaled = scale_action(action_unscaled_np, self.ACTION_MIN, self.ACTION_MAX)
            return action_scaled

    def store_transition(self, s, a, r, s_prime, done):
        """Store a transition; `a` is given in environment units and saved unscaled in [-1, 1]."""
        unscaled_a = unscale_action(a, self.ACTION_MIN, self.ACTION_MAX)
        self.memory.push(s, unscaled_a, r, s_prime, done)

    def learn(self):
        """
        Perform one TD3 update.

        The critics are updated on every call; the actor and all target
        networks are updated on every POLICY_DELAY-th call.

        Returns:
            `critic_loss` plus `actor_loss` (0.0 on calls without an actor
            update) as a float, or None while the replay buffer holds fewer
            than BATCH_SIZE transitions.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return None

        self.critic_updates += 1
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.BATCH_SIZE)

        with torch.no_grad():
            # Target policy smoothing: clipped noise on the target action.
            noise = (torch.randn_like(action_batch) * self.POLICY_NOISE).clamp(-self.NOISE_CLIP, self.NOISE_CLIP)
            next_action_unscaled = self.actor_target(next_state_batch)
            next_action = torch.clamp(next_action_unscaled + noise, -1.0, 1.0)

            # Clipped double-Q: bootstrap from the smaller target estimate.
            q1_target_val = self.q1_target(next_state_batch, next_action)
            q2_target_val = self.q2_target(next_state_batch, next_action)
            min_q_target = torch.min(q1_target_val, q2_target_val)

            q_target = reward_batch + self.GAMMA * (1 - done_batch) * min_q_target

        q1_val = self.q1(state_batch, action_batch)
        q2_val = self.q2(state_batch, action_batch)

        q1_loss = F.mse_loss(q1_val, q_target)
        q2_loss = F.mse_loss(q2_val, q_target)
        critic_loss = q1_loss + q2_loss

        self.q_optimizer.zero_grad()
        critic_loss.backward()
        self.q_optimizer.step()

        actor_loss = None
        if self.critic_updates % self.POLICY_DELAY == 0:
            q_actor = self.q1(state_batch, self.actor(state_batch))
            actor_loss = -q_actor.mean()

            # Gradients that reach q1 here are discarded by the next
            # q_optimizer.zero_grad(); only the actor is stepped.
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft updates
            self._soft_update(self.q1_target, self.q1)
            self._soft_update(self.q2_target, self.q2)
            self._soft_update(self.actor_target, self.actor)

        return critic_loss.item() + (actor_loss.item() if actor_loss is not None else 0.0)


In [None]:
# ============================================================
# Cell 9: Training & Testing Loops
# ============================================================

def train_agent(env_name, agent, algo_name, hyperparams, num_timesteps, seed=0, log_wandb=True):
    """
    Train `agent` on `env_name` for `num_timesteps` environment steps.

    Args:
        env_name: Gymnasium environment id passed to `make_env`.
        agent: A `PPOAgent`, `SACAgent` or `TD3Agent`.
        algo_name: "PPO" selects the on-policy branch; any other value uses the
            replay-buffer (off-policy) branch.
        hyperparams: Algorithm configuration; reads 'trajectory_size' (PPO,
            defaults to never updating within the run) and 'batch_size'
            (off-policy).
        num_timesteps: Total number of (frame-skipped) steps to run.
        seed: Seed handed to `make_env`.
        log_wandb: Whether to log losses and episode statistics to W&B.

    Returns:
        The same `agent` object, trained in place.
    """
    env = make_env(env_name, seed=seed, render_mode=None)
    total_steps = 0
    episodes = 0

    ppo_steps_per_update = hyperparams.get('trajectory_size', num_timesteps + 1)

    # try/finally guarantees the environment is released even if training fails
    # or is interrupted.
    try:
        while total_steps < num_timesteps:
            state, info = env.reset()
            episode_reward = 0.0
            episode_steps = 0
            done = False

            while not done:
                action_scaled = agent.select_action(state)
                next_state, reward, terminated, truncated, info = env.step(action_scaled)
                done = terminated or truncated

                if algo_name == "PPO":
                    # NOTE: for PPO a truncated episode is treated like a terminated
                    # one (mask 0 / no bootstrap), as in the original loop.
                    agent.store_reward(reward, done)
                    if len(agent.buffer['s']) >= ppo_steps_per_update:
                        loss = agent.learn(next_state if not done else None)
                        if log_wandb and loss is not None:
                            wandb.log({"train/ppo_loss": loss}, step=total_steps)
                else:
                    # Only a true termination may zero the bootstrap term of the TD
                    # target. A time-limit truncation still has a valid next state,
                    # so `terminated` (not `done`) is what goes into the buffer.
                    agent.store_transition(state, action_scaled, reward, next_state, terminated)
                    if len(agent.memory) >= hyperparams['batch_size']:
                        loss = agent.learn()
                        if log_wandb and loss is not None:
                            wandb.log({"train/offpolicy_loss": loss}, step=total_steps)

                state = next_state
                episode_reward += reward
                episode_steps += 1
                total_steps += 1

                if total_steps >= num_timesteps:
                    break

            episodes += 1
            if log_wandb:
                wandb.log(
                    {
                        "train/episode_reward": episode_reward,
                        "train/episode_steps": episode_steps,
                        "train/episodes": episodes,
                    },
                    step=total_steps,
                )

            print(
                f"Env: {env_name} | Algo: {algo_name} | "
                f"Total Steps: {total_steps}/{num_timesteps}, "
                f"Episode {episodes}, Steps: {episode_steps}, Reward: {episode_reward:.2f}"
            )
    finally:
        env.close()

    return agent


def test_agent(env_name, agent, algo_name, num_tests=100, seed=123, record_video=False):
    """
    Evaluate `agent` with deterministic actions for `num_tests` episodes.

    Args:
        env_name: Gymnasium environment id passed to `make_env`.
        agent: Trained agent exposing `select_action(state, deterministic=True)`.
        algo_name: Used only to name the video folder and file prefix.
        num_tests: Number of evaluation episodes.
        seed: Seed handed to `make_env`.
        record_video: When True, the first episode is recorded to
            `./videos/{env_name}_{algo_name}`.

    Returns:
        `(mean_reward, std_reward, rewards)` where `rewards` lists the total
        reward of every episode.
    """
    video_dir = f"./videos/{env_name}_{algo_name}"
    if record_video:
        env = gym.wrappers.RecordVideo(
            make_env(env_name, seed=seed, render_mode="rgb_array"),
            video_folder=video_dir,
            episode_trigger=lambda i: i == 0,  # record only the first episode
            name_prefix=f"{algo_name}_test",
        )
    else:
        env = make_env(env_name, seed=seed, render_mode=None)

    rewards = []
    # try/finally ensures the environment (and any open video recorder) is closed
    # even when an episode raises.
    try:
        for _ in range(num_tests):
            state, info = env.reset()
            done = False
            ep_reward = 0.0
            while not done:
                action = agent.select_action(state, deterministic=True)
                next_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                state = next_state
                ep_reward += reward
            rewards.append(ep_reward)
    finally:
        env.close()

    avg_r = float(np.mean(rewards))
    std_r = float(np.std(rewards))
    return avg_r, std_r, rewards


In [None]:
# ============================================================
# Cell 10: Hyperparameters, ENV Config, Agent Factory
# ============================================================


# Total training timesteps per environment; the main loop passes the value to
# run_experiment as `total_timesteps`.
ENV_TIMESTEPS = {
    "LunarLander-v3": 300_000,
    "CarRacing-v3":   100_000,  
}

# Per-environment seed: stored in each run's config as "seed" and used by
# run_experiment for global seeding and environment creation.
ENV_SEEDS = {
    "LunarLander-v3": 42,
    "CarRacing-v3":   100,
}

# Base hyperparams for the 3 algorithms.
# The main loop copies the entry for the chosen algorithm and then applies the
# matching ENV_ALGO_OVERRIDES entry on top of it.
BASE_HYPERPARAMS = {
    "PPO": {
        "gamma": 0.99,              # discount factor
        "learning_rate": 3e-4,      # Adam learning rate of the actor-critic
        "trajectory_size": 4096,    # buffered steps that trigger a PPO update
        "minibatch_size": 256,
        "clip_epsilon": 0.2,        # clipping range of the probability ratio
        "ppo_epochs": 10,           # passes over the rollout per update
        "entropy_coeff": 0.01,      # weight of the entropy bonus in the loss
        "memory_size": 0,     # unused
        "batch_size": 0,       # unused
        "tau": 0.0,            # unused
        "alpha_start": 0.0,    # unused
    },
    "SAC": {
        "gamma": 0.99,
        "learning_rate": 3e-4,      # actor learning rate
        "lr_critic": 3e-4,
        "lr_alpha": 1e-4,           # learning rate of the entropy temperature
        # NOTE(review): the replay buffer stores raw observations; with image
        # observations a capacity of 1M transitions is presumably far beyond
        # available RAM - confirm before long CarRacing runs.
        "memory_size": 1_000_000,
        "batch_size": 256,
        "tau": 0.005,               # Polyak rate of the target networks
        "alpha_start": 0.1,         # initial entropy temperature
    },
    "TD3": {
        "gamma": 0.99,
        "learning_rate": 1e-3,      # actor learning rate
        "lr_critic": 1e-3,
        "memory_size": 1_000_000,
        "batch_size": 256,
        "tau": 0.005,               # Polyak rate of the target networks
        "policy_delay": 2,          # critic updates per actor/target update
        "policy_noise": 0.2,        # std of the target-policy smoothing noise
        "noise_clip": 0.5,          # clip range of the smoothing noise
    },
}

# Per-env overrides: ENV_ALGO_OVERRIDES[env_name][algo_name] is merged over
# BASE_HYPERPARAMS[algo_name] by the main loop, so any key listed here replaces
# the base value for that environment/algorithm pair.
ENV_ALGO_OVERRIDES = {
    "LunarLander-v3": {
        "PPO": {
            "learning_rate": 3e-4,
            "trajectory_size": 4096,
            "minibatch_size": 256,
            "ppo_epochs": 10,
            "entropy_coeff": 0.01,
        },
        "SAC": {
            "learning_rate": 3e-4,
            "lr_critic": 3e-4,
            "batch_size": 256,
            "tau": 0.01,
            "alpha_start": 0.1,
        },
        "TD3": {
            "learning_rate": 1e-3,
            "lr_critic": 1e-3,
            "batch_size": 256,
            "tau": 0.01,
        },
    },
    "CarRacing-v3": {
        "PPO": {
            "learning_rate": 1e-4,
            "trajectory_size": 8192,
            "minibatch_size": 512,
            "ppo_epochs": 10,
            "entropy_coeff": 0.0,   # no entropy bonus for this environment
        },
        "SAC": {
            "learning_rate": 3e-4,
            "lr_critic": 3e-4,
            "batch_size": 256,
            "tau": 0.01,
            "alpha_start": 0.1,
        },
        "TD3": {
            "learning_rate": 1e-4,
            "lr_critic": 1e-3,
            "batch_size": 512,
            "tau": 0.005,
            "policy_noise": 0.2,
            "noise_clip": 0.5,
        },
    },
}

# Environments to run. Only the keys are iterated by the main loop; comment an
# entry in or out to enable or disable that environment.
ENVIRONMENTS = {
    # "LunarLander-v3": {},
    "CarRacing-v3":   {},
}

# Runs to launch per environment: sweep name -> {"algo": algorithm name}.
# The sweep name becomes the run-name suffix and the config's "name" entry.
SWEEP_VARIATIONS = {
    # "PPO_BASE": {"algo": "PPO"},
    "SAC_BASE": {"algo": "SAC"},
    # "TD3_BASE": {"algo": "TD3"},
}

def create_agent(env_name, algo_name, config):
    """
    Instantiate the agent for `algo_name` ("PPO", "SAC" or "TD3") sized for `env_name`.

    A throw-away environment is created only to read the observation shape, the
    action dimensionality and the action bounds, and is closed again before the
    agent is built.

    Raises:
        ValueError: If `algo_name` is not one of the supported algorithms.
    """
    probe_env = make_env(env_name, seed=0, render_mode=None)
    agent_args = (
        probe_env.observation_space.shape,
        probe_env.action_space.shape[0],
        config,
        get_action_bounds(probe_env),
    )
    probe_env.close()

    if algo_name == "PPO":
        return PPOAgent(*agent_args)
    if algo_name == "SAC":
        return SACAgent(*agent_args)
    if algo_name == "TD3":
        return TD3Agent(*agent_args)
    raise ValueError(f"Unknown algo: {algo_name}")


In [None]:
# ============================================================
# Cell 11: W&B + Experiment Runner
# ============================================================


WANDB_PROJECT = "CMPS458_SOTA_Continuous_Assignment"

def run_experiment(env_name, algo_name, config, total_timesteps, run_name_suffix=""):
    """
    Train and evaluate one algorithm on one environment inside a W&B run.

    Steps: seed all RNGs, open a W&B run named
    `{env_name}_{algo_name}_{run_name_suffix}`, build and train the agent,
    evaluate it for 100 episodes (recording a video of the first one) and log
    the mean/std test reward plus a table of per-episode rewards.

    The W&B run is always closed: normally on success, and with a non-zero exit
    code if training or testing raises (the exception is re-raised).

    Args:
        env_name: Gymnasium environment id; must be a key of ENV_SEEDS.
        algo_name: "PPO", "SAC" or "TD3".
        config: Hyperparameter dict; an optional "seed" entry overrides
            ENV_SEEDS[env_name].
        total_timesteps: Number of training steps.
        run_name_suffix: Suffix appended to the W&B run name.
    """
    seed = config.get("seed", ENV_SEEDS[env_name])
    set_global_seed(seed)

    run_name = f"{env_name}_{algo_name}_{run_name_suffix}"

    wandb.init(
        project=WANDB_PROJECT,
        config=config,
        name=run_name,
        reinit=True,
    )

    try:
        # Create agent
        agent = create_agent(env_name, algo_name, config)

        # Training
        print(f"\n--- Training {algo_name} on {env_name} for {total_timesteps} steps (seed={seed}) ---")
        trained_agent = train_agent(
            env_name,
            agent,
            algo_name,
            config,
            num_timesteps=total_timesteps,
            seed=seed,
            log_wandb=True,
        )

        # Testing (a different seed than training)
        print(f"\n--- Testing {algo_name} on {env_name} (100 episodes) ---")
        avg_r, std_r, rewards = test_agent(
            env_name,
            trained_agent,
            algo_name,
            num_tests=100,
            seed=seed + 1,
            record_video=True, 
        )

        rewards_table = wandb.Table(data=[[r] for r in rewards], columns=["episode_reward"])

        wandb.log(
            {
                "test/avg_reward": avg_r,
                "test/std_reward": std_r,
                "test/rewards_table": rewards_table,
            }
        )

        print(f"RESULTS {run_name}")
        print(f"  Avg Reward: {avg_r:.2f}")
        print(f"  Std Reward: {std_r:.2f}")
    except BaseException:
        # Includes KeyboardInterrupt: an interrupted or crashed run must not stay
        # open, and is marked as failed in W&B.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()


In [None]:
# ============================================================
# Cell 12: Main - Run All 3 Algorithms on Both Environments
# ============================================================

# Launch one experiment for every enabled (environment, sweep variation) pair.
for env_name in ENVIRONMENTS:
    total_steps = ENV_TIMESTEPS[env_name]
    env_seed = ENV_SEEDS[env_name]

    for sweep_name, sweep_conf in SWEEP_VARIATIONS.items():
        algo_name = sweep_conf["algo"]

        # Algorithm defaults first, environment-specific overrides on top.
        base_hp = dict(BASE_HYPERPARAMS[algo_name])
        env_overrides = ENV_ALGO_OVERRIDES[env_name][algo_name]
        base_hp.update(env_overrides)

        # Run metadata is appended after the hyperparameters.
        final_config = dict(
            base_hp,
            algo=algo_name,
            env_name=env_name,
            seed=env_seed,
            name=sweep_name,
        )

        run_experiment(
            env_name=env_name,
            algo_name=algo_name,
            config=final_config,
            total_timesteps=total_steps,
            run_name_suffix=sweep_name,
        )
