  https://openreview.net/pdf?id=SkZxCk-0Z


In [1]:
import parser

import tensorflow as tf
tf.enable_eager_execution()

In [3]:
def read_data(fname):
  """
  Reads an entailment data file.

  Every non-empty line must hold exactly six comma-separated fields. The
  first two are kept as strings, the third is converted to an int and the
  remaining three are ignored.

  Args:
    fname (str): path of the file to read

  Returns:
    list: one [a, b, e] entry per non-empty line, in file order

  Raises:
    ValueError: if a line does not have exactly six fields, or its third
      field is not an integer
  """
  new_data = []
  with open(fname, 'r') as f:
    for line in f:
      line = line.rstrip('\n')
      # Skip blank lines explicitly rather than dropping the final chunk:
      # a file without a trailing newline keeps its last record this way.
      if not line:
        continue
      a, b, e, _, _, _ = line.split(',')
      new_data.append([a, b, int(e)])
  return new_data

def batch_data(data, batch_size):
  """
  Yields consecutive full-size batches, transposed into columns.

  Rows left over after the last full batch are not yielded, so every
  batch has exactly batch_size entries.

  Args:
    data (list): rows of three items (a, b, e)
    batch_size (int): number of rows per batch, must be >= 1

  Yields:
    tuple: (A, B, E) where each element is a tuple of length batch_size
      holding the first, second and third item of the rows in the batch

  Raises:
    ValueError: if batch_size < 1 (raised when iteration starts)
  """
  if batch_size < 1:
    raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

  n_batches = len(data) // batch_size
  if n_batches == 0:
    return

  columns = list(zip(*data))  # transpose: rows -> (all a, all b, all e)
  for i in range(n_batches):
    lo = i * batch_size
    hi = lo + batch_size
    yield columns[0][lo:hi], columns[1][lo:hi], columns[2][lo:hi]

In [5]:
# Load the training file and take the first batch of 10 rows to inspect it.
data = read_data('../logical_entailment_dataset/data/train.txt')
A, B, E = next(batch_data(data, 10))
# Bare expression: shows the batch when run as a notebook cell.
A, B, E

(('((m>m)&(((m>m)>(m>m))&((m>m)>(m|m))))',
  '~((((m&m)|(m&m))&(m&m)))',
  '((m>m)&(((m>m)>(m>m))&((m>m)>(m|m))))',
  '~((((m&m)|(m&m))&(m&m)))',
  '~((~(m)&~((m&m))))',
  '~((m|m))',
  '~((~(m)&~((m&m))))',
  '~((m|m))',
  '(m|(m|m))',
  '~((m|(m&(m|m))))'),
 ('~(~((((m|m)&~(m))|m)))',
  '~(((((m>m)&m)&m)|(m|m)))',
  '~(((((m>m)&m)&m)|(m|m)))',
  '~(~((((m|m)&~(m))|m)))',
  '((((m|m)&m)|m)|(~(~((((m|m)>m)&m)))|m))',
  '((m|((m>m)>m))>((m&m)&~((m&(m&(m>m))))))',
  '((m|((m>m)>m))>((m&m)&~((m&(m&(m>m))))))',
  '((((m|m)&m)|m)|(~(~((((m|m)>m)&m)))|m))',
  '((m>m)&m)',
  '~((m&m))'),
 (1, 1, 0, 0, 1, 1, 0, 0, 1, 1))

In [5]:
class Parser():
  """
  Wraps parser.Parser and turns the parsed operators into integer ids.

  The id of an operator is its position in language.symbols.
  """

  def __init__(self, language):
    """
    Args:
      language: object exposing a `symbols` iterable; it is also handed
        to parser.Parser unchanged
    """
    self.language = language
    self.parser = parser.Parser(language)

    symbol_to_id = {}
    for index, symbol in enumerate(language.symbols):
      symbol_to_id[symbol] = index
    self.vocabulary = symbol_to_id

  def __call__(self, s):
    """
    Parses a formula.

    Args:
      s (str): the formula to parse

    Returns:
      tuple: (ops, inputs) where ops is the list of operator ids and
        inputs is the `inputs` attribute of the parse result, untouched
    """
    parsed = self.parser.parse(s)
    op_ids = []
    for raw_op in parsed.ops:
      # ops come back as bytes; the vocabulary is keyed by str
      op_ids.append(self.vocabulary[raw_op.decode("utf-8")])
    return op_ids, parsed.inputs

