In [None]:
# pip install --upgrade torch

In [None]:
# pip install dgl.sparse 

In [None]:
# pip install dgl==0.9

In [47]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import scipy.sparse as sp

import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split
from time import time
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.cluster import FeatureAgglomeration
from sklearn.decomposition import PCA
import networkx as nx

from functions_HGPLS import train, test

In [48]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("seyonec/SMILES_tokenized_PubChem_shard00_160k")

model = AutoModelForMaskedLM.from_pretrained("seyonec/SMILES_tokenized_PubChem_shard00_160k")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [49]:
atoms = ["", "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm"]
print(len(atoms))
ids = tokenizer(atoms).input_ids

toks = list()
for i in ids:
    if len(i) >= 3:
        toks.append(i[1])
    else:
        toks.append(i[0])

100


In [50]:
dataset, labels = dgl.load_graphs("../data/HIV_dgl_graphs")

In [51]:
labels = labels['glabel'].tolist()

In [52]:
# d = dataset[0].ndata["atomic"]
tok = torch.zeros((len(toks), 1))
for i in range(len(toks)):
    tok[i] = toks[i]

tok
transfo_atoms = model(tok.to(torch.long)).logits.reshape(100, -1)

In [53]:
pca = PCA(n_components=100, random_state=42)
res = pca.fit_transform(transfo_atoms.detach().numpy())

In [54]:
new_dataset = list()
for i in range(len(dataset)):
    #g, l = dataset[i]
    g = dataset[i]
    l = labels[i]
    #g.ndata["feature"] = torch.ones_like(dataset.feature[g.ndata["_ID"]])
    ids=list(g.ndata["atomic"].numpy().flatten().astype(int))
    g.ndata["feature"] = torch.tensor(res[ids]) #g.ndata["atomic"]
    g = dgl.add_self_loop(g)
    g = dgl.add_reverse_edges(g)
    new_dataset.append((g, l))

In [55]:
train_dataset, test_dataset = train_test_split(new_dataset, test_size=0.25, random_state=42)

train_dataloader = GraphDataLoader(train_dataset, batch_size=16, drop_last=False)
test_dataloader = GraphDataLoader(test_dataset, batch_size=16, drop_last=False)

In [77]:
from dgl.nn.pytorch import GATConv

class GAT(nn.Module):
    def __init__(self, in_feats, h_feats, n_classes):
        super(GAT, self).__init__()
        self.layer1 = GATConv(in_feats, h_feats, num_heads=4)
        self.layer2 = GATConv(4*h_feats, h_feats, num_heads=4)
        self.layer3 = GATConv(4*h_feats, h_feats, num_heads=6)
        self.fc = nn.Linear(h_feats, n_classes)
        self.elu = nn.ELU()

    def forward(self, g, in_feat):
        x1 = self.layer1(g, in_feat)
        x1 = self.elu(x1)
        x1 = x1.view(in_feat.shape[0], -1)
        x2 = self.layer2(g, x1)
        x2 = self.elu(x2)
        x2 = x2.view(in_feat.shape[0], -1)
        x3 = self.layer3(g, x2)
        x3 = torch.mean(x3, dim=1)
        with g.local_scope():
            g.ndata['h'] = x3
            x4 = dgl.readout_nodes(g, 'h')
        self.hidden = x4
        return F.log_softmax(self.fc(x4), dim=-1)

In [78]:
# Load model architecture
device = 'cpu' if torch.cuda.is_available() else 'cpu'
model = GAT(in_feats=100, n_classes=3, h_feats=256).to(device)

In [79]:
# Define optimizer and loss
optimizer = torch.optim.Adam(
        model.parameters(), lr=0.000002 #, weight_decay=0.0001
    )
loss = torch.nn.CrossEntropyLoss()

In [80]:
# Train model and keep the best validation loss model
bad_cound = 0
best_val_acc = 0
best_epoch = 0
epochs = 40
patience = 30
print_every = 1
train_times = []
for e in range(epochs):
    s_time = time()
    train_loss, train_acc = train(model, optimizer, loss, train_dataloader, device)
    train_times.append(time() - s_time)
    val_acc, val_loss = test(model, loss, test_dataloader, device)
    if best_val_acc < val_acc:
        best_val_acc = val_acc
        bad_cound = 0
        best_epoch = e + 1
        torch.save(model.state_dict(), "../models/GATModel_prot.pt")
    else:
        bad_cound += 1
    if bad_cound >= patience:
        break

    if (e + 1) % print_every == 0:
        log_format = (
            "Epoch {}: train_loss={:.4f}, train_acc={:.4f}, val_acc={:.4f}, vall_loss={:.4f}"
        )
        print(log_format.format(e + 1, train_loss, train_acc, val_acc, val_loss))
