In [1]:
try:
  # Notebook specific code

  %matplotlib inline
  %tensorflow_version 2.x 
  # !pip install tensorflow-gpu==2.0.0-beta1 -q
  
  !pip install gast==0.2.2 --force-reinstall
  
  try:
    import tensorflow as tf
    import warprnnt_tensorflow
  except:
    !git clone https://github.com/HawkAaron/warp-transducer
    !export CUDA_HOME=/usr/local/cuda; cd warp-transducer; mkdir build; cd build; cmake -DCUDA_TOOLKIT_ROOT_DIR=$CUDA_HOME ..; make
    !cd warp-transducer/tensorflow_binding; pip install -e .
    
    import os
    os.kill(os.getpid(), 9)
except:
  pass

TensorFlow 2.x selected.
Collecting gast==0.2.2
Installing collected packages: gast
  Found existing installation: gast 0.2.2
    Uninstalling gast-0.2.2:
      Successfully uninstalled gast-0.2.2
Successfully installed gast-0.2.2


In [0]:
import tensorflow as tf
import warprnnt_tensorflow

def gpu_config(mem_limit=None, print_gpus=True, log_device_placement=False):
  tf.debugging.set_log_device_placement(log_device_placement)

  gpus = tf.config.experimental.list_physical_devices('GPU')
  if print_gpus:
    print(gpus)
  if mem_limit is not None and gpus:
    try:
      tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*10)])
    except RuntimeError as e:
      # Virtual devices must be set before GPUs have been initialized
      print(e)

This is the code for the On Device Transcription Model, based of [this blog post](https://ai.googleblog.com/2019/03/an-all-neural-on-device-speech.html).

Preparing the data

In [0]:
import os

import urllib.request
import shutil

import tarfile

def download_data():
  url = "http://www.openslr.org/resources/60/dev-clean.tar.gz"
  file_name = "dev-clean.tar.gz"
  dir_name = "LibriTTS/dev-clean"

  if not os.path.isdir(dir_name):
    with urllib.request.urlopen(url) as response, open(file_name, 'wb') as out_file:
      shutil.copyfileobj(response, out_file)
    tar = tarfile.open(file_name)
    tar.extractall()
    tar.close()
    
    os.remove(file_name)
    
  files = []
  for (dirpath, dirnames, filenames) in os.walk(dir_name):
      for filename in filenames:
        filepath = dirpath + "/" + filename
        if filepath.endswith(".wav"):
          files.append(filepath[:-4])
  
  return files

Data processing

In [0]:
import string

labels = [''] + [chr(i) for i in range(ord('a'), ord('z') + 1)] + [' '] + list(string.digits + string.punctuation)
label_lookup = {}
for i in range(len(labels)):
  label_lookup[labels[i]] = i
label_size = len(labels)

def text_normalise(text_data):
  labels_data = [label_lookup[x] for x in text_data.lower()]

  return labels_data

import librosa, librosa.display
import numpy as np

def get_data(filepath, sr=24000):
  filepath = bytes.decode(filepath)
  y, _ = librosa.load(filepath + ".wav", mono=True, sr=sr)
  
  y = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=int(25*sr/1000), hop_length=int(10*sr/1000), n_mels=80)
  y = librosa.power_to_db(y, ref=np.max)
  
  with open(filepath + ".original.txt", 'r') as f:
    text = f.read()
  return y.astype(np.float32), text_normalise(text), len(text)

In [0]:
def get_dataset(files, audio_batches, text_batches, rnnt_batches, caching=False):
  @tf.function
  def data_shapping(audio, text, text_length):
    text = tf.cast(text, tf.int32)
    text_length = tf.cast(text_length, tf.int32)

    audio_shape = tf.shape(audio)
    audio = tf.pad(audio, [[0, 0], [0, 3 - tf.math.floormod(audio_shape[1], 3)]], constant_values = -80.0)
    audio_shape = tf.shape(audio)
    audio = tf.reshape(tf.transpose(audio), (audio_shape[1]//3, 3 * audio_shape[0]))
    return audio, tf.one_hot(tf.concat([[0], text], axis=-1), label_size, dtype=tf.float32), tf.reshape(tf.concat([[tf.cast(tf.shape(audio)[0], tf.int32)], [text_length], text], axis=-1), (-1, 1))
    #(audio, text_embedded, [audio_length, text_length, text])

  dataset = tf.data.Dataset.from_tensor_slices(files)
  dataset = dataset.shuffle(buffer_size=10000)

  # dataset to be both
  dataset = dataset.map(lambda x: tf.numpy_function(get_data, [x], [tf.float32, tf.int64, tf.int64]))
  dataset = dataset.map(data_shapping, 16)
  if caching: dataset.cache()

  @tf.function
  def audio_crop(audio, text_embedded, text_and_length):
    return audio, text_and_length

  # Split audio
  audio_dataset = dataset.map(audio_crop, 16)

  @tf.function
  def text_crop(audio, text_embedded, text_and_length):
    return text_embedded[:-1], text_embedded[1:]

  # Split text
  text_dataset = dataset.map(text_crop, 16)

  # Test dataset split

  test_dataset = dataset.take(1000)

  # Batching split datasets

  text_dataset = text_dataset.padded_batch(text_batches, padded_shapes=((512, label_size), (512, label_size))).prefetch(20)
  audio_dataset = audio_dataset.padded_batch(audio_batches, padded_shapes=((1024, 80*3), (512, 1)), drop_remainder=True).prefetch(20)

  # Batching RNNT train & test datasets

  @tf.function
  def data_grouping(audio, text_embedded, targets):
    return ((audio, text_embedded), targets)

  def dataset_shapping(dataset, batches=1):
  # dataset = dataset.padded_batch(1, padded_shapes=((None, 80*3), (None, ), (), ())).prefetch(2) # 10 loops, best of 3: 393 ms per loop
    dataset = dataset.padded_batch(batches, padded_shapes=((1024, 80*3), (512, label_size), (512, 1))).prefetch(20) # 10 loops, best of 3: 461 ms per loop

    dataset = dataset.map(data_grouping, 2)

    return dataset

  dataset = dataset_shapping(dataset, rnnt_batches)
  test_dataset = dataset_shapping(test_dataset, 1).repeat()
  
  return audio_dataset, text_dataset, dataset, test_dataset

# next(iter(text_dataset))

Max frame size: 1296

Max text length: 419

In [0]:
import math

class Sequence():
  def __init__(self, seq=None, hidden=None, blank=0):
    if seq is None:
      self.g = [] # predictions of phoneme language model
      self.k = [blank] # prediction phoneme label
      self.orig = []
      # self.h = [None] # input hidden vector to phoneme model
      self.h = hidden
      self.logp = 0 # probability of this sequence, in log scale
    else:
      self.g = seq.g[:] # save for prefixsum
      self.k = seq.k[:]
      self.orig = seq.orig[:]
      self.h = seq.h
      self.logp = seq.logp
  
  def resetOrig(self):
    self.orig = self.k

  def __str__(self):
    return 'Prediction: {} ({:.2f})'.format(self.string(), -self.logp)
  
  def string(self):
    return ''.join([labels[i] for i in self.k])

In [0]:
import tqdm

def AudioModel():
  audio_model = tf.keras.Sequential()
  audio_model.add(tf.keras.layers.Masking(mask_value=0.))
  for i in range(8):
    audio_model.add(tf.keras.layers.LSTM(1024,#2048,
                                         return_sequences=True))#, recurrent_activation='hard_sigmoid'))
    audio_model.add(tf.keras.layers.LayerNormalization(2))
  
  return audio_model

def TextModel():
  text_model = tf.keras.Sequential()
  text_model.add(tf.keras.layers.Masking(mask_value=0.))
  for _ in range(5):
    text_model.add(tf.keras.layers.LSTM(1024,
                                        return_sequences=True))#, recurrent_activation='hard_sigmoid'))
    text_model.add(tf.keras.layers.LayerNormalization(2))
    
  return text_model

class EncapsulationModel(tf.keras.Model):
  def __init__(self, sub_model, output_units, softmax=True):
    super(EncapsulationModel, self).__init__()
    
    self.model = tf.keras.Sequential()
    self.model.add(sub_model)
    for _ in range(2):
      self.model.add(tf.keras.layers.LSTM(output_units,
                                         return_sequences=True))
    self.softmax = softmax
    
  @tf.function
  def call(self, x):
    if self.softmax: return tf.nn.softmax(self.model(x))
    return self.model(x)

# def EncapsulationModel(sub_model, output_units, softmax=True):
#   model = tf.keras.Sequential()
#   model.add(sub_model)
#   for _ in range(2):
#     model.add(tf.keras.layers.LSTM(output_units, 
#                                    return_sequences=True))
#   if softmax: model.add(tf.keras.layers.Softmax())

#   return model

class JointModel(tf.keras.Model):
  def __init__(self, output_units):
    super(JointModel, self).__init__()
    
    self.post_audio = tf.keras.Sequential()
    for i in range(3):
      self.post_audio.add(tf.keras.layers.SeparableConv1D(1024, 2, 2))
    self.post_audio.add(tf.keras.layers.Dense(512))
    
    self.post_text = tf.keras.Sequential()
    self.post_text.add(tf.keras.layers.Dense(512))
    
    self.joint_model = tf.keras.Sequential([
          tf.keras.layers.Dense(512),
          tf.keras.layers.Dense(output_units)
    ])
    
  @tf.function
  def call(self, x_audio, x_text):
    x_audio = self.post_audio(x_audio)
    x_text = self.post_text(x_text)
    
    x_audio = tf.expand_dims(x_audio, axis=2) # Batch, Time-Steps, 1, Channels
    x_text = tf.expand_dims(x_text, axis=1) # Batch 1, Graphemes, Channels
    broadcast_shape = tf.broadcast_dynamic_shape(tf.shape(x_audio)[:-1], tf.shape(x_text)[:-1])
    x_audio = tf.broadcast_to(x_audio, tf.concat([broadcast_shape, tf.shape(x_audio)[-1:]], axis=0))
    x_text = tf.broadcast_to(x_text, tf.concat([broadcast_shape, tf.shape(x_text)[-1:]], axis=0))
    x = tf.concat([x_audio, x_text], axis=-1) # Batch, Time-Steps, Graphemes, Channels (Audio_Channels + Text_Channels)
    return tf.nn.log_softmax(self.joint_model(x))
    
class RNNT(tf.keras.Model):
  def __init__(self, audio_model, text_model, joint_model):
    super(RNNT, self).__init__()
    
    self.audio_model = audio_model
    self.text_model = text_model
    self.joint = joint_model
    
  @tf.function
  def call(self, x):
    x_audio, x_text = x
    x_audio = self.audio_model(x_audio) # Batch, Time-Steps, Channels
    x_text = self.text_model(x_text) # Batch, Graphemes, Channels
    return self.joint(x_audio, x_text)
  
  def beam_search(self, audio, width=10, prefix=True, debug=False):
    def forward_step(labels, hidden=None):
      # All labels
      labels = tf.reshape(tf.one_hot([labels], label_size), (1, 1, label_size))
      return self.text_model(labels)[:, -1, :], hidden

    def is_prefix(pref, seq):
      if pref == seq or len(pref) >= len(seq): return False
      for i in range(len(pref)):
        if pref[i] != seq[i]: return False
      return True
    
    def log_aplusb(a, b):
      return max(a, b) + math.log1p(math.exp(-math.fabs(a-b)))

    audio = self.audio_model(audio) # Batch, Time-Steps, Channels
    B = [Sequence()]
    if debug: pbar = tqdm.tqdm(total = audio.shape[1], position=0)
    if debug: print()
    for i in range(audio.shape[1]):
      x = audio[:, i, :]
      if prefix: sorted(B, key=lambda a: len(a.k), reverse=True) # larger sequence first add
      if debug: pbar.update()
      if debug: pbar.set_postfix_str(s = B[0])
      A = B
      B = []
      first_print = False
      if prefix:
        for j in range(len(A)-1):
          for i in range(j+1, len(A)):
            if not is_prefix(A[i].k, A[j].k): continue
            if debug:
              if first_print: print()
              print("  Prefix: {} -> {}".format(A[i].string(), A[j].string()))
            # A[i] -> A[j]
            pred, _ = forward_step(A[i].k[-1], A[i].h)
            idx = len(A[i].k)
            logp = self.joint(x, pred)[0, :]
            curlogp = A[i].logp + float(logp[A[j].k[idx]])
            for k in range(idx, len(A[j].k)-1):
              logp = self.joint(x, A[j].g[k])[0, :]
              curlogp += float(logp[A[j].k[k+1]])
            A[j].logp = log_aplusb(A[j].logp, curlogp)
            
      for i in A:
        i.resetOrig()
      
      while len(A) > 0:
        # y* = most probable in A
        y_hat = max(A, key=lambda a: a.logp)
        # remove y* from A
        A.remove(y_hat)
        # calculate P(k|y_hat, t)
        # get last label and hidden state
        pred, hidden = forward_step(y_hat.k[-1], y_hat.h)
        logp = self.joint(x, pred)[0, :]
        # for k \in vocab
        for k in range(label_size):
          yk = Sequence(y_hat)
          yk.logp += float(logp[k])
          if k == 0:
              B.append(yk) # next move
              if len(yk.g) - len(yk.orig) > 60:
                print(yk.g, yk.orig)
                break
              continue
          # store prediction distribution and last hidden state
          yk.h = hidden; yk.k.append(k); 
          if prefix: yk.g.append(pred)
          A.append(yk)
        y_hat = max(A, key=lambda a: a.logp)
        yb = max(B, key=lambda b: b.logp)
        if len(B) >= width and yb.logp >= y_hat.logp: break

      # beam width
      sorted(B, key=lambda b: b.logp, reverse=True)
      B = B[:width]

    if debug: print()
    # return highest probability sequence
    return B[0].string(), -B[0].logp
  

class SimpleLSTM(tf.keras.Model):
  def __init__(self, output_units):
    super(SimpleLSTM, self).__init__()
    self.model = [
        tf.keras.layers.Dense(256)
    ] + [
        tf.keras.layers.LSTM(1024,
                                 return_sequences=True, recurrent_activation='hard_sigmoid') for _ in range(7)
    ] + [
        tf.keras.layers.Dense(output_units)
    ]
    self.model = tf.keras.Sequential(self.model)
    
  @tf.function
  def call(self, x):
    return tf.nn.softmax(self.model(x))

def model_test(): #TODO: Implement model tests
  pass

In [0]:
def pad_and_stack(data):
  max_data_length = max(map(len, data))
  padded_data = [np.pad(datum, (0, max_data_length - len(datum)), mode='constant', constant_values=0) for datum in data]
  data = np.stack(padded_data, axis=0)
  
  return data

In [0]:
def test():
  iterator = iter(dataset)
  for data, targets in iterator:
    data, targets = next(iterator)
    pred = model(data)
    print(tf.reduce_mean(loss(targets, pred)))
  return loss(targets, pred)

# test()

In [0]:
@tf.function
def y_processing(y):
  y = tf.cast(y, tf.int32)
  audio_length, text_length, text = y[:, 0, 0], y[:, 1, 0], y[:, 2:, 0]
  return audio_length, text_length, text

@tf.function
def ctc_loss(y, x):
  audio_length, text_length, text = y_processing(y)
  loss_value = tf.nn.ctc_loss(text, x, text_length, audio_length, logits_time_major = False)
  loss_value = tf.reduce_mean(loss_value)
  return loss_value

@tf.function
def loss(y, x):
  audio_length, text_length, text = y_processing(y)
  loss_value = warprnnt_tensorflow.rnnt_loss(x, text, tf.cast(tf.math.ceil(audio_length / 2**3), tf.int32), text_length)
  loss_value = tf.reduce_mean(loss_value)
  return loss_value
# #   label_data = pad_and_stack(targets)
# #   label_data = tf.keras.backend.ctc_label_dense_to_sparse(targets, [len(i) for i in targets])
# #   loss_value = tf.reduce_mean(inputs)
#   loss_value = tf.nn.ctc_loss(label_data, pred, label_length=label_length, logit_length=logits_length, logits_time_major=False)
# #   loss_value = tf.keras.backend.ctc_batch_cost(label_data, pred, tf.constant([pred.shape[2] for _ in range(pred.shape[0])], shape=(pred.shape[0], 1)), tf.constant([len(i) for i in targets], shape=(pred.shape[0], 1)))
#   loss_value = tf.reduce_mean(loss_value)
#   return loss_value

class TestCallback(tf.keras.callbacks.Callback):
  def __init__(self, model, dataset, steps=100):
    self.model = model
    self.dataset = iter(dataset.shuffle(10))
    self.steps = steps

  def on_train_batch_begin(self, batch, logs=None):
    if not batch % self.steps and batch:
      datum = next(self.dataset)
      print()
      print(" Predicted: {}\n Real: {}".format(self.model.beam_search(datum[0][0], width=10, debug=True), ''.join([labels[i] for i in datum[1][0, 2:, 0, 0]])))

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[33m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    BLINK = '\033[5m'
  
import shutil

audio_batches = 8
text_batches = 64
rnnt_batches= 4

def main():
  
  gpu_config(1024*12, print_gpus=True, log_device_placement=False)  
  
  files = download_data()
  audio_dataset, text_dataset, dataset, test_dataset = get_dataset(files, audio_batches=audio_batches, text_batches=text_batches, rnnt_batches=rnnt_batches, caching=True)
  
  EPOCHS = 30
  
  # Initialising models
  text_model = TextModel()
  audio_model = AudioModel()
  joint_model = JointModel(label_size)
  
  checkpoint_dir = "checkpoints"
  current_dir = checkpoint_dir + "/current"
  text_checkpoint_path = current_dir + "/cp-text-{batch:04d}.ckpt"
  audio_checkpoint_path = current_dir + "/cp-audio-{batch:04d}.ckpt"
  joint_checkpoint_path = current_dir + "/cp-joint-{batch:04d}.ckpt"
  RNNT_checkpoint_path = checkpoint_dir + "/cp-{epoch}-rnnt.tf"
   
  os.makedirs(current_dir, exist_ok=True)
  
  try:
    text_model.load_weights(tf.train.latest_checkpoint(text_checkpoint_path))
    audio_model.load_weights(tf.train.latest_checkpoint(audio_checkpoint_path))
    joint_model.load_weights(tf.train.latest_checkpoint(joint_checkpoint_path))
    
    print(bcolors.OKGREEN + "Resumed from checkpoints" + bcolors.ENDC)
  except:
    try:
      tmp_model = RNNT(audio_model, text_model, joint_model)
      tmp_model.load_weights(RNNT_checkpoint_path.format(epoch="bak"))
      
      print(bcolors.OKBLUE + "No checkpoints found. Resuming from backup joint model" + bcolors.ENDC)
    except:
      print(bcolors.WARNING + "No checkpoints or backups found. Generating fresh weights" + bcolors.ENDC)
  
  if True: # TODO: Add proper handling for debug messages
    model = RNNT(audio_model, text_model, joint_model) # TODO: Add input_specs so that it doesn't require fit to summarise
    model.compile(loss = loss,
              optimizer = 'nadam')
    model.fit(dataset.take(1))
    model.summary()

  for epoch in range(EPOCHS):
    os.makedirs(current_dir, exist_ok=True)
    
    text_model.save_weights(text_checkpoint_path.format(batch=0), overwrite=True)
    audio_model.save_weights(audio_checkpoint_path.format(batch=0), overwrite=True)
    joint_model.save_weights(joint_checkpoint_path.format(batch=0), overwrite=True)
    
    # Text Training
    encap_model = EncapsulationModel(text_model, label_size)
    
    cp_callback= tf.keras.callbacks.ModelCheckpoint(
        filepath=text_checkpoint_path,
        verbose=0,
        save_weights_only=True,
        save_freq=100)
    
    encap_model.compile(loss='categorical_crossentropy', 
                        optimizer='nadam', 
                        metrics=['categorical_accuracy'])
    
    encap_model.fit(text_dataset,
                    callbacks=[cp_callback, ],
                    epochs=1)
  
    # Audio_Training
    encap_model = EncapsulationModel(audio_model, label_size, False)

    cp_callback= tf.keras.callbacks.ModelCheckpoint(
        filepath=audio_checkpoint_path,
        verbose=0,
        save_weights_only=True,
        save_freq=100)
    
    encap_model.compile(loss=ctc_loss,
                        optimizer='nadam')
    encap_model.fit(audio_dataset,
                    callbacks=[cp_callback,], 
                    epochs=1)

    encap_model = None

    # Regular RNNT Training
    model = RNNT(audio_model, text_model, joint_model)

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=joint_checkpoint_path, 
        verbose=0,
        save_weights_only=True,
        save_freq=100)

    test_callback = TestCallback(
        model = model,
        dataset = test_dataset
    )

    model.compile(loss = loss,
                  optimizer = 'nadam')

    model.fit(dataset,
              callbacks=[cp_callback, test_callback], 
              epochs=1)

    model.save_weights(RNNT_checkpoint_path.format(epoch="bak"), overwrite=True)
    model.save_weights(RNNT_checkpoint_path.format(epoch=epoch+1))
    shutil.rmtree(current_dir, ignore_errors=True)
    
if __name__ == "__main__":
  main()

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[33mNo checkpoints or backups found. Generating fresh weights[0m
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "rnnt_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_1 (Sequential)    multiple                  63946752  
_________________________________________________________________
sequential (Sequential)      multiple                  38066176  
_________________________________________________________________
joint_model (JointModel)     multiple                  4765254   
Total params: 106,778,182
Trainable params: 106,778,182
Non-trainable params: 0
_________________________________________________________________
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
Instructions

  0%|          | 1/1024 [00:00<00:01, 802.28it/s, Prediction:  (0.00)]


    100/Unknown - 520s 5s/step - loss: 337.2027