# offline_rl.py

Auto-generated implementation from the Agentic RL PhD codebase.

### Original Implementations & References
The following links point to the official or high-quality reference implementations for the papers covered in this notebook:

- CQL (Conservative Q-Learning): https://github.com/aviralkumar2907/CQL
- DT (Decision Transformer): https://github.com/kzl/decision-transformer

*Note: The code below is a simplified pedagogical implementation.*

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Papers:
# 1. "Conservative Q-Learning for Offline Reinforcement Learning" (CQL)
# 2. "Decision Transformer: Reinforcement Learning via Sequence Modeling"

def cql_loss(q_values, current_action_q, q_function=None):
    """
    Paper: Conservative Q-Learning (Kumar et al., 2020)
    Innovation: Penalize Q-values for out-of-distribution actions.

    Computes ONLY the conservative regularizer of the CQL objective:

        mean over states of [ logsumexp_a Q(s, a) - Q(s, a_data) ]

    The full training loss is  MSE(Bellman) + alpha * (this term); the
    Bellman error and the alpha weighting are the caller's responsibility.

    Args:
        q_values: Tensor of Q-values with the action axis at dim 1
            (typically (batch, num_actions)).
        current_action_q: Q-values of the actions actually present in the
            dataset, one per state. Either (batch,) or (batch, 1) is accepted.
        q_function: Unused. Kept (now optional) so existing three-argument
            callers continue to work.

    Returns:
        Scalar tensor holding the mean conservative penalty.
    """
    # 1. Push down Q-values of random/unseen actions (soft maximum over actions).
    pushed_down = torch.logsumexp(q_values, dim=1)

    # 2. Push up Q-values of actions actually in the dataset.
    # A (batch, 1) input would broadcast against (batch,) into a
    # (batch, batch) matrix; the mean is the same but the memory is quadratic,
    # so align the shapes whenever the element counts agree.
    if (
        torch.is_tensor(current_action_q)
        and current_action_q.shape != pushed_down.shape
        and current_action_q.numel() == pushed_down.numel()
    ):
        current_action_q = current_action_q.reshape(pushed_down.shape)

    return (pushed_down - current_action_q).mean()

class DecisionTransformer(nn.Module):
    """
    Paper: Decision Transformer (Chen et al., 2021)
    Innovation: RL as Sequence Modeling (RTG, State, Action) -> Action

    Each timestep contributes three tokens (return-to-go, state, action).
    The tokens are interleaved into one sequence, passed through a causally
    masked transformer, and the action is predicted from the output at each
    state token.

    Args:
        state_dim: Size of the state feature vector.
        act_dim: Size of the action feature vector.
        hidden_size: Transformer embedding width. Must be divisible by 4,
            because the encoder layers use 4 attention heads.
        max_length: Intended context length. Stored as ``self.max_length``
            but not enforced by ``forward``.
        max_ep_len: Number of rows in the timestep embedding table; every
            value in ``timesteps`` must be smaller than this.
    """
    def __init__(self, state_dim, act_dim, hidden_size, max_length=20, max_ep_len=1000):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.max_ep_len = max_ep_len

        self.embed_t = nn.Embedding(max_ep_len, hidden_size)  # Timesteps
        self.embed_s = nn.Linear(state_dim, hidden_size)
        self.embed_a = nn.Linear(act_dim, hidden_size)
        self.embed_R = nn.Linear(1, hidden_size)  # Returns-to-go

        # batch_first=True is required: forward() feeds (Batch, Tokens, Hidden).
        # With the default (False) the encoder would treat the batch axis as
        # the sequence axis and the causal mask would not match.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True),
            num_layers=3,
        )
        self.predict_action = nn.Linear(hidden_size, act_dim)

    def forward(self, states, actions, returns_to_go, timesteps):
        """
        Predict one action per timestep.

        Args:
            states: (Batch, Seq_Len, state_dim) float tensor.
            actions: (Batch, Seq_Len, act_dim) float tensor.
            returns_to_go: (Batch, Seq_Len, 1) float tensor; (Batch, Seq_Len)
                is also accepted and gets a trailing feature axis.
            timesteps: (Batch, Seq_Len) integer tensor of embedding indices.

        Returns:
            (Batch, Seq_Len, act_dim) tensor of predicted actions.
        """
        batch_size, seq_len = states.shape[0], states.shape[1]

        # embed_R is Linear(1, hidden), so it needs an explicit feature axis.
        if returns_to_go.dim() == 2:
            returns_to_go = returns_to_go.unsqueeze(-1)

        # The same time embedding is added to all three tokens of a timestep.
        time_emb = self.embed_t(timesteps)
        state_tok = self.embed_s(states) + time_emb
        action_tok = self.embed_a(actions) + time_emb
        return_tok = self.embed_R(returns_to_go) + time_emb

        # Interleave as R1, s1, a1, R2, s2, a2, ...
        # stack -> (Batch, Seq, 3, Hidden); reshape -> (Batch, 3 * Seq, Hidden)
        num_tokens = 3 * seq_len
        tokens = torch.stack((return_tok, state_tok, action_tok), dim=2)
        tokens = tokens.reshape(batch_size, num_tokens, self.hidden_size)

        # Causal (autoregressive) mask: -inf strictly above the diagonal so a
        # token cannot attend to later tokens. Built on the input's device and
        # dtype so the module also works after .to(device).
        causal_mask = torch.triu(
            torch.full(
                (num_tokens, num_tokens),
                float("-inf"),
                device=tokens.device,
                dtype=tokens.dtype,
            ),
            diagonal=1,
        )

        hidden = self.transformer(tokens, mask=causal_mask)

        # Undo the interleaving: (Batch, Seq, 3, Hidden) with order (R, s, a).
        hidden = hidden.reshape(batch_size, seq_len, 3, self.hidden_size)

        # The state token (index 1) has seen R_t and s_t but not a_t, so it is
        # the position from which a_t is predicted.
        return self.predict_action(hidden[:, :, 1])
