# Keyword Transformer (KWT)

<img src="media/keyword_transformer/kwt.png" alt="kwt" width="500"/>

https://arxiv.org/pdf/2104.00769v2.pdf

In [None]:
# Auto-reload edited local modules (e.g. the `utils` package) before each cell runs,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf

from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import layers
from tensorflow.keras import models
from IPython import display

from utils import mel_features

from tqdm.notebook import tqdm


# Set seed for experiment reproducibility
seed = 42
tf.random.set_seed(seed)  # TensorFlow global seed (affects e.g. tf.random.shuffle)
np.random.seed(seed)  # NumPy global RNG (used for random sample selection)

In [None]:
# Restrict TensorFlow to the first GPU and let it grow its memory on demand instead of
# reserving the whole card up front. On a machine without a GPU this is skipped
# (indexing gpus[0] unconditionally would raise IndexError) and TF runs on the CPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)

In [None]:
# set some paths and variables
data_dir = pathlib.Path('/home/vitto/tensorflow_datasets/speech_commands_v2') # change your path accordingly

# Class vocabulary: the position of each word in this list is its class index (used for
# one-hot encoding), and len(labels) sets the size of the model's output layer.
# NOTE(review): speech_commands v0.02 ships 35 word folders; 'visual' is not listed here —
# confirm this is intentional (clips from unlisted folders have no class index).
labels = ['house', 'tree', 'dog','nine', 'sheila', 'four','seven','backward','wow','stop','eight','on',
 'down','bed','zero','off','six','one','five','two','marvin','forward','up','right','three','cat', 'learn',
 'bird','yes','no','left', 'follow', 'go', 'happy']

# Import the Dataset

Download the dataset with `wget` from http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz and extract it into the directory configured as `data_dir` above.

The TensorFlow Speech Commands dataset is a set of one-second .wav audio files, each containing a single spoken English word. These words come from a small set of commands and are spoken by a variety of different speakers. Version 0.02 of the dataset (the archive linked above) contains 35 words: 20 core command words, plus 15 auxiliary words that can be used to test whether an algorithm correctly ignores speech that does not contain a trigger word. (The original version 0.01 had 30 words: 20 core and 10 auxiliary.) A collection of background-noise audio files is included alongside the word recordings. The dataset was originally designed for limited-vocabulary speech recognition tasks. The audio clips were collected by Google and recorded by volunteers in uncontrolled locations around the world.

In [None]:
# Show the dataset README. The context manager closes the file handle, which the
# previous one-liner `open(...).readlines()` left open.
with data_dir.joinpath('README.md').open('r') as readme:
    readme_lines = readme.readlines()
readme_lines

In [None]:
# Collect the .wav clips stored in one of the `labels` class folders. Other entries under
# data_dir (e.g. the `_background_noise_` folder and its README, or word folders missing
# from `labels`) are skipped. Their folder name has no index in `labels`, which the one-hot
# encoding needs. The noise recordings are also presumably longer than the one second
# that get_spectrogram pads to.
filenames = [
    f for f in tf.io.gfile.glob(str(data_dir) + '/*/*.wav')
    if pathlib.Path(f).parent.name in labels
]
filenames = tf.random.shuffle(filenames)  # deterministic thanks to the global TF seed
num_samples = len(filenames)
print('Number of total examples:', num_samples)
# True average over all classes (previously this printed the size of labels[0] only).
print('Number of examples per label (average):', num_samples / len(labels))
print('Example file tensor:', filenames[0])

## Play random samples

In [None]:
def decode_audio(audio_binary):
    """Decode the raw bytes of a 16-bit PCM WAV file into a 1-D float32 waveform.

    tf.audio.decode_wav scales the samples to [-1.0, 1.0].

    Args:
        audio_binary: scalar string tensor holding the file contents
            (e.g. from tf.io.read_file).

    Returns:
        1-D float32 tensor of samples. The file's sample rate is discarded.
    """
    # desired_channels=1 guarantees a trailing channel axis of size 1, so the squeeze
    # below cannot fail on a multi-channel file (only the first channel is kept).
    audio, _ = tf.audio.decode_wav(audio_binary, desired_channels=1)
    return tf.squeeze(audio, axis=-1)

In [None]:
def play_random(folder, n=10, rate=16000):
    """Display up to `n` randomly chosen clips from `data_dir/folder` as audio widgets.

    Clips are drawn without replacement, so no file is shown twice. If the folder holds
    fewer than `n` files, all of them are shown. An empty folder displays nothing.

    Args:
        folder: name of a sub-directory of `data_dir`, e.g. a label such as 'zero'.
        n: maximum number of clips to display.
        rate: sample rate handed to the IPython audio widget (16 kHz by default).
    """
    files = os.listdir(data_dir.joinpath(folder))
    if not files:
        return
    chosen = np.random.choice(len(files), size=min(n, len(files)), replace=False)
    for i in chosen:
        audio_binary = tf.io.read_file(data_dir.joinpath(folder, files[i]).as_posix())  # binary
        waveform = decode_audio(audio_binary)  # decoded
        display.display(display.Audio(waveform, rate=rate))

In [None]:
# Listen to a few random recordings of the word "zero".
play_random('zero')

## Load dataset

In [None]:
def get_label(file_path):
    """Return a clip's class label, i.e. the name of the directory containing it.

    The path is split on the OS path separator and the second-to-last component
    is returned as a tf.string tensor (e.g. '.../zero/abc.wav' -> b'zero').
    """
    return tf.strings.split(file_path, os.path.sep)[-2]

def get_waveform_and_label(file_path):
    """Read one clip from disk and return `(waveform, label)`.

    The waveform comes from decode_audio applied to the file's bytes, and the
    label is the clip's parent-directory name as produced by get_label.
    """
    label = get_label(file_path)
    return decode_audio(tf.io.read_file(file_path)), label

def import_dataset(file_paths):
    """Load every clip in `file_paths` and return parallel arrays of waveforms and labels.

    Progress is shown with a tqdm bar. The i-th waveform and i-th label both come
    from `file_paths[i]`.

    NOTE(review): clips shorter than one second make the waveform list ragged. How
    np.array handles that depends on the NumPy version (object array vs. ValueError),
    so confirm against the installed NumPy.
    """
    # Local names are deliberately distinct from the module-level `labels` vocabulary.
    audio_list, label_list = [], []
    for path in tqdm(file_paths):
        wav, lbl = get_waveform_and_label(path)
        audio_list.append(wav)
        label_list.append(lbl)
    return np.array(audio_list), np.array(label_list)

In [None]:
# Decode every clip listed in the (shuffled) `filenames` into waveform and label arrays.
X, y = import_dataset(filenames)

In [None]:
# Sanity check: one waveform entry and one label per file.
print(X.shape, y.shape)

## Visualize the dataset

In [None]:
# Plot the first rows*cols waveforms in a grid, titled with their decoded label.
rows, cols = 3, 3
n = rows * cols
fig, axes = plt.subplots(rows, cols, figsize=(10, 12))
# axes.flat walks the grid row by row, matching the (i // cols, i % cols) layout.
for ax, audio, label in zip(axes.flat, X[:n], y[:n]):
    ax.plot(audio.numpy())
    ax.set_title(label.numpy().decode())

plt.show()

# Prepare the Dataset

## Extract spectrograms

You'll convert the waveform into a spectrogram, which shows frequency changes over time and can be represented as a 2D image. This can be done by applying the short-time Fourier transform (STFT) to convert the audio into the time-frequency domain.

A Fourier transform (tf.signal.fft) converts a signal to its component frequencies, but loses all time information. The STFT (tf.signal.stft) splits the signal into windows of time and runs a Fourier transform on each window, preserving some time information, and returning a 2D tensor that you can run standard convolutions on.

STFT produces an array of complex numbers representing magnitude and phase. However, you'll only need the magnitude for this tutorial, which can be derived by applying tf.abs on the output of tf.signal.stft.

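The window length and hop length determine the time and frequency resolution of the spectrogram "image". This notebook computes a log-mel spectrogram (with `mel_features.log_mel_spectrogram` inside `get_spectrogram`) instead of calling `tf.signal.stft` directly. It uses 25 ms windows, a 10 ms hop, and 32 mel bins between 60 Hz and 3800 Hz.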

You also want the waveforms to have the same length, so that when you convert them to spectrogram images, the results have the same dimensions. This is done by zero-padding the audio clips that are shorter than one second.

In [None]:
def get_spectrogram(waveform, sample_rate=16000):
    """Convert a waveform into a scaled log-mel spectrogram of exactly one second.

    The clip is truncated or zero-padded to `sample_rate` samples, so every
    spectrogram has the same number of frames. The log-mel features use 25 ms
    windows, a 10 ms hop and 32 mel bins spanning 60-3800 Hz.

    Args:
        waveform: 1-D tensor/array of audio samples.
        sample_rate: sampling rate in Hz. It is also the target length in samples
            (one second of audio).

    Returns:
        The array produced by mel_features.log_mel_spectrogram, shifted by
        -log(1e-3) and scaled by 30.
    """
    waveform = tf.cast(waveform, tf.float32)
    # Truncate clips longer than one second: the padding size below would otherwise be
    # negative and tf.zeros would raise.
    waveform = waveform[:sample_rate]

    # Zero-pad the end so that all audio clips have the same length.
    zero_padding = tf.zeros([sample_rate] - tf.shape(waveform), dtype=tf.float32)
    equal_length = tf.concat([waveform, zero_padding], 0)
    equal_length = equal_length.numpy().flatten()

    log_mel = mel_features.log_mel_spectrogram(
        equal_length,
        sample_rate,
        log_offset=0.001,
        window_length_secs=0.025,
        hop_length_secs=0.010,
        num_mel_bins=32,
        lower_edge_hertz=60,
        upper_edge_hertz=3800)
    # Subtract log(1e-3) (the same value as log_offset) and scale by 30. Presumably
    # log_mel_spectrogram returns log(mel + log_offset), so silence maps to 0 — confirm.
    spectrogram = 30 * (log_mel - np.log(1e-3))

    return spectrogram

In [None]:
def get_spectrogram_df(X, y):
    """Compute the spectrogram of every waveform in `X`, with a trailing channel axis.

    Each spectrogram from get_spectrogram gets an extra size-1 last axis
    (tf.expand_dims(..., -1)), and the results are stacked with np.array.
    `y` is returned unchanged.
    """
    specs = [tf.expand_dims(get_spectrogram(wave), -1) for wave in tqdm(X)]
    return np.array(specs), y

In [None]:
# Compute a spectrogram for every waveform. The .copy() calls give new top-level arrays.
# NOTE(review): ndarray.copy() is shallow for object-dtype arrays, so any tensor
# elements are still shared with X / y.
X_pre, y_pre = get_spectrogram_df(X.copy(), y.copy())

### Visualize spectrograms

In [None]:
# Show the first rows*cols spectrograms in a grid, titled with their decoded label.
rows = 3
cols = 3
n = rows*cols
fig, axes = plt.subplots(rows, cols, figsize=(15, 8))
# Slice with n (not a hard-coded 9) so the grid size stays the single source of truth.
for i, (spec, label) in enumerate(zip(X_pre[:n], y_pre[:n])):
    ax = axes[i // cols][i % cols]
    # Drop the trailing channel axis and transpose. This is the same image as the former
    # spec.T[0, ..., None]. Assuming spec is (time, mel, 1), time runs along the x-axis.
    ax.imshow(spec[..., 0].T)
    ax.set_title(label.numpy().decode())

plt.show()

## One-hot encoding

In [None]:
# Map each label to its position in `labels`, then one-hot encode.
# A dict lookup is O(1) per sample, whereas list.index scanned the whole vocabulary
# with a tensor comparison per entry. An unknown label raises a KeyError naming it.
# y_pre elements are scalar string tensors, decoded the same way as in the plots above.
label_to_index = {name: idx for idx, name in enumerate(labels)}
y_enc = np.array([label_to_index[l.numpy().decode()] for l in y_pre])
y_enc = tf.one_hot(y_enc, len(labels))

In [None]:
# Expected: (number of samples, len(labels)).
print(y_enc.shape)

## Split the dataset

In [None]:
def train_test_split(X, y, test_size):
    """Split `X` and `y` into train and test parts without shuffling.

    The first ``int(X.shape[0] * test_size)`` samples form the test set and the
    remaining samples form the training set. Shuffle beforehand if the order is
    meaningful. For NumPy inputs the returned slices are views sharing memory
    with `X` / `y`.

    Returns:
        (X_train, X_test, y_train, y_test)
    """
    cut = int(X.shape[0] * test_size)
    return X[cut:], X[:cut], y[cut:], y[:cut]

In [None]:
# Hold out the first 10% of the (already shuffled) samples as the test set.
X_train, X_test, y_train, y_test = train_test_split(X_pre, y_enc, 0.1)

In [None]:
# Sanity check of the split sizes.
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

## Standardize data

In [None]:
def standardize(X, y):
    """Standardize `X` to zero mean and unit variance along axis 1, in place.

    Statistics are computed separately for every position except axis 1. Assuming X is
    (batch, time, mel, 1), each mel bin of each sample is normalized over time — confirm
    against the actual layout. Slices with zero standard deviation (e.g. a fully silent
    clip) are only mean-centred. Dividing them by zero would fill them with NaN.

    NOTE: `X` is modified in place and returned. Any array sharing memory with it, such
    as a NumPy view produced by slicing, changes as well. `X` must therefore be a
    floating-point NumPy array.

    Returns:
        (X, y): the same `X` object, standardized, and `y` unchanged.
    """
    X -= np.mean(X, axis=1, keepdims=True)
    std = np.std(X, axis=1, keepdims=True)
    std[std == 0] = 1.0  # avoid 0/0 -> NaN for constant slices
    X /= std
    return X, y

In [None]:
# Standardize both splits. standardize() works in place, so X_train_norm is X_train and
# X_test_norm is X_test. Both are slices (views) of X_pre, which is modified as well.
X_train_norm, y_train_norm = standardize(X_train, y_train)
X_test_norm, y_test_norm = standardize(X_test, y_test)

print(X_train_norm.shape, y_train_norm.shape)
print(X_test_norm.shape, y_test_norm.shape)

# Build KWT Model

In [None]:
from utils.transformer import TransformerEncoder, PatchClassEmbedding
from utils.tools import CustomSchedule

In [None]:
# model configurations
# NOTE(review): these appear to match the paper's KWT-1 setting (dim 64, 1 head,
# MLP dim 256, 12 layers) — confirm.
d_model = 64  # token/embedding width (output of the per-frame Dense projection)
d_ff = d_model * 4  # feed-forward width passed to TransformerEncoder
n_heads = 1  # attention heads passed to TransformerEncoder
mlp_head_size = 256  # width of the hidden Dense layer in the classification head
dropout = 0.1
activation = tf.nn.gelu
n_layers = 12  # number of encoder layers passed to TransformerEncoder

In [None]:
def build_kwt(transformer, input_size):
    """Assemble the Keyword Transformer as a Keras functional model.

    Args:
        transformer: callable encoder stack applied to the embedded token sequence.
        input_size: shape of one input spectrogram without the batch axis. Its
            first entry is handed to PatchClassEmbedding as the sequence length.

    Returns:
        tf.keras.Model mapping a spectrogram to softmax probabilities over `labels`.
        It reads the module-level d_model, mlp_head_size and labels.
    """
    spec_in = tf.keras.layers.Input(shape=input_size)

    # Each spectrogram row is treated as a patch and linearly projected to d_model.
    tokens = tf.keras.layers.Dense(d_model)(spec_in)

    # Add position embeddings and the extra learnable class token.
    tokens = PatchClassEmbedding(d_model, input_size[0])(tokens)

    encoded = transformer(tokens)

    # Classification uses only the class token at sequence position 0.
    cls_token = tf.keras.layers.Lambda(lambda t: t[:, 0, :])(encoded)

    # MLP head ending in a softmax over the label vocabulary.
    hidden = tf.keras.layers.Dense(mlp_head_size)(cls_token)
    probs = tf.keras.layers.Dense(len(labels), activation='softmax')(hidden)

    return tf.keras.models.Model(inputs=spec_in, outputs=probs)


In [None]:
# Build the encoder stack and the full model. The model input is one spectrogram
# without its trailing channel axis, i.e. X_train_norm.shape[1:-1].
transformer = TransformerEncoder(d_model, n_heads, d_ff, dropout, activation, n_layers)
model = build_kwt(transformer, input_size=X_train_norm.shape[1:-1])

In [None]:
# Print layer shapes and parameter counts.
model.summary()

# Train the Model

In [None]:
# set some variables
batch_size = 64  # NOTE: only takes effect if passed to model.fit (Keras defaults to 32)
n_epochs = 50

In [None]:
# Adam driven by the learning-rate schedule from utils.tools.CustomSchedule
# (presumably the Transformer warm-up schedule parameterized by d_model — confirm).
lr = CustomSchedule(d_model, warmup_steps=20000.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
# from_logits=False because the model's last Dense layer already applies softmax.
# label_smoothing=0.1 softens the one-hot targets.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0.1),
    metrics=['accuracy'])

In [None]:
# Train the model. batch_size must be passed explicitly; otherwise Model.fit silently
# uses its default of 32. NOTE: there is no separate validation split, so the
# held-out test set doubles as validation data here.
history = model.fit(
    X_train_norm, y_train_norm,
    validation_data=(X_test_norm, y_test_norm),
    batch_size=batch_size,
    epochs=n_epochs, initial_epoch=0)

In [None]:
# Plot training and validation loss per epoch. Each curve gets its own call with an
# explicit x-axis and legend label. The old single call passed val_loss as a bare third
# positional argument, which matplotlib plots against its index.
metrics = history.history
plt.plot(history.epoch, metrics['loss'], label='loss')
plt.plot(history.epoch, metrics['val_loss'], label='val_loss')
plt.xlabel('epoch')
plt.legend()
plt.show()

# Test the model

In [None]:
# Evaluate on the standardized test split, i.e. the same preprocessing used for training.
# Passing X_test only worked because standardize() happens to modify its input in place.
model.evaluate(X_test_norm, y_test_norm)

# Convert to TensorFlow-Lite

In [None]:
# Create a TFLite converter for the trained Keras model. No converter.optimizations are
# set here, so no quantization is requested (hence the "fp32" file name below).
converter = tf.lite.TFLiteConverter.from_keras_model(model)

In [None]:
# start conversion: convert() returns the serialized TFLite flatbuffer as bytes
tflite_model = converter.convert()

In [None]:
# save model
# Create the output directory first; write_bytes raises FileNotFoundError if bin/ is missing.
tflite_model_file = pathlib.Path('bin/model_fp32_kwt.tflite')
tflite_model_file.parent.mkdir(parents=True, exist_ok=True)
tflite_model_file.write_bytes(tflite_model)