In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
train_filename = "./data/benchmark/Ki_P14416_0.3333333333333333_350_train.csv"
test_filename = "./data/benchmark/Ki_P14416_0.3333333333333333_350_test.csv"
test_active = 350
val_rate = 0.2
random_seed = 68
file_list1 = train_filename.split('/')
file1 = file_list1[-1]
file1 = file1[:-10]
number = '_run_0'
model_file = "model_file/3_GAFSE_"+file1+number
log_dir = f'log/{"3_GAFSE_"+file1}'+number
result_dir = './result/3_GAFSE_'+file1+number
print(file1)
print(model_file)

Ki_P14416_0.3333333333333333_350
model_file/3_GAFSE_Ki_P14416_0.3333333333333333_350_run_0


In [4]:
# task_name = 'Malaria Bioactivity'
tasks = ['value']

# train_filename = "../data/active_inactive/median_active/EC50/Q99500.csv"
feature_filename = train_filename.replace('.csv','.pickle')
filename = train_filename.replace('.csv','')
prefix_filename = train_filename.split('/')[-1].replace('.csv','')
train_df = pd.read_csv(train_filename, header=0, names = ["smiles","value"],usecols=[0,1])
# train_df = train_df[1:]
# train_df = train_df.drop(0,axis=1,inplace=False) 
print(train_df[:5])
# print(train_df.iloc(1))
def add_canonical_smiles(train_df):
    smilesList = train_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    atom_num_dist = []
    remained_smiles = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:        
            mol = Chem.MolFromSmiles(smiles)
            atom_num_dist.append(len(mol.GetAtoms()))
            remained_smiles.append(smiles)
            canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
        except:
            print(smiles)
            pass
    print("number of successfully processed smiles: ", len(remained_smiles))
    train_df = train_df[train_df["smiles"].isin(remained_smiles)]
    train_df['cano_smiles'] =canonical_smiles_list
    return train_df
# print(train_df)
train_df = add_canonical_smiles(train_df)

print(train_df.head())
# plt.figure(figsize=(5, 3))
# sns.set(font_scale=1.5)
# ax = sns.distplot(atom_num_dist, bins=28, kde=False)
# plt.tight_layout()
# # plt.savefig("atom_num_dist_"+prefix_filename+".png",dpi=200)
# plt.show()
# plt.close()


                                              smiles     value