In [6]:
# Parser over the language returned by parser.propositional_language().
prop_parser = Parser(parser.propositional_language())

# Parse one formula into (operator ids, child offsets); the bare `tree`
# expression shows the result when run as a notebook cell.
tree = prop_parser('((m>m)&(((m>m)>(m>m))&((m>m)>(m|m))))')
tree

([18, 18, 5, 18, 18, 5, 18, 18, 5, 5, 18, 18, 5, 18, 18, 3, 5, 2, 2],
 [[],
  [],
  [-2, -1],
  [],
  [],
  [-2, -1],
  [],
  [],
  [-2, -1],
  [-4, -1],
  [],
  [],
  [-2, -1],
  [],
  [],
  [-2, -1],
  [-4, -1],
  [-8, -1],
  [-16, -1]])

In [7]:
# Hyperparameters shared by the cells below.
d_world = 30  # width of a single world vector
n_worlds=24  # number of world vectors held by PossibleWorlds
d_embed = 50  # width of the embedding Sat3Cell produces for every node
batch_size = 10  # rows per training batch
n_ops = len(prop_parser.language.symbols)  # one parameter slice per language symbol

In [8]:
# Bare expression: shows the number of language symbols in the notebook.
n_ops

32

In [None]:
class Sat3Cell():
  """
  Real valued evaluation of satisfiability.

  Given a real valued truth assignment (the world you are in), embeds one
  node of a formula: leaves are embedded from the world, inner nodes from
  the embeddings of their children. Every operator has its own slice of
  parameters.
  """

  def __init__(self, d_world, n_ops, d_embed):
    """
    Args:
      d_world (int): width of a world vector
      n_ops (int): number of distinct operators
      d_embed (int): width of the node embeddings
    """
    hidden = d_embed
    # per-operator projection of the world, used for nullary ops
    self.op_embeddings = tf.get_variable(
        name='operation_embeddings',
        shape=(n_ops, d_world, hidden),
        dtype=tf.float32)
    # per-operator affine map over the concatenated children
    self.W4 = tf.get_variable(
        name='W4',
        shape=(n_ops, 2*d_embed, hidden),
        dtype=tf.float32)
    self.b4 = tf.get_variable(
        name='b4',
        shape=(n_ops, hidden),
        dtype=tf.float32)

  def __call__(self, w, op, l=None, r=None, scope=None):
    """
    Embeds a single node.

    Args:
      w (tf.tensor): [1, d_world]
      op (int): index of the operator, selects the parameter slice
      l (tf.tensor): embedding of the left child, None for nullary ops
      r (tf.tensor): embedding of the right child, None for nullary and
        unary ops
      scope (str): variable scope name, defaults to the class name

    Returns:
      the node embedding, l2-normalized along axis 1
    """
    # TODO change so __call__ can receive a batch, then bundle all
    # embed/matmul calls. op will be of varying length, so the per-op
    # parameters would need to be stacked.
    scope_name = scope or type(self).__name__
    with tf.variable_scope(scope_name):
      is_nullary = l is None and r is None
      if is_nullary:
        # leaves are read straight off the world
        h = tf.matmul(w, self.op_embeddings[op])
      else:
        if r is None:
          # unary op: stand in zeros for the missing right child
          r = tf.zeros_like(l)
        children = tf.concat([l, r], axis=1)
        h = tf.matmul(children, self.W4[op]) + self.b4[op]

      return tf.nn.l2_normalize(h, axis=1)

In [11]:
# Smoke test of the cell on random inputs: one world and two child embeddings.
sat3 = Sat3Cell(d_world, n_ops, d_embed)
w = tf.random_normal([1, d_world])
l = tf.random_normal((1, d_embed))
r = tf.random_normal((1, d_embed))

