In [None]:
import tensorflow as tf
import numpy as np

from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Dense, BatchNormalization, Dropout, Activation, Input, Flatten, Conv2D, AveragePooling2D
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.losses import SparseCategoricalCrossentropy, CategoricalCrossentropy

In [None]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Divide the pixel arrays by 255.
x_train = x_train / 255
x_test = x_test / 255

# One-hot encode the labels with depth 10.
y_train = tf.one_hot(y_train, 10)
y_test = tf.one_hot(y_test, 10)

In [None]:
import tensorflow as tf
import functools

from tensorflow.keras.layers import Input, Flatten, Dense, Activation, InputLayer
from tensorflow.keras.losses import CategoricalCrossentropy

class Fow(tf.keras.layers.Layer):
    """Base class for the forward half of a forward/backward layer pair.

    ``Fow.build`` and ``Fow.call`` are not ordinary layer hooks: they are used
    as decorators (``@Fow.build`` / ``@Fow.call``) on a subclass's own
    ``build`` / ``call`` and return a wrapper that adds weight bookkeeping
    around the decorated function.

    Attributes created by the ``build`` wrapper on the decorated layer:
        weights_trainable: list of the layer's weights whose ``trainable``
            flag is set.
        weights_update: one non-trainable ``tf.Variable`` per entry of
            ``weights_trainable``, initialised from that weight.
        prev_inp: non-trainable ``tf.Variable`` of zeros with the layer's
            input shape; overwritten with ``x`` on every training call.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def build(self, func):
        """Decorator for a subclass ``build(self, input_shape)``.

        Because this is a classmethod, the outer ``self`` is the class the
        decorator is looked up on; the inner ``wrapper``'s ``self`` (the layer
        instance) shadows it.

        Args:
            func: the subclass's ``build`` function, called as
                ``func(self, input_shape)``.

        Returns:
            A ``build(self, input_shapes)`` wrapper.
        """
        @functools.wraps(func)
        def wrapper(self, input_shapes):
            
            # The layer is called with a two-element input (see the ``call``
            # wrapper: ``x, grad = inp``); only the shape of the first element
            # is forwarded to the decorated ``build``.
            input_shape = input_shapes[0]
            
            func(self, input_shape)

            # Snapshot the trainable weights that ``func`` just created and
            # pair each one with a non-trainable variable initialised from it.
            self.weights_trainable = []
            self.weights_update = []
            for weight in self.weights:
                if weight.trainable:
                    self.weights_trainable.append(weight)
                    self.weights_update.append(tf.Variable(weight, trainable=False))

            # Buffer for the most recent training input.
            # NOTE(review): tf.zeros(input_shape) presumably needs a fully
            # defined shape, batch dimension included - confirm.
            self.prev_inp = tf.Variable(tf.zeros(input_shape), trainable=False)

        return wrapper

    @classmethod
    def call(self, func):
        """Decorator for a subclass ``call(self, x, weights, training)``.

        As with ``build``, the outer ``self`` is the class; the inner
        ``wrapper``'s ``self`` is the layer instance.

        Args:
            func: the subclass's ``call`` function, called as
                ``func(self, x, weights, training)``.

        Returns:
            A ``call(self, inp, training=False)`` wrapper where ``inp``
            unpacks into ``(x, grad)``.
        """
        @functools.wraps(func)
        def wrapper(self, inp, training=False):

            # ``grad`` holds one increment per trainable weight, in the same
            # order as ``weights_trainable``.
            x, grad = inp

            weights = []

            if training: 

                # zip() stops at the shortest sequence, so an empty ``grad``
                # leaves ``weights`` empty.
                for weight, weight_update, dw in zip(self.weights_trainable, self.weights_update, grad):
                    # Order matters: first overwrite the trainable weight with
                    # the value stored in its shadow variable, then form
                    # ``w = weight + dw``, hand ``w`` to the decorated call,
                    # and store ``w`` back into the shadow variable.
                    weight.assign(weight_update)
                    w = weight + dw
                    weights.append(w)
                    weight_update.assign(w)

                # Remember this input in ``prev_inp``.
                self.prev_inp.assign(x)
            
            else:
                # Outside training the stored trainable weights are used as-is.
                weights = self.weights_trainable

            return func(self, x, weights, training)
        
        return wrapper

class Back(tf.keras.layers.Layer):
    """Base class for the backward half of a forward/backward layer pair."""

    def __init__(self):
        super(Back, self).__init__()
        # Starts out empty; subclasses may replace it in their ``call``.
        self.grad = list()

class FowBack(tf.keras.layers.Layer):
    """Pairs a forward layer with a backward layer and dispatches between them.

    Args:
        foward: layer used in the forward direction; it must expose
            ``weights_trainable`` and ``prev_inp``.
        backward: layer used in the backward direction; it must expose
            ``grad``.
    """

    def __init__(self, foward, backward):
        super().__init__()
        self.foward = foward
        self.backward = backward

    def call(self, inp, backward=False, training=False):
        """Run either the forward or the backward layer on ``inp``.

        Args:
            inp: activations (forward direction) or the incoming chain
                gradient (backward direction).
            backward: when truthy, run the backward layer instead of the
                forward one.
            training: forwarded to the forward layer only.

        Returns:
            The output of whichever layer was invoked.
        """
        fwd_layer, bwd_layer = self.foward, self.backward

        if not backward:
            # Forward direction: pass the input together with the backward
            # layer's current ``grad``.
            return fwd_layer([inp, bwd_layer.grad], training=training)

        # Backward direction: pass the chain gradient plus the forward
        # layer's trainable weights and its stored previous input.
        return bwd_layer([inp, fwd_layer.weights_trainable, fwd_layer.prev_inp])

class FowBackModel(tf.keras.Model):
    """Model that runs a manual backward sweep before each training forward pass.

    The first entry of ``layers`` is applied to the raw inputs on its own; the
    remaining entries form the stack that is swept backward (only the
    ``FowBack`` members) and then forward (all members).

    Args:
        layers: sequence of layers. ``layers[0]`` must expose ``batch_size``
            and ``output_shape`` (an ``InputLayer`` in this notebook).
    """

    def __init__(self, layers):
        super().__init__()
        self._layers = layers
        # Run one forward pass on zeros so every layer creates its weights
        # before ``trainable_weights`` is inspected below.
        self.builds()

        batch_size = layers[0].batch_size
        # Output width is read off the last trainable weight's last dimension.
        out_dim = self.trainable_weights[-1].shape[-1]
        # Targets and predictions of the previous training step; their
        # difference seeds the backward sweep in ``call``.
        self.last_y = tf.Variable(tf.zeros([batch_size, out_dim]), trainable=False)
        self.last_pred = tf.Variable(tf.zeros([batch_size, out_dim]), trainable=False)

    def builds(self):
        """Call the model once on an all-zeros batch shaped like the first layer's output."""
        self(np.zeros(self.layers[0].output_shape[0]))

    @property
    def layers(self):
        """The layer sequence passed to the constructor."""
        return self._layers

    def call(self, inputs, training=False):
        """Forward pass, preceded by a backward sweep when ``training`` is true.

        Args:
            inputs: batch fed to the first layer.
            training: enables the backward sweep and is forwarded to every
                layer after the first.

        Returns:
            Output of the last layer.
        """
        x = self._layers[0](inputs)

        if training:
            # Seed the sweep with (previous targets - previous predictions)
            # and push it through the FowBack layers from last to first.
            back = self.last_y - self.last_pred
            for layer in reversed(self._layers[1:]):
                # Select by type rather than by layer name: a name test
                # misses FowBack layers given a custom name and matches any
                # unrelated layer whose name happens to contain "fow_back".
                if isinstance(layer, FowBack):
                    back = layer(back, backward=True)

        for layer in self._layers[1:]:
            x = layer(x, training=training)

        return x

    @tf.function
    def train_step(self, data):
        """One optimisation step; also records ``y``/``y_pred`` for the next step.

        Args:
            data: ``(x, y)`` or ``(x, y, sample_weight)``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)

        # Stored after the forward pass, so the backward sweep inside ``call``
        # always sees the values of the previous step.
        # NOTE(review): these buffers have a fixed [batch_size, out_dim]
        # shape; a smaller final batch presumably fails here - confirm.
        self.last_y.assign(y)
        self.last_pred.assign(y_pred)

        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)

        return {m.name: m.result() for m in self.metrics}

class FowDense(Fow):
    """Forward fully connected layer: ``x @ w + b``.

    Args:
        n: number of output units.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n
        self.initializer = tf.keras.initializers.GlorotUniform()

    @Fow.build
    def build(self, input_shape):
        """Create the kernel ``w`` (Glorot uniform) and the bias ``b`` (zeros)."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=[in_features, self.n],
            initializer=self.initializer,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=[self.n],
            initializer='zeros',
            trainable=True,
        )

    @Fow.call
    def call(self, x, weights, training=False):
        """Apply the affine map using the ``(kernel, bias)`` pair in ``weights``."""
        kernel, bias = weights
        return tf.matmul(x, kernel) + bias

class BackDense(Back):
    """Backward counterpart of a dense layer.

    Args:
        n: first dimension of the internal weight copy ``w``.
        learning_rate: factor applied to the weight and bias increments
            stored in ``grad``. Defaults to 0.
    """

    def __init__(self, n, learning_rate=0):
        super().__init__()
        self.n = n
        self.learning_rate = learning_rate
        self.initializer = tf.keras.initializers.GlorotUniform()

    def build(self, input_shape):
        """Allocate ``w`` with shape ``[n, last dim of the first input]``."""
        # An earlier variant kept a separate dense layer here; it is disabled.
        grad_width = input_shape[0][-1]
        initial_value = self.initializer([self.n, grad_width])
        self.w = tf.Variable(initial_value, trainable=False)

    def call(self, inp):
        """Compute the weight/bias increments and propagate the chain gradient.

        Args:
            inp: ``(chain_grad, weights, prev_inp)`` where ``weights`` unpacks
                into a kernel and a bias.

        Returns:
            ``chain_grad @ transpose(w)``, with ``w`` freshly copied from the
            kernel in ``weights``.
        """
        chain_grad, weights, prev_inp = inp
        kernel, _bias = weights
        # Copy the supplied kernel into the local variable before it is used.
        self.w.assign(kernel)

        scale = self.learning_rate

        # Batched product of prev_inp[..., None] and chain_grad[:, None, ...],
        # averaged over axis 0.
        per_sample = tf.expand_dims(prev_inp, axis=-1) @ tf.expand_dims(chain_grad, axis=1)
        dw = tf.reduce_mean(per_sample, axis=0) * scale

        db = tf.reduce_mean(chain_grad, axis=0) * scale

        self.grad = [dw, db]

        return chain_grad @ tf.transpose(self.w)

class FowRelu(Fow):
    """Forward ReLU: element-wise ``max(x, 0)``. Creates no weights."""

    def __init__(self):
        super(FowRelu, self).__init__()

    @Fow.build
    def build(self, input_shape):
        """Nothing to create; the ``Fow.build`` wrapper does the bookkeeping."""

    @Fow.call
    def call(self, x, weights, training=False):
        """Return ``x`` clamped below at zero; ``weights`` is not used."""
        floor = 0
        return tf.maximum(x, floor)

class BackRelu(Back):
    """Backward counterpart of ReLU: gates the chain gradient.

    The derivative of ``max(x, 0)`` with respect to ``x`` is 1 where ``x > 0``
    and 0 elsewhere, so the incoming gradient is multiplied by that 0/1 mask.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """ReLU has no parameters, so there is nothing to build."""
        pass

    def call(self, inp):
        """Propagate the chain gradient through ReLU.

        Args:
            inp: ``(chain_grad, weights, prev_inp)``; ``weights`` is unused
                and ``prev_inp`` is the input the forward ReLU last saw.

        Returns:
            ``chain_grad`` with entries zeroed wherever ``prev_inp <= 0``.
        """
        chain_grad, _weights, prev_inp = inp

        # Multiply by the derivative mask (prev_inp > 0), not by the ReLU
        # output max(prev_inp, 0): scaling by the activation value would
        # distort the gradient magnitude instead of merely gating it.
        active = tf.cast(tf.greater(prev_inp, 0), chain_grad.dtype)

        return chain_grad * active

