# Differentiable Indirection

In computer programming, [indirection](https://en.wikipedia.org/wiki/Indirection) (also called dereferencing) is the ability to reference something using a name, reference, or container instead of the value itself.

Continuous, differentiable implementations of indexing can be divided into two types. In the first type, the data decides how it wants to be addressed. An example of this is [Transformers](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf), where the query and key vectors are generated from the data itself. In the second type, the addressing takes no extra information from the data itself, e.g. the location-based addressing of [Neural Turing Machines](https://arxiv.org/pdf/1410.5401.pdf).

## Linked List

In this example we implement a circular [linked list](https://en.wikipedia.org/wiki/Linked_list) with differentiable indirection.

![circular linked list](./images/525px-Circularly-linked-list.svg.png)

A linked list is a structure in which each node holds a data value and a pointer to the memory location where the next node (its data and pointer) is stored.

We can think of RAM as an indexable array of data. Each pointer address is then essentially an integer index into this array.

The nodes are therefore defined implicitly by two arrays: the data array and the array holding the index of the next node.

$$ data = [12, 37, 99] $$
$$ pointers_{next} = [2, 0, 1] $$

The $pointers_{next}$ array can be used to traverse the linked list and visit the elements in the required order, starting from index 0:

$$ pointers_{next}[0] = 2 $$
$$ pointers_{next}[pointers_{next}[0]] = 1 $$
$$ pointers_{next}[pointers_{next}[pointers_{next}[0]]] = 0 $$

In a non-differentiable setting, we can traverse the linked list like so:

```python
data = [12, 37, 99]
pointers_next = [2, 0, 1]

ptr = 0
for _ in range(len(data)):
    print(data[ptr])  # prints 12, 99, 37
    ptr = pointers_next[ptr]
```

## Superpositioned Indexes

As discussed in [differentiable-indexed-arrays.ipynb](differentiable-indexed-arrays.ipynb), there are many ways to implement continuous indexing of arrays. In this example we use the superposition lookup.

Each pointer is represented as a one-hot vector corresponding to the index it points to, so a lookup becomes a dot product:

$$ element_i = data \cdot p $$
$$ element_i = 
\begin{bmatrix}
12 & 37 & 99
\end{bmatrix} 
\begin{bmatrix}
1\\
0\\
0
\end{bmatrix} 
 = 12 
$$

Since the lookup is a dot product, the weights do not have to be one-hot, which allows a continuous lookup that blends elements:

$$ element_i = 
\begin{bmatrix}
12 & 37 & 99
\end{bmatrix} 
\begin{bmatrix}
0.5\\
0\\
0.5
\end{bmatrix} 
 = 55.5 
$$
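The superposition lookup is just a dot product between the data and the weight vector. A minimal plain-Python sketch (the function name `soft_lookup` is our own, not from any library) covers both the one-hot and the blended case shown above:

```python
def soft_lookup(data, weights):
    """Superposition lookup: the weighted sum of data elements.

    With a one-hot `weights` vector this is ordinary indexing; with
    fractional weights it blends elements, so the result varies
    continuously (and differentiably) with the weights.
    """
    if len(data) != len(weights):
        raise ValueError("data and weights must have the same length")
    return sum(d * w for d, w in zip(data, weights))


data = [12, 37, 99]
print(soft_lookup(data, [1, 0, 0]))      # 12
print(soft_lookup(data, [0.5, 0, 0.5]))  # 55.5
```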

Similarly, the pointer array

$$
pointers_{next} = \begin{bmatrix}
2 & 0 & 1
\end{bmatrix}
$$

after converting each entry into a one-hot row vector, becomes

$$
pointers_{next} = P = \begin{bmatrix}
0 & 0 & 1\\
1 & 0 & 0\\
0 & 1 & 0
\end{bmatrix}
$$

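Thus, treating $p_i$ as a column vector, the pointer to the next element is obtained by multiplying with the transpose of $P$ (equivalently, $p_{i+1}^\top = p_i^\top P$ in row-vector form, which is what the code below computes with `p @ P`):

$$
p_{i+1} = P^\top p_i = \begin{bmatrix}
0 & 1 & 0\\
0 & 0 & 1\\
1 & 0 & 0
\end{bmatrix}
\begin{bmatrix}
1\\
0\\
0
\end{bmatrix} 
=
\begin{bmatrix}
0\\
0\\
1
\end{bmatrix} 
$$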
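To check the pointer update numerically, here is a small plain-Python sketch (the helpers `one_hot` and `step` are hypothetical names chosen for illustration). It builds $P$ from the pointer array, advances a row-vector pointer by multiplying it on the left of $P$ (the same product as `p @ P` in the TensorFlow code), and reads each element with a dot product:

```python
def one_hot(index, n):
    """Hypothetical helper: one-hot vector of length n as a list of floats."""
    return [1.0 if j == index else 0.0 for j in range(n)]


def step(p, P):
    """Advance a (possibly soft) pointer: row vector p times matrix P."""
    n = len(P)
    return [sum(p[i] * P[i][j] for i in range(n)) for j in range(n)]


data = [12, 37, 99]
pointers_next = [2, 0, 1]
# Row i of P is the one-hot encoding of pointers_next[i]
P = [one_hot(k, len(data)) for k in pointers_next]

p = one_hot(0, len(data))
visited = []
for _ in range(len(data)):
    visited.append(sum(d * w for d, w in zip(data, p)))  # superposition lookup
    p = step(p, P)                                       # follow the pointer
print(visited)  # [12.0, 99.0, 37.0]
```

Because `step` is linear, a soft pointer such as `[0.5, 0.5, 0.0]` simply advances to the matching blend of next nodes, `[0.5, 0.0, 0.5]`.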

```python
import tensorflow as tf
```

```python
@tf.function
def iterate_over(data, nexts):
    """Traverse the list defined by the (soft) pointer matrix `nexts`.

    Element i of the result is the superposition lookup data . p_i,
    where p_0 = onehot(0) and p_{i+1} = p_i @ P.
    """
    data_len = tf.shape(data)[0]
    P = nexts
    p = tf.one_hot([0], data_len)      # row vector of shape (1, n): start at node 0

    x = tf.expand_dims(data, -1)       # column vector of shape (n, 1)
    y_ = tf.zeros((data_len))
    eye = tf.eye(data_len)

    for i in tf.range(data_len):
        # The @ token denotes matrix multiplication
        x_scalar = tf.squeeze(p @ x)   # superposition lookup of the current element
        y_ += eye[i] * x_scalar        # write it into slot i of the output

        p = p @ P                      # follow the pointer to the next node

    return y_

data = tf.Variable([1, 3, 2], dtype=tf.float32)
target = tf.Variable([1, 2, 3], dtype=tf.float32)
data_len = tf.shape(data)[0]
nexts = tf.Variable(tf.one_hot([2, 0, 1], data_len), dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result = iterate_over(data, nexts)
    loss = tf.nn.l2_loss(result - target)

print(result)
print(tape.gradient(result, data))
print(tape.gradient(result, nexts))

print(loss)
print(tape.gradient(loss, nexts))
```

```
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)
tf.Tensor(
[[3. 4. 5.]
 [0. 0. 0.]
 [1. 3. 2.]], shape=(3, 3), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]], shape=(3, 3), dtype=float32)
```


## Toy problem: Find a pointer array whose traversal yields a particular permutation

The goal of this toy problem is not to find a permutation matrix $P$ such that $y = Px$, but instead to find a matrix such that $y_i = p^\top P^i x$, where $y_i$ is the $i^{th}$ element visited when traversing the linked list and $p = \mathrm{onehot}(0)$.

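For a circular linked list of $n$ nodes, following the pointers $n$ times brings every node back to itself, so $P^n = I$. For example, the pointer array `[2, 4, 1, 0, 3]` forms a single cycle of length 5: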

```python
Q = tf.one_hot([2, 4, 1, 0, 3], 5)
Q @ Q @ Q @ Q @ Q
```

```
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```
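The smallest such $n$ can also be read off the pointer array directly, without matrix powers: follow each cycle of the permutation and take the least common multiple of the cycle lengths. A plain-Python sketch (the name `permutation_order` is hypothetical):

```python
from math import gcd


def permutation_order(pointers_next):
    """Smallest n >= 1 with P^n = I for the permutation matrix built from pointers_next.

    Each unvisited node is followed until its cycle closes; the order is the
    least common multiple of all cycle lengths. For a single circular linked
    list of length N this is simply N.
    """
    n = len(pointers_next)
    if sorted(pointers_next) != list(range(n)):
        raise ValueError("pointers_next must be a permutation of 0..n-1")
    seen = [False] * n
    order = 1
    for start in range(n):
        if seen[start]:
            continue
        length, node = 0, start
        while not seen[node]:
            seen[node] = True
            node = pointers_next[node]
            length += 1
        order = order * length // gcd(order, length)  # lcm(order, length)
    return order


print(permutation_order([2, 4, 1, 0, 3]))  # 5: one cycle 0 -> 2 -> 1 -> 4 -> 3 -> 0
print(permutation_order([1, 0, 2]))        # 2: a swap plus a fixed point
```

The second case is not a single circular list, which is why the cycle loss below is only meaningful with the right cycle length.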

## Loss function

We want the predicted $\bar{y}$ to match the real $y$ after one complete traversal of the linked list, so we add an L2 loss $\frac{1}{2}\lVert \bar{y} - y \rVert_2^2$. However, on its own this lets the network learn a $P$ that reproduces $y$ through a linear combination of the elements of $x$ instead of a permutation matrix. Thus we need additional losses to push $P$ towards a permutation matrix.

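A permutation matrix obeys the following rules, each of which becomes a loss term:
* All columns must add up to 1
* All rows must add up to 1
* All elements must be close to either 0 or 1 (bistable loss)
* Cycle loss: $P^n = I$ (if the list is circular with length $n$)

The implementation below sums the *squared* entries of each row and column; for entries that are exactly 0 or 1, this equals the plain sum.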

More information about `bistable_loss` can be found [here](notebooks/boolean-satisfiability.ipynb)

```python
from library.loss import bistable_loss
```

```python
@tf.function
def permute_matrix_loss(P, cycle_length=1, cycle_weight=0):
    loss = 0

    # Row and column sums of the squared entries
    P_square = tf.math.square(P)
    axis_1_sum = tf.reduce_sum(P_square, axis=1)
    axis_0_sum = tf.reduce_sum(P_square, axis=0)

    # Penalize rows and columns whose squared entries do not add up to one
    loss += tf.nn.l2_loss(axis_1_sum - 1)
    loss += tf.nn.l2_loss(axis_0_sum - 1)

    # Push every entry towards either 0 or 1
    loss += tf.math.reduce_sum(bistable_loss(P))

    # Cycle loss: compute Q = P^cycle_length and penalize its distance from I
    Q = P
    for _ in tf.range(cycle_length - 1):
        Q = P @ Q
    cycle_loss = tf.nn.l2_loss(Q - tf.eye(tf.shape(Q)[0]))
    loss += cycle_loss * cycle_weight

    return loss
```

```python
# Valid permutation matrices: the identity and a swap of the first two nodes
test1 = tf.constant([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]
], dtype=tf.float32)

test2 = tf.constant([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1]
], dtype=tf.float32)

# Invalid: an entry of -1
test3 = tf.constant([
    [-1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
], dtype=tf.float32)

# Invalid: an entry of 2
test4 = tf.constant([
    [2, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
], dtype=tf.float32)

# Invalid: entries far from 0 or 1, rows and columns do not add up to one
test5 = tf.constant([
    [0.1, 0, 0],
    [0, 0.1, 0],
    [0, 0, 0.1],
], dtype=tf.float32)

# Invalid: the first two rows blend two nodes
test6 = tf.constant([
    [0.5, 0.5, 0],
    [0.5, 0.5, 0],
    [0, 0, 1],
], dtype=tf.float32)

# A valid swap, evaluated with cycle_length=1 and cycle_weight=1,
# so the cycle loss penalizes P != I
test7 = tf.constant([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
], dtype=tf.float32)

print(permute_matrix_loss(test1))
print(permute_matrix_loss(test2))
print(permute_matrix_loss(test3))
print(permute_matrix_loss(test4))
print(permute_matrix_loss(test5))
print(permute_matrix_loss(test6))
print(permute_matrix_loss(test7, 1, 1))
```

```
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(13.0, shape=(), dtype=float32)
tf.Tensor(2.9646, shape=(), dtype=float32)
tf.Tensor(0.75, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
```
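For readers without the `library` package, here is a plain-Python re-implementation of `permute_matrix_loss`, a sketch for checking the numbers above. `bistable_loss` is not shown in this notebook; we *assume* the per-entry penalty $x^2(1-x)^2$, which reproduces every value printed above, but the real `library.loss.bistable_loss` may be defined differently. `tf.nn.l2_loss(t)` computes `sum(t ** 2) / 2`, which the sums below mirror. The name `permute_matrix_loss_py` is hypothetical.

```python
def matmul(A, B):
    """Plain-Python matrix product of two list-of-lists matrices."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]


def bistable(x):
    # ASSUMPTION: stand-in for library.loss.bistable_loss. It is zero only at
    # x = 0 and x = 1 and reproduces the printed losses, but the real helper
    # may be defined differently.
    return x ** 2 * (1 - x) ** 2


def permute_matrix_loss_py(P, cycle_length=1, cycle_weight=0.0):
    """Hypothetical plain-Python mirror of permute_matrix_loss above."""
    n = len(P)
    sq = [[v * v for v in row] for row in P]
    row_sums = [sum(row) for row in sq]
    col_sums = [sum(sq[i][j] for i in range(n)) for j in range(n)]

    # tf.nn.l2_loss(t) computes sum(t ** 2) / 2
    loss = sum((s - 1) ** 2 for s in row_sums) / 2
    loss += sum((s - 1) ** 2 for s in col_sums) / 2
    loss += sum(bistable(v) for row in P for v in row)

    # Cycle loss: Q = P^cycle_length compared against the identity
    Q = P
    for _ in range(cycle_length - 1):
        Q = matmul(P, Q)
    cycle = sum((Q[i][j] - (1.0 if i == j else 0.0)) ** 2
                for i in range(n) for j in range(n)) / 2
    return loss + cycle_weight * cycle


swap = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
print(permute_matrix_loss_py([[2, 0, 0], [0, 1, 0], [0, 0, 1]]))  # 13.0
print(permute_matrix_loss_py(swap, 1, 1))  # 2.0: the swap is not the identity
print(permute_matrix_loss_py(swap, 2, 1))  # 0.0: swapping twice gives back I
```

The last line shows why the cycle length matters: the same swap matrix is penalized with `cycle_length=1` but not with its true cycle length of 2.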


## Training
Applying a row-wise [softmax](https://en.wikipedia.org/wiki/Softmax_function) to $P$ before traversing it and computing the loss is critical for convergence. `TODO: Why?`
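A row-wise softmax maps each row of logits to positive values that sum to 1, so every row becomes a valid (soft) pointer. The plain-Python sketch below mirrors `tf.nn.softmax(nexts, axis=1)` (the helper `softmax_rows` is our own name) and applies it to the first row of the logits printed at the end of the training run, where one entry clearly dominates:

```python
import math


def softmax_rows(M):
    """Row-wise softmax for a list-of-lists matrix, like tf.nn.softmax(M, axis=1)."""
    out = []
    for row in M:
        m = max(row)                         # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# First row of the learned logits printed after training
row0 = [-2.18890929, -2.56413817, 4.51624966, -2.87268877, -3.21818304]
probs = softmax_rows([row0])[0]
print([round(p, 3) for p in probs])  # nearly one-hot, peaked at index 2
```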

Initializing $P$ with any invalid permutation matrix leads to faster convergence than initializing it with a valid permutation matrix. `TODO: WHY?`

To make sure that $P$ is learning a permutation matrix and not a linear combination of the input, we also print a defuzzified result, `y_defuzz`. It is generated by taking the row-wise `argmax` of $P$, converting it back to one-hot vectors, and iterating over it again.
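The defuzzification step can be sketched in plain Python (the helpers `defuzzify` and `traverse` are hypothetical names). Unlike the notebook, which converts the argmax back to one-hot rows and calls `iterate_over` again, this sketch follows the hard indices directly; for one-hot pointers the two give the same result. The soft matrix below is made up for illustration.

```python
def defuzzify(P):
    """Row-wise argmax: turn a soft pointer matrix into a hard pointer array."""
    return [max(range(len(row)), key=row.__getitem__) for row in P]


def traverse(data, pointers_next, start=0):
    """Follow hard pointers from `start`, collecting len(data) elements."""
    out, ptr = [], start
    for _ in range(len(data)):
        out.append(data[ptr])
        ptr = pointers_next[ptr]
    return out


# Made-up soft matrix: 0.6 on the intended pointer, 0.1 elsewhere in each row
soft_P = [[0.6 if j == k else 0.1 for j in range(5)] for k in [2, 4, 1, 0, 3]]
pointers = defuzzify(soft_P)
print(pointers)                            # [2, 4, 1, 0, 3]
print(traverse([1, 3, 2, 5, 4], pointers))  # [1, 2, 3, 4, 5]
```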

```python
opt = tf.keras.optimizers.Adam()

@tf.function
def train_step(data, nexts, target_data):
    data_length = tf.shape(data)[0]

    with tf.GradientTape() as tape:
        # Row-wise softmax turns the raw logits into a soft pointer matrix
        P_soft = tf.nn.softmax(nexts, axis=1)
        actual_data = iterate_over(data, P_soft)
        # Reconstruction loss plus the permutation-matrix constraints,
        # with the cycle length set to the list length
        loss = tf.nn.l2_loss(actual_data - target_data)
        loss += permute_matrix_loss(P_soft, data_length, 1)

    grads = tape.gradient(loss, nexts)
    opt.apply_gradients(zip([grads], [nexts]))

    return loss, actual_data

data = tf.constant([1, 3, 2, 5, 4], dtype=tf.float32)
target_data = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32)
data_len = tf.shape(data)[0]
# Initial pointer logits: every node points to node 1 (an invalid permutation)
nexts = tf.Variable(tf.one_hot([1, 1, 1, 1, 1], data_len), dtype=tf.float32)
# Alternative initializations:
# nexts = tf.Variable(tf.one_hot([2, 4, 1, 0, 3], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([0, 0, 0, 0, 0], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([4, 4, 4, 4, 4], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([3, 3, 3, 3, 3], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([1, 2, 3, 4, 5], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.random.uniform((data_len, data_len), 0, 1))

tf.print('|   loss  |   y_pred  | y_defuzz |   P_pred  |   P_actual   |')
for i in range(10000):
    loss, actual_data = train_step(data, nexts, target_data)
    if i % 1000 == 0:
        # Defuzzify: take the row-wise argmax as hard pointers and traverse again
        argmax_next = tf.argmax(nexts, 1)
        defuzzified = tf.one_hot(argmax_next, data_len)
        defuzzified_data = iterate_over(data, defuzzified)
        tf.print(loss, tf.round(actual_data), defuzzified_data, argmax_next, [2, 4, 1, 0, 3])

tf.print(nexts)
```

```
|   loss  |   y_pred  | y_defuzz |   P_pred  |   P_actual   |
8.7374239 [1 3 3 3 3] [1 3 3 3 3] [1 1 1 1 1] [2, 4, 1, 0, 3]
5.66760302 [1 2 3 4 4] [1 2 3 5 5] [2 3 1 3 3] [2, 4, 1, 0, 3]
3.87311959 [1 2 3 4 4] [1 2 3 4 5] [2 4 1 3 3] [2, 4, 1, 0, 3]
1.50285113 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.205469772 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0740788057 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0350763313 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0185657572 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0103952968 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.00600781338 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
[[-2.18890929 -2.56413817 4.51624966 -2.87268877 -3.21818304]
 [-2.38402414 -2.36196756 -2.74159765 -1.29020119 4.7228055]
 [-2.63246131 5.06587076 -2.92895341 -2.58874464 -2.98058629]
 [3.2075038 -3.05285668 -3.33832383 -1.79198897 -2.805336]
 [-2.95142794 -2.45104289 -3.21477652 4.53860283 -2.0
```

### Verifying cyclic permutation

We can check that $P^n = I$ (here $n = 5$) holds both for the learned soft matrix (after rounding) and for its defuzzified version.

```python
P = tf.nn.softmax(nexts, axis=1)
tf.print(tf.argmax(P, 1))
tf.round(P @ P @ P @ P @ P)
```

```
[2 4 1 0 3]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```

```python
argmax_next = tf.argmax(P, 1)
DQ = tf.one_hot(argmax_next, data_len)
DQ @ DQ @ DQ @ DQ @ DQ
```

```
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```