# Fit terrain mesh from rover images


In [None]:
import os
import cv2
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from pytorch3d.vis.plotly_vis import plot_scene

from pytorch3d.renderer import (
    PointLights,
    BlendParams,
    DirectionalLights,
    FoVPerspectiveCameras,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
)

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the chosen GPU the process-wide default CUDA device as well.
    torch.cuda.set_device(device)

from pytorch3d_utils import structured_grid_to_pytorch3d_mesh
from lac.utils.plotting import plot_poses, plot_surface, plot_mesh
from lac.utils.frames import get_cam_pose_rover, CAMERA_TO_OPENCV_PASSIVE, invert_transform_mat
from lac.util import load_data
from lac.params import IMG_FOV_RAD

%load_ext autoreload
%autoreload 2

In [None]:
# Load the competition height map for Moon_Map_01 / preset 0.
# The array is bound to `heightmap` rather than `map` so the builtin `map()`
# stays usable in the rest of the notebook.
# NOTE(review): allow_pickle=True lets np.load execute pickled code from the
# file; only load map files that come from a trusted source.
heightmap = np.load(
    "../../../data/heightmaps/competition/Moon_Map_01_preset_0.dat", allow_pickle=True
)

# Build a PyTorch3D mesh from the first three channels of the grid.
# presumably these channels are the per-cell (x, y, z) coordinates — TODO confirm
mesh = structured_grid_to_pytorch3d_mesh(heightmap[..., :3])
mesh = mesh.to(device)

In [None]:
# Interactive view of the initial terrain mesh with equal scaling on all axes.
fig = plot_mesh(mesh)
fig.update_layout(
    width=1200,
    height=700,
    scene={"aspectmode": "data"},
)
fig.show()

In [None]:
# Recorded run to fit against. The commented-out path is an alternative run
# that can be swapped in.
# data_path = "../../../output/LocalizationAgent/map1_preset0_4m_spiral"
data_path = "../../../output/DataCollectionAgent/map1_preset0_nolight_allcams"
# Going by the names: initial rover pose, lander pose, per-frame rover poses,
# IMU samples and camera configuration — confirm against lac.util.load_data.
initial_pose, lander_pose, poses, imu_data, cam_config = load_data(data_path)

In [None]:
def rover_pose_to_cam_pose(rover_pose, cam_name="FrontLeft"):
    """Compose a rover pose with the mounting transform of one of its cameras.

    Args:
        rover_pose: Rover pose matrix; it is left-multiplied onto the camera
            transform (presumably a 4x4 homogeneous rover-to-world matrix —
            TODO confirm).
        cam_name: Camera name passed to ``get_cam_pose_rover``.

    Returns:
        ``rover_pose @ cam_in_rover``, where the upper-left 3x3 block of the
        camera transform has been replaced by ``CAMERA_TO_OPENCV_PASSIVE``.
    """
    cam_in_rover = get_cam_pose_rover(cam_name)
    # Overwrite the rotation block with the fixed axis change. This writes into
    # the array returned by get_cam_pose_rover.
    # NOTE(review): confirm that helper returns a fresh array on every call.
    cam_in_rover[:3, :3] = CAMERA_TO_OPENCV_PASSIVE
    return rover_pose @ cam_in_rover

In [None]:
# PyTorch3D transforms points as row vectors (X_cam = X_world @ R + T), so the
# R it expects is the transpose of the usual world-to-camera rotation, i.e. the
# camera-to-world rotation, while T is the usual world-to-camera translation.
def cam_pose_to_p3d_cam(cam_pose):
    """Split a camera-to-world pose into the (R, T) pair used by PyTorch3D cameras.

    Args:
        cam_pose: Array whose ``[:3, :3]`` block is the camera-to-world rotation
            and whose ``[:3, 3]`` column is the camera centre in world
            coordinates.

    Returns:
        Tuple ``(R, T)``: ``R`` has the same values as ``cam_pose[:3, :3]`` but
        does not share memory with it; ``T`` equals
        ``-cam_pose[:3, :3].T @ cam_pose[:3, 3]``.
    """
    cam_to_world_rot = cam_pose[:3, :3]
    cam_center = cam_pose[:3, 3]
    # Copy so the returned rotation is independent of the caller's pose array.
    world_to_cam_rot = cam_to_world_rot.T.copy()
    translation = -world_to_cam_rot @ cam_center
    return world_to_cam_rot.T, translation

