In [1]:
import tensorflow as tf
import numpy as np
import better_exchook
better_exchook.install()

In [2]:
tf.enable_eager_execution()

In [3]:
NEG_INF = -float("inf")

def logsumexp(*args):  # summation in linear space -> LSE in log-space
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = np.log(sum(np.exp(a - a_max)
                   for a in args))
    return a_max + lsp


def log_softmax(acts, axis=None):
    """computes log(softmax(x, axis)) in a numerical stable way."""
    assert axis is not None
    a = acts - np.max(acts, axis=axis, keepdims=True)  # normalize for stability
    probs = np.sum(np.exp(a), axis=axis, keepdims=True)
    log_probs = a - np.log(probs)
    return log_probs

In [6]:
n_batch = 3
n_time = 5
n_labels = 7
n_vocab = 4
np.random.seed(42)
acts = np.random.random((n_batch, n_time, n_labels, n_vocab))
labels = np.random.randint(1, n_vocab, (n_batch, n_labels-1))
input_lengths = np.random.randint(1, n_time, (n_batch,), dtype=np.int32)
label_lengths = np.random.randint(1, n_labels-1, (n_batch,), dtype=np.int32)
log_probs = log_softmax(acts, axis=3)  # along vocabulary for (B, T, U, V)

We want to achieve:
$$
\alpha(t,u) = \bar{\alpha}(u, t-u)\\
\beta(t,u) = \bar{\beta}(u, t-u)\\
y(t, u, k) = \bar{y}(u, t-u, k)
$$

In [8]:
def triangular_shift_matrix(mat, axis, axis_to_expand, batch_axis=0):
    """
    Shifts the matrix in one dimension such that diagonal elements will be in one dimension.
    
    :param mat: matrix (B, ..., dim, ...)
    :param axis: axis to perform the shifting
    :param axis_from: axis from 
    :param int batch_axis:
    :return Tensor of shape (B, ..., dim+dim_expand, ...)
    """
    assert batch_axis == 0
    # mat: (B, T, U, V)
    # axis_to_expand: usually U
    # axis: usually T
    # batch-axis has to be first
    dim_axis = tf.shape(mat)[axis]
    n_batch = tf.shape(mat)[batch_axis]
    #n_vocab = tf.shape(mat)[-1]
    rem_axes = list(tf.shape(mat))
    print("rem_axes", rem_axes)
    rem_axes.pop(axis)
    rem_axes.pop(batch_axis)
    print("rem_axes", rem_axes)
    #rem_axes.remove(axis)
    #rem_axes.remove(batch_axis)
    rem_axes_list = [i+2 for i in range(len(rem_axes))]
    print(rem_axes_list)
    mat = tf.transpose(mat, [0, axis, ] + rem_axes_list)  # (B, axis, ...)
    rem_axes_prod = np.prod(rem_axes)
    mat = tf.reshape(mat, (n_batch, dim_axis, rem_axes_prod))   # (B, axis, *)
    print("reshaped", mat.shape)
    shifts = tf.cast(tf.range(dim_axis), tf.float32)  # (T,)
    #shifts = shifts[tf.newaxis, :, :, tf.newaxis]
    #shifts = tf.tile(shifts, [n_batch, 1, 1, n_vocab])
    shifts = tf.tile(shifts[tf.newaxis, n, tf.newaxis], [n_batch, 1, rem_axes_prod])  # (B, axis, *)
    pads = tf.zeros((n_batch, dim_axis,), dtype=tf.float32)
    # (B, axis, *) ; (B, T, U, V) ; (B, T, T, V)
    # -> (B, T, U+T+1, V)
    a_ranged = tf.concat([shifts, tf.cast(mat, tf.float32), pads], axis=1)
    #U = tf.shape(mat)[axis_to_expand]
    #T = dim_axis
    def fn(x): # x: (B, U+T+1, *)
        shift = tf.cast(x[0][0][0], tf.int32) # (B,)
        # 1:U+1 is the original data, in front: shift as wanted, back: padding for shape
        n = tf.pad(x[:, 1:U+1, :], [[0,0],  # B
                                      [shift, T+1-shift],  # U+T+1
                                      [0,0] # V
                                     ])
        return n
    t = tf.map_fn(fn, elems=tf.transpose(a_ranged, [1,0,2,3]))
    t = tf.transpose(t, [1, 0, 2, 3])
    return t
shifted = triangular_shift_matrix(log_probs, axis=1, axis_to_expand=2)
assert shifted.shape == (n_batch, n_time, n_time+n_labels+1, n_vocab)
print("shifted", shifted.shape)
print_diagonal(log_probs)

rem_axes [<tf.Tensor: id=17, shape=(), dtype=int32, numpy=3>, <tf.Tensor: id=21, shape=(), dtype=int32, numpy=5>, <tf.Tensor: id=25, shape=(), dtype=int32, numpy=7>, <tf.Tensor: id=29, shape=(), dtype=int32, numpy=4>]
rem_axes [<tf.Tensor: id=25, shape=(), dtype=int32, numpy=7>, <tf.Tensor: id=29, shape=(), dtype=int32, numpy=4>]
[2, 3]
reshaped (3, 5, 28)


NameError: name 'n' is not defined

In [12]:
def print_diagonal(lp, n):
    b = 0  # batch-idx
    v = 0  # vocab-idx
    # [0,2], [1,1], [2,0]
    # [0,3], [1,2], [2,1], [3,0]
    for i in range(0, n+1):
        j = n - i
        print("n=%d" % n, i, j, lp[b, i, j, v])
        #np.testing.assert_almost_equal(lp[b, i, j, v], shifted[b, i, n, v])
n = 4
print_diagonal(log_probs, n)
print("shifted", shifted[0, :, 4, 0])

