In [2]:
from rdkit import Chem
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
import pandas as pd
import pickle
from rdkit.Chem.rdchem import HybridizationType, ChiralType

[12:23:19] Enabling RDKit 2019.09.3 jupyter extensions


In [3]:
# Excel workbook holding one sheet per token.
filepath = '../Dataset/Supplementary Material-2.xls'

# Token names; the position of each token is used as the sheet index when
# the workbook is read, so the order here must match the sheet order.
token_list = [
    '2E1', '2D6', '3A4',
    '2A6', '2C19', '2C9',
    '2B6', '1A2', '2C8',
]

In [4]:
# Merge all sheets into one mapping:
#     train_data[name] = [smiles, {token: target, ...}]
# Sheet number `idx` of the workbook is attributed to token_list[idx].
train_data = {}
for idx, token in enumerate(token_list):
    df = pd.read_excel(io=filepath, sheet_name=idx)
    # Each row is expected to unpack into exactly (name, smiles, target).
    for name, smiles, target in df.values:
        # First sighting of a name stores its SMILES; later sheets only add
        # their token's target to the existing per-token dict.
        train_data.setdefault(name, [smiles, {}])[-1][token] = target
print(len(train_data))

1811


In [5]:
# Collapse the per-token targets of every compound into a single label:
# True as soon as any token's target is True (including compounds whose
# tokens disagree), False otherwise.
raw_data = [
    [record[0], True in record[-1].values()]
    for record in train_data.values()
]
print(len(raw_data))

1811


In [6]:
# Class balance: records whose label is truthy vs. all the others.
true_count = sum(1 for record in raw_data if record[-1])
false_count = len(raw_data) - true_count
print(f'truecount is {true_count} false count is {false_count}')

truecount is 700 false count is 1111


In [7]:
# `random` is imported here rather than in the notebook's top import cell.
import random
# Seed the global generator so the shuffle order is reproducible when the
# notebook is run top to bottom on a fresh kernel.
random.seed(42)
# NOTE: this shuffles `raw_data` in place. Re-running this cell on its own
# reshuffles the already-shuffled list and gives a different order than a
# fresh run, so re-run the cell that builds `raw_data` first.
random.shuffle(raw_data)

In [8]:
# 80/20 split of the shuffled records; the boundary is computed once so both
# slices are guaranteed to use the same index.
split_index = int(len(raw_data) * 0.8)
train_data = raw_data[:split_index]
test_data = raw_data[split_index:]

# Persist each split as a pickle. The target folder is created on demand so
# the cell also works on a fresh checkout, and every file is written inside a
# `with` block so the handle is flushed and closed deterministically.
split_files = {
    '../Dataset/cypstrate_all/raw/train.pkl': train_data,
    '../Dataset/cypstrate_all/raw/test.pkl': test_data,
}
for split_path, split_records in split_files.items():
    os.makedirs(os.path.dirname(split_path), exist_ok=True)
    with open(split_path, 'wb') as handle:
        pickle.dump(split_records, handle)

