# Import Libraries

In [None]:
import os
import pathlib
import glob

import librosa.util
import librosa

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf

from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import layers
from tensorflow.keras import models
from IPython import display
import tensorflow_datasets as tfds
import datetime, os
from wandb.keras import WandbCallback
import tensorflow_addons as tfa


# Seed both NumPy's and TensorFlow's global random generators with the same
# value so that repeated runs of the notebook draw the same random numbers.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
import wandb

# Insert your wandb login key to track metrics
# wandb.login(key='Your login key')


# 35-word Speech Commands dataset

In [None]:
def get_train_val_test_split(root: str, val_file: str, test_file: str):
    """Creates train, val, and test split according to provided val and test files.

    Every sub-directory of ``root`` whose name does not start with ``_`` is
    treated as one class; class indices are assigned in sorted directory-name
    order. All ``*.wav`` files inside those directories that are listed in
    neither ``val_file`` nor ``test_file`` form the training split.

    Args:
        root (str): Path to base directory of the dataset.
        val_file (str): Path to file containing list of validation data files
            (one path per line, relative to ``root``).
        test_file (str): Path to file containing list of test data files
            (one path per line, relative to ``root``).

    Returns:
        train_list (list): List of paths to training data items.
        val_list (list): List of paths to validation data items.
        test_list (list): List of paths to test data items.
        train_label (list): List of train labels
        val_label (list): List of val labels
        test_label (list): List of test labels

    Raises:
        AssertionError: If a file is listed in both the val and test files.
        KeyError: If a listed path's parent directory is not a known label.
    """

    ####################
    # Labels
    ####################

    label_list = [label for label in sorted(os.listdir(root))
                  if os.path.isdir(os.path.join(root, label)) and not label.startswith("_")]
    label_to_idx = {label: idx for idx, label in enumerate(label_list)}

    def _read_split(list_file: str) -> set:
        """Return the entries of ``list_file`` joined onto ``root``."""
        with open(list_file, "r") as f:
            # Blank lines (or a completely empty file) are skipped: joining an
            # empty entry onto ``root`` would produce a bogus path that has no
            # label directory.
            return {os.path.join(root, line.strip()) for line in f if line.strip()}

    def _label_of(path: str) -> int:
        """Map a file path to the index of its parent (label) directory."""
        # os.path handles the platform's separator, unlike splitting on "/".
        return label_to_idx[os.path.basename(os.path.dirname(path))]

    ###################
    # Split
    ###################

    all_files_set = set()
    for label in label_list:
        all_files_set.update(glob.glob(os.path.join(root, label, "*.wav")))

    val_files_set = _read_split(val_file)
    test_files_set = _read_split(test_file)

    assert len(val_files_set.intersection(
        test_files_set)) == 0, "Sanity check: No files should be common between val and test."

    all_files_set -= val_files_set
    all_files_set -= test_files_set

    # NOTE: the lists keep the (arbitrary) set iteration order on purpose; they
    # are not sorted, so samples are not grouped by class.
    train_list, val_list, test_list = list(all_files_set), list(val_files_set), list(test_files_set)

    print(f"Number of training samples: {len(train_list)}")
    print(f"Number of validation samples: {len(val_list)}")
    print(f"Number of test samples: {len(test_list)}")

    train_label = [_label_of(path) for path in train_list]
    val_label = [_label_of(path) for path in val_list]
    test_label = [_label_of(path) for path in test_list]

    return train_list, val_list, test_list, train_label, val_label, test_label

In [None]:
def time_shift(wav: np.ndarray, sr: int, s_min: float, s_max: float) -> np.ndarray:
    """Time shift augmentation.

    Draws a random shift (in samples) uniformly from ``[sr * s_min, sr * s_max)``
    and moves the waveform by that amount, filling the vacated samples with
    low-amplitude uniform noise in ``[-0.001, 0.001)``. A positive shift drops
    samples from the start (noise is appended); a negative shift drops samples
    from the end (noise is prepended). The output always has the same number
    of samples as ``wav``.

    Refer to https://www.kaggle.com/haqishen/augmentation-methods-for-audio#1.-Time-shifting.
    Changed np.r_ to np.hstack for numba support.

    Args:
        wav (np.ndarray): Waveform array of shape (n_samples,).
        sr (int): Sampling rate.
        s_min (float): Minimum fraction of a second by which to shift.
        s_max (float): Maximum fraction of a second by which to shift.

    Returns:
        wav_time_shift (np.ndarray): Time-shifted waveform array.
    """

    start = int(np.random.uniform(sr * s_min, sr * s_max))

    # Clamp the shift to the waveform length; otherwise a shift larger than
    # the signal would return |start| noise samples, i.e. a longer array.
    n_samples = len(wav)
    start = max(-n_samples, min(n_samples, start))

    noise = np.random.uniform(-0.001, 0.001, abs(start))
    if start >= 0:
        return np.hstack((wav[start:], noise))
    return np.hstack((noise, wav[:start]))


