In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.nn as gnn
from torch_geometric.data import Data 


import numpy as np
import pandas as pd
import os

from rdkit import Chem

from Bio.PDB.PDBParser import PDBParser
import scipy.linalg as linalg
import dgl

## Module-level Bio.PDB parser, handed to load_pdb() further down.
## PERMISSIVE=1 is forwarded to PDBParser as-is (see the Bio.PDB documentation for its exact effect).
p = PDBParser(PERMISSIVE=1)

## Load SMILES


ref:  https://github.com/jacquesboitreaud/interpretGCN/blob/master/dataloading/rdkit_to_nx.py

atoms: https://blog.csdn.net/dreadlesss/article/details/106306472

In [2]:
## Reverse lookup of RDKit's BondType table: enum member -> its integer code
BOUND_TYPE_dict = {bond_type: code for code, bond_type in Chem.rdchem.BondType.values.items()}

def one_hot(emb_len, attr_list):
    """Return one one-hot row of width ``emb_len`` for every index in ``attr_list``.

    The rows are picked from an ``emb_len`` x ``emb_len`` identity matrix, so the
    result is a float array of shape ``(len(attr_list), emb_len)``.
    """
    identity = np.eye(emb_len)
    return identity[attr_list]


def Node_feature(atom):
    """Collect the per-atom features used by the graph builders into a dict.

    ``atom`` must provide GetIdx, GetAtomicNum, GetSymbol, GetIsAromatic and
    GetTotalValence (an RDKit atom does). The aromatic flag is stored as 0/1.
    """
    return {
        "id": atom.GetIdx(),
        "atomic": atom.GetAtomicNum(),
        "symbol": atom.GetSymbol(),
        "aromatic": int(bool(atom.GetIsAromatic())),
        "valence": atom.GetTotalValence(),
    }

def Smiles_to_mtx(Smiles):
    """Build a bond-type adjacency matrix and per-atom features from a SMILES string.

    Parameters
    ----------
    Smiles : str
        SMILES representation of the molecule.

    Returns
    -------
    mtx : np.ndarray, shape (num_atom, num_atom)
        Symmetric matrix; entry [i, j] holds the BOUND_TYPE_dict code of the bond
        between atoms i and j, and 0 where there is no bond.
    nodes : list of dict
        One Node_feature() dict per atom, in the molecule's atom order.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``Smiles``.
    """
    mol = Chem.MolFromSmiles(Smiles)
    if mol is None:     ## MolFromSmiles signals a parse failure by returning None, not by raising
        raise ValueError(f"Invalid SMILES string: {Smiles!r}")
    mol = Chem.RemoveHs(mol)
    num_atom = mol.GetNumAtoms()
    mtx = np.zeros((num_atom, num_atom))
    nodes = [Node_feature(atom) for atom in mol.GetAtoms()]
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        ## Bonds are undirected, so fill both triangles
        mtx[i, j] = mtx[j, i] = BOUND_TYPE_dict[bond.GetBondType()]
    return mtx, nodes