# nullary call: only the world and an operator index
h = sat3(w, 0)
print(h.shape)

# binary call: both children supplied
h = sat3(w, 12, l, r)
print(h.shape)

(1, 50)
(1, 50)


In [None]:
class TreeNN():
  """
  Encodes a formula by applying a cell bottom-up over its parse tree.

  Args:
    cell (callable): called as cell(w, op, l, r), returns the result for
      one node
    parser (callable): maps a string to a tree (ops, args)
  """

  def __init__(self, cell, parser):
    self.cell = cell
    self.parser = parser
    # !? what about learning to parse the inputs into a tree!?

  def __call__(self, w, s):
    """
    Encodes a single formula in a single world.

    Args:
      w: a world
      s: a string

    Returns: whatever the cell returns for the root node
    """
    # NOTE Can only handle a single element of a batch at a time
    tree = self.parser(s)
    return self.apply(tree, [], w=w)

  def apply(self, tree, results, i=0, w=None):
    """
    Applies self.cell to every node of the tree, children before parents.

    Args:
      tree (tuple): (ops, args)
        ops (list): nodes in depth first order
        args (list): the children of ops[i:], one entry per node, given
          as offsets relative to the node's own position
      results (list): results of the nodes before position i; every node
        evaluated here except the last one is appended to it
      i (int): position in ops of the first node described by args
      w: the world handed to every cell call; required

    Returns:
      the result of the cell for the last node in args

    Raises:
      ValueError: if w is missing or args is empty
    """
    # w is passed in explicitly so that the encoding depends on the world
    # given by the caller and never on a module-level name.
    if w is None:
      raise ValueError('a world `w` must be passed to apply')

    ops, args = tree
    if not args:
      raise ValueError('cannot apply the cell to an empty tree')

    # A loop instead of one recursive call per node: long formulas do not
    # hit the recursion limit and args is not re-sliced at every step.
    last = len(args) - 1
    for offset, children in enumerate(args):
      node = i + offset

      # if the current node has children, fetch them from results
      l = None
      r = None
      if len(children) == 1:
        l = results[node + children[0]]
      elif len(children) == 2:
        l = results[node + children[0]]
        r = results[node + children[1]]

      h = self.cell(w, ops[node], l, r)
      if offset == last:
        return h
      results.append(h)

In [13]:
# Smoke test: encode one formula with a freshly initialised cell. The bare
# call expression shows the resulting tensor when run as a notebook cell.
treenn = TreeNN(Sat3Cell(d_world, n_ops, d_embed), prop_parser)
treenn(w, '((m>m)&(((m>m)>(m>m))&((m>m)>(m|m))))')

<tf.Tensor: id=395, shape=(1, 50), dtype=float32, numpy=
array([[-0.03000259,  0.12397584,  0.07279226, -0.2090192 ,  0.12104005,
        -0.18710051, -0.10114643,  0.08847106,  0.23247388,  0.01880257,
         0.16483982, -0.03280551, -0.14048365, -0.00199388,  0.07384284,
        -0.06964146,  0.0094379 , -0.1465935 ,  0.21219827, -0.13372083,
        -0.12857585,  0.06898802,  0.21416542,  0.08968278, -0.18789943,
        -0.10467124, -0.02284116, -0.1230051 ,  0.0354168 , -0.05308964,
         0.16737577, -0.01140991,  0.19947988, -0.04662628, -0.22023317,
        -0.15620331, -0.20942825,  0.24156609, -0.17697154,  0.04958358,
        -0.24008581, -0.09461536, -0.11807208, -0.199788  , -0.04693555,
         0.01570038, -0.21335697, -0.17780055, -0.20999888, -0.07432975]],
      dtype=float32)>

