# In [2]:
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPaths
from rdkit.Chem import rdmolops, rdmolfiles

import dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoo

from dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraph
# from dgllife.utils import mol_to_bigraph

from dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizer

from dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_dataset

from functools import partial
from sklearn.metrics import roc_auc_score
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error as mse

import pandas as pd
# Hide every CUDA device from this process.
# NOTE(review): presumably this is meant to force CPU execution (all tensors
# and models below are moved to 'cpu') -- confirm before enabling a GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

# Run identifier; data_train_test() embeds it in its result directory paths.
seed = 1
def data_train_test(data, n_epochs=100):
    """Train an AttentiveFP regressor and score it on the held-out SDF file.

    Molecules are read from ``./data/<data>/<data>_training.sdf`` and
    ``./data/<data>/<data>_prediction.sdf``; the regression target is the
    ``Tox`` property of every record.  After each epoch the model is
    evaluated on the prediction file.  Whenever the test R^2 improves on the
    best value seen so far in this run, the predictions are written to
    ``./results/<data>/DL/<seed>/GNN/test/test_pred.csv`` (``seed`` is the
    module-level variable).

    Args:
        data: Dataset name used to build the input and output paths.
        n_epochs: Number of training epochs.

    Returns:
        None.  Metrics are printed, curves are plotted and the best test
        predictions are written to disk.
    """
    device = torch.device('cpu')

    def chirality(atom):
        """Return ``[is_R, is_S, chirality_possible]`` for an RDKit atom."""
        possible = [atom.HasProp('_ChiralityPossible')]
        # Atoms without a CIP label get an all-False R/S encoding.  Checking
        # for the property replaces the former bare ``except:``, which also
        # hid unrelated errors.
        if atom.HasProp('_CIPCode'):
            return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + possible
        return [False, False] + possible

    def collate_molgraphs(samples):
        """Batch ``(smiles, graph, label[, mask])`` tuples for the DataLoader.

        Returns ``(smiles, batched_graph, labels, masks)``; when the samples
        carry no mask an all-ones mask with the shape of ``labels`` is used.
        """
        assert len(samples[0]) in [3, 4], \
            'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(samples[0]))
        if len(samples[0]) == 3:
            smiles, graphs, labels = map(list, zip(*samples))
            masks = None
        else:
            smiles, graphs, labels, masks = map(list, zip(*samples))

        bg = dgl.batch(graphs)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        labels = torch.stack(labels, dim=0)

        if masks is None:
            masks = torch.ones(labels.shape)
        else:
            masks = torch.stack(masks, dim=0)
        return smiles, bg, labels, masks

    def train_one_epoch(model, data_loader, loss_criterion, optimizer):
        """Run one optimisation pass over ``data_loader``; return the mean batch loss."""
        model.train()
        losses = []
        for _, bg, labels, masks in data_loader:
            # NOTE(review): the return value of ``bg.to`` is ignored, exactly
            # as in the original code.  Everything is on the CPU here, so this
            # should be a no-op -- confirm before moving to a GPU.
            bg.to(device)
            labels = labels.to(device)
            masks = masks.to(device)

            prediction = model(bg, bg.ndata['hv'], bg.edata['he'])
            loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return np.mean(losses)

    def predict(model, data_loader):
        """Return the stacked model outputs for every batch of ``data_loader``."""
        model.eval()
        batches = []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for _, bg, _, _ in data_loader:
                bg.to(device)
                batches.append(model(bg, bg.ndata['hv'], bg.edata['he']).cpu().numpy())
        return np.vstack(batches)

    atom_featurizer = BaseAtomFeaturizer(
        {'hv': ConcatFeaturizer([
            partial(atom_type_one_hot,
                    allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl',
                                   'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge,
            atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # A placeholder for aromatic information
            atom_total_num_H_one_hot,
            chirality,
        ])})
    # Every bond gets the same all-zero feature vector of length 10.
    bond_featurizer = BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })

    def load_sdf(path):
        """Read an SDF file; return ``(smiles, graphs, targets)``.

        ``targets`` is a float tensor of shape (n, 1) built from the ``Tox``
        property.  The supplier is iterated once per pass, as before.
        """
        mols = Chem.SDMolSupplier(path)
        smiles = [Chem.MolToSmiles(m) for m in mols]
        targets = torch.tensor([float(mol.GetProp('Tox')) for mol in mols]).reshape(-1, 1)
        graphs = [mol_to_bigraph(mol,
                                 node_featurizer=atom_featurizer,
                                 edge_featurizer=bond_featurizer) for mol in mols]
        return smiles, graphs, targets

    os.makedirs("./results/{}/DL/{}/GNN/test/".format(data, seed), exist_ok=True)
    os.makedirs("./results/{}/DL/{}/GNN/train/".format(data, seed), exist_ok=True)
    test_pred_path = "./results/{}/DL/{}/GNN/test/test_pred.csv".format(data, seed)

    train_smi, train_graph, train_sol = load_sdf(
        "./data/{}/{}_training.sdf".format(data, data))
    test_smi, test_graph, test_sol = load_sdf(
        "./data/{}/{}_prediction.sdf".format(data, data))
    # The test loader is not shuffled, so predictions line up with this array.
    test_true = test_sol.numpy()

    model = model_zoo.chem.AttentiveFP(node_feat_size=39,
                                       edge_feat_size=10,
                                       num_layers=2,
                                       num_timesteps=2,
                                       graph_feat_size=200,
                                       output_size=1,
                                       dropout=0.2)
    model = model.to(device)

    train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)),
                              batch_size=16, collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)),
                             batch_size=16, collate_fn=collate_molgraphs)

    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0))

    epochs = []
    scores = []
    r2_all = []
    mse_all = []
    # The best score must live outside the per-epoch code.  It used to be
    # reset to 0 every epoch, so any epoch with R^2 > 0 overwrote the file.
    best_r2 = float('-inf')
    for epoch in range(n_epochs):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)

        res = predict(model, test_loader)
        r2 = r2_score(test_true, res)
        Mse = mse(test_true, res)
        r2_all.append(r2)
        mse_all.append(Mse)
        if r2 > best_r2:
            best_r2 = r2
            pd.DataFrame([res.ravel(), test_true.ravel()]).T.to_csv(test_pred_path, index=0)
        print('epoch {:d}/{:d}, training mse: {:.4f} test r_2: {:.4f} mse: {:.4f}'.format(
            epoch + 1, n_epochs, train_loss, r2, Mse))

        epochs.append(epoch)
        scores.append(train_loss)

    plt.plot(epochs, scores)
    # plt.savefig('score.png')

    # Validate the final model on the test file.
    res = predict(model, test_loader)
    print(r2_score(test_true, res))

    plt.figure()
    plt.plot(range(len(r2_all)), r2_all)
    plt.plot(range(len(r2_all)), mse_all)
    plt.show()  # was ``plt.show`` without the call, which did nothing
    if r2_all:  # max()/min() would raise on an empty history (n_epochs == 0)
        print(max(r2_all), min(mse_all))
        print(r2_all.index(max(r2_all)), mse_all.index(min(mse_all)))
    
