In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
# Every available ADMET classification dataset, one CSV file per task.
all_raw_filenames = [
    "./data/ADMET/A/C/Pgp-inhibitor.csv",
    "./data/ADMET/A/C/Pgp-substrate.csv",
    "./data/ADMET/D/C/BBB_Penetration.csv",
    "./data/ADMET/M/C/CYP1A2_inhibitor.csv",
    "./data/ADMET/M/C/CYP1A2_substrate.csv",
    "./data/ADMET/M/C/CYP2C9_inhibitor.csv",
    "./data/ADMET/M/C/CYP2C9_substrate.csv",
    "./data/ADMET/M/C/CYP3A4_inhibitor.csv",
    "./data/ADMET/M/C/CYP3A4_substrate.csv",
    "./data/ADMET/T/C/Ames.csv",
    "./data/ADMET/T/C/Eye_Corrosion.csv",
    "./data/ADMET/T/C/FDAMDD.csv",
    "./data/ADMET/T/C/NR-AhR.csv",
    "./data/ADMET/T/C/NR-AR-LBD.csv",
    "./data/ADMET/T/C/NR-AR.csv",
    "./data/ADMET/T/C/NR-Aromatase.csv",
    "./data/ADMET/T/C/NR-ER-LBD.csv",
    "./data/ADMET/T/C/NR-ER.csv",
    "./data/ADMET/T/C/NR-PPAR-gamma.csv",
    "./data/ADMET/T/C/Skin_Sensitization.csv",
    "./data/ADMET/T/C/SR-ARE.csv",
    "./data/ADMET/T/C/SR-ATAD5.csv",
    "./data/ADMET/T/C/SR-HSE.csv",
    "./data/ADMET/T/C/SR-MMP.csv",
    "./data/ADMET/T/C/SR-p53.csv",
]

# Indices (into the list above) of the tasks with sample size > 7000.
task_id = [3, 5, 7, 9, 14, 16, 21]
task_num = len(task_id)
raw_filename = [all_raw_filenames[idx] for idx in task_id]
random_seed = 68
file_name = 'Multi_Tasks_Large'

# All artefacts of one run share the same "<prefix><file_name>_<number>" tag.
number = 'run_0'
run_tag = '3C_GAFSE_' + file_name + '_' + number
model_file = "model_file/" + run_tag
log_dir = 'log/' + run_tag
result_dir = './result/' + run_tag
print(raw_filename)
print(file_name)
print(model_file)

['./data/ADMET/M/C/CYP1A2_inhibitor.csv', './data/ADMET/M/C/CYP2C9_inhibitor.csv', './data/ADMET/M/C/CYP3A4_inhibitor.csv', './data/ADMET/T/C/Ames.csv', './data/ADMET/T/C/NR-AR.csv', './data/ADMET/T/C/NR-ER-LBD.csv', './data/ADMET/T/C/SR-ATAD5.csv']
Multi_Tasks_Large
model_file/3C_GAFSE_Multi_Tasks_Large_run_0


In [4]:
tasks = ['value']  # name of the label column read by the metric/training helpers

# Read every task's CSV, tag its rows with the task index, and concatenate a
# single time. Calling pd.concat inside the loop re-copies the growing frame
# on every iteration, which is needlessly quadratic.
task_frames = []
for i in range(task_num):
    task_df = pd.read_csv(raw_filename[i], header=0, names = ["smiles", "dataset", "value"],usecols=[0,1,2])
    task_df["task_id"] = i
    task_frames.append(task_df)
# pd.concat raises on an empty list; keep the previous "empty frame" result.
total_df = pd.concat(task_frames) if task_frames else pd.DataFrame([])

print(total_df[:3],total_df[-3:])

def add_canonical_smiles(total_df):
    """Drop rows whose SMILES RDKit cannot parse and add a ``cano_smiles`` column.

    Args:
        total_df: DataFrame with a ``smiles`` column.

    Returns:
        A new DataFrame (a copy, so the input is not modified) holding only the
        rows whose SMILES parsed, with the isomeric canonical SMILES stored in
        ``cano_smiles``.

    Every SMILES that fails to parse is printed, as are the counts before and
    after filtering.
    """
    smilesList = total_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    keep_mask = []
    canonical_smiles_list = []
    for smiles in smilesList:
        canonical = None
        try:
            # MolFromSmiles returns None for an invalid string; it raises only
            # for non-string input (e.g. NaN), hence the explicit None check.
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                canonical = Chem.MolToSmiles(mol, isomericSmiles=True)
        except Exception:
            canonical = None
        if canonical is None:
            print(smiles)
            keep_mask.append(False)
        else:
            keep_mask.append(True)
            canonical_smiles_list.append(canonical)
    print("number of successfully processed smiles: ", len(canonical_smiles_list))
    # Positional boolean mask + copy: row count always matches the list being
    # assigned, and no SettingWithCopyWarning is triggered.
    total_df = total_df[np.asarray(keep_mask, dtype=bool)].copy()
    total_df['cano_smiles'] = canonical_smiles_list
    return total_df

# Replace total_df with its filtered version that carries the extra
# 'cano_smiles' column, then show the first rows as a sanity check.
total_df = add_canonical_smiles(total_df)
print(total_df.head())

                                           smiles   dataset value  task_id
