In [174]:
import tensorflow as tf
import numpy as np
import better_exchook
better_exchook.install()

In [2]:
tf.enable_eager_execution()

In [30]:
NEG_INF = -float("inf")

def logsumexp(*args):  # summation in linear space -> LSE in log-space
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = np.log(sum(np.exp(a - a_max)
                   for a in args))
    return a_max + lsp


def log_softmax(acts, axis=None):
    """computes log(softmax(x, axis)) in a numerical stable way."""
    assert axis is not None
    a = acts - np.max(acts, axis=axis, keepdims=True)  # normalize for stability
    probs = np.sum(np.exp(a), axis=axis, keepdims=True)
    log_probs = a - np.log(probs)
    return log_probs

In [33]:
n_batch = 3
n_time = 5
n_labels = 7
n_vocab = 4
np.random.seed(42)
acts = np.random.random((n_batch, n_time, n_labels, n_vocab))
labels = np.random.randint(1, n_vocab, (n_batch, n_labels-1))
input_lengths = np.random.randint(1, n_time, (n_batch,), dtype=np.int32)
label_lengths = np.random.randint(1, n_labels-1, (n_batch,), dtype=np.int32)
log_probs = log_softmax(acts, axis=3)  # along vocabulary for (B, T, U, V)

We want to achieve:
$$
\alpha(t,u) = \bar{\alpha}(u, t-u)\\
\beta(t,u) = \bar{\beta}(u, t-u)\\
y(t, u, k) = \bar{y}(u, t-u, k)
$$

In [142]:
# log_probs: (B, T, U, V)
new_y = tf.transpose(new_y, [0,2,1,3])  # (B, U, T, K)
new_y = tf.roll(new_y, shift=[-n_labels], axis=[2])
# this will break when we have uneven seq-lengths

print("new_y", new_y.shape)

new_y (3, 5, 7, 4)


In [69]:
u = 2
t = 3
i = 1  # batch idx
k = 1
print(new_y[i, u, t-u, k])
print(log_probs[i, t, u, k])

tf.Tensor(-1.007450722532022, shape=(), dtype=float64)
-1.4083998261235249


In [180]:
def triangular_shift_matrix(mat, axis, axis_to_expand, batch_axis=0):
    """
    Shifts the matrix in one dimension such that diagonal elements will be in one dimension.
    
    :param mat: matrix (B, ..., dim, ...)
    :param axis: axis to perform the shifting
    :param axis_from: axis from 
    :param int batch_axis:
    :return Tensor of shape (B, ..., dim+dim_expand, ...)
    """
    assert batch_axis == 0
    # mat: (B, T, U, V)
    # axis_to_expand: usually U
    # axis: usually T
    # batch-axis has to be first
    dim_axis = tf.shape(mat)[axis]
    n_batch = tf.shape(mat)[batch_axis]
    #n_vocab = tf.shape(mat)[-1]
    rem_axes = list(tf.shape(mat))
    print("rem_axes", rem_axes)
    rem_axes.pop(axis)
    rem_axes.pop(batch_axis)
    print("rem_axes", rem_axes)
    #rem_axes.remove(axis)
    #rem_axes.remove(batch_axis)
    rem_axes_list = [i+2 for i in range(len(rem_axes))]
    print(rem_axes_list)
    mat = tf.transpose(mat, [0, axis, ] + rem_axes_list)  # (B, axis, ...)
    rem_axes_prod = np.prod(rem_axes)
    mat = tf.reshape(mat, (n_batch, dim_axis, rem_axes_prod))   # (B, axis, *)
    print("reshaped", mat.shape)
    shifts = tf.cast(tf.range(dim_axis), tf.float32)  # (T,)
    #shifts = shifts[tf.newaxis, :, :, tf.newaxis]
    #shifts = tf.tile(shifts, [n_batch, 1, 1, n_vocab])
    shifts = tf.tile(shifts[tf.newaxis, n, tf.newaxis], [n_batch, 1, rem_axes_prod])  # (B, axis, *)
    pads = tf.zeros((n_batch, dim_axis,), dtype=tf.float32)
    # (B, axis, *) ; (B, T, U, V) ; (B, T, T, V)
    # -> (B, T, U+T+1, V)
    a_ranged = tf.concat([shifts, tf.cast(mat, tf.float32), pads], axis=1)
    #U = tf.shape(mat)[axis_to_expand]
    #T = dim_axis
    def fn(x): # x: (B, U+T+1, *)
        shift = tf.cast(x[0][0][0], tf.int32) # (B,)
        # 1:U+1 is the original data, in front: shift as wanted, back: padding for shape
        n = tf.pad(x[:, 1:U+1, :], [[0,0],  # B
                                      [shift, T+1-shift],  # U+T+1
                                      [0,0] # V
                                     ])
        return n
    t = tf.map_fn(fn, elems=tf.transpose(a_ranged, [1,0,2,3]))
    t = tf.transpose(t, [1, 0, 2, 3])
    return t
