In [2]:
import torch

torch.finfo(torch.bfloat16).min

-3.3895313892515355e+38

In [4]:
"""
Row-self-attention equivalence test:
OpenFold vs a “flatten-and-permute” PyTorch port
"""

import math, torch, numpy as np
import torch.nn as nn
from typing import Optional, List, Tuple

torch.manual_seed(0)
np.random.seed(0)

# ----------------------------------------------------------------------
#  OpenFold helper layers (trimmed to what the test needs)
# ----------------------------------------------------------------------
def _prod(nums):          # utility
    out = 1
    for n in nums: out *= n
    return out

def _fan(shape, fan="fan_in"):
    fan_out, fan_in = shape
    if fan == "fan_in":  return fan_in
    if fan == "fan_out": return fan_out
    return (fan_in + fan_out) / 2

def lecun_normal_(w):
    from scipy.stats import truncnorm
    f = _fan(w.shape, "fan_in")
    scale = 1. / max(1, f)
    std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
    vals = truncnorm.rvs(-2, 2, loc=0, scale=std, size=_prod(w.shape))
    with torch.no_grad(): w.copy_(torch.tensor(vals.reshape(w.shape)))

class OFLinear(nn.Linear):
    def __init__(self, inp, outp, bias=False):
        super().__init__(inp, outp, bias=bias)
        lecun_normal_(self.weight)
        if bias: nn.init.zeros_(self.bias)

class OFLayerNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.bias   = nn.Parameter(torch.zeros(d))
        self.eps    = eps
    def forward(self,x): return nn.functional.layer_norm(
        x, x.shape[-1:], self.weight, self.bias, self.eps)

# ----------------------------------------------------------------------
#  OpenFold Attention (row, no pair-bias)
# ----------------------------------------------------------------------
class OFAttention(nn.Module):
    def __init__(self, d, heads):
        super().__init__()
        dh = d // heads
        self.h = heads
        self.q = OFLinear(d, d, bias=False)
        self.k = OFLinear(d, d, bias=False)
        self.v = OFLinear(d, d, bias=False)
        self.o = OFLinear(d, d, bias=True)     # Wᵒ
        self.g = OFLinear(d, d, bias=True)     # Wᵍ
        self.sig = nn.Sigmoid()
        self.dh = dh
    def _split(self,t):        # [..., L, H*dh] -> [..., H, L, dh]
        t = t.view(*t.shape[:-1], self.h, self.dh)
        return t.transpose(-3,-2)
    def forward(self, x, mask=None):
        q,k,v = map(self._split,(self.q(x), self.k(x), self.v(x)))
        q = q / math.sqrt(self.dh)
        a = torch.matmul(q, k.transpose(-2,-1))   # [..., H, L, L]
        if mask is not None: a = a + mask         # additive (-inf) bias
        a = torch.softmax(a, dim=-1)
        o = torch.matmul(a, v)                    # [..., H, L, dh]
        o = o.transpose(-3,-2).reshape_as(x)
        o = self.o(o) * self.sig(self.g(x))
        return o

class OFRowAttention(nn.Module):
    def __init__(self, d, heads):
        super().__init__()
        self.norm = OFLayerNorm(d)
        self.att  = OFAttention(d, heads)
    def forward(self, m, mask=None):
        m_norm = self.norm(m)
        update = self.att(m_norm, mask)           # no pair-bias -> mask maybe None
        return m + update                         # residual inside

# ----------------------------------------------------------------------
#  Your (flatten-and-permute) implementation
# ----------------------------------------------------------------------
class RowwiseDropout(nn.Module):
    def __init__(self,p): super().__init__(); self.p=p
    def forward(self,x):
        if (not self.training) or self.p==0: return x
        B,N,*rest = x.shape
        mask = (torch.rand(B,N,1,1, device=x.device) > self.p).float()
        return x * mask / (1-self.p)

class IdentityLinear(nn.Linear):
    """acts like nn.Identity but still has .weight/.bias attrs"""
    def __init__(self, d):
        super().__init__(d,d, bias=False)
        nn.init.eye_(self.weight)
        for p in self.parameters(): p.requires_grad=False
    def forward(self,x): return x

