# Image to 3D

## Transformers Implementation

## Render to Mesh

In [None]:
import numpy as np

def depth_edges_mask(depth, threshold=0.01):
    """Return a boolean mask marking strong discontinuities in a depth map.

    A pixel is flagged when the magnitude of the finite-difference depth
    gradient (as computed by ``np.gradient``) at that pixel exceeds
    ``threshold``.

    Args:
        depth: 2D numpy array of shape (H, W), e.g. dtype float32. Each axis
            needs at least two samples for ``np.gradient``.
        threshold: gradient-magnitude cutoff; pixels whose magnitude is
            strictly greater are marked as edges. Defaults to 0.01.

    Returns:
        mask: 2D numpy array of shape (H, W) with dtype bool.
    """
    # np.gradient returns one array per axis, in axis order: axis 0 (rows, i.e.
    # the y direction) first, then axis 1 (columns, the x direction).
    depth_dy, depth_dx = np.gradient(depth)
    # Euclidean gradient magnitude.
    depth_grad = np.sqrt(depth_dy**2 + depth_dx**2)
    return depth_grad > threshold

In [None]:
import os
import json
import torch
import trimesh
import numpy as np
# import open3d as o3d
from tqdm import tqdm 
from PIL import Image
from typing import Callable, List, Optional, Tuple
from pytorch3d.io import load_objs_as_meshes, load_obj

import pytorch3d
from pytorch3d.structures import Meshes, Pointclouds
import pytorch3d.utils
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    PointLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    HardPhongShader,
    TexturesUV,
    TexturesVertex,
    Textures
)
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene

In [None]:
import torch
from pytorch3d import renderer
from pytorch3d.vis.plotly_vis import plot_scene, AxisArgs
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

def get_cameras_looking_at_points(camera_locations, image_size, look_at_points=None, focal_length=0.7):
    """Create screen-space perspective cameras placed at ``camera_locations``.

    Every camera is oriented with ``look_at_view_transform`` so that it looks
    at the corresponding row of ``look_at_points``.

    Args:
        camera_locations: (N, 3) tensor of camera positions. Its device is used
            for the cameras and for the default focal-length tensor.
        image_size: 1D tensor with two entries. Callers in this notebook pass
            ``torch.tensor(image.size)``, i.e. PIL's (width, height). It is
            halved as-is for the principal point and flipped for the cameras'
            ``image_size`` argument.
        look_at_points: optional (N, 3) tensor of target points. Defaults to
            the origin for every camera.
        focal_length: Python scalar replicated for all N cameras, or a tensor
            passed through unchanged. Because ``in_ndc=False`` it is expressed
            in screen (pixel) units.

    Returns:
        Tuple ``(main_camera, other_cameras)``: camera 0, and cameras 1..N-1
        joined into a single batch (``None`` when N == 1).
    """
    # Local import: a later cell rebinds the global name `renderer` to a
    # MeshRenderer. Going through `renderer.cameras.PerspectiveCameras` would
    # then fail if this function is called again after that cell has run.
    from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform

    # Extract device from input `camera_locations` tensor
    device = camera_locations.device
    number_of_cameras = camera_locations.shape[0]

    # `look_at_points` defaults to the origin
    if look_at_points is None:
        look_at_points = torch.zeros_like(camera_locations)

    if not torch.is_tensor(focal_length):
        focal_length = torch.tensor([focal_length] * number_of_cameras).to(dtype=torch.float32, device=device)

    # Per-camera rotation and translation from the eye/at pairs.
    R, T = look_at_view_transform(at=look_at_points, eye=camera_locations)

    # Repeat the (2,) size vector once per camera -> (N, 2).
    image_size = image_size.unsqueeze(0).expand(number_of_cameras, -1)

    cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=image_size / 2,
        R=R,
        T=T,
        in_ndc=False,
        image_size=image_size.flip(-1),
        device=device,
    )

    main_camera = cameras[0]
    if number_of_cameras > 1:
        other_cameras = join_cameras_as_batch([cameras[i] for i in range(1, number_of_cameras)])
    else:
        other_cameras = None
    return main_camera, other_cameras

