<a href="https://colab.research.google.com/github/bbi-yggy/keras-jumpnet/blob/main/jumpnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing JumpNet and training with CIFAR-10
_Deep Learning for Computer Vision (Cohort #5)_  
_Yossarian King / Blackbird Interactive / October 2020_

The JumpNet implementation is based on the description from [keras-idiomatic-programmer/zoo/jumpnet](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/tree/master/zoo/jumpnet).

The implementation follows the idiomatic style: stem > learner > classifier.
The learner consists of groups, each composed of blocks, as illustrated here:

![JumpNet diagram](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/blob/master/zoo/jumpnet/macro.jpg?raw=true)

For clarity, the model is built by a class with a fluent API. Each method adds one section of the network and returns `self`, so the calls can be chained, and every model metaparameter is passed as a named argument. This makes the model structure explicit and the parameters clear, and makes it easy to build other models in the JumpNet style.
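Here is a toy sketch of the fluent pattern that uses no Keras at all. `FluentBuilder` and its string "steps" are hypothetical stand-ins for real layers. The only point is that each method does its work and then returns `self`, which is what lets the calls be chained inside one parenthesized expression.

```python
class FluentBuilder:
    """Hypothetical toy builder: each method records a step and returns self."""

    def __init__(self):
        self.steps = []

    def stem(self, filters=16):
        self.steps.append(f"stem({filters})")
        return self  # returning self is what makes chaining possible

    def group(self, filters, blocks):
        self.steps.append(f"group({filters}x{blocks})")
        return self

    def classifier(self, classes):
        self.steps.append(f"classifier({classes})")
        return self


# Parentheses let the chained calls span several lines without backslashes.
builder = (
    FluentBuilder()
    .stem()
    .group(filters=32, blocks=3)
    .classifier(classes=10)
)
print(builder.steps)
```

In the JumpNet class the "step" is a Keras layer applied to the running tensor rather than a string, but the chaining mechanism is the same.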

Metaparameters used in this notebook precisely mimic the model built by [jumpnet_c.py](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/tree/master/zoo/jumpnet/jumpnet_c.py) (see the comment "Example of JumpNet for CIFAR-10").

The model is trained on [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), from [Learning Multiple Layers of Features from Tiny Images](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf), Alex Krizhevsky, 2009.


In [1]:
# JumpNet class implementation. This class implements the fluent model-building API.

from tensorflow.keras import Input, Model
import tensorflow.keras.layers as layers

class JumpNet():

	def __init__(self, shape):
		self.inputs = Input(shape)
		self.layers = None   # output tensor of the most recently added layer
		self.outputs = None
		self.model = None

	# Stem: two (3x3 conv -> batch norm -> ReLU) stages.
	def stem(self, filters1=16, filters2=32, stride1=1, stride2=1):
		self.layers = layers.Conv2D(filters1, (3, 3), strides=stride1, padding='same', use_bias=False)(self.inputs)
		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)

		self.layers = layers.Conv2D(filters2, (3, 3), strides=stride2, padding='same', use_bias=False)(self.layers)
		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)
		return self

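	# Group: a stack of residual blocks followed by a strided 1x1 conv, concatenated with a
	# strided 1x1 projection shortcut. Halves height and width; outputs 2 * filters channels.
	def group(self, filters, blocks, blockfilters=None):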
		shortcut = layers.BatchNormalization()(self.layers)
		shortcut = layers.Conv2D(filters, (1,1), strides=(2,2), use_bias=False)(shortcut)

		for _ in range(blocks):
			self.block(filters, blockfilters)

		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)
		self.layers = layers.Conv2D(filters, (1,1), strides=(2,2), use_bias=False)(self.layers)

		self.layers = layers.Concatenate()([shortcut, self.layers])
		return self

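	# Block: pre-activation 1x1 -> 3x3 -> 1x1 convolutions with an identity shortcut.
	# The input must already have `filters` channels for the Add() to work.
	def block(self, filters, blockfilters=None):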
		shortcut = self.layers

		blockfilters = blockfilters or filters
		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)
		self.layers = layers.Conv2D(blockfilters, (1,1), strides=(1,1), use_bias=False)(self.layers)

		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)
		self.layers = layers.Conv2D(blockfilters, (3,3), strides=(1,1), padding='same', use_bias=False)(self.layers)

		self.layers = layers.BatchNormalization()(self.layers)
		self.layers = layers.ReLU()(self.layers)
		self.layers = layers.Conv2D(filters, (1,1), strides=(1,1), use_bias=False)(self.layers)

		self.layers = layers.Add()([shortcut, self.layers])
		return self
	
	# Classifier: global average pooling -> dense -> softmax class probabilities.
	def classifier(self, classes):
		self.layers = layers.GlobalAveragePooling2D()(self.layers)
		self.layers = layers.Dense(classes)(self.layers)
		self.layers = layers.Activation('softmax')(self.layers)
		self.outputs = self.layers
		self.model = Model(self.inputs, self.outputs)
		return self


With this class, we can build a JumpNet model to be trained with the CIFAR-10 dataset.

Here we see the fluent API in action, with the metaparameters exposed as named arguments. (The enclosing parentheses let the statement span multiple lines without ending each line with a backslash.)

In [2]:
jumpnet = (
	JumpNet(shape=(32, 32, 3))
	.stem()
	.group(filters=32, blocks=3)
	.group(filters=64, blocks=4)
	.group(filters=128, blocks=3)
	.classifier(classes=10)
)
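Because `block` adds its input to its output, the metaparameters have to line up. Each group's `filters` must equal the number of channels arriving from the previous stage. Reading the class above, the stem (stride 1, `'same'` padding) keeps the 32×32 size and emits `filters2=32` channels. Each group ends with two 1×1, stride-2 convolutions (Keras's default `'valid'` padding) whose outputs are concatenated, which halves height and width and produces `2 * filters` channels. The pure-Python helper below is a sketch of that bookkeeping under those assumptions. `trace_shapes` is a hypothetical name, not part of the notebook.

```python
def trace_shapes(input_hw, stem_filters, groups):
    """Trace (height, width, channels) through the stem and each JumpNet group.

    Assumes the layout of the JumpNet class above: the stem preserves spatial
    size, and each group concatenates two 1x1 stride-2 'valid' convolutions
    with `filters` channels each.
    """
    h, w = input_hw
    channels = stem_filters
    shapes = [(h, w, channels)]
    for filters, _blocks in groups:
        # Every block's Add() needs the incoming channel count to equal `filters`.
        if channels != filters:
            raise ValueError(
                f"group(filters={filters}) receives {channels} channels; "
                "the residual Add() in its blocks would fail"
            )
        # 1x1 conv, stride 2, 'valid' padding: out = floor((in - 1) / 2) + 1
        h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
        channels = 2 * filters  # shortcut branch + block branch, concatenated
        shapes.append((h, w, channels))
    return shapes


print(trace_shapes((32, 32), stem_filters=32, groups=[(32, 3), (64, 4), (128, 3)]))
```

For the configuration above this gives (32, 32, 32) after the stem, then (16, 16, 64), (8, 8, 128) and (4, 4, 256). That is why the group filters double from one group to the next. Changing the first group to `filters=64`, say, makes the check fail, just as Keras would fail to build the `Add()` layer.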

So, what does that model look like?

In [4]:
jumpnet.model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 16)   432         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 32, 32, 16)   64          conv2d[0][0]                     
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 32, 32, 16)   0           batch_normalization[0][0]        
_______________________________________________________________________________________
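The parameter counts in the summary follow directly from the layer shapes. The helpers below are a sketch of the standard formulas, and their names are hypothetical. A `Conv2D` without bias has kernel height × width × input channels × output channels weights. `BatchNormalization` stores four values per channel: gamma and beta, which are trained, plus the moving mean and variance, which are not. A `Dense` layer with bias has inputs × units + units.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels, use_bias=False):
    """Weights in a Conv2D layer: one kernel_h x kernel_w x in_channels filter per output channel."""
    weights = kernel_h * kernel_w * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


def batchnorm_params(channels):
    """gamma, beta (trainable) plus moving mean, moving variance (non-trainable), per channel."""
    return 4 * channels


def dense_params(in_features, units, use_bias=True):
    """Weight matrix plus one bias per unit."""
    return in_features * units + (units if use_bias else 0)


print(conv2d_params(3, 3, 3, 16))  # first stem convolution
print(batchnorm_params(16))        # first batch normalization
print(dense_params(256, 10))       # classifier Dense layer
```

The first two values reproduce the 432 and 64 shown for the first stem convolution and batch normalization. For the classifier, global average pooling passes the Dense layer the 2 × 128 = 256 channels from the last group's concatenation, so `Dense(10)` should have 256 × 10 + 10 = 2,570 parameters.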

Now let's get the CIFAR-10 dataset.

In [6]:
import tensorflow as tf
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
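`load_data()` returns the images as `uint8` arrays of shape (N, 32, 32, 3) with pixel values in [0, 255], and the labels as integer arrays of shape (N, 1). The notebook feeds the raw pixels straight into the model. A common preprocessing step, not part of the original run, is to rescale the pixels to [0, 1]. The sketch below does that and also flattens the labels to shape (N,). `preprocess` is a hypothetical helper. Flattening is optional, since `SparseCategoricalCrossentropy` also accepts (N, 1) labels. The sketch uses synthetic arrays so it runs without downloading anything.

```python
import numpy as np


def preprocess(images, labels):
    """Scale uint8 pixels to float32 in [0, 1] and flatten (N, 1) labels to (N,)."""
    images = images.astype("float32") / 255.0
    labels = labels.reshape(-1).astype("int64")
    return images, labels


# Synthetic stand-in with CIFAR-10's layout: N x 32 x 32 x 3 uint8 images, N x 1 labels.
fake_images = np.full((2, 32, 32, 3), 255, dtype=np.uint8)
fake_labels = np.array([[3], [7]], dtype=np.uint8)
x, y = preprocess(fake_images, fake_labels)
print(x.dtype, x.max(), y.shape)
```

Applied to the real data, it would be called as `train_images, train_labels = preprocess(train_images, train_labels)`, and likewise for the test split.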


Now we can compile and train the model. The training procedure is naive: we just run 200 epochs and hope the model stabilizes. This takes a while. As of this writing it's about 30 seconds per epoch, and 200 epochs × 30 seconds is 6,000 seconds, or an hour and forty minutes. Now would be a good time to go get a coffee.
☕
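The classifier ends in `Activation('softmax')`, so the model outputs probabilities rather than logits, and the loss needs `from_logits=False` to match. Passing `from_logits=True` makes the loss treat those probabilities as logits and apply softmax a second time. The NumPy sketch below shows the effect. Its `softmax` and `sparse_ce` are simple hand-rolled functions, not the Keras implementation.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sparse_ce(probs, label):
    """Cross-entropy for one example, given class probabilities and an integer label."""
    return float(-np.log(probs[label]))


logits = np.array([4.0, 1.0, 0.0])  # toy pre-softmax scores for 3 classes
probs = softmax(logits)             # what a softmax output layer emits

correct = sparse_ce(probs, 0)          # from_logits=False: probabilities used as-is
double = sparse_ce(softmax(probs), 0)  # from_logits=True on probabilities: softmax twice
print(f"{correct:.3f} vs {double:.3f}")

# With 10 classes, the best a double softmax can reach is a probability vector
# of e / (e + 9) on the true class, so the loss can never drop below this floor.
floor_10 = -np.log(np.e / (np.e + 9))
print(f"{floor_10:.3f}")
```

The model here is confident and correct, yet the double-softmax loss is many times larger than the true cross-entropy. With 10 classes that loss cannot fall below about 1.46 however good the model gets, which distorts the training signal.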

In [None]:
jumpnet.model.compile(optimizer='adam',
              # The classifier already ends in a softmax, so the loss receives probabilities, not logits.
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
history = jumpnet.model.fit(train_images, train_labels, epochs=200, validation_data=(test_images, test_labels))
test_loss, test_acc = jumpnet.model.evaluate(test_images, test_labels, verbose=2)

print("test accuracy", test_acc)
print("test loss", test_loss)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
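`fit` returns a `History` object whose `history` attribute is a dict. It maps each metric name (here `loss`, `accuracy`, `val_loss` and `val_accuracy`) to a list with one value per epoch. A first check on whether the model actually stabilized is to find where validation accuracy peaked and compare that with the final epoch. The helper below is a hypothetical sketch, shown on made-up numbers rather than results from this run.

```python
def best_epoch(history_dict, metric="val_accuracy"):
    """Return (1-based epoch, value) at which `metric` peaked (first peak on ties)."""
    values = history_dict[metric]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


# Toy numbers for illustration only; not results from this notebook.
toy_history = {"val_accuracy": [0.41, 0.55, 0.62, 0.60, 0.61]}
print(best_epoch(toy_history))
```

On the real run it would be called as `best_epoch(history.history)`. Keep in mind that `validation_data` here is the CIFAR-10 test split, so picking an epoch this way tunes on the test set. Keras also provides `tf.keras.callbacks.EarlyStopping`, which can monitor `val_accuracy` with a `patience` argument and `restore_best_weights=True`, to stop training automatically once the metric stops improving.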