In [1]:
from inspect import getframeinfo, stack
import tensorflow as tf
import timeit
import numpy as np
import itertools
import t3f

In [2]:
sess = tf.Session()

In [14]:

def my_timeit(tens, sess):
    timings = []
    for rep in range(20):
        best_of_three = np.inf
        for i in range(3):
            start = timeit.default_timer()
            sess.run(tens)
            end = timeit.default_timer()
            best_of_three = min(best_of_three, end - start)
        timings.append(best_of_three)
    return np.mean(timings)
    

def freeze_args(argument_list, sess):
    cheap_args = []
    for argument in argument_list:
        shape = sess.run(tf.shape(argument))
        cheap_args.append(tf.constant(np.random.rand(*shape)))
    return cheap_args


def optimize_einsum(struct, sess):
    subscripts = struct['subscripts']
    pos = subscripts.find('->')
    if pos != -1:
        output_str = subscripts[pos:]
        subscripts = subscripts[:pos]
    else:
        output_str = ''
    argument_strings = np.array(subscripts.split(','))
    
    cheap_args = np.array(struct['cheap_args'])
    num_args = len(cheap_args[0])
    perm_idx = 0
    orders = np.array(list(itertools.permutations(range(num_args))))
    num_orders = orders.shape[0]
    timings_table = np.zeros((num_orders, len(cheap_args)))
    for order_idx in range(num_orders):
        curr_order = orders[order_idx, :]
        curr_einsum_string = ','.join(argument_strings[curr_order])
        curr_einsum_string += output_str
        for i in range(len(cheap_args)):
            curr_tens = tf.einsum(curr_einsum_string, *cheap_args[i][curr_order])
            timings_table[order_idx, i] = my_timeit(curr_tens, sess)
    
    return timings_table, orders
            
def optimizer(f, sess, *args):
    cache = {}
    original_einsum = tf.einsum
    def my_einsum(subscripts, *args):
        caller = getframeinfo(stack()[1][0])
        caller_str = "%s:%d" % (caller.filename, caller.lineno)
        if caller_str in cache:
            if cache[caller_str]['subscripts'] != subscripts:
                raise ValueError('Calling different types of einsum from the same line of code '
                                 'is not supported, %s sometimes calls einsum with argumens "%s"'
                                 'and sometimes with "%s"' % (caller_str, cache[caller_str]['subscripts'],
                                                              subscripts))
            cache[caller_str]['arguments'].append(args)
        else:
            cache[caller_str] = {'subscripts': subscripts, 'arguments': [args]}
        return original_einsum(subscripts, *args)
    tf.einsum = my_einsum
    f_out = f(*args)
    tf.einsum = original_einsum
    print('Found %d einsums.' % len(cache))
    vanilla_whole_runtime = my_timeit(f_out, sess)
    print('The running time of the whole function is %f s' % vanilla_whole_runtime)
    for caller_str in cache:
        subscripts = cache[caller_str]['subscripts']
        arguments = cache[caller_str]['arguments']
        cache[caller_str]['cheap_args'] = []
        cur_timings = np.zeros(len(arguments))
        for i in range(len(arguments)):
            cheap_args = freeze_args(arguments[i], sess)
            cache[caller_str]['cheap_args'].append(cheap_args)
            curr_tens = original_einsum(subscripts, *cheap_args)
            cur_timings[i] = my_timeit(curr_tens, sess)
        cache[caller_str]['timings'] = cur_timings
    vanilla_einsum_runtime = [np.sum(cache[s]['timings']) for s in cache]
    print('Einsums constitue %0.1f %% of the running time of the whole function (%f s).' %
          (100 * np.sum(vanilla_einsum_runtime) / vanilla_whole_runtime, np.sum(vanilla_einsum_runtime)))
    
    worst_einsum_idx = np.argmax([np.max(cache[s]['timings']) for s in cache])
    worst_einsum = list(cache)[worst_einsum_idx]
    vanilla_wors_timings = cache[worst_einsum]['timings']
    print('The slowest einsum (on which we gonna focus) is located in %s and it '
          'constitues %0.1f %% of the running time of the whole function (%f s).' %
          (worst_einsum, 100 * np.sum(vanilla_wors_timings) / vanilla_whole_runtime, np.sum(vanilla_wors_timings)))
    
    print(cache)
    timings_table, orders = optimize_einsum(cache[worst_einsum], sess)
    print(vanilla_wors_timings, timings_table, np.sum(vanilla_wors_timings - timings_table, axis=1))
    absolute_savings = np.sum(vanilla_wors_timings - timings_table, axis=1)
    global_rel_savings = (absolute_savings) / float(vanilla_whole_runtime)
    best_order_idx = np.argmax(global_rel_savings, axis=1)
    best_order = orders[best_order_idx]
    best_improovement = 100 * global_rel_savings[best_order_idx]
    if best_improovement >= 20:
        print('By changing the order of einsum in "%s" to %s you program will run %0.1f %% faster.' % 
              (worst_einsum, best_order, best_improovement))
    else:
        print('Einsum improovements haven\'t found, good work!')


