# Deterministic Flash Attention - Overnight Build & Test

**This notebook can run unattended overnight.**

It builds, verifies, and saves everything automatically without requiring runtime restart.

## Step 1: Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Clone latest version from GitHub (includes bug fix)
!cd /content && rm -rf flash-attention-deterministic
!git clone https://github.com/ProbioticFarmer/flash-attention-deterministic.git /content/flash-attention-deterministic
%cd /content/flash-attention-deterministic

# Install dependencies
!pip install -q ninja packaging

print("\n✓ Setup complete")
print(f"✓ Using latest code with long sequence bug fix")

## Step 2: Build (15-20 minutes)

In [None]:
import time
start_time = time.time()

!MAX_JOBS=16 pip install -e . --no-build-isolation -v 2>&1 | tee build.log

build_time = (time.time() - start_time) / 60
print(f"\n{'='*80}")
print(f"BUILD COMPLETE - took {build_time:.1f} minutes")
print(f"{'='*80}")

## Step 3: Save Build Artifacts

In [None]:
# Create save directory
!mkdir -p /content/drive/MyDrive/flash_attn_FINAL

# Save build log
!cp build.log /content/drive/MyDrive/flash_attn_FINAL/

# Save all compiled .so files
!find . -name "*.so" -exec cp {} /content/drive/MyDrive/flash_attn_FINAL/ \;

# Save build directory
!cp -r build /content/drive/MyDrive/flash_attn_FINAL/build_backup 2>/dev/null || echo "No build dir"

print("\n✓ Build artifacts saved to Drive")

# Verify compilation
import subprocess
flag_count = subprocess.run(['grep', '-c', 'DFLASH_ATTENTION_DETERMINISTIC', 'build.log'], 
                           capture_output=True, text=True)
count = int(flag_count.stdout.strip() if flag_count.returncode == 0 else '0')
print(f"\n-DFLASH_ATTENTION_DETERMINISTIC found: {count} times")

if count > 50:
    print("✅ Compilation looks good!")
else:
    print("⚠️  WARNING: Flag count is low")

## Step 4: Verification Tests

**NOTE:** These tests run in the same Python session as the build.
The library should work correctly without restart since we're not replacing an existing installation.

In [None]:
# Imports for the verification tests, plus a fresh import of the just-built package.
import gc
import importlib
import sys
import time

import torch

# Force reimport after build.
# Only the top-level package entry is dropped; already-loaded submodules are kept
# as they are, so nothing they executed at import time runs a second time.
if 'flash_attn' in sys.modules:
    del sys.modules['flash_attn']

# The package was built after this interpreter started. Refresh the import
# finders' cached directory listings so the newly created files are visible.
importlib.invalidate_caches()

from flash_attn import flash_attn_func, set_deterministic_mode

print("="*80)
print("VERIFICATION TESTS")
print("="*80)

### Test 1: Memory Allocation

In [None]:
def measure_memory(deterministic):
    """Measure the extra peak GPU memory of one attention forward call.

    Args:
        deterministic: passed as ``enabled`` to ``set_deterministic_mode``
            (always with ``split_size=512``).

    Returns:
        Peak allocated memory during the call minus the memory allocated just
        before it (i.e. with q/k/v already resident), in units of
        ``1024**2`` bytes.
    """
    bytes_per_mb = 1024**2
    # Tensor layout: (batch, seqlen, heads, headdim)
    shape = (4, 2048, 32, 64)

    # Start from a clean allocator state so the peak counter reflects this run only.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    set_deterministic_mode(enabled=deterministic, split_size=512)

    q, k, v = [
        torch.randn(*shape, dtype=torch.float16, device='cuda')
        for _ in range(3)
    ]

    torch.cuda.synchronize()
    allocated_before = torch.cuda.memory_allocated() / bytes_per_mb
    out = flash_attn_func(q, k, v, causal=False)
    torch.cuda.synchronize()
    peak_allocated = torch.cuda.max_memory_allocated() / bytes_per_mb

    # Release everything so the next measurement is not affected by this one.
    del q, k, v, out
    torch.cuda.empty_cache()
    return peak_allocated - allocated_before

