# Permalink
# Fetching contributors...
# Cannot retrieve contributors at this time
# 79 lines (61 sloc) 2.67 KB
from __future__ import absolute_import
from __future__ import print_function
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
from autograd import grad
from autograd.misc import flatten
from autograd.misc.optimizers import adam
def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
    """Draw scaled Gaussian (weights, biases) pairs for consecutive layers.

    Args:
        scale: Multiplier applied to every ``rs.randn`` draw.
        layer_sizes: Sequence of layer widths; each adjacent pair
            ``(n_in, n_out)`` produces one layer.
        rs: Random generator exposing ``randn``. The default instance is
            created once, when the function is defined, so calls relying on
            it keep advancing the same random stream.

    Returns:
        A list with one ``(weights, biases)`` tuple per layer; ``weights``
        is drawn as ``randn(n_in, n_out)`` and ``biases`` as ``randn(n_out)``.
    """
    layers = []
    fan_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
    for n_in, n_out in fan_pairs:
        # Weights are drawn before biases so the random stream is consumed
        # in a fixed order for each layer.
        weights = rs.randn(n_in, n_out) * scale
        biases = rs.randn(n_out) * scale
        layers.append((weights, biases))
    return layers
def nn_predict(params, inputs, nonlinearity=np.tanh):
    """Run ``inputs`` through the fully connected layers in ``params``.

    Args:
        params: Iterable of ``(W, b)`` pairs, one per layer, applied in order.
        inputs: Array fed to the first layer via ``np.dot(inputs, W) + b``.
        nonlinearity: Callable applied to each layer's affine output before
            it is passed on to the next layer.

    Returns:
        The affine output of the last layer, i.e. the value *before* the
        nonlinearity is applied to it.
    """
    activations = inputs
    for weights, biases in params:
        pre_activation = np.dot(activations, weights) + biases
        activations = nonlinearity(pre_activation)
    # The final layer is linear: return its pre-activation, not `activations`.
    return pre_activation
def log_gaussian(params, scale):
    """Sum the log-density of every parameter under a zero-centred normal.

    ``params`` is collapsed into a single vector with ``flatten``; each entry
    is scored with ``norm.logpdf(x, 0, scale)`` and the scores are summed.

    Args:
        params: Parameter container accepted by ``flatten``.
        scale: Forwarded as the ``scale`` argument of ``norm.logpdf``.

    Returns:
        The summed log-density as a scalar.
    """
    flat_vector, _unflatten = flatten(params)
    elementwise_logpdf = norm.logpdf(flat_vector, 0, scale)
    return np.sum(elementwise_logpdf)
def logprob(weights, inputs, targets, noise_scale=0.1):
    """Log-probability of the network predictions under Gaussian noise.

    The network output for ``inputs`` is scored with
    ``norm.logpdf(predictions, targets, noise_scale)`` and summed over all
    entries.

    Args:
        weights: Layer parameters passed to ``nn_predict``.
        inputs: Inputs passed to ``nn_predict``.
        targets: Values used as the location of the normal density.
        noise_scale: Forwarded as the ``scale`` argument of ``norm.logpdf``.

    Returns:
        The summed log-density as a scalar.
    """
    predicted = nn_predict(weights, inputs)
    pointwise_logpdf = norm.logpdf(predicted, targets, noise_scale)
    return np.sum(pointwise_logpdf)
def build_toy_dataset(n_data=80, noise_std=0.1):
    """Build a 1-D regression set: a noisy cosine sampled on two intervals.

    Raw inputs are evenly spaced on [0, 3] and on [6, 8], leaving a gap
    between them. Targets are ``cos(raw_input)`` plus Gaussian noise drawn
    from a ``RandomState`` seeded with 0, so repeated calls return the same
    data. Inputs are then rescaled with ``(x - 4) / 2`` and targets halved.

    Args:
        n_data: Total number of points (an integer). The first interval gets
            ``n_data // 2`` points and the second gets the remainder, so odd
            values are supported.
        noise_std: Multiplier applied to the standard-normal noise.

    Returns:
        ``(inputs, targets)``, both arrays of shape ``(n_data, 1)``.
    """
    rs = npr.RandomState(0)
    # np.linspace needs integer counts; `n_data / 2` is a float on Python 3.
    n_left = n_data // 2
    # Give the remainder to the second interval so the total stays n_data
    # and matches the number of noise samples drawn below.
    n_right = n_data - n_left
    inputs = np.concatenate([np.linspace(0, 3, num=n_left),
                             np.linspace(6, 8, num=n_right)])
    targets = np.cos(inputs) + rs.randn(n_data) * noise_std
    inputs = (inputs - 4.0) / 2.0
    # Add a trailing axis so both arrays are column vectors.
    inputs = inputs[:, np.newaxis]
    targets = targets[:, np.newaxis] / 2.0
    return inputs, targets
if __name__ == '__main__':
    # Script entry point: fit a small network to the toy dataset with Adam
    # and redraw the fitted curve after every optimizer step.
    init_scale = 0.1
    # NOTE: despite the name, this value is passed to log_gaussian as its
    # `scale` argument.
    weight_prior_variance = 10.0

    init_params = init_random_params(init_scale, layer_sizes=[1, 4, 4, 1])
    inputs, targets = build_toy_dataset()

    def objective(weights, t):
        """Negated sum of the data log-probability and the weight log-prior.

        `t` is accepted to match the optimizer's calling convention and is
        not used.
        """
        data_term = logprob(weights, inputs, targets)
        prior_term = log_gaussian(weights, weight_prior_variance)
        return -data_term - prior_term

    # Print the gradient at the initial parameters before optimizing.
    initial_gradient = grad(objective)(init_params, 0)
    print(initial_gradient)

    # Set up figure.
    fig = plt.figure(figsize=(12, 8), facecolor='white')
    ax = fig.add_subplot(111, frameon=False)
    plt.show(block=False)

    def callback(params, t, g):
        """Print the current score and redraw the data with the current fit."""
        # The reported value is -objective, so it includes the prior term
        # as well as the data term.
        score = -objective(params, t)
        print("Iteration {} log likelihood {}".format(t, score))

        # Plot data and functions.
        plt.cla()
        ax.plot(inputs.ravel(), targets.ravel(), 'bx', ms=12)
        grid = np.reshape(np.linspace(-7, 7, num=300), (300, 1))
        fitted = nn_predict(params, grid)
        ax.plot(grid, fitted, 'r', lw=3)
        ax.set_ylim([-1, 1])
        plt.draw()
        # Brief pause so the GUI event loop can refresh the window.
        plt.pause(1.0 / 60.0)

    print("Optimizing network parameters...")
    objective_gradient = grad(objective)
    optimized_params = adam(
        objective_gradient,
        init_params,
        step_size=0.01,
        num_iters=1000,
        callback=callback,
    )