In [1]:
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from rdkit import Chem
import pickle
from rdkit.Chem.rdchem import HybridizationType, ChiralType
from torch_geometric.utils import from_networkx

[09:06:50] Enabling RDKit 2019.09.3 jupyter extensions


In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
filepath = '../Dataset/merged.sdf'
mols = Chem.SDMolSupplier(filepath)

In [4]:
def mol2y(mol):
    _y = []
    som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
           'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
           'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
          ]
    result = []
    for k in som:
        try:
            _res = mol.GetProp(k)
            if ' ' in _res:
                res = _res.split(' ')
                for s in res:
                    result.append(int(s))
                # res = [int(temp) for temp in res]
            else:
                # res = [int(_res)]
                result.append(int(_res))
        except:
            pass

    for data in result:
        _y.append(data)
    _y = list(set(_y))

    y = np.zeros(len(mol.GetAtoms()))
    for i in _y:
        y[i-1] = 1
    return y

In [5]:
def mol2graph(mol):
    target = mol2y(mol)
    g = nx.Graph()
    identity = {
        'C':[1,0,0,0,0,0,0,0,0,0],
        'N':[0,1,0,0,0,0,0,0,0,0],
        'O':[0,0,1,0,0,0,0,0,0,0],
        'F':[0,0,0,1,0,0,0,0,0,0],
        'P':[0,0,0,0,1,0,0,0,0,0],
        'S':[0,0,0,0,0,1,0,0,0,0],
        'Cl':[0,0,0,0,0,0,1,0,0,0],
        'Br':[0,0,0,0,0,0,0,1,0,0],
        'I':[0,0,0,0,0,0,0,0,1,0],
        'other':[0,0,0,0,0,0,0,0,0,1],
    }
    for atom in mol.GetAtoms():
        node_feats = []
        # atom number
        idx = atom.GetIdx()
        # atom type one-hot 10
        node_feats.extend(identity.get(atom.GetSymbol(),[0,0,0,0,0,0,0,0,0,1]))
        # implicit valence
        node_feats.append(atom.GetImplicitValence())
        # formal charge
        node_feats.append(atom.GetFormalCharge())
        # radical electrons
        node_feats.append(atom.GetNumRadicalElectrons())
            
        # aromatic 0 or 1
        if atom.GetIsAromatic():
            node_feats.append(1)
        else:
            node_feats.append(0)

        # chirality
        chirality = atom.GetChiralTag()
        if chirality == ChiralType.CHI_TETRAHEDRAL_CCW: temp = [1, 0, 0, 0]
        if chirality == ChiralType.CHI_TETRAHEDRAL_CW: temp = [0, 1, 0, 0]
        if chirality == ChiralType.CHI_OTHER: temp = [0, 0, 1, 0]
        if chirality == ChiralType.CHI_UNSPECIFIED: temp = [0, 0, 0, 1]
        node_feats.extend(temp)
        # hybridization
        hybridization = atom.GetHybridization()
        if hybridization == HybridizationType.S: tmp = [1, 0, 0, 0, 0, 0, 0, 0]
        if hybridization == HybridizationType.SP: tmp = [0, 1, 0, 0, 0, 0, 0, 0]
        if hybridization == HybridizationType.SP2: tmp = [0, 0, 1, 0, 0, 0, 0, 0]
        if hybridization == HybridizationType.SP3: tmp = [0, 0, 0, 1, 0, 0, 0, 0]
        if hybridization == HybridizationType.SP3D: tmp = [0, 0, 0, 0, 1, 0, 0, 0]
        if hybridization == HybridizationType.SP3D2: tmp = [0, 0, 0, 0, 0, 1, 0, 0]
        if hybridization == HybridizationType.OTHER: tmp = [0, 0, 0, 0, 0, 0, 1, 0]
        if hybridization == HybridizationType.UNSPECIFIED: tmp = [0, 0, 0, 0, 0, 0, 0, 1]
        node_feats.extend(tmp)
        node_feats = np.asarray(node_feats)
        g.add_node(idx, x=node_feats, y=int(target[idx]))

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            edge_feats = np.asarray(edge_feats)
            g.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), edge_attr = edge_feats)

    return g

In [6]:
def get_neighbors_aslist(g, node, depth=1):
    output = {}
    output[0] = [node]
    layers = dict(nx.bfs_successors(g, source=node, depth_limit=depth))
    nodes = [node]
    for i in range(1, depth+1):
        output[i] = []
        for x in nodes:
            output[i].extend(layers.get(x, []))
        nodes = output[i]
    res = []
    for _, v in output.items():
        res.extend(v)
    return res

