# In [3]:
import torch
import torch.nn as nn
from functools import partial

class MaskedAutoencoderViT(nn.Module):
    """Transformer encoder/decoder that reconstructs video clips patch by patch.

    Inputs and outputs are 5-D tensors laid out as ``[N, C, T, H, W]``. Each
    frame is cut into non-overlapping ``patch_size x patch_size`` patches,
    projected to ``embed_dim`` by ``patch_embed``, run through the encoder and
    decoder, and mapped back to pixels by ``decoder_pred``.

    NOTE(review): despite the name, no random masking is performed yet.
    ``mask_token`` and ``norm_pix_loss`` are stored but not used by
    ``forward``, and every patch is fed to the encoder.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=8,
                 mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # Per-frame patch embedding: a stride-p convolution maps each
        # p x p x C patch to one embed_dim vector.
        self.patch_embed = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        # Sized for one square img_size x img_size frame. forward() requires
        # T * (H // p) * (W // p) to equal this count.
        num_patches = (img_size // patch_size) ** 2
        self.num_patches = num_patches

        # Learned positional embedding, one vector per token (zero-initialised).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        # Reserved for masked tokens; not used by forward() yet.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Encoder blocks. batch_first=True because tokens are passed as
        # (N, L, D). With the PyTorch default (batch_first=False), a (N, L, D)
        # tensor is read as L sequences of length N, and attention would not
        # mix the patches of one clip.
        self.encoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                embed_dim, num_heads, int(embed_dim * mlp_ratio), batch_first=True
            )
            for _ in range(depth)
        ])
        self.encoder_norm = norm_layer(embed_dim)

        # Decoder embedding: project encoder features to the decoder width.
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # Decoder blocks (same batch_first layout as the encoder).
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                decoder_embed_dim, decoder_num_heads, int(decoder_embed_dim * mlp_ratio),
                batch_first=True
            )
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(decoder_embed_dim)

        # Output prediction layer: one flattened p x p x C pixel patch per token.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size * patch_size * in_chans, bias=True)

        # Stored for a future reconstruction loss; not read by forward().
        self.norm_pix_loss = norm_pix_loss

    def patchify(self, imgs):
        """Split ``imgs`` ``[N, C, T, H, W]`` into flattened patches.

        Returns a tensor of shape ``(N, (H/p) * (W/p) * T, p * p * C)``.
        Tokens are ordered (row-patch, col-patch, frame), and each token's
        features are ordered (patch-row, patch-col, channel).

        Raises:
            ValueError: if H or W is not divisible by the patch size.
        """
        p = self.patch_embed.kernel_size[0]
        n, c, t, h, w = imgs.shape
        if h % p or w % p:
            raise ValueError("Image dimensions must be divisible by the patch size.")
        hp, wp = h // p, w // p
        x = imgs.reshape(n, c, t, hp, p, wp, p)
        x = x.permute(0, 3, 5, 2, 4, 6, 1)  # -> (n, hp, wp, t, p, p, c)
        return x.reshape(n, hp * wp * t, p * p * c)

    def unpatchify(self, patches, t, h, w):
        """Inverse of :meth:`patchify`: rebuild ``[N, C, T, H, W]`` from patches.

        ``patches`` must have shape ``(N, (H/p) * (W/p) * T, p * p * C)``,
        with tokens and features in the same order that ``patchify`` produces.
        """
        p = self.patch_embed.kernel_size[0]
        n, _, dim = patches.shape
        c = dim // (p * p)
        hp, wp = h // p, w // p
        x = patches.reshape(n, hp, wp, t, p, p, c)
        x = x.permute(0, 6, 3, 1, 4, 2, 5)  # -> (n, c, t, hp, p, wp, p)
        return x.reshape(n, c, t, h, w)

    def _embed_patches(self, imgs):
        """Project every frame's patches to ``embed_dim`` tokens.

        Returns ``(N, (H/p) * (W/p) * T, embed_dim)``, with tokens in the
        same (row-patch, col-patch, frame) order as :meth:`patchify`. This
        keeps prediction token i aligned with the patch that
        :meth:`unpatchify` writes it to.
        """
        p = self.patch_embed.kernel_size[0]
        n, c, t, h, w = imgs.shape
        if h % p or w % p:
            raise ValueError("Image dimensions must be divisible by the patch size.")
        hp, wp = h // p, w // p
        # (N, C, T, H, W) -> (N*T, C, H, W): apply the shared 2-D conv per frame.
        frames = imgs.transpose(1, 2).reshape(n * t, c, h, w)
        emb = self.patch_embed(frames)  # (N*T, D, hp, wp)
        d = emb.shape[1]
        emb = emb.reshape(n, t, d, hp, wp).permute(0, 3, 4, 1, 2)  # (N, hp, wp, T, D)
        return emb.reshape(n, hp * wp * t, d)

    def forward(self, imgs):
        """Reconstruct ``imgs`` (``[N, C, T, H, W]``) and return a tensor of the same shape.

        Raises:
            ValueError: if H or W is not divisible by the patch size, or if
                the clip yields a different number of patches than
                ``pos_embed`` holds.
        """
        x = self._embed_patches(imgs)
        if x.shape[1] != self.pos_embed.shape[1]:
            raise ValueError(
                f"Input yields {x.shape[1]} patches but pos_embed holds "
                f"{self.pos_embed.shape[1]}; adjust img_size or the input size."
            )
        # Out-of-place add, so the input tensor is never modified.
        x = x + self.pos_embed

        # Encoder
        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        # Decoder
        x = self.decoder_embed(x)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # Per-token pixel prediction, reshaped back to the input layout.
        pred = self.decoder_pred(x)
        t, h, w = imgs.shape[-3:]
        return self.unpatchify(pred, t, h, w)


if __name__ == "__main__":
    # Model initialization: hyper-parameters gathered in one mapping.
    model_kwargs = dict(
        img_size=1440,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=8,
        mlp_ratio=4.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
    )
    model = MaskedAutoencoderViT(**model_kwargs)
    model.eval()

    # Smoke-test input: one random clip laid out as [N, C, T, H, W].
    clip_shape = (1, 3, 2, 720, 1440)
    input_tensor = torch.rand(*clip_shape)
    print(f"Input shape: {input_tensor.shape}")

    # Inference-only forward pass, without autograd bookkeeping.
    with torch.no_grad():
        output = model(input_tensor)
    print(f"Output shape: {output.shape}")


# Recorded cell output:
# Input shape: torch.Size([1, 3, 2, 720, 1440])
# Output shape: torch.Size([1, 3, 2, 720, 1440])
