In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Extracting features into HDF5 files for each split.
"""
import torch
import torch.nn as nn
from einops import rearrange
from util.misc import default




class MaskableCrossAttention(nn.Module):
    """Multi-head attention whose keys/values can be masked per context position.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (or from ``default(context, x)`` when ``context`` is None).

    Args:
        query_dim: Feature size of ``x`` and of the returned tensor.
        context_dim: Feature size of ``context``. Defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size of each head; the q/k/v projections have width
            ``heads * dim_head``.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads

        if context_dim is None:
            context_dim = query_dim

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, context_mask=None):
        """
        Args:
            x: Query tensor of shape (B, N, query_dim).
            context: Context tensor of shape (B, L, context_dim). When None,
                ``default(context, x)`` is used — presumably ``x`` itself
                (self-attention); confirm against ``util.misc.default``.
            context_mask: Optional boolean tensor of shape (B, L); True marks a
                context position that must be ignored. It must have dtype bool
                (``Tensor.masked_fill`` rejects other dtypes).

        Returns:
            Tensor of shape (B, N, query_dim).

        Note:
            If *every* context position of a sample is masked, that sample's
            attention weights are all set to zero, so its output is ``to_out``
            applied to a zero vector instead of an average over masked values.

        Where:
            B: batch size
            N: sequence length of query
            L: sequence length of context
        """
        h = self.heads
        q = self.to_q(x)

        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Split heads: (B, *, h*d) -> (B*h, *, d), batch-major with the head index
        # varying fastest inside each batch entry.
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # sim shape: (B*h, N, L)
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        mask = None
        if context_mask is not None:
            # (B, L) -> (B, 1, L) -> (B*h, 1, L). repeat_interleave copies each
            # batch row h times in a row, matching the "(b h)" order above. The
            # singleton query axis broadcasts over N inside masked_fill.
            mask = context_mask.unsqueeze(1).repeat_interleave(h, dim=0)

            # Use the most negative finite value rather than -inf so the masked
            # logits stay finite (an all -inf row would make softmax return NaN).
            sim = sim.masked_fill(mask, -torch.finfo(sim.dtype).max)

        # attn shape: (B*h, N, L)
        attn = sim.softmax(dim=-1)

        if mask is not None:
            # A fully masked row has identical logits everywhere, so softmax would
            # spread weight uniformly over the *masked* values. Zero such rows so
            # masked positions can never reach the output.
            attn = attn.masked_fill(mask.all(dim=-1, keepdim=True), 0.0)

        # out shape: (B*h, N, d) -> (B, N, h*d)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        return self.to_out(out)



/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
import torch
from torch import randperm


def assert_close(tensor1, tensor2, rtol=1e-7, atol=1e-7, revert=False):
    """Assert two tensors are element-wise close, or with ``revert=True`` that they are not.

    Closeness is decided by ``torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)``.

    Raises:
        AssertionError: if the closeness check does not match the expectation;
            the message shows both tensors.
    """
    are_close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)
    if revert:
        expectation_met, reason = not are_close, "Tensors are close"
    else:
        expectation_met, reason = are_close, "Tensors are not close"
    assert expectation_met, f"{reason}: \n{tensor1}\n{tensor2}"

def test_query_permutation_equivariance():
    """Permuting the queries must permute the output rows the same way.

    Checks ``model(x[:, perm], context) == model(x, context)[:, perm]``.
    """
    layer = MaskableCrossAttention(query_dim=64, context_dim=64, heads=4)

    n_batch, n_tokens, width = 2, 8, 64
    queries = torch.randn(n_batch, n_tokens, width)
    ctx = torch.randn(n_batch, n_tokens, width)

    reference = layer(queries, ctx)

    # Shuffle only the query axis; the context stays in its original order.
    order = randperm(n_tokens)
    shuffled = layer(queries[:, order, :], ctx)

    assert_close(reference[:, order, :], shuffled)
    print("Query permutation equivariance test passed!")

