# Transfer Learning in the Audio Domain with Model Maker on Golf Shot Classification 

In this notebook, I use TensorFlow Lite Model Maker for the audio domain to train a golf shot sound classifier.

It is part of the [Codelab to Customize an Audio model and deploy on Android](https://codelabs.developers.google.com/codelabs/tflite-audio-classification-custom-model-android).

I use a custom golf shot dataset and export three models: a TFLite model for on-device inference on a phone, a TensorFlow.js model for inference in the browser, and a SavedModel for serving.


In [11]:
# Prerequisites
# Model Maker for the audio domain needs TensorFlow 2.5 or later; this notebook pins 2.6.
# Note: tflite-model-maker depends on scann, which is only published for Linux, so this install fails on Windows.
! pip install tflite-model-maker tensorflow==2.6
# Reinstall pycocotools from source so that it is compiled against the NumPy version already installed.
# Upgrading NumPy would often solve the binary-incompatibility error instead, but that is not always viable:
# for example, tensorflow==2.5 isn't compatible with the newest NumPy (it requires numpy~=1.19.2).
# Building from source needs a C compiler (on Windows, the Microsoft Visual C++ build tools).
! pip uninstall pycocotools --yes
! pip install pycocotools --no-binary pycocotools
# Restart the kernel manually after the installs finish.

# Note: to make sure you're using a tested and supported configuration, change either the TensorFlow version or the TensorFlow Addons version.
# You can find the compatibility matrix in the TensorFlow Addons README:
# https://github.com/tensorflow/addons

Collecting tflite-model-maker
  Using cached tflite_model_maker-0.4.0-py3-none-any.whl (642 kB)
Collecting tensorflow==2.6
  Using cached tensorflow-2.6.0-cp38-cp38-win_amd64.whl (423.2 MB)
Collecting absl-py>=0.10.0
  Using cached absl_py-1.0.0-py3-none-any.whl (126 kB)
Collecting tensorflow-hub<0.13,>=0.7.0; python_version >= "3"
  Using cached tensorflow_hub-0.12.0-py2.py3-none-any.whl (108 kB)
Collecting sentencepiece>=0.1.91
  Using cached sentencepiece-0.1.96-cp38-cp38-win_amd64.whl (1.1 MB)
Collecting numba==0.53
  Using cached numba-0.53.0-cp38-cp38-win_amd64.whl (2.3 MB)
Collecting tensorflow-datasets>=2.1.0
  Using cached tensorflow_datasets-4.5.2-py3-none-any.whl (4.2 MB)
Collecting flatbuffers==1.12
  Using cached flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting neural-structured-learning>=1.3.1
  Using cached neural_structured_learning-1.3.1-py2.py3-none-any.whl (120 kB)
Collecting tensorflowjs>=2.4.0
  Using cached tensorflowjs-3.15.0-py3-none-any.whl (77 kB)
Coll

ERROR: Could not find a version that satisfies the requirement scann>=1.2.6 (from tflite-model-maker) (from versions: none)
ERROR: No matching distribution found for scann>=1.2.6 (from tflite-model-maker)


Collecting pycocotools
  Using cached pycocotools-2.0.4.tar.gz (106 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing wheel metadata: finished with status 'done'
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (PEP 517): started
  Building wheel for pycocotools (PEP 517): finished with status 'error'
Failed to build pycocotools


  ERROR: Command errored out with exit status 1:
   command: 'C:\ProgramData\Anaconda3\python.exe' 'C:\ProgramData\Anaconda3\lib\site-packages\pip\_vendor\pep517\_in_process.py' build_wheel 'C:\Users\MAXIME~1.CAR\AppData\Local\Temp\tmpiz_hvk86'
       cwd: C:\Users\maxime.carpentier\AppData\Local\Temp\pip-install-ctzuksb_\pycocotools
  Complete output (14 lines):
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build\lib.win-amd64-cpython-38
  creating build\lib.win-amd64-cpython-38\pycocotools
  copying pycocotools\coco.py -> build\lib.win-amd64-cpython-38\pycocotools
  copying pycocotools\cocoeval.py -> build\lib.win-amd64-cpython-38\pycocotools
  copying pycocotools\mask.py -> build\lib.win-amd64-cpython-38\pycocotools
  copying pycocotools\__init__.py -> build\lib.win-amd64-cpython-38\pycocotools
  running build_ext
  skipping 'pycocotools\_mask.c' Cython extension (up-to-date)
  building 'pycocotools._mask' extension
  error: Microsoft Visual C+

In [12]:
import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random

from IPython.display import Audio, Image, display
from scipy.io import wavfile

print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")

ModuleNotFoundError: No module named 'tensorflow'

In [6]:
# Download the golf shot dataset and extract it into ./dataset/
golf_dataset_folder = tf.keras.utils.get_file('golf_dataset.zip',
                                              'https://www.dropbox.com/s/4qedjfyzghxfzds/golf_dataset.zip?dl=1',
                                              cache_dir='./',
                                              cache_subdir='dataset',
                                              extract=True)

NameError: name 'tf' is not defined

In [7]:
# @title [Run this] Util functions and data structures.

data_dir = './dataset/golf_dataset/small_golf_dataset'

# Class folders may differ only in case (e.g. 'driver' vs 'Driver'),
# so class codes are lowercased before lookup and the tables only need lowercase keys.
shot_code_to_name = {
  'driver': 'Driver strike',
  'inthehole': 'The ball is in the hole',
  'iron': 'Iron shot',
  'putter': 'Putt',
  'wood': 'Wood shot',
}

shot_images = {
  'driver': 'https://golfworkoutprogram.com/wp-content/uploads/2018/05/golf-driving-tips-943x600.jpg',
  'inthehole': 'https://cdn.golfmagic.com/field/image/hfd.jpg',
  'iron': 'https://golfswingremedy.com/wp-content/uploads/2020/03/How-to-Swing-a-Golf-Iron-golfswingremedy.com_-1024x683-1.jpeg',
  'putter': 'https://www.golfibiza.com/wp-content/uploads/putt-golf-1.jpg',
  'wood': 'https://golf.com/wp-content/uploads/2021/10/GettyImages-1132705407rr.jpg',
}

# Glob pattern for every test clip: <data_dir>/test/<class>/<clip>.wav
test_files = os.path.join(data_dir, 'test', '*', '*.wav')

print(test_files)


def get_random_audio_file():
  """Return the path of a randomly chosen test clip."""
  test_list = glob.glob(test_files)
  if not test_list:
    raise FileNotFoundError(f'No .wav files match {test_files}')
  random_audio_path = random.choice(test_list)
  print(random_audio_path)
  return random_audio_path


def show_bird_data(audio_path):
  """Show the label, a picture, the waveform and an audio player for one clip."""
  sample_rate, audio_data = wavfile.read(audio_path)
  # The class code is the clip's parent folder name; os.path handles the
  # platform's native separator (backslash on Windows).
  shot_code = os.path.basename(os.path.dirname(audio_path)).lower()
  shot_name = shot_code_to_name[shot_code]
  print('======= show_bird_data =======')
  print(f'Shot name: {shot_name}')
  print(f'Shot code: {shot_code}')
  display(Image(shot_images[shot_code]))

  plt.title(f'{shot_name} ({shot_code})')
  plt.plot(audio_data)
  plt.show()
  # IPython's Audio expects multichannel data as (channels, samples); .T is a no-op for mono.
  display(Audio(audio_data.T, rate=sample_rate))

print('functions and data structures created')

NameError: name 'os' is not defined

### Playing some audio

To get a better understanding of the data, let's listen to a random audio file from the test split.

Note: later in this notebook, you'll run inference on this audio for testing.

In [8]:
random_audio = get_random_audio_file()
show_bird_data(random_audio)

NameError: name 'get_random_audio_file' is not defined

## Choosing the model spec

When using Model Maker for audio, you start with a model spec: the base model from which your new model extracts information to learn the new classes. The spec also determines how the dataset is transformed to match the model's parameters, such as sample rate and number of channels.

[YAMNet](https://tfhub.dev/google/yamnet/1) is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.

Its input is expected to be 16 kHz, single-channel (mono) audio.

You don't need to do any resampling yourself. Model Maker takes care of that for you.
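To make "16 kHz, one channel" concrete, here is a minimal pure-Python sketch of the two conversions involved: averaging channels into mono and resampling with linear interpolation. This is an illustration under stated assumptions only: linear interpolation is a simple stand-in chosen for clarity, not the resampler Model Maker uses internally, and `to_mono` and `resample_linear` are hypothetical helper names.

```python
from typing import List, Sequence

TARGET_RATE = 16_000  # YAMNet expects 16 kHz input


def to_mono(frames: Sequence[Sequence[float]]) -> List[float]:
    """Average each multi-channel frame, e.g. (left, right), into one sample."""
    return [sum(frame) / len(frame) for frame in frames]


def resample_linear(samples: Sequence[float], src_rate: int,
                    dst_rate: int = TARGET_RATE) -> List[float]:
    """Resample by linear interpolation between neighboring input samples.

    Assumption: a simple illustrative resampler, not Model Maker's own.
    """
    if src_rate == dst_rate or not samples:
        return list(samples)
    n_out = int(len(samples) * dst_rate / src_rate)
    out = []
    for i in range(n_out):
        pos = i * src_rate / dst_rate        # fractional position in the input
        left = int(pos)
        right = min(left + 1, len(samples) - 1)
        frac = pos - left
        out.append(samples[left] * (1 - frac) + samples[right] * frac)
    return out


# One second of hypothetical 44.1 kHz stereo audio becomes 16000 mono samples.
stereo = [(0.5, -0.5)] * 44_100
mono_16k = resample_linear(to_mono(stereo), src_rate=44_100)
print(len(mono_16k))
```

Production resamplers also low-pass filter before downsampling to avoid aliasing; the sketch skips that to keep the idea visible.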

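- `frame_length` decides how long each training sample is. In this case, it is 6 × `EXPECTED_WAVEFORM_LENGTH` samples. `EXPECTED_WAVEFORM_LENGTH` is YAMNet's 15600-sample input window (0.975 s at 16 kHz), so each sample lasts about 5.85 s.

- `frame_step` decides how far apart the training samples are. In this case, the i-th sample starts 3 × `EXPECTED_WAVEFORM_LENGTH` samples (about 2.9 s) after the (i-1)-th sample, so consecutive samples overlap by half.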
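The following sketch shows how these two values slice a clip into training samples. It assumes `EXPECTED_WAVEFORM_LENGTH` is 15600 samples and that no padding is applied, so a trailing partial frame is dropped; both are assumptions about Model Maker's behavior, and `frame_bounds` is a hypothetical helper, not part of the library.

```python
from typing import List, Tuple

SAMPLE_RATE = 16_000  # YAMNet's expected input rate (Hz)
# Assumption: EXPECTED_WAVEFORM_LENGTH is 15600 samples (0.975 s at 16 kHz).
EXPECTED_WAVEFORM_LENGTH = 15_600

FRAME_LENGTH = 6 * EXPECTED_WAVEFORM_LENGTH  # 93600 samples = 5.85 s
FRAME_STEP = 3 * EXPECTED_WAVEFORM_LENGTH    # 46800 samples = 2.925 s


def frame_bounds(num_samples: int,
                 frame_length: int = FRAME_LENGTH,
                 frame_step: int = FRAME_STEP) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices of every full frame in a clip.

    Assumption: no padding, so a trailing partial frame is dropped.
    """
    bounds = []
    start = 0
    while start + frame_length <= num_samples:
        bounds.append((start, start + frame_length))
        start += frame_step
    return bounds


# A hypothetical 20-second clip at 16 kHz.
for start, end in frame_bounds(20 * SAMPLE_RATE):
    print(f'{start / SAMPLE_RATE:6.3f}s -> {end / SAMPLE_RATE:6.3f}s')
```

Under these assumptions, a 20-second clip yields five overlapping training samples, and any clip shorter than about 5.85 s yields none. Halving `frame_length` would produce more, shorter samples from the same clip.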

Setting these values works around some limitations of real-world datasets.

For example, in the bird dataset from the original codelab, birds don't sing all the time: they sing, rest, and sing again, with noise in between. A long frame helps capture the singing, but setting it too long reduces the number of training samples.


In [9]:
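spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,  # keep YAMNet's original head alongside the new custom head
    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,    # hop between training samples
    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)  # length of each training sample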

NameError: name 'audio_classifier' is not defined

## Loading the data

Model Maker has an API to load data from a folder and convert it into the format expected by the model spec.

The train and test splits come from the `train` and `test` folders. The validation set is created from 20% of the train split.
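Because labels come from folder names, it can help to check what each split contains before loading it. The sketch below counts `.wav` files per class subfolder and estimates the sizes of an 80/20 split. `count_clips_per_class` and `split_sizes` are hypothetical helpers, and the rounding in `split_sizes` is an assumption, so treat its result as approximate rather than as what `DataLoader.split` returns exactly.

```python
import os
from collections import Counter


def count_clips_per_class(split_dir: str) -> Counter:
    """Count .wav files in each class subfolder, e.g. train/driver/*.wav."""
    counts = Counter()
    for label in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, label)
        if os.path.isdir(class_dir):  # skip stray files at the split level
            counts[label] = sum(
                1 for name in os.listdir(class_dir)
                if name.lower().endswith('.wav'))
    return counts


def split_sizes(total: int, fraction: float = 0.8) -> tuple:
    """Approximate sizes of the two parts of a fraction / (1 - fraction) split."""
    first = int(total * fraction)  # assumption: rounds down
    return first, total - first
```

In this notebook, `count_clips_per_class(os.path.join(data_dir, 'train'))` would list the labels Model Maker is likely to see. It would also expose case mismatches such as `driver` vs `Driver`, which the notebook's label tables account for and which would likely be treated as separate classes.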

Note: `cache=True` makes later training faster, but it also requires more RAM to hold the data. For the codelab's birds dataset that is not a problem since it's only 300 MB, but with your own data you have to pay attention to it.
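To judge whether your own data fits in memory, a back-of-the-envelope estimate is enough. This sketch assumes the cache holds 16 kHz mono audio as float32 (4 bytes per sample); the `overlap_factor` parameter covers the further assumption that the cache stores overlapping frames rather than whole clips. `cache_size_mb` is a hypothetical helper, not a Model Maker function.

```python
SAMPLE_RATE = 16_000   # YAMNet input rate (Hz)
BYTES_PER_SAMPLE = 4   # assumption: cached audio is stored as float32


def cache_size_mb(total_seconds: float,
                  overlap_factor: float = 1.0,
                  sample_rate: int = SAMPLE_RATE,
                  bytes_per_sample: int = BYTES_PER_SAMPLE) -> float:
    """Rough RAM (in MB, 10**6 bytes) needed to cache mono audio in memory.

    overlap_factor: use frame_length / frame_step (2.0 in this notebook) if
    the cache holds overlapping frames instead of whole clips (an assumption).
    """
    return total_seconds * overlap_factor * sample_rate * bytes_per_sample / 1e6


# One hour of audio at 16 kHz, float32, mono.
print(f'{cache_size_mb(3600):.1f} MB')                      # whole clips
print(f'{cache_size_mb(3600, overlap_factor=2.0):.1f} MB')  # half-overlapping frames
```

Under these assumptions, one hour of audio needs roughly 230 MB, or about twice that if every sample appears in two overlapping frames.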


In [10]:
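# Each subfolder of train/ and test/ is one class; cache=True keeps the loaded data in memory.
train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
# Hold out 20% of the training data for validation.
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)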

NameError: name 'audio_classifier' is not defined

## Training the model

The `audio_classifier` module has the [`create`](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/audio_classifier/create) method, which creates a model and immediately starts training it.

You can customize many parameters; see the documentation for details.

On this first try, you'll use the default configuration and train for 100 epochs.

Note: The first epoch takes longer than the others because that's when the cache is created. After that, each epoch takes close to 1 second.
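Why is only the first epoch slow? A cached input pipeline pays the loading and preprocessing cost once, on the first full pass, and replays the stored results afterwards. The class below is a hypothetical pure-Python analogy of that behavior, not Model Maker or `tf.data` code; `decode` stands in for whatever expensive loading step a real pipeline performs.

```python
from typing import Callable, Iterable, Iterator, List, Optional


class CachedDataset:
    """Decode items on the first complete pass, then replay them from memory."""

    def __init__(self, paths: Iterable[str], decode: Callable[[str], list]):
        self._paths = list(paths)
        self._decode = decode
        self._cache: Optional[List[list]] = None

    def __iter__(self) -> Iterator[list]:
        if self._cache is not None:
            yield from self._cache          # later epochs: memory only, fast
            return
        items = []
        for path in self._paths:
            item = self._decode(path)       # first epoch: expensive decode
            items.append(item)
            yield item
        self._cache = items                 # stored only after a complete pass


# Count how often the (hypothetical) decoder runs over three epochs.
decode_calls = []


def fake_decode(path: str) -> list:
    decode_calls.append(path)
    return [len(path)]


dataset = CachedDataset(['a.wav', 'bb.wav'], fake_decode)
for epoch in range(3):
    batch = list(dataset)
print(len(decode_calls))
```

The decoder runs only during the first epoch, which mirrors the timing described in the note: the cost is front-loaded, and the trade-off is the memory the cache occupies.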