In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Make every @tf.function in this notebook execute eagerly (as plain Python).
# Prefer the stable API; fall back to the deprecated experimental name on
# older TensorFlow releases that do not provide it.
if hasattr(tf.config, "run_functions_eagerly"):
    tf.config.run_functions_eagerly(True)
else:
    tf.config.experimental_run_functions_eagerly(True)

In [3]:
@tf.function
def is_phi(element):
    """Soft test for whether ``element`` is the PHI token.

    PHI is the first standard basis vector ``e_0``. The score is the cosine
    similarity between ``element`` and ``e_0``, which equals the first
    component of the L2-normalised element. It is 1 for any positive multiple
    of PHI and 0 for vectors orthogonal to it. It is differentiable in
    between (e.g. ~0.707 for ``[.5, .5, 0]``) and lies in [-1, 1].

    Args:
        element: rank-1 floating-point vector.

    Returns:
        Scalar tensor with the same dtype as ``element``.
    """
    tf.debugging.assert_rank(element, 1)

    element = tf.math.l2_normalize(element)

    # Build PHI in the normalised element's dtype. A fixed float32 one-hot
    # would make tensordot reject float64/float16 inputs.
    elem_dim = tf.shape(element)[0]
    phi = tf.one_hot(0, elem_dim, dtype=element.dtype)

    t = tf.tensordot(element, phi, axes=1)

    return t

# Probe vectors: exactly PHI (e_0), orthogonal to PHI, and halfway between.
test1 = tf.Variable([1.0, 0.0, 0.0], dtype=tf.float32)
test2 = tf.Variable([0.0, 1.0, 0.0], dtype=tf.float32)
test3 = tf.Variable([0.5, 0.5, 0.0], dtype=tf.float32)

# Persistent so the tape can be queried once per probe below.
with tf.GradientTape(persistent=True) as tape:
    result1, result2, result3 = is_phi(test1), is_phi(test2), is_phi(test3)

# Each line prints the score followed by d(score)/d(probe).
tf.print(result1, tape.gradient(result1, test1))
tf.print(result2, tape.gradient(result2, test2))
tf.print(result3, tape.gradient(result3, test3))

1 [0 0 0]
0 [1 0 0]
0.707106769 [0.707106829 -0.707106709 0]


In [4]:
from library.stacks import stack_push, stack_pop

In [5]:
@tf.function
def safe_push(stack, element, is_phi_fn):
    """Push ``element`` onto ``stack``, weighted by how un-PHI-like it is.

    The weight ``t = is_phi_fn(element)`` gates the push. Both the buffer and
    the index become ``t * old + (1 - t) * pushed``, where ``pushed`` is the
    result of ``stack_push(stack, element)``:

    - ``t == 1`` (element is PHI) leaves the stack unchanged.
    - ``t == 0`` is an ordinary push.
    - Intermediate values give a differentiable blend.

    Args:
        stack: ``(buffer, index)`` pair; ``buffer`` has rank >= 2 and
            ``index`` has rank 1.
        element: entry to push; its shape must equal ``buffer.shape[1:]``.
        is_phi_fn: callable mapping ``element`` to a scalar weight ``t``.

    Returns:
        New ``(buffer, index)`` pair.

    Raises:
        The tf.debugging assertion error (ValueError when the check can be
        done statically, tf.errors.InvalidArgumentError otherwise) if a rank
        or shape check fails.
    """
    tf.debugging.assert_rank_at_least(stack[0], 2)
    tf.debugging.assert_rank(stack[1], 1)
    # Check rank before shape. assert_equal broadcasts, so comparing shape
    # vectors of different lengths can spuriously pass (e.g. [3, 3] vs [3]).
    # It can also fail with an opaque "incompatible shapes" error instead of
    # a clear rank mismatch.
    tf.debugging.assert_equal(
        tf.rank(stack[0]) - 1, tf.rank(element),
        message="element rank must be one less than the buffer rank")
    tf.debugging.assert_equal(
        tf.shape(stack[0])[1:], tf.shape(element),
        message="element shape must equal buffer.shape[1:]")

    # NOTE(review): with is_phi, t is a cosine similarity in [-1, 1]. For
    # negative t the update below extrapolates rather than blends convexly.
    t = is_phi_fn(element)

    old_buffer, old_index = stack
    new_buffer, new_index = stack_push(stack, element)

    buffer = t * old_buffer + (1 - t) * new_buffer
    index = t * old_index + (1 - t) * new_index

    return buffer, index

# A 3-slot stack of 3-dim tokens: zero buffer, index one-hot at position 0.
buffer = tf.Variable(tf.zeros((3, 3), dtype=tf.float32))
index = tf.Variable(tf.one_hot(0, 3, dtype=tf.float32))
original_stack = stack = (buffer, index)

