In [2]:
import random, time, statistics, sys
from typing import List
import torch
from torch import autocast  # Changed from torch.cuda.amp import autocast
# ==============================
# CONFIG — adjust as needed
# ==============================
from pathlib import Path

DATA_ROOT = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH/train_pdbch")
A3M_NAME  = "final_filtered_256_stripped.a3m"
N_SAMPLES = 16          # number of random MSAs to time
SEED      = 42          # reproducible sampling
DEVICE_GPU = "cuda"     # change to "cuda:1" etc. if needed

# --------------------------------------------------
# Helper functions
# --------------------------------------------------
def find_a3m_files(root: Path, name: str) -> List[Path]:
    return list(root.rglob(name))

def parse_a3m(path: Path) -> List[str]:
    """Return list of sequences (query first) from an A3M file."""
    seqs, seq = [], []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                if seq:
                    seqs.append("".join(seq))
                seq = []
            else:
                seq.append(line)
        if seq:
            seqs.append("".join(seq))
    if not (1 <= len(seqs) <= 256):
        raise ValueError(f"{path}: expected 1–256 sequences, got {len(seqs)}")
    return seqs

def timed_forward(model, batch_converter, seqs, device, use_autocast=False):
    """Return wall-clock seconds for one forward pass (sync'ed)."""
    msa = [(f"seq{i}", s) for i, s in enumerate(seqs)]
    _, _, tok = batch_converter([msa])
    tok = tok.to(device, non_blocking=True)

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()

    with torch.no_grad():
        ctx = autocast(
            device_type=device.type,
            enabled=use_autocast,
            dtype=torch.bfloat16,
        )
        with ctx:
            _ = model(tok, repr_layers=[12])["representations"][12]

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.perf_counter() - t0

def print_stats(label, times):
    mean, med = statistics.mean(times), statistics.median(times)
    sd = statistics.stdev(times) if len(times) > 1 else 0.0
    print(f"\n{label}  —  summary (sec)")
    print("-" * 40)
    print(f"  mean   : {mean:8.4f}")
    print(f"  median : {med:8.4f}")
    print(f"  stddev : {sd:8.4f}")
    print(f"  min    : {min(times):8.4f}")
    print(f"  max    : {max(times):8.4f}")

def run_variant(label, device, compile_model=False, use_autocast=False):
    import esm  # local import keeps notebook cell light
    print(f"\n===== {label} =====")
    # -------- load / compile (excluded from timing) --------
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model.eval().to(device)
    if compile_model:
        print("Compiling model …")
        model = torch.compile(model, mode="default", dynamic=True)
        # 1 dummy pass to finish compile graph (not timed)
        _ = model(torch.randint(0, 20, (1, 1, 16), device=device))
    batch_converter = alphabet.get_batch_converter()

    # -------- per-file timing ------------------------------
    times = []
    for p in SAMPLE_FILES:
        seqs = parse_a3m(p)
        dt = timed_forward(model, batch_converter, seqs, device, use_autocast)
        times.append(dt)
        print(f"{p.name:<35} {dt:8.4f} s")

    print_stats(label, times)

# --------------------------------------------------
# Pick files & run
# --------------------------------------------------
random.seed(SEED)
all_a3m = find_a3m_files(DATA_ROOT, A3M_NAME)
if len(all_a3m) < N_SAMPLES:
    raise RuntimeError(f"Found only {len(all_a3m)} A3M files (need {N_SAMPLES}).")
SAMPLE_FILES = random.sample(all_a3m, N_SAMPLES)
print(f"Selected {N_SAMPLES} random A3M files (seed={SEED}).")