In [None]:
# Scratch check. It relies on `image`, which is only created by a later cell
# (the one that downloads the example picture), so it must be run out of order.
# It shows the (3, 2) size tensor that get_cameras_looking_at_points builds,
# flipped from PIL's (width, height) order to (height, width).
imsz = torch.tensor(image.size)
imsz.unsqueeze(0).expand(3, -1).flip(-1)

In [None]:
def load_mesh(obj_file_path, device="cpu"):
    """Load a single OBJ mesh, colour it solid red, and normalize its vertices.

    The vertices are centred on their mean and divided by the largest absolute
    coordinate, so the mesh fits inside the [-1, 1] cube.

    Args:
        obj_file_path: path to the .obj file.
        device: torch device to load the mesh onto.

    Returns:
        (mesh, verts, faces): the updated pytorch3d ``Meshes`` object and the
        vertex / face tensors of its first (only) mesh.
    """
    mesh = load_objs_as_meshes([obj_file_path], device=device)
    verts, faces = mesh.get_mesh_verts_faces(0)

    # Per-vertex colour: red, by keeping R=1 and zeroing G and B.
    texture_rgb = torch.ones_like(verts)
    texture_rgb[:, 1:] = 0.0
    # TexturesVertex is what the deprecated `Textures(verts_rgb=...)` wrapper
    # builds internally.
    mesh.textures = TexturesVertex(verts_features=texture_rgb[None])

    # Normalize: centre on the mean, then scale by the max *absolute*
    # coordinate. Dividing by verts.max() (the largest signed value) can leave
    # coordinates below -1 when the mesh extends further in the negative
    # direction.
    verts = verts - verts.mean(dim=0)
    verts = verts / verts.abs().max()

    # This updates the pytorch3d mesh with the new vertex coordinates.
    mesh = mesh.update_padded(verts.unsqueeze(0))
    verts, faces = mesh.get_mesh_verts_faces(0)

    return mesh, verts, faces

In [None]:
def generate_faces_from_grid(H, W, device):
    # Source (modified): https://huggingface.co/spaces/shariqfarooq/ZoeDepth/blob/main/gradio_im_to_3d.py
    """Create triangle indices for an H x W grid of vertices.

    Vertices are numbered row-major (``index = row * W + col``), which is the
    order produced by flattening an (H, W) image. Each grid cell contributes
    two triangles::

        00---01
        |  / |
        10---11

    The indices are fixed, so this need not be differentiable.

    Args:
        H: (int) height of the grid (number of rows).
        W: (int) width of the grid (number of columns).
        device: torch device on which the index tensor is created.

    Returns:
        int64 tensor of shape (2*(H-1)*(W-1), 3). The first (H-1)*(W-1) rows
        are the upper-left triangles (00, 10, 01) of each cell in row-major
        cell order. The remaining rows are the lower-right triangles
        (10, 11, 01). Both triangles use the same winding, so their face
        normals point to the same side of the grid.
    """
    # Build on the requested device, not on the device of an unrelated
    # global tensor.
    vertex_ids = torch.arange(H * W, device=device).reshape(H, W)
    vertex_00 = vertex_ids[:H - 1, :W - 1]
    vertex_01 = vertex_00 + 1
    vertex_10 = vertex_00 + W
    vertex_11 = vertex_00 + W + 1

    upper_left = torch.stack([vertex_00, vertex_10, vertex_01], dim=-1).reshape(-1, 3)
    lower_right = torch.stack([vertex_10, vertex_11, vertex_01], dim=-1).reshape(-1, 3)
    return torch.cat([upper_left, lower_right], dim=0)


