# Dit
ref：
- https://github.com/haidog-yaqub/MeanFlow
- https://github.com/facebookresearch/DiT/blob/main/models.py


## 2D

### embedding

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed

# --- 依赖的辅助模块 (从原代码中提取) ---

class TimestepEmbedder(nn.Module):
    def __init__(self, dim, nfreq=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.nfreq = nfreq

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):  # TODO ROPE
        half_dim = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half_dim, dtype=torch.float32)
            / half_dim
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t = t * 1000
        t_freq = self.timestep_embedding(t, self.nfreq)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
        self.num_classes = num_classes

    def forward(self, labels):
        embeddings = self.embedding(labels)
        return embeddings

# --- 主要的 Embedding 模块 ---

class MfditEmbedding(nn.Module):
    """
    将输入图像 x, 时间步 t, r, 和可选的类别标签 y 转换为
    token 序列 x 和条件向量 c。
    """
    def __init__(self, input_size=32, patch_size=2, in_channels=4, dim=1152, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None

        # 图像和位置嵌入
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, dim)
        num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim), requires_grad=False) # 通常在初始化后固定

        # 时间步和标签嵌入
        self.t_embedder = TimestepEmbedder(dim)
        self.r_embedder = TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None

    def forward(self, x, t, r, y=None):
        # 1. 图像和位置嵌入
        x = self.x_embedder(x) + self.pos_embed # init?

        # 2. 条件嵌入
        t_emb = self.t_embedder(t)
        r_emb = self.r_embedder(r)
        c = t_emb + r_emb # 合并时间步嵌入

        if self.use_cond and y is not None:
            y_emb = self.y_embedder(y)
            c = c + y_emb # 添加类别嵌入

        return x, c

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Embedding Demo ---")
    # 模型参数
    B, C, H, W = 4, 4, 32, 32
    DIM = 1152
    NUM_CLASSES = 1000

    # 创建 Embedding 模块实例
    embedding_layer = MfditEmbedding(
        input_size=H,
        in_channels=C,
        dim=DIM,
        num_classes=NUM_CLASSES
    )

    # 创建模拟输入数据
    x_in = torch.randn(B, C, H, W)
    t_in = torch.rand(B)
    r_in = torch.rand(B)
    y_in = torch.randint(0, NUM_CLASSES, (B,))

    # 前向传播
    x_tokens, c_vector = embedding_layer(x_in, t_in, r_in, y_in)

    # 打印输出形状
    print(f"输入图像形状: {x_in.shape}")
    print(f"输出 Token 序列形状: {x_tokens.shape}")
    print(f"输出条件向量 c 形状: {c_vector.shape}")
    print("-" * 20)

--- Embedding Demo ---
输入图像形状: torch.Size([4, 4, 32, 32])
输出 Token 序列形状: torch.Size([4, 256, 1152])
输出条件向量 c 形状: torch.Size([4, 1152])
--------------------


  from .autonotebook import tqdm as notebook_tqdm


### backbone

In [1]:
import torch
import torch.nn as nn
from timm.models.vision_transformer import Mlp, Attention
import torch.nn.functional as F

# --- 依赖的辅助模块 (从原代码中提取) ---
def modulate(x, scale, shift):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

class DiTBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_norm=True, norm_layer=RMSNorm)
        self.attn.fused_attn = False # 禁用融合注意力 (flash attention)，因为它不能与 JVP 一起使用
        self.norm2 = RMSNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate_msa.unsqueeze(1) * self.attn(
            modulate(self.norm1(x), scale_msa, shift_msa)
        )
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.norm2(x), scale_mlp, shift_mlp)
        )
        return x

# --- 主要的 Backbone 模块 ---

class MfditBackbone(nn.Module):
    """
    接收 token 序列 x 和条件向量 c，通过一系列 DiTBlock 进行处理。
    """
    def __init__(self, dim=1152, depth=28, num_heads=16, mlp_ratio=4.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

    def forward(self, x, c):
        for block in self.blocks:
            x = block(x, c)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Backbone Demo ---")
    # 模型参数
    B = 4
    NUM_TOKENS = 256 # (32/2) * (32/2)
    DIM = 1152

    # 创建 Backbone 模块实例
    backbone = MfditBackbone(dim=DIM, depth=28, num_heads=16)

    # 创建模拟输入数据 (来自 Embedding 层的输出)
    x_tokens_in = torch.randn(B, NUM_TOKENS, DIM)
    c_vector_in = torch.randn(B, DIM)

    # 前向传播
    x_features = backbone(x_tokens_in, c_vector_in)

    # 打印输出形状
    print(f"输入 Token 序列形状: {x_tokens_in.shape}")
    print(f"输入条件向量 c 形状: {c_vector_in.shape}")
    print(f"输出特征序列形状: {x_features.shape}")
    print("-" * 20)

  from .autonotebook import tqdm as notebook_tqdm


--- Backbone Demo ---
输入 Token 序列形状: torch.Size([4, 256, 1152])
输入条件向量 c 形状: torch.Size([4, 1152])
输出特征序列形状: torch.Size([4, 256, 1152])
--------------------


### head

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- 依赖的辅助模块 (从原代码中提取) ---
def modulate(x, scale, shift):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

class FinalLayer(nn.Module):
    def __init__(self, dim, patch_size, out_dim):
        super().__init__()
        self.norm_final = RMSNorm(dim)
        self.linear = nn.Linear(dim, patch_size * patch_size * out_dim)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

# --- 主要的 Task Head 模块 ---

class MfditTaskHead(nn.Module):
    """
    接收来自 Backbone 的特征 x 和条件向量 c，并将其转换回图像格式。
    """
    def __init__(self, dim=1152, patch_size=2, out_channels=4):
        super().__init__()
        self.final_layer = FinalLayer(dim, patch_size, out_channels)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x, c):
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Task Head Demo ---")
    # 模型参数
    B = 4
    NUM_TOKENS = 256
    DIM = 1152
    PATCH_SIZE = 2
    OUT_CHANNELS = 4

    # 创建 Task Head 模块实例
    task_head = MfditTaskHead(dim=DIM, patch_size=PATCH_SIZE, out_channels=OUT_CHANNELS)

    # 创建模拟输入数据 (来自 Backbone 层的输出)
    x_features_in = torch.randn(B, NUM_TOKENS, DIM)
    c_vector_in = torch.randn(B, DIM)

    # 前向传播
    output_image = task_head(x_features_in, c_vector_in)

    # 打印输出形状
    print(f"输入特征序列形状: {x_features_in.shape}")
    print(f"输入条件向量 c 形状: {c_vector_in.shape}")
    print(f"输出图像形状: {output_image.shape}")
    print("-" * 20)

--- Task Head Demo ---
输入特征序列形状: torch.Size([4, 256, 1152])
输入条件向量 c 形状: torch.Size([4, 1152])
输出图像形状: torch.Size([4, 4, 32, 32])
--------------------


## 1D

### embedding

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math
from typing import Optional, Callable, Tuple, Union

# --- 辅助模块 (从原代码中提取并保持不变) ---
class TimestepEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim, nfreq=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.nfreq = nfreq
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half_dim = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        t = t * 1000
        t_freq = self.timestep_embedding(t, self.nfreq)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, num_classes, dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
        self.num_classes = num_classes
    def forward(self, labels):
        return self.embedding(labels)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): # TO-DO + sample_rate 
    # ... (代码与上一回答完全相同)
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

# --- 关键修改: 1D Patch Embedding ---
class PatchEmbed1d(nn.Module):
    """ 1D Time-Series to Patch Embedding """
    def __init__(
        self,
        seq_len: int = 2048,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # 使用 Conv1d 进行 Patching
        self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.flatten = flatten

        self.num_patches = (seq_len // patch_size)

    def forward(self, x):
        # 输入 x 形状: B, C, L
        B, C, L = x.shape
        assert L == self.seq_len, f"Input sequence length ({L}) doesn't match model ({self.seq_len})."

        x = self.proj(x) # B, E, L' (L' = L / patch_size)
        if self.flatten:
            x = x.transpose(1, 2)  # B, L', E
        x = self.norm(x)
        return x

# --- 主要的 1D Embedding 模块 ---
class MfditEmbedding1d(nn.Module):
    def __init__(self, seq_len=2048, patch_size=16, in_channels=3, dim=768, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None

        # 1D 图像和位置嵌入
        self.x_embedder = PatchEmbed1d(seq_len, patch_size, in_channels, dim)
        self.num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, dim), requires_grad=False)

        # 条件嵌入 (保持不变)
        self.t_embedder = TimestepEmbedder(dim)
        self.r_embedder = TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None

        self.initialize_weights()

    def initialize_weights(self):
        # 初始化 1D 位置嵌入
        pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], np.arange(self.num_patches))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # 其他权重初始化... (可以从原代码中复制)


    def forward(self, x, t, r, y=None):
        # 1. 1D 时间序列嵌入
        # 输入 x 形状: B, C, L
        x = self.x_embedder(x) + self.pos_embed # x 输出: B, Num_Patches, Dim # TO-DO 输入的 self.num_patches 和 patch_size 是确定的， 可能会涉及到unpatchify

        # 2. 条件嵌入
        t_emb = self.t_embedder(t)
        r_emb = self.r_embedder(r)
        c = t_emb + r_emb

        if self.use_cond and y is not None:
            y_emb = self.y_embedder(y)
            c = c + y_emb

        return x, c

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 1D Embedding Demo ---")
    # 模型参数
    B, C, L = 4, 3, 2048  # (Batch, Channels, Length)
    PATCH_SIZE = 16
    DIM = 768
    NUM_CLASSES = 10

    # 创建 Embedding 模块实例
    embedding_layer_1d = MfditEmbedding1d(
        seq_len=L,
        patch_size=PATCH_SIZE,
        in_channels=C,
        dim=DIM,
        num_classes=NUM_CLASSES
    )

    # 创建模拟输入数据
    x_in_1d = torch.randn(B, C, L) # 注意输入格式为 B,C,L
    t_in = torch.rand(B)
    r_in = torch.rand(B)
    y_in = torch.randint(0, NUM_CLASSES, (B,))

    # 前向传播
    x_tokens, c_vector = embedding_layer_1d(x_in_1d, t_in, r_in, y_in)

    # 打印输出形状
    print(f"输入时间序列形状: {x_in_1d.shape}")
    print(f"输出 Token 序列形状: {x_tokens.shape}")
    print(f"输出条件向量 c 形状: {c_vector.shape}")
    print(f"预期 Token 数量: {L // PATCH_SIZE}, 实际 Token 数量: {x_tokens.shape[1]}")
    print("-" * 20)

--- 1D Embedding Demo ---
输入时间序列形状: torch.Size([4, 3, 2048])
输出 Token 序列形状: torch.Size([4, 128, 768])
输出条件向量 c 形状: torch.Size([4, 768])
预期 Token 数量: 128, 实际 Token 数量: 128
--------------------


### backbone

In [5]:
import torch
import torch.nn as nn
from timm.models.vision_transformer import Mlp, Attention
import torch.nn.functional as F

# --- 依赖的辅助模块 (从原代码中提取，无任何修改) ---
def modulate(x, scale, shift):
    # ... (代码与上一回答完全相同)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class RMSNorm(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))
    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

class DiTBlock(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_norm=True, norm_layer=RMSNorm)
        self.attn.fused_attn = False
        self.norm2 = RMSNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.adaLN_modulation(c).chunk(6, dim=-1))
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), scale_msa, shift_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), scale_mlp, shift_mlp))
        return x

