In [1]:
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense,Dropout
import keras
import numpy as np
import sys
import matplotlib.pyplot as plt
from keras import utils
from keras.optimizers import Adam
from keras.datasets import mnist
from tensorflow.keras.models import load_model

In [2]:
dataset = mnist

In [3]:
def build_alexnet(inputShape):
    # Define the input layer
    inputs = keras.Input(shape = inputShape)

    #converlutional layer 1
    conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)    
    pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
    stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)

    #converlutional layer 2
    conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
    pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
    
    stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)

    #converlutional layer 3
    conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
    stand3 = keras.layers.BatchNormalization(axis=1)(conv3)

    #converlutional layer 4
    conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
    stand4 = keras.layers.BatchNormalization(axis=1)(conv4)

    #converlutional layer 5
    conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
    pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
    stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)

    # fully connected layer
    flatten = keras.layers.Flatten()(stand5)
    fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
    drop1 = keras.layers.Dropout(0.5)(fc1)

    fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
    drop2 = keras.layers.Dropout(0.5)(fc2)

    fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
    model = keras.Model(inputs= inputs, outputs = fc3)
    
    model.compile(optimizer= Adam(0.001),
              loss= keras.losses.categorical_crossentropy,
              metrics= ['accuracy'])
    return model

In [4]:
def train_mnist(dataset):
    
    (x_train, y_train), (x_test, y_test) = dataset.load_data() #original size train(xxx, 28, 28),test(xxx)
        
    #translate the data to image
    x_train = x_train.reshape(x_train.shape[0], -1).astype("float64")
    x_test = x_test.reshape(x_test.shape[0], -1).astype("float64")

    #normalization
    x_train /= 255
    x_test /= 255

    #translate y_train and y_test to “one hot” form
    y_train = utils.to_categorical(y_train) 
    y_test = utils.to_categorical(y_test)
        
    x_train = x_train.reshape(-1, 28, 28, 1)
    x_test = x_test.reshape(-1, 28, 28, 1)
    model = build_alexnet((28,28,1))
    batch_size = 64
    print("\nTraining：")
    model.fit(x_train, y_train, batch_size, epochs=4)

    # Evaluation
    print("\nEvaluation：")
    final_loss, final_accuracy = model.evaluate(x_test, y_test)
    print("loss= ", final_loss)
    print("accuracy= ", final_accuracy)
    
        
    print("success!")
    model.save('alexnet_mnist.h5')
        
    outcome = model.predict(x_test[:1])
    print(outcome)
        

In [5]:
train_mnist(dataset)


Training：
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4

Evaluation：
loss=  0.03555772081017494
accuracy=  0.991599977016449
success!
[[2.2360580e-11 2.0816005e-08 1.0460592e-09 5.0800466e-08 3.6237293e-09
  2.0222701e-10 1.0478430e-12 9.9999833e-01 5.8541888e-10 1.6446148e-06]]
