# Demo: Encoding ApolloScape vehicle meshes with TSDF and PCA

## Step 1: Loading all meshes from ApolloScape

In [None]:
import glob
import torch
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
from pprint import pprint

# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Size of one TSDF cell along each of the three vertex axes; the resolution
# computed in the next step is bbox / TSDF_UNIT.
TSDF_UNIT = torch.full((3,), 0.1, device=DEVICE)
APOLLO_SCAPE_PATH = "../assets/apollo_scape/*.obj"

# glob() order is filesystem-dependent; sort it so mesh indices (and hence the
# train/test split) are stable across machines and runs.
obj_glob = sorted(glob.glob(APOLLO_SCAPE_PATH))
pprint(obj_glob)
if not obj_glob:
    # Fail here with a clear message instead of later on an empty mesh batch.
    raise FileNotFoundError(f"No .obj files matched {APOLLO_SCAPE_PATH!r}")

verts_list, faces_list = [], []
for obj_path in obj_glob:
    # Only geometry is needed for the TSDF, so textures are not loaded.
    verts, faces, _ = load_obj(obj_path, load_textures=False)
    verts_list.append(verts)
    faces_list.append(faces.verts_idx)

apollo_scape_meshes = Meshes(verts_list, faces_list).to(DEVICE)
print(f"Meshes size: {len(apollo_scape_meshes)}")

## Step 2: Preparing all the meshes

In [None]:
# Axis-aligned bounding box of every vertex in the dataset (all meshes packed).
all_verts = apollo_scape_meshes.verts_packed()
verts_min = all_verts.amin(0)
verts_max = all_verts.amax(0)
bbox = verts_max - verts_min  # extent along each axis

# Round the extent up to a whole number of TSDF cells on each axis; the number
# of cells generally differs per axis.
antisotropic_res = (bbox / TSDF_UNIT).ceil().int()
quantified_bbox = TSDF_UNIT * antisotropic_res

print(f"bbox: {bbox}")
print(f"resolution: {antisotropic_res}")
print(f"quantified bbox: {quantified_bbox}")


# Move all the vehicles so that the dataset's bounding box is centred on the
# origin. Shifting by the box midpoint (rather than by half its extent) is
# correct regardless of where the meshes' coordinates start on each axis.
apollo_scape_meshes.offset_verts_(-(verts_min + verts_max) / 2)

## Step 3: Splitting the meshes into training and test sets

In [None]:
dataset_size = len(apollo_scape_meshes)

train_ratio = 0.8
# Not used by the split below: the test set is whatever the training set
# leaves over.
test_ratio = 1 - train_ratio

# A fixed seed makes the split, and therefore the reported train/test errors,
# reproducible from run to run.
SPLIT_SEED = 0
indices = torch.randperm(
    dataset_size, generator=torch.Generator().manual_seed(SPLIT_SEED))

train_size = int(train_ratio * dataset_size)
if not 0 < train_size < dataset_size:
    raise ValueError(
        f"Cannot split {dataset_size} meshes with train_ratio={train_ratio}: "
        "both the training and the test set must be non-empty")

train_indices = indices[:train_size]
test_indices = indices[train_size:]

train_meshes: Meshes = apollo_scape_meshes[train_indices]
test_meshes: Meshes = apollo_scape_meshes[test_indices]

## Step 4: Instantiating the TSDF object and computing the TSDF grids

In [None]:
import sys
import os

sys.path.append("..")

from voxeltorch import TSDF, tsdf2meshes


# TSDF sampler over the quantified bounding box. The grid gets one more sample
# than cells along each axis (presumably samples sit on cell corners — confirm
# against voxeltorch.TSDF).
tsdf = TSDF(
    resolution=antisotropic_res + 1,
    bbox=quantified_bbox,
    isotropic=True,
    sampling_count=4096,
    downsampling_count=2048,
)

In [None]:
# Voxelize each split into a TSDF grid and report the resulting shape.
train_tsdf_grid = tsdf.tsdf(train_meshes)
print(train_tsdf_grid.shape)

test_tsdf_grid = tsdf.tsdf(test_meshes)
print(test_tsdf_grid.shape)

## Step 5: Visualizing TSDF Meshes

