In [1]:
# NOTE(review): machine-specific absolute path; consider making the project root configurable.
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
# Loading the extension alone does not turn reloading on; mode 2 reloads all
# (non-excluded) modules before each cell runs, so edits under models/ are picked up.
%autoreload 2
# Needed by torch.use_deterministic_algorithms(True) for cuBLAS (CUDA >= 10.2);
# it must be set before CUDA is initialised, i.e. before any GPU work.
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Sanity checks for PartQueriesEncoder: permutation invariance/equivariance and masked-part invariance on random inputs.
"""
import torch

# Fixed seed so the random test inputs (and any random initialization done by
# later cells) are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Force deterministic kernels so repeated forward passes on identical inputs
# are comparable; nondeterministic kernels could make the permutation checks
# fail spuriously. On CUDA, use_deterministic_algorithms(True) additionally
# relies on CUBLAS_WORKSPACE_CONFIG being set before CUDA initialisation.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)


def assert_close(tensor1, tensor2, rtol=1e-5, atol=1e-5):
    """Raise ``AssertionError`` unless two tensors have equal shapes and are element-wise close.

    Args:
        tensor1: First tensor to compare.
        tensor2: Second tensor to compare.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Raises:
        AssertionError: if the shapes differ, or if any element violates
            ``|tensor1 - tensor2| <= atol + rtol * |tensor2|``.
    """
    # torch.allclose broadcasts its arguments, so e.g. shapes (1, N) and (M, N)
    # could be reported as "close"; require identical shapes explicitly.
    if tensor1.shape != tensor2.shape:
        raise AssertionError(
            f"Tensor shapes differ: {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )
    # Explicit raise instead of an `assert` statement, so the check is not
    # stripped when Python runs with -O.
    if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol):
        # Cast to float64 so the diff also works for integer/bool tensors.
        max_abs_diff = (
            (tensor1.to(torch.float64) - tensor2.to(torch.float64)).abs().max().item()
        )
        raise AssertionError(
            f"Tensors are not close (max abs diff {max_abs_diff:.3e}): \n{tensor1}\n{tensor2}"
        )


def test_latent_permutation_invariance(model):
    """Check that ``part_latents`` is unchanged when ``latents`` is shuffled along its last axis.

    Builds one random batch and runs ``model`` on it. It then runs the model
    again with ``latents[:, :, perm]`` for a random permutation ``perm`` of the
    last axis, keeping every other input identical. Prints PASSED if the two
    ``part_latents`` outputs match within ``assert_close`` tolerances, otherwise
    FAILED followed by the mismatch report.

    NOTE(review): assumes the last axis of ``latents`` indexes an unordered set
    (e.g. latent tokens) that the model should not depend on the order of —
    confirm against the model definition.

    Args:
        model: callable as ``model(latents, part_bbs, part_labels, batch_mask)``
            returning a ``(part_latents, part_embeds)`` pair.
    """
    batch_size, num_parts = 1, 24

    latents = torch.rand(batch_size, 512, 8)
    part_bbs = torch.rand(batch_size, num_parts, 4, 3)
    part_labels = torch.randint(0, 10, (batch_size, num_parts), dtype=torch.long)
    batch_mask = torch.ones(batch_size, num_parts).bool()

    # Forward passes only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        original_part_latents, _ = model(latents, part_bbs, part_labels, batch_mask)

        # Size the permutation from the tensor itself rather than a repeated
        # literal, so it cannot drift out of sync with the latent shape above.
        perm = torch.randperm(latents.shape[-1])
        permuted_part_latents, _ = model(
            latents[:, :, perm], part_bbs, part_labels, batch_mask
        )

    try:
        assert_close(original_part_latents, permuted_part_latents)
        print("Part latents invariant to latents permutation: PASSED")
    except AssertionError as e:
        print("Part latents invariant to latents permutation: FAILED")
        print(str(e))
        
        
def test_part_embeddings_equivariance(model):
    """Check that shuffling the input parts shuffles ``part_embeds`` the same way.

    Runs ``model`` on a random batch. It then runs the model on the same batch
    with every per-part input (``part_bbs``, ``part_labels``, ``batch_mask``)
    permuted by one random permutation ``perm`` of the parts axis. Prints PASSED
    if ``part_embeds[:, perm]`` from the first run matches the second run's
    ``part_embeds`` within ``assert_close`` tolerances, otherwise FAILED
    followed by the mismatch report.

    Args:
        model: callable as ``model(latents, part_bbs, part_labels, batch_mask)``
            returning a ``(part_latents, part_embeds)`` pair.
    """
    batch_size, num_parts = 1, 24

    latents = torch.rand(batch_size, 512, 8)
    part_bbs = torch.rand(batch_size, num_parts, 4, 3)
    part_labels = torch.randint(0, 10, (batch_size, num_parts), dtype=torch.long)
    batch_mask = torch.ones(batch_size, num_parts).bool()

    # Forward passes only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        _, part_embeds = model(latents, part_bbs, part_labels, batch_mask)

        perm = torch.randperm(num_parts)
        # Permute the mask together with the other per-part inputs. With the
        # all-True mask used here this is a no-op, but it keeps the inputs
        # aligned if a partial mask is ever used.
        _, permuted_part_embeds = model(
            latents,
            part_bbs[:, perm],
            part_labels[:, perm],
            batch_mask[:, perm],
        )

    try:
        assert_close(part_embeds[:, perm], permuted_part_embeds)
        print("Part embeddings equivariant to parts permutation: PASSED")
    except AssertionError as e:
        print("Part embeddings equivariant to parts permutation: FAILED")
        print(str(e))


def test_part_latents_equivariance(model):
    """Check that shuffling the input parts shuffles ``part_latents`` the same way.

    Equivariance along the parts axis only makes sense when the model returns
    one part latent per input part (``part_latents.shape[1] == num_parts``).
    When it does not, for example an encoder with a fixed number of output
    queries, the check is reported as SKIPPED instead of being run on
    misaligned inputs. Otherwise the function prints PASSED if
    ``part_latents[:, perm]`` from the original run matches the run on permuted
    parts, else FAILED followed by the mismatch report.

    Args:
        model: callable as ``model(latents, part_bbs, part_labels, batch_mask)``
            returning a ``(part_latents, part_embeds)`` pair.
    """
    batch_size, num_parts = 1, 24

    latents = torch.rand(batch_size, 512, 8)
    part_bbs = torch.rand(batch_size, num_parts, 4, 3)
    part_labels = torch.randint(0, 10, (batch_size, num_parts), dtype=torch.long)
    batch_mask = torch.ones(batch_size, num_parts).bool()

    # Forward passes only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        part_latents, _ = model(latents, part_bbs, part_labels, batch_mask)

        if part_latents.shape[1] != num_parts:
            print(
                "Part latents equivariant to parts permutation: SKIPPED "
                f"(model returns {part_latents.shape[1]} part latents "
                f"for {num_parts} input parts)"
            )
            return

        # The permutation acts on the *input* parts axis, so it must have
        # length num_parts. Sizing it by the output length would truncate
        # part_bbs/part_labels and misalign them with batch_mask.
        perm = torch.randperm(num_parts)
        permuted_part_latents, _ = model(
            latents,
            part_bbs[:, perm],
            part_labels[:, perm],
            batch_mask[:, perm],
        )

    try:
        assert_close(part_latents[:, perm], permuted_part_latents)
        print("Part latents equivariant to parts permutation: PASSED")
    except AssertionError as e:
        print("Part latents equivariant to parts permutation: FAILED")
        print(str(e))
        

def print_return_shapes(model):
    """Run ``model`` once on a random batch and print the shapes of its two outputs.

    Args:
        model: callable as ``model(latents, part_bbs, part_labels, batch_mask)``
            returning a ``(part_latents, part_embeds)`` pair.
    """
    n_batch = 1
    n_parts = 24
    # Positional inputs, in call order: latents, part_bbs, part_labels, batch_mask.
    inputs = (
        torch.rand(n_batch, 512, 8),
        torch.rand(n_batch, n_parts, 4, 3),
        torch.randint(0, 10, (n_batch, n_parts), dtype=torch.long),
        torch.ones(n_batch, n_parts).bool(),
    )
    latents_out, embeds_out = model(*inputs)
    for label, tensor in (
        ("Part latents shape:", latents_out),
        ("Part embeddings shape:", embeds_out),
    ):
        print(label, tensor.shape)

def test_masked_elements_invariance(model):
    """Check that masked-out parts have no influence on the model outputs.

    A random, non-empty, strict subset of the parts is masked out
    (``batch_mask`` set to ``False`` for them). The model is run once. It is
    then run again after re-drawing the bounding boxes and labels of exactly
    those masked parts. Both ``part_latents`` and ``part_embeds`` must agree
    between the two runs (via ``assert_close``). The outcome is printed as
    PASSED or FAILED, together with the number of masked parts.

    Args:
        model: callable as ``model(latents, part_bbs, part_labels, batch_mask)``
            returning a ``(part_latents, part_embeds)`` pair.
    """
    batch_size = 1
    num_parts = 24

    latents = torch.rand(batch_size, 512, 8)
    part_bbs = torch.rand(batch_size, num_parts, 4, 3)
    part_labels = torch.randint(0, 10, (batch_size, num_parts), dtype=torch.long)

    # Between 1 and num_parts - 1 parts are hidden, so at least one stays visible.
    num_masked = int(torch.randint(1, num_parts, (1,))[0])
    hidden = torch.randperm(num_parts)[:num_masked]
    batch_mask = torch.ones(batch_size, num_parts, dtype=torch.bool)
    batch_mask[:, hidden] = False

    ref_latents, ref_embeds = model(latents, part_bbs, part_labels, batch_mask)

    # Re-draw only the hidden parts; visible parts keep their original values.
    noisy_bbs = part_bbs.clone()
    noisy_labels = part_labels.clone()
    noisy_bbs[:, hidden] = torch.rand_like(noisy_bbs[:, hidden])
    noisy_labels[:, hidden] = torch.randint(
        0, 10, (batch_size, num_masked), dtype=torch.long
    )

    new_latents, new_embeds = model(latents, noisy_bbs, noisy_labels, batch_mask)

    verdict = f"Masked elements invariance (with {num_masked} masked elements)"
    try:
        assert_close(ref_latents, new_latents)
        assert_close(ref_embeds, new_embeds)
    except AssertionError as err:
        print(f"{verdict}: FAILED")
        print(str(err))
    else:
        print(f"{verdict}: PASSED")

/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [3]:
import torch
from models.part_queries import PQM, PQMShallow, PartQueriesEncoder


# Model under test. The test helpers build inputs with 24 parts; presumably
# input_length must match that (TODO confirm against PartQueriesEncoder).
part_latents_dim = 128
pqm = PQMShallow(
    dim=512,
    latent_dim=part_latents_dim,
    heads=8,
    dim_head=64,
    use_attention_masking=False,
)
model = PartQueriesEncoder(
    pqm=pqm,
    dim=part_latents_dim,
    input_length=24,
    output_length=8,
)
# Run the checks in inference mode so train-time stochastic layers (e.g. dropout,
# if the model has any) cannot make two forward passes differ.
# Assumes PartQueriesEncoder is a torch.nn.Module.
model.eval()

# Run the tests
print_return_shapes(model)

# test_part_latents_equivariance is intentionally not run here. It compares
# outputs along the parts axis, which requires one part latent per input part,
# but this encoder returns output_length=8 latents for input_length=24 parts.
for check in (
    test_latent_permutation_invariance,
    test_part_embeddings_equivariance,
    test_masked_elements_invariance,
):
    print()
    check(model)

Part latents shape: torch.Size([1, 8, 128])
Part embeddings shape: torch.Size([1, 24, 128])

Part latents invariant to latents permutation: PASSED

Part embeddings equivariant to parts permutation: PASSED


Masked elements invariance (with 19 masked elements): PASSED
