# MNIST Odd-Even classification with TensorFlow (CNN)

## Data loading

Let's load the dataset using the Keras API implementation included in TensorFlow.

As we only want to classify the images between **odd** and **even** numbers, we will map labels to *1* if odd or *0* if even.

In [None]:
import tensorflow as tf
import numpy as np

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Keep the original test digits (0-9) so that misclassified samples can be
# displayed with their true number later on. y_test is rebound below, so
# this name keeps pointing at the untouched digit array.
y_numbers = y_test

# Set parity labels for each sample: 1.0 if the digit is odd, 0.0 if even.
# The modulo is applied to the whole array at once instead of mapping a
# Python lambda over every element.
y_train = (y_train % 2).astype(np.float32)
y_test = (y_test % 2).astype(np.float32)

# Normalize images dividing by max pixel value (255)
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Reshape to TF API (#img, rows, cols, channels)
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

## CNN Model

### Neural Network architecture

We will use the classical architecture for a CNN (Convolution + Pooling > Convolution + Pooling > Flatten > Dense > Dense)

Instead of the common Convolution > ReLU > Pooling order, we can apply the ReLU after the pooling step, because ReLU(MaxPooling(x)) == MaxPooling(ReLU(x)) (ReLU is monotonic, so it commutes with taking the maximum). This is more efficient: after 2x2 pooling, the tensor that ReLU has to process is 75% smaller.

In [None]:
class CNNClassifier:
    """Small TF1 graph-mode CNN that labels 28x28x1 images as odd (1) or even (0).

    Architecture: two blocks of (5x5 convolution with 3 filters -> 2x2 max
    pooling -> ReLU), a 4-unit ReLU dense layer and a single-logit output.
    Training minimizes sigmoid cross-entropy with plain gradient descent.
    """

    def __init__(self, train_data=None):
        """Build the graph, open a session and, if data is given, train.

        Args:
            train_data: optional ``(data, labels)`` pair. ``data`` is fed to
                a placeholder of shape ``(None, 28, 28, 1)`` and ``labels``
                to one of shape ``(None,)``. When omitted, the model is
                built but left untrained.
        """
        # Only unpack when data was actually supplied; the default of None
        # must not crash the constructor.
        if train_data is None:
            self.train_data = None
        else:
            data, labels = train_data
            self.train_data = (data, labels)

        self.assemble_graph()

        self._open_session()

        self._draw_graph_tensorboard()

        if self.train_data is not None:
            self.train()

    def assemble_graph(self, learning_rate=0.02):
        """Create placeholders, layers, loss, optimizer and metric tensors.

        Args:
            learning_rate: step size for the gradient descent optimizer.
        """
        #### Placeholders/Variables/Constants

        self.X = tf.placeholder(name='X', dtype=tf.float32, shape=(None, 28, 28, 1))
        self.Y = tf.placeholder(name='Y', dtype=tf.float32, shape=(None,))
        # Column-vector view of the labels so they line up with the
        # (batch, 1) output of the last dense layer.
        self.L = tf.reshape(self.Y, shape=(-1, 1))

        #### Layers

        # 1st Convolutional Layer 3x5x5 Kernel + MaxPooling 2x2 + ReLu
        # (ReLU is applied after pooling, i.e. on the smaller tensor.)
        conv1 = tf.layers.conv2d(self.X, filters=3, kernel_size=(5, 5), padding='same')
        pool1 = tf.layers.max_pooling2d(conv1, pool_size=(2, 2), strides=2)
        relu1 = tf.nn.relu(pool1)

        # 2nd Convolutional Layer 3x5x5 Kernel + MaxPooling 2x2 + ReLu
        conv2 = tf.layers.conv2d(relu1, filters=3, kernel_size=(5, 5), padding='same')
        pool2 = tf.layers.max_pooling2d(conv2, pool_size=(2, 2), strides=2)
        relu2 = tf.nn.relu(pool2)

        # Flatten data into (-1, 7*7*3 tensor): two stride-2 poolings take
        # 28 -> 14 -> 7 per spatial axis, with 3 filters.
        flatten = tf.reshape(relu2, shape=(-1, 7*7*3))

        # Dense Layer 1 with 4 neurons
        dense1 = tf.layers.dense(flatten, units=4, activation=tf.nn.relu)

        # Dense Layer 2 with 1 neurons (our final output, a raw logit)
        output_layer = tf.layers.dense(dense1, units=1, activation=None)

        #### Optimizer, Loss, Predictions and Accuracy

        # Cross-Entropy as loss function
        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.L, logits=output_layer)
        self.cost = tf.reduce_mean(cross_entropy)
        self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(self.cost)

        # Apply sigmoid and round to get the predicted class
        self.predicted = tf.nn.sigmoid(output_layer)
        correct_pred = tf.equal(tf.round(self.predicted), self.L)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    def train(self, epochs=20, minibatch_size=256):
        """Run minibatch gradient descent over the stored training data.

        Progress is printed every 50 minibatches.

        Args:
            epochs: number of full passes over the training data.
            minibatch_size: maximum number of samples per optimization step.

        Raises:
            ValueError: if the classifier holds no training data.
        """
        if self.train_data is None:
            raise ValueError('No training data available; pass train_data to the constructor.')

        n_total = len(self.train_data[0])
        msg = 'Epoch: {:>2} [{:>5}/{:>5} ({:2.0f}%)]\tLoss: {:2.6f}\tAccuracy: {:5.2f}%'

        for epoch in range(1, epochs + 1):
            n_treated = 0
            for batch_idx, (features, labels) in enumerate(self._create_minibatches(minibatch_size)):
                d = {self.X: features, self.Y: labels}
                loss, _, acc = self.sess.run([self.cost, self.optimizer, self.accuracy], feed_dict=d)

                # Count samples actually consumed so the report stays right
                # even when the last batch is shorter than the others.
                n_treated += len(features)

                if (batch_idx + 1) % 50 == 0:
                    p_treated = 100.0 * n_treated / n_total
                    print(msg.format(epoch, n_treated, n_total, p_treated, loss, acc * 100))

    def predict(self, data, minibatch_size=256):
        """Predict the parity class of every sample in ``data``.

        Args:
            data: array-like of images; it is reshaped to
                ``(-1, 28, 28, 1)`` to match the input placeholder, so a
                single ``(28, 28, 1)`` image is accepted too.
            minibatch_size: number of samples evaluated per session run.

        Returns:
            1-D float32 NumPy array with one label per sample: 1.0 for odd
            and 0.0 for even (the rounded sigmoid output).
        """
        samples = np.asarray(data, dtype=np.float32).reshape(-1, 28, 28, 1)

        outputs = []
        # Feed whole batches: the placeholder requires a leading batch
        # axis, so single (28, 28, 1) samples cannot be fed directly.
        for start in range(0, len(samples), minibatch_size):
            d = {self.X: samples[start:start + minibatch_size]}
            probabilities = self.sess.run(self.predicted, feed_dict=d)
            outputs.append(np.asarray(probabilities).reshape(-1))

        if not outputs:
            return np.zeros(0, dtype=np.float32)
        return np.round(np.concatenate(outputs)).astype(np.float32)

    def _create_minibatches(self, minibatch_size):
        """Split the stored training data into consecutive minibatches.

        Args:
            minibatch_size: maximum number of samples per batch.

        Returns:
            List of ``(data, labels)`` slices in original order; the last
            batch holds the remainder and may be shorter.
        """
        data, labels = self.train_data
        n_samples = len(labels)

        return [
            (data[pos:pos + minibatch_size], labels[pos:pos + minibatch_size])
            for pos in range(0, n_samples, minibatch_size)
        ]

    def _open_session(self):
        """Open a TensorFlow session and initialize all global variables."""
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def _draw_graph_tensorboard(self):
        """Write the session graph to ``./tmp`` for inspection in TensorBoard."""
        writer = tf.summary.FileWriter('./tmp', self.sess.graph)
        writer.close()

