In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """
    Build the command-line parser used by this notebook.

    The parser is created with ``add_help=False`` (no automatic ``-h/--help``
    option), presumably so it can be composed with other parsers.

    ``--ae_pth`` is the only required option. The latent-dimension option
    accepts both its original hyphenated spelling (``--ae-latent-dim``) and
    the underscore spelling used by every other flag (``--ae_latent_dim``);
    both store into ``ae_latent_dim``.

    :return: a configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU"
        " (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        # Original hyphenated flag first, underscore alias second; the
        # destination is pinned explicitly so both map to the same attribute.
        "--ae-latent-dim",
        "--ae_latent_dim",
        dest="ae_latent_dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient "
        "(sometimes) transfer to GPU.",
    )

    return parser


# Hard-coded argument string standing in for a real command line
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --data_path /ibex/project/c2273/PADS/3DCoMPaT \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Build the parser and parse the whitespace-separated tokens in one step
args = get_args_parser().parse_args(call_string.split())

# ---- runtime setup ----
device = torch.device(args.device)
misc.set_all_seeds(args.seed)
torch.backends.cudnn.benchmark = True
# -----------------------

# Look up the autoencoder constructor by name in the models module,
# restore its weights from the "model" entry of the checkpoint (loaded on CPU),
# then switch to eval mode, move to the target device and compile.
ae = vars(ae_mods)[args.ae]()
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.eval().to(device)
ae = torch.compile(ae, mode="max-autotune")


/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8
Set seed to 0


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader

class PairType:
    """
    String constants naming the pair types passed to the paired-shapes loader.

    Each value is a comma-separated pair of identifiers.
    """

    # Both elements of the pair use "rand_no_rot"
    NO_ROT_PAIR = "rand_no_rot,rand_no_rot"
    # "part_drop" paired with "orig"
    PART_DROP = "part_drop,orig"

# Shape-latent datasets for the "train" and "test" splits
dataset_train = ShapeLatentDataset(
    args.data_path,
    split="train",
    shuffle_parts=True,
    filter_n_ids=2,
)
dataset_val = ShapeLatentDataset(
    args.data_path,
    split="test",
    shuffle_parts=False,
    filter_n_ids=2,
)

# Paired-shapes loader over the training split (single process, no
# distributed sampling, only the NO_ROT_PAIR pair type)
data_loader_train = ComposedPairedShapesLoader(
    dataset_train,
    batch_size=4,
    pair_types_list=[PairType.NO_ROT_PAIR],
    num_workers=0,
    shuffle=True,
    use_distributed=False,
)

In [3]:
import torch
from torch import nn
from models.modules import (
    Attention,
    PreNorm,
)
from models.partqueries import PartEmbed


class PartQueriesGenerator(nn.Module):
    """
    Build "part queries": one latent vector per part slot, derived from part
    bounding boxes, part labels and a set of shape latents.

    The projected part embeddings are fed to an attention block as its input,
    with the (transposed, tiled) shape latents as its ``context``; the result
    is then projected to ``latent_dim``.

    NOTE(review): ``heads``, ``dim_head``, ``depth`` and ``weight_tie_layers``
    are accepted by the constructor but never read; ``max_parts`` is only
    stored on the module.
    """

    def __init__(
        self,
        dim=512,
        latent_dim=128,
        max_parts=24,
        heads=8,
        in_heads=1,
        dim_head=64,
        depth=2,
        weight_tie_layers=False,
        use_attention_masking=True,
    ):
        super().__init__()

        # Plain settings kept on the module
        self.latent_dim = latent_dim
        self.max_parts = max_parts
        self.use_attention_masking = use_attention_masking

        # Sub-modules are created in a fixed order (part_embed, embed_proj,
        # in_encode, compress_latents) so parameter registration and random
        # initialisation order stay stable.

        # Embedding of part labels / bounding boxes, followed by an MLP
        self.part_embed = PartEmbed(dim)
        self.embed_proj = self._two_layer_mlp(dim, dim, dim)

        # Pre-normalised attention of part embeddings over the latents
        # (the attention head width is `dim`, not the `dim_head` argument)
        cross_attention = Attention(dim, dim, heads=in_heads, dim_head=dim)
        self.in_encode = PreNorm(dim, cross_attention, context_dim=dim)

        # Final projection from `dim` to `latent_dim`
        self.compress_latents = self._two_layer_mlp(dim, dim, latent_dim)

    @staticmethod
    def _two_layer_mlp(in_dim, hidden_dim, out_dim):
        """Return a Linear -> ReLU -> Linear stack."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, latents, part_bbs, part_labels, batch_mask):
        """
        Shapes below are the ones annotated by the original author.

        :param latents:     B x 512 x 8
        :param part_bbs:    B x 24 x 4 x 3
        :param part_labels: B x 24
        :param batch_mask:  B x 24
        :return: tuple ``(part_latents, part_embeds)`` where ``part_latents``
                 has ``latent_dim`` as its last dimension and ``part_embeds``
                 is the output of the embedding MLP.
        """
        # Only the first of the three embedding outputs is used
        part_embeds, _labels_embed, _bb_embeds = self.part_embed(
            part_bbs, part_labels, batch_mask
        )

        # Swap the last two axes of the latents and tile them three times
        # along axis 1 (B x 24 x 512 for the documented input shape)
        context = latents.transpose(1, 2).repeat(1, 3, 1)

        # Project the part embeddings through the MLP (last dim stays `dim`)
        part_embeds = self.embed_proj(part_embeds)

        # Optionally forward the part mask to the attention block
        if self.use_attention_masking:
            attn_mask = batch_mask
        else:
            attn_mask = None

        attended = self.in_encode(part_embeds, context=context, mask=attn_mask)
        part_latents = self.compress_latents(attended)  # last dim: latent_dim
        return part_latents, part_embeds

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device