In [None]:
from scipy.spatial.transform import Rotation

# -90 degree rotation about Y; the uppercase "XYZ" selects intrinsic axes in SciPy.
# NOTE(review): `r` is not referenced by any other cell visible in this notebook.
r = Rotation.from_euler("XYZ", (0, -90, 0), degrees=True)

In [None]:
# Per-view PyTorch3D camera extrinsics, target images and camera poses.
R = []
T = []
target_rgb = []
cam_poses = []

# Use every `increment`-th frame in [start_idx, end_idx) as a training view.
start_idx = 100
end_idx = 4000
increment = 10
idxs = np.arange(start_idx, end_idx, increment)
num_views = len(idxs)
print("num_views: ", num_views)

for i in idxs:
    cam_pose = rover_pose_to_cam_pose(poses[i])
    # Convert to pytorch3d convention by negating the camera x and y axes
    # (presumably OpenCV x-right/y-down -> PyTorch3D x-left/y-up — TODO confirm).
    cam_pose[:, 0] *= -1
    cam_pose[:, 1] *= -1
    cam_poses.append(cam_pose)

    R_p3d, T_p3d = cam_pose_to_p3d_cam(cam_pose)

    R.append(torch.tensor(R_p3d, device=device).float())
    T.append(torch.tensor(T_p3d, device=device).float())

    img_path = os.path.join(data_path, "FrontLeft", f"{i}.png")
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV loads images as BGR; the renderer output and matplotlib are RGB,
    # so reorder the channels before using the image as an RGB target.
    img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) / 255.0
    # Halve the resolution of the target image.
    new_size = (img_np.shape[1] // 2, img_np.shape[0] // 2)
    downscaled_img_np = cv2.resize(img_np, new_size, interpolation=cv2.INTER_AREA)
    target_rgb.append(torch.tensor(downscaled_img_np, device=device).float())

R = torch.stack(R)
T = torch.stack(T)
# One single-view camera object per training view.
target_cameras = [
    FoVPerspectiveCameras(
        device=device, R=R[None, i, ...], T=T[None, i, ...], fov=IMG_FOV_RAD, degrees=False
    )
    for i in range(num_views)
]

In [None]:
# Index of the view used for previews in the following cells.
i = 100
# Overlay that view's camera pose on the terrain mesh.
fig = plot_poses(cam_poses[i : i + 1], fig=plot_mesh(mesh))
fig.update_layout(
    width=1200,
    height=700,
    scene={"aspectmode": "data"},
)
fig.show()

In [None]:
# Show the target image of preview view `i`.
plt.imshow(target_rgb[i].detach().cpu().numpy())

In [None]:
# Batched camera object holding every view.
cameras = FoVPerspectiveCameras(
    device=device,
    R=R,
    T=T,
    fov=IMG_FOV_RAD,
    degrees=False,
)
# Single-view camera for preview view `i`, with a leading batch dimension of 1.
camera = FoVPerspectiveCameras(
    device=device,
    R=R[i].unsqueeze(0),
    T=T[i].unsqueeze(0),
    fov=IMG_FOV_RAD,
    degrees=False,
)

In [None]:
# Scene lighting. The commented-out PointLights variants are kept as
# alternatives; the active setup is a single directional light with diffuse
# colour only (ambient and specular are zero).
# lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
# lights = PointLights(
#     device=device,
#     diffuse_color=[[1.0, 1.0, 1.0]],
#     ambient_color=[[0.5, 0.5, 0.5]],
#     specular_color=[[0.0, 0.0, 0.0]],
#     location=[[0.0, 0.0, -3.0]],
# )
lights = DirectionalLights(
    device=device,
    diffuse_color=[[2.0, 2.0, 2.0]],
    ambient_color=[[0.0, 0.0, 0.0]],
    specular_color=[[0.0, 0.0, 0.0]],
    direction=[[0.0, 1.0, -0.05]],
)

# Rasterization settings for differentiable rendering, where the blur_radius
# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable
# Renderer for Image-based 3D Reasoning', ICCV 2019
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
    # (height, width) of the rendered image
    image_size=(360, 640),
    blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
    faces_per_pixel=50,
    perspective_correct=False,
)

