In [1]:
import numpy as np
import tensorflow as tf
from bokeh.io import output_notebook, show
from bokeh.plotting import figure

In [2]:
output_notebook()

In [3]:
class Model:
    def __init__(self):
        self.W = tf.Variable(5.)
        self.b = tf.Variable(0.)
        
    def __call__(self, x):
        return self.W * x + self.b

In [4]:
model = Model()

In [5]:
model.summary()

AttributeError: 'Model' object has no attribute 'summary'

In [4]:
model(3.)

<tf.Tensor: id=19, shape=(), dtype=float32, numpy=15.0>

In [5]:
def loss(y_hat, y):
    return tf.reduce_mean(tf.square(y_hat - y))

In [6]:
TRUE_W = 3.
TRUE_b = 2.
NUM_EXAMPLES = 1000

In [7]:
inputs = tf.random.normal(shape=[NUM_EXAMPLES])
noise = tf.random.normal(shape=[NUM_EXAMPLES])
outputs = inputs * TRUE_W + TRUE_b

In [13]:
p = figure(plot_height=300, plot_width=300)
p.scatter(inputs.numpy(), outputs.numpy())
show(p)

In [15]:
loss_val = loss(model(inputs), outputs).numpy()
print(f"Current loss: {loss_val:.3f}")

Current loss: 7.204


In [19]:
def train(model, inputs, outputs, lr):
    with tf.GradientTape() as t:
        loss_val = loss(model(inputs), outputs)
    dW, db = t.gradient(loss_val, [model.W, model.b])
    model.W.assign_sub(lr * dW)
    model.b.assign_sub(lr * db)

In [20]:
model = Model()
Ws, bs = [], []
epochs = range(10)
for epoch in epochs:
    Ws.append(model.W.numpy())
    bs.append(model.b.numpy())
    epoch_loss = loss(model(inputs), outputs)
    train(model, inputs, outputs, lr=0.1)
    print(f"Epoch {epoch}: W={Ws[-1]:.3f} b={bs[-1]:3f}, loss={epoch_loss:.3f}")

Epoch 0: W=5.000 b=0.000000, loss=7.204
Epoch 1: W=4.649 b=0.368940, loss=4.842
Epoch 2: W=4.359 b=0.669549, loss=3.254
Epoch 3: W=4.119 b=0.914541, loss=2.188
Epoch 4: W=3.922 b=1.114248, loss=1.471
Epoch 5: W=3.760 b=1.277077, loss=0.989
Epoch 6: W=3.625 b=1.409866, loss=0.665
Epoch 7: W=3.515 b=1.518179, loss=0.447
Epoch 8: W=3.424 b=1.606545, loss=0.301
Epoch 9: W=3.349 b=1.678651, loss=0.202


In [30]:
p = figure(plot_height=400, plot_width=400, x_axis_label="Epochs")
p.line(range(len(Ws)), Ws, color="darkorange", legend="W")
p.line(range(len(bs)), bs, color="forestgreen", legend="b")
p.line(range(len(Ws)), TRUE_W, color="darkorange", line_dash="dashed", legend="True W")
p.line(range(len(bs)), TRUE_b, color="forestgreen", line_dash="dashed", legend="True b")
show(p)