In [2]:
# %% [1] Imports and Mock Configuration
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Check whether the required library is installed.
# A missing `diffusers` package is reported with an install hint instead of
# raising, so the rest of this cell can still run.
try:
    import diffusers
except ImportError:
    print("请安装 diffusers: pip install diffusers")

# --- 模拟 Lerobot 的辅助函数 (为了让代码独立运行) ---
def get_output_shape(model, image_dim):
    """Return the shape of ``model``'s output for a random input of shape ``image_dim``.

    Args:
        model: Callable (typically an ``nn.Module``) mapping a tensor to a tensor.
        image_dim: Iterable of ints giving the full input shape, including any
            batch dimension the model expects.

    Returns:
        torch.Size: Shape of the tensor produced by ``model``.
    """
    probe = torch.rand(*image_dim)
    # Only the shape is needed, so skip autograd bookkeeping for this forward pass.
    with torch.no_grad():
        output = model(probe)
    return output.shape

# --- 模拟配置类 ---
@dataclass
class MockConfig:
    """Minimal stand-in for the policy configuration consumed by the modules in this notebook.

    Every field has a default so ``MockConfig()`` works without arguments.
    ``image_features`` defaults to ``None`` and is expected to be filled in by
    the caller. ``action_feature`` is created per instance, so mutating one
    config's tensor never affects another config.
    """

    # --- Vision ---
    # Name of the constructor looked up on ``torchvision.models``.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: Optional[str] = None
    # (height, width) of the crop applied to input images; ``None`` disables cropping.
    crop_shape: Optional[Tuple[int, int]] = (84, 84)
    crop_is_random: bool = False
    use_group_norm: bool = True
    # Number of keypoints produced by the spatial-softmax pooling layer.
    spatial_softmax_num_keypoints: int = 32

    # --- Dimensions ---
    # Mapping of camera name -> example image tensor; set after construction.
    image_features: Optional[dict] = None
    # Placeholder tensor whose first dimension is the action dimension (14 here).
    # Built by a factory so each config owns its own tensor.
    action_feature: torch.Tensor = field(default_factory=lambda: torch.zeros(14))
    # Length of the observation history.
    n_obs_steps: int = 2

    # --- U-Net structure ---
    # Channel widths of the encoder stages (kept small for quick testing).
    down_dims: Tuple[int, ...] = (128, 256, 512)
    kernel_size: int = 5
    n_groups: int = 8
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True

    # --- Diffusion parameters ---
    num_train_timesteps: int = 100
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "squaredcos_cap_v2"
    prediction_type: str = "epsilon"
    clip_sample: bool = True
    clip_sample_range: float = 1.0

# Instantiate the configuration.
# Each simulated input image is stored as (Channel, Height, Width).
config = MockConfig(image_features={"camera_0": torch.zeros(3, 96, 96)})

print("配置加载完成。动作维度:", config.action_feature.shape[0])

