In [1]:
import numpy as np
import tensorflow as tf
import sys
sys.path.append('..')
import TFUtil
%load_ext autoreload
%autoreload 2

In [2]:
tf.enable_eager_execution()

In [3]:
NEG_INF = -float("inf")

def logsumexp(*args):  # summation in linear space -> LSE in log-space
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = np.log(sum(np.exp(a - a_max)
                   for a in args))
    return a_max + lsp


def log_softmax(acts, axis=None):
    """computes log(softmax(x, axis)) in a numerical stable way."""
    assert axis is not None
    a = acts - np.max(acts, axis=axis, keepdims=True)  # normalize for stability
    probs = np.sum(np.exp(a), axis=axis, keepdims=True)
    log_probs = a - np.log(probs)
    return log_probs

def py_print_iteration_info(msg, var, n, debug=True):
    """adds a tf.print op to the graph while ensuring it will run (when the output is used)."""
    if not debug:
        return var
    var_print = tf.print("n=", n, "\t", msg, tf.shape(var), var, output_stream=sys.stdout)
    with tf.control_dependencies([var_print]):
        var = tf.identity(var)
    return var

In [38]:
acts = np.array(
    [
      [[[0.06535690384862791, 0.7875301411923206, 0.08159176605666074],
        [0.5297155426466327, 0.7506749639230854, 0.7541348379087998],
        [0.6097641124736383, 0.8681404965673826, 0.6225318186056529]],

       [[0.6685222872103057, 0.8580392805336061, 0.16453892311765583],
        [0.989779515236694, 0.944298460961015, 0.6031678586829663],
        [0.9467833543605416, 0.666202507295747, 0.28688179752461884]],

       [[0.09418426230195986, 0.3666735970751962, 0.736168049462793],
        [0.1666804425271342, 0.7141542198635192, 0.3993997272216727],
        [0.5359823524146038, 0.29182076440286386, 0.6126422611507932]],

       [[0.3242405528768486, 0.8007644367291621, 0.5241057606558068],
        [0.779194617063042, 0.18331417220174862, 0.113745182072432],
        [0.24022162381327106, 0.3394695622533106, 0.1341595066017014]]],

      [[[0.5055615569388828, 0.051597282072282646, 0.6402903936686337],
        [0.43073311517251, 0.8294731834714112, 0.1774668847323424],
        [0.3207001991262245, 0.04288308912457006, 0.30280282975568984]],

       [[0.6751777088333762, 0.569537369330242, 0.5584738347504452],
        [0.08313242153985256, 0.06016544344162322, 0.10795752845152584],
        [0.7486153608562472, 0.943918041459349, 0.4863558118797222]],

       [[0.4181986264486809, 0.6524078485043804, 0.024242983423721887],
        [0.13458171554507403, 0.3663418070512402, 0.2958297395361563],
        [0.9236695822497084, 0.6899291482654177, 0.7418981733448822]],

       [[0.25000547599982104, 0.6034295486281007, 0.9872887878887768],
        [0.5926057265215715, 0.8846724004467684, 0.5434495396894328],
        [0.6607698886038497, 0.3771277082495921, 0.3580209022231813]]]])

labels = np.array([[1, 2],
                 [1, 1]])
input_lengths = np.array([4, 3], dtype=np.int32)
label_lengths = np.array([2, 2], dtype=np.int32)
log_probs = log_softmax(acts, axis=3)  # along vocabulary for (B, T, U, V)
n_batch = 2