def edge_threshold_filter(vertices, faces, edge_threshold=0.1, test=1):
    """
    Split faces by whether all of their edges are shorter than edge_threshold.
    Removes stretch artifacts that are caused by inconsistent depth at object borders.

    :param vertices: (N, 3) torch.Tensor of type torch.float32
    :param faces: (M, 3) torch.Tensor of type torch.long
    :param edge_threshold: maximum length per edge (otherwise the face is removed).
    :param test: edge-length metric. ``1`` (the default, also ``True``) measures
        each edge by the absolute difference of its endpoints' third (z)
        coordinate only. Any other value uses the full Euclidean edge length.

    :return: ``(kept_faces, removed_faces)``: faces whose edges are all
        ``< edge_threshold``, and faces with at least one edge ``>= edge_threshold``.
    """
    face_vertices = vertices[faces]  # (M, 3, 3): the three corners of each face
    # Pair each corner with the next one (cyclically) to get the three edges.
    edge_vectors = face_vertices - face_vertices.roll(shifts=-1, dims=1)

    if test == 1:
        # Depth-only metric. abs() is required: with a signed difference, a
        # large jump in the negative direction would pass the `<` check.
        edge_distances = edge_vectors[:, :, 2].abs()
    else:
        edge_distances = torch.linalg.vector_norm(edge_vectors, dim=2)

    mask_small_edge = (edge_distances < edge_threshold).all(1)
    mask_large_edge = (edge_distances >= edge_threshold).any(1)

    return faces[mask_small_edge, :], faces[mask_large_edge, :]

@torch.inference_mode()
def predict_depth(image, image_processor, model):
    """Predict a depth map for a PIL image, using horizontal-flip test-time augmentation.

    The model runs on the processed image and on a horizontally flipped copy.
    Both outputs go to ``post_process_depth_estimation_zoedepth`` (imported in
    a later cell from the local ``zoedepth_post_processing`` module), with the
    input's (height, width) as the target size.

    Args:
        image: PIL image. ``image.size`` is (width, height); it is reversed to
            (height, width) for the target size.
        image_processor: Hugging Face image processor for the depth model.
        model: depth-estimation model whose forward accepts ``pixel_values``.

    Returns:
        The ``"predicted_depth"`` entry of the first (only) post-processed result.
    """
    inputs = image_processor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    # Flip dim 3, presumably the width axis of the (batch, channels, height,
    # width) `pixel_values` tensor.
    flipped_pixels = torch.flip(inputs.pixel_values, dims=[3])
    outputs_flip = model(pixel_values=flipped_pixels)

    target_size = image.size[::-1]
    results = post_process_depth_estimation_zoedepth(outputs, [target_size], outputs_flip=outputs_flip)
    return results[0]["predicted_depth"]

In [None]:
# Scratch check: shape of a two-image batch of `pixel_values` from the image
# processor. Relies on `image_processor` and `image`, which later cells define.
image_processor(images=[image, image], return_tensors="pt")["pixel_values"].shape

In [None]:
# num_points = {}
# for test in [True, False]:
#     num_points[test] = {}
#     for focal_length in np.arange(0.1, 1.8, 0.1):
#         num_points[test][focal_length.item()] = {}
#         for threshold in np.arange(0.001, 0.04, 0.001):
#             cameras = get_cameras_looking_at_points(torch.tensor(
#                 [
#                     0.0,
#                     0.0,
#                     predicted_depth[predicted_depth.shape[0]//2, predicted_depth.shape[1]//2]
#                 ]).unsqueeze(0), focal_length=focal_length.item())[0]
#             world_points = cameras.unproject_points(xyz)
#             num_points[test][focal_length.item()][threshold.item()] = edge_threshold_filter(
#                 world_points, faces_new, threshold, test=test
#             )[0].shape[0]

# import plotly.graph_objects as go
# import numpy as np
# from plotly.subplots import make_subplots

# # Create a figure object
# fig = make_subplots(rows=1, cols=2)

# # Add a surface plot for each combination of test, focal_length, and threshold
# for test in [True]:#, False]:
#     for threshold in np.arange(0.001, 0.04, 0.001):
#         num_points_val = [
#             num_points[True][focal_length.item()][threshold.item()]
#             for focal_length in np.arange(0.1, 1.8, 0.1)
#         ]
#         fig.add_trace(
#             go.Scatter(
#                 x=np.arange(0.1, 1.8, 0.1), y=num_points_val,
#                 mode='lines', name=f"Test=True, Threshold={threshold:.4f}"
#             ),
#             row=1, col=1
#         )
#         num_points_val = [
#             num_points[False][focal_length.item()][threshold.item()]
#             for focal_length in np.arange(0.1, 1.8, 0.1)
#         ]
#         fig.add_trace(
#             go.Scatter(
#                 x=np.arange(0.1, 1.8, 0.1), y=num_points_val,
#                 mode='lines', name=f"Test=False, Threshold={threshold:.4f}"
#             ),
#             row=1, col=2
#         )

