In [2]:
import jax
import jax.numpy as jnp
from jax import random
import hard_not
from flax import linen as nn

In [8]:
def hard_not(w, x):
    return 1.0 - w + x * (2.0 * w - 1.0)

[hard_not(0.0, 1.0), hard_not(1.0, 0.0), hard_not(0.0, 0.0), hard_not(1.0, 1.0)]

[0.0, 0.0, 1.0, 1.0]

In [11]:
hard_not_neuron = jax.vmap(hard_not, 0, 0)

hard_not_neuron(jnp.array([0.0, 1.0, 0.0, 1.0]), jnp.array([1.0, 0.0, 0.0, 1.0]))

DeviceArray([0., 0., 1., 1.], dtype=float32)

In [25]:
hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)

hard_not_layer(jnp.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]), jnp.array([0.0, 1.0, 0.0, 1.0]))

DeviceArray([[0., 0., 1., 1.],
             [0., 0., 1., 1.]], dtype=float32)

In [7]:
from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
                    Union)

In [54]:
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any

class HardNeuralNOT(nn.Module):
    """A Not layer than transforms its inputs along the last dimension.

    Attributes:
        kernel_init: initializer function for the weight matrix.
        dtype: the dtype of the computation (default: infer from input and params).
        param_dtype: the dtype passed to parameter initializers (default: float32).
    """
    layer_size: int
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, x: Array) -> Array:
        kernel = self.param('kernel',
                        self.kernel_init,
                        (self.layer_size, jnp.shape(x)[-1]),
                        self.param_dtype)
        threaded_differentiable_hard_not = jax.vmap(hard_not.differentiable_hard_not, in_axes=1, out_axes=0)
        return threaded_differentiable_hard_not(x, kernel)
    

In [55]:
hnn = HardNeuralNOT(layer_size=4)

In [56]:
x = jax.numpy.ones((1, 2))
print(x)

[[1. 1.]]


In [59]:
weights = hnn.init(random.PRNGKey(0), x)
print(weights)

FrozenDict({
    params: {
        kernel: DeviceArray([[-0.34059498,  0.59856474],
                     [-0.07577372, -0.91646206],
                     [ 0.15280248,  0.2605774 ],
                     [ 0.71123385, -0.7111998 ]], dtype=float32),
    },
})


In [60]:
y = hnn.apply(weights, x)
print(y)

[DeviceArray([-0.340595  ,  0.59856474], dtype=float32)]
