In [None]:
%run include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb
%run Medical_Utility.ipynb

from torch.autograd import Variable
from sklearn.metrics import roc_auc_score

class CNN(object):
    """Two branch networks whose outputs are concatenated and fed to a fusion network.

    ``net_s1`` processes volumes from ``FLAGS.data_path`` and ``net_s2`` volumes
    from ``FLAGS.data_path2``; their outputs are concatenated along dim 1 and
    passed to ``net_s3``. All three networks are optimized jointly by a single
    SGD optimizer.
    """

    def __init__(self, general, arch_list):
        """Build the three networks, the loss and the optimizer.

        Args:
            general: config dict; its "loss" and "reduction" entries are passed
                to ``StandardLoss``.
            arch_list: exactly two layer specs for ``Net.parse_layers``.
                ``arch_list[0]`` is used for both branch networks (net_s1 and
                net_s2), ``arch_list[1]`` for the fusion network net_s3.
        """
        loss_mode = general["loss"]
        reduction = general["reduction"]

        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num = FLAGS.gpu_num
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() and FLAGS.gpu_enable else "cpu")
        torch.manual_seed(random.randint(1, 10000))

        assert len(arch_list) == 2, "arch_list must contain exactly two layer specs"
        self.input_dims_arch1, layers_arch1 = Net.parse_layers(arch_list[0])
        self.input_dims_arch2, layers_arch2 = Net.parse_layers(arch_list[1])
        # NOTE(review): both branches are built from arch_list[0]; presumably
        # intentional (same architecture per input) — confirm.
        self.net_s1 = Network_template(gpu_num, layers_arch1).to(self.device)
        self.net_s2 = Network_template(gpu_num, layers_arch1).to(self.device)
        self.net_s3 = Network_template(gpu_num, layers_arch2).to(self.device)
        for net in (self.net_s1, self.net_s2, self.net_s3):
            Net.init_weights(net, "normal")

        self.criterion = StandardLoss(loss_mode, reduction).to(self.device)
        # NOTE(review): SGD hyper-parameters are hard-coded here; the config's
        # "learning_rate", "beta1" and "beta2" entries are not used.
        network_params = (list(self.net_s1.parameters())
                          + list(self.net_s2.parameters())
                          + list(self.net_s3.parameters()))
        self.optimizer = optim.SGD(network_params, lr=0.01, momentum=0.9, weight_decay=1e-2)

    def optimize_step(self, Dtrain, labels):
        """Run one joint optimization step over all three networks.

        Args:
            Dtrain: concatenated branch outputs; handed to the criterion
                together with ``net_s3`` and ``labels``.
            labels: target labels for the batch.

        Returns:
            The loss as a Python float.
        """
        self.net_s1.zero_grad()
        self.net_s2.zero_grad()
        self.net_s3.zero_grad()
        loss = self.criterion([self.net_s3, Dtrain, labels])
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _named_nets(self):
        """Return (checkpoint name, network) pairs for the three networks."""
        return (("net_s1", self.net_s1), ("net_s2", self.net_s2), ("net_s3", self.net_s3))

    def _checkpoint_path(self, name, step):
        """Return the checkpoint file path for network ``name`` at ``step``."""
        return '%s/%s_step_%d.pth' % (FLAGS.model_save, name, step)

    def _save_checkpoint(self, step):
        """Save the state dict of every network for ``step``."""
        for name, net in self._named_nets():
            torch.save(net.state_dict(), self._checkpoint_path(name, step))

    def _load_checkpoint(self, step):
        """Load every network's state dict saved at ``step``.

        ``map_location`` lets checkpoints written on a GPU load on a CPU-only host.
        """
        for name, net in self._named_nets():
            state = torch.load(self._checkpoint_path(name, step), map_location=self.device)
            net.load_state_dict(state)

    @staticmethod
    def _log(log, *messages):
        """Print each message, append it (newline-terminated) to ``log``, then flush."""
        for message in messages:
            print(message)
            log.write(message + "\n")
        log.flush()

    def _predict(self, test_loader1, test_loader2):
        """Run the full model over paired test loaders.

        The networks are put in eval mode and gradients are disabled for the
        pass; each network's previous train/eval mode is restored afterwards.
        Batches of size 1 are skipped, matching the training loop.

        Returns:
            (labels, predictions, scores) as 1-D numpy arrays. ``predictions`` is
            the argmax over dim 1 of net_s3's output; ``scores`` is
            ``output[:, 1] - output[:, 0]``, used as the ROC AUC ranking score.
        """
        nets = (self.net_s1, self.net_s2, self.net_s3)
        previous_modes = [net.training for net in nets]
        labels, preds, scores = [], [], []
        try:
            for net in nets:
                net.eval()
            with torch.no_grad():
                for batch1, batch2 in zip(test_loader1, test_loader2):
                    label_test = batch1['label']
                    if label_test.shape[0] == 1:
                        continue
                    vol_test1 = batch1['vol'].unsqueeze(1).to(self.device)
                    vol_test2 = batch2['vol'].unsqueeze(1).to(self.device)
                    # Each branch gets its own input, exactly as in training.
                    fused = torch.cat((self.net_s1(vol_test1), self.net_s2(vol_test2)), dim=1)
                    pred_test = self.net_s3(fused)
                    labels.append(np.asarray(label_test))
                    preds.append(pred_test.argmax(dim=1).cpu().numpy())
                    # Assumes a binary, two-column output (the metrics used are
                    # binary). The difference is rank-equivalent to P(class 1)
                    # whether the outputs are logits, probabilities or log-probs.
                    scores.append((pred_test[:, 1] - pred_test[:, 0]).cpu().numpy())
        finally:
            for net, mode in zip(nets, previous_modes):
                net.train(mode)
        if not labels:
            return (np.array((), dtype=np.int32),
                    np.array((), dtype=np.int32),
                    np.array((), dtype=np.float32))
        return np.concatenate(labels), np.concatenate(preds), np.concatenate(scores)

    @staticmethod
    def _compute_metrics(labels, preds, scores):
        """Compute the test metrics for hard predictions and ranking scores.

        Returns:
            dict with keys accuracy, balanced_accuracy, specificity,
            sensitivity, f1 and auc. ``auc`` is NaN when ``labels`` holds fewer
            than two classes (ROC AUC is undefined and roc_auc_score raises).
        """
        specificity, sensitivity = Utility_MEDICAL.compute_specificity_sensitivity(labels, preds)
        if np.unique(labels).size < 2:
            auc = float("nan")
        else:
            auc = roc_auc_score(labels, scores)
        return {
            "accuracy": Utility_MEDICAL.compute_accuracy(labels, preds),
            "balanced_accuracy": Utility_MEDICAL.binary_balanced_evaluation(labels, preds),
            "specificity": specificity,
            "sensitivity": sensitivity,
            "f1": Utility_MEDICAL.compute_F1(labels, preds),
            "auc": auc,
        }

    def train(self, data_params, branch_name="Undefined Here"):
        """Jointly train the three networks with periodic evaluation and checkpoints.

        Every ``FLAGS.print_step`` steps, the mean training loss since the last
        report is logged and the model is evaluated on the test split. Every
        ``FLAGS.save_step`` steps, all three state dicts are saved under
        ``FLAGS.model_save``. Messages are printed and appended to
        ``FLAGS.log_path``. If ``FLAGS.continue_model`` is set, training resumes
        from the checkpoints saved at ``FLAGS.model_step``.

        Args:
            data_params: dict with the data/loader settings read below.
            branch_name: label written to the log header.
        """
        epochs = data_params["epochs"]
        batch_size = data_params["batch_size"]
        batch_workers = data_params["batch_workers"]
        shuffle = data_params["shuffle"]
        drop_last = data_params["drop_last"]
        datasplit_scheme = data_params["datasplit_scheme"]
        test_split = data_params["test_split"]
        xfold = data_params["xfold"]
        fold_idx = data_params["fold_idx"]
        random_seed = data_params["random_seed"]
        fetch_args = (batch_size, batch_workers, shuffle, drop_last, 0.5,
                      datasplit_scheme, test_split, xfold, fold_idx, random_seed)
        train_loader1, test_loader1 = Data_fetcher.fetch_dataset_wValidation(
            FLAGS.dataset, FLAGS.data_path, *fetch_args)
        train_loader2, test_loader2 = Data_fetcher.fetch_dataset_wValidation(
            FLAGS.dataset, FLAGS.data_path2, *fetch_args)

        # `with` guarantees the log is closed even if training raises.
        with open(FLAGS.log_path, "a") as log:
            log.write('Branch: %s  Fold ID: %d\n\n' % (branch_name, fold_idx))
            log.flush()

            step = 0
            if FLAGS.continue_model:
                self._load_checkpoint(FLAGS.model_step)
                # A checkpoint named N is saved after the N-th update, so the
                # next update is N + 1.
                step = FLAGS.model_step

            loss_history = []
            best_f1 = -1.0
            best_step = 0
            stabilize_step = 1000  # F1 scores before this step are not tracked as best
            for epoch in range(epochs):
                # NOTE(review): labels are taken from loader 1 only, so both train
                # loaders must yield the same subjects in the same order (with
                # shuffle=True this presumably relies on Data_fetcher seeding) — confirm.
                for i, (batch1, batch2) in enumerate(zip(train_loader1, train_loader2)):
                    if batch1['label'].shape[0] == 1:
                        continue
                    vol1 = batch1['vol'].unsqueeze(1).to(self.device)
                    labels1 = batch1['label'].to(self.device)
                    vol2 = batch2['vol'].unsqueeze(1).to(self.device)

                    out = torch.cat((self.net_s1(vol1), self.net_s2(vol2)), dim=1)
                    loss_history.append(self.optimize_step(out, labels1))
                    step += 1

                    if step % FLAGS.print_step == 0:
                        self._log(log, '[%d/%d][%d/%d] loss: %.4f Step: %d'
                                  % (epoch, epochs, i, len(train_loader1),
                                     np.mean(np.asarray(loss_history)), step))
                        loss_history.clear()

                        labels_t, preds_t, scores_t = self._predict(test_loader1, test_loader2)
                        metrics = self._compute_metrics(labels_t, preds_t, scores_t)
                        if step > stabilize_step and metrics["f1"] > best_f1:
                            best_f1 = metrics["f1"]
                            best_step = step

                        self._log(
                            log,
                            'Test accuracy: %.4f  Balanced_accuracy: %.4f'
                            % (metrics["accuracy"], metrics["balanced_accuracy"]),
                            'Specificity: %.4f  Sensitivity: %.4f'
                            % (metrics["specificity"], metrics["sensitivity"]),
                            'AUC: %.4f\nF1 score: %.4f' % (metrics["auc"], metrics["f1"]),
                            'Best F1 score: %.4f  Step: %d' % (best_f1, best_step),
                        )

                    if step % FLAGS.save_step == 0:
                        self._save_checkpoint(step)
        print("Training complete.")