In [2]:
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

In [3]:
class Net(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = jnp.reshape(x, [-1, 784])
    x = nn.Dense(8)(x)
    self.sow('intermediates', 'x1', x)
    x = nn.relu(x)
    x = nn.Dense(8)(x)
    self.sow('intermediates', 'x2', x)
    x = nn.relu(x)
    x = nn.Dense(10)(x)
    return x

In [4]:
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [5]:
def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics

In [6]:
def create_train_state(rng, learning_rate):
    """Creates initial `TrainState`."""
    net = Net()
    params = net.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(
        apply_fn=net.apply, params=params, tx=tx)

In [7]:
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits = Net().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics

In [8]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state

def forward_batch(state, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  perms = jax.random.permutation(rng, train_ds_size)
  perm = perms[:batch_size]  # skip incomplete batch
  batch = {k: v[perm, ...] for k, v in train_ds.items()}
  _, state = Net().apply({'params': state.params}, batch['image'], mutable=['intermediates'])
  return state['intermediates']

In [9]:
@jax.jit
def eval_step(params, batch):
  logits = Net().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

In [10]:
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(1)
rng, init_rng = jax.random.split(rng)
learning_rate = 0.001
state = create_train_state(init_rng, learning_rate)
del init_rng  # Must not be used anymore.
num_epochs = 10
batch_size = 32
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))



train epoch: 1, loss: 0.5761, accuracy: 83.79
 test epoch: 1, loss: 0.33, accuracy: 90.57
train epoch: 2, loss: 0.3170, accuracy: 90.97
 test epoch: 2, loss: 0.29, accuracy: 91.65
train epoch: 3, loss: 0.2825, accuracy: 91.98
 test epoch: 3, loss: 0.28, accuracy: 92.14
train epoch: 4, loss: 0.2642, accuracy: 92.43
 test epoch: 4, loss: 0.26, accuracy: 92.64
train epoch: 5, loss: 0.2529, accuracy: 92.76
 test epoch: 5, loss: 0.26, accuracy: 92.80
train epoch: 6, loss: 0.2438, accuracy: 92.99
 test epoch: 6, loss: 0.25, accuracy: 92.90
train epoch: 7, loss: 0.2381, accuracy: 93.10
 test epoch: 7, loss: 0.25, accuracy: 93.09
train epoch: 8, loss: 0.2318, accuracy: 93.35
 test epoch: 8, loss: 0.24, accuracy: 93.11
train epoch: 9, loss: 0.2270, accuracy: 93.48
 test epoch: 9, loss: 0.25, accuracy: 93.09
train epoch: 10, loss: 0.2235, accuracy: 93.48
 test epoch: 10, loss: 0.24, accuracy: 92.99


In [11]:
rng, input_rng = jax.random.split(rng)
inters = forward_batch(state, 1000, rng)
x1 = np.sign(inters['x1'][0])
x2 = np.sign(inters['x2'][0])
x = np.concatenate([x1, x2], axis=1)
x = (x + 1) / 2

In [12]:
print(np.mean(x, axis=0))
print(x.shape[0])
print(np.mean(x[x[:, 1] == 0], axis=0))
print(x[x[:, 1] == 0].shape[0])
print(np.mean(x[np.logical_and(x[:, 1] == 0, x[:, 10] == 1)], axis=0))
print(x[np.logical_and(x[:, 1] == 0, x[:, 10] == 1)].shape[0])
print(np.mean(x[((x[:, 1] == 0) & (x[:, 10] == 1) & (x[:, 3] == 1))], axis=0))
print(x[((x[:, 1] == 0) & (x[:, 10] == 1) & (x[:, 3] == 1))].shape[0])
print(np.mean(x[((x[:, 1] == 0) & (x[:, 10] == 1) & (x[:, 3] == 1) & (x[:, 7] == 1))], axis=0))
print(x[((x[:, 1] == 0) & (x[:, 10] == 1) & (x[:, 3] == 1)) & (x[:, 7] == 1)].shape[0])

[0.454 0.971 0.988 0.857 0.677 0.645 0.812 0.764 0.512 0.808 0.953 0.715
 0.785 0.808 0.802 0.992]
1000
[0.13793103 0.         1.         1.         0.31034482 0.44827586
 1.         1.         0.06896552 0.6551724  1.         0.10344828
 1.         1.         1.         0.82758623]
29
[0.13793103 0.         1.         1.         0.31034482 0.44827586
 1.         1.         0.06896552 0.6551724  1.         0.10344828
 1.         1.         1.         0.82758623]
