In [1]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.image_util import load_image

from PIL import Image
import torchvision.transforms as transforms
import os
import numpy as np

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
xm = load_model('transmitter', device=device)
model = load_model('image300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

In [8]:
# default values:4
batch_size = 1
# default values:3.0
guidance_scale = 3.0

# To get the best result, you should remove the background and show only the object of interest to the model.
# image = load_image("example_data/dra_1.png")
image_name = "sh9"

# Define your maximum allowed size
max_width = 800
max_height = 600

image = load_image(f"img/{image_name}.png")
# Check if the image is larger than the maximum dimensions
if image.width > max_width or image.height > max_height:
    # Resize the image, maintaining the aspect ratio
    image.thumbnail((max_width, max_height))

latents = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs=dict(images=[image] * batch_size),
    progress=True,
    clip_denoised=True,
    use_fp16=True,
    use_karras=True,
    karras_steps=64,
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)

  0%|          | 0/64 [00:00<?, ?it/s]

In [9]:
render_mode = 'nerf' # you can change this to 'stf' / 'nerf' (Neural Radiance Fields)
size = 128 # this is the size of the renders; higher values take longer to render.


to_pil = transforms.ToPILImage()
save_dir = 'output_gif'
os.makedirs(save_dir, exist_ok=True)

print('creating cameras...')
cameras = create_pan_cameras(size, device)


for i, latent in enumerate(latents):
    print(f'creating images...{len(latents)-i}left')

    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    print(f'images done, saving...{len(latents)-i}left')
    print('saving...')
    # save
    pil_images = []
    for j, img in enumerate(images):
        if isinstance(img, Image.Image):
            # If img is already a PIL Image, no need to convert
            pil_img = img
        elif hasattr(img, 'cpu') and callable(getattr(img, 'cpu')):
            # If it has a 'cpu' attribute, it might be a PyTorch tensor
            pil_img = Image.fromarray(img.cpu().detach().numpy().astype('uint8'))
        elif isinstance(img, np.ndarray):
            # If it's a NumPy array, convert directly to PIL
            pil_img = Image.fromarray(img)
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

        # Save the image
        pil_images.append(pil_img)
        print(f'{j} saved')
    
    gif_path = os.path.join(save_dir, f"{image_name}_{i}.gif")
    pil_images[0].save(
        gif_path,
        save_all=True,
        append_images=pil_images[1:],
        duration=100,  # Duration between frames in milliseconds (adjust as needed)
        loop=0  # Number of times the GIF should loop (0 means infinite)
    )

    display(gif_widget(images))

creating cameras...
creating images...1left
images done, saving...1left
saving...
0 saved
1 saved
2 saved
3 saved
4 saved
5 saved
6 saved
7 saved
8 saved
9 saved
10 saved
11 saved
12 saved
13 saved
14 saved
15 saved
16 saved
17 saved
18 saved
19 saved


HTML(value='<img src="data:image/gif;base64,R0lGODlhgACAAIcAAMHDwrW2t6mqqqCgqpSXpoySi4WGmnuEfX5+knZ8i3R3inBzhm…

In [6]:
# render_mode = 'nerf' # you can change this to 'stf' for mesh rendering
# size = 64 # this is the size of the renders; higher values take longer to render.

# cameras = create_pan_cameras(size, device)
# for i, latent in enumerate(latents):
#     images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
#     display(gif_widget(images))


# Example of saving the latents as meshes.
##############################################

# from shap_e.util.notebooks import decode_latent_mesh

# file_name = "mesh_dra_1"

# for i, latent in enumerate(latents):
#     t = decode_latent_mesh(xm, latent).tri_mesh()
#     with open(f'{file_name}_{i}.ply', 'wb') as f:
#         t.write_ply(f)
#     with open(f'{file_name}_{i}.obj', 'w') as f:
#         t.write_obj(f)

#################################################

In [7]:
###### loop folder########
# batch_size = 1
# guidance_scale = 3.0

# max_width = 800
# max_height = 600

# source_dir = 'img'
# save_dir = 'output_gif'
# os.makedirs(save_dir, exist_ok=True)

# # Loop through all files in the source directory
# for image_name in os.listdir(source_dir):
#     if image_name.endswith(('.png', '.jpg', '.jpeg')):  # Check if the file is an image
#         image_path = os.path.join(source_dir, image_name)
#         print(f"Processing {image_path}...")
        
#         image = load_image(image_path)
#         if image.width > max_width or image.height > max_height:
#             image.thumbnail((max_width, max_height))
        
#         latents = sample_latents(
#             batch_size=batch_size,
#             model=model,
#             diffusion=diffusion,
#             guidance_scale=guidance_scale,
#             model_kwargs=dict(images=[image] * batch_size),
#             progress=True,
#             clip_denoised=True,
#             use_fp16=True,
#             use_karras=True,
#             karras_steps=64,
#             sigma_min=1e-3,
#             sigma_max=160,
#             s_churn=0,
#         )

#         render_mode = 'nerf'
#         size = 128
#         cameras = create_pan_cameras(size, device)

#         pil_images = []
#         for i, latent in enumerate(latents):
#             images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
#             for img in images:
#                 if isinstance(img, Image.Image):
#                     pil_img = img
#                 elif hasattr(img, 'cpu') and callable(getattr(img, 'cpu')):
#                     pil_img = Image.fromarray(img.cpu().detach().numpy().astype('uint8'))
#                 elif isinstance(img, np.ndarray):
#                     pil_img = Image.fromarray(img)
#                 else:
#                     continue  # Skip if the image type is not supported
#                 pil_images.append(pil_img)

#         gif_filename = os.path.splitext(image_name)[0] + ".gif"
#         gif_path = os.path.join(save_dir, gif_filename)
#         pil_images[0].save(
#             gif_path,
#             save_all=True,
#             append_images=pil_images[1:],
#             duration=100,
#             loop=0
#         )
#         print(f"Generated {gif_path}")