# Main script

In [None]:
import tensorflow as tf
import numpy as np
import math
import tqdm
import IPython.display
import matplotlib.pyplot as plt
from matplotlib import gridspec
import time
%matplotlib inline  

def get_params(name, shape):
    """Create a weight/bias variable pair via tf.get_variable.

    The weight is named ``<name>_w``, has the given `shape` and uses a
    variance-scaling initializer. The bias is named ``<name>_b``, has one
    entry per element of the last dimension of `shape` and is initialised
    to zero. Both variables are float32.

    Returns:
        (w, b): the weight and bias variables.
    """
    weight_name = name + "_w"
    bias_name = name + "_b"

    w = tf.get_variable(
        name=weight_name,
        shape=shape,
        initializer=tf.contrib.layers.variance_scaling_initializer(),
        dtype=tf.float32,
    )
    # One bias per output unit (the last axis of the weight shape).
    b = tf.get_variable(
        name=bias_name,
        shape=shape[-1],
        initializer=tf.constant_initializer(0),
        dtype=tf.float32,
    )
    return w, b

def FC(x, w, b):
    """Affine (fully connected) layer: return ``matmul(x, w)`` plus bias `b`.

    `b` is added along the last axis via tf.nn.bias_add. In this notebook
    `x` is a (batch, features) matrix and `w` is (features, units).
    """
    return tf.nn.bias_add(tf.matmul(x, w), b)

def siamese_model(data, keep_prob, phase_train=True, hidden_units=64,
                  num_layers=6, fc_units=100):
    """Build the shared LSTM encoder applied to every scan in `data`.

    Every row r of every sample b in `data` is treated as an independent
    univariate sequence of length T. Each sequence runs through a stack of
    `num_layers` BasicLSTMCells of `hidden_units` units, each with input
    dropout. The output at the final time step is averaged over the hidden
    units, giving one value per row. The resulting (B, R) matrix goes through
    a fully connected layer of width `fc_units` followed by tanh.

    Variable scopes and names match the original implementation, so existing
    checkpoints still restore when the default sizes are used.

    Args:
        data: float32 tensor with fully defined static shape (B, R, T).
        keep_prob: input keep probability fed to every DropoutWrapper.
        phase_train: accepted for API compatibility; not used by this graph.
        hidden_units: LSTM units per layer (default 64, the former constant).
        num_layers: number of stacked LSTM layers (default 6).
        fc_units: width of the final fully connected layer (default 100).

    Returns:
        tuple (outputs, final_state, last_output, glob_avg_pool, w1, FC1,
        FC1_tanh):
        outputs is the (B*R, T, hidden_units) dynamic_rnn output,
        final_state is the MultiRNNCell state tuple,
        last_output is the (B, R, hidden_units) last-step output,
        glob_avg_pool is the (B, R) mean over hidden units,
        w1 is the FC weight, FC1 the pre-activation and FC1_tanh the
        (B, fc_units) embedding.
    """
    with tf.variable_scope('siamese'):

        with tf.variable_scope('rnn_block_1'):

            B, R, T = (int(dim) for dim in data.get_shape())
            # One sequence per (sample, row), with a single feature per step.
            data_re = tf.reshape(data, [B * R, T, 1])

            layers = [tf.nn.rnn_cell.BasicLSTMCell(hidden_units)
                      for _ in range(num_layers)]
            layers_drop = [tf.nn.rnn_cell.DropoutWrapper(layer, input_keep_prob=keep_prob)
                           for layer in layers]
            cell = tf.nn.rnn_cell.MultiRNNCell(layers_drop, state_is_tuple=True)

            initial_state = cell.zero_state(B * R, tf.float32)

            outputs, final_state = tf.nn.dynamic_rnn(cell, data_re, initial_state=initial_state, dtype=tf.float32)
            # Output at the final time step. This was hard-coded as index 99,
            # which is only correct when T == 100.
            last_outputs = outputs[:, -1, :]
            last_output = tf.reshape(last_outputs, [B, R, hidden_units])

            # Average over the hidden units -> one value per row: (B, R).
            glob_avg_pool = tf.reduce_mean(last_output, 2)
            w1, b1 = get_params(name='1_fc', shape=(R, fc_units))
            FC1 = FC(x=glob_avg_pool, w=w1, b=b1)
            FC1_tanh = tf.nn.tanh(FC1, name="FC1_tanh")

        return outputs, final_state, last_output, glob_avg_pool, w1, FC1, FC1_tanh

