## Imports

In [2]:
import tensorflow as tf
import jax
import jax.numpy as jnp
from flax import linen as nn
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_and, hard_not, primitives
from tests import test_mnist
tf.config.experimental.set_visible_devices([], "GPU")

2022-12-05 13:53:00.557773: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-05 13:53:14.166363: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-05 13:53:14.166550: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [3]:
import ml_collections

In [4]:
from jax._src.util import safe_map

In [5]:
from jax import core

In [6]:
from jax.interpreters import xla

# Sandpit

In [7]:
def examine_jaxpr(closed_jaxpr):
   jaxpr = closed_jaxpr.jaxpr
   print("invars:", jaxpr.invars)
   print("outvars:", jaxpr.outvars)
   print("constvars:", jaxpr.constvars)
   for eqn in jaxpr.eqns:
     print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
   print()
   print("jaxpr:", jaxpr)

In [8]:
train_ds, test_ds = test_mnist.get_datasets()
train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1))
test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1))

In [9]:
def nln(type, x, width):
    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jnp.float32)(x) 
    x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
    x = primitives.nl_ravel(type)(x) 
    #x = harden_layer.harden_layer(type)(x) 
    #x = primitives.nl_reshape(type)((10, width))(x) 
    #x = primitives.nl_sum(type)(-1)(x) 
    return x

def batch_nln(type, x, width):
    return jax.vmap(lambda x: nln(type, x, width))(x)

In [10]:
width = 10
soft, hard, _ = neural_logic_net.net(lambda type, x: nln(type, x, width))

In [11]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
mock_input = harden.harden(jnp.ones([28 * 28]))
hard_weights = harden.hard_weights(soft.init(rng, mock_input))
hard_weights

FrozenDict({
    params: {
        HardNotLayer_0: {
            weights: DeviceArray([[False,  True,  True,  True,  True,  True, False, False,
                          False, False],
                         [False,  True,  True, False, False, False,  True, False,
                           True, False],
                         [ True,  True, False,  True,  True,  True, False,  True,
                           True, False],
                         [False, False, False, False, False, False,  True,  True,
                          False,  True],
                         [ True,  True, False,  True, False, False, False,  True,
                          False, False],
                         [False, False,  True, False, False, False,  True, False,
                          False, False],
                         [ True,  True,  True,  True, False, False,  True,  True,
                          False, False],
                         [False, False,  True,  True,  True, False, False, Fa

In [15]:

jaxpr = jax.make_jaxpr(lambda x: hard.apply(hard_weights, x))(harden.harden(test_ds['image'][0]))

In [16]:
examine_jaxpr(jaxpr)

invars: [c]
outvars: [g]
constvars: [a, b]
equation: [a, c] xla_call [d] {'device': None, 'backend': None, 'name': 'hard_or_include', 'donated_invars': (False, False), 'inline': False, 'keep_unused': False, 'call_jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:bool[10,784][39m b[35m:bool[784][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:bool[1,784][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 784)] b
    d[35m:bool[10,784][39m = and c a
  [34m[22m[1min [39m[22m[22m(d,) }}
equation: [d] reduce_or [e] {'axes': (1,)}
equation: [b, e] xla_call [f] {'device': None, 'backend': None, 'name': 'hard_not', 'donated_invars': (False, False), 'inline': False, 'keep_unused': False, 'call_jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:bool[10,10][39m b[35m:bool[10][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:bool[1,10][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 10)] b
    d[35m:bool[10,10][39m = xor c a
    e[35m:bool[10,10][39

In [25]:
def symbolic_bind(prim, *args, **params):
  outvals = prim.bind(*args, **params)
  print("prim:", prim.name)
  return outvals

def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}
  
  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  def eval_jaxpr_impl(jaxpr):
    # Loop through equations and evaluate primitives using `bind`
    for eqn in jaxpr.eqns:
      # Read inputs to equation from environment
      invals = safe_map(read, eqn.invars)  
      # `bind` is how a primitive is called
      prim = eqn.primitive
      #print("prim:", prim)
      #print("type:", type(prim))
      if type(prim) is jax.core.CallPrimitive:
        #print("calling prim:", prim)
        call_jaxpr = eqn.params['call_jaxpr']
        safe_map(write, call_jaxpr.invars, map(read, eqn.invars))
        eval_jaxpr_impl(call_jaxpr)
        safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars))
      else:
        #print("binding prim:", prim)
        # outvals = prim.bind(*invals, **eqn.params)
        outvals = symbolic_bind(prim, *invals, **eqn.params)
        # Primitives may return multiple outputs or not
        if not prim.multiple_results: 
          outvals = [outvals]
        # Write the results of the primitive into the environment
        safe_map(write, eqn.outvars, outvals)

  # Read the final result of the Jaxpr from the environment
  eval_jaxpr_impl(jaxpr)
  return safe_map(read, jaxpr.outvars) 

In [26]:
hard_mock_input = harden.harden(test_ds['image'][0])
hard_output = hard.apply(hard_weights, hard_mock_input)
print("hard_output shape:", hard_output.shape)
print("hard_output:", hard_output)
eval_hard_output = eval_jaxpr(jaxpr.jaxpr, jaxpr.literals, hard_mock_input)
print("eval_hard_output shape:", eval_hard_output[0].shape)
print("eval_hard_output:", eval_hard_output[0])

hard_output shape: (100,)
hard_output: [False  True  True  True  True  True False False False False False  True
  True False False False  True False  True False  True  True False  True
  True  True False  True  True False False False False False False False
  True  True False  True  True  True False  True False False False  True
 False False False False  True False False False  True False False False
  True  True  True  True False False  True  True False False False False
  True  True  True False False False False  True  True  True  True False
  True  True False  True False False False  True False False False  True
 False False  True False]
eval_hard_output shape: (100,)
eval_hard_output: [False  True  True  True  True  True False False False False False  True
  True False False False  True False  True False  True  True False  True
  True  True False  True  True False False False False False False False
  True  True False  True  True  True False  True False False False  True
 False Fal