29
[0.13793103 0.         1.         1.         0.31034482 0.44827586
 1.         1.         0.06896552 0.6551724  1.         0.10344828
 1.         1.         1.         0.82758623]
29
[0.13793103 0.         1.         1.         0.31034482 0.44827586
 1.         1.         0.06896552 0.6551724  1.         0.10344828
 1.         1.         1.         0.82758623]
29


In [13]:
def id3(x):
    if x.shape[0] < 80:
        return x.shape[0]
    marginals = np.mean(x, axis=0)
    split_dim = np.argmin(np.abs(marginals - 0.5))
    p = x[x[:, split_dim] == 1].shape[0] / x.shape[0]
    return split_dim, p, id3(x[x[:, split_dim] == 0]), id3(x[x[:, split_dim] == 1])

def viz(t):
    dot = graphviz.Digraph(comment='Plinko Tree')
    def parse_tuple(t, bs='', ps=[]):
        if len(bs) > 0:
            dot.edge(bs[:-1], bs)
        dim, p, left, right = t
        if isinstance(left, tuple):
            l_number = parse_tuple(left, bs + '0', ps=ps+[1-p])
        else:
            dot.node(bs + '0', '{} {}'.format(left, bs + '0'))
            dot.edge(bs, bs + '0')
            l_number = left
        if isinstance(right, tuple):
            r_number = parse_tuple(right, bs + '1', ps=ps+[p])
        else:
            dot.node(bs + '1', '{} {}'.format(right, bs + '1'))
            dot.edge(bs, bs + '1')
            r_number = right
        est_n = (l_number + r_number)
        est_var = (p * (1 - p)) / est_n
        path_prob = np.prod(ps)
        path_var = np.prod(ps) - np.prod(np.square(ps))
        dot.node(bs, "Split at {}\n(n={}, p={:.2f}) Var {:.2e}\nPath P: {:.2e} Var: {:.2e}".format(dim, est_n, p, est_var, path_prob, path_var))
        return l_number + r_number
    parse_tuple(t)
    return dot

t = id3(x)
print(t)
dot = viz(t)
dot.render('doctest-output/round-table.gv', view=True)
dot.source

(8, 0.512, (4, 0.5676229508196722, (3, 0.6066350710900474, (7, 0.42168674698795183, 48, 35), (5, 0.5859375, 53, 75)), (5, 0.5487364620938628, (0, 0.496, 63, 62), (0, 0.3223684210526316, (12, 0.6407766990291263, 37, 66), 49))), (0, 0.548828125, (13, 0.4805194805194805, (7, 0.5083333333333333, 59, 61), (14, 0.4774774774774775, 58, 53)), (14, 0.6263345195729537, (5, 0.5619047619047619, 46, 59), (11, 0.6534090909090909, 61, (9, 0.7565217391304347, 28, (5, 0.7126436781609196, 25, 62))))))


'// Plinko Tree\ndigraph {\n\t"" -> 0\n\t0 -> 00\n\t00 -> 000\n\t0000 [label="48 0000"]\n\t000 -> 0000\n\t0001 [label="35 0001"]\n\t000 -> 0001\n\t000 [label="Split at 7\n(n=83, p=0.42) Var 2.94e-03\nPath P: 8.30e-02 Var: 7.61e-02"]\n\t00 -> 001\n\t0010 [label="53 0010"]\n\t001 -> 0010\n\t0011 [label="75 0011"]\n\t001 -> 0011\n\t001 [label="Split at 5\n(n=128, p=0.59) Var 1.90e-03\nPath P: 1.28e-01 Var: 1.12e-01"]\n\t00 [label="Split at 3\n(n=211, p=0.61) Var 1.13e-03\nPath P: 2.11e-01 Var: 1.66e-01"]\n\t0 -> 01\n\t01 -> 010\n\t0100 [label="63 0100"]\n\t010 -> 0100\n\t0101 [label="62 0101"]\n\t010 -> 0101\n\t010 [label="Split at 0\n(n=125, p=0.50) Var 2.00e-03\nPath P: 1.25e-01 Var: 1.09e-01"]\n\t01 -> 011\n\t011 -> 0110\n\t01100 [label="37 01100"]\n\t0110 -> 01100\n\t01101 [label="66 01101"]\n\t0110 -> 01101\n\t0110 [label="Split at 12\n(n=103, p=0.64) Var 2.23e-03\nPath P: 1.03e-01 Var: 9.24e-02"]\n\t0111 [label="49 0111"]\n\t011 -> 0111\n\t011 [label="Split at 0\n(n=152, p=0.32) Var