In [2]:
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
import tensorflow as tf

import numpy as np
import random

In [72]:
labels = np.array([
    [0, 1, 1, 0], # XOR
    [1, 0, 0, 1], # XNOR
    [0, 0, 0, 1], # AND
    [0, 1, 1, 1], # OR
    [1, 0, 0, 0], # NOR
    [1, 1, 1, 0], # NAND
    [1, 0, 1, 0], # Custom 1
    [0, 1, 0, 1]  # Custom 2
])

x_train = np.array([[-1, -1], [-1, 1], [1, -1], [1, 1]])

In [148]:
def loss_fn(y, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_pred)

In [149]:
tf.random.set_seed(0)
random.seed(0)

In [150]:
inp = Input((2,))
x = Dense(40, activation="relu")(inp)
x = Dense(1, activation="sigmoid")(x)
model = Model(inputs=inp, outputs=x)

In [151]:
x = x_train
y = labels[:1]

In [166]:
def calc_context_loss(context_layer_idx, gradients):
        delta_at_next_layer = gradients[context_layer_idx + 1]
        transpose_of_weights_at_next_layer = tf.transpose(model.layers[context_layer_idx + 1].get_weights()[0])
        context_delta = np.dot(delta_at_next_layer, transpose_of_weights_at_next_layer).astype(np.float)
        error = tf.keras.losses.mean_squared_error(np.zeros(len(context_delta)), context_delta)
        return error

In [192]:
def calc_gradients(x_train, y_train, batch_size):
    # Calculate the total number of batches that need to be processed
    context_loss = 0.0
    grads = []
    num_batches = int(np.ceil(len(x_train) / batch_size))

    # Tensorflow 2 style training -- info can be found here: https://www.tensorflow.org/guide/effective_tf2 
    # This is similar to model.fit(), however this is a custom training loop -- ie. it does things differently than model.fit()
    # look at each input and label (there are 4 for the logic gates)
    for start, end in ((s*batch_size, (s + 1)*batch_size) for s in range(num_batches)):

        # Slice into batch
        x = x_train[start:end]
        y = y_train[start:end]

        with tf.GradientTape() as tape:
            predictions = model(x, training=True) # Forward pass
            loss = loss_fn(y, predictions) # Get the loss

        # Extract the gradients for the loss of the current sample
        gradients = tape.gradient(loss, model.trainable_variables)
        
        grads.append(gradients)

        context_loss += calc_context_loss(0, gradients)
                
    return context_loss, grads
        


In [234]:
context_loss, grads = calc_gradients(x, y, 1)

In [235]:
context_loss

<tf.Tensor: shape=(), dtype=float64, numpy=5.1966925698251475e-08>

In [237]:
total_grads = np.array(grads[0])
for i in range(1, len(grads)):
    total_grads = np.add(total_grads, grads[i]) / 2
# total_grads /= len(grads)
total_grads

array([<tf.Tensor: shape=(2, 40), dtype=float32, numpy=
array([[ 5.8108467e-06, -2.9225854e-05,  3.8266862e-05,  0.0000000e+00,
         0.0000000e+00,  3.1838561e-06, -5.0916820e-05,  0.0000000e+00,
         0.0000000e+00, -1.1877867e-05,  0.0000000e+00,  1.4344028e-05,
         0.0000000e+00,  3.7515896e-05, -4.1118659e-05,  4.8243091e-05,
        -3.0265262e-05,  0.0000000e+00, -3.9743773e-05, -3.6786765e-05,
         0.0000000e+00,  3.6138696e-05,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  2.7895858e-05,
         2.6317586e-05,  0.0000000e+00, -4.2991433e-06,  1.8742547e-05,
         0.0000000e+00,  1.6616352e-05,  2.9790192e-05,  0.0000000e+00,
        -2.9765752e-05,  0.0000000e+00,  0.0000000e+00, -4.2180283e-05],
       [ 5.8108467e-06, -2.9225854e-05,  3.8266862e-05,  0.0000000e+00,
         0.0000000e+00,  3.1838561e-06, -5.0916820e-05,  0.0000000e+00,
         0.0000000e+00, -1.1877867e-05,  0.0000000e+00,  1.4344028e-05,
       