# Prepare dataset

In [4]:
#https://drive.google.com/file/d/1r_hC1sa1H2Ht3ZGlp-rkWbcN1a-CIr8I/view?usp=sharing
# https://drive.google.com/file/d/1BOzwXO0sYPmsoMMPGpDiJxuZbafKGe9N/view?usp=sharing
!gdown --id "1BOzwXO0sYPmsoMMPGpDiJxuZbafKGe9N"

Downloading...
From: https://drive.google.com/uc?id=1BOzwXO0sYPmsoMMPGpDiJxuZbafKGe9N
To: /content/hiv1_hcv.csv
5.45GB [01:45, 51.8MB/s]


In [5]:
path2data = 'hiv1_hcv.csv'

In [6]:
'''
use subset of the original hiv1_hcv if you don't have 20Gb RAM!
'''

import pandas as pd
subset = pd.read_csv(path2data).sample(frac=0.1)

subset['active'] = subset['active'].replace('hcv',0)
subset['active'] = subset['active'].replace('hiv1',1)
# df['active'] = df['active'].replace('flua',2)

subset.to_csv('subset_hiv1_hcv.csv', index=False)




del subset
path2data = 'subset_hiv1_hcv.csv'

In [38]:
device = 'cuda'

In [7]:
#dataset.py

import pandas as pd
import torch.utils.data as data
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold, KFold
import numpy as np
import os, glob
import json


def read_data(data_path):
    data = None
    if data_path.endswith('.csv'):
        try:
            #data = pd.read_json(data_path, lines=True, nrows=100, chunksize=1000)
            data = pd.read_csv(data_path)
        except ValueError:
            print('ValueError')
            #data = pd.read_json(data_path)
    #if data_path.endswith('.zip'):
        #try:
            #data = pd.read_json(data_path, compression='zip', lines=True)
        #except ValueError:
            #data = pd.read_json(data_path, compression='zip')
    return data


def train_validation_split(data_path):
    if os.path.isdir(data_path):
        train_path = os.path.join(data_path, 'train.csv')
        val_path = os.path.join(data_path, 'val.csv')
    else:
        train_path = data_path.split('.')[0] + '_' + 'train.csv'
        val_path = data_path.split('.')[0] + '_' + 'val.csv'
    if os.path.exists(train_path) and os.path.exists(val_path):
        # return read_data(train_path), read_data(val_path)
        return pd.read_csv(train_path), pd.read_csv(val_path)
    data = read_data(data_path)
    train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
    print('Train')
    print(train_data.shape)
    train_data.to_csv(train_path, index=False)
    print('Val')
    print(val_data.shape)
    val_data.to_csv(val_path, index=False)
    return train_data, val_data


def train_cross_validation_split(data_path):
    dir_path = os.path.dirname(os.path.abspath(data_path))
    fold_dirs = glob.glob(os.path.join(dir_path, 'folds_*'))
    if len(fold_dirs) == 5:
        for fold_dir in fold_dirs:
            train_path = os.path.join(fold_dir, 'train.csv')
            val_path = os.path.join(fold_dir, 'val.csv')
            yield pd.read_csv(train_path), pd.read_csv(val_path)
    else:
        kfold = KFold(n_splits=5, shuffle=True, random_state=42)
        data = read_data(data_path)
        for i, (train_ids, val_ids) in enumerate(kfold.split(X=data.drop('active', axis=1).values,
                                                             y=data['active'].values)):
            train_data = data.iloc[train_ids, :]
            val_data = data.iloc[val_ids, :]
            # os.makedirs(os.path.join(dir_path, 'folds_{}'.format(i)), exist_ok=True)
            # train_data.to_json(os.path.join(os.path.join(dir_path, 'folds_{}'.format(i)), 'train.json'))
            # val_data.to_json(os.path.join(os.path.join(dir_path, 'folds_{}'.format(i)), 'val.json'))

            yield train_data, val_data


class ANYDataset(data.Dataset):
    def __init__(self, data, infer=False):
        if isinstance(data, pd.DataFrame):
            self.data = data
        elif isinstance(data, str):
            self.data = read_data(data)
        #self.NON_MORD_NAMES = ['smile_ft', 'id', 'subset', 'quinazoline', 'pyrimidine', 'smiles', 'active']
        self.NON_MORD_NAMES = ['smile_ft', 'smiles', 'active']
        self.infer = infer

        # Standardize mord features
        scl = StandardScaler()
        self.mord_ft = scl.fit_transform(self.data.drop(columns=self.NON_MORD_NAMES).astype(np.float64)).tolist()
        self.non_mord_ft_temp = self.data['smile_ft'].values.tolist()
        self.non_mord_ft = []
        for i in range(len(self.non_mord_ft_temp)):
          self.non_mord_ft.append(json.loads(self.non_mord_ft_temp[i]))
        self.smiles = self.data['smiles'].values.tolist()
        self.label = self.data['active'].values.tolist()

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        if self.infer:
            return self.smiles[idx], self.mord_ft[idx], self.non_mord_ft[idx], self.label[idx]
        else:
            return self.mord_ft[idx], self.non_mord_ft[idx], self.label[idx]

    def get_dim(self, ft):
        if ft == 'non_mord':
            return len(self.non_mord_ft[0])
        if ft == 'mord':
            return len(self.mord_ft[0])

    def get_smile_ft(self):
        return self.non_mord_ft

In [8]:
train_data, val_data = train_validation_split(path2data)

Train
(8715, 865)
Val
(2179, 865)


In [47]:
train_dataset = ANYDataset(train_data, True)
val_dataset = ANYDataset(val_data, True)

In [10]:
train_dataset[0]

