<a href="https://colab.research.google.com/github/koshygeoji/Texture-Classification-using-Wavelet-CNN/blob/master/TextureClassificationCNNWavelet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Checking tensorflow version

In [0]:
import tensorflow as tf
print(tf.__version__)
print(tf.test.gpu_device_name())

2.2.0-rc3
/device:GPU:0


In [0]:
import numpy as np
from matplotlib import pyplot as plt

from keras import backend as K
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Lambda
from keras.layers import Flatten
from keras.layers import Reshape
from keras.layers import Dropout
from keras.layers import Activation
from keras.layers import AveragePooling2D
from keras.layers import BatchNormalization
from keras.layers.merge import add, concatenate
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import plot_model

Using TensorFlow backend.


defining wavelet transform

In [0]:
# batch operation usng tensor slice
def WaveletTransformAxisY(batch_img):
    odd_img  = batch_img[:,0::2]
    even_img = batch_img[:,1::2]
    L = (odd_img + even_img) / 2.0
    H = K.abs(odd_img - even_img)
    return L, H

def WaveletTransformAxisX(batch_img):
    # transpose + fliplr
    tmp_batch = K.permute_dimensions(batch_img, [0, 2, 1])[:,:,::-1]
    _dst_L, _dst_H = WaveletTransformAxisY(tmp_batch)
    # transpose + flipud
    dst_L = K.permute_dimensions(_dst_L, [0, 2, 1])[:,::-1,...]
    dst_H = K.permute_dimensions(_dst_H, [0, 2, 1])[:,::-1,...]
    return dst_L, dst_H