0        Clc1c(COc2c(C(=O)OC)sc3ncccc23)c(Cl)ccc1  training     1        0
1        N#Cc1cc(-c2cc3c(NCc4cnccc4)ncnc3cc2)ccc1  training     1        0
2  O(Cc1ccccc1)c1nc2N(Cc3cc(OC)ccc3)C(=O)C=Nc2cn1  training     1        0                                               smiles dataset value  task_id
7167                                 o1c2c(cc1)cccc2     val     0        6
7168  O=[N+]([O-])c1cc(C(=O)OC(C)C)cc(C(=O)OC(C)C)c1     val     0        6
7169                         O=Cc1c(OC)cc(OC)c(OC)c1     val     0        6
number of all smiles:  67202
SMILES
ON(=O)c1ccc(Cl)cc1
Nc1ccc(cc1)N(=O)O
CCCN(=O)O
NC(COC(=O)C=N=[N-])C(=O)O
Nc1ccc(O)c(c1)N(=O)O
Cc1ccc(cc1N(=O)O)N(=O)O
Nc1ncc(s1)N(=O)O
Nc1ccc(cc1O)N(=O)O
Cc1ccc2C(=O)c3ccccc3C(=O)c2c1N(=O)O
CCNC(=O)Nc1ncc(s1)N(=O)O
CCCN(CCC)c1c(cc(cc1N(=O)O)C(F)(F)F)N(=O)O
CCOc1ccc(NC(=O)C)cc1N(=O)O
ON(=O)c1ccc(Oc2ccc(Cl)cc2Cl)cc1
CC1CN(N=Cc2ccc(o2)N(=O)O)C(=O)N1
CCN(Cc1cccc

CN1C[N]C2=C1N=C=N(=C2N)O
CC(O)CN(=[NH+]C(=O)C(=C)C)(C)C
CC(O)CN(=[NH+]C(=O)C(=C)C)(C)C
CN1C[N]C2=C1N=C=N(O)=C2N
number of successfully processed smiles:  66858
                                           smiles   dataset value  task_id  \
0        Clc1c(COc2c(C(=O)OC)sc3ncccc23)c(Cl)ccc1  training     1        0   
1        N#Cc1cc(-c2cc3c(NCc4cnccc4)ncnc3cc2)ccc1  training     1        0   
2  O(Cc1ccccc1)c1nc2N(Cc3cc(OC)ccc3)C(=O)C=Nc2cn1  training     1        0   
3       S(CC(=O)OCC)c1nc(NC(=O)c2cc(C)ccc2)[nH]n1  training     1        0   
4         Brc1ccc(C2Nc3c(NC4=C2C(=O)NC4)cccc3)cc1  training     1        0   

                                   cano_smiles  
0         COC(=O)c1sc2ncccc2c1OCc1c(Cl)cccc1Cl  
1     N#Cc1cccc(-c2ccc3ncnc(NCc4cccnc4)c3c2)c1  
2  COc1cccc(Cn2c(=O)cnc3cnc(OCc4ccccc4)nc32)c1  
3      CCOC(=O)CSc1n[nH]c(NC(=O)c2cccc(C)c2)n1  
4        O=C1NCC2=C1C(c1ccc(Br)cc1)Nc1ccccc1N2  


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [5]:
# Timestamp of this run with ':' and ' ' replaced by '-' and '_'.
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

# NOTE: the two values below are used as negative base-10 exponents when the
# optimizers are built (weight_decay=10**-weight_decay, lr=3*10**(-learning_rate)).
weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 3 # default: 2
T = 2
per_task_output_units_num = 1 # one output unit per task
output_units_num = task_num * per_task_output_units_num

In [6]:
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
# The cache file name is the run tag with its last character replaced by '0',
# so a 'run_N' tag (single-digit N) resolves to the 'run_0' cache.
feature_filename = './features/'+model_file.split('/')[-1][:-1]+'0.pickle'
filename = './features/'+model_file.split('/')[-1]
print(feature_filename)
if os.path.isfile(feature_filename):
    # NOTE: unpickling can execute arbitrary code; only load caches you
    # generated yourself. The context manager closes the file handle, which
    # the previous pickle.load(open(...)) leaked.
    with open(feature_filename, "rb") as feature_file:
        feature_dicts = pickle.load(feature_file)
    print('Loading features successfully.')
else:
    feature_dicts = save_smiles_dicts(total_smilesList,filename)

66858
./features/3C_GAFSE_Multi_Tasks_Large_run_0.pickle
Loading features successfully.


In [7]:
def _select_split(split_name):
    """Rows of total_df tagged `split_name` whose canonical SMILES has features."""
    in_split = total_df[total_df.dataset.values == split_name]
    has_features = in_split["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())
    return in_split[has_features].reset_index(drop=True)


test_df = _select_split("test")
val_df = _select_split("val")
train_df = _select_split("training")

print(total_df.shape, len(train_df)+len(val_df)+len(test_df), train_df.shape,val_df.shape,test_df.shape)

(66858, 5) 66855 (54184, 5) (6588, 5) (6083, 5)


In [8]:
# Featurize a single molecule only to read the atom/bond feature widths from
# the last axis of the returned arrays.
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
# Module-level loss object; perturb_feature() and d_loss() read it as a global.
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
# Move all three networks to the GPU before their optimizers are created.
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

# One Adam optimizer per network, all with lr = 3 * 10**-learning_rate and
# weight decay = 10**-weight_decay.
optimizer = optim.Adam(params=model.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

# Count the trainable scalar parameters of `model` (the count is stored in
# `params` but not printed).
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [9]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    """Scatter predictions and true values by row position and mark rank k.

    Args:
        dataset: DataFrame holding the label column ``tasks[0]``.
        p: predicted values, one per row of ``dataset``.
        k: rank at which a vertical marker is drawn (at x = k - 1).
        k_predict: y-level of the horizontal marker line.
        i: epoch number shown in the title.
        acc: metric value shown in the title.
    """
    true_values = dataset[tasks[0]].astype(float).tolist()
    n_rows = len(dataset)
    positions = np.arange(0, n_rows, 1)

    plt.scatter(positions, p, marker='.', s=6, color='r', label='predict')
    plt.scatter(positions, true_values, s=6, marker=',', color='blue', label='p_value')
    # Vertical line at the k-th position.
    plt.axvline(x=k-1, ls="-", c="black")
    # Horizontal line at the k_predict level.
    plt.plot(positions, np.ones(n_rows)*k_predict, '-', color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    """Top-k precision of a ranking, plus the fraction of -4.1 labels in the top k.

    Args:
        df: DataFrame with the label column ``tasks[0]``. A ``predict`` column
            is written into it (side effect kept for existing callers).
        predict: predicted score for every row of ``df``.
        k: number of top-ranked predictions to inspect.
        active_num: number of actives; when ``k > active_num`` the hit
            threshold is the label of the ``active_num``-th best true value
            instead of the ``k``-th.
        show_flag: if True, plot the ranking with ``sorted_show_pik``.
        i: epoch number forwarded to the plot title.

    Returns:
        ``(acc, fp)``: the share of the top-k predicted rows whose label
        reaches the threshold, and the share whose label equals -4.1.
    """
    label = tasks[0]
    df['predict'] = predict
    # Rows ranked by prediction, best first; keep the first k.
    top_k = df.sort_values(by='predict',ascending=False)[:k]
    # Rows ranked by true label, best first.
    true_sort = df.sort_values(by=label,ascending=False)
    true_values = true_sort[label].values

    threshold = true_values[active_num-1] if k>active_num else true_values[k-1]
    top_k_labels = top_k[label].values
    acc = int(np.count_nonzero(top_k_labels>=threshold))/k
    fp = int(np.count_nonzero(top_k_labels==-4.1))/k

    if show_flag:
        # Lowest prediction that still made the top k. This name used to be
        # undefined, so show_flag=True raised NameError.
        k_predict = top_k['predict'].values[-1]
        # The plot receives the rows ordered by true label.
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    """Recall of actives within the top-k predictions.

    A row counts as active when its label (column ``tasks[0]``) is greater
    than -4.1.

    Args:
        df: DataFrame with the label column ``tasks[0]``. A ``predict`` column
            is written into it (side effect kept for existing callers).
        predict: predicted score for every row of ``df``.
        k: number of top-ranked predictions to inspect.
        active_num: total number of actives, used as the recall denominator.
        show_flag: if True, plot the ranking with ``sorted_show_pik``.
        i: epoch number forwarded to the plot title.

    Returns:
        ``(acc, fp)``: actives found in the top k divided by ``active_num``,
        and the share of the top k whose label equals -4.1.
    """
    label = tasks[0]
    df['predict'] = predict
    # Rows ranked by prediction, best first; keep the first k.
    top_k = df.sort_values(by='predict',ascending=False)[:k]
    top_k_labels = top_k[label].values

    acc = int(np.count_nonzero(top_k_labels>-4.1))/active_num
    fp = int(np.count_nonzero(top_k_labels==-4.1))/k

    if show_flag:
        # Rows ranked by true label, which is the order the plot expects.
        true_sort = df.sort_values(by=label,ascending=False)
        # Lowest prediction that still made the top k. This name used to be
        # undefined, so show_flag=True raised NameError.
        k_predict = top_k['predict'].values[-1]
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    """Dispatch to ``topk_recall`` when k exceeds active_num, else ``topk_acc2``."""
    metric = topk_recall if k>active_num else topk_acc2
    return metric(df, predict, k, active_num, show_flag, i)

def weighted_top_index(df, predict, active_num):
    """Average of the top-k metric over every k, weighted by (n - k) / n.

    For each k from 1 to n = len(df), the first value returned by
    ``topk_acc_recall`` is multiplied by (n - k) / n; the mean of those
    products is returned.
    """
    n_rows = len(df)
    weighted_scores = []
    for k in np.arange(1, n_rows+1, 1):
        score = topk_acc_recall(df, predict, k, active_num)[0]
        weighted_scores.append(score*((n_rows-k)/n_rows))
    weighted_scores = np.array(weighted_scores)
    return np.sum(weighted_scores)/weighted_scores.shape[0]

def AP(df, predict, active_num):
    """Average precision from the per-k precision and recall values.

    Precision at k is the first value returned by ``topk_acc2`` and recall at
    k the first value returned by ``topk_recall``, for k = 1 .. len(df).
    """
    precisions = []
    recalls = []
    for k in np.arange(1, len(df)+1, 1):
        precisions.append(topk_acc2(df, predict, k, active_num)[0])
        recalls.append(topk_recall(df, predict, k, active_num)[0])

    # Add sentinel values at both ends of the curves.
    mrec = np.concatenate(([0.], recalls, [1.]))
    mpre = np.concatenate(([0.], precisions, [0.]))

    # Precision envelope: running maximum taken from the end towards the
    # start, so every entry is at least as large as all entries after it.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Positions where the recall value changes.
    steps = np.where(mrec[1:] != mrec[:-1])[0]

    # Sum of (delta recall) * precision over those positions.
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

In [10]:
def caculate_r2(predict,y):
    """Squared Pearson correlation between ``predict`` and ``y``.

    Both inputs are converted to float column tensors. The result is a
    (1, 1) tensor equal to cov(y, predict)**2 / (var(y) * var(predict) + 1e-9),
    where the sums are not divided by the sample count (it cancels out).
    """
    target = torch.FloatTensor(y).reshape(-1,1)
    estimate = torch.FloatTensor(predict).reshape(-1,1)

    target_centered = target - torch.mean(target)
    estimate_centered = estimate - torch.mean(estimate)

    cross_term_sq = torch.pow(torch.mm(target_centered.t(), estimate_centered), 2)
    spread_product = torch.mm(target_centered.t(), target_centered) * \
        torch.mm(estimate_centered.t(), estimate_centered)
    return cross_term_sq/(spread_product+1e-9)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score, roc_auc_score
def calc(TN, FP, FN, TP):
    """Sensitivity, specificity, accuracy and MCC from confusion-matrix counts.

    Args:
        TN, FP, FN, TP: true-negative, false-positive, false-negative and
            true-positive counts.

    Returns:
        ``(SN, SP, ACC, MCC)``. A ratio whose denominator is zero (e.g. no
        positive samples) is reported as 0.0 instead of raising
        ZeroDivisionError, matching the existing behaviour of MCC, which is
        already 0 in that situation thanks to its ``+1e-9`` guard.
    """
    def _ratio(numerator, denominator):
        # Undefined ratios are reported as 0.0 rather than crashing.
        return numerator / denominator if denominator != 0 else 0.0

    SN = _ratio(TP, TP + FN)  # recall
    SP = _ratio(TN, TN + FP)
    ACC = _ratio(TP + TN, TP + TN + FN + FP)
    fz = TP * TN - FP * FN
    fm = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN)
    MCC = fz / (pow(fm, 0.5)+1e-9)
    return SN, SP, ACC, MCC

In [11]:
from torch.autograd import Variable
def l2_norm(input, dim):
    """Divide ``input`` by its norm along ``dim`` (plus 1e-6 to avoid 0/0)."""
    magnitude = torch.norm(input, dim=dim, keepdim=True)
    return torch.div(input, magnitude+1e-6)

def normalize_perturbation(d,dim=-1):
    """Return ``d`` rescaled by ``l2_norm`` along ``dim`` (last axis by default)."""
    return l2_norm(d, dim)

def tanh(x):
    """Element-wise hyperbolic tangent of tensor ``x``.

    Delegates to ``torch.tanh``. The previous hand-written
    (e^x - e^-x) / (e^x + e^-x) overflowed to inf/inf = NaN once |x| was
    large enough for exp to overflow (about 89 in float32).
    """
    return torch.tanh(x)

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + e^-x) of tensor ``x``."""
    denominator = 1+torch.exp(-x)
    return 1/denominator

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None, task_id=None, sigmoid=False):
    """Compute a per-task adversarial perturbation of feature ``f`` and its loss.

    Relies on the module-level globals ``task_num``, ``loss_function`` and
    ``optimizer_AFSE``.

    Args:
        f: feature tensor fed to ``model`` (assumed to be one row per sample,
            since predictions are indexed as ``[:, i]`` -- TODO confirm).
        model: callable accepting ``feature=``, ``d=``, ``sigmoid=`` and
            optionally ``output_lr=True`` (in which case it returns a pair).
        alpha: not used in this function.
        lamda: scale applied to the normalised gradient direction.
        output_lr: also return ``max_lr`` as produced by
            ``model(..., output_lr=True)``.
        output_plr: (only with ``output_lr``) also return ``punish_lr``;
            requires ``y``.
        y: labels; read only when ``output_plr`` is True.
        task_id: per-sample task index; compared element-wise with each task
            number to build a 0/1 mask.
        sigmoid: forwarded to ``model`` (shadows the module-level ``sigmoid``
            function inside this body).

    Returns:
        ``(eps_adv, d_adv, vat_loss, mol_prediction)``, extended with
        ``max_lr`` and then ``punish_lr`` according to the two flags.
        ``eps_adv``, ``d_adv`` and ``max_lr`` are those of the last task that
        had samples in the batch.

    NOTE(review): if no task has samples in the batch, ``task_counter`` stays 0
    and the division below fails; ``eps_adv``/``d_adv`` would also be unbound.
    """
    # Unperturbed prediction; the detached copy is the reference in the ratios below.
    mol_prediction = model(feature=f, d=0, sigmoid=sigmoid)
    pred = mol_prediction.detach()
    task_counter = 0
    vat_loss = 0
    for i in range(task_num):
        batch_task_sample = len(task_id[task_id==i])
        if batch_task_sample > 0:
            task_counter += 1
            # 1 for samples of task i, 0 otherwise.
            y_mask = np.where(task_id==i, 1, 0)
            y_mask = torch.Tensor(y_mask)
        #     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
            # Small random probe direction (normalised, then scaled by 1e-6).
            eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps = Variable(eps, requires_grad=True)
            # Predict on randomly perturbed image
            eps_p = model(feature=f, d=eps.cuda(), sigmoid=sigmoid)
            eps_p_ = model(feature=f, d=-eps.cuda(), sigmoid=sigmoid)
            # Sigmoid of (perturbed / clean) prediction for task i, masked to that task's samples.
            p_aux = nn.Sigmoid()(eps_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_aux_ = nn.Sigmoid()(eps_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
        #     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
            loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
            # retain_graph: the graph through `f` is reused by later forward/backward passes.
            loss.backward(retain_graph=True)

            # Based on perturbed image, get direction of greatest error
            eps_adv = eps.grad#/10**-learning_rate
            # Discard the gradients the probe backward pass left on optimizer_AFSE's parameters.
            optimizer_AFSE.zero_grad()
            # Use that direction as adversarial perturbation
            eps_adv_normed = normalize_perturbation(eps_adv)
            d_adv = lamda * eps_adv_normed.cuda()
            # NOTE(review): f_p from this call is immediately overwritten by the
            # next line; only max_lr is kept from it.
            f_p, max_lr = model(feature=f, d=d_adv, output_lr=True, sigmoid=sigmoid)
            f_p = model(feature=f, d=d_adv, sigmoid=sigmoid)
            f_p_ = model(feature=f, d=-d_adv, sigmoid=sigmoid)
            p = nn.Sigmoid()(f_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_ = nn.Sigmoid()(f_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            vat_loss += loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    # Average over the tasks that were present in the batch.
    vat_loss /= task_counter
    if output_lr:
        if output_plr:
            loss = 0
            task_counter = 0
            # Fresh random probe, this time added directly to the feature.
            eps_ = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps_ = Variable(eps_, requires_grad=True)
            eps_p__ = model(feature=f+eps_.cuda(), d=0)
            for i in range(task_num):
                batch_task_sample = len(task_id[task_id==i])
                if batch_task_sample > 0:
                    task_counter += 1
                    y_mask = np.where(task_id==i, 1, 0)
                    y_mask = torch.Tensor(y_mask)
                    loss += loss_function(eps_p__[:,i]*y_mask,y.view(-1)*y_mask)
            loss /= task_counter
            loss.backward(retain_graph=True)
            # Norm of the batch-averaged gradient with respect to the probe.
            punish_lr = torch.norm(torch.mean(eps_.grad,0))
            optimizer_AFSE.zero_grad()
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    """Set each atom's 'molAtomMapNumber' property to its own index; return ``mol``."""
    for position in range(mol.GetNumAtoms()):
        atom_index = mol.GetAtomWithIdx(position).GetIdx()
        mol.GetAtomWithIdx(position).SetProp('molAtomMapNumber', str(atom_index))
    return mol

def d_loss(f, pred, model, y_val):
    """Mean loss between predicted and true label differences over ordered pairs.

    For every ordered pair (i, j) with i != j among the ``len(pred)`` samples,
    ``model(feature_only=True, feature1=f[i], feature2=f[j])`` is compared
    with ``y_val[i] - y_val[j]`` using the module-level ``loss_function``.
    The sum is divided by the number of pairs, n * (n - 1).
    """
    sample_count = len(pred)
    accumulated = 0
    for first in range(sample_count):
        for second in range(sample_count):
            if second != first:
                predicted_gap = model(feature_only=True, feature1=f[first], feature2=f[second])
                actual_gap = y_val[first] - y_val[second]
                accumulated += loss_function(predicted_gap, torch.Tensor([actual_gap]).view(-1,1))
    return accumulated/(sample_count*(sample_count-1))

In [12]:
def CE(x,y):
    """Class-balanced binary cross-entropy, averaged over the last axis.

    Positive terms are weighted by the fraction of entries of ``y`` that are
    not 1, negative terms by the fraction that are 1. A 1e-6 offset keeps the
    logarithms finite.
    """
    total = len(y)
    positives = sum(1 for idx in range(total) if y[idx]==1)
    positive_weight = (total-positives)/total
    negative_weight = positives/total
    per_element = -positive_weight*y*torch.log(x+1e-6)-negative_weight*(1-y)*torch.log(1-x+1e-6)
    return per_element.mean(-1)

def weighted_CE_loss(x,y):
    """Cross-entropy with each class weighted by the inverse of its mean in ``y``.

    ``y`` is reduced to class indices with ``argmax`` over its last axis; the
    class weights are 1 / (mean of ``y`` over axis 0 + 1e-9).
    """
    class_frequency = y.detach().float().mean(0)
    class_weight = 1/(class_frequency+1e-9)
    target_index = torch.argmax(y,-1)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    return criterion(x, target_index)

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    """Focal-style weighted cross-entropy, averaged over samples.

    Args:
        pred: class scores with classes on the last axis.
        target: tensor of the same shape as ``pred``; its argmax over the last
            axis is the class index, and ``pred * target`` feeds the focal
            weight.
        weight: unused; kept so the signature stays compatible.
        gamma: exponent of the focal weight ``(1 - max(pred * target))**gamma``.
        alpha: per-class weight tensor for the cross-entropy, ``None`` for no
            weighting, or a scalar applied uniformly to every class. The scalar
            default previously crashed because ``nn.CrossEntropyLoss`` only
            accepts a tensor (or None) as ``weight``.

    Returns:
        Scalar tensor: the mean of focal weight * weighted cross-entropy.
    """
    if alpha is not None and not torch.is_tensor(alpha):
        # Expand a scalar into one identical weight per class.
        alpha = torch.full((pred.shape[-1],), float(alpha), dtype=pred.dtype, device=pred.device)
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduction='none')
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    return torch.mean(loss)

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    """Atom-type reconstruction loss over a batch of molecules.

    Only the first 16 atom features are used (presumably the one-hot atom
    symbol -- TODO confirm against the featurizer). ``refer_bond_list`` and
    ``bond_list`` are not read; ``bond_neighbor`` is only unpacked for its
    shape.

    Args:
        refer_atom_list: predicted atom scores, indexed [molecule, atom, feature].
        x_atom: 3-D tensor of input atom features.
        refer_bond_list: unused.
        bond_neighbor: 4-D tensor (shape unpacked, values unused).
        validity_mask: numpy array; converted to a CUDA tensor here.
        atom_list: atom outputs used in the log(1 - p) penalty term.
        bond_list: unused.

    Returns:
        ``(total_loss, 0, 0, 0)``; the three zeros are placeholders.
    """
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()  # unused
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        # Number of atom rows whose feature sum is non-zero.
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        # Per-class weight: 1 - (count of that class among the l atoms) / l.
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)  # unused
        # print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        # Focal loss on the atom scores, minus the masked mean of
        # log(1 - atom_list) over the entries selected by validity_mask.
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        # Counted as two terms per molecule, so the total is divided by 2 * a.
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    """Run one training epoch over ``dataset`` with an auxiliary pass on ``test_df``.

    Relies on module-level globals: ``batch_size`` and ``logger`` (neither is
    defined in the visible part of the notebook), ``feature_dicts``, ``tasks``,
    ``task_num`` and ``learning_rate``.

    Args:
        model: network producing ``(activated_features, mol_feature)`` when
            called with ``output_activated_features=True``.
        amodel: network passed to ``perturb_feature`` to obtain predictions and
            the perturbation losses.
        gmodel: put in train mode and stepped, but every call to it is
            currently commented out.
        dataset: labelled DataFrame with ``cano_smiles``, ``task_id`` and the
            label column ``tasks[0]``.
        test_df: DataFrame with ``cano_smiles`` and ``task_id``; its labels are
            never read here, only its perturbation loss is used.
        optimizer_list: ``(optimizer, optimizer_AFSE, optimizer_GRN)``.
        loss_function: not used in this body (``perturb_feature`` reads the
            module-level object of the same name).
        epoch: seeds the shuffle and offsets the logging step.

    Returns:
        None. Parameters are updated in place and scalars are sent to ``logger``.
    """
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    # Seeding with the epoch makes each epoch's shuffle reproducible.
    np.random.seed(epoch)
    # Iterate over the longer of the two frames; the shorter one wraps around
    # through the modulo indexing below.
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        # .loc is label-based: this assumes both frames carry a 0..n-1 index
        # (presumably from reset_index(drop=True) -- verify against caller).
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(int)
        task_id = batch_df['task_id'].values
        task_id_test = batch_test['task_id'].values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
