<a href="https://colab.research.google.com/github/MarkoMile/shap-e-minecraft/blob/main/working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !git clone https://github.com/MarkoMile/shap-e-minecraft

In [None]:
# %cd shap-e-minecraft
# !pip install -e .

In [None]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the pretrained pieces used by the rest of the notebook onto `device`.
# `xm` is handed to voxelGrid() to decode a latent into field parameters.
xm = load_model('transmitter', device=device)
# `model` and `diffusion` are both passed to sample_latents() further down.
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

In [None]:
# Imports and a helper for sampling the renderer's bounding volume on a regular grid.
from shap_e.util.notebooks import decode_latent_mesh
from shap_e.rendering.mc import marching_cubes
from shap_e.models.query import Query
from shap_e.models.renderer import Renderer, get_camera_from_batch
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.util.collections import AttrDict
from functools import partial
from shap_e.models.nn.meta import subdict

def volume_query_points(
    volume: Volume,
    grid_size: int,
):
    """Return the points of a regular ``grid_size``-cubed lattice spanning a bounding box.

    :param volume: must be a ``BoundingBoxVolume``; its ``bbox_min`` and
                   ``bbox_max`` give the two opposite corners of the lattice.
    :param grid_size: number of lattice points along each axis.
    :return: a ``[grid_size**3, 3]`` tensor of (x, y, z) positions, ordered so
             that z varies fastest and x slowest. The first point is
             ``bbox_min`` and the last is ``bbox_max``.
    """
    assert isinstance(volume, BoundingBoxVolume)
    lower, upper = volume.bbox_min, volume.bbox_max

    flat = torch.arange(grid_size**3, device=lower.device)
    # Decompose each flat index into its (x, y, z) lattice coordinate; the
    # strides make x the most significant digit and z the least significant.
    strides = (grid_size**2, grid_size, 1)
    lattice = torch.stack(
        [torch.div(flat, stride, rounding_mode="trunc") % grid_size for stride in strides],
        dim=1,
    )

    # Normalise integer coordinates to [0, 1], then stretch onto the box.
    unit = lattice.float() / (grid_size - 1)
    return unit * (upper - lower) + lower

In [None]:
from typing import Union
from shap_e.models.transmitter.base import Transmitter, VectorDecoder

