In [1]:
import tensorflow as tf

In [2]:
@tf.function
def is_phi(element):
    """Measure how closely `element` aligns with phi.

    Here phi is the one-hot vector whose single non-zero entry sits at
    index 0, with depth taken from the leading dimension of `element`.
    The input is L2-normalised and then contracted with phi via
    `tf.tensordot(..., axes=1)`.

    Args:
        element: tensor to score; its leading dimension sets the depth of
            the one-hot phi vector.

    Returns:
        The tensordot of the normalised input with phi (a scalar when
        `element` is 1-D).
    """
    depth = tf.shape(element)[0]
    unit = tf.math.l2_normalize(element)
    first_axis = tf.one_hot(0, depth)
    return tf.tensordot(unit, first_axis, axes=1)

# Three probe vectors for is_phi: one-hot at index 0, one-hot at index 1,
# and an even mix of the first two axes.
test1 = tf.Variable([1, 0, 0], dtype=tf.float32)
test2 = tf.Variable([0, 1, 0], dtype=tf.float32)
test3 = tf.Variable([.5, .5, 0], dtype=tf.float32)

# persistent=True because gradient() is requested from this tape three times.
with tf.GradientTape(persistent=True) as tape:
    result1, result2, result3 = is_phi(test1), is_phi(test2), is_phi(test3)

# Print each score followed by its gradient with respect to its own input.
for score, source in ((result1, test1), (result2, test2), (result3, test3)):
    tf.print(score, tape.gradient(score, source))

1 [0 0 0]
0 [1 0 0]
0.707106769 [0.707106829 -0.707106709 0]


In [3]:
from library.stacks import stack_push, stack_pop

In [4]:
@tf.function
def safe_push(state, element, is_phi_fn):
    """Push `element` onto `state`, weighted by how un-phi-like it is.

    `is_phi_fn(element)` yields a weight `t`. The returned buffer and index
    are each the blend `t * old + (1 - t) * pushed`, so `t == 1` leaves the
    state untouched and `t == 0` yields exactly what `stack_push` returned.

    Args:
        state: pair `(buffer, index)` describing the stack.
        element: value handed to `stack_push`.
        is_phi_fn: callable mapping `element` to the blend weight `t`.

    Returns:
        The blended `(buffer, index)` pair.
    """
    keep_old = is_phi_fn(element)
    take_new = 1 - keep_old

    prev_buffer, prev_index = state
    pushed_buffer, pushed_index = stack_push(state, element)

    return (
        keep_old * prev_buffer + take_new * pushed_buffer,
        keep_old * prev_index + take_new * pushed_index,
    )

# Start from a zero-filled 3x3 buffer with a one-hot index at position 0.
buffer = tf.zeros(shape=(3, 3), dtype=tf.float32)
index = tf.one_hot(indices=0, depth=3, dtype=tf.float32)
state = (buffer, index)

# element2 is one-hot at index 0, so is_phi scores it 1 and safe_push
# leaves the state unchanged for that push.
element1 = tf.Variable([0, 1, 0], dtype=tf.float32)
element2 = tf.Variable([1, 0, 0], dtype=tf.float32)
element3 = tf.Variable([0, 0, 1], dtype=tf.float32)
element4 = tf.Variable([0, 1, 0], dtype=tf.float32)

# Record the four pushes, in order, so gradients can be taken afterwards.
with tf.GradientTape() as tape:
    for pushed in (element1, element2, element3, element4):
        state = safe_push(state, pushed, is_phi)

# Show the final buffer, the final index, and d(buffer)/d(element3).
final_buffer, final_index = state
tf.print(final_buffer)
tf.print(final_index)
tf.print(tape.gradient(final_buffer, element3))

[[0 1 0]
 [0 0 1]
 [0 1 0]]
[1 0 0]
[-1 1 1]


In [5]:
# Same setup as before: zero-filled 3x3 buffer, one-hot index at position 0.
buffer = tf.zeros(shape=(3, 3), dtype=tf.float32)
index = tf.one_hot(indices=0, depth=3, dtype=tf.float32)
state = (buffer, index)

