In [6]:
%pip install PyWavelets

Defaulting to user installation because normal site-packages is not writeable
Collecting PyWavelets
  Downloading pywavelets-1.6.0-cp39-cp39-macosx_11_0_arm64.whl (4.3 MB)
[K     |████████████████████████████████| 4.3 MB 2.2 MB/s eta 0:00:01
Installing collected packages: PyWavelets
Successfully installed PyWavelets-1.6.0
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [37]:
import tensorflow as tf
import pywt
import numpy as np
import tensorflow.keras.backend as K
from typing import Tuple, Optional
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, GlobalAveragePooling2D, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

In [35]:
class WaveletLayerTransformation(tf.keras.layers.Layer):
    '''Fixed (non-trainable) frequency transform applied at the network input.

    Despite its name, this layer does not perform a wavelet decomposition:
    ``call`` applies an orthonormal type-II DCT via ``tf.signal.dct``, which
    transforms along the LAST axis of ``inputs`` only. For channels-last
    image tensors that is the channel axis, so for single-channel inputs
    (e.g. shape ``(H, W, 1)``) the transform reduces to the identity.

    Args:
        wavelet: Wavelet name. Stored and serialized, but not used by ``call``.
        level: Decomposition level. Stored and serialized, but not used by ``call``.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (``name``, ``dtype``, ...).
    '''
    def __init__(self, wavelet='haar', level=1, **kwargs):
        super(WaveletLayerTransformation, self).__init__(**kwargs)
        self.wavelet = wavelet
        self.level = level

    def call(self, inputs):
        '''Return the orthonormal DCT-II of ``inputs`` along the last axis (same shape).'''
        # NOTE(review): tf.signal.dct needs float32/float64 input and only supports
        # axis=-1. If a spatial (H/W) transform or a real Haar DWT was intended,
        # this needs to be reworked — TODO confirm the intended behavior.
        return tf.signal.dct(inputs, type=2, norm='ortho')

    def get_config(self):
        '''Return the layer config including the ``wavelet`` and ``level`` constructor args.

        Needed so ``model.get_config()`` / ``model.save()`` can serialize this
        layer and re-create it with the same arguments; Keras 2 refuses to
        serialize layers with extra ``__init__`` args unless this is overridden.
        '''
        config = super(WaveletLayerTransformation, self).get_config()
        config.update({'wavelet': self.wavelet, 'level': self.level})
        return config

class WaveletNet():
    '''Builder for a small CNN classifier whose first layer is ``WaveletLayerTransformation``.

    Usage: ``model = WaveletNet(num_classes, input_shape).main()``.

    Layer stack built by ``wavelet_net``:
        WaveletLayerTransformation
        -> Conv2D(64, 3x3, relu) -> MaxPooling2D(2x2) -> Dropout(0.6)
        -> Conv2D(128, 3x3, relu) -> Conv2D(128, 3x3, relu) -> MaxPooling2D(2x2) -> Dropout(0.4)
        -> GlobalAveragePooling2D -> Dropout(0.4)
        -> Dense(num_classes, softmax, L2(0.01) kernel regularizer)
    '''
    def __init__(self, num_classes: int = 100, input_shape: Tuple[int, int, int] = (28, 28, 1)):
        '''
        Args:
            num_classes: Number of output classes (width of the final softmax layer).
            input_shape: Per-sample input shape, without the batch dimension —
                presumably (height, width, channels) for the default channels_last
                data format. The resulting functional model is fixed to this shape;
                calling ``Model.build`` later does not change it.
        '''
        self.num_classes = num_classes
        self.inputs = Input(shape=input_shape)

    def main(self) -> Model:
        '''Build and return a functional ``Model`` mapping ``self.inputs`` to class probabilities.'''
        return Model(self.inputs, self.wavelet_net())

    def wavelet_net(self):
        '''Apply the network layers to ``self.inputs`` and return the softmax output tensor.

        Every call instantiates fresh layers (and therefore fresh weights).
        '''
        x = WaveletLayerTransformation()(self.inputs)

        # First convolution block
        x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.6)(x)

        # Second convolution block
        x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.4)(x)

        # Global average pooling and output block
        x = GlobalAveragePooling2D()(x)
        x = Dropout(0.4)(x)
        # Kept for backward compatibility: earlier versions left the
        # pre-classifier feature tensor on ``self.tensor``.
        self.tensor = x
        return Dense(self.num_classes, activation='softmax', kernel_regularizer=l2(0.01))(x)


In [30]:
# Input shape and class count for this experiment (presumably CIFAR-100-style
# 32x32 RGB images with 100 classes — confirm against the training data).
# The shape must be given to the constructor: a functional Keras model is fixed
# by its `Input` layer, so the former `model.build((None, 32, 32, 3))` call could
# not change the (28, 28, 1) default and has been removed.
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 100

model = WaveletNet(num_classes=NUM_CLASSES, input_shape=INPUT_SHAPE).main()
# categorical_crossentropy expects one-hot encoded labels.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()