# Learning protein pocket shapes

Under development. 

To allow the notebook to read the cavity data, the project's file structure needs to look like this:

```
Project
├───scPDB
│   ├───1a2b_1
│   │    └─ cavity6.mol2
│   │    │  ...   
│   └───1a2f_1
│   │   ...
└───ThisNotebook.ipynb
```

In [2]:
import os
import time
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

#torch_geometric 
from torch_cluster import radius_graph
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_scatter import scatter

#e3nn
import e3nn
from e3nn import o3
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.math import soft_one_hot_linspace
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork

In [3]:
# os.walk yields the scPDB directory itself first, so that entry is skipped with [1:]
paths = [x[0] for x in tqdm(os.walk(os.getcwd()+'/scPDB'))][1:]

17595it [00:13, 1321.62it/s]
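`os.walk` yields the root directory first (hence the `[1:]`), descends into every nested folder, and returns entries in an arbitrary order determined by the filesystem. A minimal alternative sketch, assuming the layout shown at the top (one folder per protein-cavity pair, each holding a `cavity6.mol2`), lists only the folders that actually contain the file, in sorted order. `find_cavity_dirs` is a hypothetical helper, not part of the notebook.

```python
from pathlib import Path


def find_cavity_dirs(root="scPDB", filename="cavity6.mol2"):
    """Sorted list of the direct sub-folders of `root` that contain `filename`."""
    # "*/cavity6.mol2" matches the file one level below root only
    return sorted(path.parent for path in Path(root).glob(f"*/{filename}"))


# Usage (assuming the layout above):
# paths = [str(p) for p in find_cavity_dirs(Path.cwd() / "scPDB")]
```

Sorting makes the protein order reproducible across machines, which matters whenever `paths[:NUM_PROTEINS]` is used to pick proteins.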


In [25]:
# Access every cavity6.mol2 file contained in each folder. Folder == Unique Protein-Cavity

# Function retrieves all the data inside cavity6.mol2 file and returns it as a list. 
def getCavity(path):
    prot_path = path.split('/')
    # Read <folder>/cavity6.mol2; the context manager closes the file afterwards.
    with open(os.path.join(path, "cavity6.mol2"), "r") as file:
        data = file.readlines()
    # Get the PDB ID from the folder name, e.g. '1a2b_1' -> '1a2b'.
    pdb_id = prot_path[-1].split('_')[0]
    print('Protein:', pdb_id)
    return data


# Get the (x, y, z) coordinates of a cavity point from a mol2 ATOM line
def getCoordinates(line):
    fields = line.split()
    x, y, z = (float(v) for v in fields[2:5])

    return (x, y, z)

# Retrieve pharmacophoric features from the binding site 
def getFeatures(line):
    atom_type = line.split()[1]
    
    return atom_type
    
# Extract cavity geometrical data from file
def getCavityInfo(data):
    # Atom records lie between the '@<TRIPOS>ATOM' and '@<TRIPOS>BOND' section headers
    atom_lines = data[data.index('@<TRIPOS>ATOM\n') + 1 : data.index('@<TRIPOS>BOND\n')]
    coords = [getCoordinates(atom) for atom in atom_lines]
    features = [getFeatures(atom) for atom in atom_lines]

    return torch.tensor(coords, dtype = torch.get_default_dtype()), features


# Create labels for the cavities. (For now, every scPDB cavity is labeled druggable.)
def getLabels(cavities) -> torch.Tensor:
    return torch.ones(len(cavities), 1)
#     return torch.zeros((len(cavities), len(cavities))).fill_diagonal_(1)
    

# Returns the centroid of a point cloud. unsqueeze adds a leading dimension, since torch.cdist expects at least 2-D inputs.
def getCentroid(point_cloud) -> torch.Tensor:
    return torch.mean(point_cloud, dim = 0).unsqueeze(0)


# Compute the Euclidean distance between the centroids of two point clouds
def calculateDistance(t1, t2):
    return torch.cdist(getCentroid(t1), getCentroid(t2), p=2).item()


# Receives a list of lists. Returns a dict where each key corresponds to an atom type and value is a unique integer.
def featuresToDict(feat_list):
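    # Flatten the list of lists
    flat_features = [features for sublist in feat_list for features in sublist]

    # Reduce to unique features. Sorting makes the mapping reproducible: the iteration order
    # of a set of strings changes between Python sessions because of hash randomization.
    unique_features = sorted(set(flat_features))

    # Generate dictionary for later mapping/encoding of features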
    dict_features = {element:idx for idx, element in enumerate(unique_features)}
    
    return dict_features 
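`getCavityInfo` slices between the ATOM and BOND headers, so it raises a `ValueError` (from `list.index`) on any file that lacks a `@<TRIPOS>BOND` section. A minimal sketch of a slightly more tolerant parser, assuming the standard Tripos mol2 ATOM record layout (atom_id, atom_name, x, y, z, atom_type, ...): it stops at whichever `@<TRIPOS>` header comes next, or at the end of the file. `parse_mol2_atoms` is a hypothetical helper, and the sample below is synthetic, not real scPDB data.

```python
def parse_mol2_atoms(lines):
    """Return (coords, names) from the @<TRIPOS>ATOM section of a mol2 file.

    Stops at the next '@<TRIPOS>' header (BOND or any other) or at the end of the file.
    """
    start = lines.index("@<TRIPOS>ATOM\n") + 1
    end = next(
        (i for i in range(start, len(lines)) if lines[i].startswith("@<TRIPOS>")),
        len(lines),
    )
    coords, names = [], []
    for line in lines[start:end]:
        fields = line.split()
        if not fields:  # tolerate blank lines inside the section
            continue
        # ATOM record: atom_id atom_name x y z atom_type ...
        names.append(fields[1])
        coords.append(tuple(float(v) for v in fields[2:5]))
    return coords, names


# Tiny synthetic file, for illustration only
sample = [
    "@<TRIPOS>MOLECULE\n",
    "cavity\n",
    "@<TRIPOS>ATOM\n",
    "      1 C          1.0000    2.0000    3.0000 C.3 1 CAV 0.0000\n",
    "      2 N         -1.5000    0.5000    2.2500 N.3 1 CAV 0.0000\n",
    "@<TRIPOS>BOND\n",
]
coords, names = parse_mol2_atoms(sample)
print(names)   # ['C', 'N']
print(coords)  # [(1.0, 2.0, 3.0), (-1.5, 0.5, 2.25)]
```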

In [26]:
# The number of cavities to extract is set manually (for now).
NUM_PROTEINS = 1

cavities, cavities_features = [], []
for protein in paths[:NUM_PROTEINS]:
    coords, feats = getCavityInfo(getCavity(protein))
    cavities.append(coords)
    cavities_features.append(feats)

Protein: 4j22


### Fpocket toy example: 4j22

We will run a toy example with the first protein obtained from scPDB, 4j22. The goal is to show that our neural network can distinguish between druggable and non-druggable cavities of the protein. Since scPDB provides a single cavity, we use **fpocket** to find more. All pocket files produced by fpocket are stored in the 'pockets' folder, sorted by their druggability score.

In [46]:
pockets_data = [pocket for pocket in [i[2] for i in os.walk(os.getcwd()+'/pockets')][0] if '.pdb' in pocket]

# Reads one fpocket pocket file from the 'pockets' folder and returns its coordinates and atom names.
def getFpocketCoords(pocket_id):
    coords, features = [], []
    # Keep only the coordinate lines: skip the 20-line header and the last 2 lines of the file
    with open(os.path.join(os.getcwd(), 'pockets', pocket_id)) as file:
        data = file.readlines()[20:-2]
    # For each of the lines, extract the coordinates and atom name.
    for atom in data:
        atom_line_split = atom.split()
        # In some lines the chain ID and residue number are separated by whitespace (e.g. 'B 9XX'),
        # in others they are fused into one token (e.g. 'B1XXX'), which shifts the coordinates one token left.
        if atom_line_split[5].isdigit():
            coords.append(atom_line_split[6:9])
        else:
            coords.append(atom_line_split[5:8])
        features.append(atom_line_split[2])

    return (torch.Tensor(np.array(coords, dtype='float32')), features)
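Whitespace splitting breaks whenever adjacent PDB fields touch, which is what the residue-number and atom-name workarounds compensate for. PDB ATOM/HETATM records are fixed-width, so a more robust sketch reads each field by column. This assumes fpocket writes standard-width records (worth checking against the actual files); `parse_pdb_atom_line` is a hypothetical helper.

```python
def parse_pdb_atom_line(line):
    """Parse one ATOM/HETATM record using the fixed PDB column layout (1-based columns in comments)."""
    name = line[12:16].strip()      # columns 13-16: atom name
    res_name = line[17:20].strip()  # columns 18-20: residue name
    chain = line[21]                # column 22: chain identifier
    res_seq = int(line[22:26])      # columns 23-26: residue number
    x = float(line[30:38])          # columns 31-38
    y = float(line[38:46])          # columns 39-46
    z = float(line[46:54])          # columns 47-54
    return name, res_name, chain, res_seq, (x, y, z)


# A four-digit residue number touches the chain ID ('B1001'), which shifts whitespace-split tokens
line = "ATOM  " + " 1001" + " " + " CA " + " " + "CYS" + " " + "B" + "1001" + " " + "   " \
       + "  12.345" + "  -6.789" + "  10.111"
print(line.split()[4])             # 'B1001'
print(parse_pdb_atom_line(line))   # ('CA', 'CYS', 'B', 1001, (12.345, -6.789, 10.111))
```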

In [47]:
pockets, feats = [], []
for pid in pockets_data:
    p, f = getFpocketCoords(pid)
    # Some atom names seem to be fused with the residue name in the files, e.g. CAACYS (CA + CYS).
    # Hard-coded workaround, just to make the example work: names longer than 5 characters keep their first 2.
    for idx, atom_type in enumerate(f):
        if len(atom_type) > 5:
            f[idx] = atom_type[:2]
    pockets.append(p)
    feats.append(f)

# Compute the Euclidean distance between the centroid of each fpocket pocket and the centroid of the scPDB cavity
p_distances = [calculateDistance(cavities[0], p) for p in pockets]

# All labels are set to 0 except for the one that has the closest distance to the one obtained from scPDB.
p_labels = torch.zeros(len(pockets), 1)
p_labels[p_distances.index(min(p_distances))] = 1

# Generate dictionary where each key is a different atom type and value is a unique integer. 
dict_features = featuresToDict(feats)

print(p_distances)

[12.776483535766602, 27.861412048339844, 32.577369689941406, 33.63336944580078, 33.588951110839844, 50.967750549316406, 49.08852767944336, 22.45990753173828, 43.91595458984375, 18.20598793029785, 28.26254653930664, 13.513494491577148, 37.019020080566406, 21.109922409057617, 18.126548767089844, 39.27044677734375, 20.355783462524414, 25.32646942138672, 16.542572021484375, 36.99082946777344, 28.47964096069336, 18.026079177856445, 29.520065307617188, 7.529668807983398, 2.0994231700897217, 13.186159133911133]


As can be seen, the smallest distance is about 2.1 Å, far below all the others (the next smallest is about 7.5 Å). However, fpocket's drug score for that cavity, which probably corresponds to the scPDB cavity and should therefore be druggable, is almost 0, so fpocket does not consider it a potential druggable pocket (this is still an open question). Nevertheless, to complete the example, we will treat that cavity as druggable.

In [48]:
print(p_labels)

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.]])


All labels are 0 (non-druggable) except for the one matching the scPDB cavity (druggable).
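The labeling rule, restated in plain Python so each step is explicit: reduce every point cloud to its centroid, measure the Euclidean distance from each fpocket centroid to the scPDB centroid (the quantity `calculateDistance` obtains with `torch.cdist`), and mark only the nearest pocket as druggable. `nearest_pocket_labels` is a hypothetical helper written for illustration, and the coordinates are toy values.

```python
import math


def centroid(points):
    """Mean of a list of (x, y, z) tuples."""
    n = len(points)
    return tuple(sum(p[k] for p in points) / n for k in range(3))


def nearest_pocket_labels(reference, pockets):
    """Label 1 for the pocket whose centroid is closest to the reference centroid, 0 otherwise."""
    ref = centroid(reference)
    distances = [math.dist(ref, centroid(p)) for p in pockets]
    best = distances.index(min(distances))
    return [1 if i == best else 0 for i in range(len(pockets))], distances


scpdb_cavity = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]           # centroid (1, 0, 0)
fpocket_pockets = [[(10.0, 0.0, 0.0)],                        # centroid 9 away
                   [(1.0, 1.0, 0.0), (1.0, -1.0, 0.0)],       # same centroid
                   [(0.0, 5.0, 0.0)]]                         # about 5.1 away
labels, distances = nearest_pocket_labels(scpdb_cavity, fpocket_pockets)
print(labels)  # [0, 1, 0]
```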

### Dataset and batch generation

In [1]:
# To rotate the cavities, we sample a random 3x3 rotation matrix and multiply every coordinate tensor by it
def applyRotation(cavities) -> list:
    # Sample one random rotation matrix, shape (1, 3, 3); the same rotation is applied to every cavity
    rand_matrix = o3.rand_matrix(1)
    print('Random matrix:', rand_matrix)
    # Multiply cavity coordinates by the random matrix ((N, 3) @ (1, 3, 3) -> (1, N, 3), squeezed back to (N, 3))
    rotated_set = [torch.matmul(cav, rand_matrix).squeeze(0) for cav in cavities]
    
    return rotated_set


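# x: Node feature matrix. [num_nodes, num_node_features] = [num_points_in_cavity, 1] (integer-encoded atom name, stored as float)
# pos: Node position matrix. [num_nodes, num_dimensions] = [num_points_in_cavity, 3]
# Note: relies on the global `feats` and `dict_features`; Network.forward below only reads pos and batch, so x is currently unused.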
def buildDataset(cavities):
    dataset = [Data(
                    pos = cav, 
                    x = torch.tensor([dict_features[k] for k in feats[idx]],
                                     dtype=torch.get_default_dtype()).unsqueeze(1)) 
               for idx,cav in enumerate(cavities)]
    return dataset


# Create a batch, optionally applying a random rotation.
def makeBatch(cavities, rotation = True):
    # If rotation is False, the cavities are used as they are.
    if not rotation:
        dataset = buildDataset(cavities)
    else:
        dataset = buildDataset(applyRotation(cavities))
    
    # Return a single batch holding every cavity. shuffle defaults to False, so the order matches the labels.
    return next(iter(DataLoader(dataset, batch_size = len(dataset))))
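With `batch_size = len(dataset)`, the DataLoader merges every cavity into one large disconnected graph: positions are concatenated, `batch` records which cavity each point belongs to, and `ptr` holds the offset where each cavity starts. That is why the printed batch for the 26 fpocket pockets shows `pos=[629, 3]`, `batch=[629]` and `ptr=[27]`. The network later uses `batch` so that `radius_graph` never links points of different cavities and so that the final `scatter` sums one value per cavity. The sketch below mimics only that bookkeeping in plain Python; `collate_point_clouds` is a hypothetical name, and PyG's real collation does more (edges, features, custom attributes).

```python
def collate_point_clouds(clouds):
    """Concatenate point clouds and build PyG-style `batch` and `ptr` vectors (simplified)."""
    pos, batch, ptr = [], [], [0]
    for graph_idx, cloud in enumerate(clouds):
        pos.extend(cloud)                        # stack all points into one list
        batch.extend([graph_idx] * len(cloud))   # which cloud each point came from
        ptr.append(ptr[-1] + len(cloud))         # start offset of the next cloud
    return pos, batch, ptr


clouds = [[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
          [(5.0, 5.0, 5.0), (6.0, 5.0, 5.0), (5.0, 6.0, 5.0)]]
pos, batch, ptr = collate_point_clouds(clouds)
print(batch)  # [0, 0, 1, 1, 1]
print(ptr)    # [0, 2, 5]
```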

### Building the network: Equivariant convolution

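The core operation of the convolution is a fully connected tensor product (`FullyConnectedTensorProduct`) between the node features and the spherical harmonics of the edge vectors. Its weights are not shared: an MLP computes a separate set of weights for each edge from the embedded edge length.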
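Written out, the forward pass of `Convolution` computes the following for every node $i$ (a transcription of the code): $\partial(i)$ are the nodes $j$ with an edge $j \to i$ (`edge_src` $= j$, `edge_dst` $= i$), $\vec r_{ij} = \vec x_j - \vec x_i$ is `edge_vec`, $Y$ are the spherical harmonics (`edge_attr`), $\mathrm{emb}$ is the scaled soft one-hot embedding of the edge length (`edge_scalars`), $h$ is the MLP `self.fc`, $\otimes_{w}$ is the tensor product with weights $w$, and $z$ is `num_neighbors`.

$$
f'_i = \frac{1}{\sqrt{z}} \sum_{j \in \partial(i)} f_j \otimes_{w_{ij}} Y\!\left(\frac{\vec r_{ij}}{\lVert \vec r_{ij} \rVert}\right),
\qquad w_{ij} = h\big(\mathrm{emb}(\lVert \vec r_{ij} \rVert)\big)
$$

Because $Y$ depends only on the edge direction and $w_{ij}$ only on the edge length, rotating all positions rotates every $f'_i$ in the way its irreps prescribe, which is what makes the layer equivariant.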

In [36]:
class Convolution(torch.nn.Module):
    def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
        super().__init__()
        
        self.num_neighbors = num_neighbors
        
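        # Built first to know how many weights (tp.weight_numel) the Multi-Layer Perceptron (MLP) must output.
        # internal_weights=False and shared_weights=False: the weights are passed in at call time, one set per edge.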
        tp = FullyConnectedTensorProduct(
            irreps_in1 = irreps_in,
            irreps_in2 = irreps_sh,
            irreps_out = irreps_out,
            internal_weights = False,
            shared_weights = False,
        )
        
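        # MLP: [input, hidden, output] sizes and activation function. The input size 3 must match
        # the number of basis functions used to embed the edge lengths (number=3 in soft_one_hot_linspace).
        self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)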
        # Tensor product
        self.tp = tp
        # Visualize TP
#         print('Fully Connected Tensor Product:', self.tp.visualize());
#         plt.show()
        
        self.irreps_out = self.tp.irreps_out

    def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        # To map the relative distances to the weights of the tensor product, the distances are embedded
        # with a basis function and this embedding (edge_scalars) is fed to the MLP.
        weight = self.fc(edge_scalars)
        # The tensor product is computed per edge, so the input features are "lifted" to the edges:
        # edge_src contains, for each edge, the index of its source node.
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        # Sum over the neighbors. dim_size keeps one row per node, even for nodes without incoming edges.
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=node_features.shape[0]).div(self.num_neighbors**0.5)
        
        return node_features
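The aggregation at the end of `forward` is a scatter-sum over `edge_dst` followed by division by $\sqrt{z}$. A scalar-valued, plain-Python sketch of that step; `scatter_sum` is a hypothetical helper mirroring `torch_scatter.scatter` with its default sum reduction. When `dim_size` is omitted, `torch_scatter.scatter` sizes its output from the largest index in `edge_dst`, so a trailing node with no incoming edge would silently lose its row; passing the number of nodes explicitly keeps it as zeros.

```python
def scatter_sum(values, index, dim_size):
    """Add values[e] into out[index[e]]; slots that receive nothing stay 0."""
    out = [0.0] * dim_size
    for value, target in zip(values, index):
        out[target] += value
    return out


# Toy graph: 4 nodes, 3 edges with scalar messages; nodes 1 and 3 have no incoming edge
edge_messages = [1.0, 2.0, 3.0]
edge_dst = [0, 0, 2]
num_neighbors = 3.8
node_out = [s / num_neighbors ** 0.5 for s in scatter_sum(edge_messages, edge_dst, dim_size=4)]
print(len(node_out))  # 4
```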

### Building the network

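Now that the convolution layer has been defined, we can build the full equivariant neural network for point clouds: a convolution, a gated nonlinearity, and a final convolution that outputs one scalar per point, summed over each cavity.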
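Every irrep of degree $l$ has $2l+1$ components, so the feature sizes flowing through the network can be checked by hand. A small sketch with hypothetical `irreps_dim` and `irreps_count` helpers (they parse only the `MULxLp` terms used in this notebook, not the full e3nn syntax): the spherical harmonics up to $l=3$ have 16 components, the gate takes 160 and returns 128 (the gate scalars are consumed), and the number of gated irreps must equal the number of gate scalars.

```python
import re


def irreps_dim(irreps):
    """Total dimension of an irreps string such as '16x1o + 8x0e': sum of mul * (2l + 1)."""
    total = 0
    for term in irreps.split("+"):
        mul, l, _parity = re.fullmatch(r"\s*(\d+)x(\d+)([eo])\s*", term).groups()
        total += int(mul) * (2 * int(l) + 1)
    return total


def irreps_count(irreps):
    """Number of irreps (sum of multiplicities), ignoring their dimensions."""
    return sum(int(term.strip().split("x")[0]) for term in irreps.split("+"))


sh = "1x0e + 1x1o + 1x2e + 1x3o"        # spherical harmonics up to l=3
scalars = "16x0e + 16x0o"
gates = "8x0e + 8x0o + 8x0e + 8x0o"
gated = "16x1o + 16x1e"

gate_in = irreps_dim(scalars) + irreps_dim(gates) + irreps_dim(gated)  # 32 + 32 + 96
gate_out = irreps_dim(scalars) + irreps_dim(gated)                     # 32 + 96
print(irreps_dim(sh), gate_in, gate_out)  # 16 160 128
```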

In [37]:
class Network(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        # Typical number of neighbors per node; used only to normalize the sums over neighbors
        self.num_neighbors = 3.8
        
        # Irreps of the spherical harmonics up to l=3: 1x0e + 1x1o + 1x2e + 1x3o
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)
        irreps = self.irreps_sh

        # First layer with gate
        gate = Gate(
            "16x0e + 16x0o", [torch.relu, torch.abs],  # scalar
            "8x0e + 8x0o + 8x0e + 8x0o", [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates (scalars)
            "16x1o + 16x1e"  # gated tensors, num_irreps has to match with gates
        )
        # Convolutional layer. Irreps_sh, irreps_sh, gate.irreps_in, num_neighbors
        self.conv = Convolution(irreps, self.irreps_sh, gate.irreps_in, self.num_neighbors)
        # Gate and its output
        self.gate = gate
        # irreps_out = irreps_scalars + (ElementWiseTensorProduct(irreps_gates, irreps_gated))
        irreps = self.gate.irreps_out

        # Final layer. Gate output, irreps_sh, output specified (one even scalar), num_neighbors.
        self.final = Convolution(irreps, self.irreps_sh, "1x0e", self.num_neighbors)
        # Final output
        self.irreps_out = self.final.irreps_out

    def forward(self, data, nnodes, mradius) -> torch.Tensor:
        # num_nodes only scales the final per-cavity sum; max_radius is the cutoff used to build edges.
        num_nodes = nnodes
        max_radius = mradius
        
        # Generate graph using the node positions and creating the edges when the relative distance 
        # between a pair of nodes is smaller than max_radius (r).
        edge_src, edge_dst = radius_graph(x = data.pos, r= max_radius, batch=data.batch)
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        
        # Compute the spherical harmonics of the edge vectors.
        # normalize=True divides each vector by its norm before the computation.
        edge_attr = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=True,
            normalization='component'
        )
        
        # Embed the distances then feed this embedding to the MLP (Convolutional class)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_vec.norm(dim=1),
            start=0.5,
            end=2.5,
            number=3,
            basis='smooth_finite',
            cutoff=True
        ) * 3**0.5
        
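        # Initial node features: sum of the spherical harmonics of each node's edges (one row per node)
        x = scatter(edge_attr, edge_dst, dim=0, dim_size=data.pos.shape[0]).div(self.num_neighbors**0.5)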
        
        # Network architecture:
        
        #CONV
        x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
#         print('Conv output:', x.shape)
#         print('Conv irreps_out:', self.conv.irreps_out, '\n')
        
        #GATE
        x = self.gate(x)
#         print('Gate output:', x.shape)
#         print('Gate irreps_in:', self.gate.irreps_in)
#         print('Gate irreps_out:', self.gate.irreps_out, '\n')
        
        #FINAL (CONV)
        x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
#         print('Final(conv) output:', x.shape)
#         print('Final irreps_out:', self.final.irreps_out)
        
        
        return scatter(x, data.batch, dim=0).div(num_nodes**0.5)
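`radius_graph` connects every pair of distinct points of the same cavity (same `batch` value) that lie closer than `max_radius`, in both directions. A quadratic-time plain-Python sketch of that rule, suitable only for tiny inputs; `radius_graph_bruteforce` is a hypothetical helper. Unlike torch_cluster's `radius_graph`, it does not cap the neighbors of each node (`max_num_neighbors` defaults to 32 there), and its handling of points lying exactly at distance `r` may differ.

```python
import math


def radius_graph_bruteforce(pos, r, batch):
    """Directed edges (src -> dst) between distinct points of the same graph closer than r."""
    edge_src, edge_dst = [], []
    for dst in range(len(pos)):
        for src in range(len(pos)):
            same_graph = batch[src] == batch[dst]
            if src != dst and same_graph and math.dist(pos[src], pos[dst]) < r:
                edge_src.append(src)
                edge_dst.append(dst)
    return edge_src, edge_dst


# Points 0 and 1 are neighbors; point 2 is too far away; point 3 sits on point 0 but belongs to another cavity
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
batch = [0, 0, 0, 1]
edge_src, edge_dst = radius_graph_bruteforce(pos, r=1.5, batch=batch)
print(list(zip(edge_src, edge_dst)))  # [(1, 0), (0, 1)]
```

Points 2 and 3 end up isolated, so they receive no messages from the convolutions.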

### Training the model

In [38]:
def main(cavities, labels, num_nodes, max_radius):
    # Each step uses the whole (single) batch, so one step is one epoch
    NUM_EPOCHS = 1000
    # Evaluate on the test batch every EVAL_EVERY epochs
    EVAL_EVERY = 250
    accs = []
    
    # Train and test batches are rotated by different random rotations. Labels remain the same.
    train_x, train_y = makeBatch(cavities), labels
    test_x, test_y = makeBatch(cavities), labels
    print('Batch created:', train_x)
    
    # Build model
    net = Network()
    print("Model built.")
    
    # Set the optimizer
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    
    for step in tqdm(range(NUM_EPOCHS)):
        pred = net(train_x, num_nodes, max_radius)
        # Sum of squared errors over all cavities
        loss = (pred - train_y).pow(2).sum()
        
        # Update network parameters
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        if step % EVAL_EVERY == 0:
            # Round each prediction to 0/1 and count exact matches with the labels (no gradients needed)
            with torch.no_grad():
                accuracy = net(test_x, num_nodes, max_radius).round().eq(test_y).all(dim=1).double().mean(dim=0).item()
            print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")  
            accs.append(accuracy) 
    
    return net
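The accuracy printed during training rounds each scalar prediction to the nearest integer and counts exact matches with the 0/1 labels (`all(dim=1)` is trivial here because there is a single output per cavity). A plain-Python restatement with a hypothetical `rounded_accuracy` helper; both Python's `round` and `torch.round` round halves to the nearest even integer.

```python
def rounded_accuracy(predictions, labels):
    """Fraction of predictions that, rounded to the nearest integer, equal their 0/1 label."""
    hits = [round(p) == y for p, y in zip(predictions, labels)]
    return sum(hits) / len(hits)


preds = [0.2, -0.3, 0.9, 0.4]
labels = [0, 0, 1, 1]
print(rounded_accuracy(preds, labels))  # 0.75
print(-0.0 == 0.0)                      # True: a rounded "-0." output still matches a 0 label
```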

In [39]:
net = main(pockets, p_labels, 4, 1.5)

Batch created: DataBatch(x=[629, 1], pos=[629, 3], batch=[629], ptr=[27])
Model built.


  0%|          | 1/1000 [00:00<10:36,  1.57it/s]

epoch     0 | loss 33.8       |  65.4% accuracy


 26%|██▌       | 260/1000 [00:05<00:15, 46.98it/s]

epoch   250 | loss 0.3        | 100.0% accuracy


 51%|█████     | 510/1000 [00:10<00:09, 49.90it/s]

epoch   500 | loss 0.1        | 100.0% accuracy


 76%|███████▌  | 760/1000 [00:15<00:04, 49.20it/s]

epoch   750 | loss 0.1        | 100.0% accuracy


100%|██████████| 1000/1000 [00:19<00:00, 50.13it/s]


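The network learns to separate the druggable cavity from the non-druggable ones: from epoch 250 onwards it reaches 100% accuracy on the test batch, which contains the same pockets under a different random rotation.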

## Testing

The model is tested on rotated and non-rotated versions of the same set of pockets.

In [59]:
prediction = net(makeBatch(pockets, False), 4, 1.5)
print(prediction.round())

tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [0.],
        [0.],
        [-0.],
        [0.],
        [-0.],
        [-0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [-0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [-0.],
        [1.],
        [-0.]], grad_fn=<RoundBackward0>)


When the network is tested on the **non-rotated cavities**, it reaches **100% accuracy** (the `-0.` entries are small negative outputs rounded to zero).

In [68]:
prediction = net(makeBatch(pockets), 4, 1.5)
print(prediction.round())

Random matrix: tensor([[[ 0.5833, -0.8113, -0.0400],
         [ 0.7226,  0.5407, -0.4307],
         [ 0.3711,  0.2224,  0.9016]]])
tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [0.],
        [0.],
        [-0.],
        [0.],
        [-0.],
        [-0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [-0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [-0.],
        [1.],
        [-0.]], grad_fn=<RoundBackward0>)


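When the network is tested on **rotated cavities**, it also reaches **100% accuracy**. Each time *makeBatch()* is called, a new random rotation is applied to the original coordinates. This is expected: the network outputs a single even scalar (`1x0e`) per cavity, which is invariant under rotations, so the prediction should not depend on how the input is rotated.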
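Rotating is a safe augmentation because it changes the coordinates, not the shape: every pairwise distance inside a cavity is preserved, so the druggable/non-druggable label should not change either. The sketch below checks that property in plain Python. It swaps `o3.rand_matrix` for a hypothetical `random_rotation_matrix` built from a random unit quaternion (an assumption made to keep the example dependency-free), and uses the same row-vector convention `p @ R` as `applyRotation`, which amounts to rotating by the transpose of `R`, itself a rotation.

```python
import math
import random


def random_rotation_matrix(rng=random):
    """3x3 rotation matrix built from a uniformly random unit quaternion (w, x, y, z)."""
    # Normalized 4-D Gaussian samples are uniformly distributed on the unit 3-sphere
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / norm for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def rotate(points, R):
    """Row-vector convention, as in applyRotation: p' = p @ R."""
    return [tuple(sum(p[i] * R[i][j] for i in range(3)) for j in range(3)) for p in points]


def max_distance_change(a, b):
    """Largest change in any pairwise distance between two equally sized point lists."""
    return max(
        abs(math.dist(a[i], a[j]) - math.dist(b[i], b[j]))
        for i in range(len(a))
        for j in range(len(a))
    )


cavity = [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-2.0, 0.5, 1.0)]
R = random_rotation_matrix(random.Random(0))
rotated = rotate(cavity, R)
print(max_distance_change(cavity, rotated))  # close to 0 (floating-point error only)
```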