<a href="https://colab.research.google.com/github/5aurabhpathak/neural-net-training-algorithms/blob/main/direct_feedback_alignment_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Shell step: install/upgrade the TensorBoard profiler plugin in the runtime.
! pip install -U tensorboard_plugin_profile

In [None]:
import numpy as np, os, tensorflow as tf, tensorflow.keras as keras, pandas as pd
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Dense, Input, ReLU, Softmax
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import Model
from collections import namedtuple

In [None]:
# Read the CSV without treating its first row as column names; otherwise the
# first sample would be consumed as a header and silently dropped.
# NOTE(review): assumes the file has no header row (as with Colab's bundled
# sample_data MNIST CSVs) - confirm against the actual file.
mnist_train = pd.read_csv('sample_data/mnist_train_small.csv', header=None)

In [None]:
# First column becomes Y (one-hot encoded as class labels further down);
# every remaining column becomes the input matrix X.
X = mnist_train.iloc[:, 1:].to_numpy()
Y = mnist_train.iloc[:, :1].to_numpy()

In [None]:
# One-hot encode the labels and divide the inputs by 255.
y = keras.utils.to_categorical(Y)
x = X / 255.0

In [None]:
class DFAModel(Model):
  """Functional Keras model whose hidden layers are trained with direct
  feedback alignment (DFA) instead of back-propagated gradients.

  Every layer between the input layer and the last layer receives a fixed
  random ``feedback`` matrix of shape ``(n_outputs, layer_units)``. During
  training the output error is projected through that matrix to form the
  layer's error signal; only the last layer uses the true loss gradient.

  The hidden layers are expected to expose ``kernel`` and ``bias`` (e.g.
  ``Dense``), because ``train_step`` reads both attributes.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    n_outputs = self.output.shape.as_list()[1]
    for layer in self.layers[1:-1]:
      units = layer.output_shape[1]
      # Uniform in +-1/sqrt(units), fixed for the lifetime of the model.
      limit = 1. / units ** .5
      layer.feedback = tf.random.uniform(minval=-limit,
                                         maxval=limit,
                                         shape=(n_outputs, units))
    self._name = 'model_dfa'

  def train_step(self, data):
    """Run one DFA update on a batch.

    Args:
      data: ``(x, y)`` pair as delivered by ``Model.fit``.

    Returns:
      Dict mapping each metric name to its current result.
    """
    x, y = data
    hidden_layers, out_layer = self.layers[1:-1], self.layers[-1]

    outs = []
    with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
      tape.watch(self.trainable_variables)
      for layer in hidden_layers:
        # stop_gradient cuts the backward path between layers, so each
        # recorded output depends only on its own layer's variables.
        x = layer(tf.stop_gradient(x), training=True)
        outs.append(x)
      yp = out_layer(x)
      tape.watch(yp)
      loss = self.compiled_loss(y, yp, regularization_losses=self.losses)

    # Per-sample output error. It is kept un-reduced so that each sample's
    # error multiplies that same sample's local derivative below.
    dl_dy = tape.gradient(loss, yp)

    gradients, variables = [], []
    for layer, out in zip(hidden_layers, outs):
      with tf.name_scope(f'{layer.name}/global_feedback'):
        # Project the output error through the layer's fixed random matrix.
        dl_da = tf.matmul(dl_dy, layer.feedback)

      with tf.name_scope(f'{layer.name}/updates'):
        # Seeding the local backward pass with dl_da yields
        # sum_b dl_da[b] * d(out[b])/d(var) in a single call.
        dl_dw, dl_db = tape.gradient(out, [layer.kernel, layer.bias],
                                     output_gradients=dl_da)
      gradients.extend([dl_dw, dl_db])
      variables.extend([layer.kernel, layer.bias])

    # The last layer is trained on the true gradient of the loss.
    dl_dw, dl_db = tape.gradient(loss, [out_layer.kernel, out_layer.bias])
    gradients.extend([dl_dw, dl_db])
    variables.extend([out_layer.kernel, out_layer.bias])

    # Persistent tapes hold their resources until explicitly dropped.
    del tape

    # Pair each gradient with the exact variable it was computed for rather
    # than relying on the ordering of self.trainable_variables.
    self.optimizer.apply_gradients(zip(gradients, variables))
    self.compiled_metrics.update_state(y, yp)

    return {m.name: m.result() for m in self.metrics}

In [None]:
# Declarative layer spec: the class to instantiate plus its call arguments.
LayerConfig = namedtuple('LayerConfig', ['cls', 'args', 'kwargs'])


def get_twin_models(layers):
  """Return a ``get_model(tag)`` factory bound to the given layer specs.

  ``layers`` is a sequence of ``LayerConfig``; the first entry creates the
  input tensor and each later entry is applied to the previous output. Every
  spec must carry a ``'name'`` kwarg, which gets ``_{tag}`` appended.

  ``get_model('dfa')`` builds a ``DFAModel``; any other tag builds a plain
  ``Model`` named ``'model_bp'``. For tag ``'bp'`` the ``kernel_initializer``
  kwarg is dropped from every non-input layer. The factory prints the model
  summary and returns ``(model, model.get_weights())``.
  """

  def _without(mapping, dropped):
    # Copy of `mapping` minus the keys listed in `dropped`.
    return {key: val for key, val in mapping.items() if key not in dropped}

  def get_model(tag):
    first = layers[0]
    base_name = first.kwargs['name']
    inp = first.cls(*first.args,
                    name=f'{base_name}_{tag}',
                    **_without(first.kwargs, {'name'}))

    out = inp
    for cfg in layers[1:]:
      base_name = cfg.kwargs['name']
      dropped = {'name', 'kernel_initializer'} if tag == 'bp' else {'name'}
      built = cfg.cls(*cfg.args,
                      name=f'{base_name}_{tag}',
                      **_without(cfg.kwargs, dropped))
      out = built(out)

    if tag == 'dfa':
      model = DFAModel(inputs=inp, outputs=out)
    else:
      model = Model(inputs=inp, outputs=out, name='model_bp')
    model.summary()
    return model, model.get_weights()

  return get_model


# Shared layer stack for both twins. get_twin_models drops
# 'kernel_initializer' for the 'bp' tag, so only the DFA twin is built with
# zero-initialised kernels.
layers = [
    LayerConfig(Input, (784,), {'name': 'Input1'}),
    LayerConfig(Dense, (800,), {'activation': 'tanh',
                                'kernel_initializer': 'zeros',
                                'name': 'Dense1'}),
    LayerConfig(Dense, (800,), {'activation': 'tanh',
                                'kernel_initializer': 'zeros',
                                'name': 'Dense2'}),
    LayerConfig(Dense, (10,), {'kernel_initializer': 'zeros',
                               'name': 'Dense3'}),
]

get_model = get_twin_models(layers)

# Build the DFA twin first, then the back-prop twin; each call also returns
# the freshly initialised weights.
dfa_pair = get_model('dfa')
model_dfa, weights_dfa = dfa_pair
bp_pair = get_model('bp')
model_bp, weights_bp = bp_pair

In [None]:
# log dir
# Root directory for the TensorBoard logs: created if missing, then emptied
# through the shell so stale runs do not show up.
save_path = os.path.join('logs')
os.makedirs(save_path, exist_ok=True)
! rm -rf $save_path/*

In [None]:
# Shell step: send signal 9 to any process matching "tensorboard".
! sudo pkill -9 tensorboard

In [None]:
# (Re)load the TensorBoard notebook extension and open it on the log directory.
%reload_ext tensorboard
%tensorboard --logdir $save_path

In [None]:
# Shell step: recursively copy the log directory to a Drive path.
# NOTE(review): presumably needs Google Drive mounted at /content/drive; no
# mount call is visible in this notebook - confirm.
! cp -r logs /content/drive/MyDrive/experiments/logs1

In [None]:
def train_twins(x, y, model_bp, model_dfa, weights_bp, weights_dfa):
  """Return a ``train(tag)`` closure over the data and both twin models.

  ``train('dfa')`` trains ``model_dfa``; any other tag trains ``model_bp``.
  Each call restores the matching stored weights, recompiles the model with
  RMSprop and a from-logits categorical cross-entropy, writes TensorBoard
  logs under ``save_path/<tag>`` and returns whatever ``model.fit`` returns.
  """

  def train(tag):
    log_dir = os.path.join(save_path, tag)
    os.makedirs(log_dir, exist_ok=True)

    # Currently disabled callbacks, kept for reference:
    # filepath = os.path.join(log_dir, 'model.{epoch:02d}-{val_loss:.4f}.h5')
    # EarlyStopping(patience=10, verbose=1),
    # ReduceLROnPlateau(factor=0.5, patience=5, min_lr=0.00001, verbose=1),
    # ModelCheckpoint(filepath=filepath, verbose=1,
    #                 save_best_only=True, save_weights_only=True),
    tensorboard = TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        write_graph=True,
        write_images=False,
        write_steps_per_second=True,
        update_freq='batch',
        profile_batch=0,
        embeddings_freq=0,
        embeddings_metadata=None,
    )
    callbacks = [tensorboard]

    chosen = (model_dfa, weights_dfa) if tag == 'dfa' else (model_bp, weights_bp)
    model, weights = chosen

    # Start every run from the stored weights.
    model.set_weights(weights)
    # run_eagerly=True can be passed to compile() for debugging.
    model.compile(
        optimizer=RMSprop(learning_rate=.002),
        loss=CategoricalCrossentropy(from_logits=True),
    )
    history = model.fit(
        x, y,
        epochs=50,
        batch_size=32,
        validation_split=.1,
        callbacks=callbacks,
    )
    return history

  return train

# Bind the data and both twins once; train_fn(tag) then runs a training session.
train_fn = train_twins(x, y, model_bp, model_dfa, weights_bp, weights_dfa)

In [None]:
# Clear previous DFA logs, then train the DFA twin and keep the value
# returned by model.fit.
! rm -rf $save_path/dfa/*
res_dfa = train_fn('dfa')

In [None]:
# Clear previous back-prop logs, then train the back-prop twin and keep the
# value returned by model.fit.
! rm -rf $save_path/bp/*
res_bp = train_fn('bp')

In [None]:
def plot_metric(res, metric, mark_epoch, *, logy=False):
  """Plot a metric and its ``val_`` counterpart on the current axes.

  Args:
    res: object whose ``history`` mapping holds per-epoch series.
    metric: key of the training series; ``val_<metric>`` is plotted as well.
    mark_epoch: epoch index highlighted on the validation curve.
    logy: when truthy, draw with a logarithmic y axis and label it
      ``log_<metric>``.
  """
  val_metric = f'val_{metric}'
  if logy:
    draw, ylabel = plt.semilogy, f'log_{metric}'
  else:
    draw, ylabel = plt.plot, metric

  plt.title(metric)
  for key in (metric, val_metric):
    draw(res.history[key], label=key)

  # Red triangle on the validation curve at the chosen epoch.
  marked_value = res.history[val_metric][mark_epoch]
  plt.scatter(mark_epoch, marked_value,
              marker='^', color='r', s=50, label='best model')
  plt.xlabel('epochs')
  plt.ylabel(ylabel)

# Epoch with the lowest validation loss; it is marked on every subplot.
best_epoch = np.argmin(res_dfa.history['val_loss'])
n_metrics = len(model_dfa.metrics_names)

# Grid with nc columns and just enough rows to hold every metric.
nc = 3
nr = n_metrics // nc + (n_metrics % nc != 0)
fig, axes = plt.subplots(nr, nc, sharex=True, figsize=(15, 8))
for i, metric in enumerate(model_dfa.metrics_names):
  plt.subplot(nr, nc, i+1)
  plot_metric(res_dfa, metric, best_epoch, logy=True)

# Remove the unused trailing axes. Slicing from n_metrics (rather than
# [-rem:]) matters when rem == 0: [-0:] would select, and delete, every axis.
rem = (nr * nc) - n_metrics
for ax in axes.ravel()[n_metrics:]:
  fig.delaxes(ax)
plt.tight_layout()
plt.legend()

In [None]:
# Rebuild the checkpoint file name for the best epoch (1-based epoch number,
# validation loss to four decimals) and load it into the DFA model.
# NOTE(review): the ModelCheckpoint callback that would write this file is
# commented out in train_twins, so the file presumably does not exist unless
# that callback is re-enabled - confirm before running this cell.
best_weights = f'model.{best_epoch+1:02d}-{res_dfa.history["val_loss"][best_epoch]:.4f}.h5'
model_dfa.load_weights(os.path.join(save_path, 'dfa', best_weights))

In [None]:
# yp_bp = model_bp.predict(x)
# Predictions of the DFA model on x (the back-prop twin's line is disabled).
yp_dfa = model_dfa.predict(x)

In [None]:
# Overlay the DFA predictions and the (faint) targets against x.
# NOTE(review): scatter needs both arguments to have the same size; this cell
# looks written for the 1-D sine data defined in a later cell rather than the
# MNIST arrays - confirm the intended execution order.
# plt.scatter(x,yp_bp)
plt.scatter(x,yp_dfa)
plt.scatter(x,y, alpha=.01)

In [None]:
# Toy regression set: 5000 float32 inputs drawn uniformly from [0, 1) and
# targets sin(2*pi*x) plus uniform noise in [-0.1, 0.1).
# Note: this rebinds the notebook-level x and y used by the earlier cells.
x = np.random.uniform(size=(5000,1)).astype('float32')
y = np.sin(2.*np.pi*x) + np.random.uniform(-.1, .1, size=x.shape)
plt.scatter(x,y)
x.shape, y.shape

In [None]:
def _weighted_jacobian_sum(dl_da, jac):
  """Contract a per-sample error signal with a per-sample Jacobian.

  ``jac`` is summed over axis 1, each slice along axis 0 is multiplied by the
  matching slice of ``dl_da``, and the products are concatenated along axis 0
  and summed over that axis (the reduced axis is kept).
  """
  per_sample = [err * local for err, local in
                zip(tf.unstack(dl_da, axis=0),
                    tf.unstack(tf.reduce_sum(jac, axis=1), axis=0))]
  return tf.reduce_sum(tf.concat(per_sample, axis=0), axis=0, keepdims=True)


def fit(model, x, y, epochs=200, lr=.001, mode='bp'):
  """Full-batch gradient descent written out by hand, for back-prop or DFA.

  ``model(x)`` must return a sequence with one output per non-input layer,
  the last one being the prediction. Every non-input layer must expose
  ``kernel`` and ``bias``. The loss is ``0.5 * mean((yp - y) ** 2)``.

  Args:
    model: Keras model as described above.
    x, y: inputs and targets, used whole in every epoch.
    epochs: number of update passes.
    lr: step size; passed to ``tf.constant``, so it should be a float.
    mode: ``'bp'`` uses the true gradient of the loss with respect to each
      layer output and asserts that the hand-built update matches
      ``tape.gradient``; ``'dfa'`` replaces it with the output error
      multiplied element-wise by a per-layer random ``feedback`` tensor
      (identity for the last layer).

  Returns:
    ``(ys, cost)``: per-epoch predictions and per-epoch loss tensors.

  Raises:
    ValueError: if ``mode`` is neither ``'bp'`` nor ``'dfa'``.

  NOTE(review): the contraction in ``_weighted_jacobian_sum`` sums over the
  leading axis of the kernel as well, so the updates look designed for layers
  of width one - confirm against the model this is called with.
  """
  if mode not in ('bp', 'dfa'):
    # Without this guard an unknown mode fell into the DFA branch and failed
    # later with a NameError on the undefined output error.
    raise ValueError(f"mode must be 'bp' or 'dfa', got {mode!r}")

  ys, cost = [], []

  if mode == 'dfa':
    n_out = model.output[-1].shape.as_list()[1]
    for layer in model.layers[:-1]:
      if 'Input' in layer.name:
        continue
      layer.feedback = tf.random.uniform(minval=-1., maxval=1.,
                                         shape=(layer.output_shape[1], n_out))
    # The last layer sees the output error unchanged.
    model.layers[-1].feedback = tf.eye(n_out)

  step = tf.constant(lr)
  for j in range(epochs):
    with tf.GradientTape(persistent=True) as tape:
      all_out = model(x)
      yp = all_out[-1]
      loss = .5 * tf.reduce_mean((yp - y) ** 2.)

    if mode == 'dfa':
      dl_dy = tape.gradient(loss, yp)

    for i, layer in enumerate(model.layers):
      if 'Input' in layer.name:
        continue
      # all_out has no entry for the input layer, hence the i-1 offset.
      layer_out = all_out[i-1]
      if mode == 'bp':
        dl_da = tape.gradient(loss, layer_out)
      else:
        dl_da = dl_dy * layer.feedback
      da_dw, da_db = tape.jacobian(layer_out, [layer.kernel, layer.bias])

      dl_dw = _weighted_jacobian_sum(dl_da, da_dw)
      dl_db = _weighted_jacobian_sum(dl_da, da_db)

      if mode == 'bp':
        # Sanity check: the hand-built update must equal autodiff's gradient.
        dl_dw1, dl_db1 = tape.gradient(loss, [layer.kernel, layer.bias])
        tf.debugging.assert_near(dl_dw, dl_dw1)
        tf.debugging.assert_near(dl_db, dl_db1)

      layer.kernel.assign(tf.subtract(layer.kernel, tf.multiply(step, dl_dw)))
      layer.bias.assign(tf.subtract(layer.bias, tf.multiply(step, dl_db)))

    # Persistent tapes keep their resources until explicitly dropped.
    del tape
    ys.append(yp)
    cost.append(loss)
    print('\r', j+1, end='')
  return ys, cost

In [None]:
# Restore the stored weights, then run 500 DFA epochs with learning rate 1.
# NOTE(review): `model` and `weights` are not defined in any visible cell of
# this notebook - presumably created in a cell that was removed; confirm.
model.set_weights(weights)
ys, cost = fit(model, x, y, 500, 1., 'dfa')

In [None]:
# Mean absolute error of the final-output predictions.
# (NumPy is imported in this notebook as `np`; the bare name `numpy` is unbound.)
np.abs(model.predict(x)[-1] - y).mean()

In [None]:
# Stack the per-epoch prediction tensors into one array.
# (NumPy is imported in this notebook as `np`; the bare name `numpy` is unbound.)
ys = np.asarray([yy.numpy() for yy in ys])

In [None]:
# Stack the per-epoch loss tensors into one array.
# (NumPy is imported in this notebook as `np`; the bare name `numpy` is unbound.)
cost = np.asarray([yy.numpy() for yy in cost])

In [None]:
# Bottom row of a 2x2 grid: loss per epoch (left) and, per epoch, the
# prediction at the last index of axis 1 (right). The top-row panels for
# gradients and weights are disabled. Assumes ys and cost have already been
# converted to arrays.
# plt.subplot(221)
# plt.plot(grads)
 
# plt.subplot(222)
# plt.plot(weights)
 
plt.subplot(223)
plt.plot(cost)
 
plt.subplot(224)
# plt.plot(ys)
plt.plot(ys[:,-1])
# plt.scatter(xs[100:], ys[100:,0,:])
# plt.scatter(xs[100:], ys[100:,1,:])
# Show the last five recorded predictions next to the last target.
ys[-5:,-1], y[-1]#, weights[-5:], grads[-5:]

In [None]:
# Sort by input value so both curves are drawn left to right, then plot the
# model's final-output prediction against the targets.
# (NumPy is imported in this notebook as `np`; the bare name `numpy` is unbound.)
xx = np.argsort(x.ravel())
plt.plot(x[xx], model.predict(x)[-1][xx])
plt.plot(x[xx], y[xx])