# element2 now only partly overlaps the first axis, so its push is a soft
# blend of the old and pushed states rather than all-or-nothing.
element1 = tf.Variable([0, 1, 0], dtype=tf.float32)
element2 = tf.Variable([0.5, 0.5, 0], dtype=tf.float32)
element3 = tf.Variable([0, 0, 1], dtype=tf.float32)
element4 = tf.Variable([0, 1, 0], dtype=tf.float32)

# Record the four pushes, in order, so gradients can be taken afterwards.
with tf.GradientTape() as tape:
    for pushed in (element1, element2, element3, element4):
        state = safe_push(state, pushed, is_phi)

# For the buffer and then the index: print the raw values, then the rounded ones.
for part in state:
    tf.print(part)
    tf.print(tf.round(part))
tf.print(tape.gradient(state[0], element3))

[[0 1 0]
 [0.0428932235 0.0428932235 0.707106769]
 [0 0.707106769 0.0857864469]]
[[0 1 0]
 [0 0 1]
 [0 1 0]]
[0.707106769 0.292893231 0]
[1 0 0]
[0.0606601834 0.792893231 0.792893231]


In [6]:
from library.array_ops import tensor_lookup_2d

In [7]:
# @tf.function
def production_step(grammar, productions, state, output, is_phi):
    """Run one production step and return the updated stacks.

    Pops one value from `state` and one from `productions`, uses the pair to
    look up the next state symbol and the next output symbol in the grammar
    tables, and pushes each onto its stack with `safe_push`.

    Args:
        grammar: pair `(G_s, G_o)` of tables passed to `tensor_lookup_2d`;
            `G_s` supplies the value pushed onto `state`, `G_o` the value
            pushed onto `output`.
        productions: stack from which the next production is popped.
        state: stack from which the next symbol is popped and onto which the
            looked-up state symbol is pushed.
        output: stack onto which the looked-up output symbol is pushed.
        is_phi: callable forwarded to `safe_push` as its phi test.

    Returns:
        Tuple `(productions, state, output)`, where `state` and `output` are
        the values returned by `safe_push`.
    """
    G_s, G_o = grammar

    # NOTE(review): only the popped values are kept, so this step never
    # shrinks `state` or `productions` — confirm what stack_pop returns in
    # library.stacks and whether the popped stacks should be rebound here.
    next_nt = stack_pop(state)
    next_p = stack_pop(productions)

    next_state = tensor_lookup_2d(G_s, next_nt, next_p)
    next_output = tensor_lookup_2d(G_o, next_nt, next_p)

    # The push result must be rebound to `state`; binding it to any other
    # name would hand the caller back the state from before the push.
    state = safe_push(state, next_state, is_phi)
    output = safe_push(output, next_output, is_phi)

    return productions, state, output

In [8]:
# @tf.function
def generate(grammar, productions, stack_shape, S, is_phi):
    """Seed a stack with `S` and apply production steps while `top > 0`.

    Args:
        grammar: grammar tables, forwarded unchanged to `production_step`.
        productions: production stack, forwarded to `production_step` and
            replaced by the value it returns on every iteration.
        stack_shape: shape of the zero-filled buffers used for both the
            working stack and the output stack.
        S: first element pushed onto the working stack.
        is_phi: callable forwarded to `safe_push` and `production_step`.

    Returns:
        The output stack as a `(buffer, top)` pair, as left by the last
        production step (or the empty output if the loop never runs).
    """
    # NOTE(review): `top` starts as the scalar 0, while the demo cells in
    # this notebook build the index as a one-hot vector — confirm which
    # form stack_push expects and that `top > 0` is a valid test for it.
    top = 0
    # tf.zeros already returns a float32 tensor; no tf.constant wrap needed.
    stack = tf.zeros(stack_shape, dtype=tf.float32)

    # Keep the pushed state and refresh `top` from it; otherwise `S` never
    # reaches the stack and the loop below can never start.
    state = safe_push((stack, top), S, is_phi)
    _, top = state

    output_top = 0
    output_buffer = tf.zeros(stack_shape, dtype=tf.float32)
    output = (output_buffer, output_top)

    # NOTE(review): this only terminates if a production step eventually
    # brings `top` back to 0 — verify against production_step / stack_pop.
    while top > 0:
        productions, state, output = production_step(
            grammar, productions, state, output, is_phi
        )
        _, top = state

    return output