In [1]:
import tensorflow as tf

In [2]:
def pretty_print_guess_tensor(const_guess, operand_guess, operator_guess):
    """Render operand/operator guess tensors as a parenthesised infix string.

    Args:
        const_guess: accepted for interface symmetry but not used yet.
        operand_guess: iterable of rows; the argmax of each row gives the
            variable index, rendered as ``x_<index>``.
        operator_guess: iterable of per-level tensors; argmax over the last
            axis picks one of ``+ - * /`` for every pair at that level.

    Returns:
        The rendered nodes left after the last level, joined by spaces
        (a single expression when the tree fully reduces).
    """
    # TODO: const_guess
    symbols = ['+', '-', '*', '/']

    nodes = [f'x_{tf.argmax(row)}' for row in operand_guess]

    for level_one_hot in operator_guess:
        op_ids = tf.argmax(level_one_hot, axis=-1)
        # Pair element 2k with 2k+1; with an odd count the last element wraps
        # around to pair with the first (same pairing as tf.roll in `operate`).
        rotated = nodes[1:] + nodes[:1]
        lhs, rhs = nodes[::2], rotated[::2]
        nodes = [
            f'({a} {symbols[op]} {b})'
            for a, op, b in zip(lhs, op_ids[:len(lhs)], rhs)
        ]

    return ' '.join(nodes)

# Toy guess tensors for an 8-leaf expression tree with 4 operator kinds.
NUM_LEAVES = 8
NUM_OPERATORS = 4
v1 = tf.range(NUM_LEAVES)
v2 = tf.range(NUM_OPERATORS)

# Constant guess: leaf i -> slot i // 2; the (8, 4) block is repeated to (8, 8).
cgv = tf.one_hot(v1 // 2, NUM_LEAVES // 2, dtype=tf.float32)
const_guess = tf.tile(cgv, [1, 2])

# Operand guess: leaf i selects variable x_i (identity one-hot).
operand_guess = tf.one_hot(v1, NUM_LEAVES, dtype=tf.float32)

# Operator guess: all 3 levels use the identity, so pair k gets operator k.
ogv = tf.expand_dims(tf.one_hot(v2, NUM_OPERATORS, dtype=tf.float32), axis=0)
operator_guess = tf.tile(ogv, [3, 1, 1])

pretty_print_guess_tensor(const_guess, operand_guess, operator_guess)

'(((x_0 + x_1) + (x_2 - x_3)) + ((x_4 * x_5) - (x_6 / x_7)))'

In [3]:
@tf.function
def dot(x, y):
    """Multiply `x` and `y` elementwise (tf.multiply broadcasting) and sum over the last axis."""
    return tf.reduce_sum(tf.multiply(x, y), axis=-1)

# Each row dotted with itself: 3 * 2*2 = 12 and 3 * 3*3 = 27.
x = tf.constant([[2, 2, 2], [3, 3, 3]])

dot(x, x)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([12, 27])>

In [4]:
@tf.function
def resolve_values(const_guess, values, operand_guess):
    """Resolve one value per operand slot by weighting `values` with its guess row.

    Args:
        const_guess: currently unused (see TODO below).
        values: 1-D tensor of candidate values, length V.
        operand_guess: tensor whose rows (last axis length V) weight `values`;
            typically shape (num_operands, V).

    Returns:
        Tensor of per-row weighted sums, i.e. ``operand_guess @ values``
        (shape (num_operands,) for a 2-D `operand_guess`).
    """
    # TODO: const_guess

    # `dot` multiplies with broadcasting, so the (V,) vector pairs with every
    # row of `operand_guess` directly; no need to materialise a tiled copy.
    return dot(values, operand_guess)

# Leaf values are 0..NUM_LEAVES-1; with an identity operand guess each leaf
# resolves to its own value.
v1 = tf.range(NUM_LEAVES)
values = tf.cast(v1, dtype=tf.float32)

operand_guess = tf.one_hot(v1, NUM_LEAVES, dtype=tf.float32)
cgv = tf.one_hot(v1 // 2, NUM_LEAVES // 2, dtype=tf.float32)
const_guess = tf.tile(cgv, [1, 2])

resolve_values(const_guess, values, operand_guess)

<tf.Tensor: shape=(8,), dtype=float32, numpy=array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32)>

In [5]:
@tf.function
def operate(operands, operators):
    """Apply one level of (soft) binary operators to adjacent operand pairs.

    Args:
        operands: tensor paired along axis 0 (presumably 1-D leaf/partial values).
        operators: per-pair weights over the candidates [add, sub, mul, div];
            e.g. shape (num_pairs, 4), where num_pairs = ceil(len(operands) / 2).

    Returns:
        For each pair, the operator-weighted sum of the four candidate results.
    """
    # Pair operand 2k with 2k+1. tf.roll wraps around, so an odd-length input
    # pairs its final element with operands[0].
    lhs = operands[::2]
    rhs = tf.roll(operands, shift=-1, axis=0)[::2]

    candidates = tf.stack(
        [
            lhs + rhs,
            lhs - rhs,
            lhs * rhs,
            tf.math.divide_no_nan(lhs, rhs),  # 0 where rhs == 0 instead of inf/nan
        ],
        axis=1,
    )
    return dot(candidates, operators)

operands = tf.range(NUM_LEAVES, dtype=tf.float32)
# Identity weights: pair k uses operator k (column order is +, -, *, /).
operators = tf.constant([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,0],
    [0,0,0,1],
],dtype=tf.float32)