# Train and evaluate on the "IGC50" dataset for 100 epochs.
data_train_test("IGC50", n_epochs=100)

# In [1]:
import pandas as pd
def data_train_kf(data, set_cv=5, n_epochs=100):
    """Run k-fold cross-validation of AttentiveFP on a dataset's training SDF.

    Molecules are read from ``./data/<data>/<data>_training.sdf`` with the
    ``Tox`` property as regression target.  The shuffled samples are cut into
    ``set_cv`` contiguous folds; the last fold also takes the remainder.  A
    fresh model is trained per fold and evaluated on the held-out fold after
    every epoch.  For each fold the predictions of its best epoch (highest
    validation R^2) are kept, and the combined out-of-fold predictions are
    written to ``./results/<data>/DL/<seed>/GNN/train/kf_pred_all.csv``.  One
    loss-curve image ``<fold>.png`` is saved per fold in the working
    directory.

    Args:
        data: Dataset name used to build the input and output paths.
        set_cv: Number of folds (at least 2, at most the number of samples).
        n_epochs: Number of training epochs per fold.

    Returns:
        None.  Metrics are printed and results are written to disk.

    Raises:
        ValueError: If ``set_cv`` is smaller than 2 or larger than the number
            of training molecules.
    """
    # Must be bound before its first use.  It was previously assigned halfway
    # through the function, which made it a local read before assignment and
    # raised UnboundLocalError at the first os.makedirs call.
    seed = 1
    device = torch.device('cpu')

    def chirality(atom):
        """Return ``[is_R, is_S, chirality_possible]`` for an RDKit atom."""
        possible = [atom.HasProp('_ChiralityPossible')]
        # Atoms without a CIP label get an all-False R/S encoding.  Checking
        # for the property replaces the former bare ``except:``.
        if atom.HasProp('_CIPCode'):
            return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + possible
        return [False, False] + possible

    def collate_molgraphs(samples):
        """Batch ``(smiles, graph, label[, mask])`` tuples for the DataLoader.

        Returns ``(smiles, batched_graph, labels, masks)``; when the samples
        carry no mask an all-ones mask with the shape of ``labels`` is used.
        """
        assert len(samples[0]) in [3, 4], \
            'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(samples[0]))
        if len(samples[0]) == 3:
            smiles, graphs, labels = map(list, zip(*samples))
            masks = None
        else:
            smiles, graphs, labels, masks = map(list, zip(*samples))

        bg = dgl.batch(graphs)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        labels = torch.stack(labels, dim=0)

        if masks is None:
            masks = torch.ones(labels.shape)
        else:
            masks = torch.stack(masks, dim=0)
        return smiles, bg, labels, masks

    def train_one_epoch(model, data_loader, loss_criterion, optimizer):
        """Run one optimisation pass over ``data_loader``; return the mean batch loss."""
        model.train()
        losses = []
        for _, bg, labels, masks in data_loader:
            # NOTE(review): the return value of ``bg.to`` is ignored, exactly
            # as in the original code.  Everything is on the CPU here, so this
            # should be a no-op -- confirm before moving to a GPU.
            bg.to(device)
            labels = labels.to(device)
            masks = masks.to(device)

            prediction = model(bg, bg.ndata['hv'], bg.edata['he'])
            loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return np.mean(losses)

    def predict(model, data_loader):
        """Return the stacked model outputs for every batch of ``data_loader``."""
        model.eval()
        batches = []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for _, bg, _, _ in data_loader:
                bg.to(device)
                batches.append(model(bg, bg.ndata['hv'], bg.edata['he']).cpu().numpy())
        return np.vstack(batches)

    # Atom and bond featurizers.
    atom_featurizer = BaseAtomFeaturizer(
        {'hv': ConcatFeaturizer([
            partial(atom_type_one_hot,
                    allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl',
                                   'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge,
            atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],
            atom_total_num_H_one_hot,
            chirality,
        ])})
    # Every bond gets the same all-zero feature vector of length 10.
    bond_featurizer = BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })

    os.makedirs("./results/{}/DL/{}/GNN/test/".format(data, seed), exist_ok=True)
    os.makedirs("./results/{}/DL/{}/GNN/train/".format(data, seed), exist_ok=True)

    # Convert the training molecules to graphs.  The prediction SDF was also
    # loaded and featurized here before, but cross-validation never used it.
    train_mols = Chem.SDMolSupplier("./data/{}/{}_training.sdf".format(data, data))
    train_smi = [Chem.MolToSmiles(m) for m in train_mols]
    train_sol = torch.tensor([float(mol.GetProp('Tox')) for mol in train_mols]).reshape(-1, 1)
    train_graph = [mol_to_bigraph(mol,
                                  node_featurizer=atom_featurizer,
                                  edge_featurizer=bond_featurizer) for mol in train_mols]

    n_samples = len(train_sol)
    if set_cv < 2 or set_cv > n_samples:
        raise ValueError(
            'set_cv must be between 2 and the number of samples ({:d}), got {}'.format(
                n_samples, set_cv))

    base_indices = np.arange(0, n_samples)
    # The index array is shuffled twice with the same seed.  This is redundant,
    # but kept so the folds match the ones this code has always produced.
    np.random.seed(seed)
    np.random.shuffle(base_indices)
    np.random.seed(seed)
    np.random.shuffle(base_indices)
    step = int(n_samples / set_cv)

    def make_loader(indices):
        """Build an unshuffled DataLoader over the given sample indices."""
        # Plain list indexing avoids wrapping the graph objects in a NumPy array.
        samples = [(train_smi[int(i)], train_graph[int(i)], train_sol[int(i)])
                   for i in indices]
        return DataLoader(dataset=samples, batch_size=122, collate_fn=collate_molgraphs)

    # Column 0 holds the out-of-fold prediction, column 1 the true target.
    pred_kfold = pd.DataFrame(index=range(n_samples), columns=["cv1~5", "True"])
    for cv in range(set_cv):
        print("*" * 20, "Kflod", cv, "*" * 20)
        index = base_indices
        if cv < set_cv - 1:
            index_train = np.concatenate([index[:cv * step], index[(cv + 1) * step:]], axis=0)
            index_val = index[cv * step:(cv + 1) * step]
        else:
            # The last fold also takes the remainder of the integer division.
            index_train = index[0:cv * step]
            index_val = index[cv * step:]

        train_loader = make_loader(index_train)
        val_loader = make_loader(index_val)
        # The validation loader is not shuffled, so predictions line up with
        # these labels.
        val_true = train_sol[index_val].numpy()

        # A fresh model and optimizer for every fold.
        model = model_zoo.chem.AttentiveFP(node_feat_size=39,
                                           edge_feat_size=10,
                                           num_layers=2,
                                           num_timesteps=2,
                                           graph_feat_size=200,
                                           output_size=1,
                                           dropout=0.2)
        model = model.to(device)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0))

        epochs = []
        scores = []
        r2_all = []
        mse_all = []
        # Tracked per fold.  It used to be reset to 0 every epoch, so the
        # stored predictions were simply those of the last epoch with R^2 > 0.
        best_r2 = float('-inf')
        for epoch in range(n_epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)

            res = predict(model, val_loader)
            r2 = r2_score(val_true, res)
            Mse = mse(val_true, res)
            if r2 > best_r2:
                best_r2 = r2
                pred_kfold.iloc[index_val, 0] = np.ravel(res)
                pred_kfold.iloc[index_val, 1] = np.ravel(val_true)
            r2_all.append(r2)
            mse_all.append(Mse)
            print('epoch {:d}/{:d}, cv: {} train mse: {:.4f} test r_2: {:.4f} mse: {:.4f}'.format(
                epoch + 1, n_epochs, cv, train_loss, r2, Mse))

            epochs.append(epoch)
            scores.append(train_loss)

        # A new figure per fold, so <cv>.png shows only this fold's curve
        # rather than the curves of all earlier folds as well.
        plt.figure()
        plt.plot(epochs, scores)
        plt.savefig('{}.png'.format(cv))

    # r2_score expects (y_true, y_pred); the two columns used to be swapped.
    print(r2_score(pred_kfold.iloc[:, 1].astype(float), pred_kfold.iloc[:, 0].astype(float)))
    pred_kfold.to_csv("./results/{}/DL/{}/GNN/train/kf_pred_all.csv".format(data, seed), index=0)
    
# 5-fold cross-validation on the "IGC50" training set, 100 epochs per fold.
data_train_kf("IGC50", set_cv=5, n_epochs=100)