# RFDiffusion Breakdown
Forward noising: add 3D Gaussian noise to the Ca coordinates, and use Brownian motion on the manifold of rotation matrices to perturb each residue's N-CA-C orientation.  
Backward denoising: use RoseTTAFold as the denoising engine (feed the noised coordinates into RoseTTAFold, which outputs the 'right' structure).  
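
A minimal NumPy sketch of one forward noising step, under stated assumptions: the function names, `sigma_t`, and `sigma_r` are hypothetical, and a small Brownian step on the rotation manifold is approximated by a Gaussian rotation vector turned into a matrix with Rodrigues' formula. The exact rotation noise distribution and schedule used by RFdiffusion are not reproduced here.

```python
import numpy as np

def rotvec_to_matrix(v):
    """Rodrigues' formula: rotation vectors (..., 3) -> rotation matrices (..., 3, 3)."""
    theta = np.linalg.norm(v, axis=-1, keepdims=True)   # rotation angle, (..., 1)
    axis = v / np.maximum(theta, 1e-12)                  # unit axis, safe when theta ~ 0
    x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
    zero = np.zeros_like(x)
    # Skew-symmetric cross-product matrix K of the axis
    K = np.stack([np.stack([zero, -z, y], axis=-1),
                  np.stack([z, zero, -x], axis=-1),
                  np.stack([-y, x, zero], axis=-1)], axis=-2)
    t = theta[..., None]                                 # (..., 1, 1) for broadcasting
    return np.eye(3) + np.sin(t) * K + (1.0 - np.cos(t)) * (K @ K)

def forward_noise_step(ca_xyz, frames, sigma_t, sigma_r, rng):
    """ca_xyz: (L, 3) Ca coordinates, frames: (L, 3, 3) residue orientations."""
    # Translations: isotropic 3D Gaussian noise on the Ca coordinates
    noisy_xyz = ca_xyz + sigma_t * rng.standard_normal(ca_xyz.shape)
    # Rotations: compose each frame with a small random rotation (assumed approximation)
    step = rotvec_to_matrix(sigma_r * rng.standard_normal(ca_xyz.shape))
    noisy_frames = step @ frames
    return noisy_xyz, noisy_frames

rng = np.random.default_rng(0)
ca = np.zeros((5, 3))
frames = np.tile(np.eye(3), (5, 1, 1))
noisy_ca, noisy_frames = forward_noise_step(ca, frames, sigma_t=1.0, sigma_r=0.1, rng=rng)
```

Because the rotation noise is applied by matrix multiplication, the noised frames remain valid rotation matrices, which is the point of noising on the manifold rather than adding Gaussian noise to the matrix entries.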



## Key Modules
Embedding module:
1. positionalEmbedding: relative position encoding [B, L, L, d]
2. MSA_emb: inputs are the MSA [B, N, L, d_init] and the query sequence [B, L];
   outputs are msa_embedding [B, N, L, d_msa] and pair_embedding [B, L, L, d_pair]
3. extra_emb: embedding for the extra MSA
4. TemplatePairStack: uses the PairStr2Pair module to embed the template structure [L, L, d]
5. TemplateTorsionStack: uses an MLP + AttentionwithBias to embed torsion angles
6. Templ_emb: embeds template features, including
    2d: 37 distogram bins (rbf) + 6 orientations
        mask (missing/unaligned): 1
    1d: tiled AA (20 + gap); confidence: 1; contacting or not: 1

7. Recycling: takes the MSA, pair, and state representations as input and updates them with an engine whose parameters are not updated
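
The relative position encoding in item 1 can be sketched as follows. This is an illustrative NumPy version, not the module's actual code: the function name, the clipping range `max_offset=32`, and the random projection standing in for a learned embedding are all assumptions.

```python
import numpy as np

def relative_position_features(idx, max_offset=32):
    """One-hot encode clipped relative offsets idx[j] - idx[i].

    idx: (B, L) integer residue indices
    returns: (B, L, L, 2 * max_offset + 1) one-hot features
    """
    offset = idx[:, None, :] - idx[:, :, None]             # (B, L, L) signed offsets
    offset = np.clip(offset, -max_offset, max_offset)      # far-apart pairs share a bin
    bins = offset + max_offset                              # shift into 0 .. 2 * max_offset
    return np.eye(2 * max_offset + 1)[bins]                 # (B, L, L, n_bins)

# A learned embedding would project the one-hot bins to d channels;
# a random matrix stands in for the learned weights here.
rng = np.random.default_rng(0)
idx = np.arange(6)[None, :]                                 # B=1, L=6
onehot = relative_position_features(idx, max_offset=2)
pos_emb = onehot @ rng.standard_normal((onehot.shape[-1], 8))   # [B, L, L, d] with d=8
```

Using residue indices rather than array positions means a gap in `idx` (for example, a chain break) is reflected in the encoding.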





Attention module

1. FeedForwardLayer
2. Attention (conventional)
3. AttentionwithBias: adds a bias and a gate on top of 2 (Attention)
4. SequenceWeight
5. MSARowAttentionwithBias
6. MSAColAttention
7. MSAGlobalAttention: shares keys and values across all attention heads
8. BiasedAxialAttention
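
To make "bias and gate" concrete, here is a single-head NumPy sketch. The weight names (`Wq`, `Wk`, `Wv`, `Wg`) and the function name are hypothetical, and multi-head splitting, layer norm, and the output projection are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)     # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gated_biased_attention(x, bias, Wq, Wk, Wv, Wg):
    """x: (L, d_in), bias: (L, L), W*: (d_in, d_hidden). Returns (L, d_hidden)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    # The bias (e.g. derived from the pair representation) is added to the logits
    logits = q @ k.T / np.sqrt(q.shape[-1]) + bias
    attn = softmax(logits, axis=-1)
    # A sigmoid gate computed from the input scales the attention output elementwise
    gate = 1.0 / (1.0 + np.exp(-(x @ Wg)))
    return gate * (attn @ v)

rng = np.random.default_rng(0)
L, d_in, d_hidden = 4, 8, 16
x = rng.standard_normal((L, d_in))
Wq, Wk, Wv, Wg = (rng.standard_normal((d_in, d_hidden)) for _ in range(4))
out = gated_biased_attention(x, np.zeros((L, L)), Wq, Wk, Wv, Wg)
```

The bias lets a second representation steer which positions attend to each other without changing the queries and keys themselves.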

Track_module (similar to Evoformer)
#1. MSA -> MSA update (biased attention. bias from pair & structure)  
#2. Pair -> Pair update (biased attention. bias from structure)  
#3. MSA -> Pair update (extract coevolution signal)  
#4. Str -> Str update (node from MSA, edge from Pair)  

Class MSAPairStr2MSA:  
        '''  
        Inputs:  
            - msa: MSA feature (B, N, L, d_msa)  
            - pair: Pair feature (B, L, L, d_pair)  
            - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)  
            - xyz: xyz coordinates (B, L, n_atom, 3)  
            - state: updated node features after SE(3)-Transformer layer (B, L, d_state)  
        Output:  
            - msa: Updated MSA feature (B, N, L, d_msa)  
        '''  

Class PairStr2Pair:
        inputs: pair [B, L, L, d_pair] 
                rbf_feat [B, L, L, d]

        operation:   row attention / column attention / feed-forward, using rbf_feat as bias
        outputs: updated pair 
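
The `rbf_feat` input is a radial-basis expansion of Ca-Ca distances. The sketch below is a plain NumPy illustration; the 0 to 20 Å range, the Gaussian width, and the function names are assumptions, while the 36 bins match the (B, L, L, 36) shape quoted above.

```python
import numpy as np

def pairwise_distances(ca_xyz):
    """ca_xyz: (B, L, 3) -> Euclidean distance matrix (B, L, L)."""
    diff = ca_xyz[:, :, None, :] - ca_xyz[:, None, :, :]
    return np.linalg.norm(diff, axis=-1)

def rbf_features(dist, d_min=0.0, d_max=20.0, n_bins=36):
    """Expand each distance into n_bins Gaussian bumps centered on an even grid."""
    centers = np.linspace(d_min, d_max, n_bins)      # (n_bins,)
    width = (d_max - d_min) / n_bins                  # assumed shared width for all bumps
    return np.exp(-((dist[..., None] - centers) / width) ** 2)

rng = np.random.default_rng(0)
ca = rng.standard_normal((1, 7, 3)) * 5.0             # B=1, L=7 toy coordinates
feat = rbf_features(pairwise_distances(ca))           # (1, 7, 7, 36)
```

Compared with feeding the raw distance, the expansion gives the network a smooth, bounded feature vector in which nearby distances produce similar activations.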

Class MSA2Pair: (AlphaFold2 Algorithm 10)
        uses the outer product mean to generate a pair representation from the MSA
        Key step for passing MSA information into the pair representation
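
A minimal NumPy sketch of the outer product mean, assuming a plain (unweighted) average over the N sequences. The weight names `W_left`, `W_right`, and `W_out` are hypothetical stand-ins for learned projections, and any normalization or sequence weighting in the real module is omitted.

```python
import numpy as np

def outer_product_mean(msa, W_left, W_right, W_out):
    """msa: (B, N, L, d_msa) -> pair update (B, L, L, d_pair)."""
    left = msa @ W_left                                # (B, N, L, c)
    right = msa @ W_right                              # (B, N, L, c)
    n_seq = msa.shape[1]
    # Outer product of columns i and j, averaged over the N sequences
    outer = np.einsum('bnic,bnjd->bijcd', left, right) / n_seq   # (B, L, L, c, c)
    B, L = outer.shape[:2]
    return outer.reshape(B, L, L, -1) @ W_out          # flatten c*c, project to d_pair

rng = np.random.default_rng(0)
B, N, L, d_msa, c, d_pair = 1, 5, 4, 8, 3, 6
msa = rng.standard_normal((B, N, L, d_msa))
W_left = rng.standard_normal((d_msa, c))
W_right = rng.standard_normal((d_msa, c))
W_out = rng.standard_normal((c * c, d_pair))
pair_update = outer_product_mean(msa, W_left, W_right, W_out)
```

Columns i and j that co-vary across the alignment produce a consistent outer product, so the average keeps that coevolution signal while uncorrelated columns tend to cancel.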

Class SCPred: (unique to RoseTTAFold)
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        engine: many layers of residually connected MLPs
        Outputs:
            - si: predicted torsion angles (phi, psi, omega, chi1~4 with cos/sin, Cb bend, Cb twist, CG) (B, L, 10, 2)
        '''
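
The (B, L, 10, 2) output stores each angle as a two-component vector. A hedged sketch of turning those pairs back into angles is shown below; it assumes the last axis is ordered (cos, sin) and that the raw pairs may not be unit length, neither of which is stated in the docstring.

```python
import numpy as np

def torsions_to_angles(si, eps=1e-8):
    """si: (B, L, 10, 2) raw (cos, sin) pairs -> angles in radians (B, L, 10)."""
    norm = np.linalg.norm(si, axis=-1, keepdims=True)
    unit = si / np.maximum(norm, eps)                 # project each pair onto the unit circle
    return np.arctan2(unit[..., 1], unit[..., 0])     # atan2(sin, cos)

si = np.zeros((1, 2, 10, 2))
si[..., 0] = 1.0                 # default every pair to angle 0
si[0, 0, 0] = [0.0, 2.0]         # unnormalized pair pointing at +90 degrees
si[0, 1, 3] = [-3.0, 0.0]        # unnormalized pair pointing at 180 degrees
angles = torsions_to_angles(si)
```

Predicting a 2D vector instead of a scalar angle avoids the discontinuity at ±180 degrees.
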
Class Str2Str: broken down in more detail below.

Class IterBlock: combines msa2msa, msa2pair, pair2pair, and str2str to build a basic iteration block

Class IterativeSimulator: uses IterBlock and the initial conditions/values to update the msa, pair, rotation and translation parameters, alpha, and state


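import torch
import torch.nn as nn
from torch import einsum

# NOTE: init_lecun_normal, rbf, get_seqsep, make_topk_graph, make_full_graph,
# SE3TransformerWrapper and SCPred are not defined in this cell; they are assumed
# to come from the rest of the RFdiffusion code base.
class Str2Str(nn.Module):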
    def __init__(self, d_msa=256, d_pair=128, d_state=16, 
            SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
        super(Str2Str, self).__init__()
        
        # initial node & pair feature process
        self.norm_msa = nn.LayerNorm(d_msa) 
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)
    
        self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
        self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
        self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
        
        self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
        self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
        self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])
        
        self.se3 = SE3TransformerWrapper(**SE3_param)
        self.sc_predictor = SCPred(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
                                   p_drop=p_drop)
        
        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights with a LeCun-normal distribution
        self.embed_x = init_lecun_normal(self.embed_x)
        self.embed_e1 = init_lecun_normal(self.embed_e1)
        self.embed_e2 = init_lecun_normal(self.embed_e2)

        # initialize bias to zeros
        nn.init.zeros_(self.embed_x.bias)
        nn.init.zeros_(self.embed_e1.bias)
        nn.init.zeros_(self.embed_e2.bias)
    
    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
        B, N, L = msa.shape[:3]

        if motif_mask is None:
            motif_mask = torch.zeros(L).bool()
        
        # process msa & pair features
        node = self.norm_msa(msa[:,0]) # use only the query sequence representation
        pair = self.norm_pair(pair)
        state = self.norm_state(state)
       
        node = torch.cat((node, state), dim=-1) # concatenate the sequence representation and the state
        node = self.norm_node(self.embed_x(node)) # update node embedding
        pair = self.norm_edge1(self.embed_e1(pair)) # update edge embedding
        
        neighbor = get_seqsep(idx) # sequence-separation feature marking sequence-adjacent residues
        rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1])) # RBF features from Ca-Ca Euclidean distances
        pair = torch.cat((pair, rbf_feat, neighbor), dim=-1) # combine the original pair representation, the RBF features, and the neighbor feature into the edge features
        pair = self.norm_edge2(self.embed_e2(pair))
        
        # define graph
        if top_k != 0:
            G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
        else:
            G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
        l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)
        
        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)

        state = shift['0'].reshape(B, L, -1) # (B, L, C)
        
        offset = shift['1'].reshape(B, L, 2, 3)
        offset[:,motif_mask,...] = 0            # NOTE: motif mask is all zeros if not freezing the motif 

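        delTi = offset[:,:,0,:] / 10.0 # translation update (scaled down by 10)
        R = offset[:,:,1,:] / 100.0 # rotation update as a 3-vector (scaled down by 100)
        
        # normalize (1, R) into a unit quaternion (qA, qB, qC, qD), then convert it to a rotation matrix
        Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
        qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm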

        delRi = torch.zeros((B,L,3,3), device=xyz.device)
        delRi[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
        delRi[:,:,0,1] = 2*qB*qC - 2*qA*qD
        delRi[:,:,0,2] = 2*qB*qD + 2*qA*qC
        delRi[:,:,1,0] = 2*qB*qC + 2*qA*qD
        delRi[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
        delRi[:,:,1,2] = 2*qC*qD - 2*qA*qB
        delRi[:,:,2,0] = 2*qB*qD - 2*qA*qC
        delRi[:,:,2,1] = 2*qC*qD + 2*qA*qB
        delRi[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD

        # compose the predicted update with the input frames: rotate R_in by delRi, shift T_in by delTi
        Ri = einsum('bnij,bnjk->bnik', delRi, R_in)
        Ti = delTi + T_in #einsum('bnij,bnj->bni', delRi, T_in) + delTi
            
        alpha = self.sc_predictor(msa[:,0], state)
        return Ri, Ti, state, alpha
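
The rotation update at the end of `forward` can be checked in isolation. The NumPy sketch below mirrors the same arithmetic (the function name is hypothetical): the 3-vector `R` is extended to the quaternion (1, R), normalized, and converted to a rotation matrix.

```python
import numpy as np

def offset_to_rotation(R):
    """R: (..., 3) raw rotation offsets -> (..., 3, 3) rotation matrices."""
    qnorm = np.sqrt(1.0 + np.sum(R * R, axis=-1))
    qA = 1.0 / qnorm                                   # scalar part of the unit quaternion
    qB, qC, qD = R[..., 0] / qnorm, R[..., 1] / qnorm, R[..., 2] / qnorm
    rows = [
        [qA*qA + qB*qB - qC*qC - qD*qD, 2*qB*qC - 2*qA*qD, 2*qB*qD + 2*qA*qC],
        [2*qB*qC + 2*qA*qD, qA*qA - qB*qB + qC*qC - qD*qD, 2*qC*qD - 2*qA*qB],
        [2*qB*qD - 2*qA*qC, 2*qC*qD + 2*qA*qB, qA*qA - qB*qB - qC*qC + qD*qD],
    ]
    return np.stack([np.stack(r, axis=-1) for r in rows], axis=-2)

rng = np.random.default_rng(0)
delR = offset_to_rotation(rng.standard_normal((2, 5, 3)))   # (B, L, 3, 3)
identity = offset_to_rotation(np.zeros((1, 1, 3)))           # zero offset -> no rotation
```

Since the quaternion always has unit length, any network output maps to a valid rotation, and a zero offset gives the identity. That is why zeroing `offset` for masked motif residues leaves their frames unchanged (`Ri = R_in`, `Ti = T_in`).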