In [None]:

def print_all_gradients(model, epoch=None):
    """打印模型中所有可训练参数的梯度信息"""
    if epoch is not None:
        print(f"\n=== Epoch {epoch} 梯度信息 ===")
    else:
        print(f"\n=== 当前梯度信息 ===")
    
    # 收集所有梯度信息
    grad_data = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.data
            grad_info = {
                "name": name,
                "shape": tuple(grad.shape),
                "dtype": str(grad.dtype),
                "norm": grad.norm().item(),
                "min": grad.min().item(),
                "max": grad.max().item(),
                "mean": grad.mean().item(),
                "nan": torch.isnan(grad).sum().item(),
                "inf": torch.isinf(grad).sum().item(),
                "zero": (grad == 0).sum().item()
            }
            grad_data.append(grad_info)
    
    # 按梯度范数排序（从大到小）
    grad_data.sort(key=lambda x: x["norm"], reverse=True)
    
    # 打印表格头
    print(f"{'参数名称':<40} | {'形状':<20} | {'范数':>10} | {'NaN':>5} | {'Inf':>5} | {'零值%':>6} | {'范围'}")
    print("-" * 120)
    
    # 打印每个参数的梯度信息
    for info in grad_data:
        zero_percent = info["zero"] / torch.numel(torch.zeros(info['shape'])) * 100
        range_str = f"[{info['min']:.3e}, {info['max']:.3e}]"
        
        # 高亮异常梯度
        if info["nan"] > 0 or info["inf"] > 0:
            highlight = "\033[91m"  # 红色
            reset = "\033[0m"
        else:
            highlight = reset = ""
        
        print(f"{highlight}{info['name']:<40} | {str(info['shape']):<20} | "
            f"{info['norm']:>10.3e} | "
            f"{info['nan']:>5} | "
            f"{info['inf']:>5} | "
            f"{zero_percent:>5.1f}% | "
            f"{range_str}{reset}")
    
    # 打印统计摘要
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    grad_params = sum(info['shape'][0] for info in grad_data)
    nan_params = sum(info['nan'] for info in grad_data)
    inf_params = sum(info['inf'] for info in grad_data)
    
    print("\n摘要:")
    print(f"• 总可训练参数: {total_params:,}")
    print(f"• 有梯度的参数: {grad_params} ({grad_params/total_params:.1%})")
    print(f"• NaN 梯度值总数: {nan_params}")
    print(f"• Inf 梯度值总数: {inf_params}")

In [3]:
# test_attention_sdpa.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from contextlib import nullcontext

# -------------------- Config --------------------
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
print("PyTorch:", torch.__version__)
_HAS_SDPA = hasattr(F, "scaled_dot_product_attention")
print("scaled_dot_product_attention available:", _HAS_SDPA)

# -------------------- Helpers --------------------
def _mask_to_bool(mask):
    """Convert mask (None or 0/1 int tensor or bool) -> boolean mask where True = invalid (padding)."""
    if mask is None:
        return None
    return (mask == 0) if mask.dtype != torch.bool else mask

def _stable_softmax_on_last_dim(logits, dim, attn_mask_bool=None):
    """Perform softmax in float32 for stability, then cast back to logits.dtype."""
    logits_f32 = logits.float()
    if attn_mask_bool is not None:
        logits_f32 = logits_f32.masked_fill(attn_mask_bool, -1e9)
    probs_f32 = F.softmax(logits_f32, dim=dim)
    return probs_f32.to(logits.dtype)

def _sdpa_forward(q, k, v, mask_bool=None, bias_add=None, dropout_p=0.0, is_causal=False):
    """
    q: (B, Q, H, D)
    k: (B, K, H, D)
    v: (B, K, H, D)
    mask_bool: either None, (B, L) or (B, Q, K) boolean mask with True=invalid
    bias_add: optional additive bias to logits (broadcastable)
    returns: (B, Q, H, D)
    """
    q_sd = q.permute(0, 2, 1, 3).contiguous()  # (B, H, Q, D)
    k_sd = k.permute(0, 2, 1, 3).contiguous()  # (B, H, K, D)
    v_sd = v.permute(0, 2, 1, 3).contiguous()  # (B, H, K, D)
    B, H, Q, D = q_sd.shape
    _, _, K, _ = k_sd.shape

    attn_mask = None
    if mask_bool is not None:
        if mask_bool.dim() == 2:
            # mask_bool: (B, L) -> build (B, Q, K)
            mask_q = mask_bool[:, :Q]
            mask_k = mask_bool[:, :K]
            mask_2d = mask_q.unsqueeze(2) | mask_k.unsqueeze(1)  # (B, Q, K)
            mask_float = mask_2d.to(q_sd.dtype) * (-1e9)
        elif mask_bool.dim() == 3:
            # already (B, Q, K)
            mask_float = mask_bool.to(q_sd.dtype) * (-1e9)
        else:
            raise RuntimeError("Unsupported mask_bool dim in SDPA wrapper")
        attn_mask = mask_float.unsqueeze(1)  # (B,1,Q,K) to broadcast to heads

    if bias_add is not None:
        b = bias_add
        # attempt common rearrangements to (B, H, Q, K)
        if b.dim() == 4 and b.shape[-1] == H:
            b = b.permute(0, 3, 1, 2)  # (B,H,Q,K)
        elif b.dim() == 3 and b.shape[-1] == H and b.shape[1] == K:
            b = b.permute(0, 2, 1).unsqueeze(2)  # (B,H,1,K)
        b = b.to(q_sd.dtype)
        if attn_mask is None:
            attn_mask = b
        else:
            attn_mask = attn_mask + b

    # call PyTorch SDPA
    out_sd = F.scaled_dot_product_attention(q_sd, k_sd, v_sd, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    out = out_sd.permute(0, 2, 1, 3).contiguous()  # (B, Q, H, D)
    return out

# -------------------- Attention class (use_sdpa optional) --------------------
class Attention(nn.Module):
    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, use_sdpa=False):
        super().__init__()
        self.h = n_head
        self.dim = d_hidden
        self.to_q = nn.Linear(d_query, n_head * d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head * d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head * d_hidden, bias=False)
        self.to_out = nn.Linear(n_head * d_hidden, d_out)
        self.scaling = 1.0 / math.sqrt(d_hidden)
        self.use_sdpa = use_sdpa and _HAS_SDPA

        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

    def forward(self, query, key, value, mask=None):
        """
        mask: None or (B, L) with 1=valid, 0=invalid (padding)
        output: (B, Q, d_out)
        """
        mask_bool = _mask_to_bool(mask)  # True = invalid
        B, Q = query.shape[:2]
        _, K = key.shape[:2]

        # project
        q_full = self.to_q(query).reshape(B, Q, self.h, self.dim)
        k = self.to_k(key).reshape(B, K, self.h, self.dim)
        v = self.to_v(value).reshape(B, K, self.h, self.dim)

        if self.use_sdpa:
            # DO NOT apply manual scaling to q before SDPA (SDPA handles scaling)
            # build mask_2d if mask_bool is (B,L)
            if mask_bool is not None and mask_bool.dim() == 2:
                mask_2d = mask_bool[:, :Q].unsqueeze(2) | mask_bool[:, :K].unsqueeze(1)  # (B,Q,K)
            else:
                mask_2d = mask_bool
            out = _sdpa_forward(q_full, k, v, mask_bool=mask_2d, bias_add=None)
            out = out.reshape(B, Q, self.h * self.dim)
            out = self.to_out(out)
            if mask_bool is not None:
                out = out.masked_fill(mask_bool[:, :Q].unsqueeze(-1), 0.0)
            return out

        # fallback: scale q manually and use stable softmax
        q = q_full * self.scaling
        attn_logits = einsum('bqhd,bkhd->bhqk', q, k)  # (B, H, Q, K)

        if mask_bool is not None:
            mask_q = mask_bool[:, :Q]
            mask_k = mask_bool[:, :K]
            mask_2d = mask_q.unsqueeze(2) | mask_k.unsqueeze(1)  # (B, Q, K)
            attn_for_mask = mask_2d.unsqueeze(1)  # (B,1,Q,K)
        else:
            attn_for_mask = None

        attn = _stable_softmax_on_last_dim(attn_logits, dim=-1, attn_mask_bool=attn_for_mask)
        out = einsum('bhqk,bkhd->bqhd', attn, v)
        out = out.reshape(B, Q, self.h * self.dim)
        out = self.to_out(out)
        if mask_bool is not None:
            out = out.masked_fill(mask_bool[:, :Q].unsqueeze(-1), 0.0)
        return out

