# classic_rl.py

Auto-generated implementation from the Agentic RL PhD codebase.

### Original Implementations & References
The following links point to the official or high-quality reference implementations for the papers covered in this notebook:

- SAC: https://github.com/haarnoja/sac
- DQN: https://github.com/deepmind/dqn

*Note: The code below is a simplified pedagogical implementation.*

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# Papers implemented below:
# 1. "Playing Atari with Deep Reinforcement Learning" (DQN; Mnih et al., 2013)
# 2. "Continuous control with deep reinforcement learning" (DDPG; Lillicrap et al., 2015)
# 3. "Soft Actor-Critic" (SAC; Haarnoja et al., 2018)

class DQN(nn.Module):
    """Fully connected Q-network in the style of DQN.

    Paper: Playing Atari with Deep Reinforcement Learning (Mnih et al., 2013)
    Innovation: Q-Learning with Neural Networks + Experience Replay

    Only the function approximator lives here: a three-layer MLP with two
    128-unit ReLU hidden layers. Replay buffers, target networks and the
    training loop are not part of this class.

    Args:
        state_dim: Size of the last dimension of the input tensor.
        action_dim: Number of outputs (one value per discrete action).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 128
        # Layers are created in input-to-output order and wrapped in a single
        # Sequential, so parameters are registered as net.0, net.2 and net.4.
        stack = [
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, action_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for ``x`` with shape ``(..., action_dim)``."""
        q_values = self.net(x)
        return q_values

class DDPG(nn.Module):
    """Actor and critic networks for DDPG.

    Paper: Continuous control with deep reinforcement learning (Lillicrap et al., 2015)
    Innovation: DQN for continuous action spaces (Actor-Critic)

    The module bundles two small MLPs:

    * ``actor``  -- maps a state to an action; the final ``Tanh`` bounds every
      action component to ``[-1, 1]``.
    * ``critic`` -- maps the concatenation ``[state, action]`` to a single
      scalar value.

    Target networks, exploration noise and the update rule are not part of
    this class.

    Args:
        state_dim: Size of the last dimension of state tensors.
        action_dim: Size of the last dimension of action tensors.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, action_dim), nn.Tanh()  # Action range [-1, 1]
        )
        # The critic consumes state and action concatenated on the last axis.
        self.critic = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state):
        """Return the deterministic actor output for ``state``.

        Defined so that the module is callable (``model(state)``); it is
        equivalent to ``self.actor(state)``.

        Args:
            state: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor with last dimension ``action_dim`` and values in ``[-1, 1]``.
        """
        return self.actor(state)

    def q_value(self, state, action):
        """Return the critic's estimate for a state-action pair.

        Args:
            state: Tensor whose last dimension is ``state_dim``.
            action: Tensor whose last dimension is ``action_dim``; its leading
                dimensions must match those of ``state``.

        Returns:
            Tensor with last dimension 1.
        """
        # Concatenate on the feature axis to match the critic's input width.
        return self.critic(torch.cat([state, action], dim=-1))

class SAC(nn.Module):
    """Networks and action sampling for Soft Actor-Critic.

    Paper: Soft Actor-Critic (Haarnoja et al., 2018)
    Innovation: Entropy Regularization for exploration/stability.

    The module holds:

    * ``log_alpha`` -- learnable log-temperature parameter (stored here but
      not used by any method of this class).
    * ``critic1`` / ``critic2`` -- two independent critics over
      ``[state, action]`` (double Q-learning to reduce overestimation).
    * ``actor_fc`` / ``mu`` / ``log_std`` -- a stochastic policy producing the
      mean and log-standard-deviation of a diagonal Gaussian.

    Args:
        state_dim: Size of the last dimension of state tensors.
        action_dim: Size of the last dimension of action tensors.
    """

    # Clamp range for the predicted log-std; keeps exp(log_std) finite and
    # strictly positive.
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2
    _LOG_2 = 0.6931471805599453  # ln(2)

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.zeros(1))  # Learnable temperature

        # Double Q-Learning (Two Critics) to reduce overestimation
        self.critic1 = nn.Sequential(nn.Linear(state_dim + action_dim, 256), nn.ReLU(), nn.Linear(256, 1))
        self.critic2 = nn.Sequential(nn.Linear(state_dim + action_dim, 256), nn.ReLU(), nn.Linear(256, 1))

        # Stochastic Policy
        self.actor_fc = nn.Sequential(nn.Linear(state_dim, 256), nn.ReLU())
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)

    def get_action(self, state):
        """Sample a tanh-squashed action and its log-probability.

        Args:
            state: Tensor whose last dimension is ``state_dim``.

        Returns:
            A tuple ``(action, log_prob)`` where ``action`` has last dimension
            ``action_dim`` with values in ``[-1, 1]`` and ``log_prob`` has last
            dimension 1 (log-density summed over action components). Both are
            differentiable with respect to the policy parameters.
        """
        x = self.actor_fc(state)
        mu = self.mu(x)
        log_std = torch.clamp(self.log_std(x), self.LOG_STD_MIN, self.LOG_STD_MAX)
        std = log_std.exp()
        dist = torch.distributions.Normal(mu, std)

        # Reparameterization trick (tanh squash)
        u = dist.rsample()
        action = torch.tanh(u)

        # Change-of-variables correction for a = tanh(u):
        #   log pi(a) = log N(u) - log(1 - tanh(u)^2)
        # The correction is evaluated with the identity
        #   log(1 - tanh(u)^2) = 2 * (ln 2 - u - softplus(-2u))
        # which stays exact when tanh(u) rounds to +/-1 in float32. The naive
        # form log(1 - a^2 + eps) collapses to log(eps) in that regime and
        # biases the log-probability.
        log_det_jacobian = 2.0 * (self._LOG_2 - u - F.softplus(-2.0 * u))
        log_prob = dist.log_prob(u) - log_det_jacobian
        return action, log_prob.sum(dim=-1, keepdim=True)
