In [None]:
%cd /ibex/user/slimhy/PADS/code/
"""
Invert a set of input shapes.
"""
import argparse
import torch
import trimesh
import os
import numpy as np

import util.misc as misc
import util.s2vs as s2vs

from datasets.sampling import sample_surface_tpp
from eval.metrics import chamfer_distance, f_score


def get_args_parser():
    """
    Build the command-line parser used by this notebook.

    The parser is created without the automatic ``-h/--help`` option.
    Only ``--ae_pth`` is required; options without an explicit default
    resolve to ``None`` when they are not supplied.
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = [
        # Model parameters
        (
            "--batch_size",
            dict(
                type=int,
                help="Batch size per GPU (this is the grid dimension, to be cubed)",
            ),
        ),
        ("--ae", dict(type=str, metavar="MODEL", help="Name of autoencoder")),
        (
            "--ae_latent_dim",
            dict(type=int, default=512 * 8, help="AE latent dimension"),
        ),
        ("--ae_pth", dict(required=True, help="Autoencoder checkpoint")),
        # CUDA parameters
        (
            "--device",
            dict(default="cuda", help="device to use for training / testing"),
        ),
        ("--seed", dict(default=0, type=int)),
        ("--num_workers", dict(default=60, type=int)),
        (
            "--pin_mem",
            dict(
                action="store_true",
                help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
            ),
        ),
        # Datasplit parameters
        ("--obj_dir", dict(type=str, help="Path to object directory")),
        (
            "--opt_point_ratio",
            dict(
                type=int,
                help="Ratio of points to sample from the object surface (2^N)",
                default=1,
            ),
        ),
        # Hyperparameters
        ("--lr", dict(type=float, help="Learning rate")),
        (
            "--max_iter",
            dict(type=int, help="Maximum number of iterations for each stage"),
        ),
        # Logging
        ("--log_dir", dict(type=str, help="Logging directory")),
        (
            "--config_name",
            dict(type=str, help="Name of the active configuration"),
        ),
    ]

    cli = argparse.ArgumentParser("Extracting Latents", add_help=False)
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


def initialize_latents(args, ae, root_names):
    """
    Load the stored samples of a batch of shapes and encode their initial latents.

    Args:
        args: namespace providing ``obj_dir``, the directory holding the
            ``.npy`` files.
        ae: autoencoder; its ``device`` attribute decides where the loaded
            tensors are created.
        root_names: iterable of shape root names. For each name the files
            ``<name>_surface_points.npy``, ``<name>_near_surface_points.npy``
            and ``<name>_occs.npy`` are read from ``args.obj_dir``.

    Returns:
        ``(init_latents, points, near_points, occs)`` where ``init_latents`` is
        the detached output of ``s2vs.encode_pc`` on the surface points.
    """
    device = ae.device

    def load_batch(names, suffix):
        """
        Stack the ``<name><suffix>.npy`` arrays of every name into one tensor.
        """
        arrays = [
            np.load(os.path.join(args.obj_dir, name + suffix + ".npy"))
            for name in names
        ]
        # Stacking adds a leading batch axis; squeeze(1) then drops axis 1
        # when it has size one.
        # NOTE(review): presumably each file stores a leading singleton axis — confirm.
        return torch.tensor(np.array(arrays), device=device).squeeze(1)

    # Build the batches directly on the autoencoder's device
    points = load_batch(root_names, "_surface_points")
    near_points = load_batch(root_names, "_near_surface_points")
    occs = load_batch(root_names, "_occs")

    # Encode the surface points; detach so the latents carry no graph
    init_latents = s2vs.encode_pc(ae, points, fps_sampling=True).detach()

    return init_latents, points, near_points, occs


def get_metrics(ae, pc_original, latents, batch_size):
    """
    Decode a latent into a mesh and score it against a reference point cloud.

    Returns the tuple ``(reconstructed mesh, chamfer distance, f-score)``.
    """
    # Decode the latent (passed through misc.d_GPU) at a grid density of 512
    decoded_mesh = s2vs.decode_latents(
        ae=ae,
        latent=misc.d_GPU(latents),
        grid_density=512,
        batch_size=batch_size,
    )

    # Sample as many points as the first dimension of the reference cloud.
    # NOTE(review): callers in this notebook pass `points[k].unsqueeze(0)`,
    # whose first dimension is 1 — confirm this is the intended sample count.
    n_samples = pc_original.shape[0]
    sampled_pc = sample_surface_tpp(decoded_mesh, n_samples)

    # Compare the reference and reconstructed point clouds
    chamfer = chamfer_distance(pc_original, sampled_pc, backend="kaolin")
    f_sc = f_score(gt=pc_original, pred=sampled_pc)
    return decoded_mesh, chamfer, f_sc

In [None]:
# Set dummy arg string to debug the parser.
# Lines ending with a backslash are joined inside the literal; the two lines
# without one keep their newline. `.split()` below splits on any whitespace,
# so both forms parse the same way.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae_latent_dim 4096 \
    --batch_size 16 \
    --num_workers 8 \
    --device cuda
    --obj_dir /ibex/project/c2273/3DCoMPaT/packaged \
    --lr 1e-2 \
    --max_iter 100
    """

