In [1]:
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from point_e.models.fusion import TextImageFusionModule
from point_e.models.multimodal import SimpleMultimodalTransformer
from point_e.models.configs import MODEL_CONFIGS
from point_e.models.download import load_checkpoint

from point_e.diffusion.sampler import PointCloudSampler
from point_e.diffusion.configs import DIFFUSION_CONFIGS
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud

from torchvision import transforms


In [2]:
# Smoke test for the text/image fusion module.
def test_fusion_module():
    """Run ``TextImageFusionModule`` on random inputs and verify the output shape.

    Builds a random batch of text embeddings and image tokens, performs a
    single forward pass, prints the shapes involved, and asserts that the
    fused output has shape ``(batch_size, fusion_dim)``.

    Raises:
        AssertionError: If the fused output does not have the expected shape.
    """
    print("Testing TextImageFusionModule...")
    text_dim, image_dim, fusion_dim = 768, 768, 512
    fusion = TextImageFusionModule(
        text_dim=text_dim, image_dim=image_dim, fusion_dim=fusion_dim
    )

    # Random stand-ins for the text embedding and the image token sequence.
    batch_size = 2
    num_img_tokens = 196  # 196 = 14x14 grid
    text_emb = torch.randn(batch_size, text_dim)
    img_tokens = torch.randn(batch_size, num_img_tokens, image_dim)

    # Forward pass only: a shape check does not need an autograd graph.
    with torch.no_grad():
        out = fusion(text_emb, img_tokens)

    print(f"Input shapes: text_emb {text_emb.shape}, img_tokens {img_tokens.shape}")
    print(f"Output shape: {out.shape}")

    # Without this check the function would report success for any output.
    expected_shape = (batch_size, fusion_dim)
    assert tuple(out.shape) == expected_shape, (
        f"Expected fused output of shape {expected_shape}, got {tuple(out.shape)}"
    )
    print("Fusion module test passed!")


if __name__ == "__main__":
    test_fusion_module()

Testing TextImageFusionModule...
Input shapes: text_emb torch.Size([2, 768]), img_tokens torch.Size([2, 196, 768])
Output shape: torch.Size([2, 512])
Fusion module test passed!


In [3]:
import torch
import numpy as np
from PIL import Image
from point_e.models.multimodal import SimpleMultimodalTransformer

def test_simple_multimodal_transformer():
    """Smoke-test a single forward pass of ``SimpleMultimodalTransformer`` on CPU.

    Instantiates the model in eval mode, feeds it a random point-cloud tensor,
    random timesteps, random PIL images and two text prompts, then checks that
    the result is a tensor whose first dimension equals the batch size.

    Raises:
        AssertionError: If the output is not a ``torch.Tensor`` or its first
            dimension differs from the batch size.
    """
    print("Testing SimpleMultimodalTransformer…")
    device = torch.device('cpu')

    # 1) Instantiate the model and switch it to eval mode.
    model = SimpleMultimodalTransformer(device=device, dtype=torch.float32)
    model.eval()
    print("✔ Model initialized.")

    # 2) Read the input dimensions from the model itself.
    B = 2
    # Original author's note: 6 by default — not verifiable from this cell.
    C_in = model.input_channels
    # Original author's note: 1024, possibly plus grid tokens depending on
    # how the model was initialised — TODO confirm against the model class.
    T_ctx = model.n_ctx

    # 3) Random inputs x and t, created on the same device as the model.
    x = torch.randn(B, C_in, T_ctx, device=device)
    t = torch.randint(0, 1000, (B,), dtype=torch.long, device=device)

    # 4) A batch of random 224x224 RGB PIL images, plus one prompt per image.
    dummy_imgs = [
        Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))
        for _ in range(B)
    ]
    dummy_texts = ["a red chair", "a blue table"]

    # 5) Forward pass. This is inference only, so gradient tracking is
    #    disabled to avoid building an autograd graph that is never used.
    print("Running forward pass…")
    with torch.no_grad():
        out = model(x, t, images=dummy_imgs, texts=dummy_texts)
    print(f"→ out.shape = {out.shape}")

    # 6) Basic sanity assertions (messages kept as in the original).
    assert isinstance(out, torch.Tensor), "输出必须是 Tensor"
    assert out.shape[0] == B, f"第一维（batch）应为 {B}"
    print("✔ Forward pass successful!")


if __name__ == "__main__":
    test_simple_multimodal_transformer()


Testing SimpleMultimodalTransformer…
✔ Model initialized.
Running forward pass…
→ out.shape = torch.Size([2, 12, 1024])
✔ Forward pass successful!


In [4]:
import torch
import numpy as np
from PIL import Image
import torch.optim as optim
from point_e.models.multimodal import SimpleMultimodalTransformer
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config

