
cuequivariance_torch.attention_pair_bias returns None when return_z_proj is False #227

@eneskelestemur

Description

Not sure if this is intended behavior or a bug, but cuequivariance_torch.attention_pair_bias returns a 2-tuple whose second element is a placeholder None even when return_z_proj=False, instead of returning the output tensor alone. This doesn't match the example in the docs, which treats the return value as a single tensor.

From the docs: https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.attention_pair_bias.html

import torch
import cuequivariance
from cuequivariance_torch import attention_pair_bias

# Print torch and cuEquivariance versions
print("Torch version:", torch.__version__)
print("cuEquivariance version:", cuequivariance.__version__)

if torch.cuda.is_available():
    device = torch.device("cuda")
    batch_size, seq_len, num_heads, heads_dim, hidden_dim = 1, 32, 2, 32, 64
    query_len, key_len, z_dim = 32, 32, 16
    # Create input tensors on GPU
    s = torch.randn(batch_size, seq_len, hidden_dim,
                    device=device, dtype=torch.bfloat16)
    q = torch.randn(batch_size, num_heads, query_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    k = torch.randn(batch_size, num_heads, key_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    v = torch.randn(batch_size, num_heads, key_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    z = torch.randn(batch_size, query_len, key_len, z_dim,
                    device=device, dtype=torch.bfloat16)
    mask = torch.rand(batch_size, key_len,
                      device=device) < 0.5
    w_proj_z = torch.randn(num_heads, z_dim,
                           device=device, dtype=torch.bfloat16)
    w_proj_g = torch.randn(hidden_dim, hidden_dim,
                           device=device, dtype=torch.bfloat16)
    w_proj_o = torch.randn(hidden_dim, hidden_dim,
                           device=device, dtype=torch.bfloat16)
    w_ln_z = torch.randn(z_dim,
                         device=device, dtype=torch.bfloat16)
    b_ln_z = torch.randn(z_dim,
                         device=device, dtype=torch.bfloat16)
    # Perform operation
    output = attention_pair_bias(
        s=s,
        q=q,
        k=k,
        v=v,
        z=z,
        mask=mask,
        num_heads=num_heads,
        w_proj_z=w_proj_z,
        w_proj_g=w_proj_g,
        w_proj_o=w_proj_o,
        w_ln_z=w_ln_z,
        b_ln_z=b_ln_z,
        return_z_proj=False,
    )
    print(output[0].shape, output[1])  # Not from docs
    print(output.shape)  # torch.Size([1, 32, 64])

Output:

Torch version: 2.8.0+cu129
cuEquivariance version: 0.8.0
torch.Size([1, 32, 64]) None
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 53
     37 output = attention_pair_bias(
     38     s=s,
     39     q=q,
   (...)     50     return_z_proj=False,
     51 )
     52 print(output[0].shape, output[1])
---> 53 print(output.shape)  # torch.Size([1, 32, 64])

AttributeError: 'tuple' object has no attribute 'shape'
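
For reference, unpacking the return value works. Below is a minimal workaround sketch, assuming the function always returns a 2-tuple whose second element is None when return_z_proj=False, as observed in the output above:

# Workaround sketch: unpack the tuple instead of treating the result as a tensor.
# Assumption (not from the docs): the second element is always None when
# return_z_proj=False.
out, z_proj = attention_pair_bias(
    s=s, q=q, k=k, v=v, z=z, mask=mask,
    num_heads=num_heads,
    w_proj_z=w_proj_z, w_proj_g=w_proj_g, w_proj_o=w_proj_o,
    w_ln_z=w_ln_z, b_ln_z=b_ln_z,
    return_z_proj=False,
)
assert z_proj is None
print(out.shape)  # torch.Size([1, 32, 64])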