class FlatRowAttention(nn.Module):
    def __init__(self, d, heads, p_drop=0.):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.mha  = nn.MultiheadAttention(
            d, heads, dropout=p_drop, batch_first=True, bias=False)
        self.mha.out_proj = IdentityLinear(d)     # <-   fixed!
        self.Wo   = nn.Linear(d,d, bias=True)
        self.Wg   = nn.Linear(d,d, bias=True)
        nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.Wg.weight); nn.init.ones_(self.Wg.bias)
        self.drop = RowwiseDropout(p_drop)
    def forward(self,x):                          # x:[B,N,L,D]
        B,N,L,D = x.shape
        x_n = self.norm(x)
        qkv = x_n.permute(0,2,1,3).reshape(B*L, N, D)
        y,_ = self.mha(qkv,qkv,qkv, need_weights=False)
        y = y.reshape(B,L,N,D).permute(0,2,1,3)
        y = self.Wo(y) * torch.sigmoid(self.Wg(x_n))
        return x + self.drop(y)                   # residual inside

# ----------------------------------------------------------------------
#  Helpers
# ----------------------------------------------------------------------
def make_data(B=2,N=16,L=32,D=128):
    msa  = torch.randn(B,N,L,D)
    mask = torch.ones(B,N,L)                      # could randomise later
    return msa, mask

def align_weights(of, flat):
    """copy weights so the two nets start identical"""
    flat.norm.weight.data.copy_(of.norm.weight)
    flat.norm.bias .data.copy_(of.norm.bias)
    # QKV
    qkv = flat.mha.in_proj_weight.chunk(3,0)
    of.att.q.weight.data.copy_(qkv[0])
    of.att.k.weight.data.copy_(qkv[1])
    of.att.v.weight.data.copy_(qkv[2])
    # Wo / Wg
    of.att.o.weight.data.copy_(flat.Wo.weight)
    of.att.o.bias .data.copy_(flat.Wo.bias)
    of.att.g.weight.data.copy_(flat.Wg.weight)
    of.att.g.bias .data.copy_(flat.Wg.bias)

def run_pair(B,N,L,D,H):
    msa,_ = make_data(B,N,L,D)
    of    = OFRowAttention(D,H).eval()
    flat  = FlatRowAttention(D,H).eval()
    align_weights(of, flat)
    with torch.no_grad():
        out_ref  = of(msa)             # [B,N,L,D]
        out_flat = flat(msa)           # same
    diff = (out_ref - out_flat).abs()
    print(f"B={B:2d} N={N:3d} L={L:3d} D={D:3d}  maxΔ={diff.max():.6g}  meanΔ={diff.mean():.6g}")
    return diff.max().item() < 1e-5

# ----------------------------------------------------------------------
#  Tests
# ----------------------------------------------------------------------
print("Testing equivalence …")
passed = True
for cfg in [(2,16,32,128,8),
            (1, 8,16, 64,4),
            (3,32,64,256,8)]:          # enlarge cautiously if RAM limited
    passed &= run_pair(*cfg)
print("\nFINAL RESULT :", "✅ SUCCESS" if passed else "❌ MISMATCH")


Testing equivalence …
B= 2 N= 16 L= 32 D=128  maxΔ=0  meanΔ=0
B= 1 N=  8 L= 16 D= 64  maxΔ=0  meanΔ=0
B= 3 N= 32 L= 64 D=256  maxΔ=0  meanΔ=0

FINAL RESULT : ✅ SUCCESS


In [3]:
import os, pathlib, subprocess, sys, importlib

# ── 1.  Clone Cutlass once if it isn't there ───────────────────────────────
cutlass_dir = pathlib.Path.home() / "cutlass"
if not cutlass_dir.exists():
    subprocess.check_call(["git", "clone", "--depth", "1",
                           "https://github.com/NVIDIA/cutlass", str(cutlass_dir)])

# ── 2.  Export CUTLASS_PATH for *this* Python process ──────────────────────
os.environ["CUTLASS_PATH"] = str(cutlass_dir)

# ── 3.  (Optional) ensure our change is permanent for future kernels ───────
home_bashrc = pathlib.Path.home() / ".bashrc"
line = f'export CUTLASS_PATH="{cutlass_dir}"\n'
if line not in home_bashrc.read_text():
    home_bashrc.open("a").write(line)