# --- 主要的 1D Backbone 模块 (仅重命名) ---
class MfditBackbone1d(nn.Module):
    def __init__(self, dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

    def forward(self, x, c):
        for block in self.blocks:
            x = block(x, c)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 1D Backbone Demo ---")
    # 模型参数
    B = 4
    NUM_TOKENS = 2048 // 16 # L / patch_size
    DIM = 768

    # 创建 Backbone 模块实例
    backbone_1d = MfditBackbone1d(dim=DIM, depth=12, num_heads=12)

    # 创建模拟输入数据 (来自 1D Embedding 层的输出)
    x_tokens_in = torch.randn(B, NUM_TOKENS, DIM)
    c_vector_in = torch.randn(B, DIM)

    # 前向传播
    x_features = backbone_1d(x_tokens_in, c_vector_in)

    # 打印输出形状
    print(f"输入 Token 序列形状: {x_tokens_in.shape}")
    print(f"输入条件向量 c 形状: {c_vector_in.shape}")
    print(f"输出特征序列形状: {x_features.shape}")
    print("-" * 20)

--- 1D Backbone Demo ---
输入 Token 序列形状: torch.Size([4, 128, 768])
输入条件向量 c 形状: torch.Size([4, 768])
输出特征序列形状: torch.Size([4, 128, 768])
--------------------


### head

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- 依赖的辅助模块 (从原代码中提取) ---
def modulate(x, scale, shift):
    # ... (代码与上一回答完全相同)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class RMSNorm(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))
    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

class FinalLayer(nn.Module):
    # 将输出维度修改为 1D patch 所需的
    def __init__(self, dim, patch_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(dim)
        # 线性层输出每个 patch 的所有通道的值
        self.linear = nn.Linear(dim, patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

# --- 主要的 1D Task Head 模块 ---
class MfditTaskHead1d(nn.Module):
    def __init__(self, dim=768, patch_size=16, out_channels=3):
        super().__init__()
        self.final_layer = FinalLayer(dim, patch_size, out_channels)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchify_1d(self, x):
        """
        x: (N, T, patch_size * C)
        ts: (N, C, L)
        """
        N, T, _ = x.shape
        C = self.out_channels
        P = self.patch_size
        L = T * P

        # (N, T, P * C) -> (N, T, P, C)
        x = x.view(N, T, P, C)
        # (N, T, P, C) -> (N, C, T, P)
        x = x.permute(0, 3, 1, 2)
        # (N, C, T, P) -> (N, C, L)
        ts = x.reshape(N, C, L)
        return ts

    def forward(self, x, c):
        x = self.final_layer(x, c) # x: (N, T, patch_size * out_channels)
        x = self.unpatchify_1d(x)  # x: (N, C, L)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 1D Task Head Demo ---")
    # 模型参数
    B, C, L = 4, 3, 2048
    PATCH_SIZE = 16
    NUM_TOKENS = L // PATCH_SIZE
    DIM = 768

    # 创建 Task Head 模块实例
    task_head_1d = MfditTaskHead1d(dim=DIM, patch_size=PATCH_SIZE, out_channels=C)

    # 创建模拟输入数据 (来自 Backbone 层的输出)
    x_features_in = torch.randn(B, NUM_TOKENS, DIM)
    c_vector_in = torch.randn(B, DIM)

    # 前向传播
    output_ts = task_head_1d(x_features_in, c_vector_in)

    # 打印输出形状
    print(f"输入特征序列形状: {x_features_in.shape}")
    print(f"输入条件向量 c 形状: {c_vector_in.shape}")
    print(f"输出时间序列形状: {output_ts.shape}")
    print(f"预期输出长度: {L}, 实际输出长度: {output_ts.shape[2]}")
    print("-" * 20)

--- 1D Task Head Demo ---
输入特征序列形状: torch.Size([4, 128, 768])
输入条件向量 c 形状: torch.Size([4, 768])
输出时间序列形状: torch.Size([4, 3, 2048])
预期输出长度: 2048, 实际输出长度: 2048
--------------------


## 1D v2

### embedding

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math
from typing import Optional, Callable

# --- 辅助模块 (保持不变) ---
class TimestepEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim, nfreq=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.nfreq = nfreq
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half_dim = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        t = t * 1000
        t_freq = self.timestep_embedding(t, self.nfreq)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, num_classes, dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
        self.num_classes = num_classes
    def forward(self, labels):
        return self.embedding(labels)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # ... (代码与上一回答完全相同)
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

# --- 1D Patch Embedding (保持不变, 但实例化时in_chans会+1) ---
class PatchEmbed1d(nn.Module):
    def __init__(self, seq_len, patch_size, in_chans, embed_dim, norm_layer=None, flatten=True):
        super().__init__()
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.flatten = flatten
        self.num_patches = (seq_len // patch_size)

    def forward(self, x):
        B, C, L = x.shape
        assert L == self.seq_len, f"Input sequence length ({L}) doesn't match model ({self.seq_len})."
        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)
        x = self.norm(x)
        return x

# --- 主要的 1D Embedding 模块 (已修改) ---
class MfditEmbedding1d(nn.Module):
    def __init__(self, num_patches=128, patch_size=16, in_channels=3, dim=768, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None
        self.in_channels = in_channels
        self.seq_len = num_patches * patch_size # TO-DO 目前会fix PatchEmbedding 应该 可以overleaping

        # 1. 修改: in_channels+1 用于拼接 sample_T
        # 2. 修改: seq_len 是计算得出的
        self.x_embedder = PatchEmbed1d(
            seq_len=self.seq_len,
            patch_size=patch_size,
            in_chans=self.in_channels + 1, # +1 for sample_T
            embed_dim=dim
        )
        self.num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, dim), requires_grad=False)

        # 条件嵌入 (保持不变)
        self.t_embedder = TimestepEmbedder(dim)
        self.r_embedder = TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None

        self.initialize_weights()

    def initialize_weights(self):
        pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], np.arange(self.num_patches))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x, sample_T, t, r, y=None):
        # 输入 x 形状: B, L, C
        # 输入 sample_T 形状: B, L, 1
        assert x.shape[1] == sample_T.shape[1], "Sequence length of x and sample_T must match."
        assert x.shape[1] == self.seq_len, f"Input sequence length ({x.shape[1]}) doesn't match model ({self.seq_len})."

        # 1. 修改: 拼接x和sample_T
        x_with_time = torch.cat([x, sample_T], dim=-1) # -> B, L, C+1

        # 2. 修改: 调整维度以匹配Conv1d
        x_with_time = x_with_time.permute(0, 2, 1) # -> B, C+1, L

        # 3. 1D 时间序列嵌入
        x_tokens = self.x_embedder(x_with_time) + self.pos_embed # -> B, Num_Patches, Dim

        # 4. 条件嵌入
        t_emb = self.t_embedder(t)
        r_emb = self.r_embedder(r)
        c = t_emb + r_emb

        if self.use_cond and y is not None:
            y_emb = self.y_embedder(y)
            c = c + y_emb

        return x_tokens, c

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- (修改后) 1D Embedding Demo ---")
    B, C = 4, 3
    NUM_PATCHES = 128
    PATCH_SIZE = 16
    L = NUM_PATCHES * PATCH_SIZE # 2048
    DIM = 768
    NUM_CLASSES = 10

    embedding_layer_1d = MfditEmbedding1d(
        num_patches=NUM_PATCHES,
        patch_size=PATCH_SIZE,
        in_channels=C,
        dim=DIM,
        num_classes=NUM_CLASSES
    )

    x_in_1d = torch.randn(B, L, C)
    sample_T_in = torch.randn(B, L, 1) # 新增输入
    t_in = torch.rand(B)
    r_in = torch.rand(B)
    y_in = torch.randint(0, NUM_CLASSES, (B,))

    x_tokens, c_vector = embedding_layer_1d(x_in_1d, sample_T_in, t_in, r_in, y_in)

    print(f"输入时间序列形状 (x): {x_in_1d.shape}")
    print(f"输入采样时间形状 (sample_T): {sample_T_in.shape}")
    print(f"输出 Token 序列形状: {x_tokens.shape}")
    print(f"输出条件向量 c 形状: {c_vector.shape}")
    print(f"预期 Token 数量: {NUM_PATCHES}, 实际 Token 数量: {x_tokens.shape[1]}")
    print("-" * 20)

--- (修改后) 1D Embedding Demo ---
输入时间序列形状 (x): torch.Size([4, 2048, 3])
输入采样时间形状 (sample_T): torch.Size([4, 2048, 1])
输出 Token 序列形状: torch.Size([4, 128, 768])
输出条件向量 c 形状: torch.Size([4, 768])
预期 Token 数量: 128, 实际 Token 数量: 128
--------------------


### embedding evenly spaced

In [17]:
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from einops import rearrange # 导入 rearrange

# --- 辅助模块 (保持不变) ---
class TimestepEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, dim, nfreq=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.nfreq = nfreq
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half_dim = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        t = t * 1000
        t_freq = self.timestep_embedding(t, self.nfreq)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    # ... (代码与上一回答完全相同)
    def __init__(self, num_classes, dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
        self.num_classes = num_classes
    def forward(self, labels):
        return self.embedding(labels)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # ... (代码与上一回答完全相同)
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

# --- 使用 einops 重构的 1D 动态补丁嵌入 ---
class DynamicPatchEmbed1d(nn.Module):
    def __init__(self, num_patches, patch_size, in_chans, embed_dim, norm_layer=None):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.proj = nn.Linear(in_chans * patch_size, embed_dim)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, L = x.shape
        x_4d = rearrange(x, 'b c l -> b c 1 l')

        if L == self.patch_size:
            stride = 1
        else:
            stride = max(1, (L - self.patch_size) // (self.num_patches - 1))

        patches = F.unfold(x_4d, kernel_size=(1, self.patch_size), stride=(1, stride))

        if patches.shape[-1] < self.num_patches:
            patches = F.pad(patches, (0, self.num_patches - patches.shape[-1]))
        elif patches.shape[-1] > self.num_patches:
            patches = patches[:, :, :self.num_patches]

        # --- 关键修改: 使用 rearrange ---
        # (B, C * P, T) -> (B, T, C * P)
        patches = rearrange(patches, 'b d t -> b t d')
        
        x = self.proj(patches)
        x = self.norm(x)
        return x

# --- 主要的 1D Embedding 模块 (使用 einops 重构) ---
class MfditEmbedding1d(nn.Module):
    def __init__(self, num_patches, patch_size, in_channels, dim, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None
        self.num_patches = num_patches

        self.x_embedder = DynamicPatchEmbed1d(num_patches, patch_size, in_channels + 1, dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim), requires_grad=False)
        self.t_embedder = TimestepEmbedder(dim)
        self.r_embedder = TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None
        self.initialize_weights()

    def initialize_weights(self):
        pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], np.arange(self.num_patches))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x, sample_T, t, r, y=None):
        # --- 关键修改: 使用 rearrange ---
        x_with_time = torch.cat([x, sample_T], dim=-1)
        # (B, L, C+1) -> (B, C+1, L)
        x_with_time = rearrange(x_with_time, 'b l c -> b c l')

        x_tokens = self.x_embedder(x_with_time) + self.pos_embed
        
        t_emb, r_emb = self.t_embedder(t), self.r_embedder(r)
        c = t_emb + r_emb
        if self.use_cond and y is not None:
            c = c + self.y_embedder(y)
        return x_tokens
        
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- (einops 重构) 动态 1D Embedding Demo ---")
    B, C, L_variable = 4, 3, 3000
    NUM_PATCHES, PATCH_SIZE, DIM = 128, 16, 768

    embedding_layer_1d = MfditEmbedding1d(
        num_patches=NUM_PATCHES, patch_size=PATCH_SIZE, in_channels=C, dim=DIM
    )
    x_in_1d = torch.randn(B, L_variable, C)
    sample_T_in = torch.randn(B, L_variable, 1)
    t_in, r_in = torch.rand(B), torch.rand(B)
    x_tokens = embedding_layer_1d(x_in_1d, sample_T_in, t_in, r_in)
    print(f"输入序列形状: {x_in_1d.shape}")
    print(f"输出 Token 形状: {x_tokens.shape}")
    print("-" * 20)

--- (einops 重构) 动态 1D Embedding Demo ---
输入序列形状: torch.Size([4, 3000, 3])
输出 Token 形状: torch.Size([4, 128, 768])
--------------------


### backbone

### task_head

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- 依赖的辅助模块 (从原代码中提取) ---
def modulate(x, scale, shift):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))
    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g
