In [1]:
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from rdkit import Chem
import pickle
from rdkit.Chem.rdchem import HybridizationType, ChiralType
from torch_geometric.utils import from_networkx

[12:30:10] Enabling RDKit 2019.09.3 jupyter extensions


In [2]:
# Multi-molecule SDF whose per-record properties (e.g. PRIMARY_SOM_*) are read by mol2y.
filepath = '../Dataset/merged.sdf'
# RDKit reader over the SDF records. Per the RDKit docs it yields None for records
# that fail to parse, so code iterating `mols` should check for None.
mols = Chem.SDMolSupplier(filepath)

In [3]:
# SDF property names holding 1-based atom indices; presumably site-of-metabolism
# (SOM) annotations per CYP isoform (suffixes like 1A2, 3A4) -- verify against the dataset.
_SOM_RANKS = ('PRIMARY', 'SECONDARY', 'TERTIARY')
_SOM_ISOFORMS = ('1A2', '2A6', '2B6', '2C8', '2C9', '2C19', '2D6', '2E1', '3A4')
SOM_PROPERTIES = [f'{rank}_SOM_{iso}' for rank in _SOM_RANKS for iso in _SOM_ISOFORMS]


def mol2y(mol):
    """Build a per-atom 0/1 label vector from the molecule's SOM properties.

    Every name in ``SOM_PROPERTIES`` that the molecule carries is parsed as a
    whitespace-separated list of 1-based atom indices; each listed atom gets
    label 1, all other atoms 0. Properties the molecule lacks are skipped.

    Args:
        mol: an RDKit molecule (anything providing HasProp/GetProp/GetNumAtoms).

    Returns:
        float64 numpy array of shape (mol.GetNumAtoms(),) holding 0.0 / 1.0.

    Raises:
        ValueError: a property contains a non-integer token, or an index
            outside 1..num_atoms (index 0 used to silently label the last atom).
    """
    n_atoms = mol.GetNumAtoms()
    y = np.zeros(n_atoms)
    for prop in SOM_PROPERTIES:
        if not mol.HasProp(prop):
            continue
        # split() with no argument tolerates repeated/leading/trailing whitespace;
        # split(' ') produced empty tokens whose int('') error dropped the rest.
        for token in mol.GetProp(prop).split():
            try:
                atom_number = int(token)
            except ValueError:
                raise ValueError(f'{prop}: non-integer atom index {token!r}') from None
            if not 1 <= atom_number <= n_atoms:
                raise ValueError(f'{prop}: atom index {atom_number} outside 1..{n_atoms}')
            y[atom_number - 1] = 1
    return y

In [4]:
# Atom symbols with a dedicated one-hot slot; any other element uses slot 9 ('other').
_ATOM_TYPE_INDEX = {sym: i for i, sym in enumerate(['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'])}
_ATOM_TYPE_OTHER = 9

# Chirality tag -> one-hot slot. Tags not listed here fall back to the CHI_OTHER slot
# instead of raising or reusing the previous atom's encoding.
_CHIRALITY_INDEX = {
    ChiralType.CHI_TETRAHEDRAL_CCW: 0,
    ChiralType.CHI_TETRAHEDRAL_CW: 1,
    ChiralType.CHI_OTHER: 2,
    ChiralType.CHI_UNSPECIFIED: 3,
}

# Hybridization -> one-hot slot. Types not listed fall back to the OTHER slot.
_HYBRIDIZATION_INDEX = {
    HybridizationType.S: 0,
    HybridizationType.SP: 1,
    HybridizationType.SP2: 2,
    HybridizationType.SP3: 3,
    HybridizationType.SP3D: 4,
    HybridizationType.SP3D2: 5,
    HybridizationType.OTHER: 6,
    HybridizationType.UNSPECIFIED: 7,
}


def _one_hot(position, size):
    """Return a list of `size` zeros with a 1 at index `position`."""
    vec = [0] * size
    vec[position] = 1
    return vec


