In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """
    Build the argument parser for the model / dataset / runtime options.

    The parser is created with ``add_help=False``; in this notebook it is fed a
    hand-written argument string rather than the real command line.

    Returns:
        argparse.ArgumentParser: parser whose only required option is ``--ae_pth``.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU"
        " (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input point cloud size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient "
        "(sometimes) transfer to GPU.",
    )

    return parser


# Dummy argument string used to drive the parser from inside the notebook.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --data_path /ibex/project/c2273/PADS/3DCoMPaT \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Parse the arguments. The parser and the parsed namespace get separate names
# instead of rebinding `args` from one kind of object to the other.
parser = get_args_parser()
args = parser.parse_args(call_string.split())

# --------------------
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
# NOTE: cudnn.benchmark lets cuDNN pick algorithms at runtime, which can be
# non-deterministic, so seeding alone does not guarantee bitwise reproducibility.
torch.backends.cudnn.benchmark = True
# --------------------

# Initialize and load the autoencoder (eval mode, moved to `device`, compiled).
# NOTE: torch.load unpickles the checkpoint; only load trusted files.
ae = ae_mods.__dict__[args.ae]()
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = torch.compile(ae.eval().to(device), mode="max-autotune")


/ibex/user/slimhy/PADS/code
Set seed to 0


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader

class PairType:
    """
    String identifiers for the pair sampling modes handed to the paired loader.

    Each value joins two mode names with a comma (presumably one per shape of
    the pair).
    """

    PART_DROP = "part_drop,orig"
    NO_ROT_PAIR = "rand_no_rot,rand_no_rot"

# Create the datasets; part shuffling is enabled for the training split only.
dataset_train = ShapeLatentDataset(
    args.data_path,
    split="train",
    filter_n_ids=2,
    shuffle_parts=True,
)
dataset_val = ShapeLatentDataset(
    args.data_path,
    split="test",
    filter_n_ids=2,
    shuffle_parts=False,
)

# Non-distributed loader over the training split, sampling NO_ROT_PAIR pairs.
# NOTE: batch_size and num_workers are hard-coded here; args.batch_size and
# args.num_workers are not used.
data_loader_train = ComposedPairedShapesLoader(
    dataset_train,
    pair_types_list=[PairType.NO_ROT_PAIR],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    use_distributed=False,
)

In [3]:
import torch
from torch import nn
from models.modules import (
    StackedAttentionBlocks,
)

class PartAwareEncoder(nn.Module):
    """
    Generating a set of part-aware latents from a set of part bounding boxes and part labels: "part queries".

    Part embeddings (built from the boxes and labels) act as queries that
    cross-attend to a projection of the input shape latents. They are then
    refined by a stack of attention blocks and compressed to `latent_dim`
    channels.

    NOTE(review): `PartEmbed`, `PreNorm`, `Attention` and `GEGLU` are not imported
    in this cell (only `StackedAttentionBlocks` is); they must already exist in
    the kernel namespace. TODO: import them explicitly.
    """

    def __init__(
        self,
        dim=512,
        latent_dim=128,
        max_parts=24,
        heads=8,
        in_heads=1,
        dim_head=64,
        depth=2,
        weight_tie_layers=False,
        use_attention_masking=False,
    ):
        super().__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        # Stored for reference only; none of the layers below read it.
        self.max_parts = max_parts
        self.use_attention_masking = use_attention_masking

        # Part embeddings: produce the part query tokens.
        # (Sub-module construction order is kept fixed: it determines parameter
        # initialization order and state_dict key order.)
        self.part_embed = PartEmbed(dim)

        # Input cross-attention block: part queries attend to the projected latents.
        self.in_encode = PreNorm(
            dim, Attention(dim, dim, heads=in_heads, dim_head=dim), context_dim=dim
        )
        # Acts on the last axis of `latents` (size 8 -> 24). Hard-coded; presumably
        # maps the AE's 8 latent tokens to 24 context tokens. TODO: confirm.
        self.in_proj = nn.Linear(8, 24)

        # Stacked attention layers refining the part tokens.
        self.encoder_layers = StackedAttentionBlocks(
            dim, depth, heads, dim_head, weight_tie_layers
        )

        # Compress latents: dim -> dim, GEGLU (assumed to halve the width, matching
        # the next Linear's input of dim // 2), then -> latent_dim.
        self.compress_latents = nn.Sequential(
            nn.Linear(dim, dim),
            GEGLU(),
            nn.Linear(dim // 2, latent_dim),
        )

    def forward(self, latents, part_bbs, part_labels, batch_mask):
        """
        Encode part queries against the shape latents.

        Args:
            latents: shape latents; `in_proj` is applied on their last axis, so it
                must have size 8.
            part_bbs: part bounding boxes, forwarded to `PartEmbed`.
            part_labels: part labels, forwarded to `PartEmbed`.
            batch_mask: forwarded to `PartEmbed`. It is also used as the
                cross-attention mask when `use_attention_masking` is set.

        Returns:
            (part_latents, part_embeds): `part_latents` has `latent_dim` channels
            on its last axis; `part_embeds` are the part query embeddings.
        """
        # Only the combined part embeddings are needed. The separate label/box
        # embeddings returned by PartEmbed are discarded.
        part_embeds, _, _ = self.part_embed(part_bbs, part_labels, batch_mask)
        # Project the last axis (8 -> 24), then swap axes 1 and 2 to form the context.
        context = self.in_proj(latents).transpose(1, 2)

        # Encode part embeddings.
        attn_mask = batch_mask if self.use_attention_masking else None
        x = self.in_encode(part_embeds, context=context, mask=attn_mask)
        x = self.encoder_layers(x)
        part_latents = self.compress_latents(x)  # last axis: latent_dim

        return part_latents, part_embeds


In [8]:
import torch.nn as nn
from torch.nn import functional as F
from losses.partvae import KLRecLoss, RecLoss, ScaleInvariantLoss, PartDropLoss
from schedulefree import AdamWScheduleFree


def get_losses():
    """
    Build one instance of each training loss.

    Returns:
        tuple: ``(RecLoss(), KLRecLoss(), ScaleInvariantLoss(), PartDropLoss())``
        in exactly that order.
    """
    loss_classes = (RecLoss, KLRecLoss, ScaleInvariantLoss, PartDropLoss)
    return tuple(loss_cls() for loss_cls in loss_classes)


def forward_pass(
    pvae,
    data_tuple,
    rec_loss,
    kl_rec_loss,
    scale_inv_loss,
    part_drop_loss,
    pair_types,
):
    """
    Compute a single forward pass of the model and its loss terms.

    Args:
        pvae: the part-aware model. It must expose `.device` and `.is_vae`.
            When `is_vae` is true it returns `(logits, kl, part_latents)`,
            otherwise `(logits, part_latents)`.
        data_tuple: `(batch_pair_types, (l_a, bb_a, bb_l_a, meta_a),
            (l_b, bb_b, bb_l_b, meta_b))`. Part labels equal to -1 are masked out.
        rec_loss: reconstruction loss, called as `rec_loss(pred, target, transpose=True)`.
        kl_rec_loss: KL regularizer, called as `kl_rec_loss(kl, mask=mask)`.
        scale_inv_loss: invariance loss; unused while the invariance term is disabled.
        part_drop_loss: part-drop loss; unused while the invariance term is disabled.
        pair_types: currently unused. The per-batch pair type travels in `data_tuple`.

    Returns:
        dict with the scalar tensors "kl_reg" (0 for non-VAE models), "rec_loss"
        and "inv_loss" (currently always 0). It also contains "last_sample" =
        `(logits_a, l_a, bb_a, bb_l_a, mask_a)` for inspection.
    """
    # Unpack the batch. The batch's own pair type is kept under a separate name
    # so it does not silently shadow the `pair_types` argument.
    _batch_pair_types, (l_a, bb_a, bb_l_a, _meta_a), (l_b, bb_b, bb_l_b, _meta_b) = data_tuple
    device = pvae.device

    # Part slots labelled -1 are excluded through these masks.
    mask_a = (bb_l_a != -1).to(device)
    mask_b = (bb_l_b != -1).to(device)

    # Move everything to the model's device.
    # NOTE(review): exact latent layout (B x 512 x 8 vs B x 8 x 512) is unverified.
    l_a, l_b = l_a.to(device), l_b.to(device)
    bb_a, bb_b = bb_a.to(device), bb_b.to(device)
    bb_l_a, bb_l_b = bb_l_a.to(device), bb_l_b.to(device)

    def _run(latents, bbs, labels, mask):
        # Single model call with the keyword interface pvae expects.
        return pvae(latents=latents, part_bbs=bbs, part_labels=labels, batch_mask=mask)

    if pvae.is_vae:
        logits_a, kl_a, part_latents_a = _run(l_a, bb_a, bb_l_a, mask_a)
        logits_b, kl_b, part_latents_b = _run(l_b, bb_b, bb_l_b, mask_b)
        # KL regularizer averaged over the two shapes of the pair.
        kl_reg = (kl_rec_loss(kl_a, mask=mask_a) + kl_rec_loss(kl_b, mask=mask_b)) / 2.0
    else:
        logits_a, part_latents_a = _run(l_a, bb_a, bb_l_a, mask_a)
        logits_b, part_latents_b = _run(l_b, bb_b, bb_l_b, mask_b)
        kl_reg = torch.zeros((), device=device)

    # Reconstruction loss averaged over the two shapes. Out-of-place ops are used
    # so the `rec_loss` callable is not rebound to a tensor.
    rec = (
        rec_loss(logits_a, l_a, transpose=True) + rec_loss(logits_b, l_b, transpose=True)
    ) / 2.0

    # TODO: invariance term is disabled. To re-enable, dispatch on _batch_pair_types:
    #   NO_ROT_PAIR -> scale_inv_loss(part_latents_a, part_latents_b, mask_a)
    #   PART_DROP   -> part_drop_loss(part_latents_a, part_latents_b, bb_a, bb_b, mask_a, mask_b)
    # and add an MSE-to-zero magnitude penalty on part_latents_a and part_latents_b.
    inv_loss = torch.zeros((), device=device)

    return {
        "kl_reg": kl_reg,
        "rec_loss": rec,
        "inv_loss": inv_loss,
        "last_sample": (logits_a, l_a, bb_a, bb_l_a, mask_a),
    }


# Initialize the model.
# NOTE(review): `PartAwareAE` is neither defined nor imported in the cells shown;
# it must already exist in the kernel (e.g. from a deleted cell). Import it
# explicitly so the notebook survives Restart & Run All.
pvae = PartAwareAE(
    dim=512,
    latent_dim=512,
    heads=8,
    dim_head=64,
    depth=4,
).to(device)
# pvae = torch.compile(pvae, mode="max-autotune")
pvae.train(True)

# Loss weights and optimization settings.
KL_WEIGHT = 0.005
REC_WEIGHT = 1.0
INV_WEIGHT = 0.2
GRAD_CLIP_NORM = 5.0
N_ITERS = 1000
LOG_EVERY = 100

# Initialize the optimizer. Schedule-free optimizers keep separate train/eval
# parameter states; per the schedulefree docs they must be switched to train
# mode (alongside the model) before stepping.
optimizer = AdamWScheduleFree(pvae.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer.train()
optimizer.zero_grad()

rec_loss, kl_loss, scale_inv_loss, part_drop_loss = get_losses()

metric_logger = misc.MetricLogger(delimiter="  ")

for _epoch in metric_logger.log_every(list(range(N_ITERS)), LOG_EVERY):
    data_seen = False
    for data_tuple in data_loader_train:
        data_seen = True

        # Compute the loss terms.
        loss = forward_pass(
            pvae=pvae,
            data_tuple=data_tuple,
            rec_loss=rec_loss,
            kl_rec_loss=kl_loss,
            scale_inv_loss=scale_inv_loss,
            part_drop_loss=part_drop_loss,
            # Matches data_loader_train's pair_types_list; forward_pass does not
            # currently use this argument.
            pair_types=PairType.NO_ROT_PAIR,
        )
        total_loss = (
            KL_WEIGHT * loss["kl_reg"]
            + REC_WEIGHT * loss["rec_loss"]
            + INV_WEIGHT * loss["inv_loss"]
        )

        # One optimizer step per batch (no gradient accumulation).
        total_loss.backward()
        nn.utils.clip_grad_norm_(pvae.parameters(), GRAD_CLIP_NORM)
        optimizer.step()
        optimizer.zero_grad()

        # Log the losses.
        metric_logger.update(
            train_loss=float(total_loss.item()),
            kl_reg=float(loss["kl_reg"].item()),
            rec_loss=float(loss["rec_loss"].item()),
            inv_loss=float(loss["inv_loss"].item()),
        )
    assert data_seen, "No data seen in the training loop."

  [   0/1000]  eta: 0:00:47  train_loss: 0.8609 (0.8609)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.8609 (0.8609)  inv_loss: 0.0000 (0.0000)  time: 0.0475  data: 0.0000  max mem: 4823
  [ 100/1000]  eta: 0:00:35  train_loss: 0.0051 (0.0923)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.0051 (0.0923)  inv_loss: 0.0000 (0.0000)  time: 0.0394  data: 0.0000  max mem: 4823
  [ 200/1000]  eta: 0:00:31  train_loss: 0.0024 (0.0484)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.0024 (0.0484)  inv_loss: 0.0000 (0.0000)  time: 0.0396  data: 0.0000  max mem: 4823
  [ 300/1000]  eta: 0:00:27  train_loss: 0.0017 (0.0331)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.0017 (0.0331)  inv_loss: 0.0000 (0.0000)  time: 0.0401  data: 0.0000  max mem: 4823
  [ 400/1000]  eta: 0:00:23  train_loss: 0.0010 (0.0252)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.0010 (0.0252)  inv_loss: 0.0000 (0.0000)  time: 0.0401  data: 0.0000  max mem: 4823
  [ 500/1000]  eta: 0:00:19  train_loss: 0.0005 (0.0203)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.0005 (0.02

In [5]:
import numpy as np
import trimesh
import k3d
from util import s2vs


def get_args_parser():
    """
    Build the (help-less) argument parser for the autoencoder visualization cells.

    NOTE(review): this redefines the `get_args_parser` from the first cell. Once
    this cell runs, that parser is no longer reachable under this name.

    Returns:
        argparse.ArgumentParser: parser whose only required option is ``--ae_pth``.
    """
    parser = argparse.ArgumentParser("Autoencoder Visualization", add_help=False)
    arg_specs = (
        ("--ae_pth", {"required": True, "help": "Autoencoder checkpoint"}),
        ("--ae", {"type": str, "default": "kl_d512_m512_l8", "help": "Name of autoencoder"}),
        ("--ae-latent-dim", {"type": int, "default": 4096, "help": "AE latent dimension"}),
        ("--batch_size", {"default": 32, "type": int, "help": "Batch size"}),
        ("--num_workers", {"default": 8, "type": int, "help": "Number of workers for data loading"}),
        ("--device", {"default": "cuda", "help": "Device to use for computation"}),
        ("--seed", {"default": 0, "type": int, "help": "Random seed"}),
        ("--pin_mem", {"action": "store_true", "help": "Pin CPU memory in DataLoader"}),
    )
    # Registration order matches the declaration order above.
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser

def obb_to_corners(box_array):
    """
    Expand an oriented box into its 8 corner points.

    Args:
        box_array: indexable whose first four rows are `center`, `right`, `up`
            and `forward` 3-vectors.

    Returns:
        np.ndarray of shape (8, 3). Each corner is
        ``center + sx*right + sy*up + sz*forward`` with ``sx, sy, sz`` in {-1, 1}.
        Corners are ordered as the four z=-1 corners followed by the four z=+1
        corners.
    """
    center, right, up, forward = (np.array(box_array[k]) for k in range(4))
    signs = np.array([
        [-1, -1, -1], [ 1, -1, -1], [ 1,  1, -1], [-1,  1, -1],
        [-1, -1,  1], [ 1, -1,  1], [ 1,  1,  1], [-1,  1,  1],
    ])
    axes = np.column_stack((right, up, forward))
    offsets = signs @ axes.T
    return offsets + center

def create_trimesh_boxes(bounding_boxes):
    """
    Build one corner-marker mesh per non-empty box.

    Boxes whose entries are all zero (presumably padding) are skipped. For every
    other box, a sphere of radius 0.01 is placed at each corner given by
    `obb_to_corners`, and the spheres are merged with `trimesh.util.concatenate`.

    Returns:
        list with one concatenated mesh per kept box, in input order.
    """
    marker_meshes = []
    for box in bounding_boxes:
        if np.all(box == 0):
            continue
        corner_spheres = [
            trimesh.primitives.Sphere(radius=0.01).apply_translation(corner)
            for corner in obb_to_corners(box)
        ]
        marker_meshes.append(trimesh.util.concatenate(corner_spheres))
    return marker_meshes

def load_checkpoint(model, checkpoint_path):
    """
    Load the ``"model"`` entry of a checkpoint into `model` and return the model.

    The checkpoint is mapped to CPU while loading.
    NOTE: torch.load unpickles the file; only load trusted checkpoints.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    model.load_state_dict(state_dict)
    return model