def Smiles_to_data(Smiles):
    """Convert a SMILES string into a torch_geometric ``Data`` graph.

    Parameters
    ----------
    Smiles : str
        SMILES representation of the molecule.

    Returns
    -------
    Data
        ``edge_index`` : int64 [2, E], one (begin atom, end atom) column per bond.
        ``edge_attr``  : int64 [E, len(BOUND_TYPE_dict)], one-hot bond type.
        ``x``          : float32 [N, 3], columns are (atomic number, aromatic flag, total valence).

    Raises
    ------
    ValueError
        If RDKit cannot parse ``Smiles``.

    Notes
    -----
    Each bond is stored once (begin -> end only).
    NOTE(review): if downstream layers should treat bonds as undirected, the reverse
    edges still have to be added -- confirm against the model that consumes this.
    """
    mol = Chem.MolFromSmiles(Smiles)
    if mol is None:     ## MolFromSmiles signals a parse failure by returning None, not by raising
        raise ValueError(f"Invalid SMILES string: {Smiles!r}")
    mol = Chem.RemoveHs(mol)

    ## Collect per-bond columns separately: the zip(*...) unpacking idiom fails when
    ## the molecule has no bonds at all (e.g. 'C' or a single ion).
    from_, to_, attrE = [], [], []
    for bond in mol.GetBonds():
        from_.append(bond.GetBeginAtomIdx())
        to_.append(bond.GetEndAtomIdx())
        attrE.append(BOUND_TYPE_dict[bond.GetBondType()])

    edge_index = torch.tensor([from_, to_], dtype=torch.int64)      ## shape [2, 0] when there are no bonds
    ## An explicit integer index array keeps the result [0, embE] for bond-less molecules
    bond_codes = np.asarray(attrE, dtype=np.int64)
    edge_attr = torch.tensor(one_hot(len(BOUND_TYPE_dict), bond_codes), dtype=torch.int64)    ## embE: one-hot BOUND_TYPE

    nodes = [Node_feature(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(
        [[n['atomic'], n['aromatic'], n['valence']] for n in nodes],
        dtype=torch.float32,
    ).reshape(-1, 3)                                                ## embN; stays 2-D even with zero atoms
    return Data(edge_index=edge_index, edge_attr=edge_attr, x=x)

In [3]:
## Demo: convert one example SMILES string into a torch_geometric Data graph
Smiles = 'C[C@H](O)c1ccccc1'
Smiles_to_data(Smiles)

Data(x=[9, 3], edge_index=[2, 9], edge_attr=[9, 22])

In [4]:
## Demo: the same molecule as a bond-type adjacency matrix plus per-atom feature dicts
Smiles = 'C[C@H](O)c1ccccc1'
Smiles_to_mtx(Smiles)

(array([[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0., 12.,  0.,  0.,  0., 12.],
        [ 0.,  0.,  0., 12.,  0., 12.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0., 12.,  0., 12.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 12.,  0., 12.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0., 12.,  0., 12.],
        [ 0.,  0.,  0., 12.,  0.,  0.,  0., 12.,  0.]]),
 [{'id': 0, 'atomic': 6, 'symbol': 'C', 'aromatic': 0, 'valence': 4},
  {'id': 1, 'atomic': 6, 'symbol': 'C', 'aromatic': 0, 'valence': 4},
  {'id': 2, 'atomic': 8, 'symbol': 'O', 'aromatic': 0, 'valence': 2},
  {'id': 3, 'atomic': 6, 'symbol': 'C', 'aromatic': 1, 'valence': 4},
  {'id': 4, 'atomic': 6, 'symbol': 'C', 'aromatic': 1, 'valence': 4},
  {'id': 5, 'atomic': 6, 'symbol': 'C', 'aromatic': 1, 'valence': 4},
  {'id': 6, 'atomic': 6, 'symbol': 'C', 'aromatic': 1, 'valence': 4},
  {'id': 7, '

## Load PDB

- edge: [C-C bond]
- node: [coord,aa]

or refer to: https://blog.csdn.net/C20180602_csq/article/details/138327140

TODO: Make the graph fully connected, and add features to each edge: [L2_distance_of_ATOM_C > ??, C-C bond]

In [5]:
## Three-letter residue name -> one-letter amino-acid code (the 20 standard residues)
SHORTEN_dict= {'VAL':'V', 'ILE':'I', 'LEU':'L', 'GLU':'E', 'GLN':'Q',
            'ASP':'D', 'ASN':'N', 'HIS':'H', 'TRP':'W', 'PHE':'F',
            'TYR':'Y', 'ARG':'R', 'LYS':'K', 'SER':'S', 'THR':'T',
            'MET':'M', 'ALA':'A', 'GLY':'G', 'PRO':'P', 'CYS':'C'}   ## BZJX*

def load_embd_dict(AA_selc = list(''.join(SHORTEN_dict.values()) + '*'), file = "BLOSUM62.txt"):
    """Derive one embedding vector per amino-acid symbol from a substitution matrix.

    Parameters
    ----------
    AA_selc : list of str
        Symbols (row/column labels of the matrix) to keep. Defaults to the 20
        one-letter codes in SHORTEN_dict plus '*'.
    file : str
        Path to a whitespace-separated matrix whose first column holds the row
        labels. Defaults to "BLOSUM62.txt".

    Returns
    -------
    dict
        symbol -> 1-D np.ndarray of length ``len(AA_selc)``.

    Notes
    -----
    The matrix is restricted to ``AA_selc``, sorted by label, passed through
    ``2 ** score`` and eigendecomposed; eigenvectors are sign-normalised and
    scaled by the square root of their eigenvalue.
    NOTE(review): the returned vectors are the rows of ``v.T``, i.e. each symbol
    gets one *scaled eigenvector*. The rows of ``v`` are what would satisfy
    ``E @ E.T == 2 ** matrix``. Confirm which of the two is intended before
    relying on similarity between embeddings.
    NOTE(review): a negative eigenvalue makes ``w**0.5`` NaN -- assumes the
    exponentiated matrix is positive semi-definite.
    """
    ## Honour the `file` argument (it used to be ignored in favour of a hard-coded path)
    df = pd.read_csv(file, sep="\\s+", index_col=0)
    df = df[AA_selc].T[AA_selc].T
    df = df.sort_index().loc[:, df.columns.sort_values()]       ## Sort cols and rows by name
    w, v = linalg.eigh(np.exp2(df))
    ## v[:,i] is the i-th eigenvector; flip it so its first component is non-negative.
    ## np.where instead of np.sign: sign(0) == 0 would wipe out the whole eigenvector.
    v = v * np.where(v[0] < 0, -1.0, 1.0)
    v = v @ np.diag(w**0.5)            ## scale each eigenvector by sqrt(eigenvalue)
    return dict(zip(df.columns.values,v.T))

## One-letter code (plus '*') -> embedding vector, computed once when the cell runs
EMBD_dict = load_embd_dict()

def embedAA(aa):
    """Return the embedding of a residue given its three-letter name.

    ``aa`` is looked up in SHORTEN_dict (keys such as 'ALA'); any name that is
    not a key there -- including one-letter codes -- maps to the '*' embedding.
    With the default EMBD_dict the result has shape (21,).
    """
    return EMBD_dict[SHORTEN_dict.get(aa, '*')]

In [6]:
def load_pdb(p,id,file):
    """Read the first model of a PDB file into a per-chain dict of residues.

    Parameters
    ----------
    p : parser object
        Must provide ``get_structure(id, file)`` (e.g. a Bio.PDB PDBParser).
    id : str
        Structure identifier passed to the parser.
    file : str
        Path of the PDB file.

    Returns
    -------
    dict
        chain id -> {'aa': [residue names], 'coord': [coordinate of each
        residue's 'C' atom], 'seqlen': number of residues kept}.

    Notes
    -----
    Residues without an atom named 'C' (waters, ligands, incomplete residues)
    are skipped so that 'aa' and 'coord' stay aligned; indexing them would
    otherwise raise KeyError.
    """
    structure = p.get_structure(id, file)
    model = next(structure.get_models())       ## Get the first model from the structure
    chains = {}
    for chain in model.get_chains():           ## Usually only one chain as: 'COMPND   3 CHAIN: A'
        chain_id = chain.get_id()              ## separate name: do not clobber the `id` argument
        aa_names = []
        coords = []
        for residue in chain.get_residues():
            if 'C' not in residue:
                continue
            aa_names.append(residue.get_resname())
            coords.append(residue['C'].get_coord())
        chains[chain_id] = {
            'aa': aa_names,
            'coord': coords,
            'seqlen': len(aa_names),
        }
    return chains


def chains_to_dgl(chains):
    """Build a DGL graph with one node per residue and one edge per consecutive residue pair.

    ``chains`` is the dict produced by load_pdb(). Every node carries 'coord'
    (float32 coordinate) and 'aa' (float32 embedding from embedAA). Within a
    chain, residue k-1 is linked to residue k by a directed edge whose 'C-C'
    feature is 1; no edge is created across chain boundaries.
    """
    g = dgl.graph([])
    node_idx = 0
    for chain in chains.values():
        residues = zip(chain['aa'], chain['coord'])
        for k, (aa, coord) in enumerate(residues):
            node_data = {
                'coord': torch.tensor(coord,dtype=torch.float32).unsqueeze(0),
                'aa': torch.tensor(embedAA(aa),dtype=torch.float32).unsqueeze(0)
            }
            g.add_nodes(1, node_data)
            if k > 0:
                ## Other edge types (S-S bonds, H bonds, spatial proximity, ...) are not computed here
                edge_data = {'C-C': torch.tensor([1])}
                g.add_edges(torch.tensor([node_idx-1]), torch.tensor([node_idx]), edge_data)
            node_idx += 1
    return g


def chains_to_data_CCbound(chains):
    """Turn load_pdb() output into a torch_geometric ``Data`` chain graph.

    Returns ``Data`` with
      * ``x``          : float32 [N, embN], the embedAA() vector of every residue,
      * ``x_coord``    : float32 coordinates of every residue,
      * ``edge_index`` : int64 [2, E], linking residue k to residue k+1 inside
        each chain (node numbering runs on across chains, edges do not).
    """
    src, dst = [], []
    aa_embeddings = []
    coords = []
    offset = 0                      ## global index of the current chain's first residue
    for chain in chains.values():
        n_res = chain['seqlen']
        aa_embeddings.extend(embedAA(aa) for aa in chain['aa'])
        coords.extend(chain['coord'])
        for k in range(n_res - 1):
            src.append(offset + k)
            dst.append(offset + k + 1)
        offset += n_res
    x_aa = torch.tensor(np.array(aa_embeddings),dtype=torch.float32)
    x_coord = torch.tensor(np.array(coords),dtype=torch.float32)
    edge_index = torch.tensor(np.array([src, dst]),dtype=torch.int64)    ## C-C bound
    return Data(edge_index=edge_index, x=x_aa, x_coord=x_coord)


def chains_to_data_L2(chains, cutoff_ratio=1.5):
    """Build a distance-thresholded residue graph as a torch_geometric ``Data``.

    Parameters
    ----------
    chains : dict
        Output of load_pdb().
    cutoff_ratio : float, default 1.5
        A pair (i, j), i < j, is kept when its L2 distance is below
        ``cutoff_ratio`` times the distance between residues 0 and 1.

    Returns
    -------
    Data
        ``x``          : residue embeddings taken from chains_to_data_CCbound(),
        ``edge_index`` : int64 [2, E], kept pairs with i < j,
        ``edge_attr``  : float32 [E], L2 distance of each kept pair.
        With fewer than two residues the edge tensors are empty.

    Notes
    -----
    NOTE(review): the reference distance is always the (0, 1) pair -- confirm
    this is a suitable yardstick for every input structure.
    """
    dataCC = chains_to_data_CCbound(chains)
    x_aa = dataCC.x
    x_coord = dataCC.x_coord
    aa_num = x_coord.size(0)
    if aa_num < 2:
        ## No pair exists, hence no reference distance either: return an edge-less graph
        return Data(
            edge_index=torch.empty((2, 0), dtype=torch.int64),
            edge_attr=torch.empty((0,), dtype=torch.float32),
            x=x_aa,
        )
    ## All pairs i < j in row-major order, i.e. (0,1), (0,2), ..., (1,2), ...
    edge_index = torch.triu_indices(aa_num, aa_num, offset=1)
    diff = x_coord[edge_index[0]] - x_coord[edge_index[1]]
    edge_attr = torch.norm(diff, p=2, dim=1)
    select = edge_attr < edge_attr[0] * cutoff_ratio    ## edge_attr[0] is the (0, 1) pair
    return Data(edge_index=edge_index[:,select], edge_attr=edge_attr[select], x=x_aa)


In [7]:
## os.listdir() gives no ordering guarantee; sort so that files[0] / ids[0]
## pick the same structure on every run and on every machine.
files = sorted(os.listdir('pdb'))
ids = [f.split('-model_')[0] for f in files]        ## id = file name up to '-model_'
files = [os.path.join('pdb',f) for f in files]

file = files[0]
id = ids[0]

## Parse the first structure and build all three graph representations for comparison
chains = load_pdb(p,id,file)
chains_to_data_CCbound(chains), chains_to_data_L2(chains), chains_to_dgl(chains)

(Data(x=[128, 21], edge_index=[2, 127], x_coord=[128, 3]),
 Data(x=[128, 21], edge_index=[2, 159], edge_attr=[159]),
 Graph(num_nodes=128, num_edges=127,
       ndata_schemes={'coord': Scheme(shape=(3,), dtype=torch.float32), 'aa': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={'C-C': Scheme(shape=(), dtype=torch.int64)}))

In [8]:
## Try Conv
## Smoke test: push the residue graph built from `chains` through a single DGL GraphConv layer.
from dgl.nn import GraphConv

g = chains_to_dgl(chains)


## in_feats = 21 matches the width of the 'aa' node feature; each node is projected to 2 features.
## chains_to_dgl() only adds edges k-1 -> k, so the first residue of a chain has no incoming edge;
## presumably that is why allow_zero_in_degree=True is set -- confirm.
convLayer = GraphConv(in_feats = 21, out_feats = 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True)
feat = g.ndata['aa']

## Show the feature shapes before and after the layer
feat.size(), convLayer(g, feat).size()

(torch.Size([128, 21]), torch.Size([128, 2]))