class FinalLayer(nn.Module):
    def __init__(self, dim, patch_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(dim)
        self.linear = nn.Linear(dim, patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

# --- 主要的 1D Task Head 模块 (已修改) ---
class MfditTaskHead1d(nn.Module):
    def __init__(self, num_patches, patch_size, out_channels, dim=768):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.final_layer = FinalLayer(dim, patch_size, out_channels)

    def unpatchify_1d(self, x):
        """
        x: (N, T, patch_size * C_out)
        ts: (N, C_out, L)
        """
        N, T, _ = x.shape
        # 断言token数量与初始化时一致
        assert T == self.num_patches, "Input token count must match num_patches."
        C = self.out_channels
        P = self.patch_size
        L = T * P

        x = x.view(N, T, P, C)
        x = x.permute(0, 3, 1, 2)
        ts = x.reshape(N, C, L)
        return ts

    def forward(self, x, c):
        x = self.final_layer(x, c)
        x = self.unpatchify_1d(x)  # -> (N, C, L)
        # 修改: 调整输出维度以匹配 B,L,C
        x = x.permute(0, 2, 1)   # -> (N, L, C)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- (修改后) 1D Task Head Demo ---")
    B, C = 4, 3
    NUM_PATCHES = 128
    PATCH_SIZE = 16
    L = NUM_PATCHES * PATCH_SIZE
    DIM = 768

    task_head_1d = MfditTaskHead1d(
        num_patches=NUM_PATCHES,
        patch_size=PATCH_SIZE,
        out_channels=C,
        dim=DIM
    )

    x_features_in = torch.randn(B, NUM_PATCHES, DIM)
    c_vector_in = torch.randn(B, DIM)

    output_ts = task_head_1d(x_features_in, c_vector_in)

    print(f"输入特征序列形状: {x_features_in.shape}")
    print(f"输出时间序列形状: {output_ts.shape}")
    print(f"预期输出形状: ({B}, {L}, {C})")
    print("-" * 20)

--- (修改后) 1D Task Head Demo ---
输入特征序列形状: torch.Size([4, 128, 768])
输出时间序列形状: torch.Size([4, 2048, 3])
预期输出形状: (4, 2048, 3)
--------------------


## 1D v3 patcher

### 1

In [30]:
import torch
import torch.nn as nn
import numpy as np
import math
from einops import rearrange

# 核心: 精确的、无损的序列补丁化与重建
class SequencePatcher(nn.Module):
    def __init__(self, num_patches, patch_size):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size

    def patch(self, x):
        B, C, L = x.shape
        if L < self.patch_size:
            raise ValueError(f"输入序列长度 ({L}) 必须大于或等于补丁大小 ({self.patch_size})")

        start_indices = torch.linspace(
            0, L - self.patch_size, steps=self.num_patches, device=x.device
        ).round().long()

        patch_indices = torch.arange(self.patch_size, device=x.device)
        absolute_indices = start_indices.unsqueeze(1) + patch_indices
        absolute_indices = rearrange(absolute_indices, 't p -> 1 1 t p').expand(B, C, -1, -1)
        
        # 使用 unsqueeze 和 expand 替代 gather, 在某些情况下性能更好
        # (B, C, L) -> (B, C, 1, L) -> (B, C, T, L)
        expanded_x = x.unsqueeze(2).expand(-1, -1, self.num_patches, -1) # (B, C, T, L)
        patches = torch.gather(expanded_x, 3, absolute_indices) # (B, C, T, P)
        
        patches = rearrange(patches, 'b c t p -> b t (c p)')
        return patches
    def reconstruct(self, patches, L_original):
        """
        输入: patches (B, T, C * P)
             L_original (int) - 原始序列长度
        输出: (B, C, L_original)
        """
        B, T, _ = patches.shape
        assert T == self.num_patches, "输入补丁数量与初始化不符"
        
        C = patches.shape[-1] // self.patch_size
        
        patches = rearrange(patches, 'b t (c p) -> b t c p', p=self.patch_size)

        start_indices = torch.linspace(
            0, L_original - self.patch_size, steps=self.num_patches, device=patches.device
        ).round().long()
        
        output = torch.zeros(B, C, L_original, device=patches.device)
        overlap_count = torch.zeros(B, C, L_original, device=patches.device)
        
        # 使用 scatter_add_ 高效并行地放置补丁
        # 创建索引
        patch_pos_indices = torch.arange(self.patch_size, device=patches.device).unsqueeze(0)
        absolute_indices = start_indices.unsqueeze(1) + patch_pos_indices
        
        # 扩展维度以匹配 scatter_add_ 的要求
        absolute_indices_expanded = rearrange(absolute_indices, 't p -> 1 1 t p').expand(B, C, -1, -1)
        patches_for_scatter = rearrange(patches, 'b t c p -> b c t p')
        
        output.scatter_add_(2, absolute_indices_expanded.flatten(2), patches_for_scatter.flatten(2))
        overlap_count.scatter_add_(2, absolute_indices_expanded.flatten(2), torch.ones_like(patches_for_scatter).flatten(2))
            
        output = output / torch.clamp(overlap_count, min=1.0)
        
        return output
# 辅助模块
class TimestepEmbedder(nn.Module):
    def __init__(self, dim, nfreq=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half_dim = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        t = t * 1000
        t_freq = self.timestep_embedding(t, self.mlp[0].in_features)
        return self.mlp(t_freq)

class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
    def forward(self, labels):
        return self.embedding(labels)
        
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    omega = 1. / 10000**omega
    out = np.einsum('m,d->md', pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

# 主要 Embedding 模块
class MfditEmbedding1d(nn.Module):
    def __init__(self, num_patches, patch_size, in_channels, dim, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None
        self.num_patches = num_patches
        self.patcher = SequencePatcher(num_patches, patch_size)
        self.proj = nn.Linear((in_channels + 1) * patch_size, dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim), requires_grad=False)
        self.t_embedder = TimestepEmbedder(dim)
        self.r_embedder = TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None
        self.initialize_weights()

    def initialize_weights(self):
        pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], np.arange(self.num_patches))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x, sample_T, t, r, y=None):
        x_with_time = torch.cat([x, sample_T], dim=-1)
        x_with_time = rearrange(x_with_time, 'b l c -> b c l')
        patches = self.patcher.patch(x_with_time)
        x_tokens = self.proj(patches) + self.pos_embed
        
        t_emb, r_emb = self.t_embedder(t), self.r_embedder(r)
        c = t_emb + r_emb
        if self.use_cond and y is not None:
            c = c + self.y_embedder(y)
        
        # --- Bug修复: 同时返回 x_tokens 和 c ---
        return x_tokens, c
# --- Demo for Embedding Module ---
print("--- 模块一: Embedding 独立测试 ---")
# 超参数
B, C_in, L_variable = 4, 3, 3000
NUM_PATCHES, PATCH_SIZE, DIM = 128, 16, 768
NUM_CLASSES = 10

# 实例化模块
embedding_layer = MfditEmbedding1d(
    num_patches=NUM_PATCHES, 
    patch_size=PATCH_SIZE, 
    in_channels=C_in, 
    dim=DIM,
    num_classes=NUM_CLASSES
)

# 创建模拟输入
x_in = torch.randn(B, L_variable, C_in)
sample_T_in = torch.randn(B, L_variable, 1)
t_in, r_in = torch.rand(B), torch.rand(B)
y_in = torch.randint(0, NUM_CLASSES, (B,))

# 前向传播
x_tokens, c_vector = embedding_layer(x_in, sample_T_in, t_in, r_in, y_in)

# 验证输出
print(f"输入 x shape: {x_in.shape}")
print(f"输出 tokens shape: {x_tokens.shape}")
print(f"输出 条件c shape: {c_vector.shape}")
assert x_tokens.shape == (B, NUM_PATCHES, DIM)
assert c_vector.shape == (B, DIM)
print("Embedding 模块测试通过！")
print("-" * 40)

--- 模块一: Embedding 独立测试 ---
输入 x shape: torch.Size([4, 3000, 3])
输出 tokens shape: torch.Size([4, 128, 768])
输出 条件c shape: torch.Size([4, 768])
Embedding 模块测试通过！
----------------------------------------


### 2

### 3

In [29]:
# 注意: 这个模块依赖于上面定义的 SequencePatcher 类
class MfditTaskHead1d(nn.Module):
    def __init__(self, num_patches, patch_size, out_channels, dim=768):
        super().__init__()
        self.patcher = SequencePatcher(num_patches, patch_size)
        self.final_norm = RMSNorm(dim)
        self.final_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))
        self.proj_out = nn.Linear(dim, out_channels * patch_size)

    def forward(self, x_tokens, c, L_original):
        shift, scale = self.final_mod(c).chunk(2, dim=-1)
        x = modulate(self.final_norm(x_tokens), shift, scale)
        patches_out = self.proj_out(x)
        reconstructed_ts = self.patcher.reconstruct(patches_out, L_original)
        output = rearrange(reconstructed_ts, 'b c l -> b l c')
        return output
# --- Demo for Task Head Module ---
print("--- 模块三: Task Head 独立测试 ---")
# 超参数
B, L_variable, C_out = 4, 3000, 3
NUM_PATCHES, PATCH_SIZE, DIM = 128, 16, 768

# 实例化模块
task_head = MfditTaskHead1d(
    num_patches=NUM_PATCHES, 
    patch_size=PATCH_SIZE, 
    out_channels=C_out, 
    dim=DIM
)

# 模拟输入 (来自Backbone层的输出)
features_in = torch.randn(B, NUM_PATCHES, DIM)
c_vector_in = torch.randn(B, DIM)

# 前向传播
output_ts = task_head(features_in, c_vector_in, L_original=L_variable)

# 验证输出
print(f"输入 features shape: {features_in.shape}")
print(f"重建后 output shape: {output_ts.shape}")
assert output_ts.shape == (B, L_variable, C_out)
print("Task Head 模块测试通过！")
print("-" * 40)

--- 模块三: Task Head 独立测试 ---
输入 features shape: torch.Size([4, 128, 768])
重建后 output shape: torch.Size([4, 3000, 3])
Task Head 模块测试通过！
----------------------------------------


## 1D v + channel

### 1

In [1]:
import torch
import torch.nn as nn
import numpy as np
import math
from einops import rearrange
import torch.nn.functional as F

# --------------------------------------------------------------------------
# 模块一(A): 最终版 - 独立的、无状态的 SequencePatcher
# --------------------------------------------------------------------------
class SequencePatcher(nn.Module):
    def __init__(self, num_patches, patch_size):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size

    def patch(self, x):
        """ 输入 (B, C, L), 输出 (B, T, C, P) 和 start_indices (T,) """
        B, C, L = x.shape
        if L < self.patch_size:
            raise ValueError(f"序列长度 ({L}) 不能小于补丁大小 ({self.patch_size})")

        start_indices = torch.linspace(
            0, L - self.patch_size, steps=self.num_patches, device=x.device
        ).round().long()

        patch_indices = torch.arange(self.patch_size, device=x.device)
        absolute_indices = rearrange(start_indices, 't -> t 1') + rearrange(patch_indices, 'p -> 1 p')
        absolute_indices_for_gather = rearrange(absolute_indices, 't p -> 1 1 t p').expand(B, C, -1, -1)
        
        patches = torch.gather(x.unsqueeze(2).expand(-1, -1, self.num_patches, -1), 3, absolute_indices_for_gather)
        return rearrange(patches, 'b c t p -> b t c p'), start_indices

    def unpatch(self, patches, start_indices, L_original):
        """ 输入 (B, T, C, P), start_indices, L_original, 输出 (B, L, C) """
        B, T, C, P = patches.shape
        assert T == self.num_patches and P == self.patch_size
        
        output = torch.zeros(B, C, L_original, device=patches.device)
        overlap_count = torch.zeros(B, C, L_original, device=patches.device)
        
        patch_pos_indices = torch.arange(P, device=patches.device).unsqueeze(0)
        absolute_indices = start_indices.unsqueeze(1) + patch_pos_indices
        absolute_indices_expanded = rearrange(absolute_indices, 't p -> 1 1 t p').expand(B, C, -1, -1)
        patches_for_scatter = rearrange(patches, 'b t c p -> b c t p')
        
        output.scatter_add_(2, absolute_indices_expanded.flatten(2), patches_for_scatter.flatten(2))
        overlap_count.scatter_add_(2, absolute_indices_expanded.flatten(2), torch.ones_like(patches_for_scatter).flatten(2))
        
        reconstructed_ts = output / torch.clamp(overlap_count, min=1.0)
        return rearrange(reconstructed_ts, 'b c l -> b l c')

# --------------------------------------------------------------------------
# 模块一(B): 主要的 Embedding 模块
# --------------------------------------------------------------------------
class TimestepEmbedder(nn.Module):
    def __init__(self, dim, nfreq=256):
        super().__init__(); self.mlp = nn.Sequential(nn.Linear(nfreq, dim), nn.SiLU(), nn.Linear(dim, dim))
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2; freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
        embedding = torch.cat([torch.cos(t[:, None] * freqs), torch.sin(t[:, None] * freqs)], dim=-1)
        if dim % 2: embedding = F.pad(embedding, (0, 1))
        return embedding
    def forward(self, t): return self.mlp(self.timestep_embedding(t * 1000, self.mlp[0].in_features))

class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, dim): super().__init__(); self.embedding = nn.Embedding(num_classes + 1, dim)
    def forward(self, labels): return self.embedding(labels)

def get_1d_sincos_pos_embed(embed_dim, num_patches):
    pos = torch.arange(num_patches)
    omega = torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim / 2.)
    omega = 1. / 10000**omega
    out = torch.einsum('m,d->md', pos, omega)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1).float()

class MfditEmbedding1d(nn.Module):
    def __init__(self, num_patches, patch_size, channel_map, c_dim, dim, num_classes=1000):
        super().__init__()
        self.use_cond = num_classes is not None
        # 核心修改: 使用统一的补丁配置
        self.patcher = SequencePatcher(num_patches, patch_size)
        
        self.channel_embedders = nn.ModuleDict({
            name: nn.Sequential(nn.Linear(num_channels + 1, c_dim * 2), nn.GELU(), nn.Linear(c_dim * 2, c_dim))
            for name, num_channels in channel_map.items()
        })
        self.proj_patch = nn.Linear(patch_size * c_dim, dim)
        self.pos_embed = nn.Parameter(get_1d_sincos_pos_embed(dim, num_patches).unsqueeze(0), requires_grad=False)
        self.t_embedder, self.r_embedder = TimestepEmbedder(dim), TimestepEmbedder(dim)
        self.y_embedder = LabelEmbedder(num_classes, dim) if self.use_cond else None

    def forward(self, x, name, sample_T, t, r, y=None):
        x_with_time = torch.cat([x, sample_T], dim=-1)
        
        patches, start_indices = self.patcher.patch(rearrange(x_with_time, 'b l c -> b c l'))
        
        patches_for_mlp = rearrange(patches, 'b t c p -> b t p c')
        channel_embedded_patches = self.channel_embedders[name](patches_for_mlp)
        flattened_patches = rearrange(channel_embedded_patches, 'b t p c_dim -> b t (p c_dim)')
        x_tokens = self.proj_patch(flattened_patches)
        x_tokens = x_tokens + self.pos_embed
        
        t_emb, r_emb = self.t_embedder(t), self.r_embedder(r)
        c = t_emb + r_emb
        if self.use_cond and y is not None: c = c + self.y_embedder(y)
            
        # 核心修改: 返回重建所需的 start_indices
        return x_tokens, c, start_indices