# Persistent so gradients w.r.t. both operands and operators can be taken.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(operands)
    tape.watch(operators)
    result = operate(operands, operators)

tf.print(result)
tf.print(tf.reshape(tape.gradient(result, operands),(2,4)))
tf.print(tape.gradient(result, operators))
# A persistent tape keeps its resources alive until the reference is dropped.
del tape

[1 -1 20 0.857142866]
[[1 1 1 -1]
 [5 4 0.142857149 -0.122448981]]
[[1 -1 0 0]
 [5 -1 6 0.666666687]
 [9 -1 20 0.8]
 [13 -1 42 0.857142866]]


In [15]:
def eager_process_block(operands, operators_arr):
    """Reduce `operands` level by level with `operate`, eagerly.

    Args:
        operands: tensor of leaf values (paired along axis 0 by `operate`).
        operators_arr: iterable of per-level operator weights; each level is
            sliced to one row per pair that `operate` produces at that level.

    Returns:
        The tensor left after applying every level (e.g. a length-1 tensor
        when 8 leaves are reduced by 3 levels).
    """
    acc = operands

    for level_operators in operators_arr:
        # `operate` pairs elements with [::2], yielding ceil(n / 2) results, so
        # take exactly that many operator rows (matching len(left) in
        # pretty_print_guess_tensor). `n // 2` under-slices for odd n and then
        # silently broadcasts, errors, or collapses to an empty tensor.
        num_pairs = (tf.shape(acc)[0] + 1) // 2
        acc = operate(acc, level_operators[:num_pairs])

    return acc

NUM_LEAVES = 8
NUM_OPERATORS = 4
v1 = tf.range(NUM_LEAVES)
v2 = tf.range(NUM_OPERATORS)