In [15]:
def func(a, b, c):
    res = tf.einsum('ijk,ja,kb->iab', a, b, c) + 1
    res = tf.einsum('iab,kb->iak', res, c)
    return res
a = tf.random_normal((10, 11, 12))
b = tf.random_normal((11, 13))
c = tf.random_normal((12, 14))
# res = func(a, b, c)
optimizer(func, sess, a, b, c)

Found 2 einsums.
The running time of the whole function is 0.000381 s
Einsums constitue 68.7 % of the running time of the whole function (0.000262 s).
The slowest einsum (on which we gonna focus) is located in <ipython-input-15-85e52e0f41da>:2 and it constitues 35.4 % of the running time of the whole function (0.000135 s).
{'<ipython-input-15-85e52e0f41da>:3': {'subscripts': 'iab,kb->iak', 'cheap_args': [[<tf.Tensor 'Const_118:0' shape=(10, 13, 14) dtype=float64>, <tf.Tensor 'Const_119:0' shape=(12, 14) dtype=float64>]], 'arguments': [(<tf.Tensor 'add_5:0' shape=(10, 13, 14) dtype=float32>, <tf.Tensor 'random_normal_38:0' shape=(12, 14) dtype=float32>)], 'timings': array([ 0.00012703])}, '<ipython-input-15-85e52e0f41da>:2': {'subscripts': 'ijk,ja,kb->iab', 'cheap_args': [[<tf.Tensor 'Const_120:0' shape=(10, 11, 12) dtype=float64>, <tf.Tensor 'Const_121:0' shape=(11, 13) dtype=float64>, <tf.Tensor 'Const_122:0' shape=(12, 14) dtype=float64>]], 'arguments': [(<tf.Tensor 'random_normal_36

In [16]:
shape = 7 * np.ones(7)
mat_rank = 10
vec_rank = 10
B = 10
mat = t3f.random_matrix((shape, shape), mat_rank)
what = t3f.random_matrix_batch((shape, None), vec_rank, batch_size=B)
where = t3f.random_matrix((shape, None), vec_rank)

func = lambda what, where, mat: t3f.project_matmul(what, where, mat).op
optimizer(func, sess, what, where, mat)

Found 6 einsums.
The running time of the whole function is 0.148120 s
Einsums constitue 3.2 % of the running time of the whole function (0.004702 s).
The slowest einsum (on which we gonna focus) is located in /Users/alex/projects/t3f/t3f/riemannian.py:435 and it constitues 0.7 % of the running time of the whole function (0.001085 s).
{'/Users/alex/projects/t3f/t3f/riemannian.py:438': {'subscripts': 'saikcb,sbcd->saikd', 'cheap_args': [[<tf.Tensor 'Const_123:0' shape=(10, 1, 7, 1, 10, 10) dtype=float64>, <tf.Tensor 'Const_124:0' shape=(10, 10, 10, 7) dtype=float64>], [<tf.Tensor 'Const_125:0' shape=(10, 7, 7, 1, 10, 10) dtype=float64>, <tf.Tensor 'Const_126:0' shape=(10, 10, 10, 10) dtype=float64>], [<tf.Tensor 'Const_127:0' shape=(10, 10, 7, 1, 10, 10) dtype=float64>, <tf.Tensor 'Const_128:0' shape=(10, 10, 10, 10) dtype=float64>], [<tf.Tensor 'Const_129:0' shape=(10, 10, 7, 1, 10, 10) dtype=float64>, <tf.Tensor 'Const_130:0' shape=(10, 10, 10, 10) dtype=float64>], [<tf.Tensor 'Const_1