# -------------------- Test routine --------------------
import time
def run_test():
    # small random test settings
    B = 2
    Q = 5
    K = 6
    d_query = 160
    d_key = 160
    n_head = 40
    d_hidden = 80
    d_out = 200

    # generate random inputs
    query = torch.randn(B, Q, d_query, device=device)
    key   = torch.randn(B, K, d_key, device=device)
    value = torch.randn(B, K, d_key, device=device)
    mask = torch.ones(B, max(Q, K), dtype=torch.int8, device=device)
    mask[0, 3:] = 0
    mask[1, 4:] = 0

    # create model and save state
    model = Attention(d_query, d_key, n_head, d_hidden, d_out, use_sdpa=False).to(device)
    state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    def load_state(m, s):
        mdl_state = {k: v.clone() for k, v in s.items()}
        device_ = next(m.parameters()).device if any(True for _ in m.parameters()) else device
        for k in mdl_state:
            mdl_state[k] = mdl_state[k].to(device_)
        m.load_state_dict(mdl_state)

    results = {}
    names = []

    # determine autocast context for current device
    if device.type == "cuda":
        autocast_ctx = torch.cuda.amp.autocast
    else:
        # prefer torch.cpu.amp.autocast if available, else nullcontext
        autocast_ctx = getattr(torch.cpu.amp, "autocast", nullcontext)

    # 1) FP32 baseline (no autocast), use_sdpa=False
    load_state(model, state)
    model.use_sdpa = False
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        out_fp32 = model(query, key, value, mask=mask)
        print(f"FP32 forward time: {time.time() - start_time:.6f} seconds")
    results['fp32'] = out_fp32.detach(); names.append('fp32')
    print("fp32: shape", out_fp32.shape, "dtype", out_fp32.dtype, "has_nan", torch.isnan(out_fp32).any().item())

    # 2) AMP (autocast) with fallback softmax (use_sdpa=False)
    load_state(model, state)
    model.use_sdpa = False
    model.eval()
    with torch.no_grad(), autocast_ctx():
        start_time = time.time()
        out_amp = model(query, key, value, mask=mask)
        print(f"AMP forward time: {time.time() - start_time:.6f} seconds")  
    results['amp'] = out_amp.detach(); names.append('amp')
    print("amp:  shape", out_amp.shape, "dtype", out_amp.dtype, "has_nan", torch.isnan(out_amp).any().item())

    # 3) SDPA path (if available) with autocast
    if _HAS_SDPA:
        load_state(model, state)
        model.use_sdpa = True
        model.eval()
        with torch.no_grad(), autocast_ctx():
            start_time = time.time()
            out_sdpa = model(query, key, value, mask=mask)
            print(f"SDPA forward time: {time.time() - start_time:.6f} seconds")
        results['sdpa'] = out_sdpa.detach(); names.append('sdpa')
        print("sdpa: shape", out_sdpa.shape, "dtype", out_sdpa.dtype, "has_nan", torch.isnan(out_sdpa).any().item())
    else:
        print("SDPA unavailable; skipping sdpa run")

    # compute pairwise max absolute differences
    def max_abs_diff(a, b):
        return (a.to(torch.float32) - b.to(torch.float32)).abs().max().item()

    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            print(f"max_abs_diff {names[i]} vs {names[j]} = {max_abs_diff(results[names[i]], results[names[j]]):.6e}")

    # print small slice of outputs for inspection
    for k in names:
        t = results[k].cpu()
        print(f"\n--- sample values for {k} (slice [0, :3, :6]) dtype={t.dtype} ---")
        print(t[0, :3, :6])

if __name__ == "__main__":
    run_test()


Device: cuda
PyTorch: 2.6.0+cu124
scaled_dot_product_attention available: True
FP32 forward time: 0.391782 seconds
fp32: shape torch.Size([2, 5, 200]) dtype torch.float32 has_nan False
AMP forward time: 0.149990 seconds
amp:  shape torch.Size([2, 5, 200]) dtype torch.float16 has_nan False
SDPA forward time: 0.000487 seconds
sdpa: shape torch.Size([2, 5, 200]) dtype torch.float16 has_nan False
max_abs_diff fp32 vs amp = 2.413094e-04
max_abs_diff fp32 vs sdpa = 2.335757e-04
max_abs_diff amp vs sdpa = 2.441406e-04

--- sample values for fp32 (slice [0, :3, :6]) dtype=torch.float32 ---
tensor([[ 0.0273,  0.0148, -0.0659, -0.0573,  0.2050,  0.0022],
        [ 0.0096,  0.0128, -0.0640, -0.0453,  0.2135,  0.0026],
        [ 0.0100,  0.0167, -0.0386, -0.0476,  0.1880, -0.0046]])