#         refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
#                                                   torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
#                                                   mol_feature=mol_feature,activated_features=activated_features.detach())
        
#         x_atom = torch.Tensor(x_atom)
#         x_bonds = torch.Tensor(x_bonds)
#         x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
#         bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
#         bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        # Labelled batch: prediction, perturbation loss and the two lr signals.
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1,                                                                                lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1),
                                                                                       task_id=task_id, sigmoid=True) # 10**-learning_rate     
        # Binary cross-entropy per task, restricted to that task's samples and
        # averaged over the tasks present in the batch.
        classification_loss = 0
        task_counter = 0
        for i in range(task_num):
            batch_task_sample = len(batch_df[batch_df["task_id"].values==i])
            if batch_task_sample > 0:
                task_counter += 1
                y_mask = np.where(batch_df["task_id"].values==i, 1, 0)
                y_mask = torch.Tensor(y_mask)
                c_loss = - torch.Tensor(y_val) * torch.log(mol_prediction[:,i]+1e-9) - \
                                (1-torch.Tensor(y_val)) * torch.log((1-mol_prediction[:,i])+1e-9)
                classification_loss += torch.sum(c_loss * y_mask)/batch_task_sample
        classification_loss /= task_counter
#         atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
#                                       torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
#         success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
#                             bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
#                                                      refer_atom_list, refer_bond_list,topn=1)
#         reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
#                                                                                               bond_neighbor, validity_mask, atom_list, 
#                                                                                               bond_list)
#         x_atom_test = torch.Tensor(x_atom_test)
#         x_bonds_test = torch.Tensor(x_bonds_test)
#         x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
#         bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
#         bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        # Second batch (from test_df): only the perturbation loss is computed,
        # no labels are involved.
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate, task_id=task_id_test, sigmoid=True)
#         atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
#                                                 torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                 mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
#         refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
#                                                             torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                             mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
#         success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
#                             bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
#                                                      refer_atom_list_test, refer_bond_list_test,topn=1)
#         test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
            
        # If either perturbation loss exceeds 1, divide both by their own
        # (non-differentiable, via .item()) value so each becomes ~1 in
        # magnitude while still carrying a gradient.
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        # Learning rate for optimizer_AFSE, recomputed every batch and clamped
        # to the interval [0, max_lr].
        max_lr = 1e-3
        adapt_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr # 0.06
        if adapt_lr < max_lr and adapt_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = adapt_lr.detach()
                AFSE_lr = adapt_lr    
        elif adapt_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif adapt_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
