In [11]:
!pip install -q tensorflow-io
!pip install -q pydub

Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1


In [143]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_io as tfio

from tensorflow import keras
from tensorflow import lite as tflite
from keras import layers

import matplotlib.pyplot as plt

In [8]:
# Turn off the tensorflow_datasets progress bars so they do not flood the
# notebook output.
tfds.disable_progress_bar()

In [15]:
# Request the three predefined splits in a fixed order so they can be unpacked
# positionally below.
split = ['train', 'validation', 'test']

# as_supervised=True makes every element an (audio, label) pair.
training_data, validation_date, test_data = tfds.load(
    'crema_d',
    split=split,
    as_supervised=True,
)

In [121]:
# Number of audio samples every clip is truncated / zero-padded to.
input_len = 20000

def preprocess(audio, label):
  """Turn a raw waveform into a fixed-size magnitude spectrogram.

  The waveform is cast to float32, truncated to at most `input_len` samples
  and then right-padded with zeros to exactly `input_len` samples, so every
  example produces a spectrogram with the same number of frames.

  Args:
    audio: Waveform tensor; the slicing and padding below treat it as 1-D.
    label: Label for the clip; passed through unchanged.

  Returns:
    A `(spectrogram, label)` tuple. The spectrogram has an extra trailing
    axis of size 1 so it can be fed to 2-D convolution layers.
  """
  audio = tf.cast(audio, tf.float32)

  # Clip long recordings, then pad short ones with trailing zeros.
  audio = audio[:input_len]
  zero_padding = tf.zeros(
      [input_len] - tf.shape(audio),
      dtype=tf.float32)
  equal_length = tf.concat([audio, zero_padding], 0)

  # Use the padded signal here. Previously the unpadded `audio` was passed,
  # which left `equal_length` unused and let short clips yield fewer frames.
  spectrogram = tfio.audio.spectrogram(
      equal_length, nfft=512, window=256, stride=256)

  spectrogram = tf.abs(spectrogram)
  # Append a channel axis.
  spectrogram = spectrogram[..., tf.newaxis]

  return spectrogram, label

In [122]:
# Sanity check: walk the preprocessed validation split and print the shape of
# any spectrogram that differs from the shape the model expects. No output
# means every example matched.
processed = validation_date.map(preprocess)
for spec, label in processed:
  if spec.shape == (79, 257, 1):
    continue
  print(spec.shape)

In [128]:
# Build the batched input pipelines.
#
# Training: cache the deterministic preprocessing result *before* shuffling.
# Caching after shuffle/batch (as this cell did before) stores the first
# epoch's order and batch composition, so every later epoch replays identical
# batches and the shuffle has no effect after epoch one.
#
# NOTE: this cell rebinds the dataset names to their preprocessed versions, so
# running it a second time without re-running the load cell would apply
# `preprocess` to already-batched spectrograms.
training_data = training_data.map(preprocess).cache().shuffle(1000).batch(32).prefetch(1)

# Validation / test are never shuffled, so caching the finished batches is fine.
validation_date = validation_date.map(preprocess).batch(32).cache().prefetch(1)
test_data = test_data.map(preprocess).batch(32).cache().prefetch(1)

In [137]:
# Shape of one preprocessed example; the same shape the validation sanity
# check compares every spectrogram against.
input_shape = (79, 257, 1)

# Small separable-convolution classifier: four stride-2 SeparableConv2D
# layers, each preceded by batch normalization, then global max pooling and a
# softmax output layer.
model = keras.models.Sequential([
  layers.Input(shape=input_shape, name='input_layer'),
  layers.BatchNormalization(name='bn_0'),
  layers.SeparableConv2D(filters=8, kernel_size=3, padding='same', strides=2, activation='relu', name='conv_1'),
  layers.BatchNormalization(name='bn_1'),
  layers.SeparableConv2D(filters=16, kernel_size=3, padding='same', strides=2, activation='relu', name='conv_2'),
  layers.BatchNormalization(name='bn_2'),
  layers.SeparableConv2D(filters=32, kernel_size=3, padding='same', strides=2, activation='relu', name='conv_3'),
  layers.BatchNormalization(name='bn_3'),
  layers.SeparableConv2D(filters=32, kernel_size=3, padding='same', strides=2, activation='relu', name='conv_4'),

  # The layer name now matches the operation. It used to be called
  # 'global_average_pooling' although it performs max pooling; the computation
  # itself is unchanged.
  layers.GlobalMaxPooling2D(name='global_max_pooling'),

  # 6 output units -- presumably one per label class of the dataset
  # (TODO confirm against the dataset info).
  layers.Dense(6, activation='softmax')
], name='audio_model')

model.summary()

Model: "audio_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 bn_0 (BatchNormalization)   (None, 79, 257, 1)        4         
                                                                 
 conv_1 (SeparableConv2D)    (None, 40, 129, 8)        25        
                                                                 
 bn_1 (BatchNormalization)   (None, 40, 129, 8)        32        
                                                                 
 conv_2 (SeparableConv2D)    (None, 20, 65, 16)        216       
                                                                 
 bn_2 (BatchNormalization)   (None, 20, 65, 16)        64        
                                                                 
 conv_3 (SeparableConv2D)    (None, 10, 33, 32)        688       
                                                                 
 bn_3 (BatchNormalization)   (None, 10, 33, 32)        

In [138]:
# Configure training: Adam optimizer, sparse categorical cross-entropy loss
# (the loss variant that takes integer labels rather than one-hot vectors),
# and accuracy as the reported metric.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [139]:
# Train for 5 epochs, evaluating on the validation split after each one.
# The returned History object is kept for later inspection.
history = model.fit(
    training_data,
    epochs=5,
    validation_data=validation_date,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f40b9574f90>

In [141]:
# Score the trained model on the held-out test split. The cell displays the
# returned list: the loss followed by the compiled metric (accuracy).
model.evaluate(x=test_data)



[1.473469614982605, 0.39524421095848083]

In [142]:
# Export the trained model in SavedModel format to the 'model_v0' directory;
# the TFLite converter cell loads it from there.
tf.saved_model.save(obj=model, export_dir='model_v0')

INFO:tensorflow:Assets written to: model_v0/assets


INFO:tensorflow:Assets written to: model_v0/assets


In [151]:
# Convert the exported SavedModel into a TFLite flatbuffer.
converter = tflite.TFLiteConverter.from_saved_model('model_v0')

# Optimize.DEFAULT is the supported spelling; OPTIMIZE_FOR_SIZE is a
# deprecated alias that does the same thing.
converter.optimizations = [tflite.Optimize.DEFAULT]

# NOTE: full-integer (int8) quantization is not enabled here. Turning it on
# would additionally require
#   converter.target_spec.supported_ops = [tflite.OpsSet.TFLITE_BUILTINS_INT8]
#   converter.inference_input_type / inference_output_type = tf.int8 (or tf.uint8)
# and a `converter.representative_dataset` set to a generator *function* that
# yields lists of sample input tensors -- a tf.data.Dataset such as `test_data`
# cannot be assigned directly.

# `tflite_model` holds the serialized model as returned by convert().
tflite_model = converter.convert()



In [152]:
# Report the length of the converted model returned by the converter.
tflite_size = len(tflite_model)
print(tflite_size)

15680


In [153]:
# Persist the converted model to disk; binary mode because the converter
# output is written as-is.
with open('model.tflite', mode='wb') as tflite_file:
  tflite_file.write(tflite_model)