In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from attention_unet import *
from deeplab_model import *
from data_pipeline import *
import numpy as np
import os as os

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Allow TensorFlow to allocate up to 100% of that GPU's memory.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=1.0))
# NOTE(review): the next cell calls tf.reset_default_graph(), so this session
# stays bound to the old graph; train_model opens its own tf.Session() without
# this `config`.
sess = tf.Session(config=config)

In [None]:
import time
# Learning rate handed to train_model (used for its Adam optimizer).
lr = 0.0002
# Start from a fresh default graph before building the model.
tf.reset_default_graph()
model = attention_gating_unet2D(num_classes=1)
# TRAINING MODEL
# forward propagation
def forward(model, data_generator):
    """Run ``model`` on one ``(inputs, labels)`` batch.

    Args:
        model: callable that maps an input batch to model outputs.
        data_generator: indexable pair. Element 0 is the input batch and
            element 1 the matching labels. In this file it is the result of
            an iterator's ``get_next()``, not a Python generator.

    Returns:
        Tuple ``(logits, labels)``: the model output for element 0, and
        element 1 passed through unchanged.
    """
    inputs, targets = data_generator[0], data_generator[1]
    return model(inputs), targets

def compute_loss(model, data_generator, loss_fn='default'):
    """Forward one batch through ``model`` and compute its mean loss.

    Args:
        model: callable model, passed to ``forward``.
        data_generator: ``(inputs, labels)`` pair, passed to ``forward``.
        loss_fn: callable ``loss_fn(labels, logits)``, or the string
            ``'default'`` to use ``tf.keras.losses.BinaryCrossentropy()``.

    Returns:
        Tuple ``(mean_loss, logits, labels)``.

    NOTE(review): ``BinaryCrossentropy()`` is built with its default
    ``from_logits=False``. The model output is therefore presumably already a
    probability (e.g. after a sigmoid), despite the ``logits`` name. Confirm
    against the model definition.
    """
    logits, labels = forward(model, data_generator)
    criterion = tf.keras.losses.BinaryCrossentropy() if loss_fn == 'default' else loss_fn
    loss_value = criterion(labels, logits)
    # With their default reduction, Keras losses presumably already return a
    # scalar, so this mean would then be a no-op. It is kept for custom loss_fn.
    return tf.reduce_mean(loss_value), logits, labels
    


def start_training(sess, num_epoch, loss, train_logit, train_label, train_op, train_acc,
                   test_logit, test_label, test_acc, batch_size,
                   num_train=None, num_test=None):
    """Run ``num_epoch`` train + validation epochs inside ``sess``.

    Each epoch runs ``num_train // batch_size`` optimizer steps and then
    ``num_test // batch_size`` validation batches. It prints the loss, the
    binary accuracy and timing for both phases. After the last epoch it
    prints mean/min/max epoch times.

    Args:
        sess: ``tf.Session`` in which all ops are run.
        num_epoch: number of train + validation epochs.
        loss: scalar training-loss tensor.
        train_logit, train_label: model-output / target tensors of the
            training batch (evaluated together with ``train_op``).
        train_op: op that applies one optimizer step.
        train_acc: Keras metric (e.g. ``BinaryAccuracy``) for training.
        test_logit, test_label: model-output / target tensors of the
            validation batch.
        test_acc: Keras metric for validation.
        batch_size: samples per batch. Used only to derive steps per epoch.
        num_train, num_test: number of training / validation samples. When
            omitted, ``len(imgs_train)`` / ``len(imgs_test)`` are looked up
            as globals, which is the original behaviour.
    """
    if num_train is None:
        num_train = len(imgs_train)
    if num_test is None:
        num_test = len(imgs_test)
    num_steps = num_train // batch_size
    num_steps_test = num_test // batch_size

    # In graph mode a Keras metric only changes state when its update op is
    # *run*. Calling it on fetched numpy arrays just builds new graph ops.
    # Build the update/result ops once and run them in `sess`.
    train_acc_update = train_acc.update_state(train_label, train_logit)
    test_acc_update = test_acc.update_state(test_label, test_logit)
    train_acc_result = train_acc.result()
    test_acc_result = test_acc.result()

    variables = (train_acc.variables + test_acc.variables)
    # inits
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run([v.initializer for v in variables])

    train_times = []
    test_times = []

    for e in range(num_epoch):
        start_time = time.time()

        # NOTE(review): tqdm is not imported in this file. It presumably
        # comes from one of the star imports; confirm.
        for i in tqdm(range(num_steps)):
            # One sess.run, so the loss, optimizer step and metric update all
            # see the same batch pulled from the iterator.
            loss_value, _, _ = sess.run([loss, train_op, train_acc_update])

            if i % 20 == 0:
                print("--------------------------- Current epoch: {} --------------------------- ".format(e))
                print("Training loss at {}th iteration is {}".format(i, loss_value))

        time_diff = time.time() - start_time
        train_times.append(time_diff)

        train_acc_value = sess.run(train_acc_result)
        print("Training binary accuracy for last epoch: {}".format(train_acc_value))
        train_acc.reset_states()
        print("Time taken for last training epoch is {}s".format(time_diff))
        # max(..., 1) avoids ZeroDivisionError when num_train < batch_size.
        print("Avg time for last training epoch is {}s".format(time_diff / max(num_steps, 1)))

        # VALIDATION STEPS
        start_test_time = time.time()
        for _ in range(num_steps_test):
            sess.run(test_acc_update)
        test_time_diff = time.time() - start_test_time
        test_times.append(test_time_diff)

        print("--------------------------- Validation Step --------------------------- ")
        print("Time taken for last validation epoch is {}s".format(test_time_diff))
        print("Avg time for last validation epoch is {}s".format(test_time_diff / max(num_steps_test, 1)))
        valid_acc_value = sess.run(test_acc_result)
        print("Mean validation binary accuracy for last epoch: {}".format(valid_acc_value))
        # Reset the metric object, not the numpy value fetched above.
        test_acc.reset_states()

    if not train_times:
        # num_epoch <= 0: nothing was timed, and np.max/np.min reject empty input.
        return

    mean_train_time = np.mean(train_times)
    max_train_time = np.max(train_times)
    min_train_time = np.min(train_times)
    print("Time taken for complete training cycle: mean: {}[{},{}] /epoch".format(mean_train_time,
                                                                   min_train_time,
                                                                   max_train_time))

    mean_test_time = np.mean(test_times)
    max_test_time = np.max(test_times)
    min_test_time = np.min(test_times)

    print("Time for complete validations: mean: {}[{},{}] /epoch".format(mean_test_time,
                                                                   min_test_time,
                                                                   max_test_time))