#         AFSE_lr = 1e-4

        logger.add_scalar('loss/classification', classification_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
#         logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
#         logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
#         logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/conv_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
        
        # Clear gradients left over from the backward passes inside
        # perturb_feature before the real update.
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        # Total loss: classification plus 0.08 times both perturbation losses.
        loss =  classification_loss + 0.08 * (vat_loss + test_vat_loss) # + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    """Clear the 'molAtomMapNumber' property on every atom of ``mol``; return ``mol``."""
    for atom in mol.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return mol

def mol_with_atom_index( mol ):
    """Label every atom with its index via the 'molAtomMapNumber' property.

    Same behaviour as the identically named function defined in an earlier
    cell; this later definition is the one in effect once this cell has run.
    """
    atom_count = mol.GetNumAtoms()
    for idx in range(atom_count):
        map_number = str(mol.GetAtomWithIdx(idx).GetIdx())
        mol.GetAtomWithIdx(idx).SetProp('molAtomMapNumber', map_number)
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    """Propose single-atom element substitutions for every molecule in a batch.

    For molecule ``i`` the score ``atom_list - refer_atom_list - x_atom`` is
    taken over the first 16 feature columns, which are mapped to elements via
    ``symbol_to_rdkit`` below.  Atom positions are visited from the highest to
    the lowest maximum score; at each position elements are tried from the
    highest to the lowest score.  A candidate is accepted when RDKit can
    sanitize the edited molecule.  The search for one molecule ends after
    ``topn`` accepted edits, or once every atom position has been exhausted.

    NOTE(review): all edits for one input molecule are applied to a single
    shared ``mol`` and never reverted, so a rejected (or an earlier accepted)
    substitution stays in place for later candidates -- confirm this is the
    intended behaviour.

    Args:
        smiles: sequence of input SMILES strings, one per batch row.
        x_atom: tensor of input atom features indexed [molecule, atom, feature].
            Rows whose features sum to zero are counted as padding.
        bond_neighbor, bond_list, refer_bond_list: accepted for interface
            compatibility; not used by this function.
        atom_list: tensor of generated atom features (same indexing as x_atom).
        y_smiles: sequence of target SMILES; an accepted edit whose SMILES
            equals ``y_smiles[i]`` is counted as a success.
        smiles_to_rdkit_list: mapping from canonical SMILES to a sequence that
            translates a feature-row position into an RDKit atom index.
        refer_atom_list: tensor of reference atom features (same indexing).
        topn: number of accepted edits to collect per molecule.
        viz: accepted for interface compatibility; not used.

    Returns:
        Tuple ``(success_smiles, modified_smiles, success, total,
        success_reconstruction, success_validity, validity_mask)`` where
        ``success``/``total`` are per-rank counters of length ``topn``,
        ``success_reconstruction`` counts molecules whose reference arg-max
        element matches the real element on every atom, ``success_validity``
        counts accepted edits made before the first rejected candidate of a
        molecule, and ``validity_mask`` marks rejected
        [molecule, atom, element] candidates with 1.
    """
    x_atom = x_atom.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()

    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    # Column k of the 16 element columns corresponds to, in order:
    # B, C, N, O, F, Si, P, S, Cl, As, Se, Br, Te, I, At, other.
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]

    for i in range(len(atom_list)):
        rank = 0              # which element (by descending score) to try next
        top_idx = 0           # which atom position (by descending score) to edit
        flag = 0              # number of accepted edits so far
        first_run_flag = True
        n_atoms = (x_atom[i].sum(-1)!=0).sum(-1)   # non-padding atom rows
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))

        # Reconstruction check: does the reference prediction name the real
        # element on every atom?
        counter = 0
        for j in range(n_atoms):
            real_number = mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum()
            if real_number == symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
        if counter == n_atoms:
            success_reconstruction += 1

        # The inputs are never modified below, so the score can be computed
        # once per molecule instead of on every pass of the search loop.
        delta = atom_list[i,:n_atoms,:16] - refer_atom_list[i,:n_atoms,:16] - x_atom[i,:n_atoms,:16]

        while not flag==topn:
            if rank == 16:
                # Every element was tried at this position: move to the next one.
                rank = 0
                top_idx += 1
            if top_idx == n_atoms:
                # No atom position left; burn the remaining slots so the loop ends.
                flag += 1
                continue
            if rank == 0:
                generate_index = np.argsort(delta.max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(delta[generate_index])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
                # Same element as the input atom: try the next element.
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            # Score of the chosen element (equal to the rank-th largest score).
            if delta[generate_index, atom_symbol_generated]<confidence_threshold:
                # Score too low at this position: move to the next position.
                top_idx += 1
                rank = 0
                continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag:
                    success_validity += 1
                total[flag] += 1
                candidate_smiles = Chem.MolToSmiles(clear_atom_map(mol))
                if candidate_smiles==y_smiles[i]:
                    success[flag] +=1
                    success_smiles.append(candidate_smiles)
                modified_smiles.append(candidate_smiles)
                rank += 1
                flag += 1
            except Exception:
                # The edited molecule is not chemically valid: remember the
                # rejected candidate and try the next element.  ``Exception``
                # (rather than a bare ``except``) lets KeyboardInterrupt and
                # SystemExit propagate.
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    """Propose bond-type edits for every molecule in a batch (experimental).

    For each molecule the first 4 columns of ``bond_list`` are ranked and a
    bond is selected for modification; the edited molecule is kept when RDKit
    can sanitize it.  Three accepted edits are collected per molecule.

    NOTE(review): apart from the fixes listed below the selection logic is
    reproduced unchanged and still looks unfinished -- see the inline notes.
    The function is only referenced from commented-out code in this section.

    Fixed here:
      * both messages interpolated the undefined name
        ``atom_symbol_generated`` (NameError); they now report
        ``bond_type_generated_one``.
      * bare ``except:`` narrowed to ``except Exception``.

    Args:
        smiles: sequence of input SMILES strings, one per batch row.
        x_atom, atom_list: accepted for interface compatibility; not used.
        bond_neighbor: tensor indexed [molecule, atom, neighbor, feature];
            atoms whose entries sum to zero are counted as padding.
        bond_list: tensor of generated bond features (same leading axes).
        y_smiles: sequence of target SMILES, displayed for comparison.
        smiles_to_rdkit_list: mapping from canonical SMILES to a sequence of
            RDKit indices.

    Returns:
        List of the accepted, edited molecule objects.
    """
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        # NOTE(review): "sorted" and "generated_sorted" are computed from the
        # same expression, so they are always identical -- presumably one of
        # them was meant to come from the input bonds; confirm.
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        # Does not change inside the search loop, so compute it once.
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        while not flag==3:
            # NOTE(review): ``bond_type_sorted`` is already sliced to molecule
            # ``i``; indexing its first axis with ``i`` again selects an atom
            # row, which looks unintended -- confirm.
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
            # NOTE(review): argsort runs over the last axis of a 2-D array, so
            # ``generate_index`` is a row of indices rather than a single
            # index; the comparison and lookups below presumably expect a
            # scalar -- confirm against real input shapes.
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            # NOTE(review): a numeric index is passed where RDKit presumably
            # expects a Chem.BondType value -- confirm.
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")     # "Molecule before modification:"
                display(mol_init)
                modified_smiles.append(mol)
                # "Molecule with bond #<idx> changed to <type>:"
                print(f"将第{generate_rdkit_index}个键修改为{bond_type_generated_one}的分子：")
                display(mol)
                mol_target = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")       # "Highly active molecule:"
                display(mol_target)
                rank += 1
                flag += 1
            except Exception:
                # Edit rejected: report it and move on to the next candidate.
                # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
                # and SystemExit propagate.
                print(f"第{generate_rdkit_index}个原子符号修改为{bond_type_generated_one}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, cth_list=[0.5 for i in range(task_num)], topn=1, output_feature=False, 
         generate=False, modify_atom=True,return_GRN_loss=False, viz=False, output_cth=False):
    """Score ``dataset`` with the current models and compute per-task metrics.

    The data is processed in chunks of the global ``batch_size``.  Each chunk
    is featurised with ``get_smiles_array``, passed through ``model`` and then
    through ``perturb_feature`` (with ``amodel``); for every row the
    prediction column that matches the row's ``task_id`` is kept.  AUC,
    sensitivity, specificity, accuracy and MCC are then computed per task.

    Besides its arguments the function reads these notebook globals:
    ``batch_size``, ``tasks``, ``task_num``, ``feature_dicts``,
    ``learning_rate``, ``get_smiles_array``, ``perturb_feature``,
    ``roc_auc_score``, ``confusion_matrix`` and ``calc``.

    NOTE: this function shadows the built-in ``eval``; the name is kept
    because callers in this notebook use it.

    Args:
        model, amodel, gmodel: networks; all three are switched to eval mode.
            ``gmodel`` is not used otherwise (the generation path is disabled).
        dataset: DataFrame with columns ``cano_smiles``, ``task_id`` and
            ``tasks[0]``.  Rows are fetched with ``dataset.loc`` on the labels
            0..len-1, so the index is assumed to be a default range index
            -- TODO confirm for all callers.
        cth_list: per-task classification thresholds, used when
            ``output_cth`` is False.
        output_feature: if True, also return the perturbations and features.
        output_cth: if True, search thresholds 0, 0.05, ..., 1 per task and
            keep the one with the highest mean of (sn, sp, acc, mcc).
        topn, generate, modify_atom, return_GRN_loss, viz: accepted for
            interface compatibility; currently unused.

    Returns:
        ``(auc, sn, sp, acc, mcc, predictions)`` by default, where the first
        five are lists with one entry per task;
        with ``output_feature`` the tuple is prefixed by ``(d_list,
        feature_list)``; otherwise with ``output_cth`` the tuple is extended
        by the list of selected thresholds.
    """
    model.eval()
    amodel.eval()
    gmodel.eval()

    predict_list = []
    feature_list = []
    d_list = []
    lamda = 10**-learning_rate          # constant for the whole evaluation
    row_labels = np.arange(0,dataset.shape[0])

    # NOTE(review): gradients are not disabled here; perturb_feature may rely
    # on autograd -- confirm before wrapping this loop in torch.no_grad().
    for start in range(0, dataset.shape[0], batch_size):
        batch = row_labels[start:start+batch_size]
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        task_id = batch_df['task_id'].values

        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, _ = get_smiles_array(smiles_list,feature_dicts)
        _, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
        _, d_adv, _, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda, task_id=task_id, sigmoid=True)
        mol_prediction = mol_prediction.cpu().detach().numpy()

        # Keep, for every row, the output column of the task it belongs to.
        # Rows whose task_id is outside range(task_num) keep column 0.
        mol_prediction_readout = mol_prediction[:,0].copy()
        for t in range(task_num):
            rows_of_task = task_id==t
            mol_prediction_readout[rows_of_task] = mol_prediction[rows_of_task, t]

        d_list.extend(d_adv.cpu().detach().numpy().tolist())
        feature_list.extend(mol_feature.cpu().detach().numpy().tolist())
        predict_list.extend(mol_prediction_readout)

    predict_list = np.array(predict_list)

    # Select each task's rows with a mask on task_id.  For data sorted by
    # ascending task_id this is identical to splitting at the cumulative task
    # counts, but it stays correct when the rows are not sorted.
    all_task_ids = dataset["task_id"].values
    all_y_val = dataset[tasks[0]].values.astype(float)
    task_predict_list = [predict_list[all_task_ids==t] for t in range(task_num)]
    task_y_val_list = [all_y_val[all_task_ids==t] for t in range(task_num)]

    auc_list = []
    if output_cth:
        best_cth_list = [0.5 for i in range(task_num)]
        best_sn_list = [0 for i in range(task_num)]
        best_sp_list = [0 for i in range(task_num)]
        best_acc_list = [0 for i in range(task_num)]
        best_mcc_list = [0 for i in range(task_num)]
        for i in range(task_num):
            task_predict = task_predict_list[i]
            task_y_val = task_y_val_list[i]
            auc_list.append(roc_auc_score(task_y_val, task_predict))
            for cth in np.linspace(0,1,21):
                class_pred = np.where(task_predict>cth,1,0).astype(int)
                tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
                sn, sp, acc, mcc = calc(tn, fp, fn, tp)
                mean_index = (sn + sp + acc + mcc)/4
                best_index = (best_sn_list[i] + best_sp_list[i] + best_acc_list[i] + best_mcc_list[i])/4
                # Strict ">" keeps the lowest threshold among equally good ones.
                if mean_index > best_index:
                    best_cth_list[i] = cth
                    best_sn_list[i] = sn
                    best_sp_list[i] = sp
                    best_acc_list[i] = acc
                    best_mcc_list[i] = mcc
        sn_list = best_sn_list
        sp_list = best_sp_list
        acc_list = best_acc_list
        mcc_list = best_mcc_list
    else:
        sn_list = []
        sp_list = []
        acc_list = []
        mcc_list = []
        for i in range(task_num):
            task_predict = task_predict_list[i]
            task_y_val = task_y_val_list[i]
            auc = roc_auc_score(task_y_val, task_predict)
            class_pred = np.where(task_predict>cth_list[i],1,0).astype(int)
            tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
            sn, sp, acc, mcc = calc(tn, fp, fn, tp)
            auc_list.append(auc)
            sn_list.append(sn)
            sp_list.append(sp)
            acc_list.append(acc)
            mcc_list.append(mcc)

    # output_feature takes precedence over output_cth when both are set.
    if output_feature:
        return d_list, feature_list, auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list
    if output_cth:
        return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list, best_cth_list
    return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list

