## Imports

In [1]:
import tensorflow as tf
import jax
import jax.numpy as jnp
from flax import linen as nn
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_and, hard_not, primitives
from tests import test_mnist
tf.config.experimental.set_visible_devices([], "GPU")

2022-12-07 08:56:49.730218: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-07 08:57:02.487205: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-07 08:57:02.487378: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
import ml_collections

In [3]:
from jax._src.util import safe_map

In [4]:
from jax import core

In [5]:
from jax.interpreters import xla

# Sandpit

In [6]:
def examine_jaxpr(closed_jaxpr):
   jaxpr = closed_jaxpr.jaxpr
   print("invars:", jaxpr.invars)
   print("outvars:", jaxpr.outvars)
   print("constvars:", jaxpr.constvars)
   for eqn in jaxpr.eqns:
     print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
   print()
   print("jaxpr:", jaxpr)

In [7]:
train_ds, test_ds = test_mnist.get_datasets()
train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1))
test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1))

In [19]:
def nln(type, x, width):
    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jnp.float32)(x) 
    #x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
    #x = primitives.nl_ravel(type)(x) 
    #x = harden_layer.harden_layer(type)(x) 
    #x = primitives.nl_reshape(type)((10, width))(x) 
    #x = primitives.nl_sum(type)(-1)(x) 
    return x

def batch_nln(type, x, width):
    return jax.vmap(lambda x: nln(type, x, width))(x)

In [20]:
width = 10
soft, hard, _ = neural_logic_net.net(lambda type, x: nln(type, x, width))

In [21]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
mock_input = harden.harden(jnp.ones([28 * 28]))
hard_weights = harden.hard_weights(soft.init(rng, mock_input))
hard_weights

FrozenDict({
    params: {
        HardOrLayer_0: {
            weights: DeviceArray([[ True, False, False, ...,  True,  True,  True],
                         [False, False, False, ...,  True, False, False],
                         [False,  True,  True, ...,  True,  True, False],
                         ...,
                         [False,  True, False, ...,  True, False, False],
                         [False,  True, False, ...,  True,  True, False],
                         [ True, False,  True, ..., False, False, False]], dtype=bool),
        },
    },
})

In [22]:

jaxpr = jax.make_jaxpr(lambda x: hard.apply(hard_weights, x))(harden.harden(test_ds['image'][0]))

In [23]:
examine_jaxpr(jaxpr)

invars: [b]
outvars: [d]
constvars: [a]
equation: [a, b] xla_call [c] {'device': None, 'backend': None, 'name': 'hard_or_include', 'donated_invars': (False, False), 'inline': False, 'keep_unused': False, 'call_jaxpr': { lambda ; a:bool[10,784] b:bool[784]. let
    c:bool[1,784] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 784)] b
    d:bool[10,784] = and c a
  in (d,) }}
equation: [c] reduce_or [d] {'axes': (1,)}

jaxpr: { lambda a:bool[10,784]; b:bool[784]. let
    c:bool[10,784] = xla_call[
      call_jaxpr={ lambda ; d:bool[10,784] e:bool[784]. let
          f:bool[1,784] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 784)
          ] e
          g:bool[10,784] = and f d
        in (g,) }
      name=hard_or_include
    ] a b
    h:bool[10] = reduce_or[axes=(1,)] c
  in (h,) }


In [36]:
import jax._src.lax_reference as lax_reference
import numpy

In [129]:
def symbolic_and(*args, **kwargs):
  if args[0].dtype == bool:
    return numpy.logical_and(*args, **kwargs)
  else:
    return {} # TODO

def symbolic_broadcast_in_dim(*args, **kwargs):
  return lax_reference.broadcast_in_dim(*args, **kwargs)

def symbolic_xor(x, y):
  return f"{x} ^ {y}"

def symbolic_not(x):
  return f"~{x}"

def symbolic_reduce_or(*args, **kwargs):
  return lax_reference.reduce(*args, init_value=False, dimensions=kwargs['axes'], computation=numpy.logical_or)

def symbolic_bind(prim, *args, **params):
  print("primitive: ", prim.name)
  symbolic_outvals = {
    'and': symbolic_and,
    'broadcast_in_dim': symbolic_broadcast_in_dim,
    'xor': symbolic_xor,
    'not': symbolic_not,
    'reshape': lax_reference.reshape,
    'reduce_or': symbolic_reduce_or,
  }[prim.name](*args, **params)
  return symbolic_outvals

