# Deterministic Flash Attention - Overnight Build & Test

**This notebook can run unattended overnight.**

It builds, verifies, and saves everything automatically, without requiring a runtime restart.

## Step 1: Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Extract to LOCAL filesystem
!unzip -q /content/drive/MyDrive/flash-attention-deterministic.zip -d /content/
%cd /content/flash-attention-deterministic

# Install dependencies
!pip install -q ninja packaging

print("\n✓ Setup complete")

## Step 2: Build (15-20 minutes)

In [None]:
import time
start_time = time.time()

!MAX_JOBS=16 pip install -e . --no-build-isolation -v 2>&1 | tee build.log

build_time = (time.time() - start_time) / 60
print(f"\n{'='*80}")
print(f"BUILD COMPLETE - took {build_time:.1f} minutes")
print(f"{'='*80}")

## Step 3: Save Build Artifacts

In [None]:
# Create save directory
!mkdir -p /content/drive/MyDrive/flash_attn_FINAL

# Save build log
!cp build.log /content/drive/MyDrive/flash_attn_FINAL/

# Save all compiled .so files
!find . -name "*.so" -exec cp {} /content/drive/MyDrive/flash_attn_FINAL/ \;

# Save build directory
!cp -r build /content/drive/MyDrive/flash_attn_FINAL/build_backup 2>/dev/null || echo "No build dir"

print("\n✓ Build artifacts saved to Drive")

# Verify compilation
import subprocess
flag_count = subprocess.run(['grep', '-c', 'DFLASH_ATTENTION_DETERMINISTIC', 'build.log'], 
                           capture_output=True, text=True)
count = int(flag_count.stdout.strip() if flag_count.returncode == 0 else '0')
print(f"\n-DFLASH_ATTENTION_DETERMINISTIC found: {count} times")

if count > 50:
    print("✅ Compilation looks good!")
else:
    print("⚠️  WARNING: Flag count is low")

## Step 4: Verification Tests

**NOTE:** These tests run in the same Python session as the build.
The library should work correctly without restart since we're not replacing an existing installation.

In [None]:
import gc
import importlib
import sys
import time

import torch

# Force reimport after build
# Deleting only the top-level 'flash_attn' entry leaves every already-imported
# 'flash_attn.*' submodule cached in sys.modules, so a re-import would silently
# reuse those stale module objects. Purge the package together with its submodules.
for module_name in list(sys.modules):
    if module_name == 'flash_attn' or module_name.startswith('flash_attn.'):
        del sys.modules[module_name]

# The package was installed while this interpreter was already running; clear the
# import finders' caches so files created by the build are visible.
importlib.invalidate_caches()

from flash_attn import flash_attn_func, set_deterministic_mode

print("="*80)
print("VERIFICATION TESTS")
print("="*80)

### Test 1: Memory Allocation

In [None]:
def measure_memory(deterministic):
    """Measure the extra peak CUDA memory of one `flash_attn_func` forward call.

    Args:
        deterministic: Passed as `enabled` to `set_deterministic_mode`
            (always with `split_size=512`).

    Returns:
        Peak allocated memory after the call minus the memory allocated just
        before it (i.e. with q, k, v already resident), in MB (1024**2 bytes).
    """
    bytes_per_mb = 1024**2
    # Layout used by this notebook: (batch, seqlen, heads, headdim)
    shape = (4, 2048, 32, 64)

    # Start from a clean allocator state so the peak statistic only reflects this run.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    set_deterministic_mode(enabled=deterministic, split_size=512)

    q, k, v = (
        torch.randn(*shape, dtype=torch.float16, device='cuda')
        for _ in range(3)
    )

    torch.cuda.synchronize()
    baseline_mb = torch.cuda.memory_allocated() / bytes_per_mb
    out = flash_attn_func(q, k, v, causal=False)
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb

    # Release the tensors so the next measurement is not affected by this one.
    del q, k, v, out
    torch.cuda.empty_cache()
    return peak_mb - baseline_mb

