In [1]:
import pickle
import sys
import os

# TF_CPP_MIN_LOG_LEVEL has to be in the environment before TensorFlow is
# imported; setting it after the import is too late to silence the messages
# TensorFlow's native runtime prints while it loads. '3' hides INFO, WARNING
# and ERROR output from the C++ side.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dropout, MaxPool2D, Flatten, Add, Dense, Activation, BatchNormalization, Lambda, ReLU, PReLU, LeakyReLU, GlobalAveragePooling2D
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.metrics import TopKCategoricalAccuracy
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D, AveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Dense
from tensorflow.keras.layers import Flatten, Input, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
"""
The CIFAR-10 dataset was downloaded from the official website; each training batch was unpickled and the batches were concatenated
to create one large training set. The images were preprocessed by reshaping and transposing each row vector into shape (32, 32, 3).
The class labels were one-hot encoded. My preliminary attempt reused my model for the MNIST dataset, altered only for the input size. This
resulted in a test accuracy of 67% after 10 epochs. My next attempt was a VGG-style network with fractional max pooling, based on a paper by Benjamin Graham. While this clearly
outperformed the previous model, the training time was far too high. I tried reducing it by using only one of the 5 training batches and doubling the
batch size to 512, but the resulting loss in accuracy was too large.

The next model I tried was another VGG-type convolutional network, which was shallower and converged much faster. This got me to 80% test accuracy.
"""
# Fractional max pooling
# - https://arxiv.org/abs/1412.6071
# - https://github.com/laplacetw/vgg-like-cifar10/blob/master/fmp_cifar10.py
# https://www.binarystudy.com/2021/09/how-to-load-preprocess-visualize-CIFAR-10-and-CIFAR-100.html#routine

# Path of the pickled CIFAR-10 test batch (read by Data10).
BTEST = '../data/test_batch'
# NOTE(review): not referenced in the visible code, and it points at a
# different directory ('../CIFAR10-data') than the batch files ('../data').
meta_file = '../CIFAR10-data/batches.meta'

# Number of data_batch_<i> files concatenated by load_training_data().
NUM_TRAINING_BATCHES = 5
# Mini-batch size handed to the augmentation generator and used to derive
# steps_per_epoch.
BATCH_SIZE = 128
# NOTE(review): not referenced in the visible code; presumably an L2
# regularisation strength - confirm or remove.
LAMBDA = 1e-5
# Upper bound on training epochs; also the horizon of poly_decay().
EPOCHS = 200
# NOTE(review): not referenced in the visible code (32 is hard-coded where
# image sizes are needed).
IMG_SIDE_LEN = 32
# Base learning rate of the poly_decay() schedule.
LR = 5e-3
# "CIFAR10" selects Data10; any other value selects Data100.
DATASET = "CIFAR10"
# Architecture to train: "model3x", "fmp" or "resnet".
MODEL = "model3x"

