In [1]:
import torch
import platform

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.image_util import load_image

In [2]:
def check_device() -> torch.device:
    """Choose the torch device for this notebook and print accelerator diagnostics.

    On macOS (``platform.system() == 'Darwin'``) the Metal backend (``mps``) is
    used when ``torch.backends.mps.is_available()`` reports it; on every other
    platform CUDA is used when ``torch.cuda.is_available()`` reports it. If the
    preferred accelerator is unavailable the CPU is returned instead, so the
    notebook can still run (slowly) on machines without a GPU.

    Returns:
        torch.device: ``mps``, ``cuda`` or ``cpu``.
    """
    if platform.system() == 'Darwin':
        mps_available = torch.backends.mps.is_available()
        print(f"Torch MPS Available: {mps_available}")
        print(f"Torch MPS Built: {torch.backends.mps.is_built()}")
        # Intel Macs / older torch builds have no usable MPS backend.
        return torch.device("mps" if mps_available else "cpu")

    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    print(f"CUDA Devices: {torch.cuda.device_count()}")
    # current_device() initialises CUDA and raises when no CUDA device exists,
    # so only query it when CUDA is actually usable.
    if cuda_available:
        print(f"Current CUDA Index: {torch.cuda.current_device()}")

    return torch.device("cuda" if cuda_available else "cpu")

In [3]:
# Resolve the compute device once; the model-loading and rendering cells below
# all receive this same `device`.
device = check_device()

NameError: name 'platform' is not defined

In [None]:
# `xm` ('transmitter') is later passed to decode_latent_images to render latents;
# `model` ('image300M') is the latent sampler conditioned on images via
# model_kwargs. Both are placed on `device`, loaded in this order.
xm, model = (load_model(name, device=device) for name in ('transmitter', 'image300M'))

# Diffusion process built from the bundled 'diffusion' config.
diffusion = diffusion_from_config(load_config('diffusion'))

In [None]:
batch_size = 4
guidance_scale = 3.0

# For the best result, strip the background from the input so the model sees
# only the object of interest.
image = load_image("example_data/corgi.png")

latents = sample_latents(
    # Batch / conditioning: the same image is repeated once per sample.
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs={"images": [image] * batch_size},
    progress=True,
    clip_denoised=True,
    # NOTE(review): fp16 is requested regardless of `device`; confirm it behaves
    # as intended when check_device() picked MPS or CPU.
    use_fp16=True,
    # Karras sampling schedule: 64 steps, sigma in [1e-3, 160], no churn.
    use_karras=True,
    karras_steps=64,
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)

In [None]:
render_mode = 'nerf'  # change this to 'stf' for mesh rendering
size = 64  # side length of each render; higher values take longer to render

# The camera path depends only on `size` and `device`, so it is built once and
# reused for every sampled latent.
cameras = create_pan_cameras(size, device)

# Decode each latent into rendered frames and show them as an animated widget.
# The loop index was never used, so iterate over the latents directly.
for latent in latents:
    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    display(gif_widget(images))