In [8]:
import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention as evo

B, N_seq, N_res, H, D = 1, 32, 128, 4, 32
Q = torch.randn(B, N_seq, N_res, H, D, dtype=torch.float16, device="cuda")
K = torch.randn_like(Q);  V = torch.randn_like(Q)
# build a -1e9 additive bias *in the same dtype as Q*
pad_bias = torch.full((B, N_seq, 1, 1, N_res),
                      fill_value=-1e4,
                      dtype=Q.dtype,          # float16 or bfloat16
                      device="cuda")
pair_bias = torch.zeros(B, 1, H, N_res, N_res, dtype=torch.float16, device="cuda")

out = evo(Q, K, V, [pad_bias, pair_bias])
print("Success →", out.shape)


Success → torch.Size([1, 32, 128, 4, 32])


In [4]:
import torch, deepspeed
print(torch.__version__)        # 2.1.2+cu121
import deepspeed.ops.deepspeed4science.evoformer_attn as ds_evo
print(ds_evo.__file__)          # confirm you’re importing the patched file


2.1.2+cu121
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/deepspeed/ops/deepspeed4science/evoformer_attn.py


In [1]:
import importlib, os, subprocess, sys, re, json, torch

import deepspeed
from deepspeed.ops.deepspeed4science import evoformer_attn

print(f"torch  ⟶ {torch.__version__}")
print(f"deepspeed ⟶ {deepspeed.__version__}")
print(f"python path ⟶ {deepspeed.__file__}")

so_path = evoformer_attn.kernel_.__file__ if getattr(evoformer_attn, "kernel_", None) else "NOT LOADED"
print(f"evoformer_attn CUDA so ⟶ {so_path}")

if os.path.exists(so_path):
    try:
        # read the first 2k bytes to sniff for view/reshape strings
        with open(so_path, "rb") as f: header = f.read(2048)
        has_reshape = b".reshape(" in header or b"reshape(" in header
        print("patched reshape present?  ", has_reshape)
    except Exception as e:  
        print("could not inspect binary:", e)


[2025-07-12 01:50:59,708] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cpu (auto detect)


/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-07-12 01:51:06,574] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
torch  ⟶ 2.1.2+cu121
deepspeed ⟶ 0.17.2
python path ⟶ /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/deepspeed/__init__.py
evoformer_attn CUDA so ⟶ NOT LOADED


In [4]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128


Looking in indexes: https://download.pytorch.org/whl/cu128


In [2]:
import math
import torch, torch.nn as nn, torch.nn.functional as F

