In [1]:
from __future__ import absolute_import
import random
import numpy as np
from tensorflow.keras.layers import *
import tensorflow as tf

class FractionalPooling2D(Layer):
    """Keras layer that applies ``tf.nn.fractional_max_pool`` to a 4-D input.

    Args:
        pool_ratio: Pooling ratio. May be a single number (used for both
            pooled axes), a 2-element sequence (axis 1, axis 2), or a
            4-element sequence (one entry per input axis). It is stored on the
            instance as a 4-element list of floats. ``None`` is accepted at
            construction time, but the layer cannot be called until a ratio
            is set.
        pseudo_random: Forwarded to ``tf.nn.fractional_max_pool`` as
            ``pseudo_random``.
        overlap: Forwarded to ``tf.nn.fractional_max_pool`` as ``overlapping``.
        name: Optional layer name, forwarded to the Keras base ``Layer``.
            When ``None`` Keras generates a unique name, which is required
            when several of these layers are used in one model.
        **kwargs: Remaining keyword arguments for the base ``Layer``.
    """

    def __init__(self, pool_ratio = None, pseudo_random = True, overlap = False, name = None, **kwargs):
        # Initialise the base class first so that its own attribute setup
        # cannot overwrite anything assigned below, and forward `name` so that
        # explicit names (and names restored by `from_config`) are honoured.
        super(FractionalPooling2D, self).__init__(name=name, **kwargs)
        self.pool_ratio = self._normalize_pool_ratio(pool_ratio)
        self.input_spec = [InputSpec(ndim=4)]
        self.pseudo_random = pseudo_random
        self.overlap = overlap

    @staticmethod
    def _normalize_pool_ratio(pool_ratio):
        """Return ``pool_ratio`` as a 4-element list of floats (or ``None``)."""
        if pool_ratio is None:
            return None
        try:
            values = [float(r) for r in pool_ratio]
        except TypeError:
            # Not iterable: treat it as a single scalar ratio.
            values = [float(pool_ratio)]
        if len(values) == 1:
            values = values * 2
        if len(values) == 2:
            # Leave the first and last input axes un-pooled.
            values = [1.0] + values + [1.0]
        if len(values) != 4:
            raise ValueError(
                'pool_ratio must be a number or a sequence of length 1, 2 or 4; '
                'got {!r}'.format(pool_ratio))
        return values

    def _require_pool_ratio(self):
        """Return the stored ratio, raising a clear error if none was given."""
        if self.pool_ratio is None:
            raise ValueError('FractionalPooling2D requires a pool_ratio.')
        return self.pool_ratio

    def call(self, input):
        """Pool ``input`` and return only the pooled tensor.

        ``tf.nn.fractional_max_pool`` also returns the row and column pooling
        sequences; they are discarded here.
        """
        pooled, _row_sequence, _col_sequence = tf.nn.fractional_max_pool(
            input,
            pooling_ratio = self._require_pool_ratio(),
            pseudo_random = self.pseudo_random,
            overlapping = self.overlap)
        return pooled

    def compute_output_shape(self, input_shape):
        """Return the 4-tuple output shape for a 4-D ``input_shape``.

        Axes 1 and 2 are divided by their pooling ratio and truncated to an
        int; unknown (``None``) dimensions stay ``None``. The first and last
        axes are returned unchanged.
        """
        ratio = self._require_pool_ratio()
        # Normalise tuples, lists and TensorShape objects to plain ints/None.
        batch_size, dim1, dim2, channels = tf.TensorShape(input_shape).as_list()

        def pooled_size(size, axis_ratio):
            return None if size is None else int(size / axis_ratio)

        return (batch_size, pooled_size(dim1, ratio[1]), pooled_size(dim2, ratio[2]), channels)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        The keys match the ``__init__`` parameter names so that the config can
        be passed back as keyword arguments; ``name`` comes from the base
        config.
        """
        config = super(FractionalPooling2D, self).get_config()
        config.update({
            'pool_ratio': self.pool_ratio,
            'pseudo_random': self.pseudo_random,
            'overlap': self.overlap,
        })
        return config

    def build(self, input_shape):
        """Pin the input spec to the shape the layer is first built with."""
        self.input_spec = [InputSpec(shape=input_shape)]
        super(FractionalPooling2D, self).build(input_shape)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import datetime as dt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


def _scale_pixels(images, labels):
    """Cast images to float32 and divide by 255; labels pass through unchanged."""
    return tf.cast(images, tf.float32) / 255.0, labels


# Shuffle individual samples *before* batching. Shuffling after `batch()`
# only permutes whole batches, so every batch would contain the same
# images in every epoch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(50000)
    .batch(64)
    .map(_scale_pixels)
    .repeat()
)

# Validation data is not shuffled: the order does not affect the metrics,
# and a fixed order keeps `val_loss` comparable from one epoch to the next.
valid_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(5000)
    .map(_scale_pixels)
    .repeat()
)



def res_net_block(input_data, filters, conv_size):
    """Build a residual block and return its output tensor.

    The block is Conv2D(relu) -> BatchNorm -> Conv2D(linear) -> BatchNorm,
    whose result is added to `input_data` and passed through a ReLU.

    Args:
        input_data: Tensor fed to the block and reused as the skip connection.
        filters: Number of filters for both convolutions.
        conv_size: Kernel size for both convolutions.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    # NOTE(review): the Add layer sums the branch with `input_data`, so the
    # two presumably need matching shapes (i.e. `filters` input channels).
    branch = layers.Conv2D(
        filters, conv_size, padding='same', activation='relu')(input_data)
    branch = layers.BatchNormalization(axis=-1)(branch)
    # The second convolution is linear; the ReLU comes after the skip is added.
    branch = layers.Conv2D(
        filters, conv_size, padding='same', activation=None)(branch)
    branch = layers.BatchNormalization(axis=-1)(branch)
    merged = layers.Add()([branch, input_data])
    return layers.Activation('relu')(merged)