# Differentiable soft renderer using per vertex RGB colors for texture.
# The renderer is built around the single preview `camera`; calls that pass
# `cameras=` explicitly render from that camera instead.
blend_params = BlendParams(background_color=(0.0, 0.0, 0.0))
renderer_textured = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=camera, raster_settings=raster_settings_soft),
    shader=SoftPhongShader(device=device, cameras=camera, lights=lights, blend_params=blend_params),
)

In [None]:
# Preview: render the initial mesh with uniform white vertex colours from the
# preview camera.
textured_mesh = mesh.clone()
verts_shape = mesh.verts_packed().shape
terrain_verts_rgb = torch.ones((1, verts_shape[0], 3), device=device)
textured_mesh.textures = TexturesVertex(terrain_verts_rgb)

rendered_image = renderer_textured(textured_mesh, cameras=camera, lights=lights)
# Keep the first three channels of the first (only) image in the batch.
rendered_rgb = rendered_image[0, ..., :3].cpu().numpy()
plt.imshow(rendered_rgb)

In [None]:
from pytorch3d.loss import (
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)

In [None]:
# Number of views to optimize over in each SGD iteration
num_views_per_iteration = 2
# Number of optimization steps
Niter = 2000
# The training loop draws a prediction/target comparison every `plot_period`
# iterations
plot_period = 250

%matplotlib inline

# Weight and per-iteration history for every loss term. The training loop in
# this notebook fills in the rgb, edge, normal and laplacian terms; no
# silhouette term is computed there, so that entry stays at its initial zero.
losses = {
    "rgb": {"weight": 1.0, "values": []},
    "silhouette": {"weight": 1.0, "values": []},
    "edge": {"weight": 1.0, "values": []},
    "normal": {"weight": 0.01, "values": []},
    "laplacian": {"weight": 1.0, "values": []},
}


# Losses to smooth / regularize the mesh shape
def update_mesh_shape_prior_losses(mesh, loss):
    """Store the mesh shape-regularisation terms for ``mesh`` in ``loss``.

    Writes the ``"edge"``, ``"normal"`` and ``"laplacian"`` entries of ``loss``
    in place, in that order, and returns nothing.

    Args:
        mesh: Mesh passed to each PyTorch3D regulariser.
        loss: Dict that receives the three terms.
    """
    regularizers = (
        ("edge", mesh_edge_loss, {}),
        ("normal", mesh_normal_consistency, {}),
        ("laplacian", mesh_laplacian_smoothing, {"method": "uniform"}),
    )
    for term, loss_fn, extra_kwargs in regularizers:
        loss[term] = loss_fn(mesh, **extra_kwargs)


# Optimisation variables: one scalar offset per vertex, which the training loop
# applies to the z coordinate only, plus one RGB colour per vertex.
verts_shape = mesh.verts_packed().shape
deform_zs = torch.zeros(verts_shape[0], device=device, requires_grad=True)
terrain_verts_rgb = torch.ones(
    (1, verts_shape[0], 3),
    device=device,
    requires_grad=True,
)

# Adam over both the offsets and the colours
optimizer = torch.optim.Adam(
    [deform_zs, terrain_verts_rgb],
    lr=1e-2,
)

