In [1]:
from pathlib import Path

from PIL import Image
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm
import nopdb

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util import plotting

In [2]:
# Hand PyTorch's cached-but-unused CUDA memory back to the driver; mainly
# useful when re-running the notebook in a kernel that already holds models.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 1. Build models

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Config names looked up in MODEL_CONFIGS / DIFFUSION_CONFIGS and passed to
# load_checkpoint. 'base1B' is used here; 'base300M' is presumably a smaller,
# faster alternative.
base_name = 'base1B'
upsampler_name = 'upsample'

print('creating base model...')
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS[upsampler_name], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[upsampler_name])

print('downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('downloading upsampler checkpoint...')
# Trailing ';' keeps the load_state_dict result repr out of the cell output.
upsampler_model.load_state_dict(load_checkpoint(upsampler_name, device));

creating base model...


  warn(


creating upsample model...
downloading base checkpoint...
downloading upsampler checkpoint...


<All keys matched successfully>

In [4]:
# Two-stage sampler: the base model first, then the upsampler.
# num_points is given per stage: 1024 points from the base stage, then
# 4096 - 1024 = 3072 more from the upsampler (presumably 4096 in total).
# Each stage uses a guidance scale of 3.0; points carry R, G, B aux channels.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    guidance_scale=[3.0] * 2,
    aux_channels=list('RGB'),
    num_points=[1024, 4096 - 1024],
)

In [5]:
# Conditioning image for the sampler. PIL opens it lazily: pixel data is read
# on first use.
img = Image.open(fp='example_data/cube_stack.jpg')

def sample_from_model(breakpoint, image=None, base_points=1024):
    """Run the sampler on one image and return a base-stage sample.

    Consumes ``sampler.sample_batch_progressive`` and keeps the most recent
    sample whose third dimension equals ``base_points`` (the base-model stage).
    Sampling stops at the first of:

    * the base-stage sample at 0-based step index ``breakpoint``; or
    * the first sample whose third dimension differs from ``base_points``
      (presumably the start of the upsampler stage). In that case the last
      base-stage sample seen is returned.

    A negative ``breakpoint`` (e.g. ``-1``) never matches a step index, so
    the base stage runs to completion.

    Args:
        breakpoint: 0-based base-stage step to stop at; negative for "no
            early stop". The name shadows the ``breakpoint()`` builtin inside
            this function but is kept for compatibility with existing callers.
        image: conditioning image; defaults to the module-level ``img``.
        base_points: point count of the base stage (the first entry of the
            sampler's ``num_points``).

    Returns:
        The captured sample tensor, or ``None`` if no base-stage sample was
        produced.
    """
    if image is None:
        image = img

    # To inspect attention, wrap the loop below in
    # nopdb.capture_call(base_model.backbone.resblocks[-1].attn.attention.forward).
    # It is not done by default: the captured call was never returned to the
    # caller and the debugger hook presumably only slowed sampling down.
    samples = None
    step = 0
    batches = sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[image]))
    # The context manager closes the progress bar even when we break early.
    with tqdm(batches) as progress:
        for x in progress:
            if x.shape[2] != base_points:
                break
            samples = x
            if step == breakpoint:
                break
            step += 1
    return samples

In [6]:
# Render the base-stage point cloud at selected diffusion steps and save one
# figure per step. A negative breakpoint means "run the base stage to
# completion".
VIZ_DIR = Path('Figures/viz')
SEED = 0
breakpoints = [30, -1]

# savefig raises if the target directory does not exist.
VIZ_DIR.mkdir(parents=True, exist_ok=True)

for k in breakpoints:
    # sample_from_model starts a fresh sampling run on every call. Re-seeding
    # first (assuming the sampler draws its noise from torch's global RNG)
    # makes all frames follow the same trajectory, up to GPU nondeterminism,
    # instead of showing unrelated samples.
    torch.manual_seed(SEED)
    samples = sample_from_model(k)
    if samples is None:
        raise RuntimeError(f'no base-stage sample was produced for breakpoint {k}')

    pc = sampler.output_to_point_clouds(samples)[0]
    plotting.plot_point_cloud(pc, grid_size=1, fixed_bounds=None, angle=0.5)
    plt.title('iterations = ' + (str(k) if k >= 0 else 'final'))
    plt.savefig(VIZ_DIR / f'fig{k}.png', bbox_inches='tight', dpi=300)
    plt.close()

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [7]:
# img = Image.open('example_data/cube_stack.jpg')

# samples = None
# for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
#     samples = x

In [8]:
# import importlib
# importlib.reload(plotting)

# pc = sampler.output_to_point_clouds(samples)[0]
# ax = plotting.plot_point_cloud(pc, grid_size=1, fixed_bounds=None, angle=1.5)
# plt.savefig('Figures/3dplots/final.png', bbox_inches='tight', dpi=300)
# plt.close()

: 