def contrastive_loss(y, d, batch_size):
    """Contrastive loss with margin 1, averaged over `batch_size` pairs.

    Pairs with label y == 1 are penalised by d**2. Pairs with y == 0 are
    penalised by max(1 - d, 0)**2. The summed penalty is divided by
    batch_size and then by 2.
    """
    similar_term = y * tf.square(d)
    hinge = tf.maximum((1 - d), 0)
    dissimilar_term = (1 - y) * tf.square(hinge)
    total = tf.reduce_sum(similar_term + dissimilar_term)
    return total / batch_size / 2
        
#REST1_data = np.load("./norm_REST1_S900_FIND_BOLD_raw.npy")
#REST2_data = np.load("./norm_REST2_S900_FIND_BOLD_raw.npy")

graph = tf.Graph()
with graph.as_default():
    # Optimiser step counter. It is passed to apply_gradients below so it
    # advances on every training step; previously it was created but never
    # updated.
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

    batch_size = 16  # pairs per step; X carries 2 * batch_size scans
    num_ROI = 141
    crop_size = 100  # time points per scan fed to the LSTM
    # The former `data_size = REST1_corr.shape[1]` was removed: data_size was
    # never used, and REST1_corr is not defined in the visible cells.

    X = tf.placeholder(dtype=tf.float32, shape=(batch_size * 2, num_ROI, crop_size), name="X")
    Y = tf.placeholder(dtype=tf.float32, shape=(batch_size,), name="Y")
    lr = tf.placeholder(tf.float32, shape=[], name="lr")

    phase_train = tf.placeholder(dtype=tf.bool, name="phase_train")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")

    # The sixth output used to be bound to the global name `FC`. That
    # overwrote the FC() helper that siamese_model calls, so re-running this
    # cell crashed.
    LSTM_out, LSTM_final, Last_o, GAP, weight, FC_out, FC_tanh = siamese_model(X, keep_prob, phase_train)

    # The first batch_size rows of X are one side of each pair; the last
    # batch_size rows are the other side.
    latent_a = FC_tanh[:batch_size]
    latent_b = FC_tanh[batch_size:]
    # Euclidean distance normalised by the sum of the embedding norms.
    # By the triangle inequality this lies in [0, 1], so tf.rint below
    # thresholds it at 0.5.
    distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(latent_a, latent_b)), 1, keep_dims=True))
    distance = tf.div(distance, tf.add(tf.sqrt(tf.reduce_sum(tf.square(latent_a), 1, keep_dims=True)),
                                       tf.sqrt(tf.reduce_sum(tf.square(latent_b), 1, keep_dims=True))))
    distance = tf.reshape(distance, [-1], name="distance")

    CR_loss = contrastive_loss(Y, distance, batch_size)

    # Predict "same subject" (1) when distance < 0.5, otherwise 0.
    temp_sim = tf.subtract(tf.ones_like(distance), tf.rint(distance), name="temp_sim")
    correct_predictions = tf.equal(temp_sim, Y)
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")

    # L2 weight regularisation.
    tvars = tf.trainable_variables()
    # NOTE(review): substring match. This selects every trainable variable
    # whose name contains the letter 'w' (e.g. '1_fc_w'). Which LSTM
    # variables also match depends on how the installed TF version names
    # them; confirm which weights are meant to be regularised.
    w_vars = [var for var in tvars if 'w' in var.name]
    # sum() starts from 0 and adds the terms in order, exactly like the former
    # accumulation loop, without calling np.shape() on a list of tf.Variables.
    regularizer = sum((tf.nn.l2_loss(var) for var in w_vars), 0)

    loss = CR_loss + 0.1 * regularizer

    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    Grads = optimizer.compute_gradients(loss)
    train = optimizer.apply_gradients(Grads, global_step=global_step)

    # Model saver
    saver = tf.train.Saver(keep_checkpoint_every_n_hours=8, max_to_keep=100)
    # Open the TensorFlow session and initialise all variables.
    sess = tf.InteractiveSession(graph=graph, config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
    sess.run(tf.global_variables_initializer())

## Data preparation step

In [None]:
def prepare_pairs(X1, X2, crop_size=100):
    """Build labelled same/different scan pairs for siamese training.

    For the N subjects along the first axis of X1, 2*N pairs are returned:

    * pairs 0..N-1 pair subject k of X1 with subject k of X2 (label 1);
    * pairs N..2N-1 pair subject k of X1 with subject k-1 of X2 (label 0),
      wrapping so that subject 0 is paired with subject N-1.

    Only the first `crop_size` time points (last axis) of every scan are kept.

    Args:
        X1: array of shape (N, R, T) with T >= crop_size.
        X2: array with at least N subjects along axis 0 and the same R and
            T constraints as X1.
        crop_size: number of leading time points kept (default 100, the value
            previously hard-coded).

    Returns:
        Pair_X: float64 array of shape (2*N, R, crop_size, 2); [..., 0] holds
            the X1 scan and [..., 1] the X2 scan of each pair.
        Pair_Y: float64 array of shape (2*N,); 1 for same-subject pairs and
            0 otherwise.
    """
    N, R, _ = X1.shape
    Pair_X = np.empty((2 * N, R, crop_size, 2))
    Pair_Y = np.ones(2 * N)

    first = X1[:N, :, :crop_size]
    second = X2[:N, :, :crop_size]

    # Same-subject pairs.
    Pair_X[:N, :, :, 0] = first
    Pair_X[:N, :, :, 1] = second
    # Different-subject pairs. Rolling by one along the subject axis places
    # X2[k-1] at position k, and X2[N-1] at position 0.
    Pair_X[N:, :, :, 0] = first
    Pair_X[N:, :, :, 1] = np.roll(second, 1, axis=0)
    Pair_Y[N:] = 0

    return Pair_X, Pair_Y

def prepare_test(X1, X2, order, n_pair=100, crop_size=100):
    """Build the candidate pairs used to identify subject `order`.

    Every pair uses X1[order] as its first scan. Pair 0 is the true match,
    X1[order] with X2[order] (label 1). Pair i, for 1 <= i < n_pair, uses
    X2[(order + i) % n_pair] (label 0). Together the pairs cover each of the
    first n_pair subjects of X2 exactly once.

    Args:
        X1, X2: arrays of shape (subjects, R, T) with T >= crop_size.
        order: index of the query subject; must satisfy 0 <= order < n_pair.
        n_pair: number of candidates (default 100, the value previously
            hard-coded); X2 must hold at least this many subjects.
        crop_size: number of leading time points kept (default 100).

    Returns:
        Pair_X: float64 array of shape (n_pair, R, crop_size, 2).
        Pair_Y: float64 array of shape (n_pair,) with 1 at index 0 and 0
            elsewhere.

    Raises:
        ValueError: if `order` is out of range or X2 has fewer than n_pair
            subjects.
    """
    if not 0 <= order < n_pair:
        raise ValueError("order must be in [0, %d), got %r" % (n_pair, order))
    if X2.shape[0] < n_pair:
        raise ValueError("X2 has %d subjects, need at least %d" % (X2.shape[0], n_pair))

    R = X1.shape[1]
    Pair_X = np.empty((n_pair, R, crop_size, 2))
    Pair_Y = np.zeros(n_pair)
    Pair_Y[0] = 1

    # The query scan is broadcast into every pair.
    Pair_X[:, :, :, 0] = X1[order, :, :crop_size]
    # Candidate 0 is the query subject itself; the rest wrap around.
    candidates = (order + np.arange(n_pair)) % n_pair
    Pair_X[:, :, :, 1] = X2[candidates, :, :crop_size]

    return Pair_X, Pair_Y

In [None]:
import pickle


def _load_pickled_index(path):
    """Unpickle and return the index object stored in the file at `path`."""
    # NOTE(review): pickle can execute arbitrary code; only load trusted,
    # locally produced index files here.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Load the fold's subject indices (test, then train, then valid).
test_index = _load_pickled_index('./HCP_ID_1_fold_test_idx.txt')
train_index = _load_pickled_index('./HCP_ID_1_fold_train_idx.txt')
valid_index = _load_pickled_index('./HCP_ID_1_fold_valid_idx.txt')

# Select each split's subjects from both resting-state sessions.
Train_REST1_data, Train_REST2_data = REST1_data[train_index, :, :], REST2_data[train_index, :, :]
Valid_REST1_data, Valid_REST2_data = REST1_data[valid_index, :, :], REST2_data[valid_index, :, :]
Test_REST1_data, Test_REST2_data = REST1_data[test_index, :, :], REST2_data[test_index, :, :]

train_data, train_label = prepare_pairs(Train_REST1_data, Train_REST2_data)
num_train_trials = train_data.shape[0]
print(num_train_trials)

In [None]:
def _plot_train_val_curves(train_values, valid_values, quantity, ylim, path):
    """Draw one training/validation curve pair on twin y-axes and save it as a PNG."""
    fig, ax1 = plt.subplots(figsize=(10, 5))
    ax1.plot(train_values, 'b.-', alpha=.3)
    ax1.set_ylabel('Training ' + quantity, color='b')
    ax1.set_xlabel('Epoch')
    ax1.set_ylim(ylim)
    ax2 = ax1.twinx()
    ax2.plot(valid_values, 'r.-', alpha=.3)
    ax2.set_ylabel('Validation ' + quantity, color='r')
    ax2.set_ylim(ylim)
    plt.title('Training & Validation ' + quantity)
    # Save before show(): with non-inline backends the figure can be empty
    # after show() returns.
    fig.savefig(path, dpi=100)
    plt.show()
    # This runs every epoch; close the figure so figures do not accumulate.
    plt.close(fig)


def plot_network_output():
    """Plot and save the per-epoch accuracy and loss curves.

    Reads the notebook-level lists Accuracy_list, Accuracy_val_list,
    Loss_list and Loss_val_list. Writes one accuracy PNG and one loss PNG,
    overwriting the previous epoch's files.
    """
    _plot_train_val_curves(Accuracy_list, Accuracy_val_list, 'accuracy', (0.1, 1.0),
                           './1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_acc_last.png')
    _plot_train_val_curves(Loss_list, Loss_val_list, 'loss', (0, 1),
                           './1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_loss_last.png')

## Start training

In [None]:
import os

Loss_list = []
CR_Loss_list = []
Accuracy_list = []

Loss_val_list = []
Accuracy_val_list = []

Loss_test_list = []
Accuracy_test_list = []

pre_val_acc = 0.1
pre_val_loss = 1
total_epoch = 1000
# prepare_test builds 100 candidate pairs per query subject by default.
n_valid_subjects = 100
print("Begin Training")
total_batch = int((num_train_trials) / batch_size)
epoch = 0
learning_rate = 1e-5

# tf.train.Saver.save raises if the target directory does not exist.
for _ckpt_dir in ('max_weight', 'min_weight'):
    if not os.path.isdir(_ckpt_dir):
        os.makedirs(_ckpt_dir)


def _touch(path):
    """Create (or truncate) an empty marker file at `path`."""
    with open(path, 'w'):
        pass


def _pair_distances(pair_x, pair_y):
    """Evaluate `distance` and `loss` for every pair in `pair_x`.

    The X/Y placeholders have a fixed batch dimension, so pairs are fed in
    batches of `batch_size`. When the pair count is not a multiple of
    batch_size, the last batch is the final batch_size pairs, and only the
    distances not already collected are kept.

    Returns:
        (distances, batch_losses): one distance per pair, in order, and the
        loss value of every batch that was run.
    """
    n_pairs = pair_x.shape[0]
    if n_pairs < batch_size:
        raise ValueError("need at least %d pairs, got %d" % (batch_size, n_pairs))

    def run(start):
        stop = start + batch_size
        feed_x = np.concatenate([pair_x[start:stop, :, :, 0], pair_x[start:stop, :, :, 1]])
        return sess.run([distance, loss], {X: feed_x, Y: pair_y[start:stop], keep_prob: 1.0, phase_train: False})

    distances, batch_losses = [], []
    n_full = n_pairs // batch_size
    for b in range(n_full):
        dist_, loss_value = run(b * batch_size)
        distances.extend(dist_)
        batch_losses.append(loss_value)
    covered = n_full * batch_size
    if covered < n_pairs:
        tail_start = n_pairs - batch_size
        dist_, loss_value = run(tail_start)
        # Skip the distances the overlapping tail batch shares with the last full batch.
        distances.extend(dist_[covered - tail_start:])
        batch_losses.append(loss_value)
    return distances, batch_losses


while epoch < total_epoch:

    Acc_avg = []
    Loss_avg = []
    CR_Loss_avg = []

    rand_idx = np.random.permutation(num_train_trials)
    for batch in tqdm.tqdm(range(total_batch)):
        position = rand_idx[batch * batch_size:(batch + 1) * batch_size]
        # First half: one scan of each pair; second half: its partner.
        batch_x = np.concatenate([train_data[position, :, :, 0], train_data[position, :, :, 1]])
        batch_y = train_label[position]

        _, loss_, CR_loss_ = sess.run([train, loss, CR_loss], {X: batch_x, Y: batch_y, lr: learning_rate, keep_prob: 0.8, phase_train: True})
        Acc_value = sess.run(accuracy, {X: batch_x, Y: batch_y, keep_prob: 1.0, phase_train: False})

        Acc_avg.append(Acc_value)
        Loss_avg.append(loss_)
        CR_Loss_avg.append(CR_loss_)

    Accuracy_list.append(np.mean(Acc_avg))
    Loss_list.append(np.mean(Loss_avg))
    CR_Loss_list.append(np.mean(CR_Loss_avg))
    IPython.display.clear_output()
    print("%dth Epoch Training Accuracy: %f" % (epoch + 1, np.mean(Acc_avg)))
    print("%dth Epoch Training Loss: %f" % (epoch + 1, np.mean(Loss_avg)))
    print("%dth Epoch Training Sigmoid Loss: %f" % (epoch + 1, np.mean(CR_Loss_avg)))

    print("Model Saving")
    saver.save(sess, "./1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_last.tfmod")
    _touch('1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_last' + str(epoch).zfill(5))

    # Identification accuracy: for each validation subject, the true match
    # (pair 0) must be the closest of all candidates.
    Total_loss_valid = []
    count = 0
    for i in range(n_valid_subjects):
        Acc_valid_data, Acc_valid_label = prepare_test(Valid_REST1_data, Valid_REST2_data, i)
        Total_dist_valid, batch_losses = _pair_distances(Acc_valid_data, Acc_valid_label)
        Total_loss_valid.extend(batch_losses)
        if np.argmin(Total_dist_valid) == 0:
            count += 1
    Acc_valid = count / float(n_valid_subjects)
    Accuracy_val_list.append(Acc_valid)
    Loss_val_list.append(np.mean(Total_loss_valid))
    print("Total Valid Accuracy: %4f" % (Acc_valid))
    print("Total Valid Max Accuracy: %4f" % (max(Accuracy_val_list)))
    print("TotalValid Loss: %4f" % (np.mean(Total_loss_valid)))

    validation_acc = Acc_valid
    if validation_acc >= pre_val_acc:
        pre_val_acc = validation_acc
        print("save max weight at %s acc" % (validation_acc))
        saver.save(sess, 'max_weight/1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_model.tfmod')
        _touch('max_weight/1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_last_model_' + str(epoch).zfill(5))

    validation_loss = np.mean(Total_loss_valid)
    if validation_loss <= pre_val_loss:
        pre_val_loss = validation_loss
        print("save min weight at %s loss" % (validation_loss))
        saver.save(sess, 'min_weight/1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_model.tfmod')
        _touch('min_weight/1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_last_model_' + str(epoch).zfill(5))

    plot_network_output()
    epoch += 1

## Evaluate the test data with the weights that reached the maximum validation accuracy

In [None]:
# Restore the trained weights into the variables of the graph built above.
# The previous code called tf.train.import_meta_graph, which imports a second
# copy of the model into the same graph. Restoring through the Saver it returns
# only updates that copy, so `distance` and `loss` below kept evaluating
# whatever weights were already in the session. A Saver over this graph's own
# variables restores into the tensors that are actually evaluated.
# NOTE(review): assumes the checkpoint was written by this notebook's model
# (same variable names); confirm for externally provided weights.
# To evaluate the best-validation checkpoint from training instead, restore
# 'max_weight/1fold_LSTM_64x6_GAP_FC_no_drop_0.8_lr1e-5_model.tfmod'.
with graph.as_default():
    restorer = tf.train.Saver()
restorer.restore(sess, './HCP_ID_1fold_trained_weight.tfmod')

Test_REST1_data = np.load("./norm_REST1_S900_test.npy")
# Bug fix: this load was previously assigned to Test_REST1_data as well,
# which discarded the REST1 scans and left Test_REST2_data stale.
Test_REST2_data = np.load("./norm_REST2_S900_test.npy")

# prepare_test builds 100 candidate pairs per query subject by default.
n_test_subjects = 100


def _pair_distances(pair_x, pair_y):
    """Evaluate `distance` and `loss` for every pair in `pair_x`.

    The X/Y placeholders have a fixed batch dimension, so pairs are fed in
    batches of `batch_size`. When the pair count is not a multiple of
    batch_size, the last batch is the final batch_size pairs, and only the
    distances not already collected are kept.

    Returns:
        (distances, batch_losses): one distance per pair, in order, and the
        loss value of every batch that was run.
    """
    n_pairs = pair_x.shape[0]
    if n_pairs < batch_size:
        raise ValueError("need at least %d pairs, got %d" % (batch_size, n_pairs))

    def run(start):
        stop = start + batch_size
        feed_x = np.concatenate([pair_x[start:stop, :, :, 0], pair_x[start:stop, :, :, 1]])
        return sess.run([distance, loss], {X: feed_x, Y: pair_y[start:stop], keep_prob: 1.0, phase_train: False})

    distances, batch_losses = [], []
    n_full = n_pairs // batch_size
    for b in range(n_full):
        dist_, loss_value = run(b * batch_size)
        distances.extend(dist_)
        batch_losses.append(loss_value)
    covered = n_full * batch_size
    if covered < n_pairs:
        tail_start = n_pairs - batch_size
        dist_, loss_value = run(tail_start)
        # Skip the distances the overlapping tail batch shares with the last full batch.
        distances.extend(dist_[covered - tail_start:])
        batch_losses.append(loss_value)
    return distances, batch_losses


# Identification accuracy: for each test subject, the true match (pair 0)
# must be the closest of all candidates.
Total_loss_test = []
count = 0
for i in range(n_test_subjects):
    Acc_test_data, Acc_test_label = prepare_test(Test_REST1_data, Test_REST2_data, i)
    Total_dist_test, batch_losses = _pair_distances(Acc_test_data, Acc_test_label)
    Total_loss_test.extend(batch_losses)
    if np.argmin(Total_dist_test) == 0:
        count += 1

print("Total Test Accuracy: %4f" % (count / float(n_test_subjects)))
print("Total Test Loss: %4f" % (np.mean(Total_loss_test)))