In [None]:
# Note: to run this notebook you must install every dependency in requirements.txt, including the commented-out ones (check the file), or use the Docker image described in the README.

%load_ext autoreload
%autoreload 2
%matplotlib inline

from pathlib import Path

import torch

from blockgen.configs import VoxelConfig, DiffusionConfig
from blockgen.inference import DiffusionInference3D
from scripts.generate import load_model_for_inference

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
# The availability check is done once and reused for the GPU-name report.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    # Report the name of GPU index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load the configs used below (see results.ipynb for loading combined configs and for running inference with other models)

In [2]:
shape_voxel_config = VoxelConfig(
    mode='two_stage',
    stage='shape',
    default_color=[0.5, 0.5, 0.5],
    alpha_weight=1.0,
    rgb_weight=1.0
)

color_voxel_config = VoxelConfig(
    mode='two_stage',
    stage='color',
    default_color=[0.5, 0.5, 0.5],
    alpha_weight=1.0,
    rgb_weight=1.0
)

diffusion_config = DiffusionConfig(
    num_timesteps=1000,
    use_ema=True,
    ema_decay=0.9999,
    ema_update_after_step=0,
    ema_device=device,
    use_ddim=False,
    seed=42
)

# Load the models. They can also be loaded for DDIM sampling: set use_ddim=True and choose the number of steps (see the DDIM section of the ablation study in results.ipynb). Models can also be loaded in EMA mode.

In [None]:
# Training-run directory; each stage's best checkpoint lives at
# <RUN_DIR>/<stage>/best_model/model.
RUN_DIR = "runs/experiment_two_stage"

# Passed as `ema=` to both loaders. False loads the models without EMA mode;
# set True to load them in EMA mode (see results.ipynb).
# NOTE(review): diffusion_config was built with use_ema=True. How that setting
# interacts with ema=False is decided inside load_model_for_inference, which is
# not visible here; confirm this is the intended combination.
LOAD_EMA_WEIGHTS = False

shape_model = load_model_for_inference(
    model_path=f"{RUN_DIR}/shape/best_model/model",
    voxel_config=shape_voxel_config,
    diffusion_config=diffusion_config,
    device=device,
    ema=LOAD_EMA_WEIGHTS,
)

color_model = load_model_for_inference(
    model_path=f"{RUN_DIR}/color/best_model/model",
    voxel_config=color_voxel_config,
    diffusion_config=diffusion_config,
    device=device,
    ema=LOAD_EMA_WEIGHTS,
)

# Create the inference object, which exposes all the visualization and inference pipelines

In [None]:
# Each stage model is paired with its own noise scheduler. The resulting object
# is used below both for two-stage sampling and for visualization.
shape_scheduler = shape_model.noise_scheduler
color_scheduler = color_model.noise_scheduler

inferencer = DiffusionInference3D(
    model=shape_model,
    noise_scheduler=shape_scheduler,
    color_model=color_model,
    color_noise_scheduler=color_scheduler,
    device=device,
)

# See the report, README, or code for all available inference options

In [None]:
# Options for two-stage sampling: one sample of the prompt on a 32x32x32 grid.
# guidance_scale / color_guidance_scale presumably set the guidance strength of
# the shape and color stages respectively; see the report/README to confirm.
two_stage_kwargs = {
    'prompt': "A tree",
    'num_samples': 1,
    'image_size': (32, 32, 32),
    'guidance_scale': 20.0,
    'color_guidance_scale': 20.0,
    'show_after_shape': True,
    'use_rotations': False,
    'use_mean_init': False,
}
samples = inferencer.sample_two_stage(**two_stage_kwargs)

In [None]:
# Render the samples and save the figure. The parent directory is created first
# because it may not exist yet. Whether visualize_samples creates it itself is
# not visible from this notebook, and a missing directory would make the save fail.
TREE_FIGURE_PATH = "output/tree.png"
Path(TREE_FIGURE_PATH).parent.mkdir(parents=True, exist_ok=True)
inferencer.visualize_samples(samples, prompt="A tree", threshold=0.5, save_path=TREE_FIGURE_PATH)