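## 手写训练循环：自定义模型、损失与梯度（TF 1.x Eager 模式）

用合成数据 $y = 3x + 2 + \epsilon$ 拟合线性模型 $\hat{y} = Wx + B$：手动定义模型、损失函数和梯度计算，并在循环中逐步调用优化器更新参数。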

In [1]:
import tensorflow as tf
from tensorflow import keras as K
# TF 1.x API: switch to eager execution so ops run immediately and return
# concrete values (needed for `.numpy()` and formatting losses directly below).
# Per the TF 1.x docs this must be called at program startup, before other ops.
tf.enable_eager_execution()

用 `tf.Variable` 保存模型参数：权重 `W`（初值 5）和偏置 `B`（初值 10），前向计算为 $\hat{y} = W x + B$。

In [2]:
class Model(K.Model):
    """Single-input linear model: ``outputs = inputs * W + B``.

    ``W`` and ``B`` are scalar ``tf.Variable`` attributes, so gradients can be
    taken with respect to them and an optimizer can update them in place.

    Args:
        init_w: initial value of the weight ``W`` (default 5.0).
        init_b: initial value of the bias ``B`` (default 10.0).
    """

    def __init__(self, init_w=5., init_b=10.):
        super(Model, self).__init__()
        # Pin float32 so passing an int (e.g. Model(3, 2)) does not create
        # integer variables, which could not be trained by gradient descent.
        self.W = tf.Variable(init_w, name='weight', dtype=tf.float32)
        self.B = tf.Variable(init_b, name='bias', dtype=tf.float32)

    def call(self, inputs):
        """Return ``inputs * W + B`` (scalar W and B broadcast over inputs)."""
        return inputs * self.W + self.B

定义两个函数：`loss` 计算均方误差（MSE），`grad` 用 `tf.GradientTape` 计算 loss 对参数 `W`、`B` 的梯度。

In [3]:
def loss(model, inputs, targets):
    """Mean squared error between ``model(inputs)`` and ``targets``.

    Returns:
        A scalar tensor: the mean, over all elements, of the squared residuals.
    """
    predictions = model(inputs)
    squared_residuals = tf.square(predictions - targets)
    return tf.reduce_mean(squared_residuals)

In [4]:
def grad(model, inputs, targets, variables=None):
    """Gradient of ``loss(model, inputs, targets)`` w.r.t. model parameters.

    Args:
        model: model exposing ``W`` and ``B`` variables.
        inputs: input tensor passed to ``loss``.
        targets: target tensor passed to ``loss``.
        variables: optional list of variables to differentiate with respect
            to. Defaults to ``[model.W, model.B]``, so existing callers get the
            same gradients in the same order.

    Returns:
        A list of gradients, one per entry of ``variables``, in the same order
        (ready to ``zip`` with ``variables`` for ``apply_gradients``).
    """
    if variables is None:
        variables = [model.W, model.B]
    # Record the forward pass so the tape can differentiate it afterwards.
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    return tape.gradient(loss_value, variables)

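训练

生成 2000 个带标准正态噪声的样本（真实参数 W=3、B=2），再用学习率为 0.01 的梯度下降迭代 300 步拟合 `W`、`B`。

> 注意：下方保存的输出来自乱序执行——训练单元格（In [12]）是在 In [10]、In [11] 之后重新运行的，所以日志中第 0 步的 loss 已接近收敛值，且 "Final loss" 与训练日志末尾并不完全一致。按顺序重新运行全部单元格即可得到一致的输出。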

In [5]:
# Fresh model with the default initial parameters (W=5.0, B=10.0).
model = Model()

In [6]:
# Plain gradient descent; the step size lives in a named constant so it is
# easy to find and tune.
LEARNING_RATE = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)

In [7]:
# Synthetic regression data: y = TRUE_W * x + TRUE_B + noise, with x and noise
# drawn from tf.random_normal (standard normal by default).
# Seeding makes the dataset, and so every printed loss, reproducible under
# Restart & Run All; re-running this cell regenerates the same tensors.
SEED = 42
TRUE_W, TRUE_B = 3., 2.
NUM_EXAMPLES = 2000

tf.set_random_seed(SEED)
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
training_outputs = training_inputs * TRUE_W + TRUE_B + noise

In [8]:
# Loss of the untrained model, as a baseline for the training log below.
initial_loss = loss(model, training_inputs, training_outputs)
print('init loss: {:.3f}'.format(initial_loss))

Instructions for updating:
Colocations handled automatically by placer.
init loss: 69.426


In [12]:
# Hand-written training loop: compute gradients with `grad`, then let the
# optimizer apply them. The global step and the variable list do not change
# between iterations, so they are fetched/built once outside the loop.
NUM_STEPS = 300
LOG_EVERY = 20
global_step = tf.train.get_or_create_global_step()
trainable_vars = [model.W, model.B]

for step in range(NUM_STEPS):
    grads = grad(model, training_inputs, training_outputs)
    optimizer.apply_gradients(zip(grads, trainable_vars),
                              global_step=global_step)
    if step % LOG_EVERY == 0:
        # Logged value is the loss *after* this step's update was applied.
        print("Loss at step {:03d}: {:.3f}".format(
            step, loss(model, training_inputs, training_outputs)))

Loss at step 000: 1.052
Loss at step 020: 1.051
Loss at step 040: 1.051
Loss at step 060: 1.051
Loss at step 080: 1.051
Loss at step 100: 1.051
Loss at step 120: 1.051
Loss at step 140: 1.051
Loss at step 160: 1.051
Loss at step 180: 1.051
Loss at step 200: 1.051
Loss at step 220: 1.051
Loss at step 240: 1.051
Loss at step 260: 1.051
Loss at step 280: 1.051


In [10]:
# Loss with the current (trained) parameters.
final_loss = loss(model, training_inputs, training_outputs)
print('Final loss: {:.3f}'.format(final_loss))

Final loss: 1.052


In [11]:
# Learned parameters as plain NumPy values, to compare with the data's true W and B.
learned_w = model.W.numpy()
learned_b = model.B.numpy()
print('W = {}, B = {}'.format(learned_w, learned_b))

W = 2.99423885345459, B = 1.9894063472747803