# Devices
gpu = torch.device(DEVICE_GPU if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")

# Run the four variants
run_variant("GPU | eager FP32",          gpu, compile_model=False, use_autocast=False)
run_variant("GPU | eager BF16 (autocast)", gpu, compile_model=False, use_autocast=True)
run_variant("GPU | compiled + BF16",     gpu, compile_model=True,  use_autocast=True)
run_variant("CPU | eager FP32",          cpu, compile_model=False, use_autocast=False)

Selected 16 random A3M files (seed=42).

===== GPU | eager FP32 =====
final_filtered_256_stripped.a3m       0.6882 s
final_filtered_256_stripped.a3m       0.1456 s
final_filtered_256_stripped.a3m       0.2433 s
final_filtered_256_stripped.a3m       0.1363 s
final_filtered_256_stripped.a3m       0.2594 s
final_filtered_256_stripped.a3m       0.1704 s
final_filtered_256_stripped.a3m       0.1921 s
final_filtered_256_stripped.a3m       0.5438 s
final_filtered_256_stripped.a3m       0.4511 s
final_filtered_256_stripped.a3m       0.2002 s
final_filtered_256_stripped.a3m       1.2911 s
final_filtered_256_stripped.a3m       0.6022 s
final_filtered_256_stripped.a3m       0.3167 s
final_filtered_256_stripped.a3m       0.8935 s
final_filtered_256_stripped.a3m       0.4789 s
final_filtered_256_stripped.a3m       0.4189 s

GPU | eager FP32  —  summary (sec)
----------------------------------------
  mean   :   0.4395
  median :   0.3678
  stddev :   0.3150
  min    :   0.1363
  max    :   1.2911





final_filtered_256_stripped.a3m      38.6286 s


KeyboardInterrupt: 

In [2]:
import random, time, statistics
import torch
from torch import autocast
import numpy as np
from pathlib import Path
from typing import List, Tuple
import torch.nn.functional as F

# ==============================
# CONFIG
# ==============================
DATA_ROOT = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH/train_pdbch")
A3M_NAME = "final_filtered_256_stripped.a3m"
N_TEST_FILES = 16        # number of files for accuracy/speed testing
SEED = 42
DEVICE_GPU = "cuda"
MAX_SEQ_LIMIT = 32      # test with this many sequences vs full MSA

# --------------------------------------------------
# Helper functions
# --------------------------------------------------
def find_a3m_files(root: Path, name: str) -> List[Path]:
    return list(root.rglob(name))

def parse_a3m(path: Path, max_sequences: int = None) -> List[str]:
    """Return list of sequences from A3M file, optionally limited."""
    seqs, seq = [], []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                if seq:
                    seqs.append("".join(seq))
                    if max_sequences and len(seqs) >= max_sequences:
                        break
                seq = []
            else:
                seq.append(line)
        if seq and (not max_sequences or len(seqs) < max_sequences):
            seqs.append("".join(seq))
    
    if not seqs:
        raise ValueError(f"{path}: no sequences found")
    
    return seqs

def get_model_output(model, batch_converter, seqs, device, use_autocast=False):
    """Get model representations and timing."""
    msa = [(f"seq{i}", s) for i, s in enumerate(seqs)]
    _, _, tok = batch_converter([msa])
    tok = tok.to(device, non_blocking=True)

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()

    with torch.no_grad():
        ctx = autocast(
            device_type=device.type,
            enabled=use_autocast,
            dtype=torch.bfloat16,
        )
        with ctx:
            output = model(tok, repr_layers=[12])["representations"][12]

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    
    elapsed_time = time.perf_counter() - t0
    return output, elapsed_time

def calculate_similarity_metrics(tensor1, tensor2):
    """Calculate various similarity metrics between two tensors."""
    # Flatten tensors for easier comparison
    flat1 = tensor1.flatten().float()
    flat2 = tensor2.flatten().float()
    
    # Cosine similarity
    cos_sim = F.cosine_similarity(flat1.unsqueeze(0), flat2.unsqueeze(0)).item()
    
    # Mean squared error
    mse = F.mse_loss(flat1, flat2).item()
    
    # Mean absolute error
    mae = F.l1_loss(flat1, flat2).item()
    
    # Relative error (%)
    rel_error = (torch.abs(flat1 - flat2) / (torch.abs(flat1) + 1e-8)).mean().item() * 100
    
    # Pearson correlation
    corr = torch.corrcoef(torch.stack([flat1, flat2]))[0, 1].item()
    
    return {
        'cosine_similarity': cos_sim,
        'mse': mse,
        'mae': mae,
        'relative_error_pct': rel_error,
        'pearson_correlation': corr
    }

def test_accuracy_fp32_vs_bf16():
    """Test accuracy difference between FP32 and BF16."""
    print("=" * 60)
    print("ACCURACY TEST: FP32 vs BF16")
    print("=" * 60)
    
    import esm
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model.eval().to(device)
    batch_converter = alphabet.get_batch_converter()
    
    all_metrics = []
    
    for i, file_path in enumerate(test_files):
        print(f"\nFile {i+1}/{len(test_files)}: {file_path.name}")
        seqs = parse_a3m(file_path)
        print(f"  Sequences in MSA: {len(seqs)}")
        print(f"  Query length: {len(seqs[0])}")
        
        # FP32 run
        fp32_output, fp32_time = get_model_output(model, batch_converter, seqs, device, use_autocast=False)
        
        # BF16 run  
        bf16_output, bf16_time = get_model_output(model, batch_converter, seqs, device, use_autocast=True)
        
        # Calculate similarity metrics
        metrics = calculate_similarity_metrics(fp32_output.cpu(), bf16_output.cpu())
        metrics['speedup'] = fp32_time / bf16_time
        metrics['fp32_time'] = fp32_time
        metrics['bf16_time'] = bf16_time
        
        all_metrics.append(metrics)
        
        print(f"  FP32 time: {fp32_time:.4f}s")
        print(f"  BF16 time: {bf16_time:.4f}s") 
        print(f"  Speedup: {metrics['speedup']:.2f}x")
        print(f"  Cosine similarity: {metrics['cosine_similarity']:.6f}")
        print(f"  Pearson correlation: {metrics['pearson_correlation']:.6f}")
        print(f"  Relative error: {metrics['relative_error_pct']:.4f}%")
        print(f"  MSE: {metrics['mse']:.2e}")
    
    # Summary statistics
    print("\n" + "=" * 40)
    print("ACCURACY SUMMARY")
    print("=" * 40)
    avg_cosine = np.mean([m['cosine_similarity'] for m in all_metrics])
    avg_corr = np.mean([m['pearson_correlation'] for m in all_metrics])
    avg_rel_err = np.mean([m['relative_error_pct'] for m in all_metrics])
    avg_speedup = np.mean([m['speedup'] for m in all_metrics])
    
    print(f"Average cosine similarity: {avg_cosine:.6f}")
    print(f"Average Pearson correlation: {avg_corr:.6f}") 
    print(f"Average relative error: {avg_rel_err:.4f}%")
    print(f"Average speedup: {avg_speedup:.2f}x")
    
    return all_metrics

def test_sequence_limit_speed():
    """Test speed difference with limited vs full MSA sequences."""
    print("\n" + "=" * 60)
    print(f"SPEED TEST: Full MSA vs {MAX_SEQ_LIMIT} sequences (BF16)")
    print("=" * 60)
    
    import esm
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model.eval().to(device)
    batch_converter = alphabet.get_batch_converter()
    
    full_times = []
    limited_times = []
    
    for i, file_path in enumerate(test_files):
        print(f"\nFile {i+1}/{len(test_files)}: {file_path.name}")
        
        # Full MSA
        seqs_full = parse_a3m(file_path)
        _, time_full = get_model_output(model, batch_converter, seqs_full, device, use_autocast=True)
        
        # Limited MSA
        seqs_limited = parse_a3m(file_path, max_sequences=MAX_SEQ_LIMIT)
        _, time_limited = get_model_output(model, batch_converter, seqs_limited, device, use_autocast=True)
        
        speedup = time_full / time_limited
        
        full_times.append(time_full)
        limited_times.append(time_limited)
        
        print(f"  Full MSA ({len(seqs_full)} seqs): {time_full:.4f}s")
        print(f"  Limited MSA ({len(seqs_limited)} seqs): {time_limited:.4f}s")
        print(f"  Speedup: {speedup:.2f}x")
    
    # Summary
    print("\n" + "=" * 40)
    print("SEQUENCE LIMIT SUMMARY")
    print("=" * 40)
    avg_full = np.mean(full_times)
    avg_limited = np.mean(limited_times)
    avg_speedup = avg_full / avg_limited
    
    print(f"Average full MSA time: {avg_full:.4f}s")
    print(f"Average limited MSA time: {avg_limited:.4f}s")
    print(f"Average speedup: {avg_speedup:.2f}x")
    print(f"Memory/compute reduction: ~{(1 - MAX_SEQ_LIMIT/256)*100:.1f}%")

# --------------------------------------------------
# Main execution
# --------------------------------------------------
random.seed(SEED)
device = torch.device(DEVICE_GPU if torch.cuda.is_available() else "cpu")

# Find test files
all_a3m = find_a3m_files(DATA_ROOT, A3M_NAME)
if len(all_a3m) < N_TEST_FILES:
    raise RuntimeError(f"Found only {len(all_a3m)} A3M files (need {N_TEST_FILES}).")

test_files = random.sample(all_a3m, N_TEST_FILES)
print(f"Selected {N_TEST_FILES} random A3M files for testing (seed={SEED})")
print(f"Device: {device}")

# Run tests
accuracy_results = test_accuracy_fp32_vs_bf16()
test_sequence_limit_speed()

print("\n" + "=" * 60)
print("TESTING COMPLETE")
print("=" * 60)

Selected 16 random A3M files for testing (seed=42)
Device: cuda
ACCURACY TEST: FP32 vs BF16

File 1/16: final_filtered_256_stripped.a3m
  Sequences in MSA: 256
  Query length: 183
  FP32 time: 0.6552s
  BF16 time: 0.1555s
  Speedup: 4.21x
  Cosine similarity: 1.006351
  Pearson correlation: 0.999922
  Relative error: 14.5984%
  MSE: 1.12e-04

File 2/16: final_filtered_256_stripped.a3m
  Sequences in MSA: 256
  Query length: 82
  FP32 time: 0.1456s
  BF16 time: 0.0561s
  Speedup: 2.60x
  Cosine similarity: 1.001910
  Pearson correlation: 0.999928
  Relative error: 15.8799%
  MSE: 1.01e-04

File 3/16: final_filtered_256_stripped.a3m
  Sequences in MSA: 230
  Query length: 159
  FP32 time: 0.2435s
  BF16 time: 0.0895s
  Speedup: 2.72x
  Cosine similarity: 1.004292
  Pearson correlation: 0.999965
  Relative error: 11.7620%
  MSE: 5.37e-05

File 4/16: final_filtered_256_stripped.a3m
  Sequences in MSA: 256
  Query length: 75
  FP32 time: 0.1364s
  BF16 time: 0.0543s
  Speedup: 2.51x
  Cosin

In [7]:
# ================================================================
# MEM-EFFICIENT-SDPA SPEED TEST
#   – BF16 autocast in both cases
#   – compares PyTorch defaults vs explicit mem_efficient_sdp
# ================================================================
import os, random, time, statistics, torch, numpy as np
from pathlib import Path
from typing import List, Tuple
import random, time, statistics
import torch
from torch import autocast
import numpy as np
from pathlib import Path
from typing import List, Tuple
import torch.nn.functional as F

# ------------------------------
#  CONFIGURATION
# ------------------------------
DATA_ROOT        = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH/train_pdbch")
A3M_NAME         = "final_filtered_256_stripped.a3m"
N_TEST_FILES     = 16        # ≥ number of *.a3m to draw
SEED             = 42
DEVICE           = torch.device("cuda" if torch.cuda.is_available() else "cpu")

assert DEVICE.type == "cuda", "Need GPU for this test"
assert torch.cuda.get_device_capability(DEVICE)[0] >= 8, "Need Ampere+ SM 80"

def find_a3m_files(root: Path, name: str) -> List[Path]:
    return list(root.rglob(name))

def parse_a3m(path: Path, max_sequences: int = None) -> List[str]:
    """Return list of sequences from A3M file, optionally limited."""
    seqs, seq = [], []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                if seq:
                    seqs.append("".join(seq))
                    if max_sequences and len(seqs) >= max_sequences:
                        break
                seq = []
            else:
                seq.append(line)
        if seq and (not max_sequences or len(seqs) < max_sequences):
            seqs.append("".join(seq))
    
    if not seqs:
        raise ValueError(f"{path}: no sequences found")
    
    return seqs

def get_model_output(model, batch_converter, seqs, device, use_autocast=False):
    """Get model representations and timing."""
    msa = [(f"seq{i}", s) for i, s in enumerate(seqs)]
    _, _, tok = batch_converter([msa])
    tok = tok.to(device, non_blocking=True)

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()

    with torch.no_grad():
        ctx = autocast(
            device_type=device.type,
            enabled=use_autocast,
            dtype=torch.bfloat16,
        )
        with ctx:
            output = model(tok, repr_layers=[12])["representations"][12]

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    
    elapsed_time = time.perf_counter() - t0
    return output, elapsed_time

# ------------------------------
#  SDPA BACKEND CONTROL
# ------------------------------
def set_pytorch_defaults():
    """Leave PyTorch SDPA settings at their defaults (do nothing)."""
    # Don't touch any backends - let PyTorch decide
    print("    Using PyTorch default SDPA settings")

def set_mem_efficient_explicit():
    """Explicitly enable mem_efficient_sdp and disable math_sdp."""
    # Don't touch flash_sdp - leave it as is
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_math_sdp(False)
    print("    Explicitly enabled mem_efficient_sdp, disabled math_sdp")

def print_current_backends():
    """Print current state of SDPA backends."""
    flash_enabled = torch.backends.cuda.flash_sdp_enabled()
    mem_enabled = torch.backends.cuda.mem_efficient_sdp_enabled()
    math_enabled = torch.backends.cuda.math_sdp_enabled()
    print(f"    Current backends - Flash: {flash_enabled}, MemEff: {mem_enabled}, Math: {math_enabled}")

# ------------------------------
#  BENCHMARK DRIVER
# ------------------------------
def benchmark_sdpa_backends(test_files: List[Path]):
    import esm
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model.eval().to(DEVICE)
    batch_converter = alphabet.get_batch_converter()

    times_default, times_mem_efficient = [], []

    for i, fp in enumerate(test_files, 1):
        seqs = parse_a3m(fp)                 # full MSA
        print(f"[{i:02}/{len(test_files)}] {fp.name:45s}  {len(seqs):3d} seqs  len={len(seqs[0])}")

        # ---------- PyTorch Defaults ----------
        set_pytorch_defaults()
        print_current_backends()
        _, t_default = get_model_output(model, batch_converter, seqs, DEVICE, use_autocast=True)
        times_default.append(t_default)

        # ---------- Explicit Mem Efficient ----------
        set_mem_efficient_explicit()
        print_current_backends()
        _, t_mem_eff = get_model_output(model, batch_converter, seqs, DEVICE, use_autocast=True)
        times_mem_efficient.append(t_mem_eff)

        print(f"    PyTorch Default    : {t_default:7.3f}s")
        print(f"    Mem Efficient Expl : {t_mem_eff:7.3f}s   ratio {t_default/t_mem_eff:4.2f}×")
        print()

    # ---------- SUMMARY ----------
    avg_default = statistics.mean(times_default)
    avg_mem_eff = statistics.mean(times_mem_efficient)
    print("\n" + "-"*60)
    print(f"Average PyTorch Default    : {avg_default:7.3f}s")
    print(f"Average Mem Efficient Expl : {avg_mem_eff:7.3f}s")
    print(f"Overall ratio              : {avg_default/avg_mem_eff:4.2f}×  "
          f"(median {statistics.median(times_default)/statistics.median(times_mem_efficient):4.2f}×)")
    print("-"*60)

    # Print final backend state
    print("\nFinal backend state:")
    print_current_backends()

# ------------------------------
#  MAIN EXECUTION
# ------------------------------
random.seed(SEED)
all_a3m = find_a3m_files(DATA_ROOT, A3M_NAME)
if len(all_a3m) < N_TEST_FILES:
    raise RuntimeError(f"Need ≥{N_TEST_FILES} files, found {len(all_a3m)}")

test_files = random.sample(all_a3m, N_TEST_FILES)
print(f"Selected {N_TEST_FILES} random A3M files (seed={SEED})")
print(f"Device: {DEVICE}")

# Show initial backend state
print("\nInitial backend state:")
print_current_backends()
print()

benchmark_sdpa_backends(test_files)

Selected 16 random A3M files (seed=42)
Device: cuda

Initial backend state:
    Current backends - Flash: True, MemEff: False, Math: False

[01/16] final_filtered_256_stripped.a3m                256 seqs  len=183
    Using PyTorch default SDPA settings
    Current backends - Flash: True, MemEff: False, Math: False
    Explicitly enabled mem_efficient_sdp, disabled math_sdp
    Current backends - Flash: True, MemEff: True, Math: False
    PyTorch Default    :   0.090s
    Mem Efficient Expl :   0.090s   ratio 1.00×

[02/16] final_filtered_256_stripped.a3m                256 seqs  len=82
    Using PyTorch default SDPA settings
    Current backends - Flash: True, MemEff: True, Math: False
    Explicitly enabled mem_efficient_sdp, disabled math_sdp
    Current backends - Flash: True, MemEff: True, Math: False
    PyTorch Default    :   0.048s
    Mem Efficient Expl :   0.048s   ratio 1.00×

[03/16] final_filtered_256_stripped.a3m                230 seqs  len=159
    Using PyTorch default S