In this demonstration, we'll use a speech recognition dataset created from the Free Spoken Digit Dataset.

The original dataset contains recordings of spoken digits.

We concatenate random recordings to make the audio and the target text longer.
We treat the text as a sequence of characters and train our transformer model to predict
them, with the log-Mel spectrograms of the audio as the input.

All the data preparation code lives in `speech_recognition_data.py`, so here
we just need to import the dataset objects.

In [1]:
from speech_recognition_data import get_data

ds, val_ds = get_data()

Each batch of our dataset is a dictionary containing two keys:

- 'src' - contains a tensor of shape (batch_size, max_src_len, num_feats)
  where batch_size = 4, max_src_len = 59 (2 seconds of audio) and num_feats = 80
- 'trg' - contains a tensor of shape (batch_size, max_trg_len) where max_trg_len = 13
  (13 characters). The characters have been replaced by their index in a vocabulary.
  
The target character vocabulary is as follows:
- 0 -> pad token ("-")
- 1,2,...,26 -> "a","b",...,"z"
- 27 -> start token ("<")
- 28 -> end token (">")
- 29 -> space token (" ")
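
As a minimal sketch of how a transcript maps onto this vocabulary, the snippet below wraps a string in the start and end tokens, looks up each character's index, and right-pads with the pad index. The helper name `encode` and its `max_len` argument are illustrative assumptions, not part of `speech_recognition_data.py`.

```python
# Index -> character table, following the vocabulary listed above.
idx_to_char = ["-"] + [chr(i + 96) for i in range(1, 27)] + ["<", ">", " "]
char_to_idx = {ch: i for i, ch in enumerate(idx_to_char)}


def encode(text, max_len=13):
    """Hypothetical helper: add start/end tokens, map to indices, pad with 0."""
    ids = [char_to_idx[ch] for ch in "<" + text + ">"]
    if len(ids) > max_len:
        raise ValueError("text does not fit in max_len tokens")
    return ids + [0] * (max_len - len(ids))


print(encode("four two"))
# [27, 6, 15, 21, 18, 29, 20, 23, 15, 28, 0, 0, 0]
```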

To run this on your own dataset, you need to create a `tf.data.Dataset` object which generates batches
in the same format.
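
One possible way to build such a dataset is sketched below with `tf.data.Dataset.from_generator` and `padded_batch`. The generator, the random features standing in for real log-Mel spectrograms, the `float32` feature dtype, and the names `example_generator` and `my_ds` are all assumptions made for illustration; only the batch layout follows the description above.

```python
import numpy as np
import tensorflow as tf

NUM_FEATS = 80    # feature dimension, as in the batch description above
MAX_SRC_LEN = 59  # frames per example after padding
MAX_TRG_LEN = 13  # target tokens per example after padding


def example_generator():
    # Hypothetical stand-in data: random features instead of real log-Mel
    # spectrograms, and hand-encoded targets for "<one>" and "<two six>".
    rng = np.random.default_rng(0)
    targets = [[27, 15, 14, 5, 28], [27, 20, 23, 15, 29, 19, 9, 24, 28]]
    for trg in targets:
        num_frames = int(rng.integers(20, MAX_SRC_LEN + 1))
        src = rng.standard_normal((num_frames, NUM_FEATS)).astype("float32")
        yield {"src": src, "trg": np.array(trg, dtype="int32")}


my_ds = tf.data.Dataset.from_generator(
    example_generator,
    output_signature={
        "src": tf.TensorSpec(shape=(None, NUM_FEATS), dtype=tf.float32),
        "trg": tf.TensorSpec(shape=(None,), dtype=tf.int32),
    },
).padded_batch(
    2,  # batch size for this toy example
    padded_shapes={"src": [MAX_SRC_LEN, NUM_FEATS], "trg": [MAX_TRG_LEN]},
)

for batch in my_ds.take(1):
    print(batch["src"].shape, batch["trg"].shape)
```

`padded_batch` pads with zeros by default, which matches index 0 being the pad token in the target vocabulary.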

In [4]:
for i in ds.take(1):
    print(i['src'].shape)
    print(i['trg'].shape)
    print(i['trg'])

(4, 59, 80)
(4, 13)
tf.Tensor(
[[27 20  8 18  5  5 29  5  9  7  8 20 28]
 [27  5  9  7  8 20 29 14  9 14  5 28  0]
 [27 14  9 14  5 29  6 15 21 18 28  0  0]
 [27  6 15 21 18 29 20 23 15 28  0  0  0]], shape=(4, 13), dtype=int32)
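
To read the printed indices as text, each row can be mapped back through the vocabulary. The sketch below copies the batch shown above into a plain list; the helper name `decode` is an assumption for illustration.

```python
# Index -> character table, following the vocabulary listed earlier.
idx_to_char = ["-"] + [chr(i + 96) for i in range(1, 27)] + ["<", ">", " "]

# The 'trg' tensor printed above, copied as plain Python lists.
batch_trg = [
    [27, 20, 8, 18, 5, 5, 29, 5, 9, 7, 8, 20, 28],
    [27, 5, 9, 7, 8, 20, 29, 14, 9, 14, 5, 28, 0],
    [27, 14, 9, 14, 5, 29, 6, 15, 21, 18, 28, 0, 0],
    [27, 6, 15, 21, 18, 29, 20, 23, 15, 28, 0, 0, 0],
]


def decode(ids):
    """Hypothetical helper: map a sequence of indices back to characters."""
    return "".join(idx_to_char[i] for i in ids)


for row in batch_trg:
    print(decode(row))
# <three eight>
# <eight nine>-
# <nine four>--
# <four two>---
```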


Now we create a transformer model and train it in a few lines of code.

The `DisplayOutputs` callback prints a batch of outputs from the validation data
after every epoch, so we can check by eye whether the outputs improve as training progresses.

In [5]:
from xformer.tf import Transformer
from xformer.tf.callbacks import DisplayOutputs

model = Transformer(
    input_type="feats",
    nvocab=1000,
    ninp=80,
    nhid=64,
    nhead=2,
    nff=128,
    src_maxlen=59,
    trg_maxlen=12,
    nlayers=2,
    nclasses=30,
)
# Use the first batch of the validation set to display outputs
batch = next(iter(val_ds))

# Vocabulary to convert predicted indices to characters
idx_to_char = ["-"] + [chr(i + 96) for i in range(1, 27)] + ["<", ">", " "]
cb = DisplayOutputs(batch, idx_to_char)
model.compile(optimizer="adam")
model.fit(ds, callbacks=[cb], epochs=10)

Epoch 1/10
target:     <one six>----
prediction: <four four>--

target:     <six seven>--
prediction: <three three>

target:     <seven three>
prediction: <three three>

target:     <three nine>-
prediction: <five three>-

Epoch 2/10
target:     <one six>----
prediction: <five five>--

target:     <six seven>--
prediction: <seven seven>

target:     <seven three>
prediction: <seven seven>

target:     <three nine>-
prediction: <nine nine>--

Epoch 3/10
target:     <one six>----
prediction: <four four>--

target:     <six seven>--
prediction: <seven seven>

target:     <seven three>
prediction: <seven zero>-

target:     <three nine>-
prediction: <five zero>--

Epoch 4/10
target:     <one six>----
prediction: <one nine>---

target:     <six seven>--
prediction: <zero seven>-

target:     <seven three>
prediction: <seven zero>-

target:     <three nine>-
prediction: <five five>--

Epoch 5/10
target:     <one six>----
prediction: <one nine>---

target:     <six seven>--
prediction: <seven

<tensorflow.python.keras.callbacks.History at 0x15bc2b090>

In [6]:
model.fit(ds, callbacks=[cb], epochs=5)

Epoch 1/5
target:     <one six>----
prediction: <one eight>--

target:     <six seven>--
prediction: <eight seven>

target:     <seven three>
prediction: <zero two>---

target:     <three nine>-
prediction: <three nine>-

Epoch 2/5
target:     <one six>----
prediction: <one eight>--

target:     <six seven>--
prediction: <six seven>--

target:     <seven three>
prediction: <seven two>--

target:     <three nine>-
prediction: <two nine>---

Epoch 3/5
target:     <one six>----
prediction: <one eight>--

target:     <six seven>--
prediction: <six seven>--

target:     <seven three>
prediction: <seven two>--

target:     <three nine>-
prediction: <three nine>-

Epoch 4/5
target:     <one six>----
prediction: <one eight>--

target:     <six seven>--
prediction: <eight seven>

target:     <seven three>
prediction: <zero three>-

target:     <three nine>-
prediction: <three nine>-

Epoch 5/5
target:     <one six>----
prediction: <one six>----

target:     <six seven>--
prediction: <six seven>

<tensorflow.python.keras.callbacks.History at 0x14abff950>