In [1]:
import os

import numpy as np
import scipy.sparse as sp

from spektral.data import Dataset, Graph
from rdkit import Chem
from rdkit.Chem import AllChem

In [2]:
def get_nodes(mol):
    """Build the per-atom feature matrix of a molecule.

    Every row describes one atom with four values: atomic number,
    Gasteiger charge, and the first two columns of the atom's position in
    the molecule's conformer.

    Parameters
    ----------
    mol : molecule object
        Gasteiger charges are computed on it before the atoms are read.

    Returns
    -------
    numpy.ndarray
        Array with one row per atom and four columns.
    """
    # Must run first: the "_GasteigerCharge" atom property is read below.
    AllChem.ComputeGasteigerCharges(mol)

    atom_props = np.array([
        (atom.GetAtomicNum(), atom.GetDoubleProp("_GasteigerCharge"))
        for atom in mol.GetAtoms()
    ])

    # Only the first two coordinate columns are kept.
    planar_coords = mol.GetConformer().GetPositions()[:, :2]

    return np.concatenate((atom_props, planar_coords), axis=1)

def get_edges(mol):
    """Collect the bond type of every bond in a molecule.

    Returns
    -------
    numpy.ndarray
        1-D array with one entry per bond, in ``mol.GetBonds()`` order,
        holding the value returned by ``GetBondTypeAsDouble()``.
    """
    bond_orders = []
    for bond in mol.GetBonds():
        bond_orders.append(bond.GetBondTypeAsDouble())
    return np.array(bond_orders)

def isfloat(s):
    """Return True when ``s`` looks like a number.

    Two checks are tried in order:

    1. ``float(s)`` succeeds.
    2. ``unicodedata.numeric(s)`` succeeds (a single Unicode character
       with a numeric value). Note that ``float(s)`` can still fail for
       such input, so a True result from this check does not guarantee
       that ``float(s)`` works.

    Parameters
    ----------
    s : object
        Value to test; objects that cannot be interpreted at all
        (e.g. ``None`` or a list) yield False instead of raising.

    Returns
    -------
    bool
    """
    import unicodedata

    # float() raises ValueError for non-numeric text and TypeError for
    # objects that are neither strings nor numbers; both mean "not a float".
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        pass

    # unicodedata.numeric() raises TypeError for anything but a single
    # character and ValueError for characters without a numeric value.
    try:
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass

    return False

def get_labels(mol, key='IC50 (nM)'):
    """Generate the label vector ``[rank, conc]`` for a molecule.

    "rank" indicates presence or absence of angle brackets, which are
    reported for concentrations beyond detection limits:
    rank = 1 when "<", 2 when ">", and 3 when none.

    "conc" contains the reported concentration value. Angle brackets are
    removed and the boundary value is kept. A conc of 0 means the metric
    was not reported; values that cannot be parsed as a number are treated
    the same way and yield ``[3, 0.0]``.

    Parameters
    ----------
    mol : molecule object exposing ``GetPropsAsDict()``
    key : str
        Name of the property holding the potency metric.

    Returns
    -------
    numpy.ndarray
        Float array ``[rank, conc]``.

    Raises
    ------
    KeyError
        If ``key`` is not among the molecule's properties.
    """
    # read potency metric; the stored value may already be numeric rather
    # than text, so normalise it to a stripped string before inspecting it
    sample = str(mol.GetPropsAsDict()[key]).strip()

    # below exp. range
    if "<" in sample:
        rank, text = 1, sample.replace('<', '')
    # outside exp. range
    elif ">" in sample:
        rank, text = 2, sample.replace('>', '')
    # inside exp. range (or nothing usable, decided by the parse below)
    else:
        rank, text = 3, sample

    try:
        conc = float(text)
    except ValueError:
        # no data provided, or the remaining text is not a number
        return np.array([3, 0.0])

    return np.array([rank, conc])

# Open the SDF file with sanitization and strict parsing enabled
suppl = Chem.SDMolSupplier(
    'datasets/estrogen_receptor_alpha.sdf',
    sanitize=True,
    strictParsing=True,
)

# Keep every molecule that was read without errors (failed entries are None)
mols = list(filter(lambda mol: mol is not None, suppl))

# Node features
x = list(map(get_nodes, mols))

# Adjacency matrices
a = list(map(Chem.rdmolops.GetAdjacencyMatrix, mols))

# Edge features: bond types
e = list(map(get_edges, mols))

# Labels: (rank, IC50)
# IC50 is less reliable than e.g. Kd, as it depends on the substrates
# used in the assay and is cell-type dependent.
y = list(map(get_labels, mols))