0  CC1CCN(C2=C1C=CC=C2CCN3CCN(CC3)C4=NSC5=CC=CC=C... -1.230449
1  C1CN(CCN1CCCOC2=CC3=C(CNC3=O)C=C2)C4=CC=CC5=C4... -1.444045
2  C1CN(CCN1CC2=CN=CC(=C2)C3=CC=CC=C3)C4=CC5=C(C=... -1.100371
3  CCC1CC2=CC=CC=C2N1C(=O)CN3CCN(CC3)CC4=CC=C(C=C... -2.340444
4  CN1CCC2=CC(=C(C=C2C3C1CCC4=C(C=CC=C34)C5=CC=C(... -1.949390
number of all smiles:  1419
number of successfully processed smiles:  1419
                                              smiles     value  \
0  CC1CCN(C2=C1C=CC=C2CCN3CCN(CC3)C4=NSC5=CC=CC=C... -1.230449   
1  C1CN(CCN1CCCOC2=CC3=C(CNC3=O)C=C2)C4=CC=CC5=C4... -1.444045   
2  C1CN(CCN1CC2=CN=CC(=C2)C3=CC=CC=C3)C4=CC5=C(C=... -1.100371   
3  CCC1CC2=CC=CC=C2N1C(=O)CN3CCN(CC3)CC4=CC=C(C=C... -2.340444   
4  CN1CCC2=CC(=C(C=C2C3C1CCC4=C(C=CC=C34)C5=CC=C(... -1.949390   

                                         cano_smiles  
0  CC(=O)N1CCC(C)c2cccc(CCN3CCN(c4nsc5ccccc45)CC3...  
1   O=C1NCc2ccc(OCCCN3CCN(c4cccc5ccc(F)c

In [5]:
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 2 # default: 2
T = 1
per_task_output_units_num = 1 # for regression model
output_units_num = len(tasks) * per_task_output_units_num

In [6]:
test_df = pd.read_csv(test_filename,header=0,names=["smiles","value"],usecols=[0,1])
test_df = add_canonical_smiles(test_df)
for l in test_df["cano_smiles"]:
    if l in train_df["cano_smiles"]:
        print("same smiles:",l)
        
print(test_df.shape)
print(test_df.head())

number of all smiles:  1664
number of successfully processed smiles:  1664
(1664, 3)
                                              smiles     value  \
0  C1CN(CCN1CCCCOC2=CC3=C(CNC3)C=C2)C4=CC=CC5=C4C... -1.906874   
1  CN1CCN(CC1)C2=NC3=C(C=CC(=C3)Cl)N(C4=CC=CC=C42... -3.107210   
2  CN1C=C(C=CC1=O)C2=NN=C(N2C)SCCCN3CCC4(C3)CC4C5... -2.940008   
3  CCOC1=CC=CC=C1N2CCCN(CC2)CCCCOC3=CC4=C(C=C3)C=... -0.556303   
4  C1CN(CCN1CC(CCNC(=O)C2=CC=C(C=C2)C3=CC=CC=N3)O... -2.426511   

                                         cano_smiles  
0    Fc1ccc2cccc(N3CCN(CCCCOc4ccc5c(c4)CNC5)CC3)c2c1  
1  CN1CCN(C2=Nc3cc(Cl)ccc3N(NC(=O)c3cc(F)c(F)c(F)...  
2  Cn1c(SCCCN2CCC3(CC3c3ccc(C(F)(F)F)cc3)C2)nnc1-...  
3   CCOc1ccccc1N1CCCN(CCCCOc2ccc3ccc(=O)[nH]c3c2)CC1  
4  O=C(NCCC(O)CN1CCN(c2cccc(Cl)c2Cl)CC1)c1ccc(-c2...  


In [7]:
print(feature_filename)
print(filename)
total_df = pd.concat([train_df,test_df],axis=0)
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
# if os.path.isfile(feature_filename):
#     feature_dicts = pickle.load(open(feature_filename, "rb" ))
# else:
#     feature_dicts = save_smiles_dicts(smilesList,filename)
feature_dicts = save_smiles_dicts(total_smilesList,filename)
remained_df = total_df[total_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = total_df.drop(remained_df.index)

./data/benchmark/Ki_P14416_0.3333333333333333_350_train.pickle
./data/benchmark/Ki_P14416_0.3333333333333333_350_train
3083
Cc1ncoc1-c1nnc(SCCCN2CC3CC3(c3cccc(S(F)(F)(F)(F)F)c3)C2)n1C
feature dicts file saved as ./data/benchmark/Ki_P14416_0.3333333333333333_350_train.pickle


In [8]:
val_df = train_df.sample(frac=val_rate,random_state=random_seed)
train_df = train_df.drop(val_df.index)
train_df = train_df.reset_index(drop=True)
train_df = train_df[train_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
train_df = train_df.reset_index(drop=True)
val_df = val_df[val_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
val_df = val_df.reset_index(drop=True)
test_df = test_df[test_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
test_df = test_df.reset_index(drop=True)
print(train_df.shape,val_df.shape,test_df.shape)

(1135, 3) (284, 3) (1663, 3)


In [9]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

optimizer = optim.Adam(params=model.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [10]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    p_value = dataset[tasks[0]].astype(float).tolist()
    x = np.arange(0,len(dataset),1)
#     print('plt',dataset.head(),p[:10],k_predict,k)
#     plt.figure()
#     fig, ax1 = plt.subplots()
#     ax1.grid(False)
#     ax2 = ax1.twinx()
#     plt.grid(False)
    plt.scatter(x,p,marker='.',s=6,color='r',label='predict')
#     plt.ylabel('predict')
    plt.scatter(x,p_value,s=6,marker=',',color='blue',label='p_value')
    plt.axvline(x=k-1,ls="-",c="black")#添加垂直直线
    k_value = np.ones(len(dataset))
# #     print(EC50[k-1])
    k_value = k_value*k_predict
    plt.plot(x,k_value,'-',color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
    
    df3 = df2[:k]  #取按预测值排完序后的前k个
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    k_true = true_sort[tasks[0]].values[k-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('k_true: ',type(k_true),k_true)
#     print('k_true: ',k_true,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=k_true,'acc_num: ',len(df3[df3['predict'].values>=k_true]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k)
    acc = len(df3[df3[tasks[0]].values>=k_true])/k #预测值前k个中真实排在前k个的个数/k
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/k
    if k>active_num:
        min_active = true_sort[tasks[0]].values[active_num-1]
        acc = len(df3[df3[tasks[0]].values>=min_active])/k
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
        
    df3 = df2[:k]  #取按预测值排完序后的前k个，因为后面的全是-4.1
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    min_active = true_sort[tasks[0]].values[active_num-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('min_active: ',type(min_active),min_active)
#     print('min_active: ',min_active,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=min_active,'acc_num: ',len(df3[df3['predict'].values>=min_active]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k,'active_num: ',active_num)
    acc = len(df3[df3[tasks[0]].values>-4.1])/active_num #预测值前k个中真实排在前active_num个的个数/active_num
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/active_num
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    if k>active_num:
        return topk_recall(df, predict, k, active_num, show_flag, i)
    return topk_acc2(df,predict,k, active_num,show_flag,i)

def weighted_top_index(df, predict, active_num):
    weighted_acc_list=[]
    for k in np.arange(1,len(df)+1,1):
        acc, fp = topk_acc_recall(df, predict, k, active_num)
        weight = (len(df)-k)/len(df)
#         print('weight=',weight,'acc=',acc)
        weighted_acc_list.append(acc*weight)#
    weighted_acc_list = np.array(weighted_acc_list)
#     print('weighted_acc_list=',weighted_acc_list)
    return np.sum(weighted_acc_list)/weighted_acc_list.shape[0]

def AP(df, predict, active_num):
    prec = []
    rec = []
    for k in np.arange(1,len(df)+1,1):
        prec_k, fp1 = topk_acc2(df,predict,k, active_num)
        rec_k, fp2 = topk_recall(df, predict, k, active_num)
        prec.append(prec_k)
        rec.append(rec_k)
    # 取所有不同的recall对应的点处的精度值做平均
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # 计算包络线，从后往前取最大保证precise非减
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # 找出所有检测结果中recall不同的点
    i = np.where(mrec[1:] != mrec[:-1])[0]
#     print(prec)
#     print('prec='+str(prec)+'\n\n'+'rec='+str(rec))

    # and sum (\Delta recall) * prec
    # 用recall的间隔对精度作加权平均
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

In [11]:
def caculate_r2(y,predict):
#     print(y)
#     print(predict)
    y = torch.FloatTensor(y).reshape(-1,1)
    predict = torch.FloatTensor(predict).reshape(-1,1)
    y_mean = torch.mean(y)
    predict_mean = torch.mean(predict)
    
    y1 = torch.pow(torch.mm((y-y_mean).t(),(predict-predict_mean)),2)
    y2 = torch.mm((y-y_mean).t(),(y-y_mean))*torch.mm((predict-predict_mean).t(),(predict-predict_mean))
#     print(y1,y2)
    return y1/y2

In [12]:
from torch.autograd import Variable
def l2_norm(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    output = torch.div(input, norm+1e-6)
    return output

def normalize_perturbation(d,dim=-1):
    output = l2_norm(d, dim)
    return output

def tanh(x):
    return (torch.exp(x)-torch.exp(-x))/(torch.exp(x)+torch.exp(-x))

def sigmoid(x):
    return 1/(1+torch.exp(-x))

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None):
    mol_prediction = model(feature=f, d=0)
    pred = mol_prediction.detach()
#     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
    eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
    eps = Variable(eps, requires_grad=True)
    # Predict on randomly perturbed image
    eps_p = model(feature=f, d=eps.cuda())
    eps_p_ = model(feature=f, d=-eps.cuda())
    p_aux = nn.Sigmoid()(eps_p/(pred+1e-6))
    p_aux_ = nn.Sigmoid()(eps_p_/(pred+1e-6))
#     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
    loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
    loss.backward(retain_graph=True)

    # Based on perturbed image, get direction of greatest error
    eps_adv = eps.grad#/10**-learning_rate
    optimizer_AFSE.zero_grad()
    # Use that direction as adversarial perturbation
    eps_adv_normed = normalize_perturbation(eps_adv)
    d_adv = lamda * eps_adv_normed.cuda()
    if output_lr:
        f_p, max_lr = model(feature=f, d=d_adv, output_lr=output_lr)
    f_p = model(feature=f, d=d_adv)
    f_p_ = model(feature=f, d=-d_adv)
    p = nn.Sigmoid()(f_p/(pred+1e-6))
    p_ = nn.Sigmoid()(f_p_/(pred+1e-6))
    vat_loss = loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    if output_lr:
        if output_plr:
            loss = loss_function(mol_prediction,y)
            loss.backward(retain_graph=True)
            optimizer_AFSE.zero_grad()
            punish_lr = torch.norm(torch.mean(eps.grad,0))
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol

def d_loss(f, pred, model, y_val):
    diff_loss = 0
    length = len(pred)
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    diff_loss = diff_loss/(length*(length-1))
    return diff_loss

In [13]:
def CE(x,y):
    c = 0
    l = len(y)
    for i in range(l):
        if y[i]==1:
            c += 1
    w1 = (l-c)/l
    w0 = c/l
    loss = -w1*y*torch.log(x+1e-6)-w0*(1-y)*torch.log(1-x+1e-6)
    loss = loss.mean(-1)
    return loss

def weighted_CE_loss(x,y):
    weight = 1/(y.detach().float().mean(0)+1e-9)
    weighted_CE = nn.CrossEntropyLoss(weight=weight)
#     atom_weights = (atom_weights-min(atom_weights))/(max(atom_weights)-min(atom_weights))
    return weighted_CE(x, torch.argmax(y,-1))

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduce=False)
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    loss = torch.mean(loss)
    return loss

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)
#         print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(float)
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                  torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
                                                  mol_feature=mol_feature,activated_features=activated_features.detach())
        
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1, 
                                                                                       lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1)) # 10**-learning_rate     
        regression_loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
        success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=1)
        reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
                                                                                              bond_neighbor, validity_mask, atom_list, 
                                                                                              bond_list)
        x_atom_test = torch.Tensor(x_atom_test)
        x_bonds_test = torch.Tensor(x_bonds_test)
        x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
        bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
        bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate)
        atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
                                                torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
        refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                            torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                            mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
        success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
                            bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
                                                     refer_atom_list_test, refer_bond_list_test,topn=1)
        test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
        
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        max_lr = 1e-3
        conv_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr
        if conv_lr < max_lr and conv_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = conv_lr.detach()
                AFSE_lr = conv_lr    
        elif conv_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif conv_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
        
        logger.add_scalar('loss/regression', regression_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
        logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
        logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
        logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/max_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
    
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  regression_loss + 0.6 * (vat_loss + test_vat_loss) + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    [a.ClearProp('molAtomMapNumber') for a  in mol.GetAtoms()]
    return mol

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0
        top_idx = 0
        flag = 0
        first_run_flag = True
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
#             print(f'atom#{smiles_to_rdkit_list[cano_smiles][j]}(f):',{symbol_list[k]: np.around(refer_atom_list[i,j,k],3) for k in range(16)},
#                   f'\natom#{smiles_to_rdkit_list[cano_smiles][j]}(f+d):',{symbol_list[k]: np.around(atom_list[i,j,k],3) for k in range(16)},
#                  '\n------------------------------------------------------------------------------------------------------------')
#         print('预测为每个原子的平均概率：\n',np.around(atom_list[i,:l,:16].mean(1),2))
#         print('预测为每个原子的最大概率：\n',np.around(atom_list[i,:l,:16].max(1),2))
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            if rank == 16:
                rank = 0
                top_idx += 1
            if top_idx == l:
#                 print('没有满足条件的分子生成。')
                flag += 1
                continue
#             if np.sum((atom_symbol_sorted[i,:l,-1]!=atom_symbol_generated_sorted[i,:l,-1-rank]).astype(int))==0:
#                 print(f'根据预测的第{rank}大概率的原子构成的分子与原分子一致，原子位重置为0，生成下一个元素……')
#                 rank += 1
#                 top_idx = 0
#                 generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
#                                              x_atom[i,:l,:16]).max(-1))[-1-top_idx]
#             print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
#                 print('生成了相同元素，生成下一个元素……')
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
#                 print(f'原子位{generate_rdkit_index}生成{symbol_list[atom_symbol_generated]}元素的置信度小于{confidence_threshold}，寻找下一个原子位……')
                top_idx += 1
                rank = 0
                continue
#             if symbol_to_rdkit[atom_symbol_generated]==6:
#                 print('生成了不推荐的C元素')
#                 rank += 1
#                 continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
#                     print('Congratulations!', success, total)
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
#                 print("修改前的分子：", smiles[i])
#                 display(mol_init)
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 print(f"将第{generate_rdkit_index}个原子修改为{symbol_list[atom_symbol_generated]}的分子：", Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 display(mol_with_atom_index(mol))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
#                 print("高活性分子：", y_smiles[i])
#                 display(mol_y)
                rank += 1
                flag += 1
            except:
#                 print(f"第{generate_rdkit_index}个原子符号修改为{symbol_list[atom_symbol_generated]}不符合规范，生成下一个元素……")
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, topn=1, output_feature=False, generate=False, modify_atom=True,return_GRN_loss=False, viz=False):
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    r2_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    feature_list = []
    d_list = []
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]
    
# #     取dataset中排序后的第k个
#     sorted_dataset = dataset.sort_values(by=tasks[0],ascending=False)
#     k_df = sorted_dataset.iloc[[k-1]]
#     k_smiles = k_df['cano_smiles'].values
#     k_value = k_df[tasks[0]].values.astype(float)    
    
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
#     print(batch_list)
    for counter, batch in enumerate(batch_list):
#         print(type(batch))
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
#         print(batch_df)
        y_val = batch_df[tasks[0]].values.astype(float)
#         print(type(y_val))
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda)
#         print(mol_feature,d_adv)
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
                                      torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
        if generate:
            if modify_atom:
                success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=topn,viz=viz)
            else:
                modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
            generated_smiles.extend(modified_smiles)
            success_smiles.extend(success_smiles_batch)
#             for n in range(topn):
#                 success[n] += success_batch[n]
#                 total[n] += total_batch[n]
#                 print('congratulations:',success,total)
            success_reconstruction += reconstruction
            success_validity += validity
            reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
        d = d_adv.cpu().detach().numpy().tolist()
        d_list.extend(d)
        mol_feature_output = mol_feature.cpu().detach().numpy().tolist()
        feature_list.extend(mol_feature_output)
#         MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')   
#         print(type(mol_prediction))
        
        MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')
#         r2 = caculate_r2(mol_prediction, torch.Tensor(y_val).view(-1,1))
# #         r2_list.extend(r2.cpu().detach().numpy())
#         if r2!=r2:
#             r2 = torch.tensor(0)
#         r2_list.append(r2.item())
#         predict_list.extend(mol_prediction.cpu().detach().numpy())
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        predict_list.extend(mol_prediction.cpu().detach().numpy())
#         test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
        test_MSE_list.extend(MSE.data.view(-1,1).cpu().numpy())
#     print(r2_list)
    if generate:
        generated_num = len(generated_smiles)
        eval_num = len(dataset)
        unique = generated_num
        novelty = generated_num
        for i in range(generated_num):
            for j in range(generated_num-i-1):
                if generated_smiles[i]==generated_smiles[i+j+1]:
                    unique -= 1
            for k in range(eval_num):
                if generated_smiles[i]==dataset['smiles'].values[k]:
                    novelty -= 1
        unique_rate = unique/(generated_num+1e-9)
        novelty_rate = novelty/(generated_num+1e-9)
#         print(f'successfully/total generated molecules =', {f'Top-{i+1}': f'{success[i]}/{total[i]}' for i in range(topn)})
        return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    if return_GRN_loss:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    return caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list

epoch = 0
max_epoch = 1000
batch_size = 10
patience = 100
stopper = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_model.pth')
stopper_afse = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_amodel.pth')
stopper_generate = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_gmodel.pth')

In [14]:
import datetime
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
# if os.path.isdir(log_dir):
#     for files in os.listdir(log_dir):
#         os.remove(log_dir+"/"+files)
#     os.rmdir(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = 149
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
#     print(train_df.shape,test_df.shape)
    train_d, train_f, train_r2, train_MSE, train_predict, reconstruction_loss, one_hot_loss, interger_loss,binary_loss = eval(model, amodel, gmodel, train_df,output_feature=True,return_GRN_loss=True)
    train_predict = np.array(train_predict)
    train_WTI = weighted_top_index(train_df, train_predict, len(train_df))
    train_tau, _ = scipy.stats.kendalltau(train_predict,train_df[tasks[0]].values.astype(float).tolist())
    val_d, val_f, val_r2, val_MSE, val_predict, val_reconstruction_loss, val_one_hot_loss, val_interger_loss,val_binary_loss = eval(model, amodel, gmodel, val_df,output_feature=True,return_GRN_loss=True)
    val_predict = np.array(val_predict)
    val_WTI = weighted_top_index(val_df, val_predict, len(val_df))
    val_AP = AP(val_df, val_predict, len(val_df))
    val_tau, _ = scipy.stats.kendalltau(val_predict,val_df[tasks[0]].values.astype(float).tolist())
    
    test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
    test_d, test_f, test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df,output_feature=True)
    test_predict = np.array(test_predict)
    test_WTI = weighted_top_index(test_df, test_predict, test_active)
#     test_AP = AP(test_df, test_predict, test_active)
    test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())
    
    k_list = [int(len(test_df)*0.01),int(len(test_df)*0.03),int(len(test_df)*0.1),10,30,100]
    topk_list =[]
    false_positive_rate_list = []
    for k in k_list:
        a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
        topk_list.append(a)
        false_positive_rate_list.append(b)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    logger.add_scalar('train/r2', train_r2, global_step)
    logger.add_scalar('train/RMSE', train_MSE**0.5, global_step)
    logger.add_scalar('train/Tau', train_tau, global_step)
    logger.add_scalar('val/WTI', val_WTI, global_step)
    logger.add_scalar('val/AP', val_AP, global_step)
    logger.add_scalar('val/r2', val_r2, global_step)
    logger.add_scalar('val/RMSE', val_MSE**0.5, global_step)
    logger.add_scalar('val/Tau', val_tau, global_step)
#     logger.add_scalar('test/TAP', test_AP, global_step)
    logger.add_scalar('test/r2', test_r2_a, global_step)
    logger.add_scalar('test/RMSE', test_MSE_a**0.5, global_step)
    logger.add_scalar('test/Tau', test_tau, global_step)
    logger.add_scalar('val/GRN', reconstruction_loss, global_step)
    logger.add_scalar('val/GRN_one_hot', one_hot_loss, global_step)
    logger.add_scalar('val/GRN_interger', interger_loss, global_step)
    logger.add_scalar('val/GRN_binary', binary_loss, global_step)
    logger.add_scalar('test/EF0.01', topk_list[0], global_step)
    logger.add_scalar('test/EF0.03', topk_list[1], global_step)
    logger.add_scalar('test/EF0.1', topk_list[2], global_step)
    logger.add_scalar('test/EF10', topk_list[3], global_step)
    logger.add_scalar('test/EF30', topk_list[4], global_step)
    logger.add_scalar('test/EF100', topk_list[5], global_step)
    
#     train_mse_list.append(train_MSE**0.5)
#     train_r2_list.append(train_r2)
#     val_mse_list.append(val_MSE**0.5)  
#     val_r2_list.append(val_r2)
#     train_f_list.append(train_f)
#     val_f_list.append(val_f)
#     test_f_list.append(test_f)
#     epoch_list.append(epoch)
#     train_predict_list.append(train_predict.flatten())
#     test_predict_list.append(test_predict.flatten())
#     val_predict_list.append(val_predict.flatten())
#     train_y_list.append(train_df[tasks[0]].values)
#     val_y_list.append(val_df[tasks[0]].values)
#     test_y_list.append(test_df[tasks[0]].values)
#     train_d_list.append(train_d)
#     val_d_list.append(val_d)
#     test_d_list.append(test_d)

    stop_index = - val_MSE**0.5 + val_tau
    early_stop = stopper.step(stop_index, model)
    early_stop = stopper_afse.step(stop_index, amodel, if_print=False)
    early_stop = stopper_generate.step(stop_index, gmodel, if_print=False)
#     print('epoch {:d}/{:d}, validation {} {:.4f}, {} {:.4f},best validation {r2} {:.4f}'.format(epoch, total_epoch, 'r2', val_r2, 'mse:',val_MSE, stopper.best_score))
    print('Epoch:',epoch, 'Step:', global_step, 'Index:%.4f'%stop_index, 'R2:%.4f'%train_r2,'%.4f'%val_r2,'%.4f'%test_r2_a, 'RMSE:%.4f'%train_MSE**0.5, '%.4f'%val_MSE**0.5, 
          '%.4f'%test_MSE_a**0.5, 'Tau:%.4f'%train_tau,'%.4f'%val_tau,'%.4f'%test_tau)#, 'Tau:%.4f'%val_tau,'%.4f'%test_tau,'GRN:%.4f'%reconstruction_loss,'%.4f'%val_reconstruction_loss
    if early_stop:
        continue


log/3_GAFSE_Ki_P14416_0.3333333333333333_350_run_0




Epoch: 150 Step: 24900 Index:-0.2537 R2:0.6925 0.4357 0.3771 RMSE:0.5623 0.7383 0.7883 Tau:0.6355 0.4847 0.2985
EarlyStopping counter: 1 out of 100
Epoch: 151 Step: 25066 Index:-0.2742 R2:0.7001 0.4288 0.3682 RMSE:0.5572 0.7551 0.8014 Tau:0.6414 0.4808 0.2981
EarlyStopping counter: 2 out of 100
Epoch: 152 Step: 25232 Index:-0.2588 R2:0.6943 0.4380 0.3767 RMSE:0.5495 0.7430 0.7881 Tau:0.6372 0.4842 0.3032
EarlyStopping counter: 3 out of 100
Epoch: 153 Step: 25398 Index:-0.2815 R2:0.6850 0.4176 0.3915 RMSE:0.5670 0.7560 0.7719 Tau:0.6354 0.4745 0.3222
EarlyStopping counter: 4 out of 100
Epoch: 154 Step: 25564 Index:-0.2666 R2:0.6889 0.4491 0.3922 RMSE:0.5978 0.7572 0.8017 Tau:0.6400 0.4907 0.3086
Epoch: 155 Step: 25730 Index:-0.2487 R2:0.7083 0.4287 0.3871 RMSE:0.5505 0.7283 0.7700 Tau:0.6478 0.4796 0.2949
EarlyStopping counter: 1 out of 100
Epoch: 156 Step: 25896 Index:-0.2602 R2:0.7005 0.4198 0.3809 RMSE:0.5475 0.7415 0.7771 Tau:0.6431 0.4813 0.3002
EarlyStopping counter: 2 out of 100


EarlyStopping counter: 42 out of 100
Epoch: 206 Step: 34196 Index:-0.3174 R2:0.7603 0.4406 0.4090 RMSE:0.5753 0.8026 0.8189 Tau:0.6864 0.4852 0.3151
EarlyStopping counter: 43 out of 100
Epoch: 207 Step: 34362 Index:-0.2916 R2:0.7566 0.4311 0.4001 RMSE:0.5424 0.7760 0.7987 Tau:0.6855 0.4844 0.3176
EarlyStopping counter: 44 out of 100
Epoch: 208 Step: 34528 Index:-0.2762 R2:0.7628 0.4351 0.3957 RMSE:0.5101 0.7615 0.7890 Tau:0.6885 0.4853 0.3186
EarlyStopping counter: 45 out of 100
Epoch: 209 Step: 34694 Index:-0.2745 R2:0.7555 0.4290 0.3890 RMSE:0.4946 0.7483 0.7816 Tau:0.6826 0.4737 0.3170
EarlyStopping counter: 46 out of 100
Epoch: 210 Step: 34860 Index:-0.2923 R2:0.7510 0.4093 0.3971 RMSE:0.5005 0.7612 0.7773 Tau:0.6779 0.4689 0.3112
EarlyStopping counter: 47 out of 100
Epoch: 211 Step: 35026 Index:-0.2482 R2:0.7604 0.4474 0.4095 RMSE:0.4897 0.7400 0.7689 Tau:0.6859 0.4919 0.3221
EarlyStopping counter: 48 out of 100
Epoch: 212 Step: 35192 Index:-0.2590 R2:0.7675 0.4356 0.4045 RMSE:0.4

EarlyStopping counter: 33 out of 100
Epoch: 262 Step: 43492 Index:-0.2750 R2:0.8036 0.4462 0.4019 RMSE:0.4561 0.7648 0.7969 Tau:0.7196 0.4898 0.3223
EarlyStopping counter: 34 out of 100
Epoch: 263 Step: 43658 Index:-0.2917 R2:0.8006 0.4321 0.3805 RMSE:0.4429 0.7646 0.8029 Tau:0.7155 0.4729 0.3192
EarlyStopping counter: 35 out of 100
Epoch: 264 Step: 43824 Index:-0.2748 R2:0.7919 0.4314 0.3950 RMSE:0.4561 0.7496 0.7796 Tau:0.7101 0.4748 0.3268
EarlyStopping counter: 36 out of 100
Epoch: 265 Step: 43990 Index:-0.2566 R2:0.8124 0.4387 0.4039 RMSE:0.4450 0.7410 0.7731 Tau:0.7273 0.4845 0.3183
EarlyStopping counter: 37 out of 100
Epoch: 266 Step: 44156 Index:-0.2593 R2:0.8067 0.4525 0.3837 RMSE:0.4379 0.7479 0.7956 Tau:0.7213 0.4886 0.3251
EarlyStopping counter: 38 out of 100
Epoch: 267 Step: 44322 Index:-0.2785 R2:0.8111 0.4354 0.3965 RMSE:0.4337 0.7581 0.7854 Tau:0.7271 0.4796 0.3221
EarlyStopping counter: 39 out of 100
Epoch: 268 Step: 44488 Index:-0.2848 R2:0.8061 0.4340 0.3982 RMSE:0.4

EarlyStopping counter: 88 out of 100
Epoch: 317 Step: 52622 Index:-0.2786 R2:0.8433 0.4285 0.3996 RMSE:0.3981 0.7611 0.7861 Tau:0.7501 0.4825 0.3253
EarlyStopping counter: 89 out of 100
Epoch: 318 Step: 52788 Index:-0.3272 R2:0.8422 0.4171 0.3894 RMSE:0.4038 0.7967 0.8138 Tau:0.7492 0.4695 0.3240
EarlyStopping counter: 90 out of 100
Epoch: 319 Step: 52954 Index:-0.2974 R2:0.8407 0.4227 0.3936 RMSE:0.3973 0.7707 0.7958 Tau:0.7477 0.4733 0.3239
EarlyStopping counter: 91 out of 100
Epoch: 320 Step: 53120 Index:-0.2855 R2:0.8333 0.4347 0.3914 RMSE:0.4123 0.7650 0.7992 Tau:0.7421 0.4795 0.3233
EarlyStopping counter: 92 out of 100
Epoch: 321 Step: 53286 Index:-0.3038 R2:0.8437 0.4170 0.3747 RMSE:0.3983 0.7739 0.8103 Tau:0.7497 0.4701 0.3252
EarlyStopping counter: 93 out of 100
Epoch: 322 Step: 53452 Index:-0.3301 R2:0.8476 0.4132 0.3824 RMSE:0.4105 0.7991 0.8256 Tau:0.7532 0.4690 0.3262
EarlyStopping counter: 94 out of 100
Epoch: 323 Step: 53618 Index:-0.2686 R2:0.8419 0.4468 0.3753 RMSE:0.3

EarlyStopping counter: 143 out of 100
Epoch: 372 Step: 61752 Index:-0.3046 R2:0.8732 0.4218 0.3906 RMSE:0.3557 0.7767 0.8023 Tau:0.7747 0.4721 0.3317
EarlyStopping counter: 144 out of 100
Epoch: 373 Step: 61918 Index:-0.3330 R2:0.8694 0.4145 0.3901 RMSE:0.3602 0.7967 0.8171 Tau:0.7719 0.4637 0.3328
EarlyStopping counter: 145 out of 100
Epoch: 374 Step: 62084 Index:-0.3034 R2:0.8757 0.4186 0.3803 RMSE:0.3619 0.7727 0.8117 Tau:0.7761 0.4693 0.3227
EarlyStopping counter: 146 out of 100
Epoch: 375 Step: 62250 Index:-0.3027 R2:0.8744 0.4215 0.3786 RMSE:0.3657 0.7759 0.8131 Tau:0.7779 0.4732 0.3247
EarlyStopping counter: 147 out of 100
Epoch: 376 Step: 62416 Index:-0.3230 R2:0.8470 0.4175 0.3662 RMSE:0.3922 0.7872 0.8184 Tau:0.7568 0.4642 0.3302
EarlyStopping counter: 148 out of 100
Epoch: 377 Step: 62582 Index:-0.3109 R2:0.8765 0.4236 0.3774 RMSE:0.3685 0.7812 0.8194 Tau:0.7803 0.4703 0.3272
EarlyStopping counter: 149 out of 100
Epoch: 378 Step: 62748 Index:-0.3292 R2:0.8640 0.4138 0.4087 R

EarlyStopping counter: 198 out of 100
Epoch: 427 Step: 70882 Index:-0.3577 R2:0.8937 0.3961 0.3799 RMSE:0.3244 0.8183 0.8294 Tau:0.7994 0.4605 0.3367
EarlyStopping counter: 199 out of 100
Epoch: 428 Step: 71048 Index:-0.3464 R2:0.9005 0.3997 0.3729 RMSE:0.3339 0.8144 0.8405 Tau:0.8071 0.4679 0.3391
EarlyStopping counter: 200 out of 100
Epoch: 429 Step: 71214 Index:-0.3619 R2:0.8979 0.4058 0.3766 RMSE:0.3557 0.8366 0.8577 Tau:0.8003 0.4748 0.3388
EarlyStopping counter: 201 out of 100
Epoch: 430 Step: 71380 Index:-0.3484 R2:0.8963 0.4061 0.3836 RMSE:0.3234 0.8169 0.8330 Tau:0.7978 0.4685 0.3451
EarlyStopping counter: 202 out of 100
Epoch: 431 Step: 71546 Index:-0.3462 R2:0.8958 0.3985 0.3576 RMSE:0.3266 0.8135 0.8556 Tau:0.8005 0.4673 0.3389
EarlyStopping counter: 203 out of 100
Epoch: 432 Step: 71712 Index:-0.3396 R2:0.9002 0.4107 0.3736 RMSE:0.3244 0.8094 0.8364 Tau:0.8036 0.4698 0.3358
EarlyStopping counter: 204 out of 100
Epoch: 433 Step: 71878 Index:-0.3247 R2:0.8979 0.4334 0.4002 R

EarlyStopping counter: 253 out of 100
Epoch: 482 Step: 80012 Index:-0.3701 R2:0.9066 0.3966 0.3805 RMSE:0.3202 0.8342 0.8516 Tau:0.8135 0.4641 0.3445
EarlyStopping counter: 254 out of 100
Epoch: 483 Step: 80178 Index:-0.3755 R2:0.9126 0.3979 0.3643 RMSE:0.2941 0.8457 0.8641 Tau:0.8175 0.4702 0.3400
EarlyStopping counter: 255 out of 100
Epoch: 484 Step: 80344 Index:-0.3492 R2:0.9130 0.4061 0.3611 RMSE:0.2949 0.8213 0.8580 Tau:0.8184 0.4721 0.3462
EarlyStopping counter: 256 out of 100
Epoch: 485 Step: 80510 Index:-0.3600 R2:0.9159 0.4056 0.3700 RMSE:0.3030 0.8328 0.8643 Tau:0.8219 0.4729 0.3414
EarlyStopping counter: 257 out of 100
Epoch: 486 Step: 80676 Index:-0.3900 R2:0.9173 0.3874 0.3702 RMSE:0.2881 0.8506 0.8657 Tau:0.8224 0.4606 0.3447
EarlyStopping counter: 258 out of 100
Epoch: 487 Step: 80842 Index:-0.3522 R2:0.9190 0.3980 0.3722 RMSE:0.2849 0.8279 0.8521 Tau:0.8240 0.4758 0.3432
EarlyStopping counter: 259 out of 100
Epoch: 488 Step: 81008 Index:-0.4114 R2:0.9189 0.3637 0.3594 R

EarlyStopping counter: 308 out of 100
Epoch: 537 Step: 89142 Index:-0.4349 R2:0.9259 0.3535 0.3434 RMSE:0.2832 0.9014 0.8931 Tau:0.8352 0.4665 0.3452
EarlyStopping counter: 309 out of 100
Epoch: 538 Step: 89308 Index:-0.4152 R2:0.9275 0.3632 0.3406 RMSE:0.2698 0.8784 0.8921 Tau:0.8354 0.4632 0.3460
EarlyStopping counter: 310 out of 100
Epoch: 539 Step: 89474 Index:-0.3777 R2:0.9141 0.3811 0.3327 RMSE:0.2926 0.8494 0.8987 Tau:0.8174 0.4717 0.3433
EarlyStopping counter: 311 out of 100
Epoch: 540 Step: 89640 Index:-0.4372 R2:0.9315 0.3639 0.3477 RMSE:0.2613 0.8926 0.9090 Tau:0.8428 0.4554 0.3492
EarlyStopping counter: 312 out of 100
Epoch: 541 Step: 89806 Index:-0.3869 R2:0.9077 0.4013 0.3588 RMSE:0.3071 0.8577 0.8791 Tau:0.8148 0.4708 0.3458
EarlyStopping counter: 313 out of 100
Epoch: 542 Step: 89972 Index:-0.4337 R2:0.9063 0.3862 0.3479 RMSE:0.3866 0.8983 0.9300 Tau:0.8110 0.4647 0.3417
EarlyStopping counter: 314 out of 100
Epoch: 543 Step: 90138 Index:-0.4097 R2:0.9325 0.3794 0.3509 R

EarlyStopping counter: 363 out of 100
Epoch: 592 Step: 98272 Index:-0.4338 R2:0.9391 0.3737 0.3593 RMSE:0.2555 0.9003 0.9029 Tau:0.8520 0.4665 0.3553
EarlyStopping counter: 364 out of 100
Epoch: 593 Step: 98438 Index:-0.4958 R2:0.9395 0.3268 0.3454 RMSE:0.2465 0.9603 0.9056 Tau:0.8520 0.4645 0.3532
EarlyStopping counter: 365 out of 100
Epoch: 594 Step: 98604 Index:-0.4964 R2:0.9389 0.3279 0.3371 RMSE:0.2462 0.9557 0.9228 Tau:0.8516 0.4593 0.3490
EarlyStopping counter: 366 out of 100
Epoch: 595 Step: 98770 Index:-0.4685 R2:0.9403 0.3593 0.3577 RMSE:0.2516 0.9300 0.9141 Tau:0.8503 0.4615 0.3497
EarlyStopping counter: 367 out of 100
Epoch: 596 Step: 98936 Index:-0.4984 R2:0.9183 0.3362 0.3188 RMSE:0.2855 0.9409 0.9424 Tau:0.8266 0.4425 0.3474
EarlyStopping counter: 368 out of 100
Epoch: 597 Step: 99102 Index:-0.4436 R2:0.9282 0.3388 0.3437 RMSE:0.3063 0.9012 0.8776 Tau:0.8373 0.4576 0.3479
EarlyStopping counter: 369 out of 100
Epoch: 598 Step: 99268 Index:-0.5135 R2:0.9436 0.3219 0.3406 R

EarlyStopping counter: 418 out of 100
Epoch: 647 Step: 107402 Index:-0.6043 R2:0.9417 0.2871 0.3190 RMSE:0.2573 1.0552 0.9576 Tau:0.8597 0.4509 0.3595
EarlyStopping counter: 419 out of 100
Epoch: 648 Step: 107568 Index:-0.4965 R2:0.9443 0.3324 0.3325 RMSE:0.2374 0.9713 0.9400 Tau:0.8625 0.4748 0.3588
EarlyStopping counter: 420 out of 100
Epoch: 649 Step: 107734 Index:-0.6240 R2:0.9399 0.2864 0.3289 RMSE:0.2665 1.0842 0.9876 Tau:0.8515 0.4602 0.3577
EarlyStopping counter: 421 out of 100
Epoch: 650 Step: 107900 Index:-0.5592 R2:0.9462 0.3097 0.3172 RMSE:0.2413 1.0316 0.9778 Tau:0.8633 0.4724 0.3595
EarlyStopping counter: 422 out of 100
Epoch: 651 Step: 108066 Index:-0.6068 R2:0.9385 0.2879 0.3116 RMSE:0.2470 1.0567 0.9610 Tau:0.8482 0.4499 0.3591
EarlyStopping counter: 423 out of 100
Epoch: 652 Step: 108232 Index:-0.4986 R2:0.9411 0.3398 0.3283 RMSE:0.2538 0.9466 0.9465 Tau:0.8542 0.4480 0.3599
EarlyStopping counter: 424 out of 100
Epoch: 653 Step: 108398 Index:-0.6161 R2:0.9478 0.2878 0

EarlyStopping counter: 473 out of 100
Epoch: 702 Step: 116532 Index:-0.5602 R2:0.9451 0.3147 0.3036 RMSE:0.2350 1.0002 0.9789 Tau:0.8636 0.4400 0.3629
EarlyStopping counter: 474 out of 100
Epoch: 703 Step: 116698 Index:-0.6895 R2:0.9548 0.2874 0.3051 RMSE:0.2793 1.1400 1.0596 Tau:0.8752 0.4505 0.3620
EarlyStopping counter: 475 out of 100
Epoch: 704 Step: 116864 Index:-0.6666 R2:0.9516 0.2762 0.3152 RMSE:0.2196 1.1138 0.9961 Tau:0.8705 0.4472 0.3631
EarlyStopping counter: 476 out of 100
Epoch: 705 Step: 117030 Index:-1.0108 R2:0.8023 0.1560 0.2407 RMSE:0.4552 1.4234 1.1220 Tau:0.7252 0.4126 0.3242
EarlyStopping counter: 477 out of 100
Epoch: 706 Step: 117196 Index:-0.4041 R2:0.9215 0.3997 0.3525 RMSE:0.3099 0.8939 0.9361 Tau:0.8290 0.4899 0.3625
EarlyStopping counter: 478 out of 100
Epoch: 707 Step: 117362 Index:-0.4393 R2:0.9452 0.3843 0.3509 RMSE:0.2402 0.9228 0.9385 Tau:0.8601 0.4835 0.3640
EarlyStopping counter: 479 out of 100
Epoch: 708 Step: 117528 Index:-0.4421 R2:0.9493 0.3874 0

EarlyStopping counter: 528 out of 100
Epoch: 757 Step: 125662 Index:-0.5933 R2:0.9604 0.3164 0.2903 RMSE:0.1991 1.0504 1.0444 Tau:0.8861 0.4571 0.3686
EarlyStopping counter: 529 out of 100
Epoch: 758 Step: 125828 Index:-0.6123 R2:0.9556 0.3005 0.2844 RMSE:0.2527 1.0616 1.0585 Tau:0.8782 0.4493 0.3671
EarlyStopping counter: 530 out of 100
Epoch: 759 Step: 125994 Index:-0.5847 R2:0.9410 0.2891 0.3169 RMSE:0.2451 1.0314 0.9849 Tau:0.8529 0.4467 0.3622
EarlyStopping counter: 531 out of 100
Epoch: 760 Step: 126160 Index:-0.6100 R2:0.9519 0.3016 0.2988 RMSE:0.2275 1.0540 1.0285 Tau:0.8701 0.4440 0.3672
EarlyStopping counter: 532 out of 100
Epoch: 761 Step: 126326 Index:-0.6514 R2:0.9621 0.2767 0.3019 RMSE:0.1941 1.0968 1.0171 Tau:0.8876 0.4454 0.3669
EarlyStopping counter: 533 out of 100
Epoch: 762 Step: 126492 Index:-0.6779 R2:0.9540 0.2781 0.2854 RMSE:0.2226 1.1180 1.0703 Tau:0.8755 0.4401 0.3727
EarlyStopping counter: 534 out of 100
Epoch: 763 Step: 126658 Index:-0.5746 R2:0.9554 0.3289 0

EarlyStopping counter: 583 out of 100
Epoch: 812 Step: 134792 Index:-0.7273 R2:0.9671 0.2496 0.2818 RMSE:0.1811 1.1819 1.0747 Tau:0.8974 0.4545 0.3701
EarlyStopping counter: 584 out of 100
Epoch: 813 Step: 134958 Index:-0.6776 R2:0.9521 0.2776 0.2691 RMSE:0.2171 1.1269 1.0953 Tau:0.8716 0.4493 0.3673
EarlyStopping counter: 585 out of 100
Epoch: 814 Step: 135124 Index:-0.8492 R2:0.9618 0.2102 0.2801 RMSE:0.2117 1.2938 1.1017 Tau:0.8864 0.4446 0.3734
EarlyStopping counter: 586 out of 100
Epoch: 815 Step: 135290 Index:-0.6185 R2:0.9521 0.3057 0.2742 RMSE:0.2496 1.0731 1.0844 Tau:0.8720 0.4546 0.3702
EarlyStopping counter: 587 out of 100
Epoch: 816 Step: 135456 Index:-0.7890 R2:0.9629 0.2342 0.2559 RMSE:0.1939 1.2376 1.1183 Tau:0.8896 0.4487 0.3690
EarlyStopping counter: 588 out of 100
Epoch: 817 Step: 135622 Index:-0.7998 R2:0.9644 0.2446 0.2895 RMSE:0.2348 1.2435 1.1076 Tau:0.8897 0.4437 0.3743
EarlyStopping counter: 589 out of 100
Epoch: 818 Step: 135788 Index:-0.8264 R2:0.9623 0.2290 0

EarlyStopping counter: 638 out of 100
Epoch: 867 Step: 143922 Index:-0.7076 R2:0.9404 0.2697 0.2509 RMSE:0.2610 1.1578 1.1580 Tau:0.8667 0.4502 0.3694
EarlyStopping counter: 639 out of 100
Epoch: 868 Step: 144088 Index:-0.7220 R2:0.9572 0.2614 0.2925 RMSE:0.2124 1.1783 1.0656 Tau:0.8835 0.4562 0.3760
EarlyStopping counter: 640 out of 100
Epoch: 869 Step: 144254 Index:-0.7040 R2:0.9659 0.2731 0.2651 RMSE:0.1865 1.1601 1.1345 Tau:0.8990 0.4561 0.3706
EarlyStopping counter: 641 out of 100
Epoch: 870 Step: 144420 Index:-0.8581 R2:0.9683 0.2248 0.2626 RMSE:0.1952 1.3048 1.1633 Tau:0.9012 0.4467 0.3759
EarlyStopping counter: 642 out of 100
Epoch: 871 Step: 144586 Index:-0.8009 R2:0.9625 0.2430 0.2529 RMSE:0.1945 1.2465 1.1541 Tau:0.8910 0.4457 0.3786
EarlyStopping counter: 643 out of 100
Epoch: 872 Step: 144752 Index:-0.7520 R2:0.9566 0.2643 0.2550 RMSE:0.2068 1.2008 1.1507 Tau:0.8788 0.4488 0.3788
EarlyStopping counter: 644 out of 100
Epoch: 873 Step: 144918 Index:-0.8350 R2:0.9687 0.2340 0

EarlyStopping counter: 694 out of 100
Epoch: 923 Step: 153218 Index:-0.8333 R2:0.9665 0.2374 0.2477 RMSE:0.1886 1.2834 1.2008 Tau:0.8997 0.4502 0.3783
EarlyStopping counter: 695 out of 100
Epoch: 924 Step: 153384 Index:-0.8411 R2:0.9662 0.2384 0.2435 RMSE:0.1824 1.2940 1.2082 Tau:0.8972 0.4529 0.3759
EarlyStopping counter: 696 out of 100
Epoch: 925 Step: 153550 Index:-0.9218 R2:0.9701 0.2079 0.2377 RMSE:0.1717 1.3716 1.2216 Tau:0.9052 0.4498 0.3748
EarlyStopping counter: 697 out of 100
Epoch: 926 Step: 153716 Index:-0.9202 R2:0.9699 0.2200 0.2640 RMSE:0.1918 1.3727 1.1995 Tau:0.9033 0.4525 0.3797
EarlyStopping counter: 698 out of 100
Epoch: 927 Step: 153882 Index:-1.2076 R2:0.9603 0.1441 0.2369 RMSE:0.2006 1.6576 1.2324 Tau:0.8933 0.4500 0.3780
EarlyStopping counter: 699 out of 100
Epoch: 928 Step: 154048 Index:-0.8522 R2:0.9630 0.2260 0.2702 RMSE:0.1939 1.2968 1.1428 Tau:0.8879 0.4446 0.3803
EarlyStopping counter: 700 out of 100
Epoch: 929 Step: 154214 Index:-0.8163 R2:0.9553 0.2302 0

EarlyStopping counter: 749 out of 100
Epoch: 978 Step: 162348 Index:-1.2504 R2:0.9670 0.1543 0.2137 RMSE:0.1882 1.6910 1.3667 Tau:0.8999 0.4406 0.3809
EarlyStopping counter: 750 out of 100
Epoch: 979 Step: 162514 Index:-1.1197 R2:0.9661 0.1780 0.1855 RMSE:0.2210 1.5597 1.4303 Tau:0.8949 0.4400 0.3780
EarlyStopping counter: 751 out of 100
Epoch: 980 Step: 162680 Index:-1.1131 R2:0.9701 0.1706 0.2093 RMSE:0.2195 1.5544 1.3505 Tau:0.9028 0.4413 0.3858
EarlyStopping counter: 752 out of 100
Epoch: 981 Step: 162846 Index:-0.9129 R2:0.9683 0.2300 0.2002 RMSE:0.2058 1.3600 1.3575 Tau:0.9048 0.4471 0.3829
EarlyStopping counter: 753 out of 100
Epoch: 982 Step: 163012 Index:-1.1178 R2:0.9738 0.1725 0.2073 RMSE:0.1772 1.5601 1.3677 Tau:0.9108 0.4423 0.3850
EarlyStopping counter: 754 out of 100
Epoch: 983 Step: 163178 Index:-1.0304 R2:0.9727 0.2007 0.2162 RMSE:0.1903 1.4736 1.3507 Tau:0.9088 0.4432 0.3832
EarlyStopping counter: 755 out of 100
Epoch: 984 Step: 163344 Index:-0.9817 R2:0.9743 0.2206 0

In [15]:
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
    
test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df)
test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
test_r2_ina, test_MSE_ina, test_predict_ina = eval(model, amodel, gmodel, test_df[test_active:].reset_index(drop=True))
    
test_predict = np.array(test_predict)
test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())

k_list = [int(len(test_df)*0.01),int(len(test_df)*0.05),int(len(test_df)*0.1),int(len(test_df)*0.15),int(len(test_df)*0.2),int(len(test_df)*0.25),
          int(len(test_df)*0.3),int(len(test_df)*0.4),int(len(test_df)*0.5),50,100,150,200,250,300]
topk_list =[]
false_positive_rate_list = []
for k in k_list:
    a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
    topk_list.append(a)
    false_positive_rate_list.append(b)
WTI = weighted_top_index(test_df, test_predict, test_active)
ap = AP(test_df, test_predict, test_active)


In [16]:
print(' epoch:',epoch,'r2:%.4f'%test_r2_a,'RMSE:%.4f'%test_MSE_a**0.5,'WTI:%.4f'%WTI,'AP:%.4f'%ap,'Tau:%.4f'%test_tau,'\n','\n',
      'Top-1:%.4f'%topk_list[0],'Top-1-fp:%.4f'%false_positive_rate_list[0],'\n',
      'Top-5:%.4f'%topk_list[1],'Top-5-fp:%.4f'%false_positive_rate_list[1],'\n',
      'Top-10:%.4f'%topk_list[2],'Top-10-fp:%.4f'%false_positive_rate_list[2],'\n',
      'Top-15:%.4f'%topk_list[3],'Top-15-fp:%.4f'%false_positive_rate_list[3],'\n',
      'Top-20:%.4f'%topk_list[4],'Top-20-fp:%.4f'%false_positive_rate_list[4],'\n',
      'Top-25:%.4f'%topk_list[5],'Top-25-fp:%.4f'%false_positive_rate_list[5],'\n',
      'Top-30:%.4f'%topk_list[6],'Top-30-fp:%.4f'%false_positive_rate_list[6],'\n',
      'Top-40:%.4f'%topk_list[7],'Top-40-fp:%.4f'%false_positive_rate_list[7],'\n',
      'Top-50:%.4f'%topk_list[8],'Top-50-fp:%.4f'%false_positive_rate_list[8],'\n','\n',
      'Top50:%.4f'%topk_list[9],'Top50-fp:%.4f'%false_positive_rate_list[9],'\n',
      'Top100:%.4f'%topk_list[10],'Top100-fp:%.4f'%false_positive_rate_list[10],'\n',
      'Top150:%.4f'%topk_list[11],'Top150-fp:%.4f'%false_positive_rate_list[11],'\n',
      'Top200:%.4f'%topk_list[12],'Top200-fp:%.4f'%false_positive_rate_list[12],'\n',
      'Top250:%.4f'%topk_list[13],'Top250-fp:%.4f'%false_positive_rate_list[13],'\n',
      'Top300:%.4f'%topk_list[14],'Top300-fp:%.4f'%false_positive_rate_list[14],'\n')

 epoch: 1000 r2:0.3946 RMSE:0.7808 WTI:0.3163 AP:0.4335 Tau:0.3120 
 
 Top-1:0.2500 Top-1-fp:0.0625 
 Top-5:0.5663 Top-5-fp:0.3253 
 Top-10:0.4819 Top-10-fp:0.4639 
 Top-15:0.4699 Top-15-fp:0.4900 
 Top-20:0.4578 Top-20-fp:0.5331 
 Top-25:0.5143 Top-25-fp:0.5663 
 Top-30:0.5771 Top-30-fp:0.5944 
 Top-40:0.6829 Top-40-fp:0.6406 
 Top-50:0.7800 Top-50-fp:0.6715 
 
 Top50:0.5200 Top50-fp:0.2400 
 Top100:0.5500 Top100-fp:0.3500 
 Top150:0.5000 Top150-fp:0.4200 
 Top200:0.4750 Top200-fp:0.4750 
 Top250:0.4720 Top250-fp:0.4880 
 Top300:0.4667 Top300-fp:0.5100 



In [17]:
# print('target_file:',train_filename)
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)

In [18]:
# loss = loss_function(mol_prediction,y)
#             loss.backward(retain_graph=True)
#             optimizer_AFSE.zero_grad()
#             punish_lr = torch.norm(torch.mean(eps.grad,0))

#         init_lr = 1e-4
#         max_lr = 10**-(init_lr-1)
#         conv_lr = conv_lr - conv_lr**2 + 0.1 * punish_lr
#         if conv_lr < max_lr:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = conv_lr.detach()
#                 AFSE_lr = conv_lr    
#         else:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = max_lr
#                 AFSE_lr = max_lr
# epoch: 512 r2:0.5619 RMSE:0.7306 WTI:0.3784 AP:0.7228 Tau:0.5159 
 
#  Top-1:0.0000 Top-1-fp:0.0000 
#  Top-5:0.8571 Top-5-fp:0.0000 
#  Top-10:0.7857 Top-10-fp:0.0714 