# --- 模块一: 最终版 Embedding 独立测试 ---
print("--- 模块一: 统一补丁配置的 Embedding 独立测试 ---")
B, L_variable = 4, 3000
# 统一的补丁配置
NUM_PATCHES, PATCH_SIZE = 128, 16
# 其他超参数
C_DIM, DIM = 16, 768

CHANNEL_MAP = {'vibration': 3, 'temperature': 1}

# 实例化模块
embedding_layer = MfditEmbedding1d(NUM_PATCHES, PATCH_SIZE, CHANNEL_MAP, C_DIM, DIM)

# --- 测试'vibration'信号 ---
x_vib = torch.randn(B, L_variable, 3)
sample_T_in = torch.randn(B, L_variable, 1)
t_in, r_in = torch.rand(B), torch.rand(B)

# 接收三个返回值
x_tokens_vib, c_vib, indices_vib = embedding_layer(x_vib, 'vibration', sample_T_in, t_in, r_in)

print(f"输入 vibration (3通道), L={L_variable}")
print(f"  -> 输出 tokens shape: {x_tokens_vib.shape}")
print(f"  -> 输出 condition shape: {c_vib.shape}")
print(f"  -> 输出 indices shape: {indices_vib.shape}")
assert x_tokens_vib.shape == (B, NUM_PATCHES, DIM)
assert indices_vib.shape == (NUM_PATCHES,)
print("  'vibration'信号处理成功!")
print("-" * 40)

# --- 测试'temperature'信号 ---
x_temp = torch.randn(B, L_variable, 1)
# 接收三个返回值
x_tokens_temp, c_temp, indices_temp = embedding_layer(x_temp, 'temperature', sample_T_in, t_in, r_in)
print(f"输入 temperature (1通道), L={L_variable}")
print(f"  -> 输出 tokens shape: {x_tokens_temp.shape}")
print(f"  -> 输出 condition shape: {c_temp.shape}")
print(f"  -> 输出 indices shape: {indices_temp.shape}")
assert x_tokens_temp.shape == (B, NUM_PATCHES, DIM)
assert indices_temp.shape == (NUM_PATCHES,)
print("  'temperature'信号处理成功!")
print("-" * 40)

--- 模块一: 统一补丁配置的 Embedding 独立测试 ---
输入 vibration (3通道), L=3000
  -> 输出 tokens shape: torch.Size([4, 128, 768])
  -> 输出 condition shape: torch.Size([4, 768])
  -> 输出 indices shape: torch.Size([128])
  'vibration'信号处理成功!
----------------------------------------
输入 temperature (1通道), L=3000
  -> 输出 tokens shape: torch.Size([4, 128, 768])
  -> 输出 condition shape: torch.Size([4, 768])
  -> 输出 indices shape: torch.Size([128])
  'temperature'信号处理成功!
----------------------------------------


### 2

In [None]:
import torch
import torch.nn as nn
from timm.models.vision_transformer import Mlp, Attention
import torch.nn.functional as F


def modulate(x, scale, shift):

    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class RMSNorm(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(1))
    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

class DiTBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_norm=True, norm_layer=RMSNorm)
        self.attn.fused_attn = False
        self.norm2 = RMSNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.adaLN_modulation(c).chunk(6, dim=-1))
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), scale_msa, shift_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), scale_mlp, shift_mlp))
        return x

# --- 主要的 1D Backbone 模块 (仅重命名) ---
class MfditBackbone1d(nn.Module):
    def __init__(self, dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

    def forward(self, x, c):
        for block in self.blocks:
            x = block(x, c)
        return x

# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 1D Backbone Demo ---")
    # 模型参数
    B = 4
    NUM_TOKENS = 2048 // 16 # L / patch_size
    DIM = 768

    # 创建 Backbone 模块实例
    backbone_1d = MfditBackbone1d(dim=DIM, depth=12, num_heads=12)

    # 创建模拟输入数据 (来自 1D Embedding 层的输出)
    x_tokens_in = torch.randn(B, NUM_TOKENS, DIM)
    c_vector_in = torch.randn(B, DIM)

    # 前向传播
    x_features = backbone_1d(x_tokens_in, c_vector_in)

    # 打印输出形状
    print(f"输入 Token 序列形状: {x_tokens_in.shape}")
    print(f"输入条件向量 c 形状: {c_vector_in.shape}")
    print(f"输出特征序列形状: {x_features.shape}")
    print("-" * 20)

  from .autonotebook import tqdm as notebook_tqdm


--- 1D Backbone Demo ---
输入 Token 序列形状: torch.Size([4, 128, 768])
输入条件向量 c 形状: torch.Size([4, 768])
输出特征序列形状: torch.Size([4, 128, 768])
--------------------


### 3

In [None]:
class MfditTaskHead1d(nn.Module):
    def __init__(self, num_patches, patch_size, channel_map, c_dim, dim):
        super().__init__()
        # 核心修改: 使用统一的补丁配置
        self.patcher = SequencePatcher(num_patches, patch_size)
        self.channel_map = channel_map
        self.final_norm = RMSNorm(dim)
        self.final_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))
        self.proj_out = nn.Linear(dim, patch_size * c_dim)
        self.recon_heads = nn.ModuleDict({
            name: nn.Sequential(nn.Linear(c_dim, c_dim * 2), nn.GELU(), nn.Linear(c_dim * 2, num_channels + 1))
            for name, num_channels in channel_map.items()
        })
        
    def forward(self, x_tokens, name, c, L_original, start_indices):
        shift, scale = self.final_mod(c).chunk(2, dim=-1)
        x = modulate(self.final_norm(x_tokens), shift, scale)
        x = self.proj_out(x)
        
        patches_for_mlp = rearrange(x, 'b t (p c_dim) -> b t p c_dim', p=self.patcher.patch_size)
        reconstructed_channels = self.recon_heads[name](patches_for_mlp)
        patches_to_unpatch = rearrange(reconstructed_channels, 'b t p c -> b t c p')
        
        # 核心修改: 使用传入的 start_indices 进行重建
        reconstructed_ts_with_time = self.patcher.unpatch(patches_to_unpatch, start_indices, L_original)
        
        num_original_channels = self.channel_map[name]
        reconstructed_ts = reconstructed_ts_with_time[:, :, :num_original_channels]
        
        return reconstructed_ts
# --- 模块三: 最终版 Task Head 独立测试 ---
print("--- 模块三: 最终版 Task Head 独立测试 ---")
B, L_variable = 4, 3000
# 统一的补丁配置
NUM_PATCHES, PATCH_SIZE = 128, 16
C_DIM, DIM = 16, 768

CHANNEL_MAP_HEAD = {'vibration': 3, 'temperature': 1}
task_head = MfditTaskHead1d(NUM_PATCHES, PATCH_SIZE, CHANNEL_MAP_HEAD, C_DIM, DIM)

# --- 测试重建 'vibration' (3通道) 信号 ---
output_name = 'vibration'
C_out = CHANNEL_MAP_HEAD[output_name]
print(f"请求重建 '{output_name}' 信号 ({C_out}个通道)")

# 模拟输入
features_in = torch.randn(B, NUM_PATCHES, DIM)
c_vector_in = torch.randn(B, DIM)
# 模拟从Embedding层得到的重建索引
start_indices_in = torch.linspace(0, L_variable - PATCH_SIZE, steps=NUM_PATCHES).round().long()

output_ts = task_head(features_in, output_name, c_vector_in, L_variable, start_indices_in)

print(f"输入 features shape: {features_in.shape}")
print(f"重建后 output shape: {output_ts.shape}")
assert output_ts.shape == (B, L_variable, C_out)
print(f"  '{output_name}' 信号重建成功!")
print("-" * 40)



--- 模块三: 最终版 Task Head 独立测试 ---
请求重建 'vibration' 信号 (3个通道)
输入 features shape: torch.Size([4, 128, 768])
重建后 output shape: torch.Size([4, 3000, 3])
  'vibration' 信号重建成功!
----------------------------------------


# DOPT

## 2D

### 1

In [8]:
# --- Embedding (嵌入层) ---
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

ACTIVATION = {'gelu':nn.GELU(),'tanh':nn.Tanh(),'sigmoid':nn.Sigmoid(),'relu':nn.ReLU(),'leaky_relu':nn.LeakyReLU(0.1),'softplus':nn.Softplus(),'ELU':nn.ELU(),'silu':nn.SiLU()}

import math
import logging
from torch.nn.modules.container import Sequential

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, out_dim=128,act='gelu'):
        super().__init__()
        # img_size = to_2tuple(img_size)
        # patch_size = to_2tuple(patch_size)
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.out_dim = out_dim
        self.act = ACTIVATION[act]

        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size),
            self.act,
            nn.Conv2d(embed_dim, out_dim, kernel_size=1, stride=1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

class TimeAggregator(nn.Module):
    def __init__(self, n_channels, n_timesteps, out_channels, type='mlp'):
        super(TimeAggregator, self).__init__()
        self.n_channels = n_channels
        self.n_timesteps = n_timesteps
        self.out_channels = out_channels
        self.type = type
        if self.type == 'mlp':
            self.w = nn.Parameter(1/(n_timesteps * out_channels**0.5) *torch.randn(n_timesteps, out_channels, out_channels),requires_grad=True)   # initialization could be tuned
        elif self.type == 'exp_mlp':
            self.w = nn.Parameter(1/(n_timesteps * out_channels**0.5) *torch.randn(n_timesteps, out_channels, out_channels),requires_grad=True)   # initialization could be tuned
            self.gamma = nn.Parameter(2**torch.linspace(-10,10, out_channels).unsqueeze(0),requires_grad=True)  # 1, C
    ##  B, X, Y, T, C
    def forward(self, x):
        if self.type == 'mlp':
            x = torch.einsum('tij, ...ti->...j', self.w, x)
        elif self.type == 'exp_mlp':
            t = torch.linspace(0, 1, x.shape[-2]).unsqueeze(-1).to(x.device) # T, 1
            t_embed = torch.cos(t @ self.gamma)
            x = torch.einsum('tij,...ti->...j', self.w, x * t_embed)

        return x

class DPOTNetEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=1, out_channels=4, in_timesteps=1, embed_dim=768,
                 act='gelu', time_agg='exp_mlp', normalize=False):
        """
        初始化嵌入层。
        :param img_size: 输入图像的尺寸。
        :param patch_size: 每个分块的尺寸。
        :param in_channels: 输入数据的通道数。
        :param out_channels: 输出数据的通道数。
        :param in_timesteps: 输入数据的时间步长。
        :param embed_dim: 嵌入向量的维度。
        :param act: 激活函数类型。
        :param time_agg: 时间聚合层类型。
        :param normalize: 是否进行数据规范化。
        """
        super().__init__()
        self.normalize = normalize
        self.in_channels = in_channels
        # 初始化分块嵌入模块
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_channels + 3,
                                      embed_dim=out_channels * patch_size + 3, out_dim=embed_dim, act=act)
        # 初始化可学习的位置编码
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, self.patch_embed.out_size[0], self.patch_embed.out_size[1]))
        # 如果启用规范化，则初始化用于学习缩放参数的线性层
        if self.normalize:
            self.scale_feats_mu = nn.Linear(2 * in_channels, embed_dim)
            self.scale_feats_sigma = nn.Linear(2 * in_channels, embed_dim)
        # 初始化时间聚合层
        self.time_agg_layer = TimeAggregator(in_channels, in_timesteps, embed_dim, time_agg)
        # 使用截断正态分布初始化位置编码
        torch.nn.init.trunc_normal_(self.pos_embed, std=.02)

    def get_grid_3d(self, x):
        """生成3D坐标网格"""
        batchsize, size_x, size_y, size_z = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).to(x.device).repeat([batchsize, 1, size_y, size_z, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).to(x.device).repeat([batchsize, size_x, 1, size_z, 1])
        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).to(x.device).repeat([batchsize, size_x, size_y, 1, 1])
        grid = torch.cat((gridx, gridy, gridz), dim=-1)
        return grid

    def forward(self, x):
        """
        前向传播函数。
        输入 x 的形状: (B, H, W, T, C)
        """
        B, H, W, T, C = x.shape
        mu, sigma = None, None
        
        # 1. (可选) 规范化
        if self.normalize:
            mu, sigma = x.mean(dim=(1, 2, 3), keepdim=True), x.std(dim=(1, 2, 3), keepdim=True) + 1e-6
            x = (x - mu) / sigma
            # 学习用于仿射变换的参数，实现类似AdaIN的功能
            scale_mu = self.scale_feats_mu(torch.cat([mu, sigma], dim=-1)).squeeze(-2).permute(0, 3, 1, 2)
            scale_sigma = self.scale_feats_sigma(torch.cat([mu, sigma], dim=-1)).squeeze(-2).permute(0, 3, 1, 2)

        # 2. 添加坐标网格信息
        grid = self.get_grid_3d(x)
        x = torch.cat((x, grid), dim=-1).contiguous()  # (B, H, W, T, C+3)
        
        # 调整维度以适应卷积操作
        x = rearrange(x, 'b x y t c -> (b t) c x y')
        
        # 3. 分块与投影
        x = self.patch_embed(x)
        
        # 4. 添加位置编码
        x = x + self.pos_embed
        
        # 调整维度以进行时间聚合
        x = rearrange(x, '(b t) c x y -> b x y t c', b=B, t=T)
        
        # 5. 时间聚合
        x = self.time_agg_layer(x)
        
        # 调整为骨干网络期望的输入形状
        x = rearrange(x, 'b x y c -> b c x y')
        
        # (可选) 应用学习到的仿射变换
        if self.normalize:
            x = scale_sigma * x + scale_mu
            
        return x, mu, sigma
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Embedding Demo ---")
    
    # --- 模型参数 ---
    IMG_SIZE = 64       # 图像尺寸
    PATCH_SIZE = 5      # 分块大小
    IN_CHANNELS = 3     # 输入通道数
    OUT_CHANNELS = 3    # 输出通道数
    IN_TIMESTEPS = 6    # 输入时间步
    EMBED_DIM = 32      # 嵌入维度
    B = 4               # 批量大小

    # --- 创建 Embedding 模块实例 ---
    embedding_layer = DPOTNetEmbedding(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        in_timesteps=IN_TIMESTEPS,
        embed_dim=EMBED_DIM,
        normalize=True  # 启用规范化
    )

    # --- 创建模拟输入数据 ---
    # 输入形状: (批量, 高, 宽, 时间步, 通道)
    x_in_embed = torch.randn(B, IMG_SIZE, IMG_SIZE, IN_TIMESTEPS, IN_CHANNELS)

    # --- 前向传播 ---
    embedded_x, mu, sigma = embedding_layer(x_in_embed)

    # --- 打印输出形状 ---
    print(f"输入形状: {x_in_embed.shape}")
    # 预期输出形状: (批量, 嵌入维度, 高/分块大小, 宽/分块大小)
    print(f"嵌入后特征形状: {embedded_x.shape}")
    print(f"计算出的均值形状: {mu.shape}")
    print(f"计算出的标准差形状: {sigma.shape}")
    print("-" * 20)