cgv = tf.one_hot(v1 // 2, NUM_LEAVES//2, dtype=tf.float32)
const_guess = tf.concat([cgv, cgv],axis=1)
operand_guess = tf.one_hot(v1, NUM_LEAVES, dtype=tf.float32)
ogv = tf.expand_dims(tf.one_hot(v2, NUM_OPERATORS, dtype=tf.float32), axis=0)
operator_guess = tf.concat([ogv,ogv,ogv], axis=0)
values = tf.cast(v1,dtype=tf.float32)
operands = resolve_values(const_guess, values, operand_guess)

# Persistent so gradients w.r.t. both operands and operator_guess can be taken.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(operands)
    tape.watch(operator_guess)
    result = eager_process_block(operands, operator_guess)

tf.print(pretty_print_guess_tensor(const_guess, operand_guess, operator_guess))

# Reference value: evaluate the printed formula with plain Python arithmetic.
x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7 = list(range(8))
expected = (((x_0 + x_1) + (x_2 - x_3)) + ((x_4 * x_5) - (x_6 / x_7)))
tf.print(expected)
tf.print(result)
# Fail loudly instead of relying on a visual comparison of the two prints.
tf.debugging.assert_near(result, tf.constant([expected], dtype=result.dtype))

tf.print(tape.gradient(result, operands))
tf.print(tape.gradient(result, operator_guess))
# A persistent tape keeps its resources alive until the reference is dropped.
del tape

(((x_0 + x_1) + (x_2 - x_3)) + ((x_4 * x_5) - (x_6 / x_7)))
19.142857142857142
[19.1428566]
[1 1 1 ... 4 -0.142857149 0.122448981]
[[[1 -1 0 0]
  [5 -1 6 0.666666687]
  [9 -1 20 0.8]
  [-13 1 -42 -0.857142866]]

 [[0 2 -1 -1]
  [20.8571434 19.1428566 17.1428566 23.333334]
  [0 0 0 0]
  [0 0 0 0]]

 [[19.1428566 -19.1428566 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]]


In [19]:
@tf.function
def unrolled_process_block_3(operands, operators_arr):
    """Three-level reduction with fixed per-level operator widths 4, 2, 1.

    The widths correspond to an 8-leaf tree. The loop runs over a Python
    tuple, so tf.function unrolls it while tracing; the resulting graph is
    the same as writing out each level by hand.

    Args:
        operands: leaf values, paired along axis 0 by `operate`.
        operators_arr: per-level operator weights; index 0, 1, 2 are used.

    Returns:
        The tensor produced by the third level.
    """
    acc = operands
    for level, width in enumerate((4, 2, 1)):
        acc = operate(acc, operators_arr[level][:width])
    return acc

NUM_LEAVES = 8
NUM_OPERATORS = 4
v1 = tf.range(NUM_LEAVES)
v2 = tf.range(NUM_OPERATORS)

cgv = tf.one_hot(v1 // 2, NUM_LEAVES//2, dtype=tf.float32)
const_guess = tf.concat([cgv, cgv],axis=1)
operand_guess = tf.one_hot(v1, NUM_LEAVES, dtype=tf.float32)
ogv = tf.expand_dims(tf.one_hot(v2, NUM_OPERATORS, dtype=tf.float32), axis=0)
operator_guess = tf.concat([ogv,ogv,ogv], axis=0)
values = tf.cast(v1,dtype=tf.float32)
operands = resolve_values(const_guess, values, operand_guess)

# Persistent so gradients w.r.t. both operands and operator_guess can be taken.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(operands)
    tape.watch(operator_guess)
    result = unrolled_process_block_3(operands, operator_guess)

tf.print(pretty_print_guess_tensor(const_guess, operand_guess, operator_guess))

# Reference value: evaluate the printed formula with plain Python arithmetic.
x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7 = list(range(8))
expected = (((x_0 + x_1) + (x_2 - x_3)) + ((x_4 * x_5) - (x_6 / x_7)))
tf.print(expected)
tf.print(result)
# Fail loudly instead of relying on a visual comparison of the two prints.
tf.debugging.assert_near(result, tf.constant([expected], dtype=result.dtype))

tf.print(tape.gradient(result, operands))
tf.print(tape.gradient(result, operator_guess))
# A persistent tape keeps its resources alive until the reference is dropped.
del tape

(((x_0 + x_1) + (x_2 - x_3)) + ((x_4 * x_5) - (x_6 / x_7)))
19.142857142857142
[19.1428566]
[1 1 1 ... 4 -0.142857149 0.122448981]
[[[1 -1 0 0]
  [5 -1 6 0.666666687]
  [9 -1 20 0.8]
  [-13 1 -42 -0.857142866]]

 [[0 2 -1 -1]
  [20.8571434 19.1428566 17.1428566 23.333334]
  [0 0 0 0]
  [0 0 0 0]]

 [[19.1428566 -19.1428566 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]]