def non_res_block(input_data, filters, conv_size):
    """Build a plain (no skip connection) two-stage convolutional block.

    Each of the two stages is a same-padded ReLU Conv2D followed by
    BatchNormalization.

    Args:
        input_data: Tensor fed to the first convolution.
        filters: Number of filters for both convolutions.
        conv_size: Kernel size for both convolutions.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    features = input_data
    for _stage in range(2):
        features = layers.Conv2D(
            filters, conv_size, padding='same', activation='relu')(features)
        features = layers.BatchNormalization()(features)
    return features

# Model: two unpadded convolutions, fractional pooling, a stack of residual
# blocks, one more convolution and pooling stage, then a dense classifier.
inputs = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(32, 3, activation='relu')(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
# The pooling ratio is spelled out per input axis (1.0 leaves the first and
# last axes un-pooled) rather than as a bare scalar, so it is valid both for
# `tf.nn.fractional_max_pool` signatures that require a list of floats and
# for the layer's `compute_output_shape`, which indexes the ratio.
x = FractionalPooling2D(pool_ratio=[1.0, 3.0, 3.0, 1.0])(x)
num_res_net_blocks = 10
for _block in range(num_res_net_blocks):
    x = res_net_block(x, 64, 3)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = FractionalPooling2D(pool_ratio=[1.0, 3.0, 3.0, 1.0])(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
hybrid_net_model = keras.Model(inputs, outputs)


callbacks = [
    # Write TensorBoard logs to a timestamped sub-directory of `./log`
    keras.callbacks.TensorBoard(
        log_dir='./log/{}'.format(dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")), write_images=True),
]
hybrid_net_model.compile(optimizer=keras.optimizers.Adam(),
                        loss='sparse_categorical_crossentropy',
                        metrics=['acc'])

# `summary()` prints the table itself and returns None, so it is not wrapped
# in `print()` (that would add a stray "None" line to the output).
hybrid_net_model.summary()

# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = ModelCheckpoint('Model.hdf5', monitor='val_loss', save_best_only = True, verbose=1, mode='min')

# Pass the TensorBoard callback as well; it was previously built but never used.
callbacks_list = callbacks + [checkpoint]
# To resume from a saved checkpoint:
# hybrid_net_model.load_weights('Model.hdf5')
epochs = 30
# NOTE(review): 195 steps of 64-image batches is 12,480 images per "epoch",
# about a quarter of the training set -- confirm this is intended.
# `valid_dataset` yields 5000-image batches; assuming the standard 10,000
# image CIFAR-10 test split, two steps cover it exactly once (a third step
# would count one batch twice).
results = hybrid_net_model.fit(train_dataset, epochs=epochs, steps_per_epoch=195,
                    validation_data=valid_dataset,
                    validation_steps=2, callbacks=callbacks_list)

# plot epoch vs accuracy


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
