In [1]:
import os
# Expose only GPU 0 to this process; set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
# Seed torch's RNG (numpy is re-seeded per epoch inside train()).
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
# NOTE(review): very high recursion limit — presumably needed by a recursive helper or a
# deep pickle; confirm it is still required.
sys.setrecursionlimit(50000)
import pickle
# Let cuDNN autotune its kernels (faster, but runs are not bit-for-bit reproducible).
torch.backends.cudnn.benchmark = True
# Make plain torch.Tensor(...) allocate float32 on the GPU; the notebook relies on this
# throughout (e.g. torch.Tensor(x_atom) inside train()).
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
# NOTE(review): purpose unclear from here (affects loading of whole pickled modules); confirm it is needed.
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
# Project-local modules: the three networks and the SMILES featurisation helpers.
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
# RDKit (chemistry), plotting and project utility imports.
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
# NOTE(review): duplicate of the `from rdkit import Chem` import at the top of this cell.
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
# NOTE(review): `sascorer` is imported as a top-level module, so it must be on sys.path.
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
# Silence all RDKit log output (modify_atoms deliberately attempts invalid edits).
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
# ---- run configuration -------------------------------------------------------
train_filename = "./data/benchmark/Kd_P31389_1_25_train.csv"
test_filename = "./data/benchmark/Kd_P31389_1_25_test.csv"
test_active = 25   # presumably the active count used as active_num by the ranking metrics — confirm
val_rate = 0.15    # fraction of the training file sampled out as the validation split
random_seed = 1    # random_state of that validation split

# Dataset stem: the training file's base name minus its last 10 characters,
# i.e. minus the "_train.csv" suffix.
file_list1 = train_filename.split('/')
file1 = file_list1[-1][:-len("_train.csv")]
number = '_run_0'

# Every output location of this run is tagged "0_GAFSE_<stem><run>".
model_file = f"model_file/0_GAFSE_{file1}{number}"
log_dir = f"log/0_GAFSE_{file1}{number}"
result_dir = f"./result/0_GAFSE_{file1}{number}"

print(file1)
print(model_file)

Kd_P31389_1_25
model_file/0_GAFSE_Kd_P31389_1_25_run_0


In [4]:
# task_name = 'Malaria Bioactivity'
tasks = ['value']  # the single regression target column

# Paths derived from the training file name.
filename = train_filename.replace('.csv','')
feature_filename = train_filename.replace('.csv','.pickle')
prefix_filename = train_filename.split('/')[-1].replace('.csv','')

# Only the first two CSV columns are read, renamed to smiles / value.
train_df = pd.read_csv(
    train_filename,
    header=0,
    names=["smiles", "value"],
    usecols=[0, 1],
)
print(train_df.head(5))
# print(train_df.iloc(1))
def add_canonical_smiles(train_df):
    """Drop rows whose SMILES RDKit cannot parse and add a ``cano_smiles`` column.

    Each entry of ``train_df.smiles`` is parsed once.  Entries that fail (RDKit
    returns ``None`` or raises, e.g. for a missing value) are printed and their rows
    removed.  Surviving rows get the isomeric canonical SMILES in ``cano_smiles``.

    Args:
        train_df: DataFrame with a ``smiles`` column.

    Returns:
        A new DataFrame (the input is left untouched) with the parseable rows in
        their original order and with their original index labels, plus
        ``cano_smiles``.
    """
    smilesList = train_df.smiles.values
    print("number of all smiles: ", len(smilesList))
    keep_mask = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:
            mol = Chem.MolFromSmiles(smiles)
            # MolFromSmiles signals a parse failure by returning None, not by raising
            cano = None if mol is None else Chem.MolToSmiles(mol, isomericSmiles=True)
        except Exception:
            # e.g. non-string entries such as NaN/None
            cano = None
        if cano is None:
            print(smiles)
            keep_mask.append(False)
        else:
            keep_mask.append(True)
            canonical_smiles_list.append(cano)
    print("number of successfully processed smiles: ", sum(keep_mask))
    # A positional mask keeps the rows aligned with canonical_smiles_list.  .copy()
    # makes the column assignment below act on an independent frame
    # (no SettingWithCopyWarning).
    train_df = train_df[np.array(keep_mask, dtype=bool)].copy()
    train_df['cano_smiles'] = canonical_smiles_list
    return train_df
# print(train_df)
# Drop unparsable SMILES from the training frame and add its canonical-SMILES column.
train_df = add_canonical_smiles(train_df)

print(train_df.head())
# Optional atom-count histogram, currently disabled.
# NOTE(review): no notebook-level `atom_num_dist` is defined, so these lines would raise
# NameError if simply uncommented.
# plt.figure(figsize=(5, 3))
# sns.set(font_scale=1.5)
# ax = sns.distplot(atom_num_dist, bins=28, kde=False)
# plt.tight_layout()
# # plt.savefig("atom_num_dist_"+prefix_filename+".png",dpi=200)
# plt.show()
# plt.close()


                                              smiles     value
0    CN(C)CCCN1C(SCC1=O)C2=CC(=CC=C2)[N+](=O)[O-].Cl -3.600000
1                          CC(=O)NC(=NCCCC1=CN=CN1)N -3.930000
2       C1CCC(CC1)CCCOCCCC2=CN=CN2.C(=CC(=O)O)C(=O)O -3.900000
3  CN(C)C1=CC=CC2=C1C=CC=C2S(=O)(=O)NCCCCCCN(C)CC... -1.249932
4    C1=CC=C2C(=C1)C(=CN2)CCC(=O)NC(=NCCCC3=CN=CN3)N -2.600003
number of all smiles:  178
number of successfully processed smiles:  178
                                              smiles     value  \
0    CN(C)CCCN1C(SCC1=O)C2=CC(=CC=C2)[N+](=O)[O-].Cl -3.600000   
1                          CC(=O)NC(=NCCCC1=CN=CN1)N -3.930000   
2       C1CCC(CC1)CCCOCCCC2=CN=CN2.C(=CC(=O)O)C(=O)O -3.900000   
3  CN(C)C1=CC=CC2=C1C=CC=C2S(=O)(=O)NCCCCCCN(C)CC... -1.249932   
4    C1=CC=C2C(=C1)C(=CN2)CCC(=O)NC(=NCCCC3=CN=CN3)N -2.600003   

                                         cano_smiles  
0       CN(C)CCCN1C(=O)CSC1c1cccc([N+](=O)[O-])c1.Cl  
1                         CC(=O)NC(N)=NCCC

In [5]:
# Filename-safe timestamp (":" and " " replaced). In the visible cells it is only
# referenced by the commented-out tensorboard writer.
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

# Optimiser hyper-parameters are stored as negative base-10 exponents: the optimisers
# are built with lr=10**(-learning_rate) and weight_decay=10**-weight_decay.
weight_decay = 4.3 # L2-penalty exponent: weight_decay = 10**-4.3 (about 5e-5)
learning_rate = 4 # learning-rate exponent: lr = 1e-4; also perturb_feature's default lamda
radius = 2 # default: 2
T = 1 # passed to Fingerprint and GRN
per_task_output_units_num = 1 # for regression model
output_units_num = len(tasks) * per_task_output_units_num

In [6]:
test_df = pd.read_csv(test_filename,header=0,names=["smiles","value"],usecols=[0,1])
test_df = add_canonical_smiles(test_df)
# Leakage check: report test molecules whose canonical SMILES also occurs in the
# training data.  `x in some_series` tests the Series *index* labels, not its values,
# so membership is checked against a set of the training values instead.
_train_cano_smiles = set(train_df["cano_smiles"])
for smi in test_df["cano_smiles"]:
    if smi in _train_cano_smiles:
        print("same smiles:",smi)

