# In [None]:
import tensorflow as tf
import numpy as np

def rnnt_loss(logits, labels, time_lengths, label_lengths):
  """Return the batch RNN-T loss as one scalar.

  Computes the per-example losses with ``pr_loss`` and adds them together
  with ``tf.reduce_sum``.
  """
  per_example = pr_loss(logits, labels, time_lengths, label_lengths)
  return tf.reduce_sum(per_example)

def _rnnt_single_example(log_probs, labels, num_frames, num_rows):
  """Run the RNN-T forward-backward pass for one example.

  Args:
    log_probs: array (>= num_frames, >= num_rows, vocab) of log-softmax
      outputs; index 0 of the last axis is the blank symbol.
    labels: integer array (>= num_rows,). labels[u] is the symbol emitted on
      the move from lattice row u - 1 to row u, so labels[0] is never read.
    num_frames: number of valid time steps T (>= 1).
    num_rows: number of valid lattice rows U (>= 1).

  Returns:
    (log_like, grad): log_like is the log-probability of the label sequence
    summed over all alignments; grad has shape (T, U, vocab) and holds
    d log_like / d log_probs[:T, :U].
  """
  T, U = num_frames, num_rows
  row_idx = np.arange(U - 1)
  next_labels = labels[1:U]
  # blank[t, u]: log-prob of the move (t, u) -> (t + 1, u).
  blank = log_probs[:T, :U, 0]
  # emit[t, u]: log-prob of emitting labels[u + 1], the move (t, u) -> (t, u + 1).
  emit = log_probs[:T, row_idx, next_labels]

  # Forward variables: alpha[t, u] = log-prob of reaching node (t, u).
  alpha = np.full((T, U), -np.inf)
  alpha[0, 0] = 0.0
  for u in range(1, U):
    alpha[0, u] = alpha[0, u - 1] + emit[0, u - 1]
  for t in range(1, T):
    alpha[t, 0] = alpha[t - 1, 0] + blank[t - 1, 0]
    for u in range(1, U):
      alpha[t, u] = np.logaddexp(alpha[t - 1, u] + blank[t - 1, u],
                                 alpha[t, u - 1] + emit[t, u - 1])

  # Backward variables: beta[t, u] = log-prob of finishing from node (t, u),
  # including the final blank emitted at (T - 1, U - 1).
  beta = np.full((T, U), -np.inf)
  beta[T - 1, U - 1] = blank[T - 1, U - 1]
  for u in range(U - 2, -1, -1):
    beta[T - 1, u] = beta[T - 1, u + 1] + emit[T - 1, u]
  for t in range(T - 2, -1, -1):
    beta[t, U - 1] = beta[t + 1, U - 1] + blank[t, U - 1]
    for u in range(U - 2, -1, -1):
      beta[t, u] = np.logaddexp(beta[t + 1, u] + blank[t, u],
                                beta[t, u + 1] + emit[t, u])
  log_like = beta[0, 0]

  # d log_like / d log_probs[t, u, k] is the posterior probability of taking
  # that transition: exp(alpha(source) + log_prob + beta(target) - log_like).
  beta_after_blank = np.full((T, U), -np.inf)
  beta_after_blank[:-1] = beta[1:]
  beta_after_blank[T - 1, U - 1] = 0.0  # the final blank ends every path
  grad = np.zeros((T, U, log_probs.shape[-1]))
  grad[:, :, 0] = np.exp(alpha + blank + beta_after_blank - log_like)
  # "+=" so a label equal to the blank index would add rather than overwrite.
  grad[:, row_idx, next_labels] += np.exp(
      alpha[:, :-1] + emit + beta[:, 1:] - log_like)
  return log_like, grad


@tf.custom_gradient
def pr_loss(logits, labels, time_lengths, label_lengths):
  """Per-example RNN-Transducer negative log-likelihood with an exact gradient.

  Args:
    logits: float tensor (batch, max_time, max_rows, vocab) of unnormalized
      joint-network scores; index 0 of the last axis is the blank symbol.
    labels: integer tensor (batch, L) with L >= every label_lengths[b].
      labels[b, u] is the symbol emitted on the move from lattice row u - 1
      to row u; labels[b, 0] is never read.
    time_lengths: integer tensor (batch,), valid frames per example.
    label_lengths: integer tensor (batch,), valid lattice rows per example
      (number of real labels + 1).

  NOTE(review): the labels/label_lengths convention above is inferred from
  the previous implementation's alpha recursion and gradient indexing —
  confirm against callers.

  Runs eagerly only: tensor inputs are converted to NumPy arrays.

  Returns:
    (loss, grad): loss has shape (batch,) and the dtype of logits; grad maps
    the upstream gradient (batch,) to d loss / d logits and returns None for
    the three integer inputs.

  Raises:
    ValueError: if a length is < 1 or exceeds the matching logits axis.
  """
  logits = tf.convert_to_tensor(logits)
  # float64 keeps the exp/log arithmetic on the lattice numerically stable.
  log_probs = np.asarray(tf.math.log_softmax(tf.cast(logits, tf.float64)))
  labels_np = np.asarray(labels)
  frames = np.asarray(time_lengths)
  rows = np.asarray(label_lengths)
  batch_size, max_frames, max_rows, _ = log_probs.shape

  neg_log_like = np.zeros(batch_size)
  dlogits = np.zeros(log_probs.shape)
  for b in range(batch_size):
    num_frames, num_rows = int(frames[b]), int(rows[b])
    if not (1 <= num_frames <= max_frames and 1 <= num_rows <= max_rows):
      raise ValueError(
          f"example {b}: time_length={num_frames}, label_length={num_rows} "
          f"out of range for logits of shape {log_probs.shape}")
    log_like, dlog_probs = _rnnt_single_example(
        log_probs[b], labels_np[b], num_frames, num_rows)
    neg_log_like[b] = -log_like
    # Chain rule through log_softmax for loss = -log_like:
    # d loss / d z_k = p_k * sum_j g_j - g_k, with g = d log_like / d log_probs.
    # Padded (t, u) positions do not affect the loss and keep a zero gradient.
    probs = np.exp(log_probs[b, :num_frames, :num_rows])
    dlogits[b, :num_frames, :num_rows] = (
        probs * dlog_probs.sum(axis=-1, keepdims=True) - dlog_probs)

  loss = tf.cast(neg_log_like, logits.dtype)
  dlogits = tf.cast(dlogits, logits.dtype)

  def grad(upstream):
    # upstream has one entry per example; broadcast it over (time, rows, vocab).
    scale = tf.reshape(tf.cast(upstream, logits.dtype), [-1, 1, 1, 1])
    return [scale * dlogits, None, None, None]

  return loss, grad
