# MNIST-Azure
## Train Model
### By: Sebastian Goodfellow

In [None]:
# Configure Notebook
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Import 3rd party libraries
import os
import cv2
import sys
import json
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt

# Import local libraries
# The project root can be overridden with the MNIST_AZURE_PATH environment
# variable so the notebook runs on other machines; the author's original
# checkout path is kept as the fallback.
MNIST_AZURE_PATH = os.environ.get(
    'MNIST_AZURE_PATH', '/home/sebastiangoodfellow/Documents/Code/mnist-azure')
# Only prepend once so re-running this cell does not grow sys.path
if MNIST_AZURE_PATH not in sys.path:
    sys.path.insert(0, MNIST_AZURE_PATH)
from mnistazure.config import DATA_PATH, TENSORBOARD_PATH
from mnistazure.generator import DataGenerator
from mnistazure.graph import Graph
from mnistazure.network import Network

# Set Model Parameters

In [None]:
# Input image shape, ordered (height, width, channels) as consumed by Network
image_shape = (28, 28, 1)

# Remaining scalar settings:
#   num_labels  - number of unique class labels
#   max_to_keep - maximum number of checkpoints to keep
#   seed        - random seed
num_labels, max_to_keep, seed = 10, 1, 0

# Test Data Generator

In [None]:
# Initialize a training-mode generator to smoke-test the input pipeline.
# `seed` comes from the parameters cell instead of a duplicated literal, so a
# single setting controls the randomness used throughout the notebook.
# NOTE(review): num_parallel_calls=24 is hard-coded; consider tying it to the
# host's core count.
generator = DataGenerator(path=DATA_PATH, mode='train', shape=image_shape,
                          batch_size=32, prefetch_buffer=100, seed=seed,
                          num_parallel_calls=24)

# View dataset (last expression, so Jupyter displays it)
generator.dataset

# Build Graph and Train

In [None]:
# Initialize network; the seed comes from the parameters cell rather than a
# duplicated literal
network = Network(height=image_shape[0], width=image_shape[1],
                  channels=image_shape[2], num_labels=num_labels, seed=seed)

# Initialize graph
graph = Graph(network=network, save_path=TENSORBOARD_PATH,
              data_path=DATA_PATH, max_to_keep=max_to_keep)

# Training hyperparameters
learning_rate = 1e-3
epochs = 5
batch_size = 128

# Print loss/accuracy every `log_every` batches and on the last batch of each
# epoch; set to 1 to print every batch.
log_every = 50

with tf.Session() as sess:

    # Initialize variables
    sess.run(graph.init_global)

    # Feed shared by the batch-count op and the iterator initializer
    init_feed = {graph.batch_size: batch_size}

    # Get number of training batches
    num_train_batches = graph.generator_train.num_batches.eval(feed_dict=init_feed)

    # One optimizer step per training batch
    steps_per_epoch = int(np.ceil(num_train_batches))

    # Get mode handle that selects the training iterator
    handle_train = sess.run(graph.generator_train.iterator.string_handle())

    # Feed and fetches are identical for every training step, so build them once.
    # Named (dict) fetches avoid fragile positional unpacking of the results.
    train_feed = {graph.batch_size: batch_size, graph.is_training: True,
                  graph.learning_rate: learning_rate,
                  graph.mode_handle: handle_train}
    train_fetches = {'loss': graph.loss,
                     'accuracy': graph.accuracy,
                     'train_op': graph.train_op,
                     'update_metrics': graph.update_metrics_op,
                     'summary': graph.train_summary_metrics_op,
                     'global_step': graph.global_step}

    # Loop through epochs
    for epoch in range(epochs):

        # Initialize the train dataset iterator at the beginning of each epoch
        sess.run(graph.generator_train.iterator.initializer, feed_dict=init_feed)

        # Initialize metrics
        sess.run(graph.init_metrics_op)

        # Loop through train dataset batches
        for batch in range(steps_per_epoch):
            results = sess.run(train_fetches, feed_dict=train_feed)

            # NOTE(review): results['summary'] and results['global_step'] are
            # fetched but not written to a summary writer in this cell; unless
            # Graph handles that internally, TensorBoard receives nothing. Confirm.
            if (batch + 1) % log_every == 0 or batch == steps_per_epoch - 1:
                print('epoch {} batch {}/{}: loss={} accuracy={}'.format(
                    epoch, batch + 1, steps_per_epoch,
                    results['loss'], results['accuracy']))