# Run the measurement once per mode: standard first, then deterministic.
print("\nMEMORY TEST (B=4, L=2048, H=32, D=64):")
mem_std = measure_memory(False)
mem_det = measure_memory(True)

print(f"  Standard:      {mem_std:.2f} MB")
print(f"  Deterministic: {mem_det:.2f} MB")
print(f"  Increase:      {mem_det - mem_std:.2f} MB")

# The test passes only when deterministic mode uses more than 100 MB of
# additional memory compared with standard mode.
memory_pass = (mem_det - mem_std) > 100
if memory_pass:
    print("  ✅ PASS: Significant memory increase detected")
else:
    print("  ❌ FAIL: Memory increase too small")

### Test 2: Performance Overhead

In [None]:
# Compare the average forward latency of standard vs. deterministic mode.
print("\nPERFORMANCE TEST (B=8, L=4096, H=32, D=64):")

# Layout used by this notebook: (batch, seqlen, heads, headdim)
batch_size, seqlen, num_heads, head_dim = 8, 4096, 32, 64
q, k, v = (
    torch.randn(batch_size, seqlen, num_heads, head_dim, dtype=torch.float16, device='cuda')
    for _ in range(3)
)

# Warmup: ten untimed calls in each mode before anything is measured.
set_deterministic_mode(enabled=False)
for _ in range(10):
    _ = flash_attn_func(q, k, v, causal=False)
set_deterministic_mode(enabled=True, split_size=512)
for _ in range(10):
    _ = flash_attn_func(q, k, v, causal=False)

# Standard mode: synchronize on both sides of the loop so queued GPU work is
# fully included in the measured interval.
set_deterministic_mode(enabled=False)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(50):
    _ = flash_attn_func(q, k, v, causal=False)
torch.cuda.synchronize()
t1 = time.perf_counter()
time_std = (t1 - t0) / 50 * 1000  # average milliseconds per call

# Deterministic mode, measured the same way.
set_deterministic_mode(enabled=True, split_size=512)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(50):
    _ = flash_attn_func(q, k, v, causal=False)
torch.cuda.synchronize()
t1 = time.perf_counter()
time_det = (t1 - t0) / 50 * 1000  # average milliseconds per call

# Relative slowdown of deterministic mode, in percent.
overhead = ((time_det / time_std) - 1) * 100

print(f"  Standard:      {time_std:.3f} ms")
print(f"  Deterministic: {time_det:.3f} ms")
print(f"  Overhead:      {overhead:+.1f}%")

# Only an overhead below 5% fails the test; above 50% still passes but warns.
perf_pass = overhead >= 5
if not perf_pass:
    print("  ⚠️  WARNING: Overhead suspiciously low")
elif overhead > 50:
    print("  ⚠️  WARNING: High overhead")
else:
    print("  ✅ PASS: Reasonable overhead")

### Test 3: Batch Invariance

In [None]:
# Check that one batch of 8 gives bit-identical output to two batches of 4.
print("\nBATCH INVARIANCE TEST (B=8 vs B=4+4):")

set_deterministic_mode(enabled=True, split_size=512)

# Layout used by this notebook: (batch, seqlen, heads, headdim)
batch_size, seqlen, num_heads, head_dim = 8, 2048, 32, 64
q_full, k_full, v_full = (
    torch.randn(batch_size, seqlen, num_heads, head_dim, dtype=torch.float16, device='cuda')
    for _ in range(3)
)

# Full batch
out_full = flash_attn_func(q_full, k_full, v_full, causal=False)

# Split batch: first four samples, then the remaining four, re-joined on dim 0.
out_split1 = flash_attn_func(q_full[:4], k_full[:4], v_full[:4], causal=False)
out_split2 = flash_attn_func(q_full[4:], k_full[4:], v_full[4:], causal=False)
out_split = torch.cat((out_split1, out_split2), dim=0)