print(test_df.shape)
print(test_df.head())

number of all smiles:  133
number of successfully processed smiles:  133
(133, 3)
                                              smiles     value  \
0  CN(CCCCCCNC(=O)CCCCCNC1=CC=C(C2=NON=C12)[N+](=... -0.290035   
1                 CN(C)CCCN1C(SCC1=O)C2=CC=C(C=C2)Cl -3.400001   
2  C1CCN(CC1)CC2=CC(=CC=C2)OCCCN=C(NCCNC(=O)C3=CC... -3.199999   
3  C1CN(CCN1CCCOC2=CC3=C(C=C2)C(=C(C(=O)O3)[N+](=... -1.100026   
4                         C1CCC(CC1)CCCOCCCC2=CN=CN2 -3.900000   

                                         cano_smiles  
0  COc1ccc(CN(CCN(C)CCCCCCNC(=O)CCCCCNc2ccc([N+](...  
1                    CN(C)CCCN1C(=O)CSC1c1ccc(Cl)cc1  
2  N#CNC(=NCCCOc1cccc(CN2CCCCC2)c1)NCCNC(=O)c1cc(...  
3  O=c1oc2cc(OCCCN3CCN(Cc4ccccc4)CC3)ccc2c(O)c1[N...  
4                        c1ncc(CCCOCCCC2CCCCC2)[nH]1  


In [7]:
print(feature_filename)
print(filename)
total_df = pd.concat([train_df,test_df],axis=0)
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
# Features are always recomputed here; the pickle cache is not reused.
# NOTE(review): raw SMILES are passed in while coverage below is checked on
# cano_smiles — relies on save_smiles_dicts keying its output by canonical SMILES; confirm.
feature_dicts = save_smiles_dicts(total_smilesList,filename)
# train_df and test_df each keep row labels taken from their own read_csv index, so
# labels repeat in total_df.  Split on a boolean mask: drop(remained_df.index) would also
# discard uncovered rows that merely share a label with a covered row.
_covered_mask = total_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())
remained_df = total_df[_covered_mask]
uncovered_df = total_df[~_covered_mask]

./data/benchmark/Kd_P31389_1_25_train.pickle
./data/benchmark/Kd_P31389_1_25_train
311
feature dicts file saved as ./data/benchmark/Kd_P31389_1_25_train.pickle


In [8]:
def _keep_featurized(frame):
    """Rows of ``frame`` whose ``cano_smiles`` was featurized, re-indexed 0..n-1.

    train() looks rows up with ``.loc[position]``, so the fresh RangeIndex matters.
    """
    featurized_smiles = feature_dicts['smiles_to_atom_mask'].keys()
    return frame[frame["cano_smiles"].isin(featurized_smiles)].reset_index(drop=True)

# Hold out a random val_rate fraction of the training molecules for validation, then
# keep only featurized molecules in every split.
val_df = train_df.sample(frac=val_rate, random_state=random_seed)
train_df = _keep_featurized(train_df.drop(val_df.index))
val_df = _keep_featurized(val_df)
test_df = _keep_featurized(test_df)
print(train_df.shape, val_df.shape, test_df.shape)

(151, 3) (27, 3) (133, 3)


In [9]:
# Featurise one molecule only to read the atom/bond feature widths (last axis) that
# the model constructors need.
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
# Global regression loss; perturb_feature() reads this global directly.
loss_function = nn.MSELoss()
# model: Fingerprint encoder/regressor; amodel: AFSE perturbation head; gmodel: GRN generator.
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# Disabled alternative: one Adam over model+gmodel.
# NOTE(review): its 'weight_decay ' keys carry a trailing space, so the decay would not
# take effect if this were re-enabled as written.
# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

# One Adam optimiser per network; learning_rate / weight_decay are negative base-10
# exponents (lr = 1e-4 and weight_decay = 10**-4.3 with the values set earlier).
optimizer = optim.Adam(params=model.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

# Number of trainable parameters in the Fingerprint model (computed, not printed).
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [10]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    """Plot predictions and true targets against row position in ``dataset`` order, then show it.

    Draws ``p`` (red) and the true ``tasks[0]`` values (blue) as scatter plots, a
    vertical line at position ``k - 1``, a horizontal line at ``k_predict``, and
    titles the figure with the epoch ``i`` and the metric ``acc``.
    """
    positions = np.arange(0, len(dataset), 1)
    true_values = dataset[tasks[0]].astype(float).tolist()
    plt.scatter(positions, p, marker='.', s=6, color='r', label='predict')
    plt.scatter(positions, true_values, s=6, marker=',', color='blue', label='p_value')
    plt.axvline(x=k - 1, ls="-", c="black")  # vertical cut-off at the k-th row
    plt.plot(positions, np.ones(len(dataset)) * k_predict, '-', color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i, acc))
    plt.legend(loc=3, fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    """Top-k accuracy of the ranking induced by ``predict``.

    Side effect: stores ``predict`` in ``df['predict']`` (the caller's frame).

    Args:
        df: DataFrame whose true target is column ``tasks[0]``.
        predict: predictions aligned with the rows of ``df``.
        k: number of top-predicted rows inspected.
        active_num: number of actives.  For k <= active_num the threshold is the k-th
            best true value; for k > active_num it is the active_num-th best.
        show_flag: if True, plot via ``sorted_show_pik``.
        i: epoch number shown in the plot title.

    Returns:
        (acc, fp): acc is the fraction of the k top-predicted rows whose true value is
        at least the threshold; fp is the fraction whose true value equals -4.1
        (presumably the placeholder value for inactives — see topk_recall).
    """
    target = tasks[0]
    df['predict'] = predict
    top_k = df.sort_values(by='predict', ascending=False).iloc[:k]
    true_sort = df.sort_values(by=target, ascending=False)
    top_k_true = top_k[target].values
    if k > active_num:
        threshold = true_sort[target].values[active_num-1]
    else:
        threshold = true_sort[target].values[k-1]
    acc = np.count_nonzero(top_k_true >= threshold) / k
    fp = np.count_nonzero(top_k_true == -4.1) / k
    if show_flag:
        # horizontal reference line: the lowest prediction that still made the top k
        # (k_predict was previously never defined, so plotting raised NameError)
        k_predict = top_k['predict'].values[-1]
        sorted_show_pik(true_sort, true_sort['predict'], k, k_predict, i, acc)
    return acc, fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    """Recall of actives among the k top-predicted rows.

    Side effect: stores ``predict`` in ``df['predict']`` (the caller's frame).

    Rows whose true ``tasks[0]`` value is greater than -4.1 count as actives
    (-4.1 presumably being the placeholder value for inactives).

    Args:
        df: DataFrame whose true target is column ``tasks[0]``.
        predict: predictions aligned with the rows of ``df``.
        k: number of top-predicted rows inspected.
        active_num: total number of actives (recall denominator).
        show_flag: if True, plot via ``sorted_show_pik``.
        i: epoch number shown in the plot title.

    Returns:
        (acc, fp): acc = actives in the top k / active_num; fp = rows equal to -4.1
        in the top k / k.
    """
    target = tasks[0]
    df['predict'] = predict
    top_k = df.sort_values(by='predict', ascending=False).iloc[:k]
    top_k_true = top_k[target].values
    acc = np.count_nonzero(top_k_true > -4.1) / active_num
    fp = np.count_nonzero(top_k_true == -4.1) / k
    if show_flag:
        true_sort = df.sort_values(by=target, ascending=False)
        # horizontal reference line: the lowest prediction that still made the top k
        k_predict = top_k['predict'].values[-1]
        sorted_show_pik(true_sort, true_sort['predict'], k, k_predict, i, acc)
    return acc, fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    """Top-k score: ``topk_recall`` once k exceeds the number of actives, else ``topk_acc2``.

    Returns the chosen metric's ``(acc, fp)`` pair unchanged.
    """
    metric = topk_recall if k > active_num else topk_acc2
    return metric(df, predict, k, active_num, show_flag, i)

def weighted_top_index(df, predict, active_num):
    """Mean top-k score over every cut-off k, linearly down-weighting large k.

    For k = 1..n (n = len(df)) the accuracy from ``topk_acc_recall`` is scaled by
    (n - k) / n; the mean of those n weighted values is returned (NaN when n == 0).
    """
    n = len(df)
    weighted = np.array([
        topk_acc_recall(df, predict, k, active_num)[0] * ((n - k) / n)
        for k in np.arange(1, n + 1, 1)
    ])
    return np.sum(weighted) / weighted.shape[0]

def AP(df, predict, active_num):
    """Average precision of the ranking induced by ``predict`` (all-point interpolation).

    Precision at each cut-off k comes from ``topk_acc2`` and recall from
    ``topk_recall``.  The precision curve is replaced by its running maximum taken
    from the right, and AP is the sum of precision times the recall increment at
    every point where recall changes.
    """
    precisions, recalls = [], []
    for k in np.arange(1, len(df) + 1, 1):
        precisions.append(topk_acc2(df, predict, k, active_num)[0])
        recalls.append(topk_recall(df, predict, k, active_num)[0])
    # sentinels: recall runs from 0 to 1, precision is 0 at both ends
    mrec = np.concatenate(([0.], recalls, [1.]))
    mpre = np.concatenate(([0.], precisions, [0.]))
    # precision envelope: each entry becomes the max of itself and everything after it
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    # positions where recall changes
    steps = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

In [11]:
def caculate_r2(y,predict):
    """Squared Pearson correlation between targets ``y`` and ``predict``.

    Both inputs are flattened into float32 CPU column vectors.  Returns a 1x1 tensor
    holding cov(y, p)^2 / (var(y) * var(p)), unnormalised sums throughout; the
    result is NaN when either input is constant.
    """
    targets = torch.FloatTensor(y).reshape(-1, 1)
    preds = torch.FloatTensor(predict).reshape(-1, 1)
    targets_c = targets - torch.mean(targets)
    preds_c = preds - torch.mean(preds)
    numerator = torch.mm(targets_c.t(), preds_c) ** 2
    denominator = torch.mm(targets_c.t(), targets_c) * torch.mm(preds_c.t(), preds_c)
    return numerator / denominator

In [12]:
from torch.autograd import Variable
def l2_norm(input, dim):
    """Scale ``input`` to unit L2 norm along ``dim`` (1e-6 added to the norm to avoid /0)."""
    magnitude = input.norm(dim=dim, keepdim=True)
    return input / (magnitude + 1e-6)

def normalize_perturbation(d,dim=-1):
    """Unit-L2-normalise the perturbation ``d`` along ``dim`` (default: last axis).

    The norm is offset by 1e-6 so that an all-zero ``d`` stays zero.
    """
    return d / (torch.norm(d, dim=dim, keepdim=True) + 1e-6)

def tanh(x):
    """Element-wise hyperbolic tangent of tensor ``x``.

    Delegates to ``torch.tanh``.  The explicit (e^x - e^-x) / (e^x + e^-x) form
    overflows to inf / inf = NaN once |x| exceeds roughly 88 in float32.
    """
    return torch.tanh(x)

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + e^-x) of tensor ``x``."""
    denominator = torch.exp(-x) + 1
    return 1 / denominator

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None):
    """Virtual-adversarial perturbation of features ``f`` through ``model`` (the AFSE head in train()).

    Steps:
      1. predict on the clean features (``d=0``); the detached prediction is the reference;
      2. draw a random direction ``eps`` (unit L2 norm along the last axis, times 1e-6),
         predict with ``d=+eps`` and ``d=-eps``, and backpropagate an MSE loss (the global
         ``loss_function``) pushing sigmoid(perturbed_prediction / reference) towards 1;
      3. use ``eps.grad`` as the adversarial direction: normalise it and scale by ``lamda``
         to obtain ``d_adv``;
      4. recompute the same loss at ``+d_adv`` / ``-d_adv`` -> ``vat_loss`` (graph kept).

    Args:
        f: features passed as ``model(feature=f, d=...)``. Shape is not fixed here —
            in train() it is the ``mol_feature`` output of the Fingerprint model.
        model: callable accepting ``feature``, ``d`` and optionally ``output_lr``.
        alpha: unused.
        lamda: length of ``d_adv``. NOTE: the default ``10**-learning_rate`` is evaluated
            once, when this cell runs, from the global ``learning_rate``.
        output_lr: also return ``max_lr`` from ``model(..., output_lr=True)``.
        output_plr: only honoured when ``output_lr`` is True; also return ``punish_lr``.
        y: targets for the extra regression loss computed when ``output_plr`` is True.

    Returns:
        ``(eps_adv, d_adv, vat_loss, mol_prediction)``; with ``output_lr`` additionally
        ``max_lr``; with ``output_lr`` and ``output_plr`` additionally ``punish_lr``.

    Side effects: runs ``.backward(retain_graph=True)`` (gradients reach ``model``'s
    parameters and, if ``f`` is part of an autograd graph, whatever produced ``f``) and
    calls ``optimizer_AFSE.zero_grad()`` on the global optimiser, which only clears the
    gradients of ``amodel``'s parameters.
    """
    mol_prediction = model(feature=f, d=0)
    # reference prediction, detached so the ratios below do not backprop through it
    pred = mol_prediction.detach()
#     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
    eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
    eps = Variable(eps, requires_grad=True)
    # Predict with a tiny random perturbation in both directions.
    eps_p = model(feature=f, d=eps.cuda())
    eps_p_ = model(feature=f, d=-eps.cuda())
    p_aux = nn.Sigmoid()(eps_p/(pred+1e-6))
    p_aux_ = nn.Sigmoid()(eps_p_/(pred+1e-6))
#     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
    loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
    loss.backward(retain_graph=True)

    # The gradient w.r.t. the random perturbation gives the adversarial direction.
    eps_adv = eps.grad
    # discard the gradients this auxiliary backward left on amodel's parameters
    optimizer_AFSE.zero_grad()
    # Use that direction, rescaled to length lamda, as the adversarial perturbation.
    eps_adv_normed = normalize_perturbation(eps_adv)
    d_adv = lamda * eps_adv_normed.cuda()
    if output_lr:
        # NOTE(review): f_p from this call is overwritten on the next line; only max_lr is kept.
        f_p, max_lr = model(feature=f, d=d_adv, output_lr=output_lr)
    f_p = model(feature=f, d=d_adv)
    f_p_ = model(feature=f, d=-d_adv)
    p = nn.Sigmoid()(f_p/(pred+1e-6))
    p_ = nn.Sigmoid()(f_p_/(pred+1e-6))
    vat_loss = loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    if output_lr:
        if output_plr:
            loss = loss_function(mol_prediction,y)
            loss.backward(retain_graph=True)
            optimizer_AFSE.zero_grad()
            # NOTE(review): mol_prediction was computed with d=0, i.e. without eps, so the
            # backward above leaves eps.grad unchanged; punish_lr is therefore the norm of
            # the batch-mean of eps_adv.
            punish_lr = torch.norm(torch.mean(eps.grad,0))
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    """Label every atom of ``mol`` with its own index through the 'molAtomMapNumber' property.

    Mutates ``mol`` in place and returns the same object.
    """
    for idx in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(idx)
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))
    return mol

def d_loss(f, pred, model, y_val):
    """Mean pairwise loss between predicted and true target differences.

    For every ordered pair (i, j) with i != j, ``model(feature_only=True,
    feature1=f[i], feature2=f[j])`` is compared with ``y_val[i] - y_val[j]`` via the
    global ``loss_function``; the sum is averaged over all n * (n - 1) pairs, where
    n = len(pred) (``pred`` is only used for its length).

    Returns:
        The averaged loss, or a zero tensor when fewer than two samples are given
        (there are no pairs to average, which used to divide by zero).
    """
    length = len(pred)
    if length < 2:
        return torch.zeros(())
    diff_loss = 0
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    return diff_loss / (length * (length - 1))

In [13]:
def CE(x,y):
    """Class-balanced binary cross-entropy, averaged over the last axis.

    Positives are the entries of ``y`` (iterated along its first axis) equal to 1.
    The positive term is weighted by the fraction of negatives and the negative term
    by the fraction of positives; 1e-6 guards both logarithms.
    """
    total = len(y)
    positives = sum(1 for idx in range(total) if y[idx] == 1)
    pos_weight = (total - positives) / total
    neg_weight = positives / total
    loss = -pos_weight * y * torch.log(x + 1e-6) - neg_weight * (1 - y) * torch.log(1 - x + 1e-6)
    return loss.mean(-1)

def weighted_CE_loss(x,y):
    """Cross-entropy of logits ``x`` against one-hot targets ``y`` with inverse-frequency class weights.

    Each class is weighted by 1 / (its mean along dim 0 of ``y`` + 1e-9); the target
    class of each row is the argmax of ``y`` over its last axis.
    """
    class_freq = y.detach().float().mean(0)
    criterion = nn.CrossEntropyLoss(weight=1 / (class_freq + 1e-9))
    return criterion(x, y.argmax(-1))

def generate_loss_function(refer_atom_list, x_atom, validity_mask, atom_list):
    """Batch-averaged GRN reconstruction loss.

    For each molecule i (first axis of ``x_atom``), l counts the rows of ``x_atom[i]``
    with non-zero features, and the first l rows are used:

    * class-weighted cross-entropy (``weighted_CE_loss``) between the element logits
      ``refer_atom_list[i, :l, :16]`` and the element one-hot ``x_atom[i, :l, :16]``;
    * minus the mean over those atoms of
      ``sum(mask * log(1 - atom_list[i, :l, :16])) / sum(mask)``, where
      ``mask = validity_mask[i, :l]`` flags substitutions that failed sanitization.
      This term penalises probability placed on invalid elements.

    Args:
        refer_atom_list: tensor whose first 16 feature columns are treated as logits.
        x_atom: 3-D atom feature tensor (molecules, atoms, features).
        validity_mask: numpy array shaped like ``atom_list[:, :, :16]`` (from modify_atoms).
        atom_list: tensor of generated values, presumably probabilities in [0, 1]
            since log(1 - p) is taken.

    Returns:
        Mean of the per-molecule losses; 0 for an empty batch.
    """
    n_mols = x_atom.shape[0]
    if n_mols == 0:
        # nothing to average; avoid dividing by zero
        return 0
    # put the mask on the predictions' device instead of forcing .cuda()
    validity_mask = torch.from_numpy(validity_mask).to(atom_list.device)
    reconstruction_loss = 0
    for i in range(n_mols):
        l = (x_atom[i].sum(-1) != 0).sum(-1)
        ce_term = weighted_CE_loss(refer_atom_list[i, :l, :16], x_atom[i, :l, :16])
        mask = validity_mask[i, :l]
        invalid_log_prob = (mask * torch.log(1 - atom_list[i, :l, :16] + 1e-9)).sum(-1) / (mask.sum(-1) + 1e-9)
        reconstruction_loss += ce_term - invalid_log_prob.mean(-1).mean(-1)
    return reconstruction_loss / n_mols


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    """Run one epoch of joint training of the Fingerprint, AFSE and GRN networks.

    Positions 0..max(len(dataset), len(test_df))-1 are shuffled (numpy seeded with
    ``epoch``) and cut into batches of the global ``batch_size``.  Each batch takes rows
    ``position % len(frame)`` from both frames, so the shorter frame is cycled.  Per batch:
      * regression loss (``loss_function``) on the labelled ``dataset`` rows;
      * AFSE virtual-adversarial loss (via perturb_feature) on both frames;
      * GRN reconstruction loss (via modify_atoms + generate_loss_function) on both frames;
    then the summed loss is backpropagated and all three optimisers step.

    Args:
        model, amodel, gmodel: the Fingerprint, AFSE and GRN networks.
        dataset: labelled frame with ``cano_smiles`` and ``tasks[0]``; rows are looked up
            with ``.loc[position]``, so it needs a 0..n-1 index.
        test_df: second frame, also looked up by position; only ``cano_smiles`` is read
            (its labels are not used here).
        optimizer_list: (optimizer, optimizer_AFSE, optimizer_GRN).
        loss_function: regression loss used in this function. NOTE: perturb_feature
            uses the module-level ``loss_function`` instead.
        epoch: seeds the shuffle and offsets ``global_step`` for logging.

    NOTE(review): also relies on the globals ``batch_size``, ``logger``, ``feature_dicts``,
    ``tasks`` and ``learning_rate``; ``batch_size`` and ``logger`` are not defined in the
    visible part of the notebook.
    """
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    # shuffle the positions
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        # wrap positions around the length of each frame
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(float)
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        # ---- labelled batch --------------------------------------------------------
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
        # reference reconstruction from the unperturbed molecule feature
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                  torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
                                                  mol_feature=mol_feature,activated_features=activated_features.detach())
        
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
        # per molecule: bond features gathered by x_bond_index
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1, 
                                                                                       lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1))
        regression_loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))
        # generation from the adversarially shifted feature; d_adv / 1e-6 amplifies the
        # lamda-length perturbation by 1e6
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
        success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=1)
        reconstruction_loss = generate_loss_function(refer_atom_list, x_atom, validity_mask, atom_list)
        # ---- second (unlabelled) batch ---------------------------------------------
        x_atom_test = torch.Tensor(x_atom_test)
        x_bonds_test = torch.Tensor(x_bonds_test)
        x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
        bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
        bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate)
        atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
                                                torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
        refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                            torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                            mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
        success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
                            bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
                                                     refer_atom_list_test, refer_bond_list_test,topn=1)
        # NOTE(review): unlike the labelled branch (which passes refer_atom_list), the first
        # argument here is atom_list_test, not refer_atom_list_test — confirm this is intended.
        test_reconstruction_loss = generate_loss_function(atom_list_test, x_atom_test, validity_mask_test, atom_list_test)
        
        # If either VAT loss exceeds 1, divide both by their own detached value (+1e-6):
        # the values become ~1 while the gradient direction is kept.
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        logger.add_scalar('loss/regression', regression_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
        logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
        # clear every gradient left by the auxiliary backward passes in perturb_feature
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  regression_loss + 0.6 * (vat_loss + test_vat_loss) + reconstruction_loss + test_reconstruction_loss
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    """Remove the 'molAtomMapNumber' property from every atom of ``mol``.

    Mutates ``mol`` in place and returns the same object.
    """
    for atom in mol.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return mol

# NOTE(review): this duplicates mol_with_atom_index from an earlier cell and silently
# shadows it; keep only one definition.
def mol_with_atom_index( mol ):
    """Set each atom's 'molAtomMapNumber' property to its index; mutates and returns ``mol``."""
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    """Try single-atom element substitutions proposed by the generator and record which are valid.

    For each molecule i, atom positions are ranked by how far the generated element scores
    ``atom_list`` exceed the reference reconstruction ``refer_atom_list``, with the current
    element's one-hot (``x_atom``) subtracted so the existing element is disfavoured.
    At a position, candidate elements are tried in decreasing score order.  A candidate is
    applied to an RDKit molecule built from ``smiles[i]`` with ``SetAtomicNum`` and accepted
    if ``Chem.SanitizeMol`` succeeds.  Up to ``topn`` accepted edits (or exhausted searches)
    are collected per molecule.  Candidates scoring below ``confidence_threshold`` move the
    search to the next position.

    Args:
        smiles: SMILES strings of the batch.
        x_atom: atom feature tensor; only the first 16 columns (one per entry of
            ``symbol_list``) are used.
        bond_neighbor, bond_list, refer_bond_list: converted to numpy but otherwise unused.
        atom_list: generated atom features (first 16 columns scored).
        y_smiles: SMILES each accepted edit is compared against.
            NOTE(review): train() passes the same list as ``smiles``.
        smiles_to_rdkit_list: maps canonical SMILES to per-position RDKit atom indices.
        refer_atom_list: reference reconstruction (first 16 columns used).
        topn: accepted edits (or exhausted searches) per molecule.
        viz: unused.

    Returns:
        (success_smiles, modified_smiles, success, total, success_reconstruction,
        success_validity, validity_mask):
          - success_smiles: accepted edits whose SMILES equals ``y_smiles[i]``;
          - modified_smiles: SMILES of every accepted edit;
          - success / total: per-slot (0..topn-1) counts of matching / accepted edits;
          - success_reconstruction: molecules whose reference argmax gives the RDKit
            element at every one of their atom positions;
          - success_validity: incremented on every accepted edit that happens before the
            molecule's first failed sanitization (with topn=1: molecules whose first
            attempted edit was valid);
          - validity_mask: shaped like atom_list[:, :, :16]; 1 marks (molecule, position,
            element) substitutions that failed — consumed by generate_loss_function.
    """
    # work on detached CPU numpy copies
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    # NOTE(review): the three sorted arrays below are computed but never used.
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    # one counter slot per accepted edit (flag = 0 .. topn-1)
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    # minimum score for a (position, element) candidate to be tried
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    # element of each of the 16 feature columns and its RDKit atomic number ('other' -> 0)
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0                # which candidate element (by descending score) to try next
        top_idx = 0             # which candidate position (by descending score) is being tried
        flag = 0                # accepted edits / exhausted searches so far for this molecule
        first_run_flag = True   # stays True until the first failed sanitization
        # number of atom rows with non-zero features; the first l rows are treated as the
        # molecule's atoms (assumes padding rows come last — confirm)
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        # reconstruction check: does the reference argmax match the element at every position?
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            # all 16 elements tried at this position: move to the next position
            if rank == 16:
                rank = 0
                top_idx += 1
            # every position tried: give up on this slot
            if top_idx == l:
                flag += 1
                continue
            # score = generated - reference - current one-hot; choose the position with the
            # (top_idx+1)-th largest best-element score
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            # the (rank+1)-th best element at that position
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            # proposing the element already present is not an edit: try the next element
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            # candidate score below the threshold: move to the next position
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
                top_idx += 1
                rank = 0
                continue
            # NOTE(review): this substitution is not reverted when sanitization fails below,
            # so later candidates are applied on top of it; with topn > 1 accepted edits
            # also accumulate in the same `mol`.
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol   # alias, not a copy: clear_atom_map below also strips mol's atom maps
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                # NOTE(review): mol_init and mol_y are built but never used.
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                rank += 1
                flag += 1
            except:
                # Invalid substitution: record it for generate_loss_function and try the next
                # element. NOTE(review): the bare except also catches unrelated errors raised
                # anywhere in the try body (e.g. an unparsable y_smiles[i]).
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    """Bond-editing counterpart of ``modify_atoms``.

    Only reached from ``eval`` with ``generate=True, modify_atom=False``.

    For each molecule ``smiles[i]`` it picks a position from the generator
    output ``bond_list`` (first 4 feature channels). It sets that bond's type on
    an atom-indexed RDKit copy of the molecule and tries ``Chem.SanitizeMol``.
    On success it displays the original, the modified and the ``y_smiles[i]``
    molecules. It loops until three sanitizable modifications have been made
    for that molecule.

    Returns:
        list of the modified RDKit ``Mol`` objects. These are not SMILES
        strings, unlike ``modify_atoms``, which appends
        ``Chem.MolToSmiles(...)``.

    NOTE(review): this function looks non-functional as written; confirm
    before relying on it:
      * ``bond_type_sorted`` and ``bond_type_generated_sorted`` are computed
        from the same array (``bond_list``). The reference bond types were
        presumably meant to come from ``bond_neighbor``.
      * ``bond_type_sorted`` is already sliced to molecule ``i``, yet it is
        indexed again with ``[i, ...]`` (this selects atom row ``i``).
      * ``generate_index`` is the argsort of a 2-D array along the last axis
        followed by ``[-1-top_idx]``, so it is a row of indices, not a scalar.
        ``if generate_index >= ...`` therefore raises when there is more than
        one neighbour slot.
      * ``smiles_to_rdkit_list`` is used as an atom-index map in
        ``modify_atoms`` (``GetAtomWithIdx``) but as a bond index here.
      * ``SetBondType`` expects an ``rdkit.Chem.BondType``, but a numpy
        integer is passed.
      * ``atom_symbol_generated`` is not defined in this function. Both
        f-string messages raise ``NameError`` unless a global of that name
        happens to exist.
      * the ``while`` loop has no iteration cap.
    """
    # Move the tensors to host numpy arrays.
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        # Number of atom rows whose neighbour-bond features are not all zero,
        # presumably the real (unpadded) atom count; TODO confirm.
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        # NOTE(review): both of these are argsorts of the same bond_list slice.
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0     # which-best bond-type candidate is being tried
        top_idx = 0  # which-best position is being tried for that candidate
        flag = 0     # number of successful (sanitizable) modifications so far
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            # If the rank-th candidate equals the reference everywhere, move on to the next rank.
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            # Score = confidence + 1 where the candidate differs from the reference,
            # so positions that would actually change rank highest.
            # NOTE(review): this yields a row of indices, not a single index.
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            # Skip positions that have no entry in the index map.
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            # NOTE(review): SetBondType expects a Chem.BondType, not a numpy integer.
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                # NOTE(review): atom_symbol_generated is undefined here.
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                # Sanitization failed: try the next-best position.
                # NOTE(review): the bare except also hides the NameError above.
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, topn=1, output_feature=False, generate=False, modify_atom=True,return_GRN_loss=False, viz=False):
    """Evaluate the predictor and its AFSE/generator companions on ``dataset``.

    ``model``, ``amodel`` and ``gmodel`` are put in eval mode, and ``dataset``
    is processed in positional mini-batches of the global ``batch_size``. For
    each batch:
      * ``model`` produces the molecule features;
      * ``perturb_feature`` (using ``amodel``) returns the prediction and the
        perturbation ``d_adv``;
      * ``gmodel`` is run on the perturbed and on the unperturbed features.

    Args:
        dataset: DataFrame with a ``cano_smiles`` column and the target column
            ``tasks[0]``. Rows are taken by position, so predictions line up
            with ``dataset[tasks[0]].values`` whatever the index is.
        topn, viz: forwarded to ``modify_atoms`` when generating.
        output_feature: also return the ``d_adv`` values and molecule features.
        generate: run molecule modification on every batch and return
            generation statistics.
        modify_atom: when generating, edit atoms (``modify_atoms``) rather than
            bonds (``modify_bonds``). Reconstruction/validity counts, success
            SMILES and the generator losses only come from the atom path.
        return_GRN_loss: also return the generator losses. They are computed
            only when ``generate`` and ``modify_atom`` are both true, and come
            from the last batch; otherwise they are 0.

    Returns:
        generate:        (reconstruction rate, validity rate, unique rate,
                          novelty rate, success_smiles, generated_smiles,
                          r2, mse, predictions)
        return_GRN_loss: (d_list, feature_list, r2, mse, predictions,
                          reconstruction_loss, one_hot_loss, interger_loss,
                          binary_loss)
        output_feature:  (d_list, feature_list, r2, mse, predictions)
        otherwise:       (r2, mse, predictions)
    """
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    feature_list = []
    d_list = []
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]

    for start in range(0, dataset.shape[0], batch_size):
        # Positional slicing keeps batch order identical to dataset[tasks[0]].values
        # used for r2 below (label-based .loc would break on non-RangeIndex frames).
        batch_df = dataset.iloc[start:start+batch_size]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
        y_val = batch_df[tasks[0]].values.astype(float)

        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)

        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda)
        # Generator output for the perturbed features, plus an unperturbed reference pass.
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
                                      torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
        if generate:
            if modify_atom:
                success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=topn,viz=viz)
                # These values only exist on the atom path; using them after the
                # if/else raised NameError when modify_atom=False.
                success_smiles.extend(success_smiles_batch)
                success_reconstruction += reconstruction
                success_validity += validity
                reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
            else:
                modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
            generated_smiles.extend(modified_smiles)
        d_list.extend(d_adv.cpu().detach().numpy().tolist())
        feature_list.extend(mol_feature.cpu().detach().numpy().tolist())

        MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')
        predict_list.extend(mol_prediction.cpu().detach().numpy())
        test_MSE_list.extend(MSE.data.view(-1,1).cpu().numpy())

    if generate:
        generated_num = len(generated_smiles)
        # Distinct molecules. The previous pairwise decrement over-counted
        # duplicates appearing three or more times.
        unique = len(set(generated_smiles))
        if generated_smiles:
            # Generated molecules absent from the dataset. Each is counted at most
            # once, even if the dataset itself contains duplicate SMILES.
            # NOTE(review): compares against the 'smiles' column while batches use
            # 'cano_smiles'; confirm which column is intended.
            known_smiles = set(dataset['smiles'].values)
            novelty = sum(1 for s in generated_smiles if s not in known_smiles)
        else:
            novelty = 0
        unique_rate = unique/(generated_num+1e-9)
        novelty_rate = novelty/(generated_num+1e-9)
        r2 = caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist())
        return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, r2,np.array(test_MSE_list).mean(),predict_list
    r2 = caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist())
    mse = np.array(test_MSE_list).mean()
    if return_GRN_loss:
        return d_list, feature_list,r2,mse,predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list,r2,mse,predict_list
    return r2,mse,predict_list

# Training hyper-parameters, plus one early-stopping tracker per network
# (predictor, AFSE, generator), each checkpointing to its own file.
patience = 30
batch_size = 10
max_epoch = 1000
epoch = 0
stopper, stopper_afse, stopper_generate = [
    EarlyStopping(mode='higher', patience=patience, filename=model_file + checkpoint_suffix)
    for checkpoint_suffix in ('_model.pth', '_amodel.pth', '_gmodel.pth')
]

In [14]:
import datetime
import shutil
from tensorboardX import SummaryWriter
# Timestamp string; not used in this cell (kept in case later cells read it).
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
# Start every run with an empty log directory so event files from a previous
# run with the same name are not mixed into the TensorBoard plots.
# shutil.rmtree also removes nested sub-directories (e.g. those created by
# SummaryWriter.add_scalars). The previous os.remove/os.rmdir loop failed on them.
if os.path.isdir(log_dir):
    shutil.rmtree(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

log/0_GAFSE_Kd_P31389_1_25_run_0


In [15]:
# Main training loop. Each epoch runs:
#   1. one training pass;
#   2. evaluation on the train / validation / test splits;
#   3. TensorBoard logging;
#   4. early stopping on validation R^2.
# `train`, `eval`, the metric helpers, the models and the optimizers come from earlier cells.
epoch = 0
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)

    # Training split.
    train_d, train_f, train_r2, train_MSE, train_predict, reconstruction_loss, one_hot_loss, interger_loss,binary_loss = eval(model, amodel, gmodel, train_df,output_feature=True,return_GRN_loss=True)
    train_predict = np.array(train_predict)
    train_WTI = weighted_top_index(train_df, train_predict, len(train_df))
    train_tau, _ = scipy.stats.kendalltau(train_predict,train_df[tasks[0]].values.astype(float).tolist())

    # Validation split (its R^2 drives early stopping).
    val_d, val_f, val_r2, val_MSE, val_predict, val_reconstruction_loss, val_one_hot_loss, val_interger_loss,val_binary_loss = eval(model, amodel, gmodel, val_df,output_feature=True,return_GRN_loss=True)
    val_predict = np.array(val_predict)
    val_WTI = weighted_top_index(val_df, val_predict, len(val_df))
    val_AP = AP(val_df, val_predict, len(val_df))
    val_tau, _ = scipy.stats.kendalltau(val_predict,val_df[tasks[0]].values.astype(float).tolist())

    # Test split: r2/RMSE on the first `test_active` rows, ranking metrics on all rows.
    test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
    test_d, test_f, test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df,output_feature=True)
    test_predict = np.array(test_predict)
    test_WTI = weighted_top_index(test_df, test_predict, test_active)
    test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())

    # Top-k metrics at 1%/3%/10% of the test set and at fixed cut-offs 10/30/100.
    k_list = [int(len(test_df)*0.01),int(len(test_df)*0.03),int(len(test_df)*0.1),10,30,100]
    topk_list =[]
    false_positive_rate_list = []
    for k in k_list:
        a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
        topk_list.append(a)
        false_positive_rate_list.append(b)

    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    scalars = {
        'val/WTI': val_WTI,
        'val/AP': val_AP,
        'val/r2': val_r2,
        'val/RMSE': val_MSE**0.5,
        'val/Tau': val_tau,
        'test/r2': test_r2_a,
        'test/RMSE': test_MSE_a**0.5,
        'test/Tau': test_tau,
        # Validation-split loss under the 'val/' tag (this used the train loss).
        # eval only computes GRN losses when generate=True, so this is 0 here.
        'val/GRN': val_reconstruction_loss,
    }
    ef_tags = ['test/EF0.01', 'test/EF0.03', 'test/EF0.1', 'test/EF10', 'test/EF30', 'test/EF100']
    for tag, topk_value in zip(ef_tags, topk_list):
        scalars[tag] = topk_value
    for tag, value in scalars.items():
        logger.add_scalar(tag, value, global_step)

    stop_index = val_r2
    # A list (not a generator) so all three step() calls always run and every
    # tracker updates/saves its checkpoint. Stop if any tracker signals.
    early_stop = any([
        stopper.step(stop_index, model),
        stopper_afse.step(stop_index, amodel, if_print=False),
        stopper_generate.step(stop_index, gmodel, if_print=False),
    ])
    print('Epoch:',epoch, 'Step:', global_step, 'Index:%.4f'%stop_index, 'R2:%.4f'%train_r2,'%.4f'%val_r2,'%.4f'%test_r2_a, 'RMSE:%.4f'%train_MSE**0.5, '%.4f'%val_MSE**0.5, 
          '%.4f'%test_MSE_a**0.5, 'Tau:%.4f'%train_tau,'%.4f'%val_tau,'%.4f'%test_tau)
    if early_stop:
        break


  y = torch.FloatTensor(y).reshape(-1,1)


Epoch: 1 Step: 15 Index:0.0124 R2:0.0350 0.0124 0.1081 RMSE:1.5943 1.8478 1.5284 Tau:-0.1101 0.1066 -0.3244
EarlyStopping counter: 1 out of 30
Epoch: 2 Step: 30 Index:0.0102 R2:0.0323 0.0102 0.1059 RMSE:1.5805 1.6692 1.5941 Tau:-0.1069 0.1009 -0.3193
EarlyStopping counter: 2 out of 30
Epoch: 3 Step: 45 Index:0.0056 R2:0.0254 0.0056 0.1022 RMSE:1.5119 1.7209 1.4610 Tau:-0.0938 0.1066 -0.3088
EarlyStopping counter: 3 out of 30
Epoch: 4 Step: 60 Index:0.0030 R2:0.0205 0.0030 0.0973 RMSE:1.4696 1.6316 1.4339 Tau:-0.0873 0.0951 -0.2963
EarlyStopping counter: 4 out of 30
Epoch: 5 Step: 75 Index:0.0007 R2:0.0149 0.0007 0.0914 RMSE:1.4357 1.5285 1.4196 Tau:-0.0793 0.0836 -0.2831
EarlyStopping counter: 5 out of 30
Epoch: 6 Step: 90 Index:0.0013 R2:0.0061 0.0013 0.0787 RMSE:1.3933 1.5485 1.3547 Tau:-0.0639 0.0951 -0.2604
EarlyStopping counter: 6 out of 30
Epoch: 7 Step: 105 Index:0.0120 R2:0.0007 0.0120 0.0625 RMSE:1.3498 1.4536 1.3193 Tau:-0.0447 0.1009 -0.2238
Epoch: 8 Step: 120 Index:0.0778 R

Epoch: 63 Step: 945 Index:0.4685 R2:0.5625 0.4685 0.4460 RMSE:0.8955 0.9806 0.9289 Tau:0.5848 0.4294 0.3497
EarlyStopping counter: 1 out of 30
Epoch: 64 Step: 960 Index:0.4554 R2:0.5608 0.4554 0.4736 RMSE:0.8807 1.0679 0.9206 Tau:0.5759 0.4294 0.3657
Epoch: 65 Step: 975 Index:0.4759 R2:0.5696 0.4759 0.4448 RMSE:0.9145 0.9759 0.9521 Tau:0.5904 0.4294 0.3427
EarlyStopping counter: 1 out of 30
Epoch: 66 Step: 990 Index:0.4732 R2:0.5713 0.4732 0.4602 RMSE:0.8638 0.9918 0.9043 Tau:0.5893 0.4352 0.3571
EarlyStopping counter: 2 out of 30
Epoch: 67 Step: 1005 Index:0.4722 R2:0.5744 0.4722 0.4647 RMSE:0.8625 0.9863 0.9007 Tau:0.5904 0.4352 0.3564
EarlyStopping counter: 3 out of 30
Epoch: 68 Step: 1020 Index:0.4676 R2:0.5747 0.4676 0.4781 RMSE:0.8611 1.0382 0.9056 Tau:0.5876 0.4410 0.3642
Epoch: 69 Step: 1035 Index:0.4812 R2:0.5794 0.4812 0.4523 RMSE:0.8719 0.9691 0.9218 Tau:0.5945 0.4294 0.3505
EarlyStopping counter: 1 out of 30
Epoch: 70 Step: 1050 Index:0.4770 R2:0.5817 0.4770 0.4724 RMSE:0.8

Epoch: 127 Step: 1905 Index:0.6252 R2:0.7184 0.6252 0.5857 RMSE:0.7379 0.8399 0.8161 Tau:0.6648 0.4813 0.3700
EarlyStopping counter: 1 out of 30
Epoch: 128 Step: 1920 Index:0.6114 R2:0.7160 0.6114 0.5933 RMSE:0.7559 0.8678 0.8278 Tau:0.6645 0.4928 0.3610
EarlyStopping counter: 2 out of 30
Epoch: 129 Step: 1935 Index:0.6193 R2:0.7150 0.6193 0.5981 RMSE:0.7899 0.9431 0.8752 Tau:0.6534 0.4582 0.3759
Epoch: 130 Step: 1950 Index:0.6257 R2:0.7211 0.6257 0.5927 RMSE:0.7225 0.8311 0.8013 Tau:0.6638 0.4755 0.3673
EarlyStopping counter: 1 out of 30
Epoch: 131 Step: 1965 Index:0.6218 R2:0.7225 0.6218 0.6004 RMSE:0.6941 0.8359 0.7778 Tau:0.6650 0.4698 0.3720
Epoch: 132 Step: 1980 Index:0.6278 R2:0.7239 0.6278 0.5951 RMSE:0.6905 0.8232 0.7808 Tau:0.6637 0.4813 0.3716
Epoch: 133 Step: 1995 Index:0.6308 R2:0.7254 0.6308 0.5979 RMSE:0.6893 0.8220 0.7787 Tau:0.6652 0.4698 0.3708
Epoch: 134 Step: 2010 Index:0.6383 R2:0.7044 0.6383 0.5430 RMSE:0.8259 0.8737 0.9334 Tau:0.6510 0.5159 0.3567
EarlyStopping c

EarlyStopping counter: 1 out of 30
Epoch: 189 Step: 2835 Index:0.6967 R2:0.7760 0.6967 0.6265 RMSE:0.6262 0.7467 0.7570 Tau:0.7034 0.5274 0.3727
EarlyStopping counter: 2 out of 30
Epoch: 190 Step: 2850 Index:0.6879 R2:0.7815 0.6879 0.6491 RMSE:0.6177 0.7539 0.7290 Tau:0.7015 0.5735 0.3727
EarlyStopping counter: 3 out of 30
Epoch: 191 Step: 2865 Index:0.7021 R2:0.7825 0.7021 0.6354 RMSE:0.6157 0.7349 0.7415 Tau:0.7052 0.5735 0.3727
Epoch: 192 Step: 2880 Index:0.7038 R2:0.7836 0.7038 0.6380 RMSE:0.6511 0.7589 0.7588 Tau:0.7113 0.5678 0.3712
EarlyStopping counter: 1 out of 30
Epoch: 193 Step: 2895 Index:0.6972 R2:0.7841 0.6972 0.6489 RMSE:0.6112 0.7418 0.7320 Tau:0.7063 0.5735 0.3751
EarlyStopping counter: 2 out of 30
Epoch: 194 Step: 2910 Index:0.6936 R2:0.7854 0.6936 0.6518 RMSE:0.6163 0.7463 0.7242 Tau:0.7107 0.5620 0.3751
EarlyStopping counter: 3 out of 30
Epoch: 195 Step: 2925 Index:0.7028 R2:0.7803 0.7028 0.6231 RMSE:0.6270 0.7382 0.7569 Tau:0.7081 0.5678 0.3708
Epoch: 196 Step: 294

Epoch: 247 Step: 3705 Index:0.7429 R2:0.8279 0.7429 0.6751 RMSE:0.6256 0.7348 0.7425 Tau:0.7400 0.6542 0.3801
EarlyStopping counter: 1 out of 30
Epoch: 248 Step: 3720 Index:0.7204 R2:0.8247 0.7204 0.6976 RMSE:0.5963 0.7341 0.6898 Tau:0.7334 0.6081 0.3875
EarlyStopping counter: 2 out of 30
Epoch: 249 Step: 3735 Index:0.7224 R2:0.8238 0.7224 0.6870 RMSE:0.5685 0.7243 0.7048 Tau:0.7341 0.6196 0.3930
Epoch: 250 Step: 3750 Index:0.7489 R2:0.8285 0.7489 0.6738 RMSE:0.5461 0.6734 0.7012 Tau:0.7304 0.6312 0.3829
EarlyStopping counter: 1 out of 30
Epoch: 251 Step: 3765 Index:0.7453 R2:0.8295 0.7453 0.6772 RMSE:0.5451 0.6789 0.6979 Tau:0.7351 0.6196 0.3883
EarlyStopping counter: 2 out of 30
Epoch: 252 Step: 3780 Index:0.7428 R2:0.8270 0.7428 0.6774 RMSE:0.6046 0.7140 0.7179 Tau:0.7412 0.6542 0.3914
EarlyStopping counter: 3 out of 30
Epoch: 253 Step: 3795 Index:0.7344 R2:0.8305 0.7344 0.6940 RMSE:0.5532 0.7074 0.7040 Tau:0.7347 0.6196 0.3911
EarlyStopping counter: 4 out of 30
Epoch: 254 Step: 381

EarlyStopping counter: 6 out of 30
Epoch: 305 Step: 4575 Index:0.7437 R2:0.8586 0.7437 0.6934 RMSE:0.5532 0.6962 0.6974 Tau:0.7519 0.6081 0.3778
EarlyStopping counter: 7 out of 30
Epoch: 306 Step: 4590 Index:0.7462 R2:0.8585 0.7462 0.6731 RMSE:0.5034 0.6797 0.7028 Tau:0.7564 0.6369 0.3981
EarlyStopping counter: 8 out of 30
Epoch: 307 Step: 4605 Index:0.7257 R2:0.8382 0.7257 0.6484 RMSE:0.5418 0.7179 0.7357 Tau:0.7452 0.6254 0.4109
EarlyStopping counter: 9 out of 30
Epoch: 308 Step: 4620 Index:0.7505 R2:0.8361 0.7505 0.6607 RMSE:0.6195 0.7605 0.7897 Tau:0.7396 0.5908 0.3805
EarlyStopping counter: 10 out of 30
Epoch: 309 Step: 4635 Index:0.6974 R2:0.8429 0.6974 0.7001 RMSE:0.5991 0.7777 0.7084 Tau:0.7493 0.6023 0.3992
EarlyStopping counter: 11 out of 30
Epoch: 310 Step: 4650 Index:0.7380 R2:0.8558 0.7380 0.6548 RMSE:0.4999 0.6954 0.7345 Tau:0.7587 0.6715 0.4043
EarlyStopping counter: 12 out of 30
Epoch: 311 Step: 4665 Index:0.7455 R2:0.8610 0.7455 0.6715 RMSE:0.5162 0.6843 0.7113 Tau:0.7

In [16]:
# Restore the best checkpoints recorded by the three early-stopping trackers.
for checkpoint_tracker, network in [(stopper, model), (stopper_afse, amodel), (stopper_generate, gmodel)]:
    checkpoint_tracker.load_checkpoint(network)

# Whole test set, its first `test_active` rows (`_a`) and the remainder (`_ina`).
test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df)
test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
remaining_test_df = test_df[test_active:].reset_index(drop=True)
test_r2_ina, test_MSE_ina, test_predict_ina = eval(model, amodel, gmodel, remaining_test_df)