def test_context_permutation_invariance():
    """Reordering the context tokens must not change the output.

    Attention sums over context positions, so ``model(x, context[:, perm])``
    should match ``model(x, context)`` up to floating-point rounding.
    """
    # Initialize model
    model = MaskableCrossAttention(query_dim=64, context_dim=64, heads=4)

    # Create random input
    batch_size, seq_len, dim = 2, 8, 64
    x = torch.randn(batch_size, seq_len, dim)
    context = torch.randn(batch_size, seq_len, dim)

    # Get output for original input
    out1 = model(x, context)

    # Create random permutation for context
    perm = randperm(seq_len)

    # Permute context
    context_perm = context[:, perm, :]

    # Get output for permuted context
    out2 = model(x, context_perm)

    # The output should be the same regardless of context permutation. Permuting
    # the key axis reorders the softmax normalisation and the weighted sum, so the
    # float32 results may differ by a few ulps; 1e-7 is below one ulp near 1.0.
    # Use the float32 tolerances that torch.testing.assert_close defaults to.
    assert_close(out1, out2, rtol=1.3e-6, atol=1e-5)
    print("Context permutation invariance test passed!")

def test_mask_invariance():
    """Check that masked context positions cannot influence the output.

    With a random context mask (True = masked) this verifies that:
      1. overwriting masked context vectors leaves the output unchanged;
      2. overwriting unmasked context vectors changes the output;
      3. overwriting query vectors changes the output (the mask only gates the
         context keys/values, never the queries).
    """
    # Initialize model
    model = MaskableCrossAttention(query_dim=64, context_dim=64, heads=4)

    # Create random input
    batch_size, seq_len, dim = 2, 8, 64
    x = torch.randn(batch_size, seq_len, dim)
    context = torch.randn(batch_size, seq_len, dim)

    # Create random mask (True means masked)
    mask = torch.rand(batch_size, seq_len) > 0.5
    # Pin one unmasked and one masked position per sample so every check is
    # meaningful for any random draw. With no unmasked position the layer has
    # no valid key to attend to, and check 1 would test that degenerate case
    # instead of masking. With no masked or no unmasked position, check 3 or
    # check 2 would overwrite nothing and fail spuriously.
    mask[:, 0] = False
    mask[:, -1] = True

    # Get output with original context
    out1 = model(x, context, mask)

    # Modify masked values in context
    modified_context = context.clone()
    modified_context[mask] = 100.0  # Set masked values to large number

    # Get output with modified context
    out2 = model(x, modified_context, mask)

    # The output should be the same regardless of masked values
    assert_close(out1, out2)
    print("Mask invariance test passed!")

    # Now modify non-masked values in context
    modified_context = context.clone()
    modified_context[~mask] = 100.0  # Set non-masked values to large number

    # Get output with modified context
    out3 = model(x, modified_context, mask)

    # The output should be different when non-masked values are modified
    assert_close(out1, out3, revert=True)
    print("Modifying non-masked values in context changes the output as expected!")

    # Now modify the query tensor and keep the context tensor the same.
    # The context mask is reused here only to pick query positions. This works
    # because queries and context both have length seq_len. The mask does not
    # gate queries, so changing them must change the output.
    modified_x = x.clone()
    modified_x[mask] = 100.0  # Set the selected query vectors to large number

    # Get output with modified query
    out4 = model(modified_x, context, mask)

    # The output should be different when query values are modified
    assert_close(out1, out4, revert=True)
    print("Modifying values in query changes the output as expected!")


# Reproducibility: force deterministic kernels and fix the global torch RNG.
# Deterministic cuBLAS also needs CUBLAS_WORKSPACE_CONFIG, which the first cell
# sets via %set_env.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)


# Run all tests in order; each prints a message on success and raises
# AssertionError on failure.
for run_test in (
    test_query_permutation_equivariance,
    test_context_permutation_invariance,
    test_mask_invariance,
):
    run_test()

Query permutation equivariance test passed!
Context permutation invariance test passed!
Mask invariance test passed!
Modifying non-masked values in context changes the output as expected!
Modifying values in query changes the output as expected!
