In [6]:
import jax
import jax.numpy as jnp

In [7]:

#import & utils

import jax
import jax.numpy as jp
import optax

input_n, output_n = 8, 8
case_n = 1<<input_n
arity, layer_width, layer_n = 4, 64, 5
layer_sizes = [(input_n, 1)] + [(layer_width, arity)]*(layer_n-1) + [
    (layer_width//2, arity//2), (output_n, 1)]

def gen_wires(key, in_n, out_n, arity, group_size):
    edge_n = out_n*arity//group_size
    n = max(in_n, edge_n)
    return  jax.random.permutation(key, n)[:edge_n].reshape(arity,-1) % in_n

def make_nops(gate_n, arity, group_size, nop_scale=3.0):
    I = jp.arange(1<<arity)
    bits = (I>>I[:arity,None])&1
    luts = bits[jp.arange(gate_n)%arity]
    logits = (2.0*luts-1.0) * nop_scale
    return logits.reshape(gate_n//group_size, group_size, -1)

@jax.jit
def run_layer(lut, inputs):
    # lut:[group_n, group_size, 1<<arity], [arity, ... , group_n]
    for x in inputs:
        x = x[...,None,None]
        lut = (1.0-x)*lut[...,::2] + x*lut[...,1::2]
    # [..., group_n, group_size, 1]
    return lut.reshape(*lut.shape[:-3]+(-1,))

def run_circuit(logits, wires, x, hard=False):
    acts = [x]
    for ws, lgt in zip(wires, logits):
        luts = jax.nn.sigmoid(lgt)
        if hard:
            luts = jp.round(luts)
        x = run_layer(luts, [x[...,w] for w in ws])
        acts.append(x)
    return acts

def gen_circuit(key, layer_sizes, arity=4):
  in_n = layer_sizes[0][0]
  all_wires, all_logits = [], []
  for out_n, group_size in layer_sizes:
      wires = gen_wires(key, in_n, out_n, arity, group_size)
      logits = make_nops(out_n, arity, group_size)
      _, key = jax.random.split(key)
      in_n = out_n
      all_wires.append(wires)
      all_logits.append(logits)
  return all_wires, all_logits

def unpack(x, bit_n=8):
    return jp.float32((x[...,None] >> np.r_[:bit_n])&1)

def res2loss(res):
    return jp.square(jp.square(res)).sum()

def loss_f(logits, wires, x, y0):
    act = run_circuit(logits, wires, x)
    y = act[-1]
    res = y-y0
    return res2loss(res), dict(act=act)
grad_loss_f = jax.jit(jax.value_and_grad(loss_f, has_aux=True))


key = jax.random.PRNGKey(42)
wires, logits0 = gen_circuit(key, layer_sizes)

opt = optax.adamw(1.0, 0.8, 0.8, weight_decay=0.1)

TrainState = namedtuple('TrainState', 'params opt_state')
state = TrainState(params=logits0, opt_state=opt.init(logits0))


x = jp.arange(case_n)
y0 = (x&0xf) * (x>>4)
x, y0 = unpack(x), unpack(y0)

def train_step(state):
    logits, opt_state = state
    (loss, aux), grad = grad_loss_f(logits, wires, x, y0)
    upd, opt_state = opt.update(grad, opt_state, logits)
    logits = optax.apply_updates(logits, upd)
    return loss, TrainState(logits, opt_state)

loss_log = []

imshow(zoom(x.T, 4))
imshow(zoom(y0.T, 4))

for i in range(100):
  loss, state = train_step(state)
  loss_log.append(loss)
  if i%10 == 0:
    print(i, loss.item())

0 6.036703586578369
10 5.957761287689209
20 5.989912033081055
30 5.984184265136719
40 5.9712748527526855
50 5.9743523597717285
60 6.052950382232666
70 5.951821327209473
80 5.833841800689697
90 5.826268196105957

Colab paid products - Cancel contracts here


4