[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abossenbroek/3d-from-2d-reconstruction/blob/main/labs/pytorch3d_intro/Deformable_mesh_rendering_on_point_cloud.ipynb)

# Fit a mesh to a raw point cloud

This tutorial shows how to:
- Load a point cloud from a `.ply` file and view it
- Fit a mesh to the point cloud
- Use loss functions on meshes and point clouds
- Sample points from meshes so they can be compared with point clouds

In [None]:
%%shell 
# Install Open3D, which is used below to read the .ply file for plotting.
pip install open3d

In [None]:
import os
import sys
import torch

need_pytorch3d = False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d = True
if need_pytorch3d:
    if torch.__version__.startswith("1.13.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
        version_str = "".join(
            [
                f"py3{sys.version_info.minor}_cu",
                torch.version.cuda.replace(".", ""),
                f"_pyt{pyt_version_str}",
            ]
        )
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import plotly.graph_objects as go

# Util function for loading meshes
from pytorch3d.io import load_obj, load_objs_as_meshes, load_ply, save_obj, save_ply
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.renderer import (
    DirectionalLights,
    FoVPerspectiveCameras,
    Materials,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    look_at_view_transform,
)

# Data structures and functions for rendering
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import ico_sphere
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from tqdm import tqdm

## 0. Install and Import modules

Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, the installation cell near the top of this notebook installs it.

In [None]:
# Make modules in the current working directory importable from the notebook.
sys.path.append(os.getcwd())

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# and warn, since the optimization below is much slower there.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("WARNING: CPU only, this will be slow!")

### 1. Load a raw point cloud

Load a `.ply` file.

- load with `open3d` to allow easy access to x, y, z for plotting with `plotly`
- plot with `plotly` as scatter plot
- load with `pytorch3d` using `load_ply()`
- center the raw point cloud so that its mean is 0 along every axis, then divide it by its largest absolute coordinate so that every coordinate lies in $[-1, 1]$

In [None]:
# Download the example pedestrian point cloud that the cells below load.
!wget https://raw.githubusercontent.com/PacktPublishing/3D-Deep-Learning-with-Python/main/chap3/pedestrian.ply

In [None]:
# Read the point cloud with Open3D; its `points` are converted to NumPy below
# for printing and for the plotly scatter plot.
pcd = o3d.io.read_point_cloud("pedestrian.ply")

In [None]:
# Read the target point cloud with PyTorch3D's load_ply (verts: (V, 3)).
# `faces` is moved to the device but is not used further below.
verts, faces = load_ply("pedestrian.ply")

# Keep an un-normalized copy for plotting: the fitted meshes are mapped back to
# the original coordinates (via `* scale + center`) before being displayed.
point_cloud = Pointclouds(points=[verts])

verts = verts.to(device)
faces = faces.to(device)

# Center the cloud at the origin, then divide by its largest absolute
# coordinate so that every coordinate lies in [-1, 1].
center = verts.mean(0)
verts = verts - center
scale = verts.abs().max()
verts = verts / scale

# Add a batch dimension -> (1, V, 3); the optimization loop reads the number of
# target points from verts.shape[1].
verts = verts.unsqueeze(0)

In [None]:
# Inspect the raw point coordinates as a NumPy array.
print(np.asarray(pcd.points))

In [None]:
# Split the (N, 3) point array into per-axis coordinate arrays for plotly.
x, y, z = np.asarray(pcd.points).T

In [None]:
# Interactive 3D scatter of the raw point cloud, colored by its z coordinate.
marker_style = dict(
    size=12,
    color=z,
    colorscale="Viridis",
    opacity=0.8,
)
scatter = go.Scatter3d(x=x, y=y, z=z, mode="markers", marker=marker_style)

fig = go.Figure(data=[scatter])
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

## 2. Mesh prediction via optimization
In the previous section, we loaded our raw point cloud and normalized it.

Next we want to create a mesh that minimizes our losses.

### 2.1 Initialization


In [None]:
# Plot losses as a function of optimization iteration
def plot_losses(losses):
    fig = plt.figure(figsize=(13, 5))
    ax = fig.gca()
    for k, l in losses.items():
        ax.plot(l["values"], label=k + " loss")
    ax.legend(fontsize="16")
    ax.set_xlabel("Iteration", fontsize="16")
    ax.set_ylabel("Loss", fontsize="16")
    ax.set_title("Loss vs iterations", fontsize="16")

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Relative weight of each loss term in the total objective. Each entry's
# "values" list records that term's per-iteration value during optimization.
loss_weights = {
    "chamfer": 1.0,  # chamfer distance between sampled mesh points and the cloud
    "edge": 1.0,  # mesh edge loss
    "normal": 0.01,  # mesh normal consistency
    "laplacian": 1.0,  # uniform mesh Laplacian smoothing
}
losses = {
    name: {"weight": weight, "values": []}
    for name, weight in loss_weights.items()
}

# Number of optimization steps
Niter = 2000
# Period (in iterations) used to decide when intermediate meshes are stored
plot_period = 250

# SGD hyper-parameters: learning rate and momentum
SGD_lr = 1.0
SGD_momentum = 0.9

In [None]:
# Start from a level-4 icosphere. The optimizer learns one offset per vertex,
# initialized to zero (i.e. the undeformed sphere).
src_mesh = ico_sphere(4, device)
src_vert = src_mesh.verts_list()
deform_verts = torch.zeros(src_vert[0].shape, device=device, requires_grad=True)

# Only the per-vertex offsets are optimized; the base sphere stays fixed.
optimizer = torch.optim.SGD(
    [deform_verts],
    lr=SGD_lr,
    momentum=SGD_momentum,
)

### 2.2 Define update rules

In [None]:
# Intermediate meshes recorded during optimization, with their iteration index.
meshes = {"iter": [], "mesh": []}


def update_mesh_shape_prior_losses(mesh, original_points, sampled_points, loss):
    """Compute the fitting and regularization loss terms into ``loss``.

    Args:
        mesh: current deformed mesh.
        original_points: target point cloud the mesh should fit.
        sampled_points: points sampled from the surface of ``mesh``.
        loss: dict updated in place with the keys "chamfer" (chamfer distance
            between the sampled and target points), "edge" (mesh edge loss),
            "normal" (mesh normal consistency) and "laplacian" (uniform
            Laplacian smoothing).
    """
    chamfer, _ = chamfer_distance(sampled_points, original_points)
    loss.update(
        chamfer=chamfer,
        edge=mesh_edge_loss(mesh),
        normal=mesh_normal_consistency(mesh),
        laplacian=mesh_laplacian_smoothing(mesh, method="uniform"),
    )

### 2.3 Optimization loop

In [None]:
loop = tqdm(range(Niter))

for i in loop:
    # Clear the gradients accumulated by the previous step
    optimizer.zero_grad()

    # Deform the sphere by the current per-vertex offsets
    new_src_mesh = src_mesh.offset_verts(deform_verts)

    # Sample as many points from the mesh surface as the target cloud contains
    sample_src = sample_points_from_meshes(new_src_mesh, verts.shape[1])

    # Fitting (chamfer) and regularization (edge / normal / laplacian) terms
    loss = {k: torch.tensor(0.0, device=device) for k in losses}
    update_mesh_shape_prior_losses(new_src_mesh, verts, sample_src, loss)

    # Weighted sum of the losses; also record each term's history for plotting
    sum_loss = torch.tensor(0.0, device=device)
    for k, value in loss.items():
        sum_loss += value * losses[k]["weight"]
        losses[k]["values"].append(value.item())

    # Print the losses
    loop.set_description("total_loss = %.6f" % sum_loss)

    # Optimization step
    sum_loss.backward()
    optimizer.step()

    # Store a snapshot every `plot_period` iterations (0, plot_period, ...)
    if i % plot_period == 0:
        tmp_verts, tmp_faces = new_src_mesh.get_mesh_verts_faces(0)
        # Detach so stored snapshots do not keep the autograd graph alive, and
        # undo the normalization so they line up with `point_cloud`.
        tmp_verts_transformed = tmp_verts.detach() * scale + center
        temp_mesh = Meshes(
            verts=[tmp_verts_transformed.to(device)],
            faces=[tmp_faces.to(device)],
        )

        meshes["iter"].append(i)
        meshes["mesh"].append(temp_mesh)


# Rebuild the final mesh from the optimized offsets, so the result includes the
# last optimizer.step() (new_src_mesh was built before it) and is defined even
# when Niter == 0.
final_verts, final_faces = src_mesh.offset_verts(
    deform_verts.detach()
).get_mesh_verts_faces(0)

# Scale normalize back to the original target size
final_verts = final_verts * scale + center

In [None]:
# Final fitted mesh in the original (de-normalized) coordinates, for plotting.
mesh = Meshes(verts=[final_verts.to(device)], faces=[final_faces.to(device)])

### 2.4 Show intermediate and final result

In [None]:
# One subplot per stored snapshot, overlaid on the target point cloud, followed
# by a subplot for the final mesh.
plots = {
    f"Iteration {it}": {"point cloud to fit": point_cloud, "mesh_progression": snapshot}
    for it, snapshot in zip(meshes["iter"], meshes["mesh"])
}
plots["Final mesh"] = {"mesh_progression": mesh}

# Render the plotly figure
fig = plot_scene(plots, ncols=4)
fig.update_layout(width=1600, height=2400, margin=dict(l=80, r=80, t=20, b=20))
fig.show()

#### 2.4.2 Show losses

In [None]:
# Plot the per-iteration value of every loss term recorded in `losses`.
plot_losses(losses)

## 3. Conclusion
In this tutorial, we learned how to load a raw point cloud. We centered and scaled the point cloud so that we can easily use it with a deep learning optimizer. Then we deformed an initial icosphere onto the raw point cloud, optimizing per-vertex offsets to reduce several losses defined for point clouds and meshes.