def unpickle(file):
    """Load one pickled object from the file at path *file*.

    The stream is decoded with ``encoding='latin1'`` so that 8-bit strings
    written by Python 2 can be read under Python 3.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # SECURITY: unpickling can execute arbitrary code. Only use this on
    # files from a trusted source.
    with open(file, 'rb') as fo:
        # Public API equivalent of building a private pickle._Unpickler and
        # setting its encoding by hand; also avoids shadowing builtin `dict`.
        return pickle.load(fo, encoding='latin1')

def load_training_data():
    """Load and concatenate all CIFAR-10 training batches.

    Reads ``../data/data_batch_1`` .. ``data_batch_<NUM_TRAINING_BATCHES>``
    through reshape_features() and stacks them along the first axis.

    Returns:
        A ``(features, classes)`` tuple. ``features`` has shape
        (N, 32, 32, 3); ``classes`` holds the one-hot labels and is expected
        to have 10 columns.
    """
    # Each raw batch row is a 1-D array of 3,072 entries: the first 1024 are
    # the red channel, the next 1024 green and the last 1024 blue.
    print("Loading Data:")
    # The empty float64 seeds keep the result dtype and the zero-batch result
    # identical to the previous np.append-based implementation.
    feature_parts = [np.empty((0, 32, 32, 3))]
    class_parts = [np.empty((0, 10))]
    for i in range(NUM_TRAINING_BATCHES):
        print(f"Batch {i+1}")
        batch_path = f'../data/data_batch_{i+1}'
        x, y = reshape_features(batch_path)
        feature_parts.append(x)
        class_parts.append(y)
    # Concatenate once at the end: np.append inside the loop re-copied the
    # whole accumulated array for every batch.
    features = np.concatenate(feature_parts, axis=0)
    classes = np.concatenate(class_parts, axis=0)
    return features, classes

def reshape_features(feat_path, CIFAR100=False):
    """Read one pickled CIFAR batch and return images plus one-hot labels.

    Args:
        feat_path: Path of the pickled batch file.
        CIFAR100: When true, labels are read from the 'coarse_labels' entry
            instead of 'labels'.

    Returns:
        A ``(images, one_hot_labels)`` tuple. ``images`` is the 'data' entry
        reshaped to (N, 3, 32, 32), transposed to (N, 32, 32, 3) and divided
        by 255.
    """
    batch = unpickle(feat_path)
    label_key = 'coarse_labels' if CIFAR100 else 'labels'

    raw_rows = batch['data']
    # Rows are stored channel-first; move the channel axis last.
    channel_first = raw_rows.reshape(len(raw_rows), 3, 32, 32)
    images = channel_first.transpose(0, 2, 3, 1) / 255

    one_hot_labels = tf.keras.utils.to_categorical(batch[label_key])
    return (images, one_hot_labels)

def frac_max_pool(x):
    """Apply pseudo-random, overlapping fractional max pooling to *x*.

    The two middle axes are pooled with ratio 1.41; the first and last axes
    are left at ratio 1.0. Only the first element of what
    ``tf.nn.fractional_max_pool`` returns is kept.
    """
    pooling_ratio = [1.0, 1.41, 1.41, 1.0]
    pooled = tf.nn.fractional_max_pool(
        x,
        pooling_ratio,
        pseudo_random=True,
        overlapping=True,
    )
    return pooled[0]

def poly_decay(epoch, *, max_epochs=None, base_lr=None, power=1.0):
    """Polynomial learning-rate schedule.

    Computes ``base_lr * (1 - epoch / max_epochs) ** power``. With the
    default ``power=1.0`` this is a linear decay from ``base_lr`` at epoch 0
    to 0 at ``max_epochs``.

    The extra parameters are keyword-only on purpose: a caller that tries
    ``schedule(epoch, lr)`` positionally still gets a TypeError, exactly as
    with the one-argument version, so the one-argument calling convention
    stays in effect.

    Args:
        epoch: Zero-based epoch index.
        max_epochs: Epoch at which the rate reaches zero. Defaults to the
            module-level EPOCHS.
        base_lr: Learning rate at epoch 0. Defaults to the module-level LR.
        power: Exponent of the polynomial.

    Returns:
        The learning rate for *epoch*.
    """
    if max_epochs is None:
        max_epochs = EPOCHS
    if base_lr is None:
        base_lr = LR
    return base_lr * (1 - (epoch / float(max_epochs))) ** power

# Training-time augmentation: small rotations, horizontal/vertical shifts of
# up to 10% and random horizontal flips. zoom_range=0.3 is left disabled.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)

# Second augmentation recipe (shifts and flips only, nearest-pixel fill).
# NOTE(review): not referenced in the visible code.
aug = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    fill_mode="nearest",
)

# Stop once validation accuracy has not improved for 25 epochs and roll the
# model back to the best weights seen.
stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    min_delta=0,
    patience=25,
    baseline=None,
    restore_best_weights=True,
    verbose=1,
)

def normalize_x_data(x_train, x_test):
    """Standardise both arrays with the training set's global statistics.

    A single scalar mean and standard deviation are computed over all four
    axes of *x_train* and applied to *x_train* and *x_test* alike. A small
    epsilon is added to the standard deviation to avoid division by zero.

    Returns:
        The ``(x_train, x_test)`` pair after standardisation.
    """
    all_axes = (0, 1, 2, 3)
    train_mean = np.mean(x_train, axis=all_axes)
    train_std = np.std(x_train, axis=all_axes)
    epsilon = 1e-7

    def _standardise(batch):
        return (batch - train_mean) / (train_std + epsilon)

    return _standardise(x_train), _standardise(x_test)

class Data10:
    """CIFAR-10 data holder with train / validation / test splits.

    After construction the instance exposes x_train, y_train, x_val, y_val,
    x_test and y_test. Features are standardised with normalize_x_data()
    before 20% of the training data is split off for validation with a fixed
    random seed.
    """

    def __init__(self):
        train_x, train_y = load_training_data()
        test_x, self.y_test = reshape_features(BTEST)

        # Statistics come from the full training set (validation included).
        train_x, self.x_test = normalize_x_data(train_x, test_x)

        split = train_test_split(train_x, train_y, test_size=0.2, random_state=31415)
        self.x_train, self.x_val, self.y_train, self.y_val = split

class Data100:
    """CIFAR-100 (coarse labels) data holder with train / validation / test splits.

    After construction the instance exposes x_train, y_train, x_val, y_val,
    x_test and y_test. 20% of the training data is split off for validation
    with a fixed random seed.
    """

    def __init__(self):
        # NOTE(review): unlike Data10, the features are not passed through
        # normalize_x_data() here - confirm that this is intended.
        train_x, train_y = reshape_features('../data/train', CIFAR100=True)
        self.x_test, self.y_test = reshape_features('../data/test', CIFAR100=True)

        split = train_test_split(train_x, train_y, test_size=0.2, random_state=31415)
        self.x_train, self.x_val, self.y_train, self.y_val = split

def double_conv_module(input, num_filters, activation, kern_reg, dropout, padding="same"):
    """Build a (Conv3x3 -> BatchNorm) x 2 -> MaxPool2x2 -> Dropout stage.

    Args:
        input: Tensor the stage is applied to.
        num_filters: Number of filters of both convolutions.
        activation: Activation passed to both convolutions.
        kern_reg: Kernel regularizer passed to both convolutions.
        dropout: Rate of the trailing Dropout layer.
        padding: Padding mode of both convolutions.

    Returns:
        The output tensor of the stage.
    """
    tensor = input
    for _ in range(2):
        conv = Conv2D(
            filters=num_filters,
            kernel_size=(3, 3),
            activation=activation,
            padding=padding,
            kernel_regularizer=kern_reg,
        )
        tensor = conv(tensor)
        tensor = BatchNormalization(axis=-1)(tensor)

    tensor = MaxPooling2D(pool_size=(2, 2))(tensor)
    return Dropout(dropout)(tensor)

def rav_model(width, height, depth, classes):
    """Build the VGG-style classifier named "rav_net".

    Four double_conv_module() stages (32, 64, 128, 128 filters with dropout
    0.1 to 0.4) are followed by a 512-unit dense layer with batch
    normalisation and dropout, and a softmax output of size *classes*.

    Args:
        width, height, depth: Input image dimensions; the input shape is
            (height, width, depth).
        classes: Number of output classes.

    Returns:
        The uncompiled Keras Model.
    """
    # (Step 1) Define the model input
    inputs = Input(shape=(height, width, depth))

    # Kernel regularisation is currently switched off (an l2 penalty was
    # used here previously).
    kernel_reg = None

    # (filters, dropout rate) of each convolutional stage, in order.
    stage_specs = ((32, 0.1), (64, 0.2), (128, 0.3), (128, 0.4))
    x = inputs
    for n_filters, drop_rate in stage_specs:
        x = double_conv_module(x, n_filters, activation='relu', kern_reg=kernel_reg,
                               dropout=drop_rate, padding='same')

    # Classification head.
    x = Flatten()(x)
    x = Dense(512, activation='relu', kernel_regularizer=None)(x)
    x = BatchNormalization(axis=-1)(x)
    x = Dropout(0.5)(x)
    logits = Dense(classes)(x)
    outputs = Activation("softmax")(logits)

    return Model(inputs, outputs, name="rav_net")

def fmp_unit(input, num_filters, dropout, padding="same", frac_pool=True):
    """Build a (Conv3x3 -> LeakyReLU -> BatchNorm) x 2 unit with optional pooling.

    Args:
        input: Tensor the unit is applied to.
        num_filters: Number of filters of both convolutions.
        dropout: Rate of the trailing Dropout layer.
        padding: Padding mode of both convolutions.
        frac_pool: When true, frac_max_pool() is applied (through a Lambda
            layer) before the dropout.

    Returns:
        The output tensor of the unit.
    """
    tensor = input
    for _ in range(2):
        tensor = Conv2D(
            filters=num_filters,
            kernel_size=(3, 3),
            padding=padding,
            kernel_initializer='he_uniform',
        )(tensor)
        tensor = LeakyReLU()(tensor)
        tensor = BatchNormalization(axis=-1)(tensor)

    if frac_pool:
        tensor = Lambda(frac_max_pool)(tensor)

    return Dropout(dropout)(tensor)

def fmp_model(width, height, depth, classes):
    """Build the fractional-max-pooling classifier named "fmp_rav_net".

    Six fmp_unit() stages (the first one without pooling) are followed by a
    1x1 convolution, global average pooling and a softmax output of size
    *classes*.

    Args:
        width, height, depth: Input image dimensions; the input shape is
            (height, width, depth).
        classes: Number of output classes.

    Returns:
        The uncompiled Keras Model.
    """
    inputs = Input(shape=(height, width, depth))

    # The first unit keeps the full resolution.
    x = fmp_unit(inputs, 32, dropout=0.3, padding='same', frac_pool=False)

    # (filters, dropout rate) of the pooled units, in order.
    pooled_units = ((64, 0.35), (96, 0.35), (128, 0.4), (160, 0.45), (192, 0.5))
    for n_filters, drop_rate in pooled_units:
        x = fmp_unit(x, n_filters, dropout=drop_rate, padding='same')

    # Classification head.
    x = Conv2D(filters=192, kernel_size=(1, 1), padding='same',
               kernel_initializer='he_uniform')(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = GlobalAveragePooling2D()(x)
    logits = Dense(units=classes, kernel_initializer='he_uniform')(x)
    outputs = Activation("softmax")(logits)

    return Model(inputs, outputs, name="fmp_rav_net")

def relu_bn(input):
    """Apply a ReLU layer followed by a BatchNormalization layer to *input*."""
    activated = ReLU()(input)
    return BatchNormalization()(activated)

def residual_block(x, downsample: bool, filters: int, kernel_size: int = 3):
    """Build one two-convolution residual block.

    Args:
        x: Input tensor.
        downsample: When true, the first convolution uses stride 2 and the
            shortcut goes through a stride-2 1x1 convolution so that both
            branches can be added.
        filters: Number of filters of every convolution in the block.
        kernel_size: Kernel size of the two main-branch convolutions.

    Returns:
        The block output: ReLU + BatchNorm applied to (shortcut + branch).
    """
    first_stride = 2 if downsample else 1

    # Main branch: conv -> ReLU/BN -> conv.
    branch = Conv2D(kernel_size=kernel_size, strides=first_stride,
                    filters=filters, padding="same")(x)
    branch = relu_bn(branch)
    branch = Conv2D(kernel_size=kernel_size, strides=1,
                    filters=filters, padding="same")(branch)

    # Shortcut: identity, or a strided 1x1 projection when downsampling.
    if downsample:
        shortcut = Conv2D(kernel_size=1, strides=2, filters=filters, padding="same")(x)
    else:
        shortcut = x

    merged = Add()([shortcut, branch])
    return relu_bn(merged)

def create_res_net(width, height, depth, classes):
    """Build the residual-network classifier.

    A batch-normalised stem convolution with 64 filters is followed by four
    stages of 2, 5, 5 and 2 residual blocks. The filter count doubles after
    every stage, and every stage except the first downsamples in its first
    block. The head is 4x4 average pooling, flatten and a softmax output of
    size *classes*.

    Args:
        width, height, depth: Input image dimensions; the input shape is
            (height, width, depth).
        classes: Number of output classes.

    Returns:
        The uncompiled Keras Model.
    """
    inputs = Input(shape=(height, width, depth))

    # Stem.
    filters = 64
    t = BatchNormalization()(inputs)
    t = Conv2D(kernel_size=3, strides=1, filters=filters, padding="same")(t)
    t = relu_bn(t)

    # Residual stages.
    for stage, num_blocks in enumerate([2, 5, 5, 2]):
        for block in range(num_blocks):
            first_block_of_later_stage = (block == 0 and stage != 0)
            t = residual_block(t, downsample=first_block_of_later_stage, filters=filters)
        filters *= 2

    # Head (a Dense(512) + BatchNorm + Dropout(0.5) stack is left disabled).
    t = AveragePooling2D(4)(t)
    x = Flatten()(t)
    x = Dense(classes)(x)
    outputs = Activation("softmax")(x)

    # NOTE(review): the model is named "rav_net", the same name rav_model()
    # uses - confirm that this is intended.
    return Model(inputs, outputs, name="rav_net")

if __name__ == "__main__":
    # Training script: load the selected dataset, build the selected model,
    # train it with augmentation and report loss/accuracy on the test set.
    if DATASET == 'CIFAR10':
        data = Data10()
        NUM_CLASSES = 10
    else:
        data = Data100()
        NUM_CLASSES = 20  # Data100 reads the 'coarse_labels' entry

    x_train, y_train = data.x_train, data.y_train
    x_test, y_test = data.x_test, data.y_test
    x_val, y_val = data.x_val, data.y_val

    # NOTE(review): both callbacks below adjust the learning rate.
    # LearningRateScheduler sets it from poly_decay() at the start of every
    # epoch, which appears to override both the optimizer's initial rate and
    # any reduction made by ReduceLROnPlateau - confirm which is intended.
    lr_scheduler = LearningRateScheduler(poly_decay)
    variable_learning_rate = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2)

    # `decay=0` is no longer passed to Adam: zero is the default, and newer
    # Keras optimizers reject the argument.
    if MODEL == "model3x":
        model = rav_model(width=32, height=32, depth=3, classes=NUM_CLASSES)
        opt = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    elif MODEL == "fmp":
        model = fmp_model(width=32, height=32, depth=3, classes=NUM_CLASSES)
        opt = RMSprop(decay=1e-6)
    elif MODEL == "resnet":
        model = create_res_net(width=32, height=32, depth=3, classes=NUM_CLASSES)
        opt = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    else:
        # Fail early instead of hitting a NameError on `model` below.
        raise ValueError(
            f"Unknown MODEL {MODEL!r}; expected 'model3x', 'fmp' or 'resnet'"
        )

    model.compile(loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy'], optimizer=opt)
    #model.compile(loss=tf.keras.losses.categorical_crossentropy, metrics=[TopKCategoricalAccuracy(k = 5)],optimizer=opt)
    model.summary()

    # The generator already yields batches of BATCH_SIZE, so `batch_size` must
    # not be passed to fit() as well.
    history = model.fit(
        datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
        epochs=EPOCHS,
        callbacks=[variable_learning_rate, lr_scheduler, stopping],
        validation_data=(x_val, y_val),
        verbose=1,
        steps_per_epoch=len(x_train) // BATCH_SIZE,
    )
    score = model.evaluate(x_test, y_test, verbose=0)

    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    # 30 epochs 82%

  # 30 epochs 82%

Loading Data:
Batch 1
Batch 2
Batch 3
Batch 4
Batch 5
Model: "rav_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d (Conv2D)             (None, 32, 32, 32)        896       
                                                                 
 batch_normalization (BatchN  (None, 32, 32, 32)       128       
 ormalization)                                                   
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 32)        9248      
                                                                 
 batch_normalization_1 (Batc  (None, 32, 32, 32)       128       
 hNormalization)                                                 
                                                                 
 max_