# ResNet Train

This notebook uses the ResNet model implemented in `resnet.py`.

The dataset used is the MNIST dataset. The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.

In [1]:
import tensorflow as tf
import numpy as np

from resnet import resnet

  from ._conv import register_converters as _register_converters


In [2]:
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
# Creating placeholders
x = tf.placeholder(tf.float32, shape = [None, 784])
y_true = tf.placeholder(tf.float32, shape = [None, 10])

In [5]:
x_image = tf.reshape(x, [-1, 28, 28, 1])

In [6]:
# Create the ResNet model
model = resnet(x = x_image, n = 20, num_classes = 10)

#define activation of last layer as score
score = model.out

In [7]:
# Loss Function
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = y_true, logits = score))

In [8]:
# Optimizer
optimizer = tf.train.AdamOptimizer(learning_rate = 0.01)
train = optimizer.minimize(cross_entropy)

In [9]:
init = tf.global_variables_initializer()

# To measure accuracy
correct_prediction = tf.equal(tf.argmax(score, 1), tf.argmax(y_true, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [10]:
epochs = 10
batch_size = 128

with tf.Session() as sess:
    
    sess.run(init)
    
    for j in range(epochs):
        for i in range(0, 60000, batch_size):
        
            batch_x, batch_y = mnist.train.next_batch(batch_size)                
            sess.run(train, feed_dict = {x : batch_x, y_true : batch_y})
            
        
        test_accuracy = accuracy.eval(feed_dict = {x : mnist.test.images[: 5000], 
                                                           y_true : mnist.test.labels[: 5000]})
        print('After %d epochs, accuracy: %g' % (j + 1, test_accuracy)) 
                         

After 1 epochs, accuracy: 0.6014
After 2 epochs, accuracy: 0.651
After 3 epochs, accuracy: 0.6856
After 4 epochs, accuracy: 0.7222
After 5 epochs, accuracy: 0.7268
After 6 epochs, accuracy: 0.7498
After 7 epochs, accuracy: 0.7648
After 8 epochs, accuracy: 0.7754
After 9 epochs, accuracy: 0.7754
After 10 epochs, accuracy: 0.7904
