# Instalations

In [1]:
!pip install torch torchvision
!pip install networkx
!pip install pysmiles
!curl -L bit.ly/rdkit-colab | tar xz -C /
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-geometric

import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   163  100   163    0     0   3134      0 --:--:-- --:--:-- --:--:--  3134
100   133  100   133    0     0    950      0 --:--:-- --:--:-- --:--:--   950
100   620  100   620    0     0   2707      0 --:--:-- --:--:-- --:--:--  605k
100 29.6M  100 29.6M    0     0  20.1M      0  0:00:01  0:00:01 --:--:-- 26.2M


# Dataset

In [2]:
#https://drive.google.com/file/d/18XqI6VT5dLR-m-jJ7fLtv4Y1Uyc9ZtXP/view?usp=sharing
#https://drive.google.com/file/d/1Se9qKoSHE24HwYhZfQP5xOWio1fvET6m/view?usp=sharing
#https://drive.google.com/file/d/1mN10JygWjyIfEvWP7fro8nHZTLlM2SZJ/view?usp=sharing

!gdown --id "18s872gYgaLXk_sLxySKjsh3FRjNa15L3" # hiv1_hcv dataset
!gdown --id "1Se9qKoSHE24HwYhZfQP5xOWio1fvET6m" # flua_hiv1 dataset
!gdown --id '1mN10JygWjyIfEvWP7fro8nHZTLlM2SZJ' # 

