
torch triangle_attention fails when sequence length is 3000 or larger #136

@lupengk

Description


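When I pass inputs with a sequence length of 3000 or more to `triangle_attention` (from `cuequivariance_torch`), the call returns without error and `output.shape` prints correctly, but the first operation that reads the output values (`torch.sum(output)`) fails with `CUDA error: an illegal memory access was encountered`. Because CUDA kernels are launched asynchronously and the output shape is known before the kernel finishes, the fault most likely happens inside the `triangle_attention` kernel itself and is only reported at the next call that waits on the GPU.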

OUTPUT:

```
torch.Size([1, 3000, 4, 3000, 32])
Traceback (most recent call last):
  File "test_triattn.py", line 14, in <module>
    print(torch.sum(output))
          ^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
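The error message recommends `CUDA_LAUNCH_BLOCKING=1`, which makes kernel launches synchronous so the traceback points at the call that actually faulted (presumably `triangle_attention` rather than `torch.sum`). A minimal launcher sketch, assuming the repro is saved as `test_triattn.py` as in the traceback; `blocking_env` and `rerun_blocking` are hypothetical helper names, not part of any library:

```python
import os
import subprocess
import sys
from pathlib import Path


def blocking_env() -> dict:
    """Copy of the current environment with synchronous CUDA kernel launches enabled."""
    env = dict(os.environ)  # copy, so this process's own environment is left untouched
    # Only takes effect if set before CUDA is initialized in the target process.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def rerun_blocking(script: str = "test_triattn.py") -> int:
    """Run `script` in a fresh interpreter with CUDA_LAUNCH_BLOCKING=1 and return its exit code."""
    # A fresh process matters: after an illegal memory access the CUDA context of the
    # faulting process stays in an error state, so retries there are not meaningful.
    result = subprocess.run([sys.executable, script], env=blocking_env())
    return result.returncode


if __name__ == "__main__" and Path("test_triattn.py").exists():
    print("exit code:", rerun_blocking())
```

From a shell, `CUDA_LAUNCH_BLOCKING=1 python test_triattn.py` has the same effect; the helper is only convenient when rerunning the repro repeatedly, for example at several sequence lengths.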

CODE:

```python
import math

import torch
from cuequivariance_torch import triangle_attention

# Sequence length at which the failure reproduces. Layouts as used here:
# q/k/v: [B, N, H, S, D], bias: [B, 1, H, S, S], mask: [B, N, 1, 1, S].
N = 3000

if torch.cuda.is_available():
    device = torch.device("cuda")
    q_x = torch.randn([1, N, 4, N, 32], dtype=torch.bfloat16, device=device, requires_grad=True)
    kv_x = torch.randn([1, N, 4, N, 32], dtype=torch.bfloat16, device=device, requires_grad=True)
    triangle_bias = torch.randn([1, 1, 4, N, N], dtype=torch.bfloat16, device=device, requires_grad=True)
    # Boolean mask with every key position marked valid.
    mask_bias = torch.ones([1, N, 1, 1, N], dtype=torch.bfloat16, device=device) > 0
    scale = 1 / math.sqrt(32)  # 1 / sqrt(head_dim)
    output = triangle_attention(q_x, kv_x, kv_x, triangle_bias, mask_bias, scale)
    print(output.shape)       # succeeds: the shape is known without waiting for the kernel
    print(torch.sum(output))  # the CUDA error is reported on this line
```
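A quick size check, offered only as an unconfirmed hypothesis: if the kernel computed memory offsets with signed 32-bit integers, it would fault once a tensor spans more than 2^31 - 1 bytes (or elements). The sketch below uses the `[1, N, 4, N, 32]` shape from the repro; `qkv_numel` and `first_overflowing_len` are hypothetical helpers, and nothing here inspects the cuequivariance kernel.

```python
# Unconfirmed hypothesis check: at what sequence length do the repro's tensors
# outgrow a signed 32-bit offset? Helper names are hypothetical.
INT32_MAX = 2**31 - 1


def qkv_numel(seq_len: int, batch: int = 1, heads: int = 4, head_dim: int = 32) -> int:
    """Element count of a [batch, seq_len, heads, seq_len, head_dim] tensor (q_x, kv_x, output)."""
    return batch * seq_len * heads * seq_len * head_dim


def first_overflowing_len(bytes_per_elem: int, limit: int = INT32_MAX) -> int:
    """Smallest seq_len whose tensor size (numel * bytes_per_elem) exceeds `limit`."""
    n = 1
    while qkv_numel(n) * bytes_per_elem <= limit:
        n += 1
    return n


if __name__ == "__main__":
    print(qkv_numel(3000))           # 1152000000 elements, below INT32_MAX
    print(qkv_numel(3000) * 2)       # 2304000000 bf16 bytes, above INT32_MAX
    print(first_overflowing_len(2))  # 2897: first length past the limit in bf16 bytes
    print(first_overflowing_len(1))  # 4096: first length past the limit in elements
```

At N = 3000 each bf16 q/k/v/output tensor is 2,304,000,000 bytes, just past the 2,147,483,647 limit, while its element count (1,152,000,000) is still below it. If N = 2896 runs cleanly and N = 2897 fails, that would point toward a byte-offset overflow; if the boundary sits elsewhere, this explanation can be ruled out.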