def test_training_loop():
    """Run one optimisation step of ``SimpleMultimodalTransformer`` on dummy data.

    Builds the model and the ``base40M`` diffusion, leaves only the parameters
    whose names start with ``"fusion"`` or ``"clip.model"`` trainable, computes
    the diffusion training loss on a random batch, and performs a single
    AdamW step.

    Raises:
        Exception: Any error raised while computing the loss, back-propagating
            or stepping the optimizer is reported on stdout and then re-raised,
            so a failing run cannot be mistaken for a passing one.
    """
    print("Testing training loop…")
    device = torch.device('cpu')

    # 1) Instantiate the model and the diffusion process.
    model = SimpleMultimodalTransformer(device=device, dtype=torch.float32)
    diffusion = diffusion_from_config(DIFFUSION_CONFIGS['base40M'])
    model.train()

    # 2) Freeze everything except parameters under the "fusion" and
    #    "clip.model" name prefixes.
    trainable_prefixes = ("fusion", "clip.model")
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)

    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4,
    )

    # 3) Build a dummy batch.
    B = 2
    # Point clouds are created as [B, model.input_channels, model.n_ctx].
    pc = torch.randn(B, model.input_channels, model.n_ctx, device=device)

    # Conditioning inputs: a list of random PIL images and one prompt each.
    imgs = [
        Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))
        for _ in range(B)
    ]
    texts = ["a red chair", "a blue table"]

    # 4) Run a single training step.
    t = torch.randint(0, diffusion.num_timesteps, (B,), device=device)
    model_kwargs = {"images": imgs, "texts": texts}

    try:
        losses = diffusion.training_losses(
            model=model,
            x_start=pc,
            t=t,
            model_kwargs=model_kwargs
        )
        loss = losses["loss"].mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Training loss: {loss.item():.4f}")
        print("Training loop test passed!")
    except Exception as e:
        # Report the failure, then propagate it with its original traceback
        # instead of letting the test return as if it had succeeded.
        print(f"Error in training loop test: {e}")
        raise


if __name__ == "__main__":
    test_training_loop()


Testing training loop…
Training loss: 1.0100
Training loop test passed!


In [10]:
# import torch
# import numpy as np
# from PIL import Image

# from point_e.models.multimodal import SimpleMultimodalTransformer
# from point_e.models.configs import MODEL_CONFIGS, model_from_config
# from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
# from point_e.diffusion.sampler import PointCloudSampler

# def test_generation_two_stage():
#     print("Testing two-stage point cloud generation…")
#     device = torch.device("cpu")
#     B = 1

#     # ——— 1) Base Stage: 1024 pts ———
#     base_model = SimpleMultimodalTransformer(device=device, dtype=torch.float32).eval()
#     base_diff = diffusion_from_config(DIFFUSION_CONFIGS["base40M"])
#     base_sampler = PointCloudSampler(
#         device=device,
#         models=[base_model],
#         diffusions=[base_diff],
#         num_points=[1024],
#         aux_channels=['R', 'G', 'B'],
#         guidance_scale=[3.0],
#         use_karras=[True],
#         karras_steps=[2],
#         sigma_min=[None],  # Fix: Ensure length matches n=1
#         sigma_max=[None],
#         s_churn=[0.0],
#     )

#     # Dummy image + prompt
#     img = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))
#     prompt = "a red chair"

#     with torch.no_grad():
#         low_res = base_sampler.sample_batch(
#             batch_size=B,
#             model_kwargs={'images': [img], 'texts': [prompt]},
#         )
#     print("✔ Base stage output:", low_res.shape)
#     assert low_res.shape[2] == 1024

#     # ——— 2) Upsample Stage: 1024→4096 pts ———
#     up_model = model_from_config(MODEL_CONFIGS["upsample"], device=device).eval()
#     up_diff = diffusion_from_config(DIFFUSION_CONFIGS["upsample"])
#     up_sampler = PointCloudSampler(
#         device=device,
#         models=[up_model],
#         diffusions=[up_diff],
#         num_points=[4096 - 1024],
#         aux_channels=['R', 'G', 'B'],
#         guidance_scale=[3.0],
#         use_karras=[True],
#         karras_steps=[2],
#         sigma_min=[None],  # Fix: Ensure length matches n=1
#         sigma_max=[None],
#         s_churn=[0.0],
#     )

#     with torch.no_grad():
#         high_res = up_sampler.sample_batch(
#             batch_size=B,
#             model_kwargs={'low_res': low_res},
#         )
#     print("✔ Upsample stage output:", high_res.shape)
#     assert high_res.shape[2] == 4096 - 1024

#     # ——— 3) 拼接最终点云 ———
#     final_pc = torch.cat([low_res, high_res], dim=2)
#     assert final_pc.shape[2] == 4096
#     print("✔ Two-stage generation successful, final shape:", final_pc.shape)

# if __name__ == "__main__":
#     test_generation_two_stage()