In [None]:
# Show the rendered predicted mesh side by side with a target camera image.
def visualize_prediction(
    predicted_mesh,
    renderer=renderer_textured,
    # Bound once, when this cell runs: `i` must still be the preview view index
    # that `camera` (and therefore `renderer_textured`) was built from.
    target_image=target_rgb[i],
    title="",
    silhouette=False,
):
    """Plot a render of ``predicted_mesh`` next to ``target_image``.

    The mesh is rendered with ``renderer`` using whatever camera that renderer
    was constructed with, so ``target_image`` should be the image taken from
    the same view. The default target is the image of preview view ``i`` to
    match the default renderer; it was previously ``target_rgb[1]``, a
    different view from the one being rendered.

    Args:
        predicted_mesh: Mesh to render.
        renderer: Callable that renders ``predicted_mesh`` to a batch of images.
        target_image: Tensor shown in the right-hand panel.
        title: Title placed on the right-hand panel.
        silhouette: If true, show channel 3 of the render instead of
            channels 0-2.
    """
    inds = 3 if silhouette else range(3)
    # Visualisation only: no gradients are needed for this render.
    with torch.no_grad():
        predicted_images = renderer(predicted_mesh)
    plt.figure(figsize=(20, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(predicted_images[0, ..., inds].cpu().detach().numpy())

    plt.subplot(1, 2, 2)
    plt.imshow(target_image.cpu().detach().numpy())
    plt.title(title)
    plt.axis("off")

In [None]:
# Joint optimisation of the per-vertex z offsets and vertex colours.
# Note: the loop variable `i` overwrites the preview view index set earlier.
loop = tqdm(range(Niter))

for i in loop:
    # Clear gradients accumulated in the previous iteration
    optimizer.zero_grad()

    # Deform the mesh: only the z coordinate of each vertex is offset
    deform_verts = torch.zeros(verts_shape, device=device)
    deform_verts[:, 2] = deform_zs
    new_mesh = mesh.offset_verts(deform_verts)
    # new_mesh = mesh.clone()

    # Add per vertex colors to texture the mesh
    new_mesh.textures = TexturesVertex(verts_features=terrain_verts_rgb)

    # Start every term at zero, then fill in the shape regularisers
    loss = {k: torch.tensor(0.0, device=device) for k in losses}
    update_mesh_shape_prior_losses(new_mesh, loss)

    # Randomly select `num_views_per_iteration` views to optimize over in this
    # iteration.  Compared to using just one view, this helps resolve
    # ambiguities between updating mesh shape vs. updating mesh texture
    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:
        images_predicted = renderer_textured(new_mesh, cameras=target_cameras[j], lights=lights)

        # Mean squared difference between the predicted RGB image and the
        # target image from our dataset, averaged over the selected views
        predicted_rgb = images_predicted[..., :3]
        loss_rgb = ((predicted_rgb - target_rgb[j]) ** 2).mean()
        loss["rgb"] += loss_rgb / num_views_per_iteration

    # Weighted sum of the losses; also record each unweighted term
    sum_loss = torch.tensor(0.0, device=device)
    for k, l in loss.items():
        sum_loss += l * losses[k]["weight"]
        losses[k]["values"].append(float(l.detach().cpu()))

    # Show the total loss in the progress bar
    loop.set_description("total_loss = %.6f" % sum_loss)

    # Periodically plot the current prediction next to the target image
    if i % plot_period == 0:
        visualize_prediction(
            new_mesh, renderer=renderer_textured, title="iter: %d" % i, silhouette=False
        )

    # Optimization step
    sum_loss.backward()
    optimizer.step()

In [None]:
# Rebuild the vertex offsets from the optimised parameters. The `deform_verts`
# left over from the training loop was assembled before the last
# optimizer.step(), so it lags one update behind `deform_zs`.
with torch.no_grad():
    final_deform_verts = torch.zeros(verts_shape, device=device)
    final_deform_verts[:, 2] = deform_zs
final_mesh = mesh.offset_verts(final_deform_verts)
final_mesh.textures = TexturesVertex(verts_features=terrain_verts_rgb)

In [None]:
# Display the packed vertex tensor of the fitted mesh as the cell output.
final_mesh.verts_packed()

In [None]:
# Interactive view of the fitted mesh with its vertex colours, edges hidden.
fig = plot_mesh(final_mesh, show_edges=False, textured=True)
fig.update_layout(
    width=1200,
    height=700,
    scene={"aspectmode": "data"},
)
fig.show()

In [None]:
# Plot losses as a function of optimization iteration
def plot_losses(losses):
    """Draw every recorded loss curve on one set of axes.

    Args:
        losses: Mapping from loss name to a dict whose ``"values"`` entry holds
            that term's per-iteration history.
    """
    label_size = "16"
    ax = plt.figure(figsize=(13, 5)).gca()
    for name, record in losses.items():
        ax.plot(record["values"], label=name + " loss")
    ax.legend(fontsize=label_size)
    ax.set_xlabel("Iteration", fontsize=label_size)
    ax.set_ylabel("Loss", fontsize=label_size)
    ax.set_title("Loss vs iterations", fontsize=label_size)

In [None]:
# Compare the mesh built in the last training iteration with the target image,
# then plot the recorded loss curves.
visualize_prediction(new_mesh, renderer=renderer_textured, silhouette=False)
plot_losses(losses)