def eval_jaxpr(symbolic, jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}
  symbolic_env = {}
  
  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def symbolic_read(var):
    return symbolic_env[var]

  def write(var, val):
    env[var] = val

  def symbolic_write(var, val):
    symbolic_env[var] = val

  # Bind args and consts to environment
  if not symbolic:
    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)
  safe_map(symbolic_write, jaxpr.invars, args)
  safe_map(symbolic_write, jaxpr.constvars, consts)

  def eval_jaxpr_impl(jaxpr):
    # Loop through equations and evaluate primitives using `bind`
    for eqn in jaxpr.eqns:
      # Read inputs to equation from environment
      if not symbolic:
        invals = safe_map(read, eqn.invars)  
      symbolic_invals = safe_map(symbolic_read, eqn.invars)
      # `bind` is how a primitive is called
      prim = eqn.primitive
      if type(prim) is jax.core.CallPrimitive:
        call_jaxpr = eqn.params['call_jaxpr']
        if not symbolic:
          safe_map(write, call_jaxpr.invars, map(read, eqn.invars))
        safe_map(symbolic_write, call_jaxpr.invars, map(symbolic_read, eqn.invars))
        eval_jaxpr_impl(call_jaxpr)
        if not symbolic:
          safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars))
        safe_map(symbolic_write, eqn.outvars, map(symbolic_read, call_jaxpr.outvars))
      else:
        if not symbolic:
          outvals = prim.bind(*invals, **eqn.params)
        symbolic_outvals = symbolic_bind(prim, *symbolic_invals, **eqn.params)
        if not symbolic:
          print(f"outvals: {type(outvals)}: {outvals.shape}: {outvals}")
        print(f"symbolic_outvals: {type(symbolic_outvals)}: {symbolic_outvals.shape}: {symbolic_outvals}")
        # Primitives may return multiple outputs or not
        if not prim.multiple_results: 
          if not symbolic:
            outvals = [outvals]
          symbolic_outvals = [symbolic_outvals]
        if not symbolic:
          assert numpy.array_equal(numpy.array(outvals), symbolic_outvals)
        # Write the results of the primitive into the environment
        if not symbolic:
          safe_map(write, eqn.outvars, outvals)
        safe_map(symbolic_write, eqn.outvars, symbolic_outvals)

  # Read the final result of the Jaxpr from the environment
  eval_jaxpr_impl(jaxpr)
  if not symbolic:
    val, symbolic_val = safe_map(read, jaxpr.outvars), safe_map(symbolic_read, jaxpr.outvars)
  else:
    val, symbolic_val = safe_map(symbolic_read, jaxpr.outvars)
  return val, symbolic_val

In [130]:
hard_mock_input = harden.harden(test_ds['image'][0])
hard_output = hard.apply(hard_weights, hard_mock_input)
print("hard_output shape:", hard_output.shape)
print("hard_output:", hard_output)
eval_hard_output, eval_symbolic_output = eval_jaxpr(False, jaxpr.jaxpr, jaxpr.literals, hard_mock_input)
print("eval_hard_output:", eval_hard_output)
print("eval_symbolic_output:", eval_symbolic_output)
assert numpy.array_equal(numpy.array(eval_hard_output), eval_symbolic_output)
print("SUCCESS: jax primitives and symbolic primitives are identical.")
# TODO: why do we need to add extra brackets?
standard_jax_output = [hard.apply(hard_weights, hard_mock_input)]
print("standard_jax_output", standard_jax_output)
print("eval_hard_output", numpy.array(eval_hard_output))
assert jax.numpy.array_equal(numpy.array(eval_hard_output), standard_jax_output)
print("SUCCESS: non-standard evaluation is identical to standard evaluation of jaxpr.")
symbolic_mock_input = numpy.array([f"{x}" for x in hard_mock_input])
_, eval_symbolic_output = eval_jaxpr(True, jaxpr.jaxpr, jaxpr.literals, symbolic_mock_input)
print("eval_symbolic_output (with symobls):", eval_symbolic_output)


hard_output shape: (10,)
hard_output: [ True  True  True  True  True  True  True  True  True  True]
primitive:  broadcast_in_dim
outvals: <class 'jaxlib.xla_extension.DeviceArray'>: (1, 784): [[False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False False False 

UFuncTypeError: ufunc 'logical_and' did not contain a loop with signature matching types (<class 'numpy.dtype[str_]'>, <class 'numpy.dtype[bool_]'>) -> None