In [0]:
def Wavelet(batch_image):

  def WaveletTransformAxisY(batch_img):
      odd_img  = batch_img[:,0::2]
      even_img = batch_img[:,1::2]
      L = (odd_img + even_img) / 2.0
      H = K.abs(odd_img - even_img)
      return L, H

  def WaveletTransformAxisX(batch_img):
      # transpose + fliplr
      tmp_batch = K.permute_dimensions(batch_img, [0, 2, 1])[:,:,::-1]
      _dst_L, _dst_H = WaveletTransformAxisY(tmp_batch)
      # transpose + flipud
      dst_L = K.permute_dimensions(_dst_L, [0, 2, 1])[:,::-1,...]
      dst_H = K.permute_dimensions(_dst_H, [0, 2, 1])[:,::-1,...]
      return dst_L, dst_H

  # make channel first image
  batch_image = K.permute_dimensions(batch_image, [0, 3, 1, 2])
  r = batch_image[:,0]
  g = batch_image[:,1]
  b = batch_image[:,2]

  # level 1 decomposition
  wavelet_L, wavelet_H = WaveletTransformAxisY(r)
  r_wavelet_LL, r_wavelet_LH = WaveletTransformAxisX(wavelet_L)
  r_wavelet_HL, r_wavelet_HH = WaveletTransformAxisX(wavelet_H)

  wavelet_L, wavelet_H = WaveletTransformAxisY(g)
  g_wavelet_LL, g_wavelet_LH = WaveletTransformAxisX(wavelet_L)
  g_wavelet_HL, g_wavelet_HH = WaveletTransformAxisX(wavelet_H)

  wavelet_L, wavelet_H = WaveletTransformAxisY(b)
  b_wavelet_LL, b_wavelet_LH = WaveletTransformAxisX(wavelet_L)
  b_wavelet_HL, b_wavelet_HH = WaveletTransformAxisX(wavelet_H)

  wavelet_data = [r_wavelet_LL, r_wavelet_LH, r_wavelet_HL, r_wavelet_HH, 
                  g_wavelet_LL, g_wavelet_LH, g_wavelet_HL, g_wavelet_HH,
                  b_wavelet_LL, b_wavelet_LH, b_wavelet_HL, b_wavelet_HH]
  transform_batch = K.stack(wavelet_data, axis=1)

  # level 2 decomposition
  wavelet_L2, wavelet_H2 = WaveletTransformAxisY(r_wavelet_LL)
  r_wavelet_LL2, r_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
  r_wavelet_HL2, r_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)

  wavelet_L2, wavelet_H2 = WaveletTransformAxisY(g_wavelet_LL)
  g_wavelet_LL2, g_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
  g_wavelet_HL2, g_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)

  wavelet_L2, wavelet_H2 = WaveletTransformAxisY(b_wavelet_LL)
  b_wavelet_LL2, b_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
  b_wavelet_HL2, b_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)


  wavelet_data_l2 = [r_wavelet_LL2, r_wavelet_LH2, r_wavelet_HL2, r_wavelet_HH2, 
                  g_wavelet_LL2, g_wavelet_LH2, g_wavelet_HL2, g_wavelet_HH2,
                  b_wavelet_LL2, b_wavelet_LH2, b_wavelet_HL2, b_wavelet_HH2]
  transform_batch_l2 = K.stack(wavelet_data_l2, axis=1)

  # level 3 decomposition
  wavelet_L3, wavelet_H3 = WaveletTransformAxisY(r_wavelet_LL2)
  r_wavelet_LL3, r_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
  r_wavelet_HL3, r_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

  wavelet_L3, wavelet_H3 = WaveletTransformAxisY(g_wavelet_LL2)
  g_wavelet_LL3, g_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
  g_wavelet_HL3, g_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

  wavelet_L3, wavelet_H3 = WaveletTransformAxisY(b_wavelet_LL2)
  b_wavelet_LL3, b_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
  b_wavelet_HL3, b_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

  wavelet_data_l3 = [r_wavelet_LL3, r_wavelet_LH3, r_wavelet_HL3, r_wavelet_HH3, 
                  g_wavelet_LL3, g_wavelet_LH3, g_wavelet_HL3, g_wavelet_HH3,
                  b_wavelet_LL3, b_wavelet_LH3, b_wavelet_HL3, b_wavelet_HH3]
  transform_batch_l3 = K.stack(wavelet_data_l3, axis=1)

  # level 4 decomposition
  wavelet_L4, wavelet_H4 = WaveletTransformAxisY(r_wavelet_LL3)
  r_wavelet_LL4, r_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
  r_wavelet_HL4, r_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)

  wavelet_L4, wavelet_H4 = WaveletTransformAxisY(g_wavelet_LL3)
  g_wavelet_LL4, g_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
  g_wavelet_HL4, g_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)

  wavelet_L3, wavelet_H3 = WaveletTransformAxisY(b_wavelet_LL3)
  b_wavelet_LL4, b_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
  b_wavelet_HL4, b_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)


  wavelet_data_l4 = [r_wavelet_LL4, r_wavelet_LH4, r_wavelet_HL4, r_wavelet_HH4, 
                  g_wavelet_LL4, g_wavelet_LH4, g_wavelet_HL4, g_wavelet_HH4,
                  b_wavelet_LL4, b_wavelet_LH4, b_wavelet_HL4, b_wavelet_HH4]
  transform_batch_l4 = K.stack(wavelet_data_l4, axis=1)

  # print('shape before')
  # print(transform_batch.shape)
  # print(transform_batch_l2.shape)
  # print(transform_batch_l3.shape)
  # print(transform_batch_l4.shape)

  decom_level_1 = K.permute_dimensions(transform_batch, [0, 2, 3, 1])
  decom_level_2 = K.permute_dimensions(transform_batch_l2, [0, 2, 3, 1])
  decom_level_3 = K.permute_dimensions(transform_batch_l3, [0, 2, 3, 1])
  decom_level_4 = K.permute_dimensions(transform_batch_l4, [0, 2, 3, 1])
  
  # print('shape after')
  # print(decom_level_1.shape)
  # print(decom_level_2.shape)
  # print(decom_level_3.shape)
  # print(decom_level_4.shape)
  return [decom_level_1, decom_level_2, decom_level_3, decom_level_4]


def Wavelet_out_shape(input_shapes):
    # print('in to shape')
    return [tuple([None, 112, 112, 12]), tuple([None, 56, 56, 12]), 
            tuple([None, 28, 28, 12]), tuple([None, 14, 14, 12])]

In [0]:
img_batch = K.zeros(shape=(8, 224, 224, 3), dtype='float32')
Wavelet(img_batch)

