# In [7]:  (Jupyter notebook cell prompt, kept as a comment so the file parses as Python)
import torch
import torch.nn as nn
import torch.nn.functional as F
from mytest.RoseTTAFoldModel import RoseTTAFoldModule  # 假设模型保存在model.py中

def test_forward_backward():
    """Smoke-test RoseTTAFoldModule with two forward passes and one backward pass.

    Builds the model from a small hard-coded config, then runs:

    1. a forward pass without templates or recycling inputs (training outputs),
    2. a forward pass with templates and recycling inputs (``return_infer=True``),
    3. a backward pass over a summed multi-task loss, checking that every
       trainable parameter receives a finite gradient.

    Raises:
        AssertionError: if an output count/shape check or a gradient check fails.
    """
    # Model hyper-parameters
    config = {
        'n_extra_block': 2,
        'n_main_block': 4,
        'n_ref_block': 4,
        'd_msa': 64,
        'd_msa_full': 64,
        'd_pair': 128,
        'd_templ': 64,
        'n_head_msa': 8,
        'n_head_pair': 4,
        'n_head_templ': 4,
        'd_hidden': 128,
        'd_hidden_templ': 64,
        'p_drop': 0.1,
        'd_t1d': 22,
        'd_t2d': 44,
        'T': 100,
        'use_motif_timestep': False,
        'freeze_track_motif': False,
        'SE3_param_full': {'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32},
        'SE3_param_topk': {'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32},
        'input_seq_onehot': True  # make the model consume one-hot sequence input
    }

    # Number of residue classes used for the random one-hot features.
    n_aa_classes = 21
    # Per-position feature width of the latent MSA. A previous run failed with
    # "mat1 and mat2 shapes cannot be multiplied (500x21 and 48x64)": the
    # B * N_latent * L = 2 * 5 * 50 = 500 latent-MSA rows hit a linear layer
    # with 48 input features (and d_msa = 64 outputs), so the latent MSA must
    # carry 48 features per position rather than 21.
    d_msa_latent_in = 48
    # NOTE(review): the widths the model expects for `msa_full` and `seq` are
    # not visible from this file; they are left at 21 - confirm against the
    # model's embedding layers.

    # Build the model
    model = RoseTTAFoldModule(**config)
    model.train()  # training mode

    def random_one_hot(shape, num_classes, width=None):
        """Return random one-hot float features, zero-padded on the last axis.

        Args:
            shape: shape of the random class-index tensor.
            num_classes: number of classes to sample from and one-hot encode.
            width: optional final feature width; when larger than
                ``num_classes`` the extra trailing channels are zeros.

        Returns:
            Float tensor of shape ``(*shape, max(num_classes, width or 0))``.
        """
        feats = F.one_hot(torch.randint(0, num_classes, shape), num_classes=num_classes).float()
        if width is not None and width > num_classes:
            # (0, k) pads only the last dimension, k zeros on the right.
            feats = F.pad(feats, (0, width - num_classes))
        return feats

    # Test case 1: basic forward pass without templates / recycling
    def test_case1():
        """Run a forward pass with no templates and no recycling; return the outputs."""
        B, L = 2, 50  # batch_size=2, seq_len=50
        N_latent, N_full = 5, 10  # number of MSA sequences

        # Dummy inputs (one-hot encoded)
        msa_latent = random_one_hot((B, N_latent, L), n_aa_classes, width=d_msa_latent_in)
        msa_full = random_one_hot((B, N_full, L), n_aa_classes)
        seq = random_one_hot((B, L), n_aa_classes)
        xyz = torch.randn(B, L, 3, 3)  # initial coordinates [N, CA, C]
        bond_mat = random_one_hot((B, L, L), 6)  # bond type [0-5]
        idx = torch.arange(L).unsqueeze(0).repeat(B, 1)  # residue indices
        t = torch.randint(0, 100, (B,))  # diffusion timestep
        motif_mask = torch.randint(0, 2, (B, L)).bool()  # random motif mask

        # Forward pass
        outputs = model(
            msa_latent, msa_full, seq, xyz, bond_mat, idx, t,
            t1d=None, t2d=None, xyz_t=None, alpha_t=None,
            msa_prev=None, pair_prev=None, state_prev=None,
            return_raw=False, return_infer=False,
            motif_mask=motif_mask
        )

        # All seven outputs must be present
        assert len(outputs) == 7
        logits, logits_aa, logits_exp, xyz_out, alpha_s, bond_matrix, lddt = outputs

        # Shape checks
        assert logits.shape == (B, L, L, 37)  # distogram + orientations
        assert logits_aa.shape == (B, L, 20)   # amino-acid prediction
        assert logits_exp.shape == (B, L, 2)    # experimentally-resolved prediction
        assert xyz_out.shape == (4, B, L, 3, 3)  # 4 intermediate structures
        assert alpha_s.shape == (4, B, L, 3)    # torsion angles
        assert bond_matrix.shape == (B, L, L, 6)  # bond prediction
        assert lddt.shape == (B, 50, L)          # binned lDDT prediction

        return outputs

    # Test case 2: full pipeline with templates / recycling
    def test_case2():
        """Run a forward pass with template and recycling inputs; return the outputs."""
        B, L = 2, 30
        N_latent, N_full, N_templ = 4, 8, 3

        # Dummy inputs (one-hot encoded)
        msa_latent = random_one_hot((B, N_latent, L), n_aa_classes, width=d_msa_latent_in)
        msa_full = random_one_hot((B, N_full, L), n_aa_classes)
        seq = random_one_hot((B, L), n_aa_classes)
        xyz = torch.randn(B, L, 3, 3)
        bond_mat = random_one_hot((B, L, L), 6)
        idx = torch.arange(L).unsqueeze(0).repeat(B, 1)
        t = torch.randint(0, 100, (B,))
        motif_mask = torch.ones(B, L).bool()  # every residue is part of the motif

        # Template inputs
        t1d = torch.randn(B, N_templ, L, config['d_t1d'])
        t2d = torch.randn(B, N_templ, L, L, config['d_t2d'])
        xyz_t = torch.randn(B, N_templ, L, 3, 3)
        alpha_t = torch.randn(B, N_templ, L, 3)

        # Recycling inputs
        msa_prev = torch.randn(B, L, config['d_msa'])
        pair_prev = torch.randn(B, L, L, config['d_pair'])
        state_prev = torch.randn(B, L, config['SE3_param_topk']['l0_out_features'])

        # Forward pass (inference-format outputs)
        outputs = model(
            msa_latent, msa_full, seq, xyz, bond_mat, idx, t,
            t1d=t1d, t2d=t2d, xyz_t=xyz_t, alpha_t=alpha_t,
            msa_prev=msa_prev, pair_prev=pair_prev, state_prev=state_prev,
            return_infer=True,
            motif_mask=motif_mask
        )

        # Validate outputs
        assert len(outputs) == 8
        msa_s, pair_s, xyz, state, alpha, logits_aa, bond_matrix, pred_lddt = outputs

        assert msa_s.shape == (B, L, config['d_msa'])
        assert pair_s.shape == (B, L, L, config['d_pair'])
        assert xyz.shape == (B, L, 3, 3)
        assert pred_lddt.shape == (B, L)  # scalar pLDDT per residue

        return outputs

    # Backward-pass test
    def test_backward():
        """Back-propagate a summed multi-task loss and check all gradients."""
        outputs = test_case1()
        logits, logits_aa, logits_exp, _, _, bond_matrix, lddt = outputs

        # B and L are locals of test_case1 and are not visible here, so recover
        # them from an output whose shape test_case1 asserted to be (B, L, 20).
        B, L = logits_aa.shape[:2]

        # Dummy labels
        dist_label = torch.randint(0, 37, (B, L, L))
        aa_label = torch.randint(0, 20, (B, L))
        exp_label = torch.randint(0, 2, (B, L))
        bond_label = torch.randint(0, 6, (B, L, L))

        # Multi-task loss; CrossEntropyLoss wants the class axis at dim 1,
        # hence the permutes.
        loss_fn = nn.CrossEntropyLoss()
        loss = (
            loss_fn(logits.permute(0, 3, 1, 2), dist_label) +
            loss_fn(logits_aa.permute(0, 2, 1), aa_label) +
            loss_fn(logits_exp.permute(0, 2, 1), exp_label) +
            loss_fn(bond_matrix.permute(0, 3, 1, 2), bond_label) +
            lddt.mean()  # simplified stand-in for an lDDT loss
        )

        # Backward pass
        loss.backward()

        # Every trainable parameter must have a finite gradient.
        # NOTE(review): parameters that are unused when no templates are
        # passed (if the model has any) would fail the first assert - confirm.
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert param.grad is not None, f"no gradient for parameter {name}"
                assert not torch.isnan(param.grad).any(), f"NaN gradient in parameter {name}"

        print("Backward pass completed without NaNs")

    # Run the tests
    print("Running Test Case 1...")
    test_case1()
    print("Test Case 1 passed!\n")

    print("Running Test Case 2...")
    test_case2()
    print("Test Case 2 passed!\n")

    print("Testing Backward Pass...")
    test_backward()
    print("Backward test passed!")

# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    test_forward_backward()

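# Recorded output of the original notebook run:
#
#   Running Test Case 1...
#
#   RuntimeError: mat1 and mat2 shapes cannot be multiplied (500x21 and 48x64)
#
# That is, a 500-row input with 21 features per row (500 = B * N_latent * L =
# 2 * 5 * 50, the flattened latent MSA of test case 1) reached a linear layer
# whose weight expects 48 input features and produces 64 outputs.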