# # Set the layout and show the plot
# # fig.update_layout(title='Number of Points',
# #                    scene=dict(xaxis_title='Test',
# #                                yaxis_title='Focal Length',
# #                                zaxis_title='Threshold'))
# fig.show()

In [None]:
def get_pixel_coordinates_pt3d(
    height: int,
    width: int,
    device: torch.device = torch.device('cpu')
):
    """Return integer (x, y) pixel coordinates for every pixel of a (height, width) image.

    The coordinates are NOT normalized. Both axes run opposite to array
    indexing:

    * x is ``width - 1`` in the leftmost column and ``0`` in the rightmost;
    * y is ``height - 1`` in the top row and ``0`` in the bottom row.

    NOTE(review): the reversal was presumably chosen to match PyTorch3D's
    screen-space convention for ``unproject_points`` — confirm.

    Args:
        height: number of pixel rows.
        width: number of pixel columns.
        device: device of the returned tensor.

    Returns:
        xy_pix: int64 tensor of shape (height, width, 2) with
        ``xy_pix[r, c] == (width - 1 - c, height - 1 - r)``.
    """
    xs = torch.arange(width - 1, -1, -1, device=device)   # descending x
    ys = torch.arange(height - 1, -1, -1, device=device)  # descending y
    # 'xy' indexing gives (height, width) grids: x varies along columns and y
    # along rows.
    x, y = torch.meshgrid(xs, ys, indexing='xy')

    return torch.stack([x, y], dim=2)

In [None]:
from PIL import Image
import requests

import torch
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation

from zoedepth_post_processing import post_process_depth_estimation_zoedepth

image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu")

# prepare image for the model
# Only the second `url` assignment takes effect; the first (person example) is
# immediately overwritten and kept only as an alternative input.
url = "https://shariqfarooq-zoedepth.hf.space/file=/home/user/app/examples/person_1.jpeg"
url = "https://shariqfarooq-zoedepth.hf.space/file=/home/user/app/examples/mountains.jpeg"
# NOTE(review): no timeout and no HTTP status check on this download.
image = Image.open(requests.get(url, stream=True).raw)
# Shrink in place so neither side exceeds 512 px; the aspect ratio is kept.
image.thumbnail((512,512))

In [None]:
image

In [None]:
predicted_depth = predict_depth(image, image_processor, model)

In [None]:
# Preview the depth map as 8-bit grayscale, scaled by its maximum value.
# NOTE(review): `.astype` exists on numpy arrays, not torch tensors. If
# `predicted_depth` is a tensor it needs `.numpy()` first — confirm.
Image.fromarray(((predicted_depth/predicted_depth.max())*255).astype("uint8"))

In [None]:
def sph2cart(az, el, r):
    """Convert spherical coordinates to Cartesian ones.

    ``az`` (azimuth) and ``el`` (elevation above the x-y plane) are in degrees,
    and ``r`` is the radius. Scalars or array-likes are accepted. Returns the
    tuple ``(x, y, z)``, with each component rounded to 5 decimals.

    Adapted from https://github.com/numpy/numpy/issues/5228#issue-46746558
    """
    az_rad = np.radians(az)
    el_rad = np.radians(el)
    rho = r * np.cos(el_rad)  # length of the radius projected onto the x-y plane
    components = (rho * np.cos(az_rad), rho * np.sin(az_rad), r * np.sin(el_rad))
    return tuple(np.round(component, 5) for component in components)

In [None]:
# Preview: 8 points on a ring of radius 3 at 80 degrees elevation
# (azimuths 0, 45, ..., 315), as an (8, 3) tensor.
torch.tensor(np.stack(sph2cart([45*i for i in range(8)], [80]*8, [3]*8)).T)