--- sample values for amp (slice [0, :3, :6]) dtype=torch.float16 ---
tensor([[ 0.0272,  0.0148, -0.0658, -0.0573,  0.2050,  0.0022],
        [ 0.0096,  0.0129, -0.0641, -0.0453,  0.2136,  0.0026],
        [ 0.0099,  

  with torch.no_grad(), autocast_ctx():
  with torch.no_grad(), autocast_ctx():


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from einops import rearrange
from opt_einsum import contract as einsum

# =============================================================================
# 0. 辅助函数 (模拟外部依赖)
# =============================================================================
def init_lecun_normal(module):
    """模拟 rfdiffusion.util_module.init_lecun_normal"""
    nn.init.xavier_normal_(module.weight)
    return module

# =============================================================================
# 1. 定义模型 - Attention
# =============================================================================

class AttentionOriginal(nn.Module):
    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
        super().__init__()
        self.h, self.dim = n_head, d_hidden
        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_out = nn.Linear(n_head*d_hidden, d_out)
        self.scaling = 1/math.sqrt(d_hidden)
        self.reset_parameter()
    def reset_parameter(self):
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
    def forward(self, query, key, value, mask=None):
        bool_mask = (mask == 0)
        B, Q, _, K = *query.shape[:2], *key.shape[:2]
        q = self.to_q(query).reshape(B, Q, self.h, self.dim)
        k = self.to_k(key).reshape(B, K, self.h, self.dim)
        v = self.to_v(value).reshape(B, K, self.h, self.dim)
        q = q * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', q, k)
        mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
        attn = attn.masked_fill(mask_2d.unsqueeze(1), -1e9)
        attn = F.softmax(attn, dim=-1)
        out = einsum('bhqk,bkhd->bqhd', attn, v)
        out = out.reshape(B, Q, self.h*self.dim)
        out = self.to_out(out)
        if mask is not None: out = out.masked_fill(bool_mask.unsqueeze(-1), 0.0)
        return out

class AttentionOptimized(nn.Module):
    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
        super().__init__()
        self.h, self.d_hidden = n_head, d_hidden
        self.to_q = nn.Linear(d_query, n_head * d_hidden, bias=False)
        self.to_kv = nn.Linear(d_key, 2 * n_head * d_hidden, bias=False)
        self.to_out = nn.Linear(n_head * d_hidden, d_out)
        self.reset_parameter()
    def reset_parameter(self):
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_kv.weight)
    def forward(self, query, key, value, mask=None):
        q = self.to_q(query)
        k, v = self.to_kv(key).chunk(2, dim=-1)
        
        q = rearrange(q, 'b q (h d) -> b h q d', h=self.h)
        k = rearrange(k, 'b k (h d) -> b h k d', h=self.h)
        v = rearrange(v, 'b k (h d) -> b h k d', h=self.h)
        
        attn_mask = None
        if mask is not None:
            bool_mask = (mask == 0)
            mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
            attn_mask = mask_2d.unsqueeze(1)
            
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
        out = rearrange(out, 'b h q d -> b q (h d)')
        out = self.to_out(out)
        if mask is not None: out = out.masked_fill((mask == 0).unsqueeze(-1), 0.0)
        return out

# =============================================================================
# 2. 定义模型 - AttentionWithBias
# =============================================================================

class AttentionWithBiasOriginal(nn.Module):
    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_in, self.norm_bias = nn.LayerNorm(d_in), nn.LayerNorm(d_bias)
        self.to_q, self.to_k, self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False), nn.Linear(d_in, n_head*d_hidden, bias=False), nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_bias, n_head, bias=False), nn.Linear(d_in, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_in)
        self.scaling, self.h, self.dim = 1 / math.sqrt(d_hidden), n_head, d_hidden
        self.reset_parameter()
    def reset_parameter(self):
        for w in [self.to_q.weight, self.to_k.weight, self.to_v.weight]: nn.init.xavier_uniform_(w)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, x, bias, mask=None):
        B,L = x.shape[:2]
        bool_mask = (mask == 0)
        x_norm, bias_norm = self.norm_in(torch.nan_to_num(x)), self.norm_bias(torch.nan_to_num(bias))
        query, key, value = self.to_q(x_norm).reshape(B, L, self.h, self.dim), self.to_k(x_norm).reshape(B, L, self.h, self.dim), self.to_v(x_norm).reshape(B, L, self.h, self.dim)
        bias_h, gate = self.to_b(bias_norm), torch.sigmoid(self.to_g(x_norm))
        key = key * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', query, key)
        attn = rearrange(attn, 'b h q k -> b q k h') + bias_h
        if bool_mask.any():
            mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
            attn.masked_fill_(mask_2d.unsqueeze(-1), -1e9)
        attn = F.softmax(attn, dim=2)
        attn = rearrange(attn, 'b q k h -> b h q k')
        out = einsum('bhqk,bkhd->bqhd', attn, value).reshape(B, L, -1)
        out = gate * out
        out = self.to_out(out)
        if bool_mask.any(): out = out.masked_fill(bool_mask.unsqueeze(-1), 0.0)
        return out

class AttentionWithBiasOptimized(nn.Module):
    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_in, self.norm_bias = nn.LayerNorm(d_in), nn.LayerNorm(d_bias)
        self.to_qkv = nn.Linear(d_in, 3 * n_head * d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_bias, n_head, bias=False), nn.Linear(d_in, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_in)
        self.h = n_head
        self.reset_parameter()
    def reset_parameter(self):
        nn.init.xavier_uniform_(self.to_qkv.weight)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, x, bias, mask=None):
        x_norm, bias_norm = self.norm_in(torch.nan_to_num(x)), self.norm_bias(torch.nan_to_num(bias))
        q, k, v = self.to_qkv(x_norm).chunk(3, dim=-1)
        q, k, v = [rearrange(t, 'b l (h d) -> b h l d', h=self.h) for t in (q, k, v)]
        
        float_mask = self.to_b(bias_norm)
        float_mask = rearrange(float_mask, 'b q k h -> b h q k')

        if mask is not None:
            padding_mask = (mask == 0)
            mask_2d = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
            float_mask = float_mask.masked_fill(mask_2d.unsqueeze(1), -torch.inf)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
        out = rearrange(out, 'b h l d -> b l (h d)')
        
        gate = torch.sigmoid(self.to_g(x_norm))
        out = gate * out
        out = self.to_out(out)
        if mask is not None: out = out.masked_fill((mask == 0).unsqueeze(-1), 0.0)
        return out

# =============================================================================
# 3. 定义模型 - MSARowAttentionWithBias
# =============================================================================

class MSARowAttentionWithBiasOriginal(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa, self.norm_pair = nn.LayerNorm(d_msa), nn.LayerNorm(d_pair)
        self.to_q, self.to_k, self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False), nn.Linear(d_msa, n_head*d_hidden, bias=False), nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_pair, n_head, bias=False), nn.Linear(d_msa, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_msa)
        self.scaling, self.h, self.dim = 1/math.sqrt(d_hidden), n_head, d_hidden
        self.reset_parameter()
    def reset_parameter(self):
        for w in [self.to_q.weight, self.to_k.weight, self.to_v.weight]: nn.init.xavier_uniform_(w)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, msa, pair, mask=None):
        B, N, L, _ = msa.shape
        bool_mask = (mask == 0)
        msa_norm, pair_norm = self.norm_msa(torch.nan_to_num(msa)), self.norm_pair(torch.nan_to_num(pair))
        query, key, value = [t.reshape(B,N,L,self.h,self.dim) for t in (self.to_q(msa_norm), self.to_k(msa_norm), self.to_v(msa_norm))]
        bias, gate = self.to_b(pair_norm), torch.sigmoid(self.to_g(msa_norm))
        m_view = bool_mask.view(B, 1, L, 1, 1)
        query, value, key = query.masked_fill(m_view, 0.0), value.masked_fill(m_view, 0.0), key.masked_fill(m_view, 0.0)
        attn = einsum('bnqhd,bnkhd->bqkh', query*self.scaling, key) + bias
        mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
        if mask_2d.any(): attn = attn.masked_fill(mask_2d.view(B, L, L, 1), -1e9)
        attn = F.softmax(attn, dim=2)
        out = einsum('bqkh,bnkhd->bnqhd', attn, value).reshape(B, N, L, -1)
        out = gate * out
        out = self.to_out(out)
        if bool_mask.any(): out = out.masked_fill(bool_mask.view(B, 1, L, 1), 0.0)
        return out

class MSARowAttentionWithBiasOptimized(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa, self.norm_pair = nn.LayerNorm(d_msa), nn.LayerNorm(d_pair)
        self.to_qkv = nn.Linear(d_msa, 3 * n_head * d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_pair, n_head, bias=False), nn.Linear(d_msa, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_msa)
        self.scaling, self.h, self.dim = 1/math.sqrt(d_hidden), n_head, d_hidden
        self.reset_parameter()
    def reset_parameter(self):
        nn.init.xavier_uniform_(self.to_qkv.weight)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, msa, pair, mask=None):
        B, N, L, _ = msa.shape
        bool_mask = (mask == 0)
        msa_norm, pair_norm = self.norm_msa(torch.nan_to_num(msa)), self.norm_pair(torch.nan_to_num(pair))
        
        q, k, v = self.to_qkv(msa_norm).chunk(3, dim=-1)
        query, key, value = [t.reshape(B,N,L,self.h,self.dim) for t in (q, k, v)]

        bias, gate = self.to_b(pair_norm), torch.sigmoid(self.to_g(msa_norm))
        
        m_view = bool_mask.view(B, 1, L, 1, 1)
        query, value, key = query.masked_fill(m_view, 0.0), value.masked_fill(m_view, 0.0), key.masked_fill(m_view, 0.0)
        
        attn = einsum('bnqhd,bnkhd->bqkh', query*self.scaling, key) + bias
        mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
        if mask_2d.any(): attn = attn.masked_fill(mask_2d.view(B, L, L, 1), -1e9)
        
        attn = F.softmax(attn, dim=2)
        
        out = einsum('bqkh,bnkhd->bnqhd', attn, value).reshape(B, N, L, -1)
        out = gate * out
        out = self.to_out(out)
        if bool_mask.any(): out = out.masked_fill(bool_mask.view(B, 1, L, 1), 0.0)
        return out

