In [14]:
import tensorflow as tf

class DenseLayer(tf.Module):
    def __init__(self, outputs):
        super().__init__()
        self.outputs = outputs
        self.fl_init = False
    def __call__(self, x):
        if not self.fl_init:
            self.w = tf.random.truncated_normal((x.shape[-1], self.outputs), stddev = 1, name = 'w')
            self.b = tf.zeros([self.outputs], dtype = tf.float32, name = 'b')

            self.w = tf.Variable(self.w)
            self.b = tf.Variable(self.b)

            self.fl_init = True
        y = tf.matmul(x, self.w) + self.b
        
        return y

In [15]:
class DenseNetwork(tf.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = []
        for i in range(len(layer_sizes) - 1):
            self.layers.append(DenseLayer(layer_sizes[i]))

    def __call__(self, inputs):
        for layer in self.layers:
            inputs = tf.nn.relu(layer(inputs))
        return inputs

In [16]:
model = DenseNetwork([32,1])
x_train = tf.random.uniform(minval = 0, maxval = 10, shape = (100, 2))
y_train = [a+b for a,b in x_train]

loss = lambda x,y: tf.reduce_mean(tf.square(x-y))
opt = tf.optimizers.SGD(learning_rate=0.01)
epochs =  100

for n in range(epochs):
    for x, y in zip(x_train, y_train):
        x = tf.expand_dims(x, axis = 0)
        y = tf.constant(y, shape = (1, 1))

        with tf.GradientTape() as tape:
            f_loss = loss(y, model(x))

        grads = tape.gradient(f_loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

    print(f_loss.numpy())


59.952904
51.192356
47.685574
43.69117
43.67299
43.662746
43.65944
43.658455
43.65821
43.658176
43.65819
43.658207
43.658215
43.65822
43.65821
43.658203
43.65819
43.65818
43.658165
43.658157
43.658142
43.65813
43.658115
43.658104
43.658092
43.65808
43.658066
43.65806
43.658047
43.658035
43.658024
43.658012
43.658
43.65799
43.65798
43.657967
43.65796
43.657948
43.65794
43.657932
43.65792
43.657913
43.6579
43.657894
43.657883
43.65788
43.657867
43.65786
43.657852
43.657845
43.657837
43.65783
43.65782
43.657814
43.657806
43.6578
43.65779
43.657784
43.657776
43.65777
43.657764
43.657757
43.657753
43.65774
43.657738
43.65773
43.657722
43.657722
43.657715
43.657707
43.6577
43.657696
43.657692
43.657684
43.65768
43.657673
43.65767
43.657665
43.657658
43.657654
43.657646
43.657646
43.65764
43.65764
43.65763
43.657623
43.657623
43.657616
43.65761
43.657608
43.657604
43.6576
43.6576
43.657593
43.65759
43.657585
43.657585
43.657578
43.657574
43.65757