('O=C(Nc1ccccc1)C2CCN(CC2)S(=O)(=O)c3ccc4OCCOc4c3',
 [-0.12477677497720815,
  0.16997366858129753,
  0.19944471464629376,
  0.044969534265017326,
  -0.2515138661932105,
  -0.15929467829346314,
  -0.095745712764554,
  0.13034938319307443,
  0.2522470364622163,
  -0.17062190595282817,
  0.7820400988084859,
  0.0636185085151905,
  0.48885857108214104,
  -0.690802506655947,
  -0.5867501053463648,
  0.7249459014764524,
  0.15532818814395566,
  0.12445747414584754,
  0.35492081797602826,
  0.0750212618337784,
  0.11553023164584444,
  0.44527709845617447,
  -0.17376866165045968,
  -0.09824097196419096,
  -0.1087649975641443,
  0.06397351858782278,
  -0.1403196626314968,
  0.420233608446911,
  0.5618025822911455,
  -0.1911834242348009,
  -0.14192439068512502,
  -0.011475957513711645,
  -0.23641333721835917,
  -0.4174944181534245,
  0.21971575560739257,
  -0.19006305795893863,
  -0.4137313913315446,
  -0.05387300997095499,
  0.24129816162349682,
  -0.5637374811561482,
  -0.24395225103135704,
  

In [36]:
import sys
import os
import requests
import subprocess
import shutil
from logging import getLogger, StreamHandler, INFO


logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel(INFO)


def install(
        chunk_size=4096,
        file_name="Miniconda3-latest-Linux-x86_64.sh",
        url_base="https://repo.continuum.io/miniconda/",
        conda_path=os.path.expanduser(os.path.join("~", "miniconda")),
        rdkit_version=None,
        add_python_path=True,
        force=False):
    """install rdkit from miniconda
    ```
    import rdkit_installer
    rdkit_installer.install()
    ```
    """

    python_path = os.path.join(
        conda_path,
        "lib",
        "python{0}.{1}".format(*sys.version_info),
        "site-packages",
    )

    if add_python_path and python_path not in sys.path:
        logger.info("add {} to PYTHONPATH".format(python_path))
        sys.path.append(python_path)

    if os.path.isdir(os.path.join(python_path, "rdkit")):
        logger.info("rdkit is already installed")
        if not force:
            return

        logger.info("force re-install")

    url = url_base + file_name
    python_version = "{0}.{1}.{2}".format(*sys.version_info)

    logger.info("python version: {}".format(python_version))

    if os.path.isdir(conda_path):
        logger.warning("remove current miniconda")
        shutil.rmtree(conda_path)
    elif os.path.isfile(conda_path):
        logger.warning("remove {}".format(conda_path))
        os.remove(conda_path)

    logger.info('fetching installer from {}'.format(url))
    res = requests.get(url, stream=True)
    res.raise_for_status()
    with open(file_name, 'wb') as f:
        for chunk in res.iter_content(chunk_size):
            f.write(chunk)
    logger.info('done')

    logger.info('installing miniconda to {}'.format(conda_path))
    subprocess.check_call(["bash", file_name, "-b", "-p", conda_path])
    logger.info('done')

    logger.info("installing rdkit")
    subprocess.check_call([
        os.path.join(conda_path, "bin", "conda"),
        "install",
        "--yes",
        "-c", "rdkit",
        "python=={}".format(python_version),
        "rdkit" if rdkit_version is None else "rdkit=={}".format(rdkit_version)])
    logger.info("done")

    import rdkit
    logger.info("rdkit-{} installation finished!".format(rdkit.__version__))


if __name__ == "__main__":
    install()

add /root/miniconda/lib/python3.7/site-packages to PYTHONPATH
python version: 3.7.10
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing rdkit
done
rdkit-2020.09.1 installation finished!


# Define the model


In [39]:
import torch.nn as nn

class MyEncoder(nn.Module):
    def __init__(self, hidden=64, heads=4, layers = 4, max_len = 250, vocab_size=45):
        super(MyEncoder,self).__init__()
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=heads,dim_feedforward=64, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        
        self.pos_embed = nn.Embedding(max_len+2, hidden)
        
        self.tok_embed = nn.Embedding(vocab_size, hidden)
        
        self.max_len = max_len
        
    def forward(self, input, mask=None, bracket_mask=None):
        tok_emb = self.tok_embed(input)  # (T,B,H)
        # print(tok_emb)

        # bracket_mask = None
        # mask = None
        if bracket_mask is None:
          pos = torch.arange(0., self.max_len+2, dtype=torch.int).unsqueeze(1).to(device)
        else:
          
          pos = torch.cumsum(bracket_mask, dim=0).to(device)
          
          # print(pos)
        pos_emb = self.pos_embed(pos) # (T,B,H)
        
        emb = tok_emb + pos_emb
        
        hidden = self.transformer_encoder(emb,src_key_padding_mask = mask) # (T,B,H)
        
        return hidden

import torch
from pretrained_transformer.pretrain_trfm import TrfmSeq2seq
from pretrained_transformer.build_vocab import WordVocab
from pretrained_transformer.utils import split

max_length = 250 # TODO dynamic padding

pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

def get_inputs(sm):
    seq_len = max_length+2
    sm = sm.split()
#     print(sm)
    if len(sm)>max_length:
        print('SMILES is too long ({:d})'.format(len(sm)))
        sm = sm[:max_length//2]+sm[-max_length//2:]
    ids = [vocab.stoi.get(token, unk_index) for token in sm]
#     print(ids)
    ids = [sos_index] + ids + [eos_index]

    # Fill seg with zeros where brackets (try to exclude branches) 
    # TODO vectorize
    in_bracket = 0
    branches = []
    for token in sm:
      if token == '(':
        in_bracket+=1
      elif token == ')':
        in_bracket-=1

      if not in_bracket and token != '(' and token != ')':
        branches.append(1)
      else:
        branches.append(0)

    branches = [1] + branches + [1]
    seg = [True]*len(ids)
    padding = [pad_index]*(seq_len - len(ids))
    bool_padding = [False]*(seq_len - len(ids))
    ids.extend(padding), seg.extend(bool_padding), branches.extend(padding)
    return ids, seg, branches

def get_array(smiles):
    x_id, x_seg, branches = [], [], []
    for sm in smiles:
        #print(sm)
        a,b,c = get_inputs(split(sm))
        x_id.append(a)
        x_seg.append(b)
        branches.append(c)
    return torch.tensor(x_id), torch.tensor(x_seg), torch.tensor(branches)

vocab = WordVocab.load_vocab('pretrained_transformer/vocab.pkl')
trfm = MyEncoder().to(device)
print('Total parameters:', sum(p.numel() for p in trfm.parameters()))

Total parameters: 119872


In [59]:
#nets.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class UnitedNet(nn.Module):
    def __init__(self, dense_dim, smiles_len, use_mord=True, use_mat=True, infer=False, dir_path=None, vis_thresh=0.2):
        super(UnitedNet, self).__init__()
        self.use_mord = use_mord
        self.use_mat = use_mat
        self.infer = infer
        self.vis_thresh = vis_thresh
        self.dir_path = dir_path
        self.smiles_len = smiles_len
        
        if self.dir_path:
            self.smile_out_f = open(os.path.join(self.dir_path, 'smiles.txt'), 'w')
            self.weight_f = open(os.path.join(self.dir_path, 'weight.txt'), 'w')
        
        self.trfm = MyEncoder()#.to(device)#trfm.to(device)
            

        self.trfm_conv1 = nn.Conv1d(64, 64, kernel_size=3,padding=1)
        self.trfm_pool = nn.MaxPool1d(3)
        # self.trfm_pool = nn.Identity()
        self.trfm_conv2 = nn.Conv1d(64, 128, kernel_size=3,padding=1)
        self.trfm_fc = nn.Linear(128 * 28, self.smiles_len)
        # self.trfm_fc = nn.Linear(128 * 252, 150)

        self.trfm_dropout = nn.Dropout(0.05)
        
#         # PARAMS FOR CNN NET
#         # Convolutionals
#         self.conv_conv1 = nn.Conv1d(42, 64, kernel_size=5, padding=2)
#         self.conv_pool = nn.MaxPool1d(5)
#         self.conv_conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
#         self.relu = nn.ReLU()

#         # Fully connected
        
#         self.conv_fc = nn.Linear(128*12, self.smiles_len)#self.smiles_len // 5 // 5 , 120)

        # Batch norms
        self.conv_batch_norm1 = nn.BatchNorm1d(64)
        self.conv_batch_norm2 = nn.BatchNorm1d(128)
        # # PARAMS FOR CNN NET
        # # Convolutionals
        # self.conv_conv1 = nn.Conv2d(1, 6, kernel_size=3) #smiles_len-2
        # self.conv_pool = nn.MaxPool2d(2, 2) #smiles_len/2
        # self.conv_conv2 = nn.Conv2d(6, 16, kernel_size=3) #smiles_len-2

        # # Fully connected
        # self.conv_fc = nn.Linear(16 * 9 * int(((self.smiles_len-2)/2 - 2)/2), self.smiles_len)

        # # Batch norms
        # self.conv_batch_norm1 = nn.BatchNorm2d(6)
        # self.conv_batch_norm2 = nn.BatchNorm2d(16)

        # PARAMS FOR DENSE NET
        # Fully connected
        if self.use_mord:
            self.dense_fc1 = nn.Linear(dense_dim, 512)
            self.dense_fc2 = nn.Linear(512, 128)
            self.dense_fc3 = nn.Linear(128, 64)

            # Batch norms
            self.dense_batch_norm1 = nn.BatchNorm1d(512)
            self.dense_batch_norm2 = nn.BatchNorm1d(128)
            self.dense_batch_norm3 = nn.BatchNorm1d(64)

            # Dropouts
            self.dense_dropout = nn.Dropout()

        # PARAMS FOR ATTENTION NET
        if self.use_mat:
            #self.att_fc = nn.Linear(256, 1)
            self.att_fc = nn.Linear(self.smiles_len+42+64, 1)
        else:
            self.comb_fc_alt = nn.Linear(128, 1)

        # PARAMS FOR COMBINED NET
        if self.use_mord:
            self.comb_fc = nn.Linear(self.smiles_len+64, 1)
        else:
            self.comb_fc = nn.Linear(self.smiles_len, 1)

    def forward(self, x_non_mord, x_mord, x_mat, smiles=None, mask=None, branch_mask=None):
        
        trfm_x = self.trfm(torch.t(smiles), mask, torch.t(branch_mask)) 
        x = torch.transpose(trfm_x, 0,1) # (B,T,H)
        
#                 trfm_x*=mask[:,:,None]
#         print(trfm_x)
        
        x = torch.transpose(x, -1,-2) # (B,R,T)
        #print(x.shape)
        x = self.trfm_conv1(x) # (B, 64, T)
        #print(x.shape)
        x = F.relu(x)
        x = self.trfm_pool(x)
        #print(x.shape)
        x = self.trfm_conv2(x) # (B, 64, T)
        #print(x.shape)
        x = F.relu(x)
        x = self.trfm_pool(x)
        # print(x.view(x.size(0), -1).shape)
        # print(x.shape)
        x = self.trfm_dropout(self.trfm_fc(x.view(x.size(0), -1)))
        if self.use_mat:
            x_non_mord = F.sigmoid(x)
        else:
            x_non_mord = F.relu(x)
#         # FORWARD CNN
#         x_non_mord = torch.transpose(x_non_mord, -1,-2)

#         x_non_mord = self.conv_conv1(x_non_mord)
#         x_non_mord = self.conv_batch_norm1(x_non_mord)
#         x_non_mord = F.relu(x_non_mord)
#         x_non_mord = self.conv_pool(x_non_mord)

#         x_non_mord = self.conv_conv2(x_non_mord)
#         x_non_mord = self.conv_batch_norm2(x_non_mord)
#         x_non_mord = F.relu(x_non_mord)
#         x_non_mord = self.conv_pool(x_non_mord)

#         # print(x_non_mord.shape)
#         x_non_mord = x_non_mord.view(x_non_mord.size(0), -1)
#         if self.use_mat:
#             x_non_mord = F.sigmoid(self.conv_fc(x_non_mord))
#         else:
#             x_non_mord = F.relu(self.conv_fc(x_non_mord))

        # FORWARD DENSE
        if self.use_mord:
            x_mord = F.relu(self.dense_fc1(x_mord))
            x_mord = self.dense_batch_norm1(x_mord)
            x_mord = self.dense_dropout(x_mord)

            x_mord = F.relu(self.dense_fc2(x_mord))
            x_mord = self.dense_batch_norm2(x_mord)
            x_mord = self.dense_dropout(x_mord)

            x_mord = F.relu(self.dense_fc3(x_mord))
            x_mord = self.dense_batch_norm3(x_mord)
            x_mord = self.dense_dropout(x_mord)

        # FORWARD ATTENTION
        if self.use_mat:
            x_mat = torch.bmm(x_mat.permute(0, 2, 1), x_non_mord.unsqueeze(-1)).squeeze(-1)
            x_mat = torch.cat([x_mat, x_non_mord], dim=1)

            if self.use_mord:
                x_comb = torch.cat([x_mat, x_mord], dim=1)
                probs = torch.sigmoid(self.att_fc(x_comb))
                if self.infer:
                    if not smiles:
                        raise ValueError('Please input smiles')
                    alphas = x_comb.cpu().detach().numpy().tolist()
                    alphas = ["\t".join([str(round(elem, 4)) for elem in seq]) for seq in alphas]
                    prob_list = probs.cpu().detach().numpy().tolist()
                    for smile, alpha, prob in zip(smiles, alphas, prob_list):
                        if prob[0] > self.vis_thresh:
                            self.weight_f.write(alpha + '\n')
                            self.smile_out_f.write(smile + '\n')
                return probs
            else:
                return torch.sigmoid(self.comb_fc(x_mat))
        else:
            if self.use_mord:
                x_comb = torch.cat([x_non_mord, x_mord], dim=1)
            else:
                x_comb = x_non_mord
            # print(x_comb.shape)
            return torch.sigmoid(self.comb_fc(x_comb))

    def __del__(self):
        print('Closing files ...')
        if hasattr(self, 'weight_f'):
            self.weight_f.close()
        if hasattr(self, 'smile_out_f'):
            self.smile_out_f.close()

# Metrics + utils



In [41]:
# metrics.py

import numpy as np
import sklearn.metrics as metrics
THRESH = 0.2


def auc(y_true, y_scores):
    y_true = y_true.cpu().detach().numpy()
    y_scores = y_scores.cpu().detach().numpy()
    return metrics.roc_auc_score(y_true, y_scores)


def auc_threshold(y_true, y_scores):
    y_true = y_true.cpu().detach().numpy()
    y_scores = y_scores.cpu().detach().numpy()
    fpr, tpr, threshold = metrics.roc_curve(y_true, y_scores)
    return metrics.auc(fpr, tpr)


def get_score_obj(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + thresh).astype(np.int16)
    return metrics.classification_report(y_true, y_scores, output_dict=True)


def f1(y_true, y_scores):
    score_obj = get_score_obj(y_true, y_scores)
    return score_obj['weighted avg']['f1-score']

# Metrics for benchmark


def sensitivity(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tp / (tp + fn)


def specificity(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tn / (tn + fp)


def accuracy(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    return metrics.accuracy_score(y_true, y_scores)


def mcc(y_true, y_scores, thresh=THRESH):
    y_true = y_true.cpu().detach().numpy()
    y_scores = (y_scores.cpu().detach().numpy() + 1 - thresh).astype(np.int16)
    return metrics.matthews_corrcoef(y_true, y_scores)

# METRICS FOR CV


def auc_cv(y_true, y_scores):
    return metrics.roc_auc_score(y_true, y_scores)


def get_score_obj_cv(y_true, y_scores, thresh=THRESH):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    y_scores = (y_scores + 1 - thresh).astype(np.int16)
    return metrics.classification_report(y_true, y_scores, output_dict=True)


def f1_cv(y_true, y_scores):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    score_obj = get_score_obj_cv(y_true, y_scores)
    return score_obj['weighted avg']['f1-score']


def class1_precision_cv(y_true, y_scores):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    score_obj = get_score_obj_cv(y_true, y_scores)
    return score_obj['1.0']['precision']


def class1_recall_cv(y_true, y_scores):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    score_obj = get_score_obj_cv(y_true, y_scores)
    return score_obj['1.0']['recall']


def sensitivity_cv(y_true, y_scores, thresh=THRESH):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    y_scores = (y_scores + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tp / (tp + fn)


def specificity_cv(y_true, y_scores, thresh=THRESH):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    y_scores = (y_scores + 1 - thresh).astype(np.int16)
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_scores).ravel()
    return tn / (tn + fp)


def accuracy_cv(y_true, y_scores, thresh=THRESH):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    y_scores = (y_scores + 1 - thresh).astype(np.int16)
    return metrics.accuracy_score(y_true, y_scores)


def mcc_cv(y_true, y_scores, thresh=THRESH):
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)
    y_scores = (y_scores + 1 - thresh).astype(np.int16)
    return metrics.matthews_corrcoef(y_true, y_scores)

In [42]:
#utils.py

import os
import pickle
import torch


def get_max_length(x):
    return len(max(x, key=len))


def pad_sequence(seq):
    def _pad(_it, _max_len):
        return [0] * (_max_len - len(_it)) + _it
    padded = [_pad(it, get_max_length(seq)) for it in seq]
    return padded


def custom_collate(batch):
    transposed = zip(*batch)
    lst = []
    for samples in transposed:
        try:
            if isinstance(samples[0], str) or isinstance(samples[0], unicode):
                lst.append(samples)
        except NameError:
            if isinstance(samples[0], str):
                lst.append(samples)
        if isinstance(samples[0], int):
            lst.append(torch.LongTensor(samples))
        elif isinstance(samples[0], float):
            lst.append(torch.DoubleTensor(samples))
        elif isinstance(samples[0], list):
            lst.append(torch.LongTensor(pad_sequence(samples)))
    return lst


def create_dir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)


def save_pickle(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def read_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def save_model(model, model_dir_path, hash_code):
    if not os.path.exists(model_dir_path):
        os.makedirs(model_dir_path)
    torch.save(model.state_dict(), "{}/model_{}_{}".format(model_dir_path, hash_code, "BEST"))
    print('Save done!')

# Train

In [43]:
!pip install tensorboard_logger



In [44]:
#single_run.py

import argparse
import torch
import torch.nn as nn
import tensorboard_logger
from torch.utils.data import dataloader
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import warnings
from sklearn.metrics import precision_recall_curve
plt.switch_backend('agg')
warnings.filterwarnings('ignore')

#saved models
models_path = '/content'

def train_validate_united(train_dataset,
                          val_dataset,
                          train_device,
                          val_device,
                          use_mat,
                          use_mord,
                          opt_type,
                          n_epoch,
                          batch_size,
                          metrics,
                          hash_code,
                          lr):
    train_loader = dataloader.DataLoader(dataset=train_dataset,
                                         batch_size=batch_size,
                                         collate_fn=custom_collate,
                                         shuffle=False)

    val_loader = dataloader.DataLoader(dataset=val_dataset,
                                       batch_size=batch_size,
                                       collate_fn=custom_collate,
                                       shuffle=False)
    
    #tensorflow_logger fix
    #tensorboard_logger.clean_default_logger()
    try:
      tensorboard_logger.configure('logs/' + hash_code)
    except Exception:
      print('tensorboard_logger already configured!')
    

    sm, mord_ft, non_mord_ft, label = next(iter(train_loader))
    smiles_len = int(non_mord_ft.shape[1]/42)

    criterion = nn.BCELoss()
    united_net = UnitedNet(dense_dim=train_dataset.get_dim('mord'), smiles_len=smiles_len,
                           use_mat=use_mat, use_mord=use_mord).to(train_device)

    if opt_type == 'sgd':
        opt = optim.SGD(united_net.parameters(),
                        lr=lr,
                        momentum=0.99)
    elif opt_type == 'adam':
        opt = optim.Adam(united_net.parameters(),
                         lr=lr)

    min_loss = 100  # arbitary large number
    early_stop_count = 0
    for e in range(n_epoch):
        train_losses = []
        val_losses = []
        train_outputs = []
        val_outputs = []
        train_labels = []
        val_labels = []
        print(e, '--', 'TRAINING ==============>')
        for i, (sm, mord_ft, non_mord_ft, label) in enumerate(train_loader):
            
            sm, mask, br_mask = get_array(sm)
            sm = sm.to(val_device)
            mask = mask.to(val_device)
            br_mask = br_mask.to(val_device)
            
            united_net.train()
            mord_ft = mord_ft.float().to(train_device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(train_device)
            mat_ft = non_mord_ft.squeeze(1).float().to(train_device)
            # print(label)
            label = label.float().to(train_device)

            # Forward
            opt.zero_grad()
            outputs = united_net(non_mord_ft, mord_ft, mat_ft, smiles = sm, mask=mask, branch_mask=br_mask)
            
            outputs = torch.squeeze(outputs)
            
            loss = criterion(outputs, label)
            train_losses.append(float(loss.item()))
            train_outputs.extend(outputs)
            train_labels.extend(label)

            # Parameters update
            loss.backward()
            opt.step()

        # Validate after each epoch
        print('EPOCH', e, '--', 'VALIDATION ==============>')
        for i, (sm, mord_ft, non_mord_ft, label) in enumerate(val_loader):
            
            
            united_net.eval()
            
            sm, mask, br_mask = get_array(sm)
            sm = sm.to(val_device)
            mask = mask.to(val_device)
            br_mask = br_mask.to(val_device)
            
            mord_ft = mord_ft.float().to(val_device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(val_device)
            mat_ft = non_mord_ft.squeeze(1).float().to(train_device)
            label = label.float().to(val_device)
            
            with torch.no_grad():
                outputs = united_net(non_mord_ft, mord_ft, mat_ft, smiles=sm, mask=mask, branch_mask=br_mask)
                
                outputs = torch.squeeze(outputs)
                
                loss = criterion(outputs, label)
                val_losses.append(float(loss.item()))
                val_outputs.extend(outputs)
                val_labels.extend(label)

        train_outputs = torch.stack(train_outputs)
        val_outputs = torch.stack(val_outputs)
        train_labels = torch.stack(train_labels)
        val_labels = torch.stack(val_labels)
        tensorboard_logger.log_value('train_loss', sum(train_losses) / len(train_losses), e + 1)
        tensorboard_logger.log_value('val_loss', sum(val_losses) / len(val_losses), e + 1)
        print('{"metric": "train_loss", "value": %f, "epoch": %d}' % (sum(train_losses) / len(train_losses), e + 1))
        print('{"metric": "val_loss", "value": %f, "epoch": %d}' % (sum(val_losses) / len(val_losses), e + 1))
        for key in metrics.keys():
            train_metric = metrics[key](train_labels, train_outputs)
            val_metric = metrics[key](val_labels, val_outputs)
            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('train_' + key, train_metric, e + 1))
            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('val_' + key, val_metric, e + 1))
            tensorboard_logger.log_value('train_{}'.format(key),
                                         train_metric, e + 1)
            tensorboard_logger.log_value('val_{}'.format(key),
                                         val_metric, e + 1)
        loss_epoch = sum(val_losses) / len(val_losses)
        if loss_epoch < min_loss:
            early_stop_count = 0
            min_loss = loss_epoch
            save_model(united_net, models_path, hash_code)
        else:
            early_stop_count += 1
            if early_stop_count > 30:
                print('Traning can not improve from epoch {}\tBest loss: {}'.format(e, min_loss))
                break

    train_metrics = {}
    val_metrics = {}
    for key in metrics.keys():
        train_metrics[key] = metrics[key](train_labels, train_outputs)
        val_metrics[key] = metrics[key](val_labels, val_outputs)

    return train_metrics, val_metrics


def predict(dataset, model_path, device='cpu'):
    loader = dataloader.DataLoader(dataset=dataset,
                                   batch_size=128,
                                   collate_fn=custom_collate,
                                   shuffle=False)
    
    sm, mord_ft, non_mord_ft, label = next(iter(loader))
    smiles_len = int(non_mord_ft.shape[1]/42)

    united_net = UnitedNet(dense_dim=dataset.get_dim('mord'), smiles_len=smiles_len, use_mat=True).to(device)
    united_net.load_state_dict(torch.load(model_path, map_location=device))
    # EVAL_MODE
    united_net.eval()
    probas = []
    for i, (sm, mord_ft, non_mord_ft, label) in enumerate(loader):
        with torch.no_grad():
            
            sm, mask, br_mask = get_array(sm)
            sm = sm.to(val_device)
            mask = mask.to(val_device)
            br_mask = br_mask.to(val_device)
            
            mord_ft = mord_ft.float().to(device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(device)
            mat_ft = non_mord_ft.squeeze(1).float().to(device)
            # Forward to get smiles and equivalent weights
            proba = united_net(non_mord_ft, mord_ft, mat_ft).cpu()
            probas.append(proba)
    print('Forward done !!!')
    probas = np.concatenate(probas)
    return probas


def plot_roc_curve(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc_roc = metrics.roc_auc_score(y_true, y_pred)
    print('AUC: {:4f}'.format(auc_roc))
    plt.plot(fpr, tpr)
    plt.savefig('vis/ROC_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure


def plot_precision_recall(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    precisions, recalls, thresholds = precision_recall_curve(y_true, y_pred)
    plt.plot(thresholds, precisions[:-1], label="Precision")
    plt.plot(thresholds, recalls[:-1], label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])
    plt.savefig('vis/PR_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure


if torch.cuda.is_available():
    train_device = 'cuda'
    val_device = 'cuda'
else:
    train_device = 'cpu'
    val_device = 'cpu'

In [None]:
#Hashcode for tf.events
hashcode = 'TEST'

train_validate_united(train_dataset,
                      val_dataset,
                      train_device,
                      val_device,
                      False, #Use mat feature (True) or not (False)
                      False, #Use mord feature (True) or not (False)
                      'adam', #Optimizer adam ('adam') or sgd ('sgd')
                      int(500), #Number of epochs
                      int(128), #Batch size
                      {'sensitivity': sensitivity, 'specificity': specificity,
                        'accuracy': accuracy, 'mcc': mcc, 'auc': auc},
                      hashcode, #Hashcode for tf.events
                      1e-5) #Learning rate

tensorboard_logger already configured!
Closing files ...
Closing files ...
Closing files ...
Closing files ...
{"metric": "train_loss", "value": 0.605204, "epoch": 1}
{"metric": "val_loss", "value": 0.605946, "epoch": 1}
{"metric": "train_sensitivity", "value": 1.000000, "epoch": 1}
{"metric": "val_sensitivity", "value": 1.000000, "epoch": 1}
{"metric": "train_specificity", "value": 0.000000, "epoch": 1}
{"metric": "val_specificity", "value": 0.000000, "epoch": 1}
{"metric": "train_accuracy", "value": 0.269650, "epoch": 1}
{"metric": "val_accuracy", "value": 0.272602, "epoch": 1}
{"metric": "train_mcc", "value": 0.000000, "epoch": 1}
{"metric": "val_mcc", "value": 0.000000, "epoch": 1}
{"metric": "train_auc", "value": 0.491777, "epoch": 1}
{"metric": "val_auc", "value": 0.500636, "epoch": 1}
Save done!
{"metric": "train_loss", "value": 0.582091, "epoch": 2}
{"metric": "val_loss", "value": 0.603345, "epoch": 2}
{"metric": "train_sensitivity", "value": 1.000000, "epoch": 2}
{"metric": "v

In [None]:
#single_run.py

import argparse
import torch
import torch.nn as nn
import tensorboard_logger
from torch.utils.data import dataloader
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import warnings
from sklearn.metrics import precision_recall_curve
plt.switch_backend('agg')
warnings.filterwarnings('ignore')

#saved models
models_path = '/content'

def train_validate_united(train_dataset,
                          val_dataset,
                          train_device,
                          val_device,
                          use_mat,
                          use_mord,
                          opt_type,
                          n_epoch,
                          batch_size,
                          metrics,
                          hash_code,
                          lr):
    train_loader = dataloader.DataLoader(dataset=train_dataset,
                                         batch_size=batch_size,
                                         collate_fn=custom_collate,
                                         shuffle=False)

    val_loader = dataloader.DataLoader(dataset=val_dataset,
                                       batch_size=batch_size,
                                       collate_fn=custom_collate,
                                       shuffle=False)
    
    #tensorflow_logger fix
    #tensorboard_logger.clean_default_logger()
    try:
      tensorboard_logger.configure('logs/' + hash_code)
    except Exception:
      print('tensorboard_logger already configured!')
    

    sm, mord_ft, non_mord_ft, label = next(iter(train_loader))
    smiles_len = int(non_mord_ft.shape[1]/42)

    criterion = nn.BCELoss()
    united_net = UnitedNet(dense_dim=train_dataset.get_dim('mord'), smiles_len=smiles_len,
                           use_mat=use_mat, use_mord=use_mord).to(train_device)

    if opt_type == 'sgd':
        opt = optim.SGD(united_net.parameters(),
                        lr=lr,
                        momentum=0.99)
    elif opt_type == 'adam':
        opt = optim.Adam(united_net.parameters(),
                         lr=lr)

    min_loss = 100  # arbitary large number
    early_stop_count = 0
    for e in range(n_epoch):
        train_losses = []
        val_losses = []
        train_outputs = []
        val_outputs = []
        train_labels = []
        val_labels = []
        print(e, '--', 'TRAINING ==============>')
        for i, (mord_ft, non_mord_ft, label) in enumerate(train_loader):
            united_net.train()
            mord_ft = mord_ft.float().to(train_device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(train_device)
            mat_ft = non_mord_ft.squeeze(1).float().to(train_device)
            # print(label)
            label = label.float().to(train_device)

            # Forward
            opt.zero_grad()
            outputs = united_net(non_mord_ft, mord_ft, mat_ft)
            
            outputs = torch.squeeze(outputs)
            
            loss = criterion(outputs, label)
            train_losses.append(float(loss.item()))
            train_outputs.extend(outputs)
            train_labels.extend(label)

            # Parameters update
            loss.backward()
            opt.step()

        # Validate after each epoch
        print('EPOCH', e, '--', 'VALIDATION ==============>')
        for i, (mord_ft, non_mord_ft, label) in enumerate(val_loader):
            united_net.eval()
            mord_ft = mord_ft.float().to(val_device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(val_device)
            mat_ft = non_mord_ft.squeeze(1).float().to(train_device)
            label = label.float().to(val_device)
            
            with torch.no_grad():
                outputs = united_net(non_mord_ft, mord_ft, mat_ft)
                
                outputs = torch.squeeze(outputs)
                
                loss = criterion(outputs, label)
                val_losses.append(float(loss.item()))
                val_outputs.extend(outputs)
                val_labels.extend(label)

        train_outputs = torch.stack(train_outputs)
        val_outputs = torch.stack(val_outputs)
        train_labels = torch.stack(train_labels)
        val_labels = torch.stack(val_labels)
        tensorboard_logger.log_value('train_loss', sum(train_losses) / len(train_losses), e + 1)
        tensorboard_logger.log_value('val_loss', sum(val_losses) / len(val_losses), e + 1)
        print('{"metric": "train_loss", "value": %f, "epoch": %d}' % (sum(train_losses) / len(train_losses), e + 1))
        print('{"metric": "val_loss", "value": %f, "epoch": %d}' % (sum(val_losses) / len(val_losses), e + 1))
        for key in metrics.keys():
            train_metric = metrics[key](train_labels, train_outputs)
            val_metric = metrics[key](val_labels, val_outputs)
            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('train_' + key, train_metric, e + 1))
            print('{"metric": "%s", "value": %f, "epoch": %d}' % ('val_' + key, val_metric, e + 1))
            tensorboard_logger.log_value('train_{}'.format(key),
                                         train_metric, e + 1)
            tensorboard_logger.log_value('val_{}'.format(key),
                                         val_metric, e + 1)
        loss_epoch = sum(val_losses) / len(val_losses)
        if loss_epoch < min_loss:
            early_stop_count = 0
            min_loss = loss_epoch
            save_model(united_net, models_path, hash_code)
        else:
            early_stop_count += 1
            if early_stop_count > 30:
                print('Traning can not improve from epoch {}\tBest loss: {}'.format(e, min_loss))
                break

    train_metrics = {}
    val_metrics = {}
    for key in metrics.keys():
        train_metrics[key] = metrics[key](train_labels, train_outputs)
        val_metrics[key] = metrics[key](val_labels, val_outputs)

    return train_metrics, val_metrics


def predict(dataset, model_path, device='cpu'):
    loader = dataloader.DataLoader(dataset=dataset,
                                   batch_size=128,
                                   collate_fn=custom_collate,
                                   shuffle=False)
    
    sm, mord_ft, non_mord_ft, label = next(iter(loader))
    smiles_len = int(non_mord_ft.shape[1]/42)

    united_net = UnitedNet(dense_dim=dataset.get_dim('mord'), smiles_len=smiles_len, use_mat=True).to(device)
    united_net.load_state_dict(torch.load(model_path, map_location=device))
    # EVAL_MODE
    united_net.eval()
    probas = []
    for i, (mord_ft, non_mord_ft, label) in enumerate(loader):
        with torch.no_grad():
            mord_ft = mord_ft.float().to(device)
            non_mord_ft = non_mord_ft.view((-1, int(non_mord_ft.shape[1]/42), 42)).float().to(device)
            mat_ft = non_mord_ft.squeeze(1).float().to(device)
            # Forward to get smiles and equivalent weights
            proba = united_net(non_mord_ft, mord_ft, mat_ft).cpu()
            probas.append(proba)
    print('Forward done !!!')
    probas = np.concatenate(probas)
    return probas


def plot_roc_curve(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc_roc = metrics.roc_auc_score(y_true, y_pred)
    print('AUC: {:4f}'.format(auc_roc))
    plt.plot(fpr, tpr)
    plt.savefig('vis/ROC_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure


def plot_precision_recall(y_true, y_pred, hashcode=''):

    if not os.path.exists('vis/'):
        os.makedirs('vis/')

    precisions, recalls, thresholds = precision_recall_curve(y_true, y_pred)
    plt.plot(thresholds, precisions[:-1], label="Precision")
    plt.plot(thresholds, recalls[:-1], label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])
    plt.savefig('vis/PR_{}'.format(hashcode + '.png'))
    plt.clf()  # Clear figure


if torch.cuda.is_available():
    train_device = 'cuda'
    val_device = 'cuda'
else:
    train_device = 'cpu'
    val_device = 'cpu'

In [None]:
#nets.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class UnitedNet(nn.Module):
    def __init__(self, dense_dim, smiles_len, use_mord=True, use_mat=True, infer=False, dir_path=None, vis_thresh=0.2):
        super(UnitedNet, self).__init__()
        self.use_mord = use_mord
        self.use_mat = use_mat
        self.infer = infer
        self.vis_thresh = vis_thresh
        self.dir_path = dir_path
        self.smiles_len = smiles_len
        
        if self.dir_path:
            self.smile_out_f = open(os.path.join(self.dir_path, 'smiles.txt'), 'w')
            self.weight_f = open(os.path.join(self.dir_path, 'weight.txt'), 'w')

        # PARAMS FOR CNN NET
        # Convolutionals
        self.conv_conv1 = nn.Conv1d(42, 64, kernel_size=5, padding=2)
        self.conv_pool = nn.MaxPool1d(5)
        self.conv_conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

        # Fully connected
        
        self.conv_fc = nn.Linear(128*12, self.smiles_len)#self.smiles_len // 5 // 5 , 120)

        # Batch norms
        self.conv_batch_norm1 = nn.BatchNorm1d(64)
        self.conv_batch_norm2 = nn.BatchNorm1d(128)
        # # PARAMS FOR CNN NET
        # # Convolutionals
        # self.conv_conv1 = nn.Conv2d(1, 6, kernel_size=3) #smiles_len-2
        # self.conv_pool = nn.MaxPool2d(2, 2) #smiles_len/2
        # self.conv_conv2 = nn.Conv2d(6, 16, kernel_size=3) #smiles_len-2

        # # Fully connected
        # self.conv_fc = nn.Linear(16 * 9 * int(((self.smiles_len-2)/2 - 2)/2), self.smiles_len)

        # # Batch norms
        # self.conv_batch_norm1 = nn.BatchNorm2d(6)
        # self.conv_batch_norm2 = nn.BatchNorm2d(16)

        # PARAMS FOR DENSE NET
        # Fully connected
        if self.use_mord:
            self.dense_fc1 = nn.Linear(dense_dim, 512)
            self.dense_fc2 = nn.Linear(512, 128)
            self.dense_fc3 = nn.Linear(128, 64)

            # Batch norms
            self.dense_batch_norm1 = nn.BatchNorm1d(512)
            self.dense_batch_norm2 = nn.BatchNorm1d(128)
            self.dense_batch_norm3 = nn.BatchNorm1d(64)

            # Dropouts
            self.dense_dropout = nn.Dropout()

        # PARAMS FOR ATTENTION NET
        if self.use_mat:
            #self.att_fc = nn.Linear(256, 1)
            self.att_fc = nn.Linear(self.smiles_len+42+64, 1)
        else:
            self.comb_fc_alt = nn.Linear(128, 1)

        # PARAMS FOR COMBINED NET
        if self.use_mord:
            self.comb_fc = nn.Linear(self.smiles_len+64, 1)
        else:
            self.comb_fc = nn.Linear(self.smiles_len, 1)

    def forward(self, x_non_mord, x_mord, x_mat, smiles=None):
        # FORWARD CNN
        x_non_mord = torch.transpose(x_non_mord, -1,-2)

        x_non_mord = self.conv_conv1(x_non_mord)
        x_non_mord = self.conv_batch_norm1(x_non_mord)
        x_non_mord = F.relu(x_non_mord)
        x_non_mord = self.conv_pool(x_non_mord)

        x_non_mord = self.conv_conv2(x_non_mord)
        x_non_mord = self.conv_batch_norm2(x_non_mord)
        x_non_mord = F.relu(x_non_mord)
        x_non_mord = self.conv_pool(x_non_mord)

        # print(x_non_mord.shape)
        x_non_mord = x_non_mord.view(x_non_mord.size(0), -1)
        if self.use_mat:
            x_non_mord = F.sigmoid(self.conv_fc(x_non_mord))
        else:
            x_non_mord = F.relu(self.conv_fc(x_non_mord))

        # FORWARD DENSE
        if self.use_mord:
            x_mord = F.relu(self.dense_fc1(x_mord))
            x_mord = self.dense_batch_norm1(x_mord)
            x_mord = self.dense_dropout(x_mord)

            x_mord = F.relu(self.dense_fc2(x_mord))
            x_mord = self.dense_batch_norm2(x_mord)
            x_mord = self.dense_dropout(x_mord)

            x_mord = F.relu(self.dense_fc3(x_mord))
            x_mord = self.dense_batch_norm3(x_mord)
            x_mord = self.dense_dropout(x_mord)

        # FORWARD ATTENTION
        if self.use_mat:
            x_mat = torch.bmm(x_mat.permute(0, 2, 1), x_non_mord.unsqueeze(-1)).squeeze(-1)
            x_mat = torch.cat([x_mat, x_non_mord], dim=1)

            if self.use_mord:
                x_comb = torch.cat([x_mat, x_mord], dim=1)
                probs = torch.sigmoid(self.att_fc(x_comb))
                if self.infer:
                    if not smiles:
                        raise ValueError('Please input smiles')
                    alphas = x_comb.cpu().detach().numpy().tolist()
                    alphas = ["\t".join([str(round(elem, 4)) for elem in seq]) for seq in alphas]
                    prob_list = probs.cpu().detach().numpy().tolist()
                    for smile, alpha, prob in zip(smiles, alphas, prob_list):
                        if prob[0] > self.vis_thresh:
                            self.weight_f.write(alpha + '\n')
                            self.smile_out_f.write(smile + '\n')
                return probs
            else:
                return torch.sigmoid(self.comb_fc(x_mat))
        else:
            if self.use_mord:
                x_comb = torch.cat([x_non_mord, x_mord], dim=1)
            else:
                x_comb = x_non_mord
            return torch.sigmoid(self.comb_fc(x_comb))

    def __del__(self):
        print('Closing files ...')
        if hasattr(self, 'weight_f'):
            self.weight_f.close()
        if hasattr(self, 'smile_out_f'):
            self.smile_out_f.close()

In [None]:
!pip install torch torchvision

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs

# Test

In [None]:
del train_dataset

In [None]:
del val_dataset

In [None]:
pred_dataset = ANYDataset(path2data)

In [None]:
trained_model_path = f'/content/model_{hashcode}_BEST'

In [None]:
y_pred = predict(pred_dataset, trained_model_path, train_device)
y_true = pred_dataset.label

#plots are saved in /content/vis
plot_roc_curve(y_true, y_pred, trained_model_path.split('/')[-1])
plot_precision_recall(y_true, y_pred, trained_model_path.split('/')[-1])