In [None]:
# Switch the working directory to the project checkout; presumably this is what
# lets the project-local imports in later cells resolve.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
%cd /ibex/user/slimhy/PADS/code
# (Re)load IPython's autoreload extension.
# NOTE(review): no `%autoreload <mode>` line is visible here -- confirm that
# automatic reloading is actually switched on somewhere.
%reload_ext autoreload
# cuBLAS workspace setting that PyTorch documents as required for deterministic
# CUDA matmuls when torch.use_deterministic_algorithms(True) is enabled.
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Invariance / equivariance sanity checks for the part-query model.
"""
import torch

# Ask PyTorch for reproducible kernels: deterministic cuDNN algorithms, no cuDNN
# auto-tuning, and an error whenever an op has no deterministic implementation.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)


def assert_close(tensor1, tensor2, rtol=1e-6, atol=1e-6):
    """
    Raise an AssertionError unless the two tensors are element-wise close.

    Closeness is decided by ``torch.allclose`` with the given tolerances.
    The error is raised explicitly instead of through an ``assert`` statement,
    so the check is still performed when Python strips assertions (``-O``).

    :param tensor1: first tensor
    :param tensor2: second tensor
    :param rtol:    relative tolerance forwarded to ``torch.allclose``
    :param atol:    absolute tolerance forwarded to ``torch.allclose``
    :raises AssertionError: if the tensors are not close; the message
                            contains both tensors
    """
    if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol):
        raise AssertionError(f"Tensors are not close: \n{tensor1}\n{tensor2}")


def test_latent_permutation_invariance(model):
    """
    Check that the first output of `model` (the part latents) is unchanged
    when the last axis of `latents` is shuffled.

    The model is run on random inputs, once with the original latents and once
    with the permuted ones. The outcome is printed (PASSED / FAILED plus the
    mismatch message); nothing is raised or returned.

    :param model: callable taking (latents, part_bbs, part_labels, batch_mask)
                  and returning a pair whose first element is compared
    """
    n_batch = 1
    n_parts = 24

    latents = torch.rand(n_batch, 512, 8)
    part_bbs = torch.rand(n_batch, n_parts, 4, 3)
    part_labels = torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long)
    batch_mask = torch.ones(n_batch, n_parts, dtype=torch.bool)

    reference, _ = model(latents, part_bbs, part_labels, batch_mask)

    # Shuffle the 8 entries of the last latent axis and run the model again
    shuffle = torch.randperm(8)
    shuffled_output, _ = model(latents[..., shuffle], part_bbs, part_labels, batch_mask)

    try:
        assert_close(reference, shuffled_output)
    except AssertionError as err:
        print("Part latents invariant to latents permutation: FAILED")
        print(str(err))
    else:
        print("Part latents invariant to latents permutation: PASSED")
        
        
def test_part_embeddings_equivariance(model):
    """
    Check that the second output of `model` (the part embeddings) is
    equivariant to a permutation of the parts.

    Permuting the parts of the inputs must give the same embeddings as
    permuting the embeddings obtained from the unpermuted inputs. The outcome
    is printed (PASSED / FAILED plus the mismatch message); nothing is raised
    or returned.

    :param model: callable taking (latents, part_bbs, part_labels, batch_mask)
                  and returning a pair whose second element is compared
    """
    n_batch = 1
    n_parts = 24

    latents = torch.rand(n_batch, 512, 8)
    part_bbs = torch.rand(n_batch, n_parts, 4, 3)
    part_labels = torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long)
    batch_mask = torch.ones(n_batch, n_parts, dtype=torch.bool)

    _, embeds = model(latents, part_bbs, part_labels, batch_mask)

    # Re-order the parts (boxes and labels together) and run the model again
    order = torch.randperm(n_parts)
    _, embeds_after = model(latents, part_bbs[:, order], part_labels[:, order], batch_mask)

    try:
        assert_close(embeds[:, order], embeds_after)
    except AssertionError as err:
        print("Part embeddings equivariant to parts permutation: FAILED")
        print(str(err))
    else:
        print("Part embeddings equivariant to parts permutation: PASSED")


def test_part_latents_equivariance(model):
    """
    Check that the first output of `model` (the part latents) is equivariant
    to a permutation of the parts.

    Permuting the parts of the inputs must give the same latents as permuting
    the latents obtained from the unpermuted inputs. The outcome is printed
    (PASSED / FAILED plus the mismatch message); nothing is raised or returned.

    :param model: callable taking (latents, part_bbs, part_labels, batch_mask)
                  and returning a pair whose first element is compared
    """
    n_batch = 1
    n_parts = 24

    latents = torch.rand(n_batch, 512, 8)
    part_bbs = torch.rand(n_batch, n_parts, 4, 3)
    part_labels = torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long)
    batch_mask = torch.ones(n_batch, n_parts, dtype=torch.bool)

    part_latents, _ = model(latents, part_bbs, part_labels, batch_mask)

    # The permutation length is taken from the model output, not from n_parts
    order = torch.randperm(part_latents.size(1))
    latents_after, _ = model(latents, part_bbs[:, order], part_labels[:, order], batch_mask)

    try:
        assert_close(part_latents[:, order], latents_after)
    except AssertionError as err:
        print("Part latents equivariant to parts permutation: FAILED")
        print(str(err))
    else:
        print("Part latents equivariant to parts permutation: PASSED")
        

def print_return_shapes(model):
    """
    Run `model` once on random inputs and print the shapes of its two outputs.

    :param model: callable taking (latents, part_bbs, part_labels, batch_mask)
                  and returning a pair of tensors
    """
    n_batch, n_parts = 1, 24

    # Built in call order: latents, part boxes, part labels, mask
    inputs = (
        torch.rand(n_batch, 512, 8),
        torch.rand(n_batch, n_parts, 4, 3),
        torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long),
        torch.ones(n_batch, n_parts, dtype=torch.bool),
    )

    part_latents, part_embeds = model(*inputs)

    report = (
        ("Part latents shape:", part_latents),
        ("Part embeddings shape:", part_embeds),
    )
    for title, tensor in report:
        print(title, tensor.shape)


def test_masked_elements_invariance(model):
    """
    Check that both outputs of `model` ignore the content of masked-out parts.

    A random, non-empty subset of parts is masked. The model is run once, the
    boxes and labels of the masked parts are replaced by fresh random values,
    and the model is run again with the same mask. Both outputs must be
    unchanged. The outcome is printed (PASSED / FAILED plus the mismatch
    message); nothing is raised or returned.

    :param model: callable taking (latents, part_bbs, part_labels, batch_mask)
                  and returning a pair of tensors
    """
    n_batch = 1
    n_parts = 24

    latents = torch.rand(n_batch, 512, 8)
    part_bbs = torch.rand(n_batch, n_parts, 4, 3)
    part_labels = torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long)

    # Mask between 1 and n_parts - 1 randomly chosen parts
    num_masked = torch.randint(1, n_parts, (1,)).item()
    batch_mask = torch.ones(n_batch, n_parts, dtype=torch.bool)
    hidden = torch.randperm(n_parts)[:num_masked]
    batch_mask[:, hidden] = False

    # Reference run
    latents_before, embeds_before = model(latents, part_bbs, part_labels, batch_mask)

    # Overwrite only the masked parts with new random boxes and labels
    noisy_bbs = part_bbs.clone()
    noisy_labels = part_labels.clone()
    noisy_bbs[:, hidden] = torch.rand_like(noisy_bbs[:, hidden])
    noisy_labels[:, hidden] = torch.randint(0, 10, (n_batch, num_masked), dtype=torch.long)

    # Second run: same mask, perturbed masked content
    latents_after, embeds_after = model(latents, noisy_bbs, noisy_labels, batch_mask)

    try:
        assert_close(latents_before, latents_after)
        assert_close(embeds_before, embeds_after)
    except AssertionError as err:
        print(f"Masked elements invariance (with {num_masked} masked elements): FAILED")
        print(str(err))
    else:
        print(f"Masked elements invariance (with {num_masked} masked elements): PASSED")

In [7]:
"""
Generating a set of part-aware latents from a set of part bounding boxes and part labels: "part queries".
"""
import torch
from torch import nn
from datasets.metadata import N_COMPAT_FINE_PARTS
from models.modules import (
    Attention,
    PointEmbed,
    PreNorm,
)


class PartEmbed(nn.Module):
    """
    Part-aware embeddings for part labels and bounding boxes.

    Every part is described by an integer label and a bounding box made of
    four 3D points: point 0 is embedded as the centroid, points 1-3 as
    vectors. Label and box are embedded separately (``dim // 2`` features
    each), concatenated and projected to ``dim`` features. Parts that are
    masked out ("empty") are replaced by learnable empty-object queries.
    """

    def __init__(
        self,
        dim,
        max_parts=24,
        single_learnable_query=False,
    ):
        """
        :param dim:       size of the final part embedding; must be even because
                          the label and box halves are ``dim // 2`` wide each
        :param max_parts: number of per-slot empty queries (only used when
                          ``single_learnable_query`` is False)
        :param single_learnable_query: share one empty query between all slots
                          instead of learning one query per slot
        :raises ValueError: if ``dim`` is odd (the two ``dim // 2`` halves could
                          never be concatenated into ``dim`` features)
        """
        super().__init__()

        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")

        self.embed_dim = dim
        self.max_parts = max_parts
        self.single_learnable_query = single_learnable_query

        half_dim = dim // 2

        # Embedding layers
        self.centroid_embed = PointEmbed(dim=half_dim)
        self.vector_embed = PointEmbed(dim=half_dim)
        self.part_label_embed = nn.Embedding(N_COMPAT_FINE_PARTS, half_dim)

        # Projections
        self.bb_embeds_proj = nn.Sequential(
            nn.Linear(4 * half_dim, half_dim), nn.ReLU(), nn.Linear(half_dim, half_dim)
        )
        self.final_embeds_proj = nn.Linear(dim, dim)

        # Learnable empty object queries
        if single_learnable_query:
            self.empty_object_query = nn.Parameter(torch.randn(dim))
        else:
            self.empty_queries = nn.Parameter(torch.randn(max_parts, dim))

    def forward(self, part_bbs, part_labels, batch_mask):
        """
        :param part_bbs:    B x N x 4 x 3 (centroid followed by three vectors)
        :param part_labels: B x N integer part labels
        :param batch_mask:  B x N, truthy for real parts, falsy for empty slots
                            (any dtype accepted by ``.bool()``)
        :return: tuple ``(part_embeds, labels_embed, bb_embeds)`` of shapes
                 B x N x dim, B x N x dim//2 and B x N x dim//2
        """
        B, N, _, _ = part_bbs.shape
        D = self.embed_dim // 2
        empty_mask = ~batch_mask.bool()

        # Embed centroids and vectors coordinates
        # ================================================
        # The embedders are fed B x M x 3 point sets, so the three vectors of
        # every part are flattened into one axis first.
        # (assumes they return B x M x D -- the reshape below relies on it)
        centroid_embeds = self.centroid_embed(part_bbs[:, :, 0, :])  # B x N x D
        vector_embeds = self.vector_embed(
            part_bbs[:, :, 1:, :].reshape(B, N * 3, 3)
        )  # B x 3N x D
        vector_embeds = vector_embeds.reshape(B, N, 3, -1)  # B x N x 3 x D

        # Per part: [centroid, v0, v1, v2] laid out side by side. Concatenating
        # (instead of filling a torch.empty buffer) keeps the input dtype.
        bb_embeds = torch.cat(
            [centroid_embeds.unsqueeze(2), vector_embeds], dim=2
        )  # B x N x 4 x D

        # Project the embeddings (vectors + centroids) to the same dimension
        bb_embeds = self.bb_embeds_proj(bb_embeds.reshape(B, N, 4 * D))  # B x N x D
        # ================================================

        # Embed part labels
        # ================================================
        # Masked slots are looked up as label 0; going through the boolean mask
        # keeps the indices integer even when batch_mask is a float tensor.
        labels_embed = self.part_label_embed(
            part_labels.masked_fill(empty_mask, 0)
        )  # B x N x D
        # ================================================

        # Final embeddings
        # ================================================
        part_embeds = torch.cat([labels_embed, bb_embeds], dim=-1)
        part_embeds = self.final_embeds_proj(part_embeds)  # B x N x dim

        # Replace masked (empty) objects with the corresponding learnable empty object query
        if self.single_learnable_query:
            part_embeds = torch.where(
                empty_mask.unsqueeze(-1),
                self.empty_object_query.to(part_embeds.dtype),
                part_embeds,
            )
        else:
            # Slot i uses query i. Only the first min(N, max_parts) slots have a
            # query, so inputs with fewer parts than max_parts are accepted and
            # slots beyond max_parts are left untouched.
            n_slots = min(N, self.max_parts)
            part_embeds[:, :n_slots] = torch.where(
                empty_mask[:, :n_slots].unsqueeze(-1),
                self.empty_queries[:n_slots].to(part_embeds.dtype),
                part_embeds[:, :n_slots],
            )
        # ================================================

        return part_embeds, labels_embed, bb_embeds


class PQM(nn.Module):
    """
    Produces part-aware latents ("part queries") from part bounding boxes and
    part labels: the parts are embedded, projected, and then used as queries
    that attend to the shape latents.
    """

    def __init__(self, dim=512, max_parts=24, heads=8, in_heads=8, dim_head=64):
        """
        :param dim:       feature size of the part embeddings and of the latents
        :param max_parts: forwarded to ``PartEmbed``
        :param heads:     accepted for interface compatibility; not used here
        :param in_heads:  number of heads of the input cross-attention
        :param dim_head:  accepted for interface compatibility; not used here
        """
        super().__init__()

        self.max_parts = max_parts

        # Part embeddings and their linear projection
        self.part_embed = PartEmbed(dim, max_parts)
        self.embed_proj = nn.Linear(dim, dim)

        # Input cross-attention block (queries: parts, context: latents).
        # NOTE(review): the attention is built with dim_head=dim, while the
        # `heads` and `dim_head` arguments are never read -- confirm that this
        # is intended before changing it (it determines parameter shapes).
        cross_attention = Attention(dim, dim, heads=in_heads, dim_head=dim)
        self.in_encode = PreNorm(dim, cross_attention, context_dim=dim)

    def forward(self, latents, part_bbs, part_labels, batch_mask):
        """
        :param latents:     B x 512 x 8
        :param part_bbs:    B x 24 x 4 x 3
        :param part_labels: B x 24
        :param batch_mask:  B x 24
        :return: tuple ``(part_queries, part_embeds)``: the attended part
                 queries and the projected part embeddings fed to the attention
        """
        # Embed part labels and bounding boxes (the label-only and box-only
        # embeddings returned alongside are not needed here)
        raw_embeds, _labels_embed, _bb_embeds = self.part_embed(
            part_bbs, part_labels, batch_mask
        )

        # Move the feature axis last so the latents can serve as attention context
        context = latents.transpose(1, 2)  # B x 8 x 512

        # Linear projection of the part embeddings
        part_embeds = self.embed_proj(raw_embeds)  # B x 24 x 512

        # Cross-attend: part embeddings are the queries, latents the context
        part_queries = self.in_encode(part_embeds, context=context)  # B x 24 x 512
        return part_queries, part_embeds

    @property
    def device(self):
        """Device of the module's first parameter."""
        first_param = next(self.parameters())
        return first_param.device


In [None]:
import torch

# Seed the global RNG so that the random model initialisation and the random
# inputs drawn by the checks below are the same every time this cell is run.
torch.manual_seed(0)

# Initialize your model
part_latents_dim = 128  # NOTE(review): not referenced by any visible cell
pqm = PQM(
    dim=512,
    heads=8,
    dim_head=64,
)

# Run the tests (each one prints its own PASSED / FAILED line)
print_return_shapes(pqm)
print()
test_latent_permutation_invariance(pqm)
print()
test_part_embeddings_equivariance(pqm)
print()
test_part_latents_equivariance(pqm)
print()
test_masked_elements_invariance(pqm)

## Testing

In [None]:
import models.diffusion as dm
# Build the diffusion model
model = dm.kl_d512_m512_l8_d24_pq()

# Part queries for one random shape with 24 unmasked parts
part_queries, part_embeds = pqm(
    latents=torch.rand(1, 512, 8),
    part_bbs=torch.rand(1, 24, 4, 3),
    part_labels=torch.randint(0, 10, (1, 24), dtype=torch.long),
    batch_mask=torch.ones(1, 24, dtype=torch.bool),
)

print(part_queries.shape)

# Single denoiser call on an all-zero latent at timestep 5, conditioned on the
# part queries. The condition mask is all False.
# NOTE(review): the meaning of True/False in cond_mask is defined by the model
# and is not visible here -- confirm.
in_noised_latent = torch.zeros(1, 512, 8)
in_timestep = torch.tensor([5], dtype=torch.long)
model.model(
    x=in_noised_latent,
    t=in_timestep,
    cond=part_queries,
    cond_mask=torch.zeros(1, 24, dtype=torch.bool),
)

In [10]:
# NOTE(review): this repeats the import and model construction of the previous
# cell and rebinds `model` to a freshly built instance -- confirm that this cell
# is still needed.
import models.diffusion as dm
model = dm.kl_d512_m512_l8_d24_pq()