In [None]:
# Reference distance for the camera ring: the midpoint of the predicted depth range.
# z_0 = predicted_depth[predicted_depth.shape[0]//2, predicted_depth.shape[1]//2]
z_0 = (predicted_depth.max() + predicted_depth.min())/2
# main_camera, other_cameras = get_cameras_looking_at_points(
#     camera_locations = torch.tensor([[0, 0, z_0]] + [
#         [x, y, z_0]
#         for x in [-z_0, 0, z_0]
#         for y in [-z_0, 0, z_0]
#         if x+y != x*y
#     ]).to(dtype=torch.float32, device=device),
#     image_size=torch.tensor(image.size).to(dtype=torch.float32, device=device),
#     focal_length=6 * image.size[0]
# )
device="cpu"
# The main camera sits at (0, 0, z_0). Eight more cameras sit on a ring of
# radius z_0 at 80 degrees elevation. All of them look at the origin (the
# default look_at_points). The focal length is in screen units: half the
# image width.
main_camera, other_cameras = get_cameras_looking_at_points(
    camera_locations = torch.tensor(
        [[0, 0, z_0]] + torch.tensor(np.stack(
            sph2cart([45*i for i in range(8)], [80]*8, [z_0]*8)
        ).T).tolist()
    ).to(dtype=torch.float32, device=device),
    image_size=torch.tensor(image.size).to(dtype=torch.float32, device=device),
    focal_length=0.5 * image.size[0]
)

In [None]:
device="cpu"
img_resolution = image.size[::-1]  # PIL (width, height) -> (height, width)
xy_pix = get_pixel_coordinates_pt3d(img_resolution[0], img_resolution[1], device=device)
xy_pix = xy_pix.flatten(0, -2)  # (H*W, 2), row-major pixel order
# Per-pixel depth as an (H*W, 1) column in the same row-major order
# (assumes predicted_depth is (H, W) — TODO confirm).
# NOTE(review): torch.tensor() on something that is already a tensor copies it
# and emits a warning. Use torch.as_tensor if predicted_depth is a tensor.
depth = torch.tensor(predicted_depth).unsqueeze(2).flatten(0, -2)
# (x, y, depth) screen-space points, unprojected through the main camera.
xyz = torch.cat((xy_pix, depth), dim=1)
world_points = main_camera.unproject_points(xyz)    

In [None]:
image.size

In [None]:
# Two triangles per 2x2 pixel neighbourhood, indexing the row-major world_points.
faces_new = generate_faces_from_grid(image.size[1], image.size[0], "cpu")
# Split faces into kept / removed with a 0.01 edge threshold (default `test=1` metric).
faces_filtered, faces_removed = edge_threshold_filter(world_points, faces_new, 0.01)

In [None]:
print(len(faces_removed))
print(len(faces_filtered))

In [None]:
from pytorch3d.renderer import TexturesVertex

# One colour per vertex, taken from the input photo's pixels, scaled to [0, 1]
# (assumes `image` is RGB with the same pixel count as world_points).
# NOTE(review): the numpy division yields float64, which may not match the
# vertex dtype — confirm the renderer accepts this.
colors = torch.tensor(np.array(image).reshape(-1, 3)/255)
# Paint every vertex touched by a removed (over-stretched) face black.
colors[faces_removed.unique()] = torch.zeros(3).to(colors)
textures = TexturesVertex(verts_features=colors.unsqueeze(0))
# textures_white = TexturesVertex(verts_features=torch.tensor(np.array(image).reshape(-1, 3)*0.0 + 1).unsqueeze(0))
# textures_white = TexturesAtlas(atlas=torch.ones(size=(1, faces_removed.shape[0], 1, 1, 3)))

# The mesh keeps ALL faces (faces_new), not faces_filtered. The removed faces
# are only highlighted through the black vertex colours above.
trg_mesh = Meshes(verts=[world_points], faces=[faces_new], textures=textures)
# bad_mesh = Meshes(verts=[world_points], faces=[faces_removed], textures=textures_white)
verts, faces = trg_mesh.get_mesh_verts_faces(0)
print(
    f"\nVertices: {verts.shape}"
    f"\nFaces: {faces.shape}"
)

