In [2]:
import tensorflow as tf
tf.enable_eager_execution()

import functools
import numpy as np

In [89]:
# Size of the latent code: used below as the encoder's output width and as
# the number of experts handed to the ICM decoder.
n_hidden = 4
# Hidden-layer width shared by the Dense layers in Expert and in the encoder.
width = 32
# Dimensionality of the data vector the model tries to reconstruct.
n_inputs = 28

# A single random sample of shape [1, n_inputs], used as a smoke-test input.
# NOTE(review): no random seed is set, so this value changes on every run.
x = tf.random_normal([1, n_inputs])

In [90]:
class Expert():
    """A small two-layer MLP applied as a residual update to `x`.

    The network sees `x` concatenated with a conditioning slice `z` and its
    output is added back onto `x`, so the identity map is easy to represent.

    Args:
        n_outputs: size of the MLP output. Must equal the last dimension of
            the `x` passed to `__call__`, otherwise the residual add fails.
    """

    def __init__(self, n_outputs):
        # Hidden size comes from the notebook-level `width` constant.
        self.fn = tf.keras.Sequential([
            tf.keras.layers.Dense(width, activation=tf.nn.relu),
            tf.keras.layers.Dense(n_outputs)
        ])

    @property
    def variables(self):
        """Current variables of the underlying Keras model.

        Looked up on every access rather than cached in `__init__`: the
        Sequential model is constructed without an input shape, so its
        weights do not exist until the first call and a list captured at
        construction time would stay empty.
        """
        return self.fn.variables

    def __call__(self, x, z):
        """Return `x + fn(concat([x, z]))`.

        Args:
            x: tensor to refine; its last dimension must match `n_outputs`.
            z: conditioning tensor, concatenated to `x` along the last axis.
        """
        # Open questions carried over from the original notes:
        #  - why should residual experts work better? (working hypothesis:
        #    the identity is then trivial to learn)
        #  - is this equivalent to a ResNet whose blocks are randomly
        #    shuffled, or would plain noise in a ResNet be enough?
        inputs = tf.concat([x, z], axis=-1)
        return self.fn(inputs) + x

In [91]:
class ICM():
    def __init__(self, n_inputs, n_experts):
        self.n_inputs = n_inputs
        self.experts = [Expert(n_inputs) for _ in range(n_experts)]
        self.variables = [var for expert in self.experts for var in expert.variables]

    def __call__(self, z):
        experts = self.condition_experts(z)
        np.random.shuffle(experts)
        x_init = tf.zeros([tf.shape(z)[0], self.n_inputs])
        return functools.reduce(lambda x, f: f(x), [x_init] + experts)
    
    def condition_experts(self, z):
      # maybe pick a subset as well!?
      return [lambda x: expert(x,z[:, i:i+1]) for i, expert in enumerate(self.experts)]

In [92]:
# Encoder: a three-layer MLP whose final layer emits an n_hidden-wide
# latent code (no activation on the output layer).
encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(width, activation=tf.nn.relu),
    tf.keras.layers.Dense(width, activation=tf.nn.relu),
    tf.keras.layers.Dense(n_hidden)
])
# Decoder: one expert per latent dimension, producing n_inputs-wide outputs.
decoder = ICM(n_inputs, n_hidden)


In [93]:
# Run one encode -> decode round trip under a gradient tape and score the
# reconstruction `y` against the input `x` with mean squared error.
with tf.GradientTape() as tape:
    y = decoder(encoder(x))
    loss = tf.losses.mean_squared_error(x,  y)
# Only the encoder's variables are differentiated here; the decoder's
# variables are not included in this gradient computation.
grads = tape.gradient(loss, encoder.variables)

In [94]:
# Display the gradients of the loss with respect to the encoder variables.
grads

[<tf.Tensor: id=3456, shape=(28, 32), dtype=float32, numpy=
 array([[ 1.35224164e-02,  0.00000000e+00,  3.33243877e-01,
         -2.00659201e-01, -3.71341348e-01, -1.95218727e-01,
         -2.54331172e-01,  0.00000000e+00,  2.30213404e-01,
          6.35093898e-02,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  2.12030128e-01,  2.01497480e-01,
          0.00000000e+00,  0.00000000e+00, -1.66284353e-01,
         -4.11939621e-01,  0.00000000e+00,  1.59954990e-03,
          2.12467819e-01, -2.75752217e-01,  0.00000000e+00,
          0.00000000e+00, -8.06286260e-02, -2.73477316e-01,
          0.00000000e+00, -4.06928927e-01,  0.00000000e+00,
          0.00000000e+00, -5.38353801e-01],
        [ 5.67620620e-04,  0.00000000e+00,  1.39883356e-02,
         -8.42292514e-03, -1.55875254e-02, -8.19455460e-03,
         -1.06758745e-02,  0.00000000e+00,  9.66350082e-03,
          2.66588735e-03,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  8.90023448e-03,  8.45811330e-