# Training-loop settings.  ``batch_size`` is also read as a global by eval().
epoch, max_epoch = 0, 1000
batch_size = 10
patience = 100
# One early-stopping tracker per network; each writes its own checkpoint file.
stopper, stopper_afse, stopper_generate = [
    EarlyStopping(mode='higher', patience=patience, filename=model_file + checkpoint_suffix)
    for checkpoint_suffix in ('_model.pth', '_amodel.pth', '_gmodel.pth')
]

In [13]:
import datetime
import shutil
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
# Start every run from an empty log directory.  shutil.rmtree also removes
# nested directories, which the previous file-by-file os.remove loop could not.
if os.path.isdir(log_dir):
    shutil.rmtree(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

log/3C_GAFSE_Multi_Tasks_Large_run_0


In [None]:
# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = 0
# To resume from saved checkpoints, uncomment:
# stopper.load_checkpoint(model)
# stopper_afse.load_checkpoint(amodel)
# stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
    train_auc, train_sn, train_sp, train_acc, train_mcc, train_predict = eval(model, amodel, gmodel, train_df)
    # Thresholds are selected on the validation set and reused on the test set.
    val_auc, val_sn, val_sp, val_acc, val_mcc, val_predict, val_cth_list = eval(model, amodel, gmodel, val_df, output_cth= True)
    test_auc, test_sn, test_sp, test_acc, test_mcc, test_predict = eval(model, amodel, gmodel, test_df, cth_list=val_cth_list)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)

    # Per-split metric lists, keyed by the name used in the TensorBoard tag.
    split_metrics = {
        'train': {'auc': train_auc, 'sn': train_sn, 'sp': train_sp, 'acc': train_acc, 'mcc': train_mcc},
        'val': {'cth': val_cth_list, 'auc': val_auc, 'sn': val_sn, 'sp': val_sp, 'acc': val_acc, 'mcc': val_mcc},
        'test': {'auc': test_auc, 'sn': test_sn, 'sp': test_sp, 'acc': test_acc, 'mcc': test_mcc},
    }
    for i in range(task_num):
        for split_name, named_values in split_metrics.items():
            for metric_name, values in named_values.items():
                logger.add_scalar(f'{split_name}/T{i}_{metric_name}', values[i], global_step)
        # Each metric is printed as: train, validation, test.
        print('Epoch:',epoch, 'Task:', i+1,
              'auc:%.3f'%train_auc[i],'%.3f'%val_auc[i],'%.3f'%test_auc[i], 
              'sn:%.3f'%train_sn[i],'%.3f'%val_sn[i],'%.3f'%test_sn[i], 
              'sp:%.3f'%train_sp[i], '%.3f'%val_sp[i], '%.3f'%test_sp[i], 
              'acc:%.3f'%train_acc[i], '%.3f'%val_acc[i], '%.3f'%test_acc[i], 
              'mcc:%.3f'%train_mcc[i],'%.3f'%val_mcc[i],'%.3f'%test_mcc[i])
    
    # Model selection score: sum of the task-averaged validation metrics.
    stop_index = np.mean(val_auc) +  np.mean(val_sn) +  np.mean(val_sp) +  np.mean(val_acc) +  np.mean(val_mcc)
    # All three trackers must be stepped every epoch so each network gets its
    # checkpoint written; the list is built eagerly for that reason.
    stop_votes = [
        stopper.step(stop_index, model),
        stopper_afse.step(stop_index, amodel, if_print=False),
        stopper_generate.step(stop_index, gmodel, if_print=False),
    ]
    early_stop = any(stop_votes)
    
    if early_stop:
        # Was ``continue``, which is a no-op as the last statement of the loop
        # body, so training always ran to max_epoch regardless of patience.
        break


Epoch: 1 Task: 1 auc:0.885 0.890 0.877 sn:0.809 0.776 0.757 sp:0.794 0.842 0.837 acc:0.801 0.812 0.799 mcc:0.602 0.621 0.596
Epoch: 1 Task: 2 auc:0.822 0.816 0.827 sn:0.312 0.744 0.751 sp:0.950 0.723 0.741 acc:0.738 0.730 0.744 mcc:0.359 0.444 0.469
Epoch: 1 Task: 3 auc:0.804 0.797 0.805 sn:0.709 0.835 0.861 sp:0.718 0.640 0.641 acc:0.714 0.720 0.732 mcc:0.421 0.470 0.499
Epoch: 1 Task: 4 auc:0.708 0.721 0.669 sn:0.586 0.962 0.914 sp:0.722 0.241 0.241 acc:0.648 0.848 0.578 mcc:0.309 0.292 0.210
Epoch: 1 Task: 5 auc:0.851 0.774 0.874 sn:0.099 0.444 0.654 sp:1.000 0.987 0.989 acc:0.967 0.967 0.977 mcc:0.301 0.487 0.655
Epoch: 1 Task: 6 auc:0.751 0.715 0.810 sn:0.000 0.314 0.091 sp:1.000 0.957 0.952 acc:0.951 0.925 0.912 mcc:0.000 0.255 0.042
Epoch: 1 Task: 7 auc:0.751 0.735 0.723 sn:0.000 0.360 0.440 sp:1.000 0.889 0.883 acc:0.965 0.870 0.868 mcc:0.000 0.140 0.177
Epoch: 2 Task: 1 auc:0.893 0.895 0.887 sn:0.806 0.838 0.818 sp:0.815 0.812 0.770 acc:0.811 0.824 0.792 mcc:0.620 0.649 0.586


Epoch: 11 Task: 1 auc:0.927 0.915 0.909 sn:0.833 0.821 0.798 sp:0.860 0.867 0.856 acc:0.847 0.846 0.829 mcc:0.693 0.689 0.655
Epoch: 11 Task: 2 auc:0.895 0.876 0.890 sn:0.728 0.806 0.841 sp:0.880 0.811 0.813 acc:0.829 0.809 0.822 mcc:0.612 0.595 0.627
Epoch: 11 Task: 3 auc:0.905 0.889 0.896 sn:0.670 0.809 0.794 sp:0.902 0.805 0.796 acc:0.807 0.807 0.795 mcc:0.598 0.608 0.585
Epoch: 11 Task: 4 auc:0.898 0.865 0.837 sn:0.841 0.809 0.767 sp:0.798 0.759 0.759 acc:0.822 0.801 0.763 mcc:0.640 0.462 0.526
Epoch: 11 Task: 5 auc:0.884 0.796 0.910 sn:0.535 0.407 0.692 sp:0.999 0.994 0.997 acc:0.982 0.973 0.986 mcc:0.703 0.534 0.783
Epoch: 11 Task: 6 auc:0.885 0.824 0.930 sn:0.369 0.514 0.606 sp:0.993 0.951 0.957 acc:0.963 0.929 0.940 mcc:0.505 0.390 0.467
Epoch: 11 Task: 7 auc:0.862 0.810 0.746 sn:0.050 0.400 0.240 sp:1.000 0.944 0.934 acc:0.967 0.925 0.909 mcc:0.209 0.250 0.123
Epoch: 12 Task: 1 auc:0.927 0.912 0.909 sn:0.826 0.834 0.833 sp:0.870 0.832 0.823 acc:0.850 0.833 0.828 mcc:0.697 0.66

Epoch: 21 Task: 1 auc:0.946 0.925 0.914 sn:0.832 0.841 0.835 sp:0.903 0.846 0.829 acc:0.870 0.844 0.832 mcc:0.739 0.687 0.663
Epoch: 21 Task: 2 auc:0.926 0.895 0.902 sn:0.729 0.831 0.856 sp:0.919 0.816 0.808 acc:0.856 0.821 0.824 mcc:0.669 0.622 0.634
Epoch: 21 Task: 3 auc:0.932 0.905 0.904 sn:0.806 0.835 0.843 sp:0.888 0.811 0.813 acc:0.854 0.821 0.825 mcc:0.698 0.638 0.649
Epoch: 21 Task: 4 auc:0.927 0.920 0.896 sn:0.828 0.923 0.914 sp:0.863 0.716 0.716 acc:0.844 0.890 0.815 mcc:0.689 0.610 0.642
Epoch: 21 Task: 5 auc:0.912 0.772 0.898 sn:0.601 0.407 0.692 sp:0.997 0.992 0.991 acc:0.982 0.970 0.981 mcc:0.718 0.499 0.711
Epoch: 21 Task: 6 auc:0.931 0.797 0.920 sn:0.409 0.429 0.333 sp:0.998 0.996 0.993 acc:0.970 0.968 0.961 mcc:0.601 0.584 0.462
Epoch: 21 Task: 7 auc:0.937 0.837 0.867 sn:0.231 0.720 0.800 sp:0.999 0.876 0.892 acc:0.972 0.870 0.889 mcc:0.440 0.310 0.374
Epoch: 22 Task: 1 auc:0.942 0.918 0.908 sn:0.865 0.910 0.891 sp:0.877 0.755 0.741 acc:0.872 0.827 0.811 mcc:0.742 0.66