# Test 1 driver: compare the extra memory of standard vs deterministic mode.
print("\nMEMORY TEST (B=4, L=2048, H=32, D=64):")
mem_std = measure_memory(False)   # standard mode is measured first
mem_det = measure_memory(True)

print(f"  Standard:      {mem_std:.2f} MB")
print(f"  Deterministic: {mem_det:.2f} MB")
print(f"  Increase:      {mem_det - mem_std:.2f} MB")

# The test passes only when deterministic mode uses more than 100 MB extra.
memory_pass = (mem_det - mem_std > 100)
if memory_pass:
    print("  ✅ PASS: Significant memory increase detected")
else:
    print("  ❌ FAIL: Memory increase too small")

### Test 2: Performance Overhead

In [None]:
# Test 2: wall-clock cost of deterministic mode relative to standard mode.
print("\nPERFORMANCE TEST (B=8, L=4096, H=32, D=64):")

# Tensor layout: (batch, seqlen, heads, headdim)
batch_size, seqlen, num_heads, head_dim = 8, 4096, 32, 64
q, k, v = [
    torch.randn(batch_size, seqlen, num_heads, head_dim, dtype=torch.float16, device='cuda')
    for _unused in range(3)
]

# Mode settings, in the order they are exercised: standard first, then deterministic.
mode_settings = (
    ('std', {'enabled': False}),
    ('det', {'enabled': True, 'split_size': 512}),
)

# Warmup: 10 untimed calls per mode
for _mode, _kwargs in mode_settings:
    set_deterministic_mode(**_kwargs)
    for _ in range(10):
        _ = flash_attn_func(q, k, v, causal=False)

# Timed runs: 50 calls per mode, bracketed by synchronize() so all queued GPU
# work is inside the measured interval. Result is the mean per call in ms.
mean_ms = {}
for _mode, _kwargs in mode_settings:
    set_deterministic_mode(**_kwargs)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(50):
        _ = flash_attn_func(q, k, v, causal=False)
    torch.cuda.synchronize()
    mean_ms[_mode] = (time.perf_counter() - t0) / 50 * 1000

time_std = mean_ms['std']
time_det = mean_ms['det']

overhead = ((time_det / time_std) - 1) * 100

print(f"  Standard:      {time_std:.3f} ms")
print(f"  Deterministic: {time_det:.3f} ms")
print(f"  Overhead:      {overhead:+.1f}%")

# Above 50% is tolerated with a warning; 5-50% is the expected band; anything
# lower is treated as a failure.
if overhead > 50:
    print("  ⚠️  WARNING: High overhead")
    perf_pass = True
elif overhead >= 5:
    print("  ✅ PASS: Reasonable overhead")
    perf_pass = True
else:
    print("  ⚠️  WARNING: Overhead suspiciously low")
    perf_pass = False

### Test 3: Batch Invariance

In [None]:
# Test 3: the output for a batch of 8 must equal, bit for bit, the concatenated
# outputs of its two halves processed separately.
print("\nBATCH INVARIANCE TEST (B=8 vs B=4+4):")

set_deterministic_mode(enabled=True, split_size=512)

# Tensor layout: (batch, seqlen, heads, headdim)
batch_size, seqlen, num_heads, head_dim = 8, 2048, 32, 64
q_full, k_full, v_full = [
    torch.randn(batch_size, seqlen, num_heads, head_dim, dtype=torch.float16, device='cuda')
    for _unused in range(3)
]

# Whole batch in a single call
out_full = flash_attn_func(q_full, k_full, v_full, causal=False)

# Same inputs as two independent calls (first four samples, then the rest)
out_split1 = flash_attn_func(q_full[:4], k_full[:4], v_full[:4], causal=False)
out_split2 = flash_attn_func(q_full[4:], k_full[4:], v_full[4:], causal=False)
out_split = torch.cat([out_split1, out_split2], dim=0)

