In [1]:
from __future__ import division, print_function
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets
import tensorflow as tf
import math
from tensorflow.python.framework.ops import reset_default_graph

from sklearn.utils import shuffle

In [2]:
# Load data (download if you haven't already)
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('MNIST_data', 
                                       one_hot=True,   # Convert the labels into one hot encoding
                                       dtype='float32', # rescale images to `[0, 1]`
                                       reshape=False, # Don't flatten the images to vectors
                                      )

gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)


Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz


In [3]:
from tensorflow import layers
from tensorflow.contrib.layers import fully_connected, convolution2d, convolution2d_transpose, batch_norm, max_pool2d, dropout
from tensorflow.python.ops.nn import relu, elu, relu6, sigmoid, tanh, softmax, softplus

In [19]:
# reset graph
reset_default_graph()

# -- THE MODEL --#
num_channels = 1; #Black and white for MNIST
num_classes = 1
k = 16;
height = width = 28

# Layer definitions
def layer(x, units):
    with tf.name_scope('layer_' + str(units)):
        x = fully_connected(x, num_outputs=units, activation_fn=relu,
                             normalizer_fn=batch_norm)
        x = convolution2d(x, num_outputs=units, kernel_size=(3, 3),
                                 stride=1)
        return dropout(x, is_training=is_training_pl)
    
def dense_block(x, num_layers):
    with tf.name_scope('dense_' + str(num_layers)):
        for i in range(num_layers):
            layer_output = layer(x, k)
            x = tf.concat([x, layer_output], axis=-1)
            if i == 0:
                res = layer_output
            else:
                res = tf.concat([res, layer_output], axis=-1)
        return res
    

def transition_up(x, units):
    return convolution2d_transpose(x, num_outputs=units, kernel_size=(3, 3), stride=2)
    
    
def transition_down(x, units):
    with tf.name_scope('transition_down_' + str(units)):
        x = batch_norm(x) #Batch norm should be included in fully_connected layer below
        x = relu(x)
        x = convolution2d(x, num_outputs=units, kernel_size=(1, 1),
                             stride=1)
        x = dropout(x, is_training=is_training_pl)
        x = max_pool2d(x, kernel_size=(2, 2))
        return x

# - Tiramisu Architecture - #
# Input placeholder
x_pl = tf.placeholder(tf.float32, [None, height, width, num_channels], 'x_pl')
y_pl = tf.placeholder(tf.float32, [None, height, width, num_classes], 'y_pl')
is_training_pl = tf.placeholder(tf.bool, name="is-training_pl")
print('x_pl', x_pl.shape)
print('y_pl', y_pl.shape)

def upsample(x, skip, num_dense, skip_up=False):
    x = transition_up(x, x.shape[-1].value)
    if skip_up:
        skip = tf.concat([x, skip], axis=-1)
    x = tf.concat([x, skip], axis=-1)
    x = dense_block(x, num_dense)
    print('DB ({} layers) + TU'.format(num_dense), '\t', x.shape)
    return x

def downsample(x, num_dense):
    skip = dense_block(x, num_dense)
    skip = tf.concat([x, skip], axis=-1)
    x = transition_down(skip, num_dense*k + x.shape[-1].value)
    print('DB ({} layers) + TD'.format(num_dense), '\t', x.shape)
    return x, skip

with tf.name_scope('tiramisu'):
    # DOWN SAMPLING
    x = convolution2d(x_pl, num_outputs=k, kernel_size=(3, 3),
                             stride=1, scope="pre-convolution")
    print('pre_conv', '\t\t', x.shape)
    
    x, skip1 = downsample(x, 4)
    x, skip2 = downsample(x, 5)

    # BOTTLENECK
    x = dense_block(x, 15)
    print('Bottleneck (15 layers)', '\t', x.shape)

    # UPSAMPLING
    x = upsample(x, skip2, 5)
    x = upsample(x, skip1, 4, skip_up=True)

    # Output layers
    x = convolution2d(x, num_outputs=num_classes, kernel_size=(1, 1),
                             stride=1, scope="post-convolution")
    print('post-convolution', '\t', x.shape)
    y = fully_connected(x, num_outputs=num_classes, activation_fn=softmax, scope="SoftMax")
    print('SoftMax output', '\t\t', y.shape)

print("Model built")