Epoch: 31 Task: 1 auc:0.959 0.929 0.925 sn:0.895 0.846 0.840 sp:0.885 0.864 0.842 acc:0.890 0.856 0.841 mcc:0.779 0.710 0.682
Epoch: 31 Task: 2 auc:0.938 0.889 0.897 sn:0.798 0.876 0.876 sp:0.909 0.769 0.783 acc:0.872 0.804 0.814 mcc:0.710 0.611 0.625
Epoch: 31 Task: 3 auc:0.943 0.904 0.908 sn:0.829 0.815 0.773 sp:0.886 0.836 0.843 acc:0.863 0.827 0.814 mcc:0.717 0.647 0.617
Epoch: 31 Task: 4 auc:0.943 0.938 0.926 sn:0.779 0.873 0.871 sp:0.929 0.888 0.888 acc:0.847 0.875 0.879 mcc:0.708 0.644 0.759
Epoch: 31 Task: 5 auc:0.936 0.799 0.902 sn:0.554 0.407 0.615 sp:0.999 0.997 1.000 acc:0.983 0.975 0.986 mcc:0.728 0.577 0.779
Epoch: 31 Task: 6 auc:0.949 0.799 0.873 sn:0.496 0.400 0.333 sp:0.993 0.984 0.981 acc:0.969 0.955 0.950 mcc:0.614 0.451 0.366
Epoch: 31 Task: 7 auc:0.957 0.828 0.908 sn:0.266 0.560 0.480 sp:0.999 0.945 0.954 acc:0.974 0.932 0.937 mcc:0.495 0.357 0.332
EarlyStopping counter: 1 out of 100
Epoch: 32 Task: 1 auc:0.960 0.929 0.919 sn:0.841 0.875 0.844 sp:0.928 0.843 0.841 

Epoch: 40 Task: 1 auc:0.966 0.931 0.921 sn:0.900 0.877 0.859 sp:0.898 0.842 0.820 acc:0.899 0.858 0.838 mcc:0.797 0.717 0.678
Epoch: 40 Task: 2 auc:0.948 0.893 0.906 sn:0.856 0.799 0.826 sp:0.898 0.833 0.837 acc:0.884 0.822 0.833 mcc:0.743 0.614 0.642
Epoch: 40 Task: 3 auc:0.951 0.907 0.911 sn:0.848 0.860 0.839 sp:0.896 0.803 0.812 acc:0.876 0.826 0.823 mcc:0.744 0.654 0.643
Epoch: 40 Task: 4 auc:0.961 0.943 0.933 sn:0.914 0.904 0.871 sp:0.868 0.836 0.836 acc:0.893 0.893 0.853 mcc:0.785 0.660 0.707
Epoch: 40 Task: 5 auc:0.961 0.750 0.880 sn:0.615 0.407 0.692 sp:0.999 0.997 0.999 acc:0.985 0.975 0.988 mcc:0.758 0.577 0.804
Epoch: 40 Task: 6 auc:0.965 0.813 0.902 sn:0.584 0.400 0.303 sp:0.996 0.994 0.996 acc:0.976 0.965 0.963 mcc:0.708 0.543 0.469
Epoch: 40 Task: 7 auc:0.969 0.820 0.876 sn:0.392 0.360 0.360 sp:0.998 0.993 0.990 acc:0.977 0.971 0.968 mcc:0.571 0.468 0.435
EarlyStopping counter: 5 out of 100
Epoch: 41 Task: 1 auc:0.970 0.932 0.922 sn:0.926 0.886 0.857 sp:0.894 0.839 0.822 

Epoch: 49 Task: 1 auc:0.977 0.935 0.918 sn:0.914 0.884 0.852 sp:0.929 0.858 0.828 acc:0.922 0.870 0.839 mcc:0.843 0.741 0.678
Epoch: 49 Task: 2 auc:0.962 0.891 0.900 sn:0.859 0.831 0.836 sp:0.924 0.807 0.822 acc:0.903 0.815 0.827 mcc:0.781 0.612 0.634
Epoch: 49 Task: 3 auc:0.960 0.891 0.904 sn:0.929 0.870 0.861 sp:0.852 0.757 0.771 acc:0.884 0.804 0.808 mcc:0.770 0.617 0.623
Epoch: 49 Task: 4 auc:0.969 0.953 0.930 sn:0.919 0.954 0.905 sp:0.897 0.767 0.767 acc:0.909 0.925 0.836 mcc:0.817 0.719 0.679
Epoch: 49 Task: 5 auc:0.970 0.735 0.885 sn:0.385 0.333 0.423 sp:1.000 0.996 0.997 acc:0.978 0.971 0.977 mcc:0.613 0.488 0.589
Epoch: 49 Task: 6 auc:0.981 0.808 0.821 sn:0.628 0.400 0.394 sp:0.998 0.990 0.988 acc:0.980 0.960 0.960 mcc:0.759 0.498 0.475
Epoch: 49 Task: 7 auc:0.987 0.796 0.869 sn:0.613 0.400 0.480 sp:0.996 0.975 0.964 acc:0.982 0.955 0.947 mcc:0.707 0.362 0.368
EarlyStopping counter: 6 out of 100
Epoch: 50 Task: 1 auc:0.976 0.928 0.917 sn:0.914 0.858 0.849 sp:0.921 0.843 0.834 

Epoch: 58 Task: 1 auc:0.984 0.928 0.918 sn:0.951 0.846 0.811 sp:0.915 0.882 0.865 acc:0.932 0.865 0.840 mcc:0.864 0.729 0.678
Epoch: 58 Task: 2 auc:0.969 0.901 0.901 sn:0.874 0.761 0.779 sp:0.928 0.873 0.864 acc:0.910 0.836 0.836 mcc:0.798 0.631 0.635
Epoch: 58 Task: 3 auc:0.972 0.895 0.909 sn:0.953 0.866 0.867 sp:0.873 0.771 0.781 acc:0.906 0.810 0.817 mcc:0.815 0.627 0.638
Epoch: 58 Task: 4 auc:0.978 0.962 0.954 sn:0.947 0.948 0.922 sp:0.897 0.802 0.802 acc:0.924 0.925 0.862 mcc:0.848 0.727 0.729
Epoch: 58 Task: 5 auc:0.977 0.756 0.894 sn:0.634 0.407 0.654 sp:1.000 0.997 0.997 acc:0.986 0.975 0.985 mcc:0.785 0.577 0.758
Epoch: 58 Task: 6 auc:0.985 0.792 0.854 sn:0.719 0.343 0.303 sp:0.994 1.000 0.999 acc:0.981 0.968 0.966 mcc:0.777 0.576 0.514
Epoch: 58 Task: 7 auc:0.991 0.817 0.833 sn:0.472 0.600 0.520 sp:0.999 0.948 0.925 acc:0.980 0.936 0.911 mcc:0.652 0.391 0.284
EarlyStopping counter: 3 out of 100
Epoch: 59 Task: 1 auc:0.982 0.931 0.919 sn:0.935 0.846 0.815 sp:0.925 0.882 0.850 

Epoch: 67 Task: 1 auc:0.987 0.931 0.912 sn:0.941 0.805 0.770 sp:0.946 0.917 0.897 acc:0.944 0.865 0.838 mcc:0.887 0.731 0.676
Epoch: 67 Task: 2 auc:0.973 0.882 0.892 sn:0.901 0.734 0.771 sp:0.927 0.857 0.883 acc:0.918 0.816 0.846 mcc:0.818 0.587 0.653
Epoch: 67 Task: 3 auc:0.971 0.892 0.894 sn:0.905 0.799 0.755 sp:0.913 0.846 0.850 acc:0.910 0.826 0.811 mcc:0.815 0.643 0.609
Epoch: 67 Task: 4 auc:0.980 0.970 0.954 sn:0.924 0.954 0.914 sp:0.930 0.853 0.853 acc:0.927 0.938 0.884 mcc:0.853 0.779 0.769
Epoch: 67 Task: 5 auc:0.980 0.790 0.900 sn:0.695 0.407 0.615 sp:0.998 0.997 1.000 acc:0.987 0.975 0.986 mcc:0.801 0.577 0.779
Epoch: 67 Task: 6 auc:0.992 0.826 0.816 sn:0.719 0.486 0.485 sp:0.998 0.967 0.970 acc:0.985 0.944 0.947 mcc:0.820 0.430 0.437
Epoch: 67 Task: 7 auc:0.993 0.794 0.789 sn:0.578 0.640 0.560 sp:0.999 0.926 0.915 acc:0.984 0.916 0.903 mcc:0.738 0.357 0.288
EarlyStopping counter: 7 out of 100
Epoch: 68 Task: 1 auc:0.988 0.931 0.915 sn:0.914 0.836 0.818 sp:0.964 0.874 0.857 

Epoch: 76 Task: 1 auc:0.989 0.934 0.915 sn:0.963 0.855 0.828 sp:0.927 0.876 0.844 acc:0.944 0.866 0.837 mcc:0.888 0.731 0.672
Epoch: 76 Task: 2 auc:0.977 0.881 0.884 sn:0.916 0.734 0.731 sp:0.932 0.867 0.863 acc:0.927 0.823 0.819 mcc:0.837 0.600 0.593
Epoch: 76 Task: 3 auc:0.977 0.889 0.897 sn:0.923 0.795 0.769 sp:0.921 0.833 0.835 acc:0.922 0.817 0.808 mcc:0.840 0.626 0.604
Epoch: 76 Task: 4 auc:0.984 0.970 0.970 sn:0.974 0.958 0.931 sp:0.880 0.879 0.879 acc:0.931 0.945 0.905 mcc:0.863 0.805 0.811
Epoch: 76 Task: 5 auc:0.986 0.790 0.898 sn:0.709 0.407 0.654 sp:0.998 0.996 0.997 acc:0.988 0.974 0.985 mcc:0.807 0.555 0.758
Epoch: 76 Task: 6 auc:0.994 0.773 0.804 sn:0.682 0.343 0.303 sp:0.999 0.999 0.993 acc:0.984 0.966 0.960 mcc:0.810 0.551 0.433
Epoch: 76 Task: 7 auc:0.994 0.771 0.754 sn:0.578 0.400 0.360 sp:0.999 0.986 0.983 acc:0.984 0.965 0.961 mcc:0.728 0.429 0.373
EarlyStopping counter: 16 out of 100
Epoch: 77 Task: 1 auc:0.989 0.922 0.912 sn:0.917 0.858 0.835 sp:0.971 0.843 0.838

