In [51]:
import tensorflow as tf
from tensorflow.keras.models import Model

In [52]:
class SyntheticData:
    def __init__(self, w, b, num_train, num_val, noise=0.3):
        num = num_train + num_val
        self.X = tf.random.normal((num, len(w)))
        self.y = self.X @ tf.reshape(w, [len(w), 1]) + tf.random.normal((num, 1)) * noise + b
        self.X_train, self.X_val = self.X[:num_train], self.X[num_train:]
        self.y_train, self.y_val = self.y[:num_train], self.y[num_train:]

    def train_dataloader(self, batch_size):
        return tf.data.Dataset.from_tensor_slices((self.X_train, self.y_train)).batch(batch_size=batch_size)

    def val_dataloader(self, batch_size):
        return tf.data.Dataset.from_tensor_slices((self.X_val, self.y_val)).batch(batch_size=batch_size)

In [53]:
class MyModel(tf.keras.Model):
    def __init__(self, lr, decay):
        super().__init__()
        self.decay = decay
        self.net = tf.keras.layers.Dense(1, kernel_initializer=tf.initializers.RandomNormal(),
                                         kernel_regularizer=tf.keras.regularizers.L2(1.0))
        self.loss = tf.keras.losses.MeanSquaredError()
        self.optimizer = tf.optimizers.SGD(learning_rate=lr)

    def my_loss(self, y, y_hat):
        return tf.keras.losses.MeanSquaredError()(y, y_hat) + self.net.losses

    def forward(self, X):
        return self.net(X)


In [54]:
class Trainer:
    def __init__(self):
        self.train_errors = []
        self.val_errors = []

    def my_fit(self, data: SyntheticData, model: MyModel, epochs, batch_size):
        for epoch in range(epochs):
            self.train_errors.append(tf.constant(0.0))
            self.val_errors.append(tf.constant(0.0))
            for batch_X, batch_y in data.train_dataloader(batch_size):
                with tf.GradientTape() as g:
                    loss = model.my_loss(batch_y, model.forward(batch_X))
                    self.train_errors[-1] += loss
                grad = g.gradient(loss, model.trainable_variables)
                model.optimizer.apply_gradients(zip(grad, model.trainable_variables))
            for batch_X, batch_y in data.val_dataloader(batch_size):
                loss = model.my_loss(batch_y, model.forward(batch_X))
                self.val_errors[-1] += loss


In [55]:
w = tf.constant([0.0, 1.0, 1.0, 1.0])
b = tf.constant(3.0)
sdata = SyntheticData(w, b, 50, 50)

In [56]:
model = MyModel(0.05, 2.0)
trainer = Trainer()
trainer.my_fit(sdata, model, 7, 20)

In [57]:
trainer.train_errors, trainer.val_errors

([<tf.Tensor: shape=(1,), dtype=float32, numpy=array([19.888708], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([12.894963], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([9.019101], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.779762], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.4633875], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.683413], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.2190185], dtype=float32)>],
 [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([30.917377], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([23.254818], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([18.258463], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([14.961313], dtype=float32)>,
  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([12.76186], dtype=float32)>,
  <tf.T