In [None]:
class PossibleWorlds():
  """
  A NN designed specifically for predicting entailment.

  Holds a set of world vectors, encodes both formulas of a pair in every
  world and multiplies the per-world scores together.
  """

  def __init__(self, encoder, num_units, n_worlds, d_world):
    """
    Args:
      encoder (callable): called as encoder(world, formula)
      num_units (int): output width of the scoring layer
      n_worlds (int): number of world vectors
      d_world (int): width of one world vector
    """
    self.encoder = encoder
    self.n_worlds = n_worlds
    self.worlds = tf.get_variable(
        name='worlds', shape=(n_worlds, d_world), dtype=tf.float32)

    self.dense = tf.keras.layers.Dense(num_units)

  def inner(self, a, b):
    """
    Scores a single pair of formulas across all worlds.

    Args:
      a: first formula, passed to the encoder as is
      b: second formula, passed to the encoder as is

    Returns:
      the product over all worlds of the dense layer's output
    """
    score = tf.constant(1.0, dtype=tf.float32)
    for index in range(self.n_worlds):
      # slice rather than index so the world keeps its leading axis
      world = self.worlds[index:index+1]
      encoded_a = self.encoder(world, a)
      encoded_b = self.encoder(world, b)
      pair = tf.concat([encoded_a, encoded_b], axis=1)
      # in the paper this isnt actually a dense layer....
      score = score * self.dense(pair)

    return score

  def __call__(self, A, B):
    """
    Scores a batch, one pair at a time.

    Args:
      A: iterable of first formulas
      B: iterable of second formulas, paired with A by position

    Returns:
      the per-pair scores concatenated along axis 0
    """
    scores = []
    for a, b in zip(A, B):
      scores.append(self.inner(a, b))
    return tf.concat(scores, axis=0)

In [None]:
possibleworldsnet = PossibleWorlds(
    encoder=TreeNN(Sat3Cell(d_world, n_ops, d_embed), prop_parser),
    num_units=1,
    n_worlds=n_worlds,
    d_world=d_world
)

# A Keras layer only creates its weights when it is built, which normally
# happens on the first call. Build it now so that `dense.variables` below
# is not an empty list. The layer's input is the concatenation of two
# encodings of width d_embed.
possibleworldsnet.dense.build((None, 2 * d_embed))

# Everything the optimizer updates: the scoring layer, the worlds and the
# per-operator parameters of the cell.
variables = (possibleworldsnet.dense.variables +
             [possibleworldsnet.worlds,
              possibleworldsnet.encoder.cell.b4,
              possibleworldsnet.encoder.cell.op_embeddings,
              possibleworldsnet.encoder.cell.W4])

opt = tf.train.AdamOptimizer()

In [None]:
def gradients(A, B, E):
  """
  Computes the gradients of the entailment loss for one batch.

  Also prints the current global step and the mean loss on a single,
  continuously overwritten line.

  Args:
    A: batch of first formulas, passed to possibleworldsnet
    B: batch of second formulas, passed to possibleworldsnet
    E: sequence of labels, one per (a, b) pair

  Returns:
    list: the gradient of the loss with respect to each entry of the
      module-level `variables`, in the same order
  """
  # Size the labels from the batch itself rather than the global
  # batch_size, so a batch of any length can be scored.
  labels = tf.constant(E, dtype=tf.float32, shape=(len(E), 1))

  with tf.GradientTape() as tape:
    y = possibleworldsnet(A, B)
    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels,
                                           logits=y)

  # Logging needs no gradients, so it happens outside the tape.
  step = tf.train.get_or_create_global_step().numpy()
  print('\rstep: {} loss {}'.format(step, tf.reduce_mean(loss)), end='', flush=True)

  return tape.gradient(loss, variables)

In [None]:
# One pass over the training data: one optimizer step per batch.
for A, B, E in batch_data(data, batch_size):
  # pair every gradient with the variable it belongs to
  gnvs = zip(gradients(A, B, E), variables)
  # passing global_step makes the optimizer increment it on every update
  opt.apply_gradients(gnvs, global_step=tf.train.get_or_create_global_step())

step: 67 loss 0.6931471824645996

Argh, it is so slow... A problem for another day.