In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Extracting features into HDF5 files for each split.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.points.pointbert_utils import fps
from models.points.encoders import pointbert_g512_d12_compat
from datasets.metadata import N_COMPAT_CLASSES, N_COMPAT_FINE_PARTS

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)


def gen_mock_data(
    batch_size=4,
    num_parts=24,
    num_latents=512,
    num_classes=250,
    get_part_points=False,
    num_points=4096,
    device="cuda",
):
    latents = torch.rand(batch_size, num_latents, 8)
    part_bbs = torch.rand(batch_size, num_parts, 4, 3)
    part_labels = torch.randint(
        0, num_classes, (batch_size, num_parts), dtype=torch.long
    )
    shape_cls = torch.randint(0, 10, (batch_size, 1), dtype=torch.long)
    batch_mask = torch.ones(batch_size, num_parts).bool()

    if get_part_points:
        part_points = torch.rand(batch_size, num_parts, num_points, 3)
        data_tup = latents, part_bbs, part_labels, part_points, shape_cls, batch_mask
    else:
        data_tup = latents, part_bbs, part_labels, shape_cls, batch_mask
    return tuple(map(lambda x: x.to(device), data_tup))


def is_close(tensor1, tensor2, rtol=1e-5, atol=1e-5):
    return torch.allclose(
        tensor1, tensor2, rtol=rtol, atol=atol
    )
    

latents, part_bbs, part_labels, part_points, shape_cls, batch_mask = gen_mock_data(get_part_points=True)


