In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
train_filename = "./data/benchmark/Ki_P21554_1_420_train.csv"
test_filename = "./data/benchmark/Ki_P21554_1_420_test.csv"
test_active = 420
val_rate = 0.2
random_seed = 68
file_list1 = train_filename.split('/')
file1 = file_list1[-1]
file1 = file1[:-10]
number = '_run_0'
model_file = "model_file/3_GAFSE_"+file1+number
log_dir = f'log/{"3_GAFSE_"+file1}'+number
result_dir = './result/3_GAFSE_'+file1+number
print(file1)
print(model_file)

Ki_P21554_1_420
model_file/3_GAFSE_Ki_P21554_1_420_run_0


In [4]:
# task_name = 'Malaria Bioactivity'
tasks = ['value']

# train_filename = "../data/active_inactive/median_active/EC50/Q99500.csv"
feature_filename = train_filename.replace('.csv','.pickle')
filename = train_filename.replace('.csv','')
prefix_filename = train_filename.split('/')[-1].replace('.csv','')
train_df = pd.read_csv(train_filename, header=0, names = ["smiles","value"],usecols=[0,1])
# train_df = train_df[1:]
# train_df = train_df.drop(0,axis=1,inplace=False) 
print(train_df[:5])
# print(train_df.iloc(1))
def add_canonical_smiles(train_df):
    smilesList = train_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    atom_num_dist = []
    remained_smiles = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:        
            mol = Chem.MolFromSmiles(smiles)
            atom_num_dist.append(len(mol.GetAtoms()))
            remained_smiles.append(smiles)
            canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
        except:
            print(smiles)
            pass
    print("number of successfully processed smiles: ", len(remained_smiles))
    train_df = train_df[train_df["smiles"].isin(remained_smiles)]
    train_df['cano_smiles'] =canonical_smiles_list
    return train_df
# print(train_df)
train_df = add_canonical_smiles(train_df)

print(train_df.head())
# plt.figure(figsize=(5, 3))
# sns.set(font_scale=1.5)
# ax = sns.distplot(atom_num_dist, bins=28, kde=False)
# plt.tight_layout()
# # plt.savefig("atom_num_dist_"+prefix_filename+".png",dpi=200)
# plt.show()
# plt.close()


                                              smiles     value
