In [None]:
%run Include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb
%run Medical_Utility.ipynb

from torch.autograd import Variable
from sklearn.metrics import roc_auc_score


class CNN(object):
    """Configuration holder and paired-loader training/logging driver.

    NOTE(review): despite the name, no network, optimiser or loss is built in
    this class as shown; the ``general`` hyper-parameters are only stored on
    the instance, and ``train`` iterates the data and logs the test split.
    """

    def __init__(self, general, arch_list):
        """Read hyper-parameters and select the compute device.

        Args:
            general: mapping that must contain the keys ``learning_rate``,
                ``beta1``, ``beta2``, ``loss``, ``reduction``, ``gamma`` and
                ``alpha``; a missing key raises ``KeyError`` here.
            arch_list: architecture description, stored unchanged (it is not
                interpreted by this class).

        Side effects:
            Sets the global ``cudnn.benchmark`` flag from
            ``FLAGS.cudnn_benchmark``.
        """
        # Read eagerly so a malformed config fails at construction time, and
        # keep the values on the instance for use by training code.
        self.lr        = general["learning_rate"]
        self.beta1     = general["beta1"]
        self.beta2     = general["beta2"]
        self.loss_mode = general["loss"]
        self.reduction = general["reduction"]
        self.gamma     = general['gamma']
        self.alpha     = general['alpha']
        self.arch_list = arch_list

        cudnn.benchmark = FLAGS.cudnn_benchmark
        self.gpu_num = FLAGS.gpu_num
        # CUDA is used only when it is both available and enabled via FLAGS.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and FLAGS.gpu_enable else "cpu")

    def _unpack_pair(self, batch_pair, tuple_batches):
        """Return ``(vol1, labels, vol2)`` from one zipped (loader1, loader2) batch.

        Args:
            batch_pair: 2-tuple of batches, one from each loader.
            tuple_batches: if true, batches are indexed positionally
                (``[0]`` volume, ``[1]`` label); otherwise by the ``'vol'`` and
                ``'label'`` keys, with a singleton dim inserted at axis 1
                (presumably the channel axis — confirm against the dataset).

        Volumes are moved to ``self.device``. Labels come from the first loader
        only and are left on whatever device the loader produced them on: the
        test-set path feeds them to ``np.concatenate``, which cannot consume
        CUDA tensors.
        """
        first, second = batch_pair
        if tuple_batches:
            return first[0].to(self.device), first[1], second[0].to(self.device)
        return (first['vol'].unsqueeze(1).to(self.device),
                first['label'],
                second['vol'].unsqueeze(1).to(self.device))

    def _collect_test_set(self, test_loader1, test_loader2, tuple_batches):
        """Walk the paired test loaders and gather sample paths and labels.

        Returns:
            ``(paths, labels)``. ``paths`` is a list built from the first
            loader's ``'path'`` entries. It stays empty for tuple batches,
            which are indexed here only for volume and label. ``labels`` is a
            NumPy array of the first loader's labels.

        Single-sample batches are skipped, matching the training loop.
        """
        paths = []
        # Seed with an empty int32 array so the result dtype matches the
        # incremental concatenation this replaces (and so an empty test set
        # still yields an empty array).
        label_chunks = [np.array((), dtype=np.int32)]
        for batch_pair in zip(test_loader1, test_loader2):
            # Volumes are transferred but not consumed yet (no model is run).
            _vol1, labels, _vol2 = self._unpack_pair(batch_pair, tuple_batches)
            if labels.shape[0] == 1:
                continue
            if not tuple_batches:
                paths += batch_pair[0]['path']
            label_chunks.append(labels)
        return paths, np.concatenate(label_chunks)

    def train(self, data_params, branch_name="Undefined Here"):
        """Iterate the paired training loaders and periodically log the test split.

        Two datasets (``FLAGS.data_path`` and ``FLAGS.data_path2``) are fetched
        with identical split arguments and iterated in lockstep with ``zip``,
        so each epoch is as long as the shorter loader. Every
        ``FLAGS.print_step`` counted steps, a progress line plus the test-set
        labels and paths are printed and appended to ``FLAGS.log_path``.

        Args:
            data_params: mapping with keys ``epochs``, ``batch_size``,
                ``batch_workers``, ``shuffle``, ``drop_last``, ``transform``,
                ``extra_aug``, ``datasplit_scheme``, ``test_split``, ``xfold``,
                ``fold_idx`` and ``random_seed``.
            branch_name: label written to the log header.
        """
        epochs           = data_params["epochs"]
        batch_size       = data_params["batch_size"]
        batch_workers    = data_params["batch_workers"]
        shuffle          = data_params["shuffle"]
        drop_last        = data_params["drop_last"]
        transform        = data_params['transform']
        extra_aug        = data_params['extra_aug']
        datasplit_scheme = data_params["datasplit_scheme"]
        test_split       = data_params["test_split"]
        xfold            = data_params["xfold"]
        fold_idx         = data_params["fold_idx"]
        random_seed      = data_params["random_seed"]

        fetch = (Data_fetcher.fetch_dataset_wValidation_aug if transform
                 else Data_fetcher.fetch_dataset_wValidation)
        # Identical split arguments for both datasets (presumably so their
        # folds line up — confirm in Data_fetcher). NOTE(review): the
        # positional 0.5 is passed through as before; its meaning is defined
        # by Data_fetcher.
        split_args = (batch_size, batch_workers, shuffle, drop_last, 0.5,
                      datasplit_scheme, test_split, xfold, fold_idx, random_seed)
        train_loader1, test_loader1 = fetch(FLAGS.dataset, FLAGS.data_path, *split_args)
        train_loader2, test_loader2 = fetch(FLAGS.dataset, FLAGS.data_path2, *split_args)

        # Positional batch indexing is selected by either flag (see _unpack_pair).
        tuple_batches = transform or extra_aug
        step = 0

        # ``with`` closes the log even if iteration raises.
        with open(FLAGS.log_path, "a") as log:
            log.write('Branch: %s  Fold ID: %d\n\n' % (branch_name, fold_idx))
            log.flush()

            for epoch in range(epochs):
                for i, batch_pair in enumerate(zip(train_loader1, train_loader2)):
                    # vol1/vol2 are transferred but not consumed yet (no model is run).
                    vol1, labels1, vol2 = self._unpack_pair(batch_pair, tuple_batches)
                    labels1 = labels1.to(self.device)

                    # Single-sample batches are skipped and not counted as steps.
                    if labels1.shape[0] == 1:
                        continue

                    step += 1
                    if step % FLAGS.print_step != 0:
                        continue

                    msg = ('[%d/%d][%d/%d] Step: %d'
                           % (epoch, epochs, i, len(train_loader1), step))
                    print(msg)
                    log.write(msg + "\n")
                    log.flush()

                    path_test_cum, label_test_cum = self._collect_test_set(
                        test_loader1, test_loader2, tuple_batches)

                    msg_label = "test_label: {}".format(label_test_cum)
                    msg0 = "test_set: {}".format(path_test_cum)
                    print(msg_label)
                    print(msg0)
                    log.write(msg_label + "\n")
                    log.write(msg0 + "\n")
                    log.flush()

        print("Training complete.")