n=4 0 4 -1.4747639375195227
n=4 1 3 -1.7166377703348124
n=4 2 2 -1.5795390703990098
n=4 3 1 -1.0868037065231573
n=4 4 0 -1.273270601638503


NameError: name 'shifted' is not defined

In [13]:
def shift_matrix_2d(mat, n_time, batch_dim_axis=0, axis=1, axis_to_shift=2):
    assert batch_dim_axis == 0
    mat = tf.convert_to_tensor(mat)
    shape = tf.shape(mat)
    mat = tf.expand_dims(mat, axis=-1)  # (B, U, 1)
    mat = tf.tile(mat, [1,1, n_time])  # (B, U, T)
    # batch, rows
    B, R, C = tf.meshgrid(
        tf.range(shape[0]),  # (B,)
        tf.range(shape[1]),  # (U,)
        tf.range(n_time)     # (T,)
        ,indexing='ij')
    shifts = tf.range(n_labels-1)  # (T,)
    # (B, U, T) + (1, U, 1)
    C = C + shifts[tf.newaxis, :, tf.newaxis]
    idxs = tf.stack([B,R, C], axis=-1)
    print("idxs", idxs.shape)

    # (B, U, T+U)
    new_shape = [shape[0]]  # (B,)
    new_shape.append(shape[1])
    new_shape.append(shape[1] + n_time)
    print("new shape", [v.numpy() for v in new_shape])
    # idxs: (B, U, U+T, 3)
    scat_mat = tf.scatter_nd(indices=idxs, updates=mat,
                            shape=new_shape)
    return scat_mat

print(labels[0])
labels_shifted = shift_matrix_2d(labels, n_time=n_time)
print(labels_shifted[0])

[3 1 3 1 1 2]
idxs (3, 6, 5, 3)
new shape [3, 6, 11]
tf.Tensor(
[[3 3 3 3 3 0 0 0 0 0 0]
 [0 1 1 1 1 1 0 0 0 0 0]
 [0 0 3 3 3 3 3 0 0 0 0]
 [0 0 0 1 1 1 1 1 0 0 0]
 [0 0 0 0 1 1 1 1 1 0 0]
 [0 0 0 0 0 2 2 2 2 2 0]], shape=(6, 11), dtype=int64)


In [14]:
def shift_rows(mat, shifts, batch_dim_axis=0, axis=1, axis_to_shift=2):
    assert batch_dim_axis == 0
    assert len(shifts.shape) == 1  # per row
    from TFUtil import move_axis
    mat = tf.convert_to_tensor(mat)
    #mat = move_axis(mat, old_axis=axis_to_shift, new_axis=1)  # (B, axis, ...)
    shape = tf.shape(mat)
    print("shape", shape)
    #idxs_range = tf.stack([
    #    tf.range(shape[0]),  # (B,)
    #    tf.range(shape[1]),  # (T,)
    #], axis=-1)  # (B, T, 2)
    # batch, rows, cols
    B, R, C = tf.meshgrid(
        tf.range(shape[0]),  # (B,)
        tf.range(shape[1]),  # (T,)
        tf.range(shape[2])  # (U,)
        ,indexing='ij')
    # [B=3,T=5,U=7] + [1,5,1]
    C = C + shifts[tf.newaxis, :, tf.newaxis]
    idxs = tf.stack([B,R,C], axis=-1)
    print(C[0])
    # idxs are slices into the `mat` matrix
    #print(idxs)
    rem_shape = shape[2:] if len(shape) > 2 else []
    print("idxs", idxs.shape)

    new_shape = [shape[0]]  # (B,)
    new_shape.append(shape[1] + shape[2])
    #new_shape.append(shape[2] + len(shifts))
    new_shape.extend(rem_shape)  # (B, axis-to-shift, axis, ...)
    print("new shape", [v.numpy() for v in new_shape])
    #mat_tr = tf.transpose(mat, (0, 2, 1, 3))  # (B, U, T, V)
    # idxs: (B, U+T, 2)
    scat_mat = tf.scatter_nd(indices=idxs, updates=mat,
                            shape=new_shape)
    
    #shifted_idxs = tf.transpose(idxs[:,:,1] + shifts)
    return scat_mat
# B=3, T=5, U=7, V=4
shifts = tf.range(n_time)  # (U,)
print("log_probs", log_probs.shape)
shifted_mat = shift_rows(log_probs, shifts, axis=1, axis_to_shift=2)
print("shifted", shifted_mat.shape)
print("expected", (n_batch, n_time+n_labels, n_labels, n_vocab))
#assert shifted_mat.shape == (n_batch, n_time+n_labels, n_labels, n_vocab)

print_diagonal(log_probs, n=4)
print("shifted", shifted_mat[0, :, 4, 0])

log_probs (3, 5, 7, 4)
shape tf.Tensor([3 5 7 4], shape=(4,), dtype=int32)
tf.Tensor(
[[ 0  1  2  3  4  5  6]
 [ 1  2  3  4  5  6  7]
 [ 2  3  4  5  6  7  8]
 [ 3  4  5  6  7  8  9]
 [ 4  5  6  7  8  9 10]], shape=(5, 7), dtype=int32)
idxs (3, 5, 7, 3)
new shape [3, 12, 7, 4]
shifted (3, 12, 7, 4)
expected (3, 12, 7, 4)
n=4 0 4 -1.4747639375195227
n=4 1 3 -1.7166377703348124
n=4 2 2 -1.5795390703990098
n=4 3 1 -1.0868037065231573
n=4 4 0 -1.273270601638503
shifted tf.Tensor(
[-1.47476394 -1.71663777 -1.57953907 -1.08680371 -1.2732706   0.
  0.          0.          0.          0.          0.          0.        ], shape=(12,), dtype=float64)