In [None]:
# The same points and colours as a point cloud, for side-by-side display.
point_cloud = Pointclouds(points=[world_points], features=[colors])

In [None]:
from rich import inspect
inspect(plot_scene)  # print plot_scene's signature and docs

In [None]:
# Two plotly subplots stacked in a single column (ncols=1): the textured mesh
# and the point cloud, each shown together with all cameras.
fig = plot_scene(
    {
        "mesh": {
            "mesh": trg_mesh,
            # "mesh_bad": bad_mesh,
            "main_camera": main_camera,
            "other_cameras": other_cameras,
        },
        "mesh2": {
            "pointcloud": point_cloud,
            "main_camera": main_camera,
            "other_cameras": other_cameras,
        }
    },
    axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)", showgrid=True, showticklabels=True),
    ncols=1,
    viewpoint_cameras=main_camera
)

fig.update_layout(
    autosize=False,
    width=1200,
    height=1500,
)

fig

In [None]:
from rich import inspect
inspect(trg_mesh, all=True)

In [None]:
import os
import torch
import random
import numpy as np
# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj
import time
# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    OpenGLPerspectiveCameras,
    PointLights,
    diffuse,
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesAtlas,
    BlendParams,
    AmbientLights
)

# NOTE(review): these lines set attributes on the BlendParams *class* instead
# of passing a BlendParams instance to the shader (SoftPhongShader takes
# `blend_params=`). Their actual effect depends on how BlendParams is defined
# in the installed PyTorch3D version — confirm that the white background and
# gamma really apply.
BlendParams.background_color = (1.,1.,1.)
BlendParams.gamma = 0.001

# Settings for rasterization
# Output size is (height, width) = (512, 341).
# NOTE(review): hard-coded. It presumably should match the thumbnailed input
# image used for the cameras — confirm, or derive it from `image.size`.
raster_settings = RasterizationSettings(
    image_size=(512, 341),
    blur_radius=0.0,
    faces_per_pixel=1,
    max_faces_per_bin = 50000,
)

# class Lighting(NamedTuple):  # pragma: no cover
#     ambient: float = 0.8
#     diffuse: float = 1.0
#     fresnel: float = 0.0
#     specular: float = 0.0
#     roughness: float = 0.5
#     facenormalsepsilon: float = 1e-6
#     vertexnormalsepsilon: float = 1e-12

# # Setting the lights
# One point light per camera in `other_cameras`, placed at that camera's `T`.
# NOTE(review): `T` is the world-to-view translation, not the camera centre in
# world coordinates (that is `get_camera_center()`) — confirm this is the
# intended light position.
lights = PointLights(
    ambient_color=((0.7, 0.7, 0.7), ),
    diffuse_color=((0.1, 0.1, 0.1), ),
    specular_color=((0.1, 0.1, 0.1), ),
    device=device,
    location=(
        torch.cat([other_camera.T for other_camera in other_cameras]).tolist()
    )
)

# Setting the renderer
# NOTE: this rebinds the global name `renderer`, which an earlier cell bound to
# the `pytorch3d.renderer` module.
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=other_cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=other_cameras,
        lights=lights
    )
)
from pytorch3d.structures import join_meshes_as_batch
# Rendering the image
# One copy of the mesh per camera, so the mesh batch lines up with the camera batch.
images = renderer(join_meshes_as_batch([trg_mesh]*len(other_cameras)))

In [None]:
def image_grid(imgs, rows, cols):
    """Paste PIL images into a ``rows`` x ``cols`` RGB grid.

    Images are placed row by row (left to right, then top to bottom). Every
    cell takes the size of ``imgs[0]``, so the images are expected to share
    one size.

    Args:
        imgs: sequence of exactly ``rows * cols`` PIL images.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is ``imgs[0].size``.
    """
    assert len(imgs) == rows * cols, f"expected {rows * cols} images, got {len(imgs)}"

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# Convert the float renders to uint8 PIL images.
# NOTE(review): values are scaled by 255 without clamping, and float -> uint8
# casts do not saturate. Anything outside [0, 1] would come out wrong.
images_p = images.clone()
images_p = (images_p.numpy()*255).astype("uint8")
images_p = [Image.fromarray(image_p) for image_p in images_p]
# Insert the original photo at index 4, giving 9 tiles (8 renders + photo).
images_p = images_p[:4] + [image] + images_p[4:]