print(
    "Best Epoch {}, final test loss {:.4f}".format(
        best_epoch, best_val_acc
    )
)

100%|██████████| 115/115 [00:19<00:00,  5.83it/s]
100%|██████████| 39/39 [00:01<00:00, 19.76it/s]


Epoch 1: train_loss=62.1915, train_acc=0.3876, val_acc=0.4255, vall_loss=6.7287


100%|██████████| 115/115 [00:15<00:00,  7.21it/s]
100%|██████████| 39/39 [00:02<00:00, 15.58it/s]


Epoch 2: train_loss=6.3449, train_acc=0.3750, val_acc=0.3650, vall_loss=4.4796


100%|██████████| 115/115 [00:18<00:00,  6.20it/s]
100%|██████████| 39/39 [00:02<00:00, 15.62it/s]


Epoch 3: train_loss=3.8966, train_acc=0.3881, val_acc=0.3928, vall_loss=3.1683


100%|██████████| 115/115 [00:18<00:00,  6.09it/s]
100%|██████████| 39/39 [00:03<00:00, 11.71it/s]


Epoch 4: train_loss=2.9235, train_acc=0.3979, val_acc=0.4648, vall_loss=2.4346


100%|██████████| 115/115 [00:25<00:00,  4.45it/s]
100%|██████████| 39/39 [00:03<00:00, 11.05it/s]


Epoch 5: train_loss=2.5616, train_acc=0.4078, val_acc=0.4632, vall_loss=2.1929


100%|██████████| 115/115 [00:26<00:00,  4.33it/s]
100%|██████████| 39/39 [00:03<00:00, 12.09it/s]


Epoch 6: train_loss=2.4541, train_acc=0.4007, val_acc=0.4484, vall_loss=2.0312


100%|██████████| 115/115 [00:25<00:00,  4.58it/s]
100%|██████████| 39/39 [00:03<00:00, 11.73it/s]


Epoch 7: train_loss=2.3391, train_acc=0.3996, val_acc=0.4664, vall_loss=1.8969


100%|██████████| 115/115 [00:25<00:00,  4.56it/s]
100%|██████████| 39/39 [00:03<00:00, 11.92it/s]


Epoch 8: train_loss=2.1957, train_acc=0.3985, val_acc=0.4484, vall_loss=1.9681


100%|██████████| 115/115 [00:25<00:00,  4.56it/s]
100%|██████████| 39/39 [00:03<00:00, 11.61it/s]


Epoch 9: train_loss=2.1086, train_acc=0.4083, val_acc=0.4534, vall_loss=1.8757


100%|██████████| 115/115 [00:20<00:00,  5.50it/s]
100%|██████████| 39/39 [00:02<00:00, 16.02it/s]


Epoch 10: train_loss=2.0569, train_acc=0.4105, val_acc=0.4534, vall_loss=1.8053


100%|██████████| 115/115 [00:18<00:00,  6.24it/s]
100%|██████████| 39/39 [00:02<00:00, 16.39it/s]


Epoch 11: train_loss=2.0113, train_acc=0.4192, val_acc=0.4337, vall_loss=1.7953


100%|██████████| 115/115 [00:19<00:00,  5.76it/s]
100%|██████████| 39/39 [00:03<00:00, 11.80it/s]


Epoch 12: train_loss=1.9673, train_acc=0.4159, val_acc=0.4321, vall_loss=1.8242


100%|██████████| 115/115 [00:24<00:00,  4.69it/s]
100%|██████████| 39/39 [00:03<00:00, 11.74it/s]


Epoch 13: train_loss=1.8962, train_acc=0.4181, val_acc=0.4255, vall_loss=1.7991


100%|██████████| 115/115 [00:25<00:00,  4.58it/s]
100%|██████████| 39/39 [00:03<00:00, 12.18it/s]


Epoch 14: train_loss=1.8342, train_acc=0.4056, val_acc=0.4239, vall_loss=1.9766


