In [8]:
# TF CODE
import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but Tensorflow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.AdamOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.initialize_all_variables()

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))

0 [0.60629356] [0.49999976]
20 [0.01879729] [0.27725616]
40 [0.07610643] [0.27667654]
60 [0.11405155] [0.30311984]
80 [0.10251226] [0.30413422]
100 [0.10005254] [0.30157742]
120 [0.10037202] [0.30064023]
140 [0.0999805] [0.29993704]
160 [0.10001576] [0.29996943]
180 [0.10000985] [0.29998872]
200 [0.10001373] [0.30000582]


In [0]:
from jax import jit, grad, vmap, random
from functools import partial
import jax
import jax.numpy as np
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax, LeakyRelu, Dropout # neural network layers
import matplotlib.pyplot as plt # visualization
import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

In [0]:
#Generate data
rng = random.PRNGKey(1)
x_data = random.normal(rng, (100,1))
y_data = x_data * 0.1 + 0.3

In [0]:
#Define model
net_init, net_apply = stax.serial(Dense(1))
in_shape = (-1, 1,)
out_shape, net_params = net_init(rng, in_shape)

In [0]:
#Define losses and optimizers
def loss(params, inputs, targets):
    # Computes average loss for the batch
    predictions = net_apply(params, inputs)
    return np.mean((targets - predictions)**2)

opt_init, opt_update, get_params = optimizers.adam(step_size=0.5)  # this LR seems to be better than 1e-2 and 1e-4
out_shape, net_params = net_init(rng, in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x, y):
    p = get_params(opt_state)
    g = grad(loss)(p, x, y)
    l = loss(p, x, y)
    return opt_update(i, g, opt_state), l

In [28]:
#Training
losses=[]
for i in range(201):
    opt_state, l = step(i, opt_state, x_data, y_data)
    losses.append(l)
    if i % 20 == 0:
        print(get_params(opt_state))
net_params=get_params(opt_state)

[(DeviceArray([[0.3528281]], dtype=float32), DeviceArray([0.49513587], dtype=float32))]
[(DeviceArray([[0.09138642]], dtype=float32), DeviceArray([0.39634892], dtype=float32))]
[(DeviceArray([[0.09057112]], dtype=float32), DeviceArray([0.3366886], dtype=float32))]
[(DeviceArray([[0.12353643]], dtype=float32), DeviceArray([0.30404165], dtype=float32))]
[(DeviceArray([[0.10185698]], dtype=float32), DeviceArray([0.29825112], dtype=float32))]
[(DeviceArray([[0.09979293]], dtype=float32), DeviceArray([0.29922235], dtype=float32))]
[(DeviceArray([[0.10034811]], dtype=float32), DeviceArray([0.30057368], dtype=float32))]
[(DeviceArray([[0.10039034]], dtype=float32), DeviceArray([0.30027398], dtype=float32))]
[(DeviceArray([[0.10002711]], dtype=float32), DeviceArray([0.3000829], dtype=float32))]
[(DeviceArray([[0.09994931]], dtype=float32), DeviceArray([0.30000454], dtype=float32))]
[(DeviceArray([[0.10001832]], dtype=float32), DeviceArray([0.29999167], dtype=float32))]
