In [1]:
%cd /ibex/user/slimhy/PADS/code
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
import torch
from models.partvae import PartAwareVAE, PartAwareAE

BATCH_SIZE = 1
LATENT_DIM = 512
N_PARTS = 24
device = torch.device("cuda")

def sample_mask():
    # Set first n_to_mask parts to be masked in a mask for each sample
    n_to_mask = torch.randint(0, N_PARTS, (BATCH_SIZE,))
    mask = torch.zeros(BATCH_SIZE, N_PARTS).bool()
    for i in range(BATCH_SIZE):
        mask[i, :n_to_mask[i]] = True
    return mask


def init_model(use_vae):
    if use_vae:
        model_class = PartAwareVAE
    else:
        model_class = PartAwareAE
    pvae = model_class(
        dim=512,
        latent_dim=128,
        heads=8,
        dim_head=64,
        depth=2,
    ).to(device)
    pvae = pvae.to(device)
    pvae = pvae.eval()
    return pvae


def create_sample_data():
    # Create sample data
    l = torch.randn(BATCH_SIZE, LATENT_DIM, 8)
    bb = torch.randn(BATCH_SIZE, N_PARTS, 4, 3)
    bb_l = torch.randint(0, N_PARTS, (BATCH_SIZE, N_PARTS)).long()
    
    return l, bb, bb_l, sample_mask()


def clone_tuple(sample_data):
    return tuple([x.clone() for x in sample_data])


@torch.no_grad()
def test_deterministic(pvae_model, sample_data):
    l, bb, bb_l, mask = sample_data
    
    # Forward pass with original data
    logits1, kl1, part_latents1 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=mask, deterministic=True)

    logits2, kl2, part_latents2 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=mask, deterministic=True)
    
    print("\nDeterministic inference test")
    print("=="*20)
    print(f"- Logits are equal: {torch.allclose(logits1, logits2)}")
    

@torch.no_grad()
def test_masked_entries(pvae_model, sample_data):
    l, bb, bb_l, mask = sample_data
    
    # Forward pass with original data
    logits1, kl1, part_latents1 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=mask, deterministic=True)
    
    # Resample the mask
    new_mask = sample_mask().to(device)
    logits2, kl2, part_latents2 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=new_mask, deterministic=True)
    
    # Change masked values and verify that the output does NOT change
    bb[~new_mask] = 100.
    logits3, kl3, part_latents3 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=new_mask, deterministic=True)
    
    # Change masked values and verify that the output does NOT change
    bb[new_mask] = 100.
    logits4, kl4, part_latents4 = pvae_model(latents=l, part_bbs=bb, part_labels=bb_l, batch_mask=new_mask, deterministic=True)

    print("\nMasked entries test")
    print("=="*20)
    print(f"- Resampled mask gives different output: {not torch.allclose(logits1, logits2)}")
    print(f"- Changing masked entries gives same output: {torch.max(logits2 - logits3)}")
    print(f"- Changing unmasked entries gives diff. output: {torch.max(logits2 - logits4)}")


def test_permutation_invariance(pvae_model, sample_data):
    l, bb, bb_l, _ = [t.to(device) for t in create_sample_data()]
    mask = ~((torch.ones(BATCH_SIZE, N_PARTS))*0).bool().to(device)
    ctxt_short = torch.randn((BATCH_SIZE, LATENT_DIM, 8)).to(device)

    def encode_stuff(pvae, bb, bb_l, mask, ctxt):
        _, part_latents, _ = pvae.encode(latents=ctxt, part_bbs=bb, part_labels=bb_l, batch_mask=mask, deterministic=True)
        return part_latents

    # Forward pass with original data
    part_latents_1 = encode_stuff(pvae_model, bb, bb_l, mask, ctxt_short)

    rnd_perm = torch.randperm(N_PARTS)
    bb_perm = bb[:, rnd_perm, :, :]
    bb_l_perm = bb_l[:, rnd_perm]

    # Forward pass with permuted data
    part_latents_2 = encode_stuff(pvae_model, bb_perm, bb_l_perm, mask, ctxt_short)

    # Check if the outputs are the same (up to permutation)
    print("\nPermutation invariance test:")
    print("=="*20)
    print(f"- Part latents are invariant: {torch.allclose(part_latents_1, part_latents_2)}")
    print(f"- Part latents are equivariant: {torch.allclose(part_latents_1[:, rnd_perm, :], part_latents_2)}")

/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
from util.misc import set_all_seeds
set_all_seeds(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
torch.set_default_dtype(torch.float64)

# Create sample data
sample_data = [t.to(device) for t in create_sample_data()]

# Run tests
pvae_model = init_model(use_vae=True)
test_deterministic(pvae_model, sample_data)
test_masked_entries(pvae_model, sample_data)
test_permutation_invariance(pvae_model, sample_data)

Set seed to 0

Deterministic inference test
- Logits are equal: True

Masked entries test
- Resampled mask gives different output: True
- Changing masked entries gives same output: 0.0003500815745000885
- Changing unmasked entries gives diff. output: 2.0631369963220867

Permutation invariance test:
- Part latents are invariant: False
- Part latents are equivariant: True