@torch.no_grad()
def voxelGrid(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
):
    """Sample the implicit field decoded from one latent on a regular grid.

    The signed-distance field is evaluated on the renderer's
    ``grid_size``-cubed lattice, average-pooled with a 3x3x3 window to a
    coarser grid, and the colour channels are evaluated directly on a lattice
    of that coarser resolution.

    :param xm: transmitter (or vector decoder) whose ``renderer`` provides the
               ``volume``, ``grid_size`` and the field models.
    :param latent: a single latent; it is given a leading batch axis of 1
                   before being converted to parameters.
    :return: an ``AttrDict`` with
             ``fields`` - signed distances reshaped to ``[1, G, G, G]``;
             ``downsampled_fields`` - ``fields`` average-pooled by 3 per axis
             and rounded;
             ``texture_grid`` - channels reshaped to ``[G//3, G//3, G//3, 3]``;
             ``raw_signed_distance`` - the signed-distance output before
             casting and reshaping;
             ``volume`` - the renderer's volume.
    """
    # Exactly one latent is decoded per call (see `latent[None]` below), so the
    # meta-batch is always 1. It must not be read from a notebook-level global.
    batch_size = 1
    query_batch_size = 4096
    # Window/stride of the average pool; the texture lattice uses the same
    # factor so both coarse grids have matching shapes.
    pool_factor = 3

    decoder = xm.encoder if isinstance(xm, Transmitter) else xm
    parameters = decoder.bottleneck_to_params(latent[None])
    options = AttrDict(rendering_mode="stf", render_with_direction=False)

    renderer = xm.renderer
    volume = renderer.volume
    grid_size = renderer.grid_size

    if renderer.nerstf is not None:
        # A joint model answers both the distance and the colour queries.
        sdf_fn = tf_fn = partial(
            renderer.nerstf.forward_batched,
            params=subdict(parameters, "nerstf"),
            options=options,
        )
    else:
        # Separate distance / texture models.
        # NOTE(review): assumes their outputs expose `.signed_distance` and
        # `.channels` respectively, as the joint model's output does - confirm.
        sdf_fn = partial(
            renderer.sdf.forward_batched,
            params=subdict(parameters, "sdf"),
            options=options,
        )
        tf_fn = partial(
            renderer.tf.forward_batched,
            params=subdict(parameters, "tf"),
            options=options,
        )

    def _query(field_fn, points):
        # Evaluate a field model at `points`, adding the meta-batch axis.
        return field_fn(
            query=Query(position=points[None].repeat(batch_size, 1, 1)),
            query_batch_size=query_batch_size,
        )

    # Signed distances on the full-resolution lattice.
    sdf_out = _query(sdf_fn, volume_query_points(volume, grid_size))
    raw_signed_distance = sdf_out.signed_distance
    fields = raw_signed_distance.float()
    assert (
        len(fields.shape) == 3 and fields.shape[-1] == 1
    ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
    fields = fields.reshape(batch_size, *([grid_size] * 3))

    # Downsample the field so low-resolution grids reflect the average of the
    # finer samples instead of a single point sample.
    pool = torch.nn.AvgPool3d(pool_factor, stride=pool_factor)
    downsampled_fields = pool(fields).round()

    # Colours are sampled directly on a lattice of the coarse resolution.
    coarse_size = grid_size // pool_factor
    tf_out = _query(tf_fn, volume_query_points(volume, coarse_size))
    texture_grid = tf_out.channels.float().reshape(coarse_size, coarse_size, coarse_size, 3)

    return AttrDict(
        fields=fields,
        downsampled_fields=downsampled_fields,
        texture_grid=texture_grid,
        raw_signed_distance=raw_signed_distance,
        volume=volume,
    )

In [None]:
def marching_cubes_voxel(
    field: torch.Tensor,
    min_point: torch.Tensor,
    size: torch.Tensor,
):
    """Threshold a 3D scalar field into an occupancy mask plus index helpers.

    Despite the name, no mesh is extracted here: the field is only compared
    against zero, one value per grid point.

    :param field: a 3D tensor of scalar field values.
    :param min_point: not read; kept so the call signature is unchanged.
    :param size: not read; kept so the call signature is unchanged.
    :return: an ``AttrDict`` with
             ``bitmasks`` - uint8 tensor shaped like ``field``, 1 where
             ``field >= 0`` and 0 elsewhere;
             ``flat_bitmasks`` - ``bitmasks`` flattened and cast to long;
             ``used_indices`` - ``[N, 3]`` integer indices of the non-zero
             entries of ``bitmasks``;
             ``corner_coords`` - ``[X, Y, Z, 3]`` tensor holding each grid
             point's own (x, y, z) index, in ``field``'s dtype;
             ``flat_cube_indices`` - ``[(X-1)*(Y-1)*(Z-1), 3]`` long tensor of
             (x, y, z) indices, one row per cell between grid points.
    """
    assert len(field.shape) == 3, "input must be a 3D scalar field"
    dev = field.device
    grid_size = tuple(field.shape)

    # One flag per grid point: non-negative field values count as occupied.
    bitmasks = (field >= 0).to(torch.uint8)

    # There is one fewer cell than grid points along every axis.
    cell_counts = [n - 1 for n in grid_size]
    corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
    cube_indices = torch.zeros(*cell_counts, 3, device=dev, dtype=torch.long)
    for axis in range(3):
        # Shape a 1-D ramp so it broadcasts along `axis` only. Explicit sizes
        # (rather than -1) keep this valid when an axis has zero cells.
        point_shape = [1, 1, 1]
        point_shape[axis] = grid_size[axis]
        corner_coords[..., axis] = torch.arange(
            grid_size[axis], device=dev, dtype=field.dtype
        ).view(point_shape)

        cell_shape = [1, 1, 1]
        cell_shape[axis] = cell_counts[axis]
        cube_indices[..., axis] = torch.arange(cell_counts[axis], device=dev).view(cell_shape)

    flat_cube_indices = cube_indices.reshape(-1, 3)

    # Cast to long so that using this tensor as an index is treated as
    # integer indexing rather than as a boolean/byte mask.
    flat_bitmasks = bitmasks.reshape(-1).long()

    used_indices = torch.nonzero(bitmasks)

    return AttrDict(
        bitmasks=bitmasks,
        flat_bitmasks=flat_bitmasks,
        used_indices=used_indices,
        corner_coords=corner_coords,
        flat_cube_indices=flat_cube_indices,
    )


In [None]:
# Sampling settings.
batch_size = 1  # how many latents to sample for the prompt; larger values take longer to generate.
guidance_scale = 10.0  # guidance strength; higher values make the result follow the prompt more closely.
prompt = "a rubber duck"  # text description of the object; change this to anything you want.

# Generate the latents: every sample in the batch is conditioned on the same prompt.
latents = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs={"texts": [prompt for _ in range(batch_size)]},
    progress=True,
    clip_denoised=True,
    use_fp16=True,
    use_karras=True,
    karras_steps=64,
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)