In [7]:
dataset = []
for mol in mols:
    g = mol2graph(mol)
    dataset.append(g)

In [8]:
len(dataset)

680

In [9]:
# split training set， test set， validation set
import random
random.seed(42)
random.shuffle(dataset)

In [10]:
training_set = dataset[:int(len(dataset) * 0.8)]
test_set = dataset[int(len(dataset) * 0.8):]

In [11]:
tr_set = training_set[:int(len(training_set) * 0.8)]
val_set = training_set[int(len(training_set) * 0.8):]

In [12]:
len(tr_set), len(val_set), len(test_set)

(435, 109, 136)

In [13]:
_tr_set = []
for g in tr_set:
    tmp = []
    for node in g.nodes(data=True):
        out = get_neighbors_aslist(g, node[0], depth=2)
        # subgraph
        subgraph = g.subgraph(out)
        # generate new y
        y = []
        for n in subgraph.nodes(data=True):
            y.append(n[-1]['y'])
        tmp.append((subgraph, np.array(y)))
    _tr_set.append(tmp)

In [14]:
_val_set = []
for g in val_set:
    tmp = []
    for node in g.nodes(data=True):
        out = get_neighbors_aslist(g, node[0], depth=2)
        # subgraph
        subgraph = g.subgraph(out)
        # generate new y
        y = []
        for n in subgraph.nodes(data=True):
            y.append(n[-1]['y'])
        tmp.append((subgraph, np.array(y)))
    _val_set.append(tmp)

In [15]:
_test_set = []
for g in test_set:
    tmp = []
    for node in g.nodes(data=True):
        out = get_neighbors_aslist(g, node[0], depth=2)
        # subgraph
        subgraph = g.subgraph(out)
        # generate new y
        y = []
        for n in subgraph.nodes(data=True):
            y.append(n[-1]['y'])
        tmp.append((subgraph, np.array(y)))
    _test_set.append(tmp)

In [16]:
len(_tr_set), len(_val_set), len(_test_set)

(435, 109, 136)

In [17]:
import torch.nn as nn
from torch_geometric.nn import SAGEConv
from torch.utils.data import random_split
from torch_geometric.utils import from_networkx
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, precision_score, f1_score, recall_score, jaccard_score, matthews_corrcoef

In [18]:
from torch.utils.tensorboard import SummaryWriter
tr_writer = SummaryWriter("./tensorboard/subgraph2/train")
val_writer = SummaryWriter("./tensorboard/subgraph2/val")

In [19]:
class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        num_classses = 2

        conv_hidden = args['conv_hidden']
        cls_hidden = args['cls_hidden']
        self.n_layers = args['n_layers']
        # cls_drop = ['cls_drop']

        self.conv_layers = nn.ModuleList([])

        self.conv1 = SAGEConv(26, conv_hidden)

        for i in range(self.n_layers):
            self.conv_layers.append(
                SAGEConv(conv_hidden, conv_hidden)
            )

        self.linear1 = nn.Linear(conv_hidden, cls_hidden)
        self.linear2 = nn.Linear(cls_hidden, num_classses)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(p=0.5)

    
    def forward(self, mol):

        res = self.conv1(mol.x, mol.edge_index)
        for i in range(self.n_layers):
            res = self.conv_layers[i](res, mol.edge_index)

        res = self.linear1(res)
        res = self.relu(res)
        res = self.drop1(res)
        res = self.linear2(res)

        return res

In [20]:
import random
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
def seed_torch(seed=42):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化，使得实验可复现
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

In [21]:
# evaluation
def top2(output, label):
    sf = nn.Softmax(dim=1)
    preds = sf(output)
    preds = preds[:, 1]
    _, indices = torch.topk(preds, 2)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)  
    for li in pos_index:
        if li in indices:
            return True
    return False

def MCC(output, label):
    print(output, label)
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

