In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Full imports
import torch
from blockgen.configs import VoxelConfig, DiffusionConfig
from blockgen.utils import create_model_and_trainer, create_dataloaders
from blockgen.models import DiffusionModel3D
from blockgen.inference import DiffusionInference3D
from diffusers import UNet3DConditionModel

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. Later cells reuse this `device` object.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# A GPU name can only be reported when the CUDA device was selected above.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cpu


In [None]:
# Configs
# Voxel settings. This same object is later handed to the model/trainer
# factory, the dataloaders and the inferencer.
voxel_config = VoxelConfig(
    use_rgb=False,
    default_color=[0.5, 0.5, 0.5],  # presumably the fallback colour while use_rgb is off — TODO confirm
    alpha_weight=1.0,
    rgb_weight=1.0
)

# Diffusion/training settings. This object is also passed to the dataloaders.
diffusion_config = DiffusionConfig(
    num_timesteps=1000,
    use_ema=True,
    ema_decay=0.9999,
    ema_update_after_step=0,
    ema_device=device,  # the device chosen in the device-selection cell
    seed=42
)

In [5]:
# Create model and trainer
# Both objects are built from the shared configs; note the factory's return
# order is (trainer, model).
# NOTE(review): resolution=32 presumably has to agree with the
# image_size=(32, 32, 32) requested when sampling — confirm before changing it.
trainer, model = create_model_and_trainer(
    voxel_config=voxel_config,
    diffusion_config=diffusion_config,
    resolution=32,
    device=device
)

Total parameters: 120096449


In [None]:
# Build the train and test loaders from the voxel directory and the
# annotation file, using the same two config objects as the model.
# Both paths are relative — presumably resolved against the kernel's working
# directory, so run the notebook from the directory that contains them.
train_loader, test_loader = create_dataloaders(
    voxel_dir="objaverse_data_voxelized",
    annotation_file="objaverse_data/annotations.json",
    config=diffusion_config,
    config_voxel=voxel_config,
    batch_size=2
)

In [None]:
# Run training. The returned `metrics` is indexed by the plotting cell with
# the keys 'training_losses', 'test_steps' and 'test_losses'.
# NOTE(review): total_steps (95_000) is not a multiple of eval_every (10_000);
# unless the trainer also evaluates at the very end, the last evaluation would
# land at step 90_000 — confirm this is intended.
metrics = trainer.train(
    train_loader,
    test_loader,
    total_steps=95_000,
    save_every=5_000,
    eval_every=10_000,
    save_dir='runs/experiment_4'
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side loss curves: training loss on the left, test loss on the right.
# The explicit Figure/Axes objects are used instead of pyplot's implicit
# "current axes", so every call below names the panel it draws on.
fig, (train_ax, test_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Only y-values are given, so the x-axis is each value's position in the sequence.
train_ax.plot(metrics['training_losses'])
train_ax.set_title('Training Loss')
train_ax.set_xlabel('Step')
train_ax.set_ylabel('Loss')

# x-values come from metrics['test_steps'], y-values from metrics['test_losses'].
test_ax.plot(metrics['test_steps'], metrics['test_losses'])
test_ax.set_title('Test Loss')
test_ax.set_xlabel('Step')
test_ax.set_ylabel('Loss')

fig.tight_layout()
plt.show()

In [None]:
# Inference
# Wrap the model for sampling, reusing the noise scheduler attached to the
# model and the same voxel config and device used to build it.
# NOTE(review): diffusion_config sets use_ema=True, but the raw `model` is
# passed here — confirm whether sampling is meant to use the EMA weights.
inferencer = DiffusionInference3D(
    model=model,
    noise_scheduler=model.noise_scheduler,
    config=voxel_config,
    device=device
)

In [None]:
# Request two samples for the prompt "Stone" with image_size 32x32x32 and
# intermediate display switched off.
# NOTE(review): image_size presumably has to match the resolution=32 used when
# the model was created — confirm before changing either value.
samples = inferencer.sample(
    prompt="Stone",
    num_samples=2,
    image_size=(32, 32, 32),
    show_intermediate=False,
    guidance_scale=7
)

In [None]:
# Visualize the `samples` returned by inferencer.sample(...).
inferencer.visualize_samples(samples)