# advanced.py

Auto-generated implementation from the Agentic RL PhD codebase.

### Original Implementations & References
The following links point to the official or high-quality reference implementations for the papers covered in this notebook:

- MuZero: https://github.com/werner-duvaud/muzero-general
- DreamerV3: https://github.com/danijar/dreamerv3

*Note: The code below is a simplified pedagogical implementation.*

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Papers:
# 1. "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model" (MuZero)
# 2. "Mastering Diverse Domains through World Models" (DreamerV3)

class MuZeroNetwork(nn.Module):
    """
    Paper: MuZero (Schrittwieser et al., 2019)
    Innovation: Planning without a ground-truth simulator.

    Simplified three-part model:
      * representation (h): observation -> latent state
      * dynamics (g): latent state + action -> next latent state, reward
      * prediction (f): latent state -> policy scores, value

    Args:
        obs_dim: Number of input channels of the observation image
            (it is passed as ``in_channels`` to the convolution).
        action_dim: Width of the action encoding concatenated to the latent
            state, and the number of policy outputs.
        latent_dim: Width of the latent state vector.
        obs_hw: ``(height, width)`` of the observation image. It is used to
            size the linear layer that follows the flattened convolution
            output. The default ``(6, 6)`` yields 64 * 4 * 4 = 1024 input
            features, the value that used to be hard-coded.

    Raises:
        ValueError: If ``obs_hw`` is smaller than the 3x3 convolution kernel.
    """

    _CONV_CHANNELS = 64
    _CONV_KERNEL = 3
    # Integer dtypes that are interpreted as action indices (not encodings).
    _INDEX_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)

    def __init__(self, obs_dim, action_dim, latent_dim, obs_hw=(6, 6)):
        super().__init__()
        height, width = obs_hw
        # Unpadded, stride-1 convolution shrinks each spatial side by kernel - 1.
        out_h = height - (self._CONV_KERNEL - 1)
        out_w = width - (self._CONV_KERNEL - 1)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"obs_hw={tuple(obs_hw)} is smaller than the "
                f"{self._CONV_KERNEL}x{self._CONV_KERNEL} convolution kernel"
            )
        flat_features = self._CONV_CHANNELS * out_h * out_w

        # Kept so integer action indices can be one-hot encoded later.
        self.action_dim = action_dim

        # 1. Representation Network (h): Obs -> Latent State
        self.representation = nn.Sequential(
            nn.Conv2d(obs_dim, self._CONV_CHANNELS, self._CONV_KERNEL),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),
        )

        # 2. Dynamics Network (g): Latent + Action -> Next Latent + Reward
        self.dynamics_state = nn.Linear(latent_dim + action_dim, latent_dim)
        self.dynamics_reward = nn.Linear(latent_dim + action_dim, 1)

        # 3. Prediction Network (f): Latent -> Policy + Value
        self.prediction_policy = nn.Linear(latent_dim, action_dim)
        self.prediction_value = nn.Linear(latent_dim, 1)

    def initial_inference(self, observation):
        """
        Encode an observation and evaluate the resulting latent state.

        Args:
            observation: Image batch fed to the representation network; its
                spatial size must match the ``obs_hw`` given at construction.

        Returns:
            Tuple ``(latent_state, policy, value)``. ``policy`` is the raw
            output of a linear layer (no softmax is applied here).
        """
        s = self.representation(observation)
        p = self.prediction_policy(s)
        v = self.prediction_value(s)
        return s, p, v

    def recurrent_inference(self, hidden_state, action):
        """
        Advance the latent state by one action and evaluate the result.

        Args:
            hidden_state: Latent state batch of shape ``(batch, latent_dim)``.
            action: Either an already encoded action batch of shape
                ``(batch, action_dim)``, or a 1-D integer tensor of action
                indices, which is one-hot encoded to ``action_dim`` columns.

        Returns:
            Tuple ``(next_latent_state, reward, policy, value)``. ``policy``
            is the raw output of a linear layer (no softmax is applied here).
        """
        if action.dim() == 1 and action.dtype in self._INDEX_DTYPES:
            # Indices cannot be concatenated along dim=1; expand them first.
            action = F.one_hot(action.long(), num_classes=self.action_dim)
            action = action.to(hidden_state.dtype)
        x = torch.cat([hidden_state, action], dim=1)
        next_s = self.dynamics_state(x)
        reward = self.dynamics_reward(x)
        p = self.prediction_policy(next_s)
        v = self.prediction_value(next_s)
        return next_s, reward, p, v

class DreamerV3(nn.Module):
    """
    Paper: DreamerV3 (Hafner et al., 2023)
    Innovation: Symlog, Discrete Latents, KL Balancing.

    Only the symlog / symexp transforms are implemented here; the world
    model itself is a placeholder.
    """
    def __init__(self):
        super().__init__()
        # Placeholder for RSSM (Recurrent State Space Model)
        # Innovation: Uses categorical latents instead of Gaussian

    def symlog(self, x):
        """
        The magic scaling function from DreamerV3 that handles diverse reward scales.
        symlog(x) = sign(x) * ln(|x| + 1)

        Args:
            x: Tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # log1p keeps precision for tiny |x|, where |x| + 1 would round to 1.
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def symexp(self, x):
        """
        Inverse of symlog.
        symexp(x) = sign(x) * (exp(|x|) - 1)

        Args:
            x: Tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # expm1 avoids the cancellation in exp(|x|) - 1 when |x| is near zero.
        return torch.sign(x) * torch.expm1(torch.abs(x))
