In [1]:
"""

@author: daehyeon
"""

from __future__ import division, print_function, absolute_import

In [2]:
import numpy as np
import tensorflow as tf
from acrnn1 import acrnn
from extract_mel import next_batch, test_dataset
import cPickle
import matplotlib.pyplot as plt

from sklearn.metrics import precision_score as precision
from sklearn.metrics import recall_score as recall
from sklearn.metrics import confusion_matrix as confusion
import os

In [3]:
# Hyper-parameters and paths, registered as TensorFlow flags and read through
# FLAGS by train().

# Number of training-loop iterations in train(); each iteration fetches and
# trains on one batch.
tf.app.flags.DEFINE_integer('num_epoch', 30000, 'The number of epoches for training.')
# Width of the one-hot label placeholder.
tf.app.flags.DEFINE_integer('num_classes', 7, 'The number of emotion classes.')
# Passed to next_batch() for both training and validation batches.
tf.app.flags.DEFINE_integer('batch_size', 45, 'The number of samples in each batch.')
# In train(): True -> plain Adam minimize(); False -> Adam with gradients
# clipped by global norm. Adam is used in both cases.
tf.app.flags.DEFINE_boolean('is_adam', True, 'whether to use adam optimizer.')
tf.app.flags.DEFINE_float('learning_rate', 0.00001, 'learning rate of Adam optimizer')
# Fed as keep_prob on training steps only; validation always feeds 1.
tf.app.flags.DEFINE_float('dropout_keep_prob', 1, 'the prob of every unit keep in dropout layer')
# Shape of one input sample: (image_height, image_width, image_channel).
tf.app.flags.DEFINE_integer('image_height', 300, 'image height')
tf.app.flags.DEFINE_integer('image_width', 40, 'image width')
tf.app.flags.DEFINE_integer('image_channel', 3, 'image channels as input')
# Number of validation batches evaluated in each validation pass.
tf.app.flags.DEFINE_integer('val_iter', 5, 'the number of validation test')

# Checkpoints are written to os.path.join(checkpoint, model_name).
tf.app.flags.DEFINE_string('checkpoint', './checkpoint/', 'the checkpoint dir')
tf.app.flags.DEFINE_string('model_name', 'model4.ckpt', 'model name')
# NOTE(review): presumably declared so that flag parsing accepts the `-f`
# argument a Jupyter kernel is launched with -- confirm before removing.
tf.app.flags.DEFINE_string('f', '', 'kernel')

FLAGS = tf.app.flags.FLAGS


In [4]:
# Per-validation-pass metrics appended by train(): accuracy and mean loss.
# Kept at module level so they can be inspected from other cells.
history_acc, history_cost = [], []