# Parse the arguments (`args` first holds the parser, then the parsed namespace)
args = get_args_parser()
args = args.parse_args(call_string.split())

# Set device and seed
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
# Let cuDNN benchmark and pick convolution algorithms
torch.backends.cudnn.benchmark = True

# Instantiate autoencoder from the checkpoint and keep the result of .eval()
# NOTE(review): the same setup is repeated in the optimization cell further
# down, which loads the model a second time — confirm whether both are needed.
ae = s2vs.load_model(args.ae, args.ae_pth, device, torch_compile=True)
ae = ae.eval()

In [None]:
from datasets.metadata import class_to_hex


def get_root_name(f, no_rot):
    """
    Get the root name of a file.

    The ``_near_surface`` and ``_surface`` markers are removed first, then the
    last underscore-separated token (e.g. ``points.npy`` or ``occs.npy``) is
    stripped from the end of the name.

    Args:
        f: file name.
        no_rot: when true, names containing ``_all_aug_`` or ``_rand_rot_``
            are rejected.

    Returns:
        The root name, or ``None`` for a rejected file.
    """
    f = f.replace("_near_surface", "")
    f = f.replace("_surface", "")

    if no_rot and ("_all_aug_" in f or "_rand_rot_" in f):
        return None

    # Drop only the trailing token: a global str.replace would also delete
    # earlier occurrences of the same token inside the root name.
    return f.rsplit("_", 1)[0]


def get_all_files(path, class_name=None, no_rot=False, batch_size=None):
    """
    List the shapes stored in a directory and group their root names in batches.

    Args:
        path: directory to list.
        class_name: when given, keep only root names that start with
            ``class_to_hex(class_name) + "_"``.
        no_rot: forwarded to ``get_root_name`` to reject augmented shapes.
        batch_size: number of root names per batch; defaults to the
            notebook-level ``args.batch_size``.

    Returns:
        A list of lists of root names; the last batch may be shorter.
    """
    if batch_size is None:
        batch_size = args.batch_size

    # Several files share one root name, so the raw listing repeats every
    # root. Deduplicate, and sort because os.listdir returns entries in
    # arbitrary order.
    candidates = (get_root_name(f, no_rot=no_rot) for f in os.listdir(path))
    root_names = sorted({name for name in candidates if name is not None})

    if class_name is not None:
        class_hex = class_to_hex(class_name) + "_"
        root_names = [f for f in root_names if f.startswith(class_hex)]

    # Batch the root names
    return [
        root_names[i : i + batch_size]
        for i in range(0, len(root_names), batch_size)
    ]


def show_orig_mesh(root_name):
    """
    Load the source mesh of a shape and return it after the front-to-back flip.

    The mesh file name is the first six characters of `root_name` followed by
    ``.obj``, looked up in the hard-coded ``obj_manifold`` directory.
    """
    from datasets.metadata import flip_front_to_back

    # Place the flip in the upper-left 3x3 block of a 4x4 identity transform
    homogeneous_flip = np.eye(4)
    homogeneous_flip[:3, :3] = flip_front_to_back

    mesh_file = root_name[:6] + ".obj"
    mesh_path = os.path.join("/ibex/project/c2273/3DCoMPaT/obj_manifold", mesh_file)
    original_mesh = trimesh.load(mesh_path)
    return original_mesh.apply_transform(homogeneous_flip)
    

In [None]:
from schedulefree import AdamWScheduleFree
import torch.nn.functional as F
from models.s2vs import fps_subsample



def resample_points(points, occs, opt_point_ratio):
    """
    Endless generator yielding a fresh subsample of points and occupancies.

    Every ``next()`` subsamples the ORIGINAL ``points`` with ``fps_subsample``
    at ratio ``1 / opt_point_ratio``, so each yielded batch has the same size.
    (Re-using the previous subsample as the next input would shrink the point
    set at every iteration.)

    Args:
        points: tensor with three dimensions; the first is the batch size.
        occs: occupancy values matching ``points``.
        opt_point_ratio: inverse of the fraction of points kept per draw.

    Yields:
        ``(sub_points, sub_occs)`` where ``sub_occs`` is viewed as
        ``(batch, -1)``.
    """
    ratio = 1. / opt_point_ratio

    # The inputs never change, so the batch size and flattened occupancies
    # can be computed once.
    batch_size = points.shape[0]
    flat_occs = occs.flatten()

    while True:
        sub_points, idx = fps_subsample(points, ratio, return_idx=True)
        # NOTE(review): idx is used to index the flattened occupancies, so it
        # is assumed to be a flat index over the whole batch — confirm against
        # fps_subsample.
        sub_occs = flat_occs[idx].view(batch_size, -1)
        yield sub_points, sub_occs