# Check: exact equality decides the verdict; max_diff is reported for diagnosis only.
batch_invariant = torch.equal(out_full, out_split)
max_diff = (out_full - out_split).abs().max().item()

print(f"  Max difference: {max_diff:.2e}")

batch_pass = True if batch_invariant else False
if batch_pass:
    print("  ✅ PASS: Bit-exact batch invariance")
else:
    print(f"  ❌ FAIL: Results differ by {max_diff}")

## Step 5: Summary and Save Results

In [None]:
# Print the per-test verdicts, build a plain-text report and store it on Drive.
print("\n" + "="*80)
print("FINAL RESULTS")
print("="*80)

print(f"\n  Memory Test:      {'✅ PASS' if memory_pass else '❌ FAIL'}")
print(f"  Performance Test: {'✅ PASS' if perf_pass else '❌ FAIL'}")
print(f"  Batch Invariance: {'✅ PASS' if batch_pass else '❌ FAIL'}")

all_pass = memory_pass and perf_pass and batch_pass

if all_pass:
    print("\n" + "="*80)
    print("✅ ALL TESTS PASSED - DETERMINISTIC MODE WORKING!")
    print("="*80)
else:
    print("\n" + "="*80)
    print("❌ SOME TESTS FAILED")
    print("="*80)

# Relative memory increase. A zero baseline would raise ZeroDivisionError and
# lose the whole report, so fall back to a placeholder in that case.
if mem_std:
    mem_increase_pct = f"{((mem_det/mem_std - 1) * 100):.1f}%"
else:
    mem_increase_pct = "n/a"

# Save detailed results
results = f"""Deterministic Flash Attention - Build & Verification Results
{'='*70}

BUILD INFO:
  Build time: {build_time:.1f} minutes
  GPU: {torch.cuda.get_device_name(0)}
  CUDA: {torch.version.cuda}
  PyTorch: {torch.__version__}

MEMORY TEST (B=4, L=2048, H=32, D=64):
  Standard:      {mem_std:.2f} MB
  Deterministic: {mem_det:.2f} MB
  Increase:      {mem_det - mem_std:.2f} MB ({mem_increase_pct})
  Status: {'PASS' if memory_pass else 'FAIL'}

PERFORMANCE TEST (B=8, L=4096, H=32, D=64):
  Standard:      {time_std:.3f} ms
  Deterministic: {time_det:.3f} ms
  Overhead:      {overhead:+.1f}%
  Status: {'PASS' if perf_pass else 'FAIL'}

BATCH INVARIANCE TEST (B=8 vs B=4+4):
  Max difference: {max_diff:.2e}
  Bit-exact: {'YES' if batch_invariant else 'NO'}
  Status: {'PASS' if batch_pass else 'FAIL'}

OVERALL: {'✅ ALL TESTS PASSED' if all_pass else '❌ SOME TESTS FAILED'}
"""

print("\n" + results)

# Save to Drive
# The report contains non-ASCII symbols (✅ / ❌); write it as UTF-8 explicitly
# instead of relying on the platform's default text encoding.
with open('/content/drive/MyDrive/flash_attn_FINAL/verification_results.txt', 'w', encoding='utf-8') as f:
    f.write(results)

print("\n✓ Results saved to /content/drive/MyDrive/flash_attn_FINAL/verification_results.txt")

## Step 6: Create Wheel (Optional)

In [None]:
if all_pass:
    print("Creating wheel package...")
    !python setup.py bdist_wheel
    !cp dist/*.whl /content/drive/MyDrive/flash_attn_FINAL/
    print("\n✓ Wheel saved to Drive")
    print("\nInstall in future sessions with:")
    print("  !pip install /content/drive/MyDrive/flash_attn_FINAL/*.whl")
else:
    print("⚠️  Skipping wheel creation - tests failed")