def visualize_decoded_mesh(decoded_mesh, bounding_boxes):
    """
    Build a k3d plot of a decoded mesh overlaid with its part-box corner markers.

    Args:
        decoded_mesh: object exposing `vertices` and `faces` (e.g. a trimesh mesh).
        bounding_boxes: part boxes passed to `create_trimesh_boxes`, which skips
            all-zero boxes.

    Returns:
        k3d plot containing the decoded mesh (blue, opacity 0.5) plus one
        rainbow-coloured marker mesh per kept box (opacity 0.5).
    """
    trimesh_boxes = create_trimesh_boxes(bounding_boxes)

    plot = k3d.plot()
    plot += k3d.mesh(decoded_mesh.vertices, decoded_mesh.faces, color=0x0000ff, opacity=0.5)

    n_boxes = len(trimesh_boxes)
    if n_boxes == 0:
        # Nothing to colour; map_colors would fail reducing an empty array.
        return plot

    # Pin the colour range explicitly. With a single box, min == max and
    # map_colors' min/max normalization would divide by zero. For n >= 2 this
    # range equals the default (0, n - 1).
    col_map = k3d.helpers.map_colors(
        np.arange(n_boxes),
        k3d.colormaps.basic_color_maps.Rainbow,
        color_range=(0, max(n_boxes - 1, 1)),
    )
    for k, bb_mesh in enumerate(trimesh_boxes):
        plot += k3d.mesh(bb_mesh.vertices, bb_mesh.faces, color=int(col_map[k]), opacity=0.5)
    return plot