100%|██████████| 115/115 [00:24<00:00,  4.66it/s]
100%|██████████| 39/39 [00:03<00:00, 12.10it/s]


Epoch 15: train_loss=1.7756, train_acc=0.4165, val_acc=0.4206, vall_loss=1.8510


100%|██████████| 115/115 [00:25<00:00,  4.46it/s]
100%|██████████| 39/39 [00:03<00:00, 11.48it/s]


Epoch 16: train_loss=1.7292, train_acc=0.4121, val_acc=0.4239, vall_loss=1.6098


100%|██████████| 115/115 [00:25<00:00,  4.57it/s]
100%|██████████| 39/39 [00:03<00:00, 12.48it/s]


Epoch 17: train_loss=1.7228, train_acc=0.4170, val_acc=0.4304, vall_loss=1.5823


100%|██████████| 115/115 [00:24<00:00,  4.62it/s]
100%|██████████| 39/39 [00:03<00:00, 11.99it/s]


Epoch 18: train_loss=1.6724, train_acc=0.4219, val_acc=0.4452, vall_loss=1.4232


100%|██████████| 115/115 [00:23<00:00,  4.84it/s]
100%|██████████| 39/39 [00:03<00:00, 11.01it/s]


Epoch 19: train_loss=1.6588, train_acc=0.4252, val_acc=0.4435, vall_loss=1.4295


100%|██████████| 115/115 [00:24<00:00,  4.64it/s]
100%|██████████| 39/39 [00:03<00:00, 12.07it/s]


Epoch 20: train_loss=1.6439, train_acc=0.4209, val_acc=0.4435, vall_loss=1.4181


100%|██████████| 115/115 [00:24<00:00,  4.70it/s]
100%|██████████| 39/39 [00:03<00:00, 11.77it/s]


Epoch 21: train_loss=1.6267, train_acc=0.4285, val_acc=0.4452, vall_loss=1.4133


100%|██████████| 115/115 [00:24<00:00,  4.68it/s]
100%|██████████| 39/39 [00:03<00:00, 11.34it/s]


Epoch 22: train_loss=1.6138, train_acc=0.4361, val_acc=0.4435, vall_loss=1.4087


100%|██████████| 115/115 [00:23<00:00,  4.82it/s]
100%|██████████| 39/39 [00:03<00:00, 11.70it/s]


Epoch 23: train_loss=1.6034, train_acc=0.4361, val_acc=0.4468, vall_loss=1.4036


100%|██████████| 115/115 [00:24<00:00,  4.61it/s]
100%|██████████| 39/39 [00:03<00:00, 11.70it/s]


Epoch 24: train_loss=1.5946, train_acc=0.4367, val_acc=0.4419, vall_loss=1.3992


100%|██████████| 115/115 [00:24<00:00,  4.67it/s]
100%|██████████| 39/39 [00:03<00:00, 11.78it/s]


Epoch 25: train_loss=1.5870, train_acc=0.4356, val_acc=0.4354, vall_loss=1.3958


100%|██████████| 115/115 [00:24<00:00,  4.68it/s]
100%|██████████| 39/39 [00:03<00:00, 11.51it/s]


Epoch 26: train_loss=1.5797, train_acc=0.4383, val_acc=0.4370, vall_loss=1.3929


100%|██████████| 115/115 [00:24<00:00,  4.70it/s]
100%|██████████| 39/39 [00:03<00:00, 11.83it/s]


Epoch 27: train_loss=1.5731, train_acc=0.4378, val_acc=0.4354, vall_loss=1.3899


100%|██████████| 115/115 [00:24<00:00,  4.64it/s]
100%|██████████| 39/39 [00:03<00:00, 11.85it/s]


Epoch 28: train_loss=1.5667, train_acc=0.4400, val_acc=0.4321, vall_loss=1.3834


100%|██████████| 115/115 [00:24<00:00,  4.74it/s]
100%|██████████| 39/39 [00:03<00:00, 12.04it/s]


Epoch 29: train_loss=1.5617, train_acc=0.4389, val_acc=0.4321, vall_loss=1.3754


100%|██████████| 115/115 [00:24<00:00,  4.77it/s]
100%|██████████| 39/39 [00:03<00:00, 11.69it/s]