class AFNOMix2D(nn.Module):
    """
    Adaptive Fourier Neural Operator token mixer (2-D) with padding support.

    Args
    ----
    d_model : int   # channel dimension D
    num_blocks : int = 8      # block-diagonal chunks along D
    sparsity_thresh : float = 1e-2   # λ for soft-shrink
    hard_thresh_frac : float = 1.0   # keep lowest-freq fraction (≤ 1.0)
    hidden_factor : int = 1          # expansion in inner GELU
    dropout : float = 0.1
    """
    def __init__(
        self,
        d_model: int,
        num_blocks: int = 8,
        sparsity_thresh: float = 1e-2,
        hard_thresh_frac: float = 1.0,
        hidden_factor: int = 1,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert d_model % num_blocks == 0, "d_model must divide num_blocks"
        blk = d_model // num_blocks

        self.norm   = nn.LayerNorm(d_model)
        self.drop   = nn.Dropout(dropout)
        self.gamma  = nn.Parameter(torch.full((d_model,), 0.1))

        scale = 0.02
        # real&imag weights packed in dim-0 (2, …)
        self.w1 = nn.Parameter(scale * torch.randn(2, num_blocks, blk,
                                                   blk * hidden_factor))
        self.b1 = nn.Parameter(scale * torch.randn(2, num_blocks,
                                                   blk * hidden_factor))
        self.w2 = nn.Parameter(scale * torch.randn(2, num_blocks,
                                                   blk * hidden_factor, blk))
        self.b2 = nn.Parameter(scale * torch.randn(2, num_blocks, blk))

        self.num_blocks      = num_blocks
        self.block_size      = blk
        self.sparsity_thresh = sparsity_thresh
        self.hard_frac       = hard_thresh_frac
        self.hidden_factor   = hidden_factor

    # ----------------------------------------------------------
    # helpers
    # ----------------------------------------------------------
    @staticmethod
    def _soft_shrink_complex(z: torch.Tensor, lambd: float):
        """Complex soft-shrink with safe div."""
        mag = z.abs()
        return z * F.relu(mag - lambd) / (mag + 1e-9)

    # ----------------------------------------------------------
    # forward
    # ----------------------------------------------------------
    def forward(self, x: torch.Tensor, pad: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x   : [B, S, L, D]  – real tensor
        pad : [B, S, L]     – True ⇒ padded token
        """
        B, S, L, D = x.shape
        assert D == self.block_size * self.num_blocks, "dim mismatch"

        # ------------------------------------------------------
        # 0) mask pads & layer-norm
        # ------------------------------------------------------
        x_masked = x.masked_fill(pad.unsqueeze(-1), 0.)
        x_n      = self.norm(x_masked)

        # ------------------------------------------------------
        # 1) 2-D real FFT  →  [B, S_hat, L_hat, D] (complex)
        # ------------------------------------------------------
        z = torch.fft.rfft2(x_n, dim=(-3, -2), norm="ortho")   # complex
        S_hat, L_hat = z.shape[-3:-1]          # S stays full, L halves

        # ------------------------------------------------------
        # 2) optional hard truncation of high-freq modes
        # ------------------------------------------------------
        if self.hard_frac < 1.0:
            k_S = int(S_hat * self.hard_frac)
            k_L = int(L_hat * self.hard_frac)
            z[..., k_S:, :, :] = 0
            z[..., :, k_L:, :] = 0

        # ------------------------------------------------------
        # 3) block-diagonal mixing in frequency domain
        # ------------------------------------------------------
        # reshape channel dim into (B, Ŝ, L̂, nb, blk)
        z = z.view(B, S_hat, L_hat, self.num_blocks, self.block_size)

        # split real/imag weights like the official code
        w1_r, w1_i = self.w1[0], self.w1[1]
        b1_r, b1_i = self.b1[0], self.b1[1]
        w2_r, w2_i = self.w2[0], self.w2[1]
        b2_r, b2_i = self.b2[0], self.b2[1]

        # first linear + GELU in complex domain
        o1_r = (torch.einsum('...bi,bij->...bj',  z.real, w1_r) -
                torch.einsum('...bi,bij->...bj',  z.imag, w1_i) + b1_r)
        o1_i = (torch.einsum('...bi,bij->...bj',  z.imag, w1_r) +
                torch.einsum('...bi,bij->...bj',  z.real, w1_i) + b1_i)
        o1   = torch.complex(o1_r, o1_i)
        o1   = torch.complex(F.gelu(o1.real), F.gelu(o1.imag))

        # second linear
        o2_r = (torch.einsum('...bi,bij->...bj',  o1.real, w2_r) -
                torch.einsum('...bi,bij->...bj',  o1.imag, w2_i) + b2_r)
        o2_i = (torch.einsum('...bi,bij->...bj',  o1.imag, w2_r) +
                torch.einsum('...bi,bij->...bj',  o1.real, w2_i) + b2_i)
        z    = torch.complex(o2_r, o2_i)

        # ------------------------------------------------------
        # 4) soft-shrink sparsity, reshape back
        # ------------------------------------------------------
        z = self._soft_shrink_complex(z, self.sparsity_thresh)
        z = z.view(B, S_hat, L_hat, D)

        # ------------------------------------------------------
        # 5) inverse FFT → residual + dropout, re-apply pad
        # ------------------------------------------------------
        y = torch.fft.irfft2(z, s=(S, L), dim=(-3, -2), norm="ortho")
        y = y.masked_fill(pad.unsqueeze(-1), 0.)

        return x + self.drop(self.gamma * y)


In [3]:
import torch
x   = torch.randn(2, 256, 1024, 768)
pad = torch.zeros (2, 256, 1024, dtype=torch.bool)
mix = AFNOMix2D(768)
y   = mix(x, pad)
assert y.shape == x.shape and y.dtype == x.dtype


: 