In [None]:
# 784 -> 128 -> ReLU -> 10 -> softmax on flattened 28x28 inputs.
# The first layer fixes the batch size at 32. Each BackDense receives the
# input width of its paired FowDense and keeps its default learning_rate (0).
model = FowBackModel((
    InputLayer((28, 28), batch_size=32),
    Flatten(),
    FowBack(foward=FowDense(n=128), backward=BackDense(n=784)),
    FowBack(foward=FowRelu(), backward=BackRelu()),
    FowBack(foward=FowDense(n=10), backward=BackDense(n=128)),
    Activation('softmax'),
))

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
activation (Activation)      (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Display the model's trainable variables. This prints the full value arrays,
# so the cell output is long.
model.trainable_variables

[<tf.Variable 'fow_back_model/fow_back/fow_dense/Variable:0' shape=(784, 128) dtype=float32, numpy=
 array([[-0.00945815, -0.02329049, -0.01490351, ...,  0.06388699,
          0.05249404, -0.02342992],
        [ 0.0017712 ,  0.00118714,  0.05955144, ..., -0.01236433,
          0.03673445,  0.04385818],
        [-0.06888115, -0.01985686,  0.0357481 , ..., -0.05697545,
          0.05334226, -0.05025539],
        ...,
        [-0.07005154, -0.05987211,  0.059344  , ...,  0.02526644,
          0.06492225,  0.02652292],
        [ 0.00087293, -0.07608298, -0.01250216, ..., -0.06386201,
          0.03524765, -0.06769297],
        [-0.05166741,  0.06980281,  0.07097396, ...,  0.02213468,
         -0.01319141,  0.01672918]], dtype=float32)>,
 <tf.Variable 'fow_back_model/fow_back/fow_dense/Variable:0' shape=(128,) dtype=float32, numpy=
 array([-1.3663010e-04,  2.2185789e-03,  1.8914131e-03,  1.4470483e-03,
        -3.0852724e-03, -9.1923086e-04,  1.3057167e-03, -1.2181591e-03,
         1.779886

In [None]:
# Plain SGD with step size 0.1, categorical cross-entropy on the one-hot labels.
model.compile(
    optimizer=SGD(1e-1),
    loss=CategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Train for 5 epochs with the same batch size of 32 that the model's input
# layer was created with; evaluate on the test split after each epoch.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fae44d4af10>