In [None]:
def visualize_meshes(meshes: Meshes, title: str = "Mesh", ncols: int = 2):
    """Show every mesh of a batch in its own plotly subplot.

    Args:
        meshes: batch of meshes; one subplot is created per mesh.
        title: prefix of each subplot title; the mesh index is appended.
        ncols: maximum number of subplots per row.
    Returns:
        Whatever ``fig.show()`` returns.
    """
    num_meshes = len(meshes)
    mesh_dict = {
        f"{title} {idx}": {"mesh": meshes[idx]} for idx in range(num_meshes)
    }
    # Never request more columns than there are meshes, so the subplot grid
    # has no empty columns.
    cols = max(1, min(ncols, num_meshes))
    rows = max(1, -(-num_meshes // cols))  # ceiling division
    fig = plot_scene(mesh_dict, ncols=cols)
    # Roughly 400 px per subplot in each direction.
    fig.update_layout(height=400 * rows, width=400 * cols)
    return fig.show()

In [None]:
# Pick up to two distinct training meshes uniformly at random. randperm covers
# every index (randint's exclusive upper bound of `len - 1` never picked the
# last mesh) and does not fail when fewer than two meshes are available.
selected_idx = torch.randperm(len(train_meshes))[:2]
print("Visualizing original meshes...")
visualize_meshes(train_meshes[selected_idx], "Original Meshes")

print("Visualizing voxel meshes from TSDF...")
tsdf_meshes = tsdf2meshes(train_tsdf_grid[selected_idx], TSDF_UNIT)
visualize_meshes(tsdf_meshes, "TSDF Meshes")

## Step 6: Building PCA Encoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PCA_encoder(nn.Module):
    """Linear PCA encoder/decoder for batches of flattened samples.

    ``fit`` learns the per-feature mean and a rank-``q`` principal basis of the
    training data. ``encode`` projects centred samples onto that basis and
    divides by the singular values; ``decode`` applies the inverse map.
    """

    def __init__(self, Q: int = 5):
        # nn.Module must be initialised before attributes are assigned,
        # otherwise repr(), .to(), state_dict(), ... raise AttributeError.
        super().__init__()
        # Parameters, populated by ``fit`` (plain attributes, not buffers).
        self.mean = None  # [N]
        self.U = None  # [B, q]
        self.S = None  # [q]
        self.V = None  # [N, q]

        # Target rank to compress to (upper bound; clipped in ``fit``).
        self.Q = Q

    def _check_fitted(self):
        """Raise a clear error if ``fit`` has not been called yet."""
        if self.V is None:
            raise RuntimeError(
                "PCA_encoder must be fit() before calling encode()/decode()")

    def fit(self, batch_X: torch.Tensor):
        """
        Args:
            batch_X: [B, N]
        """
        num_samples, num_features = batch_X.shape[-2:]
        self.mean = batch_X.mean(dim=0)
        centered = batch_X - self.mean
        # Centred data from B samples has rank at most B - 1, and
        # pca_lowrank rejects q > min(B, N). Components beyond the rank have
        # (near-)zero singular values that would blow up the 1/S scaling in
        # encode(), so clip the requested rank accordingly.
        q = max(0, min(self.Q, num_samples - 1, num_features))
        # The data was centred above; don't let pca_lowrank centre it again.
        self.U, self.S, self.V = torch.pca_lowrank(centered, q=q, center=False)

    def encode(self, batch_X: torch.Tensor):
        """
        Args:
            batch_X: [B, N]
        Returns:
            latent: [B, q]
        """
        self._check_fitted()
        # Element-wise division by S equals right-multiplying by diag(S)^-1
        # without building and inverting a q x q matrix.
        latent = (batch_X - self.mean) @ self.V / self.S
        return latent

    def decode(self, latent: torch.Tensor):
        """
        Args:
            latent: [B, q]
        Returns:
            reconstructed_X: [B, N]
        """
        self._check_fitted()
        # Inverse of encode(): undo the 1/S scaling, map back, re-add the mean.
        reconstructed_X = (latent * self.S) @ self.V.T + self.mean
        return reconstructed_X


tsdf_encoder = PCA_encoder(Q=5)

In [None]:
def _evaluate_reconstruction(encoder, tsdf_grid):
    """Round-trip a batch of TSDF grids through ``encoder``.

    Args:
        encoder: fitted encoder whose ``encode``/``decode`` work on [B, N]
            tensors.
        tsdf_grid: [B, ...] batch of grids; all trailing dims are flattened.
    Returns:
        Tuple of (latent codes, reconstruction reshaped like ``tsdf_grid``,
        mean over the batch of the per-sample L2 reconstruction error).
    """
    # reshape (not view) also accepts non-contiguous grids.
    flat = tsdf_grid.reshape(tsdf_grid.size(0), -1)
    latent = encoder.encode(flat)
    reconstructed = encoder.decode(latent).reshape(tsdf_grid.shape)
    # Per-sample L2 norm over all voxels, whatever the grid's rank.
    l2 = (tsdf_grid - reconstructed).flatten(start_dim=1).norm(dim=1).mean()
    return latent, reconstructed, l2


tsdf_encoder.fit(train_tsdf_grid.reshape(train_tsdf_grid.size(0), -1))

# Train l2 norm
latent, train_reconstructed_tsdf_grid, train_l2 = _evaluate_reconstruction(
    tsdf_encoder, train_tsdf_grid)
print(f"Train L2 norm: {train_l2}")

# Test l2 norm
latent, test_reconstructed_tsdf_grid, test_l2 = _evaluate_reconstruction(
    tsdf_encoder, test_tsdf_grid)
print(f"Test L2 norm: {test_l2}")

## Step 7: Visualizing Reconstructed TSDF Meshes

In [None]:
# Pick up to two distinct test meshes uniformly at random. randperm covers
# every index (randint's exclusive upper bound of `len - 1` never picked the
# last mesh) and does not fail when the test set has fewer than two meshes.
selected_idx = torch.randperm(len(test_meshes))[:2]

print(f"Randomly selected index: [{selected_idx}]")
print("Visualizing voxel meshes from TSDF...")
tsdf_meshes = tsdf2meshes(test_tsdf_grid[selected_idx], TSDF_UNIT)
visualize_meshes(tsdf_meshes, "Original TSDF Meshes")

print("Visualizing voxel meshes from reconstructed TSDF...")
tsdf_meshes = tsdf2meshes(
    test_reconstructed_tsdf_grid[selected_idx], TSDF_UNIT)
visualize_meshes(tsdf_meshes, "Reconstructed TSDF Meshes")