In [5]:
def select_diagonal_batched(n=0, input_lens=None, label_lens=None):
    """
    Helper function to index various diagonals in a 2D matrix, which can be non-square.
    One diagonal starts from the top-right and goes down to the bottom-left.
    `n=1` indices (with start_row=0, start_col=0):
    [[0,0]]
    `n`=2:
    [[0,1], [1,0]]
    `n`=3:
    [[0,2], [1,1], [2,0]]

    :param n: specifies the diagonal to select
    :param tf.Tensor input_lens:
    :param tf.Tensor label_lens:
    :return: (B, N') tensor of indices
    :rtype: tf.Tensor
    """
    from TFUtil import expand_dims_unbroadcast, sequence_mask
    n_tiled = tf.tile([n], [n_batch])  # (B,)
    diff_t_u = tf.abs(input_lens - label_lens)  # (B,)
    min_t_u = tf.minimum(input_lens, label_lens)  # (B,)
    max_t_u = tf.maximum(input_lens, label_lens)  # (B,)
    
    # diagonal lengths
    #n_prime_a = tf.where(tf.less_equal(n_tiled, min_t_u), n_tiled, tf.zeros_like(n_tiled)-1)
    #n_prime_b = tf.where(tf.greater(n_tiled, min_t_u), min_t_u, tf.zeros_like(n_tiled)-1)
    #n_prime_c = tf.where(tf.greater(n_tiled, min_t_u + diff_t_u), min_t_u + diff_t_u - n_tiled, tf.zeros_like(n_tiled)-1)
    #diag_len = tf.reduce_max(tf.stack([n_prime_a, n_prime_b, n_prime_c]), axis=0)  # (B,)
    #diag_len = py_print_iteration_info("diag len", diag_len, n, debug=True)
    

    batch_idxs = expand_dims_unbroadcast(tf.range(n_batch), 1, n)  # (B, N)
    batch_idxs = tf.reshape(batch_idxs, (-1,))  # (B*N,)
    indices = tf.stack([
        batch_idxs,
        tf.tile(tf.range(0, n), [n_batch]),
        tf.tile(n - tf.range(n) - 1, [n_batch]),
    ], axis=-1)  # (N*B, 3)
    
    # reshape, so that we have for each batch each item in the diag
    indices = tf.reshape(indices, [n_batch, n, 3])  # (B, N, 3)
    
    # mask for phase (b)
    idxs_len_b = tf.where(tf.logical_and(
        tf.greater(n_tiled, min_t_u),
        tf.less_equal(n_tiled, min_t_u + diff_t_u)),
                           min_t_u, n_tiled)
    print("len_b", idxs_len_b)
    idxs_mask_b = tf.where(input_lens > label_lens,
                           tf.sequence_mask(idxs_len_b, maxlen=n),  # T > U
                           tf.reverse(tf.sequence_mask(idxs_len_b, maxlen=n), axis=[0]) # U > T
                          )  # (B, N)
    print("mask_b", idxs_mask_b)
    
    # mask for phase (c)
    idxs_len_c = tf.where(tf.greater(n_tiled, min_t_u + diff_t_u),
                         n_tiled - (min_t_u + diff_t_u),  # phase (c)
                          n_tiled)  # default-case
    idxs_start_c = tf.where(tf.greater(n_tiled, min_t_u + diff_t_u),
                            min_t_u + diff_t_u,
                            tf.ones_like(n_tiled))  # (B,)
    # build mask from slice limits
    range_mat = tf.expand_dims(tf.tile([0], [n_batch]), axis=1) \
    + tf.expand_dims(tf.range(n), axis=0)
    print("range_mat", range_mat)
    idxs_mask_c = tf.where(tf.logical_and(range_mat >= tf.expand_dims(idxs_start_c, axis=1), # (B, 1)
                                          range_mat < tf.expand_dims(idxs_start_c+idxs_len_c, axis=1)),
                       tf.ones_like(range_mat),
                       tf.zeros_like(range_mat)
                      )  # (B, N)
    idxs_mask_c = tf.cast(idxs_mask_c, tf.bool)
    print("start_c", idxs_start_c, "len(b):", idxs_len_c)
    print("mask_c", idxs_mask_c)
    
    

    print("indices pre-mask", indices.shape, indices)
    mask = tf.logical_and(idxs_mask_b, idxs_mask_c)
    idxs = tf.boolean_mask(indices, mask)
    print("indices post-mask", idxs.shape, idxs)
    
    idxs = tf.reshape(idxs, [n_batch, -1, 3])
    
    print("indices post-mask-reshape", idxs.shape, idxs)
    return idxs
