In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import math
torch.manual_seed(8)
import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# from tensorboardX import SummaryWriter
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP.AttentiveLayers_Sim_copy import Fingerprint, GRN, AFSE
from AttentiveFP import Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set()
from IPython.display import SVG, display
import sascorer
from AttentiveFP.utils import EarlyStopping
from AttentiveFP.utils import Meter
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import AttentiveFP.Featurizer
import scipy

In [3]:
raw_filename = ["./data/ADMET/A/C/Pgp-inhibitor.csv", "./data/ADMET/A/C/Pgp-substrate.csv", "./data/ADMET/D/C/BBB_Penetration.csv",
               "./data/ADMET/M/C/CYP1A2_inhibitor.csv", "./data/ADMET/M/C/CYP1A2_substrate.csv", "./data/ADMET/M/C/CYP2C9_inhibitor.csv",
               "./data/ADMET/M/C/CYP2C9_substrate.csv", "./data/ADMET/M/C/CYP3A4_inhibitor.csv", "./data/ADMET/M/C/CYP3A4_substrate.csv",
               "./data/ADMET/T/C/Ames.csv",
               "./data/ADMET/T/C/Eye_Corrosion.csv", "./data/ADMET/T/C/FDAMDD.csv",
               "./data/ADMET/T/C/NR-AhR.csv", "./data/ADMET/T/C/NR-AR-LBD.csv", "./data/ADMET/T/C/NR-AR.csv",
               "./data/ADMET/T/C/NR-Aromatase.csv", "./data/ADMET/T/C/NR-ER-LBD.csv", "./data/ADMET/T/C/NR-ER.csv",
               "./data/ADMET/T/C/NR-PPAR-gamma.csv", "./data/ADMET/T/C/Skin_Sensitization.csv", "./data/ADMET/T/C/SR-ARE.csv",
               "./data/ADMET/T/C/SR-ATAD5.csv", "./data/ADMET/T/C/SR-HSE.csv", "./data/ADMET/T/C/SR-MMP.csv",
               "./data/ADMET/T/C/SR-p53.csv"]

task_id = [0,2,10,15,20,23] # 1200 < sample size < 6000
task_num = len(task_id)
raw_filename = [raw_filename[i] for i in task_id]
random_seed = 68
file_name = f'Multi_Tasks_Medium'
# for i in range(task_num):
#     file_list = raw_filename[i].split('/')
#     file = '_'+file_list[-3]+'_'+file_list[-1]
#     file_name += file[:-4]
    
number = 'run_0'
model_file = "model_file/3C_GAFSE_"+file_name+'_'+number
log_dir = f'log/{"3C_GAFSE_"+file_name}_'+number
result_dir = './result/3C_GAFSE_'+file_name+'_'+number
print(raw_filename)
print(file_name)
print(model_file)

['./data/ADMET/A/C/Pgp-inhibitor.csv', './data/ADMET/D/C/BBB_Penetration.csv', './data/ADMET/T/C/Eye_Corrosion.csv', './data/ADMET/T/C/NR-Aromatase.csv', './data/ADMET/T/C/SR-ARE.csv', './data/ADMET/T/C/SR-MMP.csv']
Multi_Tasks_Medium
model_file/3C_GAFSE_Multi_Tasks_Medium_run_0


In [4]:
tasks = ['value']
total_df = pd.DataFrame([])

for i in range(task_num):
    task_df = pd.read_csv(raw_filename[i], header=0, names = ["smiles", "dataset", "value"],usecols=[0,1,2])
    task_df["task_id"] = i
    total_df = pd.concat([total_df,task_df])
    
print(total_df[:3],total_df[-3:])

def add_canonical_smiles(total_df):
    smilesList = total_df.smiles.values
    print("number of all smiles: ",len(smilesList))
    atom_num_dist = []
    remained_smiles = []
    canonical_smiles_list = []
    for smiles in smilesList:
        try:        
            mol = Chem.MolFromSmiles(smiles)
            atom_num_dist.append(len(mol.GetAtoms()))
            remained_smiles.append(smiles)
            canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
        except:
            print(smiles)
            pass
    print("number of successfully processed smiles: ", len(remained_smiles))
    total_df = total_df[total_df["smiles"].isin(remained_smiles)]
    total_df['cano_smiles'] =canonical_smiles_list
    return total_df

total_df = add_canonical_smiles(total_df)
print(total_df.head())

                     smiles   dataset  value  task_id
0  O(C)c1c(OC)cc2c(c1)CNCC2  training    0.0        0
1          N1Cc2c(cccc2)CC1  training    0.0        0
2      O(C)c1cc(CN(C)C)ccc1  training    0.0        0                       smiles dataset  value  task_id
5910  O=C1N(C)C(=O)c2c1cccc2     val    0.0        5
5911         O=C(NO)c1ccccc1     val    0.0        5
5912    S(P(SCCC)(=O)OCC)CCC     val    0.0        5
number of all smiles:  24824
number of successfully processed smiles:  24824
                                 smiles   dataset  value  task_id  \
0              O(C)c1c(OC)cc2c(c1)CNCC2  training    0.0        0   
1                      N1Cc2c(cccc2)CC1  training    0.0        0   
2                  O(C)c1cc(CN(C)C)ccc1  training    0.0        0   
3  Clc1ccc([C@@H](N2CCNCC2)c2ccccc2)cc1  training    0.0        0   
4                   C(N1CC=CC1)c1ccccc1  training    0.0        0   

                           cano_smiles  
0                 COc1cc2c(cc1OC)CNCC2 

In [5]:
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

p_dropout= 0.03
fingerprint_dim = 100

weight_decay = 4.3 # also known as l2_regularization_lambda
learning_rate = 4
radius = 3 # default: 2
T = 2
per_task_output_units_num = 1 # for regression model
output_units_num = task_num * per_task_output_units_num

In [6]:
total_smilesList = total_df['smiles'].values
print(len(total_smilesList))
feature_filename = './features/'+model_file.split('/')[-1][:-1]+'0.pickle'
filename = './features/'+model_file.split('/')[-1]
print(feature_filename)
if os.path.isfile(feature_filename):
    feature_dicts = pickle.load(open(feature_filename, "rb" ))
    print('Loading features successfully.')
else:
    feature_dicts = save_smiles_dicts(total_smilesList,filename)

24824
./features/3C_GAFSE_Multi_Tasks_Medium_run_0.pickle
Loading features successfully.