def show_mesh(ae, logits_a, bb_a):
    """
    Decode one latent with the autoencoder and plot it with the first sample's boxes.

    Args:
        ae: autoencoder exposing `.device`, passed to `s2vs.decode_latents`.
        logits_a: latents for exactly one shape (leading dimension must be 1).
        bb_a: batched part boxes; only `bb_a[0]` is drawn.

    Returns:
        the plot built by `visualize_decoded_mesh`.
    """
    assert logits_a.shape[0] == 1, "Can only visualize a single shape at a time"
    latents_on_device = logits_a.to(ae.device)
    decoded = s2vs.decode_latents(ae, latents_on_device, grid_density=128, batch_size=64**3)
    first_sample_boxes = np.array(bb_a[0].cpu())
    return visualize_decoded_mesh(decoded.trimesh_mesh, first_sample_boxes)

In [6]:
# Re-run the model on the last training sample for visualization.
# First switch the model and the schedule-free optimizer to evaluation mode; the
# optimizer's eval mode swaps in its evaluation weights. Afterwards restore
# training mode for both, so this cell leaves the state as it found it and the
# training cell can be re-run.
pvae.eval()
optimizer.eval()
try:
    with torch.inference_mode():
        _, l_a, bb_a, bb_l_a, mask_a = loss["last_sample"]
        logits_a, _ = pvae(
            latents=l_a, part_bbs=bb_a, part_labels=bb_l_a, batch_mask=mask_a
        )
finally:
    optimizer.train()
    pvae.train(True)

In [7]:
# Decode and display the first shape of the batch (show_mesh takes one shape at a time).
first_logits = logits_a[0][None]
show_mesh(ae, first_logits, bb_a)



Plot(antialias=3, axes=['x', 'y', 'z'], axes_helper=1.0, axes_helper_colors=[16711680, 65280, 255], background…