# =============================================================================
# 4. 定义模型 - BiasedAxialAttention
# =============================================================================
class BiasedAxialAttentionOriginal(nn.Module):
    def __init__(self, d_pair, d_bias, n_head, d_hidden, is_row=True):
        super().__init__()
        self.is_row, self.norm_pair, self.norm_bias = is_row, nn.LayerNorm(d_pair), nn.LayerNorm(d_bias)
        self.to_q, self.to_k, self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False), nn.Linear(d_pair, n_head*d_hidden, bias=False), nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_bias, n_head, bias=False), nn.Linear(d_pair, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_pair)
        self.scaling, self.h, self.dim = 1/math.sqrt(d_hidden), n_head, d_hidden
        self.reset_parameter()
    def reset_parameter(self):
        for w in [self.to_q.weight, self.to_k.weight, self.to_v.weight]: nn.init.xavier_uniform_(w)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, pair, bias, mask=None):
        B, L, _, _ = pair.shape
        bool_mask = (mask == 0)
        mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
        safe_pair, safe_bias = torch.nan_to_num(pair), torch.nan_to_num(bias)
        if self.is_row:
            safe_pair, safe_bias = safe_pair.permute(0, 2, 1, 3), safe_bias.permute(0, 2, 1, 3)
            mask_2d = mask_2d.permute(0, 2, 1)
        pair_norm, bias_norm = self.norm_pair(safe_pair), self.norm_bias(safe_bias)
        query, key, value = [t.reshape(B,L,L,self.h,self.dim) for t in (self.to_q(pair_norm), self.to_k(pair_norm), self.to_v(pair_norm))]
        bias_h, gate = self.to_b(bias_norm), torch.sigmoid(self.to_g(pair_norm))
        query, key = query * self.scaling, key / math.sqrt(L)
        attn = einsum('bnihd,bnjhd->bijh', query, key) + bias_h
        if mask_2d.any(): attn = attn.masked_fill(mask_2d.unsqueeze(-1), -1e9)
        attn = F.softmax(attn, dim=2)
        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out
        out = self.to_out(out)
        if self.is_row: out = out.permute(0, 2, 1, 3)
        if mask_2d.any(): out = out.masked_fill(mask_2d.unsqueeze(-1), 0.0)
        return out

class BiasedAxialAttentionOptimized(nn.Module):
    def __init__(self, d_pair, d_bias, n_head, d_hidden, is_row=True):
        super().__init__()
        self.is_row, self.norm_pair, self.norm_bias = is_row, nn.LayerNorm(d_pair), nn.LayerNorm(d_bias)
        self.to_qkv = nn.Linear(d_pair, 3 * n_head * d_hidden, bias=False)
        self.to_b, self.to_g, self.to_out = nn.Linear(d_bias, n_head, bias=False), nn.Linear(d_pair, n_head*d_hidden), nn.Linear(n_head*d_hidden, d_pair)
        self.scaling, self.h, self.dim = 1/math.sqrt(d_hidden), n_head, d_hidden
        self.reset_parameter()
    def reset_parameter(self):
        nn.init.xavier_uniform_(self.to_qkv.weight)
        self.to_b = init_lecun_normal(self.to_b)
        if hasattr(self.to_g, 'bias') and self.to_g.bias is not None:
            nn.init.ones_(self.to_g.bias)
    def forward(self, pair, bias, mask=None):
        B, L, _, _ = pair.shape
        bool_mask = (mask == 0)
        mask_2d = bool_mask.unsqueeze(1) | bool_mask.unsqueeze(2)
        safe_pair, safe_bias = torch.nan_to_num(pair), torch.nan_to_num(bias)
        if self.is_row:
            safe_pair, safe_bias = safe_pair.permute(0, 2, 1, 3), safe_bias.permute(0, 2, 1, 3)
            mask_2d = mask_2d.permute(0, 2, 1)
        
        pair_norm, bias_norm = self.norm_pair(safe_pair), self.norm_bias(safe_bias)
        
        q, k, v = self.to_qkv(pair_norm).chunk(3, dim=-1)
        query, key, value = [t.reshape(B,L,L,self.h,self.dim) for t in (q, k, v)]

        bias_h, gate = self.to_b(bias_norm), torch.sigmoid(self.to_g(pair_norm))
        query, key = query * self.scaling, key / math.sqrt(L)
        
        attn = einsum('bnihd,bnjhd->bijh', query, key) + bias_h
        if mask_2d.any(): attn = attn.masked_fill(mask_2d.unsqueeze(-1), -1e9)
        attn = F.softmax(attn, dim=2)
        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out
        out = self.to_out(out)
        if self.is_row: out = out.permute(0, 2, 1, 3)
        if mask_2d.any(): out = out.masked_fill(mask_2d.unsqueeze(-1), 0.0)
        return out

# =============================================================================
# 5. 测试函数
# =============================================================================

def copy_weights(model_original, model_optimized):
    """
    将原始模型的权重精确地复制到优化模型中。
    这个新版本更健壮，能正确处理所有情况。
    """
    # 获取两个模型的 state_dict
    orig_sd = model_original.state_dict()
    opt_sd = model_optimized.state_dict()
    
    new_opt_sd = {}
    
    # 遍历优化模型的每一个参数
    for name, param in opt_sd.items():
        if name == 'to_kv.weight':
            # 从原始模型拼接 k, v 权重
            k_w = orig_sd['to_k.weight']
            v_w = orig_sd['to_v.weight']
            new_opt_sd[name] = torch.cat([k_w, v_w], dim=0)
        elif name == 'to_qkv.weight':
            # 从原始模型拼接 q, k, v 权重
            q_w = orig_sd['to_q.weight']
            k_w = orig_sd['to_k.weight']
            v_w = orig_sd['to_v.weight']
            new_opt_sd[name] = torch.cat([q_w, k_w, v_w], dim=0)
        elif name in orig_sd:
            # 如果在原始模型中存在同名参数，直接复制
            new_opt_sd[name] = orig_sd[name]
        else:
            # 如果没有，保留优化模型自己的参数（这种情况不应发生）
            new_opt_sd[name] = param
            
    # 加载构建好的新 state_dict
    model_optimized.load_state_dict(new_opt_sd)


def check_consistency(model_original, model_optimized, inputs, device, name=""):
    """检查原始模型和优化模型的输出是否一致"""
    print(f"--- 正在检查 {name} 的一致性 ---")
    
    # 将模型和数据移到设备
    model_original.to(device)
    model_optimized.to(device)
    
    # **关键步骤：在比较前复制权重**
    copy_weights(model_original, model_optimized)
    
    inputs_dev = [t.to(device) for t in inputs if t is not None]

    # 设置为评估模式
    model_original.eval()
    model_optimized.eval()

    try:
        with torch.no_grad():
            output_original = model_original(*inputs_dev)
            output_optimized = model_optimized(*inputs_dev)

        # 比较输出
        are_close = torch.allclose(output_original, output_optimized, atol=1e-5, rtol=1e-4)
        max_diff = (output_original - output_optimized).abs().max().item()

        if are_close:
            print(f"✅ PASS: 输出一致。最大差异: {max_diff:.6e}")
        else:
            print(f"❌ FAIL: 输出不一致。最大差异: {max_diff:.6e}")

    except Exception as e:
        print(f"❌ ERROR: 在检查一致性时发生错误: {e}")
    
    print("-" * (20 + len(name)))


