In [1]:
"""
Flash Attention Compatibility Test for RTX 4000 Ada
Run this BEFORE Experiment 6
"""

import torch
import sys

print("="*70)
print("FLASH ATTENTION COMPATIBILITY TEST")
print("="*70)

# Test 1: Basic GPU Info
print("\n[TEST 1] GPU Information:")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    compute_cap = torch.cuda.get_device_capability(0)
    print(f"Compute capability: {compute_cap[0]}.{compute_cap[1]}")
else:
    print("❌ CUDA not available! Check your environment.")
    sys.exit(1)

# Test 2: Check if scaled_dot_product_attention exists (PyTorch 2.0+)
print("\n[TEST 2] PyTorch Native Flash Attention:")
try:
    from torch.nn.functional import scaled_dot_product_attention
    print("✅ scaled_dot_product_attention found (PyTorch 2.0+)")
except ImportError:
    print("❌ scaled_dot_product_attention NOT found")
    print("   You need PyTorch 2.0+")
    sys.exit(1)

# Test 3: Test basic Flash Attention call
print("\n[TEST 3] Basic Flash Attention Test:")
try:
    Q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
    
    # Test with is_causal (what you need for language modeling)
    out = torch.nn.functional.scaled_dot_product_attention(
        Q, K, V, 
        is_causal=True,
        dropout_p=0.0
    )
    
    print(f"✅ Flash Attention works!")
    print(f"   Input shape: {Q.shape}")
    print(f"   Output shape: {out.shape}")
    print(f"   Memory used: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    
except Exception as e:
    print(f"❌ Flash Attention failed!")
    print(f"   Error: {str(e)}")
    print("\n   Possible causes:")
    print("   - GPU compute capability not supported")
    print("   - CUDA/PyTorch version mismatch")
    print("   - Driver issue")

# Test 4: Check which backends are available
print("\n[TEST 4] Available SDPA Backends:")
try:
    from torch.backends.cuda import sdp_kernel, SDPBackend
    
    # Create dummy tensors
    Q = torch.randn(1, 1, 128, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(1, 1, 128, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(1, 1, 128, 64, device='cuda', dtype=torch.float16)
    
    backends = {
        "FLASH_ATTENTION": SDPBackend.FLASH_ATTENTION,
        "EFFICIENT_ATTENTION": SDPBackend.EFFICIENT_ATTENTION,
        "MATH": SDPBackend.MATH
    }
    
    for name, backend in backends.items():
        try:
            with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
                # Enable only this backend
                if backend == SDPBackend.FLASH_ATTENTION:
                    with sdp_kernel(enable_flash=True):
                        torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
                elif backend == SDPBackend.EFFICIENT_ATTENTION:
                    with sdp_kernel(enable_mem_efficient=True):
                        torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
                else:
                    with sdp_kernel(enable_math=True):
                        torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
                print(f"✅ {name} available")
        except Exception as e:
            print(f"❌ {name} not available: {str(e)}")
            
except ImportError:
    print("⚠️  Cannot check backends (need PyTorch 2.0+)")

# Test 5: Memory comparison test
print("\n[TEST 5] Memory Comparison Test:")
print("Testing standard attention vs Flash Attention memory usage...")

batch_size = 4
seq_len = 512
d_model = 256
num_heads = 8
head_dim = d_model // num_heads

def standard_attention(Q, K, V):
    """Your current attention implementation"""
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
    
    # Causal mask
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device='cuda'))
    scores = scores.masked_fill(causal_mask == 0, float('-inf'))
    
    attn = torch.softmax(scores, dim=-1)
    out = torch.matmul(attn, V)
    return out

# Standard attention memory
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)

try:
    out_standard = standard_attention(Q, K, V)
    mem_standard = torch.cuda.max_memory_allocated() / 1e6
    print(f"Standard Attention: {mem_standard:.2f} MB")
except Exception as e:
    print(f"Standard Attention failed: {e}")
    mem_standard = 0

# Flash attention memory
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)

try:
    out_flash = torch.nn.functional.scaled_dot_product_attention(
        Q, K, V, is_causal=True
    )
    mem_flash = torch.cuda.max_memory_allocated() / 1e6
    print(f"Flash Attention: {mem_flash:.2f} MB")
    
    if mem_standard > 0:
        savings = ((mem_standard - mem_flash) / mem_standard) * 100
        print(f"Memory savings: {savings:.1f}%")
except Exception as e:
    print(f"Flash Attention failed: {e}")

# Test 6: Speed comparison
print("\n[TEST 6] Speed Comparison Test:")
import time

def benchmark(fn, Q, K, V, num_runs=100):
    # Warmup
    for _ in range(10):
        fn(Q, K, V)
    
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        fn(Q, K, V)
    torch.cuda.synchronize()
    end = time.time()
    
    return (end - start) / num_runs * 1000  # ms per run

Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)
V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float32)

try:
    time_standard = benchmark(standard_attention, Q, K, V)
    print(f"Standard Attention: {time_standard:.2f} ms/iteration")
except:
    print("Standard Attention: Failed")
    time_standard = 0

try:
    def flash_fn(Q, K, V):
        return torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    
    time_flash = benchmark(flash_fn, Q, K, V)
    print(f"Flash Attention: {time_flash:.2f} ms/iteration")
    
    if time_standard > 0:
        speedup = time_standard / time_flash
        print(f"Speedup: {speedup:.2f}x faster")
except:
    print("Flash Attention: Failed")

print("\n" + "="*70)
print("TEST COMPLETE")
print("="*70)

# Final verdict
print("\n[VERDICT]")
if torch.cuda.is_available() and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    try:
        Q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
        K = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
        V = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
        torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
        print("✅ Flash Attention WILL WORK on this GPU!")
        print("   You can proceed with Experiment 6.")
    except:
        print("❌ Flash Attention WILL NOT WORK on this GPU!")
        print("   Consider alternative GPUs or standard attention.")
else:
    print("❌ Requirements not met. Check PyTorch version and CUDA.")

FLASH ATTENTION COMPATIBILITY TEST

[TEST 1] GPU Information:
PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA version: 12.8
GPU name: NVIDIA RTX 4000 Ada Generation
GPU memory: 20.99 GB
Compute capability: 8.9

[TEST 2] PyTorch Native Flash Attention:
✅ scaled_dot_product_attention found (PyTorch 2.0+)

[TEST 3] Basic Flash Attention Test:
✅ Flash Attention works!
   Input shape: torch.Size([2, 8, 512, 64])
   Output shape: torch.Size([2, 8, 512, 64])
   Memory used: 0.00 GB

[TEST 4] Available SDPA Backends:
✅ FLASH_ATTENTION available
✅ EFFICIENT_ATTENTION available
✅ MATH available

[TEST 5] Memory Comparison Test:
Testing standard attention vs Flash Attention memory usage...


  self.gen = func(*args, **kwds)


Standard Attention: 86.11 MB
Flash Attention: 20.05 MB
Memory savings: 76.7%

[TEST 6] Speed Comparison Test:
Standard Attention: 1.01 ms/iteration
Flash Attention: 0.15 ms/iteration
Speedup: 6.63x faster

TEST COMPLETE

[VERDICT]
✅ Flash Attention WILL WORK on this GPU!
   You can proceed with Experiment 6.