In [4]:
import torch.nn as nn


# Initialize the model and move it to the target device (once), in train mode
pvae = PartQueriesGenerator(
    dim=512,
    latent_dim=512,
    heads=8,
    dim_head=64,
    depth=4,
).to(device)
pvae.train(True)

# The model's device does not change while iterating: look it up once
# instead of on every batch.
device = pvae.device

# Smoke test: run a forward pass on every batch of the training loader.
# Only the first element of each pair is fed to the model; the second
# element is still moved to the device, as in the original cell.
data_seen = False
for data_tuple in data_loader_train:
    data_seen = True

    # Unpack the data tuple
    pair_types, (l_a, bb_a, bb_l_a, meta_a), (l_b, bb_b, bb_l_b, meta_b) = data_tuple

    # Compute the mask from batch labels: label -1 marks an unused part slot
    mask_a = (bb_l_a != -1).to(device)  # B x 24
    mask_b = (bb_l_b != -1).to(device)  # B x 24

    # NOTE(review): this cell originally annotated the latents as B x 8 x 512,
    # whereas PartQueriesGenerator.forward documents B x 512 x 8 -- TODO confirm.
    l_a, l_b = l_a.to(device), l_b.to(device)
    bb_a, bb_b = bb_a.to(device), bb_b.to(device)  # B x 24 x 4 x 3
    bb_l_a, bb_l_b = bb_l_a.to(device), bb_l_b.to(device)  # B x 24

    part_latents, part_embeds = pvae(
        latents=l_a, part_bbs=bb_a, part_labels=bb_l_a, batch_mask=mask_a
    )
assert data_seen, "No data seen in the training loop."

In [5]:
import torch
from torch import nn

# Make the equivariance checks below reproducible: request deterministic
# algorithms globally, and switch cuDNN from autotuning (benchmark was
# enabled in the first cell) to deterministic kernels.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def assert_close(tensor1, tensor2, rtol=1e-5, atol=1e-5):
    """
    Assert that two tensors are element-wise equal within tolerance.

    :param tensor1: first tensor
    :param tensor2: second tensor
    :param rtol:    relative tolerance forwarded to ``torch.allclose``
    :param atol:    absolute tolerance forwarded to ``torch.allclose``
    :raises AssertionError: when ``torch.allclose`` reports a mismatch; the
        message contains both tensors.
    """
    tensors_match = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)
    # The message is only formatted when the assertion actually fails
    assert tensors_match, f"Tensors are not close: \n{tensor1}\n{tensor2}"