def resample(x: np.ndarray, sr: int, r_min: float, r_max: float) -> tuple:
    """Resamples waveform to a randomly scaled sampling rate.

    The target rate is ``sr`` multiplied by a factor drawn uniformly from
    ``[r_min, r_max)``.

    Args:
        x (np.ndarray): Input waveform, array of shape (n_samples, ).
        sr (int): Sampling rate.
        r_min (float): Minimum resampling factor (multiplier applied to ``sr``).
        r_max (float): Maximum resampling factor (multiplier applied to ``sr``).

    Returns:
        x (np.ndarray): Resampled waveform.
        sr_new (float): Sampling rate the waveform was resampled to.
    """

    sr_new = sr * np.random.uniform(r_min, r_max)
    # The rates are passed by keyword: recent librosa releases make
    # ``orig_sr``/``target_sr`` keyword-only, and older ones accept the same names.
    x = librosa.resample(x, orig_sr=sr, target_sr=sr_new)
    return x, sr_new



def spec_augment(mel_spec: np.ndarray, n_time_masks: int, time_mask_width: int, n_freq_masks: int, freq_mask_width: int):
    """Numpy implementation of spectral augmentation.

    Zeroes out random bands along the time axis (axis 1) and then along the
    frequency axis (axis 0). The input array is modified IN PLACE and the same
    object is returned.

    Args:
        mel_spec (np.ndarray): Mel spectrogram, array of shape (n_mels, T).
        n_time_masks (int): Number of time bands.
        time_mask_width (int): Exclusive upper bound on the width of each time
            band (widths are drawn from ``[0, time_mask_width)``). A value
            <= 0 disables time masking.
        n_freq_masks (int): Number of frequency bands.
        freq_mask_width (int): Exclusive upper bound on the width of each
            frequency band. A value <= 0 disables frequency masking.

    Returns:
        mel_spec (np.ndarray): Spectrogram with random time bands and freq bands masked out.
    """

    n_freq, n_frames = mel_spec.shape[0], mel_spec.shape[1]

    # np.random.randint(0, 0) raises, so a non-positive width means "no mask".
    if time_mask_width > 0:
        for _ in range(n_time_masks):
            width = np.random.randint(0, time_mask_width)
            # max(1, ...) keeps randint valid when the drawn width is at least
            # as large as the axis; the band then starts at 0.
            begin = np.random.randint(0, max(1, n_frames - width))
            mel_spec[:, begin: begin + width] = 0.0

    if freq_mask_width > 0:
        for _ in range(n_freq_masks):
            width = np.random.randint(0, freq_mask_width)
            begin = np.random.randint(0, max(1, n_freq - width))
            mel_spec[begin: begin + width, :] = 0.0

    return mel_spec

In [None]:
def sample_generator(data_list: list, label_list: list, augment: bool):
    """
    Generator function to create samples.

    For every (audio file, label) pair the audio is loaded at 16 kHz, padded or
    trimmed to exactly one second, converted to 40 MFCCs computed from a
    dB-scaled mel spectrogram, optionally masked with ``spec_augment`` and
    given a trailing channel axis.

    :param data_list: Paths of the audio files (``str`` or ``bytes``)
    :param label_list: Integer class index for each entry of ``data_list``
    :param augment: Whether to apply spectral masking to the features
    :return: Yields ``(features, one_hot_label)`` tuples; the label is one-hot
        encoded with depth 35
    """

    sr = 16000

    def transform(x, augment=True):
        # ``size`` is passed by keyword: it is keyword-only in recent librosa.
        x = librosa.util.fix_length(x, size=sr)
        # Pass the real sampling rate; without it librosa builds the mel filter
        # bank for its 22050 Hz default rather than for the 16 kHz audio.
        x = librosa.feature.melspectrogram(y=x, sr=sr, n_fft=480, win_length=480, hop_length=160, center=False)
        x = librosa.feature.mfcc(S=librosa.power_to_db(x), n_mfcc=40)
        if augment:
            x = spec_augment(mel_spec=x, n_time_masks=2, time_mask_width=25, n_freq_masks=2, freq_mask_width=7)
        x = tf.expand_dims(x, axis=-1)
        return x

    for audio_file, label in zip(data_list, label_list):
        # tf.data.Dataset.from_generator hands string arguments over as bytes.
        if isinstance(audio_file, bytes):
            audio_file = audio_file.decode("utf-8")

        x = librosa.load(audio_file, sr=sr)[0]
        x = transform(x, augment)
        label = tf.convert_to_tensor(label, dtype=tf.int32)
        label = tf.one_hot(label, depth=35)
        yield x, label