x_pl (?, 28, 28, 1)
y_pl (?, 28, 28, 1)
pre_conv 		 (?, 28, 28, 16)
DB (4 layers) + TD 	 (?, 14, 14, 80)
DB (5 layers) + TD 	 (?, 7, 7, 160)
Bottleneck (15 layers) 	 (?, 7, 7, 240)
DB (5 layers) + TU 	 (?, 14, 14, 80)
DB (4 layers) + TU 	 (?, 28, 28, 64)
post-convolution 	 (?, 28, 28, 1)
SoftMax output 		 (?, 28, 28, 1)
Model built


In [17]:
with tf.variable_scope('loss'):
    # computing cross entropy per sample
    cross_entropy = -tf.reduce_sum(y_pl * tf.log(y+1e-8), reduction_indices=[1])

    # averaging over samples
    cross_entropy = tf.reduce_mean(cross_entropy)

    
with tf.variable_scope('training'):
    # defining our optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

    # applying the gradients
    train_op = optimizer.minimize(cross_entropy)

    
with tf.variable_scope('performance'):
    # making a one-hot encoded vector of correct (1) and incorrect (0) predictions
    correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pl, axis=1))

    # averaging the one-hot encoded vector
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [None]:
#Test the forward pass
x_batch, y_batch = mnist_data.train.next_batch(32)

with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
    sess.run(tf.global_variables_initializer())
    y_pred = sess.run(fetches=y, feed_dict={x_pl: x_batch, is_training_pl: True})

assert y_pred.shape == x_batch.shape, "ERROR the output shape is not as expected!" \
        + " Output shape should be " + str(x_batch.shape) + ' but was ' + str(y_pred.shape)

print('Forward pass successful!')

Forward pass successful!


In [None]:
#Training Loop
batch_size = 10
max_epochs = 10


valid_loss, valid_accuracy = [], []
train_loss, train_accuracy = [], []
test_loss, test_accuracy = [], []


