In [None]:
%run Include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb
%run Medical_Utility.ipynb
%run Medical_IO.ipynb

from torch.autograd import Variable
from sklearn.metrics import roc_auc_score
from medcam import medcam

class CNN_Attention(object):
    
    def __init__(self, general, arch_list):
        """Build the sub-networks, the loss and the optimizer.

        Args:
            general: dict providing ``learning_rate``, ``beta1``, ``beta2``,
                ``loss``, ``reduction``, ``gamma``, ``alpha``, ``attention``
                and ``theta``.
            arch_list: architecture descriptions handed to
                ``Net.parse_layers``; three entries are read when
                ``attention == 2``, two otherwise.

        Which networks exist depends on ``attention``:
            * 2        -> net_s0, net_s1, net_s2, net_s3
            * 4, 5, 6  -> net_s1, net_s3
            * others   -> net_s1, net_s2, net_s3
        """
        # Read every hyper-parameter up front so a missing key fails before
        # any global state (cudnn flags) is touched.
        lr, beta1, beta2 = (general["learning_rate"], general["beta1"],
                            general["beta2"])
        loss_mode, reduction = general["loss"], general["reduction"]
        gamma, alpha = general['gamma'], general['alpha']
        attention, theta = general['attention'], general['theta']

        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num = FLAGS.gpu_num
        device_name = "cpu"
        if torch.cuda.is_available() and FLAGS.gpu_enable:
            device_name = "cuda"
        self.device = torch.device(device_name)
        self.Attention = attention

        # A non-zero theta is only accepted together with an attention mode.
        assert self.Attention != 0 or theta == 0

        def spawn(layers, *extra):
            """Instantiate one Network_template and move it to the device."""
            return Network_template(gpu_num, layers, *extra).to(self.device)

        # All networks of a configuration are constructed first and only then
        # re-initialised, in the same order, so the sequence of random draws
        # is independent of how this method is laid out.
        if self.Attention == 2:
            self.input_dims_arch1, spec_a = Net.parse_layers(arch_list[0])
            self.input_dims_arch2, spec_b = Net.parse_layers(arch_list[1])
            self.input_dims_arch3, spec_c = Net.parse_layers(arch_list[2])
            self.net_s0 = spawn(spec_a)
            self.net_s1 = spawn(spec_b, self.Attention)
            self.net_s2 = spawn(spec_b, self.Attention)
            self.net_s3 = spawn(spec_c)
            built = (self.net_s0, self.net_s1, self.net_s2, self.net_s3)
        elif self.Attention in (4, 5, 6):
            self.input_dims_arch1, spec_a = Net.parse_layers(arch_list[0])
            self.input_dims_arch2, spec_b = Net.parse_layers(arch_list[1])
            self.net_s1 = spawn(spec_a, self.Attention)
            self.net_s3 = spawn(spec_b)
            built = (self.net_s1, self.net_s3)
        else:
            self.input_dims_arch1, spec_a = Net.parse_layers(arch_list[0])
            self.input_dims_arch2, spec_b = Net.parse_layers(arch_list[1])
            self.net_s1 = spawn(spec_a, self.Attention)
            self.net_s2 = spawn(spec_a, self.Attention)
            self.net_s3 = spawn(spec_b)
            built = (self.net_s1, self.net_s2, self.net_s3)

        for net in built:
            Net.init_weights(net, "normal")

        # One flat parameter list so a single optimizer updates every network.
        network_params = []
        for net in built:
            network_params += list(net.parameters())

        self.criterion = StandardLoss(loss_mode, reduction, gamma, alpha, theta).to(self.device)
        self.optimizer = optim.Adam(network_params, lr=lr, betas=(beta1, beta2))
        
    def optimize_step(self, Dtrain, labels):
        self.net_s1.zero_grad()
        if self.Attention != 5 and self.Attention != 6:
            self.net_s2.zero_grad()
        self.net_s3.zero_grad()
        loss = self.criterion([self.net_s3, Dtrain, labels])
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    def optimize_step_attention(self, Dtrain, labels, msk1, msk2, vm1,vm2):
        if self.Attention == 2:
            self.net_s0.zero_grad()
        self.net_s1.zero_grad()
        self.net_s2.zero_grad()
        self.net_s3.zero_grad()
        loss = self.criterion([self.net_s3, Dtrain, labels, msk1, msk2, vm1, vm2])
        loss.backward()
        self.optimizer.step() 
        return loss.item()
    
    
    def optimize_step_attention_multi(self, Dtrain, labels, msk11,msk12,mask13,mask14,maks21,mask22,mask23,mask24,vm1,vm2):
        self.net_s1.zero_grad()
        self.net_s2.zero_grad()
        self.net_s3.zero_grad()
        loss = self.criterion([self.net_s3, Dtrain, labels, msk11,msk12,mask13,mask14,maks21,mask22,mask23,mask24,vm1,vm2])
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    def optimize_step_attention_onestream(self, Dtrain, labels, msk, vm):
        self.net_s1.zero_grad()
        self.net_s3.zero_grad()
        loss = self.criterion([self.net_s3, Dtrain, labels, msk, vm])
        loss.backward()
        self.optimizer.step() 
        return loss.item()
        
    def train(self, data_params, branch_name="Undefined Here"):
        """Train the networks and periodically evaluate on the held-out split.

        Three loader pairs are fetched from ``FLAGS.data_path``,
        ``FLAGS.data_path2`` and ``FLAGS.ori_path`` and iterated in lockstep.
        Every ``FLAGS.print_step`` optimisation steps the mean training loss is
        reported and the held-out loaders are evaluated.  Whenever accuracy
        improves (ties broken by a higher F1 score) the labels/predictions are
        logged and, once ``step`` exceeds ``FLAGS.save_step``, the weights of
        every active network are written to ``FLAGS.model_save``.  All messages
        are printed and appended to ``FLAGS.log_path``.

        Args:
            data_params: dict providing ``epochs``, ``batch_size``,
                ``batch_workers``, ``shuffle``, ``drop_last``, ``transform``,
                ``binary_mask``, ``datasplit_scheme``, ``test_split``,
                ``xfold``, ``fold_idx`` and ``random_seed``.
            branch_name: label written into the log header.
        """
        epochs           = data_params["epochs"]
        batch_size       = data_params["batch_size"]
        batch_workers    = data_params["batch_workers"]
        shuffle          = data_params["shuffle"]
        drop_last        = data_params["drop_last"]
        transform        = data_params['transform']
        binary_mask      = data_params['binary_mask']
        datasplit_scheme = data_params["datasplit_scheme"]
        test_split       = data_params["test_split"]
        xfold            = data_params["xfold"]
        fold_idx         = data_params["fold_idx"]
        random_seed      = data_params["random_seed"]

        fetch = (Data_fetcher.fetch_dataset_wValidation_aug if transform
                 else Data_fetcher.fetch_dataset_wValidation)

        def fetch_loaders(data_path):
            """Fetch the (train, test) loader pair for one data directory."""
            return fetch(FLAGS.dataset, data_path, batch_size, batch_workers,
                         shuffle, drop_last, 0.5, datasplit_scheme, test_split,
                         xfold, fold_idx, random_seed)

        train_loader1, test_loader1 = fetch_loaders(FLAGS.data_path)
        train_loader2, test_loader2 = fetch_loaders(FLAGS.data_path2)
        train_loader3, test_loader3 = fetch_loaders(FLAGS.ori_path)

        # (checkpoint name, module) for every network that exists in the
        # current attention mode: net_s0 only in mode 2, net_s2 in every mode
        # except 4, 5 and 6.
        nets = []
        if self.Attention == 2:
            nets.append(("net_s0", self.net_s0))
        if self.Attention not in (4, 5, 6):
            nets.append(("net_s2", self.net_s2))
        nets.append(("net_s1", self.net_s1))
        nets.append(("net_s3", self.net_s3))

        def save_checkpoint(at_step):
            """Write the weights of every active network for ``at_step``."""
            for name, net in nets:
                torch.save(net.state_dict(),
                           '%s/%s_step_%d.pth' % (FLAGS.model_save, name, at_step))

        def binarize(vol):
            """Map the minimum value to 0 and everything above it to 1.

            The threshold is taken once, so the result does not depend on
            values written while binarising.
            """
            return (vol > vol.min()).to(vol.dtype)

        # The context manager closes the log even if training raises.
        with open(FLAGS.log_path, "a") as log:
            log.write('Branch: %s  Fold ID: %d\n\n' % (branch_name, fold_idx))
            log.flush()

            step = 0
            if FLAGS.continue_model:
                for name, net in nets:
                    # map_location lets a checkpoint written on another device
                    # be restored onto the device in use.
                    net.load_state_dict(torch.load(
                        '%s/%s_step_%d.pth' % (FLAGS.model_save, name, FLAGS.model_step),
                        map_location=self.device))
                step = FLAGS.model_step + 1

            lrec = []
            best_f1 = -1.0
            best_accuracy = 0.0
            best_step = 0
            stablize_step = FLAGS.save_step
            for epoch in range(epochs):
                for i, data in enumerate(zip(train_loader1, train_loader2, train_loader3)):
                    # Evaluation below switches to eval mode, so training mode
                    # is re-enabled at the start of every iteration.
                    for _, net in nets:
                        net.train()

                    if transform:
                        vol1 = data[0][0].to(self.device)
                        labels1 = data[0][1].to(self.device)
                        vol2 = data[1][0].to(self.device)
                        vol0 = data[2][0].to(self.device)
                    else:
                        vol1 = data[0]['vol'].unsqueeze(1).to(self.device)
                        labels1 = data[0]['label'].to(self.device)
                        vol2 = data[1]['vol'].unsqueeze(1).to(self.device)
                        vol0 = data[2]['vol'].unsqueeze(1).to(self.device)

                    if binary_mask:
                        vol1 = binarize(vol1)
                        vol2 = binarize(vol2)

                    # Batches holding a single sample are skipped entirely
                    # (they do not advance the step counter either).
                    if labels1.shape[0] == 1:
                        continue

                    if self.Attention == 1:
                        out1, mask1 = self.net_s1(vol0)
                        out2, mask2 = self.net_s2(vol0)
                        out = torch.cat((out1, out2), dim=1)
                        lrec.append(self.optimize_step_attention(out, labels1, mask1, mask2, vol1, vol2))
                    elif self.Attention == 2:
                        out0 = self.net_s0(vol0)
                        out1, mask1 = self.net_s1(out0)
                        out2, mask2 = self.net_s2(out0)
                        out = torch.cat((out1, out2), dim=1)
                        lrec.append(self.optimize_step_attention(out, labels1, mask1, mask2, vol1, vol2))
                    elif self.Attention == 3:
                        out1, mask11, mask12, mask13, mask14 = self.net_s1(vol0)
                        out2, mask21, mask22, mask23, mask24 = self.net_s2(vol0)
                        out = torch.cat((out1, out2), dim=1)
                        lrec.append(self.optimize_step_attention_multi(
                            out, labels1, mask11, mask12, mask13, mask14,
                            mask21, mask22, mask23, mask24, vol1, vol2))
                    elif self.Attention == 4:
                        out, mask = self.net_s1(vol0)
                        # vol1 supervises the mask here; swap in vol2 to use
                        # the other volume instead.
                        lrec.append(self.optimize_step_attention_onestream(out, labels1, mask, vol1))
                    elif self.Attention == 5:
                        # The mask is produced but not supervised in this mode.
                        out, _ = self.net_s1(vol0)
                        lrec.append(self.optimize_step(out, labels1))
                    elif self.Attention == 6:
                        out = self.net_s1(vol0)
                        lrec.append(self.optimize_step(out, labels1))
                    else:
                        out1 = self.net_s1(vol1)
                        out2 = self.net_s2(vol2)
                        out = torch.cat((out1, out2), dim=1)
                        lrec.append(self.optimize_step(out, labels1))

                    step = step + 1

                    if step % FLAGS.print_step != 0:
                        continue

                    msg = ('[%d/%d][%d/%d] loss: %.4f Step: %d'
                      %(epoch, epochs, i, len(train_loader1), np.mean(np.asarray(lrec)), step))
                    lrec[:] = []
                    print(msg)
                    log.write(msg+"\n")
                    log.flush()

                    for _, net in nets:
                        net.eval()

                    path_test_cum = []
                    label_pred_cum = np.array((),dtype=np.int32)
                    label_test_cum = np.array((),dtype=np.int32)
                    # No gradients are needed for evaluation.
                    with torch.no_grad():
                        for data_test in zip(test_loader1, test_loader2, test_loader3):
                            if transform:
                                vol_test1 = data_test[0][0].to(self.device)
                                label_test = data_test[0][1]
                                vol_test2 = data_test[1][0].to(self.device)
                                vol_test0 = data_test[2][0].to(self.device)
                            else:
                                vol_test1  = data_test[0]['vol'].unsqueeze(1).to(self.device)
                                vol_test2  = data_test[1]['vol'].unsqueeze(1).to(self.device)
                                label_test = data_test[0]['label']
                                vol_test0 = data_test[2]['vol'].unsqueeze(1).to(self.device)
                                path_test = data_test[0]['path']

                            if binary_mask:
                                vol_test1 = binarize(vol_test1)
                                vol_test2 = binarize(vol_test2)

                            # NOTE(review): training unpacks (output, mask...)
                            # tuples from net_s1/net_s2 in attention modes 1-5,
                            # while evaluation uses the return value directly --
                            # presumably Network_template returns only the
                            # features in eval mode; confirm.
                            if self.Attention == 2:
                                out_test0 = self.net_s0(vol_test0)
                                pred_test0 = self.net_s1(out_test0)
                                pred_test1 = self.net_s2(out_test0)
                                features = torch.cat((pred_test0, pred_test1), dim=1)
                            elif self.Attention in (1, 3):
                                pred_test0 = self.net_s1(vol_test0)
                                pred_test1 = self.net_s2(vol_test0)
                                features = torch.cat((pred_test0, pred_test1), dim=1)
                            elif self.Attention in (4, 5, 6):
                                features = self.net_s1(vol_test0)
                            else:
                                pred_test0 = self.net_s1(vol_test1)
                                pred_test1 = self.net_s2(vol_test2)
                                features = torch.cat((pred_test0, pred_test1), dim=1)

                            pred_test = self.net_s3(features)
                            predicted = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu().numpy()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                            if not transform:
                                path_test_cum += path_test

                    accuracy = Utility_MEDICAL.compute_accuracy(label_test_cum,label_pred_cum)
                    balanced_accuracy = Utility_MEDICAL.binary_balanced_evaluation(label_test_cum,label_pred_cum)
                    specificity, sensitivity = Utility_MEDICAL.compute_specificity_sensitivity(label_test_cum,label_pred_cum)
                    f1_score = Utility_MEDICAL.compute_F1(label_test_cum, label_pred_cum)
                    # AUC is computed from the hard class predictions.
                    auc = roc_auc_score(label_test_cum, label_pred_cum)

                    # A new best is a strictly higher accuracy, or the same
                    # accuracy with a strictly higher F1 score.
                    improved = accuracy > best_accuracy or (
                        accuracy == best_accuracy and f1_score > best_f1)
                    if improved:
                        best_accuracy = accuracy
                        best_f1 = f1_score
                        best_step = step
                        # Checkpoints are only written after the warm-up
                        # period; every active network is saved so that the
                        # checkpoint can be reloaded in any attention mode.
                        if step > stablize_step:
                            save_checkpoint(step)

                        # Record what was predicted for the new best model.
                        if not transform:
                            path_label = ("path_label: {}".format(path_test_cum))
                            log.write(path_label+"\n")
                        test_label = 'labels: ' + str(label_test_cum)
                        prediction = 'prediction: ' + str(label_pred_cum)
                        log.write(test_label+"\n")
                        log.write(prediction+"\n")

                    msg0 = ('Test accuracy: %.4f  Balanced_accuracy: %.4f'%(accuracy, balanced_accuracy))
                    msg1 = ('Specificity: %.4f  Sensitivity: %.4f'%(specificity, sensitivity))
                    msg2 = ('AUC: %.4f\nF1 score: %.4f'%(auc, f1_score))
                    msg3 = ('Best F1 score: %.4f  At Step: %d' %(best_f1, best_step))
                    for line in (msg0, msg1, msg2, msg3):
                        print(line)
                    for line in (msg0, msg1, msg2, msg3):
                        log.write(line+"\n")
                    log.flush()

        print("Training complete.")

    def train_val(self, data_params, branch_name="Undefined Here"):
        
        epochs           = data_params["epochs"]
        batch_size       = data_params["batch_size"]
        batch_workers    = data_params["batch_workers"]
        shuffle          = data_params["shuffle"]
        drop_last        = data_params["drop_last"]
        transform        = data_params['transform']
        binary_mask      = data_params['binary_mask']
        datasplit_scheme = data_params["datasplit_scheme"]
        test_split       = data_params["test_split"]
        xfold            = data_params["xfold"]
        fold_idx         = data_params["fold_idx"]
        random_seed      = data_params["random_seed"]
    
        if transform:
            train_loader1, val_loader1 = Data_fetcher.fetch_dataset_wValidation_aug(FLAGS.dataset, FLAGS.data_path, batch_size, 
                                                                                     batch_workers, shuffle, drop_last, 0.5, 
                                                                                     datasplit_scheme, test_split, xfold, 
                                                                                     fold_idx, random_seed)
            train_loader2, val_loader2 = Data_fetcher.fetch_dataset_wValidation_aug(FLAGS.dataset, FLAGS.data_path2, batch_size, 
                                                                                     batch_workers, shuffle, drop_last, 0.5, 
                                                                                     datasplit_scheme, test_split, xfold, 
                                                                                     fold_idx, random_seed)
            train_loader3, val_loader3 = Data_fetcher.fetch_dataset_wValidation_aug(FLAGS.dataset, FLAGS.ori_path, batch_size, 
                                                                                     batch_workers, shuffle, drop_last, 0.5, 
                                                                                     datasplit_scheme, test_split, xfold, 
                                                                                     fold_idx, random_seed)
            
            test_loader1 = Data_fetcher.fetch_dataset_wValidation_aug(name =FLAGS.dataset, data_path = FLAGS.extra_path_dim1, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)
            test_loader2 = Data_fetcher.fetch_dataset_wValidation_aug(name =FLAGS.dataset, data_path = FLAGS.extra_path_dim2, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)
            test_loader3 = Data_fetcher.fetch_dataset_wValidation_aug(name =FLAGS.dataset, data_path = FLAGS.extra_path_ori, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)
            
        else:
            train_loader1, val_loader1 = Data_fetcher.fetch_dataset_wValidation(FLAGS.dataset, FLAGS.data_path, batch_size, 
                                                                                 batch_workers, shuffle, drop_last, 0.5, 
                                                                                 datasplit_scheme, test_split, xfold, 
                                                                                 fold_idx, random_seed)
            train_loader2, val_loader2 = Data_fetcher.fetch_dataset_wValidation(FLAGS.dataset, FLAGS.data_path2, batch_size, 
                                                                                 batch_workers, shuffle, drop_last, 0.5, 
                                                                                 datasplit_scheme, test_split, xfold, 
                                                                                 fold_idx, random_seed)
            train_loader3, val_loader3 = Data_fetcher.fetch_dataset_wValidation(FLAGS.dataset, FLAGS.ori_path, batch_size, 
                                                                                 batch_workers, shuffle, drop_last, 0.5, 
                                                                                 datasplit_scheme, test_split, xfold, 
                                                                                 fold_idx, random_seed)
            
            test_loader1 = Data_fetcher.fetch_dataset_wValidation(name =FLAGS.dataset, data_path = FLAGS.extra_path_dim1, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)
            test_loader2 = Data_fetcher.fetch_dataset_wValidation(name =FLAGS.dataset, data_path = FLAGS.extra_path_dim2, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)
            test_loader3 = Data_fetcher.fetch_dataset_wValidation(name =FLAGS.dataset, data_path = FLAGS.extra_path_ori, batch_size = batch_size, 
                                                                      batch_workers = batch_workers, shuffle = shuffle, drop_last=drop_last, scalor =0.5, 
                                                                      datasplit_scheme = 'All', test_split =test_split, xfold =xfold, 
                                                                      fold_idx = fold_idx, random_seed = random_seed)

        log = open(FLAGS.log_path, "a")
        log.write('Branch: %s  Fold ID: %d\n\n' % (branch_name, fold_idx))
        log.flush()
        
        step = 0
        if FLAGS.continue_model:
            if self.Attention == 2:
                self.net_s0.load_state_dict(torch.load('%s/net_s0_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
            if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
                self.net_s2.load_state_dict(torch.load('%s/net_s2_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
            self.net_s1.load_state_dict(torch.load('%s/net_s1_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
            self.net_s3.load_state_dict(torch.load('%s/net_s3_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
            step = FLAGS.model_step + 1
            
        lrec = []
        # best_f1 = -1.0
        best_accuracy = 0.0
        best_step = 0
        ###
        best_accuracy_val = 0.0
        best_accuracy_all = 0.0
        best_step_all = 0
        best_step_val = 0

        stablize_step = FLAGS.save_step
        for epoch in range(epochs):
            for i, data in enumerate(zip(train_loader1,train_loader2,train_loader3), 0):
                #################set train
                if self.Attention == 2:
                    self.net_s0.train()
                if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
                    self.net_s2.train()
                self.net_s1.train()
                self.net_s3.train()
                ####################
                ##############################################################data load
                if transform:
                    vol1 = data[0][0].to(self.device)
                    labels1 = data[0][1].to(self.device)
                    vol2 = data[1][0].to(self.device)
                    vol0 = data[2][0].to(self.device)
                else:
                    vol1 = data[0]['vol'].unsqueeze(1).to(self.device)
                    labels1 = data[0]['label'].to(self.device)
                    vol2 = data[1]['vol'].unsqueeze(1).to(self.device)
                    vol0 = data[2]['vol'].unsqueeze(1).to(self.device)
                #########################################binary
                if binary_mask:
                    vol1[vol1>vol1.min()] = 1.0
                    vol1[vol1==vol1.min()] = 0.0
                    vol2[vol2>vol2.min()] = 1.0
                    vol2[vol2==vol2.min()] = 0.0
                ###############################################
                ###############################################################
                if labels1.shape[0] == 1:
                    continue
                ##############################################################
                if self.Attention == 1:
                    out1,mask1 = self.net_s1(vol0)
                    out2,mask2 = self.net_s2(vol0)
                    out = torch.cat((out1, out2), dim=1)

                    lrec.append(self.optimize_step_attention(out,labels1,mask1,mask2,vol1,vol2))
                #####################################################################
                elif self.Attention == 2:
                    out0 = self.net_s0(vol0)
                    out1,mask1 = self.net_s1(out0)
                    out2,mask2 = self.net_s2(out0)
                    out = torch.cat((out1, out2), dim=1)
                    lrec.append(self.optimize_step_attention(out,labels1,mask1,mask2,vol1,vol2))
                ############################################################
                elif self.Attention == 3:
                    out1,mask11,mask12,mask13,mask14 = self.net_s1(vol0)
                    out2,mask21,mask22,mask23,mask24 = self.net_s2(vol0)
                    out = torch.cat((out1, out2), dim=1)
                    lrec.append(self.optimize_step_attention_multi(out,labels1,mask11,mask12,mask13,mask14,mask21,mask22,mask23,mask24,vol1,vol2))
                #####################################################################
                elif self.Attention == 4:
                    out,mask = self.net_s1(vol0)
                    ##################change the vol mask: vol1/vol2
                    lrec.append(self.optimize_step_attention_onestream(out,labels1,mask,vol1))
                #####################################################################
                elif self.Attention == 5:
                    out,mask = self.net_s1(vol0)
                    lrec.append(self.optimize_step(out,labels1))
                #####################################################################
                elif self.Attention == 6:
                    out = self.net_s1(vol0)
                    lrec.append(self.optimize_step(out,labels1))
                #####################################################################
                else:
                    out1 = self.net_s1(vol1)
                    out2 = self.net_s2(vol2)
                    out = torch.cat((out1, out2), dim=1)
                    lrec.append(self.optimize_step(out, labels1))

                step = step + 1
                
                
                if step % FLAGS.print_step == 0:
                    msg = ('[%d/%d][%d/%d] loss: %.4f Step: %d'
                      %(epoch, epochs, i, len(train_loader1), np.mean(np.asarray(lrec)), step))
                    lrec[:] = []
                    print(msg)
                    log.write(msg+"\n")
                    log.flush()
                    #################set eval
                    if self.Attention == 2:
                        self.net_s0.eval()
                    if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
                        self.net_s2.eval()
                    self.net_s1.eval()
                    self.net_s3.eval()
                    ####################

                    label_pred_cum = np.array((),dtype=np.int32)
                    label_test_cum = np.array((),dtype=np.int32)
                    ###
                    label_pred_cum_val = np.array((),dtype=np.int32)
                    label_test_cum_val = np.array((),dtype=np.int32)
                    
                    for j, data_val in enumerate(zip(val_loader1,val_loader2,val_loader3), 0):
                        if transform:

                            vol_val1 = data_val[0][0].to(self.device)
                            label_val = data_val[0][1]
                            vol_val2 = data_val[1][0].to(self.device)
                            vol_val0 = data_val[2][0].to(self.device)
                        else:

                            vol_val1  = data_val[0]['vol'].unsqueeze(1).to(self.device)
                            vol_val2  = data_val[1]['vol'].unsqueeze(1).to(self.device)
                            label_val = data_val[0]['label']
                            vol_val0 = data_val[2]['vol'].unsqueeze(1).to(self.device)
                            
                        ############################binary
                        if binary_mask:
                            vol_val1[vol_val1>vol_val1.min()] = 1.0
                            vol_val1[vol_val1==vol_val1.min()] = 0.0
                            vol_val2[vol_val2>vol_val2.min()] = 1.0
                            vol_val2[vol_val2==vol_val2.min()] = 0.0
                        #############################
                        # if label_test.shape[0] == 1:
                        #     continue
                       
                    ##############################################################
                        if self.Attention == 1:

                            ###
                            pred_val0 = self.net_s1(vol_val0)
                            pred_val1 = self.net_s2(vol_val0)

                            pred_val  = self.net_s3(torch.cat((pred_val0, pred_val1), dim=1))  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))

                        ######################################################################
                        elif self.Attention == 2:

                            ###
                            out_val0 = self.net_s0(vol_val0)
                            pred_val0 = self.net_s1(out_val0)
                            pred_val1 = self.net_s2(out_val0)

                            pred_val  = self.net_s3(torch.cat((pred_val0, pred_val1), dim=1))  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))
                        
                        ##############################################################
                        elif self.Attention == 3:

                            ###
                            pred_val0 = self.net_s1(vol_val0)
                            pred_val1 = self.net_s2(vol_val0)

                            pred_val  = self.net_s3(torch.cat((pred_val0, pred_val1), dim=1))  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))

                        ######################################################################
                        elif self.Attention == 4:

                            ###
                            pred_val0 = self.net_s1(vol_val0)

                            pred_val  = self.net_s3(pred_val0)  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))

                        ######################################################################
                        elif self.Attention == 5:

                            pred_val0 = self.net_s1(vol_val0)

                            pred_val  = self.net_s3(pred_val0)  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))

                        ######################################################################
                        elif self.Attention == 6:

                            ###
                            pred_val0 = self.net_s1(vol_val0)

                            pred_val  = self.net_s3(pred_val0)  
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))

                        ######################################################################
                        else:

                            pred_val0 = self.net_s1(vol_val1)
                            pred_val1 = self.net_s2(vol_val2)  

                            pred_val  = self.net_s3(torch.cat((pred_val0, pred_val1), dim=1))
                            predicted_val  = torch.max(pred_val.data, 1)[1]
                            label_pred_cum_val = np.concatenate((label_pred_cum_val, predicted_val.detach().cpu()))
                            label_test_cum_val = np.concatenate((label_test_cum_val, label_val))


                    ##########################################test
                    for k, data_test in enumerate(zip(test_loader1,test_loader2,test_loader3), 0):
                        if transform:
                            vol_test1 = data_test[0][0].to(self.device)
                            label_test = data_test[0][1]
                            vol_test2 = data_test[1][0].to(self.device)
                            vol_test0 = data_test[2][0].to(self.device)
                        else:
                            vol_test1  = data_test[0]['vol'].unsqueeze(1).to(self.device)
                            vol_test2  = data_test[1]['vol'].unsqueeze(1).to(self.device)
                            label_test = data_test[0]['label']
                            vol_test0 = data_test[2]['vol'].unsqueeze(1).to(self.device)
                            
                        ############################binary
                        if binary_mask:
                            vol_test1[vol_test1>vol_test1.min()] = 1.0
                            vol_test1[vol_test1==vol_test1.min()] = 0.0
                            vol_test2[vol_test2>vol_test2.min()] = 1.0
                            vol_test2[vol_test2==vol_test2.min()] = 0.0
                        #############################
                        # if label_test.shape[0] == 1:
                        #     continue
                       
                    ##############################################################
                        if self.Attention == 1:
                            pred_test0 = self.net_s1(vol_test0)
                            pred_test1 = self.net_s2(vol_test0)

                            pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                        ######################################################################
                        elif self.Attention == 2:
                            out_test0 = self.net_s0(vol_test0)
                            pred_test0 = self.net_s1(out_test0)
                            pred_test1 = self.net_s2(out_test0)

                            pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))
                        
                        ##############################################################
                        elif self.Attention == 3:
                            pred_test0 = self.net_s1(vol_test0)
                            pred_test1 = self.net_s2(vol_test0)

                            pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                        ######################################################################
                        elif self.Attention == 4:
                            pred_test0 = self.net_s1(vol_test0)

                            pred_test  = self.net_s3(pred_test0)  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                        ######################################################################
                        elif self.Attention == 5:
                            pred_test0 = self.net_s1(vol_test0)

                            pred_test  = self.net_s3(pred_test0)  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                        ######################################################################
                        elif self.Attention == 6:
                            pred_test0 = self.net_s1(vol_test0)

                            pred_test  = self.net_s3(pred_test0)  
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                        ######################################################################
                        else:
                            pred_test0 = self.net_s1(vol_test1)
                            pred_test1 = self.net_s2(vol_test2)  

                            pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))
                            predicted  = torch.max(pred_test.data, 1)[1]
                            label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                            label_test_cum = np.concatenate((label_test_cum, label_test))

                    ############################################################
                    
                    accuracy = Utility_MEDICAL.compute_accuracy(label_test_cum,label_pred_cum)
                    balanced_accuracy = Utility_MEDICAL.binary_balanced_evaluation(label_test_cum,label_pred_cum)
                    specificity, sensitivity = Utility_MEDICAL.compute_specificity_sensitivity(label_test_cum,label_pred_cum)
                    f1_score = Utility_MEDICAL.compute_F1(label_test_cum, label_pred_cum)
                    auc = roc_auc_score(label_test_cum, label_pred_cum)

                    ###
                    accuracy_val = Utility_MEDICAL.compute_accuracy(label_test_cum_val,label_pred_cum_val)
                    balanced_accuracy_val = Utility_MEDICAL.binary_balanced_evaluation(label_test_cum_val,label_pred_cum_val)
                    specificity_val, sensitivity_val = Utility_MEDICAL.compute_specificity_sensitivity(label_test_cum_val,label_pred_cum_val)
                    f1_score_val = Utility_MEDICAL.compute_F1(label_test_cum_val, label_pred_cum_val)
                    auc_val = roc_auc_score(label_test_cum_val, label_pred_cum_val)

                    #############################################################
                    
                    if (accuracy_val) >= (best_accuracy_val):###
                        best_accuracy_val = accuracy_val
                        best_step_val = step
                        if accuracy >= best_accuracy:
                            best_accuracy = accuracy
                            best_step = step
                        if (accuracy+accuracy_val) >= best_accuracy_all:
                            best_accuracy_all = accuracy+accuracy_val
                            best_step_all = step
                        if (step >= stablize_step and 
                            (accuracy>=0.9 or accuracy_val>=0.9 or (accuracy_val+accuracy)>=1.5)):
                            if self.Attention == 2:
                                torch.save(self.net_s0.state_dict(), '%s/net_s0_step_%d.pth' % (FLAGS.model_save, step))
                            if self.Attention != 4 and self.Attention!= 5 and self.Attention!= 6:
                                torch.save(self.net_s2.state_dict(), '%s/net_s2_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s1.state_dict(), '%s/net_s1_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s3.state_dict(), '%s/net_s3_step_%d.pth' % (FLAGS.model_save, step))

                        #########check the results
                        print('best validation results===================')
                        test_label = 'labels: ' + str(label_test_cum)
                        prediction = 'prediction: ' + str(label_pred_cum)

                        ###
                        val_label = 'val labels: ' + str(label_test_cum_val)
                        prediction_val = 'val prediction: ' + str(label_pred_cum_val)


                        log.write(test_label+"\n")
                        log.write(prediction+"\n")

                        ###
                        log.write(val_label+"\n")
                        log.write(prediction_val+"\n")
                        #########
                    elif (accuracy) >= (best_accuracy):###
                        best_accuracy = accuracy
                        best_step = step
                        if (accuracy+accuracy_val) >= best_accuracy_all:
                            best_accuracy_all = accuracy+accuracy_val
                            best_step_all = step
                        if (step >= stablize_step and 
                            (accuracy>=0.9 or accuracy_val>=0.9 or (accuracy_val+accuracy)>=1.5)):
                            if self.Attention == 2:
                                torch.save(self.net_s0.state_dict(), '%s/net_s0_step_%d.pth' % (FLAGS.model_save, step))
                            if self.Attention != 4 and self.Attention!= 5 and self.Attention!= 6:
                                torch.save(self.net_s2.state_dict(), '%s/net_s2_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s1.state_dict(), '%s/net_s1_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s3.state_dict(), '%s/net_s3_step_%d.pth' % (FLAGS.model_save, step))

                        #########check the results
                        print('best test results===================')
                        test_label = 'labels: ' + str(label_test_cum)
                        prediction = 'prediction: ' + str(label_pred_cum)

                        ###
                        val_label = 'val labels: ' + str(label_test_cum_val)
                        prediction_val = 'val prediction: ' + str(label_pred_cum_val)


                        log.write(test_label+"\n")
                        log.write(prediction+"\n")

                        ###
                        log.write(val_label+"\n")
                        log.write(prediction_val+"\n")
                        #########

                    elif (accuracy+accuracy_val) >= (best_accuracy+best_accuracy_val):###
                        best_accuracy_all = accuracy+accuracy_val###
                        best_step_all = step
                        if (step >= stablize_step and 
                            (accuracy>=0.9 or accuracy_val>=0.9 or (accuracy_val+accuracy)>=1.5)):
                            if self.Attention == 2:
                                torch.save(self.net_s0.state_dict(), '%s/net_s0_step_%d.pth' % (FLAGS.model_save, step))
                            if self.Attention != 4 and self.Attention!= 5 and self.Attention!= 6:
                                torch.save(self.net_s2.state_dict(), '%s/net_s2_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s1.state_dict(), '%s/net_s1_step_%d.pth' % (FLAGS.model_save, step))
                            torch.save(self.net_s3.state_dict(), '%s/net_s3_step_%d.pth' % (FLAGS.model_save, step))

                        #########check the results
                        print('best summation results===================')
                        test_label = 'labels: ' + str(label_test_cum)
                        prediction = 'prediction: ' + str(label_pred_cum)

                        ###
                        val_label = 'val labels: ' + str(label_test_cum_val)
                        prediction_val = 'val prediction: ' + str(label_pred_cum_val)


                        log.write(test_label+"\n")
                        log.write(prediction+"\n")

                        ###
                        log.write(val_label+"\n")
                        log.write(prediction_val+"\n")
                        #########
                            
                    
                    msg0 = ('Test accuracy: %.4f  Balanced_accuracy: %.4f'%(accuracy, balanced_accuracy))
                    msg1 = ('Specificity: %.4f  Sensitivity: %.4f'%(specificity, sensitivity))
                    msg2 = ('AUC: %.4f\nF1 score: %.4f'%(auc, f1_score))
                    msg3 = ('Best test score: %.4f  At Step: %d' %(best_accuracy, best_step))

                    ###
                    msg4 = ('Val accuracy: %.4f  Balanced_accuracy_val: %.4f'%(accuracy_val, balanced_accuracy_val))
                    msg5 = ('Specificity_val: %.4f  Sensitivity_val: %.4f'%(specificity_val, sensitivity_val))
                    msg6 = ('AUC: %.4f\nF1 score: %.4f'%(auc_val, f1_score_val))
                    msg7 = ('Best validation score: %.4f  At Step: %d' %(best_accuracy_val, best_step_val))

                    msg8 = ('Sum accuracy: %.4f'%(accuracy+accuracy_val))
                    msg9 = ('Best summation score: %.4f  At Step: %d' %(best_accuracy_all, best_step_all))
                    msg10 = ('######################################')

                    print(msg0)
                    print(msg1)
                    print(msg2)
                    print(msg3)
                    print(msg4)
                    print(msg5)
                    print(msg6)
                    print(msg7)
                    print(msg8)
                    print(msg9)
                    print(msg10)

                    log.write(msg0+"\n")
                    log.write(msg1+"\n")
                    log.write(msg2+"\n")
                    log.write(msg3+"\n")
                    log.write(msg4+"\n")
                    log.write(msg5+"\n")
                    log.write(msg6+"\n")
                    log.write(msg7+"\n")
                    log.write(msg8+"\n")
                    log.write(msg9+"\n")
                    log.write(msg10+"\n")
                    log.flush()
                        
        log.close()
        print("Training complete.")

    def eval(self, data_params, data_type, ori_path, data_path1, data_path2):

        # epochs           = data_params["epochs"]
        # batch_size       = data_params["batch_size"]
        batch_size       = 1
        batch_workers    = data_params["batch_workers"]
        # shuffle          = data_params["shuffle"]
        shuffle          = False
        drop_last        = data_params["drop_last"]
        transform        = data_params['transform']
        
        binary_mask      = data_params['binary_mask']
        datasplit_scheme = 'All'
        # datasplit_scheme = data_params["datasplit_scheme"]
        test_split       = data_params["test_split"]
        xfold            = data_params["xfold"]
        fold_idx         = data_params["fold_idx"]
        random_seed      = data_params["random_seed"]

        if transform:
            if datasplit_scheme == 'Test':
                train_loader1, test_loader1 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, data_path1, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                train_loader2, test_loader2 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, data_path2, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                train_loader3, test_loader3 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, ori_path, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
            else:
                test_loader1 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, data_path1, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                test_loader2 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, data_path2, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                test_loader3 = Data_fetcher.fetch_dataset_wValidation_aug(data_type, ori_path, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
        else:   
            if datasplit_scheme == 'Test':
                train_loader1, test_loader1 = Data_fetcher.fetch_dataset_wValidation(data_type, data_path1, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                train_loader2, test_loader2 = Data_fetcher.fetch_dataset_wValidation(data_type, data_path2, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                train_loader3, test_loader3 = Data_fetcher.fetch_dataset_wValidation(data_type, ori_path, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
            else:
                test_loader1 = Data_fetcher.fetch_dataset_wValidation(data_type, data_path1, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                test_loader2 = Data_fetcher.fetch_dataset_wValidation(data_type, data_path2, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
                test_loader3 = Data_fetcher.fetch_dataset_wValidation(data_type, ori_path, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)
        log = open(FLAGS.log_path, "a")
        log.write('Start evaluation>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>'+'\n')
        log.flush()
        
        if self.Attention == 2:
            self.net_s0.load_state_dict(torch.load('%s/net_s0_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
        self.net_s1.load_state_dict(torch.load('%s/net_s1_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
        if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
            self.net_s2.load_state_dict(torch.load('%s/net_s2_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
        self.net_s3.load_state_dict(torch.load('%s/net_s3_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)))
        step = FLAGS.model_step + 1  

        msg = ('Step: %d'%(step))
        log.write(msg+"\n")
        log.flush()

        ################
        label_pred_cum = np.array((),dtype=np.int32)
        label_test_cum = np.array((),dtype=np.int32)
        #################set eval
        if self.Attention == 2:
            self.net_s0.eval()
        if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
            self.net_s2.eval()
        self.net_s1.eval()
        self.net_s3.eval()
        
        ####################
        # print(medcam.get_layers(self.net_s1))
        
        # #############attention map
        # # self.net_s1 = medcam.inject(self.net_s1, output_dir= FLAGS.image_save +'/attention_maps_dim1_'+ str(FLAGS.model_step) , backend='gcam', layer = 'attention1.1'  ,label= 'best', save_maps=True)
        # self.net_s1 = medcam.inject(self.net_s1, output_dir= FLAGS.image_save +'/attention_maps_dim1_'+ str(FLAGS.model_step) , backend='gcam', label= None, save_maps=True)
        # if self.Attention != 4 and self.Attention != 5 and self.Attention != 6:
        #     self.net_s2 = medcam.inject(self.net_s2, output_dir= FLAGS.image_save +'/attention_maps_dim2_'+ str(FLAGS.model_step) , backend='gcam', label= None, save_maps=True)
        #     # self.net_s2 = medcam.inject(self.net_s2, output_dir= FLAGS.image_save +'/attention_maps_dim2_'+ str(FLAGS.model_step) , backend='gcam', layer = 'attention1.1'  ,label= 'best', save_maps=True)
        # #################

        ####################count layers and architecture
        if self.Attention == 1:
            print('s1:',self.net_s1.parameters())
        #################################################

        for j, data_test in enumerate(zip(test_loader1,test_loader2,test_loader3), 0):
            ########################################data load
            if transform:
                    vol_test1 = data_test[0][0].to(self.device)
                    label_test = data_test[0][1]
                    vol_test2 = data_test[1][0].to(self.device)
                    vol_test0 = data_test[2][0].to(self.device)
            else:
                vol_test1  = data_test[0]['vol'].unsqueeze(1).to(self.device)
                vol_test2  = data_test[1]['vol'].unsqueeze(1).to(self.device)
                label_test = data_test[0]['label']
                vol_test0 = data_test[2]['vol'].unsqueeze(1).to(self.device)
            
            ############################binary
            if binary_mask:
                vol_test1[vol_test1>vol_test1.min()] = 1.0
                vol_test1[vol_test1==vol_test1.min()] = 0.0
                vol_test2[vol_test2>vol_test2.min()] = 1.0
                vol_test2[vol_test2==vol_test2.min()] = 0.0
            ##################################################

        ##############################################################
            if self.Attention == 1:
                pred_test0 = self.net_s1(vol_test0)
                pred_test1= self.net_s2(vol_test0)

                pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))
                 
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ######################################################################
            elif self.Attention == 2:
                out_test0 = self.net_s0(vol_test0)
                pred_test0 = self.net_s1(out_test0)
                pred_test1 = self.net_s2(out_test0)

                pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))  
                predicted  = torch.max(pred_test.data,1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ##############################################################
            elif self.Attention == 3:
                pred_test0 = self.net_s1(vol_test0)
                pred_test1 = self.net_s2(vol_test0)

                pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))  
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ######################################################################
            elif self.Attention == 4:
                pred_test0 = self.net_s1(vol_test0)

                pred_test  = self.net_s3(pred_test0)  
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ######################################################################
            elif self.Attention == 5:
                pred_test0 = self.net_s1(vol_test0)

                pred_test  = self.net_s3(pred_test0)  
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ######################################################################
            elif self.Attention == 6:
                pred_test0 = self.net_s1(vol_test0)

                pred_test  = self.net_s3(pred_test0)  
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

            ######################################################################
            else:
                pred_test0 = self.net_s1(vol_test1)
                pred_test1 = self.net_s2(vol_test2)

                pred_test  = self.net_s3(torch.cat((pred_test0, pred_test1), dim=1))
                predicted  = torch.max(pred_test.data, 1)[1]
                label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
                label_test_cum = np.concatenate((label_test_cum, label_test))

    ############################################################
        accuracy = Utility_MEDICAL.compute_accuracy(label_test_cum,label_pred_cum)
        balanced_accuracy = Utility_MEDICAL.binary_balanced_evaluation(label_test_cum,label_pred_cum)
        specificity, sensitivity = Utility_MEDICAL.compute_specificity_sensitivity(label_test_cum,label_pred_cum)
        f1_score = Utility_MEDICAL.compute_F1(label_test_cum, label_pred_cum)
        auc = roc_auc_score(label_test_cum, label_pred_cum)  
        print('predictions:',label_pred_cum)
        print('labels:', label_test_cum)

        msg0 = ('Test accuracy: %.4f  Balanced_accuracy: %.4f'%(accuracy, balanced_accuracy))
        msg1 = ('Specificity: %.4f  Sensitivity: %.4f'%(specificity, sensitivity))
        msg2 = ('AUC: %.4f\nF1 score: %.4f'%(auc, f1_score))
        print(msg0)
        print(msg1)
        print(msg2)
        log.write(msg0+"\n")
        log.write(msg1+"\n")
        log.write(msg2+"\n")
        log.flush()