In [None]:
# Render the latents as images. (gif)
# (disabled by default, uncomment to enable)

# render_mode = 'nerf' # you can change this to 'stf'
# size = 64 # this is the size of the renders, higher values take longer to render.

# cameras = create_pan_cameras(size, device)
# for i, latent in enumerate(latents):
#     images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
#     display(gif_widget(images))

In [None]:
# Edge length, in voxels, of the final grid; any value works, larger ones take longer to evaluate.
final_grid_size = 24

# Sample the field three times finer along each axis; voxelGrid average-pools
# it back down with a window of 3, which gives better results at low resolutions.
xm.renderer.grid_size = 3 * final_grid_size

# `t` and `mc` are rebound on every pass, so after the loop they hold the
# results for the last latent only.
for i, latent in enumerate(latents):
    t = voxelGrid(xm, latent)
    mc = marching_cubes_voxel(
        t.downsampled_fields[0],
        t.volume.bbox_min,
        t.volume.bbox_max - t.volume.bbox_min,
    )


In [None]:
# Plotting the voxel grid to visualize the model.

from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

# Create a figure with a single 3D axes to draw the voxels on.
ax = plt.figure().add_subplot(projection='3d')
# Draw one cube per non-zero entry of the occupancy mask, coloured from the texture grid.
# NOTE(review): assumes texture_grid holds RGB values in [0, 1] and shares the
# mask's three leading dimensions - confirm.
ax.voxels(mc.bitmasks.cpu().numpy(),facecolors=t.texture_grid.cpu().numpy())
# Use the same scale on all three axes so the cubes are not stretched.
ax.set_aspect('equal')


plt.show()


In [None]:
import os

# Prepare the save directory: one folder per prompt, with spaces replaced by
# underscores. torch.save does not create missing parent directories, so the
# folder has to exist before anything is written into it.
prompt_dirname = prompt.replace(" ", "_")
os.makedirs(prompt_dirname, exist_ok=True)

# Save colour and coordinate data for use in the Minecraft script.
# Tensors are moved to the CPU first so the files can be loaded on a machine
# without a GPU, without needing `map_location`.
torch.save(mc.used_indices.cpu(), f'{prompt_dirname}/coords.pt')
torch.save(t.texture_grid.cpu(), f'{prompt_dirname}/colors.pt')