# Building GNN_1
## feature representation 
Inputs: X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all

node: CA atoms <br>
edge E_idx: for each residue, the indices of its top-k nearest residues by CA-CA distance <br>
node_feature: h_V <br>
edge_feature: h_E, containing physical distances (num_rbf RBF values for each of the 25 ordered atom pairs drawn from [N, CA, C, O, CB]) and the relative residue offset <br>
masking: two types of masking: padding (based on mask) and intra- vs. inter-chain pairs (based on chain_encoding_all)<br>

outputs: E_idx, h_E<br>
Next step:
h_V (initialized with zeros) and h_E are then passed to the GNN, where message passing updates both h_V and h_E<br>


In [17]:
# we used output from featurize function
#example data
import numpy as np
from MPNN_featurize import featurize

# torch is only needed here to pick a device that actually exists.
import torch

# Two toy complexes with random backbone coordinates. Each entry holds, per
# chain, a sequence string and N/CA/C/O coordinate arrays of shape
# (chain_length, 3), plus the 'masked_list' / 'visible_list' chain IDs and the
# concatenated sequence under 'seq'.
batch = [
    {
        'seq_chain_A': 'MKLVFLVLLVFVQGF',
        'coords_chain_A': {'N_chain_A': np.random.rand(15, 3), 'CA_chain_A': np.random.rand(15, 3), 'C_chain_A': np.random.rand(15, 3), 'O_chain_A': np.random.rand(15, 3)},
        'seq_chain_B': 'MSVKVEEVG',
        'coords_chain_B': {'N_chain_B': np.random.rand(9, 3), 'CA_chain_B': np.random.rand(9, 3), 'C_chain_B': np.random.rand(9, 3), 'O_chain_B': np.random.rand(9, 3)},
        'seq_chain_C': 'ATCGATCGATCGATCG',
        'coords_chain_C': {'N_chain_C': np.random.rand(16, 3), 'CA_chain_C': np.random.rand(16, 3), 'C_chain_C': np.random.rand(16, 3), 'O_chain_C': np.random.rand(16, 3)},
        'masked_list': ['A', 'B'],
        'visible_list': ['C'],
        'num_of_chains': 3,
        'seq': 'MKLVFLVLLVFVQGF' + 'MSVKVEEVG' + 'ATCGATCGATCGATCG',
    },
    {
        'seq_chain_X': 'ACDEFGHIKLMNPQRSTVWY',
        'coords_chain_X': {'N_chain_X': np.random.rand(20, 3), 'CA_chain_X': np.random.rand(20, 3), 'C_chain_X': np.random.rand(20, 3), 'O_chain_X': np.random.rand(20, 3)},
        'seq_chain_Y': 'ACCDEFGHILKLM',
        'coords_chain_Y': {'N_chain_Y': np.random.rand(13, 3), 'CA_chain_Y': np.random.rand(13, 3), 'C_chain_Y': np.random.rand(13, 3), 'O_chain_Y': np.random.rand(13, 3)},
        'seq_chain_Z': 'LKLMNRPQRST',
        'coords_chain_Z': {'N_chain_Z': np.random.rand(11, 3), 'CA_chain_Z': np.random.rand(11, 3), 'C_chain_Z': np.random.rand(11, 3), 'O_chain_Z': np.random.rand(11, 3)},
        'masked_list': ['X', 'Y'],
        'visible_list': ['Z'],
        'num_of_chains': 3,
        'seq': 'ACDEFGHIKLMNPQRSTVWY' + 'ACCDEFGHILKLM' + 'LKLMNRPQRST',
    },
]

# Prefer the GPU, but fall back to the CPU so the example also runs on
# machines without CUDA instead of failing inside featurize.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)

In [24]:
%%writefile build_GNN_1.py
# original code
import torch
import torch.nn as nn
import numpy as np

class PositionalEncodings(nn.Module):
    """Embed relative residue offsets through a one-hot lookup and a linear layer.

    Offsets are shifted by ``max_relative_feature`` and clipped to
    ``[0, 2 * max_relative_feature]``. Entries whose ``mask`` is 0 are sent to
    one extra bin, ``2 * max_relative_feature + 1``. ``ProteinFeatures`` passes
    a same-chain indicator as ``mask``, so that bin marks cross-chain pairs.
    """

    def __init__(self, num_embeddings, max_relative_feature=32):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.max_relative_feature = max_relative_feature
        # 2*max+1 clipped-offset bins plus one bin for masked-out entries.
        self.linear = nn.Linear(2 * max_relative_feature + 2, num_embeddings)

    def forward(self, offset, mask):
        """Return embeddings with a trailing axis of size ``num_embeddings``.

        Args:
            offset: integer tensor of relative offsets.
            mask: integer tensor broadcastable to ``offset``; 1 keeps the
                clipped offset bin, 0 selects the extra bin.
        """
        max_rel = self.max_relative_feature
        extra_bin = 2 * max_rel + 1
        clipped = torch.clip(offset + max_rel, 0, 2 * max_rel)
        bins = clipped * mask + (1 - mask) * extra_bin
        onehot = torch.nn.functional.one_hot(bins, extra_bin + 1)
        return self.linear(onehot.float())
    
def gather_edges(edges, neighbor_idx):
    """Select per-neighbor edge features.

    Takes features ``edges`` of shape [B, N, N, C] and neighbor indices
    ``neighbor_idx`` of shape [B, N, K], and returns the features at those
    neighbors with shape [B, N, K, C].
    """
    channels = edges.size(-1)
    # Repeat each neighbor index across the channel axis so gather picks whole feature vectors.
    index = neighbor_idx[..., None].expand(-1, -1, -1, channels)
    return edges.gather(2, index)

class ProteinFeatures(nn.Module):
    """Build k-nearest-neighbor edge features from backbone coordinates.

    Each residue is a node located at its Ca atom. For every node the
    ``top_k`` closest nodes are selected, and each selected edge is described
    by RBF-encoded distances between backbone atoms (N, Ca, C, O and a Cb
    computed from them) plus an embedding of the relative residue offset.
    The concatenated features go through a bias-free linear layer and a
    LayerNorm.

    ``node_features`` is stored, and ``num_chain_embeddings`` is accepted, but
    neither influences the computation in this class.
    """

    # Ordered (atom_A, atom_B) pairs whose neighbor distances are RBF-encoded
    # after the leading Ca-Ca block. The order fixes the channel layout fed to
    # ``edge_embedding``, so it must stay as is for weights trained with it.
    _NEIGHBOR_ATOM_PAIRS = (
        ("N", "N"), ("C", "C"), ("O", "O"), ("Cb", "Cb"),
        ("Ca", "N"), ("Ca", "C"), ("Ca", "O"), ("Ca", "Cb"),
        ("N", "C"), ("N", "O"), ("N", "Cb"),
        ("Cb", "C"), ("Cb", "O"), ("O", "C"),
        ("N", "Ca"), ("C", "Ca"), ("O", "Ca"), ("Cb", "Ca"),
        ("C", "N"), ("O", "N"), ("Cb", "N"),
        ("C", "Cb"), ("O", "Cb"), ("C", "O"),
    )

    def __init__(self, edge_features, node_features, num_positional_embeddings=16,
        num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
        """ Extract protein features """
        super().__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # One RBF block for Ca-Ca plus one per listed pair (25 blocks in total).
        num_rbf_blocks = 1 + len(self._NEIGHBOR_ATOM_PAIRS)
        edge_in = num_positional_embeddings + num_rbf * num_rbf_blocks
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = nn.LayerNorm(edge_features)

    def _dist(self, X, mask, eps=1E-6):
        """Return distances and indices of the nearest neighbors of each point.

        Args:
            X: point coordinates, [B, L, 3].
            mask: [B, L]; pairs where either entry is 0 are treated as padding.
            eps: added under the square root to keep it away from zero.

        Returns:
            ``(D_neighbors, E_idx)``, both [B, L, K] with
            ``K = min(top_k, L)``, sorted by increasing distance.
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        # Masked pairs get the row's largest distance so they are picked last.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * D_max
        k = min(self.top_k, X.shape[1])
        D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances [B, L, K] into ``num_rbf`` Gaussian features [B, L, K, num_rbf].

        Centers are evenly spaced from 2 to 22 and the width is
        ``(22 - 2) / num_rbf``.
        """
        device = D.device
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    def _get_rbf(self, A, B, E_idx):
        """RBF-encode the distances from atoms ``A`` to atoms ``B`` along the edges ``E_idx``.

        The full [B, L, L] distance matrix is computed first and the neighbor
        entries are gathered from it afterwards.
        """
        D_A_B = torch.sqrt(torch.sum((A[:, :, None, :] - B[:, None, :, :])**2, -1) + 1e-6)  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B

    def forward(self, X, mask, residue_idx, chain_labels):
        """Compute edge features and neighbor indices.

        Args:
            X: backbone coordinates indexed as ``X[:, :, atom, :]`` with
                atoms 0..3 read as N, Ca, C, O.
            mask: [B, L] padding mask passed to ``_dist``.
            residue_idx: [B, L] residue numbers used for relative offsets.
            chain_labels: [B, L] chain identifiers; equal labels mean the
                same chain.

        Returns:
            ``(E, E_idx)``: edge features [B, L, K, edge_features] and
            neighbor indices [B, L, K].
        """
        if self.training and self.augment_eps > 0:
            # Coordinate noise augmentation, applied in training mode only.
            X = X + self.augment_eps * torch.randn_like(X)

        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]
        # Cb is not read from X; it is built from the N, Ca and C positions.
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
        atoms = {"N": N, "Ca": Ca, "C": C, "O": O, "Cb": Cb}

        D_neighbors, E_idx = self._dist(Ca, mask)

        # Ca-Ca reuses the (mask-adjusted) distances from the neighbor search;
        # every other pair is gathered along the same edges, in table order.
        rbf_blocks = [self._rbf(D_neighbors)]
        rbf_blocks.extend(
            self._get_rbf(atoms[first], atoms[second], E_idx)
            for first, second in self._NEIGHBOR_ATOM_PAIRS
        )
        RBF_all = torch.cat(rbf_blocks, dim=-1)

        offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        # 1 where both residues carry the same chain label, 0 otherwise.
        d_chains = ((chain_labels[:, :, None] - chain_labels[:, None, :]) == 0).long()
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all), -1)
        E = self.edge_embedding(E)
        E = self.norm_edges(E)
        return E, E_idx

Overwriting build_GNN_1.py


# Breakdown of functions
Class ProteinFeatures(nn.Module):
    - Initialize with parameters:
        edge_features, node_features
        num_positional_embeddings (default: 16)
        num_rbf (default: 16)
        top_k (default: 30)
        augment_eps (default: 0.0)
        num_chain_embeddings (default: 16; accepted but not used by the class)

    - Call the parent class initializer (super).
    - Define instance variables for input parameters.
    - Initialize:
        - PositionalEncodings for positional embeddings.
        - Linear transformation for edge features.
        - Layer normalization for edges.

Methods:

1. `_dist(X, mask, eps=1e-6)`:
    - Create a 2D mask for pairwise distance computation.
    - Compute pairwise distances (D) between coordinates in X using the mask.
    - Adjust distances for masked regions.
    - Identify the top-k nearest neighbors for each point.
    - Return the distances and indices of the top-k neighbors.

2. `_rbf(D)`:
    - Define radial basis function (RBF) parameters (min, max, count).
    - Compute RBF expansion of distances D.
    - Return the RBF representation.

3. `_get_rbf(A, B, E_idx)`:
    - Compute pairwise distances between points in A and B.
    - Gather distances for top-k neighbors using E_idx.
    - Convert distances to RBF representation.
    - Return the RBF representation.

4. `forward(X, mask, residue_idx, chain_labels)`:
    - If in training mode and `augment_eps` > 0:
        - Add random noise to the input coordinates (X).
    - Extract atom positions (N, Ca, C, O, Cb) for the protein backbone.
    - Compute pairwise distances and top-k indices using `_dist`.
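    - Generate RBF features along the neighbor edges selected by `_dist` for:
        - Same-atom-type pairs (Ca-Ca, N-N, C-C, O-O, Cb-Cb).
        - Mixed-atom-type pairs, in both orders (Ca-N and N-Ca, Ca-C and C-Ca, etc.).
    - Concatenate all 25 RBF feature blocks.
    - Compute relative residue offsets from the residue indices and gather them at the neighbor indices.
    - Identify same-chain vs. different-chain pairs from the chain labels; this is used as the mask for the positional encoding.
    - Create edge features by concatenating the positional embeddings and the RBF features.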
    - Apply edge embedding and normalization.
    - Return edge features and indices.

In [None]:
# Breakdown of ProteinFeatures.forward, step by step.
# torch and the definitions written to build_GNN_1.py are not imported by the
# cells shown above, so bring them into the notebook namespace here.
import torch
from build_GNN_1 import ProteinFeatures, gather_edges

# How X is used: axis 2 of X selects the backbone atom (0=N, 1=Ca, 2=C, 3=O).
b = X[:,:,1,:] - X[:,:,0,:]  # N -> Ca vector, [B, L, 3]
c = X[:,:,2,:] - X[:,:,1,:]  # Ca -> C vector, [B, L, 3]
a = torch.cross(b, c, dim=-1)  # orthogonal to the N-Ca-C plane
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + X[:,:,1,:]  # virtual ("ghost") Cb position
Ca = X[:,:,1,:]  # backbone atom coordinates
N = X[:,:,0,:]
C = X[:,:,2,:]
O = X[:,:,3,:]

demo = ProteinFeatures(edge_features=16, node_features=16, num_positional_embeddings=16,
                       num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16)

# Use Ca as the node position to find each residue's neighbors.
# `mask` tells _dist which positions are padding.
D_neighbors, E_idx = demo._dist(Ca, mask)
# What _dist does:
#   1. mask_2D = mask[:, None, :] * mask[:, :, None]  -> pairwise padding mask
#   2. D = mask_2D * sqrt(|X_i - X_j|^2 + eps)        -> inter-residue distances, zeroed for padded pairs
#   3. D_adjust = D + (1 - mask_2D) * max_j(D)        -> padded pairs get the row's largest distance
#   4. topk(D_adjust, min(top_k, L), largest=False)   -> the k shortest distances and their indices
print(D_neighbors[0,0,:])
print(E_idx[0,0,:])  # neighbors of residue 0 in the first example

tensor([0.0010, 0.1319, 0.1947, 0.2920, 0.3160, 0.3528, 0.3663, 0.3675, 0.3915,
        0.4363, 0.4774, 0.4821, 0.5453, 0.5860, 0.5929, 0.6301, 0.6304, 0.6589,
        0.6782, 0.7061, 0.7180, 0.7350, 0.7607, 0.7616, 0.7640, 0.7864, 0.7878,
        0.7886, 0.7932, 0.7976], device='cuda:0')
tensor([ 0, 17,  1, 37, 23,  5, 28, 13,  7, 11, 36, 12, 14, 29, 18,  2, 15, 34,
        24, 35, 19, 22, 20, 33, 26, 27, 21,  9, 31,  3], device='cuda:0')


In [20]:
# Every distance is encoded with RBF kernels.
# Ca-Ca is immediate, because D_neighbors already holds those values.
# For any other atom pair (N-N is the example here) there are two options:
# compute all L x L distances and then pick the neighbors with E_idx, or pick
# the neighbors first and compute only their distances.
# _get_rbf takes the first option, which is reproduced below.
A = N  # [B, L, 3]
B = N  # [B, L, 3]
pairwise_diff = A.unsqueeze(2) - B.unsqueeze(1)  # [B, L, L, 3]
D_A_B = torch.sqrt(pairwise_diff.pow(2).sum(-1) + 1e-6).unsqueeze(-1)  # [B, L, L, 1]
print(D_A_B.shape)
neighbors = E_idx[..., None].expand(-1, -1, -1, D_A_B.size(-1))  # [B, L, K, 1]
print(neighbors.shape)
edge_features = D_A_B.gather(2, neighbors).squeeze(-1)  # [B, L, K]
print(edge_features.shape)  # N-N distances along the neighbor edges

torch.Size([2, 44, 44, 1])
torch.Size([2, 44, 30, 1])
torch.Size([2, 44, 30])


In [21]:
# RBF kernel: turns each distance [B, L, K] into num_rbf features [B, L, K, num_rbf].
NN_features = demo._rbf(edge_features)
print(NN_features.shape)

# forward() computes these RBF features for every backbone atom pair (virtual
# Cb included) and concatenates them along the last axis, in a fixed order:
# RBF_all = torch.cat(tuple(RBF_all), dim=-1)

torch.Size([2, 44, 30, 16])


In [22]:
# Residue-index part of the edge features.
offset = residue_idx.unsqueeze(2) - residue_idx.unsqueeze(1)  # pairwise index difference, [B, L, L]
offset = gather_edges(offset.unsqueeze(-1), E_idx).squeeze(-1)  # keep only neighbor pairs, [B, L, K]
print(offset.shape)

# 1 for pairs in the same chain, 0 for pairs in different chains; used as a mask below.
same_chain = (chain_encoding_all.unsqueeze(2) - chain_encoding_all.unsqueeze(1)) == 0
d_chains = same_chain.long()  # [B, L, L]
E_chains = gather_edges(d_chains.unsqueeze(-1), E_idx).squeeze(-1)  # [B, L, K]
print(E_chains.shape)

# The index difference is one-hot encoded (an idea the original notes attribute
# to AlphaFold): differences beyond a maximum are clipped to that maximum, and
# the one-hot vector is then fed to a linear layer to give the residue-index
# feature. This is what PositionalEncodings implements.
max_relative_feature = 20  # look at most 20 residues left and right (the class default is 32)
# Different-chain pairs all go to the extra bin 2*max_relative_feature + 1.
# Note: d == max_relative_feature means "same residue"; both 0 and
# 2*max_relative_feature mean "far apart" within a chain.
cross_chain_bin = 2 * max_relative_feature + 1
in_chain_bin = torch.clip(offset + max_relative_feature, 0, 2 * max_relative_feature)
d = in_chain_bin * E_chains + (1 - E_chains) * cross_chain_bin  # [B, L, K]
d_onehot = torch.nn.functional.one_hot(d, cross_chain_bin + 1)  # [B, L, K, 2*max_relative_feature + 2]
print(d_onehot.shape)
# d_onehot is then the input of the linear layer.


torch.Size([2, 44, 30])
torch.Size([2, 44, 30])
torch.Size([2, 44, 30, 42])


## Finally, concatenate the physical-distance (RBF) features and the residue-index (positional) features
The concatenated tensor is passed through a linear layer and then layer-normalized:
E = torch.cat((E_positional, RBF_all), -1)
E = self.edge_embedding(E)
E = self.norm_edges(E)

## Node feature

### E is the edge feature tensor h_E, with shape [B, L, K, edge_features]
### Initialize h_V with zeros, shape [B, L, edge_features], so its feature width matches h_E
h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device) 
### h_V and h_E are then used in the GNN, where message passing updates both h_V and h_E