In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# MLP Networks
class MLP(nn.Module):
    """Plain fully-connected network.

    Each hidden width in ``hidden_dims`` becomes a ``Linear`` layer followed by
    a fresh ``activation()`` module; a final ``Linear`` maps to ``output_dim``,
    optionally followed by ``output_activation()``.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation=nn.ReLU, output_activation=None):
        super().__init__()
        widths = [input_dim] + hidden_dims
        modules = []
        # Consecutive (fan_in, fan_out) pairs describe the hidden layers.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(activation())
        modules.append(nn.Linear(widths[-1], output_dim))
        if output_activation:
            modules.append(output_activation())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed through ``tanh``.

    A single MLP emits ``2 * action_dim`` values, split into the mean and the
    log standard deviation; the log-std is clamped to
    ``[log_std_min, log_std_max]``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dims=[256, 256], log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_min, log_std_max
        self.net = MLP(obs_dim, hidden_dims, action_dim * 2)
        self.action_dim = action_dim

    def forward(self, obs):
        """Return ``(mean, log_std)`` for ``obs``."""
        mean, raw_log_std = torch.split(self.net(obs), self.action_dim, dim=-1)
        return mean, raw_log_std.clamp(self.log_std_min, self.log_std_max)

    def sample(self, obs, deterministic=False):
        """Draw a squashed action.

        Returns ``(action, log_prob)``, where ``log_prob`` has a trailing
        singleton dimension. With ``deterministic=True`` the action is
        ``tanh(mean)`` and ``log_prob`` is ``None``.
        """
        mu, log_sigma = self.forward(obs)
        sigma = log_sigma.exp()

        if deterministic:
            return torch.tanh(mu), None

        dist = torch.distributions.Normal(mu, sigma)
        pre_tanh = dist.rsample()  # reparameterised sample keeps gradients
        squashed = torch.tanh(pre_tanh)
        # Change-of-variables term for the tanh squashing (1e-6 avoids log(0)).
        jacobian = torch.log(1 - squashed.pow(2) + 1e-6)
        log_prob = (dist.log_prob(pre_tanh) - jacobian).sum(dim=-1, keepdim=True)
        return squashed, log_prob

class Critic(nn.Module):
    """Q-network: an MLP over the concatenation of observation and action."""

    def __init__(self, obs_dim, action_dim, hidden_dims=[256, 256]):
        super().__init__()
        self.net = MLP(obs_dim + action_dim, hidden_dims, 1)

    def forward(self, obs, action):
        """Return Q(obs, action) with a trailing dimension of size 1."""
        joint = torch.cat((obs, action), dim=-1)
        return self.net(joint)

class TwinCritic(nn.Module):
    """Two independently initialised ``Critic`` networks evaluated side by side."""

    def __init__(self, obs_dim, action_dim, hidden_dims=[256, 256]):
        super().__init__()
        self.q1 = Critic(obs_dim, action_dim, hidden_dims)
        self.q2 = Critic(obs_dim, action_dim, hidden_dims)

    def forward(self, obs, action):
        """Return the pair ``(Q1(obs, action), Q2(obs, action))``."""
        first = self.q1(obs, action)
        second = self.q2(obs, action)
        return first, second

    def min_q(self, obs, action):
        """Element-wise minimum of the two Q estimates."""
        q_a, q_b = self.forward(obs, action)
        return torch.min(q_a, q_b)

# CNN Networks
class NatureCNN(nn.Module):
    """Three conv layers (8x8/4, 4x4/2, 3x3/1) plus a linear projection, all ReLU.

    ``obs_shape`` may be channels-last ``(H, W, C)`` (the layout produced by the
    gym wrappers) or channels-first ``(C, H, W)``. A 3-D shape whose last entry
    is <= 4 is treated as channels-last; in that case ``forward`` permutes 4-D
    input ``(B, H, W, C)`` to ``(B, C, H, W)``. Channels-first input is passed to
    the convolutions unchanged.

    Args:
        obs_shape: Observation shape, channels-last or channels-first.
        feature_dim: Size of the output feature vector.
    """

    def __init__(self, obs_shape, feature_dim=512):
        super().__init__()
        obs_shape = tuple(obs_shape)
        # Heuristic: a small trailing dimension is a channel count (gym's (H, W, C)).
        self.channels_last = len(obs_shape) == 3 and obs_shape[2] <= 4
        if self.channels_last:
            self.input_shape = (obs_shape[2], obs_shape[0], obs_shape[1])  # (C, H, W)
        else:
            self.input_shape = obs_shape  # already (C, H, W)

        self.conv1 = nn.Conv2d(self.input_shape[0], 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)

        # Infer the flattened conv output size with a dummy forward pass.
        with torch.no_grad():
            dummy = torch.zeros(1, *self.input_shape)
            dummy = self.conv3(self.conv2(self.conv1(dummy)))
            conv_out_size = dummy.numel()

        self.fc = nn.Linear(conv_out_size, feature_dim)
        self.feature_dim = feature_dim

    def forward(self, x):
        """Encode a batch of observations into ``(B, feature_dim)`` features."""
        # Only convert the layout when the encoder was built for channels-last
        # input; permuting channels-first input would scramble the channel axis.
        if self.channels_last and x.dim() == 4:  # (B, H, W, C) -> (B, C, H, W)
            x = x.permute(0, 3, 1, 2).contiguous()
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.reshape(x.size(0), -1)
        return F.relu(self.fc(x))

class ConvGaussianPolicy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy on top of ``NatureCNN`` features.

    Separate MLP heads produce the mean and the log standard deviation; the
    log-std is clamped to ``[log_std_min, log_std_max]``.
    """

    def __init__(self, obs_shape, action_dim, feature_dim=512, hidden_dims=[256, 256], log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_min, log_std_max
        self.cnn = NatureCNN(obs_shape, feature_dim)
        self.mean_net = MLP(feature_dim, hidden_dims, action_dim)
        self.log_std_net = MLP(feature_dim, hidden_dims, action_dim)
        self.action_dim = action_dim

    def forward(self, obs):
        """Return ``(mean, log_std)`` for image observations ``obs``."""
        feats = self.cnn(obs)
        mu = self.mean_net(feats)
        raw_log_std = self.log_std_net(feats)
        return mu, raw_log_std.clamp(self.log_std_min, self.log_std_max)

    def sample(self, obs, deterministic=False):
        """Draw a squashed action.

        Returns ``(action, log_prob)`` with ``log_prob`` shaped ``(..., 1)``;
        when ``deterministic`` is true the action is ``tanh(mean)`` and
        ``log_prob`` is ``None``.
        """
        mu, log_sigma = self.forward(obs)
        sigma = log_sigma.exp()

        if deterministic:
            return torch.tanh(mu), None

        dist = torch.distributions.Normal(mu, sigma)
        pre_tanh = dist.rsample()
        squashed = torch.tanh(pre_tanh)
        # tanh change-of-variables correction; 1e-6 keeps the log finite.
        jacobian = torch.log(1 - squashed.pow(2) + 1e-6)
        log_prob = (dist.log_prob(pre_tanh) - jacobian).sum(dim=-1, keepdim=True)
        return squashed, log_prob

class ConvCritic(nn.Module):
    """Q-network that encodes images with ``NatureCNN`` and appends the action."""

    def __init__(self, obs_shape, action_dim, feature_dim=512, hidden_dims=[256, 256]):
        super().__init__()
        self.cnn = NatureCNN(obs_shape, feature_dim)
        self.net = MLP(feature_dim + action_dim, hidden_dims, 1)

    def forward(self, obs, action):
        """Return Q(obs, action) with a trailing dimension of size 1."""
        encoded = self.cnn(obs)
        joint = torch.cat((encoded, action), dim=-1)
        return self.net(joint)

class ConvTwinCritic(nn.Module):
    """Two ``ConvCritic`` networks, each with its own CNN encoder."""

    def __init__(self, obs_shape, action_dim, feature_dim=512, hidden_dims=[256, 256]):
        super().__init__()
        self.q1 = ConvCritic(obs_shape, action_dim, feature_dim, hidden_dims)
        self.q2 = ConvCritic(obs_shape, action_dim, feature_dim, hidden_dims)

    def forward(self, obs, action):
        """Return the pair ``(Q1(obs, action), Q2(obs, action))``."""
        first = self.q1(obs, action)
        second = self.q2(obs, action)
        return first, second

    def min_q(self, obs, action):
        """Element-wise minimum of the two Q estimates."""
        q_a, q_b = self.forward(obs, action)
        return torch.min(q_a, q_b)

In [None]:
import numpy as np
import torch

class ReplayBuffer:
    """Fixed-capacity circular transition store backed by NumPy arrays.

    Observations, actions and rewards are kept as float32, ``dones`` as bool.
    Once ``capacity`` transitions are stored, new ones overwrite the oldest.
    """

    def __init__(self, capacity, obs_shape, action_dim):
        self.capacity = capacity
        obs_dims = (capacity, *obs_shape)
        self.obs = np.zeros(obs_dims, dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_obs = np.zeros(obs_dims, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)
        self.ptr = 0   # slot the next transition is written to
        self.size = 0  # number of valid transitions

    def __len__(self):
        return self.size

    def add(self, obs, action, reward, next_obs, done):
        """Write one transition at the current slot and advance the ring pointer."""
        slot = self.ptr
        self.obs[slot] = obs
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.next_obs[slot] = next_obs
        self.dones[slot] = done
        self.ptr = (slot + 1) % self.capacity
        if self.size < self.capacity:
            self.size += 1

    def sample(self, batch_size, device):
        """Sample ``batch_size`` transitions uniformly (with replacement) as float32 tensors."""
        picks = np.random.randint(0, self.size, batch_size)

        def as_batch(array):
            return torch.as_tensor(array[picks], dtype=torch.float32, device=device)

        return {
            'obs': as_batch(self.obs),
            'actions': as_batch(self.actions),
            'rewards': as_batch(self.rewards),
            'next_obs': as_batch(self.next_obs),
            'dones': as_batch(self.dones)
        }

class PPORolloutBuffer:
    """List-backed storage for one on-policy rollout.

    ``store`` appends one step; ``get`` returns everything as NumPy arrays
    (float32, except bool ``dones``) and empties the buffer.
    """

    def __init__(self):
        self.obs, self.actions, self.rewards = [], [], []
        self.values, self.logps, self.dones = [], [], []

    def store(self, obs, action, reward, value, logp, done):
        """Append a single time step."""
        columns = (self.obs, self.actions, self.rewards, self.values, self.logps, self.dones)
        entries = (obs, action, reward, value, logp, done)
        for column, entry in zip(columns, entries):
            column.append(entry)

    def get(self):
        """Return the rollout as a dict of arrays and reset the buffer."""
        float_fields = ('obs', 'actions', 'rewards', 'values', 'logps')
        data = {name: np.array(getattr(self, name), dtype=np.float32) for name in float_fields}
        data['dones'] = np.array(self.dones, dtype=np.bool_)
        self.__init__()
        return data

In [None]:
import cv2
import gymnasium as gym
from gymnasium import Wrapper, ObservationWrapper
from collections import deque
import numpy as np

class PreprocessCarRacing(ObservationWrapper):
    """Turn RGB frames into cropped, resized, [0, 1] float32 grayscale images.

    Output observations have shape ``(resize[0], resize[1], 1)``, i.e.
    ``resize`` is interpreted as ``(height, width)``.

    Args:
        env: Environment whose observations are RGB images accepted by
            ``cv2.cvtColor(..., COLOR_RGB2GRAY)``.
        resize: Target ``(height, width)`` of the processed frame.
        crop_bottom: Number of rows removed from the bottom of the frame
            before resizing, to drop the score/indicator bar.
            NOTE(review): assumes the bar is drawn along the bottom edge as in
            Gymnasium's CarRacing renderer (12 rows of the 96x96 state frame)
            — confirm for other render sizes.
    """

    def __init__(self, env, resize=(84, 84), crop_bottom=12):
        super().__init__(env)
        self.resize = resize
        self.crop_bottom = crop_bottom
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0,
            shape=(resize[0], resize[1], 1),
            dtype=np.float32
        )

    def observation(self, obs):
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        # Remove the score bar at the bottom. An explicit stop index (instead of
        # a negative slice) keeps the whole frame when crop_bottom == 0.
        cropped = gray[:gray.shape[0] - self.crop_bottom, :]
        # cv2.resize expects dsize as (width, height); self.resize is
        # (height, width) to match observation_space.
        height, width = self.resize
        resized = cv2.resize(cropped, (width, height), interpolation=cv2.INTER_AREA)
        return resized[:, :, np.newaxis].astype(np.float32) / 255.0

class FrameStack(Wrapper):
    """Stack the most recent ``num_stack`` observations along the last axis.

    Assumes the wrapped env yields 3-D ``(H, W, C)`` observations; the stacked
    observation space is ``(H, W, C * num_stack)`` in ``[0, 1]``.
    """

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        base_shape = env.observation_space.shape
        height, width, channels = base_shape[0], base_shape[1], base_shape[2]
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0,
            shape=(height, width, channels * num_stack),
            dtype=np.float32
        )

    def reset(self, **kwargs):
        """Reset the env and fill the whole stack with the first observation."""
        first_obs, info = self.env.reset(**kwargs)
        self.frames.extend([first_obs] * self.num_stack)
        return self._get_obs(), info

    def step(self, action):
        """Step the env and push the new observation onto the stack."""
        next_obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(next_obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        return np.concatenate(list(self.frames), axis=-1)

In [None]:
from abc import ABC, abstractmethod

class BaseAgent(ABC):
    """Abstract interface shared by the agents in this module."""

    @abstractmethod
    def select_action(self, obs, deterministic=False):
        """Pick an action for ``obs``; ``deterministic`` requests the greedy action."""

    @abstractmethod
    def train_step(self):
        """Run one training update.

        Implementations here return a dict of training statistics, or ``None``
        when no update was performed.
        """

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils.networks import GaussianPolicy, MLP
from utils.buffers import PPORolloutBuffer
import numpy as np

class PPOAgent(BaseAgent):
    """Proximal Policy Optimization agent for flat (vector) observations.

    The actor is a tanh-squashed ``GaussianPolicy`` and the critic an ``MLP``
    state-value network. The caller appends rollout steps to ``self.buffer``
    (a ``PPORolloutBuffer``); ``train_step`` consumes the whole buffer,
    computes GAE advantages and runs minibatch clipped-surrogate updates with
    KL-based early stopping.

    ``train_v_iters`` and ``max_ep_len`` are stored but not used by this class.
    """

    def __init__(self, obs_dim, action_dim, lr=3e-4, gamma=0.99, clip_ratio=0.2,
                 lam=0.95, train_pi_iters=80, train_v_iters=80, target_kl=0.01,
                 hidden_dims=[64, 64], max_ep_len=1000, ent_coef=0.0, batch_size=64, device='cuda'):
        self.device = device
        self.gamma = gamma
        self.clip_ratio = clip_ratio
        self.lam = lam
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.target_kl = target_kl
        self.max_ep_len = max_ep_len
        self.ent_coef = ent_coef
        self.batch_size = batch_size

        self.actor = GaussianPolicy(obs_dim, action_dim, hidden_dims).to(device)
        self.critic = MLP(obs_dim, hidden_dims, 1).to(device)

        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)

        self.buffer = PPORolloutBuffer()

    def select_action(self, obs, deterministic=False):
        """Act on a single observation.

        Returns:
            ``(action, log_prob, value)``: the squashed action as a NumPy array,
            its log-probability as a float (0.0 when ``deterministic``, since
            the actor then returns no log-prob), and the critic's value estimate
            as a float.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # actor.sample returns exactly (action, log_prob).
            action, log_prob = self.actor.sample(obs, deterministic)
            value = self.critic(obs)
        action_np = action.cpu().numpy()[0]
        value_f = value.cpu().numpy()[0].item()
        # log_prob is None when deterministic=True
        logp_f = 0.0 if log_prob is None else log_prob.cpu().numpy()[0].item()
        return action_np, logp_f, value_f

    def compute_gae(self, rewards, values, dones, last_value=0.0):
        """Generalized Advantage Estimation over one rollout.

        ``dones[t]`` marks that the episode ended after step ``t``. The value
        after the final step is ``last_value`` (default 0.0, i.e. no
        bootstrap); pass the critic's estimate to bootstrap a truncated rollout.

        Returns:
            ``(returns, advantages)`` arrays with the same shape/dtype as ``rewards``.
        """
        returns = np.zeros_like(rewards)
        advantages = np.zeros_like(rewards)
        last_gae = 0.0
        n = len(rewards)

        for t in reversed(range(n)):
            next_non_terminal = 1.0 - float(dones[t])
            next_value = last_value if t == n - 1 else values[t + 1]
            delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
            last_gae = delta + self.gamma * self.lam * next_non_terminal * last_gae
            advantages[t] = last_gae
            returns[t] = advantages[t] + values[t]

        return returns, advantages

    def _evaluate_log_prob(self, obs, actions):
        """Return ``(log pi(actions | obs), log_std)`` under the current actor.

        ``actions`` are the tanh-squashed values produced by ``actor.sample``;
        they are mapped back through atanh (clamped away from +/-1 so the
        result stays finite) and scored with the same tanh correction term
        ``actor.sample`` uses. The log-prob is summed over action dimensions.
        """
        mean, log_std = self.actor.forward(obs)
        clipped = actions.clamp(-1.0 + 1e-6, 1.0 - 1e-6)
        pre_tanh = torch.atanh(clipped)
        normal = torch.distributions.Normal(mean, log_std.exp())
        log_prob = normal.log_prob(pre_tanh) - torch.log(1 - clipped.pow(2) + 1e-6)
        return log_prob.sum(dim=-1), log_std

    def train_step(self):
        """Update actor and critic on the buffered rollout; the buffer is emptied.

        Returns:
            ``None`` if the buffer is empty, otherwise a dict with the mean
            ``actor_loss``, ``critic_loss``, approximate ``kl_divergence`` and
            ``entropy`` over all minibatches, plus ``training_epochs`` run.
        """
        # Train when called - train.py handles the timing based on rollout_length
        if len(self.buffer.obs) == 0:
            return None

        data = self.buffer.get()  # also resets the buffer
        returns, advantages = self.compute_gae(data['rewards'], data['values'], data['dones'])

        def to_tensor(array):
            return torch.as_tensor(array, dtype=torch.float32, device=self.device)

        old_obs = to_tensor(data['obs'])
        old_actions = to_tensor(data['actions'])
        old_logps = to_tensor(data['logps'])
        returns = to_tensor(returns)
        advantages = to_tensor(advantages)
        # std() of a single element is NaN; only normalize when it is defined.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        actor_losses, critic_losses, kl_divs, entropies = [], [], [], []

        dataset_size = len(data['obs'])
        indices = np.arange(dataset_size)
        # Minibatches per epoch (ceil division); the KL check averages over one epoch.
        num_minibatches = -(-dataset_size // self.batch_size)
        epochs_run = 0

        for _ in range(self.train_pi_iters):
            epochs_run += 1
            np.random.shuffle(indices)

            for start in range(0, dataset_size, self.batch_size):
                idx = indices[start:start + self.batch_size]

                batch_obs = old_obs[idx]
                batch_act = old_actions[idx]
                batch_logp = old_logps[idx]
                batch_adv = advantages[idx]
                batch_ret = returns[idx]

                # Score the *stored* actions under the current policy; sampling
                # fresh actions here would make the importance ratio meaningless.
                new_logps, log_std = self._evaluate_log_prob(batch_obs, batch_act)
                # Entropy of the pre-tanh Gaussian (tanh correction ignored).
                entropy = (0.5 + 0.5 * np.log(2 * np.pi) + log_std).sum(dim=-1)

                ratio = torch.exp(new_logps - batch_logp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_adv
                # Clipped surrogate plus entropy bonus (no-op when ent_coef == 0).
                actor_loss = -torch.min(surr1, surr2).mean() - self.ent_coef * entropy.mean()

                # squeeze(-1) keeps the batch dimension for size-1 minibatches.
                values = self.critic(batch_obs).squeeze(-1)
                critic_loss = F.mse_loss(values, batch_ret)

                self.actor_optim.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                self.actor_optim.step()

                self.critic_optim.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
                self.critic_optim.step()

                # Sample-based approximate KL(old || new), from pre-update log-probs.
                with torch.no_grad():
                    kl = (batch_logp - new_logps).mean()
                kl_divs.append(kl.item())
                actor_losses.append(actor_loss.item())
                critic_losses.append(critic_loss.item())
                entropies.append(entropy.mean().item())

            # Early stopping based on KL divergence (averaged over the last epoch)
            if np.mean(kl_divs[-num_minibatches:]) > self.target_kl * 1.5:
                break

        return {
            'actor_loss': np.mean(actor_losses),
            'critic_loss': np.mean(critic_losses),
            'kl_divergence': np.mean(kl_divs),
            'training_epochs': epochs_run,
            'entropy': np.mean(entropies)
        }

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from utils.networks import ConvGaussianPolicy, NatureCNN, MLP
from utils.buffers import PPORolloutBuffer

class PPOAgentCNN(PPOAgent):
    """PPO agent whose actor and critic encode image observations with ``NatureCNN``.

    ``PPOAgent.__init__`` is bypassed (it builds MLPs for flat observations);
    the same hyperparameters are read from ``**kwargs`` instead. ``lr`` and the
    keys listed below are required; ``ent_coef``, ``batch_size`` and
    ``hidden_dims`` are optional. Training logic is inherited from ``PPOAgent``.
    """

    def __init__(self, obs_shape, action_dim, feature_dim=512, **kwargs):
        # Required hyperparameters, read in a fixed order (a missing one raises KeyError).
        for key in ('device', 'gamma', 'clip_ratio', 'lam', 'train_pi_iters',
                    'train_v_iters', 'target_kl', 'max_ep_len'):
            setattr(self, key, kwargs[key])
        # Optional hyperparameters.
        self.ent_coef = kwargs.get('ent_coef', 0.0)
        self.batch_size = kwargs.get('batch_size', 64)
        hidden_dims = kwargs.get('hidden_dims', [256, 256])

        # CNN actor
        self.actor = ConvGaussianPolicy(
            obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims
        ).to(self.device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=kwargs['lr'])

        class CriticCNN(torch.nn.Module):
            """State-value network: NatureCNN features -> MLP -> one scalar."""

            def __init__(self, obs_shape, feature_dim, hidden_dims):
                super().__init__()
                self.cnn = NatureCNN(obs_shape, feature_dim)
                self.fc = MLP(feature_dim, hidden_dims, 1)

            def forward(self, obs):
                encoded = self.cnn(obs)
                return self.fc(encoded)

        # CNN critic
        self.critic = CriticCNN(obs_shape, feature_dim, hidden_dims).to(self.device)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=kwargs['lr'])

        self.buffer = PPORolloutBuffer()

    def select_action(self, obs, deterministic=False):
        """Return ``(action, log_prob, value)`` for a single image observation.

        ``log_prob`` is reported as 0.0 when ``deterministic`` is true, because
        the actor returns no log-probability in that case.
        """
        batch = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action, log_prob = self.actor.sample(batch, deterministic)
            value = self.critic(batch)
        act = action.cpu().numpy()[0]
        val = value.cpu().numpy()[0].item()
        logp = 0.0 if log_prob is None else log_prob.cpu().numpy()[0].item()
        return act, logp, val

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils.networks import GaussianPolicy, TwinCritic
from utils.buffers import ReplayBuffer

class SACAgent(BaseAgent):
    """Soft Actor-Critic agent for flat (vector) observations.

    Uses a tanh-squashed ``GaussianPolicy`` actor, a ``TwinCritic`` with a
    Polyak-averaged target copy, and optionally tunes the entropy temperature
    ``alpha`` towards ``target_entropy`` (default ``-action_dim``).
    """

    def __init__(self, obs_dim, action_dim, lr=3e-4, gamma=0.99, tau=0.005,
                 alpha=0.2, buffer_size=500000, batch_size=256,
                 hidden_dims=[256, 256], automatic_entropy_tuning=True,
                 target_entropy=None, device='cuda'):
        self.device = device
        self.gamma = gamma
        self.tau = tau
        # Entropy temperature: a plain number when fixed, a detached tensor when tuned.
        self.alpha = alpha
        self.batch_size = batch_size

        self.actor = GaussianPolicy(obs_dim, action_dim, hidden_dims).to(device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = TwinCritic(obs_dim, action_dim, hidden_dims).to(device)
        self.critic_target = TwinCritic(obs_dim, action_dim, hidden_dims).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)

        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            self.target_entropy = target_entropy if target_entropy is not None else -action_dim
            self.log_alpha = torch.tensor(0.0, requires_grad=True, device=device)
            self.alpha_optim = optim.Adam([self.log_alpha], lr=lr)
            # Detached so the actor/critic losses never backpropagate into log_alpha.
            self.alpha = self.log_alpha.detach().exp()

        self.replay_buffer = ReplayBuffer(buffer_size, (obs_dim,), action_dim)

    def select_action(self, obs, deterministic=False):
        """Return an action for one observation as a NumPy array (greedy if ``deterministic``)."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action, _ = self.actor.sample(obs, deterministic)
        return action.cpu().numpy()[0]

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target``: t <- tau*s + (1-tau)*t."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def train_step(self):
        """One SAC update of critic, actor, temperature and target critic.

        Returns ``None`` until the buffer holds ``batch_size`` transitions,
        otherwise a dict of losses, the current ``alpha`` (as a float) and
        a sample-based entropy estimate.
        """
        if self.replay_buffer.size < self.batch_size:
            return None

        batch = self.replay_buffer.sample(self.batch_size, self.device)
        # Treat the temperature as a constant in the critic and actor objectives.
        alpha = self.alpha.detach() if torch.is_tensor(self.alpha) else self.alpha

        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(batch['next_obs'])
            min_q_next = self.critic_target.min_q(batch['next_obs'], next_action)
            not_done = 1 - batch['dones'].unsqueeze(1)
            q_target = batch['rewards'].unsqueeze(1) + not_done * self.gamma * (min_q_next - alpha * next_log_prob)

        q1, q2 = self.critic(batch['obs'], batch['actions'])
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
        self.critic_optim.step()

        action_pred, log_prob = self.actor.sample(batch['obs'])
        q_new = self.critic.min_q(batch['obs'], action_pred)
        actor_loss = (alpha * log_prob - q_new).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
        self.actor_optim.step()

        alpha_loss = torch.tensor(0.0, device=self.device)
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_prob.detach() + self.target_entropy)).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.detach().exp()

        self._soft_update(self.critic, self.critic_target, self.tau)

        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            # float() works both for a fixed float alpha and a tuned tensor alpha.
            'alpha': float(self.alpha),
            'entropy': -log_prob.mean().item()
        }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils.networks import ConvGaussianPolicy, ConvTwinCritic
from utils.buffers import ReplayBuffer

class SACAgentCNN(SACAgent):
    """SAC agent with convolutional actor and critics for image observations.

    ``SACAgent.__init__`` is bypassed (it builds MLPs for flat observations);
    hyperparameters come from ``**kwargs``. ``device``, ``gamma``, ``tau``,
    ``batch_size``, ``lr`` and ``buffer_size`` are required. Optional keys are
    ``alpha``, ``hidden_dims``, ``automatic_entropy_tuning``, ``target_entropy``
    (``None`` or absent -> ``-action_dim``) and ``reward_divisor`` (default 20.0;
    rewards are divided by it before computing TD targets).
    """

    def __init__(self, obs_shape, action_dim, feature_dim=512, **kwargs):
        self.device = kwargs['device']
        self.gamma = kwargs['gamma']
        self.tau = kwargs['tau']
        self.alpha = kwargs.get('alpha', 0.2)
        self.batch_size = kwargs['batch_size']
        self.reward_divisor = kwargs.get('reward_divisor', 20.0)

        hidden_dims = kwargs.get('hidden_dims', [256, 256])

        # --- CNN Networks ---
        self.actor = ConvGaussianPolicy(obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims).to(self.device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=kwargs['lr'])

        self.critic = ConvTwinCritic(obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims).to(self.device)
        self.critic_target = ConvTwinCritic(obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=kwargs['lr'])

        self.automatic_entropy_tuning = kwargs.get('automatic_entropy_tuning', True)
        if self.automatic_entropy_tuning:
            # An explicit target_entropy=None falls back to -action_dim, as in SACAgent.
            target_entropy = kwargs.get('target_entropy')
            self.target_entropy = -action_dim if target_entropy is None else target_entropy
            self.log_alpha = torch.tensor(0.0, requires_grad=True, device=self.device)
            self.alpha_optim = optim.Adam([self.log_alpha], lr=kwargs['lr'])
            # Detached so actor/critic losses never backpropagate into log_alpha.
            self.alpha = self.log_alpha.detach().exp()

        # NOTE(review): ReplayBuffer keeps obs and next_obs as two float32 arrays of
        # shape (buffer_size, *obs_shape); for image observations this needs a lot
        # of memory — confirm buffer_size fits the host.
        self.replay_buffer = ReplayBuffer(kwargs['buffer_size'], obs_shape, action_dim)

    def select_action(self, obs, deterministic=False):
        """Return an action for one observation as a NumPy array (greedy if ``deterministic``)."""
        # Obs comes from the FrameStack wrapper in (H, W, C) format, already in [0, 1].
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=self.device)

        # Add batch dimension if needed
        if len(obs_tensor.shape) == 3:
            obs_tensor = obs_tensor.unsqueeze(0)

        # No permutation here: the CNN handles the HWC -> CHW conversion internally.
        with torch.no_grad():
            action, _ = self.actor.sample(obs_tensor, deterministic)
        return action.cpu().numpy()[0]

    def train_step(self):
        """One SAC update of critic, actor, temperature and target critic.

        Returns ``None`` until the buffer holds ``batch_size`` transitions,
        otherwise a dict of losses, the current ``alpha`` (as a float) and a
        sample-based entropy estimate.
        """
        if self.replay_buffer.size < self.batch_size:
            return None

        batch = self.replay_buffer.sample(self.batch_size, self.device)

        # Observations are stored as produced by FrameStack (HWC, already in [0, 1]);
        # the CNN handles the HWC -> CHW conversion internally.
        obs = batch['obs'].float()
        next_obs = batch['next_obs'].float()

        # Scale rewards down so Q-values stay moderate (CarRacing returns reach ~900).
        rewards = batch['rewards'] / self.reward_divisor

        # Treat the temperature as a constant in the critic and actor objectives.
        alpha = self.alpha.detach() if torch.is_tensor(self.alpha) else self.alpha

        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_obs)
            # Use 'min_q' if the critic supports it, otherwise compute min(q1, q2).
            if hasattr(self.critic_target, 'min_q'):
                min_q_next = self.critic_target.min_q(next_obs, next_action)
            else:
                q1_next, q2_next = self.critic_target(next_obs, next_action)
                min_q_next = torch.min(q1_next, q2_next)
            not_done = 1 - batch['dones'].unsqueeze(1)
            q_target = rewards.unsqueeze(1) + not_done * self.gamma * (min_q_next - alpha * next_log_prob)

        # Critic update
        q1, q2 = self.critic(obs, batch['actions'])
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
        self.critic_optim.step()

        # Actor update
        action_pred, log_prob = self.actor.sample(obs)
        if hasattr(self.critic, 'min_q'):
            q_new = self.critic.min_q(obs, action_pred)
        else:
            q1_new, q2_new = self.critic(obs, action_pred)
            q_new = torch.min(q1_new, q2_new)

        actor_loss = (alpha * log_prob - q_new).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
        self.actor_optim.step()

        # Entropy temperature update
        alpha_loss = torch.tensor(0.0, device=self.device)
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_prob.detach() + self.target_entropy)).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.detach().exp()

        # Soft (Polyak) update of the target critic
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            # float() works both for a fixed float alpha and a tuned tensor alpha.
            'alpha': float(self.alpha),
            'entropy': -log_prob.mean().item()
        }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils.networks import MLP, TwinCritic
from utils.buffers import ReplayBuffer
import numpy as np

class TD3Agent(BaseAgent):
    """Twin Delayed DDPG agent for flat (vector) observations.

    A deterministic tanh-output actor is trained to maximise the first critic
    head. Critic targets use clipped Gaussian smoothing noise on the target
    action; the actor and both target networks are updated only every
    ``policy_delay`` calls to ``train_step``.
    """

    def __init__(self, obs_dim, action_dim, lr=3e-4, gamma=0.99, tau=0.005,
                 policy_noise=0.2, noise_clip=0.5, policy_delay=2,
                 buffer_size=500000, batch_size=256, hidden_dims=[256, 256],
                 exploration_noise=0.1, device='cuda'):
        self.device = device
        self.gamma, self.tau = gamma, tau
        self.policy_noise, self.noise_clip = policy_noise, noise_clip
        self.policy_delay = policy_delay
        self.batch_size = batch_size

        def build_actor():
            return MLP(obs_dim, hidden_dims, action_dim, output_activation=nn.Tanh).to(device)

        def build_critic():
            return TwinCritic(obs_dim, action_dim, hidden_dims).to(device)

        self.actor = build_actor()
        self.actor_target = build_actor()
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = build_critic()
        self.critic_target = build_critic()
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)

        self.replay_buffer = ReplayBuffer(buffer_size, (obs_dim,), action_dim)
        self.exploration_noise = exploration_noise
        self.train_steps = 0  # counts train_step calls that performed a critic update

    def select_action(self, obs, deterministic=False):
        """Return the actor's action, plus Gaussian exploration noise unless ``deterministic``.

        The result is clipped to [-1, 1].
        """
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            greedy = self.actor(obs_t)
        action = greedy.cpu().numpy()[0]
        if deterministic:
            return np.clip(action, -1.0, 1.0)
        action += np.random.normal(0, self.exploration_noise, size=action.shape)
        return np.clip(action, -1.0, 1.0)

    @staticmethod
    def _soft_update(online, target, tau):
        """Polyak-average ``online`` parameters into ``target``."""
        for src, dst in zip(online.parameters(), target.parameters()):
            dst.data.copy_(tau * src.data + (1 - tau) * dst.data)

    def train_step(self):
        """One TD3 update.

        Returns ``None`` until the buffer holds ``batch_size`` transitions,
        otherwise ``{'critic_loss', 'actor_loss'}``; ``actor_loss`` is 0.0 on
        steps where the delayed policy update is skipped.
        """
        if self.replay_buffer.size < self.batch_size:
            return None

        batch = self.replay_buffer.sample(self.batch_size, self.device)
        obs, acts, rews = batch['obs'], batch['actions'], batch['rewards']
        next_obs, dones = batch['next_obs'], batch['dones']

        with torch.no_grad():
            # Target policy smoothing: clipped noise on the target action.
            smoothing = (torch.randn_like(acts) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            target_action = (self.actor_target(next_obs) + smoothing).clamp(-1.0, 1.0)
            target_q1, target_q2 = self.critic_target(next_obs, target_action)
            bootstrap = torch.min(target_q1, target_q2)
            td_target = rews.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * bootstrap

        q1, q2 = self.critic(obs, acts)
        critic_loss = F.mse_loss(q1, td_target) + F.mse_loss(q2, td_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)
        self.critic_optim.step()

        actor_loss = torch.tensor(0.0, device=self.device)
        if self.train_steps % self.policy_delay == 0:
            actor_loss = -self.critic.q1(obs, self.actor(obs)).mean()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
            self.actor_optim.step()

            self._soft_update(self.critic, self.critic_target, self.tau)
            self._soft_update(self.actor, self.actor_target, self.tau)

        self.train_steps += 1

        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item()
        }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from utils.networks import NatureCNN, MLP, ConvTwinCritic
from utils.buffers import ReplayBuffer

class ConvDeterministicActor(nn.Module):
    """Deterministic TD3 actor: a NatureCNN encoder followed by an MLP head.

    The MLP head is built with ``output_activation=nn.Tanh``, so actions are
    squashed into [-1, 1].

    Args:
        obs_shape: 3-D observation shape, channels-first ``(C, H, W)`` or
            channels-last ``(H, W, C)``.
        action_dim: size of the action vector.
        feature_dim: width of the CNN feature vector fed to the MLP head.
        hidden_dims: hidden layer sizes of the MLP head (default ``[256, 256]``).
    """

    def __init__(self, obs_shape, action_dim, feature_dim=512, hidden_dims=None):
        super().__init__()
        # None instead of a list literal so no mutable default is shared between instances.
        if hidden_dims is None:
            hidden_dims = [256, 256]
        # Take the smaller of the first/last axes as the channel count so both
        # (C, H, W) and (H, W, C) work. This assumes there are fewer channels
        # than pixels per side (e.g. 4 stacked frames vs 84 pixels).
        in_channels = min(obs_shape[0], obs_shape[2])
        self.cnn = NatureCNN(in_channels, feature_dim)
        self.net = MLP(feature_dim, hidden_dims, action_dim, output_activation=nn.Tanh)

    def forward(self, obs):
        """Encode a batch of observations with the CNN, then map features to actions."""
        return self.net(self.cnn(obs))

class TD3AgentCNN(TD3Agent):
    """TD3 agent for image observations, using CNN actor and twin-critic networks.

    Subclasses ``TD3Agent`` but does not call its ``__init__``: every network,
    optimizer and the replay buffer are built here instead.

    Required keyword arguments (read with ``kwargs[...]``): ``device``,
    ``gamma``, ``tau``, ``policy_noise``, ``noise_clip``, ``policy_delay``,
    ``batch_size``, ``lr``, ``buffer_size``.

    Optional keyword arguments:
        ``hidden_dims`` (default ``[256, 256]``).
        ``exploration_noise`` (default ``0.1``).
        ``reward_divisor`` (default ``20.0``): sampled rewards are divided by
        this before the TD target is built.
    """

    def __init__(self, obs_shape, action_dim, feature_dim=512, **kwargs):
        self.device = kwargs['device']
        self.gamma = kwargs['gamma']
        self.tau = kwargs['tau']
        self.policy_noise = kwargs['policy_noise']
        self.noise_clip = kwargs['noise_clip']
        self.policy_delay = kwargs['policy_delay']
        self.batch_size = kwargs['batch_size']
        # Rewards are divided by this in train_step to keep Q-value targets in a
        # moderate range; the default 20.0 is the value previously hard-coded there.
        self.reward_divisor = kwargs.get('reward_divisor', 20.0)

        hidden_dims = kwargs.get('hidden_dims', [256, 256])

        # CNN actor; the target copy starts with identical weights.
        self.actor = ConvDeterministicActor(obs_shape, action_dim, feature_dim, hidden_dims).to(self.device)
        self.actor_target = ConvDeterministicActor(obs_shape, action_dim, feature_dim, hidden_dims).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=kwargs['lr'])

        # CNN twin critic; the target copy starts with identical weights.
        self.critic = ConvTwinCritic(obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims).to(self.device)
        self.critic_target = ConvTwinCritic(obs_shape, action_dim, feature_dim, hidden_dims=hidden_dims).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=kwargs['lr'])

        # Presumably (capacity, obs_shape, action_dim); see utils.buffers.ReplayBuffer.
        self.replay_buffer = ReplayBuffer(kwargs['buffer_size'], obs_shape, action_dim)
        self.exploration_noise = kwargs.get('exploration_noise', 0.1)
        self.train_steps = 0

    def select_action(self, obs, deterministic=False):
        """Return an action clipped to [-1, 1] for a single observation.

        ``obs`` becomes a float32 tensor on ``self.device``. A batch dimension is
        added when it is 3-D. Unless ``deterministic`` is True, Gaussian noise
        with std ``self.exploration_noise`` is added before clipping.
        """
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        if obs_tensor.dim() == 3:
            obs_tensor = obs_tensor.unsqueeze(0)

        # No HWC -> CHW permutation here. The original notes say the CNN handles
        # the layout internally (utils.networks is not visible here).
        with torch.no_grad():
            action = self.actor(obs_tensor).cpu().numpy()[0]

        if not deterministic:
            noise = np.random.normal(0, self.exploration_noise, size=action.shape)
            action = action + noise

        return np.clip(action, -1.0, 1.0)

    def train_step(self):
        """Run one TD3 update.

        The critic is updated on every call. The actor and both target networks
        are updated every ``policy_delay``-th call.

        Returns:
            ``None`` while the buffer holds fewer than ``batch_size``
            transitions. Otherwise a dict with ``'critic_loss'``, plus
            ``'actor_loss'`` on delayed policy-update steps.
        """
        if self.replay_buffer.size < self.batch_size:
            return None

        self.train_steps += 1
        batch = self.replay_buffer.sample(self.batch_size, self.device)

        obs = batch['obs'].float()
        next_obs = batch['next_obs'].float()
        # Observations are used without permutation (the CNN handles layout).
        rewards = batch['rewards'] / self.reward_divisor

        # --- Critic update ---
        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(batch['actions']) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_actions = (self.actor_target(next_obs) + noise).clamp(-1, 1)

            # Clipped double-Q: bootstrap from the smaller target estimate.
            target_q1, target_q2 = self.critic_target(next_obs, next_actions)
            target_q = torch.min(target_q1, target_q2)
            target_q = rewards.unsqueeze(1) + (1 - batch['dones'].unsqueeze(1)) * self.gamma * target_q

        current_q1, current_q2 = self.critic(obs, batch['actions'])
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()

        actor_loss = None

        # --- Delayed actor and target updates ---
        if self.train_steps % self.policy_delay == 0:
            # Deterministic policy gradient through the first critic head.
            actor_loss = -self.critic.q1(obs, self.actor(obs)).mean()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
            self.actor_optim.step()

            # Polyak averaging of the target networks.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        metrics = {'critic_loss': critic_loss.item()}
        if actor_loss is not None:
            metrics['actor_loss'] = actor_loss.item()

        return metrics

# 🎮 RL Assignment 4 - Complete Training Pipeline
## Kaggle-Ready Implementation

This notebook contains a complete implementation of:
- **SAC** (Soft Actor-Critic) - MLP & CNN variants
- **TD3** (Twin Delayed DDPG) - MLP & CNN variants  
- **PPO** (Proximal Policy Optimization) - MLP & CNN variants

**Environments:**
- LunarLander-v3 (MLP agents)
- CarRacing-v3 (CNN agents)

**Features:**
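- Core code lives in this notebook (note: the CNN TD3 cell imports `NatureCNN`, `MLP`, `ConvTwinCritic` and `ReplayBuffer` from the local `utils` package, so that package must be available)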
- Easy configuration switching
- WandB integration for tracking
- Video recording of evaluations
- Checkpoint saving/loading

## 📦 Installation & Setup

In [None]:
# Install required packages
!pip install gymnasium[box2d] -q
!pip install wandb -q
!pip install opencv-python -q
!pip install swig -q

print("‚úì All packages installed successfully!")

In [None]:
import os
import sys
import torch
import numpy as np
import wandb
from datetime import datetime

# Output folders used later for checkpoints, evaluation videos and logs
for _output_dir in ('models', 'videos', 'logs'):
    os.makedirs(_output_dir, exist_ok=True)

# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"üñ•Ô∏è  Using device: {device}")
if device == 'cuda':
    # Report the first visible GPU and its total memory in GB
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## ⚙️ Configuration System

Select which experiment to run by changing the `EXPERIMENT` variable below.

In [None]:
# ========================================
# 🎯 SELECT EXPERIMENT HERE
# ========================================
# Valid choices (the keys of CONFIGS defined further down):
#   sac_carracing, td3_carracing, ppo_carracing,
#   sac_lunarlander, td3_lunarlander, ppo_lunarlander
EXPERIMENT = "sac_carracing"

# --- Weights & Biases ---
# Set WANDB_ENABLED = False to turn WandB logging off.
WANDB_ENTITY = "ziadhf-cairo-university"  # WandB username / team
WANDB_PROJECT = "cmps458-assignment4_2"   # WandB project name
WANDB_ENABLED = True                      # master switch for WandB

# --- Training loop ---
CHECKPOINT_FREQ = 100_000  # environment steps between periodic checkpoints
ENABLE_VIDEO = True        # record evaluation episodes into videos/<run_name>/

print(f"üéØ Selected Experiment: {EXPERIMENT}")
print(f"üìä WandB: {'Enabled' if WANDB_ENABLED else 'Disabled'}")
print(f"üé• Video Recording: {'Enabled' if ENABLE_VIDEO else 'Disabled'}")

In [None]:
# WandB authentication (for Kaggle).
# Preferred: store the key as a Kaggle Secret named 'WANDB_API_KEY'.
# Fallback: an already-set WANDB_API_KEY environment variable.
# Pasting the key directly below works but is not recommended for public notebooks.

if WANDB_ENABLED:
    try:
        # kaggle_secrets only exists inside Kaggle notebooks (ImportError elsewhere).
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        wandb_key = user_secrets.get_secret("WANDB_API_KEY")
        os.environ['WANDB_API_KEY'] = wandb_key
        print("‚úì WandB API key loaded from Kaggle secrets")
    except Exception:
        # Not on Kaggle, or the secret is missing: fall back to the environment.
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        if 'WANDB_API_KEY' not in os.environ:
            print("‚ö†Ô∏è  WandB API key not found!")
            print("   Option 1: Add 'WANDB_API_KEY' to Kaggle Secrets")
            print("   Option 2: Set WANDB_ENABLED = False to disable logging")
            # Uncomment and paste your key here (not recommended for public notebooks):
            # os.environ['WANDB_API_KEY'] = 'your-api-key-here'
        else:
            print("‚úì WandB API key found in environment")
else:
    print("‚ÑπÔ∏è  WandB logging disabled")

In [None]:
# Configuration Dictionary
# One entry per experiment; the entry named by EXPERIMENT is selected below.
# Top-level keys and where they are read:
#   algo, use_cnn, feature_dim    -> create_agent() (agent class, CNN settings)
#   env_id, seed                  -> make_env()
#   total_steps                   -> number of iterations of train()'s main loop
#   learning_starts               -> size of the random-action replay prefill
#                                    (off-policy agents only)
#   train_freq, gradient_steps    -> off-policy update schedule in train();
#                                    PPO entries' values are not read there
#   rollout_length                -> PPO update trigger (rollout buffer length)
#   max_ep_len                    -> cap on training-episode length in train()
#   eval_interval, eval_episodes  -> evaluation schedule in train()
#   checkpoint_freq               -> periodic checkpoint interval in train()
#   agent_params                  -> keyword arguments for the agent constructor;
#                                    create_agent() drops keys belonging to the
#                                    other algorithm family (PPO vs SAC/TD3)
CONFIGS = {
    "sac_carracing": {
        'algo': 'sac_cnn',
        'env_id': 'CarRacing-v3',
        'use_cnn': True,
        'feature_dim': 512,
        'seed': 42,
        'device': device,
        'total_steps': 2000000,
        'learning_starts': 35000,
        'train_freq': 1,
        'gradient_steps': 1,
        'max_ep_len': 1000,
        'eval_interval': 25000,
        'eval_episodes': 5,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0001,
            'gamma': 0.99,
            'tau': 0.005,
            'alpha': 0.2,
            'automatic_entropy_tuning': True,
            'buffer_size': 200000,
            'batch_size': 256,
            'hidden_dims': [256, 256]
        }
    },
    "td3_carracing": {
        'algo': 'td3_cnn',
        'env_id': 'CarRacing-v3',
        'use_cnn': True,
        'feature_dim': 512,
        'seed': 42,
        'device': device,
        'total_steps': 2000000,
        'learning_starts': 35000,
        'train_freq': 1,
        'gradient_steps': 1,
        'max_ep_len': 1000,
        'eval_interval': 25000,
        'eval_episodes': 3,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0001,
            'gamma': 0.99,
            'tau': 0.005,
            'policy_noise': 0.2,
            'noise_clip': 0.5,
            'policy_delay': 2,
            'buffer_size': 200000,
            'batch_size': 256,
            'hidden_dims': [256, 256],
            'exploration_noise': 0.3
        }
    },
    "ppo_carracing": {
        'algo': 'ppo_cnn',
        'env_id': 'CarRacing-v3',
        'use_cnn': True,
        'feature_dim': 512,
        'seed': 42,
        'device': device,
        'total_steps': 3000000,
        'learning_starts': 0,
        'rollout_length': 4096,
        'train_freq': 4096,
        'gradient_steps': 1,
        # NOTE(review): this top-level max_ep_len (2000) differs from
        # agent_params['max_ep_len'] (1000) below; confirm which is intended.
        'max_ep_len': 2000,
        'eval_interval': 50000,
        'eval_episodes': 3,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0003,
            'gamma': 0.99,
            'clip_ratio': 0.2,
            'lam': 0.95,
            'train_pi_iters': 15,
            'train_v_iters': 15,
            'target_kl': 0.02,
            'ent_coef': 0.01,
            'batch_size': 128,
            'hidden_dims': [256, 256],
            'max_ep_len': 1000
        }
    },
    "sac_lunarlander": {
        'algo': 'sac',
        'env_id': 'LunarLander-v3',
        'use_cnn': False,
        'seed': 42,
        'device': device,
        'total_steps': 300000,
        'learning_starts': 10000,
        'train_freq': 1,
        'gradient_steps': 1,
        'max_ep_len': 1000,
        'eval_interval': 10000,
        'eval_episodes': 10,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0003,
            'gamma': 0.99,
            'tau': 0.005,
            'alpha': 0.2,
            'automatic_entropy_tuning': True,
            'buffer_size': 1000000,
            'batch_size': 256,
            'hidden_dims': [256, 256]
        }
    },
    "td3_lunarlander": {
        'algo': 'td3',
        'env_id': 'LunarLander-v3',
        'use_cnn': False,
        'seed': 42,
        'device': device,
        'total_steps': 300000,
        'learning_starts': 10000,
        'train_freq': 1,
        'gradient_steps': 1,
        'max_ep_len': 1000,
        'eval_interval': 10000,
        'eval_episodes': 10,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0003,
            'gamma': 0.99,
            'tau': 0.005,
            'policy_noise': 0.2,
            'noise_clip': 0.5,
            'policy_delay': 2,
            'buffer_size': 1000000,
            'batch_size': 256,
            'hidden_dims': [256, 256],
            'exploration_noise': 0.1
        }
    },
    "ppo_lunarlander": {
        'algo': 'ppo',
        'env_id': 'LunarLander-v3',
        'use_cnn': False,
        'seed': 42,
        'device': device,
        'total_steps': 1000000,
        'learning_starts': 0,
        'rollout_length': 2048,
        'train_freq': 2048,
        'gradient_steps': 1,
        'max_ep_len': 1000,
        'eval_interval': 20000,
        'eval_episodes': 10,
        'checkpoint_freq': CHECKPOINT_FREQ,
        'agent_params': {
            'lr': 0.0003,
            'gamma': 0.99,
            'clip_ratio': 0.2,
            'lam': 0.95,
            'train_pi_iters': 80,
            'train_v_iters': 80,
            'target_kl': 0.01,
            'ent_coef': 0.0,
            'batch_size': 64,
            'hidden_dims': [64, 64],
            'max_ep_len': 1000
        }
    }
}

# Load the selected config.
# Fail early with the list of valid names instead of a bare KeyError.
if EXPERIMENT not in CONFIGS:
    raise KeyError(f"Unknown EXPERIMENT {EXPERIMENT!r}; valid options: {', '.join(CONFIGS)}")
# Shallow copy so adding 'run_name' does not mutate the shared CONFIGS entry.
config = dict(CONFIGS[EXPERIMENT])
config['run_name'] = f"{config['algo']}-{config['env_id']}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

print(f"\nüìã Configuration Loaded:")
print(f"   Algorithm: {config['algo']}")
print(f"   Environment: {config['env_id']}")
print(f"   Total Steps: {config['total_steps']:,}")
print(f"   Device: {config['device']}")

## 🏋️ Training & Evaluation Functions

In [None]:
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

def make_env(env_id, seed=42, use_cnn=False, capture_video=False, run_name=None):
    """Create a continuous-action env with optional video capture and CNN preprocessing.

    Args:
        env_id: Gymnasium id; the env must accept ``continuous=True``
            (e.g. CarRacing-v3, LunarLander-v3).
        seed: seeds both ``env.reset`` and ``env.action_space``.
        use_cnn: wrap with ``PreprocessCarRacing`` (84x84) and a 4-frame ``FrameStack``.
        capture_video: render to ``rgb_array``. Episodes are recorded to
            ``videos/<run_name>`` when ``run_name`` is given and
            ``ENABLE_VIDEO`` is true.
        run_name: sub-folder name for recorded videos.

    Returns:
        The wrapped environment, already reset once with ``seed``.
    """
    render_mode = 'rgb_array' if capture_video else None
    env = gym.make(env_id, continuous=True, render_mode=render_mode)

    # Video is recorded from the raw frames, before any CNN preprocessing.
    if capture_video and run_name and ENABLE_VIDEO:
        video_folder = f"videos/{run_name}"
        env = RecordVideo(
            env,
            video_folder=video_folder,
            episode_trigger=lambda episode_id: True,  # record every episode
            disable_logger=True
        )

    if use_cnn:
        env = PreprocessCarRacing(env, resize=(84, 84))
        env = FrameStack(env, num_stack=4)

    env.reset(seed=seed)
    # reset(seed=...) does not seed the action space's own RNG; seed it too so
    # env.action_space.sample() (used for the replay-buffer prefill) is reproducible.
    env.action_space.seed(seed)
    return env

def evaluate(agent, env, n_episodes=5, max_ep_len=1000):
    """Run the agent's deterministic policy and summarize the episode returns.

    Each episode lasts until the env terminates or truncates, or until
    ``max_ep_len`` steps have been taken.

    Args:
        agent: any agent exposing ``select_action(obs, deterministic=...)``.
            PPO agents return ``(action, logp, value)``; others return the action.
        env: Gymnasium-style environment (5-tuple ``step`` API).
        n_episodes: number of evaluation episodes.
        max_ep_len: per-episode step cap.

    Returns:
        dict with the ``'mean'``, ``'std'``, ``'min'`` and ``'max'`` of the returns.
    """
    on_policy = isinstance(agent, (PPOAgent, PPOAgentCNN))
    episode_returns = []

    for episode in range(n_episodes):
        obs, _ = env.reset()
        total_return = 0

        for _step in range(max_ep_len):
            if on_policy:
                action, _logp, _value = agent.select_action(obs, deterministic=True)
            else:
                action = agent.select_action(obs, deterministic=True)

            obs, reward, terminated, truncated, _ = env.step(action)
            total_return += reward
            if terminated or truncated:
                break

        episode_returns.append(total_return)

    return {
        'mean': np.mean(episode_returns),
        'std': np.std(episode_returns),
        'min': np.min(episode_returns),
        'max': np.max(episode_returns)
    }

def save_checkpoint(agent, path, config, eval_score, step):
    """Write the agent's networks plus run metadata to ``path`` via ``torch.save``.

    The actor is always saved. The critic and the target networks
    (``actor_target``, ``critic_target``) are saved when the agent has them,
    so a resumed run can restore its TD targets.

    Args:
        agent: object exposing at least ``actor`` (an ``nn.Module``).
        path: destination file; its parent directory is created if needed.
        config: experiment config dict, stored as-is.
        eval_score: best evaluation score so far, stored for reference.
        step: training step at which the checkpoint is taken.
    """
    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') raises, so skip this for bare filenames
        os.makedirs(directory, exist_ok=True)

    save_dict = {
        'actor_state_dict': agent.actor.state_dict(),
        'config': config,
        'eval_score': eval_score,
        'step': step
    }
    # Optional modules: checkpoint key -> agent attribute.
    optional_modules = {
        'critic_state_dict': 'critic',
        'actor_target_state_dict': 'actor_target',
        'critic_target_state_dict': 'critic_target',
    }
    for key, attr in optional_modules.items():
        module = getattr(agent, attr, None)
        if module is not None:
            save_dict[key] = module.state_dict()

    torch.save(save_dict, path)
    print(f"üíæ Checkpoint saved: {path}")

def load_checkpoint(agent, checkpoint_path):
    """Restore network weights from a checkpoint and return its step and score.

    Loads the actor, and the critic when both the agent and the checkpoint
    have one. Target networks (``actor_target`` / ``critic_target``) are loaded
    from the checkpoint when it contains them. Otherwise they are re-synced to
    the freshly loaded online weights, so a resumed run does not bootstrap
    from randomly initialized targets.

    Returns:
        ``(step, eval_score)`` as stored in the checkpoint.
    """
    # weights_only=False: the checkpoint also pickles the config dict and the
    # score, which train() takes from np.mean. Only load checkpoints you produced.
    checkpoint = torch.load(checkpoint_path, map_location=agent.device, weights_only=False)
    agent.actor.load_state_dict(checkpoint['actor_state_dict'])
    if hasattr(agent, 'critic') and 'critic_state_dict' in checkpoint:
        agent.critic.load_state_dict(checkpoint['critic_state_dict'])

    for target_attr, online_attr in (('actor_target', 'actor'), ('critic_target', 'critic')):
        target = getattr(agent, target_attr, None)
        online = getattr(agent, online_attr, None)
        if target is None or online is None:
            continue
        key = f'{target_attr}_state_dict'
        if key in checkpoint:
            target.load_state_dict(checkpoint[key])
        else:
            target.load_state_dict(online.state_dict())

    print(f"‚úÖ Loaded checkpoint from step {checkpoint['step']}, score: {checkpoint['eval_score']:.2f}")
    return checkpoint['step'], checkpoint['eval_score']

print("‚úì Training functions defined")

In [None]:
def create_agent(config, env):
    """Build the agent selected by ``config['algo']`` for the given environment.

    Hyperparameters come from ``config['agent_params']``. Keys that belong only
    to the other algorithm family (PPO vs SAC/TD3) are dropped, and so are
    ``obs_shape`` / ``feature_dim``, which are set explicitly. CNN agents get
    ``obs_shape`` and ``feature_dim``; MLP agents get ``obs_dim``.

    Raises:
        ValueError: if ``config['algo']`` names no known agent for this mode.
    """
    algo = config['algo']
    algo_key = algo.lower()
    use_cnn = config.get('use_cnn', False)

    ppo_only = ['max_ep_len', 'lam', 'clip_ratio', 'train_pi_iters', 'train_v_iters', 'target_kl']
    sac_td3_only = ['tau', 'alpha', 'automatic_entropy_tuning', 'buffer_size',
                    'policy_noise', 'noise_clip', 'policy_delay', 'exploration_noise']

    # Exclude the other family's keys plus the CNN keys assigned below.
    other_family = sac_td3_only if 'ppo' in algo_key else ppo_only
    excluded = other_family + ['obs_shape', 'feature_dim']

    kwargs = {'action_dim': env.action_space.shape[0], 'device': config['device']}
    kwargs.update(
        (name, value) for name, value in config['agent_params'].items() if name not in excluded
    )

    if not use_cnn:
        kwargs['obs_dim'] = env.observation_space.shape[0]
        registry = {'sac': SACAgent, 'td3': TD3Agent, 'ppo': PPOAgent}
    else:
        kwargs['obs_shape'] = env.observation_space.shape
        kwargs['feature_dim'] = config.get('feature_dim', 512)
        registry = {
            'sac': SACAgentCNN, 'sac_cnn': SACAgentCNN,
            'td3': TD3AgentCNN, 'td3_cnn': TD3AgentCNN,
            'ppo': PPOAgentCNN, 'ppo_cnn': PPOAgentCNN,
        }

    agent_cls = registry.get(algo_key)
    if agent_cls is None:
        raise ValueError(f"Unknown algorithm: {algo}")

    return agent_cls(**kwargs)

print("‚úì Agent factory defined")

## 🚀 Main Training Loop

In [None]:
def train(config):
    """Train one agent end-to-end as described by ``config``.

    Steps:
      * Creates a training env and a separate evaluation env (video capture on).
      * Builds the agent with ``create_agent``.
      * For off-policy agents, prefills the replay buffer with random actions.
      * Alternates environment steps, gradient updates, periodic checkpoints,
        per-episode logging and periodic evaluation.

    Best, periodic and final models are written under ``models/<run_name>/``.

    Args:
        config: one ``CONFIGS`` entry plus a ``'run_name'`` key.

    Returns:
        Best mean evaluation return seen during training (``-inf`` if no
        evaluation ran).
    """
    run_name = config['run_name']

    # One flag gates wandb.init, every wandb.log and wandb.finish, so the loop
    # never logs without an active run (e.g. when WANDB_ENTITY is None).
    use_wandb = bool(WANDB_ENABLED and WANDB_ENTITY)
    if use_wandb:
        wandb.init(
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
            config=config,
            name=run_name,
            tags=[config['algo'], config['env_id']],
            monitor_gym=True,
            save_code=True
        )

    print(f"\n{'='*70}")
    print(f"üéØ Starting Training: {run_name}")
    print(f"{'='*70}\n")

    env = eval_env = None
    best_eval_score = -np.inf
    try:
        # Create environments
        use_cnn = config.get('use_cnn', False)
        env = make_env(config['env_id'], seed=config['seed'], use_cnn=use_cnn, capture_video=False)
        eval_env = make_env(config['env_id'], seed=config['seed']+100, use_cnn=use_cnn,
                            capture_video=True, run_name=run_name)

        print(f"üåç Environment: {config['env_id']}")
        print(f"   Observation space: {env.observation_space}")
        print(f"   Action space: {env.action_space}")
        print(f"   CNN Mode: {'ON' if use_cnn else 'OFF'}\n")

        # Create agent
        agent = create_agent(config, env)
        print(f"ü§ñ Agent: {agent.__class__.__name__}\n")
        is_on_policy = isinstance(agent, (PPOAgent, PPOAgentCNN))

        # Training variables
        obs, _ = env.reset()
        episode_reward = 0
        episode_length = 0
        episode_count = 0
        episode_rewards_history = []  # returns of the last 100 training episodes
        last_log_step = 0
        log_freq = 1000
        max_ep_len = config.get('max_ep_len', 1000)

        # Prefill the replay buffer with random actions (off-policy agents only)
        prefill_steps = config.get('learning_starts', 0)
        if prefill_steps > 0 and not is_on_policy:
            print(f"üîÑ Prefilling replay buffer ({prefill_steps} steps)...")
            while len(agent.replay_buffer) < prefill_steps:
                action = env.action_space.sample()
                next_obs, reward, term, trunc, _ = env.step(action)
                # Store only true termination as "done". A time-limit truncation
                # is not a terminal state, so its TD target must still bootstrap.
                agent.replay_buffer.add(obs, action, reward, next_obs, term)
                obs = next_obs
                if term or trunc:
                    obs, _ = env.reset()
            print("‚úÖ Buffer prefilled\n")
            obs, _ = env.reset()

        # Main training loop
        print("üèÉ Training started...\n")
        for step in range(config['total_steps']):
            # Act in the environment and store the transition
            if is_on_policy:
                action, logp, val = agent.select_action(obs, deterministic=False)
                next_obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                agent.buffer.store(obs, action, reward, val, logp, done)
            else:
                action = agent.select_action(obs, deterministic=False)
                next_obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                # Bootstrap through truncation: the buffer gets `term`, not `done`.
                agent.replay_buffer.add(obs, action, reward, next_obs, term)

            obs = next_obs
            episode_reward += reward
            episode_length += 1

            # Training
            metrics = None
            if is_on_policy:
                should_train = len(agent.buffer.obs) >= config.get('rollout_length', 2048)
            else:
                # The prefill above already collected `learning_starts` random
                # transitions, so updates start immediately. Gating on
                # `step > prefill_steps` would delay them by another
                # `learning_starts` steps. Without a prefill, skip step 0.
                warmed_up = prefill_steps > 0 or step > 0
                should_train = warmed_up and step % config['train_freq'] == 0

            if should_train:
                if is_on_policy:
                    metrics = agent.train_step()
                else:
                    for _ in range(config['gradient_steps']):
                        metrics = agent.train_step()

                if metrics and use_wandb:
                    wandb.log({f'train/{k}': v for k, v in metrics.items()}, step=step)

            # Periodic checkpoint
            if step > 0 and step % config['checkpoint_freq'] == 0:
                ckpt_path = f"models/{run_name}/checkpoint_{step}.pth"
                save_checkpoint(agent, ckpt_path, config, best_eval_score, step)

            # Episode end (env done, or the configured length cap is reached)
            if done or episode_length >= max_ep_len:
                episode_count += 1
                episode_rewards_history.append(episode_reward)
                if len(episode_rewards_history) > 100:
                    episode_rewards_history.pop(0)

                success = False
                if 'CarRacing' in config['env_id']:
                    success = episode_reward >= 900

                if use_wandb:
                    wandb.log({
                        'train/episode_reward': episode_reward,
                        'train/episode_length': episode_length,
                        'train/episode_count': episode_count,
                        'train/success': int(success),
                        'train/rolling_mean_reward': np.mean(episode_rewards_history),
                    }, step=step)

                obs, _ = env.reset()
                episode_reward = 0
                episode_length = 0

                # Console progress
                if step - last_log_step >= log_freq:
                    progress = 100 * step / config['total_steps']
                    recent_mean = np.mean(episode_rewards_history[-10:]) if episode_rewards_history else 0
                    print(f"üìä Step {step:,}/{config['total_steps']:,} ({progress:.1f}%) | "
                          f"Ep: {episode_count} | R_avg: {recent_mean:.1f} | Best: {best_eval_score:.1f}")
                    last_log_step = step

            # Periodic evaluation
            if step % config['eval_interval'] == 0 and step > 0:
                print(f"\nüéØ Evaluating at step {step:,}...")
                eval_results = evaluate(agent, eval_env, n_episodes=config['eval_episodes'],
                                        max_ep_len=max_ep_len)

                if use_wandb:
                    wandb.log({f'eval/{k}': v for k, v in eval_results.items()}, step=step)

                print(f"   Mean: {eval_results['mean']:.2f} ¬± {eval_results['std']:.2f}")
                print(f"   Range: [{eval_results['min']:.2f}, {eval_results['max']:.2f}]")

                if eval_results['mean'] > best_eval_score:
                    best_eval_score = eval_results['mean']
                    print(f"   üåü New best score!")
                    best_path = f"models/{run_name}/best_model.pth"
                    save_checkpoint(agent, best_path, config, best_eval_score, step)
                print()
                # The training env is deliberately NOT reset here. Evaluation used
                # the separate eval_env, so the current training episode just
                # continues. Resetting would split the episode without resetting
                # the reward/length counters or marking the PPO rollout as done.

        # Final save
        final_path = f"models/{run_name}/final_model.pth"
        save_checkpoint(agent, final_path, config, best_eval_score, config['total_steps'])

        print(f"\n{'='*70}")
        print(f"‚úÖ Training Complete!")
        print(f"   Best Score: {best_eval_score:.2f}")
        print(f"   Models saved in: models/{run_name}/")
        print(f"{'='*70}\n")
    finally:
        # Runs on success and on error, so the WandB run and both envs are always released.
        if use_wandb:
            wandb.finish()
        for open_env in (env, eval_env):
            if open_env is not None:
                open_env.close()

    return best_eval_score

print("‚úì Training function defined")

## ▶️ RUN TRAINING

Execute the cell below to start training!

In [None]:
# 🚀 START TRAINING
# Errors are reported (message + traceback) rather than re-raised, so later
# cells such as model listing and video display can still run.
try:
    best_score = train(config)
except Exception as err:
    import traceback
    print(f"\n‚ùå Error during training: {err}")
    traceback.print_exc()
else:
    print(f"\nüéâ Training finished successfully!")
    print(f"üèÜ Best evaluation score: {best_score:.2f}")

## 📊 Evaluation & Visualization

Use these cells to load a trained model and evaluate it, or to view recorded videos.

In [None]:
# Load and evaluate a trained model
def evaluate_model(model_path, n_episodes=10):
    """Load a checkpoint written by ``save_checkpoint`` and evaluate its policy.

    The agent is rebuilt from the config stored in the checkpoint, but on the
    current ``device`` rather than the device recorded at training time. This
    lets a GPU-trained checkpoint be evaluated on a CPU-only machine.

    Args:
        model_path: path to a ``.pth`` checkpoint.
        n_episodes: number of deterministic evaluation episodes.

    Returns:
        The dict from ``evaluate`` (``'mean'``, ``'std'``, ``'min'``, ``'max'``).
    """
    # weights_only=False: the checkpoint also stores the config dict; only load
    # checkpoints you produced yourself.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    # Override the saved device (e.g. 'cuda' from a Kaggle run) with the current one.
    eval_config = {**checkpoint['config'], 'device': device}

    use_cnn = eval_config.get('use_cnn', False)
    env = make_env(eval_config['env_id'], seed=42, use_cnn=use_cnn, capture_video=False)
    try:
        agent = create_agent(eval_config, env)

        # Load weights
        agent.actor.load_state_dict(checkpoint['actor_state_dict'])
        if hasattr(agent, 'critic') and 'critic_state_dict' in checkpoint:
            agent.critic.load_state_dict(checkpoint['critic_state_dict'])

        print(f"üìÇ Loaded model from: {model_path}")
        print(f"   Step: {checkpoint['step']}")
        print(f"   Saved eval score: {checkpoint['eval_score']:.2f}\n")

        print(f"üéØ Evaluating for {n_episodes} episodes...")
        # Use the same episode cap as training.
        results = evaluate(agent, env, n_episodes=n_episodes,
                           max_ep_len=eval_config.get('max_ep_len', 1000))

        print(f"\nüìä Results:")
        print(f"   Mean: {results['mean']:.2f} ¬± {results['std']:.2f}")
        print(f"   Range: [{results['min']:.2f}, {results['max']:.2f}]")
    finally:
        env.close()
    return results

# Example usage (uncomment to use):
# model_path = "models/sac_cnn-CarRacing-v3-20251210_035927/best_model.pth"
# evaluate_model(model_path, n_episodes=10)

In [None]:
# List all saved models
import glob

print("üíæ Available Models:\n")
# Sorted so the listing order is deterministic (glob order is arbitrary).
models = sorted(glob.glob("models/*/best_model.pth"))
if models:
    for i, model in enumerate(models, 1):
        try:
            # weights_only=False: the checkpoint also stores the config dict.
            # Only open checkpoints you produced yourself.
            checkpoint = torch.load(model, map_location='cpu', weights_only=False)
        except Exception:
            print(f"{i}. {model} (could not load info)\n")
            continue
        score = checkpoint.get('eval_score', 'N/A')
        # ':.2f' raises on non-numbers such as the 'N/A' fallback, so format defensively.
        try:
            score_text = f"{score:.2f}"
        except (TypeError, ValueError):
            score_text = str(score)
        print(f"{i}. {model}")
        print(f"   Step: {checkpoint.get('step', 'N/A')}, Score: {score_text}\n")
else:
    print("No models found. Train a model first!")

In [None]:
# Display recorded videos (for Jupyter/Kaggle)
from IPython.display import Video, display
import glob

def show_videos(run_name=None, max_videos=3):
    """Display up to ``max_videos`` recorded evaluation videos inline.

    Args:
        run_name: limit the search to ``videos/<run_name>/*.mp4``. When falsy,
            every sub-folder of ``videos/`` is searched recursively.
        max_videos: maximum number of videos to display.
    """
    if run_name:
        video_pattern = f"videos/{run_name}/*.mp4"
    else:
        video_pattern = "videos/**/*.mp4"

    # Sorted so the displayed subset is deterministic (glob order is arbitrary).
    videos = sorted(glob.glob(video_pattern, recursive=True))

    if not videos:
        print("No videos found. Enable video recording and run evaluation.")
        return

    print(f"üé• Found {len(videos)} videos\n")

    for i, video_path in enumerate(videos[:max_videos], 1):
        print(f"Video {i}: {video_path}")
        try:
            display(Video(video_path, width=600))
        except Exception:
            # Best effort: one unplayable file should not stop the rest.
            # (Not a bare except, so KeyboardInterrupt still propagates.)
            print(f"Could not display {video_path}\n")

# Example usage (uncomment to use):
# show_videos(run_name="sac_cnn-CarRacing-v3-20251210_035927", max_videos=3)

## 🎓 Usage Instructions

### To run on Kaggle:

1. **Upload this notebook** to Kaggle
2. **Enable GPU**: Settings → Accelerator → GPU T4 x2
3. **Set experiment**: Change `EXPERIMENT` variable in the configuration cell
4. **Configure WandB** (optional): 
   - Set `WANDB_ENTITY` to your username
   - Set `WANDB_ENABLED = True`
   - Add WandB API key in Kaggle Secrets as `WANDB_API_KEY`
5. **Run all cells** from top to bottom

### Available Experiments:
- `sac_carracing` - SAC on CarRacing (2M steps, ~6-8 hours)
- `td3_carracing` - TD3 on CarRacing (2M steps, ~6-8 hours)
- `ppo_carracing` - PPO on CarRacing (3M steps, ~8-10 hours)
- `sac_lunarlander` - SAC on LunarLander (300K steps, ~1 hour)
- `td3_lunarlander` - TD3 on LunarLander (300K steps, ~1 hour)
- `ppo_lunarlander` - PPO on LunarLander (1M steps, ~2-3 hours)

### Tips:
- Use **GPU** for faster training (5-10x speedup)
- Enable **Internet** in Kaggle settings for WandB logging
- Models are saved in `models/` directory
- Videos are saved in `videos/` directory
- Checkpoints are saved every 100K steps by default