# Training a wake word detection pipeline from scratch

Wake word detection is a technology used in voice recognition systems, such as virtual assistants and smart speakers, to trigger the system to start listening when a specific word or phrase is spoken.

In this tutorial, you will learn how to train and evaluate a wake word pipeline from scratch.

## Work with predefined wake words
In this configuration, audio recordings containing the **wake word have already been collected**.

In particular, we will use a dataset containing two sets of audio samples:
* positive samples: recordings of the wake word "alexa"
* negative samples consisting of **audio where the wake word/phrase is not present**

### Data preparation
1. We download a repository containing audio samples of the wake word
2. We download a recording from the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) that will be used as the *source* of negative samples

In [1]:
# Fetch the data: the Picovoice wake-word benchmark repository (which holds the
# positive "alexa" clips) and one AMI meeting recording used as the source of
# negative samples.
# NOTE: `&> /dev/null` also hides errors; remove it if a download seems to fail.
!git clone https://github.com/Picovoice/wake-word-benchmark.git &> /dev/null
!wget https://groups.inf.ed.ac.uk/ami/AMICorpusMirror/amicorpus/TS3003a/audio/TS3003a.Mix-Headset.wav &> /dev/null

^C


In [2]:
import os
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import lightning.pytorch as pl

import torchaudio as ta
import torchaudio.transforms as T

from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader
from IPython.display import Audio

RuntimeError: operator torchvision::nms does not exist

