## Imports

In [1]:
import tensorflow as tf
import jax
import jax.numpy as jnp
from flax import linen as nn
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_and, hard_not, primitives, symbolic_primitives
from tests import test_mnist
tf.config.experimental.set_visible_devices([], "GPU")
import numpy

2023-01-05 11:24:33.805733: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-05 11:24:34.719410: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-05 11:24:34.719533: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
import ml_collections

In [3]:
from neurallogic import sym_gen

In [4]:
from jax.interpreters import xla

In [14]:
# clear the GPU memory
from numba import cuda
cuda.select_device(0)
cuda.close()

# Sandpit

In [5]:
def examine_jaxpr(closed_jaxpr):
   jaxpr = closed_jaxpr.jaxpr
   print("invars:", jaxpr.invars)
   print("outvars:", jaxpr.outvars)
   print("constvars:", jaxpr.constvars)
   for eqn in jaxpr.eqns:
     print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
   print()
   print("jaxpr:", jaxpr)

In [6]:
train_ds, test_ds = test_mnist.get_datasets()
train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1))
test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1))

In [7]:
def nln(type, x, width):
    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jnp.float32)(x) 
    #x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
    #x = primitives.nl_ravel(type)(x) 
    #x = harden_layer.harden_layer(type)(x) 
    #x = primitives.nl_reshape(type)((10, width))(x) 
    #x = primitives.nl_sum(type)(-1)(x) 
    return x

def batch_nln(type, x, width):
    return jax.vmap(lambda x: nln(type, x, width))(x)

In [8]:
width = 10
soft, hard, _ = neural_logic_net.net(lambda type, x: nln(type, x, width))

In [9]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
mock_input = harden.harden(jnp.ones([28 * 28]))
hard_weights = harden.hard_weights(soft.init(rng, mock_input))
hard_weights

FrozenDict({
    params: {
        HardOrLayer_0: {
            weights: DeviceArray([[ True, False, False, ...,  True,  True,  True],
                         [False, False, False, ...,  True, False, False],
                         [False,  True,  True, ...,  True,  True, False],
                         ...,
                         [False,  True, False, ...,  True, False, False],
                         [False,  True, False, ...,  True,  True, False],
                         [ True, False,  True, ..., False, False, False]], dtype=bool),
        },
    },
})

In [10]:

jaxpr = jax.make_jaxpr(lambda x: hard.apply(hard_weights, x))(harden.harden(test_ds['image'][0]))

In [11]:
examine_jaxpr(jaxpr)

invars: [b]
outvars: [d]
constvars: [a]
equation: [a, b] xla_call [c] {'device': None, 'backend': None, 'name': 'hard_or_include', 'donated_invars': (False, False), 'inline': False, 'keep_unused': False, 'call_jaxpr': { lambda ; a:bool[10,784] b:bool[784]. let
    c:bool[1,784] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 784)] b
    d:bool[10,784] = and c a
  in (d,) }}
equation: [c] reduce_or [d] {'axes': (1,)}

jaxpr: { lambda a:bool[10,784]; b:bool[784]. let
    c:bool[10,784] = xla_call[
      call_jaxpr={ lambda ; d:bool[10,784] e:bool[784]. let
          f:bool[1,784] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 784)
          ] e
          g:bool[10,784] = and f d
        in (g,) }
      name=hard_or_include
    ] a b
    h:bool[10] = reduce_or[axes=(1,)] c
  in (h,) }


In [12]:
hard_mock_input = harden.harden(test_ds['image'][0])
hard_output = hard.apply(hard_weights, hard_mock_input)
#print("hard_output shape:", hard_output.shape)
#print("hard_output:", hard_output)
eval_hard_output, eval_symbolic_output = sym_gen.eval_jaxpr(False, jaxpr.jaxpr, jaxpr.literals, hard_mock_input)
#print("eval_hard_output:", eval_hard_output)
#print("eval_symbolic_output:", eval_symbolic_output)
assert numpy.array_equal(numpy.array(eval_hard_output), eval_symbolic_output)
print("SUCCESS: jax primitives and symbolic primitives are identical.")
standard_jax_output = hard.apply(hard_weights, hard_mock_input)
#print("standard_jax_output", standard_jax_output)
#print("eval_hard_output", numpy.array(eval_hard_output))
assert jax.numpy.array_equal(numpy.array(eval_hard_output), standard_jax_output)
print("SUCCESS: non-standard evaluation is identical to standard evaluation of jaxpr.")
symbolic_mock_input = symbolic_primitives.to_boolean_string(hard_mock_input)
#print("symbolic_mock_input:", symbolic_mock_input)
#print("type of symbolic_mock_input = ", type(symbolic_mock_input))
#print("type of element = ", symbolic_mock_input.dtype)
symbolic_jaxpr_literals = symbolic_primitives.to_boolean_string(jaxpr.literals)
#print("jaxpr.literals = ", symbolic_jaxpr_literals)
#print("type of jaxpr.literals = ", type(symbolic_jaxpr_literals))
#print("type of element = ", symbolic_jaxpr_literals.dtype)
eval_symbolic_output = sym_gen.eval_jaxpr(True, jaxpr.jaxpr, symbolic_jaxpr_literals, symbolic_mock_input)
# assert the dimensions of eval_hard_output and eval_symbolic_output are the same
eval_hard_output = numpy.array(eval_hard_output)
print("eval_hard_output", eval_hard_output)
#print("type of eval_hard_output = ", type(eval_hard_output))
#print("eval_symbolic_output:", eval_symbolic_output)
#print("type of eval_symbolic_output = ", type(eval_symbolic_output))
#print("shape of eval_hard_output = ", eval_hard_output.shape)
#print("shape of eval_symbolic_output = ", eval_symbolic_output.shape)
assert numpy.array_equal(eval_hard_output.shape, eval_symbolic_output.shape)
print("SUCCESS: dimensions of non-standard evaluation and standard evaluation of jaxpr are identical.")
# assert the values of eval_hard_output and eval_symbolic_output are the same
reduced_eval_symbolic_output = symbolic_primitives.symbolic_eval(eval_symbolic_output)
print("reduced_eval_symbolic_output:", reduced_eval_symbolic_output)
assert numpy.array_equal(eval_hard_output, reduced_eval_symbolic_output)
print("SUCCESS: values of symbolic evaluation and standard evaluation of jaxpr are identical.")

SUCCESS: jax primitives and symbolic primitives are identical.
SUCCESS: non-standard evaluation is identical to standard evaluation of jaxpr.
eval_hard_output [ True  True  True  True  True  True  True  True  True  True]
SUCCESS: dimensions of non-standard evaluation and standard evaluation of jaxpr are identical.
reduced_eval_symbolic_output: [ True  True  True  True  True  True  True  True  True  True]
SUCCESS: values of symbolic evaluation and standard evaluation of jaxpr are identical.