def _save_curve(values, path, ylabel):
    """Plot ``values`` against their index on a fresh figure and save it to ``path``.

    A new figure is created and closed on every call so that curves from
    earlier calls (or from a different metric) never end up in the same image.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(range(len(values)), values)
    ax.set_xlabel("validation pass")
    ax.set_ylabel(ylabel)
    fig.savefig(path)
    plt.close(fig)


def train():
    """Build the graph around ``acrnn``, train it, and track validation metrics.

    Runs ``FLAGS.num_epoch`` iterations; each iteration trains on one batch
    from ``next_batch(..., train=True)``. Every 5 iterations a validation pass
    over ``FLAGS.val_iter`` batches is run, its accuracy and mean loss are
    appended to the module-level ``history_acc`` / ``history_cost`` lists, and
    a summary is printed. Whenever validation accuracy improves, a checkpoint
    is written to ``os.path.join(FLAGS.checkpoint, FLAGS.model_name)``. Every
    300 iterations the two history curves are saved as ``acc.png`` and
    ``cost.png`` inside ``FLAGS.checkpoint``.

    Side effects: clears and refills ``history_acc`` / ``history_cost``,
    creates ``FLAGS.checkpoint`` if missing, writes checkpoint and PNG files,
    and prints progress. Returns ``None``.
    """
    plot_every = 300      # iterations between refreshes of the saved curves
    validate_every = 5    # iterations between validation passes

    best_valid_ac = 0
    best_valid_conf = None

    # Start from empty histories so that re-running the cell does not append
    # a new run onto the curves of a previous one.
    del history_acc[:]
    del history_cost[:]

    # Both savefig() and Saver.save() need the target directory to exist.
    if not os.path.isdir(FLAGS.checkpoint):
        os.makedirs(FLAGS.checkpoint)
    acc_plot_path = os.path.join(FLAGS.checkpoint, "acc.png")
    cost_plot_path = os.path.join(FLAGS.checkpoint, "cost.png")

    X = tf.placeholder(tf.float32, shape=[None, FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel])
    Y = tf.placeholder(tf.int32, shape=[None, FLAGS.num_classes])  # one label row per input row

    is_training = tf.placeholder(tf.bool)
    lr = tf.placeholder(tf.float32)         # learning rate
    keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

    # NOTE(review): if acrnn uses batch normalization driven by is_training,
    # its moving-average update ops may need to run with train_op -- confirm.
    Ylogits = acrnn(X, is_training=is_training, dropout_keep_prob=keep_prob)

    # Per-row loss; kept un-reduced so validation can sum it over all rows.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=Ylogits)
    cost = tf.reduce_mean(cross_entropy)

    # Despite the flag's name, Adam is used in both branches; the flag only
    # selects whether gradients are clipped.
    if FLAGS.is_adam:
        # Adam without gradient clipping.
        train_op = tf.train.AdamOptimizer(lr).minimize(cost)
    else:
        # Adam with gradients clipped to a global norm of 5.
        var_trainable_op = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, var_trainable_op), 5)
        opti = tf.train.AdamOptimizer(lr)
        train_op = opti.apply_gradients(zip(grads, var_trainable_op))

    correct_pred = tf.equal(tf.argmax(Ylogits, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    saver = tf.train.Saver(tf.global_variables())
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for i in range(FLAGS.num_epoch):
            data, label, _ = next_batch(FLAGS.batch_size, train=True)

            # One optimisation step; tcost / tracc describe this batch only.
            _, tcost, tracc = sess.run(
                [train_op, cost, accuracy],
                feed_dict={X: data, Y: label, is_training: True,
                           keep_prob: FLAGS.dropout_keep_prob, lr: FLAGS.learning_rate})

            if i % validate_every == 0:
                whole_y_valid = []   # one label row per group
                whole_y_pred = []    # one pooled logit row per group

                cost_valid = 0
                valid_data_size = 0
                for _ in range(FLAGS.val_iter):
                    valid_data, valid_label, pernums_valid = next_batch(FLAGS.batch_size, train=False)
                    valid_data_size += valid_label.shape[0]
                    loss, y_pred_valid = sess.run(
                        [cross_entropy, Ylogits],
                        feed_dict={X: valid_data, Y: valid_label, is_training: False, keep_prob: 1})

                    # pernums_valid[s] is the number of consecutive rows that
                    # form group s (presumably the segments of one utterance --
                    # confirm against extract_mel.next_batch). Each group is
                    # reduced to one prediction via an element-wise max over
                    # its logit rows, and is labelled with its last row's label.
                    index = 0
                    for s in range(pernums_valid.shape[0]):
                        end = index + pernums_valid[s]
                        whole_y_pred.append(np.max(y_pred_valid[index:end], 0))
                        whole_y_valid.append(valid_label[end - 1])
                        index = end

                    cost_valid += np.sum(loss)  # summed over rows, averaged below

                # Mean per-row loss over every validation row seen in this pass.
                cost_valid = cost_valid / valid_data_size

                true_cls = np.argmax(whole_y_valid, 1)
                pred_cls = np.argmax(whole_y_pred, 1)
                valid_acc = np.mean(np.equal(true_cls, pred_cls))
                valid_conf = confusion(true_cls, pred_cls)

                history_acc.append(valid_acc)
                history_cost.append(cost_valid)

                if valid_acc > best_valid_ac:
                    best_valid_ac = valid_acc
                    best_valid_conf = valid_conf
                    saver.save(sess, os.path.join(FLAGS.checkpoint, FLAGS.model_name), global_step=i + 1)

                print("*****************************************************************")
                print("Epoch: %05d" % (i + 1))
                print("Training cost: %2.3g" % tcost)
                print("Training accuracy: %3.4g" % tracc)
                print("Valid acc: %3.4g" % valid_acc)
                print("Best valid acc: %3.4g" % best_valid_ac)
                print("")
                print("Valid cost: %2.3g" % cost_valid)
                # Uncomment to inspect the confusion matrices:
                # print(valid_conf)
                # print(best_valid_conf)
                print("*****************************************************************")

            # Refresh the saved curves so over-fitting can be spotted while
            # training is still running. Each curve gets its own figure.
            if i % plot_every == 0:
                _save_curve(history_acc, acc_plot_path, "validation accuracy")
                _save_curve(history_cost, cost_plot_path, "validation cost")
                
                
                    

In [None]:
# Build the graph and run the whole training loop (FLAGS.num_epoch iterations);
# progress is printed every 5 iterations, which produces the log below.
train()

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

*****************************************************************
Epoch: 00001
Training cost: 2.06
Training accuracy: 0.1237
Valid acc: 0.1022
Best valid acc: 0.1022

Valid cost:  2
*****************************************************************
*****************************************************************
Epoch: 00006
Training cost: 1.94
Training accuracy: 0.1296
Valid acc: 0.1244
Best valid acc: 0.1244

Valid cost: 1.96
*****************************************************************
*****************************************************************
Epoch: 00011
Training cost: 1.96
Training accuracy: 0.1509
Valid acc: 0.16
Best valid acc: 0.16

Valid cost: 1.94
*****************************************************************
*****************************************************************
E

*****************************************************************
Epoch: 00161
Training cost: 1.93
Training accuracy: 0.1616
Valid acc: 0.2667
Best valid acc: 0.2667

Valid cost: 1.9
*****************************************************************
*****************************************************************
Epoch: 00166
Training cost: 1.87
Training accuracy: 0.2736
Valid acc: 0.2356
Best valid acc: 0.2667

Valid cost: 1.91
*****************************************************************
*****************************************************************
Epoch: 00171
Training cost: 1.91
Training accuracy: 0.2056
Valid acc: 0.1956
Best valid acc: 0.2667

Valid cost: 1.91
*****************************************************************
*****************************************************************
Epoch: 00176
Training cost: 1.85
Training accuracy: 0.2925
Valid acc: 0.2711
Best valid acc: 0.2711

Valid cost: 1.91
*****************************************************************
*

*****************************************************************
Epoch: 00326
Training cost: 1.9
Training accuracy: 0.2286
Valid acc: 0.2844
Best valid acc: 0.32

Valid cost: 1.88
*****************************************************************
*****************************************************************
Epoch: 00331
Training cost: 1.87
Training accuracy: 0.1714
Valid acc: 0.2756
Best valid acc: 0.32

Valid cost: 1.87
*****************************************************************
*****************************************************************
Epoch: 00336
Training cost: 1.87
Training accuracy: 0.2381
Valid acc: 0.2356
Best valid acc: 0.32

Valid cost: 1.91
*****************************************************************
*****************************************************************
Epoch: 00341
Training cost: 1.91
Training accuracy: 0.1754
Valid acc: 0.2489
Best valid acc: 0.32

Valid cost: 1.89
*****************************************************************
*********

*****************************************************************
Epoch: 00496
Training cost: 1.84
Training accuracy: 0.25
Valid acc: 0.2622
Best valid acc: 0.36

Valid cost: 1.89
*****************************************************************
*****************************************************************
Epoch: 00501
Training cost: 1.84
Training accuracy: 0.2569
Valid acc: 0.2844
Best valid acc: 0.36

Valid cost: 1.85
*****************************************************************
*****************************************************************
Epoch: 00506
Training cost: 1.86
Training accuracy: 0.2762
Valid acc: 0.2356
Best valid acc: 0.36

Valid cost: 1.9
*****************************************************************
*****************************************************************
Epoch: 00511
Training cost: 1.79
Training accuracy: 0.2963
Valid acc: 0.2578
Best valid acc: 0.36

Valid cost: 1.88
*****************************************************************
***********

*****************************************************************
Epoch: 00666
Training cost: 1.88
Training accuracy: 0.22
Valid acc: 0.3156
Best valid acc: 0.3778

Valid cost: 1.83
*****************************************************************
*****************************************************************
Epoch: 00671
Training cost: 1.87
Training accuracy: 0.1887
Valid acc: 0.3022
Best valid acc: 0.3778

Valid cost: 1.86
*****************************************************************
*****************************************************************
Epoch: 00676
Training cost: 1.82
Training accuracy: 0.2569
Valid acc: 0.32
Best valid acc: 0.3778

Valid cost: 1.82
*****************************************************************
*****************************************************************
Epoch: 00681
Training cost: 1.82
Training accuracy: 0.2323
Valid acc: 0.2889
Best valid acc: 0.3778

Valid cost: 1.86
*****************************************************************
****

*****************************************************************
Epoch: 00831
Training cost: 1.83
Training accuracy: 0.2772
Valid acc: 0.28
Best valid acc: 0.3778

Valid cost: 1.85
*****************************************************************
*****************************************************************
Epoch: 00836
Training cost: 1.76
Training accuracy: 0.3925
Valid acc: 0.3556
Best valid acc: 0.3778

Valid cost: 1.83
*****************************************************************
*****************************************************************
Epoch: 00841
Training cost: 1.84
Training accuracy: 0.2385
Valid acc: 0.2889
Best valid acc: 0.3778

Valid cost: 1.83
*****************************************************************
*****************************************************************
Epoch: 00846
Training cost: 1.83
Training accuracy: 0.2432
Valid acc: 0.2933
Best valid acc: 0.3778

Valid cost: 1.87
*****************************************************************
**

*****************************************************************
Epoch: 00996
Training cost: 1.84
Training accuracy: 0.2411
Valid acc: 0.2933
Best valid acc: 0.3778

Valid cost: 1.82
*****************************************************************
*****************************************************************
Epoch: 01001
Training cost: 1.82
Training accuracy: 0.1909
Valid acc: 0.2711
Best valid acc: 0.3778

Valid cost: 1.85
*****************************************************************
*****************************************************************
Epoch: 01006
Training cost: 1.81
Training accuracy: 0.2667
Valid acc: 0.2844
Best valid acc: 0.3778

Valid cost: 1.86
*****************************************************************
*****************************************************************
Epoch: 01011
Training cost: 1.81
Training accuracy: 0.3097
Valid acc: 0.2756
Best valid acc: 0.3778

Valid cost: 1.85
*****************************************************************
