In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
raw_filename = ["./data/ADMET/A/C/Pgp-inhibitor.csv", "./data/ADMET/A/C/Pgp-substrate.csv", "./data/ADMET/D/C/BBB_Penetration.csv",
               "./data/ADMET/M/C/CYP1A2_inhibitor.csv", "./data/ADMET/M/C/CYP1A2_substrate.csv", "./data/ADMET/M/C/CYP2C9_inhibitor.csv",
               "./data/ADMET/M/C/CYP2C9_substrate.csv", "./data/ADMET/M/C/CYP3A4_inhibitor.csv", "./data/ADMET/M/C/CYP3A4_substrate.csv",
               "./data/ADMET/T/C/Ames.csv",
               "./data/ADMET/T/C/Eye_Corrosion.csv", "./data/ADMET/T/C/FDAMDD.csv",
               "./data/ADMET/T/C/NR-AhR.csv", "./data/ADMET/T/C/NR-AR-LBD.csv", "./data/ADMET/T/C/NR-AR.csv",
               "./data/ADMET/T/C/NR-Aromatase.csv", "./data/ADMET/T/C/NR-ER-LBD.csv", "./data/ADMET/T/C/NR-ER.csv",
               "./data/ADMET/T/C/NR-PPAR-gamma.csv", "./data/ADMET/T/C/Skin_Sensitization.csv", "./data/ADMET/T/C/SR-ARE.csv",
               "./data/ADMET/T/C/SR-ATAD5.csv", "./data/ADMET/T/C/SR-HSE.csv", "./data/ADMET/T/C/SR-MMP.csv",
               "./data/ADMET/T/C/SR-p53.csv"]

task_id = [12,13,17,18,22,24] # 6000 < sample size < 7000
task_num = len(task_id)
raw_filename = [raw_filename[i] for i in task_id]
random_seed = 68
file_name = f'Multi_Tasks_Big'
# for i in range(task_num):
#     file_list = raw_filename[i].split('/')
#     file = '_'+file_list[-3]+'_'+file_list[-1]
#     file_name += file[:-4]
    
number = 'run_0'
model_file = "model_file/3C_GAFSE_"+file_name+'_'+number
log_dir = f'log/{"3C_GAFSE_"+file_name}_'+number
result_dir = './result/3C_GAFSE_'+file_name+'_'+number
print(raw_filename)
print(file_name)
print(model_file)

['./data/ADMET/T/C/NR-AhR.csv', './data/ADMET/T/C/NR-AR-LBD.csv', './data/ADMET/T/C/NR-ER.csv', './data/ADMET/T/C/NR-PPAR-gamma.csv', './data/ADMET/T/C/SR-HSE.csv', './data/ADMET/T/C/SR-p53.csv']
Multi_Tasks_Big
model_file/3C_GAFSE_Multi_Tasks_Big_run_0


In [4]:
tasks = ['value']
total_df = pd.DataFrame([])

for i in range(task_num):
    task_df = pd.read_csv(raw_filename[i], header=0, names = ["smiles", "dataset", "value"],usecols=[0,1,2])
    task_df["task_id"] = i
    total_df = pd.concat([total_df,task_df])
    
print(total_df[:3],total_df[-3:])

def add_canonical_smiles(total_df):
    smilesList = total_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    atom_num_dist = []
    remained_smiles = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:        
            mol = Chem.MolFromSmiles(smiles)
            atom_num_dist.append(len(mol.GetAtoms()))
            remained_smiles.append(smiles)
            canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
        except:
            print(smiles)
            pass
    print("number of successfully processed smiles: ", len(remained_smiles))
    total_df = total_df[total_df["smiles"].isin(remained_smiles)]
    total_df['cano_smiles'] =canonical_smiles_list
    return total_df

total_df = add_canonical_smiles(total_df)
print(total_df.head())

                                              smiles   dataset  value  task_id