# Model training

## Model architecture

In [None]:
import tensorflow as tf


class Patches(tf.keras.layers.Layer):
    """
    Split a batch of images into non-overlapping patches and flatten each patch.
    """
    def __init__(self, patch_size_w, patch_size_h):
        """
        :param patch_size_w: Patch extent along axis 1 of the image batch
        :param patch_size_h: Patch extent along axis 2 of the image batch
        """
        super().__init__()
        self.w, self.h = patch_size_w, patch_size_h

    def call(self, images):
        """
        :param images: 4-D image batch; axis 0 is the batch axis
        :return: Tensor of shape (batch, number of patches, values per patch)
        """
        # Window and stride are identical, so neighbouring patches do not overlap.
        window = [1, self.w, self.h, 1]
        extracted = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding='SAME',
        )

        # The static last dimension holds the flattened content of one patch;
        # the batch size is read dynamically so it may be unknown at trace time.
        values_per_patch = extracted.shape[-1]
        n_images = tf.shape(images)[0]

        return tf.reshape(extracted, (n_images, -1, values_per_patch))


class Mixer(tf.keras.layers.Layer):
    """
    One MLP-Mixer block: a token-mixing MLP followed by a channel-mixing MLP,
    each wrapped in a residual connection. Both MLPs are bias-free.
    """
    def __init__(self, S, C, DS, DC):
        """
        Mixer layer for the MLP mixer
        :param S: Number of patches
        :param C: Hidden dimension for projection
        :param DS: tunable token mixing hidden width
        :param DC: tunable channel mixing hidden width
        """
        super().__init__()
        # NOTE(review): this single normalization layer is applied before both
        # the token-mixing and the channel-mixing MLP, so they share its weights.
        self.layer_norm = tf.keras.layers.LayerNormalization()
        self.S, self.C = S, C
        self.DS, self.DC = DS, DC

        initializer = tf.random_normal_initializer()

        def new_weight(rows, cols, name):
            # Trainable float32 matrix drawn from the shared initializer.
            return tf.Variable(
                initial_value=initializer(shape=(rows, cols), dtype="float32"),
                trainable=True,
                name=name,
            )

        # Token mixing: S -> DS -> S, channel mixing: C -> DC -> C.
        self.W1 = new_weight(S, DS, 'W1')
        self.W2 = new_weight(DS, S, 'W2')
        self.W3 = new_weight(C, DC, 'W3')
        self.W4 = new_weight(DC, C, 'W4')

    def call(self, X):
        """
        Call function of mixer layer
        :param X: Input; the matmuls below require shape (batch, S, C)
        :return: Tensor with the same shape as X
        """
        # Token mixing works on the transposed input so that W1/W2 act along
        # the patch axis; the result is transposed back before the skip add.
        normed_t = tf.transpose(self.layer_norm(X), perm=(0, 2, 1))
        token_hidden = tf.nn.gelu(tf.matmul(normed_t, self.W1))
        token_out = tf.matmul(token_hidden, self.W2)
        U = X + tf.transpose(token_out, perm=(0, 2, 1))

        # Channel mixing acts along the last axis, so no transpose is needed.
        channel_hidden = tf.nn.gelu(tf.matmul(self.layer_norm(U), self.W3))
        return U + tf.matmul(channel_hidden, self.W4)