In [7]:
test_df = total_df[total_df.dataset.values == "test"]
test_df = test_df[test_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
test_df = test_df.reset_index(drop=True)

val_df = total_df[total_df.dataset.values == "val"]
val_df = val_df[val_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
val_df = val_df.reset_index(drop=True)

train_df = total_df[total_df.dataset.values == "training"]
train_df = train_df[train_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
train_df = train_df.reset_index(drop=True)

print(total_df.shape, len(train_df)+len(val_df)+len(test_df), train_df.shape,val_df.shape,test_df.shape)

(24824, 5) 24821 (19860, 5) (2479, 5) (2482, 5)


In [8]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([total_df["cano_smiles"].values[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)
amodel = AFSE(fingerprint_dim, output_units_num, p_dropout)
gmodel = GRN(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, p_dropout)
model.cuda()
amodel.cuda()
gmodel.cuda()

# optimizer = optim.Adam([
# {'params': model.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# {'params': gmodel.parameters(), 'lr': 10**(-learning_rate), 'weight_decay ': 10**-weight_decay}, 
# ])

optimizer = optim.Adam(params=model.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

optimizer_AFSE = optim.Adam(params=amodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# optimizer_AFSE = optim.SGD(params=amodel.parameters(), lr = 0.01, momentum=0.9)

optimizer_GRN = optim.Adam(params=gmodel.parameters(), lr=3*10**(-learning_rate), weight_decay=10**-weight_decay)#, capturable=True

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# print(params)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data.shape)
        

In [9]:
import numpy as np
from matplotlib import pyplot as plt

def sorted_show_pik(dataset, p, k, k_predict, i, acc):
    p_value = dataset[tasks[0]].astype(float).tolist()
    x = np.arange(0,len(dataset),1)
#     print('plt',dataset.head(),p[:10],k_predict,k)
#     plt.figure()
#     fig, ax1 = plt.subplots()
#     ax1.grid(False)
#     ax2 = ax1.twinx()
#     plt.grid(False)
    plt.scatter(x,p,marker='.',s=6,color='r',label='predict')
#     plt.ylabel('predict')
    plt.scatter(x,p_value,s=6,marker=',',color='blue',label='p_value')
    plt.axvline(x=k-1,ls="-",c="black")#添加垂直直线
    k_value = np.ones(len(dataset))
# #     print(EC50[k-1])
    k_value = k_value*k_predict
    plt.plot(x,k_value,'-',color='black')
    plt.ylabel('p_value')
    plt.title("epoch: {},  top-k recall: {}".format(i,acc))
    plt.legend(loc=3,fontsize=5)
    plt.show()
    

def topk_acc2(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
    
    df3 = df2[:k]  #取按预测值排完序后的前k个
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    k_true = true_sort[tasks[0]].values[k-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('k_true: ',type(k_true),k_true)
#     print('k_true: ',k_true,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=k_true,'acc_num: ',len(df3[df3['predict'].values>=k_true]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k)
    acc = len(df3[df3[tasks[0]].values>=k_true])/k #预测值前k个中真实排在前k个的个数/k
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/k
    if k>active_num:
        min_active = true_sort[tasks[0]].values[active_num-1]
        acc = len(df3[df3[tasks[0]].values>=min_active])/k
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

def topk_recall(df, predict, k, active_num, show_flag=False, i=0):
    df['predict'] = predict
    df2 = df.sort_values(by='predict',ascending=False) # 拼接预测值后对预测值进行排序
#     print('df2:\n',df2)
        
    df3 = df2[:k]  #取按预测值排完序后的前k个，因为后面的全是-4.1
    
    true_sort = df.sort_values(by=tasks[0],ascending=False) #返回一个新的按真实值排序列表
    min_active = true_sort[tasks[0]].values[active_num-1]  # 真实排第k个的活性值
#     print('df3:\n',df3['predict'])
#     print('min_active: ',type(min_active),min_active)
#     print('min_active: ',min_active,'min_predict: ',df3['predict'].values[-1],'index: ',df3['predict'].values>=min_active,'acc_num: ',len(df3[df3['predict'].values>=min_active]),
#           'fp_num: ',len(df3[df3['predict'].values>=-4.1]),'k: ',k,'active_num: ',active_num)
    acc = len(df3[df3[tasks[0]].values>-4.1])/active_num #预测值前k个中真实排在前active_num个的个数/active_num
    fp = len(df3[df3[tasks[0]].values==-4.1])/k  #预测值前k个中为-4.1的个数/active_num
    
    if(show_flag):
        #进来的是按实际活性值排好序的
        sorted_show_pik(true_sort,true_sort['predict'],k,k_predict,i,acc)
    return acc,fp

    
def topk_acc_recall(df, predict, k, active_num, show_flag=False, i=0):
    if k>active_num:
        return topk_recall(df, predict, k, active_num, show_flag, i)
    return topk_acc2(df,predict,k, active_num,show_flag,i)

def weighted_top_index(df, predict, active_num):
    weighted_acc_list=[]
    for k in np.arange(1,len(df)+1,1):
        acc, fp = topk_acc_recall(df, predict, k, active_num)
        weight = (len(df)-k)/len(df)
#         print('weight=',weight,'acc=',acc)
        weighted_acc_list.append(acc*weight)#
    weighted_acc_list = np.array(weighted_acc_list)
#     print('weighted_acc_list=',weighted_acc_list)
    return np.sum(weighted_acc_list)/weighted_acc_list.shape[0]

def AP(df, predict, active_num):
    prec = []
    rec = []
    for k in np.arange(1,len(df)+1,1):
        prec_k, fp1 = topk_acc2(df,predict,k, active_num)
        rec_k, fp2 = topk_recall(df, predict, k, active_num)
        prec.append(prec_k)
        rec.append(rec_k)
    # 取所有不同的recall对应的点处的精度值做平均
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # 计算包络线，从后往前取最大保证precise非减
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # 找出所有检测结果中recall不同的点
    i = np.where(mrec[1:] != mrec[:-1])[0]
#     print(prec)
#     print('prec='+str(prec)+'\n\n'+'rec='+str(rec))

    # and sum (\Delta recall) * prec
    # 用recall的间隔对精度作加权平均
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

In [10]:
def caculate_r2(predict,y):
#     print(y)
#     print(predict)
    y = torch.FloatTensor(y).reshape(-1,1)
    predict = torch.FloatTensor(predict).reshape(-1,1)
    y_mean = torch.mean(y)
    predict_mean = torch.mean(predict)
    
#     print(len(y), y_mean, len(predict), predict_mean)
    y1 = torch.pow(torch.mm((y-y_mean).t(),(predict-predict_mean)),2)
    y2 = torch.mm((y-y_mean).t(),(y-y_mean))*torch.mm((predict-predict_mean).t(),(predict-predict_mean))
#     print(y1,y2)
    return y1/(y2+1e-9)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score, roc_auc_score
def calc(TN, FP, FN, TP):
    SN = TP / (TP + FN)  # recall
    SP = TN / (TN + FP)
    # Precision = TP / (TP + FP)
    ACC = (TP + TN) / (TP + TN + FN + FP)
    # F1 = (2 * TP) / (2 * TP + FP + FN)
    fz = TP * TN - FP * FN
    fm = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN)
    MCC = fz / (pow(fm, 0.5)+1e-9)
    return SN, SP, ACC, MCC

In [11]:
from torch.autograd import Variable
def l2_norm(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    output = torch.div(input, norm+1e-6)
    return output

def normalize_perturbation(d,dim=-1):
    output = l2_norm(d, dim)
    return output

def tanh(x):
    return (torch.exp(x)-torch.exp(-x))/(torch.exp(x)+torch.exp(-x))

def sigmoid(x):
    return 1/(1+torch.exp(-x))

def perturb_feature(f, model, alpha=1, lamda=10**-learning_rate, output_lr=False, output_plr=False, y=None, task_id=None, sigmoid=False):
    mol_prediction = model(feature=f, d=0, sigmoid=sigmoid)
    pred = mol_prediction.detach()
    task_counter = 0
    vat_loss = 0
    for i in range(task_num):
        batch_task_sample = len(task_id[task_id==i])
        if batch_task_sample > 0:
            task_counter += 1
            y_mask = np.where(task_id==i, 1, 0)
            y_mask = torch.Tensor(y_mask)
        #     f = torch.div(f, torch.norm(f, dim=-1, keepdim=True)+1e-9)
            eps = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps = Variable(eps, requires_grad=True)
            # Predict on randomly perturbed image
            eps_p = model(feature=f, d=eps.cuda(), sigmoid=sigmoid)
            eps_p_ = model(feature=f, d=-eps.cuda(), sigmoid=sigmoid)
            p_aux = nn.Sigmoid()(eps_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_aux_ = nn.Sigmoid()(eps_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
        #     loss = nn.BCELoss()(abs(p_aux),torch.ones_like(p_aux))+nn.BCELoss()(abs(p_aux_),torch.ones_like(p_aux_))
            loss = loss_function(p_aux,torch.ones_like(p_aux))+loss_function(p_aux_,torch.ones_like(p_aux_))
            loss.backward(retain_graph=True)

            # Based on perturbed image, get direction of greatest error
            eps_adv = eps.grad#/10**-learning_rate
            optimizer_AFSE.zero_grad()
            # Use that direction as adversarial perturbation
            eps_adv_normed = normalize_perturbation(eps_adv)
            d_adv = lamda * eps_adv_normed.cuda()
            f_p, max_lr = model(feature=f, d=d_adv, output_lr=True, sigmoid=sigmoid)
            f_p = model(feature=f, d=d_adv, sigmoid=sigmoid)
            f_p_ = model(feature=f, d=-d_adv, sigmoid=sigmoid)
            p = nn.Sigmoid()(f_p[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            p_ = nn.Sigmoid()(f_p_[:,i]*y_mask/(pred[:,i]*y_mask+1e-6))
            vat_loss += loss_function(p,torch.ones_like(p))+loss_function(p_,torch.ones_like(p_))
    vat_loss /= task_counter
    if output_lr:
        if output_plr:
            loss = 0
            task_counter = 0
            eps_ = 1e-6 * normalize_perturbation(torch.randn(f.shape))
            eps_ = Variable(eps_, requires_grad=True)
            eps_p__ = model(feature=f+eps_.cuda(), d=0)
            for i in range(task_num):
                batch_task_sample = len(task_id[task_id==i])
                if batch_task_sample > 0:
                    task_counter += 1
                    y_mask = np.where(task_id==i, 1, 0)
                    y_mask = torch.Tensor(y_mask)
                    loss += loss_function(eps_p__[:,i]*y_mask,y.view(-1)*y_mask)
            loss /= task_counter
            loss.backward(retain_graph=True)
            punish_lr = torch.norm(torch.mean(eps_.grad,0))
            optimizer_AFSE.zero_grad()
            return eps_adv, d_adv, vat_loss, mol_prediction, max_lr, punish_lr
        return eps_adv, d_adv, vat_loss, mol_prediction, max_lr
    return eps_adv, d_adv, vat_loss, mol_prediction

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol

def d_loss(f, pred, model, y_val):
    diff_loss = 0
    length = len(pred)
    for i in range(length):
        for j in range(length):
            if j == i:
                continue
            pred_diff = model(feature_only=True, feature1=f[i], feature2=f[j])
            true_diff = y_val[i] - y_val[j]
            diff_loss += loss_function(pred_diff, torch.Tensor([true_diff]).view(-1,1))
    diff_loss = diff_loss/(length*(length-1))
    return diff_loss

In [12]:
def CE(x,y):
    c = 0
    l = len(y)
    for i in range(l):
        if y[i]==1:
            c += 1
    w1 = (l-c)/l
    w0 = c/l
    loss = -w1*y*torch.log(x+1e-6)-w0*(1-y)*torch.log(1-x+1e-6)
    loss = loss.mean(-1)
    return loss

def weighted_CE_loss(x,y):
    weight = 1/(y.detach().float().mean(0)+1e-9)
    weighted_CE = nn.CrossEntropyLoss(weight=weight)
#     atom_weights = (atom_weights-min(atom_weights))/(max(atom_weights)-min(atom_weights))
    return weighted_CE(x, torch.argmax(y,-1))

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25):
    weighted_CE = nn.CrossEntropyLoss(weight=alpha, reduction='none')
    focal_weight = (1-torch.max(pred * target, -1)[0])**gamma
    loss = focal_weight * weighted_CE(pred, torch.argmax(target,-1))
    loss = torch.mean(loss)
    return loss

def generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list):
    [a,b,c] = x_atom.shape
    [d,e,f,g] = bond_neighbor.shape
    ce_loss = nn.CrossEntropyLoss()
    one_hot_loss = 0
    run_times = 0
    validity_mask = torch.from_numpy(validity_mask).cuda()
    for i in range(a):
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        atom_weights = 1-x_atom[i,:l,:16].sum(-2)/l
        ce_atom_loss = nn.CrossEntropyLoss(weight=atom_weights)
        # print(atom_weights[1], refer_atom_list[i,0,torch.argmax(x_atom[i,0,:16],-1)], torch.argmax(x_atom[i,0,:16],-1))
#         one_hot_loss += ce_atom_loss(refer_atom_list[i,:l,:16], torch.argmax(x_atom[i,:l,:16],-1))- \
#                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        one_hot_loss += py_sigmoid_focal_loss(refer_atom_list[i,:l,:16], x_atom[i,:l,:16], alpha=atom_weights)- \
                          (((validity_mask[i,:l]*torch.log(1-atom_list[i,:l,:16]+1e-9)).sum(-1)/(validity_mask[i,:l].sum(-1)+1e-9))).mean(-1)
        run_times += 2
    total_loss = one_hot_loss/run_times
    return total_loss, 0, 0, 0


def train(model, amodel, gmodel, dataset, test_df, optimizer_list, loss_function, epoch):
    model.train()
    amodel.train()
    gmodel.train()
    optimizer, optimizer_AFSE, optimizer_GRN = optimizer_list
    np.random.seed(epoch)
    max_len = np.max([len(dataset),len(test_df)])
    valList = np.arange(0,max_len)
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, max_len, batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)
    for counter, batch in enumerate(batch_list):
        batch_df = dataset.loc[batch%len(dataset),:]
        batch_test = test_df.loc[batch%len(test_df),:]
        global_step = epoch * len(batch_list) + counter
        smiles_list = batch_df.cano_smiles.values
        smiles_list_test = batch_test.cano_smiles.values
        y_val = batch_df[tasks[0]].values.astype(int)
        task_id = batch_df['task_id'].values
        task_id_test = batch_test['task_id'].values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(smiles_list_test,feature_dicts)
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
                                                torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
#         refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),
#                                                   torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),
#                                                   mol_feature=mol_feature,activated_features=activated_features.detach())
        
#         x_atom = torch.Tensor(x_atom)
#         x_bonds = torch.Tensor(x_bonds)
#         x_bond_index = torch.cuda.LongTensor(x_bond_index)
        
#         bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
#         bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        eps_adv, d_adv, vat_loss, mol_prediction, conv_lr, punish_lr = perturb_feature(mol_feature, amodel, alpha=1,                                                                                lamda=10**-learning_rate, output_lr=True, 
                                                                                       output_plr=True, y=torch.Tensor(y_val).view(-1,1),
                                                                                       task_id=task_id, sigmoid=True) # 10**-learning_rate     
        classification_loss = 0
        task_counter = 0
        for i in range(task_num):
            batch_task_sample = len(batch_df[batch_df["task_id"].values==i])
            if batch_task_sample > 0:
                task_counter += 1
                y_mask = np.where(batch_df["task_id"].values==i, 1, 0)
                y_mask = torch.Tensor(y_mask)
                c_loss = - torch.Tensor(y_val) * torch.log(mol_prediction[:,i]+1e-9) - \
                                (1-torch.Tensor(y_val)) * torch.log((1-mol_prediction[:,i])+1e-9)
                classification_loss += torch.sum(c_loss * y_mask)/batch_task_sample
        classification_loss /= task_counter
#         atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
#                                       torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/1e-6,activated_features=activated_features.detach())
#         success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(smiles_list, x_atom, 
#                             bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
#                                                      refer_atom_list, refer_bond_list,topn=1)
#         reconstruction_loss, one_hot_loss, interger_loss,binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, 
#                                                                                               bond_neighbor, validity_mask, atom_list, 
#                                                                                               bond_list)
#         x_atom_test = torch.Tensor(x_atom_test)
#         x_bonds_test = torch.Tensor(x_bonds_test)
#         x_bond_index_test = torch.cuda.LongTensor(x_bond_index_test)
        
#         bond_neighbor_test = [x_bonds_test[i][x_bond_index_test[i]] for i in range(len(batch_test))]
#         bond_neighbor_test = torch.stack(bond_neighbor_test, dim=0)
        activated_features_test, mol_feature_test = model(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
                                                          torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),
                                                          torch.Tensor(x_mask_test),output_activated_features=True)
#         mol_feature_test = torch.div(mol_feature_test, torch.norm(mol_feature_test, dim=-1, keepdim=True)+1e-9)
#         activated_features_test = torch.div(activated_features_test, torch.norm(activated_features_test, dim=-1, keepdim=True)+1e-9)
        eps_test, d_test, test_vat_loss, mol_prediction_test = perturb_feature(mol_feature_test, amodel, 
                                                                                    alpha=1, lamda=10**-learning_rate, task_id=task_id_test, sigmoid=True)
#         atom_list_test, bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),torch.cuda.LongTensor(x_atom_index_test),
#                                                 torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                 mol_feature=mol_feature_test+d_test/1e-6,activated_features=activated_features_test.detach())
#         refer_atom_list_test, refer_bond_list_test = gmodel(torch.Tensor(x_atom_test),torch.Tensor(x_bonds_test),
#                                                             torch.cuda.LongTensor(x_atom_index_test),torch.cuda.LongTensor(x_bond_index_test),torch.Tensor(x_mask_test),
#                                                             mol_feature=mol_feature_test,activated_features=activated_features_test.detach())
#         success_smiles_batch_test, modified_smiles_test, success_batch_test, total_batch_test, reconstruction_test, validity_test, validity_mask_test = modify_atoms(smiles_list_test, x_atom_test, 
#                             bond_neighbor_test, atom_list_test, bond_list_test,smiles_list_test,smiles_to_rdkit_list_test,
#                                                      refer_atom_list_test, refer_bond_list_test,topn=1)
#         test_reconstruction_loss, test_one_hot_loss, test_interger_loss,test_binary_loss = generate_loss_function(atom_list_test, x_atom_test, bond_list_test, bond_neighbor_test, validity_mask_test, atom_list_test, bond_list_test)
            
        if vat_loss>1 or test_vat_loss>1:
            vat_loss = 1*(vat_loss/(vat_loss+1e-6).item())
            test_vat_loss = 1*(test_vat_loss/(test_vat_loss+1e-6).item())
        
        max_lr = 1e-3
        adapt_lr = conv_lr - conv_lr**2 + 0.06 * punish_lr # 0.06
        if adapt_lr < max_lr and adapt_lr >= 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = adapt_lr.detach()
                AFSE_lr = adapt_lr    
        elif adapt_lr < 0:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = 0
                AFSE_lr = 0
        elif adapt_lr >= max_lr:
            for param_group in optimizer_AFSE.param_groups:
                param_group["lr"] = max_lr
                AFSE_lr = max_lr
#         AFSE_lr = 1e-4

        logger.add_scalar('loss/classification', classification_loss, global_step)
        logger.add_scalar('loss/AFSE', vat_loss, global_step)
        logger.add_scalar('loss/AFSE_test', test_vat_loss, global_step)
#         logger.add_scalar('loss/GRN', reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_test', test_reconstruction_loss, global_step)
#         logger.add_scalar('loss/GRN_one_hot', one_hot_loss, global_step)
#         logger.add_scalar('loss/GRN_interger', interger_loss, global_step)
#         logger.add_scalar('loss/GRN_binary', binary_loss, global_step)
        logger.add_scalar('lr/conv_lr', conv_lr, global_step)
        logger.add_scalar('lr/punish_lr', punish_lr, global_step)
        logger.add_scalar('lr/AFSE_lr', AFSE_lr, global_step)
        
        optimizer.zero_grad()
        optimizer_AFSE.zero_grad()
        optimizer_GRN.zero_grad()
        loss =  classification_loss + 0.08 * (vat_loss + test_vat_loss) # + 0.3 * (reconstruction_loss + test_reconstruction_loss)
        loss.backward()
        optimizer.step()
        optimizer_AFSE.step()
        optimizer_GRN.step()

        
def clear_atom_map(mol):
    [a.ClearProp('molAtomMapNumber') for a  in mol.GetAtoms()]
    return mol

def mol_with_atom_index( mol ):
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
        
def modify_atoms(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list,refer_atom_list, refer_bond_list,topn=1,viz=False):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    refer_atom_list = refer_atom_list.cpu().detach().numpy()
    refer_bond_list = refer_bond_list.cpu().detach().numpy()
    atom_symbol_sorted = np.argsort(x_atom[:,:,:16], axis=-1)
    atom_symbol_generated_sorted = np.argsort(atom_list[:,:,:16], axis=-1)
    generate_confidence_sorted = np.sort(atom_list[:,:,:16], axis=-1)
    modified_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    confidence_threshold = 0.001
    validity_mask = np.zeros_like(atom_list[:,:,:16])
    symbol_list = ['B','C','N','O','F','Si','P','S','Cl','As','Se','Br','Te','I','At','other']
    symbol_to_rdkit = [4,6,7,8,9,14,15,16,17,33,34,35,52,53,85,0]
    for i in range(len(atom_list)):
        rank = 0
        top_idx = 0
        flag = 0
        first_run_flag = True
        l = (x_atom[i].sum(-1)!=0).sum(-1)
        cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
        mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
        counter = 0
        for j in range(l): 
            if mol.GetAtomWithIdx(int(smiles_to_rdkit_list[cano_smiles][j])).GetAtomicNum() == \
                symbol_to_rdkit[refer_atom_list[i,j,:16].argmax(-1)]:
                counter += 1
#             print(f'atom#{smiles_to_rdkit_list[cano_smiles][j]}(f):',{symbol_list[k]: np.around(refer_atom_list[i,j,k],3) for k in range(16)},
#                   f'\natom#{smiles_to_rdkit_list[cano_smiles][j]}(f+d):',{symbol_list[k]: np.around(atom_list[i,j,k],3) for k in range(16)},
#                  '\n------------------------------------------------------------------------------------------------------------')
#         print('预测为每个原子的平均概率：\n',np.around(atom_list[i,:l,:16].mean(1),2))
#         print('预测为每个原子的最大概率：\n',np.around(atom_list[i,:l,:16].max(1),2))
        if counter == l:
            success_reconstruction += 1
        while not flag==topn:
            if rank == 16:
                rank = 0
                top_idx += 1
            if top_idx == l:
#                 print('没有满足条件的分子生成。')
                flag += 1
                continue
#             if np.sum((atom_symbol_sorted[i,:l,-1]!=atom_symbol_generated_sorted[i,:l,-1-rank]).astype(int))==0:
#                 print(f'根据预测的第{rank}大概率的原子构成的分子与原分子一致，原子位重置为0，生成下一个元素……')
#                 rank += 1
#                 top_idx = 0
#                 generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
#                                              x_atom[i,:l,:16]).max(-1))[-1-top_idx]
#             print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            if rank == 0:
                generate_index = np.argsort((atom_list[i,:l,:16]-refer_atom_list[i,:l,:16] -\
                                             x_atom[i,:l,:16]).max(-1))[-1-top_idx]
            atom_symbol_generated = np.argsort(atom_list[i,generate_index,:16]-\
                                                    refer_atom_list[i,generate_index,:16] -\
                                                    x_atom[i,generate_index,:16])[-1-rank]
            if atom_symbol_generated==x_atom[i,generate_index,:16].argmax(-1):
#                 print('生成了相同元素，生成下一个元素……')
                rank += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            if np.sort(atom_list[i,generate_index,:16]-\
                refer_atom_list[i,generate_index,:16] -\
                x_atom[i,generate_index,:16])[-1-rank]<confidence_threshold:
#                 print(f'原子位{generate_rdkit_index}生成{symbol_list[atom_symbol_generated]}元素的置信度小于{confidence_threshold}，寻找下一个原子位……')
                top_idx += 1
                rank = 0
                continue
#             if symbol_to_rdkit[atom_symbol_generated]==6:
#                 print('生成了不推荐的C元素')
#                 rank += 1
#                 continue
            mol.GetAtomWithIdx(int(generate_rdkit_index)).SetAtomicNum(symbol_to_rdkit[atom_symbol_generated])
            print_mol = mol
            try:
                Chem.SanitizeMol(mol)
                if first_run_flag == True:
                    success_validity += 1
                total[flag] += 1
                if Chem.MolToSmiles(clear_atom_map(print_mol))==y_smiles[i]:
                    success[flag] +=1
#                     print('Congratulations!', success, total)
                    success_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
#                 print("修改前的分子：", smiles[i])
#                 display(mol_init)
                modified_smiles.append(Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 print(f"将第{generate_rdkit_index}个原子修改为{symbol_list[atom_symbol_generated]}的分子：", Chem.MolToSmiles(clear_atom_map(print_mol)))
#                 display(mol_with_atom_index(mol))
                mol_y = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
#                 print("高活性分子：", y_smiles[i])
#                 display(mol_y)
                rank += 1
                flag += 1
            except:
#                 print(f"第{generate_rdkit_index}个原子符号修改为{symbol_list[atom_symbol_generated]}不符合规范，生成下一个元素……")
                validity_mask[i,generate_index,atom_symbol_generated] = 1
                rank += 1
                first_run_flag = False
    return success_smiles, modified_smiles, success, total, success_reconstruction, success_validity, validity_mask

def modify_bonds(smiles, x_atom, bond_neighbor, atom_list, bond_list, y_smiles, smiles_to_rdkit_list):
    x_atom = x_atom.cpu().detach().numpy()
    bond_neighbor = bond_neighbor.cpu().detach().numpy()
    atom_list = atom_list.cpu().detach().numpy()
    bond_list = bond_list.cpu().detach().numpy()
    modified_smiles = []
    for i in range(len(bond_neighbor)):
        l = (bond_neighbor[i].sum(-1).sum(-1)!=0).sum(-1)
        bond_type_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        bond_type_generated_sorted = np.argsort(bond_list[i,:l,:,:4], axis=-1)
        generate_confidence_sorted = np.sort(bond_list[i,:l,:,:4], axis=-1)
        rank = 0
        top_idx = 0
        flag = 0
        while not flag==3:
            cano_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles[i]))
            if np.sum((bond_type_sorted[i,:,-1]!=bond_type_generated_sorted[:,:,-1-rank]).astype(int))==0:
                rank += 1
                top_idx = 0
            print('i:',i,'top_idx:', top_idx, 'rank:',rank)
            bond_type = bond_type_sorted[i,:,-1]
            bond_type_generated = bond_type_generated_sorted[:,:,-1-rank]
            generate_confidence = generate_confidence_sorted[:,:,-1-rank]
#             print(np.sort(generate_confidence + \
#                                     (atom_symbol!=atom_symbol_generated).astype(int), axis=-1))
            generate_index = np.argsort(generate_confidence + 
                                (bond_type!=bond_type_generated).astype(int), axis=-1)[-1-top_idx]
            bond_type_generated_one = bond_type_generated[generate_index]
            mol = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
            if generate_index >= len(smiles_to_rdkit_list[cano_smiles]):
                top_idx += 1
                continue
            generate_rdkit_index = smiles_to_rdkit_list[cano_smiles][generate_index]
            mol.GetBondWithIdx(int(generate_rdkit_index)).SetBondType(bond_type_generated_one)
            try:
                Chem.SanitizeMol(mol)
                mol_init = mol_with_atom_index(Chem.MolFromSmiles(smiles[i]))
                print("修改前的分子：")
                display(mol_init)
                modified_smiles.append(mol)
                print(f"将第{generate_rdkit_index}个键修改为{atom_symbol_generated}的分子：")
                display(mol)
                mol = mol_with_atom_index(Chem.MolFromSmiles(y_smiles[i]))
                print("高活性分子：")
                display(mol)
                rank += 1
                flag += 1
            except:
                print(f"第{generate_rdkit_index}个原子符号修改为{atom_symbol_generated}不符合规范")
                top_idx += 1
    return modified_smiles
        
def eval(model, amodel, gmodel, dataset, cth_list=[0.5 for i in range(task_num)], topn=1, output_feature=False, 
         generate=False, modify_atom=True,return_GRN_loss=False, viz=False, output_cth=False):
    model.eval()
    amodel.eval()
    gmodel.eval()
    predict_list = []
    test_MSE_list = []
    test_MAE_list = []
    r2_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    feature_list = []
    d_list = []
    success = [0 for i in range(topn)]
    total = [0 for i in range(topn)]
    generated_smiles = []
    success_smiles = []
    success_reconstruction = 0
    success_validity = 0
    reconstruction_loss, one_hot_loss, interger_loss, binary_loss = [0,0,0,0]
    
# #     取dataset中排序后的第k个
#     sorted_dataset = dataset.sort_values(by=tasks[0],ascending=False)
#     k_df = sorted_dataset.iloc[[k-1]]
#     k_smiles = k_df['cano_smiles'].values
#     k_value = k_df[tasks[0]].values.astype(float)    
    
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
#     print(batch_list)
    for counter, batch in enumerate(batch_list):
#         print(type(batch))
        batch_df = dataset.loc[batch,:]
        smiles_list = batch_df.cano_smiles.values
        matched_smiles_list = smiles_list
#         print(batch_df)
        y_val = batch_df[tasks[0]].values.astype(int)
        task_id = batch_df['task_id'].values
#         print(type(y_val))
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(matched_smiles_list,feature_dicts)
#         x_atom = torch.Tensor(x_atom)
#         x_bonds = torch.Tensor(x_bonds)
#         x_bond_index = torch.cuda.LongTensor(x_bond_index)
#         bond_neighbor = [x_bonds[i][x_bond_index[i]] for i in range(len(batch_df))]
#         bond_neighbor = torch.stack(bond_neighbor, dim=0)
        
        lamda=10**-learning_rate
        activated_features, mol_feature = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),output_activated_features=True)
#         mol_feature = torch.div(mol_feature, torch.norm(mol_feature, dim=-1, keepdim=True)+1e-9)
#         activated_features = torch.div(activated_features, torch.norm(activated_features, dim=-1, keepdim=True)+1e-9)
        eps_adv, d_adv, vat_loss, mol_prediction = perturb_feature(mol_feature, amodel, alpha=1, lamda=lamda, task_id=task_id, sigmoid=True)
        mol_prediction = mol_prediction.cpu().detach().numpy()
        mol_prediction_readout = mol_prediction[:,0]
        for i in range(task_num):
            mol_prediction_readout[batch_df['task_id'].values==i] = mol_prediction[batch_df['task_id'].values==i, i]
#         print(mol_feature,d_adv)
#         atom_list, bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),
#                                       torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),
#                                       torch.Tensor(x_mask),mol_feature=mol_feature+d_adv/(1e-6),activated_features=activated_features)
#         refer_atom_list, refer_bond_list = gmodel(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask),mol_feature=mol_feature,activated_features=activated_features)
#         if generate:
#             if modify_atom:
#                 success_smiles_batch, modified_smiles, success_batch, total_batch, reconstruction, validity, validity_mask = modify_atoms(matched_smiles_list, x_atom, 
#                             bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list,
#                                                      refer_atom_list, refer_bond_list,topn=topn,viz=viz)
#             else:
#                 modified_smiles = modify_bonds(matched_smiles_list, x_atom, bond_neighbor, atom_list, bond_list,smiles_list,smiles_to_rdkit_list)
#             generated_smiles.extend(modified_smiles)
#             success_smiles.extend(success_smiles_batch)
#             for n in range(topn):
#                 success[n] += success_batch[n]
#                 total[n] += total_batch[n]
#                 print('congratulations:',success,total)
#             success_reconstruction += reconstruction
#             success_validity += validity
#             reconstruction_loss, one_hot_loss, interger_loss, binary_loss = generate_loss_function(refer_atom_list, x_atom, refer_bond_list, bond_neighbor, validity_mask, atom_list, bond_list)
        d = d_adv.cpu().detach().numpy().tolist()
        d_list.extend(d)
        mol_feature_output = mol_feature.cpu().detach().numpy().tolist()
        feature_list.extend(mol_feature_output)
#         MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')   
#         print(type(mol_prediction))
#         r2 = caculate_r2(mol_prediction, torch.Tensor(y_val).view(-1,1))
# #         r2_list.extend(r2.cpu().detach().numpy())
#         if r2!=r2:
#             r2 = torch.tensor(0)
#         r2_list.append(r2.item())
#         predict_list.extend(mol_prediction.cpu().detach().numpy())
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        predict_list.extend(mol_prediction_readout)
#         test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
    
    predict_list = np.array(predict_list)
    task_lens_list = []
    task_accumulated = 0
    for i in range(task_num):
        task_accumulated += len(dataset[dataset["task_id"]==i])
        task_lens_list.append(task_accumulated)
    task_predict_list = np.split(predict_list, task_lens_list)
    task_y_val_list = np.split(dataset[tasks[0]].values.astype(float), task_lens_list)
    
    auc_list = []
    if output_cth:
        best_cth_list = [0.5 for i in range(task_num)]
        best_sn_list = [0 for i in range(task_num)]
        best_sp_list = [0 for i in range(task_num)]
        best_acc_list = [0 for i in range(task_num)]
        best_mcc_list = [0 for i in range(task_num)]
        for cth in np.linspace(0,1,21):
            for i in range(task_num):
                task_predict = task_predict_list[i]
                task_y_val = task_y_val_list[i]
                if cth == 0:
                    auc = roc_auc_score(task_y_val, task_predict)
                    auc_list.append(auc)
                class_pred = np.where(task_predict>cth,1,0).astype(int)
                tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
                sn, sp, acc, mcc = calc(tn, fp, fn, tp)
                mean_index = (sn + sp + acc + mcc)/4
                best_index = (best_sn_list[i] + best_sp_list[i] + best_acc_list[i] + best_mcc_list[i])/4
                if mean_index > best_index:
                    best_cth_list[i] = cth
                    best_sn_list[i] = sn
                    best_sp_list[i] = sp
                    best_acc_list[i] = acc
                    best_mcc_list[i] = mcc
        sn_list = best_sn_list
        sp_list = best_sp_list
        acc_list = best_acc_list
        mcc_list = best_mcc_list
    else:
        sn_list = []
        sp_list = []
        acc_list = []
        mcc_list = []
        for i in range(task_num):
            task_predict = task_predict_list[i]
            task_y_val = task_y_val_list[i]
            auc = roc_auc_score(task_y_val, task_predict)
            class_pred = np.where(task_predict>cth_list[i],1,0).astype(int)
            tn, fp, fn, tp = confusion_matrix(task_y_val, class_pred).ravel()
            sn, sp, acc, mcc = calc(tn, fp, fn, tp)
            auc_list.append(auc)
            sn_list.append(sn)
            sp_list.append(sp)
            acc_list.append(acc)
            mcc_list.append(mcc)

#     print(r2_list)
#     if generate:
#         generated_num = len(generated_smiles)
#         eval_num = len(dataset)
#         unique = generated_num
#         novelty = generated_num
#         for i in range(generated_num):
#             for j in range(generated_num-i-1):
#                 if generated_smiles[i]==generated_smiles[i+j+1]:
#                     unique -= 1
#             for k in range(eval_num):
#                 if generated_smiles[i]==dataset['smiles'].values[k]:
#                     novelty -= 1
#         unique_rate = unique/(generated_num+1e-9)
#         novelty_rate = novelty/(generated_num+1e-9)
# #         print(f'successfully/total generated molecules =', {f'Top-{i+1}': f'{success[i]}/{total[i]}' for i in range(topn)})
#         return success_reconstruction/len(dataset), success_validity/len(dataset), unique_rate, novelty_rate, success_smiles, generated_smiles, caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),predict_list
#     if return_GRN_loss:
#         return d_list, feature_list,caculate_r2(predict_list,dataset[tasks[0]].values.astype(float).tolist()),np.array(test_MSE_list).mean(),np.array(test_MAE_list).mean(),predict_list,reconstruction_loss, one_hot_loss, interger_loss,binary_loss
    if output_feature:
        return d_list, feature_list, auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list
    if output_cth:
        return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list, best_cth_list
    return auc_list, sn_list, sp_list, acc_list, mcc_list, predict_list

epoch = 0
max_epoch = 1000
batch_size = 10
patience = 100
stopper = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_model.pth')
stopper_afse = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_amodel.pth')
stopper_generate = EarlyStopping(mode='higher', patience=patience, filename=model_file + '_gmodel.pth')

In [13]:
import datetime
from tensorboardX import SummaryWriter
now = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
if os.path.isdir(log_dir):
    for files in os.listdir(log_dir):
        os.remove(log_dir+"/"+files)
    os.rmdir(log_dir)
logger = SummaryWriter(log_dir)
print(log_dir)

log/3C_GAFSE_Multi_Tasks_Medium_run_0


In [None]:
# train_f_list=[]
# train_mse_list=[]
# train_r2_list=[]
# test_f_list=[]
# test_mse_list=[]
# test_r2_list=[]
# val_f_list=[]
# val_mse_list=[]
# val_r2_list=[]
# epoch_list=[]
# train_predict_list=[]
# test_predict_list=[]
# val_predict_list=[]
# train_y_list=[]
# test_y_list=[]
# val_y_list=[]
# train_d_list=[]
# test_d_list=[]
# val_d_list=[]

epoch = patience - patience
# stopper.load_checkpoint(model)
# stopper_afse.load_checkpoint(amodel)
# stopper_generate.load_checkpoint(gmodel)
optimizer_list = [optimizer, optimizer_AFSE, optimizer_GRN]
max_epoch = 1000
while epoch < max_epoch:
    train(model, amodel, gmodel, train_df, test_df, optimizer_list, loss_function, epoch)
    train_auc, train_sn, train_sp, train_acc, train_mcc, train_predict = eval(model, amodel, gmodel, train_df)
    val_auc, val_sn, val_sp, val_acc, val_mcc, val_predict, val_cth_list = eval(model, amodel, gmodel, val_df, output_cth= True)
    test_auc, test_sn, test_sp, test_acc, test_mcc, test_predict = eval(model, amodel, gmodel, test_df, cth_list=val_cth_list)
    
    epoch = epoch + 1
    global_step = epoch * int(np.max([len(train_df),len(test_df)])/batch_size)
    for i in range(task_num):
        logger.add_scalar(f'train/T{i}_auc', train_auc[i], global_step)
        logger.add_scalar(f'train/T{i}_sn', train_sn[i], global_step)
        logger.add_scalar(f'train/T{i}_sp', train_sp[i], global_step)
        logger.add_scalar(f'train/T{i}_acc', train_acc[i], global_step)
        logger.add_scalar(f'train/T{i}_mcc', train_mcc[i], global_step)
        logger.add_scalar(f'val/T{i}_cth', val_cth_list[i], global_step)
        logger.add_scalar(f'val/T{i}_auc', val_auc[i], global_step)
        logger.add_scalar(f'val/T{i}_sn', val_sn[i], global_step)
        logger.add_scalar(f'val/T{i}_sp', val_sp[i], global_step)
        logger.add_scalar(f'val/T{i}_acc', val_acc[i], global_step)
        logger.add_scalar(f'val/T{i}_mcc', val_mcc[i], global_step)
        logger.add_scalar(f'test/T{i}_auc', test_auc[i], global_step)
        logger.add_scalar(f'test/T{i}_sn', test_sn[i], global_step)
        logger.add_scalar(f'test/T{i}_sp', test_sp[i], global_step)
        logger.add_scalar(f'test/T{i}_acc', test_acc[i], global_step)
        logger.add_scalar(f'test/T{i}_mcc', test_mcc[i], global_step)
#         logger.add_scalar(f'val/GRN', reconstruction_loss, global_step)
#         logger.add_scalar(f'val/GRN_one_hot', one_hot_loss, global_step)
#         logger.add_scalar(f'val/GRN_interger', interger_loss, global_step)
#         logger.add_scalar(f'val/GRN_binary', binary_loss, global_step)
        # logger.add_scalar('test/EF0.01', topk_list[0], global_step)
        # logger.add_scalar('test/EF0.03', topk_list[1], global_step)
        # logger.add_scalar('test/EF0.1', topk_list[2], global_step)
        # logger.add_scalar('test/EF10', topk_list[3], global_step)
        # logger.add_scalar('test/EF30', topk_list[4], global_step)
        # logger.add_scalar('test/EF100', topk_list[5], global_step)

    #     train_mse_list.append(train_MSE**0.5)
    #     train_r2_list.append(train_r2)
    #     val_mse_list.append(val_MSE**0.5)  
    #     val_r2_list.append(val_r2)
    #     train_f_list.append(train_f)
    #     val_f_list.append(val_f)
    #     test_f_list.append(test_f)
    #     epoch_list.append(epoch)
    #     train_predict_list.append(train_predict.flatten())
    #     test_predict_list.append(test_predict.flatten())
    #     val_predict_list.append(val_predict.flatten())
    #     train_y_list.append(train_df[tasks[0]].values)
    #     val_y_list.append(val_df[tasks[0]].values)
    #     test_y_list.append(test_df[tasks[0]].values)
    #     train_d_list.append(train_d)
    #     val_d_list.append(val_d)
    #     test_d_list.append(test_d)
    #     print('epoch {:d}/{:d}, validation {} {:.4f}, {} {:.4f},best validation {r2} {:.4f}'.format(epoch, total_epoch, 'r2', val_r2, 'mse:',val_MSE, stopper.best_score))
        print('Epoch:',epoch, 'Task:', i+1,
              'auc:%.3f'%train_auc[i],'%.3f'%val_auc[i],'%.3f'%test_auc[i], 
              'sn:%.3f'%train_sn[i],'%.3f'%val_sn[i],'%.3f'%test_sn[i], 
              'sp:%.3f'%train_sp[i], '%.3f'%val_sp[i], '%.3f'%test_sp[i], 
              'acc:%.3f'%train_acc[i], '%.3f'%val_acc[i], '%.3f'%test_acc[i], 
              'mcc:%.3f'%train_mcc[i],'%.3f'%val_mcc[i],'%.3f'%test_mcc[i])
    
    stop_index = np.mean(val_auc) +  np.mean(val_sn) +  np.mean(val_sp) +  np.mean(val_acc) +  np.mean(val_mcc)
    early_stop = stopper.step(stop_index, model)
    early_stop = stopper_afse.step(stop_index, amodel, if_print=False)
    early_stop = stopper_generate.step(stop_index, gmodel, if_print=False)
    
    if early_stop:
        continue


Epoch: 1 Task: 1 auc:0.878 0.884 0.892 sn:0.953 0.902 0.894 sp:0.513 0.758 0.744 acc:0.776 0.843 0.833 mcc:0.540 0.672 0.651
Epoch: 1 Task: 2 auc:0.785 0.791 0.771 sn:0.695 0.848 0.794 sp:0.713 0.595 0.595 acc:0.703 0.741 0.710 mcc:0.403 0.463 0.398
Epoch: 1 Task: 3 auc:0.970 0.975 0.973 sn:0.449 0.909 0.955 sp:0.989 0.937 0.915 acc:0.781 0.926 0.930 mcc:0.555 0.844 0.858
Epoch: 1 Task: 4 auc:0.788 0.778 0.724 sn:0.000 0.731 0.440 sp:1.000 0.802 0.774 acc:0.956 0.799 0.760 mcc:0.000 0.264 0.102
Epoch: 1 Task: 5 auc:0.712 0.688 0.669 sn:0.009 0.345 0.310 sp:0.996 0.892 0.855 acc:0.845 0.808 0.771 mcc:0.022 0.242 0.159
Epoch: 1 Task: 6 auc:0.784 0.806 0.786 sn:0.007 0.534 0.484 sp:0.998 0.886 0.844 acc:0.848 0.833 0.789 mcc:0.029 0.392 0.292
Epoch: 2 Task: 1 auc:0.890 0.886 0.900 sn:0.853 0.856 0.894 sp:0.767 0.824 0.722 acc:0.819 0.843 0.824 mcc:0.622 0.677 0.632
Epoch: 2 Task: 2 auc:0.803 0.797 0.777 sn:0.947 0.848 0.806 sp:0.500 0.636 0.645 acc:0.757 0.759 0.738 mcc:0.515 0.500 0.458


Epoch: 12 Task: 1 auc:0.924 0.886 0.902 sn:0.833 0.902 0.932 sp:0.870 0.769 0.756 acc:0.848 0.848 0.860 mcc:0.693 0.682 0.709
Epoch: 12 Task: 2 auc:0.904 0.883 0.889 sn:0.712 0.861 0.867 sp:0.901 0.736 0.719 acc:0.792 0.808 0.804 mcc:0.609 0.603 0.596
Epoch: 12 Task: 3 auc:0.990 0.983 0.977 sn:0.962 0.989 0.933 sp:0.961 0.965 0.957 acc:0.961 0.974 0.948 mcc:0.919 0.946 0.890
Epoch: 12 Task: 4 auc:0.876 0.824 0.806 sn:0.083 0.577 0.400 sp:1.000 0.922 0.909 acc:0.960 0.906 0.888 mcc:0.273 0.341 0.205
Epoch: 12 Task: 5 auc:0.828 0.744 0.827 sn:0.268 0.471 0.552 sp:0.975 0.867 0.883 acc:0.867 0.806 0.832 mcc:0.364 0.314 0.405
Epoch: 12 Task: 6 auc:0.926 0.916 0.911 sn:0.496 0.830 0.802 sp:0.981 0.855 0.862 acc:0.908 0.852 0.853 mcc:0.593 0.567 0.561
EarlyStopping counter: 2 out of 100
Epoch: 13 Task: 1 auc:0.944 0.910 0.918 sn:0.945 0.902 0.947 sp:0.724 0.813 0.744 acc:0.855 0.865 0.865 mcc:0.700 0.720 0.720
Epoch: 13 Task: 2 auc:0.918 0.885 0.876 sn:0.904 0.879 0.855 sp:0.749 0.752 0.686 

Epoch: 23 Task: 1 auc:0.971 0.912 0.918 sn:0.937 0.909 0.909 sp:0.856 0.835 0.789 acc:0.904 0.879 0.860 mcc:0.800 0.748 0.708
Epoch: 23 Task: 2 auc:0.955 0.904 0.904 sn:0.912 0.855 0.836 sp:0.855 0.843 0.851 acc:0.888 0.850 0.843 mcc:0.770 0.694 0.682
Epoch: 23 Task: 3 auc:0.998 0.980 0.973 sn:0.986 0.977 0.966 sp:0.969 0.951 0.936 acc:0.976 0.961 0.948 mcc:0.949 0.919 0.893
Epoch: 23 Task: 4 auc:0.935 0.860 0.839 sn:0.249 0.538 0.600 sp:0.999 0.909 0.888 acc:0.966 0.893 0.876 mcc:0.462 0.293 0.290
Epoch: 23 Task: 5 auc:0.892 0.836 0.851 sn:0.454 0.724 0.759 sp:0.970 0.806 0.790 acc:0.892 0.794 0.785 mcc:0.522 0.428 0.435
Epoch: 23 Task: 6 auc:0.958 0.922 0.926 sn:0.700 0.773 0.758 sp:0.970 0.918 0.902 acc:0.929 0.896 0.880 mcc:0.711 0.634 0.596
EarlyStopping counter: 2 out of 100
Epoch: 24 Task: 1 auc:0.968 0.895 0.905 sn:0.829 0.848 0.826 sp:0.958 0.890 0.844 acc:0.881 0.865 0.833 mcc:0.772 0.729 0.662
Epoch: 24 Task: 2 auc:0.958 0.900 0.901 sn:0.935 0.873 0.873 sp:0.828 0.826 0.802 

Epoch: 34 Task: 1 auc:0.986 0.908 0.915 sn:0.931 0.856 0.856 sp:0.959 0.901 0.856 acc:0.943 0.874 0.856 mcc:0.883 0.748 0.705
Epoch: 34 Task: 2 auc:0.981 0.915 0.902 sn:0.961 0.897 0.885 sp:0.894 0.826 0.826 acc:0.933 0.867 0.860 mcc:0.863 0.727 0.713
Epoch: 34 Task: 3 auc:1.000 0.985 0.985 sn:0.986 0.989 0.933 sp:0.996 0.965 0.943 acc:0.992 0.974 0.939 mcc:0.984 0.946 0.872
Epoch: 34 Task: 4 auc:0.959 0.844 0.865 sn:0.346 0.462 0.280 sp:0.998 0.975 0.961 acc:0.969 0.952 0.932 mcc:0.540 0.437 0.224
Epoch: 34 Task: 5 auc:0.932 0.822 0.860 sn:0.436 0.517 0.506 sp:0.986 0.938 0.927 acc:0.902 0.873 0.862 mcc:0.564 0.484 0.450
Epoch: 34 Task: 6 auc:0.971 0.928 0.926 sn:0.808 0.875 0.791 sp:0.965 0.892 0.886 acc:0.942 0.889 0.872 mcc:0.772 0.657 0.592
Epoch: 35 Task: 1 auc:0.970 0.905 0.912 sn:0.968 0.909 0.894 sp:0.787 0.802 0.789 acc:0.895 0.865 0.851 mcc:0.783 0.720 0.690
Epoch: 35 Task: 2 auc:0.973 0.910 0.892 sn:0.935 0.909 0.885 sp:0.889 0.810 0.760 acc:0.915 0.867 0.832 mcc:0.827 0.72

Epoch: 45 Task: 1 auc:0.982 0.906 0.891 sn:0.926 0.864 0.818 sp:0.948 0.879 0.822 acc:0.935 0.870 0.820 mcc:0.867 0.736 0.633
Epoch: 45 Task: 2 auc:0.991 0.909 0.885 sn:0.951 0.933 0.921 sp:0.951 0.752 0.686 acc:0.951 0.857 0.822 mcc:0.899 0.707 0.635
Epoch: 45 Task: 3 auc:1.000 0.988 0.989 sn:0.994 1.000 0.955 sp:0.993 0.965 0.957 acc:0.993 0.978 0.957 mcc:0.986 0.955 0.909
Epoch: 45 Task: 4 auc:0.961 0.843 0.818 sn:0.400 0.423 0.200 sp:0.996 0.991 0.989 acc:0.970 0.966 0.956 mcc:0.564 0.523 0.282
Epoch: 45 Task: 5 auc:0.957 0.827 0.850 sn:0.657 0.770 0.816 sp:0.980 0.783 0.732 acc:0.931 0.781 0.745 mcc:0.712 0.435 0.414
Epoch: 45 Task: 6 auc:0.979 0.916 0.924 sn:0.755 0.795 0.780 sp:0.985 0.914 0.884 acc:0.950 0.896 0.868 mcc:0.795 0.642 0.581
EarlyStopping counter: 4 out of 100
Epoch: 46 Task: 1 auc:0.991 0.905 0.906 sn:0.977 0.871 0.902 sp:0.886 0.857 0.789 acc:0.940 0.865 0.856 mcc:0.877 0.724 0.699
Epoch: 46 Task: 2 auc:0.989 0.918 0.898 sn:0.936 0.897 0.897 sp:0.963 0.802 0.752 

Epoch: 56 Task: 1 auc:0.992 0.902 0.912 sn:0.989 0.833 0.864 sp:0.842 0.868 0.844 acc:0.929 0.848 0.856 mcc:0.856 0.693 0.704
Epoch: 56 Task: 2 auc:0.994 0.908 0.890 sn:0.973 0.855 0.848 sp:0.956 0.851 0.752 acc:0.966 0.853 0.808 mcc:0.929 0.702 0.604
Epoch: 56 Task: 3 auc:1.000 0.988 0.990 sn:0.997 0.977 0.910 sp:0.997 0.972 0.965 acc:0.997 0.974 0.943 mcc:0.994 0.945 0.880
Epoch: 56 Task: 4 auc:0.982 0.799 0.840 sn:0.561 0.308 0.120 sp:0.998 0.998 0.988 acc:0.979 0.968 0.951 mcc:0.715 0.512 0.168
Epoch: 56 Task: 5 auc:0.965 0.839 0.806 sn:0.697 0.713 0.655 sp:0.977 0.850 0.792 acc:0.935 0.829 0.771 mcc:0.732 0.477 0.361
Epoch: 56 Task: 6 auc:0.984 0.892 0.930 sn:0.825 0.795 0.780 sp:0.984 0.920 0.906 acc:0.960 0.901 0.887 mcc:0.840 0.654 0.620
EarlyStopping counter: 9 out of 100
Epoch: 57 Task: 1 auc:0.985 0.881 0.890 sn:0.960 0.826 0.826 sp:0.955 0.890 0.767 acc:0.958 0.852 0.802 mcc:0.913 0.705 0.590
Epoch: 57 Task: 2 auc:0.996 0.908 0.892 sn:0.987 0.903 0.885 sp:0.940 0.802 0.744 

Epoch: 67 Task: 1 auc:0.992 0.896 0.899 sn:0.956 0.780 0.811 sp:0.961 0.901 0.889 acc:0.958 0.830 0.842 mcc:0.913 0.670 0.688
Epoch: 67 Task: 2 auc:0.991 0.900 0.882 sn:0.949 0.939 0.927 sp:0.976 0.760 0.686 acc:0.961 0.864 0.825 mcc:0.921 0.722 0.643
Epoch: 67 Task: 3 auc:1.000 0.990 0.988 sn:0.997 1.000 0.955 sp:0.996 0.972 0.950 acc:0.997 0.983 0.952 mcc:0.993 0.964 0.900
Epoch: 67 Task: 4 auc:0.975 0.780 0.849 sn:0.727 0.538 0.440 sp:0.990 0.947 0.927 acc:0.979 0.928 0.906 mcc:0.740 0.379 0.261
Epoch: 67 Task: 5 auc:0.981 0.829 0.843 sn:0.845 0.598 0.540 sp:0.968 0.915 0.912 acc:0.949 0.866 0.855 mcc:0.806 0.499 0.448
Epoch: 67 Task: 6 auc:0.992 0.904 0.924 sn:0.938 0.739 0.736 sp:0.974 0.946 0.938 acc:0.969 0.915 0.907 mcc:0.884 0.672 0.654
EarlyStopping counter: 3 out of 100
Epoch: 68 Task: 1 auc:0.994 0.896 0.900 sn:0.984 0.841 0.833 sp:0.927 0.890 0.833 acc:0.961 0.861 0.833 mcc:0.919 0.721 0.660
Epoch: 68 Task: 2 auc:0.998 0.906 0.886 sn:0.990 0.903 0.921 sp:0.955 0.793 0.702 

Epoch: 78 Task: 1 auc:0.998 0.897 0.906 sn:0.996 0.818 0.818 sp:0.937 0.868 0.811 acc:0.972 0.839 0.815 mcc:0.943 0.677 0.623
Epoch: 78 Task: 2 auc:0.998 0.912 0.875 sn:0.989 0.879 0.861 sp:0.974 0.818 0.736 acc:0.983 0.853 0.808 mcc:0.964 0.699 0.603
Epoch: 78 Task: 3 auc:1.000 0.990 0.984 sn:0.999 1.000 0.955 sp:0.999 0.958 0.943 acc:0.999 0.974 0.948 mcc:0.998 0.947 0.892
Epoch: 78 Task: 4 auc:0.992 0.716 0.748 sn:0.732 0.385 0.200 sp:0.999 0.986 0.968 acc:0.987 0.959 0.935 mcc:0.838 0.442 0.175
Epoch: 78 Task: 5 auc:0.992 0.822 0.819 sn:0.910 0.632 0.644 sp:0.979 0.892 0.866 acc:0.968 0.852 0.832 mcc:0.878 0.483 0.450
Epoch: 78 Task: 6 auc:0.996 0.904 0.938 sn:0.907 0.830 0.879 sp:0.993 0.932 0.874 acc:0.980 0.916 0.875 mcc:0.921 0.704 0.635
EarlyStopping counter: 14 out of 100
Epoch: 79 Task: 1 auc:0.996 0.887 0.906 sn:0.983 0.788 0.833 sp:0.942 0.901 0.844 acc:0.967 0.834 0.838 mcc:0.931 0.677 0.670
Epoch: 79 Task: 2 auc:0.998 0.920 0.874 sn:0.994 0.945 0.885 sp:0.969 0.777 0.669

Epoch: 89 Task: 1 auc:0.997 0.877 0.882 sn:0.961 0.833 0.803 sp:0.985 0.868 0.767 acc:0.971 0.848 0.788 mcc:0.940 0.693 0.565
Epoch: 89 Task: 2 auc:0.998 0.912 0.863 sn:0.983 0.891 0.885 sp:0.978 0.793 0.702 acc:0.981 0.850 0.808 mcc:0.961 0.690 0.603
Epoch: 89 Task: 3 auc:1.000 0.990 0.986 sn:0.999 1.000 0.933 sp:0.999 0.972 0.957 acc:0.999 0.983 0.948 mcc:0.998 0.964 0.890
Epoch: 89 Task: 4 auc:0.995 0.778 0.788 sn:0.751 0.577 0.360 sp:0.997 0.955 0.941 acc:0.987 0.939 0.917 mcc:0.828 0.435 0.236
Epoch: 89 Task: 5 auc:0.989 0.812 0.827 sn:0.823 0.540 0.621 sp:0.988 0.915 0.906 acc:0.963 0.857 0.862 mcc:0.852 0.453 0.500
Epoch: 89 Task: 6 auc:0.976 0.876 0.897 sn:0.795 0.716 0.692 sp:0.996 0.954 0.952 acc:0.966 0.918 0.912 mcc:0.861 0.676 0.656
EarlyStopping counter: 25 out of 100
Epoch: 90 Task: 1 auc:0.998 0.895 0.889 sn:0.997 0.841 0.902 sp:0.954 0.879 0.733 acc:0.980 0.857 0.833 mcc:0.958 0.711 0.651
Epoch: 90 Task: 2 auc:0.999 0.908 0.880 sn:0.992 0.927 0.933 sp:0.979 0.744 0.579

Epoch: 100 Task: 1 auc:0.999 0.909 0.900 sn:0.981 0.841 0.856 sp:0.989 0.868 0.811 acc:0.984 0.852 0.838 mcc:0.967 0.701 0.665
Epoch: 100 Task: 2 auc:0.997 0.905 0.883 sn:0.970 0.855 0.879 sp:0.986 0.826 0.711 acc:0.977 0.843 0.808 mcc:0.953 0.679 0.603
Epoch: 100 Task: 3 auc:0.996 0.976 0.969 sn:0.839 0.864 0.865 sp:1.000 0.965 0.936 acc:0.938 0.926 0.909 mcc:0.873 0.843 0.807
Epoch: 100 Task: 4 auc:0.987 0.768 0.761 sn:0.644 0.385 0.240 sp:1.000 0.979 0.954 acc:0.984 0.952 0.923 mcc:0.793 0.393 0.172
Epoch: 100 Task: 5 auc:0.990 0.809 0.839 sn:0.867 0.609 0.575 sp:0.993 0.898 0.874 acc:0.973 0.854 0.828 mcc:0.895 0.476 0.409
Epoch: 100 Task: 6 auc:0.979 0.908 0.885 sn:0.916 0.830 0.802 sp:0.984 0.904 0.860 acc:0.973 0.892 0.851 mcc:0.896 0.647 0.558
EarlyStopping counter: 36 out of 100
Epoch: 101 Task: 1 auc:0.999 0.901 0.896 sn:0.994 0.818 0.856 sp:0.938 0.879 0.789 acc:0.972 0.843 0.829 mcc:0.942 0.687 0.645
Epoch: 101 Task: 2 auc:0.999 0.921 0.890 sn:0.989 0.903 0.897 sp:0.985 0.8

Epoch: 111 Task: 1 auc:0.999 0.909 0.899 sn:0.992 0.848 0.848 sp:0.996 0.868 0.800 acc:0.994 0.857 0.829 mcc:0.987 0.709 0.646
Epoch: 111 Task: 2 auc:1.000 0.918 0.875 sn:0.997 0.867 0.891 sp:0.978 0.851 0.736 acc:0.989 0.860 0.825 mcc:0.978 0.715 0.640
Epoch: 111 Task: 3 auc:1.000 0.991 0.983 sn:0.999 0.989 0.955 sp:0.999 0.972 0.950 acc:0.999 0.978 0.952 mcc:0.998 0.955 0.900
Epoch: 111 Task: 4 auc:0.998 0.758 0.790 sn:0.829 0.423 0.240 sp:1.000 0.984 0.968 acc:0.992 0.959 0.937 mcc:0.904 0.462 0.212
Epoch: 111 Task: 5 auc:0.997 0.824 0.855 sn:0.961 0.575 0.575 sp:0.991 0.917 0.912 acc:0.986 0.864 0.860 mcc:0.947 0.485 0.476
Epoch: 111 Task: 6 auc:0.999 0.910 0.918 sn:0.985 0.807 0.769 sp:0.992 0.940 0.902 acc:0.991 0.920 0.882 mcc:0.966 0.706 0.604
EarlyStopping counter: 47 out of 100
Epoch: 112 Task: 1 auc:0.998 0.904 0.893 sn:0.991 0.894 0.917 sp:0.982 0.791 0.733 acc:0.988 0.852 0.842 mcc:0.974 0.692 0.670
Epoch: 112 Task: 2 auc:0.998 0.907 0.868 sn:0.982 0.885 0.879 sp:0.978 0.8

Epoch: 122 Task: 1 auc:0.999 0.894 0.894 sn:0.992 0.818 0.833 sp:0.996 0.868 0.844 acc:0.994 0.839 0.838 mcc:0.987 0.677 0.670
Epoch: 122 Task: 2 auc:1.000 0.906 0.852 sn:0.992 0.903 0.897 sp:0.991 0.760 0.603 acc:0.991 0.843 0.773 mcc:0.982 0.676 0.532
Epoch: 122 Task: 3 auc:1.000 0.993 0.982 sn:0.997 0.989 0.955 sp:1.000 0.972 0.943 acc:0.999 0.978 0.948 mcc:0.998 0.955 0.892
Epoch: 122 Task: 4 auc:0.999 0.700 0.836 sn:0.946 0.462 0.240 sp:0.999 0.968 0.947 acc:0.997 0.945 0.917 mcc:0.959 0.401 0.157
Epoch: 122 Task: 5 auc:0.998 0.814 0.813 sn:0.965 0.575 0.598 sp:0.993 0.917 0.891 acc:0.989 0.864 0.846 mcc:0.957 0.485 0.455
Epoch: 122 Task: 6 auc:0.999 0.899 0.906 sn:0.983 0.784 0.769 sp:0.994 0.946 0.904 acc:0.992 0.922 0.883 mcc:0.969 0.704 0.608
EarlyStopping counter: 58 out of 100
Epoch: 123 Task: 1 auc:1.000 0.900 0.901 sn:0.986 0.773 0.803 sp:1.000 0.923 0.833 acc:0.991 0.834 0.815 mcc:0.983 0.684 0.627
Epoch: 123 Task: 2 auc:0.999 0.910 0.858 sn:0.988 0.897 0.921 sp:0.990 0.7

Epoch: 133 Task: 1 auc:0.999 0.902 0.896 sn:0.996 0.826 0.848 sp:0.983 0.868 0.767 acc:0.991 0.843 0.815 mcc:0.981 0.685 0.616
Epoch: 133 Task: 2 auc:1.000 0.908 0.851 sn:0.998 0.885 0.867 sp:0.973 0.793 0.678 acc:0.987 0.846 0.787 mcc:0.974 0.683 0.559
Epoch: 133 Task: 3 auc:1.000 0.991 0.982 sn:1.000 1.000 0.955 sp:0.997 0.958 0.936 acc:0.998 0.974 0.943 mcc:0.997 0.947 0.883
Epoch: 133 Task: 4 auc:0.999 0.711 0.730 sn:0.932 0.423 0.200 sp:0.998 0.980 0.948 acc:0.995 0.956 0.917 mcc:0.943 0.437 0.128
Epoch: 133 Task: 5 auc:0.999 0.819 0.848 sn:0.968 0.598 0.609 sp:0.995 0.944 0.895 acc:0.991 0.891 0.851 mcc:0.964 0.564 0.472
Epoch: 133 Task: 6 auc:0.998 0.894 0.902 sn:0.940 0.750 0.758 sp:0.999 0.946 0.926 acc:0.990 0.916 0.900 mcc:0.961 0.680 0.644
EarlyStopping counter: 69 out of 100
Epoch: 134 Task: 1 auc:0.999 0.903 0.889 sn:0.984 0.841 0.864 sp:0.992 0.879 0.800 acc:0.987 0.857 0.838 mcc:0.973 0.711 0.664
Epoch: 134 Task: 2 auc:0.999 0.908 0.847 sn:0.989 0.873 0.891 sp:0.987 0.8

Epoch: 144 Task: 1 auc:0.999 0.883 0.901 sn:0.990 0.765 0.848 sp:0.979 0.879 0.822 acc:0.985 0.812 0.838 mcc:0.969 0.633 0.666
Epoch: 144 Task: 2 auc:0.996 0.887 0.857 sn:0.976 0.830 0.855 sp:0.988 0.818 0.702 acc:0.981 0.825 0.790 mcc:0.962 0.645 0.567
Epoch: 144 Task: 3 auc:1.000 0.992 0.982 sn:1.000 0.989 0.921 sp:0.998 0.965 0.943 acc:0.999 0.974 0.935 mcc:0.998 0.946 0.863
Epoch: 144 Task: 4 auc:0.998 0.784 0.714 sn:0.922 0.538 0.280 sp:1.000 0.955 0.909 acc:0.996 0.937 0.883 mcc:0.953 0.408 0.128
Epoch: 144 Task: 5 auc:0.998 0.785 0.842 sn:0.938 0.414 0.483 sp:0.995 0.963 0.941 acc:0.986 0.878 0.871 mcc:0.946 0.462 0.465
Epoch: 144 Task: 6 auc:0.999 0.870 0.890 sn:0.931 0.648 0.692 sp:0.999 0.958 0.922 acc:0.988 0.911 0.887 mcc:0.954 0.637 0.587
EarlyStopping counter: 80 out of 100
Epoch: 145 Task: 1 auc:0.999 0.898 0.897 sn:0.996 0.879 0.902 sp:0.983 0.835 0.700 acc:0.991 0.861 0.820 mcc:0.981 0.713 0.622
Epoch: 145 Task: 2 auc:1.000 0.913 0.853 sn:0.994 0.891 0.891 sp:0.992 0.7

Epoch: 155 Task: 1 auc:1.000 0.899 0.906 sn:0.996 0.788 0.864 sp:0.989 0.923 0.822 acc:0.993 0.843 0.847 mcc:0.986 0.699 0.684
Epoch: 155 Task: 2 auc:0.999 0.906 0.862 sn:0.995 0.945 0.952 sp:0.983 0.719 0.587 acc:0.990 0.850 0.797 mcc:0.979 0.695 0.595
Epoch: 155 Task: 3 auc:1.000 0.992 0.974 sn:0.997 1.000 0.910 sp:1.000 0.972 0.943 acc:0.999 0.983 0.930 mcc:0.998 0.964 0.853
Epoch: 155 Task: 4 auc:1.000 0.744 0.742 sn:0.883 0.423 0.160 sp:1.000 0.984 0.959 acc:0.995 0.959 0.925 mcc:0.934 0.462 0.115
Epoch: 155 Task: 5 auc:0.999 0.826 0.825 sn:0.957 0.586 0.563 sp:0.996 0.944 0.918 acc:0.990 0.889 0.863 mcc:0.961 0.555 0.479
Epoch: 155 Task: 6 auc:0.999 0.897 0.880 sn:0.951 0.750 0.758 sp:0.997 0.950 0.904 acc:0.990 0.920 0.882 mcc:0.962 0.690 0.600
EarlyStopping counter: 91 out of 100
Epoch: 156 Task: 1 auc:1.000 0.895 0.878 sn:0.995 0.856 0.856 sp:0.987 0.846 0.711 acc:0.992 0.852 0.797 mcc:0.984 0.697 0.576
Epoch: 156 Task: 2 auc:1.000 0.895 0.861 sn:0.998 0.867 0.897 sp:0.982 0.7

Epoch: 166 Task: 1 auc:1.000 0.902 0.891 sn:0.996 0.803 0.795 sp:0.997 0.890 0.811 acc:0.997 0.839 0.802 mcc:0.993 0.682 0.599
Epoch: 166 Task: 2 auc:1.000 0.912 0.856 sn:0.997 0.885 0.909 sp:0.995 0.777 0.595 acc:0.996 0.839 0.776 mcc:0.992 0.669 0.541
Epoch: 166 Task: 3 auc:1.000 0.995 0.979 sn:1.000 1.000 0.944 sp:0.998 0.972 0.936 acc:0.999 0.983 0.939 mcc:0.998 0.964 0.873
Epoch: 166 Task: 4 auc:0.998 0.677 0.713 sn:0.941 0.500 0.320 sp:0.999 0.963 0.918 acc:0.996 0.942 0.893 mcc:0.956 0.407 0.166
Epoch: 166 Task: 5 auc:0.998 0.822 0.857 sn:0.980 0.563 0.667 sp:0.996 0.917 0.870 acc:0.994 0.862 0.839 mcc:0.975 0.475 0.474
Epoch: 166 Task: 6 auc:1.000 0.889 0.899 sn:0.975 0.670 0.659 sp:0.999 0.968 0.952 acc:0.995 0.923 0.907 mcc:0.982 0.683 0.632
EarlyStopping counter: 102 out of 100
Epoch: 167 Task: 1 auc:1.000 0.900 0.886 sn:0.990 0.780 0.811 sp:0.999 0.934 0.800 acc:0.993 0.843 0.806 mcc:0.986 0.702 0.604
Epoch: 167 Task: 2 auc:1.000 0.909 0.845 sn:0.997 0.848 0.855 sp:0.990 0.

Epoch: 177 Task: 1 auc:1.000 0.890 0.880 sn:0.994 0.788 0.856 sp:0.996 0.890 0.733 acc:0.995 0.830 0.806 mcc:0.989 0.667 0.595
Epoch: 177 Task: 2 auc:1.000 0.906 0.858 sn:0.992 0.909 0.927 sp:0.995 0.760 0.603 acc:0.993 0.846 0.790 mcc:0.986 0.683 0.574
Epoch: 177 Task: 3 auc:1.000 0.992 0.974 sn:0.996 0.989 0.910 sp:1.000 0.965 0.929 acc:0.998 0.974 0.922 mcc:0.997 0.946 0.836
Epoch: 177 Task: 4 auc:1.000 0.735 0.693 sn:0.966 0.462 0.200 sp:0.997 0.973 0.938 acc:0.995 0.951 0.906 mcc:0.945 0.427 0.110
Epoch: 177 Task: 5 auc:0.999 0.849 0.837 sn:0.958 0.690 0.632 sp:0.997 0.904 0.847 acc:0.991 0.871 0.814 mcc:0.964 0.549 0.413
Epoch: 177 Task: 6 auc:1.000 0.883 0.879 sn:0.986 0.716 0.758 sp:0.999 0.948 0.914 acc:0.997 0.913 0.890 mcc:0.988 0.661 0.619
EarlyStopping counter: 113 out of 100
Epoch: 178 Task: 1 auc:0.999 0.891 0.882 sn:0.990 0.864 0.894 sp:0.990 0.824 0.678 acc:0.990 0.848 0.806 mcc:0.980 0.686 0.593
Epoch: 178 Task: 2 auc:0.998 0.904 0.871 sn:0.985 0.836 0.867 sp:0.994 0.

Epoch: 188 Task: 1 auc:1.000 0.903 0.897 sn:0.995 0.833 0.856 sp:0.996 0.890 0.811 acc:0.995 0.857 0.838 mcc:0.991 0.713 0.665
Epoch: 188 Task: 2 auc:1.000 0.906 0.862 sn:0.984 0.861 0.891 sp:0.996 0.810 0.702 acc:0.989 0.839 0.811 mcc:0.978 0.671 0.611
Epoch: 188 Task: 3 auc:1.000 0.992 0.976 sn:0.999 0.977 0.944 sp:1.000 0.965 0.929 acc:0.999 0.970 0.935 mcc:0.999 0.936 0.865
Epoch: 188 Task: 4 auc:1.000 0.691 0.705 sn:0.961 0.423 0.200 sp:0.998 0.984 0.963 acc:0.997 0.959 0.930 mcc:0.959 0.462 0.160
Epoch: 188 Task: 5 auc:0.994 0.797 0.811 sn:0.900 0.586 0.552 sp:0.996 0.933 0.895 acc:0.981 0.880 0.842 mcc:0.927 0.530 0.426
Epoch: 188 Task: 6 auc:1.000 0.886 0.869 sn:0.978 0.739 0.714 sp:0.996 0.956 0.926 acc:0.993 0.923 0.894 mcc:0.973 0.698 0.612
EarlyStopping counter: 124 out of 100
Epoch: 189 Task: 1 auc:0.998 0.898 0.885 sn:0.985 0.856 0.886 sp:0.990 0.857 0.767 acc:0.987 0.857 0.838 mcc:0.973 0.707 0.661
Epoch: 189 Task: 2 auc:0.999 0.907 0.834 sn:0.989 0.776 0.782 sp:0.995 0.

Epoch: 199 Task: 1 auc:1.000 0.882 0.892 sn:0.997 0.811 0.894 sp:0.986 0.802 0.722 acc:0.993 0.807 0.824 mcc:0.985 0.607 0.632
Epoch: 199 Task: 2 auc:1.000 0.910 0.842 sn:0.997 0.933 0.927 sp:0.987 0.752 0.570 acc:0.993 0.857 0.776 mcc:0.985 0.707 0.546
Epoch: 199 Task: 3 auc:1.000 0.995 0.976 sn:0.999 0.989 0.944 sp:0.997 0.979 0.943 acc:0.998 0.983 0.943 mcc:0.995 0.964 0.882
Epoch: 199 Task: 4 auc:1.000 0.703 0.754 sn:0.917 0.385 0.160 sp:1.000 0.989 0.972 acc:0.996 0.963 0.937 mcc:0.953 0.472 0.146
Epoch: 199 Task: 5 auc:0.997 0.835 0.834 sn:0.964 0.632 0.678 sp:0.996 0.908 0.836 acc:0.991 0.866 0.812 mcc:0.966 0.513 0.433
Epoch: 199 Task: 6 auc:0.998 0.876 0.883 sn:0.985 0.716 0.736 sp:0.999 0.950 0.908 acc:0.997 0.915 0.882 mcc:0.987 0.666 0.591
EarlyStopping counter: 135 out of 100
Epoch: 200 Task: 1 auc:1.000 0.889 0.903 sn:0.998 0.818 0.833 sp:0.994 0.890 0.811 acc:0.997 0.848 0.824 mcc:0.993 0.697 0.639
Epoch: 200 Task: 2 auc:0.999 0.904 0.856 sn:0.997 0.867 0.873 sp:0.993 0.

Epoch: 210 Task: 1 auc:0.996 0.874 0.890 sn:0.994 0.788 0.826 sp:0.969 0.890 0.778 acc:0.984 0.830 0.806 mcc:0.967 0.667 0.601
Epoch: 210 Task: 2 auc:1.000 0.907 0.864 sn:0.996 0.903 0.945 sp:0.997 0.744 0.587 acc:0.997 0.836 0.794 mcc:0.993 0.662 0.586
Epoch: 210 Task: 3 auc:1.000 0.992 0.973 sn:0.999 1.000 0.955 sp:1.000 0.972 0.908 acc:0.999 0.983 0.926 mcc:0.999 0.964 0.850
Epoch: 210 Task: 4 auc:0.997 0.721 0.700 sn:0.951 0.385 0.240 sp:0.996 0.975 0.934 acc:0.994 0.949 0.905 mcc:0.935 0.374 0.135
Epoch: 210 Task: 5 auc:0.998 0.829 0.820 sn:0.951 0.690 0.805 sp:0.997 0.844 0.727 acc:0.990 0.820 0.739 mcc:0.959 0.451 0.402
Epoch: 210 Task: 6 auc:1.000 0.885 0.895 sn:0.993 0.773 0.791 sp:0.996 0.936 0.860 acc:0.996 0.911 0.850 mcc:0.983 0.673 0.550
EarlyStopping counter: 146 out of 100
Epoch: 211 Task: 1 auc:1.000 0.893 0.901 sn:1.000 0.833 0.909 sp:0.980 0.846 0.744 acc:0.992 0.839 0.842 mcc:0.984 0.672 0.670
Epoch: 211 Task: 2 auc:1.000 0.904 0.860 sn:0.997 0.830 0.836 sp:0.990 0.

Epoch: 222 Task: 1 auc:1.000 0.898 0.885 sn:0.993 0.795 0.803 sp:1.000 0.912 0.833 acc:0.996 0.843 0.815 mcc:0.992 0.696 0.627
Epoch: 222 Task: 2 auc:0.999 0.907 0.877 sn:0.997 0.855 0.879 sp:0.997 0.835 0.752 acc:0.997 0.846 0.825 mcc:0.994 0.687 0.640
Epoch: 222 Task: 3 auc:1.000 0.995 0.981 sn:0.999 1.000 0.944 sp:1.000 0.965 0.915 acc:0.999 0.978 0.926 mcc:0.999 0.955 0.848
Epoch: 222 Task: 4 auc:1.000 0.770 0.765 sn:0.966 0.538 0.280 sp:0.998 0.954 0.897 acc:0.997 0.935 0.871 mcc:0.962 0.402 0.114
Epoch: 222 Task: 5 auc:0.999 0.833 0.789 sn:0.974 0.494 0.437 sp:0.999 0.969 0.916 acc:0.995 0.896 0.842 mcc:0.981 0.551 0.369
Epoch: 222 Task: 6 auc:0.999 0.841 0.872 sn:0.980 0.636 0.626 sp:0.999 0.958 0.936 acc:0.996 0.910 0.889 mcc:0.985 0.628 0.568
EarlyStopping counter: 158 out of 100
Epoch: 223 Task: 1 auc:1.000 0.886 0.890 sn:0.998 0.864 0.909 sp:0.985 0.769 0.622 acc:0.993 0.825 0.793 mcc:0.985 0.636 0.566
Epoch: 223 Task: 2 auc:1.000 0.901 0.865 sn:0.996 0.830 0.879 sp:0.996 0.

Epoch: 233 Task: 1 auc:0.999 0.895 0.895 sn:0.997 0.864 0.848 sp:0.983 0.846 0.767 acc:0.991 0.857 0.815 mcc:0.982 0.705 0.616
Epoch: 233 Task: 2 auc:1.000 0.906 0.849 sn:0.998 0.909 0.909 sp:0.996 0.777 0.603 acc:0.997 0.853 0.780 mcc:0.995 0.698 0.548
Epoch: 233 Task: 3 auc:1.000 0.993 0.977 sn:0.999 0.989 0.944 sp:1.000 0.972 0.943 acc:0.999 0.978 0.943 mcc:0.999 0.955 0.882
Epoch: 233 Task: 4 auc:1.000 0.750 0.724 sn:0.951 0.385 0.160 sp:0.999 0.991 0.980 acc:0.997 0.964 0.946 mcc:0.959 0.490 0.180
Epoch: 233 Task: 5 auc:1.000 0.797 0.828 sn:0.980 0.598 0.701 sp:0.999 0.923 0.803 acc:0.996 0.873 0.787 mcc:0.985 0.516 0.408
Epoch: 233 Task: 6 auc:0.999 0.885 0.899 sn:0.997 0.739 0.791 sp:0.999 0.944 0.882 acc:0.998 0.913 0.868 mcc:0.993 0.667 0.585
EarlyStopping counter: 169 out of 100
Epoch: 234 Task: 1 auc:0.999 0.893 0.889 sn:0.996 0.811 0.848 sp:0.993 0.890 0.833 acc:0.995 0.843 0.842 mcc:0.989 0.690 0.677
Epoch: 234 Task: 2 auc:1.000 0.903 0.846 sn:0.992 0.842 0.915 sp:0.995 0.

Epoch: 244 Task: 1 auc:1.000 0.907 0.894 sn:1.000 0.871 0.886 sp:0.996 0.868 0.744 acc:0.998 0.870 0.829 mcc:0.996 0.734 0.642
Epoch: 244 Task: 2 auc:1.000 0.914 0.871 sn:0.989 0.861 0.867 sp:1.000 0.868 0.736 acc:0.993 0.864 0.811 mcc:0.987 0.724 0.611
Epoch: 244 Task: 3 auc:1.000 0.995 0.986 sn:0.997 1.000 0.955 sp:0.999 0.972 0.957 acc:0.998 0.983 0.957 mcc:0.997 0.964 0.909
Epoch: 244 Task: 4 auc:1.000 0.736 0.782 sn:0.990 0.423 0.080 sp:0.999 0.988 0.963 acc:0.999 0.963 0.925 mcc:0.987 0.490 0.044
Epoch: 244 Task: 5 auc:0.999 0.802 0.838 sn:0.975 0.540 0.713 sp:0.998 0.921 0.828 acc:0.995 0.862 0.810 mcc:0.980 0.465 0.448
Epoch: 244 Task: 6 auc:1.000 0.896 0.899 sn:0.993 0.784 0.824 sp:0.999 0.932 0.840 acc:0.998 0.910 0.838 mcc:0.991 0.672 0.545
EarlyStopping counter: 180 out of 100
Epoch: 245 Task: 1 auc:1.000 0.888 0.901 sn:0.994 0.833 0.902 sp:0.997 0.868 0.733 acc:0.995 0.848 0.833 mcc:0.991 0.693 0.651
Epoch: 245 Task: 2 auc:1.000 0.917 0.866 sn:0.995 0.921 0.933 sp:0.995 0.

Epoch: 256 Task: 1 auc:0.999 0.893 0.893 sn:0.999 0.818 0.818 sp:0.993 0.868 0.844 acc:0.997 0.839 0.829 mcc:0.993 0.677 0.654
Epoch: 256 Task: 2 auc:1.000 0.898 0.853 sn:0.998 0.921 0.921 sp:0.998 0.760 0.562 acc:0.998 0.853 0.769 mcc:0.996 0.699 0.530
Epoch: 256 Task: 3 auc:1.000 0.993 0.982 sn:0.997 1.000 0.955 sp:0.999 0.958 0.901 acc:0.998 0.974 0.922 mcc:0.997 0.947 0.842
Epoch: 256 Task: 4 auc:1.000 0.676 0.676 sn:0.951 0.462 0.240 sp:1.000 0.975 0.947 acc:0.998 0.952 0.917 mcc:0.972 0.437 0.157
Epoch: 256 Task: 5 auc:1.000 0.820 0.818 sn:0.986 0.563 0.678 sp:0.999 0.938 0.849 acc:0.997 0.880 0.823 mcc:0.989 0.521 0.451
Epoch: 256 Task: 6 auc:1.000 0.861 0.857 sn:0.992 0.705 0.703 sp:0.997 0.952 0.878 acc:0.996 0.915 0.851 mcc:0.985 0.663 0.514
EarlyStopping counter: 192 out of 100
Epoch: 257 Task: 1 auc:1.000 0.881 0.891 sn:0.991 0.826 0.864 sp:1.000 0.868 0.700 acc:0.995 0.843 0.797 mcc:0.989 0.685 0.575
Epoch: 257 Task: 2 auc:1.000 0.908 0.849 sn:0.999 0.873 0.873 sp:0.995 0.

Epoch: 267 Task: 1 auc:1.000 0.883 0.893 sn:0.998 0.795 0.818 sp:0.996 0.879 0.800 acc:0.997 0.830 0.811 mcc:0.994 0.664 0.613
Epoch: 267 Task: 2 auc:1.000 0.917 0.855 sn:0.992 0.921 0.939 sp:0.999 0.777 0.612 acc:0.995 0.860 0.801 mcc:0.990 0.713 0.598
Epoch: 267 Task: 3 auc:1.000 0.994 0.977 sn:1.000 1.000 0.955 sp:0.998 0.958 0.915 acc:0.999 0.974 0.930 mcc:0.998 0.947 0.858
Epoch: 267 Task: 4 auc:1.000 0.747 0.705 sn:0.971 0.462 0.200 sp:1.000 0.973 0.895 acc:0.999 0.951 0.866 mcc:0.985 0.427 0.062
Epoch: 267 Task: 5 auc:1.000 0.786 0.806 sn:0.977 0.517 0.586 sp:0.999 0.931 0.857 acc:0.996 0.868 0.816 mcc:0.983 0.469 0.393
Epoch: 267 Task: 6 auc:0.999 0.876 0.894 sn:0.996 0.705 0.648 sp:0.996 0.970 0.946 acc:0.996 0.930 0.900 mcc:0.985 0.713 0.608
EarlyStopping counter: 203 out of 100
Epoch: 268 Task: 1 auc:0.998 0.878 0.887 sn:0.995 0.871 0.871 sp:0.959 0.791 0.689 acc:0.981 0.839 0.797 mcc:0.960 0.665 0.574
Epoch: 268 Task: 2 auc:0.999 0.913 0.850 sn:0.991 0.818 0.830 sp:0.990 0.

Epoch: 278 Task: 1 auc:0.999 0.903 0.890 sn:1.000 0.856 0.826 sp:0.996 0.868 0.744 acc:0.998 0.861 0.793 mcc:0.996 0.717 0.570
Epoch: 278 Task: 2 auc:1.000 0.910 0.846 sn:0.998 0.933 0.952 sp:0.997 0.711 0.471 acc:0.998 0.839 0.748 mcc:0.996 0.672 0.498
Epoch: 278 Task: 3 auc:1.000 0.994 0.976 sn:1.000 1.000 0.944 sp:0.999 0.972 0.936 acc:0.999 0.983 0.939 mcc:0.999 0.964 0.873
Epoch: 278 Task: 4 auc:1.000 0.745 0.711 sn:0.980 0.385 0.120 sp:1.000 0.991 0.975 acc:0.999 0.964 0.939 mcc:0.987 0.490 0.115
Epoch: 278 Task: 5 auc:0.997 0.828 0.837 sn:0.980 0.540 0.609 sp:0.999 0.956 0.889 acc:0.996 0.892 0.846 mcc:0.985 0.551 0.461
Epoch: 278 Task: 6 auc:0.999 0.840 0.877 sn:0.997 0.670 0.659 sp:0.997 0.962 0.908 acc:0.997 0.918 0.870 mcc:0.989 0.665 0.534
EarlyStopping counter: 214 out of 100
Epoch: 279 Task: 1 auc:0.998 0.881 0.886 sn:0.979 0.818 0.841 sp:0.999 0.857 0.767 acc:0.987 0.834 0.811 mcc:0.973 0.666 0.608
Epoch: 279 Task: 2 auc:1.000 0.910 0.868 sn:0.998 0.891 0.903 sp:0.992 0.

Epoch: 289 Task: 1 auc:0.999 0.890 0.896 sn:0.985 0.795 0.856 sp:0.993 0.890 0.789 acc:0.988 0.834 0.829 mcc:0.975 0.674 0.645
Epoch: 289 Task: 2 auc:1.000 0.915 0.832 sn:0.990 0.909 0.897 sp:0.995 0.785 0.587 acc:0.992 0.857 0.766 mcc:0.984 0.705 0.518
Epoch: 289 Task: 3 auc:1.000 0.989 0.979 sn:0.993 0.977 0.955 sp:0.996 0.958 0.908 acc:0.995 0.965 0.926 mcc:0.990 0.928 0.850
Epoch: 289 Task: 4 auc:1.000 0.729 0.667 sn:0.922 0.385 0.120 sp:1.000 0.989 0.972 acc:0.996 0.963 0.935 mcc:0.953 0.472 0.104
Epoch: 289 Task: 5 auc:0.997 0.816 0.802 sn:0.971 0.575 0.632 sp:0.988 0.908 0.807 acc:0.986 0.857 0.780 mcc:0.945 0.468 0.361
Epoch: 289 Task: 6 auc:0.999 0.855 0.861 sn:0.976 0.659 0.670 sp:0.998 0.966 0.932 acc:0.995 0.920 0.892 mcc:0.980 0.668 0.592
EarlyStopping counter: 225 out of 100
Epoch: 290 Task: 1 auc:0.999 0.876 0.880 sn:0.994 0.818 0.833 sp:0.986 0.857 0.756 acc:0.991 0.834 0.802 mcc:0.981 0.666 0.589
Epoch: 290 Task: 2 auc:0.997 0.915 0.843 sn:0.991 0.891 0.879 sp:0.989 0.

Epoch: 300 Task: 1 auc:1.000 0.881 0.895 sn:0.999 0.803 0.871 sp:0.992 0.846 0.744 acc:0.996 0.821 0.820 mcc:0.992 0.640 0.623
Epoch: 300 Task: 2 auc:1.000 0.905 0.864 sn:0.998 0.921 0.909 sp:0.994 0.752 0.545 acc:0.997 0.850 0.755 mcc:0.993 0.691 0.498
Epoch: 300 Task: 3 auc:1.000 0.991 0.979 sn:0.999 0.989 0.944 sp:0.999 0.972 0.929 acc:0.999 0.978 0.935 mcc:0.998 0.955 0.865
Epoch: 300 Task: 4 auc:0.999 0.742 0.729 sn:0.951 0.423 0.160 sp:0.998 0.979 0.970 acc:0.996 0.954 0.935 mcc:0.951 0.426 0.141
Epoch: 300 Task: 5 auc:1.000 0.797 0.863 sn:0.975 0.448 0.563 sp:0.997 0.975 0.922 acc:0.994 0.894 0.867 mcc:0.976 0.533 0.488
Epoch: 300 Task: 6 auc:1.000 0.882 0.888 sn:0.989 0.727 0.692 sp:0.999 0.960 0.910 acc:0.997 0.925 0.877 mcc:0.988 0.700 0.563
EarlyStopping counter: 236 out of 100
Epoch: 301 Task: 1 auc:1.000 0.881 0.894 sn:0.996 0.795 0.833 sp:0.994 0.879 0.844 acc:0.995 0.830 0.838 mcc:0.991 0.664 0.670
Epoch: 301 Task: 2 auc:1.000 0.911 0.867 sn:0.997 0.861 0.848 sp:0.995 0.

Epoch: 311 Task: 1 auc:0.996 0.870 0.884 sn:0.980 0.848 0.879 sp:0.992 0.780 0.678 acc:0.985 0.821 0.797 mcc:0.968 0.629 0.574
Epoch: 311 Task: 2 auc:1.000 0.926 0.872 sn:0.992 0.897 0.915 sp:0.996 0.818 0.645 acc:0.994 0.864 0.801 mcc:0.988 0.720 0.592
Epoch: 311 Task: 3 auc:1.000 0.993 0.981 sn:1.000 0.989 0.955 sp:0.999 0.972 0.922 acc:0.999 0.978 0.935 mcc:0.999 0.955 0.866
Epoch: 311 Task: 4 auc:1.000 0.757 0.689 sn:0.995 0.500 0.240 sp:0.999 0.970 0.931 acc:0.999 0.949 0.901 mcc:0.985 0.439 0.130
Epoch: 311 Task: 5 auc:0.997 0.792 0.834 sn:0.971 0.563 0.655 sp:0.995 0.921 0.826 acc:0.992 0.866 0.800 mcc:0.967 0.484 0.402
Epoch: 311 Task: 6 auc:0.999 0.863 0.873 sn:0.989 0.693 0.736 sp:0.999 0.962 0.890 acc:0.997 0.922 0.867 mcc:0.989 0.682 0.559
EarlyStopping counter: 247 out of 100
Epoch: 312 Task: 1 auc:1.000 0.877 0.884 sn:0.991 0.758 0.811 sp:0.996 0.912 0.833 acc:0.993 0.821 0.820 mcc:0.986 0.659 0.635
Epoch: 312 Task: 2 auc:1.000 0.905 0.861 sn:0.997 0.921 0.927 sp:0.994 0.

Epoch: 322 Task: 1 auc:1.000 0.889 0.897 sn:0.999 0.833 0.909 sp:0.999 0.835 0.700 acc:0.999 0.834 0.824 mcc:0.998 0.662 0.632
Epoch: 322 Task: 2 auc:1.000 0.916 0.862 sn:0.999 0.867 0.897 sp:0.997 0.835 0.645 acc:0.998 0.853 0.790 mcc:0.996 0.700 0.568
Epoch: 322 Task: 3 auc:1.000 0.992 0.984 sn:1.000 0.989 0.944 sp:0.999 0.972 0.936 acc:0.999 0.978 0.939 mcc:0.999 0.955 0.873
Epoch: 322 Task: 4 auc:1.000 0.699 0.670 sn:0.956 0.385 0.120 sp:1.000 0.989 0.975 acc:0.998 0.963 0.939 mcc:0.977 0.472 0.115
Epoch: 322 Task: 5 auc:0.999 0.801 0.845 sn:0.988 0.563 0.609 sp:0.997 0.946 0.864 acc:0.996 0.887 0.824 mcc:0.984 0.542 0.420
Epoch: 322 Task: 6 auc:1.000 0.871 0.869 sn:1.000 0.716 0.769 sp:0.999 0.952 0.888 acc:0.999 0.916 0.870 mcc:0.997 0.671 0.579
EarlyStopping counter: 258 out of 100
Epoch: 323 Task: 1 auc:1.000 0.891 0.891 sn:0.998 0.826 0.879 sp:0.997 0.835 0.767 acc:0.998 0.830 0.833 mcc:0.995 0.654 0.652
Epoch: 323 Task: 2 auc:1.000 0.912 0.855 sn:0.998 0.873 0.909 sp:0.999 0.

Epoch: 333 Task: 1 auc:0.999 0.882 0.895 sn:0.998 0.841 0.894 sp:0.985 0.813 0.711 acc:0.993 0.830 0.820 mcc:0.985 0.650 0.622
Epoch: 333 Task: 2 auc:0.999 0.901 0.845 sn:0.995 0.909 0.933 sp:0.986 0.744 0.455 acc:0.991 0.839 0.731 mcc:0.982 0.669 0.455
Epoch: 333 Task: 3 auc:1.000 0.994 0.986 sn:0.999 1.000 0.944 sp:0.999 0.965 0.936 acc:0.999 0.978 0.939 mcc:0.998 0.955 0.873
Epoch: 333 Task: 4 auc:1.000 0.809 0.690 sn:0.990 0.577 0.400 sp:0.998 0.938 0.831 acc:0.998 0.922 0.813 mcc:0.977 0.379 0.122
Epoch: 333 Task: 5 auc:0.999 0.801 0.845 sn:0.983 0.517 0.529 sp:0.997 0.958 0.901 acc:0.995 0.891 0.844 mcc:0.979 0.538 0.419
Epoch: 333 Task: 6 auc:1.000 0.874 0.901 sn:0.993 0.739 0.780 sp:0.998 0.940 0.890 acc:0.997 0.910 0.873 mcc:0.988 0.658 0.591
EarlyStopping counter: 269 out of 100
Epoch: 334 Task: 1 auc:1.000 0.891 0.901 sn:0.998 0.879 0.939 sp:1.000 0.802 0.578 acc:0.999 0.848 0.793 mcc:0.998 0.683 0.572
Epoch: 334 Task: 2 auc:1.000 0.912 0.850 sn:0.997 0.873 0.909 sp:0.997 0.

Epoch: 344 Task: 1 auc:1.000 0.880 0.875 sn:0.998 0.833 0.856 sp:1.000 0.846 0.733 acc:0.999 0.839 0.806 mcc:0.998 0.672 0.595
Epoch: 344 Task: 2 auc:1.000 0.923 0.881 sn:1.000 0.885 0.921 sp:0.996 0.826 0.686 acc:0.998 0.860 0.822 mcc:0.996 0.713 0.635
Epoch: 344 Task: 3 auc:1.000 0.993 0.980 sn:0.999 0.989 0.955 sp:0.999 0.958 0.908 acc:0.999 0.970 0.926 mcc:0.998 0.937 0.850
Epoch: 344 Task: 4 auc:1.000 0.709 0.746 sn:0.985 0.423 0.320 sp:0.999 0.971 0.899 acc:0.999 0.947 0.874 mcc:0.985 0.388 0.141
Epoch: 344 Task: 5 auc:0.999 0.810 0.810 sn:0.948 0.586 0.690 sp:0.998 0.908 0.769 acc:0.990 0.859 0.757 mcc:0.962 0.477 0.361
Epoch: 344 Task: 6 auc:0.999 0.879 0.883 sn:0.985 0.705 0.780 sp:0.999 0.958 0.874 acc:0.997 0.920 0.860 mcc:0.988 0.679 0.564
EarlyStopping counter: 280 out of 100
Epoch: 345 Task: 1 auc:0.999 0.865 0.878 sn:0.992 0.803 0.848 sp:0.992 0.857 0.733 acc:0.992 0.825 0.802 mcc:0.984 0.650 0.586
Epoch: 345 Task: 2 auc:0.999 0.906 0.850 sn:0.994 0.933 0.952 sp:0.970 0.

Epoch: 356 Task: 1 auc:1.000 0.868 0.875 sn:0.994 0.788 0.917 sp:0.997 0.835 0.689 acc:0.995 0.807 0.824 mcc:0.991 0.614 0.633
Epoch: 356 Task: 2 auc:1.000 0.907 0.857 sn:0.998 0.921 0.939 sp:0.994 0.744 0.496 acc:0.996 0.846 0.752 mcc:0.992 0.684 0.500
Epoch: 356 Task: 3 auc:1.000 0.992 0.978 sn:1.000 1.000 0.966 sp:0.997 0.965 0.922 acc:0.998 0.978 0.939 mcc:0.997 0.955 0.876
Epoch: 356 Task: 4 auc:1.000 0.635 0.601 sn:0.976 0.500 0.440 sp:1.000 0.952 0.842 acc:0.999 0.932 0.825 mcc:0.985 0.369 0.151
Epoch: 356 Task: 5 auc:0.998 0.814 0.844 sn:0.986 0.540 0.690 sp:0.998 0.954 0.876 acc:0.996 0.891 0.848 mcc:0.985 0.545 0.501
Epoch: 356 Task: 6 auc:0.999 0.853 0.863 sn:0.987 0.670 0.670 sp:0.999 0.968 0.952 acc:0.997 0.923 0.909 mcc:0.988 0.683 0.640
EarlyStopping counter: 292 out of 100
Epoch: 357 Task: 1 auc:1.000 0.870 0.880 sn:0.996 0.803 0.894 sp:1.000 0.824 0.711 acc:0.998 0.812 0.820 mcc:0.995 0.619 0.622
Epoch: 357 Task: 2 auc:1.000 0.918 0.877 sn:0.986 0.909 0.915 sp:0.998 0.

Epoch: 367 Task: 1 auc:1.000 0.894 0.891 sn:1.000 0.833 0.848 sp:0.994 0.868 0.789 acc:0.998 0.848 0.824 mcc:0.995 0.693 0.636
Epoch: 367 Task: 2 auc:1.000 0.917 0.874 sn:0.997 0.891 0.952 sp:0.997 0.785 0.529 acc:0.997 0.846 0.773 mcc:0.994 0.683 0.547
Epoch: 367 Task: 3 auc:1.000 0.992 0.971 sn:0.999 1.000 0.944 sp:0.999 0.965 0.901 acc:0.999 0.978 0.917 mcc:0.998 0.955 0.832
Epoch: 367 Task: 4 auc:1.000 0.687 0.609 sn:0.971 0.423 0.200 sp:0.999 0.979 0.938 acc:0.997 0.954 0.906 mcc:0.969 0.426 0.110
Epoch: 367 Task: 5 auc:1.000 0.812 0.847 sn:0.991 0.471 0.632 sp:0.998 0.965 0.885 acc:0.997 0.889 0.846 mcc:0.989 0.518 0.471
Epoch: 367 Task: 6 auc:0.999 0.859 0.860 sn:0.986 0.705 0.681 sp:0.999 0.950 0.890 acc:0.997 0.913 0.858 mcc:0.988 0.657 0.518
EarlyStopping counter: 303 out of 100
Epoch: 368 Task: 1 auc:1.000 0.871 0.888 sn:0.999 0.811 0.841 sp:0.997 0.846 0.756 acc:0.998 0.825 0.806 mcc:0.996 0.648 0.598
Epoch: 368 Task: 2 auc:0.999 0.919 0.874 sn:0.995 0.915 0.927 sp:0.999 0.

Epoch: 378 Task: 1 auc:1.000 0.878 0.880 sn:0.999 0.894 0.962 sp:0.987 0.703 0.589 acc:0.994 0.816 0.811 mcc:0.988 0.615 0.616
Epoch: 378 Task: 2 auc:1.000 0.907 0.866 sn:1.000 0.933 0.958 sp:0.998 0.744 0.479 acc:0.999 0.853 0.755 mcc:0.998 0.700 0.515
Epoch: 378 Task: 3 auc:1.000 0.989 0.975 sn:1.000 0.989 0.944 sp:1.000 0.951 0.936 acc:1.000 0.965 0.939 mcc:1.000 0.929 0.873
Epoch: 378 Task: 4 auc:1.000 0.705 0.695 sn:0.966 0.423 0.240 sp:1.000 0.982 0.956 acc:0.998 0.957 0.925 mcc:0.977 0.449 0.177
Epoch: 378 Task: 5 auc:0.999 0.808 0.824 sn:0.980 0.506 0.690 sp:0.993 0.942 0.832 acc:0.991 0.875 0.810 mcc:0.967 0.484 0.436
Epoch: 378 Task: 6 auc:0.998 0.871 0.880 sn:0.986 0.739 0.736 sp:0.999 0.954 0.884 acc:0.997 0.922 0.861 mcc:0.988 0.692 0.548
EarlyStopping counter: 314 out of 100
Epoch: 379 Task: 1 auc:1.000 0.882 0.885 sn:1.000 0.856 0.924 sp:0.997 0.802 0.689 acc:0.999 0.834 0.829 mcc:0.998 0.657 0.643
Epoch: 379 Task: 2 auc:1.000 0.895 0.860 sn:0.997 0.897 0.933 sp:0.994 0.

Epoch: 390 Task: 1 auc:1.000 0.875 0.896 sn:0.999 0.795 0.886 sp:0.997 0.824 0.756 acc:0.998 0.807 0.833 mcc:0.996 0.611 0.651
Epoch: 390 Task: 2 auc:1.000 0.908 0.863 sn:1.000 0.903 0.909 sp:0.991 0.736 0.496 acc:0.996 0.832 0.734 mcc:0.992 0.654 0.455
Epoch: 390 Task: 3 auc:1.000 0.991 0.978 sn:0.999 0.989 0.933 sp:1.000 0.965 0.915 acc:0.999 0.974 0.922 mcc:0.999 0.946 0.838
Epoch: 390 Task: 4 auc:1.000 0.772 0.750 sn:0.995 0.462 0.240 sp:0.999 0.973 0.934 acc:0.999 0.951 0.905 mcc:0.987 0.427 0.135
Epoch: 390 Task: 5 auc:0.999 0.812 0.852 sn:0.984 0.552 0.736 sp:0.998 0.925 0.820 acc:0.996 0.868 0.807 mcc:0.983 0.484 0.454
Epoch: 390 Task: 6 auc:1.000 0.870 0.863 sn:0.999 0.705 0.714 sp:0.999 0.960 0.904 acc:0.999 0.922 0.875 mcc:0.996 0.684 0.568
EarlyStopping counter: 326 out of 100
Epoch: 391 Task: 1 auc:1.000 0.887 0.895 sn:0.999 0.841 0.864 sp:0.996 0.846 0.778 acc:0.998 0.843 0.829 mcc:0.995 0.680 0.644
Epoch: 391 Task: 2 auc:1.000 0.915 0.870 sn:0.999 0.903 0.909 sp:0.984 0.

Epoch: 401 Task: 1 auc:1.000 0.888 0.913 sn:1.000 0.864 0.917 sp:0.999 0.813 0.667 acc:0.999 0.843 0.815 mcc:0.999 0.676 0.614
Epoch: 401 Task: 2 auc:1.000 0.899 0.853 sn:0.998 0.867 0.879 sp:0.997 0.818 0.620 acc:0.997 0.846 0.769 mcc:0.995 0.685 0.523
Epoch: 401 Task: 3 auc:1.000 0.991 0.990 sn:1.000 0.989 0.955 sp:0.999 0.958 0.936 acc:0.999 0.970 0.943 mcc:0.999 0.937 0.883
Epoch: 401 Task: 4 auc:1.000 0.741 0.740 sn:0.990 0.500 0.360 sp:1.000 0.955 0.883 acc:0.999 0.935 0.861 mcc:0.990 0.381 0.147
Epoch: 401 Task: 5 auc:1.000 0.806 0.830 sn:0.996 0.540 0.644 sp:0.998 0.956 0.866 acc:0.998 0.892 0.832 mcc:0.991 0.551 0.450
Epoch: 401 Task: 6 auc:1.000 0.881 0.890 sn:0.985 0.705 0.769 sp:0.998 0.962 0.872 acc:0.996 0.923 0.856 mcc:0.984 0.690 0.553
EarlyStopping counter: 337 out of 100
Epoch: 402 Task: 1 auc:1.000 0.875 0.886 sn:0.993 0.803 0.864 sp:1.000 0.835 0.789 acc:0.996 0.816 0.833 mcc:0.992 0.630 0.654
Epoch: 402 Task: 2 auc:1.000 0.911 0.860 sn:0.989 0.903 0.915 sp:1.000 0.

Epoch: 412 Task: 1 auc:1.000 0.884 0.891 sn:0.997 0.833 0.924 sp:1.000 0.846 0.656 acc:0.998 0.839 0.815 mcc:0.996 0.672 0.615
Epoch: 412 Task: 2 auc:1.000 0.902 0.878 sn:0.998 0.842 0.861 sp:0.998 0.835 0.769 acc:0.998 0.839 0.822 mcc:0.996 0.673 0.633
Epoch: 412 Task: 3 auc:1.000 0.995 0.985 sn:1.000 1.000 0.966 sp:0.997 0.951 0.908 acc:0.998 0.970 0.930 mcc:0.997 0.938 0.860
Epoch: 412 Task: 4 auc:0.999 0.700 0.618 sn:0.966 0.423 0.280 sp:1.000 0.966 0.876 acc:0.998 0.942 0.850 mcc:0.979 0.364 0.093
Epoch: 412 Task: 5 auc:1.000 0.829 0.866 sn:0.993 0.540 0.713 sp:0.999 0.954 0.872 acc:0.998 0.891 0.848 mcc:0.994 0.545 0.511
Epoch: 412 Task: 6 auc:1.000 0.881 0.891 sn:0.999 0.784 0.868 sp:0.999 0.908 0.792 acc:0.999 0.889 0.804 mcc:0.994 0.622 0.516
EarlyStopping counter: 348 out of 100
Epoch: 413 Task: 1 auc:1.000 0.886 0.878 sn:0.998 0.818 0.833 sp:0.997 0.857 0.722 acc:0.998 0.834 0.788 mcc:0.995 0.666 0.559
Epoch: 413 Task: 2 auc:1.000 0.909 0.869 sn:0.998 0.885 0.915 sp:0.995 0.

Epoch: 423 Task: 1 auc:1.000 0.862 0.881 sn:0.999 0.803 0.879 sp:0.986 0.813 0.700 acc:0.994 0.807 0.806 mcc:0.987 0.609 0.594
Epoch: 423 Task: 2 auc:1.000 0.908 0.866 sn:0.996 0.873 0.861 sp:0.994 0.851 0.669 acc:0.995 0.864 0.780 mcc:0.990 0.722 0.544
Epoch: 423 Task: 3 auc:0.999 0.992 0.984 sn:0.993 0.989 0.989 sp:1.000 0.937 0.865 acc:0.997 0.957 0.913 mcc:0.994 0.912 0.834
Epoch: 423 Task: 4 auc:1.000 0.742 0.712 sn:0.980 0.423 0.320 sp:0.998 0.973 0.924 acc:0.998 0.949 0.898 mcc:0.972 0.396 0.175
Epoch: 423 Task: 5 auc:1.000 0.808 0.855 sn:0.987 0.517 0.655 sp:0.998 0.965 0.881 acc:0.997 0.896 0.846 mcc:0.987 0.556 0.482
Epoch: 423 Task: 6 auc:0.997 0.834 0.863 sn:0.980 0.670 0.747 sp:0.999 0.948 0.876 acc:0.996 0.906 0.856 mcc:0.986 0.627 0.543
EarlyStopping counter: 359 out of 100
Epoch: 424 Task: 1 auc:1.000 0.864 0.897 sn:0.999 0.856 0.932 sp:0.994 0.769 0.578 acc:0.997 0.821 0.788 mcc:0.994 0.628 0.560
Epoch: 424 Task: 2 auc:1.000 0.916 0.876 sn:0.999 0.915 0.958 sp:0.997 0.

Epoch: 434 Task: 1 auc:1.000 0.874 0.899 sn:0.998 0.811 0.909 sp:0.996 0.824 0.756 acc:0.997 0.816 0.847 mcc:0.994 0.627 0.680
Epoch: 434 Task: 2 auc:1.000 0.910 0.855 sn:0.995 0.867 0.903 sp:0.997 0.835 0.653 acc:0.996 0.853 0.797 mcc:0.992 0.700 0.583
Epoch: 434 Task: 3 auc:1.000 0.995 0.975 sn:1.000 1.000 0.955 sp:0.999 0.965 0.908 acc:0.999 0.978 0.926 mcc:0.999 0.955 0.850
Epoch: 434 Task: 4 auc:0.998 0.668 0.691 sn:0.980 0.385 0.080 sp:1.000 0.980 0.943 acc:0.999 0.954 0.906 mcc:0.987 0.404 0.020
Epoch: 434 Task: 5 auc:1.000 0.798 0.853 sn:0.993 0.517 0.713 sp:0.995 0.942 0.847 acc:0.994 0.877 0.826 mcc:0.978 0.494 0.474
Epoch: 434 Task: 6 auc:1.000 0.888 0.880 sn:0.999 0.727 0.769 sp:0.999 0.944 0.880 acc:0.999 0.911 0.863 mcc:0.996 0.659 0.566
EarlyStopping counter: 370 out of 100
Epoch: 435 Task: 1 auc:1.000 0.878 0.893 sn:1.000 0.833 0.917 sp:0.999 0.824 0.700 acc:0.999 0.830 0.829 mcc:0.999 0.652 0.642
Epoch: 435 Task: 2 auc:1.000 0.913 0.861 sn:0.995 0.909 0.952 sp:1.000 0.

Epoch: 445 Task: 1 auc:1.000 0.875 0.904 sn:1.000 0.826 0.871 sp:0.999 0.879 0.789 acc:0.999 0.848 0.838 mcc:0.999 0.695 0.663
Epoch: 445 Task: 2 auc:1.000 0.913 0.878 sn:0.997 0.836 0.848 sp:0.999 0.860 0.777 acc:0.998 0.846 0.818 mcc:0.996 0.690 0.627
Epoch: 445 Task: 3 auc:1.000 0.992 0.980 sn:1.000 1.000 0.888 sp:0.999 0.979 0.943 acc:0.999 0.987 0.922 mcc:0.999 0.973 0.835
Epoch: 445 Task: 4 auc:1.000 0.658 0.627 sn:0.990 0.385 0.120 sp:1.000 0.986 0.925 acc:1.000 0.959 0.891 mcc:0.995 0.442 0.034
Epoch: 445 Task: 5 auc:1.000 0.783 0.827 sn:0.993 0.540 0.724 sp:0.998 0.925 0.784 acc:0.997 0.866 0.775 mcc:0.990 0.474 0.403
Epoch: 445 Task: 6 auc:1.000 0.860 0.879 sn:0.993 0.739 0.824 sp:1.000 0.916 0.814 acc:0.999 0.889 0.816 mcc:0.996 0.605 0.511
EarlyStopping counter: 381 out of 100
Epoch: 446 Task: 1 auc:1.000 0.868 0.896 sn:0.999 0.871 0.939 sp:0.994 0.780 0.611 acc:0.997 0.834 0.806 mcc:0.994 0.655 0.600
Epoch: 446 Task: 2 auc:1.000 0.910 0.873 sn:0.999 0.897 0.915 sp:0.995 0.

Epoch: 457 Task: 1 auc:0.999 0.869 0.904 sn:1.000 0.826 0.909 sp:0.994 0.846 0.767 acc:0.998 0.834 0.851 mcc:0.995 0.664 0.689
Epoch: 457 Task: 2 auc:1.000 0.916 0.859 sn:0.998 0.903 0.909 sp:1.000 0.769 0.545 acc:0.999 0.846 0.755 mcc:0.997 0.683 0.498
Epoch: 457 Task: 3 auc:1.000 0.989 0.980 sn:1.000 1.000 0.978 sp:1.000 0.944 0.865 acc:1.000 0.965 0.909 mcc:1.000 0.930 0.823
Epoch: 457 Task: 4 auc:1.000 0.717 0.778 sn:0.990 0.385 0.200 sp:0.999 0.975 0.929 acc:0.999 0.949 0.898 mcc:0.987 0.374 0.098
Epoch: 457 Task: 5 auc:1.000 0.817 0.822 sn:0.997 0.678 0.816 sp:0.998 0.863 0.696 acc:0.998 0.834 0.715 mcc:0.992 0.470 0.381
Epoch: 457 Task: 6 auc:0.999 0.858 0.887 sn:0.992 0.716 0.758 sp:0.998 0.950 0.868 acc:0.997 0.915 0.851 mcc:0.989 0.666 0.539
EarlyStopping counter: 393 out of 100
Epoch: 458 Task: 1 auc:1.000 0.868 0.887 sn:0.998 0.818 0.871 sp:1.000 0.824 0.767 acc:0.999 0.821 0.829 mcc:0.998 0.635 0.643
Epoch: 458 Task: 2 auc:1.000 0.909 0.846 sn:0.997 0.897 0.915 sp:0.996 0.

Epoch: 468 Task: 1 auc:1.000 0.870 0.885 sn:0.999 0.803 0.902 sp:0.997 0.857 0.733 acc:0.998 0.825 0.833 mcc:0.996 0.650 0.651
Epoch: 468 Task: 2 auc:1.000 0.899 0.860 sn:0.998 0.885 0.909 sp:1.000 0.802 0.595 acc:0.999 0.850 0.776 mcc:0.997 0.691 0.541
Epoch: 468 Task: 3 auc:1.000 0.994 0.975 sn:1.000 1.000 0.899 sp:0.999 0.958 0.922 acc:0.999 0.974 0.913 mcc:0.999 0.947 0.818
Epoch: 468 Task: 4 auc:0.999 0.611 0.736 sn:0.966 0.500 0.600 sp:0.998 0.932 0.783 acc:0.996 0.913 0.776 mcc:0.955 0.316 0.183
Epoch: 468 Task: 5 auc:1.000 0.797 0.832 sn:0.990 0.598 0.690 sp:0.998 0.923 0.792 acc:0.997 0.873 0.777 mcc:0.988 0.516 0.387
Epoch: 468 Task: 6 auc:1.000 0.880 0.880 sn:0.996 0.670 0.725 sp:0.999 0.964 0.896 acc:0.998 0.920 0.870 mcc:0.993 0.671 0.561
EarlyStopping counter: 404 out of 100
Epoch: 469 Task: 1 auc:0.999 0.865 0.871 sn:0.995 0.803 0.833 sp:0.999 0.890 0.789 acc:0.997 0.839 0.815 mcc:0.993 0.682 0.619
Epoch: 469 Task: 2 auc:0.999 0.897 0.852 sn:0.988 0.842 0.873 sp:0.998 0.

Epoch: 479 Task: 1 auc:1.000 0.868 0.878 sn:0.999 0.780 0.803 sp:1.000 0.901 0.789 acc:0.999 0.830 0.797 mcc:0.999 0.670 0.586
Epoch: 479 Task: 2 auc:1.000 0.914 0.862 sn:0.998 0.885 0.915 sp:1.000 0.802 0.570 acc:0.999 0.850 0.769 mcc:0.998 0.691 0.528
Epoch: 479 Task: 3 auc:1.000 0.991 0.979 sn:0.999 0.989 0.910 sp:0.999 0.958 0.929 acc:0.999 0.970 0.922 mcc:0.998 0.937 0.836
Epoch: 479 Task: 4 auc:1.000 0.728 0.714 sn:0.990 0.346 0.240 sp:1.000 0.979 0.952 acc:0.999 0.951 0.922 mcc:0.992 0.360 0.168
Epoch: 479 Task: 5 auc:1.000 0.810 0.841 sn:0.997 0.632 0.782 sp:0.993 0.908 0.746 acc:0.994 0.866 0.752 mcc:0.977 0.513 0.404
Epoch: 479 Task: 6 auc:1.000 0.867 0.894 sn:0.999 0.705 0.714 sp:0.999 0.958 0.928 acc:0.999 0.920 0.895 mcc:0.995 0.679 0.616
EarlyStopping counter: 415 out of 100
Epoch: 480 Task: 1 auc:1.000 0.874 0.881 sn:0.995 0.795 0.833 sp:0.996 0.857 0.756 acc:0.995 0.821 0.802 mcc:0.991 0.643 0.589
Epoch: 480 Task: 2 auc:1.000 0.911 0.846 sn:0.992 0.903 0.939 sp:0.996 0.

Epoch: 490 Task: 1 auc:0.999 0.879 0.877 sn:0.994 0.803 0.886 sp:0.994 0.846 0.722 acc:0.994 0.821 0.820 mcc:0.988 0.640 0.622
Epoch: 490 Task: 2 auc:1.000 0.919 0.875 sn:0.994 0.915 0.933 sp:0.994 0.802 0.595 acc:0.994 0.867 0.790 mcc:0.988 0.727 0.575
Epoch: 490 Task: 3 auc:1.000 0.989 0.979 sn:0.999 0.989 0.955 sp:0.996 0.965 0.908 acc:0.997 0.974 0.926 mcc:0.994 0.946 0.850
Epoch: 490 Task: 4 auc:1.000 0.730 0.752 sn:0.985 0.423 0.280 sp:0.999 0.966 0.899 acc:0.998 0.942 0.872 mcc:0.980 0.364 0.116
Epoch: 490 Task: 5 auc:0.999 0.793 0.848 sn:0.975 0.609 0.713 sp:0.996 0.927 0.822 acc:0.992 0.878 0.805 mcc:0.971 0.534 0.440
Epoch: 490 Task: 6 auc:0.999 0.848 0.879 sn:0.976 0.602 0.626 sp:0.997 0.966 0.932 acc:0.993 0.911 0.885 mcc:0.974 0.626 0.559
EarlyStopping counter: 426 out of 100
Epoch: 491 Task: 1 auc:1.000 0.878 0.875 sn:0.999 0.803 0.864 sp:0.999 0.824 0.667 acc:0.999 0.812 0.784 mcc:0.998 0.619 0.545
Epoch: 491 Task: 2 auc:1.000 0.916 0.871 sn:0.997 0.903 0.939 sp:0.998 0.

Epoch: 501 Task: 1 auc:0.999 0.872 0.851 sn:0.980 0.864 0.917 sp:0.996 0.780 0.533 acc:0.986 0.830 0.761 mcc:0.972 0.646 0.500
Epoch: 501 Task: 2 auc:0.999 0.903 0.861 sn:0.998 0.927 0.970 sp:0.983 0.769 0.438 acc:0.991 0.860 0.745 mcc:0.982 0.713 0.501
Epoch: 501 Task: 3 auc:1.000 0.989 0.979 sn:1.000 0.989 0.955 sp:0.999 0.958 0.915 acc:0.999 0.970 0.930 mcc:0.999 0.937 0.858
Epoch: 501 Task: 4 auc:0.995 0.619 0.496 sn:0.888 0.346 0.080 sp:0.999 0.980 0.947 acc:0.994 0.952 0.910 mcc:0.927 0.370 0.024
Epoch: 501 Task: 5 auc:0.999 0.808 0.850 sn:0.961 0.517 0.598 sp:0.992 0.948 0.868 acc:0.987 0.882 0.826 mcc:0.951 0.510 0.417
Epoch: 501 Task: 6 auc:1.000 0.857 0.888 sn:0.962 0.716 0.780 sp:0.998 0.950 0.884 acc:0.992 0.915 0.868 mcc:0.970 0.666 0.581
EarlyStopping counter: 437 out of 100
Epoch: 502 Task: 1 auc:1.000 0.876 0.904 sn:0.999 0.848 0.939 sp:0.993 0.824 0.711 acc:0.997 0.839 0.847 mcc:0.993 0.669 0.682
Epoch: 502 Task: 2 auc:1.000 0.917 0.885 sn:0.996 0.885 0.933 sp:0.999 0.

Epoch: 512 Task: 1 auc:1.000 0.875 0.900 sn:0.995 0.841 0.902 sp:1.000 0.846 0.667 acc:0.997 0.843 0.806 mcc:0.994 0.680 0.594
Epoch: 512 Task: 2 auc:1.000 0.903 0.854 sn:0.996 0.897 0.915 sp:0.992 0.851 0.612 acc:0.994 0.878 0.787 mcc:0.988 0.749 0.564
Epoch: 512 Task: 3 auc:1.000 0.991 0.974 sn:1.000 0.989 0.955 sp:1.000 0.944 0.887 acc:1.000 0.961 0.913 mcc:1.000 0.920 0.826
Epoch: 512 Task: 4 auc:1.000 0.721 0.674 sn:0.961 0.423 0.200 sp:1.000 0.973 0.938 acc:0.998 0.949 0.906 mcc:0.979 0.396 0.110
Epoch: 512 Task: 5 auc:0.996 0.785 0.809 sn:0.970 0.586 0.644 sp:0.997 0.912 0.788 acc:0.993 0.862 0.766 mcc:0.973 0.485 0.348
Epoch: 512 Task: 6 auc:0.998 0.860 0.865 sn:0.978 0.682 0.736 sp:0.998 0.946 0.888 acc:0.995 0.906 0.865 mcc:0.978 0.631 0.555
EarlyStopping counter: 448 out of 100
Epoch: 513 Task: 1 auc:1.000 0.876 0.895 sn:0.999 0.818 0.894 sp:1.000 0.857 0.711 acc:0.999 0.834 0.820 mcc:0.999 0.666 0.622
Epoch: 513 Task: 2 auc:1.000 0.917 0.859 sn:0.998 0.915 0.952 sp:0.999 0.

Epoch: 523 Task: 1 auc:1.000 0.876 0.882 sn:0.998 0.864 0.917 sp:1.000 0.824 0.600 acc:0.999 0.848 0.788 mcc:0.998 0.686 0.557
Epoch: 523 Task: 2 auc:1.000 0.910 0.862 sn:0.999 0.903 0.927 sp:0.994 0.793 0.496 acc:0.997 0.857 0.745 mcc:0.994 0.705 0.482
Epoch: 523 Task: 3 auc:1.000 0.990 0.981 sn:1.000 0.977 0.966 sp:0.998 0.951 0.936 acc:0.999 0.961 0.948 mcc:0.998 0.919 0.893
Epoch: 523 Task: 4 auc:1.000 0.718 0.681 sn:0.980 0.423 0.160 sp:1.000 0.966 0.906 acc:0.999 0.942 0.874 mcc:0.987 0.364 0.045
Epoch: 523 Task: 5 auc:1.000 0.810 0.840 sn:0.990 0.517 0.586 sp:0.999 0.952 0.889 acc:0.998 0.885 0.842 mcc:0.991 0.521 0.442
Epoch: 523 Task: 6 auc:1.000 0.834 0.877 sn:0.982 0.625 0.604 sp:1.000 0.962 0.936 acc:0.997 0.911 0.885 mcc:0.988 0.631 0.551
EarlyStopping counter: 459 out of 100
Epoch: 524 Task: 1 auc:1.000 0.867 0.853 sn:0.992 0.879 0.924 sp:0.997 0.725 0.489 acc:0.994 0.816 0.748 mcc:0.988 0.616 0.473
Epoch: 524 Task: 2 auc:0.998 0.913 0.883 sn:0.998 0.885 0.952 sp:0.985 0.

In [None]:
stopper.load_checkpoint(model)
stopper_afse.load_checkpoint(amodel)
stopper_generate.load_checkpoint(gmodel)
    
test_auc, test_sn, test_sp, test_acc, test_mcc, test_predict = eval(model, amodel, gmodel, test_df)

In [None]:
for i in range(task_num):
    print('Epoch:',epoch, 'Task:', i+1,
      'auc:%.3f'%train_auc[i],'%.3f'%val_auc[i],'%.3f'%test_auc[i], 
      'sn:%.3f'%train_sn[i],'%.3f'%val_sn[i],'%.3f'%test_sn[i], 
      'sp:%.3f'%train_sp[i], '%.3f'%val_sp[i], '%.3f'%test_sp[i], 
      'acc:%.3f'%train_acc[i], '%.3f'%val_acc[i], '%.3f'%test_acc[i], 
      'mcc:%.3f'%train_mcc[i],'%.3f'%val_mcc[i],'%.3f'%test_mcc[i])

In [None]:
# print('target_file:',raw_filename[0])
# print('inactive_file:',test_filename)
# np.savez(result_dir, epoch_list, train_f_list, train_d_list, 
#          train_predict_list, train_y_list, val_f_list, val_d_list, val_predict_list, val_y_list, test_f_list, 
#          test_d_list, test_predict_list, test_y_list)
# sim_space = np.load(result_dir+'.npz')
# print(sim_space['arr_10'].shape)

In [None]:
# Task-specific AFSE
# Dynamic cth
# loss =  regression_loss + 0.08 * (vat_loss + test_vat_loss)
# r=3 t=2 200 100 100 1 all_lr=3e-4