# Exact equality decides the verdict; the max abs difference is only reported.
batch_invariant = torch.equal(out_full, out_split)
max_diff = (out_full - out_split).abs().max().item()

print(f"  Max difference: {max_diff:.2e}")

batch_pass = True if batch_invariant else False
if batch_pass:
    print("  ✅ PASS: Bit-exact batch invariance")
else:
    print(f"  ❌ FAIL: Results differ by {max_diff}")

## Step 5: Memory Scaling Analysis

In [None]:
# Step 5 measurement loop: peak GPU memory of standard vs deterministic mode at
# growing batch sizes. Stops at the first RuntimeError (CUDA OOM is one).
print("\nMEMORY SCALING ANALYSIS:")
print("Testing memory overhead at increasing scales (10x, 20x, 30x...)")
print("Test will continue until CUDA OOM or error occurs\n")

# Baseline parameters from Test 1
baseline_b, baseline_l = 4, 2048
baseline_mem_std = mem_std
baseline_mem_det = mem_det

# One dict per attempted scale; failed attempts carry None measurements and an 'error'.
scaling_results = []

# Test at multiples of baseline
for scale in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
    # Scale batch size while keeping sequence length manageable
    test_b = baseline_b * scale
    test_l = baseline_l

    print(f"Scale {scale}x (B={test_b}, L={test_l}):")

    peak_mb = {}
    try:
        for deterministic in (False, True):
            # Drop whatever these names still hold (a previous pass, or tensors
            # left behind by an earlier cell) BEFORE resetting the peak counter.
            q = k = v = out = None
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

            if deterministic:
                set_deterministic_mode(enabled=True, split_size=512)
            else:
                set_deterministic_mode(enabled=False)

            q = torch.randn(test_b, test_l, 32, 64, dtype=torch.float16, device='cuda')
            k = torch.randn(test_b, test_l, 32, 64, dtype=torch.float16, device='cuda')
            v = torch.randn(test_b, test_l, 32, 64, dtype=torch.float16, device='cuda')

            torch.cuda.synchronize()
            out = flash_attn_func(q, k, v, causal=False)
            torch.cuda.synchronize()
            # Absolute peak since the reset above (inputs included), in 1024**2-byte units.
            peak_mb[deterministic] = torch.cuda.max_memory_allocated() / (1024**2)

        mem_std_scaled = peak_mb[False]
        mem_det_scaled = peak_mb[True]
        overhead = mem_det_scaled - mem_std_scaled
        overhead_pct = ((mem_det_scaled / mem_std_scaled) - 1) * 100

        scaling_results.append({
            'scale': scale,
            'batch': test_b,
            'seqlen': test_l,
            'mem_std': mem_std_scaled,
            'mem_det': mem_det_scaled,
            'overhead_mb': overhead,
            'overhead_pct': overhead_pct,
            'success': True
        })

        print(f"  ✅ Standard: {mem_std_scaled:.1f} MB, Deterministic: {mem_det_scaled:.1f} MB")
        # Signed format: a negative overhead prints as "-x", not "+-x".
        print(f"     Overhead: {overhead:+.1f} MB ({overhead_pct:.1f}%)\n")

    except RuntimeError as e:
        error_msg = str(e)
        scaling_results.append({
            'scale': scale,
            'batch': test_b,
            'seqlen': test_l,
            'mem_std': None,
            'mem_det': None,
            'overhead_mb': None,
            'overhead_pct': None,
            'success': False,
            'error': error_msg
        })

        # Report the last scale that actually succeeded instead of assuming a step of 10.
        last_ok_scale = max((r['scale'] for r in scaling_results if r['success']), default=0)
        print(f"  ❌ FAILED: {error_msg[:80]}...")
        print(f"     Max scale reached: {last_ok_scale}x\n")
        break

    finally:
        # Runs on success AND on the failure path (before the break takes effect),
        # so partially allocated tensors never outlive the iteration.
        q = k = v = out = None
        gc.collect()
        torch.cuda.empty_cache()

# Leave no placeholder names behind once the sweep is over.
del q, k, v, out

