In [1]:
import t3f
import numpy as np

In [2]:
import tensorflow as tf

In [3]:
# Eager execution stays disabled: later cells use TF1 graph-mode APIs
# (tf.Session, tf.gradients), which are not available under eager execution.
# tf.enable_eager_execution()

In [4]:
# Fix the graph-level seed so the random initial values below -- and therefore
# the tolerance-based assertions later in the notebook -- are reproducible.
tf.set_random_seed(0)
# w, z, x: TT-matrices with row shape 10x10x10 and column shape None;
# x uses tt_rank=100, w and z use t3f's default TT-rank.
# A: TT-matrix with row and column shape 10x10x10.
w = t3f.random_matrix(([10] * 3, None))
w = t3f.get_variable('w', initializer=w)
A = t3f.random_matrix(([10] * 3, [10] * 3))
A = t3f.get_variable('A', initializer=A)
z = t3f.random_matrix(([10] * 3, None))
z = t3f.get_variable('z', initializer=z)
x = t3f.random_matrix(([10] * 3, None), tt_rank=100)
x = t3f.get_variable('x', initializer=x)

In [5]:
# Sanity check: w was built with t3f.random_matrix, so it should be a TT-matrix.
w.is_tt_matrix()

True

In [6]:

def tangent_space_to_deltas(tt):
    """Extract the per-core delta blocks from a tangent-space TT.

    Reads back the layout produced by ``deltas_to_tangent_space``: the first
    core is ``[delta, left]`` (split along the right rank axis), middle cores
    are ``[[right, 0], [delta, left]]`` (delta is the lower-left rank block),
    and the last core is ``[right; delta]`` (split along the left rank axis).

    Args:
        tt: TensorTrain (TT-tensor or TT-matrix) with a ``projection_on``
            attribute. Every internal TT-rank must be even, because each one
            is the sum of two equal-sized blocks.

    Returns:
        List of ``tt.ndims()`` core slices, one delta per dimension.

    Raises:
        ValueError: if ``tt.projection_on`` is None or any internal rank is odd.
    """
    if tt.projection_on is None:
        raise ValueError('tt argument is supposed to be a projection, but it lacks projection_on field')
    num_dims = tt.ndims()
    cores = tt.tt_cores
    shapes = [core.get_shape().as_list() for core in cores]
    # Check every internal rank (the right rank of all cores but the last).
    # Previously only the left ranks of middle cores were checked, which
    # missed the last internal rank (and checked nothing when ndims == 2).
    for shape in shapes[:-1]:
        if shape[-1] % 2 != 0:
            raise ValueError('tt argument is supposed to be a projection, but its ranks are not even.')
    deltas = [None] * num_dims
    # Rank axes are the first and last axes for both 3-D TT-tensor cores and
    # 4-D TT-matrix cores, so `...` spans the mode axes in either case.
    for i in range(1, num_dims - 1):
        r1, r2 = shapes[i][0], shapes[i][-1]
        deltas[i] = cores[i][r1 // 2:, ..., :r2 // 2]
    deltas[0] = cores[0][..., :shapes[0][-1] // 2]
    # Assigned last (as before), so for a single-core TT this slice wins.
    deltas[num_dims - 1] = cores[num_dims - 1][shapes[-1][0] // 2:, ...]
    return deltas

def left_q(X, i, session=None):
    r"""Compute the orthogonal matrix Q_{\leq i} as defined in [1].

    NOTE(review): reference [1] is not cited anywhere in this notebook.

    Contracts cores ``0..i`` of ``X`` (evaluated to numpy) and flattens every
    axis except the last rank axis into rows. The result is orthogonal only if
    ``X`` is left-orthogonalized -- TODO confirm callers pass such an X.

    Args:
        X: TensorTrain whose cores can be evaluated by ``session.run``.
        i: index of the last core to include; ``i < 0`` gives a 1x1 ones matrix.
        session: object with a ``run`` method (e.g. ``tf.Session``). Defaults
            to the notebook-global ``sess`` for backward compatibility.

    Returns:
        float32 numpy array of shape ``(rows, r_{i+1})``.
    """
    if i < 0:
        return np.ones([1, 1], dtype=np.float32)
    if session is None:
        session = sess  # notebook-global session, created in a later cell
    # One run() for all needed cores instead of one round-trip per core.
    cores = session.run([X.tt_cores[dim] for dim in range(i + 1)])
    answ = np.ones([1, 1])
    for core in cores:
        answ = np.tensordot(answ, core, 1)
    answ = np.reshape(answ, (-1, answ.shape[-1]))
    return answ.astype(np.float32)

def right_q(X, i, session=None):
    r"""Compute the orthogonal matrix Q_{\geq i} as defined in [1].

    NOTE(review): reference [1] is not cited anywhere in this notebook.

    Contracts cores ``i..X.ndims()-1`` of ``X`` (evaluated to numpy) from the
    right, flattens every axis except the first rank axis, and transposes, so
    the first rank axis becomes the columns. The result is orthogonal only if
    ``X`` is right-orthogonalized -- TODO confirm callers pass such an X.

    Args:
        X: TensorTrain whose cores can be evaluated by ``session.run``.
        i: index of the first core to include; ``i > X.ndims() - 1`` gives a
            1x1 ones matrix.
        session: object with a ``run`` method (e.g. ``tf.Session``). Defaults
            to the notebook-global ``sess`` for backward compatibility.

    Returns:
        float32 numpy array of shape ``(rows, r_i)``.
    """
    if i > X.ndims() - 1:
        return np.ones([1, 1], dtype=np.float32)
    if session is None:
        session = sess  # notebook-global session, created in a later cell
    # Same core order as before (last core first), fetched in a single run().
    dims = range(X.ndims() - 1, i - 1, -1)
    cores = session.run([X.tt_cores[dim] for dim in dims])
    answ = np.ones([1, 1])
    for core in cores:
        answ = np.tensordot(core, answ, 1)
    answ = np.reshape(answ, (answ.shape[0], -1))
    return answ.T.astype(np.float32)

def deltas_to_tangent_space(deltas, tt, left, right):
    """Assemble a tangent-space TensorTrain from per-core deltas.

    Core layout (rank axes are ``left.left_tt_rank_dim`` / ``right_tt_rank_dim``):
      first core:   [delta_0, left_0]            concatenated along the right rank axis
      middle cores: [[right_i, 0], [delta_i, left_i]]
      last core:    [right_d; delta_d]           concatenated along the left rank axis

    For the concatenations to be valid, a middle delta_i must have shape
    ``[left rank_i, modes..., right-core rank_{i+1}]``.

    Args:
        deltas: list of ``left.ndims()`` delta cores.
        tt: base point; used for ``is_tt_matrix()`` and stored as the result's
            ``projection_on``.
        left: left-orthogonalized version of ``tt``.
        right: right-orthogonalized version of ``tt``.

    Returns:
        t3f.TensorTrain with ``projection_on`` set to ``tt``.
    """
    cores = []
    dtype = deltas[0].dtype
    num_dims = left.ndims()
    left_tangent_tt_ranks = t3f.shapes.lazy_tt_ranks(left)
    # Bug fix: the zero block sits next to right_i, so its row count is the
    # left rank of the *right* core. These ranks were previously read from
    # `left`, which is only correct when left and right share TT-ranks.
    right_tangent_tt_ranks = t3f.shapes.lazy_tt_ranks(right)
    raw_shape = t3f.shapes.lazy_raw_shape(left)
    right_rank_dim = left.right_tt_rank_dim
    left_rank_dim = left.left_tt_rank_dim
    for i in range(num_dims):
        left_tt_core = left.tt_cores[i]
        right_tt_core = right.tt_cores[i]

        if i == 0:
            tangent_core = tf.concat((deltas[i], left_tt_core),
                                     axis=right_rank_dim)
        elif i == num_dims - 1:
            tangent_core = tf.concat((right_tt_core, deltas[i]),
                                     axis=left_rank_dim)
        else:
            rank_1 = right_tangent_tt_ranks[i]
            rank_2 = left_tangent_tt_ranks[i + 1]
            if tt.is_tt_matrix():
                mode_sizes = [raw_shape[0][i], raw_shape[1][i]]
            else:
                mode_sizes = [raw_shape[0][i]]
            zeros = tf.zeros([rank_1] + mode_sizes + [rank_2], dtype)
            upper = tf.concat((right_tt_core, zeros), axis=right_rank_dim)
            lower = tf.concat((deltas[i], left_tt_core), axis=right_rank_dim)
            tangent_core = tf.concat((upper, lower), axis=left_rank_dim)
        cores.append(tangent_core)
    tangent = t3f.TensorTrain(cores)
    tangent.projection_on = tt
    return tangent


In [11]:
def _riemannian_grad(func, w, w_projection, left, right):
    """Riemannian gradient of `func` at `w`, via autodiff on a tangent-space TT.

    Differentiates ``func(w_projection)`` with respect to the cores of
    ``w_projection`` and keeps the delta block of each core gradient. It then
    projects every delta except the last onto the orthogonal complement of the
    matching left-orthogonal core, and reassembles the deltas with
    ``deltas_to_tangent_space``.

    Args:
        func: callable mapping a TensorTrain to a scalar tf.Tensor.
        w: base point (TensorTrain); provides ``ndims()`` and is passed on as
            the result's ``projection_on``.
        w_projection: tangent-space TT (built from ``left``/``right``) whose
            cores are differentiated.
        left: left-orthogonalized ``w``.
        right: right-orthogonalized ``w``.

    Returns:
        Tangent-space TensorTrain with ``projection_on`` set to ``w``.
    """
    h = func(w_projection)
    cores_grad = tf.gradients(h, w_projection.tt_cores)
    num_dims = w.ndims()
    deltas = []
    for i in range(num_dims):
        left_core = left.tt_cores[i]
        core_grad = cores_grad[i]
        if core_grad is None:
            # tf.gradients yields None for cores `func` does not depend on.
            core_grad = tf.zeros_like(w_projection.tt_cores[i])
        core_shape = left_core.shape.as_list()
        r1, r2 = core_shape[0], core_shape[-1]
        # Rank axes are first and last for both TT-tensor and TT-matrix cores,
        # so `...` covers the mode axes (n, or n and m).
        if i == 0:
            curr_grad = core_grad[:, ..., :r2]
        elif i == num_dims - 1:
            curr_grad = core_grad[r1:, ...]
        else:
            curr_grad = core_grad[r1:, ..., :r2]
        if i < num_dims - 1:
            # (I - Q Q^T) G computed as G - Q (Q^T G). The old explicit
            # tf.eye(r1 * n) ignored the column mode size m of TT-matrix cores
            # and was hard-wired to float32.
            q = tf.reshape(left_core, (-1, r2))
            g = tf.reshape(curr_grad, (-1, r2))
            delta = g - q @ (tf.transpose(q) @ g)
            delta = tf.reshape(delta, left_core.shape)
        else:
            delta = curr_grad
        deltas.append(delta)
    return deltas_to_tangent_space(deltas, w, left, right)
def riemannian_grad(func, w):
    """Riemannian gradient of `func` at the TensorTrain `w`.

    Orthogonalizes ``w`` left-to-right, then right-to-left. It builds a
    tangent-space TT whose deltas are (first right-orthogonal core, zeros, ...)
    -- i.e. a tangent-space representation of ``w`` itself -- and hands it to
    ``_riemannian_grad`` for differentiation.

    Args:
        func: callable mapping a TensorTrain to a scalar tf.Tensor.
        w: TensorTrain (TT-tensor or TT-matrix).

    Returns:
        Tangent-space TensorTrain with ``projection_on`` set to ``w``.
    """
    left = t3f.orthogonalize_tt_cores(w)
    right = t3f.orthogonalize_tt_cores(left, left_to_right=False)
    first_core, *remaining_cores = right.tt_cores
    deltas = [first_core]
    deltas.extend(tf.zeros_like(core) for core in remaining_cores)
    w_projection = deltas_to_tangent_space(deltas, w, left, right)
    return _riemannian_grad(func, w, w_projection, left, right)

def hessian_by_vector(f, w, vector):
    """Riemannian Hessian of `f` at `w` applied to `vector`.

    Returns ``riemannian_grad(g, w)``, where ``g(new_w)`` is the projected flat
    inner product between ``_riemannian_grad(f, w, new_w, left, right)`` and
    the projection of ``vector`` onto the tangent space at ``w``. The
    orthogonalizations ``left``/``right`` of ``w`` are held fixed inside ``g``.
    The notebook checks this against ``P_w(H P_w(vector))`` for a quadratic
    form.

    Args:
        f: callable mapping a TensorTrain to a scalar tf.Tensor.
        w: TensorTrain base point.
        vector: TensorTrain; projected onto the tangent space at ``w`` first.

    Returns:
        Tangent-space TensorTrain (see ``riemannian_grad``).
    """
    left = t3f.orthogonalize_tt_cores(w)
    right = t3f.orthogonalize_tt_cores(left, left_to_right=False)
    batched_vector = t3f.expand_batch_dim(t3f.project(vector, w))
    # Re-attach projection_on after expand_batch_dim (presumably it is not
    # carried over), so pairwise_flat_inner_projected accepts the argument.
    batched_vector.projection_on = w

    def grad_dot_vector(new_w):
        batched_grad = t3f.expand_batch_dim(
            _riemannian_grad(f, w, new_w, left, right))
        # TODO: dirty hack -- manually re-attach projection_on after batching.
        batched_grad.projection_on = w
        return t3f.pairwise_flat_inner_projected(batched_grad, batched_vector)[0, 0]

    return riemannian_grad(grad_dot_vector, w)

In [12]:
# Single TF1 session for the rest of the notebook; left_q/right_q also read
# this global `sess`. Initialize the variables created by t3f.get_variable above.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [18]:
# Check 1: f(w) = 0.5 * <x, w>^2 has Euclidean gradient <x, w> * x, so the
# Riemannian gradient should equal <x, w> * P_w(x).
func = lambda w: 0.5 * t3f.flat_inner(x, w)**2
desired, actual = sess.run([
    t3f.full(t3f.flat_inner(x, w) * t3f.project(x, w)),
    t3f.full(riemannian_grad(func, w)),
])
np.testing.assert_allclose(desired, actual, rtol=1e-3)

# Check 2: f(w) = w^T A w has Euclidean Hessian A^T + A; compare
# hessian_by_vector(f, w, z) against P_w((A^T + A) P_w(z)).
func = lambda w: t3f.quadratic_form(A, w, w)
desired, actual = sess.run([
    t3f.full(t3f.project(t3f.matmul(t3f.transpose(A) + A, t3f.project(z, w)), w)),
    t3f.full(hessian_by_vector(func, w, z)),
])
np.testing.assert_allclose(desired, actual, rtol=1e-2)

In [26]:
# Benchmark ops. op1: autodiff Riemannian gradient of the quadratic form
# w^T A w. It is built explicitly here instead of relying on `func` having
# been rebound in an earlier cell (hidden state). op2: t3f's project_matmul,
# presumably P_w(A w) -- see the t3f docs. The two ops compute different
# quantities; only their run times are compared below.
op1 = riemannian_grad(lambda w: t3f.quadratic_form(A, w, w), w).op
op2 = t3f.project_matmul(t3f.expand_batch_dim(w), w, A).op

In [29]:
# Time the autodiff Riemannian-gradient op (op1).
%timeit sess.run(op1)

115 µs ± 7.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [30]:
# Time the t3f.project_matmul op (op2) for comparison with op1.
%timeit sess.run(op2)

106 µs ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