def optimize_latents(
    ae,
    near_surface_points,
    occs,
    init_latents,
    *,
    opt_point_ratio=1,
    accumulation_steps=1,
    max_iter=100,
    optimizer=AdamWScheduleFree,
    lr=1e-3,
):
    """
    Optimize latent codes so that the decoded occupancies match `occs`.

    Args:
        ae: autoencoder queried through ``s2vs.query_latents``; its ``device``
            attribute is where the optimization runs.
        near_surface_points: query points, resampled at every step with
            ``resample_points``.
        occs: target occupancies for ``near_surface_points``.
        init_latents: starting latents; they are cloned, never modified.
        opt_point_ratio: forwarded to ``resample_points``.
        accumulation_steps: number of resampled batches whose gradients are
            averaged before each optimizer step.
        max_iter: number of optimizer steps.
        optimizer: optimizer class, called as ``optimizer([latents], lr=lr)``.
        lr: learning rate.

    Returns:
        The optimized latents, detached and moved to CPU.
    """
    latents = init_latents.clone().detach().to(ae.device).requires_grad_(True)
    optimizer = optimizer([latents], lr=lr)

    # Schedule-free optimizers must be switched to train mode before stepping.
    # Regular optimizers have no such method, hence the guard.
    if callable(getattr(optimizer, "train", None)):
        optimizer.train()

    # Resample from the near-surface points given as argument: they are the
    # points that `occs` refers to.
    shape_it = resample_points(near_surface_points, occs, opt_point_ratio)

    # Main optimization loop
    for _ in range(max_iter):
        optimizer.zero_grad()

        for _ in range(accumulation_steps):
            batch_points, batch_occs = next(shape_it)
            batch_points = batch_points.to(ae.device)
            batch_occs = batch_occs.float().to(ae.device)

            logits = s2vs.query_latents(
                ae, latents, batch_points, use_graph=True
            ).squeeze(-1)

            # Default reduction already averages over all elements
            loss = F.binary_cross_entropy_with_logits(logits, batch_occs)

            # Accumulate gradients without stepping the optimizer; dividing
            # keeps the step an average over the accumulated batches.
            (loss / accumulation_steps).backward()

        # Step the optimizer
        optimizer.step()

    # Switch schedule-free optimizers to eval mode so `latents` holds the
    # parameters meant for evaluation.
    if callable(getattr(optimizer, "eval", None)):
        optimizer.eval()

    return latents.detach().cpu()


# Set device and seed
# NOTE(review): this setup block repeats the one in the argument-parsing cell
# and loads the autoencoder a second time — confirm whether the reload is
# intentional (e.g. to re-seed right before optimizing).
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True

# Instantiate autoencoder
ae = s2vs.load_model(args.ae, args.ae_pth, device, torch_compile=True)
ae = ae.eval()

# NOTE(review): all_metrics is never filled in the visible cells.
all_metrics = {}
# Batches of root names for the "table" class
root_names = get_all_files(args.obj_dir, class_name="table")
for root_batch in root_names:
    # Initialize the latents, define the ground truth mesh
    init_latents, points, near_points, occs = initialize_latents(args, ae, root_batch)
    
    # Save the initial latents
    init_latents_save = init_latents.clone().detach()

    # Optimize the latents
    # NOTE(review): opt_point_ratio and lr are hard-coded here, so the parsed
    # args.opt_point_ratio and args.lr are not used — confirm this is intended.
    optimized_latents = optimize_latents(
        ae=ae,
        near_surface_points=near_points,
        occs=occs,
        init_latents=init_latents,
        opt_point_ratio=4,
        accumulation_steps=1,
        max_iter=args.max_iter,
        optimizer=AdamWScheduleFree,
        lr=1e-3,
    )

    # Only the first batch is processed; later cells read root_batch, points,
    # init_latents and optimized_latents left over from this iteration.
    break

In [None]:
# Metrics of the optimized latents, one (mesh, chamfer, f-score) tuple per shape
optimized_metrics = [
    get_metrics(
        ae=ae,
        pc_original=points[k].unsqueeze(0),
        latents=optimized_latents[k].unsqueeze(0),
        batch_size=128**3,
    )
    for k in range(optimized_latents.shape[0])
]

In [None]:
# Metrics of the initial (encoder) latents, one tuple per shape of the batch
init_metrics = [
    get_metrics(
        ae=ae,
        pc_original=points[k].unsqueeze(0),
        latents=init_latents[k].unsqueeze(0),
        batch_size=128**3,
    )
    for k in range(optimized_latents.shape[0])
]

In [None]:
from util.mesh import show_side_by_side

In [None]:
def show_all(k):
    """
    Show, for the k-th shape of the current batch, the mesh stored in
    `optimized_metrics`, the mesh stored in `init_metrics` and the original mesh.
    """
    optimized_mesh = optimized_metrics[k][0]
    initial_mesh = init_metrics[k][0]
    original_mesh = show_orig_mesh(root_batch[k])
    return show_side_by_side(optimized_mesh, initial_mesh, original_mesh)

In [None]:
# Compare the optimized, initial and original meshes of the shape at batch index 2
show_all(2)

In [None]:
# Compare the optimized, initial and original meshes of the shape at batch index 4
show_all(4)