0             O=C(O)[C@H](O)[C@@H](O)[C@H](O)C(=O)CO  training    0.0        0
1          O=C(O)[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO  training    0.0        0
2  O=C(O[C@@H]1[C@]2(C)[C@H]([C@H]3[C@@H](c4c(cc(...  training    0.0        0                         smiles dataset  value  task_id
6912                    BrCCCO     val    0.0        5
6913  c12c3c4ccc1cccc2ccc3ccc4     val    0.0        5
6914      S(P(SCCC)(=O)OCC)CCC     val    0.0        5
number of all smiles:  39451
number of successfully processed smiles:  39451
                                              smiles   dataset  value  \
0             O=C(O)[C@H](O)[C@@H](O)[C@H](O)C(=O)CO  training    0.0   
1          O=C(O)[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO  training    0.0   
2  O=C(O[C@@H]1[C@]2(C)[C@H]([C@H]3[C@@H](c4c(cc(...  training    0.0   
3   S(C#N)CC(=O)O[C@H]1[C@@]2(C)C(C)(C)[C@@H](C1)CC2  training    0.0   
4                     

In [5]:
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 3 # default: 2
T = 2
per_task_output_units_num = 1 # for regression model
output_units_num = task_num * per_task_output_units_num

In [6]:
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
feature_filename = './features/'+model_file.split('/')[-1][:-1]+'0.pickle'
filename = './features/'+model_file.split('/')[-1]
print(feature_filename)
if os.path.isfile(feature_filename):
    feature_dicts = pickle.load(open(feature_filename, "rb" ))
    print('Loading features successfully.')
else:
    feature_dicts = save_smiles_dicts(total_smilesList,filename)

39451
./features/3C_GAFSE_Multi_Tasks_Big_run_0.pickle
Loading features successfully.


In [7]:
test_df = total_df[total_df.dataset.values == "test"]
test_df = test_df[test_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
test_df = test_df.reset_index(drop=True)

val_df = total_df[total_df.dataset.values == "val"]
val_df = val_df[val_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
val_df = val_df.reset_index(drop=True)

train_df = total_df[total_df.dataset.values == "training"]
train_df = train_df[train_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
train_df = train_df.reset_index(drop=True)

print(total_df.shape, len(train_df)+len(val_df)+len(test_df), train_df.shape,val_df.shape,test_df.shape)

(39451, 5) 39445 (31577, 5) (3924, 5) (3944, 5)


In [8]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

optimizer = optim.Adam(params=model.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [9]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    p_value = dataset[tasks[0]].astype(float).tolist()
    x = np.arange(0,len(dataset),1)
#     print('plt',dataset.head(),p[:10],k_predict,k)
#     plt.figure()
#     fig, ax1 = plt.subplots()
#     ax1.grid(False)
#     ax2 = ax1.twinx()
#     plt.grid(False)
    plt.scatter(x,p,marker='.',s=6,color='r',label='predict')
#     plt.ylabel('predict')
    plt.scatter(x,p_value,s=6,marker=',',color='blue',label='p_value')
    plt.axvline(x=k-1,ls="-",c="black")#添加垂直直线
    k_value = np.ones(len(dataset))
# #     print(EC50[k-1])
    k_value = k_value*k_predict
    plt.plot(x,k_value,'-',color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
    
    df3 = df2[:k]  #取按预测值排完序后的前k个
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    k_true = true_sort[tasks[0]].values[k-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('k_true: ',type(k_true),k_true)
#     print('k_true: ',k_true,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=k_true,'acc_num: ',len(df3[df3['predict'].values>=k_true]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k)
    acc = len(df3[df3[tasks[0]].values>=k_true])/k #预测值前k个中真实排在前k个的个数/k
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/k
    if k>active_num:
        min_active = true_sort[tasks[0]].values[active_num-1]
        acc = len(df3[df3[tasks[0]].values>=min_active])/k
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
        
    df3 = df2[:k]  #取按预测值排完序后的前k个，因为后面的全是-4.1
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    min_active = true_sort[tasks[0]].values[active_num-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('min_active: ',type(min_active),min_active)
#     print('min_active: ',min_active,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=min_active,'acc_num: ',len(df3[df3['predict'].values>=min_active]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k,'active_num: ',active_num)
    acc = len(df3[df3[tasks[0]].values>-4.1])/active_num #预测值前k个中真实排在前active_num个的个数/active_num
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/active_num
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    if k>active_num:
        return topk_recall(df, predict, k, active_num, show_flag, i)
    return topk_acc2(df,predict,k, active_num,show_flag,i)

def weighted_top_index(df, predict, active_num):
    weighted_acc_list=[]
    for k in np.arange(1,len(df)+1,1):
        acc, fp = topk_acc_recall(df, predict, k, active_num)
        weight = (len(df)-k)/len(df)
#         print('weight=',weight,'acc=',acc)
        weighted_acc_list.append(acc*weight)#
    weighted_acc_list = np.array(weighted_acc_list)
#     print('weighted_acc_list=',weighted_acc_list)
    return np.sum(weighted_acc_list)/weighted_acc_list.shape[0]

def AP(df, predict, active_num):
    prec = []
    rec = []
    for k in np.arange(1,len(df)+1,1):
        prec_k, fp1 = topk_acc2(df,predict,k, active_num)
        rec_k, fp2 = topk_recall(df, predict, k, active_num)
        prec.append(prec_k)
        rec.append(rec_k)
    # 取所有不同的recall对应的点处的精度值做平均
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # 计算包络线，从后往前取最大保证precise非减
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # 找出所有检测结果中recall不同的点
    i = np.where(mrec[1:] != mrec[:-1])[0]
#     print(prec)
#     print('prec='+str(prec)+'\n\n'+'rec='+str(rec))

    # and sum (\Delta recall) * prec
    # 用recall的间隔对精度作加权平均
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

In [10]:
def caculate_r2(predict,y):
#     print(y)
#     print(predict)
    y = torch.FloatTensor(y).reshape(-1,1)
    predict = torch.FloatTensor(predict).reshape(-1,1)
    y_mean = torch.mean(y)
    predict_mean = torch.mean(predict)
    
#     print(len(y), y_mean, len(predict), predict_mean)
    y1 = torch.pow(torch.mm((y-y_mean).t(),(predict-predict_mean)),2)
    y2 = torch.mm((y-y_mean).t(),(y-y_mean))*torch.mm((predict-predict_mean).t(),(predict-predict_mean))
#     print(y1,y2)
    return y1/(y2+1e-9)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score, roc_auc_score
def calc(TN, FP, FN, TP):
    SN = TP / (TP + FN)  # recall
    SP = TN / (TN + FP)
    # Precision = TP / (TP + FP)
    ACC = (TP + TN) / (TP + TN + FN + FP)
    # F1 = (2 * TP) / (2 * TP + FP + FN)
    fz = TP * TN - FP * FN
    fm = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN)
    MCC = fz / (pow(fm, 0.5)+1e-9)
    return SN, SP, ACC, MCC

In [11]:
from torch.autograd import Variable
def l2_norm(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    output = torch.div(input, norm+1e-6)
    return output

def normalize_perturbation(d,dim=-1):
    output = l2_norm(d, dim)
    return output

def tanh(x):
    return (torch.exp(x)-torch.exp(-x))/(torch.exp(x)+torch.exp(-x))

def sigmoid(x):
    return 1/(1+torch.exp(-x))

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None, task_id=None, sigmoid=False):
    mol_prediction = model(feature=f, d=0, sigmoid=sigmoid)
    pred = mol_prediction.detach()
    task_counter = 0
    vat_loss = 0
    for i in range(task_num):
        batch_task_sample = len(task_id[task_id==i])
        if batch_task_sample > 0:
            task_counter += 1
            y_mask = np.where(task_id==i, 1, 0)
            y_mask = torch.Tensor(y_mask)
        #     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
            eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps = Variable(eps, requires_grad=True)
            # Predict on randomly perturbed image
            eps_p = model(feature=f, d=eps.cuda(), sigmoid=sigmoid)
            eps_p_ = model(feature=f, d=-eps.cuda(), sigmoid=sigmoid)
            p_aux = nn.Sigmoid()(eps_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_aux_ = nn.Sigmoid()(eps_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
        #     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
            loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
            loss.backward(retain_graph=True)

            # Based on perturbed image, get direction of greatest error
            eps_adv = eps.grad#/10**-learning_rate
            optimizer_AFSE.zero_grad()
            # Use that direction as adversarial perturbation
            eps_adv_normed = normalize_perturbation(eps_adv)
            d_adv = lamda * eps_adv_normed.cuda()
            f_p, max_lr = model(feature=f, d=d_adv, output_lr=True, sigmoid=sigmoid)
            f_p = model(feature=f, d=d_adv, sigmoid=sigmoid)
            f_p_ = model(feature=f, d=-d_adv, sigmoid=sigmoid)
            p = nn.Sigmoid()(f_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_ = nn.Sigmoid()(f_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            vat_loss += loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    vat_loss /= task_counter
    if output_lr:
        if output_plr:
            loss = 0
            task_counter = 0
            eps_ = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps_ = Variable(eps_, requires_grad=True)
            eps_p__ = model(feature=f+eps_.cuda(), d=0)
            for i in range(task_num):
                batch_task_sample = len(task_id[task_id==i])
                if batch_task_sample > 0:
                    task_counter += 1
                    y_mask = np.where(task_id==i, 1, 0)
                    y_mask = torch.Tensor(y_mask)
                    loss += loss_function(eps_p__[:,i]*y_mask,y.view(-1)*y_mask)
            loss /= task_counter
            loss.backward(retain_graph=True)
            punish_lr = torch.norm(torch.mean(eps_.grad,0))
            optimizer_AFSE.zero_grad()
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol

def d_loss(f, pred, model, y_val):
    diff_loss = 0
    length = len(pred)
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    diff_loss = diff_loss/(length*(length-1))
    return diff_loss

In [12]:
def CE(x,y):
    c = 0
    l = len(y)
    for i in range(l):
        if y[i]==1:
            c += 1
    w1 = (l-c)/l
    w0 = c/l
    loss = -w1*y*torch.log(x+1e-6)-w0*(1-y)*torch.log(1-x+1e-6)
    loss = loss.mean(-1)
    return loss

def weighted_CE_loss(x,y):
    weight = 1/(y.detach().float().mean(0)+1e-9)
    weighted_CE = nn.CrossEntropyLoss(weight=weight)
#     atom_weights = (atom_weights-min(atom_weights))/(max(atom_weights)-min(atom_weights))
    return weighted_CE(x, torch.argmax(y,-1))

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduction='none')
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    loss = torch.mean(loss)
    return loss

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)
        # print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(int)
        task_id = batch_df['task_id'].values
        task_id_test = batch_test['task_id'].values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
#         refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
#                                                   torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
#                                                   mol_feature=mol_feature,activated_features=activated_features.detach())
        
#         x_atom = torch.Tensor(x_atom)
#         x_bonds = torch.Tensor(x_bonds)
#         x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
#         bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
#         bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1,                                                                                lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1),
                                                                                       task_id=task_id, sigmoid=True) # 10**-learning_rate     
        classification_loss = 0
        task_counter = 0
        for i in range(task_num):
            batch_task_sample = len(batch_df[batch_df["task_id"].values==i])
            if batch_task_sample > 0:
                task_counter += 1
                y_mask = np.where(batch_df["task_id"].values==i, 1, 0)
                y_mask = torch.Tensor(y_mask)
                c_loss = - torch.Tensor(y_val) * torch.log(mol_prediction[:,i]+1e-9) - \
                                (1-torch.Tensor(y_val)) * torch.log((1-mol_prediction[:,i])+1e-9)
                classification_loss += torch.sum(c_loss * y_mask)/batch_task_sample
        classification_loss /= task_counter
#         atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
#                                       torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
#         success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
#                             bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
#                                                      refer_atom_list, refer_bond_list,topn=1)
#         reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
#                                                                                               bond_neighbor, validity_mask, atom_list, 
#                                                                                               bond_list)
#         x_atom_test = torch.Tensor(x_atom_test)
#         x_bonds_test = torch.Tensor(x_bonds_test)
#         x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
#         bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
#         bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate, task_id=task_id_test, sigmoid=True)
#         atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
#                                                 torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                 mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
#         refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
#                                                             torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                             mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
#         success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
#                             bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
#                                                      refer_atom_list_test, refer_bond_list_test,topn=1)
#         test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
            
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        max_lr = 1e-3
        adapt_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr # 0.06
        if adapt_lr < max_lr and adapt_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = adapt_lr.detach()
                AFSE_lr = adapt_lr    
        elif adapt_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif adapt_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
#         AFSE_lr = 1e-4

        logger.add_scalar('loss/classification', classification_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
#         logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
#         logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
#         logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/conv_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
        
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  classification_loss + 0.08 * (vat_loss + test_vat_loss) # + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    [a.ClearProp('molAtomMapNumber') for a  in mol.GetAtoms()]
    return mol

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0
        top_idx = 0
        flag = 0
        first_run_flag = True
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
#             print(f'atom#{smiles_to_rdkit_list[cano_smiles][j]}(f):',{symbol_list[k]: np.around(refer_atom_list[i,j,k],3) for k in range(16)},
#                   f'\natom#{smiles_to_rdkit_list[cano_smiles][j]}(f+d):',{symbol_list[k]: np.around(atom_list[i,j,k],3) for k in range(16)},
#                  '\n------------------------------------------------------------------------------------------------------------')
#         print('预测为每个原子的平均概率：\n',np.around(atom_list[i,:l,:16].mean(1),2))
#         print('预测为每个原子的最大概率：\n',np.around(atom_list[i,:l,:16].max(1),2))
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            if rank == 16:
                rank = 0
                top_idx += 1
            if top_idx == l:
#                 print('没有满足条件的分子生成。')
                flag += 1
                continue
#             if np.sum((atom_symbol_sorted[i,:l,-1]!=atom_symbol_generated_sorted[i,:l,-1-rank]).astype(int))==0:
#                 print(f'根据预测的第{rank}大概率的原子构成的分子与原分子一致，原子位重置为0，生成下一个元素……')
#                 rank += 1
#                 top_idx = 0
#                 generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
#                                              x_atom[i,:l,:16]).max(-1))[-1-top_idx]
#             print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
#                 print('生成了相同元素，生成下一个元素……')
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
#                 print(f'原子位{generate_rdkit_index}生成{symbol_list[atom_symbol_generated]}元素的置信度小于{confidence_threshold}，寻找下一个原子位……')
                top_idx += 1
                rank = 0
                continue
#             if symbol_to_rdkit[atom_symbol_generated]==6:
#                 print('生成了不推荐的C元素')
#                 rank += 1
#                 continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
#                     print('Congratulations!', success, total)
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
#                 print("修改前的分子：", smiles[i])
#                 display(mol_init)
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 print(f"将第{generate_rdkit_index}个原子修改为{symbol_list[atom_symbol_generated]}的分子：", Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 display(mol_with_atom_index(mol))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
#                 print("高活性分子：", y_smiles[i])
#                 display(mol_y)
                rank += 1
                flag += 1
            except:
#                 print(f"第{generate_rdkit_index}个原子符号修改为{symbol_list[atom_symbol_generated]}不符合规范，生成下一个元素……")
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, cth_list=[0.5 for i in range(task_num)], topn=1, output_feature=False, 
         generate=False, modify_atom=True,return_GRN_loss=False, viz=False, output_cth=False):
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    test_MAE_list = []
    r2_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    feature_list = []
    d_list = []
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]
    
# #     取dataset中排序后的第k个
#     sorted_dataset = dataset.sort_values(by=tasks[0],ascending=False)
#     k_df = sorted_dataset.iloc[[k-1]]
#     k_smiles = k_df['cano_smiles'].values
#     k_value = k_df[tasks[0]].values.astype(float)    
    
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
#     print(batch_list)
    for counter, batch in enumerate(batch_list):
#         print(type(batch))
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
#         print(batch_df)
        y_val = batch_df[tasks[0]].values.astype(int)
        task_id = batch_df['task_id'].values
#         print(type(y_val))
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
#         x_atom = torch.Tensor(x_atom)
#         x_bonds = torch.Tensor(x_bonds)
#         x_bond_index = torch.cuda.LongTensor(x_bond_index)
#         bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
#         bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda, task_id=task_id, sigmoid=True)
        mol_prediction = mol_prediction.cpu().detach().numpy()
        mol_prediction_readout = mol_prediction[:,0]
        for i in range(task_num):
            mol_prediction_readout[batch_df['task_id'].values==i] = mol_prediction[batch_df['task_id'].values==i, i]
#         print(mol_feature,d_adv)
#         atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
#                                       torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
#                                       torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
#         refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
#         if generate:
#             if modify_atom:
#                 success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
#                             bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
#                                                      refer_atom_list, refer_bond_list,topn=topn,viz=viz)
#             else:
#                 modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
#             generated_smiles.extend(modified_smiles)
#             success_smiles.extend(success_smiles_batch)
#             for n in range(topn):
#                 success[n] += success_batch[n]
#                 total[n] += total_batch[n]
#                 print('congratulations:',success,total)
#             success_reconstruction += reconstruction
#             success_validity += validity
#             reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
        d = d_adv.cpu().detach().numpy().tolist()
        d_list.extend(d)
        mol_feature_output = mol_feature.cpu().detach().numpy().tolist()
        feature_list.extend(mol_feature_output)
#         MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')   
#         print(type(mol_prediction))
#         r2 = caculate_r2(mol_prediction, torch.Tensor(y_val).view(-1,1))
# #         r2_list.extend(r2.cpu().detach().numpy())
#         if r2!=r2:
#             r2 = torch.tensor(0)
#         r2_list.append(r2.item())
#         predict_list.extend(mol_prediction.cpu().detach().numpy())
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        predict_list.extend(mol_prediction_readout)
#         test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
    
    predict_list = np.array(predict_list)
    task_lens_list = []
    task_accumulated = 0
    for i in range(task_num):
        task_accumulated += len(dataset[dataset["task_id"]==i])
        task_lens_list.append(task_accumulated)
    task_predict_list = np.split(predict_list, task_lens_list)
    task_y_val_list = np.split(dataset[tasks[0]].values.astype(float), task_lens_list)
    
    auc_list = []
    if output_cth:
        best_cth_list = [0.5 for i in range(task_num)]
        best_sn_list = [0 for i in range(task_num)]
        best_sp_list = [0 for i in range(task_num)]
        best_acc_list = [0 for i in range(task_num)]
        best_mcc_list = [0 for i in range(task_num)]
        for cth in np.linspace(0,1,21):
            for i in range(task_num):
                task_predict = task_predict_list[i]
                task_y_val = task_y_val_list[i]
                if cth == 0:
                    auc = roc_auc_score(task_y_val, task_predict)
                    auc_list.append(auc)
                class_pred = np.where(task_predict>cth,1,0).astype(int)
                tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
                sn, sp, acc, mcc = calc(tn, fp, fn, tp)
                mean_index = (sn + sp + acc + mcc)/4
                best_index = (best_sn_list[i] + best_sp_list[i] + best_acc_list[i] + best_mcc_list[i])/4
                if mean_index > best_index:
                    best_cth_list[i] = cth
                    best_sn_list[i] = sn
                    best_sp_list[i] = sp
                    best_acc_list[i] = acc
                    best_mcc_list[i] = mcc
        sn_list = best_sn_list
        sp_list = best_sp_list
        acc_list = best_acc_list
        mcc_list = best_mcc_list
    else:
        sn_list = []
        sp_list = []
        acc_list = []
        mcc_list = []
        for i in range(task_num):
            task_predict = task_predict_list[i]
            task_y_val = task_y_val_list[i]
            auc = roc_auc_score(task_y_val, task_predict)
            class_pred = np.where(task_predict>cth_list[i],1,0).astype(int)
            tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
            sn, sp, acc, mcc = calc(tn, fp, fn, tp)
            auc_list.append(auc)
            sn_list.append(sn)
            sp_list.append(sp)
            acc_list.append(acc)
            mcc_list.append(mcc)

#     print(r2_list)
#     if generate:
#         generated_num = len(generated_smiles)
#         eval_num = len(dataset)
#         unique = generated_num
#         novelty = generated_num
#         for i in range(generated_num):
#             for j in range(generated_num-i-1):
#                 if generated_smiles[i]==generated_smiles[i+j+1]:
#                     unique -= 1
#             for k in range(eval_num):
#                 if generated_smiles[i]==dataset['smiles'].values[k]:
#                     novelty -= 1
#         unique_rate = unique/(generated_num+1e-9)
#         novelty_rate = novelty/(generated_num+1e-9)
# #         print(f'successfully/total generated molecules =', {f'Top-{i+1}': f'{success[i]}/{total[i]}' for i in range(topn)})
#         return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
#     if return_GRN_loss:
#         return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),np.array(test_MAE_list).mean(),predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list, auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list
    if output_cth:
        return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list, best_cth_list
    return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list

epoch = 0
max_epoch = 1000
batch_size = 10
patience = 100
stopper = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_model.pth')
stopper_afse = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_amodel.pth')
stopper_generate = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_gmodel.pth')

In [13]:
import datetime
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
if os.path.isdir(log_dir):
    for files in os.listdir(log_dir):
        os.remove(log_dir+"/"+files)
    os.rmdir(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

log/3C_GAFSE_Multi_Tasks_Big_run_0


In [None]:
# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = patience - patience
# stopper.load_checkpoint(model)
# stopper_afse.load_checkpoint(amodel)
# stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
    train_auc, train_sn, train_sp, train_acc, train_mcc, train_predict = eval(model, amodel, gmodel, train_df)
    val_auc, val_sn, val_sp, val_acc, val_mcc, val_predict, val_cth_list = eval(model, amodel, gmodel, val_df, output_cth= True)
    test_auc, test_sn, test_sp, test_acc, test_mcc, test_predict = eval(model, amodel, gmodel, test_df, cth_list=val_cth_list)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    for i in range(task_num):
        logger.add_scalar(f'train/T{i}_auc', train_auc[i], global_step)
        logger.add_scalar(f'train/T{i}_sn', train_sn[i], global_step)
        logger.add_scalar(f'train/T{i}_sp', train_sp[i], global_step)
        logger.add_scalar(f'train/T{i}_acc', train_acc[i], global_step)
        logger.add_scalar(f'train/T{i}_mcc', train_mcc[i], global_step)
        logger.add_scalar(f'val/T{i}_cth', val_cth_list[i], global_step)
        logger.add_scalar(f'val/T{i}_auc', val_auc[i], global_step)
        logger.add_scalar(f'val/T{i}_sn', val_sn[i], global_step)
        logger.add_scalar(f'val/T{i}_sp', val_sp[i], global_step)
        logger.add_scalar(f'val/T{i}_acc', val_acc[i], global_step)
        logger.add_scalar(f'val/T{i}_mcc', val_mcc[i], global_step)
        logger.add_scalar(f'test/T{i}_auc', test_auc[i], global_step)
        logger.add_scalar(f'test/T{i}_sn', test_sn[i], global_step)
        logger.add_scalar(f'test/T{i}_sp', test_sp[i], global_step)
        logger.add_scalar(f'test/T{i}_acc', test_acc[i], global_step)
        logger.add_scalar(f'test/T{i}_mcc', test_mcc[i], global_step)
#         logger.add_scalar(f'val/GRN', reconstruction_loss, global_step)
#         logger.add_scalar(f'val/GRN_one_hot', one_hot_loss, global_step)
#         logger.add_scalar(f'val/GRN_interger', interger_loss, global_step)
#         logger.add_scalar(f'val/GRN_binary', binary_loss, global_step)
        # logger.add_scalar('test/EF0.01', topk_list[0], global_step)
        # logger.add_scalar('test/EF0.03', topk_list[1], global_step)
        # logger.add_scalar('test/EF0.1', topk_list[2], global_step)
        # logger.add_scalar('test/EF10', topk_list[3], global_step)
        # logger.add_scalar('test/EF30', topk_list[4], global_step)
        # logger.add_scalar('test/EF100', topk_list[5], global_step)

    #     train_mse_list.append(train_MSE**0.5)
    #     train_r2_list.append(train_r2)
    #     val_mse_list.append(val_MSE**0.5)  
    #     val_r2_list.append(val_r2)
    #     train_f_list.append(train_f)
    #     val_f_list.append(val_f)
    #     test_f_list.append(test_f)
    #     epoch_list.append(epoch)
    #     train_predict_list.append(train_predict.flatten())
    #     test_predict_list.append(test_predict.flatten())
    #     val_predict_list.append(val_predict.flatten())
    #     train_y_list.append(train_df[tasks[0]].values)
    #     val_y_list.append(val_df[tasks[0]].values)
    #     test_y_list.append(test_df[tasks[0]].values)
    #     train_d_list.append(train_d)
    #     val_d_list.append(val_d)
    #     test_d_list.append(test_d)
    #     print('epoch {:d}/{:d}, validation {} {:.4f}, {} {:.4f},best validation {r2} {:.4f}'.format(epoch, total_epoch, 'r2', val_r2, 'mse:',val_MSE, stopper.best_score))
        print('Epoch:',epoch, 'Task:', i+1,
              'auc:%.3f'%train_auc[i],'%.3f'%val_auc[i],'%.3f'%test_auc[i], 
              'sn:%.3f'%train_sn[i],'%.3f'%val_sn[i],'%.3f'%test_sn[i], 
              'sp:%.3f'%train_sp[i], '%.3f'%val_sp[i], '%.3f'%test_sp[i], 
              'acc:%.3f'%train_acc[i], '%.3f'%val_acc[i], '%.3f'%test_acc[i], 
              'mcc:%.3f'%train_mcc[i],'%.3f'%val_mcc[i],'%.3f'%test_mcc[i])
    
    stop_index = np.mean(val_auc) +  np.mean(val_sn) +  np.mean(val_sp) +  np.mean(val_acc) +  np.mean(val_mcc)
    early_stop = stopper.step(stop_index, model)
    early_stop = stopper_afse.step(stop_index, amodel, if_print=False)
    early_stop = stopper_generate.step(stop_index, gmodel, if_print=False)
    
    if early_stop:
        continue


Epoch: 1 Task: 1 auc:0.836 0.815 0.839 sn:0.390 0.553 0.506 sp:0.935 0.885 0.884 acc:0.872 0.847 0.840 mcc:0.342 0.376 0.342
Epoch: 1 Task: 2 auc:0.880 0.783 0.738 sn:0.000 0.500 0.435 sp:1.000 0.957 0.955 acc:0.966 0.941 0.938 mcc:0.000 0.359 0.299
Epoch: 1 Task: 3 auc:0.706 0.695 0.728 sn:0.000 0.209 0.273 sp:1.000 0.956 0.955 acc:0.891 0.875 0.881 mcc:0.000 0.214 0.276
Epoch: 1 Task: 4 auc:0.701 0.780 0.652 sn:0.000 0.400 0.158 sp:1.000 0.923 0.930 acc:0.970 0.907 0.908 mcc:0.000 0.197 0.057
Epoch: 1 Task: 5 auc:0.651 0.714 0.680 sn:0.000 0.222 0.229 sp:1.000 0.919 0.913 acc:0.943 0.879 0.875 mcc:0.000 0.115 0.110
Epoch: 1 Task: 6 auc:0.725 0.767 0.716 sn:0.000 0.478 0.391 sp:1.000 0.855 0.839 acc:0.934 0.829 0.809 mcc:0.000 0.224 0.151
Epoch: 2 Task: 1 auc:0.851 0.805 0.838 sn:0.011 0.421 0.377 sp:1.000 0.943 0.945 acc:0.886 0.883 0.878 mcc:0.101 0.391 0.356
Epoch: 2 Task: 2 auc:0.878 0.804 0.779 sn:0.608 0.500 0.391 sp:0.990 0.991 0.985 acc:0.977 0.974 0.965 mcc:0.633 0.564 0.413


Epoch: 12 Task: 1 auc:0.926 0.867 0.912 sn:0.549 0.618 0.701 sp:0.966 0.902 0.895 acc:0.918 0.870 0.872 mcc:0.565 0.456 0.505
Epoch: 12 Task: 2 auc:0.932 0.899 0.869 sn:0.591 0.708 0.565 sp:0.997 0.982 0.965 acc:0.983 0.972 0.952 mcc:0.705 0.630 0.428
Epoch: 12 Task: 3 auc:0.792 0.770 0.804 sn:0.108 0.448 0.455 sp:0.997 0.971 0.947 acc:0.901 0.914 0.894 mcc:0.280 0.496 0.422
Epoch: 12 Task: 4 auc:0.864 0.855 0.827 sn:0.165 0.850 0.632 sp:0.997 0.759 0.779 acc:0.972 0.762 0.775 mcc:0.322 0.238 0.162
Epoch: 12 Task: 5 auc:0.838 0.805 0.888 sn:0.208 0.722 0.743 sp:0.997 0.877 0.827 acc:0.952 0.868 0.822 mcc:0.393 0.382 0.323
Epoch: 12 Task: 6 auc:0.845 0.828 0.849 sn:0.110 0.522 0.500 sp:0.996 0.885 0.906 acc:0.938 0.860 0.879 mcc:0.256 0.292 0.309
Epoch: 13 Task: 1 auc:0.924 0.857 0.907 sn:0.500 0.487 0.558 sp:0.969 0.954 0.941 acc:0.915 0.900 0.896 mcc:0.536 0.475 0.500
Epoch: 13 Task: 2 auc:0.926 0.905 0.873 sn:0.151 0.750 0.609 sp:0.999 0.954 0.953 acc:0.971 0.947 0.942 mcc:0.356 0.50

Epoch: 23 Task: 1 auc:0.940 0.882 0.917 sn:0.707 0.658 0.688 sp:0.942 0.916 0.909 acc:0.914 0.886 0.883 mcc:0.609 0.513 0.522
Epoch: 23 Task: 2 auc:0.957 0.930 0.846 sn:0.683 0.750 0.565 sp:0.994 0.982 0.970 acc:0.983 0.974 0.956 mcc:0.728 0.657 0.450
Epoch: 23 Task: 3 auc:0.838 0.784 0.797 sn:0.336 0.418 0.455 sp:0.985 0.964 0.947 acc:0.914 0.904 0.894 mcc:0.456 0.443 0.422
Epoch: 23 Task: 4 auc:0.920 0.909 0.856 sn:0.215 0.850 0.737 sp:0.996 0.840 0.832 acc:0.973 0.841 0.829 mcc:0.365 0.308 0.245
Epoch: 23 Task: 5 auc:0.876 0.815 0.879 sn:0.374 0.722 0.743 sp:0.991 0.864 0.824 acc:0.956 0.856 0.819 mcc:0.497 0.362 0.320
Epoch: 23 Task: 6 auc:0.889 0.920 0.878 sn:0.382 0.891 0.717 sp:0.971 0.839 0.841 acc:0.933 0.843 0.832 mcc:0.394 0.450 0.350
Epoch: 24 Task: 1 auc:0.940 0.879 0.913 sn:0.643 0.658 0.675 sp:0.957 0.911 0.907 acc:0.921 0.882 0.880 mcc:0.607 0.502 0.509
Epoch: 24 Task: 2 auc:0.957 0.927 0.862 sn:0.548 0.833 0.609 sp:0.999 0.982 0.967 acc:0.983 0.976 0.955 mcc:0.710 0.71

Epoch: 34 Task: 1 auc:0.955 0.864 0.924 sn:0.503 0.658 0.753 sp:0.987 0.906 0.895 acc:0.931 0.877 0.878 mcc:0.615 0.492 0.541
Epoch: 34 Task: 2 auc:0.967 0.944 0.893 sn:0.570 0.833 0.696 sp:0.999 0.974 0.959 acc:0.984 0.969 0.951 mcc:0.724 0.657 0.486
Epoch: 34 Task: 3 auc:0.851 0.767 0.800 sn:0.233 0.388 0.470 sp:0.997 0.964 0.945 acc:0.914 0.901 0.894 mcc:0.435 0.416 0.430
Epoch: 34 Task: 4 auc:0.946 0.952 0.870 sn:0.184 0.650 0.632 sp:1.000 0.961 0.953 acc:0.975 0.951 0.944 mcc:0.408 0.450 0.401
Epoch: 34 Task: 5 auc:0.902 0.827 0.892 sn:0.388 0.528 0.543 sp:0.993 0.968 0.958 acc:0.959 0.943 0.935 mcc:0.534 0.483 0.450
Epoch: 34 Task: 6 auc:0.928 0.902 0.872 sn:0.310 0.761 0.630 sp:0.993 0.864 0.889 acc:0.948 0.857 0.871 mcc:0.467 0.411 0.366
EarlyStopping counter: 2 out of 100
Epoch: 35 Task: 1 auc:0.952 0.876 0.929 sn:0.639 0.566 0.636 sp:0.973 0.942 0.941 acc:0.935 0.898 0.906 mcc:0.660 0.505 0.559
Epoch: 35 Task: 2 auc:0.963 0.921 0.855 sn:0.688 0.583 0.565 sp:0.997 1.000 0.994 

Epoch: 45 Task: 1 auc:0.971 0.883 0.921 sn:0.805 0.658 0.688 sp:0.970 0.937 0.931 acc:0.951 0.905 0.903 mcc:0.762 0.561 0.571
Epoch: 45 Task: 2 auc:0.977 0.954 0.857 sn:0.710 0.792 0.652 sp:0.996 0.980 0.961 acc:0.986 0.974 0.951 mcc:0.776 0.673 0.466
Epoch: 45 Task: 3 auc:0.895 0.808 0.808 sn:0.494 0.567 0.576 sp:0.981 0.911 0.895 acc:0.928 0.873 0.860 mcc:0.576 0.427 0.401
Epoch: 45 Task: 4 auc:0.960 0.966 0.911 sn:0.291 0.600 0.316 sp:0.999 0.991 0.992 acc:0.978 0.979 0.973 mcc:0.516 0.622 0.402
Epoch: 45 Task: 5 auc:0.932 0.837 0.909 sn:0.550 0.583 0.629 sp:0.991 0.949 0.961 acc:0.966 0.929 0.943 mcc:0.643 0.453 0.525
Epoch: 45 Task: 6 auc:0.956 0.901 0.866 sn:0.503 0.739 0.674 sp:0.992 0.897 0.890 acc:0.960 0.887 0.876 mcc:0.622 0.453 0.396
Epoch: 46 Task: 1 auc:0.957 0.873 0.907 sn:0.467 0.553 0.610 sp:0.994 0.971 0.953 acc:0.933 0.923 0.913 mcc:0.624 0.586 0.574
Epoch: 46 Task: 2 auc:0.970 0.947 0.848 sn:0.726 0.875 0.739 sp:0.996 0.965 0.952 acc:0.987 0.962 0.945 mcc:0.786 0.63

Epoch: 56 Task: 1 auc:0.982 0.882 0.923 sn:0.679 0.776 0.844 sp:0.991 0.875 0.869 acc:0.955 0.864 0.866 mcc:0.764 0.520 0.559
Epoch: 56 Task: 2 auc:0.974 0.831 0.912 sn:0.634 0.542 0.609 sp:0.999 0.998 0.979 acc:0.987 0.982 0.967 mcc:0.781 0.702 0.535
Epoch: 56 Task: 3 auc:0.927 0.799 0.821 sn:0.675 0.388 0.424 sp:0.959 0.971 0.971 acc:0.928 0.907 0.912 mcc:0.632 0.443 0.475
Epoch: 56 Task: 4 auc:0.969 0.955 0.904 sn:0.430 0.800 0.526 sp:0.999 0.972 0.961 acc:0.982 0.967 0.949 mcc:0.631 0.599 0.364
Epoch: 56 Task: 5 auc:0.948 0.821 0.852 sn:0.481 0.611 0.600 sp:0.996 0.912 0.903 acc:0.967 0.895 0.886 mcc:0.641 0.377 0.348
Epoch: 56 Task: 6 auc:0.960 0.862 0.803 sn:0.536 0.435 0.457 sp:0.992 0.962 0.966 acc:0.962 0.926 0.932 mcc:0.644 0.405 0.436
EarlyStopping counter: 1 out of 100
Epoch: 57 Task: 1 auc:0.967 0.868 0.886 sn:0.561 0.539 0.545 sp:0.993 0.962 0.952 acc:0.943 0.914 0.904 mcc:0.689 0.545 0.518
Epoch: 57 Task: 2 auc:0.974 0.947 0.887 sn:0.683 0.833 0.739 sp:0.997 0.985 0.977 

Epoch: 67 Task: 1 auc:0.984 0.875 0.931 sn:0.851 0.671 0.753 sp:0.975 0.932 0.933 acc:0.961 0.902 0.912 mcc:0.811 0.558 0.622
Epoch: 67 Task: 2 auc:0.985 0.955 0.852 sn:0.753 0.792 0.696 sp:0.996 0.971 0.940 acc:0.987 0.965 0.932 mcc:0.798 0.613 0.418
Epoch: 67 Task: 3 auc:0.945 0.771 0.838 sn:0.601 0.328 0.348 sp:0.986 0.985 0.989 acc:0.944 0.914 0.920 mcc:0.684 0.454 0.493
Epoch: 67 Task: 4 auc:0.984 0.978 0.842 sn:0.506 0.900 0.579 sp:0.998 0.939 0.930 acc:0.983 0.938 0.920 mcc:0.660 0.512 0.305
Epoch: 67 Task: 5 auc:0.974 0.838 0.904 sn:0.671 0.778 0.800 sp:0.995 0.899 0.876 acc:0.977 0.892 0.871 mcc:0.763 0.453 0.420
Epoch: 67 Task: 6 auc:0.985 0.917 0.867 sn:0.714 0.630 0.587 sp:0.990 0.940 0.932 acc:0.972 0.919 0.909 mcc:0.758 0.481 0.426
EarlyStopping counter: 7 out of 100
Epoch: 68 Task: 1 auc:0.985 0.870 0.926 sn:0.711 0.684 0.779 sp:0.993 0.918 0.916 acc:0.961 0.891 0.900 mcc:0.795 0.536 0.601
Epoch: 68 Task: 2 auc:0.987 0.938 0.854 sn:0.726 0.792 0.696 sp:0.999 0.979 0.970 

Epoch: 78 Task: 1 auc:0.987 0.870 0.927 sn:0.864 0.645 0.675 sp:0.978 0.950 0.950 acc:0.965 0.915 0.918 mcc:0.832 0.588 0.612
Epoch: 78 Task: 2 auc:0.991 0.908 0.851 sn:0.710 0.625 0.696 sp:0.999 0.992 0.983 acc:0.990 0.979 0.974 mcc:0.828 0.674 0.629
Epoch: 78 Task: 3 auc:0.956 0.787 0.816 sn:0.552 0.403 0.364 sp:0.996 0.976 0.973 acc:0.948 0.914 0.907 mcc:0.699 0.479 0.427
Epoch: 78 Task: 4 auc:0.988 0.971 0.857 sn:0.639 0.850 0.579 sp:0.997 0.942 0.955 acc:0.986 0.939 0.944 mcc:0.740 0.495 0.374
Epoch: 78 Task: 5 auc:0.980 0.838 0.882 sn:0.675 0.611 0.600 sp:0.996 0.926 0.933 acc:0.978 0.908 0.914 mcc:0.773 0.407 0.413
Epoch: 78 Task: 6 auc:0.983 0.880 0.855 sn:0.810 0.587 0.543 sp:0.977 0.950 0.952 acc:0.966 0.925 0.925 mcc:0.742 0.479 0.453
EarlyStopping counter: 18 out of 100
Epoch: 79 Task: 1 auc:0.992 0.869 0.905 sn:0.856 0.684 0.727 sp:0.988 0.921 0.907 acc:0.973 0.894 0.886 mcc:0.864 0.544 0.546
Epoch: 79 Task: 2 auc:0.988 0.919 0.865 sn:0.774 0.875 0.783 sp:0.999 0.963 0.944

Epoch: 89 Task: 1 auc:0.994 0.855 0.909 sn:0.807 0.605 0.649 sp:0.997 0.947 0.941 acc:0.975 0.908 0.907 mcc:0.874 0.549 0.569
Epoch: 89 Task: 2 auc:0.991 0.918 0.871 sn:0.823 0.792 0.826 sp:0.997 0.948 0.941 acc:0.991 0.943 0.938 mcc:0.853 0.509 0.497
Epoch: 89 Task: 3 auc:0.977 0.790 0.844 sn:0.700 0.328 0.379 sp:0.993 0.987 0.982 acc:0.961 0.915 0.917 mcc:0.787 0.464 0.482
Epoch: 89 Task: 4 auc:0.993 0.978 0.859 sn:0.715 0.700 0.316 sp:0.998 0.991 0.988 acc:0.990 0.982 0.968 mcc:0.813 0.691 0.352
Epoch: 89 Task: 5 auc:0.987 0.868 0.877 sn:0.675 0.583 0.514 sp:0.999 0.953 0.951 acc:0.981 0.932 0.927 mcc:0.807 0.465 0.406
Epoch: 89 Task: 6 auc:0.995 0.874 0.835 sn:0.745 0.609 0.565 sp:0.997 0.948 0.957 acc:0.980 0.925 0.931 mcc:0.829 0.489 0.485
EarlyStopping counter: 6 out of 100
Epoch: 90 Task: 1 auc:0.993 0.853 0.903 sn:0.795 0.605 0.597 sp:0.996 0.967 0.950 acc:0.973 0.926 0.909 mcc:0.860 0.614 0.554
Epoch: 90 Task: 2 auc:0.992 0.914 0.841 sn:0.812 0.750 0.696 sp:0.996 0.994 0.985 

Epoch: 100 Task: 1 auc:0.995 0.858 0.900 sn:0.918 0.684 0.649 sp:0.986 0.945 0.940 acc:0.978 0.915 0.906 mcc:0.893 0.603 0.565
Epoch: 100 Task: 2 auc:0.993 0.919 0.851 sn:0.747 0.875 0.696 sp:0.999 0.966 0.943 acc:0.991 0.963 0.935 mcc:0.851 0.638 0.427
Epoch: 100 Task: 3 auc:0.979 0.784 0.837 sn:0.797 0.388 0.379 sp:0.986 0.973 0.962 acc:0.965 0.909 0.899 mcc:0.814 0.451 0.401
Epoch: 100 Task: 4 auc:0.997 0.980 0.821 sn:0.741 0.750 0.316 sp:1.000 0.995 0.989 acc:0.992 0.988 0.970 mcc:0.850 0.784 0.367
Epoch: 100 Task: 5 auc:0.991 0.839 0.891 sn:0.761 0.667 0.800 sp:0.996 0.862 0.829 acc:0.983 0.851 0.827 mcc:0.830 0.328 0.356
Epoch: 100 Task: 6 auc:0.998 0.873 0.813 sn:0.816 0.565 0.609 sp:0.998 0.920 0.915 acc:0.986 0.896 0.895 mcc:0.880 0.384 0.401
EarlyStopping counter: 17 out of 100
Epoch: 101 Task: 1 auc:0.995 0.844 0.877 sn:0.903 0.605 0.597 sp:0.989 0.959 0.947 acc:0.979 0.918 0.906 mcc:0.898 0.585 0.544
Epoch: 101 Task: 2 auc:0.991 0.952 0.862 sn:0.769 0.833 0.739 sp:0.999 0.9

Epoch: 112 Task: 1 auc:0.997 0.864 0.897 sn:0.957 0.711 0.662 sp:0.987 0.937 0.929 acc:0.983 0.911 0.898 mcc:0.920 0.599 0.548
Epoch: 112 Task: 2 auc:0.994 0.956 0.850 sn:0.817 0.792 0.696 sp:0.998 0.989 0.967 acc:0.992 0.982 0.958 mcc:0.869 0.752 0.521
Epoch: 112 Task: 3 auc:0.989 0.796 0.859 sn:0.828 0.373 0.379 sp:0.992 0.982 0.971 acc:0.975 0.915 0.907 mcc:0.864 0.477 0.434
Epoch: 112 Task: 4 auc:0.996 0.955 0.851 sn:0.873 0.750 0.368 sp:0.997 0.991 0.989 acc:0.993 0.983 0.971 mcc:0.881 0.723 0.415
Epoch: 112 Task: 5 auc:0.996 0.842 0.843 sn:0.827 0.583 0.571 sp:0.999 0.921 0.909 acc:0.989 0.902 0.890 mcc:0.891 0.377 0.342
Epoch: 112 Task: 6 auc:0.998 0.881 0.811 sn:0.898 0.587 0.522 sp:0.996 0.946 0.952 acc:0.990 0.922 0.923 mcc:0.915 0.469 0.436
EarlyStopping counter: 29 out of 100
Epoch: 113 Task: 1 auc:0.996 0.863 0.890 sn:0.879 0.750 0.727 sp:0.996 0.901 0.883 acc:0.982 0.883 0.865 mcc:0.911 0.548 0.501
Epoch: 113 Task: 2 auc:0.996 0.913 0.863 sn:0.806 0.708 0.696 sp:1.000 0.9

Epoch: 123 Task: 1 auc:0.998 0.850 0.874 sn:0.838 0.632 0.623 sp:0.999 0.950 0.933 acc:0.981 0.914 0.896 mcc:0.902 0.579 0.528
Epoch: 123 Task: 2 auc:0.995 0.907 0.841 sn:0.801 0.667 0.696 sp:0.999 0.986 0.982 acc:0.992 0.975 0.972 mcc:0.877 0.640 0.616
Epoch: 123 Task: 3 auc:0.989 0.785 0.830 sn:0.884 0.463 0.530 sp:0.981 0.949 0.933 acc:0.970 0.896 0.890 mcc:0.850 0.435 0.446
Epoch: 123 Task: 4 auc:0.997 0.972 0.828 sn:0.848 0.750 0.316 sp:0.999 0.994 0.991 acc:0.995 0.986 0.971 mcc:0.905 0.762 0.384
Epoch: 123 Task: 5 auc:0.997 0.810 0.830 sn:0.830 0.472 0.429 sp:0.998 0.963 0.960 acc:0.989 0.935 0.930 mcc:0.891 0.419 0.369
Epoch: 123 Task: 6 auc:0.996 0.804 0.777 sn:0.830 0.413 0.500 sp:0.998 0.979 0.971 acc:0.987 0.941 0.939 mcc:0.890 0.466 0.491
EarlyStopping counter: 40 out of 100
Epoch: 124 Task: 1 auc:0.994 0.843 0.898 sn:0.879 0.645 0.727 sp:0.995 0.930 0.914 acc:0.982 0.897 0.892 mcc:0.908 0.534 0.561
Epoch: 124 Task: 2 auc:0.996 0.947 0.830 sn:0.801 0.625 0.696 sp:0.999 0.9

Epoch: 134 Task: 1 auc:0.998 0.851 0.871 sn:0.885 0.553 0.662 sp:0.998 0.961 0.933 acc:0.985 0.914 0.901 mcc:0.926 0.550 0.557
Epoch: 134 Task: 2 auc:0.996 0.941 0.879 sn:0.849 0.833 0.739 sp:0.999 0.986 0.982 acc:0.994 0.981 0.974 mcc:0.910 0.748 0.645
Epoch: 134 Task: 3 auc:0.994 0.775 0.833 sn:0.797 0.493 0.515 sp:0.997 0.945 0.960 acc:0.976 0.896 0.912 mcc:0.869 0.450 0.511
Epoch: 134 Task: 4 auc:0.999 0.954 0.777 sn:0.785 0.750 0.474 sp:1.000 0.980 0.981 acc:0.994 0.973 0.967 mcc:0.883 0.621 0.433
Epoch: 134 Task: 5 auc:0.997 0.859 0.902 sn:0.872 0.444 0.543 sp:0.998 0.970 0.970 acc:0.991 0.940 0.946 mcc:0.915 0.425 0.499
Epoch: 134 Task: 6 auc:0.998 0.878 0.819 sn:0.901 0.717 0.674 sp:0.998 0.904 0.870 acc:0.992 0.891 0.857 mcc:0.929 0.452 0.364
EarlyStopping counter: 51 out of 100
Epoch: 135 Task: 1 auc:0.998 0.846 0.874 sn:0.948 0.724 0.740 sp:0.991 0.908 0.879 acc:0.986 0.886 0.863 mcc:0.934 0.543 0.505
Epoch: 135 Task: 2 auc:0.996 0.949 0.836 sn:0.833 0.750 0.739 sp:1.000 0.9

Epoch: 145 Task: 1 auc:0.999 0.831 0.887 sn:0.926 0.605 0.727 sp:0.998 0.932 0.914 acc:0.990 0.894 0.892 mcc:0.950 0.509 0.561
Epoch: 145 Task: 2 auc:0.997 0.920 0.862 sn:0.812 0.792 0.739 sp:1.000 0.980 0.970 acc:0.993 0.974 0.962 mcc:0.895 0.673 0.565
Epoch: 145 Task: 3 auc:0.994 0.782 0.809 sn:0.920 0.403 0.379 sp:0.990 0.969 0.960 acc:0.982 0.907 0.898 mcc:0.907 0.450 0.395
Epoch: 145 Task: 4 auc:0.999 0.963 0.843 sn:0.905 0.850 0.368 sp:0.999 0.966 0.978 acc:0.996 0.962 0.961 mcc:0.936 0.593 0.330
Epoch: 145 Task: 5 auc:0.998 0.845 0.844 sn:0.900 0.556 0.686 sp:0.997 0.934 0.916 acc:0.992 0.913 0.903 mcc:0.921 0.390 0.428
Epoch: 145 Task: 6 auc:0.999 0.897 0.793 sn:0.940 0.674 0.609 sp:0.995 0.937 0.924 acc:0.991 0.919 0.903 mcc:0.928 0.502 0.422
EarlyStopping counter: 3 out of 100
Epoch: 146 Task: 1 auc:0.998 0.837 0.885 sn:0.910 0.645 0.675 sp:0.996 0.947 0.916 acc:0.986 0.912 0.887 mcc:0.932 0.579 0.527
Epoch: 146 Task: 2 auc:0.998 0.959 0.835 sn:0.839 0.875 0.739 sp:0.999 0.98

Epoch: 156 Task: 1 auc:0.999 0.842 0.873 sn:0.921 0.671 0.753 sp:0.998 0.930 0.878 acc:0.989 0.900 0.863 mcc:0.947 0.554 0.511
Epoch: 156 Task: 2 auc:0.998 0.923 0.827 sn:0.839 0.833 0.739 sp:1.000 0.979 0.965 acc:0.994 0.974 0.958 mcc:0.907 0.687 0.541
Epoch: 156 Task: 3 auc:0.996 0.751 0.830 sn:0.912 0.537 0.636 sp:0.996 0.925 0.889 acc:0.987 0.883 0.862 mcc:0.932 0.435 0.436
Epoch: 156 Task: 4 auc:0.999 0.947 0.821 sn:0.892 0.700 0.368 sp:0.999 0.989 0.989 acc:0.996 0.980 0.971 mcc:0.926 0.673 0.415
Epoch: 156 Task: 5 auc:0.999 0.791 0.861 sn:0.841 0.444 0.429 sp:0.999 0.983 0.982 acc:0.990 0.952 0.951 mcc:0.901 0.499 0.472
Epoch: 156 Task: 6 auc:0.998 0.812 0.797 sn:0.923 0.543 0.630 sp:0.999 0.943 0.916 acc:0.994 0.916 0.897 mcc:0.951 0.428 0.419
EarlyStopping counter: 14 out of 100
Epoch: 157 Task: 1 auc:0.999 0.853 0.891 sn:0.975 0.632 0.688 sp:0.994 0.952 0.926 acc:0.992 0.915 0.898 mcc:0.960 0.584 0.559
Epoch: 157 Task: 2 auc:0.993 0.975 0.850 sn:0.849 0.917 0.739 sp:1.000 0.9

Epoch: 167 Task: 1 auc:0.999 0.841 0.860 sn:0.970 0.645 0.714 sp:0.996 0.937 0.898 acc:0.993 0.903 0.877 mcc:0.966 0.551 0.520
Epoch: 167 Task: 2 auc:0.999 0.936 0.838 sn:0.898 0.875 0.739 sp:1.000 0.977 0.967 acc:0.996 0.974 0.959 mcc:0.943 0.702 0.549
Epoch: 167 Task: 3 auc:0.997 0.756 0.825 sn:0.886 0.358 0.364 sp:0.998 0.985 0.982 acc:0.986 0.917 0.916 mcc:0.925 0.482 0.468
Epoch: 167 Task: 4 auc:0.999 0.955 0.837 sn:0.905 0.750 0.474 sp:0.999 0.983 0.972 acc:0.996 0.976 0.958 mcc:0.933 0.646 0.376
Epoch: 167 Task: 5 auc:0.988 0.807 0.815 sn:0.779 0.500 0.514 sp:1.000 0.960 0.936 acc:0.987 0.933 0.913 mcc:0.876 0.428 0.363
Epoch: 167 Task: 6 auc:0.994 0.829 0.817 sn:0.901 0.652 0.630 sp:0.999 0.912 0.890 acc:0.993 0.894 0.873 mcc:0.940 0.426 0.369
EarlyStopping counter: 25 out of 100
Epoch: 168 Task: 1 auc:0.997 0.818 0.841 sn:0.933 0.684 0.727 sp:0.997 0.914 0.876 acc:0.989 0.888 0.858 mcc:0.946 0.529 0.490
Epoch: 168 Task: 2 auc:0.999 0.921 0.862 sn:0.876 0.833 0.696 sp:1.000 0.9

Epoch: 178 Task: 1 auc:0.999 0.835 0.885 sn:0.982 0.605 0.597 sp:0.988 0.962 0.926 acc:0.987 0.921 0.887 mcc:0.939 0.596 0.492
Epoch: 178 Task: 2 auc:0.997 0.909 0.817 sn:0.855 0.792 0.696 sp:1.000 0.991 0.983 acc:0.995 0.984 0.974 mcc:0.922 0.767 0.629
Epoch: 178 Task: 3 auc:0.999 0.746 0.839 sn:0.948 0.507 0.621 sp:0.996 0.940 0.893 acc:0.991 0.893 0.864 mcc:0.952 0.447 0.431
Epoch: 178 Task: 4 auc:0.999 0.954 0.834 sn:0.892 0.850 0.474 sp:0.999 0.977 0.977 acc:0.996 0.973 0.962 mcc:0.930 0.660 0.402
Epoch: 178 Task: 5 auc:0.999 0.799 0.871 sn:0.865 0.417 0.400 sp:1.000 0.981 0.980 acc:0.992 0.949 0.948 mcc:0.926 0.465 0.437
Epoch: 178 Task: 6 auc:0.999 0.868 0.807 sn:0.929 0.630 0.587 sp:0.998 0.961 0.941 acc:0.994 0.938 0.918 mcc:0.949 0.549 0.451
EarlyStopping counter: 36 out of 100
Epoch: 179 Task: 1 auc:0.999 0.828 0.893 sn:0.984 0.566 0.636 sp:0.993 0.957 0.929 acc:0.991 0.912 0.895 mcc:0.959 0.549 0.529
Epoch: 179 Task: 2 auc:0.998 0.940 0.853 sn:0.898 0.792 0.739 sp:0.999 0.9

Epoch: 203 Task: 1 auc:0.999 0.846 0.859 sn:0.946 0.539 0.623 sp:0.998 0.969 0.929 acc:0.992 0.920 0.893 mcc:0.960 0.569 0.519
Epoch: 203 Task: 2 auc:0.999 0.863 0.833 sn:0.925 0.667 0.739 sp:0.998 0.985 0.977 acc:0.996 0.974 0.969 mcc:0.935 0.627 0.612
Epoch: 203 Task: 3 auc:0.999 0.746 0.820 sn:0.944 0.537 0.621 sp:0.993 0.922 0.889 acc:0.988 0.880 0.860 mcc:0.936 0.427 0.425
Epoch: 203 Task: 4 auc:0.999 0.910 0.784 sn:0.937 0.650 0.368 sp:0.999 0.994 0.992 acc:0.997 0.983 0.974 mcc:0.954 0.697 0.451
Epoch: 203 Task: 5 auc:0.999 0.783 0.877 sn:0.931 0.444 0.600 sp:0.997 0.963 0.931 acc:0.993 0.933 0.913 mcc:0.937 0.397 0.408
Epoch: 203 Task: 6 auc:0.999 0.839 0.774 sn:0.956 0.565 0.587 sp:0.996 0.954 0.932 acc:0.993 0.928 0.909 mcc:0.945 0.478 0.426
EarlyStopping counter: 61 out of 100
Epoch: 204 Task: 1 auc:0.999 0.840 0.883 sn:0.964 0.645 0.623 sp:0.996 0.957 0.928 acc:0.993 0.921 0.892 mcc:0.964 0.609 0.515
Epoch: 204 Task: 2 auc:1.000 0.900 0.831 sn:0.882 0.792 0.609 sp:1.000 0.9

Epoch: 214 Task: 1 auc:1.000 0.826 0.879 sn:0.987 0.566 0.649 sp:0.997 0.969 0.943 acc:0.995 0.923 0.909 mcc:0.978 0.590 0.574
Epoch: 214 Task: 2 auc:1.000 0.905 0.848 sn:0.957 0.792 0.739 sp:0.999 0.991 0.974 acc:0.997 0.984 0.967 mcc:0.961 0.767 0.592
Epoch: 214 Task: 3 auc:0.998 0.766 0.837 sn:0.976 0.463 0.576 sp:0.992 0.940 0.924 acc:0.990 0.888 0.886 mcc:0.951 0.411 0.459
Epoch: 214 Task: 4 auc:0.999 0.949 0.786 sn:0.968 0.750 0.474 sp:0.997 0.980 0.978 acc:0.996 0.973 0.964 mcc:0.934 0.621 0.412
Epoch: 214 Task: 5 auc:1.000 0.829 0.818 sn:0.969 0.417 0.457 sp:0.998 0.985 0.982 acc:0.996 0.952 0.952 mcc:0.967 0.487 0.496
Epoch: 214 Task: 6 auc:1.000 0.865 0.795 sn:0.989 0.696 0.652 sp:0.997 0.924 0.879 acc:0.997 0.909 0.864 mcc:0.973 0.483 0.365
EarlyStopping counter: 72 out of 100
Epoch: 215 Task: 1 auc:1.000 0.832 0.881 sn:0.948 0.658 0.701 sp:0.999 0.947 0.897 acc:0.993 0.914 0.874 mcc:0.967 0.588 0.508
Epoch: 215 Task: 2 auc:1.000 0.921 0.814 sn:0.919 0.833 0.739 sp:1.000 0.9

Epoch: 225 Task: 1 auc:1.000 0.833 0.878 sn:0.982 0.684 0.727 sp:0.998 0.933 0.903 acc:0.996 0.905 0.883 mcc:0.981 0.572 0.539
Epoch: 225 Task: 2 auc:0.999 0.923 0.866 sn:0.898 0.875 0.739 sp:0.999 0.982 0.980 acc:0.996 0.978 0.972 mcc:0.934 0.736 0.633
Epoch: 225 Task: 3 auc:0.997 0.741 0.787 sn:0.940 0.418 0.470 sp:0.992 0.953 0.935 acc:0.986 0.894 0.885 mcc:0.929 0.408 0.402
Epoch: 225 Task: 4 auc:1.000 0.947 0.756 sn:0.949 0.650 0.474 sp:0.999 0.994 0.989 acc:0.998 0.983 0.974 mcc:0.960 0.697 0.503
Epoch: 225 Task: 5 auc:0.999 0.815 0.829 sn:0.893 0.583 0.571 sp:0.999 0.958 0.936 acc:0.993 0.937 0.916 mcc:0.932 0.483 0.402
Epoch: 225 Task: 6 auc:0.999 0.886 0.830 sn:0.931 0.587 0.609 sp:0.999 0.953 0.938 acc:0.994 0.928 0.916 mcc:0.952 0.489 0.458
EarlyStopping counter: 83 out of 100
Epoch: 226 Task: 1 auc:1.000 0.846 0.882 sn:0.967 0.684 0.701 sp:0.997 0.942 0.905 acc:0.994 0.912 0.881 mcc:0.970 0.594 0.524
Epoch: 226 Task: 2 auc:1.000 0.901 0.833 sn:0.898 0.833 0.783 sp:1.000 0.9

Epoch: 236 Task: 1 auc:1.000 0.840 0.871 sn:0.982 0.658 0.753 sp:0.998 0.938 0.905 acc:0.996 0.906 0.887 mcc:0.982 0.565 0.561
Epoch: 236 Task: 2 auc:1.000 0.918 0.826 sn:0.903 0.875 0.696 sp:1.000 0.985 0.989 acc:0.997 0.981 0.980 mcc:0.949 0.761 0.685
Epoch: 236 Task: 3 auc:0.999 0.720 0.825 sn:0.966 0.388 0.455 sp:0.996 0.971 0.949 acc:0.993 0.907 0.896 mcc:0.962 0.443 0.427
Epoch: 236 Task: 4 auc:1.000 0.943 0.838 sn:0.937 0.800 0.474 sp:1.000 0.969 0.960 acc:0.998 0.964 0.946 mcc:0.967 0.580 0.323
Epoch: 236 Task: 5 auc:1.000 0.817 0.873 sn:0.941 0.444 0.429 sp:1.000 0.973 0.966 acc:0.996 0.943 0.937 mcc:0.965 0.441 0.395
Epoch: 236 Task: 6 auc:1.000 0.820 0.799 sn:0.975 0.565 0.543 sp:0.998 0.964 0.950 acc:0.997 0.937 0.923 mcc:0.975 0.514 0.448
EarlyStopping counter: 94 out of 100
Epoch: 237 Task: 1 auc:0.998 0.835 0.895 sn:0.980 0.658 0.766 sp:0.992 0.949 0.905 acc:0.991 0.915 0.889 mcc:0.955 0.593 0.570
Epoch: 237 Task: 2 auc:0.999 0.961 0.835 sn:0.930 0.917 0.739 sp:0.999 0.9

Epoch: 247 Task: 1 auc:1.000 0.837 0.863 sn:0.980 0.724 0.714 sp:0.997 0.909 0.881 acc:0.995 0.888 0.861 mcc:0.976 0.546 0.489
Epoch: 247 Task: 2 auc:0.972 0.837 0.806 sn:0.769 0.625 0.652 sp:1.000 0.986 0.982 acc:0.992 0.974 0.971 mcc:0.873 0.611 0.587
Epoch: 247 Task: 3 auc:0.995 0.740 0.798 sn:0.896 0.328 0.439 sp:0.996 0.976 0.969 acc:0.985 0.906 0.912 mcc:0.922 0.410 0.481
Epoch: 247 Task: 4 auc:1.000 0.983 0.823 sn:0.987 0.950 0.632 sp:0.999 0.947 0.938 acc:0.999 0.947 0.929 mcc:0.981 0.566 0.353
Epoch: 247 Task: 5 auc:0.999 0.790 0.828 sn:0.962 0.472 0.571 sp:0.996 0.960 0.941 acc:0.994 0.932 0.921 mcc:0.947 0.406 0.416
Epoch: 247 Task: 6 auc:0.990 0.776 0.729 sn:0.918 0.522 0.543 sp:0.997 0.937 0.923 acc:0.992 0.909 0.897 mcc:0.934 0.394 0.373
EarlyStopping counter: 105 out of 100
Epoch: 248 Task: 1 auc:0.999 0.832 0.877 sn:0.967 0.671 0.688 sp:0.998 0.928 0.921 acc:0.995 0.898 0.893 mcc:0.974 0.550 0.548
Epoch: 248 Task: 2 auc:0.999 0.916 0.811 sn:0.919 0.833 0.696 sp:1.000 0.

Epoch: 258 Task: 1 auc:1.000 0.827 0.895 sn:0.995 0.724 0.753 sp:0.993 0.932 0.879 acc:0.993 0.908 0.865 mcc:0.966 0.596 0.514
Epoch: 258 Task: 2 auc:0.999 0.911 0.844 sn:0.968 0.833 0.739 sp:0.998 0.954 0.944 acc:0.997 0.950 0.938 mcc:0.953 0.557 0.457
Epoch: 258 Task: 3 auc:0.997 0.702 0.831 sn:0.922 0.433 0.485 sp:0.998 0.958 0.929 acc:0.990 0.901 0.881 mcc:0.948 0.438 0.401
Epoch: 258 Task: 4 auc:0.999 0.951 0.775 sn:0.962 0.650 0.368 sp:0.999 0.994 0.992 acc:0.998 0.983 0.974 mcc:0.967 0.697 0.451
Epoch: 258 Task: 5 auc:0.998 0.819 0.828 sn:0.900 0.583 0.543 sp:0.998 0.943 0.909 acc:0.992 0.922 0.889 mcc:0.927 0.433 0.324
Epoch: 258 Task: 6 auc:0.999 0.852 0.789 sn:0.953 0.587 0.587 sp:0.998 0.953 0.920 acc:0.995 0.928 0.897 mcc:0.963 0.489 0.397
EarlyStopping counter: 116 out of 100
Epoch: 259 Task: 1 auc:1.000 0.826 0.876 sn:0.980 0.671 0.688 sp:0.999 0.938 0.919 acc:0.997 0.908 0.892 mcc:0.984 0.575 0.544
Epoch: 259 Task: 2 auc:0.999 0.905 0.830 sn:0.919 0.833 0.739 sp:1.000 0.

Epoch: 269 Task: 1 auc:1.000 0.821 0.877 sn:0.975 0.658 0.727 sp:0.999 0.932 0.902 acc:0.997 0.900 0.881 mcc:0.983 0.548 0.536
Epoch: 269 Task: 2 auc:1.000 0.923 0.896 sn:0.952 0.708 0.696 sp:1.000 0.988 0.979 acc:0.998 0.978 0.969 mcc:0.969 0.683 0.594
Epoch: 269 Task: 3 auc:0.999 0.728 0.853 sn:0.953 0.418 0.439 sp:0.997 0.965 0.942 acc:0.993 0.906 0.888 mcc:0.961 0.449 0.395
Epoch: 269 Task: 4 auc:1.000 0.945 0.791 sn:0.949 0.950 0.474 sp:0.999 0.945 0.928 acc:0.998 0.945 0.915 mcc:0.957 0.560 0.243
Epoch: 269 Task: 5 auc:0.996 0.783 0.830 sn:0.913 0.417 0.543 sp:1.000 0.963 0.939 acc:0.995 0.932 0.917 mcc:0.951 0.375 0.391
Epoch: 269 Task: 6 auc:1.000 0.811 0.780 sn:0.942 0.565 0.630 sp:0.999 0.946 0.907 acc:0.996 0.921 0.889 mcc:0.964 0.453 0.400
EarlyStopping counter: 127 out of 100
Epoch: 270 Task: 1 auc:1.000 0.833 0.870 sn:0.967 0.671 0.714 sp:0.999 0.947 0.921 acc:0.995 0.915 0.896 mcc:0.976 0.598 0.566
Epoch: 270 Task: 2 auc:1.000 0.930 0.865 sn:0.925 0.792 0.696 sp:1.000 0.

Epoch: 280 Task: 1 auc:1.000 0.842 0.865 sn:0.975 0.697 0.688 sp:1.000 0.940 0.905 acc:0.997 0.912 0.880 mcc:0.984 0.599 0.515
Epoch: 280 Task: 2 auc:1.000 0.882 0.831 sn:0.925 0.708 0.652 sp:1.000 0.992 0.989 acc:0.997 0.982 0.978 mcc:0.955 0.731 0.656
Epoch: 280 Task: 3 auc:0.999 0.718 0.816 sn:0.927 0.388 0.424 sp:0.999 0.958 0.945 acc:0.991 0.896 0.890 mcc:0.952 0.398 0.392
Epoch: 280 Task: 4 auc:0.999 0.950 0.807 sn:0.943 0.850 0.474 sp:0.999 0.969 0.958 acc:0.998 0.965 0.944 mcc:0.960 0.610 0.318
Epoch: 280 Task: 5 auc:1.000 0.791 0.827 sn:0.945 0.472 0.457 sp:0.998 0.963 0.958 acc:0.995 0.935 0.930 mcc:0.954 0.419 0.385
Epoch: 280 Task: 6 auc:1.000 0.799 0.757 sn:0.989 0.543 0.543 sp:0.998 0.957 0.912 acc:0.997 0.929 0.887 mcc:0.977 0.473 0.351
EarlyStopping counter: 138 out of 100
Epoch: 281 Task: 1 auc:0.999 0.835 0.865 sn:0.966 0.645 0.623 sp:0.998 0.955 0.921 acc:0.995 0.920 0.886 mcc:0.973 0.604 0.500
Epoch: 281 Task: 2 auc:1.000 0.912 0.815 sn:0.887 0.792 0.652 sp:0.999 0.

Epoch: 291 Task: 1 auc:1.000 0.819 0.884 sn:0.993 0.671 0.688 sp:0.995 0.949 0.914 acc:0.995 0.917 0.887 mcc:0.974 0.603 0.533
Epoch: 291 Task: 2 auc:1.000 0.909 0.823 sn:0.925 0.833 0.739 sp:1.000 0.977 0.971 acc:0.997 0.972 0.964 mcc:0.955 0.677 0.574
Epoch: 291 Task: 3 auc:0.999 0.691 0.820 sn:0.963 0.373 0.424 sp:0.994 0.962 0.945 acc:0.990 0.898 0.890 mcc:0.951 0.397 0.392
Epoch: 291 Task: 4 auc:0.998 0.932 0.834 sn:0.892 0.750 0.526 sp:1.000 0.980 0.974 acc:0.997 0.973 0.961 mcc:0.943 0.621 0.422
Epoch: 291 Task: 5 auc:1.000 0.815 0.806 sn:0.948 0.500 0.514 sp:0.999 0.965 0.948 acc:0.996 0.938 0.924 mcc:0.959 0.448 0.395
Epoch: 291 Task: 6 auc:0.999 0.792 0.784 sn:0.970 0.478 0.478 sp:0.998 0.973 0.961 acc:0.996 0.940 0.929 mcc:0.966 0.488 0.435
EarlyStopping counter: 149 out of 100
Epoch: 292 Task: 1 auc:1.000 0.830 0.865 sn:0.984 0.697 0.714 sp:0.998 0.932 0.903 acc:0.997 0.905 0.881 mcc:0.983 0.577 0.530
Epoch: 292 Task: 2 auc:1.000 0.896 0.837 sn:0.952 0.708 0.696 sp:0.999 0.

Epoch: 302 Task: 1 auc:1.000 0.835 0.901 sn:0.984 0.645 0.714 sp:0.998 0.961 0.928 acc:0.996 0.924 0.903 mcc:0.981 0.620 0.582
Epoch: 302 Task: 2 auc:0.998 0.915 0.854 sn:0.941 0.833 0.739 sp:1.000 0.985 0.988 acc:0.998 0.979 0.980 mcc:0.969 0.735 0.698
Epoch: 302 Task: 3 auc:0.998 0.719 0.811 sn:0.968 0.388 0.424 sp:0.993 0.956 0.924 acc:0.990 0.894 0.870 mcc:0.950 0.392 0.339
Epoch: 302 Task: 4 auc:1.000 0.914 0.834 sn:0.943 0.700 0.421 sp:1.000 0.994 0.980 acc:0.998 0.985 0.964 mcc:0.967 0.730 0.382
Epoch: 302 Task: 5 auc:1.000 0.828 0.834 sn:0.990 0.444 0.429 sp:0.998 0.985 0.968 acc:0.997 0.954 0.938 mcc:0.975 0.510 0.402
Epoch: 302 Task: 6 auc:0.997 0.759 0.766 sn:0.978 0.587 0.565 sp:0.995 0.935 0.887 acc:0.994 0.912 0.866 mcc:0.949 0.437 0.322
EarlyStopping counter: 160 out of 100
Epoch: 303 Task: 1 auc:1.000 0.843 0.884 sn:0.995 0.671 0.701 sp:0.995 0.938 0.907 acc:0.995 0.908 0.883 mcc:0.976 0.575 0.528
Epoch: 303 Task: 2 auc:1.000 0.920 0.855 sn:0.909 0.750 0.696 sp:1.000 0.

Epoch: 313 Task: 1 auc:0.998 0.844 0.870 sn:0.967 0.632 0.597 sp:0.995 0.961 0.933 acc:0.992 0.923 0.893 mcc:0.961 0.610 0.508
Epoch: 313 Task: 2 auc:0.985 0.906 0.848 sn:0.726 0.792 0.609 sp:1.000 0.982 0.977 acc:0.991 0.975 0.965 mcc:0.848 0.684 0.524
Epoch: 313 Task: 3 auc:0.985 0.727 0.762 sn:0.860 0.448 0.470 sp:0.995 0.951 0.927 acc:0.980 0.896 0.878 mcc:0.893 0.428 0.384
Epoch: 313 Task: 4 auc:0.997 0.913 0.769 sn:0.911 0.800 0.474 sp:0.999 0.983 0.970 acc:0.997 0.977 0.956 mcc:0.940 0.677 0.368
Epoch: 313 Task: 5 auc:0.992 0.828 0.788 sn:0.799 0.611 0.543 sp:0.997 0.934 0.921 acc:0.985 0.916 0.900 mcc:0.857 0.428 0.347
Epoch: 313 Task: 6 auc:0.995 0.784 0.770 sn:0.940 0.565 0.565 sp:0.988 0.950 0.906 acc:0.985 0.924 0.883 mcc:0.883 0.463 0.354
EarlyStopping counter: 171 out of 100
Epoch: 314 Task: 1 auc:1.000 0.835 0.872 sn:0.989 0.671 0.662 sp:0.999 0.945 0.900 acc:0.998 0.914 0.872 mcc:0.991 0.593 0.486
Epoch: 314 Task: 2 auc:1.000 0.929 0.826 sn:0.946 0.917 0.739 sp:1.000 0.

Epoch: 324 Task: 1 auc:0.999 0.839 0.846 sn:0.939 0.658 0.649 sp:0.998 0.947 0.909 acc:0.991 0.914 0.878 mcc:0.955 0.588 0.494
Epoch: 324 Task: 2 auc:1.000 0.919 0.800 sn:0.978 0.667 0.696 sp:0.995 0.989 0.979 acc:0.995 0.978 0.969 mcc:0.925 0.670 0.594
Epoch: 324 Task: 3 auc:0.998 0.694 0.804 sn:0.957 0.403 0.470 sp:0.993 0.947 0.942 acc:0.989 0.888 0.891 mcc:0.946 0.379 0.420
Epoch: 324 Task: 4 auc:0.999 0.932 0.818 sn:0.981 0.700 0.421 sp:0.998 0.980 0.970 acc:0.997 0.971 0.955 mcc:0.956 0.588 0.330
Epoch: 324 Task: 5 auc:1.000 0.819 0.828 sn:0.979 0.472 0.543 sp:0.998 0.963 0.970 acc:0.997 0.935 0.946 mcc:0.969 0.419 0.499
Epoch: 324 Task: 6 auc:0.999 0.843 0.809 sn:0.978 0.609 0.652 sp:0.997 0.943 0.899 acc:0.996 0.921 0.883 mcc:0.967 0.475 0.399
EarlyStopping counter: 182 out of 100
Epoch: 325 Task: 1 auc:1.000 0.842 0.870 sn:0.997 0.658 0.714 sp:0.993 0.942 0.886 acc:0.994 0.909 0.866 mcc:0.970 0.574 0.498
Epoch: 325 Task: 2 auc:1.000 0.916 0.841 sn:0.903 0.833 0.739 sp:1.000 0.

In [None]:
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
    
test_auc, test_sn, test_sp, test_acc, test_mcc, test_predict = eval(model, amodel, gmodel, test_df)

In [None]:
for i in range(task_num):
    print('Epoch:',epoch, 'Task:', i+1,
      'auc:%.3f'%train_auc[i],'%.3f'%val_auc[i],'%.3f'%test_auc[i], 
      'sn:%.3f'%train_sn[i],'%.3f'%val_sn[i],'%.3f'%test_sn[i], 
      'sp:%.3f'%train_sp[i], '%.3f'%val_sp[i], '%.3f'%test_sp[i], 
      'acc:%.3f'%train_acc[i], '%.3f'%val_acc[i], '%.3f'%test_acc[i], 
      'mcc:%.3f'%train_mcc[i],'%.3f'%val_mcc[i],'%.3f'%test_mcc[i])

In [None]:
# print('target_file:',raw_filename[0])
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)

In [None]:
# Task-specific AFSE
# Dynamic cth
# loss =  regression_loss + 0.08 * (vat_loss + test_vat_loss)
# r=3 t=2 200 100 100 1 all_lr=3e-4