def _atom_features(atom):
    """Return the 26-element integer feature array for one RDKit atom.

    Layout: atom type one-hot (10), implicit valence, formal charge, radical
    electrons, aromatic flag (4), chirality one-hot (4), hybridization one-hot (8).
    """
    feats = _one_hot(_ATOM_TYPE_INDEX.get(atom.GetSymbol(), _ATOM_TYPE_OTHER), 10)
    feats.append(atom.GetImplicitValence())
    feats.append(atom.GetFormalCharge())
    feats.append(atom.GetNumRadicalElectrons())
    feats.append(int(atom.GetIsAromatic()))
    chiral_slot = _CHIRALITY_INDEX.get(atom.GetChiralTag(), _CHIRALITY_INDEX[ChiralType.CHI_OTHER])
    feats.extend(_one_hot(chiral_slot, 4))
    hybrid_slot = _HYBRIDIZATION_INDEX.get(atom.GetHybridization(), _HYBRIDIZATION_INDEX[HybridizationType.OTHER])
    feats.extend(_one_hot(hybrid_slot, 8))
    return np.asarray(feats)


def mol2graph(mol):
    """Convert an RDKit molecule into a networkx graph with node and edge features.

    Nodes are keyed by atom index and carry ``x`` (feature array from
    ``_atom_features``) and ``y`` (0/1 label from ``mol2y``). Each bond becomes an
    undirected edge with ``edge_attr = [bond order as double, in-ring flag]``.
    """
    target = mol2y(mol)
    g = nx.Graph()
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        g.add_node(idx, x=_atom_features(atom), y=int(target[idx]))
    # Bonds are added once per molecule, after all atoms (not once per atom).
    for bond in mol.GetBonds():
        edge_feats = np.asarray([bond.GetBondTypeAsDouble(), bond.IsInRing()])
        g.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), edge_attr=edge_feats)
    return g

In [5]:
def get_neighbors_aslist(g, node, depth=3):
    """List `node` and every node within `depth` BFS hops of it, level by level.

    The result starts with `node`, followed by the hop-1 nodes, then hop-2, and
    so on. Within a level, nodes are grouped by their BFS parent in parent order.
    """
    children = dict(nx.bfs_successors(g, source=node, depth_limit=depth))
    collected = [node]
    frontier = [node]
    for _ in range(depth):
        # Expand the current BFS level into the next one via the successor map.
        frontier = [child for parent in frontier for child in children.get(parent, [])]
        collected += frontier
    return collected

In [6]:
# Convert every molecule to a feature graph. SDMolSupplier yields None for
# records RDKit cannot parse; skip those instead of crashing inside mol2graph.
dataset = []
n_unparsed = 0
for mol in mols:
    if mol is None:
        n_unparsed += 1
        continue
    g = mol2graph(mol)
    dataset.append(g)
if n_unparsed:
    print(f'Skipped {n_unparsed} unparseable SDF record(s)')

In [7]:
# Shuffle molecules before splitting them into training and test sets.
# A dedicated Random(42) produces exactly the permutation that
# `random.seed(42); random.shuffle(dataset)` would, without reseeding the
# global `random` module for every later cell.
# NOTE: the shuffle is in place, so re-running this cell without first
# re-running the cell that builds `dataset` changes the split.
import random
random.Random(42).shuffle(dataset)

In [8]:
# Split by molecule: every subgraph derived from one molecule later lands on the same side.
TRAIN_FRACTION = 0.8
split_at = int(len(dataset) * TRAIN_FRACTION)
training_set, test_set = dataset[:split_at], dataset[split_at:]

In [11]:
# One sample per atom of each training molecule: the induced subgraph on the atom's
# 3-hop neighbourhood, labelled 1 if any atom in that neighbourhood has y == 1.
_tr_set = []
for g in training_set:
    for center in g.nodes:
        neighborhood = get_neighbors_aslist(g, center, depth=3)
        # .copy() detaches the sample from `g`: a networkx subgraph view keeps a
        # reference to the whole parent graph, which would also get pickled.
        subgraph = g.subgraph(neighborhood).copy()
        label = int(any(attrs['y'] for _, attrs in subgraph.nodes(data=True)))
        _tr_set.append((subgraph, np.array(label, dtype=np.int64)))

