In [1]:
import tensorflow as tf
import numpy as np

# Differentiable Abstract Syntax Trees (ASTs)

This notebook demonstrates how to implement differentiable Abstract Syntax Trees using TensorFlow. We'll explore how to make tree operations differentiable by using soft selections and gradient-friendly stack operations.

## Overview

Traditional ASTs are discrete structures that don't allow gradient flow. By replacing discrete operations with continuous approximations, we can train neural networks to generate and manipulate syntactic structures end-to-end.

## Initial Setup

First, let's import the necessary libraries:

## Phi Function - The "Empty" Detector

The `is_phi` function determines how "close" an element is to the empty symbol (φ). This is crucial for differentiable operations as it allows us to softly decide whether to perform operations based on how "empty" an element is.

- When an element exactly matches the first basis vector [1,0,0,...], it returns 1 (indicating it's the empty symbol)
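- For any other element, it returns the cosine similarity with that basis vector, i.e. the first component of the L2-normalized element (0 for a one-hot non-empty token, about 0.707 for [0.5, 0.5, 0])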
- The gradient flows through the normalization, allowing the network to learn when elements should be treated as empty
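
As a cross-check of the bullets above, here is a NumPy-only sketch of the same quantity together with its analytic gradient. The names `is_phi_np` and `is_phi_grad_np` are hypothetical helpers introduced for illustration; they are not part of the notebook or of TensorFlow. Unlike `tf.math.l2_normalize`, this sketch does not guard against a zero vector.

```python
import numpy as np

def is_phi_np(element):
    """Cosine similarity between `element` and the phi basis vector e_0."""
    element = np.asarray(element, dtype=np.float64)
    phi = np.zeros_like(element)
    phi[0] = 1.0
    norm = np.linalg.norm(element)  # assumption: element is not the zero vector
    return float(element @ phi / norm)

def is_phi_grad_np(element):
    """Analytic gradient of the similarity: (phi - t * x_hat) / ||x||."""
    element = np.asarray(element, dtype=np.float64)
    phi = np.zeros_like(element)
    phi[0] = 1.0
    norm = np.linalg.norm(element)
    t = element @ phi / norm
    return (phi - t * element / norm) / norm

print(is_phi_np([1, 0, 0]))          # exactly phi
print(is_phi_np([0, 1, 0]))          # orthogonal to phi
print(is_phi_np([0.5, 0.5, 0]))      # halfway between the two
print(is_phi_grad_np([0.5, 0.5, 0]))
```

The gradient is zero at φ itself and points toward φ everywhere else, which is what lets an optimizer nudge an element toward or away from "empty".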

## Stack Operations

Now let's import our differentiable stack operations. These allow us to maintain stack-like data structures while preserving gradient flow:

## Safe Push Operation

The `safe_push` function implements a differentiable stack push operation. Instead of discretely deciding whether to push or not, it uses the `is_phi` function to blend between the old and new stack states:

- `t = is_phi_fn(element)`: Determines how "empty" the element is
- If `t` is close to 1 (element is empty), the stack remains unchanged
- If `t` is close to 0 (element is not empty), the element gets pushed
- The blending `t * old + (1-t) * new` allows gradients to flow through both paths

This is a key insight: instead of discrete conditional operations, we use weighted combinations that preserve differentiability.
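
The blending idea can be sketched without TensorFlow. The toy stack below (a buffer plus a one-hot pointer to the next free row) is an assumption made for illustration; the real `stack_push` in `library.stacks` may represent and update its state differently. `toy_push` and `toy_safe_push` are hypothetical names.

```python
import numpy as np

def is_phi_np(element):
    # Cosine similarity with the first basis vector (the phi symbol).
    return float(element[0] / np.linalg.norm(element))

def toy_push(stack, element):
    """Hypothetical push: write `element` at the pointer row, then advance the pointer."""
    buffer, pointer = stack
    mask = pointer[:, None]                       # (n, 1) soft write mask
    new_buffer = (1.0 - mask) * buffer + mask * element
    new_pointer = np.roll(pointer, 1)             # note: wraps around when the stack is full
    return new_buffer, new_pointer

def toy_safe_push(stack, element):
    """Blend the unchanged and the pushed stack with weight t = is_phi(element)."""
    t = is_phi_np(element)
    old_buffer, old_pointer = stack
    new_buffer, new_pointer = toy_push(stack, element)
    return (t * old_buffer + (1.0 - t) * new_buffer,
            t * old_pointer + (1.0 - t) * new_pointer)

empty = (np.zeros((3, 3)), np.array([1.0, 0.0, 0.0]))
s1 = toy_safe_push(empty, np.array([0.0, 1.0, 0.0]))  # real token: fully pushed
s2 = toy_safe_push(s1, np.array([1.0, 0.0, 0.0]))     # phi: stack left unchanged
s3 = toy_safe_push(s2, np.array([0.5, 0.5, 0.0]))     # in between: partial push
print(s3[1])                                          # pointer is now spread over two rows
```

After the partial push the pointer is no longer one-hot, which is the price of differentiability: the stack holds a weighted superposition of "pushed" and "not pushed".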

## Pop and Purge Operation

The `pop_and_purge` function demonstrates a more complex stack operation that:
1. Pops an element from the stack
2. Pushes a phi (empty) element
3. Immediately pops the phi element

The net effect is a "clean" pop: the vacated slot is overwritten with φ, so the stack structure is preserved and no stale token is left behind in the buffer.

In [2]:
@tf.function
def is_phi(element):
    tf.debugging.assert_rank(element, 1)
    
    elem_dim = tf.shape(element)[0]
    phi = tf.one_hot(0, elem_dim)
    
    element = tf.math.l2_normalize(element)
    t = tf.tensordot(element, phi, axes=1)

    return t

test1 = tf.Variable([1,0,0], dtype=tf.float32)
test2 = tf.Variable([0,1,0], dtype=tf.float32)
test3 = tf.Variable([.5,.5,0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result1 = is_phi(test1)
    result2 = is_phi(test2)
    result3 = is_phi(test3)

tf.print(result1, tape.gradient(result1, test1))
tf.print(result2, tape.gradient(result2, test2))
tf.print(result3, tape.gradient(result3, test3))

1 [0 0 0]
0 [1 0 0]
0.707106769 [0.707106829 -0.707106709 0]


## Tensor Lookup Operations

We'll need 2D tensor lookup functionality for our grammar operations:

## Grammar Definition

Here we define a simple context-free grammar for generating expressions. The grammar has two parts:

- **G_s**: Productions for the stack (what gets pushed back onto the stack)  
- **G_o**: Productions for the output (what gets written to the output stream)

### Token Definitions:
- **PHI** (φ): Empty symbol  
- **S**: Start symbol
- **O**: Open expression
- **T**: Terminal symbol  
- **X**: Variable (x)
- **PLUS**: Addition operator (+)

### Grammar Rules:
The grammar tensors encode production rules: `G_s[production][token]` gives the tokens to push onto the stack, and `G_o[production][token]` gives the tokens to output. Each entry holds up to three tokens and is padded with φ; an entry of three φ tokens means "do nothing".

For example, when we see 'S' and apply production 2, we push `[S, O, T]` back onto the stack.
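
One way to read such a lookup when the production and the token are soft (probability-like) vectors is as a weighted sum over all table entries. The sketch below shows that reading in NumPy; `soft_lookup_2d` is a hypothetical stand-in, and the real `tensor_lookup_2d` in `library.array_ops` may be implemented differently.

```python
import numpy as np

TOKEN_DIM, PRODUCTION_DIM = 6, 4
PHI, S, O, T, X, PLUS = np.eye(TOKEN_DIM)   # one row per token
E = [PHI, PHI, PHI]

G_s = np.array([
    [E, E, E, E, E, E],
    [E, E, E, E, E, E],
    [E, [S, O, T], E, E, E, E],
    [E, [T, PHI, PHI], E, E, E, E],
])                                           # shape (4, 6, 3, 6)

def soft_lookup_2d(table, row_weights, col_weights):
    """Weighted sum of table[i, j], with weight row_weights[i] * col_weights[j]."""
    return np.einsum('i,j,ijkl->kl', row_weights, col_weights, table)

production = np.eye(PRODUCTION_DIM)[2]
pushed = soft_lookup_2d(G_s, production, S)
print(pushed.argmax(axis=1))                 # [1 2 3], i.e. S, O, T

# A 50/50 mix of productions 2 and 3 yields a 50/50 mix of their entries.
mix = 0.5 * np.eye(PRODUCTION_DIM)[2] + 0.5 * np.eye(PRODUCTION_DIM)[3]
blended = soft_lookup_2d(G_s, mix, S)
```

With one-hot inputs this reduces to ordinary indexing; with soft inputs every table entry contributes in proportion to its weight, so gradients reach both the production and the token.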

## Pretty Printing Utility

A helper function to convert one-hot encoded tokens back to readable symbols for debugging:

## Production Step

The `production_step` function implements one step of grammar-based generation:

1. **Pop**: Remove the top token from the stack
2. **Lookup**: Use the current production and popped token to look up what to do in the grammar
3. **Push to Stack**: Push the tokens from `G_s` back onto the stack (in reverse order)  
4. **Push to Output**: Push the tokens from `G_o` to the output stream

This simulates how a pushdown automaton processes a context-free grammar, but in a differentiable way.

In [3]:
from library.stacks import stack_push, stack_pop, stack_peek, new_stack, new_stack_from_buffer

## Debugging Utility

A helper function to print detailed information about each production step, showing:
- Which production rule was applied
- What tokens were looked up in the grammar
- Current state of stack and output

## Generate Function

The `generate` function orchestrates the complete generation process:

1. **Initialize**: Create empty stack and output buffers
2. **Start**: Push the start symbol 'S' onto the stack  
3. **Process**: Apply each production rule in sequence
4. **Return**: The generated output sequence and final stack state

In effect, this function is a differentiable generator: given a sequence of productions, it produces structured output that follows the grammar rules.
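
For intuition, the same four steps can be written as an ordinary, non-differentiable pushdown loop. This is a plain-Python sketch of the discrete behaviour the soft version approximates; the rule dictionaries restate the non-empty entries of `G_s` and `G_o`, and `generate_discrete` is a hypothetical name. Stopping early on an empty stack is an assumption of this sketch.

```python
# Non-empty entries of G_s and G_o, keyed by (production index, popped token).
RULES_STACK = {(2, 'S'): ['S', 'O', 'T'], (3, 'S'): ['T']}
RULES_OUTPUT = {(0, 'T'): ['x'], (1, 'O'): ['+']}

def generate_discrete(productions, start='S'):
    stack, output = [start], []
    for p in productions:
        if not stack:                      # assumption: stop when nothing is left to expand
            break
        top = stack.pop()
        # Push in reverse so the first listed token ends up on top of the stack.
        for token in reversed(RULES_STACK.get((p, top), [])):
            stack.append(token)
        output.extend(RULES_OUTPUT.get((p, top), []))
    return output, stack

out, rest = generate_discrete([2, 3, 0, 1, 0])
print(' '.join(out), rest)                 # x + x []
```

A (production, token) pair with no rule simply consumes the token, which mirrors the all-φ entries in the grammar tensors.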

## Example Generation

Let's see the grammar in action! We'll apply a sequence of production rules `[2, 3, 0, 1, 0]` starting from the symbol 'S' and observe how it generates an expression step by step.

The debug output shows:
- Which production rule is applied at each step
- What gets pushed onto the stack (G_s lookups)  
- What gets added to the output (G_o lookups)
- Current stack and output states

## Token Encoding Utility

A helper function to convert string expressions into one-hot encoded token sequences. This is useful for creating training targets and comparing generated outputs.

## Soft Generation with Loss

Here we demonstrate the fully differentiable version:

1. **Soften Everything**: Apply softmax to grammar rules, production sequences, and symbols
2. **Generate**: Run the soft generation process  
3. **Compute Loss**: Compare generated output with expected target using cross-entropy loss
4. **Gradients**: Show that gradients flow back through the entire generation process

This is the key insight: by making all discrete operations continuous, we can train end-to-end using standard gradient descent!
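
One detail worth keeping in mind is that applying softmax to a one-hot vector does not return something close to one-hot. The NumPy sketch below illustrates this, along with the row-wise cross-entropy that `tf.nn.softmax_cross_entropy_with_logits` computes (it applies softmax to its `logits` argument internally). The helper names are hypothetical.

```python
import numpy as np

def softmax(logits, axis=-1):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def cross_entropy(labels, logits):
    """Row-wise -sum(labels * log(softmax(logits)))."""
    return -(labels * np.log(softmax(logits))).sum(axis=-1)

one_hot = np.eye(4)[2]
soft = softmax(one_hot)
print(soft)                       # hot entry is e / (e + 3), the others 1 / (e + 3)

# Even a "perfect" prediction fed in as logits has a non-zero loss.
loss = cross_entropy(one_hot, one_hot)
print(loss)
```

Because the hot entry only reaches about 0.475 for four classes, softened symbols are fairly flat, and the loss has a floor above zero when one-hot values are used as logits.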

## Training Step

Now we implement a complete training step using the Adam optimizer. The training process:

1. **Forward Pass**: Generate output using current production sequence
2. **Loss Calculation**: Compare with target output and desired final stack state  
3. **Backward Pass**: Compute gradients with respect to production parameters
4. **Update**: Apply gradients using Adam optimizer

Notice that the production sequence itself is the trainable parameter: the optimizer learns which grammar rules to apply.
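
The shape of such an update can be sketched in NumPy under two simplifying assumptions: plain gradient descent stands in for Adam, and the forward pass is the identity (the parameters are compared to the target directly instead of going through `generate`). The loss uses the same definition as `tf.nn.l2_loss`, namely `sum(x ** 2) / 2`.

```python
import numpy as np

def l2_loss(x):
    """Same definition as tf.nn.l2_loss: sum(x ** 2) / 2."""
    return 0.5 * np.sum(x ** 2)

target = np.eye(4)[[2, 3, 0]]        # desired one-hot rows (illustrative target)
params = np.full((3, 4), 0.25)       # start from a uniform guess
learning_rate = 0.1                  # assumption: plain gradient descent, not Adam
losses = []

for _ in range(50):
    losses.append(l2_loss(target - params))
    grad = params - target           # gradient of l2_loss(target - params) w.r.t. params
    params = params - learning_rate * grad

print(losses[0], losses[-1])
print(params.argmax(axis=1))         # [2 3 0]
```

In the notebook the gradient additionally has to pass through every soft push, pop, and lookup, so convergence there is less predictable than in this toy.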

## Training Loop

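Finally, let's train the model to generate the target expression "x + x". We start from the production sequence `[2, 3, 0, 0, 0]`, which differs in one position from the sequence `[2, 3, 0, 1, 0]` used in the earlier example, and let the optimizer search for a sequence of grammar rules that produces the target.

Every 10 iterations the loop prints the loss, the generated output, the final stack, and the current arg-max production sequence, so convergence toward the target expression can be checked directly.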

In [41]:
# @tf.function
def safe_push(stack, element, is_phi_fn):
    tf.debugging.assert_rank_at_least(stack[0], 2)
    tf.debugging.assert_rank(stack[1], 1)
    tf.debugging.assert_equal(tf.shape(stack[0])[1:], tf.shape(element))
    tf.debugging.assert_equal(tf.rank(stack[0]) - 1, tf.rank(element) )
    
    t = is_phi_fn(element)
    
    old_buffer, old_index = stack
    new_buffer, new_index = stack_push(stack, element)

    buffer = t * old_buffer + (1 - t) * new_buffer
    index = t * old_index + (1 - t) * new_index
    
    tf.print(tokens_pretty_print(old_buffer))
    tf.print(tokens_pretty_print(new_buffer))
    tf.print(tokens_pretty_print(buffer))
    tf.print('-'*80)

    # Hack to tell tensorflow that the shape has not changed
    # TODO: Why does this hack work?
    buffer = tf.reshape(buffer, tf.shape(old_buffer))
    index = tf.reshape(index, tf.shape(old_index))

    new_stack = (buffer, index)

    return new_stack

stack = new_stack((3,3), True)
original_stack = stack

element1 = tf.Variable([0,1,0], dtype=tf.float32)
element2 = tf.Variable([0.5,0.5,0], dtype=tf.float32)
element3 = tf.Variable([0,0,1], dtype=tf.float32)
element4 = tf.Variable([0,1,0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    stack = safe_push(stack, element1, is_phi)
    stack = safe_push(stack, element2, is_phi)
    stack = safe_push(stack, element3, is_phi)
    stack = safe_push(stack, element4, is_phi)
    
tf.print(stack[0])
tf.print(tf.round(stack[0]))
tf.print(stack[1])
tf.print(tf.round(stack[1]))
tf.print(tape.gradient(stack[0], element3))
tf.print(tape.gradient(stack, original_stack))

_ _ _ 
S _ _ 
S _ _ 
--------------------------------------------------------------------------------
S _ _ 
S _ _ 
S _ _ 
--------------------------------------------------------------------------------
S _ _ 
S O O 
S O O 
--------------------------------------------------------------------------------
S O O 
S O S 
S O S 
--------------------------------------------------------------------------------
[[0 1 0]
 [0.0428932235 0.0428932235 0.707106769]
 [0 0.707106769 0.0857864469]]
[[0 1 0]
 [0 0 1]
 [0 1 0]]
[0.707106769 0.292893231 0]
[1 0 0]
[0.0606601834 0.792893231 0.792893231]
([[0 0 0]
 [0.207106784 0.207106784 0.207106784]
 [0.207106799 0.207106799 0.207106799]], [2.87867975 1.53553391 1.76776707])


In [5]:
@tf.function
def pop_and_purge(stack, phi):
    # Pop the top element, then push phi and pop it again so that the
    # vacated slot holds phi instead of the stale token.
    stack, element = stack_pop(stack)
    stack = stack_push(stack, phi)
    stack, _ = stack_pop(stack)
    
    return stack, element

stack = new_stack_from_buffer(tf.ones((3,3), dtype=tf.float32))
phi = tf.one_hot(0, 3, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    stack, element = pop_and_purge(stack, phi)
    
tf.print(stack)
tf.print(element)

([[1 1 1]
 [1 1 1]
 [1 0 0]], [0 0 1])
[1 1 1]


In [6]:
stack = new_stack((3,3), True)

element1 = tf.Variable([0,1,0], dtype=tf.float32)
element2 = tf.Variable([1,0,0], dtype=tf.float32)
element3 = tf.Variable([0,0,1], dtype=tf.float32)
element4 = tf.Variable([0,1,0], dtype=tf.float32)

original_stack = stack

with tf.GradientTape(persistent=True) as tape:
    stack = safe_push(stack, element1, is_phi)
    stack = safe_push(stack, element2, is_phi)
    stack = safe_push(stack, element3, is_phi)
    stack = safe_push(stack, element4, is_phi)
    
tf.print(stack[0])
tf.print(stack[1])
tf.print(tape.gradient(stack[0], element3))
tf.print(tape.gradient(stack, original_stack))

[[0 1 0]
 [0 0 1]
 [0 1 0]]
[1 0 0]
[-1 1 1]
([[0 0 0]
 [0 0 0]
 [0 0 0]], [4 1 1])


In [7]:
from library.array_ops import tensor_lookup_2d

In [8]:
TOKEN_DIM = 6
PRODUCTION_DIM = 4
STACK_SIZE = 10
PHI = np.eye(TOKEN_DIM)[0]
S = np.eye(TOKEN_DIM)[1]
O = np.eye(TOKEN_DIM)[2]
T = np.eye(TOKEN_DIM)[3]
X = np.eye(TOKEN_DIM)[4]
PLUS = np.eye(TOKEN_DIM)[5]

E = [PHI, PHI, PHI]

G_s = tf.constant([
    [E, E, E, E, E, E],
    [E, E, E, E, E, E],
    [E, [S, O, T], E, E, E, E],
    [E, [T, PHI, PHI], E, E, E, E],
], dtype=tf.float32)
G_o = tf.constant([
    [E, E, E, [X, PHI, PHI], E, E],
    [E, E, [PLUS, PHI, PHI], E, E, E],
    [E, E, E, E, E, E],
    [E, E, E, E, E, E],
], dtype=tf.float32)
grammar = (G_s, G_o)

In [9]:
def tokens_pretty_print(tokens):
    tokens = tf.argmax(tokens, axis=1)
    lookup = ['_', 'S', 'O', 'T', 'x', '+']
    
    result = ''
    
    for token in tokens:
        result += f'{lookup[token]} '
        
    return result

tokens = tf.transpose(tf.one_hot([0,1,2,3,4,5], TOKEN_DIM, dtype=tf.float32))
tokens_pretty_print(tokens)

'_ S O T x + '

In [30]:
@tf.function
def production_step(grammar, production, stack, output, phi, is_phi_fn):
    tf.debugging.assert_rank(grammar[0], 4)
    tf.debugging.assert_rank(grammar[1], 4)
    tf.debugging.assert_rank(production, 1)
    tf.debugging.assert_rank(stack[0], 2)
    tf.debugging.assert_rank(output[0], 2)
    
    G_s, G_o = grammar
    
    # Save the shapes
    stack_0_shape = tf.shape(stack[0])
    stack_1_shape = tf.shape(stack[1])
    output_0_shape = tf.shape(output[0])
    output_1_shape = tf.shape(output[1])
    
    # Get next token from stack
    stack, stack_top_token = pop_and_purge(stack, phi)

    # Push tokens back onto the stack
    tokens_to_push = tensor_lookup_2d(G_s, production, stack_top_token)
    for token in tf.reverse(tokens_to_push, axis=[0]):
        stack = safe_push(stack, token, is_phi_fn)
    
    # Push tokens to output
    tokens_to_push = tensor_lookup_2d(G_o, production, stack_top_token)
    for token in tokens_to_push:
        output = safe_push(output, token, is_phi_fn)
    
    return stack, output

stack = new_stack((STACK_SIZE, TOKEN_DIM))
output = new_stack((STACK_SIZE, TOKEN_DIM))

stack = safe_push(stack, tf.constant(S, dtype=tf.float32), is_phi)
production = tf.one_hot(2, PRODUCTION_DIM)
phi = tf.one_hot(0, TOKEN_DIM, dtype=tf.float32)

with tf.GradientTape(persistent = True) as tape:
    tape.watch(grammar)
    tape.watch(production)
    tape.watch(stack)
    tape.watch(output)
    
    new_s, new_o = production_step(grammar, production, stack, output, phi, is_phi)

tf.print(tokens_pretty_print(new_s[0]))
# tf.print(tape.gradient(new_o, output))
# tf.print(tape.gradient(new_s, stack))
# tf.print(tape.gradient(new_s[0], grammar[0]).shape)
# tf.print(tape.gradient(new_s[1], grammar[0]).shape)
# tf.print(tape.gradient(new_o[0], grammar[1]).shape)
# tf.print(tape.gradient(new_o[1], grammar[1]).shape)
tf.print(tape.gradient(new_s, production))

T O S _ _ _ _ _ _ _ 
[-0.0833333358 -0.0833333358 3 -0.0416666679]


In [42]:
stack_shape = (STACK_SIZE, TOKEN_DIM)

phi = tf.one_hot([0], TOKEN_DIM, dtype=tf.float32)
soft_phi = tf.nn.softmax(phi, axis=-1)
stack_buffer = tf.tile(soft_phi, (stack_shape[0], 1))
soft_stack = new_stack_from_buffer(stack_buffer)
output_buffer = tf.tile(soft_phi, (stack_shape[0], 1))
soft_output = new_stack_from_buffer(output_buffer)
soft_s = tf.nn.softmax(tf.constant(S, dtype=tf.float32))

soft_stack = safe_push(soft_stack, soft_s, is_phi)
soft_p = tf.nn.softmax(tf.one_hot(2, PRODUCTION_DIM))

tf.config.experimental_run_functions_eagerly(True)
with tf.GradientTape(persistent = True) as tape:
    tape.watch(grammar)
    # Watch the soft tensors that are actually passed to production_step;
    # the hard `production`, `stack` and `output` are not used in this cell.
    tape.watch(soft_p)
    tape.watch(soft_stack)
    tape.watch(soft_output)
    
    # Soften the grammar
    gs, go = grammar
    sgs, sgo = tf.nn.softmax(gs), tf.nn.softmax(go)
    soft_g = (sgs, sgo)

    new_s, new_o = production_step(soft_g, soft_p, soft_stack, soft_output, soft_phi[0], is_phi)
# tf.config.experimental_run_functions_eagerly(False)

tf.print(tokens_pretty_print(soft_stack[0]))
tf.print(tokens_pretty_print(new_s[0]))
# tf.print(tape.gradient(new_o, soft_output))
# tf.print(tape.gradient(new_s, soft_stack))
# tf.print(tape.gradient(new_s[0], grammar[0]).shape)
# tf.print(tape.gradient(new_s[1], grammar[0]).shape)
# tf.print(tape.gradient(new_o[0], grammar[1]).shape)
# tf.print(tape.gradient(new_o[1], grammar[1]).shape)
tf.print(tape.gradient(new_s, soft_p))

_ _ _ _ _ _ _ _ _ _ 
S _ _ _ _ _ _ _ _ _ 
S _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
--------------------------------------------------------------------------------
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ 
-------------------------------------------------------------------------

In [None]:
def dump_step_info(grammar, production, stack, output):
    gs, go = grammar
    top = stack_peek(stack)
    tf.print('p\t', tf.argmax(production))
    i = tf.argmax(production)
    j = tf.argmax(top)
    tf.print('G_s\t', tokens_pretty_print(gs[i][j]), (i,j))
    tf.print('G_o\t', tokens_pretty_print(go[i][j]), (i,j))
    tf.print('S_i+1\t', tokens_pretty_print(stack[0]), tf.argmax(stack[1]))
    tf.print('O_i+1\t', tokens_pretty_print(output[0]), tf.argmax(output[1]))
    tf.print('-'*80)

In [None]:
@tf.function
def generate(grammar, productions, stack_shape, S, phi, is_phi_fn, print_steps=False):
    # Reserve space for stack and output
    stack_buffer = tf.tile(phi, (stack_shape[0], 1))
    stack = new_stack_from_buffer(stack_buffer)
    output_buffer = tf.tile(phi, (stack_shape[0], 1))
    output = new_stack_from_buffer(output_buffer)
    
    # Push S to top of stack
    stack = safe_push(stack, S, is_phi_fn)
    
    productions = tf.unstack(productions)

    for production in productions:
        stack, output = production_step(grammar, production, stack, output, phi[0], is_phi_fn)
        
        if print_steps:
            dump_step_info(grammar, production, stack, output)
        
    return tf.reverse(output[0], axis=[0]), stack[0]

In [None]:
tf.config.experimental_run_functions_eagerly(True)
productions = tf.one_hot([2, 3, 0, 1, 0], PRODUCTION_DIM)

stack_shape = (STACK_SIZE, TOKEN_DIM)
d_S = tf.constant(S, dtype=tf.float32)
d_phi = tf.constant(tf.one_hot([0], TOKEN_DIM))

with tf.GradientTape(persistent = True) as tape:
    tape.watch(productions)
    output, final_stack = generate(grammar, productions, stack_shape, d_S, d_phi, is_phi, True)

tf.config.experimental_run_functions_eagerly(False)
tf.print('Final result:')
tf.print(tokens_pretty_print(output))
tf.print('-'*80)
tf.print('Final stack:')
tf.print(tokens_pretty_print(final_stack))
tf.print('-'*80)
tf.print(tape.gradient(output, productions))

In [None]:
def encode_to_tokens(s, token_dim, total_length):
    lookup = ['_', 'S', 'O', 'T', 'x', '+']
    arr = []
    for t in s.split(' '):
        arr.append(lookup.index(t))
    
    phi = lookup.index('_')
    arr = ([phi] * (total_length - len(arr))) + arr
        
    return tf.one_hot(arr, token_dim)

encode_to_tokens('x + x +', TOKEN_DIM, 5)

In [None]:
tf.config.experimental_run_functions_eagerly(True)
productions = tf.one_hot([2, 3, 0, 1, 0], PRODUCTION_DIM)

stack_shape = (STACK_SIZE, TOKEN_DIM)
d_S = tf.constant(S, dtype=tf.float32)
d_phi = tf.constant(tf.one_hot([0], TOKEN_DIM))
expected_output = encode_to_tokens('x + x +', TOKEN_DIM, STACK_SIZE)
zero_stack = tf.one_hot([0] * stack_shape[0], stack_shape[1], dtype=tf.float32)

with tf.GradientTape(persistent = True) as tape:
    tape.watch(productions)
    # Soften the grammar
    gs, go = grammar
    sgs, sgo = tf.nn.softmax(gs), tf.nn.softmax(go)
    soft_g = (sgs, sgo)

    # Soften the productions
    soft_p = tf.nn.softmax(productions,axis=-1)

    # Soften S
    soft_s = tf.nn.softmax(d_S)

    soft_phi = tf.nn.softmax(d_phi, axis=-1)
    
    output_, stack_ = generate(soft_g, soft_p, stack_shape, soft_s, soft_phi, is_phi, True)

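    # Compare against the target defined above (`expected_output`), not the
    # `output` left over from the previous cell.
    loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=expected_output, logits=output_))
    loss += tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=zero_stack, logits=stack_))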

tf.config.experimental_run_functions_eagerly(False)
tf.print('Final result:')
tf.print(tokens_pretty_print(output_))
tf.print('-'*80)
tf.print('Final stack:')
tf.print(tokens_pretty_print(stack_))
tf.print('-'*80)
tf.print(loss)
tf.print(tape.gradient(loss, productions))

In [None]:
opt = tf.keras.optimizers.Adam(1e-1)

@tf.function
def train_step(grammar, productions, stack_shape, S, is_phi_fn, output, print_steps=False):
    zero_stack = tf.one_hot([0] * stack_shape[0], stack_shape[1], dtype=tf.float32)
    phi = tf.one_hot([0], stack_shape[1], dtype=tf.float32)
    
    with tf.GradientTape() as tape:
        tape.watch(productions)
#         # Soften the grammar
#         gs, go = grammar
#         sgs, sgo = tf.nn.softmax(gs), tf.nn.softmax(go)
#         soft_g = (sgs, sgo)
        
#         # Soften the productions
#         soft_p = tf.nn.softmax(productions,axis=-1)
        
#         # Soften S
#         soft_s = tf.nn.softmax(S)
        
#         soft_phi = tf.nn.softmax(phi, axis=-1)
        
#         output_, stack_ = generate(soft_g, soft_p, stack_shape, soft_s, soft_phi, is_phi_fn, True)
        
#         loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(output, output_))
#         loss += tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(zero_stack, stack_))
        
        output_, stack_ = generate(grammar, productions, stack_shape, S, phi, is_phi_fn, print_steps)
        loss = tf.nn.l2_loss(output - output_) + tf.nn.l2_loss(zero_stack - stack_)
        
    grads = tape.gradient(loss, productions)
    opt.apply_gradients(zip([grads], [productions]))
    
    return loss, output_, stack_

MAX_PRODUCTIONS = 5
# productions = tf.Variable(tf.one_hot([0] * MAX_PRODUCTIONS, PRODUCTION_DIM), dtype=tf.float32)
productions = tf.Variable(tf.one_hot([2, 3, 0, 1, 0], PRODUCTION_DIM, dtype=tf.float32))
# productions = tf.Variable(tf.one_hot([2, 0, 0, 0, 0], PRODUCTION_DIM), dtype=tf.float32)
stack_shape = (STACK_SIZE, TOKEN_DIM)
d_S = tf.constant(S, dtype=tf.float32)
output = encode_to_tokens('x + x', TOKEN_DIM, STACK_SIZE)

tf.config.experimental_run_functions_eagerly(True)
loss, output_, stack_ = train_step(grammar, productions, stack_shape, d_S, is_phi, output, True)
tf.config.experimental_run_functions_eagerly(False)
tf.print(loss)
tf.print(tokens_pretty_print(output_), tokens_pretty_print(output))
tf.print(tokens_pretty_print(stack_))

In [None]:
# productions = tf.Variable(tf.one_hot([2, 3, 0, 1, 0], PRODUCTION_DIM), dtype=tf.float32)
productions = tf.Variable(tf.one_hot([2, 3, 0, 0, 0], PRODUCTION_DIM), dtype=tf.float32)
output = encode_to_tokens('x + x', TOKEN_DIM, STACK_SIZE)

for var in opt.variables():
    var.assign(tf.zeros_like(var))

for i in range(100):
    loss, output_, stack_ = train_step(grammar, productions, stack_shape, d_S, is_phi, output)
    if i % 10 == 0:
        p_output = tokens_pretty_print(output_)
        p_stack = tokens_pretty_print(stack_)
        
        tf.print(loss, p_output, p_stack, tf.argmax(productions, axis=-1))