test_predict = np.array(test_predict)
test_targets = test_df[tasks[0]].values.astype(float).tolist()
test_tau, _ = scipy.stats.kendalltau(test_predict, test_targets)

# Fractional cut-offs of the test-set size, followed by absolute cut-offs.
n_test = len(test_df)
k_fractions = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]
k_list = [int(n_test*frac) for frac in k_fractions] + [50, 100, 150, 200, 250, 300]
topk_list = []
false_positive_rate_list = []
for k in k_list:
    a, b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
    topk_list.append(a)
    false_positive_rate_list.append(b)
WTI = weighted_top_index(test_df, test_predict, test_active)
ap = AP(test_df, test_predict, test_active)


In [17]:
# Summary line followed by one "Top<k>" row per cut-off in k_list. Rows "-1" to
# "-50" are the percentage cut-offs and "50" to "300" the absolute ones, with a
# blank line between the two groups.
summary_fields = [' epoch:', epoch, 'r2:%.4f' % test_r2_a, 'RMSE:%.4f' % test_MSE_a**0.5,
                  'WTI:%.4f' % WTI, 'AP:%.4f' % ap, 'Tau:%.4f' % test_tau, '\n', '\n']
row_labels = ['-1', '-5', '-10', '-15', '-20', '-25', '-30', '-40', '-50',
              '50', '100', '150', '200', '250', '300']
