In [4]:
# ✅ 基础导入
import torch
from torchvision.transforms import Resize
from torchvision.transforms.functional import InterpolationMode

import sys
sys.path.append(".")  # 添加当前路径，或指定完整路径

# ✅ 正确导入 get_model 和 DiffusionEngine
from generate_t2i_sr import get_args
from sat.model.base_model import get_model
from dit.model import DiffusionEngine

# ✅ 获取 args 并设置必要参数
args = get_args([])  # 传入空参数，模拟不使用 CLI
# args = get_args([])

# ✅ 手动补齐 DiffusionEngine 需要的字段
args.image_size = 256
args.input_size = 256
args.patch_size = 4
args.image_block_size = 128
args.in_channels = 6
args.out_channels = 3
args.scale_factor = 4
args.clip_img_dim = 1024
args.time_embed_dim = 1280
args.input_time = True
args.is_decoder = False
args.no_crossmask = False
args.use_fp16 = False

# ✅ 模型路径 & 模式
args.network = 'ckpt/mp_rank_00_model_states.pt'
args.inference_type = 'full'
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ✅ 构建模型 & 加载权重
net = get_model(args, DiffusionEngine).to(args.device).eval()
ckpt = torch.load(args.network, map_location=args.device)
net.load_state_dict(ckpt, strict=False)


[2025-03-29 17:09:18,099] [INFO] [RANK 0] using world size: 1 and model-parallel size: 1 
[2025-03-29 17:09:18,099] [INFO] [RANK 0] > padded vocab (size: 100) with 28 dummy tokens (new size: 128)


[2025-03-29 17:09:18,101] [INFO] [checkpointing.py:229:model_parallel_cuda_manual_seed] > initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234


[2025-03-29 17:09:18,102] [INFO] [RANK 0] building DiffusionEngine model ...


AttributeError: 'Namespace' object has no attribute 'stop_grad_patch_embed'

In [None]:
# 🧪 构造模拟输入：低分辨率图像 + 上采样图 + 噪声图
B, C, H, W = 1, 3, 64, 64
scale = 4
HR_H, HR_W = H * scale, W * scale
device = args.device

lr_img = torch.randn(B, C, H, W).to(torch.float32).to(device)
upsampled_img = Resize((HR_H, HR_W), interpolation=InterpolationMode.BICUBIC)(lr_img)
noisy_img = torch.randn(B, C, HR_H, HR_W).to(torch.float32).to(device)
concat_img = torch.cat([upsampled_img, noisy_img], dim=1)  # [B, 6, HR_H, HR_W]
sigmas = torch.ones(B).to(device) * 1.0


In [None]:
# 🔎 patchify 前向测试
with torch.no_grad():
    patches = net.patchify(concat_img)
    print(f"[patchify output shape] = {patches.shape}")


In [None]:
# 🧠 注册所有 attention 层 hook，打印输出 shape
def make_hook(name):
    def hook(module, input, output):
        print(f"[Hook] {name} output shape: {output.shape}")
    return hook

for i, layer in enumerate(net.transformer.layers):
    if hasattr(layer, 'attn'):
        layer.attn.register_forward_hook(make_hook(f"Layer-{i}-Attn"))


In [None]:
# 🚀 执行一次 model_forward（前向传播）
with torch.no_grad():
    out = net.model_forward(
        images=concat_img,
        lr_imgs=lr_img,
        sigmas=sigmas,
        input_ids=None, position_ids=None, attention_mask=None
    )
    print(f"[Final output shape] = {out.shape}")


In [None]:
# 🏗️ 打印 Transformer 主体结构（层级）
print(net.transformer)


In [None]:
# 🔍 选择第 0 层 Attention 结构展开看
print(net.transformer.layers[0].attn)
