In [5]:
import jax.numpy as jnp
from jax import grad, jit, vmap, jacfwd
from jax import random

# Fixed seed so every Restart & Run All draws the same random parameters.
key = random.PRNGKey(seed=0)

In [21]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated through tanh.

    Relies on the identity sigmoid(x) = (tanh(x / 2) + 1) / 2; tanh saturates
    at +/-1 instead of overflowing, so large |x| stays finite.
    """
    shifted = jnp.tanh(x / 2) + 1
    return 0.5 * shifted

def critic(W, b, inputs):
    """Probability that each input row carries a "true" label.

    Applies a linear score ``jnp.dot(inputs, W) + b`` and squashes it with
    ``sigmoid``. With a weight vector W (e.g. shape (3,)) a batch of rows
    yields one probability per row; a single row yields a scalar.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

def generator(W, b, noise):
    """Map a batch of noise vectors to fake samples in (-1, 1).

    A single linear layer ``jnp.dot(noise, W) + b`` followed by tanh.
    """
    pre_activation = jnp.dot(noise, W) + b
    return jnp.tanh(pre_activation)

# Build a toy dataset. The same four 3-feature rows are used both as the
# generator's noise batch and as the critic's "real" inputs.
_toy_rows = [
    [0.52, 0.12, 0.77],
    [0.88, -0.08, 0.15],
    [0.52, 0.06, -0.30],
    [0.74, -0.49, 0.39],
]
noise = jnp.array(_toy_rows)
inputs = jnp.array(_toy_rows)
# One boolean label per row of `inputs`.
targets = jnp.array([True, True, False, True])
del _toy_rows

def loss(W, b, data, labels=None):
    """Training loss: negative log-likelihood of the labels under the critic.

    Args:
        W: critic weights, passed straight to ``critic``.
        b: critic bias, passed straight to ``critic``.
        data: batch of examples to score, one row per example.
        labels: boolean (or 0/1) label per row of ``data``. When omitted,
            the module-level ``targets`` array is used (looked up at call
            time), which preserves the original behaviour.

    Returns:
        Scalar ``-sum(log p(label_i))`` over the batch.
    """
    if labels is None:
        labels = targets
    preds = critic(W, b, data)
    # Probability the critic assigns to each example's *observed* label:
    # preds where the label is true, (1 - preds) where it is false.
    label_probs = preds * labels + (1 - preds) * (1 - labels)
    return -jnp.sum(jnp.log(label_probs))

def genLoss(Wg, bg, Wc, bc, noise):
    """Critic negative log-likelihood evaluated on generator samples.

    The generator (Wg, bg) turns ``noise`` into fake data, which ``loss``
    then scores with the critic parameters (Wc, bc).

    NOTE(review): ``loss`` compares predictions against the module-level
    ``targets`` labels, so the fakes are scored against the real-data
    labels. A conventional GAN generator objective would score fakes as
    all-"real" (True). Confirm which is intended.
    """
    return loss(Wc, bc, generator(Wg, bg, noise))

# Initialize random model coefficients. One split yields a fresh carry key
# plus one independent subkey per parameter tensor.
key, W_key_c, b_key_c, W_key_g, b_key_g = random.split(key, 5)

# Generator: 3-dim noise -> 3-dim fake sample; a single scalar bias is
# broadcast over every output unit.
Wg = random.normal(W_key_g, shape=(3, 3))
bg = random.normal(b_key_g, shape=())

# Critic: 3-feature row -> one logit.
Wc = random.normal(W_key_c, shape=(3,))
bc = random.normal(b_key_c, shape=())

In [22]:
# Differentiate `loss` with respect to its first TWO positional arguments
# (argnums=(0, 1): the critic weights W and bias b), evaluated on the real
# `inputs` batch. `grad` then returns one gradient per selected argument.
W_grad, b_grad = grad(loss, argnums=(0, 1))(Wc, bc, inputs)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-1.1221914   0.24741234 -0.8037301 ]
b_grad -1.4731382


In [34]:
def i_grad(ins):
    """Gradient of the critic's output with respect to one input row.

    Uses the module-level critic parameters ``Wc`` and ``bc``, read from
    globals when this runs (or is traced). ``ins`` is presumably a single
    example (e.g. shape (3,)) so that ``critic`` returns the scalar that
    ``grad`` requires.
    """
    d_critic_d_input = grad(critic, argnums=2)
    return d_critic_d_input(Wc, bc, ins)

@jit
def vmap_batched_grad(Wg, bg, noise, critic_W=None, critic_b=None):
    """Mean critic input-gradient over a batch of generated samples.

    Generates fake inputs with ``generator(Wg, bg, noise)``, then takes the
    gradient of ``critic`` with respect to each generated row, vectorised
    over rows with ``vmap``. Every entry of the per-row gradients is
    averaged into one scalar. The scalar output is what allows ``grad`` to
    be applied to this function.

    Args:
        Wg, bg: generator weights and bias (the usual differentiation targets).
        noise: batch of generator inputs, one row per sample.
        critic_W, critic_b: optional critic parameters. When omitted, the
            module-level ``Wc`` / ``bc`` are used, as before.

    Why the optional parameters: under ``jit``, globals read during tracing
    are baked into the compiled function as constants. Rebinding ``Wc``/``bc``
    later (e.g. in a training loop) is silently ignored by the cached
    compilation. Passing them explicitly makes them traced inputs, so new
    values are always honoured.
    """
    W_c = Wc if critic_W is None else critic_W
    b_c = bc if critic_b is None else critic_b
    # Renamed from `inputs` to avoid shadowing the real-data global.
    fake_inputs = generator(Wg, bg, noise)
    # Same computation as vmap(i_grad): map over rows of fake_inputs while
    # broadcasting the critic parameters unchanged to every row.
    per_row_grads = vmap(grad(critic, argnums=2), in_axes=(None, None, 0))(
        W_c, b_c, fake_inputs
    )
    return jnp.mean(per_row_grads)

# Evaluate the jitted objective once at the initial generator parameters.
fn = vmap_batched_grad(Wg=Wg, bg=bg, noise=noise)

In [35]:
# Gradient of the objective w.r.t. the generator weights (argnum 0) and bias (argnum 1).
grad(vmap_batched_grad, argnums=(0, 1))(Wg, bg, noise)

(DeviceArray([[ 0.0056585 ,  0.01171166, -0.00460712],
              [-0.00060031, -0.00234188,  0.00013395],
              [ 0.00124006,  0.00499358, -0.00217985]], dtype=float32),
 DeviceArray(0.01692257, dtype=float32))

In [36]:
fn  # scalar value of vmap_batched_grad(Wg, bg, noise) computed in the earlier cell

DeviceArray(-0.08819881, dtype=float32)