In [1]:
!rm -rf ./mixnet_bin_10

In [2]:
%config IPCompleter.greedy=True
import numpy as np
import tensorflow as tf
import time

In [3]:
%run ./datagen.py
datagen, (x_train, y_train), (x_test, y_test) = data_preparation()

In [4]:
save_dir = './mixnet_bin_10/'
batch_size = 100
iterations = x_train.shape[0] // batch_size
epochs = 1200
old_acc = 0
start_lr = 2.
end_lr = 0.5
decay_rate = (end_lr / start_lr) ** (1 / epochs)
k = 10

In [5]:
%run ./binary_layer.py

# resnet layer
def res_layer(inputs, filter_num, filter_size, stride, is_train,
              binarized=False, batch_norm=True, activation=True):

    x = inputs

    if binarized:
        if batch_norm:
            x = tf.layers.batch_normalization(x, training=is_train)
        if activation:
            x = tf.square(x)
        x = conv2d(inputs=x, filters=filter_num, 
                   kernel_size=filter_size, strides=stride, padding='same')
    else:
        if batch_norm:
            x = tf.layers.batch_normalization(x, training=is_train)
        if activation:
            x = tf.square(x)
        x = tf.layers.conv2d(inputs=x, filters=filter_num, 
                             kernel_size=filter_size, strides=stride, padding='same')
        
    return x

In [6]:
def wide_resnet(inputs, k, is_train, binarized):

    with tf.variable_scope("1st_Conv"):
        x = tf.layers.conv2d(inputs=inputs, filters=16, 
                             kernel_size=3, strides=1, padding='same')
        x = tf.layers.batch_normalization(x, training=is_train)
        x = tf.nn.relu(x)
        tf.summary.histogram('activation', x)
    
    x_temp_0 = x
    
    with tf.variable_scope('ResBlock_%d_%d' % (1, 1)):
        
        with tf.variable_scope('conv1'):
            x = res_layer(x, 16*k, 3, 1, is_train, binarized=binarized, 
                          batch_norm=False, activation=False)
                    
        x = tf.layers.dropout(x, 0.1)
                
        with tf.variable_scope('conv2'):
            x = res_layer(x, 16*k, 3, 1, is_train, binarized=binarized)
    
        with tf.variable_scope('x_plus_shortcut'):
            shortcut = res_layer(x_temp_0, 16*k, 1, 1, is_train=False, 
                                 binarized=False, batch_norm=False, activation=False)
            x = x + shortcut 
            
        tf.summary.histogram('output', x)
        
    x_temp_1 = x
    
    with tf.variable_scope('ResBlock_%d_%d' % (2, 1)):
        
        with tf.variable_scope('conv1'):
            x = res_layer(x, 32*k, 3, 2, is_train, binarized=binarized)
                    
        x = tf.layers.dropout(x, 0.1)
                
        with tf.variable_scope('conv2'):
            x = res_layer(x, 32*k, 3, 1, is_train, binarized=binarized)
            
        with tf.variable_scope('x_plus_shortcut'):
            shortcut = res_layer(x_temp_0, 32*k, 1, 2, is_train=False, 
                                 binarized=False, batch_norm=False, activation=False)
            x = x + shortcut
            
            shortcut = res_layer(x_temp_1, 32*k, 1, 2, is_train=False, 
                                 binarized=False, batch_norm=False, activation=False)
            x = x + shortcut
            
        tf.summary.histogram('output', x)
        
    x_temp_2 = x    
    
    with tf.variable_scope('ResBlock_%d_%d' % (3, 1)):
        
        with tf.variable_scope('conv1'):
            x = res_layer(x, 64*k, 3, 2, is_train, binarized=binarized)
                    
        x = tf.layers.dropout(x, 0.1)
                
        with tf.variable_scope('conv2'):
            x = res_layer(x, 64*k, 3, 1, is_train, binarized=binarized)
            
        with tf.variable_scope('x_plus_shortcut'):
            
            shortcut = res_layer(x_temp_0, 64*k, 1, 4, is_train=False, 
                                     binarized=False, batch_norm=False, activation=False)
            x = x + shortcut

            shortcut = res_layer(x_temp_1, 64*k, 1, 4, is_train=False, 
                                     binarized=False, batch_norm=False, activation=False)
            x = x + shortcut

            shortcut = res_layer(x_temp_2, 64*k, 1, 2, is_train=False, 
                                     binarized=False, batch_norm=False, activation=False)
            x = x + shortcut
            
        tf.summary.histogram('output', x)

    with tf.variable_scope("AfterResBlock"):
        x = tf.layers.batch_normalization(x, training=is_train)
        x = tf.square(x)
        x = tf.layers.average_pooling2d(x, pool_size=8, strides=8, 
                                        padding='SAME', name='ave_pool')
        tf.summary.histogram('bn_relu_pooling', x)
    
    ######## current x.shape = (?, 1, 1, N) ##########

    with tf.variable_scope("Flatten"):
        x = tf.transpose(x, perm=[0, 3, 1, 2])
        x = tf.layers.flatten(x)

    with tf.variable_scope("Prediction"):
        pred = tf.layers.dense(x, units=10)
        tf.summary.histogram('prediction', pred)
        
    return pred

In [7]:
tf.reset_default_graph()