def benchmark(model, inputs, device, iterations=50):
    model.to(device)
    model.train() # 确保模型处于训练模式以进行反向传播
    inputs = [t.to(device) for t in inputs if t is not None]
    
    # Warmup
    for _ in range(5):
        output = model(*inputs)
        loss = output.sum()
        loss.backward()
        model.zero_grad()
    if device.type == 'cuda': torch.cuda.synchronize()
    
    # Benchmark
    if device.type == 'cuda': torch.cuda.reset_peak_memory_stats(device)
    start_time = time.perf_counter()
    for _ in range(iterations):
        output = model(*inputs)
        loss = output.sum()
        loss.backward()
        model.zero_grad()
    if device.type == 'cuda': torch.cuda.synchronize()
    end_time = time.perf_counter()
    
    avg_time_ms = (end_time - start_time) / iterations * 1000
    peak_memory_mb = torch.cuda.max_memory_allocated(device) / (1024 * 1024) if device.type == 'cuda' else 0
    return avg_time_ms, peak_memory_mb

# =============================================================================
# 6. 主执行逻辑
# =============================================================================
if __name__ == "__main__":
    # --- 参数 (调整为更能体现优化效果的真实负载) ---
    B, N, L, D_MSA, D_PAIR = 4, 64, 256, 64, 32
    N_HEAD, D_HIDDEN = 4, 16
    ITERATIONS = 50

    # --- 环境 ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"正在使用设备: {device}")
    if device.type == 'cuda':
        print(f"设备名称: {torch.cuda.get_device_name(device)}")
        if not hasattr(F, 'scaled_dot_product_attention'):
            print("\n警告: PyTorch 版本过低，不支持 SDPA，Attention 模块的优化效果将不明显。")
            AttentionOptimized = AttentionOriginal
    
    print("-" * 50)
    print("测试参数 (较大负载):")
    print(f"  Batch Size: {B}, MSA Depth: {N}, Seq Len: {L}")
    print(f"  D_MSA: {D_MSA}, D_PAIR: {D_PAIR}")
    print(f"  Heads: {N_HEAD}, Hidden Dim: {D_HIDDEN}")
    print("-" * 50)

    # --- 输入数据 ---
    query = torch.randn(B, L, D_MSA)
    key = torch.randn(B, L, D_MSA)
    msa = torch.randn(B, N, L, D_MSA)
    pair = torch.randn(B, L, L, D_PAIR)
    bias_for_attn = torch.randn(B, L, L, N_HEAD)
    bias_for_axial = torch.randn(B, L, L, N_HEAD)
    mask = torch.ones(B, L)
    mask[:, -L//4:] = 0 # Mask last quarter

    # --- 模型和测试配置 ---
    test_configs = {
        "Attention": {
            "models": (AttentionOriginal(D_MSA, D_MSA, N_HEAD, D_HIDDEN, D_MSA), 
                       AttentionOptimized(D_MSA, D_MSA, N_HEAD, D_HIDDEN, D_MSA)),
            "inputs": (query, key, key, mask)
        },
        "AttentionWithBias": {
            "models": (AttentionWithBiasOriginal(D_MSA, N_HEAD, N_HEAD, D_HIDDEN), 
                       AttentionWithBiasOptimized(D_MSA, N_HEAD, N_HEAD, D_HIDDEN)),
            "inputs": (query, bias_for_attn, mask)
        },
        "MSARowAttention": {
            "models": (MSARowAttentionWithBiasOriginal(D_MSA, D_PAIR, N_HEAD, D_HIDDEN),
                       MSARowAttentionWithBiasOptimized(D_MSA, D_PAIR, N_HEAD, D_HIDDEN)),
            "inputs": (msa, pair, mask)
        },
        "BiasedAxialAttention": {
            "models": (BiasedAxialAttentionOriginal(D_PAIR, N_HEAD, N_HEAD, D_HIDDEN),
                       BiasedAxialAttentionOptimized(D_PAIR, N_HEAD, N_HEAD, D_HIDDEN)),
            "inputs": (pair, bias_for_axial, mask)
        }
    }
    
    results = {}
    for name, config in test_configs.items():
        original_model, optimized_model = config["models"]
        inputs = config["inputs"]
        
        # 1. 检查一致性
        check_consistency(original_model, optimized_model, inputs, device, name)
        
        # 2. 运行性能基准测试
        print(f"--- 正在进行 {name} 的性能基准测试 ---")
        print("  - 原始版本...")
        orig_t, orig_m = benchmark(original_model, inputs, device, ITERATIONS)
        
        print("  - 优化版本...")
        opt_t, opt_m = benchmark(optimized_model, inputs, device, ITERATIONS)
        
        results[name] = {
            "orig_t": orig_t, "orig_m": orig_m,
            "opt_t": opt_t, "opt_m": opt_m,
            "time_gain": (orig_t - opt_t) / orig_t * 100,
            "mem_gain": (orig_m - opt_m) / orig_m * 100 if orig_m > 0 else 0
        }

    # --- 打印结果 ---
    print("\n" + "="*80)
    print("基准测试最终结果")
    print("="*80)
    header = f"| {'模块名':<22} | {'版本':<8} | {'时间 (ms/iter)':<16} | {'内存 (MB)':<12} | {'提升':<10} |"
    print(header)
    print(f"|{'-'*24}|{'-'*10}|{'-'*18}|{'-'*14}|{'-'*12}|")
    
    for name, res in results.items():
        print(f"| {name:<22} | {'原始':<8} | {res['orig_t']:<16.3f} | {res['orig_m']:<12.2f} | {'-':<10} |")
        print(f"| {'':<22} | {'优化':<8} | {res['opt_t']:<16.3f} | {res['opt_m']:<12.2f} | {res['time_gain']:>6.2f}% (T) |")
        if device.type == 'cuda':
            print(f"| {'':<22} | {'':<8} | {'':<16} | {'':<12} | {res['mem_gain']:>6.2f}% (M) |")
        print(f"|{'-'*24}|{'-'*10}|{'-'*18}|{'-'*14}|{'-'*12}|")
    print("="*80)

正在使用设备: cuda
设备名称: NVIDIA A800-SXM4-80GB
--------------------------------------------------
测试参数 (较大负载):
  Batch Size: 4, MSA Depth: 64, Seq Len: 256
  D_MSA: 64, D_PAIR: 32
  Heads: 4, Hidden Dim: 16
--------------------------------------------------
--- 正在检查 Attention 的一致性 ---
❌ FAIL: 输出不一致。最大差异: 7.560725e-01
-----------------------------
--- 正在进行 Attention 的性能基准测试 ---
  - 原始版本...
  - 优化版本...
--- 正在检查 AttentionWithBias 的一致性 ---
✅ PASS: 输出一致。最大差异: 4.768372e-07
-------------------------------------
--- 正在进行 AttentionWithBias 的性能基准测试 ---
  - 原始版本...
  - 优化版本...
--- 正在检查 MSARowAttention 的一致性 ---
✅ PASS: 输出一致。最大差异: 0.000000e+00
-----------------------------------
--- 正在进行 MSARowAttention 的性能基准测试 ---
  - 原始版本...
  - 优化版本...
--- 正在检查 BiasedAxialAttention 的一致性 ---
✅ PASS: 输出一致。最大差异: 0.000000e+00
----------------------------------------
--- 正在进行 BiasedAxialAttention 的性能基准测试 ---
  - 原始版本...
  - 优化版本...

基准测试最终结果
| 模块名                    | 版本       | 时间 (ms/iter)     | 内存 (MB)      | 提升        

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from rfdiffusion.util_module import init_lecun_normal
from mytest.Attention_module import *

def print_all_gradients(model, epoch=None):
    """打印模型中所有可训练参数的梯度信息"""
    if epoch is not None:
        print(f"\n=== Epoch {epoch} 梯度信息 ===")
    else:
        print(f"\n=== 当前梯度信息 ===")
    
    # 收集所有梯度信息
    grad_data = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.data
            grad_info = {
                "name": name,
                "shape": tuple(grad.shape),
                "dtype": str(grad.dtype),
                "norm": grad.norm().item(),
                "min": grad.min().item(),
                "max": grad.max().item(),
                "mean": grad.mean().item(),
                "nan": torch.isnan(grad).sum().item(),
                "inf": torch.isinf(grad).sum().item(),
                "zero": (grad == 0).sum().item()
            }
            grad_data.append(grad_info)
    
    # 按梯度范数排序（从大到小）
    grad_data.sort(key=lambda x: x["norm"], reverse=True)
    
    # 打印表格头
    print(f"{'参数名称':<40} | {'形状':<20} | {'范数':>10} | {'NaN':>5} | {'Inf':>5} | {'零值%':>6} | {'范围'}")
    print("-" * 120)
    
    # 打印每个参数的梯度信息
    for info in grad_data:
        if len(grad_data) > 0:
            zero_percent = info["zero"] / (info['shape'][0] if len(info['shape']) > 0 else 1) * 100
        else:
            zero_percent = 0
        range_str = f"[{info['min']:.3e}, {info['max']:.3e}]"
        
        # 高亮异常梯度
        if info["nan"] > 0 or info["inf"] > 0:
            highlight = "\033[91m"  # 红色
            reset = "\033[0m"
        else:
            highlight = reset = ""
        
        print(f"{highlight}{info['name']:<40} | {str(info['shape']):<20} | "
            f"{info['norm']:>10.3e} | "
            f"{info['nan']:>5} | "
            f"{info['inf']:>5} | "
            f"{zero_percent:>5.1f}% | "
            f"{range_str}{reset}")

def test_all_attention_modules():
    torch.manual_seed(42)
    B, L = 2, 5
    d_model, d_pair, d_bias = 64, 1, 32
    n_head, d_hidden = 4, 16
    
    print("="*80)
    print("测试所有 Attention 模块的梯度")
    print("="*80)

    # 1. 测试 FeedForwardLayer
    print("\n" + "="*50)
    print("1. 测试 FeedForwardLayer")
    print("="*50)
    
    data = torch.randn(B, L, d_model)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    data[0, 3:] = float('nan')
    data[1, :2] = float('nan')
    
    ffn = FeedForwardLayer(d_model, 4)
    optimizer = torch.optim.Adam(ffn.parameters(), lr=0.001)
    target = torch.randn_like(data)
    
    for i in range(3):
        optimizer.zero_grad()
        out = ffn(data)
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(ffn, epoch=i+1)
        optimizer.step()

    # 2. 测试 Attention
    print("\n" + "="*50)
    print("2. 测试 Attention")
    print("="*50)
    
    data = torch.randn(B, L, d_model)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    data[0, 3:] = float('nan')
    data[1, :2] = float('nan')
    
    attn = Attention(d_model, d_model, n_head, d_hidden, d_model)
    optimizer = torch.optim.Adam(attn.parameters(), lr=0.001)
    target = torch.randn_like(data)
    
    for i in range(3):
        optimizer.zero_grad()
        out = attn(data, data, data)
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(attn, epoch=i+1)
        optimizer.step()

    # 3. 测试 AttentionWithBias
    print("\n" + "="*50)
    print("3. 测试 AttentionWithBias")
    print("="*50)
    
    data = torch.randn(B, L, d_model)
    bias = torch.randn(B, L, L, d_bias)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    data[0, 3:] = float('nan')
    data[1, :2] = float('nan')
    
    attn_bias = AttentionWithBias(d_model, d_bias, n_head, d_hidden)
    optimizer = torch.optim.Adam(attn_bias.parameters(), lr=0.001)
    target = torch.randn_like(data)
    
    for i in range(3):
        optimizer.zero_grad()
        out = attn_bias(data, bias)
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(attn_bias, epoch=i+1)
        optimizer.step()

    # 4. 测试 SequenceWeight
    print("\n" + "="*60)
    print("4. 测试 SequenceWeight")
    print("="*60)

    msa = torch.randn(B, N, L, d_model)
    # 在MSA的pad位置设置为0
    msa_mask = mask.unsqueeze(1).expand(B, N, L)  # [B, N, L]
    msa[~msa_mask] = 0.0

    seq_weight = SequenceWeight(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(seq_weight.parameters(), lr=0.001)
    target = torch.randn(B, N, L, n_head, 1)
    target_mask = mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1).expand_as(target)
    target[~target_mask] = 0.0

    for i in range(3):
        optimizer.zero_grad()
        out = seq_weight(msa, mask=mask)
        
        # 克隆输出以避免就地修改
        out_cloned = out.clone()
        print(out_cloned)
        
        loss = F.mse_loss(out_cloned[target_mask], target[target_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(seq_weight, epoch=i+1)
        optimizer.step()

    # 5. 测试 MSARowAttentionWithBias
    print("\n" + "="*50)
    print("5. 测试 MSARowAttentionWithBias")
    print("="*50)
    
    B, N, L = 2, 3, 5
    msa = torch.randn(B, N, L, d_model)
    pair = torch.randn(B, L, L, d_pair)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    msa[0, :, 3:] = float('nan')
    msa[1, :, :2] = float('nan')
    
    msa_row_attn = MSARowAttentionWithBias(d_model, d_pair, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_row_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    
    for i in range(3):
        optimizer.zero_grad()
        out = msa_row_attn(msa, pair)
        # 创建目标mask
        target_mask = mask.unsqueeze(1).unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask], target[target_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(msa_row_attn, epoch=i+1)
        optimizer.step()

    # 6. 测试 MSAColAttention
    print("\n" + "="*50)
    print("6. 测试 MSAColAttention")
    print("="*50)
    
    B, N, L = 2, 3, 5
    msa = torch.randn(B, N, L, d_model)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    msa[0, :, 3:] = float('nan')
    msa[1, :, :2] = float('nan')
    
    msa_col_attn = MSAColAttention(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_col_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    
    for i in range(3):
        optimizer.zero_grad()
        out = msa_col_attn(msa)
        # 创建目标mask
        target_mask = mask.unsqueeze(1).unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask], target[target_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(msa_col_attn, epoch=i+1)
        optimizer.step()

    # 7. 测试 MSAColGlobalAttention
    print("\n" + "="*50)
    print("7. 测试 MSAColGlobalAttention")
    print("="*50)
    
    B, N, L = 2, 3, 5
    msa = torch.randn(B, N, L, d_model)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    msa[0, :, 3:] = float('nan')
    msa[1, :, :2] = float('nan')
    
    msa_global_attn = MSAColGlobalAttention(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_global_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    
    for i in range(3):
        optimizer.zero_grad()
        out = msa_global_attn(msa)
        # 创建目标mask
        target_mask = mask.unsqueeze(1).unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask], target[target_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(msa_global_attn, epoch=i+1)
        optimizer.step()

    # 8. 测试 BiasedAxialAttention
    print("\n" + "="*50)
    print("8. 测试 BiasedAxialAttention")
    print("="*50)
    
    B, L = 2, 5
    pair = torch.randn(B, L, L, d_pair)
    bias = torch.randn(B, L, L, d_bias)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 3:] = 0
    mask[1, :2] = 0
    
    # 在无效位置添加NaN
    pair[0, 3:, :] = float('nan')
    pair[0, :, 3:] = float('nan')
    pair[1, :2, :] = float('nan')
    pair[1, :, :2] = float('nan')
    
    axial_attn = BiasedAxialAttention(d_pair, d_bias, n_head, d_hidden, is_row=True)
    optimizer = torch.optim.Adam(axial_attn.parameters(), lr=0.001)
    target = torch.randn_like(pair)
    
    # 创建布尔掩码
    loss_mask = torch.zeros(B, L, L, dtype=torch.bool)
    loss_mask[0, :3, :3] = True
    loss_mask[1, 2:, 2:] = True
    loss_mask = loss_mask.unsqueeze(-1).expand_as(pair)
    
    for i in range(3):
        optimizer.zero_grad()
        out = axial_attn(pair, bias)
        loss = F.mse_loss(out[loss_mask], target[loss_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(axial_attn, epoch=i+1)
        optimizer.step()

    print("\n" + "="*80)
    print("所有模块测试完成!")
    print("="*80)

if __name__ == "__main__":
    test_all_attention_modules()

测试所有 Attention 模块的梯度

1. 测试 FeedForwardLayer
Epoch 1: Loss = 1.198163

=== Epoch 1 梯度信息 ===
参数名称                                     | 形状                   |         范数 |   NaN |   Inf |    零值% | 范围
------------------------------------------------------------------------------------------------------------------------
linear2.weight                           | (64, 256)            |  7.948e-01 |     0 |     0 | 1000.0% | [-4.731e-02, 4.004e-02]
linear1.weight                           | (256, 64)            |  3.954e-01 |     0 |     0 | 250.0% | [-2.223e-02, 1.733e-02]
linear2.bias                             | (64,)                |  1.308e-01 |     0 |     0 |   0.0% | [-5.037e-02, 3.718e-02]
linear1.bias                             | (256,)               |  5.741e-02 |     0 |     0 |   3.9% | [-9.596e-03, 1.318e-02]
norm.bias                                | (64,)                |  3.470e-02 |     0 |     0 |   0.0% | [-9.698e-03, 9.952e-03]
norm.weight                            

AttributeError: 'NoneType' object has no attribute 'to'

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from rfdiffusion.util_module import init_lecun_normal
from mytest.Attention_module import *

def print_all_gradients(model, epoch=None):
    """打印模型中所有可训练参数的梯度信息"""
    if epoch is not None:
        print(f"\n=== Epoch {epoch} 梯度信息 ===")
    else:
        print(f"\n=== 当前梯度信息 ===")
    
    # 收集所有梯度信息
    grad_data = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.data
            grad_info = {
                "name": name,
                "shape": tuple(grad.shape),
                "dtype": str(grad.dtype),
                "norm": grad.norm().item(),
                "min": grad.min().item(),
                "max": grad.max().item(),
                "mean": grad.mean().item(),
                "nan": torch.isnan(grad).sum().item(),
                "inf": torch.isinf(grad).sum().item(),
                "zero": (grad == 0).sum().item()
            }
            grad_data.append(grad_info)
    
    # 按梯度范数排序（从大到小）
    grad_data.sort(key=lambda x: x["norm"], reverse=True)
    
    # 打印表格头
    print(f"{'参数名称':<40} | {'形状':<20} | {'范数':>10} | {'NaN':>5} | {'Inf':>5} | {'零值%':>6} | {'范围'}")
    print("-" * 120)
    
    # 打印每个参数的梯度信息
    for info in grad_data:
        total_elements = 1
        for dim in info['shape']:
            total_elements *= dim
        zero_percent = info["zero"] / total_elements * 100 if total_elements > 0 else 0
        range_str = f"[{info['min']:.3e}, {info['max']:.3e}]"
        
        # 高亮异常梯度
        if info["nan"] > 0 or info["inf"] > 0:
            highlight = "\033[91m"  # 红色
            reset = "\033[0m"
        else:
            highlight = reset = ""
        
        print(f"{highlight}{info['name']:<40} | {str(info['shape']):<20} | "
            f"{info['norm']:>10.3e} | "
            f"{info['nan']:>5} | "
            f"{info['inf']:>5} | "
            f"{zero_percent:>5.1f}% | "
            f"{range_str}{reset}")

def test_all_attention_modules_with_padding():
    """测试所有注意力模块，确保pad为0"""
    torch.manual_seed(42)
    
    # 设置基本参数
    B, L = 2, 8
    N = 1  # MSA序列数
    d_model, d_pair, d_bias = 16,24,32
    n_head, d_hidden = 4, 4
    
    print("="*100)
    print("测试所有 Attention 模块 - 使用 padding=0")
    print("="*100)

    # 创建mask - pad位置为0，有效位置为1
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[0, 5:] = 0  # 第一个batch后3位为pad
    mask[1, :3] = 0  # 第二个batch前3位为pad
    mask[1, 6:] = 0  # 第二个batch后2位为pad
    pad = float('nan')  # 使用NaN表示pad位置
    print(f"Mask shape: {mask.shape}")
    print(f"Valid positions - Batch 0: {mask[0].sum()}, Batch 1: {mask[1].sum()}")

    # 1. 测试 FeedForwardLayer
    print("\n" + "="*60)
    print("1. 测试 FeedForwardLayer")
    print("="*60)
    
    data = torch.randn(B, L, d_model)
    # 在pad位置设置为0而不是NaN
    data[~mask] = pad
    
    ffn = FeedForwardLayer(d_model, 4)
    optimizer = torch.optim.Adam(ffn.parameters(), lr=0.001)
    target = torch.randn_like(data)
    target[~mask] = 0.0
    
    for i in range(3):
        optimizer.zero_grad()
        out = ffn(data)
        # 只在有效位置计算loss
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(ffn, epoch=i+1)
        optimizer.step()

    # 2. 测试 Attention
    print("\n" + "="*60)
    print("2. 测试 Attention")
    print("="*60)
    
    data = torch.randn(B, L, d_model)
    data[~mask] = pad
    
    attn = Attention(d_model, d_model, n_head, d_hidden, d_model)
    optimizer = torch.optim.Adam(attn.parameters(), lr=0.001)
    target = torch.randn_like(data)
    target[~mask] = 0.0
    
    for i in range(3):
        optimizer.zero_grad()
        out = attn(data, data, data, mask=mask)
        print(out)
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(attn, epoch=i+1)
        optimizer.step()

    # 3. 测试 AttentionWithBias
    print("\n" + "="*60)
    print("3. 测试 AttentionWithBias")
    print("="*60)
    
    data = torch.randn(B, L, d_model)
    bias = torch.randn(B, L, L, d_bias)
    data[~mask] = pad
    # 修正bias的pad处理 - 需要正确的维度扩展
    bias_mask = mask.unsqueeze(1).unsqueeze(-1) & mask.unsqueeze(2).unsqueeze(-1)  # [B, L, L, 1]
    bias_mask = bias_mask.expand_as(bias)  # [B, L, L, d_bias]
    bias[~bias_mask] = pad
    
    attn_bias = AttentionWithBias(d_model, d_bias, n_head, d_hidden)
    optimizer = torch.optim.Adam(attn_bias.parameters(), lr=0.001)
    target = torch.randn_like(data)
    target[~mask] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = attn_bias(data, bias, mask=mask)
        loss = F.mse_loss(out[mask], target[mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(attn_bias, epoch=i+1)
        optimizer.step()

    # 4. 测试 SequenceWeight
    print("\n" + "="*60)
    print("4. 测试 SequenceWeight")
    print("="*60)
    
    msa = torch.randn(B, N, L, d_model)
    # 在MSA的pad位置设置为0
    msa_mask = mask.unsqueeze(1).expand(B, N, L)  # [B, N, L]
    msa[~msa_mask] = pad
    
    seq_weight = SequenceWeight(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(seq_weight.parameters(), lr=0.001)
    target = torch.randn(B, N, L, n_head, 1)
    target_mask = mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1).expand_as(target)
    target[~target_mask] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = seq_weight(msa, mask=mask)
        loss = F.mse_loss(out[target_mask], target[target_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(seq_weight, epoch=i+1)
        optimizer.step()

    # 5. 测试 MSARowAttentionWithBias
    print("\n" + "="*60)
    print("5. 测试 MSARowAttentionWithBias")
    print("="*60)
    
    msa = torch.randn(B, N, L, d_model)
    pair = torch.randn(B, L, L, d_pair)
    msa[~msa_mask] = 0.0
    # pair的pad处理 - [B, L, L, d_pair]
    pair_mask = mask.unsqueeze(1) & mask.unsqueeze(2)  # [B, L, L]
    pair[~pair_mask.unsqueeze(-1).expand_as(pair)] = 0.0
    
    msa_row_attn = MSARowAttentionWithBias(d_model, d_pair, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_row_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    target[~msa_mask.unsqueeze(-1).expand_as(target)] = 0.0
    msa.requires_grad_(True)  # 确保MSA的梯度被计算
    for i in range(3):
        optimizer.zero_grad()
        out = msa_row_attn(msa, pair, mask=mask)
        print(out.shape)
        target_mask_expanded = msa_mask.unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask_expanded], target[target_mask_expanded])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print(msa.grad)
        print_all_gradients(msa_row_attn, epoch=i+1)
        optimizer.step()

    # 6. 测试 MSAColAttention
    print("\n" + "="*60)
    print("6. 测试 MSAColAttention")
    print("="*60)
    
    msa = torch.randn(B, N, L, d_model)
    msa[~msa_mask] = pad
    
    msa_col_attn = MSAColAttention(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_col_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    target[~msa_mask.unsqueeze(-1).expand_as(target)] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = msa_col_attn(msa, mask=mask)
        target_mask_expanded = msa_mask.unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask_expanded], target[target_mask_expanded])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(msa_col_attn, epoch=i+1)
        optimizer.step()

    # 7. 测试 MSAColGlobalAttention
    print("\n" + "="*60)
    print("7. 测试 MSAColGlobalAttention")
    print("="*60)
    
    msa = torch.randn(B, N, L, d_model)
    msa[~msa_mask] = pad
    
    msa_global_attn = MSAColGlobalAttention(d_model, n_head, d_hidden)
    optimizer = torch.optim.Adam(msa_global_attn.parameters(), lr=0.001)
    target = torch.randn_like(msa)
    target[~msa_mask.unsqueeze(-1).expand_as(target)] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = msa_global_attn(msa, mask=mask)
        target_mask_expanded = msa_mask.unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[target_mask_expanded], target[target_mask_expanded])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(msa_global_attn, epoch=i+1)
        optimizer.step()

    # 8. 测试 BiasedAxialAttention (行方向)
    print("\n" + "="*60)
    print("8. 测试 BiasedAxialAttention (行方向)")
    print("="*60)
    
    pair = torch.randn(B, L, L, d_pair)
    bias = torch.randn(B, L, L, d_bias)
    pair[~pair_mask.unsqueeze(-1).expand_as(pair)] = pad
    bias[~pair_mask.unsqueeze(-1).expand_as(bias)] = pad
    
    axial_attn_row = BiasedAxialAttention(d_pair, d_bias, n_head, d_hidden, is_row=True)
    optimizer = torch.optim.Adam(axial_attn_row.parameters(), lr=0.001)
    target = torch.randn_like(pair)
    target[~pair_mask.unsqueeze(-1).expand_as(target)] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = axial_attn_row(pair, bias, mask=mask)
        loss_mask = pair_mask.unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[loss_mask], target[loss_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(axial_attn_row, epoch=i+1)
        optimizer.step()

    # 9. 测试 BiasedAxialAttention (列方向)
    print("\n" + "="*60)
    print("9. 测试 BiasedAxialAttention (列方向)")
    print("="*60)
    
    pair = torch.randn(B, L, L, d_pair)
    bias = torch.randn(B, L, L, d_bias)
    pair[~pair_mask.unsqueeze(-1).expand_as(pair)] = pad
    bias[~pair_mask.unsqueeze(-1).expand_as(bias)] = pad
    
    axial_attn_col = BiasedAxialAttention(d_pair, d_bias, n_head, d_hidden, is_row=False)
    optimizer = torch.optim.Adam(axial_attn_col.parameters(), lr=0.001)
    target = torch.randn_like(pair)
    target[~pair_mask.unsqueeze(-1).expand_as(target)] = pad
    
    for i in range(3):
        optimizer.zero_grad()
        out = axial_attn_col(pair, bias, mask=mask)
        loss_mask = pair_mask.unsqueeze(-1).expand_as(out)
        loss = F.mse_loss(out[loss_mask], target[loss_mask])
        print(f"Epoch {i+1}: Loss = {loss.item():.6f}")
        loss.backward()
        print_all_gradients(axial_attn_col, epoch=i+1)
        optimizer.step()

    print("\n" + "="*100)
    print("所有 Attention 模块测试完成！")
    print("="*100)
    
    # 测试总结
    print("\n测试总结:")
    print(f"- 批次大小: {B}")
    print(f"- 序列长度: {L}")
    print(f"- MSA序列数: {N}")
    print(f"- 模型维度: {d_model}")
    print(f"- 对维度: {d_pair}")
    print(f"- 偏置维度: {d_bias}")
    print(f"- 注意力头数: {n_head}")
    print(f"- 隐藏维度: {d_hidden}")
    print(f"- Padding策略: 无效位置设置为0")
    print(f"- 有效位置数量: Batch 0: {mask[0].sum()}, Batch 1: {mask[1].sum()}")

if __name__ == "__main__":
    test_all_attention_modules_with_padding()

测试所有 Attention 模块 - 使用 padding=0
Mask shape: torch.Size([2, 8])
Valid positions - Batch 0: 5, Batch 1: 3

1. 测试 FeedForwardLayer
Epoch 1: Loss = 1.179109

=== Epoch 1 梯度信息 ===
参数名称                                     | 形状                   |         范数 |   NaN |   Inf |    零值% | 范围
------------------------------------------------------------------------------------------------------------------------
[91mnorm.weight                              | (16,)                |        nan |    16 |     0 |   0.0% | [nan, nan][0m
norm.bias                                | (16,)                |  6.382e-02 |     0 |     0 |   0.0% | [-2.606e-02, 2.799e-02]
[91mlinear1.weight                           | (64, 16)             |        nan |  1024 |     0 |   0.0% | [nan, nan][0m
linear1.bias                             | (64,)                |  9.206e-02 |     0 |     0 |   3.1% | [-1.496e-02, 3.071e-02]
[91mlinear2.weight                           | (16, 64)             |        nan |  1024 | 

In [None]:
import torch.nn  as nn 
 
# 创建 LayerNorm 层 
layer = nn.LayerNorm(64)
print(layer)
# 打印所有参数及其名称 
for name, param in layer.named_parameters(): 
    print(f"参数名称: {name}, 形状: {param.shape},  数据类型: {param.dtype}") 

LayerNorm((64,), eps=1e-05, elementwise_affine=True)
参数名称: weight, 形状: torch.Size([64]),  数据类型: torch.float32
参数名称: bias, 形状: torch.Size([64]),  数据类型: torch.float32