# Print summary table: one row per attempted scale from `scaling_results`.
print("\n" + "="*80)
print("MEMORY SCALING SUMMARY")
print("="*80)
print(f"{'Scale':<8} {'Batch':<8} {'Std (MB)':<12} {'Det (MB)':<12} {'Overhead':<15} {'Status':<10}")
print("-"*80)

for r in scaling_results:
    if r['success']:
        # The '+' flag makes the number carry its own sign, so a negative overhead
        # is shown as "-x" rather than "+-x"; the field stays 8 characters wide.
        print(f"{r['scale']:>4}x    {r['batch']:<8} {r['mem_std']:>10.1f}  {r['mem_det']:>10.1f}  {r['overhead_mb']:>+8.1f} ({r['overhead_pct']:>5.1f}%)  {'✅ OK':<10}")
    else:
        # Failed attempts have no measurements to show.
        print(f"{r['scale']:>4}x    {r['batch']:<8} {'N/A':<10}  {'N/A':<10}  {'N/A':<15}  {'❌ FAIL':<10}")

print("="*80)

# Plot absolute and relative overhead against scale, if matplotlib can be imported.
# A plot is drawn only when at least two scales succeeded.
try:
    import matplotlib.pyplot as plt

    successful = [r for r in scaling_results if r['success']]

    if len(successful) > 1:
        scales = [r['scale'] for r in successful]
        overheads_mb = [r['overhead_mb'] for r in successful]
        overheads_pct = [r['overhead_pct'] for r in successful]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # (axis, y-values, line format, extra plot kwargs, y label, title)
        panel_specs = (
            (ax1, overheads_mb, 'o-', {},
             'Memory Overhead (MB)', 'Deterministic Mode: Absolute Memory Overhead'),
            (ax2, overheads_pct, 's-', {'color': 'orange'},
             'Memory Overhead (%)', 'Deterministic Mode: Relative Memory Overhead'),
        )
        for panel_ax, y_values, line_fmt, plot_extra, y_label, panel_title in panel_specs:
            panel_ax.plot(scales, y_values, line_fmt, linewidth=2, markersize=8, **plot_extra)
            panel_ax.set_xlabel('Scale Factor (×baseline)', fontsize=12)
            panel_ax.set_ylabel(y_label, fontsize=12)
            panel_ax.set_title(panel_title, fontsize=14)
            panel_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('/content/drive/MyDrive/flash_attn_FINAL/memory_scaling.png', dpi=150, bbox_inches='tight')
        print("\n✓ Memory scaling plot saved to Drive")
        plt.show()

except ImportError:
    print("\n(Matplotlib not available - skipping plot generation)")

# Persist the raw per-scale measurements to Drive as JSON.
import json
with open('/content/drive/MyDrive/flash_attn_FINAL/memory_scaling_data.json', 'w') as results_file:
    json.dump(scaling_results, results_file, indent=2)
print("✓ Memory scaling data saved to Drive")

# Verdict: the largest scale that completed must be at least 20x the baseline
# (0 when nothing succeeded).
max_successful_scale = max(
    (entry['scale'] for entry in scaling_results if entry['success']),
    default=0,
)
scaling_pass = not (max_successful_scale < 20)

print(f"\nMaximum successful scale: {max_successful_scale}x")
print("✅ PASS: Memory scaling test successful" if scaling_pass else "⚠️  WARNING: Low maximum scale")

## Step 6: Long Sequence Bug Fix Verification

## Step 7: Final Results

## Step 8: Create Wheel (Optional)

The wheel is built only if all verification tests passed, and is copied to Drive so it can be installed in future sessions.

In [None]:
if all_pass:
    print("Creating wheel package...")
    !python setup.py bdist_wheel
    !cp dist/*.whl /content/drive/MyDrive/flash_attn_FINAL/
    print("\n✓ Wheel saved to Drive")
    print("\nInstall in future sessions with:")
    print("  !pip install /content/drive/MyDrive/flash_attn_FINAL/*.whl")
else:
    print("⚠️  Skipping wheel creation - tests failed")