with tf.device('/GPU:7'):

    inputs = tf.placeholder(tf.float32, [None, 32, 32, 3], name='input')
    outputs = tf.placeholder(tf.float32, [None, 10], name='output')
    is_train = tf.placeholder(tf.bool, name='is_train')

    global_step = tf.Variable(0, trainable=False)

    l_r = tf.train.exponential_decay(
        start_lr, global_step, iterations, decay_rate, staircase=True)
    tf.summary.scalar('learning_rate', l_r)

    opt = tf.train.MomentumOptimizer(learning_rate=l_r, momentum=0.9)

    pred = wide_resnet(inputs, k, is_train, binarized=True)

    loss = tf.losses.softmax_cross_entropy(outputs, pred)
    grads = opt.compute_gradients(loss)

    ########################## LARS ##########################
    grads_norm = [(tf.norm(grad, ord=2), tf.norm(var, ord=2)) for grad, var in grads]
    eta = 1e-4
    local_lr = [tf.where(var_norm < 1e-3, 1e-3, eta*var_norm / (grad_norm + 1e-8)) 
                for grad_norm, var_norm in grads_norm]
    new_grads = [(local_lr[i]*grad, var) for i, (grad, var) in enumerate(grads)]
    ##########################################################

    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(outputs, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    for grad, var in new_grads:
        if grad is not None:
            tf.summary.histogram(var.name.split(":")[0] + '/gradients', grad)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    train_op = opt.apply_gradients(new_grads, global_step=global_step)

    kernel_vars = []

    for i in tf.trainable_variables():
        if 'bin/kernel' in i.name:
            kernel_vars.append(i)

    with tf.control_dependencies(update_ops):
        with tf.control_dependencies([train_op]):
            kernel_clip_op = [tf.clip_by_value(var, -1, 1) for var in kernel_vars]

saver = tf.train.Saver(tf.global_variables())

def add_hist(train_vars):
    for i in train_vars:
        name = i.name.split(":")[0] + '/value'
        value = i.value()
        tf.summary.histogram(name, value)

add_hist(tf.trainable_variables())

tf.summary.scalar('loss', loss)
tf.summary.scalar('accuracy', accuracy)
merged = tf.summary.merge_all()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Use keras.layers.batch_normalization instead.
Instructions for updating:
Use keras.layers.dropout instead.
Instructions for updating:
Use keras.layers.average_pooling2d instead.
Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.cast instead.


In [None]:
config = tf.ConfigProto(allow_soft_placement=True,
                        log_device_placement=True)
config.gpu_options.allow_growth = True

with tf.Session(config=config) as sess:

    print('*****************Training Start!*****************')
    train_writer = tf.summary.FileWriter(save_dir+'train', sess.graph)
    sess.run(tf.global_variables_initializer())

    for m in range(epochs):
        start = time.time()
        batch_gen = datagen.flow(
            x_train, y_train, batch_size=batch_size)

        for i in range(iterations):
            x_batch, y_batch = next(batch_gen)
            _, loss_train = sess.run([kernel_clip_op, loss], 
                                     {inputs: x_batch, outputs: y_batch, is_train: True})
        
        summary = sess.run(merged, {inputs: x_batch, outputs: y_batch, is_train: False})
        train_writer.add_summary(summary, m*iterations + i + 1)

        val_accs = []
        for i in range(5000//(batch_size*5)):
            val_acc = sess.run(accuracy, {inputs: x_test[i*batch_size*5: (i+1)*batch_size*5],
                                          outputs: y_test[i*batch_size*5: (i+1)*batch_size*5],
                                          is_train: False})
            val_accs.append(val_acc)

        if np.mean(val_accs) > old_acc:
            old_acc = np.mean(val_accs)
            saver.save(sess, save_dir+'cifar10.ckpt')

        end = time.time()
        print('Epoch: {}'.format(m + 1),
              'Train_loss: {:.3f}'.format(loss_train),
              'Val_acc: {:.3f}'.format(np.mean(val_accs)),
              'Time consumed: {:.4f} s'.format(end - start))

    print('*****************Training End!*****************')

*****************Training Start!*****************
Epoch: 1 Train_loss: 2.335 Val_acc: 0.148 Time consumed: 67.0520 s
Epoch: 2 Train_loss: 2.811 Val_acc: 0.134 Time consumed: 59.6729 s
Epoch: 3 Train_loss: 3.095 Val_acc: 0.141 Time consumed: 59.8925 s
Epoch: 4 Train_loss: 2.192 Val_acc: 0.155 Time consumed: 59.8577 s
Epoch: 5 Train_loss: 2.433 Val_acc: 0.185 Time consumed: 60.3968 s
Epoch: 6 Train_loss: 2.421 Val_acc: 0.190 Time consumed: 60.2005 s
Epoch: 7 Train_loss: 2.340 Val_acc: 0.225 Time consumed: 60.6280 s
Epoch: 8 Train_loss: 2.238 Val_acc: 0.254 Time consumed: 60.6137 s
Epoch: 9 Train_loss: 2.683 Val_acc: 0.258 Time consumed: 60.5375 s
Epoch: 10 Train_loss: 2.047 Val_acc: 0.255 Time consumed: 60.0600 s
Epoch: 11 Train_loss: 2.379 Val_acc: 0.257 Time consumed: 59.7153 s
Epoch: 12 Train_loss: 2.032 Val_acc: 0.294 Time consumed: 60.2921 s
Epoch: 13 Train_loss: 2.078 Val_acc: 0.305 Time consumed: 60.5159 s
Epoch: 14 Train_loss: 2.225 Val_acc: 0.327 Time consumed: 59.8766 s
Epoch: 