[View in Colaboratory](https://colab.research.google.com/github/act65/dynamics-waves-myelin/blob/master/equilibrium_propagation.ipynb)

In [0]:
import tensorflow as tf
import numpy as np

In [0]:
def energy_fn(x, W, b):
  """
  Batch-averaged scalar energy of a set of network states.

  With h = relu(x), the energy of one example is the sum of three terms:
  - node energy:     0.5 * sum_i x_i^2
  - neighbor energy: 0.5 * sum_ij W_ij * h_i * h_j
  - bias energy:     sum_i b_i * h_i

  Args:
    x: network states, indexed as [batch, node] (the node axis is axis 1).
    W: pairwise weights, broadcast against the [batch, node, node] outer
      product of the activations.
    b: biases, broadcast against the activations.

  Returns:
    A scalar tensor: the mean over the batch of the per-example energies.

  Open questions (kept from the original notes):
  - Somehow related to Hopfield nets? Define what we mean by 'energy'.
  - The neighbor term measures something like label propagation: the
    distance between two strongly connected nodes should be small.
  - What alternatives are there? Want to explore these!
  """
  activation = tf.nn.relu(x)

  # Quadratic penalty on the raw (pre-activation) state of every node.
  self_term = 0.5*tf.reduce_sum(x**2, axis=1)

  # Outer product of the activations, [batch, node, node], weighted by W.
  outer = tf.expand_dims(activation, 1) * tf.expand_dims(activation, 2)
  pair_term = 0.5*tf.reduce_sum(tf.expand_dims(W, 0) * outer, axis=[1,2])

  # Linear bias term on the activations.
  bias_term = tf.reduce_sum(activation*b, axis=1)

  # Unused alternative for the pairwise term, an 'elastic' energy on the
  # difference between connected nodes (the original note wrote `0.5**`,
  # presumably a typo for `0.5*`):
  #   0.5*tf.reduce_sum(tf.expand_dims(W, 0) * (tf.expand_dims(h, 1) - tf.expand_dims(h, 2))**2, axis=[1,2])

  per_example = self_term + pair_term + bias_term
  return tf.reduce_mean(per_example)

In [0]:
def get_sym_adj(n_nodes):
  """
  Sample a random symmetric [n_nodes, n_nodes] matrix with a zeroed diagonal.

  A standard-normal matrix is averaged with its transpose, then its diagonal
  is subtracted so that no node is connected to itself.

  Open questions (kept from the original notes):
  - Why does the adjacency matrix need to be symmetric?
  - Else we can prove that the back prop is equivalent?
  """
  noise = tf.random_normal(shape=[n_nodes, n_nodes], dtype=tf.float32)
  symmetric = 0.5*(noise + tf.transpose(noise))
  # eye * symmetric keeps only the diagonal entries; removing them zeroes the diagonal.
  diagonal = tf.eye(n_nodes)*symmetric
  return symmetric - diagonal

In [0]:
class Network():
  """
  A single pool of input, hidden and output nodes with pairwise weights,
  relaxed by gradient descent on an energy (see `energy_fn`).
  https://github.com/bscellier/Towards-a-Biologically-Plausible-Backprop

  The state has one value per node, laid out as [inputs | hidden | outputs]
  along axis 1. A subset of nodes is 'clamped' by adding a squared-error
  forcing term (weighted by `beta`) to the energy while the state relaxes.

  Design notes (kept from the original):
  Rather than having two phases, want the nodes to have some temporal state.
  If some input values were recently 'clamped' then they should
  correlate with output values that are 'clamped' not long afterward.

  So there exists a delay between the clamping of the inputs and the outputs.
  What happens if;
  - delay is large
  - delay is variable
  - ?

  Might not even need to do anything smart? Bc SGD will want to find the shortest path from
  old state (clamped at inputs), to new state (clamped at labels).
  Optimise the parameters to minimize the distance travelled by the state!?

  What about spiking nets!?

  Pros:
  - Can easily add more nodes or new inputs

  Cons:
  - must simulate for n steps rather than 1 shot prediction
  """
  def __init__(self, n_inputs, n_hidden, n_outputs, name='',
               beta=10.0, step_size=0.1, learning_rate=0.0001):
    """
    Args:
      n_inputs: number of input nodes (the first nodes of the state).
      n_hidden: number of hidden nodes.
      n_outputs: number of output nodes (the last nodes of the state).
      name: currently unused; variables are always created under 'network'.
      beta: weight of the forcing (clamping) term relative to the energy.
      step_size: step size of the gradient descent on the state.
      learning_rate: learning rate of the Adam optimiser for the parameters.
    """
    self.n_nodes = n_inputs + n_hidden + n_outputs
    self.beta = beta
    self.step_size = step_size

    # Column indices of the clampable nodes within the state.
    self.input_idx = tf.range(n_inputs)
    self.output_idx = tf.range(n_inputs+n_hidden, self.n_nodes)

    with tf.variable_scope('network'):
      self.weights = tf.Variable(get_sym_adj(self.n_nodes), name='weights')
      self.biases = tf.Variable(tf.random_normal(shape=[1, self.n_nodes], dtype=tf.float32), name='biases')
    self.variables = [self.weights, self.biases]

    self.opt = tf.train.AdamOptimizer(learning_rate)

  def energy_loss(self, state):
    """
    Scalar energy of `state` under the current weights and biases.

    WANT a single loss to optimise!!
    """
    with tf.name_scope('energy_loss'):
      return tf.reduce_mean(energy_fn(state, self.weights, self.biases))

  def forcing_loss(self, state, vals, idx):
    """
    Clamping penalty: beta * mse(state[:, idx], vals).

    Args:
      state: network state, [batch, n_nodes].
      vals: values the selected nodes are pulled towards; must have the same
        shape as the gathered columns, [batch, len(idx)].
      idx: column indices of the nodes to clamp.

    Open question (kept from the original notes):
    How can I get grads w.r.t the parameters!?
    dLdparam = mse(state, target)
    """
    with tf.name_scope('forcing_loss'):
      return self.beta*tf.losses.mean_squared_error(tf.gather(state, idx, axis=1), vals)

  def step(self, state, vals=None, idx=None):
    """
    One gradient-descent step of the state on the (optionally forced) energy.

    Args:
      state: network state, [batch, n_nodes].
      vals: optional clamping values; see `forcing_loss`.
      idx: optional indices of the clamped nodes; see `forcing_loss`.
        The forcing term is only added when both `vals` and `idx` are given.

    Returns:
      The updated state, same shape as `state`.
    """
    with tf.name_scope('step'):
      # Always trying to find a state with lower energy
      loss = self.energy_loss(state)

      # `vals`/`idx` were previously read here without being parameters,
      # which raised as soon as `forward` called step(state, vals, idx).
      if vals is not None and idx is not None:
        loss += self.forcing_loss(state, vals, idx)

      grad = tf.gradients(loss, state)[0]

      # TODO want smarter optimisation here. AMSGrad!?
      return state - self.step_size*grad

  def forward(self, state, vals, idx, n_steps=10):
    """
    Relax `state` for `n_steps` steps with the nodes `idx` clamped to `vals`.

    Use while loop to take advantage of smart compilation!?
    but the problem is we now have a finite window of data we can view.

    Args:
      state: initial network state, [batch, n_nodes].
      vals: clamping values, passed to `step`.
      idx: indices of the clamped nodes, passed to `step`.
      n_steps: number of relaxation steps.

    Returns:
      The state after `n_steps` calls to `step`.
    """
    def body(i, current_state):
      # a wrapper for self.step(...) that also advances the loop counter
      return i + 1, self.step(current_state, vals, idx)

    with tf.name_scope('forward'):
      while_condition = lambda i, m : tf.less(i, n_steps)   # TODO change to state - old_state!? or low loss
      i = tf.constant(0)
      i_, new_state = tf.while_loop(while_condition, body, loop_vars=[i, state])

      return new_state

  def train_step(self, inputs, targets, init_state=None):
    """
    Build the training op. Sets `self.pred` and `self.loss` as side effects.

    Args:
      inputs: values for the input nodes, [batch, n_inputs].
      targets: values for the output nodes, [batch, n_outputs] (for
        classification, one-hot encoded rather than class ids).
      init_state: optional starting state, [batch, n_nodes]; zeros by default.

    Returns:
      The op that applies one optimiser update to the weights and biases.
    """
    if init_state is None:
      init_state = tf.zeros([tf.shape(inputs)[0], self.n_nodes])

    # clamp inputs
    state_f = self.forward(init_state, inputs, self.input_idx)
    # The prediction is read from the input-clamped state (was the undefined `state_0`).
    self.pred = tf.gather(state_f, self.output_idx, axis=1)

    # clamp outputs. The original call omitted the starting state; continue
    # from the input-clamped state so that the loss below measures how far
    # the state has to travel once the labels are clamped.
    state_b = self.forward(state_f, targets, self.output_idx)

    # minimise the distance to be travelled/the changes to be made. lazy.
    self.loss = tf.losses.mean_squared_error(state_f, state_b)
    return self.opt.minimize(self.loss, var_list=self.variables, global_step=tf.train.get_or_create_global_step())

In [0]:
def model_fn(features, labels, mode, params, config):
    """
    `tf.estimator` model function wrapping `Network`.

    Only training and evaluation are supported: `labels` is required.

    Args:
      features: dict with key 'x' holding the flattened input images.
      labels: integer class ids, one per example (they are compared against
        the argmax of the output nodes for the accuracy metric).
      mode: a `tf.estimator.ModeKeys` value, forwarded to the EstimatorSpec.
      params: optional overrides 'n_inputs', 'n_hidden', 'n_outputs';
        defaults are 28*28, 64 and 10.
      config: unused.

    Returns:
      A `tf.estimator.EstimatorSpec` with the loss, train op and accuracy metric.
    """
    n_inputs = params.get('n_inputs', 28*28)
    n_hidden = params.get('n_hidden', 64)
    n_outputs = params.get('n_outputs', 10)

    x = features['x']
    net = Network(n_inputs, n_hidden, n_outputs)

    # The output nodes are clamped with a mean-squared-error term, which needs
    # targets of shape [batch, n_outputs]; raw class ids of shape [batch] are
    # rejected as shape-incompatible, so one-hot encode them first.
    targets = tf.one_hot(labels, n_outputs, dtype=tf.float32)
    train_op = net.train_step(x, targets)

    return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=net.loss,
      train_op=train_op,
      # Accuracy is still computed against the original integer class ids.
      eval_metric_ops={"accuracy": tf.metrics.accuracy(labels, tf.argmax(net.pred, axis=1))}
    )