Epoch 30: train_loss=1.5590, train_acc=0.4329, val_acc=0.4337, vall_loss=1.3627


100%|██████████| 115/115 [00:24<00:00,  4.71it/s]
100%|██████████| 39/39 [00:03<00:00, 11.76it/s]


Epoch 31: train_loss=1.5563, train_acc=0.4356, val_acc=0.4337, vall_loss=1.3490


100%|██████████| 115/115 [00:24<00:00,  4.72it/s]
100%|██████████| 39/39 [00:03<00:00, 11.50it/s]


Epoch 32: train_loss=1.5512, train_acc=0.4367, val_acc=0.4354, vall_loss=1.3379


100%|██████████| 115/115 [00:24<00:00,  4.77it/s]
100%|██████████| 39/39 [00:03<00:00, 11.52it/s]


Epoch 33: train_loss=1.5468, train_acc=0.4389, val_acc=0.4354, vall_loss=1.3255


100%|██████████| 115/115 [00:24<00:00,  4.71it/s]
100%|██████████| 39/39 [00:03<00:00, 11.46it/s]


Epoch 34: train_loss=1.5419, train_acc=0.4394, val_acc=0.4370, vall_loss=1.3160


100%|██████████| 115/115 [00:23<00:00,  4.79it/s]
100%|██████████| 39/39 [00:03<00:00, 11.96it/s]


Epoch 35: train_loss=1.5372, train_acc=0.4410, val_acc=0.4370, vall_loss=1.3087


100%|██████████| 115/115 [00:24<00:00,  4.72it/s]
100%|██████████| 39/39 [00:03<00:00, 11.19it/s]


Epoch 36: train_loss=1.5325, train_acc=0.4454, val_acc=0.4337, vall_loss=1.3019


100%|██████████| 115/115 [00:24<00:00,  4.66it/s]
100%|██████████| 39/39 [00:03<00:00, 11.61it/s]

Best Epoch 7, final test loss 0.4664





In [81]:
model.load_state_dict(torch.load("../models/GATModel_prot.pt"))

<All keys matched successfully>

In [70]:
pred=[]
lab = []
for batch in tqdm(test_dataloader):
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out = model(batch_graphs, batch_graphs.ndata["feature"].to(dtype=torch.float32))
        pred += out.argmax(dim=1).tolist()
        lab += batch_labels.tolist()

100%|██████████| 39/39 [00:01<00:00, 20.96it/s]


In [82]:
model.fc = nn.Identity()

In [83]:
emb_test=[]
lab_test = []
for batch in tqdm(test_dataloader):
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out = model(batch_graphs, batch_graphs.ndata["feature"].to(dtype=torch.float32))
        out = model.hidden
        emb_test += out.tolist()
        lab_test += batch_labels.tolist()

100%|██████████| 39/39 [00:01<00:00, 21.19it/s]


In [84]:
emb_train=[]
lab_train = []
for batch in tqdm(train_dataloader):
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out = model(batch_graphs, batch_graphs.ndata["feature"].to(dtype=torch.float32))
        out = model.hidden
        emb_train += out.tolist()
        lab_train += batch_labels.tolist()

100%|██████████| 115/115 [00:06<00:00, 18.42it/s]


In [74]:
def create_features(g, emb):
    degree = (g.out_degrees().float() + g.in_degrees().float()).numpy()
    num_edges = g.number_of_edges()
    num_nodes = g.number_of_nodes()
    nx_g = g.to_networkx().to_undirected()
    laplacian = nx.laplacian_matrix(nx_g).astype(np.float32).toarray()
    eigenvals, eigenvecs = np.linalg.eigh(laplacian)
    return emb+[np.mean(degree), np.max(degree)/len(degree), 100*num_edges / (num_nodes * (num_nodes - 1))] + list(eigenvals[-3:])

In [40]:
for i in range(len(emb_train)):
    emb_train[i] = create_features(train_dataset[i][0], emb_train[i])

for i in range(len(emb_test)):
    emb_test[i] = create_features(test_dataset[i][0], emb_test[i])

In [85]:
from xgboost import XGBClassifier
xgb =XGBClassifier(random_state=0)

xgb.fit(emb_train, lab_train)
predxgb = xgb.predict(emb_test)

In [86]:
from sklearn.metrics import accuracy_score
accuracy_score(lab_test, predxgb)

0.5761047463175123