idxs = select_diagonal_batched(n=3, input_lens=input_lengths, label_lens=label_lengths)
# (B*N, 3)
# -> (B, N, 3)
#idxs = tf.reshape(idxs, (n_batch, -1, 3))
#idxs = idxs[:, :-1, :]
#idxs = tf.reshape(idxs, (-1, 3))
#idxs = idxs[:, :-1]
 # (B=2, T=4, U=3, V=3)
print("gather from", log_probs[:, :, :, 0].shape)
print("idxs", idxs.shape, idxs)  # (B, N, 2)
# gather: data=(2, 4, 3) using idxs=(2,1,2)
lp_blank = tf.gather_nd(log_probs[:, :, :, 0], idxs)
lp_blank = tf.reshape(lp_blank, (n_batch, -1))
print("lp_blank", lp_blank.shape)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
len_b tf.Tensor([2 2], shape=(2,), dtype=int32)
mask_b tf.Tensor(
[[ True  True False]
 [ True  True False]], shape=(2, 3), dtype=bool)
range_mat tf.Tensor(
[[0 1 2]
 [0 1 2]], shape=(2, 3), dtype=int32)
start_c tf.Tensor([1 1], shape=(2,), dtype=int32) len(b): tf.Tensor([3 3], shape=(2,), dtype=int32)
mask_c tf.Tensor(
[[False  True  True]
 [False  True  True]], shape=(2, 3), dtype=bool)
indices pre-mask (2, 3, 3) tf.Tensor(
[[[0 0 2]
  [0 1 1]
  [0 2 0]]

 [[1 0 2]
  [1 1 1]
  [1 2 0]]], shape=(2, 3, 3), dtype=int32)
indices post-mask (2, 3) tf.Tensor(
[[0 1 1]
 [1 1 1]], shape=(2, 3), dtype=int32)
indices post-mask-reshape (2, 1, 3) tf.Tensor(
[[[0 1 1]]

 [[1 1 1]]], shape=(2, 1, 3), dtype=int32)
gather from (2, 4, 3)
idxs (2, 1, 3) tf.Tensor(
[[[0 1 1]]

 [[1 1 1]]], shape=(2, 1, 3), dtype=int32)
lp_blank (2, 1)


In [40]:
from ref_transduce import forward_pass
from rnnt_tf_impl import numpy_forward as forward_pass_debug
i = 1
alphas, ll_forward = forward_pass_debug(log_probs[i][:input_lengths[i], :label_lengths[i]+1], labels[i], blank_index=0, debug=True)
alphas, ll_forward = forward_pass(log_probs[i][:input_lengths[i], :label_lengths[i]+1], labels[i], blank=0)
print("alpha")
print(alphas)
print("\nll_forward")
print(ll_forward)

U=3, T=3, V=3
t= 1 u= 1: LSE(-1.476 + -1.184, -1.022 +  -1.132) = LSE(-2.660, -2.154) = -1.682
t= 1 u= 2: LSE(-2.261 + -1.008, -1.682 +  -1.122) = LSE(-3.269, -2.804) = -2.317
t= 2 u= 1: LSE(-1.682 + -1.099, -2.048 +  -0.844) = LSE(-2.781, -2.892) = -2.142
t= 2 u= 2: LSE(-2.317 + -1.094, -2.142 +  -1.002) = LSE(-3.410, -3.144) = -2.575
Alpha matrix: (3, 3)
[[ 0.         -1.47617485 -2.2610643 ]
 [-1.02221057 -1.68195323 -2.31674093]
 [-2.04810766 -2.14188269 -2.57539042]]
log-posterior = alpha[2, 2] + log_probs[2, 2, 0] = -2.575 + -0.965 = -3.5406
alpha
[[ 0.         -1.47617485 -2.2610643 ]
 [-1.02221057 -1.68195323 -2.31674093]
 [-2.04810766 -2.14188269 -2.57539042]]