/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
"""
Defining a part-neural asset model.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.points.pointbert_utils import fps


class WeightedAggregation(nn.Module):
    def __init__(self, feature_dim=128):
        super().__init__()
        self.weight_layer = nn.Linear(feature_dim, 1)

    def forward(self, x):
        # Compute weights for each feature vector
        weights = self.weight_layer(x)  # B, N_p, 512, 1

        # Apply softmax to get a probability distribution
        weights = F.softmax(weights, dim=-1)  # B, N_p, 512, 1

        # Apply weights to input vectors and sum
        weighted_sum = torch.mean(weights * x, dim=2)  # B, N_p, 128

        return weighted_sum


class BoundingBoxTokenizer(nn.Module):
    def __init__(
        self, bb_input_dim=12, mlp_hidden_dim=64, mlp_output_dim=32, mlp_depth=3
    ):
        super().__init__()
        self.mlp = self._build_mlp(
            bb_input_dim, mlp_hidden_dim, mlp_output_dim, mlp_depth
        )
        self.output_dim = mlp_output_dim

    def _build_mlp(self, input_dim, hidden_dim, output_dim, depth):
        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())

        # Hidden layers
        for _ in range(depth - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())

        # Output layer
        layers.append(nn.Linear(hidden_dim, output_dim))

        return nn.Sequential(*layers)

    def forward(self, part_bbs):
        B, P, _, _ = part_bbs.shape  # B, P, 4, 3

        # Flatten the bounding boxes
        flattened_bbs = part_bbs.view(B, P, -1)  # B, P, 12

        # Process through MLP
        pose_tokens = self.mlp(flattened_bbs)  # B, P, mlp_output_dim

        return pose_tokens


class PartTokenizer(nn.Module):
    def __init__(
        self,
        pc_encoder,
        bb_input_dim=12,
        bb_hidden_dim=64,
        bb_output_dim=32,
        bb_mlp_depth=3,
        visual_feature_dim=128,
        out_dim=512,
    ):
        super().__init__()
        self.pc_encoder = pc_encoder
        self.bb_tokenizer = BoundingBoxTokenizer(
            bb_input_dim=bb_input_dim,
            mlp_hidden_dim=bb_hidden_dim,
            mlp_output_dim=bb_output_dim,
            mlp_depth=bb_mlp_depth,
        )
        self.visual_aggregator = WeightedAggregation(feature_dim=visual_feature_dim)
        self.output_dim = bb_output_dim + visual_feature_dim
        self.out_proj = nn.Linear(self.output_dim, out_dim)

    def forward(
        self,
        part_bbs,
        part_points,
        batch_mask,  # 1 if part is invalid, 0 otherwise
        shape_cls=None,
        num_samples=512,
        deterministic=False,
    ):
        batch_mask = (
            ~batch_mask
        )  # Invert mask to make 0 if part is invalid, 1 otherwise
        B, P = batch_mask.shape

        # Apply mask to bounding boxes and get tokens
        bb_tokens = self.bb_tokenizer(part_bbs)  # B, P, bb_output_dim
        bb_tokens = bb_tokens * batch_mask.unsqueeze(-1)  # Zero out invalid parts

        # Process point clouds
        N_s = num_samples
        resampled = self.subsample_parts(
            part_points, num_samples=N_s, deterministic=deterministic
        )
        _, top_fts = self.pc_encoder(resampled, cls_label=shape_cls)
        part_fts = top_fts.reshape(B, P, N_s, -1)

        # Zero out invalid parts in point features
        print(top_fts.shape)
        part_fts = part_fts * batch_mask.unsqueeze(-1).unsqueeze(-1)
        print(part_fts.shape)

        # Apply weighted aggregation to get visual tokens
        visual_tokens = self.visual_aggregator(part_fts)  # B, P, visual_feature_dim
        visual_tokens = visual_tokens * batch_mask.unsqueeze(
            -1
        )  # Zero out invalid parts

        # Concatenate bounding box tokens and visual tokens
        combined_tokens = torch.cat(
            [bb_tokens, visual_tokens], dim=-1
        )  # B, P, bb_output_dim + visual_feature_dim

        return self.out_proj(combined_tokens), (bb_tokens, visual_tokens)

    def load_encoder_checkpoint(self, ckpt_path):
        ckpt = torch.load(ckpt_path)
        self.pc_encoder.load_state_dict(ckpt)

    @staticmethod
    @torch.inference_mode()
    def subsample_parts(part_points, num_samples=512, deterministic=False):
        N_s = num_samples
        B, P, N, C = part_points.shape
        resampled = torch.zeros((B, P, N_s, C)).to(part_points.device)
        if deterministic:
            for i, p in enumerate(part_points):
                resampled[i] = p[:, :N_s]
        else:
            for i, p in enumerate(part_points):
                resampled[i] = fps(p, N_s)
        return resampled.view(B, P * N_s, C)


In [3]:
pc_encoder = pointbert_g512_d12_compat()
part_tokenizer = PartTokenizer(
    pc_encoder=pc_encoder,
    bb_input_dim=12,
    bb_hidden_dim=64,
    bb_output_dim=32,
    bb_mlp_depth=3,
    visual_feature_dim=128,
    out_dim=512,
).cuda()

ckpt_path = "/ibex/user/slimhy/PADS/code/ckpt/pointbert.pth"
part_tokenizer.load_encoder_checkpoint(ckpt_path)

part_tokens, _ = part_tokenizer(part_bbs, part_points, batch_mask=batch_mask, shape_cls=shape_cls)
print(part_tokens.shape)  # B, P, 160 (32 + 128)

torch.Size([4, 12288, 128])
torch.Size([4, 24, 512, 128])
torch.Size([4, 24, 512])


In [5]:
import models.diffusion as dm


def test_pc_encoder_permutation_equivariance():
    # Initialize encoder
    pc_encoder = dm.kl_d512_m512_l8_d24_passets().pqe.pc_encoder.cuda()
    
    # Create single shape point cloud
    points = torch.randn(1, 1024, 3).cuda()
    shape_cls = torch.zeros(1).long().cuda()
    
    # Get original output
    local_feat, global_feat = pc_encoder(points, cls_label=shape_cls)
    
    # Shuffle points with single permutation
    perm = torch.randperm(1024)
    points_shuffled = points[:, perm]
    
    # Get output with shuffled points
    local_feat_shuffled, global_feat_shuffled = pc_encoder(points_shuffled, cls_label=shape_cls)
    
    # Check local features are permuted the same way
    if not is_close(local_feat[:, perm], local_feat_shuffled):
        print("Local features are not equivariant")
    
    # Check global features don't change
    if not is_close(global_feat, global_feat_shuffled):
        print("Global features are not invariant")


def test_bb_tokenizer_permutation_equivariance():
    # Initialize the BoundingBoxTokenizer
    bb_tokenizer = BoundingBoxTokenizer(
        bb_input_dim=12,
        mlp_hidden_dim=64,
        mlp_output_dim=32,
        mlp_depth=3
    ).cuda()

    # Create mock data
    batch_size, num_parts = 2, 24
    part_bbs = torch.randn(batch_size, num_parts, 4, 3).cuda()

    # Get original output
    original_output = bb_tokenizer(part_bbs)

    # Shuffle parts
    perm = torch.randperm(num_parts)
    shuffled_part_bbs = part_bbs[:, perm]

    # Get output with shuffled parts
    shuffled_output = bb_tokenizer(shuffled_part_bbs)

    # Check if outputs are equivariant (i.e., shuffled in the same way)
    assert is_close(original_output[:, perm], shuffled_output), "Outputs are not equivariant to parts shuffling."
    print("BoundingBoxTokenizer is permutation equivariant w.r.t. parts shuffling.")


def test_part_tokenizer_permutation_equivariance():
    # Initialize the PartTokenizer
    part_tokenizer = dm.kl_d512_m512_l8_d24_passets().pqe.cuda()

    # Create mock data
    latents, part_bbs, part_labels, part_points, shape_cls, batch_mask = gen_mock_data(get_part_points=True)
    num_parts = part_bbs.shape[1]

    # Get original output
    original_output, _ = part_tokenizer(part_bbs, part_points, batch_mask=batch_mask, shape_cls=shape_cls)

    # Shuffle parts
    perm = torch.randperm(num_parts)
    shuffled_part_bbs = part_bbs[:, perm]
    shuffled_part_points = part_points[:, perm]

    # Get output with shuffled parts
    shuffled_output, _ = part_tokenizer(shuffled_part_bbs, shuffled_part_points, batch_mask=batch_mask, shape_cls=shape_cls)

    # Check if outputs are equivariant (i.e., shuffled in the same way)
    if not is_close(original_output[:, perm], shuffled_output):
        print("Outputs are not equivariant to parts shuffling.")
    else:
        print("PartTokenizer is permutation equivariant w.r.t. parts shuffling.")


test_pc_encoder_permutation_equivariance()
test_bb_tokenizer_permutation_equivariance()
test_part_tokenizer_permutation_equivariance()

Local features are not equivariant
Global features are not invariant
BoundingBoxTokenizer is permutation equivariant w.r.t. parts shuffling.
torch.Size([4, 24, 512, 128])
torch.Size([4, 24, 512, 128])
torch.Size([4, 24, 512, 128])
torch.Size([4, 24, 512, 128])
PartTokenizer is permutation equivariant w.r.t. parts shuffling.
