In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb

# Fix TensorFlow's global seed so every tf.random.normal initialisation is reproducible.
tf.random.set_seed(seed=0)

In [71]:
# XOR truth table: each row of train_X is two binary inputs, train_y the target.
train_X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
train_y = np.array([[0], [1], [1], [0]], dtype=np.float32)
# One batch holding every example (full-batch training).
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_y)).batch(len(train_X))

n_features = train_X.shape[1]
n_outputs = train_y.shape[1]

# Two parallel hidden units, each reading the raw features.
W1 = tf.Variable(tf.random.normal([n_features, n_outputs], dtype=tf.float32), name="weight1")
b1 = tf.Variable(tf.random.normal([n_outputs], dtype=tf.float32), name="bias1")
W2 = tf.Variable(tf.random.normal([n_features, n_outputs], dtype=tf.float32), name="weight2")
b2 = tf.Variable(tf.random.normal([n_outputs], dtype=tf.float32), name="bias2")

# The output layer consumes concat([layer1, layer2], axis=1), so its input size is
# the combined hidden width, not the number of input features (the two only
# coincide for XOR, where both equal 2).
n_hidden = W1.shape[1] + W2.shape[1]
W3 = tf.Variable(tf.random.normal([n_hidden, n_outputs], dtype=tf.float32), name="weight3")
b3 = tf.Variable(tf.random.normal([n_outputs], dtype=tf.float32), name="bias3")

# Trainable parameters, in the order grad_fn returns their gradients.
vrbls = [W1, b1, W2, b2, W3, b3]

def hyp(features):
    """Forward pass: return the network's sigmoid output for ``features``.

    Two parallel sigmoid units (``W1``/``b1`` and ``W2``/``b2``) each read the
    raw features. Their outputs are concatenated along axis 1 and passed
    through a final sigmoid layer (``W3``/``b3``).

    Args:
        features: rank-2 float32 tensor (rows = examples); its last dimension
            must match ``W1.shape[0]``.

    Returns:
        Sigmoid activations of the output layer, one row per input row.
    """
    hidden_a = tf.sigmoid(tf.matmul(features, W1) + b1)
    hidden_b = tf.sigmoid(tf.matmul(features, W2) + b2)
    # concat already yields [batch, W1.shape[1] + W2.shape[1]]. A reshape to a
    # hard-coded width of 2 is a no-op for XOR and would silently change the
    # batch size if the hidden width ever differed from 2.
    hidden = tf.concat([hidden_a, hidden_b], axis=1)
    return tf.sigmoid(tf.matmul(hidden, W3) + b3)

def loss_fn(features, labels, epsilon=1e-7):
    """Mean binary cross-entropy between ``hyp(features)`` and ``labels``.

    Args:
        features: batch passed straight to ``hyp``.
        labels: targets (0/1 in this notebook), broadcast elementwise against
            ``hyp(features)``.
        epsilon: predictions are clipped to ``[epsilon, 1 - epsilon]`` before
            the logs are taken.

    Returns:
        Scalar tensor holding the negative mean log-likelihood.
    """
    # In float32 the sigmoid can round to exactly 0.0 or 1.0 for large
    # |logits|; log(0) would then turn the loss and its gradients into inf/NaN.
    h = tf.clip_by_value(hyp(features), epsilon, 1.0 - epsilon)
    return -tf.reduce_mean(labels * tf.math.log(h) + (1 - labels) * tf.math.log(1 - h))

def grad_fn(features, labels):
    """Differentiate ``loss_fn(features, labels)`` with respect to ``vrbls``.

    Returns:
        List of gradient tensors, in the same order as ``vrbls``.
    """
    with tf.GradientTape() as tape:
        batch_loss = loss_fn(features, labels)
    # Only the forward pass needs to be recorded; differentiate after leaving it.
    return tape.gradient(batch_loss, vrbls)

def fit(X, Y, n_epochs=5000, learning_rate=0.4, log_every=1000):
    """Train the network on ``(X, Y)`` with full-batch gradient descent.

    Args:
        X: input array, one row per example.
        Y: target array aligned row-for-row with ``X``.
        n_epochs: number of passes over the data.
        learning_rate: SGD step size.
        log_every: print ``epoch | loss`` every this many epochs; the loss
            shown is the one measured just before that epoch's update.

    Side effects:
        Updates the tensors in the module-level ``vrbls`` list in place and
        prints progress lines.
    """
    # Build the pipeline from the arguments actually passed in, rather than
    # silently training on the notebook-global `dataset`.
    train_data = tf.data.Dataset.from_tensor_slices((X, Y)).batch(len(X))
    # One optimizer for the whole run, so it is not rebuilt every step.
    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # optimizers reject.
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    for step in range(n_epochs):
        for features, labels in train_data:
            should_log = (step + 1) % log_every == 0
            if should_log:
                # Extra forward pass only when the value will actually be printed.
                loss = loss_fn(features, labels)
            grads = grad_fn(features, labels)
            optimizer.apply_gradients(zip(grads, vrbls))
            if should_log:
                print("{:>6d} | {:>8.6f}".format(step + 1, loss))

In [72]:
# Train on the XOR table for 5000 epochs; progress is printed every 1000 epochs.
fit(X=train_X, Y=train_y, n_epochs=5000)

  1000 | 0.218325
  2000 | 0.025839
  3000 | 0.012752
  4000 | 0.008380
  5000 | 0.006217