Epoch: 85 Task: 1 auc:0.990 0.922 0.902 sn:0.932 0.807 0.777 sp:0.965 0.897 0.881 acc:0.950 0.855 0.833 mcc:0.899 0.709 0.664
Epoch: 85 Task: 2 auc:0.983 0.894 0.892 sn:0.935 0.801 0.799 sp:0.938 0.836 0.831 acc:0.937 0.824 0.820 mcc:0.860 0.619 0.611
Epoch: 85 Task: 3 auc:0.980 0.899 0.898 sn:0.917 0.817 0.788 sp:0.929 0.832 0.838 acc:0.924 0.826 0.817 mcc:0.843 0.644 0.625
Epoch: 85 Task: 4 auc:0.985 0.961 0.961 sn:0.958 0.938 0.931 sp:0.920 0.879 0.879 acc:0.941 0.929 0.905 mcc:0.881 0.759 0.811
Epoch: 85 Task: 5 auc:0.991 0.751 0.914 sn:0.624 0.444 0.654 sp:1.000 0.996 0.997 acc:0.986 0.975 0.985 mcc:0.782 0.586 0.758
Epoch: 85 Task: 6 auc:0.994 0.803 0.824 sn:0.723 0.600 0.576 sp:0.998 0.936 0.930 acc:0.985 0.919 0.913 mcc:0.822 0.405 0.367
Epoch: 85 Task: 7 auc:0.991 0.782 0.740 sn:0.558 0.400 0.320 sp:0.999 0.978 0.980 acc:0.984 0.958 0.957 mcc:0.727 0.378 0.319
EarlyStopping counter: 25 out of 100
Epoch: 86 Task: 1 auc:0.994 0.927 0.903 sn:0.947 0.833 0.787 sp:0.973 0.885 0.853

Epoch: 94 Task: 1 auc:0.992 0.923 0.905 sn:0.973 0.843 0.806 sp:0.939 0.866 0.842 acc:0.955 0.855 0.826 mcc:0.911 0.709 0.649
Epoch: 94 Task: 2 auc:0.987 0.890 0.877 sn:0.943 0.799 0.794 sp:0.944 0.848 0.829 acc:0.944 0.832 0.817 mcc:0.876 0.632 0.604
Epoch: 94 Task: 3 auc:0.987 0.868 0.885 sn:0.953 0.807 0.812 sp:0.931 0.777 0.802 acc:0.940 0.790 0.806 mcc:0.878 0.577 0.607
Epoch: 94 Task: 4 auc:0.988 0.980 0.977 sn:0.969 0.959 0.957 sp:0.919 0.888 0.888 acc:0.947 0.948 0.922 mcc:0.893 0.814 0.847
Epoch: 94 Task: 5 auc:0.989 0.777 0.884 sn:0.746 0.407 0.654 sp:0.999 0.996 0.996 acc:0.990 0.974 0.983 mcc:0.846 0.555 0.737
Epoch: 94 Task: 6 auc:0.990 0.755 0.828 sn:0.807 0.400 0.394 sp:0.997 0.994 0.988 acc:0.988 0.965 0.960 mcc:0.863 0.543 0.475
Epoch: 94 Task: 7 auc:0.996 0.817 0.774 sn:0.849 0.480 0.280 sp:0.997 0.971 0.970 acc:0.992 0.954 0.946 mcc:0.872 0.401 0.236
EarlyStopping counter: 2 out of 100
Epoch: 95 Task: 1 auc:0.991 0.922 0.907 sn:0.971 0.833 0.798 sp:0.935 0.870 0.854 

Epoch: 103 Task: 1 auc:0.997 0.929 0.910 sn:0.977 0.863 0.815 sp:0.967 0.874 0.841 acc:0.972 0.869 0.829 mcc:0.943 0.738 0.656
Epoch: 103 Task: 2 auc:0.991 0.883 0.880 sn:0.958 0.794 0.774 sp:0.951 0.849 0.834 acc:0.953 0.831 0.814 mcc:0.897 0.629 0.593
Epoch: 103 Task: 3 auc:0.992 0.897 0.897 sn:0.986 0.882 0.875 sp:0.920 0.762 0.763 acc:0.947 0.812 0.809 mcc:0.895 0.634 0.628
Epoch: 103 Task: 4 auc:0.993 0.982 0.981 sn:0.962 0.964 0.966 sp:0.956 0.871 0.871 acc:0.959 0.949 0.918 mcc:0.918 0.815 0.840
Epoch: 103 Task: 5 auc:0.994 0.744 0.895 sn:0.685 0.407 0.654 sp:1.000 0.997 0.997 acc:0.988 0.975 0.985 mcc:0.820 0.577 0.758
Epoch: 103 Task: 6 auc:0.994 0.805 0.809 sn:0.777 0.486 0.515 sp:0.998 0.976 0.969 acc:0.987 0.952 0.947 mcc:0.856 0.475 0.452
Epoch: 103 Task: 7 auc:0.994 0.819 0.756 sn:0.754 0.560 0.400 sp:0.999 0.968 0.941 acc:0.990 0.954 0.922 mcc:0.847 0.444 0.243
Epoch: 104 Task: 1 auc:0.994 0.925 0.902 sn:0.970 0.814 0.789 sp:0.946 0.898 0.869 acc:0.957 0.859 0.832 mcc:0.

Epoch: 112 Task: 1 auc:0.997 0.923 0.902 sn:0.966 0.794 0.772 sp:0.978 0.905 0.868 acc:0.972 0.854 0.823 mcc:0.944 0.707 0.645
Epoch: 112 Task: 2 auc:0.991 0.882 0.886 sn:0.875 0.786 0.791 sp:0.984 0.842 0.836 acc:0.948 0.823 0.821 mcc:0.882 0.614 0.611
Epoch: 112 Task: 3 auc:0.994 0.883 0.894 sn:0.955 0.809 0.810 sp:0.961 0.818 0.807 acc:0.959 0.814 0.808 mcc:0.915 0.622 0.611
Epoch: 112 Task: 4 auc:0.994 0.975 0.971 sn:0.979 0.943 0.940 sp:0.928 0.940 0.940 acc:0.956 0.942 0.940 mcc:0.911 0.811 0.879
Epoch: 112 Task: 5 auc:0.990 0.790 0.905 sn:0.742 0.444 0.538 sp:0.999 0.993 0.997 acc:0.990 0.973 0.981 mcc:0.846 0.547 0.678
Epoch: 112 Task: 6 auc:0.995 0.814 0.831 sn:0.850 0.457 0.515 sp:0.998 0.982 0.960 acc:0.991 0.956 0.939 mcc:0.900 0.489 0.415
Epoch: 112 Task: 7 auc:0.997 0.833 0.755 sn:0.899 0.480 0.440 sp:0.993 0.977 0.965 acc:0.990 0.959 0.947 mcc:0.860 0.433 0.345
EarlyStopping counter: 7 out of 100
Epoch: 113 Task: 1 auc:0.996 0.929 0.901 sn:0.960 0.811 0.762 sp:0.979 0.90

Epoch: 121 Task: 1 auc:0.998 0.933 0.901 sn:0.976 0.841 0.798 sp:0.977 0.895 0.850 acc:0.977 0.870 0.826 mcc:0.953 0.739 0.649
Epoch: 121 Task: 2 auc:0.995 0.882 0.889 sn:0.946 0.779 0.784 sp:0.977 0.847 0.829 acc:0.967 0.824 0.814 mcc:0.925 0.613 0.595
Epoch: 121 Task: 3 auc:0.996 0.887 0.895 sn:0.972 0.821 0.812 sp:0.964 0.814 0.816 acc:0.967 0.817 0.814 mcc:0.932 0.628 0.622
Epoch: 121 Task: 4 auc:0.993 0.979 0.971 sn:0.965 0.933 0.922 sp:0.957 0.957 0.957 acc:0.961 0.937 0.940 mcc:0.922 0.801 0.880
Epoch: 121 Task: 5 auc:0.996 0.729 0.887 sn:0.742 0.407 0.577 sp:1.000 0.997 1.000 acc:0.990 0.975 0.985 mcc:0.851 0.577 0.754
Epoch: 121 Task: 6 auc:0.993 0.779 0.831 sn:0.843 0.629 0.576 sp:0.996 0.947 0.949 acc:0.989 0.931 0.932 mcc:0.873 0.455 0.420
Epoch: 121 Task: 7 auc:0.998 0.820 0.756 sn:0.819 0.440 0.480 sp:0.999 0.987 0.987 acc:0.993 0.968 0.969 mcc:0.883 0.476 0.508
EarlyStopping counter: 16 out of 100
Epoch: 122 Task: 1 auc:0.998 0.918 0.891 sn:0.979 0.846 0.815 sp:0.977 0.8

Epoch: 130 Task: 1 auc:0.997 0.926 0.894 sn:0.948 0.845 0.787 sp:0.986 0.885 0.825 acc:0.968 0.866 0.807 mcc:0.937 0.731 0.613
Epoch: 130 Task: 2 auc:0.993 0.886 0.892 sn:0.931 0.806 0.826 sp:0.977 0.828 0.825 acc:0.962 0.821 0.825 mcc:0.913 0.615 0.628
Epoch: 130 Task: 3 auc:0.990 0.880 0.882 sn:0.943 0.797 0.786 sp:0.956 0.826 0.827 acc:0.951 0.814 0.810 mcc:0.898 0.620 0.611
Epoch: 130 Task: 4 auc:0.994 0.979 0.975 sn:0.958 0.974 0.974 sp:0.965 0.897 0.897 acc:0.962 0.962 0.935 mcc:0.923 0.859 0.873
Epoch: 130 Task: 5 auc:0.995 0.777 0.881 sn:0.751 0.444 0.654 sp:0.999 0.994 0.990 acc:0.990 0.974 0.978 mcc:0.854 0.566 0.669
Epoch: 130 Task: 6 auc:0.997 0.763 0.854 sn:0.708 0.457 0.485 sp:1.000 0.973 0.969 acc:0.986 0.948 0.946 mcc:0.833 0.436 0.429
Epoch: 130 Task: 7 auc:0.998 0.804 0.755 sn:0.719 0.560 0.400 sp:1.000 0.945 0.906 acc:0.990 0.932 0.889 mcc:0.840 0.357 0.184
EarlyStopping counter: 6 out of 100
Epoch: 131 Task: 1 auc:0.997 0.922 0.897 sn:0.966 0.857 0.828 sp:0.987 0.85

