In [1]:
import tensorflow as tf

In [2]:
from library.stacks import stack_push, stack_pop

In [3]:
# @tf.function
def is_phi(element):
    """Report whether the largest entry of ``element`` sits at index 0.

    ``tf.argmax`` is called with its default axis, so the comparison result
    is a boolean tensor rather than a Python ``bool``.

    NOTE(review): index 0 presumably encodes the empty ("phi") symbol of a
    one-hot style vector -- confirm against the grammar encoding.
    """
    winner = tf.argmax(element)
    return winner == 0

In [4]:
# @tf.function
def safe_push(state, element, is_phi):
    """Push ``element`` onto ``state`` unless the predicate flags it as phi.

    Args:
        state: stack state forwarded to ``stack_push``; handed back untouched
            when the element is skipped.
        element: candidate value to push.
        is_phi: callable taking ``element``; a truthy result means "skip".
            This parameter shadows any module-level function of the same name.

    Returns:
        ``state`` itself when ``is_phi(element)`` is truthy, otherwise whatever
        ``stack_push(state, element)`` returns.
    """
    return state if is_phi(element) else stack_push(state, element)

In [5]:
from library.array_ops import tensor_lookup_2d

In [6]:
# @tf.function
def production_step(grammar, productions, state, output, is_phi):
    """Apply a single grammar production to the parse state and the output.

    Args:
        grammar: pair ``(G_s, G_o)`` of lookup tables; ``G_s`` yields the next
            stack symbol and ``G_o`` the next output symbol.
        productions: stack state holding the productions still to be applied.
        state: stack state holding the pending non-terminals.
        output: stack state collecting the emitted symbols.
        is_phi: predicate forwarded to ``safe_push``; elements it flags are
            not pushed.

    Returns:
        Tuple ``(productions, state, output)`` where ``state`` and ``output``
        are the values returned by ``safe_push``.
    """
    G_s, G_o = grammar

    # NOTE(review): the values returned by `stack_pop` are used directly as
    # lookup indices, and neither `state` nor `productions` is rebound from
    # the pop. This presumes `stack_pop` returns only the popped element and
    # advances the stack by other means -- confirm against library.stacks.
    next_nt = stack_pop(state)
    next_p = stack_pop(productions)

    next_state = tensor_lookup_2d(G_s, next_nt, next_p)
    next_output = tensor_lookup_2d(G_o, next_nt, next_p)

    # Rebind `state` itself: the result used to be stored in an unused local
    # named `stack`, so the pushed symbol never reached the returned state.
    state = safe_push(state, next_state, is_phi)
    output = safe_push(output, next_output, is_phi)

    return productions, state, output

In [7]:
# @tf.function
def generate(grammar, productions, stack_shape, S, is_phi):
    """Expand start symbol ``S`` by applying productions until the stack empties.

    Args:
        grammar: pair of lookup tables, passed through to ``production_step``.
        productions: stack state of productions, passed to ``production_step``.
        stack_shape: shape of the zero-initialised float32 buffers used for
            both the symbol stack and the output.
        S: start symbol, pushed onto the empty stack before the loop.
        is_phi: predicate forwarded to ``safe_push`` / ``production_step``.

    Returns:
        The output state ``(buffer, top)`` after the last production step, or
        the empty output when the initial push leaves the stack empty.
    """
    # `tf.zeros` already returns a float32 tensor; wrapping it in
    # `tf.constant` was redundant.
    state = (tf.zeros(stack_shape, dtype=tf.float32), 0)

    # Keep the pushed state and refresh `top` from it. The push result used to
    # be stored in an unused local and `top` stayed 0, so the loop below never
    # ran and the function always returned the empty output.
    state = safe_push(state, S, is_phi)
    _, top = state

    output = (tf.zeros(stack_shape, dtype=tf.float32), 0)

    # NOTE(review): termination relies on `production_step` returning a state
    # whose top eventually drops back to 0 -- confirm against `stack_pop`.
    while top > 0:
        productions, state, output = production_step(
            grammar, productions, state, output, is_phi
        )
        _, top = state

    return output