0  CCCCCCC(C)(C)C1=CC(=CC(=C1)OCCCCCCCC(=O)NC(C)CO)O -1.232996
1                     CCCCCCCCC=CCCCCCCCC(=O)NC1CC1O -2.812913
2  CC1CCN(CC1)C(=O)C2=CC3=C(C=C2)N(C4=C3CN(CC4)C5... -1.748188
3  CC1=CC=C(N1C2=C(C(=NN2C3=C(C=C(C=C3)Cl)Cl)C(=O... -2.158362
4  CCCCCCC1(CC1(Br)Br)C2=CC3=C(C4CC(=CCC4C(O3)(C)...  0.149384
number of all smiles:  2046
number of successfully processed smiles:  2046
                                              smiles     value  \
0  CCCCCCC(C)(C)C1=CC(=CC(=C1)OCCCCCCCC(=O)NC(C)CO)O -1.232996   
1                     CCCCCCCCC=CCCCCCCCC(=O)NC1CC1O -2.812913   
2  CC1CCN(CC1)C(=O)C2=CC3=C(C=C2)N(C4=C3CN(CC4)C5... -1.748188   
3  CC1=CC=C(N1C2=C(C(=NN2C3=C(C=C(C=C3)Cl)Cl)C(=O... -2.158362   
4  CCCCCCC1(CC1(Br)Br)C2=CC3=C(C4CC(=CCC4C(O3)(C)...  0.149384   

                                         cano_smiles  
0     CCCCCCC(C)(C)c1cc(O)cc(OCCCCCCCC(=O)NC(C)CO)c1  
1                     CCCCCCCCC=CCCCCCCC

In [5]:
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 2 # default: 2
T = 1
per_task_output_units_num = 1 # for regression model
output_units_num = len(tasks) * per_task_output_units_num

In [6]:
test_df = pd.read_csv(test_filename,header=0,names=["smiles","value"],usecols=[0,1])
test_df = add_canonical_smiles(test_df)
for l in test_df["cano_smiles"]:
    if l in train_df["cano_smiles"]:
        print("same smiles:",l)
        
print(test_df.shape)
print(test_df.head())

number of all smiles:  1342
number of successfully processed smiles:  1342
(1342, 3)
                                              smiles     value  \
0  C1CCC(C(C1)NC(=O)C2=CN=C(C(=C2)C3=C(C=C(C=C3)C... -2.086360   
1  CC1=C(C=C(C=C1)C(=O)NC2C(C3CCC2(C3)C)(C)C)S(=O... -2.113943   
2  CCCCCN1C=C(C(=O)C2=C1C=C(C=C2)SC3=CC=CC=C3)C(=... -1.110590   
3  CC1=C(N=C(N1C2=CC=C(C=C2)Cl)C3=C(C=C(C=C3)Cl)C... -2.232996   
4  CCCCCCC(C)(C)C1=CC(=C2C3C=C(CCC3C(OC2=C1)(C)C)... -3.127105   

                                         cano_smiles  
0  O=C(NC1CCCCC1O)c1cnc(OCC2CC2)c(-c2ccc(Cl)cc2Cl)c1  
1  Cc1ccc(C(=O)NC2C3(C)CCC(C3)C2(C)C)cc1S(=O)(=O)...  
2  CCCCCn1cc(C(=O)NOCc2ccccc2)c(=O)c2ccc(Sc3ccccc...  
3  Cc1c(C(=O)NCc2ccc(C(F)(F)F)cc2)nc(-c2ccc(Cl)cc...  
4  CCCCCCC(C)(C)c1cc(NC(C)=O)c2c(c1)OC(C)(C)C1CCC...  


In [7]:
print(feature_filename)
print(filename)
total_df = pd.concat([train_df,test_df],axis=0)
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
# if os.path.isfile(feature_filename):
#     feature_dicts = pickle.load(open(feature_filename, "rb" ))
# else:
#     feature_dicts = save_smiles_dicts(smilesList,filename)
feature_dicts = save_smiles_dicts(total_smilesList,filename)
remained_df = total_df[total_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = total_df.drop(remained_df.index)

./data/benchmark/Ki_P21554_1_420_train.pickle
./data/benchmark/Ki_P21554_1_420_train
3388
feature dicts file saved as ./data/benchmark/Ki_P21554_1_420_train.pickle


In [8]:
val_df = train_df.sample(frac=val_rate,random_state=random_seed)
train_df = train_df.drop(val_df.index)
train_df = train_df.reset_index(drop=True)
train_df = train_df[train_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
train_df = train_df.reset_index(drop=True)
val_df = val_df[val_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
val_df = val_df.reset_index(drop=True)
test_df = test_df[test_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
test_df = test_df.reset_index(drop=True)
print(train_df.shape,val_df.shape,test_df.shape)

(1637, 3) (409, 3) (1342, 3)


In [9]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

optimizer = optim.Adam(params=model.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [10]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    p_value = dataset[tasks[0]].astype(float).tolist()
    x = np.arange(0,len(dataset),1)
#     print('plt',dataset.head(),p[:10],k_predict,k)
#     plt.figure()
#     fig, ax1 = plt.subplots()
#     ax1.grid(False)
#     ax2 = ax1.twinx()
#     plt.grid(False)
    plt.scatter(x,p,marker='.',s=6,color='r',label='predict')
#     plt.ylabel('predict')
    plt.scatter(x,p_value,s=6,marker=',',color='blue',label='p_value')
    plt.axvline(x=k-1,ls="-",c="black")#添加垂直直线
    k_value = np.ones(len(dataset))
# #     print(EC50[k-1])
    k_value = k_value*k_predict
    plt.plot(x,k_value,'-',color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
    
    df3 = df2[:k]  #取按预测值排完序后的前k个
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    k_true = true_sort[tasks[0]].values[k-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('k_true: ',type(k_true),k_true)
#     print('k_true: ',k_true,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=k_true,'acc_num: ',len(df3[df3['predict'].values>=k_true]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k)
    acc = len(df3[df3[tasks[0]].values>=k_true])/k #预测值前k个中真实排在前k个的个数/k
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/k
    if k>active_num:
        min_active = true_sort[tasks[0]].values[active_num-1]
        acc = len(df3[df3[tasks[0]].values>=min_active])/k
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
        
    df3 = df2[:k]  #取按预测值排完序后的前k个，因为后面的全是-4.1
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    min_active = true_sort[tasks[0]].values[active_num-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('min_active: ',type(min_active),min_active)
#     print('min_active: ',min_active,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=min_active,'acc_num: ',len(df3[df3['predict'].values>=min_active]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k,'active_num: ',active_num)
    acc = len(df3[df3[tasks[0]].values>-4.1])/active_num #预测值前k个中真实排在前active_num个的个数/active_num
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/active_num
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    if k>active_num:
        return topk_recall(df, predict, k, active_num, show_flag, i)
    return topk_acc2(df,predict,k, active_num,show_flag,i)

def weighted_top_index(df, predict, active_num):
    weighted_acc_list=[]
    for k in np.arange(1,len(df)+1,1):
        acc, fp = topk_acc_recall(df, predict, k, active_num)
        weight = (len(df)-k)/len(df)
#         print('weight=',weight,'acc=',acc)
        weighted_acc_list.append(acc*weight)#
    weighted_acc_list = np.array(weighted_acc_list)
#     print('weighted_acc_list=',weighted_acc_list)
    return np.sum(weighted_acc_list)/weighted_acc_list.shape[0]

def AP(df, predict, active_num):
    prec = []
    rec = []
    for k in np.arange(1,len(df)+1,1):
        prec_k, fp1 = topk_acc2(df,predict,k, active_num)
        rec_k, fp2 = topk_recall(df, predict, k, active_num)
        prec.append(prec_k)
        rec.append(rec_k)
    # 取所有不同的recall对应的点处的精度值做平均
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # 计算包络线，从后往前取最大保证precise非减
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # 找出所有检测结果中recall不同的点
    i = np.where(mrec[1:] != mrec[:-1])[0]
#     print(prec)
#     print('prec='+str(prec)+'\n\n'+'rec='+str(rec))

    # and sum (\Delta recall) * prec
    # 用recall的间隔对精度作加权平均
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

In [11]:
def caculate_r2(y,predict):
#     print(y)
#     print(predict)
    y = torch.FloatTensor(y).reshape(-1,1)
    predict = torch.FloatTensor(predict).reshape(-1,1)
    y_mean = torch.mean(y)
    predict_mean = torch.mean(predict)
    
    y1 = torch.pow(torch.mm((y-y_mean).t(),(predict-predict_mean)),2)
    y2 = torch.mm((y-y_mean).t(),(y-y_mean))*torch.mm((predict-predict_mean).t(),(predict-predict_mean))
#     print(y1,y2)
    return y1/y2

In [12]:
from torch.autograd import Variable
def l2_norm(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    output = torch.div(input, norm+1e-6)
    return output

def normalize_perturbation(d,dim=-1):
    output = l2_norm(d, dim)
    return output

def tanh(x):
    return (torch.exp(x)-torch.exp(-x))/(torch.exp(x)+torch.exp(-x))

def sigmoid(x):
    return 1/(1+torch.exp(-x))

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None):
    mol_prediction = model(feature=f, d=0)
    pred = mol_prediction.detach()
#     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
    eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
    eps = Variable(eps, requires_grad=True)
    # Predict on randomly perturbed image
    eps_p = model(feature=f, d=eps.cuda())
    eps_p_ = model(feature=f, d=-eps.cuda())
    p_aux = nn.Sigmoid()(eps_p/(pred+1e-6))
    p_aux_ = nn.Sigmoid()(eps_p_/(pred+1e-6))
#     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
    loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
    loss.backward(retain_graph=True)

    # Based on perturbed image, get direction of greatest error
    eps_adv = eps.grad#/10**-learning_rate
    optimizer_AFSE.zero_grad()
    # Use that direction as adversarial perturbation
    eps_adv_normed = normalize_perturbation(eps_adv)
    d_adv = lamda * eps_adv_normed.cuda()
    if output_lr:
        f_p, max_lr = model(feature=f, d=d_adv, output_lr=output_lr)
    f_p = model(feature=f, d=d_adv)
    f_p_ = model(feature=f, d=-d_adv)
    p = nn.Sigmoid()(f_p/(pred+1e-6))
    p_ = nn.Sigmoid()(f_p_/(pred+1e-6))
    vat_loss = loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    if output_lr:
        if output_plr:
            loss = loss_function(mol_prediction,y)
            loss.backward(retain_graph=True)
            optimizer_AFSE.zero_grad()
            punish_lr = torch.norm(torch.mean(eps.grad,0))
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol

def d_loss(f, pred, model, y_val):
    diff_loss = 0
    length = len(pred)
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    diff_loss = diff_loss/(length*(length-1))
    return diff_loss

In [13]:
def CE(x,y):
    c = 0
    l = len(y)
    for i in range(l):
        if y[i]==1:
            c += 1
    w1 = (l-c)/l
    w0 = c/l
    loss = -w1*y*torch.log(x+1e-6)-w0*(1-y)*torch.log(1-x+1e-6)
    loss = loss.mean(-1)
    return loss

def weighted_CE_loss(x,y):
    weight = 1/(y.detach().float().mean(0)+1e-9)
    weighted_CE = nn.CrossEntropyLoss(weight=weight)
#     atom_weights = (atom_weights-min(atom_weights))/(max(atom_weights)-min(atom_weights))
    return weighted_CE(x, torch.argmax(y,-1))

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduce=False)
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    loss = torch.mean(loss)
    return loss

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)
#         print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(float)
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                  torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
                                                  mol_feature=mol_feature,activated_features=activated_features.detach())
        
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1, 
                                                                                       lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1)) # 10**-learning_rate     
        regression_loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
        success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=1)
        reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
                                                                                              bond_neighbor, validity_mask, atom_list, 
                                                                                              bond_list)
        x_atom_test = torch.Tensor(x_atom_test)
        x_bonds_test = torch.Tensor(x_bonds_test)
        x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
        bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
        bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate)
        atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
                                                torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
        refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                            torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                            mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
        success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
                            bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
                                                     refer_atom_list_test, refer_bond_list_test,topn=1)
        test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
        
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        max_lr = 1e-3
        conv_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr
        if conv_lr < max_lr and conv_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = conv_lr.detach()
                AFSE_lr = conv_lr    
        elif conv_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif conv_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
        
        logger.add_scalar('loss/regression', regression_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
        logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
        logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
        logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/max_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
    
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  regression_loss + 0.6 * (vat_loss + test_vat_loss) + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    [a.ClearProp('molAtomMapNumber') for a  in mol.GetAtoms()]
    return mol

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0
        top_idx = 0
        flag = 0
        first_run_flag = True
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
#             print(f'atom#{smiles_to_rdkit_list[cano_smiles][j]}(f):',{symbol_list[k]: np.around(refer_atom_list[i,j,k],3) for k in range(16)},
#                   f'\natom#{smiles_to_rdkit_list[cano_smiles][j]}(f+d):',{symbol_list[k]: np.around(atom_list[i,j,k],3) for k in range(16)},
#                  '\n------------------------------------------------------------------------------------------------------------')
#         print('预测为每个原子的平均概率：\n',np.around(atom_list[i,:l,:16].mean(1),2))
#         print('预测为每个原子的最大概率：\n',np.around(atom_list[i,:l,:16].max(1),2))
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            if rank == 16:
                rank = 0
                top_idx += 1
            if top_idx == l:
#                 print('没有满足条件的分子生成。')
                flag += 1
                continue
#             if np.sum((atom_symbol_sorted[i,:l,-1]!=atom_symbol_generated_sorted[i,:l,-1-rank]).astype(int))==0:
#                 print(f'根据预测的第{rank}大概率的原子构成的分子与原分子一致，原子位重置为0，生成下一个元素……')
#                 rank += 1
#                 top_idx = 0
#                 generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
#                                              x_atom[i,:l,:16]).max(-1))[-1-top_idx]
#             print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
#                 print('生成了相同元素，生成下一个元素……')
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
#                 print(f'原子位{generate_rdkit_index}生成{symbol_list[atom_symbol_generated]}元素的置信度小于{confidence_threshold}，寻找下一个原子位……')
                top_idx += 1
                rank = 0
                continue
#             if symbol_to_rdkit[atom_symbol_generated]==6:
#                 print('生成了不推荐的C元素')
#                 rank += 1
#                 continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
#                     print('Congratulations!', success, total)
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
#                 print("修改前的分子：", smiles[i])
#                 display(mol_init)
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 print(f"将第{generate_rdkit_index}个原子修改为{symbol_list[atom_symbol_generated]}的分子：", Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 display(mol_with_atom_index(mol))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
#                 print("高活性分子：", y_smiles[i])
#                 display(mol_y)
                rank += 1
                flag += 1
            except:
#                 print(f"第{generate_rdkit_index}个原子符号修改为{symbol_list[atom_symbol_generated]}不符合规范，生成下一个元素……")
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, topn=1, output_feature=False, generate=False, modify_atom=True,return_GRN_loss=False, viz=False):
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    r2_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    feature_list = []
    d_list = []
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]
    
# #     取dataset中排序后的第k个
#     sorted_dataset = dataset.sort_values(by=tasks[0],ascending=False)
#     k_df = sorted_dataset.iloc[[k-1]]
#     k_smiles = k_df['cano_smiles'].values
#     k_value = k_df[tasks[0]].values.astype(float)    
    
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
#     print(batch_list)
    for counter, batch in enumerate(batch_list):
#         print(type(batch))
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
#         print(batch_df)
        y_val = batch_df[tasks[0]].values.astype(float)
#         print(type(y_val))
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda)
#         print(mol_feature,d_adv)
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
                                      torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
        if generate:
            if modify_atom:
                success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=topn,viz=viz)
            else:
                modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
            generated_smiles.extend(modified_smiles)
            success_smiles.extend(success_smiles_batch)
#             for n in range(topn):
#                 success[n] += success_batch[n]
#                 total[n] += total_batch[n]
#                 print('congratulations:',success,total)
            success_reconstruction += reconstruction
            success_validity += validity
            reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
        d = d_adv.cpu().detach().numpy().tolist()
        d_list.extend(d)
        mol_feature_output = mol_feature.cpu().detach().numpy().tolist()
        feature_list.extend(mol_feature_output)
#         MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')   
#         print(type(mol_prediction))
        
        MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')
#         r2 = caculate_r2(mol_prediction, torch.Tensor(y_val).view(-1,1))
# #         r2_list.extend(r2.cpu().detach().numpy())
#         if r2!=r2:
#             r2 = torch.tensor(0)
#         r2_list.append(r2.item())
#         predict_list.extend(mol_prediction.cpu().detach().numpy())
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        predict_list.extend(mol_prediction.cpu().detach().numpy())
#         test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
        test_MSE_list.extend(MSE.data.view(-1,1).cpu().numpy())
#     print(r2_list)
    if generate:
        generated_num = len(generated_smiles)
        eval_num = len(dataset)
        unique = generated_num
        novelty = generated_num
        for i in range(generated_num):
            for j in range(generated_num-i-1):
                if generated_smiles[i]==generated_smiles[i+j+1]:
                    unique -= 1
            for k in range(eval_num):
                if generated_smiles[i]==dataset['smiles'].values[k]:
                    novelty -= 1
        unique_rate = unique/(generated_num+1e-9)
        novelty_rate = novelty/(generated_num+1e-9)
#         print(f'successfully/total generated molecules =', {f'Top-{i+1}': f'{success[i]}/{total[i]}' for i in range(topn)})
        return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    if return_GRN_loss:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    return caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list

epoch = 0
max_epoch = 1000
batch_size = 10
patience = 100
stopper = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_model.pth')
stopper_afse = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_amodel.pth')
stopper_generate = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_gmodel.pth')

In [14]:
import datetime
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
# if os.path.isdir(log_dir):
#     for files in os.listdir(log_dir):
#         os.remove(log_dir+"/"+files)
#     os.rmdir(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = 152
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
#     print(train_df.shape,test_df.shape)
    train_d, train_f, train_r2, train_MSE, train_predict, reconstruction_loss, one_hot_loss, interger_loss,binary_loss = eval(model, amodel, gmodel, train_df,output_feature=True,return_GRN_loss=True)
    train_predict = np.array(train_predict)
    train_WTI = weighted_top_index(train_df, train_predict, len(train_df))
    train_tau, _ = scipy.stats.kendalltau(train_predict,train_df[tasks[0]].values.astype(float).tolist())
    val_d, val_f, val_r2, val_MSE, val_predict, val_reconstruction_loss, val_one_hot_loss, val_interger_loss,val_binary_loss = eval(model, amodel, gmodel, val_df,output_feature=True,return_GRN_loss=True)
    val_predict = np.array(val_predict)
    val_WTI = weighted_top_index(val_df, val_predict, len(val_df))
    val_AP = AP(val_df, val_predict, len(val_df))
    val_tau, _ = scipy.stats.kendalltau(val_predict,val_df[tasks[0]].values.astype(float).tolist())
    
    test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
    test_d, test_f, test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df,output_feature=True)
    test_predict = np.array(test_predict)
    test_WTI = weighted_top_index(test_df, test_predict, test_active)
#     test_AP = AP(test_df, test_predict, test_active)
    test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())
    
    k_list = [int(len(test_df)*0.01),int(len(test_df)*0.03),int(len(test_df)*0.1),10,30,100]
    topk_list =[]
    false_positive_rate_list = []
    for k in k_list:
        a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
        topk_list.append(a)
        false_positive_rate_list.append(b)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    logger.add_scalar('train/r2', train_r2, global_step)
    logger.add_scalar('train/RMSE', train_MSE**0.5, global_step)
    logger.add_scalar('train/Tau', train_tau, global_step)
    logger.add_scalar('val/WTI', val_WTI, global_step)
    logger.add_scalar('val/AP', val_AP, global_step)
    logger.add_scalar('val/r2', val_r2, global_step)
    logger.add_scalar('val/RMSE', val_MSE**0.5, global_step)
    logger.add_scalar('val/Tau', val_tau, global_step)
#     logger.add_scalar('test/TAP', test_AP, global_step)
    logger.add_scalar('test/r2', test_r2_a, global_step)
    logger.add_scalar('test/RMSE', test_MSE_a**0.5, global_step)
    logger.add_scalar('test/Tau', test_tau, global_step)
    logger.add_scalar('val/GRN', reconstruction_loss, global_step)
    logger.add_scalar('val/GRN_one_hot', one_hot_loss, global_step)
    logger.add_scalar('val/GRN_interger', interger_loss, global_step)
    logger.add_scalar('val/GRN_binary', binary_loss, global_step)
    logger.add_scalar('test/EF0.01', topk_list[0], global_step)
    logger.add_scalar('test/EF0.03', topk_list[1], global_step)
    logger.add_scalar('test/EF0.1', topk_list[2], global_step)
    logger.add_scalar('test/EF10', topk_list[3], global_step)
    logger.add_scalar('test/EF30', topk_list[4], global_step)
    logger.add_scalar('test/EF100', topk_list[5], global_step)
    
#     train_mse_list.append(train_MSE**0.5)
#     train_r2_list.append(train_r2)
#     val_mse_list.append(val_MSE**0.5)  
#     val_r2_list.append(val_r2)
#     train_f_list.append(train_f)
#     val_f_list.append(val_f)
#     test_f_list.append(test_f)
#     epoch_list.append(epoch)
#     train_predict_list.append(train_predict.flatten())
#     test_predict_list.append(test_predict.flatten())
#     val_predict_list.append(val_predict.flatten())
#     train_y_list.append(train_df[tasks[0]].values)
#     val_y_list.append(val_df[tasks[0]].values)
#     test_y_list.append(test_df[tasks[0]].values)
#     train_d_list.append(train_d)
#     val_d_list.append(val_d)
#     test_d_list.append(test_d)

    stop_index = - val_MSE**0.5 + val_tau
    early_stop = stopper.step(stop_index, model)
    early_stop = stopper_afse.step(stop_index, amodel, if_print=False)
    early_stop = stopper_generate.step(stop_index, gmodel, if_print=False)
#     print('epoch {:d}/{:d}, validation {} {:.4f}, {} {:.4f},best validation {r2} {:.4f}'.format(epoch, total_epoch, 'r2', val_r2, 'mse:',val_MSE, stopper.best_score))
    print('Epoch:',epoch, 'Step:', global_step, 'Index:%.4f'%stop_index, 'R2:%.4f'%train_r2,'%.4f'%val_r2,'%.4f'%test_r2_a, 'RMSE:%.4f'%train_MSE**0.5, '%.4f'%val_MSE**0.5, 
          '%.4f'%test_MSE_a**0.5, 'Tau:%.4f'%train_tau,'%.4f'%val_tau,'%.4f'%test_tau)#, 'Tau:%.4f'%val_tau,'%.4f'%test_tau,'GRN:%.4f'%reconstruction_loss,'%.4f'%val_reconstruction_loss
    if early_stop:
        continue


log/3_GAFSE_Ki_P21554_1_420_run_0




Epoch: 153 Step: 24939 Index:-0.2763 R2:0.6711 0.5179 0.5054 RMSE:0.6389 0.8025 0.7912 Tau:0.6123 0.5262 0.3603
Epoch: 154 Step: 25102 Index:-0.2761 R2:0.6711 0.5175 0.4897 RMSE:0.6413 0.7953 0.8013 Tau:0.6095 0.5193 0.3657
EarlyStopping counter: 1 out of 100
Epoch: 155 Step: 25265 Index:-0.3152 R2:0.6573 0.4990 0.4697 RMSE:0.6472 0.8232 0.8214 Tau:0.6002 0.5080 0.3612
Epoch: 156 Step: 25428 Index:-0.2627 R2:0.6755 0.5245 0.5066 RMSE:0.6235 0.7847 0.7802 Tau:0.6124 0.5220 0.3647
EarlyStopping counter: 1 out of 100
Epoch: 157 Step: 25591 Index:-0.2999 R2:0.6624 0.5131 0.4873 RMSE:0.6585 0.8259 0.8192 Tau:0.6079 0.5261 0.3500
EarlyStopping counter: 2 out of 100
Epoch: 158 Step: 25754 Index:-0.2921 R2:0.6648 0.5043 0.4923 RMSE:0.6426 0.8149 0.8005 Tau:0.6095 0.5227 0.3607
EarlyStopping counter: 3 out of 100
Epoch: 159 Step: 25917 Index:-0.2688 R2:0.6748 0.5212 0.5099 RMSE:0.6334 0.7946 0.7819 Tau:0.6129 0.5258 0.3738
EarlyStopping counter: 4 out of 100
Epoch: 160 Step: 26080 Index:-0.2689

EarlyStopping counter: 1 out of 100
Epoch: 211 Step: 34393 Index:-0.2599 R2:0.7294 0.5387 0.5382 RMSE:0.5932 0.8010 0.7817 Tau:0.6508 0.5411 0.3802
EarlyStopping counter: 2 out of 100
Epoch: 212 Step: 34556 Index:-0.2737 R2:0.7082 0.5206 0.5166 RMSE:0.5970 0.7984 0.7820 Tau:0.6394 0.5247 0.3834
EarlyStopping counter: 3 out of 100
Epoch: 213 Step: 34719 Index:-0.3404 R2:0.7114 0.5156 0.5052 RMSE:0.6559 0.8657 0.8529 Tau:0.6397 0.5252 0.3898
EarlyStopping counter: 4 out of 100
Epoch: 214 Step: 34882 Index:-0.3347 R2:0.7111 0.5237 0.5087 RMSE:0.6438 0.8565 0.8431 Tau:0.6368 0.5218 0.3824
EarlyStopping counter: 5 out of 100
Epoch: 215 Step: 35045 Index:-0.2324 R2:0.7266 0.5382 0.5202 RMSE:0.5765 0.7800 0.7739 Tau:0.6523 0.5477 0.3662
Epoch: 216 Step: 35208 Index:-0.2195 R2:0.7337 0.5431 0.5316 RMSE:0.5710 0.7673 0.7587 Tau:0.6543 0.5477 0.3740
EarlyStopping counter: 1 out of 100
Epoch: 217 Step: 35371 Index:-0.2506 R2:0.7252 0.5244 0.5224 RMSE:0.5764 0.7839 0.7662 Tau:0.6505 0.5332 0.3835


EarlyStopping counter: 16 out of 100
Epoch: 267 Step: 43521 Index:-0.2288 R2:0.7583 0.5608 0.5555 RMSE:0.5564 0.7802 0.7699 Tau:0.6756 0.5514 0.3821
EarlyStopping counter: 17 out of 100
Epoch: 268 Step: 43684 Index:-0.2374 R2:0.7654 0.5339 0.5377 RMSE:0.5338 0.7813 0.7580 Tau:0.6798 0.5439 0.3872
EarlyStopping counter: 18 out of 100
Epoch: 269 Step: 43847 Index:-0.2705 R2:0.7612 0.5411 0.5462 RMSE:0.5545 0.8066 0.7804 Tau:0.6749 0.5361 0.3900
EarlyStopping counter: 19 out of 100
Epoch: 270 Step: 44010 Index:-0.2625 R2:0.7620 0.5517 0.5368 RMSE:0.5698 0.8030 0.7995 Tau:0.6761 0.5405 0.3949
EarlyStopping counter: 20 out of 100
Epoch: 271 Step: 44173 Index:-0.2266 R2:0.7628 0.5567 0.5437 RMSE:0.5471 0.7760 0.7676 Tau:0.6782 0.5494 0.3811
EarlyStopping counter: 21 out of 100
Epoch: 272 Step: 44336 Index:-0.2413 R2:0.7710 0.5355 0.5404 RMSE:0.5286 0.7828 0.7579 Tau:0.6851 0.5415 0.3933
EarlyStopping counter: 22 out of 100
Epoch: 273 Step: 44499 Index:-0.2385 R2:0.7672 0.5429 0.5444 RMSE:0.5

EarlyStopping counter: 6 out of 100
Epoch: 323 Step: 52649 Index:-0.2315 R2:0.7983 0.5448 0.5526 RMSE:0.4964 0.7789 0.7492 Tau:0.7063 0.5474 0.4091
EarlyStopping counter: 7 out of 100
Epoch: 324 Step: 52812 Index:-0.2298 R2:0.8002 0.5506 0.5591 RMSE:0.5031 0.7924 0.7698 Tau:0.7106 0.5625 0.3894
EarlyStopping counter: 8 out of 100
Epoch: 325 Step: 52975 Index:-0.2115 R2:0.7998 0.5562 0.5561 RMSE:0.4917 0.7694 0.7522 Tau:0.7094 0.5579 0.3940
EarlyStopping counter: 9 out of 100
Epoch: 326 Step: 53138 Index:-0.2080 R2:0.7959 0.5567 0.5493 RMSE:0.5006 0.7774 0.7658 Tau:0.7045 0.5694 0.3888
EarlyStopping counter: 10 out of 100
Epoch: 327 Step: 53301 Index:-0.2139 R2:0.7965 0.5470 0.5627 RMSE:0.5021 0.7671 0.7343 Tau:0.7044 0.5532 0.4054
EarlyStopping counter: 11 out of 100
Epoch: 328 Step: 53464 Index:-0.2646 R2:0.7851 0.5334 0.5506 RMSE:0.5578 0.8059 0.7681 Tau:0.6967 0.5413 0.3933
EarlyStopping counter: 12 out of 100
Epoch: 329 Step: 53627 Index:-0.2308 R2:0.7953 0.5441 0.5532 RMSE:0.5051 

EarlyStopping counter: 62 out of 100
Epoch: 379 Step: 61777 Index:-0.2061 R2:0.8359 0.5649 0.5611 RMSE:0.4609 0.7723 0.7607 Tau:0.7376 0.5663 0.4016
EarlyStopping counter: 63 out of 100
Epoch: 380 Step: 61940 Index:-0.2160 R2:0.8302 0.5398 0.5461 RMSE:0.4542 0.7803 0.7557 Tau:0.7348 0.5643 0.3977
EarlyStopping counter: 64 out of 100
Epoch: 381 Step: 62103 Index:-0.2200 R2:0.8171 0.5454 0.5357 RMSE:0.4707 0.7905 0.7814 Tau:0.7243 0.5705 0.3913
EarlyStopping counter: 65 out of 100
Epoch: 382 Step: 62266 Index:-0.2188 R2:0.8239 0.5591 0.5691 RMSE:0.4739 0.7824 0.7647 Tau:0.7320 0.5636 0.4013
EarlyStopping counter: 66 out of 100
Epoch: 383 Step: 62429 Index:-0.2323 R2:0.8205 0.5414 0.5623 RMSE:0.4658 0.7882 0.7508 Tau:0.7276 0.5559 0.4060
EarlyStopping counter: 67 out of 100
Epoch: 384 Step: 62592 Index:-0.2200 R2:0.8236 0.5517 0.5629 RMSE:0.4615 0.7803 0.7537 Tau:0.7258 0.5602 0.4038
EarlyStopping counter: 68 out of 100
Epoch: 385 Step: 62755 Index:-0.2110 R2:0.8282 0.5437 0.5593 RMSE:0.4

EarlyStopping counter: 117 out of 100
Epoch: 434 Step: 70742 Index:-0.2423 R2:0.8463 0.5463 0.5488 RMSE:0.4455 0.8023 0.7827 Tau:0.7452 0.5601 0.4039
EarlyStopping counter: 118 out of 100
Epoch: 435 Step: 70905 Index:-0.2132 R2:0.8157 0.5323 0.5158 RMSE:0.4775 0.7802 0.7798 Tau:0.7230 0.5669 0.3812
EarlyStopping counter: 119 out of 100
Epoch: 436 Step: 71068 Index:-0.2065 R2:0.8516 0.5586 0.5580 RMSE:0.4284 0.7739 0.7611 Tau:0.7526 0.5674 0.4001
EarlyStopping counter: 120 out of 100
Epoch: 437 Step: 71231 Index:-0.2249 R2:0.8543 0.5581 0.5638 RMSE:0.4309 0.7877 0.7680 Tau:0.7555 0.5628 0.4117
EarlyStopping counter: 121 out of 100
Epoch: 438 Step: 71394 Index:-0.2575 R2:0.8444 0.5531 0.5629 RMSE:0.4813 0.8087 0.7858 Tau:0.7454 0.5512 0.4167
EarlyStopping counter: 122 out of 100
Epoch: 439 Step: 71557 Index:-0.1959 R2:0.8501 0.5655 0.5501 RMSE:0.4265 0.7604 0.7552 Tau:0.7500 0.5646 0.3987
EarlyStopping counter: 123 out of 100
Epoch: 440 Step: 71720 Index:-0.2209 R2:0.8484 0.5479 0.5527 R

EarlyStopping counter: 10 out of 100
Epoch: 489 Step: 79707 Index:-0.2182 R2:0.8636 0.5612 0.5465 RMSE:0.4043 0.7809 0.7804 Tau:0.7634 0.5627 0.4005
EarlyStopping counter: 11 out of 100
Epoch: 490 Step: 79870 Index:-0.2364 R2:0.8665 0.5582 0.5563 RMSE:0.4303 0.7944 0.7811 Tau:0.7662 0.5580 0.4108
EarlyStopping counter: 12 out of 100
Epoch: 491 Step: 80033 Index:-0.1867 R2:0.8733 0.5685 0.5622 RMSE:0.3961 0.7630 0.7555 Tau:0.7744 0.5762 0.4133
EarlyStopping counter: 13 out of 100
Epoch: 492 Step: 80196 Index:-0.1931 R2:0.8601 0.5733 0.5545 RMSE:0.4181 0.7711 0.7740 Tau:0.7601 0.5781 0.4055
EarlyStopping counter: 14 out of 100
Epoch: 493 Step: 80359 Index:-0.2156 R2:0.8707 0.5633 0.5646 RMSE:0.4257 0.7850 0.7706 Tau:0.7702 0.5693 0.4131
EarlyStopping counter: 15 out of 100
Epoch: 494 Step: 80522 Index:-0.2186 R2:0.8712 0.5634 0.5636 RMSE:0.3994 0.7818 0.7643 Tau:0.7705 0.5632 0.4115
EarlyStopping counter: 16 out of 100
Epoch: 495 Step: 80685 Index:-0.2253 R2:0.8681 0.5580 0.5584 RMSE:0.4

EarlyStopping counter: 26 out of 100
Epoch: 545 Step: 88835 Index:-0.2142 R2:0.8857 0.5569 0.5527 RMSE:0.3828 0.7873 0.7777 Tau:0.7871 0.5731 0.3985
EarlyStopping counter: 27 out of 100
Epoch: 546 Step: 88998 Index:-0.2348 R2:0.8787 0.5490 0.5415 RMSE:0.3878 0.7912 0.7813 Tau:0.7774 0.5564 0.4094
EarlyStopping counter: 28 out of 100
Epoch: 547 Step: 89161 Index:-0.2363 R2:0.8817 0.5514 0.5551 RMSE:0.3888 0.8063 0.7867 Tau:0.7839 0.5700 0.4055
EarlyStopping counter: 29 out of 100
Epoch: 548 Step: 89324 Index:-0.2187 R2:0.8670 0.5613 0.5417 RMSE:0.4035 0.7894 0.8010 Tau:0.7682 0.5706 0.4064
EarlyStopping counter: 30 out of 100
Epoch: 549 Step: 89487 Index:-0.2237 R2:0.8805 0.5471 0.5628 RMSE:0.3833 0.7775 0.7453 Tau:0.7807 0.5539 0.4073
EarlyStopping counter: 31 out of 100
Epoch: 550 Step: 89650 Index:-0.1985 R2:0.8785 0.5594 0.5387 RMSE:0.3850 0.7679 0.7715 Tau:0.7780 0.5694 0.4006
EarlyStopping counter: 32 out of 100
Epoch: 551 Step: 89813 Index:-0.1983 R2:0.8817 0.5659 0.5528 RMSE:0.3

EarlyStopping counter: 13 out of 100
Epoch: 601 Step: 97963 Index:-0.2083 R2:0.8998 0.5686 0.5580 RMSE:0.3506 0.7716 0.7698 Tau:0.8026 0.5633 0.4203
EarlyStopping counter: 14 out of 100
Epoch: 602 Step: 98126 Index:-0.2246 R2:0.8911 0.5500 0.5440 RMSE:0.3685 0.7810 0.7738 Tau:0.7924 0.5565 0.4202
EarlyStopping counter: 15 out of 100
Epoch: 603 Step: 98289 Index:-0.2372 R2:0.8863 0.5617 0.5363 RMSE:0.4013 0.8087 0.8244 Tau:0.7878 0.5715 0.4172
EarlyStopping counter: 16 out of 100
Epoch: 604 Step: 98452 Index:-0.1966 R2:0.8941 0.5731 0.5504 RMSE:0.3569 0.7665 0.7823 Tau:0.7987 0.5700 0.4198
EarlyStopping counter: 17 out of 100
Epoch: 605 Step: 98615 Index:-0.2341 R2:0.8801 0.5407 0.5331 RMSE:0.3829 0.7878 0.7784 Tau:0.7830 0.5537 0.4193
EarlyStopping counter: 18 out of 100
Epoch: 606 Step: 98778 Index:-0.2113 R2:0.8941 0.5573 0.5443 RMSE:0.3561 0.7838 0.7803 Tau:0.7946 0.5725 0.4147
EarlyStopping counter: 19 out of 100
Epoch: 607 Step: 98941 Index:-0.2477 R2:0.8927 0.5577 0.5504 RMSE:0.4

EarlyStopping counter: 68 out of 100
Epoch: 656 Step: 106928 Index:-0.2086 R2:0.9071 0.5589 0.5546 RMSE:0.3399 0.7720 0.7621 Tau:0.8109 0.5634 0.4254
EarlyStopping counter: 69 out of 100
Epoch: 657 Step: 107091 Index:-0.2524 R2:0.9055 0.5555 0.5421 RMSE:0.3798 0.8169 0.8246 Tau:0.8079 0.5644 0.4105
EarlyStopping counter: 70 out of 100
Epoch: 658 Step: 107254 Index:-0.1994 R2:0.8948 0.5625 0.5500 RMSE:0.3587 0.7689 0.7676 Tau:0.7954 0.5695 0.4133
EarlyStopping counter: 71 out of 100
Epoch: 659 Step: 107417 Index:-0.2116 R2:0.9104 0.5643 0.5533 RMSE:0.3402 0.7797 0.7814 Tau:0.8154 0.5681 0.4220
EarlyStopping counter: 72 out of 100
Epoch: 660 Step: 107580 Index:-0.2283 R2:0.8939 0.5676 0.5456 RMSE:0.3599 0.7904 0.7971 Tau:0.7984 0.5622 0.4191
EarlyStopping counter: 73 out of 100
Epoch: 661 Step: 107743 Index:-0.2123 R2:0.9064 0.5604 0.5446 RMSE:0.3567 0.7865 0.7986 Tau:0.8094 0.5742 0.4147
EarlyStopping counter: 74 out of 100
Epoch: 662 Step: 107906 Index:-0.2397 R2:0.9069 0.5624 0.5381 R

EarlyStopping counter: 123 out of 100
Epoch: 711 Step: 115893 Index:-0.2021 R2:0.9169 0.5601 0.5440 RMSE:0.3196 0.7755 0.7760 Tau:0.8214 0.5734 0.4084
EarlyStopping counter: 124 out of 100
Epoch: 712 Step: 116056 Index:-0.2232 R2:0.9161 0.5564 0.5430 RMSE:0.3181 0.7868 0.7879 Tau:0.8231 0.5636 0.4233
EarlyStopping counter: 125 out of 100
Epoch: 713 Step: 116219 Index:-0.2417 R2:0.9114 0.5472 0.5420 RMSE:0.3296 0.8038 0.7971 Tau:0.8140 0.5621 0.4204
EarlyStopping counter: 126 out of 100
Epoch: 714 Step: 116382 Index:-0.2552 R2:0.9130 0.5626 0.5485 RMSE:0.3580 0.8149 0.8149 Tau:0.8166 0.5598 0.4234
EarlyStopping counter: 127 out of 100
Epoch: 715 Step: 116545 Index:-0.3946 R2:0.5916 0.4282 0.4128 RMSE:0.7023 0.8627 0.8517 Tau:0.5751 0.4681 0.3954
EarlyStopping counter: 128 out of 100
Epoch: 716 Step: 116708 Index:-0.2041 R2:0.8765 0.5577 0.5448 RMSE:0.4042 0.7748 0.7751 Tau:0.7805 0.5707 0.4211
EarlyStopping counter: 129 out of 100
Epoch: 717 Step: 116871 Index:-0.1862 R2:0.8964 0.5642 0

EarlyStopping counter: 179 out of 100
Epoch: 767 Step: 125021 Index:-0.2518 R2:0.9179 0.5376 0.5330 RMSE:0.3159 0.8095 0.7969 Tau:0.8247 0.5577 0.4281
EarlyStopping counter: 180 out of 100
Epoch: 768 Step: 125184 Index:-0.2302 R2:0.9228 0.5614 0.5375 RMSE:0.3204 0.7980 0.8109 Tau:0.8291 0.5678 0.4270
EarlyStopping counter: 181 out of 100
Epoch: 769 Step: 125347 Index:-0.2268 R2:0.9220 0.5628 0.5466 RMSE:0.3100 0.7961 0.8023 Tau:0.8303 0.5693 0.4366
EarlyStopping counter: 182 out of 100
Epoch: 770 Step: 125510 Index:-0.2099 R2:0.9227 0.5699 0.5304 RMSE:0.3111 0.7813 0.8065 Tau:0.8295 0.5714 0.4255
EarlyStopping counter: 183 out of 100
Epoch: 771 Step: 125673 Index:-0.2436 R2:0.9228 0.5510 0.5260 RMSE:0.3087 0.8028 0.8167 Tau:0.8305 0.5593 0.4206
EarlyStopping counter: 184 out of 100
Epoch: 772 Step: 125836 Index:-0.2379 R2:0.9023 0.5405 0.5170 RMSE:0.3457 0.8004 0.8133 Tau:0.8040 0.5625 0.4116
EarlyStopping counter: 185 out of 100
Epoch: 773 Step: 125999 Index:-0.2165 R2:0.9168 0.5653 0

EarlyStopping counter: 234 out of 100
Epoch: 822 Step: 133986 Index:-0.2495 R2:0.9295 0.5433 0.5313 RMSE:0.2933 0.8137 0.8200 Tau:0.8385 0.5642 0.4340
EarlyStopping counter: 235 out of 100
Epoch: 823 Step: 134149 Index:-0.2315 R2:0.9349 0.5485 0.5365 RMSE:0.2833 0.7978 0.7973 Tau:0.8468 0.5663 0.4320
EarlyStopping counter: 236 out of 100
Epoch: 824 Step: 134312 Index:-0.2527 R2:0.9235 0.5370 0.5208 RMSE:0.3121 0.8213 0.8318 Tau:0.8291 0.5686 0.4299
EarlyStopping counter: 237 out of 100
Epoch: 825 Step: 134475 Index:-0.2653 R2:0.9358 0.5383 0.5274 RMSE:0.3015 0.8259 0.8348 Tau:0.8456 0.5606 0.4247
EarlyStopping counter: 238 out of 100
Epoch: 826 Step: 134638 Index:-0.3189 R2:0.9311 0.5429 0.5391 RMSE:0.3818 0.8802 0.8890 Tau:0.8405 0.5613 0.4348
EarlyStopping counter: 239 out of 100
Epoch: 827 Step: 134801 Index:-0.2307 R2:0.9315 0.5574 0.5393 RMSE:0.2868 0.7991 0.8030 Tau:0.8384 0.5685 0.4364
EarlyStopping counter: 240 out of 100
Epoch: 828 Step: 134964 Index:-0.3212 R2:0.9130 0.5368 0

EarlyStopping counter: 289 out of 100
Epoch: 877 Step: 142951 Index:-0.2580 R2:0.9412 0.5486 0.5344 RMSE:0.2762 0.8202 0.8304 Tau:0.8542 0.5622 0.4341
EarlyStopping counter: 290 out of 100
Epoch: 878 Step: 143114 Index:-0.2850 R2:0.9375 0.5457 0.5331 RMSE:0.2863 0.8492 0.8584 Tau:0.8476 0.5643 0.4296
EarlyStopping counter: 291 out of 100
Epoch: 879 Step: 143277 Index:-0.2352 R2:0.9353 0.5531 0.5136 RMSE:0.2828 0.8010 0.8289 Tau:0.8433 0.5657 0.4298
EarlyStopping counter: 292 out of 100
Epoch: 880 Step: 143440 Index:-0.3120 R2:0.9342 0.5410 0.5190 RMSE:0.3486 0.8675 0.8952 Tau:0.8431 0.5555 0.4379
EarlyStopping counter: 293 out of 100
Epoch: 881 Step: 143603 Index:-0.2791 R2:0.9333 0.5450 0.5172 RMSE:0.2878 0.8376 0.8554 Tau:0.8421 0.5586 0.4218
EarlyStopping counter: 294 out of 100
Epoch: 882 Step: 143766 Index:-0.2169 R2:0.9372 0.5604 0.5298 RMSE:0.2768 0.7891 0.8145 Tau:0.8494 0.5722 0.4349
EarlyStopping counter: 295 out of 100
Epoch: 883 Step: 143929 Index:-0.2755 R2:0.9328 0.5604 0

EarlyStopping counter: 344 out of 100
Epoch: 932 Step: 151916 Index:-0.2388 R2:0.9403 0.5613 0.5329 RMSE:0.2763 0.8068 0.8334 Tau:0.8549 0.5680 0.4500
EarlyStopping counter: 345 out of 100
Epoch: 933 Step: 152079 Index:-0.2501 R2:0.9390 0.5679 0.5395 RMSE:0.2839 0.8176 0.8425 Tau:0.8505 0.5675 0.4349
EarlyStopping counter: 346 out of 100
Epoch: 934 Step: 152242 Index:-0.2616 R2:0.9346 0.5604 0.5250 RMSE:0.2935 0.8272 0.8667 Tau:0.8453 0.5656 0.4342
EarlyStopping counter: 347 out of 100
Epoch: 935 Step: 152405 Index:-0.2626 R2:0.9444 0.5584 0.5321 RMSE:0.2767 0.8269 0.8571 Tau:0.8607 0.5643 0.4409
EarlyStopping counter: 348 out of 100
Epoch: 936 Step: 152568 Index:-0.2512 R2:0.9439 0.5536 0.5297 RMSE:0.2679 0.8170 0.8383 Tau:0.8590 0.5657 0.4389
EarlyStopping counter: 349 out of 100
Epoch: 937 Step: 152731 Index:-0.2466 R2:0.9416 0.5496 0.5283 RMSE:0.2724 0.8088 0.8276 Tau:0.8550 0.5622 0.4354
EarlyStopping counter: 350 out of 100
Epoch: 938 Step: 152894 Index:-0.2668 R2:0.9411 0.5401 0

EarlyStopping counter: 399 out of 100
Epoch: 987 Step: 160881 Index:-0.2824 R2:0.9455 0.5381 0.5138 RMSE:0.2618 0.8393 0.8608 Tau:0.8604 0.5569 0.4441
EarlyStopping counter: 400 out of 100
Epoch: 988 Step: 161044 Index:-0.2578 R2:0.9491 0.5458 0.5221 RMSE:0.2488 0.8216 0.8421 Tau:0.8671 0.5638 0.4471
EarlyStopping counter: 401 out of 100
Epoch: 989 Step: 161207 Index:-0.2652 R2:0.9464 0.5410 0.5270 RMSE:0.2570 0.8244 0.8365 Tau:0.8602 0.5592 0.4410
EarlyStopping counter: 402 out of 100
Epoch: 990 Step: 161370 Index:-0.2810 R2:0.9441 0.5255 0.5227 RMSE:0.2614 0.8413 0.8520 Tau:0.8587 0.5603 0.4415
EarlyStopping counter: 403 out of 100
Epoch: 991 Step: 161533 Index:-0.3440 R2:0.9257 0.5267 0.5183 RMSE:0.3245 0.9051 0.9221 Tau:0.8359 0.5611 0.4449
EarlyStopping counter: 404 out of 100
Epoch: 992 Step: 161696 Index:-0.2594 R2:0.9479 0.5402 0.5317 RMSE:0.2614 0.8212 0.8232 Tau:0.8648 0.5618 0.4405
EarlyStopping counter: 405 out of 100
Epoch: 993 Step: 161859 Index:-0.3050 R2:0.9473 0.5343 0

In [15]:
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
    
test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df)
test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
test_r2_ina, test_MSE_ina, test_predict_ina = eval(model, amodel, gmodel, test_df[test_active:].reset_index(drop=True))
    
test_predict = np.array(test_predict)
test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())

k_list = [int(len(test_df)*0.01),int(len(test_df)*0.05),int(len(test_df)*0.1),int(len(test_df)*0.15),int(len(test_df)*0.2),int(len(test_df)*0.25),
          int(len(test_df)*0.3),int(len(test_df)*0.4),int(len(test_df)*0.5),50,100,150,200,250,300]
topk_list =[]
false_positive_rate_list = []
for k in k_list:
    a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
    topk_list.append(a)
    false_positive_rate_list.append(b)
WTI = weighted_top_index(test_df, test_predict, test_active)
ap = AP(test_df, test_predict, test_active)


In [16]:
print(' epoch:',epoch,'r2:%.4f'%test_r2_a,'RMSE:%.4f'%test_MSE_a**0.5,'WTI:%.4f'%WTI,'AP:%.4f'%ap,'Tau:%.4f'%test_tau,'\n','\n',
      'Top-1:%.4f'%topk_list[0],'Top-1-fp:%.4f'%false_positive_rate_list[0],'\n',
      'Top-5:%.4f'%topk_list[1],'Top-5-fp:%.4f'%false_positive_rate_list[1],'\n',
      'Top-10:%.4f'%topk_list[2],'Top-10-fp:%.4f'%false_positive_rate_list[2],'\n',
      'Top-15:%.4f'%topk_list[3],'Top-15-fp:%.4f'%false_positive_rate_list[3],'\n',
      'Top-20:%.4f'%topk_list[4],'Top-20-fp:%.4f'%false_positive_rate_list[4],'\n',
      'Top-25:%.4f'%topk_list[5],'Top-25-fp:%.4f'%false_positive_rate_list[5],'\n',
      'Top-30:%.4f'%topk_list[6],'Top-30-fp:%.4f'%false_positive_rate_list[6],'\n',
      'Top-40:%.4f'%topk_list[7],'Top-40-fp:%.4f'%false_positive_rate_list[7],'\n',
      'Top-50:%.4f'%topk_list[8],'Top-50-fp:%.4f'%false_positive_rate_list[8],'\n','\n',
      'Top50:%.4f'%topk_list[9],'Top50-fp:%.4f'%false_positive_rate_list[9],'\n',
      'Top100:%.4f'%topk_list[10],'Top100-fp:%.4f'%false_positive_rate_list[10],'\n',
      'Top150:%.4f'%topk_list[11],'Top150-fp:%.4f'%false_positive_rate_list[11],'\n',
      'Top200:%.4f'%topk_list[12],'Top200-fp:%.4f'%false_positive_rate_list[12],'\n',
      'Top250:%.4f'%topk_list[13],'Top250-fp:%.4f'%false_positive_rate_list[13],'\n',
      'Top300:%.4f'%topk_list[14],'Top300-fp:%.4f'%false_positive_rate_list[14],'\n')

 epoch: 1000 r2:0.5476 RMSE:0.7600 WTI:0.3450 AP:0.5900 Tau:0.4148 
 
 Top-1:0.4615 Top-1-fp:0.0000 
 Top-5:0.5672 Top-5-fp:0.1343 
 Top-10:0.6418 Top-10-fp:0.1716 
 Top-15:0.6866 Top-15-fp:0.2040 
 Top-20:0.6567 Top-20-fp:0.2910 
 Top-25:0.6358 Top-25-fp:0.3463 
 Top-30:0.6095 Top-30-fp:0.3905 
 Top-40:0.6738 Top-40-fp:0.4720 
 Top-50:0.7643 Top-50-fp:0.5216 
 
 Top50:0.5200 Top50-fp:0.0600 
 Top100:0.5600 Top100-fp:0.1500 
 Top150:0.6200 Top150-fp:0.1867 
 Top200:0.6850 Top200-fp:0.2000 
 Top250:0.6760 Top250-fp:0.2640 
 Top300:0.6467 Top300-fp:0.3167 



In [17]:
# print('target_file:',train_filename)
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)

In [18]:
# loss = loss_function(mol_prediction,y)
#             loss.backward(retain_graph=True)
#             optimizer_AFSE.zero_grad()
#             punish_lr = torch.norm(torch.mean(eps.grad,0))

#         init_lr = 1e-4
#         max_lr = 10**-(init_lr-1)
#         conv_lr = conv_lr - conv_lr**2 + 0.1 * punish_lr
#         if conv_lr < max_lr:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = conv_lr.detach()
#                 AFSE_lr = conv_lr    
#         else:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = max_lr
#                 AFSE_lr = max_lr
# epoch: 512 r2:0.5619 RMSE:0.7306 WTI:0.3784 AP:0.7228 Tau:0.5159 
 
#  Top-1:0.0000 Top-1-fp:0.0000 
#  Top-5:0.8571 Top-5-fp:0.0000 
#  Top-10:0.7857 Top-10-fp:0.0714 