Epoch: 140 Task: 1 auc:0.998 0.915 0.893 sn:0.974 0.782 0.743 sp:0.980 0.914 0.903 acc:0.977 0.853 0.829 mcc:0.954 0.706 0.659
Epoch: 140 Task: 2 auc:0.995 0.877 0.877 sn:0.947 0.801 0.833 sp:0.978 0.814 0.789 acc:0.967 0.809 0.804 mcc:0.926 0.593 0.594
Epoch: 140 Task: 3 auc:0.994 0.888 0.891 sn:0.969 0.793 0.771 sp:0.951 0.822 0.838 acc:0.958 0.810 0.810 mcc:0.914 0.612 0.609
Epoch: 140 Task: 4 auc:0.995 0.982 0.984 sn:0.985 0.959 0.957 sp:0.929 0.931 0.931 acc:0.960 0.955 0.944 mcc:0.919 0.843 0.888
Epoch: 140 Task: 5 auc:0.993 0.775 0.882 sn:0.765 0.481 0.615 sp:0.999 0.986 0.980 acc:0.991 0.967 0.967 mcc:0.860 0.505 0.556
Epoch: 140 Task: 6 auc:0.997 0.818 0.808 sn:0.854 0.429 0.303 sp:0.998 0.993 0.991 acc:0.991 0.965 0.959 mcc:0.902 0.551 0.417
Epoch: 140 Task: 7 auc:0.999 0.778 0.779 sn:0.714 0.560 0.560 sp:1.000 0.939 0.916 acc:0.990 0.926 0.904 mcc:0.837 0.341 0.291
EarlyStopping counter: 16 out of 100
Epoch: 141 Task: 1 auc:0.998 0.923 0.899 sn:0.988 0.819 0.772 sp:0.967 0.9

Epoch: 149 Task: 1 auc:0.998 0.923 0.893 sn:0.979 0.816 0.765 sp:0.984 0.903 0.862 acc:0.982 0.862 0.817 mcc:0.963 0.723 0.632
Epoch: 149 Task: 2 auc:0.995 0.865 0.882 sn:0.929 0.821 0.838 sp:0.986 0.777 0.747 acc:0.967 0.791 0.777 mcc:0.925 0.570 0.554
Epoch: 149 Task: 3 auc:0.994 0.884 0.890 sn:0.970 0.772 0.771 sp:0.958 0.844 0.853 acc:0.963 0.814 0.819 mcc:0.924 0.617 0.626
Epoch: 149 Task: 4 auc:0.995 0.988 0.990 sn:0.975 0.969 0.974 sp:0.954 0.931 0.931 acc:0.966 0.963 0.953 mcc:0.930 0.868 0.906
Epoch: 149 Task: 5 auc:0.996 0.762 0.877 sn:0.582 0.444 0.500 sp:1.000 0.996 0.996 acc:0.985 0.975 0.978 mcc:0.757 0.586 0.627
Epoch: 149 Task: 6 auc:0.996 0.791 0.730 sn:0.657 0.371 0.394 sp:1.000 0.979 0.972 acc:0.983 0.949 0.944 mcc:0.801 0.397 0.371
Epoch: 149 Task: 7 auc:0.999 0.833 0.773 sn:0.734 0.520 0.520 sp:1.000 0.955 0.938 acc:0.990 0.940 0.923 mcc:0.846 0.363 0.313
EarlyStopping counter: 25 out of 100
Epoch: 150 Task: 1 auc:0.998 0.921 0.896 sn:0.962 0.889 0.847 sp:0.987 0.8

Epoch: 158 Task: 1 auc:0.999 0.925 0.896 sn:0.982 0.855 0.813 sp:0.985 0.867 0.823 acc:0.984 0.861 0.818 mcc:0.967 0.722 0.636
Epoch: 158 Task: 2 auc:0.998 0.883 0.886 sn:0.975 0.776 0.806 sp:0.982 0.844 0.841 acc:0.980 0.822 0.829 mcc:0.955 0.608 0.630
Epoch: 158 Task: 3 auc:0.998 0.891 0.899 sn:0.982 0.705 0.690 sp:0.973 0.893 0.909 acc:0.977 0.815 0.818 mcc:0.952 0.615 0.623
Epoch: 158 Task: 4 auc:0.997 0.989 0.989 sn:0.988 0.984 0.974 sp:0.954 0.897 0.897 acc:0.972 0.970 0.935 mcc:0.944 0.886 0.873
Epoch: 158 Task: 5 auc:0.996 0.752 0.910 sn:0.850 0.407 0.615 sp:0.997 0.994 0.993 acc:0.992 0.973 0.979 mcc:0.877 0.534 0.674
Epoch: 158 Task: 6 auc:0.999 0.801 0.798 sn:0.876 0.543 0.455 sp:0.999 0.961 0.946 acc:0.993 0.941 0.923 mcc:0.917 0.448 0.327
Epoch: 158 Task: 7 auc:0.999 0.850 0.750 sn:0.864 0.480 0.480 sp:1.000 0.981 0.970 acc:0.995 0.964 0.953 mcc:0.922 0.461 0.394
EarlyStopping counter: 34 out of 100
Epoch: 159 Task: 1 auc:0.999 0.927 0.902 sn:0.989 0.858 0.827 sp:0.978 0.8

Epoch: 167 Task: 1 auc:0.999 0.923 0.889 sn:0.954 0.821 0.764 sp:0.996 0.905 0.854 acc:0.976 0.866 0.812 mcc:0.952 0.731 0.622
Epoch: 167 Task: 2 auc:0.997 0.887 0.885 sn:0.977 0.801 0.799 sp:0.975 0.843 0.816 acc:0.976 0.829 0.810 mcc:0.946 0.628 0.594
Epoch: 167 Task: 3 auc:0.998 0.890 0.894 sn:0.984 0.823 0.829 sp:0.972 0.807 0.784 acc:0.977 0.813 0.803 mcc:0.953 0.623 0.605
Epoch: 167 Task: 4 auc:0.996 0.978 0.981 sn:0.979 0.959 0.966 sp:0.954 0.914 0.914 acc:0.967 0.952 0.940 mcc:0.934 0.832 0.880
Epoch: 167 Task: 5 auc:0.997 0.772 0.881 sn:0.831 0.407 0.577 sp:0.998 0.997 0.997 acc:0.992 0.975 0.982 mcc:0.883 0.577 0.705
Epoch: 167 Task: 6 auc:0.996 0.770 0.757 sn:0.821 0.486 0.333 sp:0.999 0.987 0.979 acc:0.990 0.962 0.949 mcc:0.888 0.544 0.357
Epoch: 167 Task: 7 auc:0.999 0.857 0.753 sn:0.799 0.600 0.520 sp:1.000 0.948 0.944 acc:0.993 0.936 0.929 mcc:0.888 0.391 0.328
EarlyStopping counter: 43 out of 100
Epoch: 168 Task: 1 auc:0.997 0.925 0.901 sn:0.992 0.823 0.782 sp:0.952 0.8

Epoch: 176 Task: 1 auc:0.999 0.920 0.891 sn:0.979 0.848 0.799 sp:0.989 0.879 0.832 acc:0.984 0.865 0.817 mcc:0.968 0.728 0.632
Epoch: 176 Task: 2 auc:0.997 0.870 0.879 sn:0.973 0.806 0.828 sp:0.979 0.817 0.781 acc:0.977 0.814 0.796 mcc:0.948 0.602 0.580
Epoch: 176 Task: 3 auc:0.997 0.879 0.890 sn:0.957 0.823 0.827 sp:0.983 0.796 0.783 acc:0.973 0.807 0.801 mcc:0.943 0.611 0.602
Epoch: 176 Task: 4 auc:0.997 0.990 0.990 sn:0.960 0.977 0.974 sp:0.979 0.948 0.948 acc:0.969 0.973 0.961 mcc:0.938 0.901 0.923
Epoch: 176 Task: 5 auc:0.997 0.814 0.890 sn:0.840 0.407 0.538 sp:0.999 0.997 1.000 acc:0.993 0.975 0.983 mcc:0.901 0.577 0.728
Epoch: 176 Task: 6 auc:0.999 0.788 0.785 sn:0.861 0.286 0.212 sp:1.000 1.000 0.997 acc:0.993 0.965 0.960 mcc:0.921 0.525 0.393
Epoch: 176 Task: 7 auc:1.000 0.803 0.752 sn:0.935 0.600 0.480 sp:0.999 0.961 0.955 acc:0.997 0.948 0.939 mcc:0.955 0.438 0.336
EarlyStopping counter: 52 out of 100
Epoch: 177 Task: 1 auc:0.999 0.923 0.895 sn:0.980 0.846 0.808 sp:0.989 0.8

In [None]:
# Have each early-stopping tracker reload its saved checkpoint into the network
# it was tracking. The pairs are processed in the same order as the original
# three calls: (stopper, model), (stopper_afse, amodel), (stopper_generate, gmodel).
# NOTE(review): presumably load_checkpoint restores the best-scoring weights
# saved during training — confirm against AttentiveFP.utils.EarlyStopping.
for _ckpt_stopper, _ckpt_net in (
    (stopper, model),
    (stopper_afse, amodel),
    (stopper_generate, gmodel),
):
    _ckpt_stopper.load_checkpoint(_ckpt_net)

# Evaluate the reloaded networks on test_df. `eval` is called with four
# positional arguments, so it is the notebook's own evaluation function
# shadowing the Python builtin, not the builtin itself.
(
    test_auc,
    test_sn,
    test_sp,
    test_acc,
    test_mcc,
    test_predict,
) = eval(model, amodel, gmodel, test_df)

In [None]:
# Print one line per task; every metric is shown as a "train val test" triple
# formatted to three decimals.
# NOTE(review): in the cells visible here only the test_* values are recomputed
# after the checkpoints are reloaded; `epoch` and the train_*/val_* values
# appear to be whatever the last training iteration left behind — confirm
# whether they should be re-evaluated with the restored weights as well.
_metric_columns = (
    ('auc:%.3f', train_auc, val_auc, test_auc),
    ('sn:%.3f', train_sn, val_sn, test_sn),
    ('sp:%.3f', train_sp, val_sp, test_sp),
    ('acc:%.3f', train_acc, val_acc, test_acc),
    ('mcc:%.3f', train_mcc, val_mcc, test_mcc),
)
for i in range(task_num):
    # Tasks are reported 1-based.
    _row = ['Epoch:', epoch, 'Task:', i + 1]
    for _head_fmt, _train_vals, _val_vals, _test_vals in _metric_columns:
        # Only the train value carries the metric label; val and test follow bare.
        _row.append(_head_fmt % _train_vals[i])
        _row.append('%.3f' % _val_vals[i])
        _row.append('%.3f' % _test_vals[i])
    print(*_row)

In [None]:
# Disabled scratch code: prints the input file names, dumps the collected
# per-epoch lists to a NumPy archive, then reloads the archive as a sanity check.
# print('target_file:',raw_filename[0])
# print('inactive_file:',test_filename)
# The arrays are passed to np.savez positionally, so they are stored under the
# keys arr_0, arr_1, ... in argument order.
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# np.savez appends '.npz' to a file name that lacks it, hence the suffix on load.
# sim_space = np.load(result_dir+'.npz')
# 'arr_10' is the eleventh positional array saved above, i.e. test_d_list.
# print(sim_space['arr_10'].shape)

In [None]:
# Task-specific AFSE
# Dynamic cth
# loss =  regression_loss + 0.08 * (vat_loss + test_vat_loss)
# r=3 t=2 200 100 100 1 all_lr=3e-4