--- Embedding Demo ---
输入形状: torch.Size([4, 64, 64, 6, 3])
嵌入后特征形状: torch.Size([4, 32, 12, 12])
计算出的均值形状: torch.Size([4, 1, 1, 1, 3])
计算出的标准差形状: torch.Size([4, 1, 1, 1, 3])
--------------------


### 2

In [5]:
# --- Backbone (骨干网络) ---

class AFNO2D(nn.Module):
    """
    hidden_size: channel dimension size
    num_blocks: how many blocks to use in the block diagonal weight matrices (higher => less complexity but less parameters)
    """
    def __init__(self, width = 32, num_blocks=8, channel_first = False,sparsity_threshold=0.01, modes = 32,hard_thresholding_fraction=1, hidden_size_factor=1, act='gelu'):
        super().__init__()
        assert width % num_blocks == 0, f"hidden_size {width} should be divisble by num_blocks {num_blocks}"



        self.hidden_size = width
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.channel_first = channel_first
        self.modes = modes
        self.hidden_size_factor = hidden_size_factor
        # self.scale = 0.02
        self.scale = 1 / (self.block_size * self.block_size * self.hidden_size_factor)

        self.act = ACTIVATION[act]

        self.w1 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size))

    ### N, C, X, Y
    def forward(self, x, spatial_size=None):
        if self.channel_first:
            B, C, H, W = x.shape
            x = x.permute(0, 2, 3, 1)  ### ->N, X, Y, C
        else:
            B, H, W, C = x.shape
        x_orig = x

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        # x = torch.fft.rfft2(x, dim=(1, 2))

        x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, x.shape[1], x.shape[2], self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, x.shape[1], x.shape[2], self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        # total_modes = H*W // 2 + 1
        kept_modes = self.modes

        o1_real[:, :kept_modes, :kept_modes] = self.act(
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, :kept_modes, :kept_modes] = self.act(
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        o2_real[:, :kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, :kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, :kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, :kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, :kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, :kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        x = torch.stack([o2_real, o2_imag], dim=-1)
        ## for ab study
        # x = F.softshrink(x, lambd=self.sparsity_threshold)

        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], x.shape[2], C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")



        x = x + x_orig
        if self.channel_first:
            x = x.permute(0, 3, 1, 2)     ### N, C, X, Y

        return x
class Block(nn.Module):
    def __init__(self, mixing_type = 'afno', double_skip = True, width = 32, n_blocks = 4, mlp_ratio=1., channel_first = True, modes = 32, drop=0., drop_path=0., act='gelu', h=14, w=8,):
        super().__init__()
        # self.norm1 = norm_layer(width)
        # self.norm1 = torch.nn.LayerNorm([width])
        self.norm1 = torch.nn.GroupNorm(8, width)
        # self.norm1 = torch.nn.InstanceNorm2d(width,affine=True,track_running_stats=False)
        self.width = width
        self.modes = modes
        self.act = ACTIVATION[act]

        if mixing_type == "afno":
            self.filter = AFNO2D(width = width, num_blocks=n_blocks, sparsity_threshold=0.01, channel_first = channel_first, modes = modes,
                                 hard_thresholding_fraction=1, hidden_size_factor=1, act=act)

        self.norm2 = torch.nn.GroupNorm(8, width)



        mlp_hidden_dim = int(width * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=mlp_hidden_dim, kernel_size=1, stride=1),
            self.act,
            nn.Conv2d(in_channels=mlp_hidden_dim, out_channels=width, kernel_size=1, stride=1),
        )

        self.double_skip = double_skip

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.filter(x)


        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)

        x = x + residual

        return x
    
class DPOTNetBackbone(nn.Module):
    def __init__(self, embed_dim=768, depth=12, mixing_type='afno', modes=32, mlp_ratio=1., n_blocks=4, act='gelu',
                 img_size=224, patch_size=16):
        """
        初始化骨干网络。
        :param embed_dim: 嵌入维度。
        :param depth: 骨干网络的层数（Block的数量）。
        :param mixing_type: 混合器类型, 此处为 'afno'。
        :param modes: 傅里叶变换中保留的模式数。
        :param mlp_ratio: MLP层的隐藏维度与嵌入维度的比率。
        :param n_blocks: AFNO中块对角矩阵的数量。
        :param act: 激活函数类型。
        :param img_size: 输入图像尺寸。
        :param patch_size: 分块尺寸。
        """
        super().__init__()
        # 计算AFNO层需要的空间频率维度
        h = img_size // patch_size
        w = h // 2 + 1
        # 创建一个由多个Block组成的模块列表
        self.blocks = nn.ModuleList([
            Block(mixing_type=mixing_type, modes=modes,
                  width=embed_dim, mlp_ratio=mlp_ratio, channel_first=True, n_blocks=n_blocks, double_skip=False,
                  h=h, w=w, act=act)
            for _ in range(depth)])

    def forward(self, x):
        """
        前向传播函数。
        输入 x 的形状: (B, C, H', W')
        """
        for blk in self.blocks:
            x = blk(x)
        return x
# --- Demo 代码 ---
if __name__ == '__main__':
    # (接上一个Demo)
    print("--- Backbone Demo ---")
    
    # --- 模型参数 ---
    DEPTH = 4  # 骨干网络深度

    # --- 创建 Backbone 模块实例 ---
    backbone_layer = DPOTNetBackbone(
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE
    )

    # --- 前向传播 ---
    # 使用上一阶段的输出 `embedded_x` 作为输入
    features = backbone_layer(embedded_x)

    # --- 打印输出形状 ---
    print(f"输入特征形状: {embedded_x.shape}")
    # 骨干网络不改变特征图的形状
    print(f"输出特征形状: {features.shape}")
    print("-" * 20)

--- Backbone Demo ---


NameError: name 'EMBED_DIM' is not defined

### 3

In [None]:
# --- Task Head (任务头) ---
class DPOTNetTaskHead(nn.Module):
    def __init__(self, embed_dim=768, out_channels=4, out_timesteps=1, n_cls=12, out_layer_dim=32, patch_size=16,
                 act='gelu', normalize=False):
        """
        初始化任务头。
        :param embed_dim: 嵌入维度。
        :param out_channels: 最终输出的通道数。
        :param out_timesteps: 最终输出的时间步长。
        :param n_cls: 分类任务的类别数。
        :param out_layer_dim: 输出层中间卷积的维度。
        :param patch_size: 分块大小，用于转置卷积。
        :param act: 激活函数类型。
        :param normalize: 是否需要进行反规范化。
        """
        super().__init__()
        self.out_channels = out_channels
        self.out_timesteps = out_timesteps
        self.normalize = normalize
        self.act = ACTIVATION[act]

        # 分类头
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            self.act,
            nn.Linear(embed_dim, embed_dim),
            self.act,
            nn.Linear(embed_dim, n_cls)
        )
        
        # 输出生成层
        self.out_layer = nn.Sequential(
            # 转置卷积，用于上采样恢复分辨率
            nn.ConvTranspose2d(in_channels=embed_dim, out_channels=out_layer_dim, kernel_size=patch_size,
                               stride=patch_size),
            self.act,
            nn.Conv2d(in_channels=out_layer_dim, out_channels=out_layer_dim, kernel_size=1, stride=1),
            self.act,
            nn.Conv2d(in_channels=out_layer_dim, out_channels=self.out_channels * self.out_timesteps,
                      kernel_size=1, stride=1)
        )

    def forward(self, x, mu=None, sigma=None):
        """
        前向传播函数。
        输入 x 的形状: (B, C, H', W')
        """
        # 1. 分类任务
        # 全局平均池化
        cls_token = x.mean(dim=(2, 3), keepdim=False)
        cls_pred = self.cls_head(cls_token)
        
        # 2. 生成任务
        x = self.out_layer(x).permute(0, 2, 3, 1)
        
        # 调整形状以匹配 (B, H, W, T_out, C_out)
        x = x.reshape(*x.shape[:3], self.out_timesteps, self.out_channels).contiguous()
        
        # (可选) 反规范化
        if self.normalize and mu is not None and sigma is not None:
            x = x * sigma + mu
            
        return x, cls_pred
# --- Demo 代码 ---
if __name__ == '__main__':
    # (接上一个Demo)
    print("--- Task Head Demo ---")

    # --- 模型参数 ---
    OUT_TIMESTEPS = 1   # 输出时间步
    N_CLS = 10          # 分类类别数

    # --- 创建 Task Head 模块实例 ---
    task_head_layer = DPOTNetTaskHead(
        embed_dim=EMBED_DIM,
        out_channels=OUT_CHANNELS,
        out_timesteps=OUT_TIMESTEPS,
        n_cls=N_CLS,
        patch_size=PATCH_SIZE,
        normalize=True # 启用反规范化
    )

    # --- 前向传播 ---
    # 使用Backbone的输出 `features` 和 Embedding的 `mu`, `sigma` 作为输入
    output, cls_pred = task_head_layer(features, mu, sigma)

    # --- 打印输出形状 ---
    print(f"输入特征形状: {features.shape}")
    # 预期输出形状: (批量, 高, 宽, 输出时间步, 输出通道)
    print(f"最终生成输出形状: {output.shape}")
    # 预期分类输出形状: (批量, 类别数)
    print(f"分类预测形状: {cls_pred.shape}")
    print("-" * 20)

### all

In [None]:
class DPOTNet(nn.Module):
    def __init__(self, img_size=224, patch_size=16, mixing_type='afno', in_channels=1, out_channels=4,
                 in_timesteps=1, out_timesteps=1, n_blocks=4, embed_dim=768, out_layer_dim=32, depth=12,
                 modes=32, mlp_ratio=1., n_cls=12, normalize=False, act='gelu', time_agg='exp_mlp'):
        super(DPOTNet, self).__init__()
        # 实例化嵌入层
        self.embedding = DPOTNetEmbedding(img_size=img_size, patch_size=patch_size, in_channels=in_channels,
                                          out_channels=out_channels, in_timesteps=in_timesteps, embed_dim=embed_dim,
                                          act=act, time_agg=time_agg, normalize=normalize)
        # 实例化骨干网络
        self.backbone = DPOTNetBackbone(embed_dim=embed_dim, depth=depth, mixing_type=mixing_type, modes=modes,
                                        mlp_ratio=mlp_ratio, n_blocks=n_blocks, act=act, img_size=img_size,
                                        patch_size=patch_size)
        # 实例化任务头
        self.task_head = DPOTNetTaskHead(embed_dim=embed_dim, out_channels=out_channels,
                                         out_timesteps=out_timesteps, n_cls=n_cls,
                                         out_layer_dim=out_layer_dim, patch_size=patch_size, act=act,
                                         normalize=normalize)

    def forward(self, x):
        # 依次通过三个模块
        x, mu, sigma = self.embedding(x)
        x = self.backbone(x)
        x, cls_pred = self.task_head(x, mu, sigma)
        return x, cls_pred
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 完整模型 Demo ---")

    # --- 模型参数 ---
    IMG_SIZE = 20
    PATCH_SIZE = 5
    IN_CHANNELS = 3
    OUT_CHANNELS = 3
    IN_TIMESTEPS = 6
    OUT_TIMESTEPS = 1
    EMBED_DIM = 32
    DEPTH = 4
    N_CLS = 10
    B = 4

    # --- 创建完整DPOTNet实例 ---
    net = DPOTNet(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        in_timesteps=IN_TIMESTEPS,
        out_timesteps=OUT_TIMESTEPS,
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        n_cls=N_CLS,
        normalize=True
    )
    
    # --- 创建模拟输入数据 ---
    x_in = torch.randn(B, IMG_SIZE, IMG_SIZE, IN_TIMESTEPS, IN_CHANNELS)
    
    # --- 前向传播 ---
    y, _ = net(x_in)
    
    # --- 打印输出形状 ---
    print(f"输入形状: {x_in.shape}")
    print(f"最终输出形状: {y.shape}")
    print("-" * 20)