In [9]:
class Cypstrateall(Dataset):
    """Molecular graph dataset built from a pickled list of ``[smiles, label]`` records.

    Every record is parsed with RDKit, converted into a
    ``torch_geometric.data.Data`` graph (node features, edge index, edge
    features, binary label) and saved as its own ``.pt`` file inside
    ``processed_dir``.

    Args:
        root: Dataset root folder handed to the base class.
        filename: Name of the pickle file; it is resolved through
            ``raw_paths`` of the base class.
        test: When true, processed files are named ``data_test_{i}.pt``
            instead of ``data_{i}.pt``.
        transform, pre_transform, pre_filter: Forwarded unchanged to the
            base class.
    """

    # One-hot vocabularies. The order of each tuple fixes the column order of
    # the encoding; values that are not listed fall back to the stated slot.
    _ATOM_SYMBOLS = ('C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other')
    _ATOM_FALLBACK = 9  # the 'other' slot
    _CHIRAL_TAGS = (
        ChiralType.CHI_TETRAHEDRAL_CCW,
        ChiralType.CHI_TETRAHEDRAL_CW,
        ChiralType.CHI_OTHER,
        ChiralType.CHI_UNSPECIFIED,
    )
    _CHIRAL_FALLBACK = 2  # the CHI_OTHER slot
    _HYBRIDIZATIONS = (
        HybridizationType.S,
        HybridizationType.SP,
        HybridizationType.SP2,
        HybridizationType.SP3,
        HybridizationType.SP3D,
        HybridizationType.SP3D2,
        HybridizationType.OTHER,
        HybridizationType.UNSPECIFIED,
    )
    _HYBRIDIZATION_FALLBACK = 6  # the OTHER slot
    # 4 scalar columns: implicit valence, formal charge, radicals, aromatic.
    _NUM_NODE_FEATURES = (
        len(_ATOM_SYMBOLS) + 4 + len(_CHIRAL_TAGS) + len(_HYBRIDIZATIONS)
    )
    _NUM_EDGE_FEATURES = 2

    def __init__(self, root, filename, test=False,transform=None, pre_transform=None, pre_filter=None):
        self.filename = filename
        self.test = test
        # Filled lazily by _load_raws(); must exist before the base class
        # constructor runs because it queries processed_file_names.
        self.raws = None
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Name of the raw pickle file expected by the base class."""
        return self.filename

    @property
    def processed_file_names(self):
        """One processed file name per raw record.

        The names must match exactly what ``process`` writes (including the
        ``.pt`` suffix for both splits); otherwise the base class cannot
        recognise already-processed data.
        """
        return [self._processed_name(i) for i in range(len(self._load_raws()))]

    def download(self):
        """Nothing to download: the raw pickle is produced locally."""
        pass

    def process(self):
        """Convert every raw record into a ``Data`` graph and save it to disk.

        Raises:
            ValueError: If RDKit cannot parse a record's SMILES string.
        """
        for idx, record in enumerate(self._load_raws()):
            mol = Chem.MolFromSmiles(record[0])
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None;
                # fail with the offending record instead of an opaque
                # AttributeError further down.
                raise ValueError(
                    f'RDKit could not parse SMILES at index {idx}: {record[0]!r}'
                )
            graph = Data(
                x=self._get_node_features(mol),
                edge_index=self._get_adjacency_info(mol),
                edge_attr=self._get_edge_features(mol),
                y=self._get_labels(record[1]),
            )
            torch.save(
                graph,
                os.path.join(self.processed_dir, self._processed_name(idx)),
            )

    def _load_raws(self):
        """Return the raw records, reading the pickle only on first use.

        NOTE: pickle can execute arbitrary code while loading; only point
        this dataset at pickle files you created yourself.
        """
        if getattr(self, 'raws', None) is None:
            with open(self.raw_paths[0], 'rb') as handle:
                self.raws = pickle.load(handle)
        return self.raws

    def _processed_name(self, idx):
        """File name of the processed graph with index ``idx`` for this split."""
        prefix = 'data_test' if self.test else 'data'
        return f'{prefix}_{idx}.pt'

    @staticmethod
    def _one_hot(value, choices, fallback):
        """One-hot encode ``value`` over ``choices``.

        Values that are not in ``choices`` activate position ``fallback`` so
        that every atom always gets a well-defined encoding.
        """
        encoding = [0] * len(choices)
        position = choices.index(value) if value in choices else fallback
        encoding[position] = 1
        return encoding

    def _get_node_features(self, mol):
        """Build the per-atom feature matrix.

        Columns, in order: 10-way element one-hot, implicit valence, formal
        charge, number of radical electrons, aromatic flag (0/1), 4-way
        chirality one-hot, 8-way hybridization one-hot.

        Returns:
            Float tensor of shape ``[num_atoms, 26]``.
        """
        rows = []
        for atom in mol.GetAtoms():
            row = self._one_hot(
                atom.GetSymbol(), self._ATOM_SYMBOLS, self._ATOM_FALLBACK
            )
            row.append(atom.GetImplicitValence())
            row.append(atom.GetFormalCharge())
            row.append(atom.GetNumRadicalElectrons())
            row.append(1 if atom.GetIsAromatic() else 0)
            row.extend(self._one_hot(
                atom.GetChiralTag(), self._CHIRAL_TAGS, self._CHIRAL_FALLBACK
            ))
            row.extend(self._one_hot(
                atom.GetHybridization(),
                self._HYBRIDIZATIONS,
                self._HYBRIDIZATION_FALLBACK,
            ))
            rows.append(row)

        # The explicit reshape keeps the matrix 2-D even without atoms.
        matrix = np.asarray(rows, dtype=np.float32)
        matrix = matrix.reshape(-1, self._NUM_NODE_FEATURES)
        return torch.tensor(matrix, dtype=torch.float)

    def _get_edge_features(self, mol):
        """Build the per-edge feature matrix.

        Each bond contributes two identical rows (one per direction), in the
        same bond order used by ``_get_adjacency_info``. Columns: bond type as
        a double, ring membership (0/1).

        Returns:
            Float tensor of shape ``[2 * num_bonds, 2]``.
        """
        rows = []
        for bond in mol.GetBonds():
            features = [bond.GetBondTypeAsDouble(), float(bond.IsInRing())]
            rows += [features, features]

        # The explicit reshape keeps the matrix 2-D for bond-less molecules.
        matrix = np.asarray(rows, dtype=np.float32)
        matrix = matrix.reshape(-1, self._NUM_EDGE_FEATURES)
        return torch.tensor(matrix, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Build the edge index in COO layout.

        The bonds are walked explicitly (instead of using an adjacency
        matrix) so that the edge order matches ``_get_edge_features``; each
        bond is emitted once per direction.

        Returns:
            Long tensor of shape ``[2, 2 * num_bonds]``.
        """
        pairs = []
        for bond in mol.GetBonds():
            begin = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            pairs += [[begin, end], [end, begin]]

        index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)
        return index.t().contiguous()

    def _get_labels(self, target):
        """Map a truthy/falsy target to a scalar int64 tensor holding 1 or 0."""
        return torch.tensor(1 if target else 0, dtype=torch.int64)

    def len(self):
        """Number of graphs, i.e. number of raw records."""
        return len(self._load_raws())

    def get(self, idx):
        """Load the processed graph with index ``idx`` from ``processed_dir``."""
        return torch.load(
            os.path.join(self.processed_dir, self._processed_name(idx))
        )

In [10]:
# Instantiate both splits from the pickles written earlier in the notebook;
# the `test` flag selects the `data_test_*` naming for processed files.
train_dataset = Cypstrateall(
    root='../Dataset/cypstrate_all/',
    filename='train.pkl',
)
test_dataset = Cypstrateall(
    root='../Dataset/cypstrate_all/',
    filename='test.pkl',
    test=True,
)

Processing...
Done!