配置加载完成。动作维度: 14


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# %% [2] Spatial Softmax
class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax pooling: reduce each feature channel to an expected (x, y) point.

    Args:
        input_shape: ``(C, H, W)`` of the incoming feature map.
        num_kp: If given, a 1x1 convolution first maps the ``C`` input channels
            to ``num_kp`` channels, and one keypoint is produced per output
            channel. If ``None``, one keypoint is produced per input channel.

    The ``pos_grid`` buffer holds the ``H * W`` pixel coordinates, normalised
    to ``[-1, 1]``, as rows of ``(x, y)``.
    """

    def __init__(self, input_shape, num_kp=None):
        super().__init__()
        channels, height, width = input_shape
        self._in_c, self._in_h, self._in_w = channels, height, width

        if num_kp is None:
            self.nets = None
            self._out_c = channels
        else:
            self.nets = torch.nn.Conv2d(channels, num_kp, kernel_size=1)
            self._out_c = num_kp

        # 'xy' indexing yields (H, W) grids: x varies along the width, y along the height.
        xs = torch.linspace(-1.0, 1.0, width)
        ys = torch.linspace(-1.0, 1.0, height)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
        coords = torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)
        self.register_buffer("pos_grid", coords)

    def forward(self, features):
        """Map ``[B, C, H, W]`` features to ``[B, K, 2]`` expected keypoint coordinates."""
        if self.nets is not None:
            features = self.nets(features)

        num_positions = self._in_h * self._in_w
        # [B, K, H, W] -> [B * K, H * W]
        flat = features.reshape(-1, num_positions)
        weights = flat.softmax(dim=-1)
        # [B * K, H * W] @ [H * W, 2] -> [B * K, 2]
        keypoints = torch.matmul(weights, self.pos_grid)
        return keypoints.view(-1, self._out_c, 2)

# --- Validation ---
# Assume the ResNet output feature map has shape (Batch=2, Channel=512, H=10, W=10).
dummy_feature_map = torch.randn(2, 512, 10, 10)
pooler = SpatialSoftmax(input_shape=(512, 10, 10), num_kp=32)
output_kp = pooler(dummy_feature_map)

print(f"输入特征图: {dummy_feature_map.shape}")
print(f"输出关键点: {output_kp.shape} (Batch, Keypoints, 2)")
# One (x, y) pair per keypoint per batch element.
assert output_kp.shape == (2, 32, 2), "Spatial Softmax 输出维度错误！"

输入特征图: torch.Size([2, 512, 10, 10])
输出关键点: torch.Size([2, 32, 2]) (Batch, Keypoints, 2)


In [4]:
# %% [3] Vision Encoder
class DiffusionRgbEncoder(nn.Module):
    """Encode a batch of RGB images into flat feature vectors.

    Pipeline: optional center crop -> torchvision backbone with its last two
    child modules removed -> ``SpatialSoftmax`` keypoints -> linear layer + ReLU.

    Args:
        config: Object providing ``crop_shape``, ``vision_backbone``,
            ``spatial_softmax_num_keypoints`` and, when ``crop_shape`` is
            ``None``, ``image_features`` (a mapping whose values expose
            ``.shape`` ending in ``(H, W)``).

    Attributes:
        feature_dim: Size of the output vector, ``2 * spatial_softmax_num_keypoints``.

    Raises:
        ValueError: If ``crop_shape`` is ``None`` and ``image_features`` is
            empty or ``None``, so the backbone input size cannot be determined.
    """

    def __init__(self, config):
        super().__init__()
        # 1. Image preprocessing (crop).
        if config.crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
        else:
            self.do_crop = False

        # 2. Build the backbone. Weights are not downloaded (weights=None) to keep this fast.
        backbone_model = getattr(torchvision.models, config.vision_backbone)(weights=None)
        # Drop the last two children (the final pooling and fully connected layers).
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))

        # 3. BatchNorm -> GroupNorm replacement is intentionally omitted in this
        #    simplified demo; the full implementation would walk the network and
        #    swap the layers.

        # 4. Probe the backbone once to learn its output feature-map shape.
        #    Without cropping, the backbone sees the raw image, so take the
        #    spatial size from the first configured image feature instead.
        if self.do_crop:
            probe_hw = tuple(config.crop_shape)
        else:
            if not config.image_features:
                raise ValueError(
                    "config.image_features must be set when config.crop_shape is None"
                )
            probe_hw = tuple(next(iter(config.image_features.values())).shape[-2:])
        dummy_input = torch.zeros(1, 3, *probe_hw)
        with torch.no_grad():
            feature_map_shape = self.backbone(dummy_input).shape[1:]  # [C, H, W]

        # 5. Pooling and projection.
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(self.feature_dim, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map images ``[B, 3, H, W]`` to features ``[B, feature_dim]``."""
        if self.do_crop:
            x = self.center_crop(x)  # simplified: center crop only

        x = self.backbone(x)               # [B, C, H', W']
        x = self.pool(x)                   # [B, K, 2]
        x = torch.flatten(x, start_dim=1)  # [B, K*2]
        x = self.relu(self.out(x))
        return x

