In [10]:
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from rf_utils import FPDataLoader, area_under_prc
import pandas as pd

from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import RandomizedSearchCV, train_test_split, GridSearchCV
from sklearn.metrics import classification_report


In [2]:
data = pd.read_csv('data/HIV.csv')
data_fp = FPDataLoader(data.smiles.values, data.HIV_active.values)
data_fp.prepare_fps()
data_fp.dump('data/HIV_preprocessed')

100%|██████████| 41127/41127 [00:03<00:00, 10701.88it/s]
100%|██████████| 41127/41127 [00:35<00:00, 1150.67it/s]
100%|██████████| 41127/41127 [00:46<00:00, 879.70it/s] 
  c /= stddev[:, None]
  c /= stddev[None, :]
100%|██████████| 10/10 [00:07<00:00,  1.27it/s]


In [20]:
data_fp = FPDataLoader()
data_fp.pick('data/HIV_preprocessed')



In [21]:
#split_dict = {'train_ind': train_set.indices, 'valid_ind': valid_set.indices, 'test_ind': test_set.indices}
with open('split_indices', 'rb') as f:
    split_dict = pickle.load(f)

In [31]:
train_ind, valid_ind, test_ind = split_dict['train_ind'], split_dict['valid_ind'], split_dict['test_ind']
#best parameters were found by GridSearchCV using roc-auc metric
rf = RandomForestClassifier(random_state=1141, max_features='sqrt', n_estimators=500,
                            min_samples_split=2, )
param_grid = {'max_depth': [None], 'min_samples_split': [2],
              'max_leaf_nodes': [None], 'n_estimators': [500]}
rf.fit(data_fp.fps[train_ind], data_fp.labels[train_ind])
y_pred = rf.predict(data_fp.fps[valid_ind])
print(classification_report(data_fp.labels[valid_ind], y_pred))
print('roc-auc', roc_auc_score(data_fp.labels[valid_ind], y_pred))
print('prc-auc', area_under_prc(data_fp.labels[valid_ind], y_pred))

              precision    recall  f1-score   support

           0       0.98      0.99      0.99      3974
           1       0.66      0.30      0.41       138

    accuracy                           0.97      4112
   macro avg       0.82      0.65      0.70      4112
weighted avg       0.97      0.97      0.97      4112

roc-auc 0.6459085505058241
prc-auc 0.21239239860065418


In [2]:
from graph_conv_utils import HIV_mols
import os
import pickle

from torchdrug import data, utils
from torchdrug.core import Registry as R
from torchdrug import core, models, tasks, datasets
import torch
import json

from collections import defaultdict
import numpy as np

from torch.utils import data as torch_data

In [None]:
%%capture
#evaluate graph convolutional model

with open("models/HIV_gin_model_wo_retrain.json", "r") as fin:
    config = json.load(fin)
    config['gpus'] = None
    solver = core.Configurable.load_config_dict(config)
solver.load("models/HIV_gin_model_wo_retrain.pth")

In [4]:
solver.evaluate("valid")
solver.evaluate("test")
print('')

17:09:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:09:05   Evaluate on valid
17:09:18   ------------------------------
17:09:18   accuracy [HIV_active]: 0.96644
17:09:18   auprc [HIV_active]: 0.202821
17:09:18   auroc [HIV_active]: 0.725041
17:09:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:09:18   Evaluate on test
17:09:29   ------------------------------
17:09:29   accuracy [HIV_active]: 0.970345
17:09:29   auprc [HIV_active]: 0.287026
17:09:29   auroc [HIV_active]: 0.747665



In [None]:
%%capture
#evaluate graph convolutional model with pretrain
with open("models/HIV_gin_model.json", "r") as fin:
    config = json.load(fin)
    config['gpus'] = None
    solver = core.Configurable.load_config_dict(config)
solver.load("models/HIV_gin_model.pth")

In [7]:
solver.evaluate("valid")
solver.evaluate("test")
print('')

17:14:27   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:14:27   Evaluate on valid
17:14:38   ------------------------------
17:14:38   accuracy [HIV_active]: 0.96644
17:14:38   auprc [HIV_active]: 0.330238
17:14:38   auroc [HIV_active]: 0.755691
17:14:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:14:38   Evaluate on test
17:14:50   ------------------------------
17:14:50   accuracy [HIV_active]: 0.970345
17:14:50   auprc [HIV_active]: 0.375858
17:14:50   auroc [HIV_active]: 0.792228