[<tf.Tensor: shape=(8, 112, 112, 12), dtype=float32, numpy=
 array([[[[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
         

model creation

In [0]:
def get_wavelet_cnn_model():

    input_shape = 224, 224, 3

    input_ = Input(input_shape, name='the_input')
    # wavelet = Lambda(Wavelet, name='wavelet')
    wavelet = Lambda(Wavelet, Wavelet_out_shape, name='wavelet')
    input_l1, input_l2, input_l3, input_l4 = wavelet(input_)
    # print(input_l1)
    # print(input_l2)
    # print(input_l3)
    # print(input_l4)
    # level one decomposition starts
    conv_1 = Conv2D(64, kernel_size=(3, 3), padding='same', name='conv_1')(input_l1)
    norm_1 = BatchNormalization(name='norm_1')(conv_1)
    relu_1 = Activation('relu', name='relu_1')(norm_1)

    conv_1_2 = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding='same', name='conv_1_2')(relu_1)
    norm_1_2 = BatchNormalization(name='norm_1_2')(conv_1_2)
    relu_1_2 = Activation('relu', name='relu_1_2')(norm_1_2)

    # level two decomposition starts
    conv_a = Conv2D(filters=64, kernel_size=(3, 3), padding='same', name='conv_a')(input_l2)
    norm_a = BatchNormalization(name='norm_a')(conv_a)
    relu_a = Activation('relu', name='relu_a')(norm_a)

    # concate level one and level two decomposition
    concate_level_2 = concatenate([relu_1_2, relu_a])
    conv_2 = Conv2D(128, kernel_size=(3, 3), padding='same', name='conv_2')(concate_level_2)
    norm_2 = BatchNormalization(name='norm_2')(conv_2)
    relu_2 = Activation('relu', name='relu_2')(norm_2)

    conv_2_2 = Conv2D(128, kernel_size=(3, 3), strides=(2, 2), padding='same', name='conv_2_2')(relu_2)
    norm_2_2 = BatchNormalization(name='norm_2_2')(conv_2_2)
    relu_2_2 = Activation('relu', name='relu_2_2')(norm_2_2)

    # level three decomposition starts 
    conv_b = Conv2D(filters=64, kernel_size=(3, 3), padding='same', name='conv_b')(input_l3)
    norm_b = BatchNormalization(name='norm_b')(conv_b)
    relu_b = Activation('relu', name='relu_b')(norm_b)

    conv_b_2 = Conv2D(128, kernel_size=(3, 3), padding='same', name='conv_b_2')(relu_b)
    norm_b_2 = BatchNormalization(name='norm_b_2')(conv_b_2)
    relu_b_2 = Activation('relu', name='relu_b_2')(norm_b_2)

    # concate level two and level three decomposition 
    concate_level_3 = concatenate([relu_2_2, relu_b_2])
    conv_3 = Conv2D(256, kernel_size=(3, 3), padding='same', name='conv_3')(concate_level_3)
    norm_3 = BatchNormalization(name='nomr_3')(conv_3)
    relu_3 = Activation('relu', name='relu_3')(norm_3)

    conv_3_2 = Conv2D(256, kernel_size=(3, 3), strides=(2, 2), padding='same', name='conv_3_2')(relu_3)
    norm_3_2 = BatchNormalization(name='norm_3_2')(conv_3_2)
    relu_3_2 = Activation('relu', name='relu_3_2')(norm_3_2)

    # level four decomposition start
    conv_c = Conv2D(64, kernel_size=(3, 3), padding='same', name='conv_c')(input_l4)
    norm_c = BatchNormalization(name='norm_c')(conv_c)
    relu_c = Activation('relu', name='relu_c')(norm_c)

    conv_c_2 = Conv2D(256, kernel_size=(3, 3), padding='same', name='conv_c_2')(relu_c)
    norm_c_2 = BatchNormalization(name='norm_c_2')(conv_c_2)
    relu_c_2 = Activation('relu', name='relu_c_2')(norm_c_2)

    conv_c_3 = Conv2D(256, kernel_size=(3, 3), padding='same', name='conv_c_3')(relu_c_2)
    norm_c_3 = BatchNormalization(name='norm_c_3')(conv_c_3)
    relu_c_3 = Activation('relu', name='relu_c_3')(norm_c_3)

    # concate level level three and level four decomposition
    concate_level_4 = concatenate([relu_3_2, relu_c_3])
    conv_4 = Conv2D(256, kernel_size=(3, 3), padding='same', name='conv_4')(concate_level_4)
    norm_4 = BatchNormalization(name='norm_4')(conv_4)
    relu_4 = Activation('relu', name='relu_4')(norm_4)

    conv_4_2 = Conv2D(256, kernel_size=(3, 3), strides=(2, 2), padding='same', name='conv_4_2')(relu_4)
    norm_4_2 = BatchNormalization(name='norm_4_2')(conv_4_2)
    relu_4_2 = Activation('relu', name='relu_4_2')(norm_4_2)

    conv_5_1 = Conv2D(128, kernel_size=(3, 3), padding='same', name='conv_5_1')(relu_4_2)
    norm_5_1 = BatchNormalization(name='norm_5_1')(conv_5_1)
    relu_5_1 = Activation('relu', name='relu_5_1')(norm_5_1)

    pool_5_1 = AveragePooling2D(pool_size=(7, 7), strides=1, padding='same', name='avg_pool_5_1')(relu_5_1)
    flat_5_1 = Flatten(name='flat_5_1')(pool_5_1) 

    fc_5 = Dense(2048, name='fc_5')(flat_5_1)
    norm_5 = BatchNormalization(name='norm_5')(fc_5)
    relu_5 = Activation('relu', name='relu_5')(norm_5)
    drop_5 = Dropout(0.5, name='drop_5')(relu_5)

    fc_6 = Dense(2048, name='fc_6')(drop_5)
    norm_6 = BatchNormalization(name='norm_6')(fc_6)
    relu_6 = Activation('relu', name='relu_6')(norm_6)
    drop_6 = Dropout(0.5, name='drop_6')(relu_6)

    output = Dense(6, activation='softmax', name='fc_7')(drop_6)

    model = Model(inputs=input_, outputs=output)
    model.summary()
    plot_model(model, to_file='wavelet_cnn_0.5.png')

    return model

In [0]:

model = get_wavelet_cnn_model()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
the_input (InputLayer)          (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
wavelet (Lambda)                [(None, 112, 112, 12 0           the_input[0][0]                  
__________________________________________________________________________________________________
conv_1 (Conv2D)                 (None, 112, 112, 64) 6976        wavelet[0][0]                    
__________________________________________________________________________________________________
norm_1 (BatchNormalization)     (None, 112, 112, 64) 256         conv_1[0][0]                     
____________________________________________________________________________________________

compile the model

In [0]:

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

to get dataset

In [0]:
!git clone https://github.com/namndh/texture-classification

Cloning into 'texture-classification'...
remote: Enumerating objects: 329, done.[K
remote: Total 329 (delta 0), reused 0 (delta 0), pack-reused 329[K
Receiving objects: 100% (329/329), 51.47 MiB | 27.84 MiB/s, done.
Resolving deltas: 100% (43/43), done.


In [0]:
!cp -r /content/texture-classification/data/subdataset /content/

In [0]:
!rm -r '/content/texture-classification'

slpit into train and validation

In [0]:
!pip install split-folders

Collecting split-folders
  Downloading https://files.pythonhosted.org/packages/20/67/29dda743e6d23ac1ea3d16704d8bbb48d65faf3f1b1eaf53153b3da56c56/split_folders-0.3.1-py3-none-any.whl
Installing collected packages: split-folders
Successfully installed split-folders-0.3.1


In [0]:
import split_folders

# Split with a ratio.
# To only split into training and validation set, set a tuple to `ratio`, i.e, `(.8, .2)`.
split_folders.ratio('/content/subdataset', output="data", seed=1337, ratio=(.7, .3))

Copying files: 240 files [00:00, 3328.01 files/s]


In [0]:

train_data_gen = ImageDataGenerator(
	#rescale=1./255,
	shear_range=0.1,
	zoom_range=0.1,
	horizontal_flip=True
)
val_data_gen = ImageDataGenerator(
	#rescale=1./255
)

train_img_dir = '/content/subdataset'
val_img_dir   = '/content/subdataset'

# def train & test generators
train_generator = train_data_gen.flow_from_directory(
	train_img_dir,
	target_size=(224, 224),
	batch_size=8,
	class_mode='categorical')
val_generator = val_data_gen.flow_from_directory(
	val_img_dir,
	target_size=(224, 224),
	batch_size=8,
	class_mode='categorical')

Found 240 images belonging to 6 classes.
Found 240 images belonging to 6 classes.


checkpoinu

In [0]:
from keras.callbacks import ModelCheckpoint, CSVLogger
checkpointer = ModelCheckpoint(filepath='best_model_3class.hdf5', verbose=1, save_best_only=True)
csv_logger = CSVLogger('history.log')

In [0]:
history = model1.fit_generator(
	train_generator,
	steps_per_epoch=240//8,
	epochs=10,
	validation_data=val_generator,
    validation_steps=240//8,
	verbose=True,
  callbacks=[csv_logger, checkpointer])

Epoch 1/10

Epoch 00001: val_loss did not improve from 1.69155
Epoch 2/10

Epoch 00002: val_loss did not improve from 1.69155
Epoch 3/10

Epoch 00003: val_loss did not improve from 1.69155
Epoch 4/10

Epoch 00004: val_loss did not improve from 1.69155
Epoch 5/10

Epoch 00005: val_loss did not improve from 1.69155
Epoch 6/10

Epoch 00006: val_loss improved from 1.69155 to 0.04478, saving model to best_model_3class.hdf5
Epoch 7/10

Epoch 00007: val_loss improved from 0.04478 to 0.00417, saving model to best_model_3class.hdf5
Epoch 8/10

Epoch 00008: val_loss did not improve from 0.00417
Epoch 9/10

Epoch 00009: val_loss did not improve from 0.00417
Epoch 10/10

Epoch 00010: val_loss did not improve from 0.00417


plotting

In [0]:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(len(acc))

plt.plot(acc, label='training accuracy')
plt.plot(val_acc, label='validation accuracy')
plt.title('Accuracy curve')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend()

In [0]:

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(loss, label='training loss')
plt.plot(val_loss, label='validation loss')
plt.title('Loss curve')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()

In [0]:
model1.save('mymodel_final.hdf5')

In [0]:
from keras import models
model1 = models.load_model('mymodel_final.hdf5',custom_objects={'wavelet': Wavelet})

loading the model

In [0]:
from keras.preprocessing import image
imgpath='/content/subdataset/train/cushion1/cushion1-a-p009.png'
img = image.load_img(imgpath, target_size=(224, 224))
img = image.img_to_array(img)                    
img = np.expand_dims(img, axis=0)         
#img /= 255.                                      

pred = model1.predict(img)
index = np.argmax(pred)
class_list=['canvas','cushion','linseeds','sand','seat','stone']
class_list.sort()


pred_value = class_list[index]
print(pred_value)

cushion


In [0]:
model1.evaluate_generator(train_generator)

[0.43022438883781433, 0.8999999761581421]

In [0]:
model1.metrics_names

['loss', 'accuracy']

In [0]:
model1.predict_generator(train_generator)

array([[3.7603822e-06, 3.6413550e-10, 9.9995303e-01, 3.0406926e-09,
        2.8634000e-05, 1.4534423e-05],
       [7.5721847e-09, 1.3147319e-13, 9.9999964e-01, 1.1096569e-12,
        8.0415834e-09, 3.0453037e-07],
       [2.4640954e-01, 2.5931178e-05, 6.7332844e-06, 7.5345069e-01,
        7.3888418e-06, 9.9638899e-05],
       ...,
       [9.9372619e-01, 4.2251859e-08, 2.5691909e-11, 6.2735379e-03,
        2.1082580e-07, 3.1135292e-10],
       [9.9921227e-01, 2.9356825e-08, 5.7336704e-09, 7.8487687e-04,
        2.8401319e-06, 3.8433150e-09],
       [3.4706753e-01, 3.1768715e-05, 7.6289721e-06, 6.5235561e-01,
        9.8601606e-07, 5.3650944e-04]], dtype=float32)