In [1]:
from pathlib import Path
import open3d as o3d
import os

from pytorch_lightning import seed_everything

from src.dataset_utils import (
    get_singleview_data,
    get_multiview_data,
    get_voxel_data_json,
    get_image_transform_latent_model,
    get_pointcloud_data,
    get_mv_dm_data,
    get_sv_dm_data,
    get_sketch_data
)
from src.model_utils import Model
from src.mvdream_utils import load_mvdream_model
import argparse
from PIL import Image


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:

def simplify_mesh(obj_path, target_num_faces=1000):
    """Decimate the mesh stored at ``obj_path`` and overwrite that file.

    The mesh is read with Open3D, reduced with quadric decimation and written
    back to the same location.

    Args:
        obj_path: Location of the mesh file, given as a ``str`` or any
            ``os.PathLike`` object such as ``pathlib.Path``.
        target_num_faces: Triangle budget forwarded to
            ``simplify_quadric_decimation``.

    Raises:
        TypeError: If ``obj_path`` is neither a string nor path-like.
    """
    # Hand Open3D a plain string so Path inputs behave the same as str inputs.
    # NOTE(review): some Open3D builds appear to accept only ``str`` filenames
    # -- confirm against the version pinned for this project.
    mesh_file = os.fspath(obj_path)

    mesh = o3d.io.read_triangle_mesh(mesh_file)
    simplified_mesh = mesh.simplify_quadric_decimation(target_num_faces)
    o3d.io.write_triangle_mesh(mesh_file, simplified_mesh)


def generate_3d_object(
    model,
    data,
    data_idx,
    scale,
    diffusion_rescale_timestep,
    save_dir="examples",
    output_format="obj",
    target_num_faces=None,
    seed=42,
):
    # Set seed
    seed_everything(seed, workers=True)

    save_dir.mkdir(parents=True, exist_ok=True)
    model.set_inference_fusion_params(scale, diffusion_rescale_timestep)
    output_path = model.test_inference(
        data, data_idx, save_dir=save_dir, output_format=output_format
    )

    if output_format == "obj" and target_num_faces:
        simplify_mesh(output_path, target_num_faces=target_num_faces)


In [3]:
# --- Run configuration: edit these values to change what gets generated ---

# Checkpoint identifier handed to Model.from_pretrained.
model_name = 'ADSKAILab/WaLa-SV-1B'

# Input images; each entry is processed in turn.
images = ['examples/single_view/table.png']

# Output settings: results land in <output_dir>/<image file stem>/.
output_dir, output_format = 'examples', 'obj'
target_num_faces = None  # falsy value -> mesh simplification is skipped

# Sampling settings forwarded to the model and to seed_everything.
scale, diffusion_rescale_timestep = 1.8, 5
seed = 42



In [4]:
print("Loading model")

# Load the pretrained model and its image preprocessing once; both are reused
# for every image below.
model = Model.from_pretrained(pretrained_model_name_or_path=model_name)
image_transform = get_image_transform_latent_model()

# The same index is used for every image, so set it once outside the loop.
data_idx = 0

for image_path in images:
    print(f"Processing image: {image_path}")
    data = get_singleview_data(
        image_file=Path(image_path),
        image_transform=image_transform,
        device=model.device,
        image_over_white=False,
    )
    # One output folder per input image, named after the image file stem.
    save_dir = Path(output_dir) / Path(image_path).stem

    # generate_3d_object applies (scale, diffusion_rescale_timestep) to the
    # model itself, so no separate set_inference_fusion_params call is made.
    generate_3d_object(
        model=model,
        data=data,
        data_idx=data_idx,
        scale=scale,
        diffusion_rescale_timestep=diffusion_rescale_timestep,
        save_dir=save_dir,
        output_format=output_format,
        target_num_faces=target_num_faces,
        seed=seed,
    )


    

Loading model


/opt/miniconda/envs/wala/lib/python3.10/site-packages/pytorch_lightning/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.1.0


'DotDict' object has no attribute 'dataset_path'
'DotDict' object has no attribute 'low_avg'
'DotDict' object has no attribute 'low_avg'
Low avg used : None high value: 63


Using cache found in /home/ray/.cache/torch/hub/facebookresearch_dinov2_main


cond_emb_dim: 1024
Input resolution: 224
Vocab size: N/A
'DotDict' object has no attribute 'use_multiple_views_inferences'
'DotDict' object has no attribute 'use_multiple_views_inferences'
'DotDict' object has no attribute 'use_multiple_views_inferences'
'DotDict' object has no attribute 'use_multiple_views_inferences'


Seed set to 42


Processing image: examples/single_view/table.png
'DotDict' object has no attribute 'use_multiple_views_inferences'
'DotDict' object has no attribute 'use_multiple_views_inferences'


  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:

# Re-run generation for every configured image. This cell relies on `model`
# and `image_transform` created in the model-loading cell, so that cell must
# have been executed first.
data_idx = 0  # same index for every image, so it is set once outside the loop

for image_path in images:
    print(f"Processing image: {image_path}")
    data = get_singleview_data(
        image_file=Path(image_path),
        image_transform=image_transform,
        device=model.device,
        image_over_white=False,
    )
    # One output folder per input image, named after the image file stem.
    save_dir = Path(output_dir) / Path(image_path).stem

    # generate_3d_object applies (scale, diffusion_rescale_timestep) to the
    # model itself, so no separate set_inference_fusion_params call is made.
    generate_3d_object(
        model=model,
        data=data,
        data_idx=data_idx,
        scale=scale,
        diffusion_rescale_timestep=diffusion_rescale_timestep,
        save_dir=save_dir,
        output_format=output_format,
        target_num_faces=target_num_faces,
        seed=seed,
    )

Seed set to 42


Processing image: examples/single_view/table.png
'DotDict' object has no attribute 'use_multiple_views_inferences'
'DotDict' object has no attribute 'use_multiple_views_inferences'


  0%|          | 0/5 [00:00<?, ?it/s]