ll_forward
-3.5406081372922227


In [7]:
# for precomputation_col
tf.cumsum(log_probs[:, :, 0, 0], exclusive=False, axis=1)[:,:-1]

<tf.Tensor: id=143, shape=(2, 3), dtype=float64, numpy=
array([[-1.40493705, -2.43911218, -3.87740021],
       [-1.02221057, -2.04810766, -3.12593642]])>

In [8]:
# for precomputation_row
n_target = tf.reduce_max(label_lengths+1)
from TFUtil import expand_dims_unbroadcast
a = expand_dims_unbroadcast(tf.range(n_batch), axis=1, dim=n_target-1)  # (B,U-1)
b = expand_dims_unbroadcast(tf.range(n_target - 1), axis=0, dim=n_batch) # (B, U-1)
c = labels # (B, U-1)
indices_w_labels = tf.stack([a, b, c], axis=-1)   # (B, U-1, 3)
print(indices_w_labels)
# log_probs[:,0,:,:]: (B, U, V)
log_probs_y = tf.gather_nd(log_probs[:,0,:,:], indices_w_labels)
# -> (B, U-1)

tf.Tensor(
[[[0 0 1]
  [0 1 2]]

 [[1 0 1]
  [1 1 1]]], shape=(2, 2, 3), dtype=int32)


In [9]:
tf.cumsum(log_probs_y, exclusive=False, axis=1)

<tf.Tensor: id=176, shape=(2, 2), dtype=float64, numpy=
array([[-0.68276381, -1.71078415],
       [-1.47617485, -2.2610643 ]])>

In [10]:
tf.concat([tf.cast(tf.tile([[0.]], [n_batch,1]), tf.double), tf.cumsum(log_probs_y, exclusive=False, axis=1)], axis=1)

<tf.Tensor: id=184, shape=(2, 3), dtype=float64, numpy=
array([[ 0.        , -0.68276381, -1.71078415],
       [ 0.        , -1.47617485, -2.2610643 ]])>

In [11]:
log_probs_y.shape

TensorShape([Dimension(2), Dimension(2)])

In [17]:
m = tf.constant([
    [0  ,   0,   1,   0, 1],
    [1,   0,   1, 1,   0],
    [0  , 1,     0,   0,   0]])

tmp_indices = tf.where(m)
tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])
#tf.argmin(, axis=0)

<tf.Tensor: id=207, shape=(3,), dtype=int64, numpy=array([2, 0, 1])>

In [41]:
acts.shape

(2, 4, 3, 3)

In [42]:
#B=2, T=4, U=2, V=3

In [45]:
label_lengths

array([2, 2], dtype=int32)

In [48]:
n_batch = 20
n_vocab = 5
max_target = 4
max_input = 8
np.random.seed(42)
labels = np.random.randint(1, n_vocab, (n_batch, max_target-1))
input_lengths = np.random.randint(1, max_input, (n_batch,), dtype=np.int32)
label_lengths = np.random.randint(1, max_target, (n_batch,), dtype=np.int32)
acts = np.random.normal(0, 1, (n_batch, max_input, max_target, n_vocab))
log_probs = log_softmax(acts, axis=3)  # along vocabulary for (B, T, U, V)

In [50]:
i = 5
alphas, ll_forward = forward_pass_debug(log_probs[i][:input_lengths[i], :label_lengths[i]+1], labels[i], blank_index=0, debug=True)

U=2, T=2, V=5
t= 1 u= 1: LSE(-1.810 + -1.891, -1.534 +  -1.054) = LSE(-3.701, -2.587) = -2.303
Alpha matrix: (2, 2)
[[ 0.         -1.80972908]
 [-1.53368601 -2.30339965]]
log-posterior = alpha[1, 1] + log_probs[1, 1, 0] = -2.303 + -1.192 = -3.4954
