In [4]:
"""
Row-self-attention equivalence test:
OpenFold vs a “flatten-and-permute” PyTorch port
"""

import math, torch, numpy as np
import torch.nn as nn
from typing import Optional, List, Tuple

torch.manual_seed(0)
np.random.seed(0)

# ----------------------------------------------------------------------
#  OpenFold helper layers (trimmed to what the test needs)
# ----------------------------------------------------------------------
def _prod(nums):          # utility
    out = 1
    for n in nums: out *= n
    return out

def _fan(shape, fan="fan_in"):
    fan_out, fan_in = shape
    if fan == "fan_in":  return fan_in
    if fan == "fan_out": return fan_out
    return (fan_in + fan_out) / 2

def lecun_normal_(w):
    from scipy.stats import truncnorm
    f = _fan(w.shape, "fan_in")
    scale = 1. / max(1, f)
    std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
    vals = truncnorm.rvs(-2, 2, loc=0, scale=std, size=_prod(w.shape))
    with torch.no_grad(): w.copy_(torch.tensor(vals.reshape(w.shape)))

class OFLinear(nn.Linear):
    def __init__(self, inp, outp, bias=False):
        super().__init__(inp, outp, bias=bias)
        lecun_normal_(self.weight)
        if bias: nn.init.zeros_(self.bias)

class OFLayerNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.bias   = nn.Parameter(torch.zeros(d))
        self.eps    = eps
    def forward(self,x): return nn.functional.layer_norm(
        x, x.shape[-1:], self.weight, self.bias, self.eps)

# ----------------------------------------------------------------------
#  OpenFold Attention (row, no pair-bias)
# ----------------------------------------------------------------------
class OFAttention(nn.Module):
    def __init__(self, d, heads):
        super().__init__()
        dh = d // heads
        self.h = heads
        self.q = OFLinear(d, d, bias=False)
        self.k = OFLinear(d, d, bias=False)
        self.v = OFLinear(d, d, bias=False)
        self.o = OFLinear(d, d, bias=True)     # Wᵒ
        self.g = OFLinear(d, d, bias=True)     # Wᵍ
        self.sig = nn.Sigmoid()
        self.dh = dh
    def _split(self,t):        # [..., L, H*dh] -> [..., H, L, dh]
        t = t.view(*t.shape[:-1], self.h, self.dh)
        return t.transpose(-3,-2)
    def forward(self, x, mask=None):
        q,k,v = map(self._split,(self.q(x), self.k(x), self.v(x)))
        q = q / math.sqrt(self.dh)
        a = torch.matmul(q, k.transpose(-2,-1))   # [..., H, L, L]
        if mask is not None: a = a + mask         # additive (-inf) bias
        a = torch.softmax(a, dim=-1)
        o = torch.matmul(a, v)                    # [..., H, L, dh]
        o = o.transpose(-3,-2).reshape_as(x)
        o = self.o(o) * self.sig(self.g(x))
        return o

class OFRowAttention(nn.Module):
    def __init__(self, d, heads):
        super().__init__()
        self.norm = OFLayerNorm(d)
        self.att  = OFAttention(d, heads)
    def forward(self, m, mask=None):
        m_norm = self.norm(m)
        update = self.att(m_norm, mask)           # no pair-bias -> mask maybe None
        return m + update                         # residual inside

# ----------------------------------------------------------------------
#  Your (flatten-and-permute) implementation
# ----------------------------------------------------------------------
class RowwiseDropout(nn.Module):
    def __init__(self,p): super().__init__(); self.p=p
    def forward(self,x):
        if (not self.training) or self.p==0: return x
        B,N,*rest = x.shape
        mask = (torch.rand(B,N,1,1, device=x.device) > self.p).float()
        return x * mask / (1-self.p)

class IdentityLinear(nn.Linear):
    """acts like nn.Identity but still has .weight/.bias attrs"""
    def __init__(self, d):
        super().__init__(d,d, bias=False)
        nn.init.eye_(self.weight)
        for p in self.parameters(): p.requires_grad=False
    def forward(self,x): return x

class FlatRowAttention(nn.Module):
    def __init__(self, d, heads, p_drop=0.):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.mha  = nn.MultiheadAttention(
            d, heads, dropout=p_drop, batch_first=True, bias=False)
        self.mha.out_proj = IdentityLinear(d)     # <-   fixed!
        self.Wo   = nn.Linear(d,d, bias=True)
        self.Wg   = nn.Linear(d,d, bias=True)
        nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.Wg.weight); nn.init.ones_(self.Wg.bias)
        self.drop = RowwiseDropout(p_drop)
    def forward(self,x):                          # x:[B,N,L,D]
        B,N,L,D = x.shape
        x_n = self.norm(x)
        qkv = x_n.permute(0,2,1,3).reshape(B*L, N, D)
        y,_ = self.mha(qkv,qkv,qkv, need_weights=False)
        y = y.reshape(B,L,N,D).permute(0,2,1,3)
        y = self.Wo(y) * torch.sigmoid(self.Wg(x_n))
        return x + self.drop(y)                   # residual inside

# ----------------------------------------------------------------------
#  Helpers
# ----------------------------------------------------------------------
def make_data(B=2,N=16,L=32,D=128):
    msa  = torch.randn(B,N,L,D)
    mask = torch.ones(B,N,L)                      # could randomise later
    return msa, mask

def align_weights(of, flat):
    """copy weights so the two nets start identical"""
    flat.norm.weight.data.copy_(of.norm.weight)
    flat.norm.bias .data.copy_(of.norm.bias)
    # QKV
    qkv = flat.mha.in_proj_weight.chunk(3,0)
    of.att.q.weight.data.copy_(qkv[0])
    of.att.k.weight.data.copy_(qkv[1])
    of.att.v.weight.data.copy_(qkv[2])
    # Wo / Wg
    of.att.o.weight.data.copy_(flat.Wo.weight)
    of.att.o.bias .data.copy_(flat.Wo.bias)
    of.att.g.weight.data.copy_(flat.Wg.weight)
    of.att.g.bias .data.copy_(flat.Wg.bias)

def run_pair(B,N,L,D,H):
    msa,_ = make_data(B,N,L,D)
    of    = OFRowAttention(D,H).eval()
    flat  = FlatRowAttention(D,H).eval()
    align_weights(of, flat)
    with torch.no_grad():
        out_ref  = of(msa)             # [B,N,L,D]
        out_flat = flat(msa)           # same
    diff = (out_ref - out_flat).abs()
    print(f"B={B:2d} N={N:3d} L={L:3d} D={D:3d}  maxΔ={diff.max():.6g}  meanΔ={diff.mean():.6g}")
    return diff.max().item() < 1e-5

# ----------------------------------------------------------------------
#  Tests
# ----------------------------------------------------------------------
print("Testing equivalence …")
passed = True
for cfg in [(2,16,32,128,8),
            (1, 8,16, 64,4),
            (3,32,64,256,8)]:          # enlarge cautiously if RAM limited
    passed &= run_pair(*cfg)
print("\nFINAL RESULT :", "✅ SUCCESS" if passed else "❌ MISMATCH")


Testing equivalence …
B= 2 N= 16 L= 32 D=128  maxΔ=0  meanΔ=0
B= 1 N=  8 L= 16 D= 64  maxΔ=0  meanΔ=0
B= 3 N= 32 L= 64 D=256  maxΔ=0  meanΔ=0

FINAL RESULT : ✅ SUCCESS
