In [1]:
import tensorflow as tf
# Switch TensorFlow into eager mode: the cells below rely on tf.GradientTape
# and on tensors that display their values directly in cell outputs.
tf.enable_eager_execution()

In [30]:
import sonnet as snt
import tensorflow_probability as tfp
tfd = tfp.distributions

In [115]:
class Transition():
    """Gradient-based one-step transition model.

    Two small MLPs score a state: ``energy_fn`` and ``value_fn``. The predicted
    next state is the current state moved one step along the gradient of
    ``value - energy`` with respect to the state (see ``forward``).

    Attributes:
        energy_fn: Keras model mapping a state to a scalar energy.
        value_fn: Keras model mapping a state to a scalar value.
        gamma: discount factor used in the TD target.
        opt: optimizer shared by both networks.
        step: global step counter incremented by the optimizer.
        old_x: the state passed to the previous ``__call__`` (None before the
            first call).
        last_loss: loss returned by the most recent ``train_step`` triggered
            from ``__call__`` (None until one has run).
    """

    def __init__(self, gamma=0.99):
        """Build both networks, the optimizer and the bookkeeping state.

        Args:
            gamma: discount factor for the value target. Defaults to 0.99.
        """
        self.energy_fn = tf.keras.Sequential([
            # could add some memory in here. LSTM or DNC
            tf.keras.layers.Dense(64, activation=tf.nn.selu),
            tf.keras.layers.Dense(1)
        ])

        self.value_fn = tf.keras.Sequential([
            # could add some memory in here. LSTM or DNC
            tf.keras.layers.Dense(64, activation=tf.nn.selu),
            tf.keras.layers.Dense(1)
        ])

        self.gamma = gamma

        self.opt = tf.train.AdamOptimizer()
        self.step = tf.Variable(0, name='step')

        self.old_x = None
        self.last_loss = None

    def __call__(self, x, r=None):
        """Predict the next state and, when possible, train on the last transition.

        Handles training and prediction. (how can these be unified!?)

        Args:
            x: current state tensor.
            r: reward passed to ``train_step`` for the transition
                ``old_x -> x``. When omitted, the call only predicts; no
                training step is run.

        Returns:
            The predicted next state, same shape as ``x``.
        """
        # or use a worker to collect data and train offline
        x_hat_tp1 = self.forward(x)

        # Training needs a previous state and a reward; skip it otherwise.
        if self.old_x is not None and r is not None:
            # recompute the x_hat_tp1. could do better.
            self.last_loss = self.train_step(self.old_x, x, r)

        self.old_x = x

        return x_hat_tp1

    def forward(self, x, step_size=0.1):
        """Take one gradient step on the state.

        Args:
            x: state tensor; it is watched by the tape, so it must be a
                tensor rather than a plain array.
            step_size: multiplier applied to the gradient.

        Returns:
            ``x + step_size * d(value - energy)/dx``, same shape as ``x``.
        """
        with tf.GradientTape() as tape:
            tape.watch(x)
            e = self.energy_fn(x)
            v = self.value_fn(x)

            cost = v - e

        # With a single tensor as source, the gradient has the same shape as
        # x. Indexing it with [0] would take only the first batch row and
        # broadcast that row to every sample.
        grad = tape.gradient(cost, x)
        return x + step_size * grad  # ascend value and descend energy

    def train_step(self, x_t, x_tp1, r_t):
        """Run one optimizer step on an observed transition.

        Args:
            x_t: state before the transition.
            x_tp1: observed state after the transition.
            r_t: reward associated with the transition.

        Returns:
            The scalar loss (prediction error plus TD error) before the update.
        """
        with tf.GradientTape() as tape:
            # observations should have low energy
            # not sure how to optimise for that!?

            # for now. optimise E, V for accuracy
            x_hat_tp1 = self.forward(x_t)
            loss_acc = tf.losses.mean_squared_error(x_tp1, x_hat_tp1)

            # value should predict future rewards
            v_t = self.value_fn(x_t)
            # Hold the bootstrapped target fixed so the update moves v_t
            # towards the target, not the target towards v_t.
            td_target = tf.stop_gradient(r_t + self.gamma * self.value_fn(x_tp1))
            loss_value = tf.losses.mean_squared_error(td_target, v_t)
            # could split out the value fn as another class. as will need for policy as well.

            loss = loss_value + loss_acc

        variables = (self.energy_fn.trainable_variables
                     + self.value_fn.trainable_variables)

        grads = tape.gradient(loss, variables)
        # Some variables can be unconnected to the loss (the output bias of
        # energy_fn does not affect d(cost)/dx), giving a None gradient.
        grads_and_vars = [(g, v) for g, v in zip(grads, variables) if g is not None]
        self.opt.apply_gradients(grads_and_vars, global_step=self.step)
        return loss

In [116]:
# Seed TensorFlow's global RNG so the sampled inputs below do not change
# from one run of the notebook to the next.
tf.set_random_seed(0)

x = tf.random_normal([1, 6])  # one state with 6 features
r = tf.random_normal([1, 1])  # reward passed alongside it
t = Transition()
# First call on a fresh Transition: there is no previous state yet, so no
# training step runs and the cell just displays the predicted next state.
t(x, r)

<tf.Tensor: id=8200, shape=(1, 6), dtype=float32, numpy=
array([[ 0.6959551 , -0.38642088,  1.6699668 , -1.3171674 , -0.15613846,
         1.3045323 ]], dtype=float32)>