shifted = triangular_shift_matrix(log_probs, axis=1, axis_to_expand=2)
assert shifted.shape == (n_batch, n_time, n_time+n_labels+1, n_vocab)
print("shifted", shifted.shape)
print_diagonal(log_probs)

rem_axes [<tf.Tensor: id=9504, shape=(), dtype=int32, numpy=3>, <tf.Tensor: id=9508, shape=(), dtype=int32, numpy=5>, <tf.Tensor: id=9512, shape=(), dtype=int32, numpy=7>, <tf.Tensor: id=9516, shape=(), dtype=int32, numpy=4>]
rem_axes [<tf.Tensor: id=9512, shape=(), dtype=int32, numpy=7>, <tf.Tensor: id=9516, shape=(), dtype=int32, numpy=4>]
[2, 3]
reshaped (3, 5, 28)


InvalidArgumentError: Expected multiples argument to be a vector of length 2 but got length 3 [Op:Tile]

In [164]:
def print_diagonal(lp):
    b = 0  # batch-idx
    v = 0  # vocab-idx
    # [0,2], [1,1], [2,0]
    # [0,3], [1,2], [2,1], [3,0]
    for n in range(4,5):
        for i in range(0, n+1):
            j = n - i
            print("n=%d" % n, i, j, lp[b, i, j, v], "%.3f" % (lp[b, i, j, v] - shifted[b, i, n, v].numpy()))
            np.testing.assert_almost_equal(lp[b, i, j, v], shifted[b, i, n, v])
print_diagonal(log_probs)
n = 4
print("shifted", shifted[0, :, 4, 0])

n=4 0 4 -1.4747639375195227 0.000
n=4 1 3 -1.7166377703348124 -0.000
n=4 2 2 -1.5795390703990098 -0.000
n=4 3 1 -1.0868037065231573 -0.000
n=4 4 0 -1.273270601638503 0.000
shifted tf.Tensor([-1.474764  -1.7166377 -1.5795391 -1.0868037 -1.2732706], shape=(5,), dtype=float32)


In [204]:
def shift_matrix_2d(mat, n_time):
    """
    mat: (B, U)
    :param int n_time:
    :param int axis 
    """
    n_batch = tf.shape(mat)[0]
    U = tf.shape(mat)[1]
    shifts = tf.cast(tf.range(n_time), tf.float32)  # (T,)
    shifts = tf.tile(shifts[tf.newaxis, :], [n_batch, 1])  # (B, T)
    #shifts = shifts[tf.newaxis, :]  # (1, T)
    print(shifts)
    pads = tf.zeros((n_batch, n_time), dtype=tf.float32)
    # (B, T) ; (B, U)
    # -> (B, 1+U+T)
    a_ranged = tf.concat([shifts, tf.cast(mat, tf.float32)], axis=1)
    #U = tf.shape(mat)[axis_to_expand]
    #T = tf.shape(mat)[1]

    def fn(x): # x: (B,)
        print("x", x)
        print("original", x[n_time:n_time+U])
        shift = tf.cast(x[0], tf.int32) # scalar
        print("shift:", shift)
        # 1:U+1 is the original data, in front: shift as wanted, back: padding for shape
        n = tf.pad(x[n_time:n_time+U], [[shift, n_time+1-shift],  # U+T+1
                                     ])
        return n
    t = tf.map_fn(fn, elems=a_ranged)
    #t = tf.transpose(t, [1, 0])  # (B, U+T+1)
    return t

labels_shifted = shift_matrix_2d(labels, n_time=n_time)
print(labels.shape)
print(labels_shifted)

tf.Tensor(
[[0. 1. 2. 3. 4.]
 [0. 1. 2. 3. 4.]
 [0. 1. 2. 3. 4.]], shape=(3, 5), dtype=float32)
x tf.Tensor([0. 1. 2. 3. 4. 3. 1. 3. 1. 1. 2.], shape=(11,), dtype=float32)
original tf.Tensor([3. 1. 3. 1. 1. 2.], shape=(6,), dtype=float32)
shift: tf.Tensor(0, shape=(), dtype=int32)
x tf.Tensor([0. 1. 2. 3. 4. 2. 3. 2. 3. 2. 1.], shape=(11,), dtype=float32)
original tf.Tensor([2. 3. 2. 3. 2. 1.], shape=(6,), dtype=float32)
shift: tf.Tensor(0, shape=(), dtype=int32)
x tf.Tensor([0. 1. 2. 3. 4. 3. 3. 1. 3. 3. 2.], shape=(11,), dtype=float32)
original tf.Tensor([3. 3. 1. 3. 3. 2.], shape=(6,), dtype=float32)
shift: tf.Tensor(0, shape=(), dtype=int32)
(3, 6)
tf.Tensor(
[[3. 1. 3. 1. 1. 2. 0. 0. 0. 0. 0. 0.]
 [2. 3. 2. 3. 2. 1. 0. 0. 0. 0. 0. 0.]
 [3. 3. 1. 3. 3. 2. 0. 0. 0. 0. 0. 0.]], shape=(3, 12), dtype=float32)
