In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
train_filename = "./data/benchmark/Ki_P34972_1_500_train.csv"
test_filename = "./data/benchmark/Ki_P34972_1_500_test.csv"
test_active = 500
val_rate = 0.2
random_seed = 68
file_list1 = train_filename.split('/')
file1 = file_list1[-1]
file1 = file1[:-10]
number = '_run_0'
model_file = "model_file/3_GAFSE_"+file1+number
log_dir = f'log/{"3_GAFSE_"+file1}'+number
result_dir = './result/3_GAFSE_'+file1+number
print(file1)
print(model_file)

Ki_P34972_1_500
model_file/3_GAFSE_Ki_P34972_1_500_run_0


In [4]:
# task_name = 'Malaria Bioactivity'
tasks = ['value']

# train_filename = "../data/active_inactive/median_active/EC50/Q99500.csv"
feature_filename = train_filename.replace('.csv','.pickle')
filename = train_filename.replace('.csv','')
prefix_filename = train_filename.split('/')[-1].replace('.csv','')
train_df = pd.read_csv(train_filename, header=0, names = ["smiles","value"],usecols=[0,1])
# train_df = train_df[1:]
# train_df = train_df.drop(0,axis=1,inplace=False) 
print(train_df[:5])
# print(train_df.iloc(1))
def add_canonical_smiles(train_df):
    smilesList = train_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    atom_num_dist = []
    remained_smiles = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:        
            mol = Chem.MolFromSmiles(smiles)
            atom_num_dist.append(len(mol.GetAtoms()))
            remained_smiles.append(smiles)
            canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
        except:
            print(smiles)
            pass
    print("number of successfully processed smiles: ", len(remained_smiles))
    train_df = train_df[train_df["smiles"].isin(remained_smiles)]
    train_df['cano_smiles'] =canonical_smiles_list
    return train_df
# print(train_df)
train_df = add_canonical_smiles(train_df)

print(train_df.head())
# plt.figure(figsize=(5, 3))
# sns.set(font_scale=1.5)
# ax = sns.distplot(atom_num_dist, bins=28, kde=False)
# plt.tight_layout()
# # plt.savefig("atom_num_dist_"+prefix_filename+".png",dpi=200)
# plt.show()
# plt.close()


                                              smiles     value