## 1D

### 1

In [23]:
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# 激活函数字典
ACTIVATION = {
    'gelu': nn.GELU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid(), 
    'relu': nn.ReLU(), 'leaky_relu': nn.LeakyReLU(0.1), 
    'softplus': nn.Softplus(), 'ELU': nn.ELU(), 'silu': nn.SiLU()
}

class PatchEmbed1D(nn.Module):
    """
    将1D序列分块并进行线性嵌入的模块。
    """
    def __init__(self, seq_len=1024, patch_len=16, in_chans=3, embed_dim=768, out_dim=128, act='gelu'):
        super().__init__()
        self.seq_len = seq_len
        self.patch_len = patch_len
        self.in_chans = in_chans
        self.out_dim = out_dim
        self.num_patches = seq_len // patch_len
        self.act = ACTIVATION[act]

        # 使用1D卷积实现分块和嵌入
        self.proj = nn.Sequential(
            nn.Conv1d(in_chans, embed_dim, kernel_size=patch_len, stride=patch_len),
            self.act,
            nn.Conv1d(embed_dim, out_dim, kernel_size=1, stride=1)
        )

    def forward(self, x):
        # 输入 x 形状: (B, C, L)
        B, C, L = x.shape
        assert L == self.seq_len, f"Input sequence length ({L}) doesn't match model ({self.seq_len})."
        # 输出 x 形状: (B, out_dim, num_patches)
        x = self.proj(x)
        return x

class DPOTNetEmbedding1D(nn.Module):
    """
    DPOTNet 的1D版本嵌入层。
    """
    def __init__(self, seq_len=1024, patch_len=16, in_channels=1, embed_dim=768, act='gelu', normalize=False):
        super().__init__()
        self.normalize = normalize
        self.in_channels = in_channels
        
        # 使用1D分块嵌入模块
        self.patch_embed = PatchEmbed1D(seq_len=seq_len, patch_len=patch_len, in_chans=in_channels + 1,
                                        embed_dim=embed_dim, out_dim=embed_dim, act=act)
        
        # 1D位置编码，对应每个patch
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.patch_embed.num_patches))
        
        if self.normalize:
            self.scale_feats_mu = nn.Linear(2 * in_channels, embed_dim)
            self.scale_feats_sigma = nn.Linear(2 * in_channels, embed_dim)
            
        torch.nn.init.trunc_normal_(self.pos_embed, std=.02)

    def get_grid_1d(self, x):
        """生成1D坐标网格"""
        batchsize, seq_len, n_feats = x.shape
        grid = torch.linspace(0, 1, seq_len, dtype=torch.float, device=x.device)
        grid = grid.reshape(1, seq_len, 1).repeat(batchsize, 1, 1)
        return grid

    def forward(self, x):
        # 输入 x 形状: (B, L, C)
        B, L, C = x.shape
        mu, sigma = None, None
        
        if self.normalize:
            mu, sigma = x.mean(dim=1, keepdim=True), x.std(dim=1, keepdim=True) + 1e-6
            x = (x - mu) / sigma
            # 调整mu和sigma的形状以用于线性层
            scale_mu = self.scale_feats_mu(torch.cat([mu, sigma], dim=-1).squeeze(1)) # -> B, embed_dim
            scale_sigma = self.scale_feats_sigma(torch.cat([mu, sigma], dim=-1).squeeze(1)) # -> B, embed_dim

        # 添加位置坐标
        grid = self.get_grid_1d(x)
        x = torch.cat((x, grid), dim=-1) # (B, L, C+1)
        
        # 调整维度以适应1D卷积: (B, C, L)
        x = x.permute(0, 2, 1)
        
        # 分块与投影
        x = self.patch_embed(x) # (B, embed_dim, num_patches)
        
        # 添加位置编码
        x = x + self.pos_embed
        
        if self.normalize:
            # 调整 scale_mu/sigma 的形状以匹配 x
            x = scale_sigma.unsqueeze(-1) * x + scale_mu.unsqueeze(-1)
            
        return x, mu, sigma
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- 1D Embedding Demo ---")
    
    # --- 模型参数 ---
    SEQ_LEN = 1024      # 输入序列长度
    PATCH_LEN = 16      # 分块长度
    IN_CHANNELS = 3     # 输入特征/通道数
    EMBED_DIM = 64      # 嵌入维度
    B = 4               # 批量大小

    # --- 创建 1D Embedding 模块实例 ---
    embedding_layer_1d = DPOTNetEmbedding1D(
        seq_len=SEQ_LEN,
        patch_len=PATCH_LEN,
        in_channels=IN_CHANNELS,
        embed_dim=EMBED_DIM,
        normalize=True
    )

    # --- 创建模拟输入数据 ---
    # 输入形状: (批量, 序列长度, 特征数)
    x_in_1d = torch.randn(B, SEQ_LEN, IN_CHANNELS)

    # --- 前向传播 ---
    embedded_x_1d, mu, sigma = embedding_layer_1d(x_in_1d)

    # --- 打印输出形状 ---
    print(f"输入序列形状 (B, L, C): {x_in_1d.shape}")
    print(f"嵌入后特征形状 (B, embed_dim, num_patches): {embedded_x_1d.shape}")
    print(f"计算出的均值形状: {mu.shape}")
    print(f"计算出的标准差形状: {sigma.shape}")
    print("-" * 20)

--- 1D Embedding Demo ---


输入序列形状 (B, L, C): torch.Size([4, 1024, 3])
嵌入后特征形状 (B, embed_dim, num_patches): torch.Size([4, 64, 64])
计算出的均值形状: torch.Size([4, 1, 3])
计算出的标准差形状: torch.Size([4, 1, 3])
--------------------


### 2

In [24]:
import torch.fft

class AFNO1D(nn.Module):
    """
    1D版本的AFNO，用于处理序列数据。
    """
    def __init__(self, width=64, num_blocks=8, modes=16, hidden_size_factor=2, act='gelu'):
        super().__init__()
        assert width % num_blocks == 0, f"hidden_size {width} should be divisible by num_blocks {num_blocks}"

        self.hidden_size = width
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.modes = modes
        self.hidden_size_factor = hidden_size_factor
        self.scale = 1 / (self.block_size * self.block_size * self.hidden_size_factor)
        self.act = ACTIVATION[act]

        self.w1 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.rand(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # 输入 x 形状: (B, C, num_patches)
        B, C, N = x.shape
        x = x.permute(0, 2, 1) # (B, num_patches, C)
        x_orig = x

        # 1D 傅里叶变换
        x = torch.fft.rfft(x, dim=1, norm="ortho")
        x = x.reshape(B, x.shape[1], self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, x.shape[1], self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, x.shape[1], self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        kept_modes = self.modes
        o1_real[:, :kept_modes] = self.act(
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )
        o1_imag[:, :kept_modes] = self.act(
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )
        o2_real[:, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )
        o2_imag[:, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = torch.view_as_complex(x)
        x = x.reshape(B, x.shape[1], C)
        
        # 1D 逆傅里叶变换
        x = torch.fft.irfft(x, n=N, dim=1, norm="ortho")
        
        x = x + x_orig
        x = x.permute(0, 2, 1) # (B, C, num_patches)
        return x

class Block1D(nn.Module):
    """
    1D版本的Block模块。
    """
    def __init__(self, width=64, mlp_ratio=4., n_blocks=8, modes=16, act='gelu'):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(n_blocks, width)
        self.filter = AFNO1D(width=width, num_blocks=n_blocks, modes=modes, act=act)
        self.norm2 = torch.nn.GroupNorm(n_blocks, width)
        
        mlp_hidden_dim = int(width * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv1d(width, mlp_hidden_dim, 1),
            ACTIVATION[act],
            nn.Conv1d(mlp_hidden_dim, width, 1),
        )

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.filter(x)
        x = x + residual
        
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        
        return x

class DPOTNetBackbone1D(nn.Module):
    """
    DPOTNet 的1D版本骨干网络。
    """
    def __init__(self, width=64, depth=4, mlp_ratio=4., n_blocks=8, modes=16, act='gelu'):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block1D(width=width, mlp_ratio=mlp_ratio, n_blocks=n_blocks, modes=modes, act=act)
            for _ in range(depth)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x
# --- Demo 代码 (续) ---
if __name__ == '__main__':
    # (前面是1D Embedding Demo)
    print("--- 1D Backbone Demo ---")

    # --- 模型参数 ---
    DEPTH = 4  # 骨干网络深度
    
    # --- 创建 1D Backbone 模块实例 ---
    backbone_layer_1d = DPOTNetBackbone1D(
        width=EMBED_DIM,
        depth=DEPTH,
        modes=EMBED_DIM // 4 # 通常modes设为嵌入维度的一个分数
    )
    
    # --- 前向传播 ---
    # 使用上一阶段的输出 `embedded_x_1d` 作为输入
    features_1d = backbone_layer_1d(embedded_x_1d)

    # --- 打印输出形状 ---
    print(f"输入特征形状 (B, embed_dim, num_patches): {embedded_x_1d.shape}")
    # Backbone不改变特征序列的形状
    print(f"输出特征形状 (B, embed_dim, num_patches): {features_1d.shape}")
    print("-" * 20)

--- 1D Backbone Demo ---
输入特征形状 (B, embed_dim, num_patches): torch.Size([4, 64, 64])
输出特征形状 (B, embed_dim, num_patches): torch.Size([4, 64, 64])
--------------------


### 3

In [25]:
class DPOTNetTaskHead1D(nn.Module):
    """
    DPOTNet 的1D版本任务头。
    """
    def __init__(self, embed_dim=64, out_channels=3, n_cls=10, patch_len=16, act='gelu', normalize=False):
        super().__init__()
        self.out_channels = out_channels
        self.normalize = normalize
        self.act = ACTIVATION[act]

        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            self.act,
            nn.Linear(embed_dim, n_cls)
        )
        
        # 1D 输出层
        self.out_layer = nn.Sequential(
            # 1D转置卷积，用于上采样恢复序列长度
            nn.ConvTranspose1d(embed_dim, embed_dim, kernel_size=patch_len, stride=patch_len),
            self.act,
            nn.Conv1d(embed_dim, self.out_channels, 1)
        )

    def forward(self, x, mu=None, sigma=None):
        # 输入 x 形状: (B, embed_dim, num_patches)
        
        # 1. 分类任务 (在 num_patches 维度上池化)
        cls_token = x.mean(dim=2) # (B, embed_dim)
        cls_pred = self.cls_head(cls_token) # (B, n_cls)
        
        # 2. 生成任务
        x_out = self.out_layer(x) # (B, out_channels, L)
        
        # 调整为 (B, L, C)
        x_out = x_out.permute(0, 2, 1)
        
        if self.normalize and mu is not None and sigma is not None:
            x_out = x_out * sigma + mu
            
        return x_out, cls_pred
# --- Demo 代码 (续) ---
if __name__ == '__main__':
    # (前面是 1D Embedding 和 1D Backbone Demo)
    print("--- 1D Task Head Demo ---")

    # --- 模型参数 ---
    N_CLS = 10  # 分类任务的类别数

    # --- 创建 1D Task Head 模块实例 ---
    task_head_layer_1d = DPOTNetTaskHead1D(
        embed_dim=EMBED_DIM,
        out_channels=IN_CHANNELS, # 输出通道数通常与输入通道数一致
        n_cls=N_CLS,
        patch_len=PATCH_LEN,
        normalize=True
    )

    # --- 前向传播 ---
    # 使用 Backbone 的输出 `features_1d` 和 Embedding 的 `mu`, `sigma` 作为输入
    output_1d, cls_pred_1d = task_head_layer_1d(features_1d, mu, sigma)

    # --- 打印输出形状 ---
    print(f"输入特征形状 (B, embed_dim, num_patches): {features_1d.shape}")
    print(f"最终生成序列形状 (B, L, C): {output_1d.shape}")
    print(f"分类预测形状 (B, n_cls): {cls_pred_1d.shape}")
    print("-" * 20)
    
    # --- 演示端到端完整模型 ---
    print("--- 1D Full Model Demo ---")
    embedding_layer = DPOTNetEmbedding1D(SEQ_LEN, PATCH_LEN, IN_CHANNELS, EMBED_DIM, normalize=True)
    backbone_layer = DPOTNetBackbone1D(EMBED_DIM, DEPTH)
    task_head_layer = DPOTNetTaskHead1D(EMBED_DIM, IN_CHANNELS, N_CLS, PATCH_LEN, normalize=True)
    
    x_in = torch.randn(B, SEQ_LEN, IN_CHANNELS)
    
    x_embed, mu_s, sigma_s = embedding_layer(x_in)
    x_feat = backbone_layer(x_embed)
    x_final, cls_final = task_head_layer(x_feat, mu_s, sigma_s)
    
    print(f"端到端输入形状: {x_in.shape}")
    print(f"端到端输出形状: {x_final.shape}")
    print(f"端到端分类形状: {cls_final.shape}")

--- 1D Task Head Demo ---
输入特征形状 (B, embed_dim, num_patches): torch.Size([4, 64, 64])
最终生成序列形状 (B, L, C): torch.Size([4, 1024, 3])
分类预测形状 (B, n_cls): torch.Size([4, 10])
--------------------
--- 1D Full Model Demo ---
端到端输入形状: torch.Size([4, 1024, 3])
端到端输出形状: torch.Size([4, 1024, 3])
端到端分类形状: torch.Size([4, 10])


# Sundial

## flow_loss

### 1

In [12]:
import torch
import torch.nn as nn
import math

class TimestepEmbedder(nn.Module):
    """将标量时间步（t）嵌入为矢量表示。"""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

class FlowEmbedding(nn.Module):
    """
    FlowLoss的嵌入模块。
    职责：将时间步t和条件z处理成统一的条件信号y。
    """
    def __init__(self, model_channels, z_channels):
        super().__init__()
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

    def forward(self, t, z):
        """
        :param t: 时间步张量 [batch_size]
        :param z: 来自上游Transformer的条件向量 [batch_size, z_channels]
        :return y: 统一的条件信号 [batch_size, model_channels]
        """
        t_emb = self.time_embed(t)
        z_emb = self.cond_embed(z)
        # 最终的条件信号是时间嵌入和条件嵌入的和
        y = t_emb + z_emb
        return y
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Flow Embedding Demo ---")
    
    # --- 模型参数 ---
    BATCH_SIZE = 4
    MODEL_CHANNELS = 64
    Z_CHANNELS = 32
    MAX_PERIOD = 10000
    
    # --- 创建 FlowEmbedding 模块实例 ---
    flow_embedding_layer = FlowEmbedding(model_channels=MODEL_CHANNELS, z_channels=Z_CHANNELS)
    
    # --- 创建模拟输入数据 ---
    t_in = torch.randint(0, 1000, (BATCH_SIZE,))  # 随机时间步
    z_in = torch.randn(BATCH_SIZE, Z_CHANNELS)   # 随机条件向量
    
    # --- 前向传播 ---
    y_out = flow_embedding_layer(t_in, z_in)
    
    # --- 打印输出形状 ---
    print(f"输入时间步形状: {t_in.shape}")
    print(f"输入条件向量形状: {z_in.shape}")
    print(f"输出统一条件信号形状: {y_out.shape}")
    print("-" * 20)

--- Flow Embedding Demo ---
输入时间步形状: torch.Size([4])
输入条件向量形状: torch.Size([4, 32])
输出统一条件信号形状: torch.Size([4, 64])
--------------------


### 2

In [17]:
def modulate(x, shift, scale):
    """辅助函数，用于通过移位(shift)和缩放(scale)来调整张量x。"""
    return x * (1 + scale) + shift

class ResBlock(nn.Module):
    """一个由条件信号y调制的残差块。"""
    def __init__(self, channels):
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True) # 输出 shift, scale, gate
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h

class FlowBackbone(nn.Module):
    """
    FlowLoss的核心处理骨干。
    职责：在条件y的指导下，处理加噪输入x。
    """
    def __init__(self, in_channels, model_channels, num_res_blocks):
        super().__init__()
        self.input_proj = nn.Linear(in_channels, model_channels)
        
        res_blocks = []
        for _ in range(num_res_blocks):
            res_blocks.append(ResBlock(model_channels))
        self.res_blocks = nn.ModuleList(res_blocks)

    def forward(self, x, y):
        """
        :param x: 加噪输入 [batch_size, in_channels]
        :param y: 来自Embedding层的条件信号 [batch_size, model_channels]
        :return: 处理后的特征 [batch_size, model_channels]
        """
        x = self.input_proj(x)
        for block in self.res_blocks:
            x = block(x, y)
        return x
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Flow Backbone Demo ---")

    # --- 模型参数 ---
    BATCH_SIZE = 4
    IN_CHANNELS = 64
    MODEL_CHANNELS = 128
    NUM_RES_BLOCKS = 6
    LENGTH = 1024  # 输入特征长度

    # --- 创建 FlowBackbone 模块实例 ---
    flow_backbone_layer = FlowBackbone(
        in_channels=IN_CHANNELS,
        model_channels=MODEL_CHANNELS,
        num_res_blocks=NUM_RES_BLOCKS
    )

    # --- 创建模拟输入数据 ---
    x_in = torch.randn(BATCH_SIZE, LENGTH, IN_CHANNELS)  # 随机加噪输入
    y_in = torch.randn(BATCH_SIZE, LENGTH, MODEL_CHANNELS)  # 随机条件信号

    # --- 前向传播 ---
    features_out = flow_backbone_layer(x_in, y_in)

    # --- 打印输出形状 ---
    print(f"输入加噪特征形状: {x_in.shape}")
    print(f"输入条件信号形状: {y_in.shape}")
    print(f"输出特征形状: {features_out.shape}")
    print("-" * 20)

--- Flow Backbone Demo ---
输入加噪特征形状: torch.Size([4, 1024, 64])
输入条件信号形状: torch.Size([4, 1024, 128])
输出特征形状: torch.Size([4, 1024, 128])
--------------------


### 3

In [15]:
class FlowTaskHead(nn.Module):
    """
    FlowLoss的任务头。
    职责：接收Backbone处理后的特征，生成最终预测。
    """
    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True) # 输出 shift, scale
        )

    def forward(self, x, y):
        """
        :param x: 来自Backbone的特征 [batch_size, model_channels]
        :param y: 来自Embedding层的条件信号 [batch_size, model_channels]
        :return: 最终预测 [batch_size, out_channels]
        """
        shift, scale = self.adaLN_modulation(y).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