class MLPMixer(tf.keras.models.Model):
    """
    MLP-Mixer classifier: patch extraction, linear projection, a stack of
    Mixer blocks, then global average pooling, dropout and a softmax layer.
    The number of patches is derived for inputs of spatial size 40 x 98.
    """
    def __init__(self, patch_size_w, patch_size_h,  C, DS, DC, num_of_mixer_blocks,  num_classes):
        """
        Creates the Mixer model
        :param patch_size_w: Patch extent along axis 1 of the input (size 40)
        :param patch_size_h: Patch extent along axis 2 of the input (size 98)
        :param C: dimension of projection layer
        :param DS: tunable token mixing hidden width
        :param DC: tunable channel mixing hidden width
        :param num_of_mixer_blocks: number of mixer layers
        :param num_classes: number of classes
        """
        super(MLPMixer, self).__init__()
        self.projection = tf.keras.layers.Dense(C)
        # Patches are extracted with 'SAME' padding, which yields
        # ceil(size / patch) patches per axis. Ceiling division (written as
        # -(-a // b)) keeps S correct for patch sizes that do not divide the
        # input evenly and is identical to the plain quotient when they do.
        patches_axis_1 = -(-40 // patch_size_w)
        patches_axis_2 = -(-98 // patch_size_h)
        self.S = patches_axis_1 * patches_axis_2
        self.mixer = [Mixer(self.S, C, DS, DC,) for _ in range(num_of_mixer_blocks)]
        self.C = C
        self.DS = DS
        self.DC = DC
        self.num_classes = num_classes
        self.patch_w = patch_size_w
        self.patch_h = patch_size_h

        self.classification_layer = tf.keras.models.Sequential([
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(self.num_classes, activation='softmax')
        ])

        # Built once here instead of on every forward pass; the layer has no
        # weights, so it adds nothing to the model's variables.
        self.patcher = Patches(self.patch_w, self.patch_h)

    def call(self, images, training=None):
        """
        Call function for MLPMixer model
        :param images: input image batch
        :param training: optional flag forwarded to the classification head so
            that its Dropout layer is only active in training mode; ``None``
            leaves the decision to Keras
        :return: class probabilities of shape (batch, num_classes)
        """
        X = self.patcher(images)

        X = self.projection(X)

        for block in self.mixer:
            X = block(X)

        out = self.classification_layer(X, training=training)
        return out

## Learning rate scheduler 


In [None]:
class CosineScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Cosine-decay learning-rate schedule with an optional linear warm-up.

    For ``step < warmup_steps`` the rate rises linearly from
    ``warmup_learning_rate`` to ``learning_rate_base``; afterwards it follows
    half a cosine period down to zero at ``total_steps`` and is 0 beyond that.
    Every evaluation also logs the resulting value to wandb under ``"lr"``.
    """

    def __init__(self,
                learning_rate_base,
                total_steps,
                warmup_learning_rate=0.0,
                warmup_steps=0):
        """
        :param learning_rate_base: Peak learning rate reached after warm-up
        :param total_steps: Step at which the cosine decay reaches zero
            (must differ from ``warmup_steps``)
        :param warmup_learning_rate: Learning rate at step 0 of the warm-up
        :param warmup_steps: Number of linear warm-up steps (0 disables warm-up)
        """
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step`` and log it to wandb."""
        step_f = tf.cast(step, tf.float32)
        learning_rate = 0.5 * self.learning_rate_base * (1 + tf.cos(
            np.pi *
            (step_f - self.warmup_steps) / float(self.total_steps - self.warmup_steps)))
        if self.warmup_steps > 0:
            slope = (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
            warmup_rate = slope * step_f + self.warmup_learning_rate
            # During warm-up the linear ramp replaces the cosine value.
            learning_rate = tf.where(step < self.warmup_steps, warmup_rate, learning_rate)
        lr = tf.where(step > self.total_steps, 0.0, learning_rate, name='learning_rate')
        wandb.log({"lr": lr})
        return lr

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "learning_rate_base": self.learning_rate_base,
            "total_steps": self.total_steps,
            "warmup_learning_rate": self.warmup_learning_rate,
            "warmup_steps": self.warmup_steps,
        }



## Model training function for a single batch

In [None]:
# Train the model

def model_train(features, labels, model, loss_func,optimizer,train_acc,train_loss):
    """
    Run one optimization step on a single batch and update the running metrics.

    :param features: Batch of model inputs
    :param labels: Targets for the batch, in the format expected by ``loss_func``
    :param model: Model to train
    :param loss_func: Callable ``loss_func(labels, predictions)`` returning the loss
    :param optimizer: Optimizer used to apply the gradients
    :param train_acc: Accuracy metric, updated with this batch
    :param train_loss: Mean-loss metric, updated with this batch
    :return: None; the running mean loss is logged to wandb as ``train_loss``
    """
    # Define the GradientTape context
    with tf.GradientTape() as tape:
        # Get the probabilities. training=True is required in a custom loop so
        # that training-only layers such as Dropout are actually active.
        predictions = model(features, training=True)
        # Calculate the loss
        loss = loss_func(labels, predictions)
    # Get the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the loss and accuracy
    train_loss(loss)
    train_acc(labels, predictions)
    # Log the running mean accumulated by the metric, not the raw batch loss.
    running_loss = train_loss.result()
    wandb.log({"train_loss": running_loss.numpy()})

## Model validation function for a single batch

In [None]:

def model_validate(features, labels,model,loss_func,valid_loss,val_acc):
    """
    Evaluate the model on a single validation batch and update the running metrics.

    :param features: Batch of model inputs
    :param labels: Targets for the batch, in the format expected by ``loss_func``
    :param model: Model to evaluate
    :param loss_func: Callable ``loss_func(labels, predictions)`` returning the loss
    :param valid_loss: Mean-loss metric, updated with this batch
    :param val_acc: Accuracy metric, updated with this batch
    :return: None; the running mean loss is logged to wandb as ``val_loss``
    """
    # Inference mode: training-only layers such as Dropout stay disabled.
    predictions = model(features, training=False)
    v_loss = loss_func(labels, predictions)

    valid_loss(v_loss)
    val_acc(labels, predictions)
    # Only the running loss is logged here; the metric parameter is no longer
    # shadowed by a local of the same name.
    running_val_loss = valid_loss.result()
    wandb.log({
               "val_loss": running_val_loss.numpy()
    })

## Training function

In [None]:
def train(train_dataset,
          val_dataset,
          model,
          optimizer,
          loss_func,
          train_loss,
          train_acc,
          valid_loss,
          valid_acc,
          epochs,
          checkpoint_path="training_1/cp.ckpt"
         ):
    """
    Train ``model`` for ``epochs`` epochs and keep the best weights on disk.

    After each epoch the weights are saved to ``checkpoint_path`` whenever the
    validation accuracy exceeds the best value seen so far. Per-epoch
    accuracies and the final best validation accuracy are logged to wandb.

    :param train_dataset: Iterable of (features, labels) training batches
    :param val_dataset: Iterable of (features, labels) validation batches
    :param model: Model to train
    :param optimizer: Optimizer used for the weight updates
    :param loss_func: Callable ``loss_func(labels, predictions)``
    :param train_loss: Metric accumulating the training loss
    :param train_acc: Metric accumulating the training accuracy
    :param valid_loss: Metric accumulating the validation loss
    :param valid_acc: Metric accumulating the validation accuracy
    :param epochs: Number of passes over ``train_dataset``
    :param checkpoint_path: Where the best weights are written
    :return: None
    """
    max_val_acc = 0.0

    for epoch in range(epochs):
        # Run the model through the train and validation sets respectively
        for features, labels in train_dataset:
            model_train(features, labels, model, loss_func, optimizer, train_acc, train_loss)

        for val_features, val_labels in val_dataset:
            model_validate(val_features, val_labels, model, loss_func, valid_loss, valid_acc)

        # Grab the results as plain Python floats so that comparison, string
        # formatting and logging do not operate on tensors.
        loss, acc = float(train_loss.result()), float(train_acc.result())
        val_loss, val_acc = float(valid_loss.result()), float(valid_acc.result())
        if val_acc > max_val_acc:
            max_val_acc = val_acc
            model.save_weights(checkpoint_path)

        # Clear the current state of the metrics so every epoch starts fresh
        train_loss.reset_states()
        train_acc.reset_states()
        valid_loss.reset_states()
        valid_acc.reset_states()

        # Local logging
        template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
        print(template.format(epoch + 1, loss, acc, val_loss, val_acc))
        wandb.log({"train_accuracy": acc,
                   "val_accuracy": val_acc})
    wandb.log({"best_val_acc": max_val_acc})


## Initiate training

In [None]:

# Default hyper-parameters; wandb exposes them (possibly overridden) as wandb.config.
config_defaults= {
    "C": 256,
    "DS" :128,
    "DC" : 1024,
    "num_of_mixer_blocks" : 8,
    'learning_rate' : 0.0003
}
wandb.init(config=config_defaults,project="MLP-mixer-audio")
train_list, val_list, test_list, train_label, val_label, test_label = get_train_val_test_split(
    '../input/google-speech-v2',
    '../input/google-speech-v2/validation_list.txt',
    '../input/google-speech-v2/testing_list.txt')

# Each dataset streams (features, one-hot label) pairs from sample_generator;
# only the training split is augmented.
train_dataset = tf.data.Dataset.from_generator(
        sample_generator,
        args=(train_list, train_label,True),
        output_types=(tf.float32, tf.float32),
        output_shapes=((40, 98, 1), (35,))
    )
val_dataset = tf.data.Dataset.from_generator(
        sample_generator,
        args=(val_list, val_label,False),
        output_types=(tf.float32, tf.float32),
        output_shapes=((40, 98, 1), (35,))
    )
test_dataset = tf.data.Dataset.from_generator(
        sample_generator,
        args=(test_list, test_label,False),
        output_types=(tf.float32, tf.float32),
        output_shapes=((40, 98, 1), (35,))
    )
AUTOTUNE = tf.data.AUTOTUNE
# Cache BEFORE shuffling and batching: caching afterwards would freeze the
# shuffled batches of the first epoch and replay them unchanged every epoch.
# NOTE(review): the augmentation runs inside the generator, i.e. upstream of
# the cache, so the masks drawn in the first epoch are still reused afterwards.
train_dataset = train_dataset.cache().shuffle(1024).batch(64).prefetch(AUTOTUNE)
# The validation set is not shuffled; the order does not affect the mean
# metrics computed over the whole set.
val_dataset = val_dataset.batch(64).cache().prefetch(AUTOTUNE)


model = MLPMixer(40, 1, C=wandb.config.C, DS=wandb.config.DS, DC=wandb.config.DC, num_of_mixer_blocks=wandb.config.num_of_mixer_blocks, num_classes=35)
learning_rate = CosineScheduler(learning_rate_base=wandb.config.learning_rate,
                            total_steps=23000,
                            warmup_learning_rate=0.0,
                            warmup_steps=1660)
loss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Running means of the per-batch losses within an epoch
train_loss = tf.keras.metrics.Mean(name="train_loss")
valid_loss = tf.keras.metrics.Mean(name="test_loss")

# Specify the performance metric
train_acc = tf.keras.metrics.CategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.CategoricalAccuracy(name="valid_acc")

train(train_dataset,
     val_dataset,
      model,
      optimizer,
      loss_func,
      train_loss,
      train_acc,
      valid_loss,
      valid_acc,
     epochs = 50)
# Materialize the (unbatched) test set as two NumPy arrays.
_test_pairs = [(audio.numpy(), label.numpy()) for audio, label in test_dataset]
test_audio = np.array([pair[0] for pair in _test_pairs])
test_labels = np.array([pair[1] for pair in _test_pairs])

# Accuracy of the model as it stands after the last training epoch.
y_true = np.argmax(test_labels, axis=1)
y_pred = np.argmax(model.predict(test_audio), axis=1)

test_acc = sum(y_pred == y_true) / len(y_true)
print(f'Test set accuracy: {test_acc:.0%}')
wandb.log({'test_acc': test_acc})

# Accuracy of the checkpoint written during training, restored into a freshly
# constructed model with the same hyper-parameters.
tf.keras.backend.clear_session()
checkpoint_path = "training_1/cp.ckpt"
new_model = MLPMixer(
    40,
    1,
    C=wandb.config.C,
    DS=wandb.config.DS,
    DC=wandb.config.DC,
    num_of_mixer_blocks=wandb.config.num_of_mixer_blocks,
    num_classes=35,
)
new_model.load_weights(checkpoint_path)

y_pred = np.argmax(new_model.predict(test_audio), axis=1)
y_true = np.argmax(test_labels, axis=1)

test_acc = sum(y_pred == y_true) / len(y_true)
print(f'Test set accuracy: {test_acc:.0%}')
wandb.log({'best_test_acc': test_acc})