In [3]:
class EstrogenDB(Dataset):
    """Dataset from BindingDB.

    The per-molecule arrays handed to the constructor are written to a
    compressed ``.npz`` archive by :meth:`download` and turned back into
    ``Graph`` objects by :meth:`read`.

    Parameters
    ----------
    n_samples : int or None
        Number of leading molecules to consider when reading. ``None``
        means all stored molecules; values larger than the number of
        stored molecules are clamped to it.
    nodes, edges, adjcs, feats : sequence
        One entry per molecule: node features, edge features, adjacency
        matrix and label vector.
    dpath : str
        Directory in which the archive is stored.
    **kwargs
        Forwarded to ``Dataset.__init__``.
    """

    # File name of the archive, shared by download() and read().
    _archive_name = 'EstrogenDB.npz'

    def __init__(self, n_samples, nodes, edges, adjcs, feats, dpath, **kwargs):
        # Attributes are assigned before the base initialiser runs.
        # NOTE(review): the base class presumably calls download()/read()
        # from its __init__, which need these attributes - confirm.
        self.n_samples = n_samples
        self.nodes = nodes
        self.edges = edges
        self.adjcs = adjcs
        self.feats = feats
        # dataset directory
        self.dpath = dpath

        super().__init__(**kwargs)

    def read(self):
        """Load the archive and build one ``Graph`` per usable molecule.

        Molecules whose label has a concentration (second label entry)
        that is not greater than zero are skipped.

        Returns
        -------
        list
            Graphs built by :meth:`make_graph`.
        """
        archive = os.path.join(self.dpath, self._archive_name)

        # allow_pickle is required because the per-molecule arrays are
        # stored as object arrays. Pickled data can execute code on load,
        # so only read archives written by download().
        with np.load(archive, allow_pickle=True) as data:
            # Members of the archive are loaded lazily on each access, so
            # fetch every array exactly once instead of once per molecule.
            nodes = data['x']
            adjcs = data['a']
            edges = data['e']
            feats = data['y']

        stored = min(len(nodes), len(adjcs), len(edges), len(feats))
        if self.n_samples is None:
            limit = stored
        else:
            # never index past the molecules that were actually stored
            limit = min(self.n_samples, stored)

        return [
            self.make_graph(
                node=nodes[i],
                adjc=adjcs[i],
                edge=edges[i],
                feat=feats[i])
            for i in range(limit)
            if feats[i][1] > 0
        ]

    def download(self):
        """Save the graph arrays into ``dpath`` as a compressed archive."""
        # An empty dpath means the current directory, which needs no creation.
        if self.dpath:
            os.makedirs(self.dpath, exist_ok=True)

        np.savez_compressed(
            os.path.join(self.dpath, self._archive_name),
            x=self._as_object_array(self.nodes),
            a=self._as_object_array(self.adjcs),
            e=self._as_object_array(self.edges),
            y=self._as_object_array(self.feats))

    @staticmethod
    def _as_object_array(items):
        """Pack a sequence into a 1-D object array, one element per item.

        Molecules differ in size, so the per-molecule arrays are ragged.
        Filling an object array element by element avoids NumPy's implicit
        ragged-sequence conversion, which warns on older releases and
        fails on newer ones.
        """
        items = list(items)
        packed = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            packed[i] = item
        return packed

    @staticmethod
    def make_graph(node, adjc, edge, feat):
        """Assemble a ``Graph`` from the arrays of a single molecule.

        Parameters
        ----------
        node : numpy.ndarray
            Node features, one row per node.
        adjc : numpy.ndarray
            Square adjacency matrix.
        edge : numpy.ndarray
            Edge features.
        feat : numpy.ndarray
            Label vector.

        Raises
        ------
        AssertionError
            If the adjacency matrix is not ``(n_nodes, n_nodes)``.
        """
        # The node features
        x = node.astype(float)

        # The adjacency matrix, converted to a scipy.sparse matrix
        a = sp.csr_matrix(adjc.astype(int))
        # check shape (n_nodes, n_nodes)
        assert len(node) == a.shape[0]
        assert len(node) == a.shape[1]

        # The labels
        y = feat.astype(float)

        # The edge features
        # NOTE(review): `edge` is passed through unchanged. If it is a 1-D
        # per-bond array, its length is likely interpreted as the number of
        # edge features; a (n_edges, n_features) layout aligned with the
        # non-zero entries of `a` is presumably what is wanted - confirm.
        e = edge.astype(float)

        return Graph(x=x, a=a, e=e, y=y)

In [4]:
# Store the archive in the same relative `datasets` directory the SDF file
# is read from, so the notebook does not depend on one user's absolute
# home path.
dataset = EstrogenDB(
    # consider at most the first 1000 molecules, never more than were parsed
    n_samples=min(1000, len(x)),
    nodes=x, edges=e,
    adjcs=a, feats=y,
    dpath='datasets')

  return array(a, dtype, copy=False, order=order, subok=True)


In [5]:
# Print the first three graphs as a quick check of their dimensions.
first_graphs = dataset[:3]
for graph in first_graphs:
    print(graph)

Graph(n_nodes=34, n_node_features=4, n_edge_features=38, n_labels=2)
Graph(n_nodes=34, n_node_features=4, n_edge_features=38, n_labels=2)
Graph(n_nodes=41, n_node_features=4, n_edge_features=44, n_labels=2)