0           CCCCCC1=NN(C(=C1)C2=CC=CC=C2)C3=CC=CC=C3 -3.468790
1  CCCNC(=O)C1=NN(C(=C1C)N2C(=CC=C2C)C)C3=C(C=C(C... -3.290769
2  CC(C)(C)C1=CC(=CC(=C1O)C(C)(C)C)C2=CSC(=N2)C3(... -2.247973
3  C1CCCC(CC1)NC(=O)C2=CC3=C(N=CC=C3)N(C2=O)CC4=C... -1.342423
4         CCCCN1C(=C(C=C(C1=S)OC2=NC3=CC=CC=C3O2)C)C -2.004321
number of all smiles:  2591
number of successfully processed smiles:  2591
                                              smiles     value  \
0           CCCCCC1=NN(C(=C1)C2=CC=CC=C2)C3=CC=CC=C3 -3.468790   
1  CCCNC(=O)C1=NN(C(=C1C)N2C(=CC=C2C)C)C3=C(C=C(C... -3.290769   
2  CC(C)(C)C1=CC(=CC(=C1O)C(C)(C)C)C2=CSC(=N2)C3(... -2.247973   
3  C1CCCC(CC1)NC(=O)C2=CC3=C(N=CC=C3)N(C2=O)CC4=C... -1.342423   
4         CCCCN1C(=C(C=C(C1=S)OC2=NC3=CC=CC=C3O2)C)C -2.004321   

                                         cano_smiles  
0                 CCCCCc1cc(-c2ccccc2)n(-c2ccccc2)n1  
1   CCCNC(=O)c1nn(-c2ccc(Cl)cc2Cl)c(-n2c

In [5]:
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 2 # default: 2
T = 1
per_task_output_units_num = 1 # for regression model
output_units_num = len(tasks) * per_task_output_units_num

In [6]:
test_df = pd.read_csv(test_filename,header=0,names=["smiles","value"],usecols=[0,1])
test_df = add_canonical_smiles(test_df)
for l in test_df["cano_smiles"]:
    if l in train_df["cano_smiles"]:
        print("same smiles:",l)
        
print(test_df.shape)
print(test_df.head())

number of all smiles:  1314
number of successfully processed smiles:  1314
(1314, 3)
                                              smiles     value  \
0  C1CN(CCC1(CNS(=O)(=O)C(F)(F)F)C#N)S(=O)(=O)C2=... -1.574031   
1  CCN(CC)C1=CC=C(C=C1)CN(C2=CC=C(C=C2)C)S(=O)(=O... -0.477121   
2  CC1(C2CCC(C2)(C1NC(=O)C3=CN(C4=C3C=CC=C4OC)CCN... -1.113943   
3  CCCN(CCC)C(=O)C1=CC2=C(C=C1)N(C(=N2)CC3=CC=C(C... -0.447158   
4  CC1=CC=C(C=C1)CN2C=C(C=C2C3=CC(=C(C=C3)Cl)C)CN... -2.535294   

                                         cano_smiles  
0  N#CC1(CNS(=O)(=O)C(F)(F)F)CCN(S(=O)(=O)c2ccc(C...  
1  CCN(CC)c1ccc(CN(c2ccc(C)cc2)S(=O)(=O)c2ccc(Cl)...  
2  COc1cccc2c(C(=O)NC3C4(C)CCC(C4)C3(C)C)cn(CCN3C...  
3  CCCN(CCC)C(=O)c1ccc2c(c1)nc(Cc1ccc(OCC)cc1)n2C...  
4  Cc1ccc(Cn2cc(CNC3C4(C)CCC(C4)C3(C)C)cc2-c2ccc(...  


In [7]:
print(feature_filename)
print(filename)
total_df = pd.concat([train_df,test_df],axis=0)
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
# if os.path.isfile(feature_filename):
#     feature_dicts = pickle.load(open(feature_filename, "rb" ))
# else:
#     feature_dicts = save_smiles_dicts(smilesList,filename)
feature_dicts = save_smiles_dicts(total_smilesList,filename)
remained_df = total_df[total_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = total_df.drop(remained_df.index)

./data/benchmark/Ki_P34972_1_500_train.pickle
./data/benchmark/Ki_P34972_1_500_train
3905
feature dicts file saved as ./data/benchmark/Ki_P34972_1_500_train.pickle


In [8]:
val_df = train_df.sample(frac=val_rate,random_state=random_seed)
train_df = train_df.drop(val_df.index)
train_df = train_df.reset_index(drop=True)
train_df = train_df[train_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
train_df = train_df.reset_index(drop=True)
val_df = val_df[val_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
val_df = val_df.reset_index(drop=True)
test_df = test_df[test_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
test_df = test_df.reset_index(drop=True)
print(train_df.shape,val_df.shape,test_df.shape)

(2073, 3) (518, 3) (1314, 3)


In [9]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

optimizer = optim.Adam(params=model.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=10**(-learning_rate), weight_decay=10**-weight_decay)

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [10]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    p_value = dataset[tasks[0]].astype(float).tolist()
    x = np.arange(0,len(dataset),1)
#     print('plt',dataset.head(),p[:10],k_predict,k)
#     plt.figure()
#     fig, ax1 = plt.subplots()
#     ax1.grid(False)
#     ax2 = ax1.twinx()
#     plt.grid(False)
    plt.scatter(x,p,marker='.',s=6,color='r',label='predict')
#     plt.ylabel('predict')
    plt.scatter(x,p_value,s=6,marker=',',color='blue',label='p_value')
    plt.axvline(x=k-1,ls="-",c="black")#添加垂直直线
    k_value = np.ones(len(dataset))
# #     print(EC50[k-1])
    k_value = k_value*k_predict
    plt.plot(x,k_value,'-',color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
    
    df3 = df2[:k]  #取按预测值排完序后的前k个
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    k_true = true_sort[tasks[0]].values[k-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('k_true: ',type(k_true),k_true)
#     print('k_true: ',k_true,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=k_true,'acc_num: ',len(df3[df3['predict'].values>=k_true]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k)
    acc = len(df3[df3[tasks[0]].values>=k_true])/k #预测值前k个中真实排在前k个的个数/k
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/k
    if k>active_num:
        min_active = true_sort[tasks[0]].values[active_num-1]
        acc = len(df3[df3[tasks[0]].values>=min_active])/k
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
        
    df3 = df2[:k]  #取按预测值排完序后的前k个，因为后面的全是-4.1
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    min_active = true_sort[tasks[0]].values[active_num-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('min_active: ',type(min_active),min_active)
#     print('min_active: ',min_active,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=min_active,'acc_num: ',len(df3[df3['predict'].values>=min_active]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k,'active_num: ',active_num)
    acc = len(df3[df3[tasks[0]].values>-4.1])/active_num #预测值前k个中真实排在前active_num个的个数/active_num
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/active_num
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    if k>active_num:
        return topk_recall(df, predict, k, active_num, show_flag, i)
    return topk_acc2(df,predict,k, active_num,show_flag,i)

def weighted_top_index(df, predict, active_num):
    weighted_acc_list=[]
    for k in np.arange(1,len(df)+1,1):
        acc, fp = topk_acc_recall(df, predict, k, active_num)
        weight = (len(df)-k)/len(df)
#         print('weight=',weight,'acc=',acc)
        weighted_acc_list.append(acc*weight)#
    weighted_acc_list = np.array(weighted_acc_list)
#     print('weighted_acc_list=',weighted_acc_list)
    return np.sum(weighted_acc_list)/weighted_acc_list.shape[0]

def AP(df, predict, active_num):
    prec = []
    rec = []
    for k in np.arange(1,len(df)+1,1):
        prec_k, fp1 = topk_acc2(df,predict,k, active_num)
        rec_k, fp2 = topk_recall(df, predict, k, active_num)
        prec.append(prec_k)
        rec.append(rec_k)
    # 取所有不同的recall对应的点处的精度值做平均
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # 计算包络线，从后往前取最大保证precise非减
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # 找出所有检测结果中recall不同的点
    i = np.where(mrec[1:] != mrec[:-1])[0]
#     print(prec)
#     print('prec='+str(prec)+'\n\n'+'rec='+str(rec))

    # and sum (\Delta recall) * prec
    # 用recall的间隔对精度作加权平均
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

In [11]:
def caculate_r2(y,predict):
#     print(y)
#     print(predict)
    y = torch.FloatTensor(y).reshape(-1,1)
    predict = torch.FloatTensor(predict).reshape(-1,1)
    y_mean = torch.mean(y)
    predict_mean = torch.mean(predict)
    
    y1 = torch.pow(torch.mm((y-y_mean).t(),(predict-predict_mean)),2)
    y2 = torch.mm((y-y_mean).t(),(y-y_mean))*torch.mm((predict-predict_mean).t(),(predict-predict_mean))
#     print(y1,y2)
    return y1/y2

In [12]:
from torch.autograd import Variable
def l2_norm(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    output = torch.div(input, norm+1e-6)
    return output

def normalize_perturbation(d,dim=-1):
    output = l2_norm(d, dim)
    return output

def tanh(x):
    return (torch.exp(x)-torch.exp(-x))/(torch.exp(x)+torch.exp(-x))

def sigmoid(x):
    return 1/(1+torch.exp(-x))

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None):
    mol_prediction = model(feature=f, d=0)
    pred = mol_prediction.detach()
#     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
    eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
    eps = Variable(eps, requires_grad=True)
    # Predict on randomly perturbed image
    eps_p = model(feature=f, d=eps.cuda())
    eps_p_ = model(feature=f, d=-eps.cuda())
    p_aux = nn.Sigmoid()(eps_p/(pred+1e-6))
    p_aux_ = nn.Sigmoid()(eps_p_/(pred+1e-6))
#     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
    loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
    loss.backward(retain_graph=True)

    # Based on perturbed image, get direction of greatest error
    eps_adv = eps.grad#/10**-learning_rate
    optimizer_AFSE.zero_grad()
    # Use that direction as adversarial perturbation
    eps_adv_normed = normalize_perturbation(eps_adv)
    d_adv = lamda * eps_adv_normed.cuda()
    if output_lr:
        f_p, max_lr = model(feature=f, d=d_adv, output_lr=output_lr)
    f_p = model(feature=f, d=d_adv)
    f_p_ = model(feature=f, d=-d_adv)
    p = nn.Sigmoid()(f_p/(pred+1e-6))
    p_ = nn.Sigmoid()(f_p_/(pred+1e-6))
    vat_loss = loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    if output_lr:
        if output_plr:
            loss = loss_function(mol_prediction,y)
            loss.backward(retain_graph=True)
            optimizer_AFSE.zero_grad()
            punish_lr = torch.norm(torch.mean(eps.grad,0))
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol

def d_loss(f, pred, model, y_val):
    diff_loss = 0
    length = len(pred)
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    diff_loss = diff_loss/(length*(length-1))
    return diff_loss

In [13]:
def CE(x,y):
    c = 0
    l = len(y)
    for i in range(l):
        if y[i]==1:
            c += 1
    w1 = (l-c)/l
    w0 = c/l
    loss = -w1*y*torch.log(x+1e-6)-w0*(1-y)*torch.log(1-x+1e-6)
    loss = loss.mean(-1)
    return loss

def weighted_CE_loss(x,y):
    weight = 1/(y.detach().float().mean(0)+1e-9)
    weighted_CE = nn.CrossEntropyLoss(weight=weight)
#     atom_weights = (atom_weights-min(atom_weights))/(max(atom_weights)-min(atom_weights))
    return weighted_CE(x, torch.argmax(y,-1))

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduce=False)
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    loss = torch.mean(loss)
    return loss

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)
#         print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(float)
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                  torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
                                                  mol_feature=mol_feature,activated_features=activated_features.detach())
        
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1, 
                                                                                       lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1)) # 10**-learning_rate     
        regression_loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
        success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=1)
        reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
                                                                                              bond_neighbor, validity_mask, atom_list, 
                                                                                              bond_list)
        x_atom_test = torch.Tensor(x_atom_test)
        x_bonds_test = torch.Tensor(x_bonds_test)
        x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
        bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
        bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate)
        atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
                                                torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
        refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                            torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
                                                            mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
        success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
                            bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
                                                     refer_atom_list_test, refer_bond_list_test,topn=1)
        test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
        
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        max_lr = 1e-3
        conv_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr
        if conv_lr < max_lr and conv_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = conv_lr.detach()
                AFSE_lr = conv_lr    
        elif conv_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif conv_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
        
        logger.add_scalar('loss/regression', regression_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
        logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
        logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
        logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
        logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/max_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
    
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  regression_loss + 0.6 * (vat_loss + test_vat_loss) + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    [a.ClearProp('molAtomMapNumber') for a  in mol.GetAtoms()]
    return mol

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0
        top_idx = 0
        flag = 0
        first_run_flag = True
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
#             print(f'atom#{smiles_to_rdkit_list[cano_smiles][j]}(f):',{symbol_list[k]: np.around(refer_atom_list[i,j,k],3) for k in range(16)},
#                   f'\natom#{smiles_to_rdkit_list[cano_smiles][j]}(f+d):',{symbol_list[k]: np.around(atom_list[i,j,k],3) for k in range(16)},
#                  '\n------------------------------------------------------------------------------------------------------------')
#         print('预测为每个原子的平均概率：\n',np.around(atom_list[i,:l,:16].mean(1),2))
#         print('预测为每个原子的最大概率：\n',np.around(atom_list[i,:l,:16].max(1),2))
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            if rank == 16:
                rank = 0
                top_idx += 1
            if top_idx == l:
#                 print('没有满足条件的分子生成。')
                flag += 1
                continue
#             if np.sum((atom_symbol_sorted[i,:l,-1]!=atom_symbol_generated_sorted[i,:l,-1-rank]).astype(int))==0:
#                 print(f'根据预测的第{rank}大概率的原子构成的分子与原分子一致，原子位重置为0，生成下一个元素……')
#                 rank += 1
#                 top_idx = 0
#                 generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
#                                              x_atom[i,:l,:16]).max(-1))[-1-top_idx]
#             print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
#                 print('生成了相同元素，生成下一个元素……')
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
#                 print(f'原子位{generate_rdkit_index}生成{symbol_list[atom_symbol_generated]}元素的置信度小于{confidence_threshold}，寻找下一个原子位……')
                top_idx += 1
                rank = 0
                continue
#             if symbol_to_rdkit[atom_symbol_generated]==6:
#                 print('生成了不推荐的C元素')
#                 rank += 1
#                 continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
#                     print('Congratulations!', success, total)
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
#                 print("修改前的分子：", smiles[i])
#                 display(mol_init)
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 print(f"将第{generate_rdkit_index}个原子修改为{symbol_list[atom_symbol_generated]}的分子：", Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 display(mol_with_atom_index(mol))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
#                 print("高活性分子：", y_smiles[i])
#                 display(mol_y)
                rank += 1
                flag += 1
            except:
#                 print(f"第{generate_rdkit_index}个原子符号修改为{symbol_list[atom_symbol_generated]}不符合规范，生成下一个元素……")
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, topn=1, output_feature=False, generate=False, modify_atom=True,return_GRN_loss=False, viz=False):
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    r2_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    feature_list = []
    d_list = []
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]
    
# #     取dataset中排序后的第k个
#     sorted_dataset = dataset.sort_values(by=tasks[0],ascending=False)
#     k_df = sorted_dataset.iloc[[k-1]]
#     k_smiles = k_df['cano_smiles'].values
#     k_value = k_df[tasks[0]].values.astype(float)    
    
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
#     print(batch_list)
    for counter, batch in enumerate(batch_list):
#         print(type(batch))
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
#         print(batch_df)
        y_val = batch_df[tasks[0]].values.astype(float)
#         print(type(y_val))
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
        x_atom = torch.Tensor(x_atom)
        x_bonds = torch.Tensor(x_bonds)
        x_bond_index = torch.cuda.LongTensor(x_bond_index)
        bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
        bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda)
#         print(mol_feature,d_adv)
        atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
                                      torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
                                      torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
        refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
        if generate:
            if modify_atom:
                success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
                            bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
                                                     refer_atom_list, refer_bond_list,topn=topn,viz=viz)
            else:
                modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
            generated_smiles.extend(modified_smiles)
            success_smiles.extend(success_smiles_batch)
#             for n in range(topn):
#                 success[n] += success_batch[n]
#                 total[n] += total_batch[n]
#                 print('congratulations:',success,total)
            success_reconstruction += reconstruction
            success_validity += validity
            reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
        d = d_adv.cpu().detach().numpy().tolist()
        d_list.extend(d)
        mol_feature_output = mol_feature.cpu().detach().numpy().tolist()
        feature_list.extend(mol_feature_output)
#         MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')   
#         print(type(mol_prediction))
        
        MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')
#         r2 = caculate_r2(mol_prediction, torch.Tensor(y_val).view(-1,1))
# #         r2_list.extend(r2.cpu().detach().numpy())
#         if r2!=r2:
#             r2 = torch.tensor(0)
#         r2_list.append(r2.item())
#         predict_list.extend(mol_prediction.cpu().detach().numpy())
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        predict_list.extend(mol_prediction.cpu().detach().numpy())
#         test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
        test_MSE_list.extend(MSE.data.view(-1,1).cpu().numpy())
#     print(r2_list)
    if generate:
        generated_num = len(generated_smiles)
        eval_num = len(dataset)
        unique = generated_num
        novelty = generated_num
        for i in range(generated_num):
            for j in range(generated_num-i-1):
                if generated_smiles[i]==generated_smiles[i+j+1]:
                    unique -= 1
            for k in range(eval_num):
                if generated_smiles[i]==dataset['smiles'].values[k]:
                    novelty -= 1
        unique_rate = unique/(generated_num+1e-9)
        novelty_rate = novelty/(generated_num+1e-9)
#         print(f'successfully/total generated molecules =', {f'Top-{i+1}': f'{success[i]}/{total[i]}' for i in range(topn)})
        return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    if return_GRN_loss:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
    return caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list

epoch = 0
max_epoch = 1000
batch_size = 10
patience = 100
stopper = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_model.pth')
stopper_afse = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_amodel.pth')
stopper_generate = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_gmodel.pth')

In [14]:
import datetime
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
# if os.path.isdir(log_dir):
#     for files in os.listdir(log_dir):
#         os.remove(log_dir+"/"+files)
#     os.rmdir(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = 135
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
#     print(train_df.shape,test_df.shape)
    train_d, train_f, train_r2, train_MSE, train_predict, reconstruction_loss, one_hot_loss, interger_loss,binary_loss = eval(model, amodel, gmodel, train_df,output_feature=True,return_GRN_loss=True)
    train_predict = np.array(train_predict)
    train_WTI = weighted_top_index(train_df, train_predict, len(train_df))
    train_tau, _ = scipy.stats.kendalltau(train_predict,train_df[tasks[0]].values.astype(float).tolist())
    val_d, val_f, val_r2, val_MSE, val_predict, val_reconstruction_loss, val_one_hot_loss, val_interger_loss,val_binary_loss = eval(model, amodel, gmodel, val_df,output_feature=True,return_GRN_loss=True)
    val_predict = np.array(val_predict)
    val_WTI = weighted_top_index(val_df, val_predict, len(val_df))
    val_AP = AP(val_df, val_predict, len(val_df))
    val_tau, _ = scipy.stats.kendalltau(val_predict,val_df[tasks[0]].values.astype(float).tolist())
    
    test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
    test_d, test_f, test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df,output_feature=True)
    test_predict = np.array(test_predict)
    test_WTI = weighted_top_index(test_df, test_predict, test_active)
#     test_AP = AP(test_df, test_predict, test_active)
    test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())
    
    k_list = [int(len(test_df)*0.01),int(len(test_df)*0.03),int(len(test_df)*0.1),10,30,100]
    topk_list =[]
    false_positive_rate_list = []
    for k in k_list:
        a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
        topk_list.append(a)
        false_positive_rate_list.append(b)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    logger.add_scalar('train/r2', train_r2, global_step)
    logger.add_scalar('train/RMSE', train_MSE**0.5, global_step)
    logger.add_scalar('train/Tau', train_tau, global_step)
    logger.add_scalar('val/WTI', val_WTI, global_step)
    logger.add_scalar('val/AP', val_AP, global_step)
    logger.add_scalar('val/r2', val_r2, global_step)
    logger.add_scalar('val/RMSE', val_MSE**0.5, global_step)
    logger.add_scalar('val/Tau', val_tau, global_step)
#     logger.add_scalar('test/TAP', test_AP, global_step)
    logger.add_scalar('test/r2', test_r2_a, global_step)
    logger.add_scalar('test/RMSE', test_MSE_a**0.5, global_step)
    logger.add_scalar('test/Tau', test_tau, global_step)
    logger.add_scalar('val/GRN', reconstruction_loss, global_step)
    logger.add_scalar('val/GRN_one_hot', one_hot_loss, global_step)
    logger.add_scalar('val/GRN_interger', interger_loss, global_step)
    logger.add_scalar('val/GRN_binary', binary_loss, global_step)
    logger.add_scalar('test/EF0.01', topk_list[0], global_step)
    logger.add_scalar('test/EF0.03', topk_list[1], global_step)
    logger.add_scalar('test/EF0.1', topk_list[2], global_step)
    logger.add_scalar('test/EF10', topk_list[3], global_step)
    logger.add_scalar('test/EF30', topk_list[4], global_step)
    logger.add_scalar('test/EF100', topk_list[5], global_step)
    
#     train_mse_list.append(train_MSE**0.5)
#     train_r2_list.append(train_r2)
#     val_mse_list.append(val_MSE**0.5)  
#     val_r2_list.append(val_r2)
#     train_f_list.append(train_f)
#     val_f_list.append(val_f)
#     test_f_list.append(test_f)
#     epoch_list.append(epoch)
#     train_predict_list.append(train_predict.flatten())
#     test_predict_list.append(test_predict.flatten())
#     val_predict_list.append(val_predict.flatten())
#     train_y_list.append(train_df[tasks[0]].values)
#     val_y_list.append(val_df[tasks[0]].values)
#     test_y_list.append(test_df[tasks[0]].values)
#     train_d_list.append(train_d)
#     val_d_list.append(val_d)
#     test_d_list.append(test_d)

    stop_index = - val_MSE**0.5 + val_tau
    early_stop = stopper.step(stop_index, model)
    early_stop = stopper_afse.step(stop_index, amodel, if_print=False)
    early_stop = stopper_generate.step(stop_index, gmodel, if_print=False)
#     print('epoch {:d}/{:d}, validation {} {:.4f}, {} {:.4f},best validation {r2} {:.4f}'.format(epoch, total_epoch, 'r2', val_r2, 'mse:',val_MSE, stopper.best_score))
    print('Epoch:',epoch, 'Step:', global_step, 'Index:%.4f'%stop_index, 'R2:%.4f'%train_r2,'%.4f'%val_r2,'%.4f'%test_r2_a, 'RMSE:%.4f'%train_MSE**0.5, '%.4f'%val_MSE**0.5, 
          '%.4f'%test_MSE_a**0.5, 'Tau:%.4f'%train_tau,'%.4f'%val_tau,'%.4f'%test_tau)#, 'Tau:%.4f'%val_tau,'%.4f'%test_tau,'GRN:%.4f'%reconstruction_loss,'%.4f'%val_reconstruction_loss
    if early_stop:
        continue


log/3_GAFSE_Ki_P34972_1_500_run_0




Epoch: 136 Step: 28152 Index:-0.3161 R2:0.6469 0.5002 0.4899 RMSE:0.7135 0.8461 0.8250 Tau:0.6125 0.5301 0.4393
Epoch: 137 Step: 28359 Index:-0.2977 R2:0.6518 0.5114 0.5232 RMSE:0.7004 0.8333 0.7941 Tau:0.6182 0.5356 0.4555
EarlyStopping counter: 1 out of 100
Epoch: 138 Step: 28566 Index:-0.3749 R2:0.6338 0.4988 0.4960 RMSE:0.8030 0.9129 0.8872 Tau:0.6080 0.5380 0.4542
Epoch: 139 Step: 28773 Index:-0.2973 R2:0.6458 0.5022 0.4890 RMSE:0.6895 0.8319 0.8151 Tau:0.6129 0.5346 0.4304
Epoch: 140 Step: 28980 Index:-0.2946 R2:0.6588 0.5106 0.5123 RMSE:0.6956 0.8380 0.8089 Tau:0.6227 0.5433 0.4466
Epoch: 141 Step: 29187 Index:-0.2917 R2:0.6596 0.5040 0.5060 RMSE:0.6859 0.8283 0.7931 Tau:0.6213 0.5366 0.4396
Epoch: 142 Step: 29394 Index:-0.2773 R2:0.6594 0.5125 0.5018 RMSE:0.6740 0.8197 0.8020 Tau:0.6208 0.5425 0.4424
EarlyStopping counter: 1 out of 100
Epoch: 143 Step: 29601 Index:-0.2829 R2:0.6635 0.5152 0.5281 RMSE:0.6867 0.8291 0.7900 Tau:0.6267 0.5462 0.4480
EarlyStopping counter: 2 out of 

EarlyStopping counter: 30 out of 100
Epoch: 194 Step: 40158 Index:-0.2612 R2:0.7143 0.5225 0.5201 RMSE:0.6152 0.8123 0.7865 Tau:0.6604 0.5511 0.4522
EarlyStopping counter: 31 out of 100
Epoch: 195 Step: 40365 Index:-0.2749 R2:0.7083 0.5376 0.5339 RMSE:0.6648 0.8347 0.8132 Tau:0.6557 0.5598 0.4537
EarlyStopping counter: 32 out of 100
Epoch: 196 Step: 40572 Index:-0.2574 R2:0.7124 0.5233 0.5485 RMSE:0.6203 0.8106 0.7575 Tau:0.6592 0.5532 0.4482
EarlyStopping counter: 33 out of 100
Epoch: 197 Step: 40779 Index:-0.2664 R2:0.7044 0.5206 0.5370 RMSE:0.6321 0.8128 0.7710 Tau:0.6544 0.5464 0.4567
EarlyStopping counter: 34 out of 100
Epoch: 198 Step: 40986 Index:-0.2527 R2:0.7175 0.5334 0.5409 RMSE:0.6196 0.8083 0.7731 Tau:0.6637 0.5556 0.4546
EarlyStopping counter: 35 out of 100
Epoch: 199 Step: 41193 Index:-0.2632 R2:0.7071 0.5234 0.5242 RMSE:0.6278 0.8141 0.7835 Tau:0.6545 0.5509 0.4416
Epoch: 200 Step: 41400 Index:-0.2422 R2:0.7229 0.5474 0.5339 RMSE:0.6290 0.8051 0.7911 Tau:0.6661 0.5628 0

EarlyStopping counter: 14 out of 100
Epoch: 252 Step: 52164 Index:-0.2630 R2:0.7656 0.5549 0.5336 RMSE:0.6206 0.8266 0.8186 Tau:0.6953 0.5636 0.4519
EarlyStopping counter: 15 out of 100
Epoch: 253 Step: 52371 Index:-0.2299 R2:0.7588 0.5476 0.5467 RMSE:0.5754 0.7919 0.7653 Tau:0.6920 0.5620 0.4687
EarlyStopping counter: 16 out of 100
Epoch: 254 Step: 52578 Index:-0.2458 R2:0.7479 0.5561 0.5251 RMSE:0.6115 0.8092 0.8112 Tau:0.6819 0.5635 0.4573
EarlyStopping counter: 17 out of 100
Epoch: 255 Step: 52785 Index:-0.2330 R2:0.7620 0.5503 0.5313 RMSE:0.5671 0.7914 0.7851 Tau:0.6936 0.5583 0.4553
EarlyStopping counter: 18 out of 100
Epoch: 256 Step: 52992 Index:-0.3466 R2:0.7334 0.5418 0.5418 RMSE:0.7416 0.9042 0.8687 Tau:0.6721 0.5576 0.4598
EarlyStopping counter: 19 out of 100
Epoch: 257 Step: 53199 Index:-0.2284 R2:0.7601 0.5459 0.5224 RMSE:0.5711 0.7950 0.7927 Tau:0.6939 0.5666 0.4516
EarlyStopping counter: 20 out of 100
Epoch: 258 Step: 53406 Index:-0.2180 R2:0.7592 0.5593 0.5674 RMSE:0.5

EarlyStopping counter: 2 out of 100
Epoch: 308 Step: 63756 Index:-0.2316 R2:0.8014 0.5667 0.5623 RMSE:0.5606 0.8053 0.7856 Tau:0.7230 0.5737 0.4686
EarlyStopping counter: 3 out of 100
Epoch: 309 Step: 63963 Index:-0.2026 R2:0.7987 0.5698 0.5561 RMSE:0.5193 0.7708 0.7552 Tau:0.7186 0.5682 0.4706
EarlyStopping counter: 4 out of 100
Epoch: 310 Step: 64170 Index:-0.2174 R2:0.8013 0.5594 0.5568 RMSE:0.5226 0.7843 0.7608 Tau:0.7242 0.5668 0.4768
EarlyStopping counter: 5 out of 100
Epoch: 311 Step: 64377 Index:-0.2187 R2:0.7960 0.5589 0.5556 RMSE:0.5232 0.7829 0.7576 Tau:0.7176 0.5641 0.4755
EarlyStopping counter: 6 out of 100
Epoch: 312 Step: 64584 Index:-0.2494 R2:0.8074 0.5531 0.5556 RMSE:0.5649 0.8137 0.7876 Tau:0.7296 0.5643 0.4798
EarlyStopping counter: 7 out of 100
Epoch: 313 Step: 64791 Index:-0.2097 R2:0.8012 0.5724 0.5612 RMSE:0.5290 0.7828 0.7670 Tau:0.7250 0.5731 0.4784
Epoch: 314 Step: 64998 Index:-0.1947 R2:0.8027 0.5710 0.5638 RMSE:0.5148 0.7694 0.7466 Tau:0.7235 0.5747 0.4726


EarlyStopping counter: 49 out of 100
Epoch: 364 Step: 75348 Index:-0.2804 R2:0.8205 0.5526 0.5355 RMSE:0.5796 0.8453 0.8454 Tau:0.7393 0.5649 0.4768
EarlyStopping counter: 50 out of 100
Epoch: 365 Step: 75555 Index:-0.2439 R2:0.8273 0.5764 0.5481 RMSE:0.5581 0.8198 0.8257 Tau:0.7433 0.5759 0.4785
EarlyStopping counter: 51 out of 100
Epoch: 366 Step: 75762 Index:-0.1943 R2:0.8292 0.5800 0.5578 RMSE:0.4923 0.7701 0.7663 Tau:0.7440 0.5758 0.4799
EarlyStopping counter: 52 out of 100
Epoch: 367 Step: 75969 Index:-0.2060 R2:0.8281 0.5666 0.5527 RMSE:0.4767 0.7775 0.7630 Tau:0.7461 0.5716 0.4730
EarlyStopping counter: 53 out of 100
Epoch: 368 Step: 76176 Index:-0.2615 R2:0.8124 0.5596 0.5320 RMSE:0.5469 0.8327 0.8412 Tau:0.7382 0.5712 0.4734
EarlyStopping counter: 54 out of 100
Epoch: 369 Step: 76383 Index:-0.1969 R2:0.8251 0.5724 0.5550 RMSE:0.4900 0.7730 0.7611 Tau:0.7442 0.5761 0.4852
EarlyStopping counter: 55 out of 100
Epoch: 370 Step: 76590 Index:-0.2047 R2:0.8307 0.5791 0.5356 RMSE:0.5

EarlyStopping counter: 104 out of 100
Epoch: 419 Step: 86733 Index:-0.2249 R2:0.8518 0.5759 0.5475 RMSE:0.4827 0.8001 0.8069 Tau:0.7665 0.5752 0.4851
EarlyStopping counter: 105 out of 100
Epoch: 420 Step: 86940 Index:-0.2089 R2:0.8431 0.5711 0.5378 RMSE:0.4591 0.7785 0.7911 Tau:0.7586 0.5696 0.4788
EarlyStopping counter: 106 out of 100
Epoch: 421 Step: 87147 Index:-0.2137 R2:0.8583 0.5646 0.5533 RMSE:0.4490 0.7828 0.7710 Tau:0.7720 0.5691 0.4799
EarlyStopping counter: 107 out of 100
Epoch: 422 Step: 87354 Index:-0.2326 R2:0.8435 0.5620 0.5487 RMSE:0.4708 0.8050 0.7933 Tau:0.7608 0.5724 0.4814
EarlyStopping counter: 108 out of 100
Epoch: 423 Step: 87561 Index:-0.2408 R2:0.8512 0.5514 0.5440 RMSE:0.4784 0.8042 0.7913 Tau:0.7653 0.5634 0.4820
EarlyStopping counter: 109 out of 100
Epoch: 424 Step: 87768 Index:-0.2327 R2:0.8433 0.5536 0.5401 RMSE:0.4648 0.7990 0.7909 Tau:0.7591 0.5663 0.4853
EarlyStopping counter: 110 out of 100
Epoch: 425 Step: 87975 Index:-0.2300 R2:0.8485 0.5567 0.5326 R

EarlyStopping counter: 30 out of 100
Epoch: 475 Step: 98325 Index:-0.2105 R2:0.8718 0.5743 0.5338 RMSE:0.4399 0.7860 0.8014 Tau:0.7841 0.5754 0.4924
EarlyStopping counter: 31 out of 100
Epoch: 476 Step: 98532 Index:-0.2169 R2:0.8290 0.5696 0.5226 RMSE:0.4936 0.7890 0.8202 Tau:0.7491 0.5722 0.4769
EarlyStopping counter: 32 out of 100
Epoch: 477 Step: 98739 Index:-0.2503 R2:0.8656 0.5767 0.5409 RMSE:0.5054 0.8277 0.8476 Tau:0.7793 0.5775 0.4814
EarlyStopping counter: 33 out of 100
Epoch: 478 Step: 98946 Index:-0.1825 R2:0.8728 0.5828 0.5487 RMSE:0.4267 0.7652 0.7789 Tau:0.7864 0.5827 0.4912
EarlyStopping counter: 34 out of 100
Epoch: 479 Step: 99153 Index:-0.2073 R2:0.8686 0.5745 0.5511 RMSE:0.4403 0.7829 0.7737 Tau:0.7823 0.5756 0.4958
EarlyStopping counter: 35 out of 100
Epoch: 480 Step: 99360 Index:-0.2052 R2:0.8683 0.5741 0.5512 RMSE:0.4275 0.7834 0.7875 Tau:0.7838 0.5782 0.4858
EarlyStopping counter: 36 out of 100
Epoch: 481 Step: 99567 Index:-0.2021 R2:0.8693 0.5742 0.5583 RMSE:0.4

EarlyStopping counter: 5 out of 100
Epoch: 530 Step: 109710 Index:-0.2383 R2:0.8634 0.5612 0.5371 RMSE:0.4518 0.8055 0.8133 Tau:0.7750 0.5672 0.4898
EarlyStopping counter: 6 out of 100
Epoch: 531 Step: 109917 Index:-0.2324 R2:0.8852 0.5569 0.5421 RMSE:0.4214 0.8024 0.8060 Tau:0.7987 0.5700 0.4936
EarlyStopping counter: 7 out of 100
Epoch: 532 Step: 110124 Index:-0.2388 R2:0.8664 0.5488 0.5432 RMSE:0.4398 0.7967 0.7859 Tau:0.7786 0.5579 0.4839
EarlyStopping counter: 8 out of 100
Epoch: 533 Step: 110331 Index:-0.2213 R2:0.8825 0.5713 0.5392 RMSE:0.4160 0.7925 0.8118 Tau:0.7940 0.5712 0.4960
EarlyStopping counter: 9 out of 100
Epoch: 534 Step: 110538 Index:-0.2137 R2:0.8835 0.5723 0.5601 RMSE:0.4046 0.7937 0.7904 Tau:0.7959 0.5800 0.4939
EarlyStopping counter: 10 out of 100
Epoch: 535 Step: 110745 Index:-0.1992 R2:0.8865 0.5768 0.5473 RMSE:0.3906 0.7768 0.7821 Tau:0.7977 0.5776 0.5015
EarlyStopping counter: 11 out of 100
Epoch: 536 Step: 110952 Index:-0.2790 R2:0.8582 0.5454 0.5239 RMSE:0

EarlyStopping counter: 60 out of 100
Epoch: 585 Step: 121095 Index:-0.1799 R2:0.8932 0.5877 0.5487 RMSE:0.3945 0.7650 0.7894 Tau:0.8057 0.5851 0.4937
EarlyStopping counter: 61 out of 100
Epoch: 586 Step: 121302 Index:-0.2258 R2:0.8979 0.5595 0.5547 RMSE:0.3790 0.7947 0.7848 Tau:0.8097 0.5689 0.4969
EarlyStopping counter: 62 out of 100
Epoch: 587 Step: 121509 Index:-0.2207 R2:0.8984 0.5602 0.5436 RMSE:0.3714 0.7882 0.7866 Tau:0.8109 0.5676 0.4942
EarlyStopping counter: 63 out of 100
Epoch: 588 Step: 121716 Index:-0.2705 R2:0.8923 0.5544 0.5453 RMSE:0.4323 0.8361 0.8431 Tau:0.8042 0.5656 0.4931
EarlyStopping counter: 64 out of 100
Epoch: 589 Step: 121923 Index:-0.2156 R2:0.8956 0.5612 0.5570 RMSE:0.3742 0.7869 0.7703 Tau:0.8081 0.5713 0.5034
EarlyStopping counter: 65 out of 100
Epoch: 590 Step: 122130 Index:-0.2433 R2:0.8959 0.5597 0.5390 RMSE:0.4080 0.8133 0.8243 Tau:0.8099 0.5700 0.4931
EarlyStopping counter: 66 out of 100
Epoch: 591 Step: 122337 Index:-0.1971 R2:0.8974 0.5727 0.5473 R

EarlyStopping counter: 115 out of 100
Epoch: 640 Step: 132480 Index:-0.2205 R2:0.8968 0.5558 0.5311 RMSE:0.3755 0.7868 0.7875 Tau:0.8091 0.5663 0.4833
EarlyStopping counter: 116 out of 100
Epoch: 641 Step: 132687 Index:-0.2285 R2:0.9065 0.5530 0.5434 RMSE:0.3544 0.7944 0.7824 Tau:0.8211 0.5660 0.4974
EarlyStopping counter: 117 out of 100
Epoch: 642 Step: 132894 Index:-0.2375 R2:0.9055 0.5486 0.5394 RMSE:0.3546 0.7999 0.7923 Tau:0.8189 0.5624 0.4958
EarlyStopping counter: 118 out of 100
Epoch: 643 Step: 133101 Index:-0.2297 R2:0.9060 0.5566 0.5615 RMSE:0.3608 0.8004 0.7633 Tau:0.8187 0.5707 0.5051
EarlyStopping counter: 119 out of 100
Epoch: 644 Step: 133308 Index:-0.3354 R2:0.8685 0.5045 0.4963 RMSE:0.4291 0.8867 0.8738 Tau:0.7965 0.5512 0.4924
EarlyStopping counter: 120 out of 100
Epoch: 645 Step: 133515 Index:-0.2167 R2:0.8949 0.5767 0.5427 RMSE:0.3906 0.7982 0.8311 Tau:0.8070 0.5815 0.4928
EarlyStopping counter: 121 out of 100
Epoch: 646 Step: 133722 Index:-0.2014 R2:0.9127 0.5640 0

EarlyStopping counter: 171 out of 100
Epoch: 696 Step: 144072 Index:-0.2046 R2:0.9140 0.5793 0.5509 RMSE:0.3544 0.7885 0.8142 Tau:0.8276 0.5839 0.5019
EarlyStopping counter: 172 out of 100
Epoch: 697 Step: 144279 Index:-0.2151 R2:0.9218 0.5611 0.5545 RMSE:0.3256 0.7900 0.7744 Tau:0.8385 0.5750 0.5042
EarlyStopping counter: 173 out of 100
Epoch: 698 Step: 144486 Index:-0.2272 R2:0.9143 0.5570 0.5487 RMSE:0.3372 0.8006 0.7949 Tau:0.8296 0.5734 0.5028
EarlyStopping counter: 174 out of 100
Epoch: 699 Step: 144693 Index:-0.2593 R2:0.9059 0.5375 0.5353 RMSE:0.3680 0.8210 0.8204 Tau:0.8246 0.5617 0.4992
EarlyStopping counter: 175 out of 100
Epoch: 700 Step: 144900 Index:-0.2374 R2:0.9165 0.5557 0.5345 RMSE:0.3555 0.8078 0.8345 Tau:0.8313 0.5704 0.4954
EarlyStopping counter: 176 out of 100
Epoch: 701 Step: 145107 Index:-0.2202 R2:0.9161 0.5617 0.5378 RMSE:0.3342 0.7951 0.8101 Tau:0.8312 0.5750 0.4966
EarlyStopping counter: 177 out of 100
Epoch: 702 Step: 145314 Index:-0.2707 R2:0.9074 0.5425 0

EarlyStopping counter: 226 out of 100
Epoch: 751 Step: 155457 Index:-0.2400 R2:0.9294 0.5553 0.5359 RMSE:0.3256 0.8065 0.8245 Tau:0.8462 0.5665 0.5034
EarlyStopping counter: 227 out of 100
Epoch: 752 Step: 155664 Index:-0.2495 R2:0.9259 0.5498 0.5440 RMSE:0.3147 0.8179 0.8171 Tau:0.8433 0.5684 0.5046
EarlyStopping counter: 228 out of 100
Epoch: 753 Step: 155871 Index:-0.2329 R2:0.9307 0.5549 0.5417 RMSE:0.3231 0.8065 0.8254 Tau:0.8478 0.5736 0.5013
EarlyStopping counter: 229 out of 100
Epoch: 754 Step: 156078 Index:-0.2324 R2:0.9305 0.5505 0.5367 RMSE:0.3126 0.8073 0.8148 Tau:0.8482 0.5750 0.5062
EarlyStopping counter: 230 out of 100
Epoch: 755 Step: 156285 Index:-0.2524 R2:0.9234 0.5421 0.5184 RMSE:0.3210 0.8133 0.8283 Tau:0.8394 0.5609 0.5027
EarlyStopping counter: 231 out of 100
Epoch: 756 Step: 156492 Index:-0.2464 R2:0.9257 0.5490 0.5235 RMSE:0.3192 0.8150 0.8366 Tau:0.8431 0.5686 0.4968
EarlyStopping counter: 232 out of 100
Epoch: 757 Step: 156699 Index:-0.2436 R2:0.9278 0.5562 0

EarlyStopping counter: 281 out of 100
Epoch: 806 Step: 166842 Index:-0.2096 R2:0.9225 0.5651 0.5427 RMSE:0.3224 0.7854 0.7939 Tau:0.8386 0.5758 0.5050
EarlyStopping counter: 282 out of 100
Epoch: 807 Step: 167049 Index:-0.2141 R2:0.9254 0.5640 0.5460 RMSE:0.3153 0.7869 0.7985 Tau:0.8420 0.5729 0.5030
EarlyStopping counter: 283 out of 100
Epoch: 808 Step: 167256 Index:-0.2433 R2:0.9080 0.5534 0.5342 RMSE:0.3747 0.8152 0.8373 Tau:0.8258 0.5719 0.5009
EarlyStopping counter: 284 out of 100
Epoch: 809 Step: 167463 Index:-0.2117 R2:0.9314 0.5673 0.5420 RMSE:0.3144 0.7858 0.8085 Tau:0.8509 0.5741 0.5019
EarlyStopping counter: 285 out of 100
Epoch: 810 Step: 167670 Index:-0.2802 R2:0.9172 0.5501 0.5180 RMSE:0.3707 0.8435 0.8892 Tau:0.8339 0.5634 0.4982
EarlyStopping counter: 286 out of 100
Epoch: 811 Step: 167877 Index:-0.2317 R2:0.9330 0.5585 0.5366 RMSE:0.2971 0.8017 0.8131 Tau:0.8498 0.5700 0.5034
EarlyStopping counter: 287 out of 100
Epoch: 812 Step: 168084 Index:-0.2340 R2:0.9320 0.5540 0

EarlyStopping counter: 336 out of 100
Epoch: 861 Step: 178227 Index:-0.2471 R2:0.9371 0.5441 0.5116 RMSE:0.2888 0.8104 0.8492 Tau:0.8554 0.5634 0.4954
EarlyStopping counter: 337 out of 100
Epoch: 862 Step: 178434 Index:-0.2247 R2:0.9418 0.5550 0.5219 RMSE:0.2794 0.7971 0.8320 Tau:0.8618 0.5724 0.5036
EarlyStopping counter: 338 out of 100
Epoch: 863 Step: 178641 Index:-0.2412 R2:0.9424 0.5532 0.5291 RMSE:0.2773 0.8156 0.8438 Tau:0.8621 0.5744 0.5070
EarlyStopping counter: 339 out of 100
Epoch: 864 Step: 178848 Index:-0.2185 R2:0.9304 0.5568 0.5272 RMSE:0.3032 0.7975 0.8317 Tau:0.8453 0.5790 0.5081
EarlyStopping counter: 340 out of 100
Epoch: 865 Step: 179055 Index:-0.2308 R2:0.9429 0.5562 0.5280 RMSE:0.2887 0.8049 0.8450 Tau:0.8635 0.5741 0.5033
EarlyStopping counter: 341 out of 100
Epoch: 866 Step: 179262 Index:-0.2655 R2:0.9256 0.5428 0.5251 RMSE:0.3151 0.8336 0.8507 Tau:0.8455 0.5681 0.5046
EarlyStopping counter: 342 out of 100
Epoch: 867 Step: 179469 Index:-0.2264 R2:0.9323 0.5535 0

EarlyStopping counter: 392 out of 100
Epoch: 917 Step: 189819 Index:-0.2523 R2:0.9454 0.5560 0.5240 RMSE:0.2898 0.8230 0.8816 Tau:0.8683 0.5707 0.5027
EarlyStopping counter: 393 out of 100
Epoch: 918 Step: 190026 Index:-0.2197 R2:0.9474 0.5657 0.5298 RMSE:0.2643 0.7967 0.8444 Tau:0.8695 0.5769 0.5046
EarlyStopping counter: 394 out of 100
Epoch: 919 Step: 190233 Index:-0.2487 R2:0.9450 0.5597 0.5199 RMSE:0.2834 0.8225 0.8860 Tau:0.8656 0.5738 0.5021
EarlyStopping counter: 395 out of 100
Epoch: 920 Step: 190440 Index:-0.2504 R2:0.9481 0.5467 0.5178 RMSE:0.2646 0.8194 0.8695 Tau:0.8700 0.5690 0.5014
EarlyStopping counter: 396 out of 100
Epoch: 921 Step: 190647 Index:-0.2693 R2:0.9419 0.5380 0.5177 RMSE:0.2963 0.8320 0.8749 Tau:0.8611 0.5628 0.5074
EarlyStopping counter: 397 out of 100
Epoch: 922 Step: 190854 Index:-0.3019 R2:0.9336 0.5393 0.4999 RMSE:0.3308 0.8600 0.9375 Tau:0.8516 0.5582 0.5056
EarlyStopping counter: 398 out of 100
Epoch: 923 Step: 191061 Index:-0.2457 R2:0.9502 0.5492 0

EarlyStopping counter: 447 out of 100
Epoch: 972 Step: 201204 Index:-0.2459 R2:0.9520 0.5547 0.5173 RMSE:0.2614 0.8215 0.8950 Tau:0.8760 0.5756 0.5039
EarlyStopping counter: 448 out of 100
Epoch: 973 Step: 201411 Index:-0.2451 R2:0.9491 0.5466 0.5207 RMSE:0.2643 0.8123 0.8382 Tau:0.8721 0.5672 0.5035
EarlyStopping counter: 449 out of 100
Epoch: 974 Step: 201618 Index:-0.2544 R2:0.9444 0.5448 0.5131 RMSE:0.2709 0.8248 0.8744 Tau:0.8641 0.5704 0.5032
EarlyStopping counter: 450 out of 100
Epoch: 975 Step: 201825 Index:-0.2636 R2:0.9483 0.5389 0.5127 RMSE:0.2636 0.8261 0.8691 Tau:0.8707 0.5625 0.5071
EarlyStopping counter: 451 out of 100
Epoch: 976 Step: 202032 Index:-0.2378 R2:0.9506 0.5450 0.5236 RMSE:0.2622 0.8088 0.8351 Tau:0.8731 0.5710 0.5112
EarlyStopping counter: 452 out of 100
Epoch: 977 Step: 202239 Index:-0.2590 R2:0.9489 0.5412 0.5159 RMSE:0.2602 0.8241 0.8571 Tau:0.8704 0.5651 0.5051
EarlyStopping counter: 453 out of 100
Epoch: 978 Step: 202446 Index:-0.2320 R2:0.9497 0.5544 0

In [15]:
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
    
test_r2, test_MSE, test_predict = eval(model, amodel, gmodel, test_df)
test_r2_a, test_MSE_a, test_predict_a = eval(model, amodel, gmodel, test_df[:test_active])
test_r2_ina, test_MSE_ina, test_predict_ina = eval(model, amodel, gmodel, test_df[test_active:].reset_index(drop=True))
    
test_predict = np.array(test_predict)
test_tau, _ = scipy.stats.kendalltau(test_predict,test_df[tasks[0]].values.astype(float).tolist())

k_list = [int(len(test_df)*0.01),int(len(test_df)*0.05),int(len(test_df)*0.1),int(len(test_df)*0.15),int(len(test_df)*0.2),int(len(test_df)*0.25),
          int(len(test_df)*0.3),int(len(test_df)*0.4),int(len(test_df)*0.5),50,100,150,200,250,300]
topk_list =[]
false_positive_rate_list = []
for k in k_list:
    a,b = topk_acc_recall(test_df, test_predict, k, test_active, False, epoch)
    topk_list.append(a)
    false_positive_rate_list.append(b)
WTI = weighted_top_index(test_df, test_predict, test_active)
ap = AP(test_df, test_predict, test_active)


In [16]:
print(' epoch:',epoch,'r2:%.4f'%test_r2_a,'RMSE:%.4f'%test_MSE_a**0.5,'WTI:%.4f'%WTI,'AP:%.4f'%ap,'Tau:%.4f'%test_tau,'\n','\n',
      'Top-1:%.4f'%topk_list[0],'Top-1-fp:%.4f'%false_positive_rate_list[0],'\n',
      'Top-5:%.4f'%topk_list[1],'Top-5-fp:%.4f'%false_positive_rate_list[1],'\n',
      'Top-10:%.4f'%topk_list[2],'Top-10-fp:%.4f'%false_positive_rate_list[2],'\n',
      'Top-15:%.4f'%topk_list[3],'Top-15-fp:%.4f'%false_positive_rate_list[3],'\n',
      'Top-20:%.4f'%topk_list[4],'Top-20-fp:%.4f'%false_positive_rate_list[4],'\n',
      'Top-25:%.4f'%topk_list[5],'Top-25-fp:%.4f'%false_positive_rate_list[5],'\n',
      'Top-30:%.4f'%topk_list[6],'Top-30-fp:%.4f'%false_positive_rate_list[6],'\n',
      'Top-40:%.4f'%topk_list[7],'Top-40-fp:%.4f'%false_positive_rate_list[7],'\n',
      'Top-50:%.4f'%topk_list[8],'Top-50-fp:%.4f'%false_positive_rate_list[8],'\n','\n',
      'Top50:%.4f'%topk_list[9],'Top50-fp:%.4f'%false_positive_rate_list[9],'\n',
      'Top100:%.4f'%topk_list[10],'Top100-fp:%.4f'%false_positive_rate_list[10],'\n',
      'Top150:%.4f'%topk_list[11],'Top150-fp:%.4f'%false_positive_rate_list[11],'\n',
      'Top200:%.4f'%topk_list[12],'Top200-fp:%.4f'%false_positive_rate_list[12],'\n',
      'Top250:%.4f'%topk_list[13],'Top250-fp:%.4f'%false_positive_rate_list[13],'\n',
      'Top300:%.4f'%topk_list[14],'Top300-fp:%.4f'%false_positive_rate_list[14],'\n')

 epoch: 1000 r2:0.5567 RMSE:0.7624 WTI:0.3579 AP:0.6780 Tau:0.4986 
 
 Top-1:0.3077 Top-1-fp:0.0000 
 Top-5:0.5385 Top-5-fp:0.0462 
 Top-10:0.6336 Top-10-fp:0.0840 
 Top-15:0.7056 Top-15-fp:0.1066 
 Top-20:0.7290 Top-20-fp:0.1565 
 Top-25:0.7134 Top-25-fp:0.2195 
 Top-30:0.7132 Top-30-fp:0.2563 
 Top-40:0.7020 Top-40-fp:0.3314 
 Top-50:0.7920 Top-50-fp:0.3973 
 
 Top50:0.5000 Top50-fp:0.0600 
 Top100:0.6300 Top100-fp:0.0600 
 Top150:0.6533 Top150-fp:0.0867 
 Top200:0.7050 Top200-fp:0.1050 
 Top250:0.7240 Top250-fp:0.1560 
 Top300:0.7067 Top300-fp:0.1967 



In [17]:
# print('target_file:',train_filename)
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)

In [18]:
# loss = loss_function(mol_prediction,y)
#             loss.backward(retain_graph=True)
#             optimizer_AFSE.zero_grad()
#             punish_lr = torch.norm(torch.mean(eps.grad,0))

#         init_lr = 1e-4
#         max_lr = 10**-(init_lr-1)
#         conv_lr = conv_lr - conv_lr**2 + 0.1 * punish_lr
#         if conv_lr < max_lr:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = conv_lr.detach()
#                 AFSE_lr = conv_lr    
#         else:
#             for param_group in optimizer_AFSE.param_groups:
#                 param_group["lr"] = max_lr
#                 AFSE_lr = max_lr
# epoch: 512 r2:0.5619 RMSE:0.7306 WTI:0.3784 AP:0.7228 Tau:0.5159 
 
#  Top-1:0.0000 Top-1-fp:0.0000 
#  Top-5:0.8571 Top-5-fp:0.0000 
#  Top-10:0.7857 Top-10-fp:0.0714 