In [8]:
# Load MNIST as numpy arrays; labels are cast to int32.
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data, eval_data = mnist.train.images, mnist.test.images
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

batch_size=50


def _single_epoch_input_fn(images, labels, batch_size, shuffle):
  """Input fn that feeds `images` under the feature key 'x' for one epoch."""
  return tf.estimator.inputs.numpy_input_fn(
      x={"x": images},
      y=labels,
      batch_size=batch_size,
      num_epochs=1,
      shuffle=shuffle)


# Training batches are shuffled; evaluation batches keep the dataset order.
train_input_fn = _single_epoch_input_fn(train_data, train_labels, batch_size, shuffle=True)
eval_input_fn = _single_epoch_input_fn(eval_data, eval_labels, batch_size, shuffle=False)

# Checkpoints and summaries go to log/1, with a checkpoint every 100 steps.
run_config = tf.estimator.RunConfig(
    model_dir='log/1',
    save_checkpoints_steps=100,
)
estimator = tf.estimator.Estimator(model_fn, params=dict(), config=run_config)

Instructions for updating:
Please use tf.data.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST-data/t10k-labels-idx1-u

In [0]:
# Alternate short bursts of training with an evaluation run, printing the
# evaluation metrics after each burst.
n_rounds = 10
steps_per_round = 100
for _ in range(n_rounds):
    estimator.train(train_input_fn, steps=steps_per_round)
    eval_results = estimator.evaluate(eval_input_fn)
    print("Evaluation_results:\n\t%s\n" % eval_results)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into log/0/model.ckpt.