# --- Validation ---
encoder = DiffusionRgbEncoder(config)
# Assume a batch of 2 input images.
dummy_img = torch.randn(2, 3, 96, 96) 
encoded_vec = encoder(dummy_img)

print(f"输入图片: {dummy_img.shape}")
print(f"编码向量: {encoded_vec.shape}")
# 32 keypoints * 2 coordinates per image.
assert encoded_vec.shape == (2, 32*2), "Visual Encoder 输出维度错误！"

输入图片: torch.Size([2, 3, 96, 96])
编码向量: torch.Size([2, 64])


In [5]:
# %% [4] U-Net Building Blocks
class DiffusionConv1dBlock(nn.Module):
    """Conv1d -> GroupNorm -> Mish.

    The convolution pads by ``kernel_size // 2``, so an odd kernel size keeps
    the sequence length unchanged.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[B, inp_channels, T]``."""
        return self.block(x)

class DiffusionConditionalResidualBlock1d(nn.Module):
    """Residual 1D conv block whose hidden features are modulated by a conditioning vector (FiLM).

    Args:
        in_channels: Channels of the input sequence.
        out_channels: Channels of the output sequence.
        cond_dim: Size of the conditioning vector.
        kernel_size: Kernel size of both convolution blocks.
        n_groups: Number of groups for GroupNorm in both convolution blocks.
        use_film_scale_modulation: If True, the conditioning vector yields a
            per-channel scale and bias; otherwise only a per-channel bias.
    """

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, use_film_scale_modulation=True):
        super().__init__()
        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM: project the conditioning vector to a bias, or to scale + bias.
        if use_film_scale_modulation:
            cond_channels = 2 * out_channels
        else:
            cond_channels = out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))

        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # A 1x1 convolution matches the channel count on the skip path when needed.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        """Apply the block to ``x`` ``[B, in_channels, T]`` given ``cond`` ``[B, cond_dim]``."""
        hidden = self.conv1(x)

        # [B, cond_dim] -> [B, cond_channels] -> [B, cond_channels, 1] (broadcast over T)
        film = self.cond_encoder(cond).unsqueeze(-1)

        if not self.use_film_scale_modulation:
            hidden = hidden + film
        else:
            # First half of the channels is the scale, second half the bias.
            scale = film[:, : self.out_channels]
            bias = film[:, self.out_channels :]
            hidden = scale * hidden + bias

        hidden = self.conv2(hidden)
        return hidden + self.residual_conv(x)

# --- Validation ---
# Assume: Batch=2, feature channels=64, sequence length (Horizon)=16.
dummy_feat = torch.randn(2, 64, 16)
# Assume: conditioning vector dimension=128 (from ResNet + TimeEmbedding).
dummy_cond = torch.randn(2, 128)

block = DiffusionConditionalResidualBlock1d(in_channels=64, out_channels=128, cond_dim=128)
out_block = block(dummy_feat, dummy_cond)

print(f"ResBlock 输入: {dummy_feat.shape}")
print(f"ResBlock 输出: {out_block.shape}")
# Channels change 64 -> 128 while the sequence length stays at 16.
assert out_block.shape == (2, 128, 16), "ResBlock 维度变换错误！"

ResBlock 输入: torch.Size([2, 64, 16])
ResBlock 输出: torch.Size([2, 128, 16])


In [6]:
# %% [5] Conditional U-Net (Refined & Annotated)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