# Reorder the tiles for a 3x3 layout. Position 4 holds index 4, so the original
# photo ends up in the centre cell.
images_grid = image_grid([images_p[i] for i in [3, 2, 1, 5, 4, 0, 6, 7, 8]], 3, 3)
images_grid.save("out.png")
images_grid

In [None]:
# # Setting the lights
# A single point light at the main camera's `T`.
# NOTE(review): `T` is the world-to-view translation, not the camera centre in
# world coordinates (`get_camera_center()`) — confirm this is intended.
lights = PointLights(
    device=device,
    location=(
        torch.cat([main_camera.T]).tolist()
    )
)

# Setting the renderer
# Rebinds the global `renderer` again, this time to a main-camera renderer.
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=main_camera, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=main_camera,
        lights=lights
    )
)
from pytorch3d.structures import join_meshes_as_batch
# Rendering the image
main_render = renderer(trg_mesh)[0]  # first (only) image of the batch

In [None]:
Image.fromarray((main_render.numpy()*255).astype("uint8"))

In [None]:
image

In [None]:
!pip uninstall diffusers -y
!pip install git+https://github.com/huggingface/diffusers

In [None]:
import torch
from diffusers.utils import load_image, check_min_version
from diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline
from diffusers.models.controlnet_sd3 import SD3ControlNetModel

controlnet = SD3ControlNetModel.from_pretrained(
    "alimama-creative/SD3-Controlnet-Inpainting", use_safetensors=True, extra_conditioning_channels=1
)
pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    # torch_dtype=torch.float16,
)
# pipe.text_encoder.to(torch.float16)
# pipe.controlnet.to(torch.float16)
# pipe.to("cuda")

# NOTE: this rebinds the global `image` (until now the photo used for depth
# prediction) to the example dog image.
image = load_image(
    "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog.png"
)
mask = load_image(
    "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog_mask.png"
)
# width, height, prompt and generator are only used by the commented-out call below.
width = 1024
height = 1024
prompt = "A cat is sitting next to a puppy."
generator = torch.Generator(device="cpu").manual_seed(24)
# res_image = pipe(
#     negative_prompt="deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
#     prompt=prompt,
#     height=height,
#     width=width,
#     control_image=image,
#     control_mask=mask,
#     num_inference_steps=28,
#     generator=generator,
#     controlnet_conditioning_scale=0.95,
#     guidance_scale=7,
# ).images[0]
# res_image.save(f"sd3.png")

In [None]:
image

In [None]:
from PIL import Image
import numpy as np
# Reload the saved 3x3 render grid.
image = Image.open("./out.png")
image  # no effect mid-cell: only a cell's last expression is displayed
# Inpainting mask: pixels that are pure white (the render background) or pure
# black (e.g. vertices painted black for removed faces) become 255 on all
# channels; every other pixel becomes 0.
mask = Image.fromarray(np.where(
    np.logical_or((np.array(image)==255).all(2), (np.array(image)==0).all(2))[:, :, None],
    np.ones_like(image)*255,
    np.zeros_like(image),
))
mask

In [None]:
from diffusers import StableDiffusionInpaintPipeline
import torch

# NOTE(review): fp16 weights are loaded, but the pipeline is never moved to a
# GPU. Half-precision inference on CPU is likely to fail or be very slow —
# confirm the target device. Newer diffusers versions also prefer
# `variant="fp16"` over `revision="fp16"`.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
prompt = "realistic, high quality"
#image and mask_image should be PIL images.
#The mask structure is white for inpainting and black for keeping as is
image = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
# image.save("./yellow_cat_on_park_bench.png")
image

In [None]:
# Crop the region from (0, 0) to (314, 512), i.e. 314 px wide and 512 px tall, and save it.
image.crop((0, 0, 314, 512)).save("test.png")