In [94]:
class Policy():
    """
    This policy has little to do with achieving 'extrinsic value'.
    Its main task is reachability. I want to go to X. This policy should make it happen.

    Attributes:
        fn: Keras model mapping the input to ``n_actions`` logits.
        temperature: temperature of the relaxed one-hot distribution that
            actions are sampled from.
    """

    # does this need memory?
    # the ability to integrate the deltas?
    # the ability to remember the past?
    # will be a pain for training...
    def __init__(self, n_actions, temperature=1.0):
        """Build the logits network.

        Args:
            n_actions: number of logits / size of the sampled action vector.
            temperature: relaxation temperature used when sampling.
                Defaults to 1.0.
        """
        self.fn = tf.keras.Sequential([  # VQ!?
            tf.keras.layers.Dense(64, activation=tf.nn.selu),
            tf.keras.layers.Dense(n_actions)
        ])
        self.temperature = temperature

    def __call__(self, x):  # possibly receives many deltas for different layers.
        """Sample a relaxed one-hot action from the logits computed for ``x``."""
        z = self.fn(x)
        return tfd.RelaxedOneHotCategorical(self.temperature, logits=z).sample()

    def train_step(self, goal, truth):
        """Score how closely the outcome matched the goal.

        No parameters are updated yet; this only computes the reward signal.

        Args:
            goal: the state the policy was asked to reach.
            truth: the state that was actually reached.

        Returns:
            The reward: the negated mean squared error between ``goal`` and
            ``truth``, detached from the gradient graph. A smaller error gives
            a larger reward.
        """
        loss = tf.losses.mean_squared_error(goal, truth)
        # Negate: the policy is rewarded for making the goal come true, so a
        # lower error must mean a higher reward.
        reward = -tf.stop_gradient(loss)
        # will have zero grads anyway.
        # unless we can use the transition fn somehow?

        # x, a, r = ?,?,reward
        # A2C - No we can do vanilla RL
        return reward

In [95]:
class Encoder():
    """Two-layer MLP whose output has ``n_hidden`` features."""

    def __init__(self, n_hidden):
        """Create the network: a 64-unit SELU layer followed by a linear layer."""
        # VQ!? + RNN
        hidden = tf.keras.layers.Dense(64, activation=tf.nn.selu)
        readout = tf.keras.layers.Dense(n_hidden)
        self.fn = tf.keras.Sequential([hidden, readout])

    def __call__(self, x):
        """Return the encoding of ``x`` produced by ``self.fn``."""
        code = self.fn(x)
        return code

    def train_step(self):
        """Placeholder: no training objective has been chosen yet.

        Open questions: train for high entropy? independence? sparsity?
        fully unsupervised!?
        """
        return None

In [91]:
class Layer():
    """One encoder + transition model + policy stack.

    Attributes:
        encode: ``Encoder`` mapping an input to a 64-feature state.
        transition: ``Transition`` predicting the next state.
        policy: ``Policy`` producing a 4-way relaxed one-hot action.
    """

    # __call__ and choose action will be moved to a parent class. Network, which takes many layers.
    def __init__(self):
        self.encode = Encoder(64)
        self.transition = Transition()

        self.policy = Policy(4)

    def forward(self, x_t, r=None):
        """Encode ``x_t`` and predict the next state.

        Args:
            x_t: current input.
            r: optional reward. When given, it is passed to the transition
                model's ``__call__``, which may also run a training step.
                When omitted, only the transition model's ``forward`` is used,
                so nothing is trained.

        Returns:
            The predicted next state.
        """
        s_t = self.encode(x_t)
        if r is None:
            # Transition.__call__ needs a reward; without one, just predict.
            return self.transition.forward(s_t)
        return self.transition(s_t, r)  # s_hat_tp1 should be stored as the state?!?

    def choose_action(self, s_hat_tp1, x_tp1):
        """Pick an action from the gap between the observed and predicted state.

        Args:
            s_hat_tp1: predicted next state.
            x_tp1: observed next input; it is encoded here.

        Returns:
            The action sampled by the policy for ``encode(x_tp1) - s_hat_tp1``.
        """
        s_tp1 = self.encode(x_tp1)
        diff = s_tp1 - s_hat_tp1
        return self.policy(diff)

    def __call__(self, x_t, x_tp1, x_tp2, r):
        """Choose an action for ``x_t -> x_tp1`` and score the policy.

        Args:
            x_t: input at time t.
            x_tp1: input at time t+1.
            x_tp2: input at time t+2.
            r: reward forwarded to the transition model.

        Returns:
            The action chosen by the policy.
        """
        s_tp2 = self.encode(x_tp2)

        # NOTE(review): the same reward is passed to both forward() calls
        # because it is the only one available here - confirm which
        # transition it actually belongs to.
        a = self.choose_action(self.forward(x_t, r), x_tp1)

        # the policy is rewarded/trained on its ability to make predictions come true
        self.policy.train_step(goal=self.forward(x_tp1, r), truth=s_tp2)

        # could use call backs to do training!?
        # return lambda r: self.train_step(s, a)
        # therefore s, a is still differentiable and we can keep them until we get the reward!?

        return a

In [92]:
# Build one Layer (encoder, transition model and policy); the next cell uses it.
l = Layer()

In [93]:
# Layer.__call__ takes (x_t, x_tp1, x_tp2, r); the two arguments used here
# are a predicted next state and an observation, which is the signature of
# choose_action. The prediction is made without a reward, so nothing is trained.
l.choose_action(l.transition.forward(l.encode(x)), x)

<tf.Tensor: id=5782, shape=(1, 4), dtype=float32, numpy=array([[0.02318801, 0.87725997, 0.08311243, 0.01643958]], dtype=float32)>