# --- Demo 代码 ---
if __name__ == '__main__':
    print("--- Flow Task Head Demo ---")

    # --- 模型参数 ---
    BATCH_SIZE = 4
    MODEL_CHANNELS = 128
    OUT_CHANNELS = 10

    # --- 创建 FlowTaskHead 模块实例 ---
    flow_task_head_layer = FlowTaskHead(
        model_channels=MODEL_CHANNELS,
        out_channels=OUT_CHANNELS
    )

    # --- 创建模拟输入数据 ---
    features_in = torch.randn(BATCH_SIZE, MODEL_CHANNELS)  # 来自Backbone的特征
    y_in = torch.randn(BATCH_SIZE, MODEL_CHANNELS)  # 来自Embedding层的条件信号

    # --- 前向传播 ---
    output_out = flow_task_head_layer(features_in, y_in)

    # --- 打印输出形状 ---
    print(f"输入特征形状: {features_in.shape}")
    print(f"输入条件信号形状: {y_in.shape}")
    print(f"最终预测形状: {output_out.shape}")
    print("-" * 20)

--- Flow Task Head Demo ---
输入特征形状: torch.Size([4, 128])
输入条件信号形状: torch.Size([4, 128])
最终预测形状: torch.Size([4, 10])
--------------------


### all

In [16]:
class SimpleMLP_Deconstructed(nn.Module):
    """ 
    将解耦后的三个部分组合成完整的网络。
    """
    def __init__(self, in_channels, model_channels, out_channels, z_channels, num_res_blocks):
        super().__init__()
        self.embedding = FlowEmbedding(model_channels, z_channels)
        self.backbone = FlowBackbone(in_channels, model_channels, num_res_blocks)
        self.task_head = FlowTaskHead(model_channels, out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # 可以在这里放置权重初始化逻辑
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)


    def forward(self, x, t, z):
        # 1. Embedding: 创建条件信号y
        y = self.embedding(t, z)
        # 2. Backbone: 在y的指导下处理x
        x_processed = self.backbone(x, y)
        # 3. Task Head: 生成最终预测
        output = self.task_head(x_processed, y)
        return output

class FlowLossDeconstructed(nn.Module):
    """ 
    使用解耦模型的总封装器，包含训练和采样逻辑。
    """
    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps):
        super().__init__()
        self.in_channels = target_channels
        # 使用我们解耦后的网络
        self.net = SimpleMLP_Deconstructed(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels,
            z_channels=z_channels,
            num_res_blocks=depth
        )
        self.num_sampling_steps = num_sampling_steps

    def forward(self, target, z):
        # 训练逻辑 (与之前相同)
        noise = torch.randn_like(target)
        t = torch.rand(target.shape[0], device=target.device)
        noised_target = t[:, None] * target + (1 - t[:, None]) * noise
        predict_v = self.net(noised_target, t * 1000, z)
        
        # --- 损失计算部分 ---
        weights = 1.0 / torch.arange(1, self.in_channels + 1, dtype=torch.float32, device=target.device)
        loss = (weights * (predict_v - target) ** 2).sum(dim=-1) #? 这里的损失计算是否合理？
        return loss.mean()

    def sample(self, z, num_samples=1):
        # 采样逻辑 (与之前相同)
        z = z.repeat(num_samples, 1)
        noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
        x = noise
        dt = 1.0 / self.num_sampling_steps
        for i in range(self.num_sampling_steps):
            t = (torch.ones((x.shape[0])) * i / self.num_sampling_steps).to(x.device)
            pred = self.net(x, t * 1000, z)
            x = x + (pred - noise) * dt
        x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1)
        return x


if __name__ == '__main__':
    print("--- 解耦后的 FlowLoss 模块 Demo ---")
    
    # 1. 定义模型参数
    TARGET_CHANNELS = 32
    Z_CHANNELS = 128
    DEPTH = 4
    WIDTH = 256
    NUM_SAMPLING_STEPS = 20
    BATCH_SIZE = 2
    NUM_SAMPLES = 3

    # 2. 实例化解耦后的FlowLoss模块
    flow_loss_module = FlowLossDeconstructed(
        target_channels=TARGET_CHANNELS,
        z_channels=Z_CHANNELS,
        depth=DEPTH,
        width=WIDTH,
        num_sampling_steps=NUM_SAMPLING_STEPS
    )

    # 3. 准备输入数据
    z_condition = torch.randn(BATCH_SIZE, Z_CHANNELS)
    
    # 4. 演示推理过程
    print("\n--- 1. 推理 (Sampling) ---")
    flow_loss_module.eval()
    with torch.no_grad():
        generated_samples = flow_loss_module.sample(z_condition, num_samples=NUM_SAMPLES)
    print(f"输入条件 z 的形状: {z_condition.shape}")
    print(f"生成样本的形状: {generated_samples.shape}") # (batch_size, num_samples, target_channels)

    # 5. 演示训练过程
    print("\n--- 2. 训练 (Loss Calculation) ---")
    flow_loss_module.train()
    true_target = torch.randn(BATCH_SIZE, TARGET_CHANNELS)
    loss = flow_loss_module(target=true_target, z=z_condition)
    print(f"真实目标 target 的形状: {true_target.shape}")
    print(f"计算得到的损失值: {loss.item():.4f}")
    
    print("\n解耦后的 FlowLoss Demo 运行完毕!")

--- 解耦后的 FlowLoss 模块 Demo ---

--- 1. 推理 (Sampling) ---
输入条件 z 的形状: torch.Size([2, 128])
生成样本的形状: torch.Size([2, 3, 32])

--- 2. 训练 (Loss Calculation) ---
真实目标 target 的形状: torch.Size([2, 32])
计算得到的损失值: 21.9188

解耦后的 FlowLoss Demo 运行完毕!


##  1D

### 1

In [20]:
from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers import PretrainedConfig

# 为了演示，我们创建一个最小化的配置类
class SundialConfig(PretrainedConfig):
    model_type = "sundial"
    
    def __init__(self, 
                 input_token_len=32, 
                 hidden_size=128, 
                 intermediate_size=256,
                 dropout_rate=0.1,
                 hidden_act="silu",
                 **kwargs):
        self.input_token_len = input_token_len
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout_rate = dropout_rate
        self.hidden_act = hidden_act
        super().__init__(**kwargs)