In [12]:
# Number of training samples: one per atom over all training molecules.
len(_tr_set)

12010

In [13]:
# One sample per atom of each test molecule: the induced subgraph on the atom's
# 3-hop neighbourhood, labelled 1 if any atom in that neighbourhood has y == 1.
_test_set = []
for g in test_set:
    for center in g.nodes:
        neighborhood = get_neighbors_aslist(g, center, depth=3)
        # .copy() detaches the sample from `g`: a networkx subgraph view keeps a
        # reference to the whole parent graph, which would also get pickled.
        subgraph = g.subgraph(neighborhood).copy()
        label = int(any(attrs['y'] for _, attrs in subgraph.nodes(data=True)))
        _test_set.append((subgraph, np.array(label, dtype=np.int64)))

In [14]:
# Number of test samples: one per atom over all test molecules.
len(_test_set)

3259

In [15]:
# Write the (subgraph, label) samples as the raw inputs of the SubGraph dataset
# (root './subgraphdataset/', raw files under its 'raw/' directory).
import pickle
os.makedirs('./subgraphdataset/raw', exist_ok=True)
# `with` guarantees each file is flushed and closed even if pickling fails.
with open('./subgraphdataset/raw/train.pkl', 'wb') as fh:
    pickle.dump(_tr_set, fh)
with open('./subgraphdataset/raw/test.pkl', 'wb') as fh:
    pickle.dump(_test_set, fh)

In [24]:
class SubGraph(Dataset):
    """PyG dataset of subgraph samples read from a pickled list.

    ``root/raw/<filename>`` must be a pickle of ``(networkx_graph, label)`` pairs.
    ``process`` converts each pair with ``from_networkx``, stores the label under
    ``data['target']`` as an int64 tensor, and saves one ``.pt`` file per sample
    in ``processed_dir``.

    Args:
        root: dataset root directory.
        filename: name of the raw pickle inside ``root/raw``.
        test: when True, processed files are named ``data_test_<i>.pt`` instead
            of ``data_<i>.pt``, so a train and a test dataset can share one root.
        transform, pre_transform, pre_filter: forwarded to ``Dataset``.
    """

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None, pre_filter=None):
        self.filename = filename
        self.test = test
        # Unpickled raw samples; filled lazily by _load_raws().
        self.raws = None
        super().__init__(root, transform, pre_transform, pre_filter)

    def _load_raws(self):
        """Unpickle the raw file on first use and cache it in ``self.raws``."""
        if self.raws is None:
            # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
            with open(self.raw_paths[0], 'rb') as fh:
                self.raws = pickle.load(fh)
        return self.raws

    def _processed_name(self, idx):
        """File name of processed sample ``idx``; shared by saving and loading."""
        prefix = 'data_test' if self.test else 'data'
        return f'{prefix}_{idx}.pt'

    @property
    def raw_file_names(self):
        return self.filename

    @property
    def processed_file_names(self):
        # PyG compares these names with the files on disk to decide whether
        # process() must run, so they must match what process() writes
        # (including the '.pt' suffix for the test set).
        return [self._processed_name(i) for i in range(len(self._load_raws()))]

    def download(self):
        pass

    def process(self):
        for idx, (subgraph, label) in enumerate(self._load_raws()):
            data = from_networkx(subgraph)
            data['target'] = torch.tensor(label, dtype=torch.int64)
            torch.save(data, os.path.join(self.processed_dir, self._processed_name(idx)))

    def len(self):
        return len(self._load_raws())

    def get(self, idx):
        return torch.load(os.path.join(self.processed_dir, self._processed_name(idx)))

In [25]:
# Both datasets share one root (raw pickles under its 'raw/' directory); the
# test=True flag gives test samples a distinct 'data_test_' file prefix.
subgraph_root = './subgraphdataset/'
train_dataset = SubGraph(subgraph_root, 'train.pkl')
test_dataset = SubGraph(subgraph_root, 'test.pkl', test=True)

Processing...
Done!
Processing...
Done!