def test_latent_permutation_invariance(model):
    """
    Check that the part latents returned by ``model`` do not change when the
    shape latents are shuffled along their last axis.

    Prints a PASSED/FAILED line (plus the assertion message on failure)
    instead of raising.

    :param model: callable taking ``(latents, part_bbs, part_labels, batch_mask)``
                  and returning a pair whose first element is compared.
    """
    n_shapes, n_parts = 1,  24

    # Random inputs, drawn in a fixed order; every part slot is marked valid
    shape_latents = torch.rand(n_shapes, 512, 8)
    boxes = torch.rand(n_shapes, n_parts, 4, 3)
    labels = torch.randint(0, 10, (n_shapes, n_parts), dtype=torch.long)
    all_valid = torch.ones(n_shapes, n_parts).bool()

    # Reference output with the latents in their original order
    reference, _ = model(shape_latents, boxes, labels, all_valid)

    # Same call with the last axis of the latents shuffled
    shuffle = torch.randperm(8)
    shuffled_out, _ = model(shape_latents[:, :, shuffle], boxes, labels, all_valid)

    try:
        assert_close(reference, shuffled_out)
    except AssertionError as err:
        print("Part latents invariant to latents permutation: FAILED")
        print(str(err))
    else:
        print("Part latents invariant to latents permutation: PASSED")
        
        
def test_part_embeddings_equivariance(model):
    """
    Check that permuting the parts (boxes and labels together) permutes the
    part embeddings returned by ``model`` in the same way.

    Prints a PASSED/FAILED line (plus the assertion message on failure)
    instead of raising.

    :param model: callable taking ``(latents, part_bbs, part_labels, batch_mask)``
                  and returning a pair whose second element is compared.
    """
    n_shapes, n_parts = 1, 24

    # Random inputs, drawn in a fixed order; every part slot is marked valid
    shape_latents = torch.rand(n_shapes, 512, 8)
    boxes = torch.rand(n_shapes, n_parts, 4, 3)
    labels = torch.randint(0, 10, (n_shapes, n_parts), dtype=torch.long)
    all_valid = torch.ones(n_shapes, n_parts).bool()

    # Reference embeddings with the parts in their original order
    _, embeds_ref = model(shape_latents, boxes, labels, all_valid)

    # Apply one random permutation to both boxes and labels
    order = torch.randperm(n_parts)
    boxes_perm = boxes[:, order, :, :]
    labels_perm = labels[:, order]

    _, embeds_perm = model(shape_latents, boxes_perm, labels_perm, all_valid)

    try:
        assert_close(embeds_ref[:, order, :], embeds_perm)
    except AssertionError as err:
        print("Part embeddings equivariant to parts permutation: FAILED")
        print(str(err))
    else:
        print("Part embeddings equivariant to parts permutation: PASSED")


def test_part_latents_equivariance(model):
    """
    Check that permuting the parts (boxes and labels together) permutes the
    part latents returned by ``model`` in the same way.

    Prints a PASSED/FAILED line (plus the assertion message on failure)
    instead of raising.

    :param model: callable taking ``(latents, part_bbs, part_labels, batch_mask)``
                  and returning a pair whose first element is compared.
    """
    n_shapes, n_parts = 1, 24

    # Random inputs, drawn in a fixed order; every part slot is marked valid
    shape_latents = torch.rand(n_shapes, 512, 8)
    boxes = torch.rand(n_shapes, n_parts, 4, 3)
    labels = torch.randint(0, 10, (n_shapes, n_parts), dtype=torch.long)
    all_valid = torch.ones(n_shapes, n_parts).bool()

    # Reference latents with the parts in their original order
    latents_ref, _ = model(shape_latents, boxes, labels, all_valid)

    # Apply one random permutation to both boxes and labels
    order = torch.randperm(n_parts)
    boxes_perm = boxes[:, order, :, :]
    labels_perm = labels[:, order]

    latents_perm, _ = model(shape_latents, boxes_perm, labels_perm, all_valid)

    try:
        assert_close(latents_ref[:, order, :], latents_perm)
    except AssertionError as err:
        print("Part latents equivariant to parts permutation: FAILED")
        print(str(err))
    else:
        print("Part latents equivariant to parts permutation: PASSED")


# Fresh, randomly initialised model on the default device for the checks
model = PartQueriesGenerator(
    dim=512,
    latent_dim=128,
    max_parts=24,
    heads=8,
    in_heads=1,
    dim_head=64,
    depth=2,
)

# Run each check in turn, with a blank line between consecutive reports
for check_idx, check in enumerate(
    (
        test_latent_permutation_invariance,
        test_part_embeddings_equivariance,
        test_part_latents_equivariance,
    )
):
    if check_idx > 0:
        print()
    check(model)

Part latents invariant to latents permutation: PASSED

Part embeddings equivariant to parts permutation: PASSED

Part latents equivariant to parts permutation: PASSED