with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
    summary_writer = tf.summary.FileWriter('./logs', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    print('Begin training loop')

    i = 0
    try:
        while mnist_data.train.epochs_completed < max_epochs:
            _train_loss, _train_accuracy = [], []
            
            print("epoch", mnist_data.train.epochs_completed, i)
            i += 1
            ## Run train op
            x_batch, y_batch = mnist_data.train.next_batch(batch_size)
            fetches_train = [train_op, cross_entropy, accuracy]
            feed_dict_train = {x_pl: x_batch, y_pl: x_batch, is_training_pl: True}
            _, _loss, _acc = sess.run(fetches_train, feed_dict_train)
            
            _train_loss.append(_loss)
            _train_accuracy.append(_acc)
            

            ## Compute validation loss and accuracy
            if mnist_data.train.epochs_completed % 1 == 0 \
                    and mnist_data.train._index_in_epoch <= batch_size:
                print("validation")
                train_loss.append(np.mean(_train_loss))
                train_accuracy.append(np.mean(_train_accuracy))

                fetches_valid = [cross_entropy, accuracy]
                
                images = mnist_data.validation.images
                feed_dict_valid = {x_pl: images, y_pl: images, is_training_pl: False}
                _loss, _acc = sess.run(fetches_valid, feed_dict_valid)
                
                valid_loss.append(_loss)
                valid_accuracy.append(_acc)
                print("Epoch {} : Train Loss {:6.3f}, Train acc {:6.3f},  Valid loss {:6.3f},  Valid acc {:6.3f}".format(
                    mnist_data.train.epochs_completed, train_loss[-1], train_accuracy[-1], valid_loss[-1], valid_accuracy[-1]))
        
        
        test_epoch = mnist_data.test.epochs_completed
        while mnist_data.test.epochs_completed == test_epoch:
            x_batch, y_batch = mnist_data.test.next_batch(batch_size)
            feed_dict_test = {x_pl: x_batch, y_pl: x_batch}
            _loss, _acc = sess.run(fetches_valid, feed_dict_test)
            test_loss.append(_loss)
            test_accuracy.append(_acc)
        print('Test Loss {:6.3f}, Test acc {:6.3f}'.format(
                    np.mean(test_loss), np.mean(test_accuracy)))


    except KeyboardInterrupt:
        pass

Begin training loop
epoch 0 0
epoch 0 1
epoch 0 2
epoch 0 3
epoch 0 4
epoch 0 5
epoch 0 6
epoch 0 7
epoch 0 8
epoch 0 9
epoch 0 10
epoch 0 11
epoch 0 12
epoch 0 13
epoch 0 14
epoch 0 15
epoch 0 16
epoch 0 17
epoch 0 18
epoch 0 19
epoch 0 20
epoch 0 21
epoch 0 22
epoch 0 23
epoch 0 24
epoch 0 25
epoch 0 26
epoch 0 27
epoch 0 28
epoch 0 29
epoch 0 30
epoch 0 31
epoch 0 32
epoch 0 33
epoch 0 34
epoch 0 35
epoch 0 36
epoch 0 37
epoch 0 38
epoch 0 39
epoch 0 40
epoch 0 41
epoch 0 42
epoch 0 43
epoch 0 44
epoch 0 45
epoch 0 46
epoch 0 47
epoch 0 48
epoch 0 49
epoch 0 50
epoch 0 51
epoch 0 52
epoch 0 53
epoch 0 54
epoch 0 55
epoch 0 56
epoch 0 57
epoch 0 58
epoch 0 59
epoch 0 60
epoch 0 61
epoch 0 62
epoch 0 63
epoch 0 64
epoch 0 65
epoch 0 66
epoch 0 67
epoch 0 68
epoch 0 69
epoch 0 70
epoch 0 71
epoch 0 72
epoch 0 73
epoch 0 74
epoch 0 75
epoch 0 76
epoch 0 77
epoch 0 78
epoch 0 79
epoch 0 80
epoch 0 81
epoch 0 82
epoch 0 83
epoch 0 84
epoch 0 85
epoch 0 86
epoch 0 87
epoch 0 88
epoch 0 89


epoch 0 691
epoch 0 692
epoch 0 693
epoch 0 694
epoch 0 695
epoch 0 696
epoch 0 697
epoch 0 698
epoch 0 699
epoch 0 700
epoch 0 701
epoch 0 702
epoch 0 703
epoch 0 704
epoch 0 705
epoch 0 706
epoch 0 707
epoch 0 708
epoch 0 709
epoch 0 710
epoch 0 711
epoch 0 712
epoch 0 713
epoch 0 714
epoch 0 715
epoch 0 716
epoch 0 717
epoch 0 718
epoch 0 719
epoch 0 720
epoch 0 721
epoch 0 722
epoch 0 723
epoch 0 724
epoch 0 725
epoch 0 726
epoch 0 727
epoch 0 728
epoch 0 729
epoch 0 730
epoch 0 731
epoch 0 732
epoch 0 733
epoch 0 734
epoch 0 735
epoch 0 736
epoch 0 737
epoch 0 738
epoch 0 739
epoch 0 740
epoch 0 741
epoch 0 742
epoch 0 743
epoch 0 744
epoch 0 745
epoch 0 746
epoch 0 747
epoch 0 748
epoch 0 749
epoch 0 750
epoch 0 751
epoch 0 752
epoch 0 753
epoch 0 754
epoch 0 755
epoch 0 756
epoch 0 757
epoch 0 758
epoch 0 759
epoch 0 760
epoch 0 761
epoch 0 762
epoch 0 763
epoch 0 764
epoch 0 765
epoch 0 766
epoch 0 767
epoch 0 768
epoch 0 769
epoch 0 770
epoch 0 771
epoch 0 772
epoch 0 773
epoc

epoch 0 1345
epoch 0 1346
epoch 0 1347
epoch 0 1348
epoch 0 1349
epoch 0 1350
epoch 0 1351
epoch 0 1352
epoch 0 1353
epoch 0 1354
epoch 0 1355
epoch 0 1356
epoch 0 1357
epoch 0 1358
epoch 0 1359
epoch 0 1360
epoch 0 1361
epoch 0 1362
epoch 0 1363
epoch 0 1364
epoch 0 1365
epoch 0 1366
epoch 0 1367
epoch 0 1368
epoch 0 1369
epoch 0 1370
epoch 0 1371
epoch 0 1372
epoch 0 1373
epoch 0 1374
epoch 0 1375
epoch 0 1376
epoch 0 1377
epoch 0 1378
epoch 0 1379
epoch 0 1380
epoch 0 1381
epoch 0 1382
epoch 0 1383
epoch 0 1384
epoch 0 1385