# is_phi scores element2 at ~0.707, so safe_push only partly pushes it.
element1 = tf.Variable([0.0, 1.0, 0.0], dtype=tf.float32)
element2 = tf.Variable([0.5, 0.5, 0.0], dtype=tf.float32)
element3 = tf.Variable([0.0, 0.0, 1.0], dtype=tf.float32)
element4 = tf.Variable([0.0, 1.0, 0.0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    stack = safe_push(stack, element1, is_phi)
    stack = safe_push(stack, element2, is_phi)
    stack = safe_push(stack, element3, is_phi)
    stack = safe_push(stack, element4, is_phi)

tf.print(stack[0])                              # soft buffer
tf.print(tf.round(stack[0]))                    # buffer rounded to integers
tf.print(stack[1])                              # soft index
tf.print(tf.round(stack[1]))                    # index rounded to integers
tf.print(tape.gradient(stack[0], element3))     # d(sum of buffer)/d(element3)
tf.print(tape.gradient(stack, original_stack))  # d(sum of stack)/d(initial stack)

[[0 1 0]
 [0.0428932235 0.0428932235 0.707106769]
 [0 0.707106769 0.0857864469]]
[[0 1 0]
 [0 0 1]
 [0 1 0]]
[0.707106769 0.292893231 0]
[1 0 0]
[0.0606601834 0.792893231 0.792893231]
([[0 0 0]
 [0.207106784 0.207106784 0.207106784]
 [0.207106799 0.207106799 0.207106799]], [2.87867975 1.53553391 1.76776707])


In [6]:
# safe_push demo where element2 is exactly PHI ([1, 0, 0]). is_phi scores it
# 1, so that push leaves the stack unchanged.
buffer = tf.Variable(tf.zeros((3, 3), dtype=tf.float32))
index = tf.Variable(tf.one_hot(0, 3, dtype=tf.float32))
original_stack = stack = (buffer, index)

element1 = tf.Variable([0.0, 1.0, 0.0], dtype=tf.float32)
element2 = tf.Variable([1.0, 0.0, 0.0], dtype=tf.float32)
element3 = tf.Variable([0.0, 0.0, 1.0], dtype=tf.float32)
element4 = tf.Variable([0.0, 1.0, 0.0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    stack = safe_push(stack, element1, is_phi)
    stack = safe_push(stack, element2, is_phi)
    stack = safe_push(stack, element3, is_phi)
    stack = safe_push(stack, element4, is_phi)

tf.print(stack[0])                              # buffer
tf.print(stack[1])                              # index
tf.print(tape.gradient(stack[0], element3))     # d(sum of buffer)/d(element3)
tf.print(tape.gradient(stack, original_stack))  # d(sum of stack)/d(initial stack)

[[0 1 0]
 [0 0 1]
 [0 1 0]]
[1 0 0]
[-1 1 1]
([[0 0 0]
 [0 0 0]
 [0 0 0]], [4 1 1])


In [7]:
from library.array_ops import tensor_lookup_2d

In [8]:
@tf.function
def production_step(grammar, productions, stack, output, is_phi_fn):
    """Apply one production to the symbol stack.

    Steps:
    1. Pop the top symbol from ``stack`` and the next production from
       ``productions``.
    2. Look up, via ``tensor_lookup_2d``, the tokens that this (production,
       symbol) pair yields in each rule table.
    3. Push the ``G_s`` tokens back onto the symbol stack and the ``G_o``
       tokens onto the output stack.

    Every push goes through ``safe_push``, so tokens that ``is_phi_fn``
    scores as PHI are (softly) skipped.

    Args:
        grammar: ``(G_s, G_o)`` pair of rank-4 rule tables.
        productions: ``(buffer, index)`` stack with a rank-2 buffer.
        stack: ``(buffer, index)`` symbol stack with a rank-2 buffer.
        output: ``(buffer, index)`` output stack with a rank-2 buffer.
        is_phi_fn: scoring function forwarded to ``safe_push``.

    Returns:
        ``(productions, stack, output)`` after the step.
    """
    tf.debugging.assert_rank(grammar[0], 4)
    tf.debugging.assert_rank(grammar[1], 4)
    tf.debugging.assert_rank(productions[0], 2)
    tf.debugging.assert_rank(stack[0], 2)
    tf.debugging.assert_rank(output[0], 2)

    symbol_rules, output_rules = grammar

    stack, top_symbol = stack_pop(stack)
    productions, production = stack_pop(productions)

    # Right-hand-side symbols go back onto the symbol stack, in table order.
    for pushed in tensor_lookup_2d(symbol_rules, production, top_symbol):
        stack = safe_push(stack, pushed, is_phi_fn)

    # Emitted tokens go onto the output stack, in table order.
    for emitted in tensor_lookup_2d(output_rules, production, top_symbol):
        output = safe_push(output, emitted, is_phi_fn)

    return productions, stack, output

# Dimensions: token vocabulary size, number of productions, slots per stack.
TOKEN_DIM = 6
PRODUCTION_DIM = 4
STACK_SIZE = 10

# One-hot token vocabulary. PHI is basis vector 0, the token is_phi detects,
# so safe_push treats it as "push nothing".
PHI, S, O, T, X, PLUS = np.eye(TOKEN_DIM)

# Right-hand side that yields only PHI, i.e. no tokens.
E = [PHI, PHI, PHI]

# Rule tables, shape (production, token, 3, TOKEN_DIM). The
# [production][token] indexing is assumed from production_step's
# tensor_lookup_2d calls — confirm.
# G_s holds symbols pushed back onto the symbol stack. For example, entry
# [2][1] (production 2 applied to S) yields S O T.
G_s = tf.constant(
    [
        [E] * 6,
        [E] * 6,
        [E, [S, O, T]] + [E] * 4,
        [E, [T, PHI, PHI]] + [E] * 4,
    ],
    dtype=tf.float32,
)
# G_o holds tokens emitted to the output stack. Production 0 maps O -> PLUS
# and T -> X.
G_o = tf.constant(
    [
        [E, E, [PLUS, PHI, PHI], [X, PHI, PHI], E, E],
        [E] * 6,
        [E] * 6,
        [E] * 6,
    ],
    dtype=tf.float32,
)
grammar = (G_s, G_o)

# Symbol stack: empty, then seeded with the start symbol S.
buffer1 = tf.zeros((STACK_SIZE, TOKEN_DIM), dtype=tf.float32)
index1 = tf.one_hot(0, STACK_SIZE, dtype=tf.float32)
stack = safe_push((buffer1, index1), tf.constant(S, dtype=tf.float32), is_phi)

# Output stack: empty.
buffer2 = tf.zeros((STACK_SIZE, TOKEN_DIM), dtype=tf.float32)
index2 = tf.one_hot(0, STACK_SIZE, dtype=tf.float32)
output = (buffer2, index2)

# Production stack: seeded with production 2.
buffer3 = tf.zeros((STACK_SIZE, PRODUCTION_DIM), dtype=tf.float32)
index3 = tf.one_hot(0, STACK_SIZE, dtype=tf.float32)
productions = safe_push((buffer3, index3), tf.one_hot(2, PRODUCTION_DIM), is_phi)

# The inputs are plain tensors rather than Variables, so they must be watched
# explicitly for gradients with respect to them.
with tf.GradientTape(persistent=True) as tape:
    tape.watch((grammar, productions, stack, output))

    new_p, new_s, new_o = production_step(grammar, productions, stack, output, is_phi)

# tf.print(new_p[1])
# tf.print(new_s[0][0])
# tf.print(new_s[0][1])
# tf.print(new_s[0][2])
# tf.print(tape.gradient(new_p, productions))
# tf.print(tape.gradient(new_o, output))
# tf.print(tape.gradient(new_s, stack))
# tf.print(tape.gradient(new_s[0], grammar[0]).shape)
# tf.print(tape.gradient(new_s[1], grammar[0]).shape)
# tf.print(tape.gradient(new_o[0], grammar[1]).shape)
# tf.print(tape.gradient(new_o[1], grammar[1]).shape)
# tf.print(tape.gradient(new_s, productions))

In [9]:
# @tf.function
def generate(grammar, productions, stack_shape, S, is_phi, max_steps=None):
    """Derive output tokens from the start symbol ``S``.

    Repeatedly applies ``production_step`` until the symbol stack is empty
    or ``max_steps`` steps have run.

    Stacks use the ``(buffer, index)`` layout used with ``safe_push``: an
    empty stack is a zero ``buffer`` of shape ``stack_shape`` with
    ``index = one_hot(0, stack_shape[0])``.

    Args:
        grammar: ``(G_s, G_o)`` pair of rank-4 rule tables.
        productions: ``(buffer, index)`` stack of production vectors; each
            step pops one.
        stack_shape: ``(stack_size, token_dim)`` used for both the symbol
            stack and the output stack.
        S: start-symbol token vector; cast to float32 before use.
        is_phi: scoring function forwarded to ``safe_push`` and
            ``production_step``.
        max_steps: upper bound on production steps. Defaults to the number
            of rows in the ``productions`` buffer, since each step pops one
            production. This guarantees the loop terminates.

    Returns:
        The output stack as a ``(buffer, index)`` pair.
    """
    def _empty_stack():
        # Zero buffer with the index one-hot at slot 0. safe_push requires a
        # rank-1 index, not a scalar counter.
        return (tf.zeros(stack_shape, dtype=tf.float32),
                tf.one_hot(0, stack_shape[0], dtype=tf.float32))

    # Cast S so NumPy float64 tokens match is_phi's float32 PHI vector.
    state = safe_push(_empty_stack(), tf.cast(S, tf.float32), is_phi)
    output = _empty_stack()

    if max_steps is None:
        max_steps = int(productions[0].shape[0])

    for _ in range(max_steps):
        # Hard read of the stack pointer: slot 0 means nothing left to expand.
        # NOTE(review): assumes stack_push/stack_pop move a one-hot pointer
        # and the symbol stack never wraps past its capacity — confirm.
        if int(tf.argmax(state[1])) == 0:
            break
        productions, state, output = production_step(
            grammar, productions, state, output, is_phi)

    return output