# In [None]:
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

from point_e.diffusion.sampler import PointCloudSampler
from point_e.diffusion.configs import DIFFUSION_CONFIGS
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.models.multimodal import MultimodalPointDiffusionTransformer
from point_e.util.plotting import plot_point_cloud

from torchvision import transforms

# CLIP preprocessing transform
def get_clip_transform():
    """Build the image preprocessing pipeline used for CLIP inputs.

    The steps run in this order: bicubic resize to 224, center crop to 224,
    conversion to a tensor, and per-channel normalization with the mean and
    standard-deviation constants defined below.

    Returns:
        A ``transforms.Compose`` object chaining the four steps.
    """
    # Per-channel normalization statistics (one value per RGB channel).
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    bicubic = transforms.InterpolationMode.BICUBIC

    steps = [
        transforms.Resize(224, interpolation=bicubic),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)

def main():
    """Sample a point cloud conditioned on an image and a text prompt.

    Builds the multimodal base model and the 'upsample' model, loads their
    weights, runs the two-stage sampler on ``example_image.jpg`` with a fixed
    text prompt, and writes ``multimodal_output.png`` (plot) and
    ``multimodal_output.npz`` (point cloud) to the working directory.

    Raises:
        RuntimeError: If the sampler yields no samples at all.
    """
    # `diffusion_from_config` is used below, but the module-level imports only
    # bring in DIFFUSION_CONFIGS from this module. Import it here so the calls
    # do not fail with NameError.
    from point_e.diffusion.configs import diffusion_from_config

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the base multimodal model.
    # NOTE(review): the whole 'base40M' config dict is forwarded as keyword
    # arguments — confirm the constructor accepts every key it contains.
    model_config = MODEL_CONFIGS['base40M'].copy()
    base_model = MultimodalPointDiffusionTransformer(
        device=device,
        **model_config
    )
    base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['base40M'])

    # Create the upsampler model (original Point-E)
    upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
    upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

    # Load checkpoints.
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    base_model.load_state_dict(
        torch.load("multimodal_point_e_final.pt", map_location=device)
    )
    upsampler_model.load_state_dict(load_checkpoint('upsample', device))

    # Set models to eval mode
    base_model.eval()
    upsampler_model.eval()

    # Create sampler: 1024 points from the base stage, then the upsampler adds
    # the remaining points for a total of 4096.
    sampler = PointCloudSampler(
        device=device,
        models=[base_model, upsampler_model],
        diffusions=[base_diffusion, upsampler_diffusion],
        num_points=[1024, 4096 - 1024],
        aux_channels=['R', 'G', 'B'],
        guidance_scale=[3.0, 3.0],
    )

    # Prepare inputs
    transform = get_clip_transform()

    # Load and preprocess image; the context manager releases the file handle
    # as soon as the RGB copy has been made.
    image_path = "example_image.jpg"
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    # Add a leading batch dimension of size 1 and move to the target device.
    image = transform(rgb_image).unsqueeze(0).to(device)

    # Text prompt
    text_prompt = "a red motorcycle with chrome details"

    # Generate point cloud; only the last yielded value is kept.
    print("Generating point cloud...")
    samples = None
    with torch.no_grad():
        for x in tqdm(sampler.sample_batch_progressive(
            batch_size=1,
            model_kwargs=dict(images=image, texts=[text_prompt])
        )):
            samples = x

    # Fail with a clear message instead of passing None to the converter.
    if samples is None:
        raise RuntimeError("Sampler produced no samples; cannot build a point cloud.")

    # Convert to point cloud
    pc = sampler.output_to_point_clouds(samples)[0]

    # Visualize
    fig = plot_point_cloud(pc, grid_size=1)
    fig.savefig("multimodal_output.png")

    # Save point cloud
    pc.save("multimodal_output.npz")

    print("Generated point cloud saved to multimodal_output.npz")
    print("Visualization saved to multimodal_output.png")


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()