def train_model(model, batch_size, lr, num_epoch=15, session_config=None):
    """Build the TF1 training graph for ``model`` and train it.

    Splits the data list 0.8/0.2 into train/test and builds one-shot
    iterators over ``input_fn`` datasets. Minimises the loss from
    ``compute_loss`` with Adam, tracks binary accuracy, and runs
    ``start_training`` in a new session.

    Args:
        model: callable model, passed to ``compute_loss`` / ``forward``.
        batch_size: samples per batch.
        lr: Adam learning rate.
        num_epoch: number of epochs. Defaults to 15, the previously
            hard-coded value.
        session_config: optional ``tf.ConfigProto`` for the training session.

    NOTE(review): assumes the datasets returned by ``input_fn`` repeat. A
    one-shot iterator over a finite dataset raises ``OutOfRangeError`` once
    exhausted, which would happen in the second epoch. TODO confirm.
    """
    # ESTABLISH DATA PIPELINE
    # SPLIT: PATH DIR -DATA(TRAIN/TEST)[0.8/0.2]
    imgs_train, label_train, imgs_test, label_test = split_datalist(train_test_ratio=0.8)

    # MAKE GENERATOR
    train_generator = input_fn(imgs_train, label_train, batch_size, num_parallel_calls=18).make_one_shot_iterator()
    test_generator = input_fn(imgs_test, label_test, batch_size, num_parallel_calls=18).make_one_shot_iterator()

    def input_train_fn():
        # tf.device(None) clears any enclosing device scope for the get_next op.
        with tf.device(None):
            return train_generator.get_next()

    # SET OPTIMIZER
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # EVALUATING METRICS
    train_acc_metric = tf.keras.metrics.BinaryAccuracy()
    valid_acc_metric = tf.keras.metrics.BinaryAccuracy()

    # COMPUTE LOSS
    train_loss, train_logits, train_labels = compute_loss(model, input_train_fn())

    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads)

    # DEFINE TEST VARIABLES
    logits_test, labels_test = forward(model, test_generator.get_next())

    with tf.Session(config=session_config) as sess:
        # Pass the split sizes explicitly: imgs_train/imgs_test are locals
        # here and are not visible inside start_training.
        start_training(sess, num_epoch, train_loss, train_logits, train_labels, apply_gradient_op,
                       train_acc_metric, logits_test, labels_test, valid_acc_metric, batch_size,
                       num_train=len(imgs_train), num_test=len(imgs_test))
        
    
# Train the model built above with batch size 15 and learning rate `lr`.
train_model(model, batch_size=15, lr=lr)