## Flash Attention Sanity Check

Run this notebook from the project root after installing `flash-attn>=2.0.0` on a CUDA-enabled runtime. The cells exercise the updated import paths for both the standalone UNet attention wrapper and the latent diffusion cross-attention module.


In [None]:
import torch

try:
    from src.cm.unet import QKVFlashAttention
except ImportError as err:
    raise RuntimeError("Make sure this notebook runs from the cd4mt repo root.") from err

device = 'cuda' if torch.cuda.is_available() else None
if device is None:
    raise RuntimeError('Flash attention 2.x requires a CUDA runtime with compatible drivers.')

embed_dim = 128
num_heads = 4
seq_len = 64
batch_size = 2

attn = QKVFlashAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    attention_dropout=0.0,
    causal=False,
    dtype=torch.float16,
)
attn = attn.to(device)

qkv = torch.randn(batch_size, 3 * embed_dim, seq_len, device=device, dtype=torch.float16)
with torch.no_grad():
    out = attn(qkv)

out.shape

torch.Size([2, 128, 64])

In [2]:
from ldm.modules.modules import attention as ldm_attn

ldm_attn.CrossAttention.use_flash_attention = True

cross = ldm_attn.CrossAttention(
    query_dim=embed_dim,
    heads=num_heads,
    dim_head=embed_dim // num_heads,
    dropout=0.0,
)
cross = cross.to(device)

x = torch.randn(batch_size, seq_len, embed_dim, device=device)

with torch.no_grad():
    cross_out = cross(x)

cross_out.shape


ModuleNotFoundError: No module named 'latent_diffusion'