In [3]:
def plot_waveform(waveform, sr, title="Waveform", ax=None):
    """Plot the first channel of a waveform against time in seconds.

    Args:
        waveform (torch.Tensor): audio indexed as (channel, frame); only
            channel 0 is drawn.
        sr (int): sample rate in Hz, used to convert frame indices to seconds.
        title (str): title of the plot.
        ax (matplotlib.axes.Axes, optional): axes to draw on. When omitted, a
            new figure with a single axes is created.

    Returns:
        The axes the waveform was drawn on.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be plotted
    waveform = waveform.detach().cpu().numpy()

    _, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames).numpy() / sr

    if ax is None:
        # Only channel 0 is plotted, so a single axes is enough:
        # plt.subplots(num_channels, 1) would return an *array* of axes for
        # multi-channel audio, which has no .plot method.
        _, ax = plt.subplots(1, 1)
    ax.plot(time_axis, waveform[0], linewidth=1)
    ax.grid(True)
    ax.set_xlim([0, float(time_axis[-1])])
    ax.set_title(title)
    return ax

def plot_mfcc(fbank, title=None, ax=None):
    """Display an MFCC (or any 2-D feature map) as an image.

    Args:
        fbank: 2-D array-like indexed as (mfcc bin, time frame), e.g. a
            torch.Tensor or a numpy array.
        title (str, optional): plot title; defaults to "MFCC".
        ax (matplotlib.axes.Axes, optional): axes to draw on. When omitted,
            a new figure with a single axes is created.

    Returns:
        The axes the image was drawn on.
    """
    if isinstance(fbank, torch.Tensor):
        # matplotlib cannot convert tensors that require grad or live on GPU
        fbank = fbank.detach().cpu().numpy()
    if ax is None:
        _, ax = plt.subplots(1, 1)
    ax.set_title(title or "MFCC")
    ax.imshow(fbank, aspect="auto")
    ax.set_ylabel("mfcc bin")
    ax.set_xlabel("time frame")
    return ax

### PyTorch dataset handling
1.   We define our dataset for loading and processing 3-second audio samples
2.   We implement the training dataloader



In [4]:
def is_valid_file(filename, extensions=(".wav", ".flac")):
  """Check whether a file has one of the allowed extensions.

  The filename is lower-cased before comparison, so the check is
  case-insensitive with respect to the filename.

  Args:
      filename (string): path to a file
      extensions (string or iterable of strings): extensions to consider
          (lowercase). A single string is accepted as well.

  Returns:
      bool: True if the filename ends with one of given extensions
  """
  # str.endswith accepts a str or a tuple (not a list). The default is an
  # immutable tuple to avoid the shared mutable-default-argument pitfall.
  if not isinstance(extensions, str):
    extensions = tuple(extensions)
  return filename.lower().endswith(extensions)

def list_directory(target_dir):
  """Recursively collect the audio files found under ``target_dir``.

  Directories (following symlinks) are visited in sorted path order and the
  files of each directory are taken in sorted name order, so the result is
  deterministic.

  Returns:
      list[str]: joined paths of the files accepted by ``is_valid_file``.
  """
  found = []
  for dirpath, _subdirs, filenames in sorted(os.walk(target_dir, followlinks=True)):
    candidates = (os.path.join(dirpath, name) for name in sorted(filenames))
    found.extend(p for p in candidates if is_valid_file(p))
  return found


class WakeWordDataset(Dataset):
  """Balanced wake-word dataset of fixed-length audio crops.

  Positive items are every audio file found under ``pos_path`` (label 1).
  The same number of negative items (label 0) is added, all pointing at the
  single recording ``neg_path``; each access takes a crop of it. Items are
  shuffled once, at construction time.

  Args:
      pos_path (str): directory searched recursively for positive clips.
      neg_path (str): path of the audio file used as source of negatives.
      max_length (int): crop length in seconds.
      sr (int): target sample rate; every file is resampled to it.
      evalmode (bool): if True, return 5 evenly spaced crops per item instead
          of one random crop.
  """

  def __init__(self, pos_path, neg_path, max_length=3, sr=16000, evalmode=False):
    pos = [[f, 1] for f in list_directory(pos_path)]
    # One independent [path, label] entry per negative; ``[[x]] * n`` would
    # alias the same inner list n times.
    neg = [[neg_path, 0] for _ in range(len(pos))]
    self.data = pos + neg
    random.shuffle(self.data)

    self.max_length = max_length * sr  # crop length in samples
    self.sr = sr
    self.evalmode = evalmode

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    """Return ``(crops, label)`` for item ``idx``.

    ``crops`` is a stack of single-channel crops of ``max_length`` samples:
    one random crop in training mode, 5 evenly spaced crops in eval mode
    (or the whole clip when evalmode is set and max_length == 0).
    """
    fn, label = self.data[idx]
    # NOTE(review): the negative source is a long meeting recording that is
    # fully decoded and resampled on every access; consider caching it.
    audio, file_sr = ta.load(fn)
    # Downmix to mono so every crop has exactly one channel; otherwise the
    # training loop's .view(batch, -1) would concatenate channels in time.
    audio = audio.mean(dim=0, keepdim=True)
    audio = ta.functional.resample(audio, file_sr, self.sr)

    # Handle too short audio samples: pad (repeating the last sample) to
    # max_length + 1 samples so torch.randint below has a valid range.
    audiosize = audio.size(1)
    if audiosize <= self.max_length:
      shortage = self.max_length - audiosize + 1
      audio = F.pad(audio, (0, shortage), mode='replicate')
      audiosize = audio.size(1)

    if self.evalmode:
      startframe = torch.linspace(0, audiosize - self.max_length, 5, dtype=int)
    else:
      startframe = torch.randint(audiosize - self.max_length, (1,))

    audios = []
    if self.evalmode and self.max_length == 0:
      audios.append(audio)
    else:
      for asf in startframe:
        audios.append(audio[:, asf:(asf + self.max_length)])

    audios = torch.stack(audios)
    return audios, label

NameError: name 'Dataset' is not defined

In [None]:
dataset = WakeWordDataset("./wake-word-benchmark/audio/alexa",
                          "./TS3003a.Mix-Headset.wav")
# The dataset is shuffled at construction, so a fixed index (e.g. 500) is not
# guaranteed to hold a given class: look up the first index of each label.
pos_idx = next(i for i, (_, lbl) in enumerate(dataset.data) if lbl == 1)
neg_idx = next(i for i, (_, lbl) in enumerate(dataset.data) if lbl == 0)
audios_pos, label = dataset[pos_idx]
audios_neg, label = dataset[neg_idx]

Let's listen to the audio of a positive sample from the dataset...

In [None]:
# Play the first crop of the positive item (the dataset resamples to 16 kHz).
Audio(audios_pos[0].numpy(), rate=16000)

... and the audio of a negative sample from the dataset

In [None]:
# Play the first crop stored in ``audios_neg`` (expected to be a negative
# item; check its label, since the dataset order is shuffled).
Audio(audios_neg[0].numpy(), rate=16000)

In [None]:
# shuffle=True: the dataset is shuffled only once at construction, so without
# it every epoch would visit the items in exactly the same order.
# pin_memory only speeds up host-to-GPU copies, so enable it only with CUDA.
dataloader = DataLoader(dataset, shuffle=True, batch_size=64,
                        num_workers=2, pin_memory=torch.cuda.is_available())

## Model definition

We design a wake word detector using a Convolutional Neural Network (CNN) that takes audio samples represented by Mel-frequency cepstral coefficients (MFCC).

Specifically, we adapt the ResNet18 architecture (randomly initialized: no ImageNet-pretrained weights are loaded, and its first convolution is replaced anyway) to encode a 40-dimensional MFCC and output one score (logit) for each of the two classes wake_word/generic_content.


Let's visualize the MFCC representation for a 3-second audio sample

In [None]:
# Same MFCC settings as the model below: 40 coefficients from 40 mel bands,
# 25 ms windows (n_fft=400) and a 10 ms hop (hop_length=160) at 16 kHz.
mfcc = T.MFCC(sample_rate=16000, n_mfcc=40, melkwargs={"n_fft": 400,
                                                       "hop_length": 160,
                                                       "n_mels": 40})
# audios_pos[0] is one (channel, time) crop, so the MFCC is (channel, n_mfcc, frames)
mfcc_sample = mfcc(audios_pos[0])
plot_waveform(audios_pos[0], 16000)
plot_mfcc(mfcc_sample[0])
print("Shape of the MFCC representation ", mfcc_sample.shape)

Let's take a look at the ResNet18 architecture to understand which parts need to be changed

In [None]:
# Build a plain ResNet18 (no pretrained weights requested) just to display its layers.
resnet18()

Now is the right time to define our model for wake word detection!

In [None]:
class WakeWordModel(nn.Module):
  """ResNet18-based classifier operating on MFCC features.

  Raw waveforms are converted to MFCCs, instance-normalized per coefficient
  over time, and fed to a ResNet18 whose stem accepts a single-channel
  (n_mfcc x frames) image. The output is one logit per class.

  Args:
      sample_rate (int): sample rate of the input waveforms.
      n_mfcc (int): number of MFCC coefficients (also used as n_mels).
      num_classes (int): size of the output layer.
  """

  def __init__(self, sample_rate=16000, n_mfcc=40, num_classes=2):
    super(WakeWordModel, self).__init__()

    self.mfcc = T.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc,
                       melkwargs={"n_fft": 400,
                                  "hop_length": 160,
                                  "n_mels": n_mfcc})
    self.instancenorm = nn.InstanceNorm1d(n_mfcc)
    self.resnet = resnet18()
    # Single input channel; stride 1 along the MFCC axis keeps its resolution.
    self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 2),
                                  padding=3, bias=False)
    self.resnet.maxpool = nn.Identity()
    # Pool over time only; the remaining MFCC rows are flattened into fc.
    self.resnet.avgpool = nn.AdaptiveAvgPool2d((None, 1))
    # layer2..layer4 each map the MFCC axis h -> ceil(h / 2) (3x3 conv,
    # stride 2, padding 1), so n_mfcc rows shrink to ceil(n_mfcc / 8):
    # 40 -> 5, which matches the previously hard-coded 5 * 512.
    freq_rows = -(-n_mfcc // 8)
    self.resnet.fc = nn.Linear(freq_rows * 512, num_classes)

  def forward(self, x):
    """Return class logits for the waveforms ``x``.

    ``x`` is expected to be (batch, time), e.g. (1, time) for a single crop,
    so that the MFCC is (batch, n_mfcc, frames).
    """
    # The MFCC transform and the non-affine instance norm have no trainable
    # parameters, so no autograd graph is built for them. (A constant offset
    # before InstanceNorm would be removed by its mean subtraction anyway.)
    with torch.no_grad():
      x = self.mfcc(x)
      x = self.instancenorm(x).unsqueeze(1)
    return self.resnet(x)

In [None]:
# Sanity check: one forward pass on a single positive crop; for a
# single-channel crop this should give logits of shape (1, 2).
model = WakeWordModel()
model(audios_pos[0])

## Training
Now that the data and the model are ready, let's train it with a plain PyTorch training loop!

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 2

for e in range(num_epochs):
    # switch to train mode
    model.train()

    for i, (audio, target) in enumerate(dataloader):
        # move data to the same device as model; batches arrive as
        # (batch, crops, channels, time) and are flattened to (batch, -1)
        audio = audio.to(device, non_blocking=True).view(audio.size(0), -1)
        target = target.to(device, non_blocking=True)

        # compute output and loss
        output = model(audio)
        loss = F.cross_entropy(output, target)

        # compute gradient and take an Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure batch accuracy and report progress (1-based counters)
        acc = (output.argmax(1) == target).float().mean()
        print("Epoch [%d/%d], iter [%d/%d], loss %.4f, acc %.2f" % (
            e + 1, num_epochs, i + 1, len(dataloader), loss.item(), acc.item()))

## Inference

Once trained, the model can be applied to an audio sample...

In [None]:
model.eval()

sample, label = dataset[101]
# A single forward pass without autograd; sample[0] is the item's first crop.
with torch.no_grad():
    prediction = model(sample[0].to(device)).argmax().item()
print(prediction, label)

...or to our own recording

In [None]:
# all imports
from io import BytesIO
from base64 import b64decode
from google.colab import output
from IPython.display import Javascript

# JavaScript injected into the Colab output frame. It defines a global
# `record(ms)` that records from the microphone for `ms` milliseconds and
# resolves with the recording encoded as a base64 data URL. The microphone
# tracks are stopped once recording ends so the mic is released.
RECORD = """
const sleep  = time => new Promise(resolve => setTimeout(resolve, time))
const b2text = blob => new Promise(resolve => {
  const reader = new FileReader()
  reader.onloadend = e => resolve(e.srcElement.result)
  reader.readAsDataURL(blob)
})
var record = time => new Promise(async resolve => {
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
  const recorder = new MediaRecorder(stream)
  const chunks = []
  recorder.ondataavailable = e => chunks.push(e.data)
  recorder.start()
  await sleep(time)
  recorder.onstop = async ()=>{
    stream.getTracks().forEach(track => track.stop())
    const blob = new Blob(chunks)
    const text = await b2text(blob)
    resolve(text)
  }
  recorder.stop()
})
"""

def record(sec=3):
  """Record audio from the browser microphone through Colab's JS bridge.

  Injects the RECORD script, runs ``record()`` for one second longer than
  ``sec``, and decodes the base64 data URL it resolves with.

  Args:
      sec (int): requested recording length in seconds.

  Returns:
      bytes: the recorded file exactly as produced by the browser.
  """
  for line in ("", "Speak Now..."):
    print(line)
  display(Javascript(RECORD))
  duration = sec + 1
  data_url = output.eval_js('record(%d)' % (duration*1000))
  print("Done Recording !")
  payload = data_url.split(',')[1]
  return b64decode(payload)  # byte stream

In [None]:
# Record ~3 s from the browser microphone (the helper adds one extra second).
audio = record(3)

In [None]:
# Play back the raw recorded bytes.
Audio(audio)

In [None]:
import numpy as np


def preprocess_audio(audio, max_length=3*16000):
  """Cut ``audio`` into 5 evenly spaced crops of ``max_length`` samples.

  Uses the same padding and eval-mode cropping as WakeWordDataset: a clip
  not longer than ``max_length`` is right-padded (repeating its last sample)
  to ``max_length + 1`` samples first. The crops are concatenated along
  dim 0, so a (1, time) input yields a (5, max_length) batch.
  """
  total = audio.size(1)
  if total <= max_length:
    audio = F.pad(audio, (0, max_length - total + 1), mode='replicate')
    total = audio.size(1)

  starts = torch.linspace(0, total - max_length, 5, dtype=int)
  return torch.cat([audio[:, s:(s + max_length)] for s in starts])

# The browser's MediaRecorder does not emit raw PCM: the bytes are a
# compressed container (e.g. WebM/Opus in Chrome). Reinterpreting them with
# np.frombuffer(..., dtype=np.int8) yields noise, so decode them instead.
waveform, rec_sr = ta.load(BytesIO(audio))
# Match the training crops: a single channel resampled to 16 kHz.
waveform = ta.functional.resample(waveform.mean(dim=0, keepdim=True), rec_sr, 16000)
pt_audio = preprocess_audio(waveform)
with torch.no_grad():
  logits = model(pt_audio.to(device))
# One row of logits per crop; class 1 is the wake-word label used in training.
print(logits.argmax(1), logits.softmax(1).mean(0))

: 

## Improve previous model
*   Add data augmentation to input samples: background noise, reverberation (e.g. by convolving with room impulse responses from the RIR dataset), and so on.
*   Tune training parameters
*   Improve model architecture

## How to train a model on a custom wake word?