In [None]:
# Build the classifier; because training data is passed, the constructor
# also trains it.
# NOTE(review): the variable is called `svm` although it holds a CNNClassifier.
svm = CNNClassifier((x_train, y_train))
predictions = svm.predict(x_test)
# NOTE(review): f1_score is not imported in any visible cell — presumably
# sklearn.metrics.f1_score; confirm and add the import. It also needs one
# 0/1 label per test sample — verify that predict() returns exactly that.
print('Testing score f1: {}'.format(f1_score(y_test, predictions)))

## Misclassifications

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

%matplotlib inline

def plot_missclasifications(preds):
    """Plot up to ten test images whose predicted parity is wrong.

    Images are drawn on a 2x5 grid. Each title shows the predicted class
    followed by the true digit in parentheses, e.g. ``Even (7)``.

    Args:
        preds: predictions for the samples of ``x_test``, in the same
            order. Either 0/1 labels or probabilities (1 = odd, 0 = even);
            any shape that flattens to one value per sample is accepted.
    """
    # Flatten and round so that hard labels and sigmoid probabilities
    # (possibly shaped (n, 1)) both compare cleanly with the 0/1 labels.
    pred_labels = np.round(np.asarray(preds, dtype=np.float32).reshape(-1))

    count = 0
    plt.figure(figsize=(10, 10))
    for i, pred in enumerate(pred_labels):
        # y_test holds scalar parity labels, so compare against the label
        # itself; a sample is misclassified when the two differ.
        if pred == y_test[i]:
            continue

        # Label 1 means odd and 0 means even.
        msg = '{} ({})'
        msg = msg.format('Odd' if pred else 'Even', str(y_numbers[i]))

        plt.subplot(2, 5, count + 1)
        plt.title(msg)
        plt.axis('off')
        plt.imshow(x_test[i].reshape(28, 28), cmap=cm.binary)
        count += 1
        if count == 10:
            break

Some misclassification examples:

In [None]:
# Render the example grid for the test-set predictions.
plot_missclasifications(predictions)