INFO:tensorflow:loss = 2.3025854, step = 1


TODO: visualise the state dynamics as a vector map (vector field) for a 2D example!?

In [8]:
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip

--2018-09-04 07:47:57--  https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
Resolving bin.equinox.io (bin.equinox.io)... 52.207.5.158, 52.5.182.176, 52.207.39.76, ...
Connecting to bin.equinox.io (bin.equinox.io)|52.207.5.158|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5363700 (5.1M) [application/octet-stream]
Saving to: ‘ngrok-stable-linux-amd64.zip’


2018-09-04 07:47:57 (20.7 MB/s) - ‘ngrok-stable-linux-amd64.zip’ saved [5363700/5363700]

Archive:  ngrok-stable-linux-amd64.zip
  inflating: ngrok                   


In [0]:
# Start TensorBoard in the background on port 6006.
# The Estimator above writes its checkpoints and summaries under 'log/<run>'
# (model_dir='log/1'), so point TensorBoard at the parent directory; the
# previous value 'EP' is not written to by anything in this notebook.
LOG_DIR = 'log'
get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(LOG_DIR)
)

In [0]:
# Open an ngrok HTTP tunnel to local port 6006 (the port TensorBoard is
# started on) as a background process.
ngrok_command = './ngrok http 6006 &'
get_ipython().system_raw(ngrok_command)

In [11]:
# Query ngrok's local API (port 4040) and print the public URL of the first
# tunnel it lists. The Python one-liner indexes ['tunnels'][0], so it fails
# if no tunnel is listed yet.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

http://5b92938b.ngrok.io