Downloading...
From: https://drive.google.com/uc?id=18s872gYgaLXk_sLxySKjsh3FRjNa15L3
To: /content/hiv1_hcv.csv
5.45GB [00:51, 106MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1Se9qKoSHE24HwYhZfQP5xOWio1fvET6m
To: /content/flua_hiv1_fixed.csv
3.09GB [00:36, 85.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1mN10JygWjyIfEvWP7fro8nHZTLlM2SZJ
To: /content/flua_hcv_fixed.csv
4.75GB [00:57, 70.2MB/s]


In [3]:
import torch
import networkx as nx
import pandas as pd
import torch.utils.data as data
import numpy as np
import os, glob

from pysmiles import read_smiles
from rdkit import Chem
from rdkit.Chem import Draw
from torch.utils.data import dataloader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold, KFold
from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Batch

def mol_to_nx(mol):
    G = nx.Graph()

    for atom in mol.GetAtoms():
        features = [atom.GetAtomicNum(),
                   atom.GetMass(),
                   atom.GetFormalCharge(),
                   hybridization_encoding(atom.GetHybridization()),
                   atom.GetNumExplicitHs(),
                   atom.GetExplicitValence(), #explicit valence (including Hs)
                   atom.GetNumRadicalElectrons(),
                   (1 if atom.GetIsAromatic() else 0)]
        G.add_node(atom.GetIdx(),
                    x=features)
    for bond in mol.GetBonds():
        G.add_edge(bond.GetBeginAtomIdx(),
                   bond.GetEndAtomIdx(),
                   bond_type=bond.GetBondType())
    return G

def hybridization_encoding(hybridization):
    if hybridization == Chem.HybridizationType.S:
        return 1
    if hybridization == Chem.HybridizationType.SP:
        return 2
    if hybridization == Chem.HybridizationType.SP2:
        return 3
    if hybridization == Chem.HybridizationType.SP3:
        return 4
    if hybridization == Chem.HybridizationType.SP3D:
        return 5
    if hybridization == Chem.HybridizationType.SP3D2:
        return 6
    if hybridization == Chem.HybridizationType.OTHER:
        return 7

In [4]:
def read_data(data_path):
    data = None
    if data_path.endswith('.csv'):
        try:
            data = pd.read_csv(data_path)
        except ValueError:
            print('ValueError')

    return data


def train_validation_split(data_path):
    if os.path.isdir(data_path):
        train_path = os.path.join(data_path, 'train.csv')
        val_path = os.path.join(data_path, 'val.csv')
    else:
        train_path = data_path.split('.')[0] + '_' + 'train.csv'
        val_path = data_path.split('.')[0] + '_' + 'val.csv'
    if os.path.exists(train_path) and os.path.exists(val_path):

        return pd.read_csv(train_path), pd.read_csv(val_path)

    data = read_data(data_path)
    train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data.to_csv(train_path, index=False)
    val_data.to_csv(val_path, index=False)

    return train_data, val_data

In [5]:
class ANYDataset(data.Dataset):

    def __init__(self, data, infer=False):

        if isinstance(data, pd.DataFrame):
            self.data = data
        elif isinstance(data, str):
            self.data = read_data(data)

        self.NON_MORD_NAMES = ['smiles', 'active']
        
        scl = StandardScaler()

        self.mord_ft = scl.fit_transform(
            self.data.drop(columns=self.NON_MORD_NAMES).astype(np.float64)).tolist()

        self.graphs = [Chem.MolFromSmiles(s) for s in self.data['smiles'].values.tolist()]
        self.graphs = [from_networkx(mol_to_nx(g)) for g in self.graphs]
        self.label = self.data['active'].values.tolist()
        

    def __len__(self):

        return len(self.graphs)

    def __getitem__(self, idx):

        return self.graphs[idx], self.mord_ft[idx], self.label[idx]

class Collater(object):

    def __init__(self, follow_batch=[], exclude_keys=[]):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def collate(self, batch):
        
        graph_batch = Batch.from_data_list(list(zip(*batch))[0], self.follow_batch,
                                           self.exclude_keys)
        md_batch = torch.tensor(list(zip(*batch))[1])
        labels_batch = torch.tensor(list(zip(*batch))[2]).long()
        
        return (graph_batch, md_batch), labels_batch

    def __call__(self, batch):
        return self.collate(batch)

In [6]:
def get_dataset(dataset_type: str):
    if dataset_type == 'flua_hcv':
        path2data = '/content/flua_hcv_fixed.csv'
        vtoi = {'hcv': 0, 'flua': 1}
    if dataset_type == 'hiv1_hcv':
        path2data = '/content/hiv1_hcv.csv'
        vtoi = {'hcv': 0, 'hiv': 1}
    if dataset_type == 'flua_hiv1':
        path2data = '/content/flua_hiv1_fixed.csv'
        vtoi = {'flua': 0, 'hiv1': 1}

    train_data, val_data = train_validation_split(path2data)

    train_data = train_data.drop(columns=['smile_ft'])
    val_data = val_data.drop(columns=['smile_ft'])
    train_data['active'] = train_data['active'].apply(lambda x: vtoi[x])
    val_data['active'] = val_data['active'].apply(lambda x: vtoi[x])
    to_drop = ['Nc1ccc(cc1)[S+]2(=O)[NH2+]c3nccc[n+]3[AgH3-4]O2']
    train_data = train_data[~train_data.smiles.isin(to_drop)]
    val_data = val_data[~val_data.smiles.isin(to_drop)]

    return train_data, val_data

# Define the model


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv, MFConv
from torch_geometric.nn import global_mean_pool

class GraphNet(nn.Module):
    def __init__(self):
        super(GraphNet, self).__init__()

        self.graph_conv1 = MFConv(8, 32)
        self.graph_conv2 = MFConv(32, 32)
        self.graph_conv3 = MFConv(32, 32)

        self.fc1 = nn.Linear(862, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 64)

        self.batch_norm1 = nn.BatchNorm1d(512)
        self.batch_norm2 = nn.BatchNorm1d(128)
        self.batch_norm3 = nn.BatchNorm1d(64)

        self.dropout = nn.Dropout()

        self.linear = nn.Linear(32 + 64, 1)

    def forward(self, x, edge_index, batch, x_md):

        x = self.graph_conv1(x, edge_index)
        x = x.relu()
        x = self.graph_conv2(x, edge_index)
        x = x.relu()
        x = self.graph_conv3(x, edge_index)

        x = global_mean_pool(x, batch)
        
        x_md = F.relu(self.fc1(x_md))
        x_md = self.batch_norm1(x_md)
        x_md = self.dropout(x_md)

        x_md = F.relu(self.fc2(x_md))
        x_md = self.batch_norm2(x_md)
        x_md = self.dropout(x_md)

        x_md = F.relu(self.fc3(x_md))
        x_md = self.batch_norm3(x_md)
        x_md = self.dropout(x_md)

        x = torch.cat([x, x_md], dim=1)

        return torch.sigmoid(self.linear(x))

# Metrics



In [8]:
import numpy as np
import sklearn.metrics as metrics
THRESH = 0.2

def auc(y_true, y_scores):
    y_true = y_true.cpu().detach().numpy()
    y_scores = y_scores.cpu().detach().numpy()
    return metrics.roc_auc_score(y_true, y_scores)


def auc_threshold(y_true, y_scores):
    y_true = y_true.cpu().detach().numpy()
    y_scores = y_scores.cpu().detach().numpy()
    fpr, tpr, threshold = metrics.roc_curve(y_true, y_scores)
    return metrics.auc(fpr, tpr)


def get_score_obj(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + thresh).astype(np.int16)
    return metrics.classification_report(y_true, y_scores, output_dict=True)


def f1(y_true, y_scores):
    score_obj = get_score_obj(y_true, y_scores)
    return score_obj['weighted avg']['f1-score']


def sensitivity(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tp / (tp + fn)


def specificity(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tn / (tn + fp)


def accuracy(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    return metrics.accuracy_score(y_true, y_scores)


def mcc(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    return metrics.matthews_corrcoef(y_true, y_scores)

def plot_roc_curve(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc_roc = metrics.roc_auc_score(y_true, y_pred)
    print('AUC: {:4f}'.format(auc_roc))
    plt.plot(fpr, tpr)
    plt.savefig('vis/ROC_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure


def plot_precision_recall(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    precisions, recalls, thresholds = precision_recall_curve(y_true, y_pred)
    plt.plot(thresholds, precisions[:-1], label="Precision")
    plt.plot(thresholds, recalls[:-1], label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])
    plt.savefig('vis/PR_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure

# Utils

In [9]:
import os
import pickle
import torch


def get_max_length(x):
    return len(max(x, key=len))


def pad_sequence(seq):
    def _pad(_it, _max_len):
        return [0] * (_max_len - len(_it)) + _it
    padded = [_pad(it, get_max_length(seq)) for it in seq]
    return padded

def create_dir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)


def save_pickle(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def read_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def save_model(model, model_dir_path, hash_code):
    if not os.path.exists(model_dir_path):
        os.makedirs(model_dir_path)
    torch.save(model.state_dict(), "{}/model_{}_{}".format(model_dir_path, hash_code, "BEST"))
    print('Save done!')

# Train loop

In [10]:
import argparse
import torch
import torch.nn as nn
import tensorboard_logger
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import warnings

from sklearn.metrics import precision_recall_curve
from torch.utils.data import dataloader

plt.switch_backend('agg')
warnings.filterwarnings('ignore')

models_path = '/content'

def train_validate(train_dataset,
                          val_dataset,
                          train_device,
                          val_device,
                          opt_type,
                          n_epoch,
                          batch_size,
                          metrics,
                          hash_code,
                          lr):
    train_loader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=batch_size, 
                                               shuffle=True,
                                               collate_fn=Collater())
    val_loader = torch.utils.data.DataLoader(val_dataset, 
                                             batch_size=batch_size, 
                                             shuffle=False,
                                             collate_fn=Collater())
    device = 'cuda:0'
    
    try:
        tensorboard_logger.configure('logs/' + hash_code)
    except:
        pass

    criterion = nn.BCELoss()
    model = GraphNet().to(device)

    if opt_type == 'sgd':
        opt = optim.SGD(model.parameters(),
                        lr=lr,
                        momentum=0.99)
        
    elif opt_type == 'adam':
        opt = optim.Adam(model.parameters(),
                         lr=lr)
        
    scheduler = optim.lr_scheduler.MultiStepLR(opt, [40, 80], gamma=0.1)

    min_loss = 100
    early_stop_count = 0

    for e in range(n_epoch):
        train_losses = []
        val_losses = []
        train_outputs = []
        val_outputs = []
        train_labels = []
        val_labels = []

        print(e, '--', 'TRAINING ==============>')

        for i, ((mol_graph, md), label) in enumerate(train_loader):
            model.train()

            mol_graph.x = mol_graph.x.float().to(device)
            label = label.float().to(train_device)
            mol_graph.edge_index = mol_graph.edge_index.to(train_device)
            mol_graph.batch = mol_graph.batch.to(train_device)
            md = md.to(device)

            # Forward
            opt.zero_grad()
            outputs = model(mol_graph.x, mol_graph.edge_index, mol_graph.batch, md)
            outputs = torch.squeeze(outputs)
            
            loss = criterion(outputs, label)
            train_losses.append(float(loss.item()))
            train_outputs.extend(outputs)
            train_labels.extend(label)

            # Parameters update
            loss.backward()
            opt.step()
            scheduler.step()

        # Validate after each epoch
        print('EPOCH', e, '--', 'VALIDATION ==============>')
        for i, ((mol_graph, md), label) in enumerate(val_loader):
            model.eval()

            mol_graph.x = mol_graph.x.float().to(device)
            label = label.float().to(train_device)
            mol_graph.edge_index = mol_graph.edge_index.to(train_device)
            mol_graph.batch = mol_graph.batch.to(train_device)
            md = md.to(device)
            
            with torch.no_grad():
                outputs = model(mol_graph.x, mol_graph.edge_index, mol_graph.batch, md)
                outputs = torch.squeeze(outputs)
                
                loss = criterion(outputs, label)
                val_losses.append(float(loss.item()))
                val_outputs.extend(outputs)
                val_labels.extend(label)

        train_outputs = torch.stack(train_outputs)
        val_outputs = torch.stack(val_outputs)
        train_labels = torch.stack(train_labels)
        val_labels = torch.stack(val_labels)
        tensorboard_logger.log_value('train_loss', sum(train_losses) / len(train_losses), e + 1)
        tensorboard_logger.log_value('val_loss', sum(val_losses) / len(val_losses), e + 1)

        print('{"metric": "train_loss", "value": %f, "epoch": %d}' % (sum(train_losses) / len(train_losses), e + 1))
        print('{"metric": "val_loss", "value": %f, "epoch": %d}' % (sum(val_losses) / len(val_losses), e + 1))
        
        for key in metrics.keys():
            train_metric = metrics[key](train_labels, train_outputs)
            val_metric = metrics[key](val_labels, val_outputs)

            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('train_' + key, train_metric, e + 1))
            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('val_' + key, val_metric, e + 1))

            tensorboard_logger.log_value('train_{}'.format(key),
                                         train_metric, e + 1)
            tensorboard_logger.log_value('val_{}'.format(key),
                                         val_metric, e + 1)
            
        loss_epoch = sum(val_losses) / len(val_losses)

        if loss_epoch < min_loss:
            early_stop_count = 0
            min_loss = loss_epoch
            save_model(model, models_path, hash_code)
        else:
            early_stop_count += 1
            if early_stop_count > 30:
                print('Traning can not improve from epoch {}\tBest loss: {}'.format(e, min_loss))
                break

    train_metrics = {}
    val_metrics = {}

    for key in metrics.keys():
        train_metrics[key] = metrics[key](train_labels, train_outputs)
        val_metrics[key] = metrics[key](val_labels, val_outputs)

    return train_metrics, val_metrics


def predict(dataset, model_path, device='cpu'):
    
    loader = torch.utils.data.DataLoader(val_dataset, 
                                             batch_size=128, 
                                             shuffle=False,
                                             collate_fn=Collater())
    model = GraphNet().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    # EVAL_MODE

    model.eval()
    probas = []
    for i, ((mol_graph, md), label) in enumerate(loader):
        with torch.no_grad():
            mol_graph = mol_graph.to(device)
            # Forward to get smiles and equivalent weights
            proba = model(mol_graph.x.to(device), 
                          mol_graph.edge_index.to(device), 
                          mol_graph.batch.to(device), md.to(device)).cpu()
            probas.append(proba)
    print('Forward done !!!')
    probas = np.concatenate(probas)
    return probas

# Train

In [11]:
#Hashcode for tf.events
hashcode = 'TEST'

train_data, val_data = get_dataset('flua_hiv1')
train_dataset = ANYDataset(train_data)
val_dataset = ANYDataset(val_data)

if torch.cuda.is_available():
    train_device = 'cuda'
    val_device = 'cuda'
else:
    train_device = 'cpu'
    val_device = 'cpu'

train_validate(train_dataset,
                  val_dataset,
                  train_device,
                  val_device,
                  'adam', #Optimizer adam ('adam') or sgd ('sgd')
                  int(500), #Number of epochs
                  int(128), #Batch size
                  {'sensitivity': sensitivity, 'specificity': specificity,
                    'accuracy': accuracy, 'mcc': mcc, 'auc': auc},
                  hashcode, #Hashcode for tf.events
                  1e-2) #Learning rate

{"metric": "train_loss", "value": 0.592743, "epoch": 1}
{"metric": "val_loss", "value": 0.556198, "epoch": 1}
{"metric": "train_sensitivity", "value": 0.970869, "epoch": 1}
{"metric": "val_sensitivity", "value": 0.992692, "epoch": 1}
{"metric": "train_specificity", "value": 0.087811, "epoch": 1}
{"metric": "val_specificity", "value": 0.042009, "epoch": 1}
{"metric": "train_accuracy", "value": 0.521533, "epoch": 1}
{"metric": "val_accuracy", "value": 0.515237, "epoch": 1}
{"metric": "train_mcc", "value": 0.124511, "epoch": 1}
{"metric": "val_mcc", "value": 0.111709, "epoch": 1}
{"metric": "train_auc", "value": 0.746193, "epoch": 1}
{"metric": "val_auc", "value": 0.773776, "epoch": 1}
Save done!
{"metric": "train_loss", "value": 0.562851, "epoch": 2}
{"metric": "val_loss", "value": 0.546578, "epoch": 2}
{"metric": "train_sensitivity", "value": 0.976465, "epoch": 2}
{"metric": "val_sensitivity", "value": 0.988308, "epoch": 2}
{"metric": "train_specificity", "value": 0.086977, "epoch": 2}


KeyboardInterrupt: ignored

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs