In [None]:
!pip install tensorflow==2.2-rc3

In [1]:
import numpy as np
import math
import cv2
import pywt
import os
from PIL import Image
from tensorflow.keras.utils import to_categorical, Sequence
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model, Sequential
import seaborn as sb
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dense, Flatten, BatchNormalization, Activation, Dropout, Lambda, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.resnet50 import ResNet50
from sklearn.metrics import classification_report,confusion_matrix
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.layers import Layer

In [2]:
K.set_image_data_format('channels_first')

In [18]:
wavename = pywt.Wavelet('haar')

class DWT_Pooling(tf.keras.layers.Layer):
    def __init__(self,**kwargs):
        super(DWT_Pooling, self).__init__(**kwargs)
        
    def build(self, input_shape):
        super(DWT_Pooling, self).build(input_shape) 
    
    @tf.function
    def call(self, inputs):
        band_low = wavename.rec_lo
        band_high = wavename.rec_hi
        assert len(band_low) == len(band_high)
        band_length = len(band_low)
        assert band_length % 2 == 0
        band_length_half = math.floor(band_length / 2)

        input_height = inputs.shape[2]
        input_width = inputs.shape[3]

        L1 = input_height
        L = math.floor(L1 / 2)
        matrix_h = np.zeros( ( L,      L1 + band_length - 2 ), dtype=np.float32)
        matrix_g = np.zeros( ( L1 - L, L1 + band_length - 2 ), dtype=np.float32)
        end = None if band_length_half == 1 else (-band_length_half+1)
        
        index = 0
        for i in range(L):
            for j in range(band_length):
                matrix_h[i, index+j] = band_low[j]
            index += 2
        matrix_h_0 = matrix_h[0:(math.floor(input_height / 2)), 0:(input_height + band_length - 2)]
        matrix_h_1 = matrix_h[0:(math.floor(input_width / 2)), 0:(input_width + band_length - 2)]

        index = 0
        for i in range(L1 - L):
            for j in range(band_length):
                matrix_g[i, index+j] = band_high[j]
            index += 2

        matrix_g_0 = matrix_g[0:(input_height - math.floor(input_height / 2)),0:(input_height + band_length - 2)]
        matrix_g_1 = matrix_g[0:(input_width - math.floor(input_width / 2)),0:(input_width + band_length - 2)]

        matrix_h_0 = matrix_h_0[:,(band_length_half-1):end]
        matrix_h_1 = matrix_h_1[:,(band_length_half-1):end]
        matrix_h_1 = np.transpose(matrix_h_1)
        matrix_g_0 = matrix_g_0[:,(band_length_half-1):end]
        matrix_g_1 = matrix_g_1[:,(band_length_half-1):end]
        matrix_g_1 = np.transpose(matrix_g_1)

        matrix_low_0 = tf.convert_to_tensor(matrix_h_0,dtype=tf.float32)
        matrix_low_1 = tf.convert_to_tensor(matrix_h_1,dtype=tf.float32)
        matrix_high_0 = tf.convert_to_tensor(matrix_g_0,dtype=tf.float32)
        matrix_high_1 = tf.convert_to_tensor(matrix_g_1,dtype=tf.float32)
        
        L = tf.matmul(matrix_low_0, inputs)
        H = tf.matmul(matrix_high_0, inputs)
        LL = tf.matmul(L, matrix_low_1)
        LH = tf.matmul(L, matrix_high_1)
        HL = tf.matmul(H, matrix_low_1)
        HH = tf.matmul(H, matrix_high_1)
        return LL    
    
    def get_config(self):
        config = super(DWT_Pooling, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2]//2, input_shape[3]//2)

In [21]:
def create_model(input_shape=(1,28,28), num_classes = 1, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)

  inputs = Input(shape=input_shape)
  
  output = Conv2D(16,(3,3),padding='same',use_bias=False)(inputs)
  output = BatchNormalization(scale=False,center=True)(output)
  output = Activation('relu')(output)
  #output = MaxPooling2D()(output)
  output = DWT_Pooling()(output)

  output = Conv2D(32,(3,3),padding='same',use_bias=False)(output)
  output = BatchNormalization(scale=False,center=True)(output)
  output = Activation('relu')(output)
  #output = MaxPooling2D()(output)
  output = DWT_Pooling()(output)

  output = Flatten()(output)
  output = Dense(256,activation='relu')(output)
  output = Dropout(0.3)(output)
  if num_classes == 1:
    activation = 'sigmoid'
  else:
    activation = 'softmax'
  output = Dense(num_classes,activation=activation,bias_initializer=output_bias)(output)
  model = Model(inputs,output)
  return model

In [22]:
model = create_model(input_shape=(1,28,28),num_classes=10)
model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 1, 28, 28)]       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 28, 28)        144       
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 28, 28)        84        
_________________________________________________________________
activation_4 (Activation)    (None, 16, 28, 28)        0         
_________________________________________________________________
dwt__pooling_3 (DWT_Pooling) (None, 16, 14, 14)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 32, 14, 14)        4608      
_________________________________________________________________
batch_normalization_5 (Batch (None, 32, 14, 14)        42  

In [9]:
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()

In [10]:
x_train = x_train/255.
x_test = x_test/255.

In [11]:
from tensorflow.keras.utils import to_categorical
y_train_oh = to_categorical(y_train,10)
y_test_oh = to_categorical(y_test,10)

In [12]:
num_train = x_train.shape[0]
num_test = x_test.shape[0]
img_height = x_train.shape[1]
img_width = x_train.shape[2]
num_channels = 1
x_train = x_train.reshape(num_train,1,img_height,img_width)
x_test = x_test.reshape(num_test,1,img_height,img_width)

In [23]:
opt= Adam(learning_rate=0.01)
model.compile(optimizer = opt,loss='categorical_crossentropy',metrics=['accuracy'])

In [24]:
def lr_decay(epoch):
  return 0.01*math.pow(0.666,epoch)
lr_decay_cb = LearningRateScheduler(lr_decay,verbose=True)
model_check_cb = ModelCheckpoint('mnist_dwt.h5',save_best_only=True,monitor='val_loss')

In [25]:
history = model.fit(x_train,y_train_oh,validation_data=(x_test,y_test_oh),epochs=10,batch_size=64,
                    callbacks=[lr_decay_cb,model_check_cb])


Epoch 00001: LearningRateScheduler reducing learning rate to 0.01.
Epoch 1/10

Epoch 00002: LearningRateScheduler reducing learning rate to 0.00666.
Epoch 2/10

Epoch 00003: LearningRateScheduler reducing learning rate to 0.004435560000000001.
Epoch 3/10

Epoch 00004: LearningRateScheduler reducing learning rate to 0.0029540829600000007.
Epoch 4/10

Epoch 00005: LearningRateScheduler reducing learning rate to 0.0019674192513600007.
Epoch 5/10

Epoch 00006: LearningRateScheduler reducing learning rate to 0.0013103012214057605.
Epoch 6/10

Epoch 00007: LearningRateScheduler reducing learning rate to 0.0008726606134562365.
Epoch 7/10

Epoch 00008: LearningRateScheduler reducing learning rate to 0.0005811919685618535.
Epoch 8/10

Epoch 00009: LearningRateScheduler reducing learning rate to 0.0003870738510621945.
Epoch 9/10

Epoch 00010: LearningRateScheduler reducing learning rate to 0.00025779118480742154.
Epoch 10/10


In [26]:
best_model = tf.keras.models.load_model('mnist_dwt.h5',custom_objects={'DWT_Pooling':DWT_Pooling})

In [28]:
best_model.evaluate(x_test,y_test_oh)
y_preds = np.argmax(m.predict(x_test),axis=1)



In [30]:
from sklearn.metrics import classification_report
print(classification_report(y_test,y_preds))

              precision    recall  f1-score   support

           0       0.99      1.00      1.00       980
           1       0.99      1.00      1.00      1135
           2       1.00      1.00      1.00      1032
           3       0.99      0.99      0.99      1010
           4       0.99      0.99      0.99       982
           5       0.99      0.99      0.99       892
           6       1.00      0.99      0.99       958
           7       0.99      1.00      0.99      1028
           8       0.99      0.99      0.99       974
           9       0.99      0.99      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