for row_idx, row_label in enumerate(row_labels):
    summary_fields += ['Top%s:%.4f' % (row_label, topk_list[row_idx]),
                       'Top%s-fp:%.4f' % (row_label, false_positive_rate_list[row_idx]),
                       '\n']
    if row_label == '-50':
        summary_fields.append('\n')
print(*summary_fields)

 epoch: 329 r2:0.6793 RMSE:0.6981 WTI:0.3439 AP:0.5135 Tau:0.3907 
 
 Top-1:0.0000 Top-1-fp:1.0000 
 Top-5:0.5000 Top-5-fp:0.5000 
 Top-10:0.3077 Top-10-fp:0.6154 
 Top-15:0.5263 Top-15-fp:0.4737 
 Top-20:0.5200 Top-20-fp:0.5000 
 Top-25:0.6800 Top-25-fp:0.4848 
 Top-30:0.7600 Top-30-fp:0.5128 
 Top-40:0.8400 Top-40-fp:0.6038 
 Top-50:0.9200 Top-50-fp:0.6515 
 
 Top50:0.8400 Top50-fp:0.5800 
 Top100:0.9600 Top100-fp:0.7600 
 Top150:1.0000 Top150-fp:0.7200 
 Top200:1.0000 Top200-fp:0.5400 
 Top250:1.0000 Top250-fp:0.4320 
 Top300:1.0000 Top300-fp:0.3600 



In [18]:
# print('target_file:',train_filename)
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)