### 测试目标：

- 构建一个可泛化的 actor，能够按配置任意搭建自回归结构
- actor由多个自回归块和末端动作网络组成
- 每个自回归块输出一个单维离散（Categorical）动作分布；末端动作网络采用无共享 backbone 的多头结构，由 Cat（可包含多个独立的单维离散头）、Cont（连续）和 Bern（伯努利）头构成
- 所有自回归结构在wrapper里面构建
- 未设计GRU或自注意力结构


关键设计点总结：
1. Wrapper 全权负责拓扑结构：
   - __init__ 中通过循环计算 current_input_dim += action_dim，实现了自动的维度对齐。你不需要手动计算第 3 层输入的维度是 69 还是 70。
2. Teacher Forcing 实现：
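   - 在 evaluate_actions 中，代码使用 Buffer 中存储的真实动作 gt_idx (Ground Truth) 生成 One-Hot 向量并拼接到 curr_input，而不是重新采样。这保证了训练时计算的条件概率与采样时一致，即 $\log \pi(a_0, a_1 \mid s) = \log \pi(a_0 \mid s) + \log \pi(a_1 \mid s, a_0^{\text{true}})$。注意：One-Hot 由离散索引构造，本身不可导，梯度不会经由它传回前面节点的参数；这里的目的就是保证条件概率 $P(B \mid A_{true})$ 计算准确（PPO 只需对各节点 logits 的对数概率求导）。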
3. 高度模块化：
    - layers 是一个 nn.ModuleList。
    - 每一层都是一个独立的 MLP（来自 MLP_heads）。
    - 没有共享的主干网络（Backbone）。在级联结构中，这通常更好，因为每个决策阶段关注的特征可能完全不同。如果需要共享特征，可以在 Wrapper 最开始加一个 Feature Extractor，然后把提取的特征作为 state 传入循环。

In [5]:
import os, sys

def get_current_file_dir():
    """Return the directory to treat as this file's directory.

    - Under a Jupyter kernel (IPython shell class ``ZMQInteractiveShell``) the
      current working directory is used.
    - Otherwise the directory containing ``__file__`` is returned.
    - If ``__file__`` is not defined either (e.g. a plain IPython terminal or
      another interactive shell), fall back to the current working directory
      instead of raising ``NameError``.
    """
    try:
        # get_ipython only exists when running under IPython; elsewhere this
        # lookup raises NameError.
        shell = get_ipython().__class__.__name__  # noqa: F821
    except NameError:
        shell = None

    if shell == 'ZMQInteractiveShell':
        return os.getcwd()

    try:
        return os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # No __file__ (interactive session): best available guess is the cwd.
        return os.getcwd()

# The project root is the grandparent of this notebook's directory.
cur_dir = get_current_file_dir()
project_root = os.path.dirname(os.path.dirname(cur_dir))

# Prepend (rather than append) so modules under the project root take
# precedence over same-named modules later on sys.path.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print(project_root)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Bernoulli, Normal
import numpy as np

# 假设 MLP_heads 在同级目录下，如果不是请调整 import
from Algorithms.MLP_heads import PolicyNetDiscrete, PolicyNetContinuous, PolicyNetBernouli, PolicyNetMultiDiscrete
from Algorithms.Utils import SquashedNormal # 假设 Utils 也在

d:\3_Machine_Learning_in_Python\project03_fire_and_dodge_missile


In [8]:



class GeneralCascadeWrapper(nn.Module):
    """General autoregressive (cascade) actor wrapper.

    The topology is driven entirely by ``chain_config``, a list of layer specs
    processed in order:

    - ``{'name': str, 'type': 'cat', 'dim': int}``: an intermediate
      autoregressive block with a single Categorical head. The one-hot of its
      action is appended to the input of every later layer, so the next
      layer's input dim grows by ``dim``.
    - ``{'name': str, 'type': 'tail', 'dims': {...}}``: a block of parallel,
      independent heads that all read the same input and are NOT fed forward.
      ``dims`` may contain ``'bern'`` (int), ``'cont'`` (int) and ``'cat'``
      (int or list of ints, one Categorical head per entry).

    Args:
        state_dim: Size of the flat state vector fed to the first layer.
        chain_config: Layer specs as described above.
        hidden_dims: Hidden layer sizes passed to every head network.
            Defaults to ``[64, 64]``.
        device: Device used for tensors created inside the wrapper. The module
            itself is not moved here; call ``.to(device)`` separately.

    Raises:
        ValueError: If a layer spec has a ``type`` other than 'cat' or 'tail'.
    """
    def __init__(self, state_dim, chain_config, hidden_dims=None, device='cpu'):
        super().__init__()
        if hidden_dims is None:
            # Fresh list per instance (a list default would be shared).
            hidden_dims = [64, 64]
        self.device = device
        self.state_dim = state_dim
        self.chain_config = chain_config
        self.hidden_dims = hidden_dims

        self.layers = nn.ModuleList()
        # Per-layer metadata (name, type, input/action dims); index-aligned with self.layers.
        self.layer_info = []

        current_input_dim = state_dim

        print(f"Build Cascade Actor with State Dim: {state_dim}")

        for i, config in enumerate(chain_config):
            layer_name = config['name']
            layer_type = config['type']

            info = {
                'name': layer_name,
                'type': layer_type,
                'input_dim': current_input_dim,
            }

            if layer_type == 'cat':
                # Intermediate autoregressive block: one Categorical decision.
                action_dim = config['dim']
                net = PolicyNetDiscrete(current_input_dim, hidden_dims, action_dim)
                self.layers.append(net)

                info['action_dim'] = action_dim
                # Next layer sees State + OneHot(all previous cat actions).
                current_input_dim += action_dim
                print(f"  [Layer {i}] {layer_name} (Cat): In={info['input_dim']}, Out={action_dim} -> Next In={current_input_dim}")

            elif layer_type == 'tail':
                # Parallel heads sharing the same input; their outputs are not fed forward.
                tail_heads = nn.ModuleDict()
                tail_dims = config['dims']  # e.g. {'bern': 1, 'cont': 2, 'cat': [2, 2]}

                info['sub_heads'] = {}

                # 1. Tail - Bernoulli head
                if 'bern' in tail_dims:
                    b_dim = tail_dims['bern']
                    tail_heads['bern'] = PolicyNetBernouli(current_input_dim, hidden_dims, b_dim)
                    info['sub_heads']['bern'] = b_dim
                    print(f"  [Layer {i}] {layer_name} (Tail-Bern): In={current_input_dim}, Out={b_dim}")

                # 2. Tail - continuous head
                if 'cont' in tail_dims:
                    c_dim = tail_dims['cont']
                    tail_heads['cont'] = PolicyNetContinuous(current_input_dim, hidden_dims, c_dim)
                    info['sub_heads']['cont'] = c_dim
                    print(f"  [Layer {i}] {layer_name} (Tail-Cont): In={current_input_dim}, Out={c_dim}")

                # 3. Tail - one independent Categorical head per entry (MultiDiscrete)
                if 'cat' in tail_dims:
                    cat_dims = tail_dims['cat']
                    if isinstance(cat_dims, int):
                        cat_dims = [cat_dims]
                    # Skip an empty list: an empty head container would make
                    # evaluate_actions look up an action get_action never stored.
                    if len(cat_dims) > 0:
                        tail_heads['cat'] = nn.ModuleList(
                            [PolicyNetDiscrete(current_input_dim, hidden_dims, c_dim) for c_dim in cat_dims]
                        )
                        info['sub_heads']['cat'] = cat_dims
                        print(f"  [Layer {i}] {layer_name} (Tail-Cat): In={current_input_dim}, Out={cat_dims}")

                self.layers.append(tail_heads)

            else:
                # Without this, layer_info would gain an entry with no matching
                # network and every later layer would be paired with the wrong metadata.
                raise ValueError(
                    f"Unknown layer type {layer_type!r} for layer {layer_name!r}; expected 'cat' or 'tail'"
                )

            self.layer_info.append(info)

    def _to_tensor(self, x):
        """Return ``x`` as a tensor with a leading batch dim (1-D input -> (1, D)).

        Non-tensor inputs are created as float tensors on ``self.device``;
        tensors are used as-is.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float, device=self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return x

    def _action_tensor(self, x, dtype):
        """Convert a stored action (array-like or tensor) to ``dtype`` on ``self.device``.

        Avoids ``torch.tensor(tensor)``, which copies and emits a warning when
        the buffer already holds tensors.
        """
        if isinstance(x, torch.Tensor):
            return x.to(device=self.device, dtype=dtype)
        return torch.tensor(x, dtype=dtype, device=self.device)

    @torch.no_grad()
    def get_action(self, state, explore=True):
        """Sample (or greedily pick) a full autoregressive action.

        Runs State -> Action0 -> State+Action0 -> Action1 -> ... -> tail heads.
        Executed under ``torch.no_grad()``: every result is returned as a NumPy
        array, so no autograd graph is needed.

        Args:
            state: A single state (1-D) or a batch (2-D), tensor or array-like.
            explore: If True, sample from each distribution; otherwise take the
                mode (argmax for Categorical, ``logit > 0`` for Bernoulli,
                ``tanh(mu)`` for the continuous head).

        Returns:
            ``(actions_exec, actions_raw)``: dicts of NumPy arrays keyed by the
            layer name (tail heads use ``<name>_bern``, ``<name>_cont`` and
            ``<name>_cat``). They match except for the continuous head, where
            ``actions_exec`` holds the squashed action and ``actions_raw`` holds
            the pre-tanh value ``u`` consumed by ``evaluate_actions``.
        """
        curr_input = self._to_tensor(state)

        actions_exec = {}  # actions to execute in the environment
        actions_raw = {}   # actions to store in the buffer for evaluate_actions

        for layer, info in zip(self.layers, self.layer_info):
            name = info['name']
            l_type = info['type']

            if l_type == 'cat':
                logits = layer(curr_input, logits=True)  # (B, action_dim)
                dist = Categorical(logits=logits)
                idx = dist.sample() if explore else torch.argmax(logits, dim=-1)  # (B,)

                actions_raw[name] = idx.detach().cpu().numpy()
                actions_exec[name] = idx.detach().cpu().numpy()

                # Condition later layers on this choice. Built the same way as in
                # evaluate_actions, on the index's own device.
                one_hot = F.one_hot(idx, num_classes=info['action_dim']).to(curr_input.dtype)
                curr_input = torch.cat([curr_input, one_hot], dim=-1)

            elif l_type == 'tail':
                heads = layer  # ModuleDict of parallel heads sharing curr_input

                if 'bern' in heads:
                    b_logits = heads['bern'](curr_input)
                    dist = Bernoulli(logits=b_logits)
                    b_val = dist.sample() if explore else (b_logits > 0).float()
                    actions_exec[name + '_bern'] = b_val.detach().cpu().numpy()
                    actions_raw[name + '_bern'] = b_val.detach().cpu().numpy()

                if 'cont' in heads:
                    mu, std = heads['cont'](curr_input)
                    dist = SquashedNormal(mu, std)
                    if explore:
                        a_norm, u = dist.sample()
                    else:
                        u = mu
                        a_norm = torch.tanh(u)
                    actions_exec[name + '_cont'] = a_norm.detach().cpu().numpy()
                    actions_raw[name + '_cont'] = u.detach().cpu().numpy()  # pre-tanh value

                if 'cat' in heads:
                    cat_res = []
                    for net in heads['cat']:
                        logits = net(curr_input, logits=True)
                        dist = Categorical(logits=logits)
                        idx = dist.sample() if explore else torch.argmax(logits, dim=-1)
                        cat_res.append(idx.detach().cpu().numpy())

                    if cat_res:
                        stacked = np.stack(cat_res, axis=-1)  # (B, n_heads)
                        actions_exec[name + '_cat'] = stacked
                        actions_raw[name + '_cat'] = stacked

        return actions_exec, actions_raw

    def evaluate_actions(self, state, actions_raw):
        """Score stored actions with teacher forcing.

        Each autoregressive layer is conditioned on the one-hot of the *stored*
        actions of earlier layers (not fresh samples). The returned log-prob is
        therefore log pi(a0|s) + log pi(a1|s,a0) + ... plus every tail head.

        Args:
            state: Batch of states, same layout as for ``get_action``.
            actions_raw: The ``actions_raw`` dict produced by ``get_action``
                (NumPy arrays or tensors).

        Returns:
            ``(total_log_prob, entropy_dict)``. ``total_log_prob`` is a (B, 1)
            tensor that keeps the autograd graph (or ``0`` for an empty
            chain). ``entropy_dict`` maps head names to batch-mean entropies as
            Python floats; they are detached via ``.item()``, so use them for
            logging, not as a loss term.
        """
        curr_input = self._to_tensor(state)

        total_log_prob = 0
        entropy_dict = {}

        for layer, info in zip(self.layers, self.layer_info):
            name = info['name']
            l_type = info['type']

            if l_type == 'cat':
                gt_idx = self._action_tensor(actions_raw[name], torch.long)
                if gt_idx.dim() > 1:
                    gt_idx = gt_idx.squeeze(-1)  # accept (B, 1) as well as (B,)

                logits = layer(curr_input, logits=True)
                dist = Categorical(logits=logits)

                # Out-of-place accumulation keeps autograd free of in-place edits.
                total_log_prob = total_log_prob + dist.log_prob(gt_idx).unsqueeze(-1)  # (B, 1)
                entropy_dict[name] = dist.entropy().mean().item()

                # Teacher forcing: condition later layers on the stored action. The
                # one-hot comes from an integer index, so no gradient flows through it.
                one_hot = F.one_hot(gt_idx, num_classes=info['action_dim']).to(curr_input.dtype)
                curr_input = torch.cat([curr_input, one_hot], dim=-1)

            elif l_type == 'tail':
                heads = layer

                if 'bern' in heads:
                    gt = self._action_tensor(actions_raw[name + '_bern'], torch.float)
                    dist = Bernoulli(logits=heads['bern'](curr_input))
                    total_log_prob = total_log_prob + dist.log_prob(gt).sum(-1, keepdim=True)
                    entropy_dict[name + '_bern'] = dist.entropy().mean().item()

                if 'cont' in heads:
                    gt_u = self._action_tensor(actions_raw[name + '_cont'], torch.float)
                    mu, std = heads['cont'](curr_input)
                    dist = SquashedNormal(mu, std)
                    # The stored pre-tanh value is passed as the second argument.
                    total_log_prob = total_log_prob + dist.log_prob(0, gt_u).sum(-1, keepdim=True)
                    entropy_dict[name + '_cont'] = dist.entropy().mean().item()  # approximate

                if 'cat' in heads and len(heads['cat']) > 0:
                    gt_cats = self._action_tensor(actions_raw[name + '_cat'], torch.long)  # (B, n_heads)
                    cat_nets = heads['cat']
                    ent_sum = 0
                    for hi, net in enumerate(cat_nets):
                        dist = Categorical(logits=net(curr_input, logits=True))
                        total_log_prob = total_log_prob + dist.log_prob(gt_cats[:, hi]).unsqueeze(-1)
                        ent_sum += dist.entropy().mean().item()
                    entropy_dict[name + '_cat'] = ent_sum / len(cat_nets)

        return total_log_prob, entropy_dict

# ==========================================
# 测试代码
# ==========================================
if __name__ == "__main__":
    # Smoke test: build a cascade actor, sample actions, re-score them, and
    # assert the shapes and dimension bookkeeping instead of only printing them.

    # 1. Define the structure.
    # Scenario: pick a target (5 options), then a maneuver (14 options),
    # finally decide firing (one switch) and control-surface deflection (2-D).
    config = [
        # Layer 0: autoregressive block 1 (target selection)
        {'name': 'target_sel', 'type': 'cat', 'dim': 5},

        # Layer 1: autoregressive block 2 (maneuver, conditioned on the target)
        {'name': 'maneuver',   'type': 'cat', 'dim': 14},

        # Layer 2: tail (conditioned on target + maneuver)
        {'name': 'execution',  'type': 'tail', 'dims': {
            'bern': 1,      # fire
            'cont': 2,      # control-surface deflection
            'cat': [2, 2]   # radar on/off, jamming on/off (MultiDiscrete)
        }}
    ]

    # 2. Build the wrapper.
    state_dim = 50
    actor = GeneralCascadeWrapper(state_dim, config, hidden_dims=[32])
    print("\n--- Network Constructed ---")

    # 3. Dummy batch.
    batch_size = 4
    dummy_state = torch.randn(batch_size, state_dim)

    # 4. get_action (inference).
    print("\n--- Testing Get Action (Inference) ---")
    actions_exec, actions_raw = actor.get_action(dummy_state, explore=True)

    for k, v in actions_exec.items():
        print(f"Action '{k}': Shape {v.shape}")
        # Every action array is batch-major.
        assert v.shape[0] == batch_size, f"{k}: expected leading dim {batch_size}, got {v.shape}"

    # 5. evaluate_actions (training).
    print("\n--- Testing Evaluate Actions (Training) ---")
    log_probs, entropies = actor.evaluate_actions(dummy_state, actions_raw)

    print(f"Total LogProb Shape: {log_probs.shape}")
    assert log_probs.shape == (batch_size, 1), f"expected ({batch_size}, 1), got {tuple(log_probs.shape)}"
    assert torch.isfinite(log_probs).all(), "log-probs contain NaN/inf"
    # The PPO loss must be able to backpropagate through the log-probs.
    assert log_probs.requires_grad, "log-probs are detached from the network parameters"
    print("Entropies:", entropies)

    # 6. Check the dimension bookkeeping: each 'cat' layer widens the next input
    #    by its action dim -> L0 In: 50, L1 In: 50 + 5 = 55, L2 In: 55 + 14 = 69.
    print("\n--- Verifying Layer Inputs ---")
    for i, info in enumerate(actor.layer_info):
        print(f"Layer {i} ({info['name']}) Input Dim: {info['input_dim']}")
    expected_input_dims = [state_dim, state_dim + 5, state_dim + 5 + 14]
    actual_input_dims = [info['input_dim'] for info in actor.layer_info]
    assert actual_input_dims == expected_input_dims, f"expected {expected_input_dims}, got {actual_input_dims}"

Build Cascade Actor with State Dim: 50
  [Layer 0] target_sel (Cat): In=50, Out=5 -> Next In=55
  [Layer 1] maneuver (Cat): In=55, Out=14 -> Next In=69
  [Layer 2] execution (Tail-Bern): In=69, Out=1
  [Layer 2] execution (Tail-Cont): In=69, Out=2
  [Layer 2] execution (Tail-Cat): In=69, Out=[2, 2]

--- Network Constructed ---

--- Testing Get Action (Inference) ---
Action 'target_sel': Shape (4,)
Action 'maneuver': Shape (4,)
Action 'execution_bern': Shape (4, 1)
Action 'execution_cont': Shape (4, 2)
Action 'execution_cat': Shape (4, 2)

--- Testing Evaluate Actions (Training) ---
Total LogProb Shape: torch.Size([4, 1])
Entropies: {'target_sel': 1.519482135772705, 'maneuver': 2.608701467514038, 'execution_bern': 0.6931433081626892, 'execution_cont': 1.0052956342697144, 'execution_cat': 0.6742390990257263}

--- Verifying Layer Inputs ---
Layer 0 (target_sel) Input Dim: 50
Layer 1 (maneuver) Input Dim: 55
Layer 2 (execution) Input Dim: 69