class SundialPatchEmbedding(nn.Module):
    """
    将一维时间序列转换为一系列嵌入向量(patch embedding)。
    这个过程类似于NLP中的分词和词嵌入。
    """
    def __init__(self, config: SundialConfig):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_rate)
        # 输入维度是数据patch和mask patch拼接后的大小 (input_token_len * 2)
        self.hidden_layer = nn.Linear(
            config.input_token_len * 2, config.intermediate_size)
        self.act = ACT2FN[config.hidden_act]
        self.output_layer = nn.Linear(
            config.intermediate_size, config.hidden_size)
        self.residual_layer = nn.Linear(
            config.input_token_len * 2, config.hidden_size)
        self.input_token_len = config.input_token_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: 输入的原始时间序列, 形状为 [batch_size, sequence_length]

        # 创建一个与输入x形状相同,值为1的掩码,用于标识真实数据位置
        mask = torch.ones_like(x, dtype=torch.float32)
        input_length = x.shape[-1]

        # 计算需要填充的长度,以确保序列总长度是 input_token_len 的整数倍
        padding_length = (self.input_token_len - (input_length %
                                         self.input_token_len)) % self.input_token_len
        
        # 在序列的左侧(过去)进行填充,填充值为0
        x = F.pad(x, (padding_length, 0))
        mask = F.pad(mask, (padding_length, 0))
        
        # 使用 unfold 将时间序列分割成多个patch(块)
        x_patched = x.unfold(dimension=-1, size=self.input_token_len,
                             step=self.input_token_len)
        mask_patched = mask.unfold(
            dimension=-1, size=self.input_token_len, step=self.input_token_len)

        # 将数据patch和mask patch沿着最后一个维度拼接
        combined = torch.cat([x_patched, mask_patched], dim=-1)
        
        # 通过一个MLP(多层感知机)将patch投影到高维空间(hidden_size)
        hid = self.act(self.hidden_layer(combined))
        out = self.dropout(self.output_layer(hid))
        res = self.residual_layer(combined)
        
        # 添加残差连接
        out = out + res
        return out

# --- Embedding层 使用示例 ---
if __name__ == '__main__':
    print("--- 1. Embedding层 Demo ---")
    
    # 1. 配置模型
    config = SundialConfig(
        input_token_len=32,
        hidden_size=128
    )

    # 2. 实例化Embedding层
    embedding_layer = SundialPatchEmbedding(config)
    embedding_layer.eval()

    # 3. 准备输入数据
    BATCH_SIZE = 2
    SEQUENCE_LENGTH = 512
    CHANNEL = 3
    # 模拟两段原始时间序列数据
    input_timeseries = torch.randn(BATCH_SIZE, SEQUENCE_LENGTH, CHANNEL)
    print(f"原始输入数据形状: (batch_size, sequence_length, channel) = {input_timeseries.shape}")

    # 4. 前向传播
    with torch.no_grad():
        embeddings = embedding_layer(input_timeseries)

    # 5. 查看输出
    # 原始序列被分成了 512 / 32 = 16 个patch
    # 输出形状应为: (batch_size, num_patches, hidden_size)
    print(f"生成的嵌入向量形状: (batch_size, num_patches, hidden_size) = {embeddings.shape}")
    print("Embedding层 Demo 运行完毕!\n")

--- 1. Embedding层 Demo ---
原始输入数据形状: (batch_size, sequence_length, channel) = torch.Size([2, 512, 3])
生成的嵌入向量形状: (batch_size, num_patches, hidden_size) = torch.Size([2, 512, 1, 128])
Embedding层 Demo 运行完毕!



### 2

In [21]:
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, Cache, DynamicCache
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

# 为了演示，我们创建一个更完整的SundialConfig
class SundialConfig(PretrainedConfig):
    model_type = "sundial"

    def __init__(self,
                 input_token_len=32,
                 hidden_size=128,
                 intermediate_size=256,
                 num_hidden_layers=4,
                 num_attention_heads=4,
                 max_position_embeddings=10000,
                 initializer_range=0.02,
                 dropout_rate=0.1,
                 hidden_act="silu",
                 use_cache=True,
                 **kwargs):
        self.input_token_len = input_token_len
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.dropout_rate = dropout_rate
        self.hidden_act = hidden_act
        self.use_cache = use_cache
        super().__init__(**kwargs)

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class SundialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype))

class SundialAttention(nn.Module):
    def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = config.dropout_rate
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary_emb = SundialRotaryEmbedding(self.head_dim, max_position_embeddings=config.max_position_embeddings)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None,
                output_attentions: bool = False, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask,
            dropout_p=(self.attention_dropout if self.training else 0.0)
        )
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        
        return attn_output, None, past_key_value

class SundialMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

class SundialDecoderLayer(nn.Module):
    def __init__(self, config: SundialConfig, layer_idx: int):
        super().__init__()
        self.self_attn = SundialAttention(config, layer_idx)
        self.ffn_layer = SundialMLP(config.hidden_size, config.intermediate_size, config.hidden_act)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
                **kwargs) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.ffn_layer(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, self_attn_weights, present_key_value

class SundialPreTrainedModel(PreTrainedModel):
    config_class = SundialConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SundialDecoderLayer"]
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

class SundialBackbone(SundialPreTrainedModel):
    def __init__(self, config: SundialConfig):
        super().__init__(config)
        self.layers = nn.ModuleList([SundialDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.gradient_checkpointing = False

    def forward(self, inputs_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
                use_cache: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[Tuple, BaseModelOutputWithPast]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, _ = inputs_embeds.shape
        past_key_values_length = 0
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        if past_key_values is not None:
             past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=inputs_embeds.device)
            position_ids = position_ids.unsqueeze(0)

        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)
        hidden_states = inputs_embeds
        next_decoder_cache = None

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values)
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2]

        hidden_states = self.norm(hidden_states)
        
        if not return_dict:
            return (hidden_states, next_decoder_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=None, attentions=None)


if __name__ == '__main__':
    # --- Backbone层 使用示例 ---
    print("--- 2. Backbone层 Demo ---")

    # 1. 准备与Embedding层一致的配置和数据
    config = SundialConfig(
        input_token_len=32,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=4,
        num_attention_heads=4
    )
    # 假设这是从Embedding层得到的输出
    batch_size = 2
    num_patches = 16 
    hidden_size = 128
    # (batch_size, num_patches, hidden_size)
    embeddings_from_previous_step = torch.randn(batch_size, num_patches, hidden_size)
    print(f"来自Embedding层的输入形状: {embeddings_from_previous_step.shape}")

    # 2. 实例化Backbone
    backbone = SundialBackbone(config)
    backbone.eval()

    # 3. 前向传播
    with torch.no_grad():
        # Backbone直接接收嵌入向量
        outputs = backbone(inputs_embeds=embeddings_from_previous_step)
        
    # 4. 查看输出
    # Backbone的输出是处理后的隐藏状态
    last_hidden_states = outputs.last_hidden_state
    print(f"Backbone输出的隐藏状态形状: {last_hidden_states.shape}")
    print("Backbone层 Demo 运行完毕!\n")

--- 2. Backbone层 Demo ---
来自Embedding层的输入形状: torch.Size([2, 16, 128])
Backbone输出的隐藏状态形状: torch.Size([2, 16, 128])
Backbone层 Demo 运行完毕!



### 3

In [22]:
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
from torch.distributions.normal import Normal

# 同样，为了演示创建一个最小化的Config
class SundialConfig(PretrainedConfig):
    model_type = "sundial"
    def __init__(self, hidden_size=128, output_token_len=32, 
                 flow_loss_depth=2, diffusion_batch_mul=1, **kwargs):
        self.hidden_size = hidden_size
        self.output_token_lens = [output_token_len]
        self.flow_loss_depth = flow_loss_depth
        self.diffusion_batch_mul = diffusion_batch_mul
        super().__init__(**kwargs)

# FlowLoss是任务头的核心计算组件
class FlowLoss(nn.Module):
    def __init__(self, n_dim, n_hidden, n_layer=2):
        super().__init__()
        self.n_dim = n_dim
        layers = []
        for _ in range(n_layer):
            layers.append(nn.Linear(n_hidden, n_hidden))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(n_hidden, 2 * n_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, y, c, mask, mask_y):
        # 预测高斯分布的均值和标准差
        params = self.layers(c)
        mu_base, log_std_base = params.chunk(2, -1)
        # 应用掩码
        mu_base = mu_base * mask_y
        log_std_base = log_std_base * mask_y
        dist = Normal(loc=mu_base, scale=torch.exp(log_std_base))
        # 计算负对数似然损失
        loss = -dist.log_prob(y)
        loss = (loss * mask_y).mean(dim=-1)
        loss = (loss * mask).mean()
        return loss

    def sample(self, c, n_sample):
        # 从预测的分布中采样
        params = self.layers(c)
        mu_base, log_std_base = params.chunk(2, -1)
        dist = Normal(loc=mu_base, scale=torch.exp(log_std_base))
        # .sample((n_sample,)) 会增加一个新的维度在最前面
        return dist.sample((n_sample,)).transpose(0, 1)

class SundialPredictionHead(nn.Module):
    """
    用于时间序列预测的任务头。
    """
    def __init__(self, config: SundialConfig):
        super().__init__()
        self.config = config
        self.flow_loss = FlowLoss(
            n_dim=config.output_token_lens[-1],
            n_hidden=config.hidden_size,
            n_layer=config.flow_loss_depth,
        )

    def forward(self, hidden_states: torch.Tensor, labels: Optional[torch.Tensor] = None, 
                loss_masks: Optional[torch.Tensor] = None, mask_y: Optional[torch.Tensor] = None,
                num_samples: int = 1):
        """
        根据是否有labels，执行训练或推理。
        :param hidden_states: 来自Backbone的输出, 形状 [batch_size, seq_len, hidden_size]
        """
        # --- 训练模式 ---
        if labels is not None:
            bsz, L, _ = hidden_states.shape
            
            # Reshape Tensors for Loss Calculation
            # 维度需要匹配flow_loss的输入
            hidden_states = hidden_states.reshape(bsz * L, -1)
            labels = labels.reshape(bsz * L, -1)
            loss_masks = loss_masks.reshape(bsz * L)
            mask_y = mask_y.repeat(L, 1)
            
            # 如果配置了diffusion_batch_mul，则复制数据以增加训练稳定性
            if self.config.diffusion_batch_mul > 1:
                hidden_states = hidden_states.repeat(self.config.diffusion_batch_mul, 1)
                labels = labels.repeat(self.config.diffusion_batch_mul, 1)
                loss_masks = loss_masks.repeat(self.config.diffusion_batch_mul)
                mask_y = mask_y.repeat(self.config.diffusion_batch_mul, 1)

            loss = self.flow_loss(labels, hidden_states, loss_masks, mask_y)
            return {"loss": loss}
            
        # --- 推理模式 ---
        else:
            # 推理时通常只使用最后一个时间步的隐藏状态进行预测
            last_hidden_state = hidden_states[:, -1, :]
            predictions = self.flow_loss.sample(last_hidden_state, num_samples)
            return {"predictions": predictions}

# --- Task Head层 使用示例 ---
if __name__ == '__main__':
    print("--- 3. Task Head Demo ---")

    # 1. 准备配置和来自Backbone的模拟输出
    config = SundialConfig(
        hidden_size=128,
        output_token_len=32, # 预测patch的长度
        flow_loss_depth=2
    )
    BATCH_SIZE = 2
    NUM_PATCHES = 16 
    HIDDEN_SIZE = 128
    
    # 模拟来自Backbone层的输出
    backbone_output = torch.randn(BATCH_SIZE, NUM_PATCHES, HIDDEN_SIZE)
    print(f"来自Backbone层的输入形状: {backbone_output.shape}")

    # 2. 实例化Task Head
    prediction_head = SundialPredictionHead(config)

    # --- 演示推理模式 ---
    print("\n--- 3a. 推理模式 (Sampling) ---")
    prediction_head.eval()
    with torch.no_grad():
        NUM_SAMPLES = 3
        outputs = prediction_head(hidden_states=backbone_output, num_samples=NUM_SAMPLES)
        predictions = outputs["predictions"]

    # 预测输出形状应为: (batch_size, num_samples, prediction_length)
    print(f"生成的预测形状: {predictions.shape}")

    # --- 演示训练模式 ---
    print("\n--- 3b. 训练模式 (Loss Calculation) ---")
    prediction_head.train()
    
    # 准备模拟的标签和掩码
    # 标签形状应与hidden_states的 (bsz*seq_len, patch_len) 对应
    labels_unfolded = torch.randn(BATCH_SIZE, NUM_PATCHES, config.output_token_lens[-1])
    loss_masks_unfolded = torch.ones(BATCH_SIZE, NUM_PATCHES)
    mask_y = torch.ones(BATCH_SIZE, config.output_token_lens[-1])
    
    # 计算损失
    loss_output = prediction_head(
        hidden_states=backbone_output, 
        labels=labels_unfolded,
        loss_masks=loss_masks_unfolded,
        mask_y=mask_y
    )
    print(f"计算得到的损失值: {loss_output['loss'].item():.4f}")
    
    print("\nTask Head Demo 运行完毕!")

--- 3. Task Head Demo ---
来自Backbone层的输入形状: torch.Size([2, 16, 128])

--- 3a. 推理模式 (Sampling) ---
生成的预测形状: torch.Size([2, 3, 32])

--- 3b. 训练模式 (Loss Calculation) ---
计算得到的损失值: 1.4341

Task Head Demo 运行完毕!


# flow loss 对比

## mean_flow in DIT

## flow loss in DIT

## flow loss in sundial