def metrics(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    mcc = up / down
    selectivity = tn / (tn + fp)
    recall = tp / (tp + fn)
    g_mean = (selectivity * recall) ** 0.5
    balancedAccuracy = (recall + selectivity) / 2
    return mcc, selectivity, recall, g_mean, balancedAccuracy

In [22]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_acc = []
    all_auc = []
    all_mcc = []
    subgraph_num = 0
    for mol in training_set:
        sub_mcc = []
        sub_auc = []
        sub_acc = []
        for sub_mol, target in mol:
            subgraph_num += 1
            sub_mol = from_networkx(sub_mol)
            sub_mol = sub_mol.to(device)
            sub_mol.x = sub_mol.x.to(torch.float32)
            target = torch.tensor(target, dtype=torch.int64).to(device)
            optimizer.zero_grad()
            output= model(sub_mol)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # tracking
            sub_acc.append(accuracy_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            try:
                # 验证集会出现全是0，这样子算不了auc，所以这种情况，直接跳过
                sub_auc.append(roc_auc_score(target.cpu().detach().numpy(), sf(output)[:, 1].cpu().detach().numpy()))
            except ValueError:
                pass
            sub_mcc.append(matthews_corrcoef(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))

        all_acc.append(np.mean(sub_acc))
        all_auc.append(np.mean(sub_auc))
        all_mcc.append(np.mean(sub_mcc))

    tr_writer.add_scalar('Ave Loss', total_loss / subgraph_num, epoch)
    tr_writer.add_scalar('ACC', np.mean(all_acc), epoch)
    tr_writer.add_scalar('AUC', np.mean(all_auc), epoch)
    tr_writer.add_scalar('MCC', np.mean(all_mcc), epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / subgraph_num} ACC: {np.mean(all_acc)}  AUC: {np.mean(all_auc)} MCC: {np.mean(all_mcc)}')

In [23]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_acc = []
    all_auc = []
    all_mcc = []
    subgraph_num = 0
    for mol in val_set:
        sub_mcc = []
        sub_auc = []
        sub_acc = []
        for sub_mol, target in mol:
            subgraph_num += 1
            sub_mol = from_networkx(sub_mol)
            sub_mol = sub_mol.to(device)
            sub_mol.x = sub_mol.x.to(torch.float32)
            target = torch.tensor(target, dtype=torch.int64).to(device)
            optimizer.zero_grad()
            output= model(sub_mol)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # tracking
            sub_acc.append(accuracy_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            try:
                # 验证集会出现全是0，这样子算不了auc，所以这种情况，直接跳过
                sub_auc.append(roc_auc_score(target.cpu().detach().numpy(), sf(output)[:, 1].cpu().detach().numpy()))
            except ValueError:
                pass
            sub_mcc.append(matthews_corrcoef(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            
        all_acc.append(np.mean(sub_acc))
        all_auc.append(np.mean(sub_auc))
        all_mcc.append(np.mean(sub_mcc))

    val_writer.add_scalar('Ave Loss', total_loss / subgraph_num, epoch)
    val_writer.add_scalar('ACC', np.mean(all_acc), epoch)
    val_writer.add_scalar('AUC', np.mean(all_auc), epoch)
    val_writer.add_scalar('MCC', np.mean(all_mcc), epoch)
    print(f'validation Epoch: {epoch}, Ave Loss: {total_loss / subgraph_num} ACC: {np.mean(all_acc)}  AUC: {np.mean(all_auc)} MCC: {np.mean(all_mcc)}')
    return np.mean(all_acc)

In [24]:
def main(args):
    seed_torch(args['seed'])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args['seed'])

    model = Model(args).to(device)
    print(model)
    weights = torch.tensor([1, args['pos_weight']], dtype=torch.float32).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    max_acc = 0
    for epoch in range(1, args['epoch'] + 1):
        train(args, model, device, _tr_set, optimizer, loss_fn, epoch)
        acc = val(args, model, device, _val_set, optimizer, loss_fn, epoch)
        random.shuffle(_tr_set)
        random.shuffle(_val_set)
        scheduler.step()
        if acc > max_acc:
            max_acc = acc
            print('Saving model (epoch = {:4d}, max_acc = {:.4f})'
                .format(epoch, max_acc))
            torch.save(model.state_dict(), args['save_path'])

In [22]:
args = {
    'lr': 0.02,
    'epoch': 400,
    'seed': 42,
    'save_path': './model/subgraph2hop',
    'pos_weight': 3,
    'conv_hidden': 1024, 
    'cls_hidden': 1024,
    'n_layers': 3
}

In [26]:
main(args)

Model(
  (conv_layers): ModuleList(
    (0): SAGEConv(1024, 1024)
    (1): SAGEConv(1024, 1024)
    (2): SAGEConv(1024, 1024)
  )
  (conv1): SAGEConv(26, 1024)
  (linear1): Linear(in_features=1024, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=2, bias=True)
  (relu): ReLU()
  (drop1): Dropout(p=0.5, inplace=False)
)
Train Epoch: 1, Ave Loss: 0.28298214054242976 ACC: 0.8980744626293523  AUC: nan MCC: 0.18903074702441905
Train Epoch: 1, Ave Loss: 0.23855979564792906 ACC: 0.9148924004477842  AUC: 0.9023518390235831 MCC: 0.21411756799963338
Saving model (epoch =    1, max_acc = 0.9149)
Train Epoch: 2, Ave Loss: 0.2711812774227886 ACC: 0.9036124761948415  AUC: nan MCC: 0.21500841314639962
Train Epoch: 2, Ave Loss: 0.25346555039926527 ACC: 0.9134354784391149  AUC: 0.9093653469821021 MCC: 0.20526605699971087
Train Epoch: 3, Ave Loss: 0.27654132011152827 ACC: 0.9007571097371871  AUC: nan MCC: 0.2056977334227321
Train Epoch: 3, Ave Loss: 0.2501215306128879 ACC

KeyboardInterrupt: 

In [23]:
def test(model, device, test_set):
    model.eval()
    sf = nn.Softmax(dim=1)
    all_acc = []
    all_auc = []
    all_mcc = []
    all_recall = []
    all_precision = []
    all_jaccard = []
    subgraph_num = 0
    with torch.no_grad():
        for mol in test_set:
            sub_acc = []
            sub_auc = []
            sub_mcc = []
            sub_recall = []
            sub_precision = []
            sub_jaccard = []
            for sub_mol, target in mol:
                subgraph_num += 1
                sub_mol = from_networkx(sub_mol)
                sub_mol = sub_mol.to(device)
                sub_mol.x = sub_mol.x.to(torch.float32)
                target = torch.tensor(target, dtype=torch.int64).to(device)
                output= model(sub_mol)
            # squeeze
            # output = torch.squeeze(output)
            # tracking
            sub_acc.append(accuracy_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            try:
                # 验证集会出现全是0，这样子算不了auc，所以这种情况，直接跳过
                sub_auc.append(roc_auc_score(target.cpu().detach().numpy(), sf(output)[:, 1].cpu().detach().numpy()))
            except ValueError:
                pass
            sub_mcc.append(matthews_corrcoef(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            sub_precision.append(precision_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            sub_recall.append(recall_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))
            sub_jaccard.append(jaccard_score(target.cpu().detach().numpy(), np.argmax(output.cpu().detach().numpy(), axis=1)))

            print(f'sub graph average metrics: ACC : {np.mean(sub_acc)} AUC : {np.mean(sub_auc)}\
                    MCC : {np.mean(sub_mcc)} Recall : {np.mean(sub_recall)}\
                    Precision : {np.mean(sub_precision)} Jaccard : {np.mean(sub_jaccard)}')
        all_acc.append(np.mean(sub_acc))
        all_auc.append(np.mean(sub_auc))
        all_mcc.append(np.mean(sub_mcc))
        all_recall.append(np.mean(sub_recall))
        all_precision.append(np.mean(sub_precision))
        all_jaccard.append(np.mean(sub_jaccard))

    # all_pred = np.concatenate(all_pred).ravel()
    # all_pred_raw = np.concatenate(all_pred_raw).ravel()
    # all_labels = np.concatenate(all_labels).ravel()
    # mcc, selectivity, recall, g_mean, balancedAcc = metrics(all_pred, all_labels)
    print(f'ACC: {np.mean(all_acc)} AUC: {np.mean(all_auc)}\
        MCC: {np.mean(all_mcc)}  recall {np.mean(all_recall)} \
        precision score {np.mean(all_precision)} jaccard score {np.mean(all_jaccard)}')

In [24]:
model = Model(args).to("cuda")
model.load_state_dict(torch.load(args['save_path']))

<All keys matched successfully>

In [25]:
test(model, "cuda", _test_set)

sub graph average metrics: ACC : 1.0 AUC : 1.0                    MCC : 1.0 Recall : 1.0                    Precision : 1.0 Jaccard : 1.0
sub graph average metrics: ACC : 1.0 AUC : 1.0                    MCC : 1.0 Recall : 1.0                    Precision : 1.0 Jaccard : 1.0
sub graph average metrics: ACC : 0.75 AUC : 1.0                    MCC : 0.0 Recall : 0.0                    Precision : 0.0 Jaccard : 0.0
sub graph average metrics: ACC : 0.6 AUC : 0.8333333333333334                    MCC : 0.0 Recall : 0.0                    Precision : 0.0 Jaccard : 0.0
sub graph average metrics: ACC : 1.0 AUC : nan                    MCC : 0.0 Recall : 0.0                    Precision : 0.0 Jaccard : 0.0
sub graph average metrics: ACC : 1.0 AUC : 1.0                    MCC : 1.0 Recall : 1.0                    Precision : 1.0 Jaccard : 1.0
sub graph average metrics: ACC : 0.6666666666666666 AUC : 1.0                    MCC : 0.0 Recall : 0.0                    Precision : 0.0 Jaccard : 0.0
sub