# -----------------------------------------------------------------------------
# 辅助模块：位置编码
# -----------------------------------------------------------------------------
class DiffusionSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

    For an input of shape ``[B]`` the output has shape ``[B, 2 * (dim // 2)]``:
    the first half holds sines and the second half cosines, over ``dim // 2``
    frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        # Log-spacing step between consecutive frequencies.
        decay = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-decay * torch.arange(n_freqs, device=x.device))
        # [B, 1] * [1, n_freqs] -> [B, n_freqs]
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

# -----------------------------------------------------------------------------
# 核心模块：1D U-Net
# -----------------------------------------------------------------------------
class DiffusionConditionalUnet1d(nn.Module):
    """1D convolutional U-Net over action sequences, FiLM-conditioned on timestep and a global vector.

    Layout for ``down_dims = (d0, ..., dn)`` and action dimension ``A``:

    * Encoder: one stage per ``(A, d0), (d0, d1), ...``; every stage except the
      last ends with a stride-2 convolution that halves the sequence length.
    * Mid: two residual blocks at width ``dn``.
    * Decoder: one stage per encoder stage, in reverse order. Each stage
      concatenates the matching encoder skip feature; every stage except the
      last ends with a transposed convolution that doubles the sequence length.
    * Final: a conv block followed by a 1x1 convolution back to ``A`` channels.

    The residual blocks take ``kernel_size``, ``n_groups`` and
    ``use_film_scale_modulation`` from ``config``; the final conv block takes
    ``kernel_size`` from ``config``.

    Args:
        config: Object providing ``action_feature`` (its ``shape[0]`` is the
            action dimension), ``down_dims``, ``diffusion_step_embed_dim``,
            ``kernel_size``, ``n_groups`` and ``use_film_scale_modulation``.
        global_cond_dim: Size of the global conditioning vector passed to ``forward``.
    """

    def __init__(self, config, global_cond_dim):
        super().__init__()
        self.config = config

        action_dim = config.action_feature.shape[0]
        embed_dim = config.diffusion_step_embed_dim

        # 1. Timestep embedding: scalar step -> vector of size ``embed_dim``.
        self.diffusion_step_encoder = nn.Sequential(
            DiffusionSinusoidalPosEmb(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.Mish(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

        # FiLM conditioning size = timestep embedding + global conditioning.
        cond_dim = embed_dim + global_cond_dim

        # Hyper-parameters shared by every residual block, taken from the config
        # so that changing the config actually changes the network.
        block_kwargs = {
            "cond_dim": cond_dim,
            "kernel_size": config.kernel_size,
            "n_groups": config.n_groups,
            "use_film_scale_modulation": config.use_film_scale_modulation,
        }

        # -----------------------------------------------------------------
        # A. Encoder (downsampling path)
        # -----------------------------------------------------------------
        # E.g. action dim 14 and down_dims (128, 256, 512) give
        # [(14, 128), (128, 256), (256, 512)].
        in_out = [(action_dim, config.down_dims[0])] + list(
            zip(config.down_dims[:-1], config.down_dims[1:])
        )

        self.down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The last encoder stage (just before the bottleneck) does not downsample.
            is_last = ind >= len(in_out) - 1

            self.down_modules.append(nn.ModuleList([
                DiffusionConditionalResidualBlock1d(dim_in, dim_out, **block_kwargs),
                DiffusionConditionalResidualBlock1d(dim_out, dim_out, **block_kwargs),
                nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
            ]))

        # -----------------------------------------------------------------
        # B. Mid (bottleneck)
        # -----------------------------------------------------------------
        mid_dim = config.down_dims[-1]
        self.mid_modules = nn.ModuleList([
            DiffusionConditionalResidualBlock1d(mid_dim, mid_dim, **block_kwargs),
            DiffusionConditionalResidualBlock1d(mid_dim, mid_dim, **block_kwargs),
        ])

        # -----------------------------------------------------------------
        # C. Decoder (upsampling path)
        # -----------------------------------------------------------------
        self.up_modules = nn.ModuleList([])

        # Walk the encoder stages in reverse; ``dim_orig_in`` / ``dim_orig_out``
        # are the input / output widths the stage had in the encoder.
        for ind, (dim_orig_in, dim_orig_out) in enumerate(reversed(in_out)):
            # Decoder input width = encoder output width of the same stage.
            dim_in = dim_orig_out

            # Decoder output width = encoder input width, except for the last
            # stage: it must not collapse to the action dimension, so it keeps
            # the stage width and leaves the projection to ``final_conv``.
            is_last_layer = ind == len(in_out) - 1
            dim_out = dim_orig_in if not is_last_layer else dim_orig_out

            # Every decoder stage except the last upsamples.
            should_upsample = not is_last_layer

            self.up_modules.append(nn.ModuleList([
                # ``dim_in * 2`` because the encoder skip feature is concatenated on channels.
                DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **block_kwargs),
                DiffusionConditionalResidualBlock1d(dim_out, dim_out, **block_kwargs),
                nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if should_upsample else nn.Identity(),
            ]))

        # -----------------------------------------------------------------
        # D. Output head
        # -----------------------------------------------------------------
        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
            nn.Conv1d(config.down_dims[0], action_dim, 1),
        )

    def forward(self, x, timestep, global_cond):
        """Run the U-Net.

        Args:
            x: ``[Batch, Horizon, ActionDim]`` input sequence (e.g. ``2, 16, 14``).
            timestep: ``[Batch]`` diffusion timesteps.
            global_cond: ``[Batch, global_cond_dim]`` conditioning vector.

        Returns:
            Tensor of shape ``[Batch, Horizon, ActionDim]``.
        """
        # 1. (B, T, D) -> (B, D, T) so Conv1d convolves over time.
        x = einops.rearrange(x, "b t d -> b d t")

        # 2. Conditioning vector: timestep embedding concatenated with the global condition.
        timesteps_embed = self.diffusion_step_encoder(timestep)
        global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)

        # 3. Encoder.
        encoder_skip_features = []
        for resnet1, resnet2, downsample in self.down_modules:
            x = resnet1(x, global_feature)
            x = resnet2(x, global_feature)
            # Store the skip feature BEFORE downsampling so it keeps this stage's length.
            encoder_skip_features.append(x)
            x = downsample(x)

        # 4. Bottleneck.
        for mid_mod in self.mid_modules:
            x = mid_mod(x, global_feature)

        # 5. Decoder.
        for resnet1, resnet2, upsample in self.up_modules:
            skip = encoder_skip_features.pop()

            # Safety net: stride-2 down/up sampling can leave the lengths off
            # by one; resize to the skip feature's length before concatenating.
            if x.shape[-1] != skip.shape[-1]:
                x = F.interpolate(x, size=skip.shape[-1], mode='nearest')

            # [B, C_dec, T] ++ [B, C_skip, T] -> [B, C_dec + C_skip, T]
            x = torch.cat((x, skip), dim=1)

            x = resnet1(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        # 6. Output head, then (B, D, T) -> (B, T, D).
        x = self.final_conv(x)
        x = einops.rearrange(x, "b d t -> b t d")
        return x

# -----------------------------------------------------------------------------
# 验证代码
# -----------------------------------------------------------------------------
print("--- 开始验证 U-Net ---")
config_test = MockConfig()
config_test.action_feature = torch.zeros(14)
config_test.down_dims = (128, 256, 512)
horizon = 16

unet = DiffusionConditionalUnet1d(config_test, global_cond_dim=64)

# Dummy inputs: action sequence, diffusion timesteps, global conditioning.
dummy_x = torch.randn(2, horizon, 14)
dummy_t = torch.tensor([10, 20])
dummy_c = torch.randn(2, 64)

# Forward pass.
out = unet(dummy_x, dummy_t, dummy_c)

print(f"输入形状: {dummy_x.shape}")
print(f"输出形状: {out.shape}")

# Fail loudly on a shape mismatch so a full notebook run cannot silently
# continue past a broken U-Net.
assert out.shape == dummy_x.shape, "❌ 验证失败：形状不匹配。"
print("✅ 验证成功：输入输出形状完全一致 (16 -> 16)！")

--- 开始验证 U-Net ---
输入形状: torch.Size([2, 16, 14])
输出形状: torch.Size([2, 16, 14])
✅ 验证成功：输入输出形状完全一致 (16 -> 16)！
