# Differentiable Indirection

Indirection is a fundamental concept in computer programming where we access data through references, pointers, or containers rather than directly. This notebook explores how to make indirection operations differentiable, enabling gradient-based optimization of data structures and algorithms.

## Understanding Indirection

In traditional programming, [indirection](https://en.wikipedia.org/wiki/Indirection) (reaching a value by following a reference; following the reference is called dereferencing) allows us to:
- Access memory through pointers
- Navigate data structures like linked lists and trees
- Implement dynamic dispatch and virtual function calls
- Create flexible, indirect addressing schemes

## Differentiable Indirection Strategies

Continuous and differentiable indirection can be categorized into two main approaches:

1. **Data-Driven Addressing**: The data itself determines how it should be accessed
   - Example: [Transformers](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) where query and key vectors generate attention weights
   - The addressing mechanism emerges from the data content

2. **External Addressing**: Addressing patterns are learned independently of the data
   - Example: [Neural Turing Machines](https://arxiv.org/pdf/1410.5401.pdf) with separate addressing mechanisms
   - The controller learns how to navigate memory structures
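The data-driven case can be illustrated with a content-based soft read in the spirit of attention, sketched below in NumPy. A query is scored against each stored key, a softmax turns the scores into weights, and the output is the weighted sum of the values. The helper names and toy keys are illustrative assumptions. Real Transformers also add scaling, learned projections, and multiple heads.

```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(z - np.max(z))
    return e / e.sum()

def soft_lookup(query, keys, values):
    """Content-based addressing: the stored keys decide where to read."""
    scores = keys @ query        # similarity between the query and each key
    weights = softmax(scores)    # differentiable "address" spread over all slots
    return weights @ values, weights

keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])
values = np.array([12.0, 37.0, 99.0])

# A query aligned with the first key reads (almost) only slot 0
result, weights = soft_lookup(np.array([10.0, -10.0]), keys, values)
print(weights.round(3), result)
```

The read address is never an integer index. It is a weight vector computed from the data, so gradients can flow back into the query and the keys.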

## Case Study: Differentiable Linked Lists

A [linked list](https://en.wikipedia.org/wiki/Linked_list) is a fundamental data structure whose elements are connected through pointers. Our primary example is a differentiable *circular* linked list, in which the last element points back to the first. It demonstrates how classic pointer-based data structures can be made trainable.

![circular linked list](./images/525px-Circularly-linked-list.svg.png)

### Traditional Linked List Structure

A linked list consists of:
- **Data elements**: The actual values stored in the list
- **Pointers**: References to the next element in the sequence

### Challenge: Making Pointers Differentiable

Traditional pointers are discrete indices that break differentiability. We need to replace discrete pointer operations with continuous, differentiable alternatives while preserving the linked list semantics.

## From Discrete to Continuous Representation

### Memory as Indexable Arrays

We can model computer memory as two parallel arrays:
- **Data array**: Contains the actual values
- **Pointer array**: Contains indices to the next elements

**Example**: A linked list with data `[12, 37, 99]` and traversal order `12 → 99 → 37 → 12`

$$\text{data} = [12, 37, 99]$$
$$\text{pointers}_{\text{next}} = [2, 0, 1]$$

### Traditional Traversal

The pointer array defines the traversal order:
- `pointers_next[0] = 2` → from element 0, go to element 2
- `pointers_next[2] = 1` → from element 2, go to element 1  
- `pointers_next[1] = 0` → from element 1, go back to element 0 (circular)

**Traditional traversal code**:
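```python
data = [12, 37, 99]
pointers_next = [2, 0, 1]

ptr = 0                       # start at element 0
for _ in range(len(data)):
    print(data[ptr])          # print current data: 12, then 99, then 37
    ptr = pointers_next[ptr]  # move to the next element
```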

**Problem**: This uses discrete indexing operations that prevent gradient flow and make the structure non-trainable.

## Superposition-Based Differentiable Pointers

We solve the differentiability problem using **superposition lookup** (as detailed in [differentiable-indexed-arrays.ipynb](differentiable-indexed-arrays.ipynb)).

### One-Hot Pointer Representation

Instead of discrete indices, we use one-hot vectors as "soft pointers":

**Current position**: $\mathbf{p} = [1, 0, 0]$ (pointing to element 0)

**Data lookup**: $\text{element}_i = \mathbf{data} \cdot \mathbf{p}$

$$\text{element}_i = \begin{bmatrix} 12 & 37 & 99 \end{bmatrix} \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} = 12$$

### Soft Pointer Lookup

The beauty of superposition is that we can have "partial" pointers:

$$\text{element}_i = \begin{bmatrix} 12 & 37 & 99 \end{bmatrix} \begin{bmatrix} 0.5 \\ 0 \\ 0.5 \end{bmatrix} = 55.5$$

This gives us a blend of elements 0 and 2, maintaining differentiability!
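Both lookups above can be reproduced with a dot product in NumPy. This minimal sketch uses illustrative variable names. It also shows why the operation is differentiable: the result is linear in the pointer, so its gradient with respect to `p` is simply the data vector.

```python
import numpy as np

data = np.array([12.0, 37.0, 99.0])

# A "hard" pointer to element 0, stored as a one-hot vector
p_hard = np.array([1.0, 0.0, 0.0])
# A "soft" pointer: half on element 0, half on element 2
p_soft = np.array([0.5, 0.0, 0.5])

hard_value = data @ p_hard   # selects data[0]
soft_value = data @ p_soft   # blends data[0] and data[2]
print(hard_value, soft_value)

# The lookup is linear in the pointer, so d(value)/d(p) is the data vector itself
grad_wrt_pointer = data
```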

### Pointer Transition Matrix

We convert the discrete pointer array to a **transition matrix** $\mathbf{P}$:

**Discrete pointers**: $\text{pointers}_{\text{next}} = [2, 0, 1]$

**Transition matrix**: 
$$\mathbf{P} = \begin{bmatrix}
0 & 0 & 1 \\
1 & 0 & 0 \\
0 & 1 & 0
\end{bmatrix}$$
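Building $\mathbf{P}$ amounts to one-hot encoding each pointer, the NumPy analogue of the `tf.one_hot(pointers, n)` calls in the TensorFlow cells of this notebook. A minimal sketch:

```python
import numpy as np

pointers_next = [2, 0, 1]
n = len(pointers_next)

# Row i of P is the one-hot encoding of pointers_next[i]:
# P[i, j] = 1 exactly when node i points to node j
P = np.eye(n)[pointers_next]
print(P)

# Every row and every column holds a single 1, so P is a permutation matrix
print(P.sum(axis=1), P.sum(axis=0))
```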

### Differentiable Traversal

Navigation becomes matrix multiplication:
$$\mathbf{p}_{i+1} = \mathbf{p}_i \mathbf{P}$$

**Example**:
$$\mathbf{p}_1 = \mathbf{p}_0 \mathbf{P} = \begin{bmatrix} 1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 0 & 0 & 1 \end{bmatrix}$$

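This moves the pointer from position 0 to position 2, exactly as the discrete traversal does, but through a differentiable matrix multiplication. With a soft pointer, the same multiplication moves each fraction of the pointer mass along its own edge.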
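The traversal loop can be sketched in NumPy to mirror the discrete loop shown earlier (`traverse` is a hypothetical helper name, not part of the notebook). From a one-hot start it reproduces 12 → 99 → 37. From a soft pointer split between nodes 0 and 2, each output is the average of the two corresponding hard traversals, because the whole computation is linear in `p`.

```python
import numpy as np

data = np.array([12.0, 37.0, 99.0])
P = np.eye(3)[[2, 0, 1]]   # transition matrix built from pointers_next

def traverse(data, P, p0, steps):
    """Read `steps` values by repeatedly applying p_{i+1} = p_i P."""
    p = p0
    out = []
    for _ in range(steps):
        out.append(data @ p)   # superposition lookup at the current position
        p = p @ P              # move the pointer mass one step along the list
    return np.array(out)

hard = traverse(data, P, np.array([1.0, 0.0, 0.0]), 3)
soft = traverse(data, P, np.array([0.5, 0.0, 0.5]), 3)
print(hard)
print(soft)
```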

```python
import tensorflow as tf
```

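## Implementation

The TensorFlow implementation of the traversal, `iterate_over`, appears after the problem formulation below. It starts from the one-hot pointer $\mathbf{p}_0$, reads one value per step with the superposition lookup, and advances the pointer with $\mathbf{p}_{i+1} = \mathbf{p}_i \mathbf{P}$.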

## Learning Problem: Discovering Optimal Traversal Orders

Now for the exciting part: **can we learn the optimal pointer structure?**

### Problem Formulation

Given:
- **Data array**: `[1, 3, 2, 5, 4]` (unordered)
- **Target sequence**: `[1, 2, 3, 4, 5]` (sorted order)

**Find**: A transition matrix $\mathbf{P}$ such that traversing the linked list produces the target sequence.

### Key Insight

We're not looking for a standard permutation matrix that transforms `data → target` directly. Instead, we want:

$$y_i = \mathbf{p}_0 \mathbf{P}^i \mathbf{x}$$

where $\mathbf{p}_0 = [1, 0, 0, \ldots, 0]$ is the starting position, $\mathbf{x}$ is the data array as a column vector, and $i = 0, \ldots, n-1$.
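The formula can be checked numerically with this problem's data. The pointer array `[2, 4, 1, 0, 3]` is the one the training run later in the notebook recovers. The sketch confirms that it produces the target, and that it differs from the permutation $S$ that sorts $\mathbf{x}$ directly.

```python
import numpy as np

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0])        # unordered data
target = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # desired traversal output

# Linked-list pointers: node i points to node [2, 4, 1, 0, 3][i]
P = np.eye(5)[[2, 4, 1, 0, 3]]
p0 = np.eye(5)[0]

# y_i = p0 P^i x for i = 0..n-1
y = np.array([p0 @ np.linalg.matrix_power(P, i) @ x for i in range(5)])
print(y)

# The permutation that sorts x in one shot (S x = target) is a different matrix
S = np.eye(5)[np.argsort(x)]
print(S @ x)
print(np.array_equal(S, P))
```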

### Circular Property

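For circular linked lists, the transition matrix must satisfy:
$$\mathbf{P}^n = \mathbf{I}$$

where $n$ is the cycle length. This ensures that after $n$ steps, we return to the starting configuration. The condition is necessary but not sufficient for a single ring. Any permutation whose cycle lengths all divide $n$ also satisfies it, so the task loss is what singles out the ring that visits every element in the target order.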
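This NumPy check makes the caveat concrete. The 5-cycle `[2, 4, 1, 0, 3]` and the identity both satisfy $\mathbf{P}^5 = \mathbf{I}$, but only the former links all five nodes into one ring. `is_single_cycle` is a hypothetical helper that follows the discrete pointers directly.

```python
import numpy as np

def is_single_cycle(pointers):
    """Follow the discrete pointers from node 0; True if all nodes form one ring."""
    seen, node = set(), 0
    for _ in range(len(pointers)):
        if node in seen:          # revisited a node before covering the list
            return False
        seen.add(node)
        node = pointers[node]
    return node == 0              # after n distinct nodes we must be back at the start

n = 5
results = {}
for pointers in ([2, 4, 1, 0, 3], [0, 1, 2, 3, 4]):
    P = np.eye(n)[pointers]
    returns_to_start = np.array_equal(np.linalg.matrix_power(P, n), np.eye(n))
    results[tuple(pointers)] = (returns_to_start, is_single_cycle(pointers))
    print(pointers, results[tuple(pointers)])
```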

```python
@tf.function
def iterate_over(data, nexts):
    """Traverse the list once, returning y_i = p0 P^i x for i = 0..n-1."""
    data_len = tf.shape(data)[0]
    P = nexts                       # transition matrix, one row per node
    p = tf.one_hot([0], data_len)   # soft pointer, shape (1, n), starts at node 0

    x = tf.expand_dims(data, -1)    # data as a column vector, shape (n, 1)
    y_ = tf.zeros(data_len)         # collected outputs
    eye = tf.eye(data_len)

    for i in tf.range(data_len):
        # The @ token denotes matrix multiplication
        x_scalar = tf.squeeze(p @ x)   # superposition lookup at the current pointer
        y_ += eye[i] * x_scalar        # write the value into output slot i

        p = p @ P                      # advance: p_{i+1} = p_i P

    return y_

data = tf.Variable([1, 3, 2], dtype=tf.float32)
target = tf.Variable([1, 2, 3], dtype=tf.float32)
data_len = tf.shape(data)[0]
# Hard pointers [2, 0, 1] encoded as a one-hot transition matrix
nexts = tf.Variable(tf.one_hot([2, 0, 1], data_len), dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result = iterate_over(data, nexts)
    loss = tf.nn.l2_loss(result - target)

print(result)
# For a non-scalar target, tape.gradient differentiates the sum of its elements
print(tape.gradient(result, data))
print(tape.gradient(result, nexts))

print(loss)
print(tape.gradient(loss, nexts))
```

```
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)
tf.Tensor(
[[3. 4. 5.]
 [0. 0. 0.]
 [1. 3. 2.]], shape=(3, 3), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]], shape=(3, 3), dtype=float32)
```


## Multi-Constraint Loss Function

Learning a valid permutation matrix requires careful constraint design. We need the learned matrix $\mathbf{P}$ to satisfy multiple properties simultaneously.

### Loss Components

Our total loss combines several constraints:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{permutation}}$$

**Task Loss**: Sequence matching (the implementation uses `tf.nn.l2_loss`, which includes a factor of ½)
$$\mathcal{L}_{\text{task}} = \tfrac{1}{2}\|\mathbf{y}_{\text{pred}} - \mathbf{y}_{\text{target}}\|^2$$

**Permutation Loss**: Ensures $\mathbf{P}$ is a valid permutation matrix
$$\mathcal{L}_{\text{permutation}} = \mathcal{L}_{\text{row}} + \mathcal{L}_{\text{col}} + \mathcal{L}_{\text{bistable}} + \mathcal{L}_{\text{cycle}}$$

### Permutation Matrix Constraints

1. **Row Constraint**: The squared entries of each row sum to 1. Together with the bistable constraint, this leaves exactly one 1 per row (one outgoing edge per node)
2. **Column Constraint**: The squared entries of each column sum to 1 (one incoming edge per node)
3. **Bistable Constraint**: Elements should be close to 0 or 1 (discrete behavior)
4. **Cycle Constraint**: $\mathbf{P}^n = \mathbf{I}$ (returns to the start after a full cycle)

A matrix for which the row, column, and bistable terms are all zero is exactly a permutation matrix. Because every term is smooth, gradient-based optimization can push $\mathbf{P}$ toward that set.
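For reference, here are the terms as computed by `permute_matrix_loss` later in this notebook, written in the document's notation. $n$ is the list length and $\lambda$ is the `cycle_weight` argument (the training step uses $\lambda = 1$). The row and column terms use sums of *squared* entries, and `tf.nn.l2_loss` contributes the factors of ½. The bistable term is shown in the form $P_{ij}^2(1 - P_{ij})^2$. This form is an assumption inferred from the test values printed later, since `bistable_loss` itself is defined in a separate module.

$$
\begin{aligned}
\mathcal{L}_{\text{row}} &= \tfrac{1}{2}\sum_{i}\Big(\sum_{j} P_{ij}^2 - 1\Big)^2 \\
\mathcal{L}_{\text{col}} &= \tfrac{1}{2}\sum_{j}\Big(\sum_{i} P_{ij}^2 - 1\Big)^2 \\
\mathcal{L}_{\text{bistable}} &= \sum_{i,j} P_{ij}^2\,(1 - P_{ij})^2 \quad \text{(assumed form)} \\
\mathcal{L}_{\text{cycle}} &= \tfrac{\lambda}{2}\,\big\|\mathbf{P}^n - \mathbf{I}\big\|_F^2
\end{aligned}
$$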

### Bistable Loss Function

The bistable loss encourages matrix elements to be either 0 or 1 by penalizing intermediate values. More details about this loss function can be found in the [boolean-satisfiability.ipynb](boolean-satisfiability.ipynb) notebook.
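Because `bistable_loss` lives in a separate module, here is a NumPy re-sketch of the complete permutation loss under an *assumed* bistable form $x^2(1-x)^2$. With that assumption, it reproduces the values printed by the TensorFlow tests later in this notebook (0, 4, 13, 2.9646, 0.75, and 2.0 with the cycle term on). This suggests, but does not prove, that the library function has this shape. `l2_loss` is a hypothetical helper mirroring `tf.nn.l2_loss`.

```python
import numpy as np

def bistable_loss(P):
    # ASSUMED form, zero exactly at 0 and 1; the notebook imports the real
    # function from library.loss, which may be defined differently
    return P ** 2 * (1.0 - P) ** 2

def l2_loss(t):
    # Mirrors tf.nn.l2_loss: half the sum of squares
    return 0.5 * np.sum(t ** 2)

def permute_matrix_loss(P, cycle_length=1, cycle_weight=0.0):
    """NumPy re-sketch of the notebook's TensorFlow permute_matrix_loss."""
    loss = l2_loss(np.sum(P ** 2, axis=1) - 1)   # row constraint
    loss += l2_loss(np.sum(P ** 2, axis=0) - 1)  # column constraint
    loss += np.sum(bistable_loss(P))             # bistable constraint
    Q = np.linalg.matrix_power(P, cycle_length)  # P^cycle_length
    loss += cycle_weight * l2_loss(Q - np.eye(len(P)))  # cycle constraint
    return loss

tests = {
    "identity": np.eye(3),
    "negative entry": np.diag([-1.0, 1.0, 1.0]),
    "entry of 2": np.diag([2.0, 1.0, 1.0]),
    "diagonal of 0.1": np.diag([0.1, 0.1, 0.1]),
    "blended rows": np.array([[0.5, 0.5, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]),
}
losses = {name: permute_matrix_loss(P) for name, P in tests.items()}
for name, value in losses.items():
    print(f"{name}: {value:.4f}")

# Swap matrix with the cycle term switched on (cycle_length=1, weight=1)
swap = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
swap_loss = permute_matrix_loss(swap, 1, 1.0)
print(swap_loss)
```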

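As a sanity check of the cycle constraint, the target pointer array `[2, 4, 1, 0, 3]` satisfies $\mathbf{Q}^5 = \mathbf{I}$:

```python
# The pointers [2, 4, 1, 0, 3] form a single 5-cycle, so Q^5 is the identity
Q = tf.one_hot([2, 4, 1, 0, 3], 5)
Q @ Q @ Q @ Q @ Q
```

```
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```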

## Loss Function

We want the predicted $\bar{y}$ to match the real $y$ after one complete traversal of the linked list, so we add an L2 loss $\|\bar{y} - y\|^2$. On its own, however, this loss lets the network settle on a $P$ that mixes elements of $x$ (a soft blend that happens to reproduce the target) instead of a permutation matrix. We therefore add further losses that push $P$ toward a permutation matrix.

The permutation-matrix loss enforces the following rules:
* All columns must sum to 1 (in the implementation, the sums of squared entries)
* All rows must sum to 1 (likewise)
* All elements must be close to either 0 or 1 (bistable loss)
* Cycle loss $P^n = I$ (if cyclic)
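The task loss alone is not enough, as this hand-built example shows (it was not produced by training). With data `[1, 3, 2]` and target `[1, 2, 3]`, the matrix below reproduces the target exactly, yet it is not a permutation matrix: one row blends two nodes and another contains an entry of 2. `traverse_from_start` is a hypothetical helper implementing $y_i = \mathbf{p}_0 \mathbf{P}^i \mathbf{x}$.

```python
import numpy as np

def traverse_from_start(x, P):
    """Return y_i = p0 P^i x for i = 0..n-1, with p0 pointing at element 0."""
    p = np.eye(len(x))[0]
    y = []
    for _ in range(len(x)):
        y.append(p @ x)   # superposition lookup
        p = p @ P         # advance the (possibly soft) pointer
    return np.array(y)

x = np.array([1.0, 3.0, 2.0])
target = np.array([1.0, 2.0, 3.0])

# Hand-built, NOT a permutation: row 0 blends nodes 0 and 1, row 1 holds a 2
P_bad = np.array([[0.5, 0.5, 0.0],
                  [0.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]])

y = traverse_from_start(x, P_bad)
print(y)                                # matches the target
print(0.5 * np.sum((y - target) ** 2))  # task loss is zero
print(np.sum(P_bad ** 2, axis=1))       # row terms are far from 1
```

The row, column, and bistable terms are what rule out matrices like `P_bad`.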

## Training the Differentiable Linked List

Now we train our system to discover the correct traversal order!

### Critical Implementation Details

1. **Softmax Normalization**: Before each forward pass, we apply softmax to the transition matrix:
   ```python
   P_normalized = tf.nn.softmax(P, axis=1)
   ```
   This makes every row a probability distribution (non-negative entries summing to 1) while keeping the operation differentiable.

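2. **Initialization Strategy**: Counterintuitively, in our runs, starting from an *invalid* permutation matrix converged faster than starting from a valid one. The training cell below starts with every node pointing to node 1. We have not established why. Possible explanations, none of them tested here, include:
   - Invalid initializations may produce larger initial gradients
   - The optimization landscape may be smoother from invalid starting points
   - Starting away from a permutation (or from random noise) may help escape local minima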

3. **Defuzzification Check**: We monitor both the soft result and a "defuzzified" version where we:
   - Take `argmax` of each row to get discrete indices
   - Reconstruct a hard permutation matrix
   - Verify it produces the same traversal order

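When `y_defuzz` matches the target, the argmax pointers form a discrete structure that reproduces the traversal. Early in training, these pointers need not even form a permutation; for example, the log below shows `[2 3 1 3 3]`.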
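The softmax and defuzzification steps can be sketched in NumPy on illustrative logits (not the trained matrix). A row-wise softmax turns arbitrary logits into a row-stochastic matrix, so a pointer that starts as a probability vector remains one after every step. `argmax` followed by one-hot encoding then recovers hard pointers. `row_softmax` and the logits are assumptions made for this example.

```python
import numpy as np

def row_softmax(logits):
    # Row-wise softmax, the NumPy analogue of tf.nn.softmax(P, axis=1)
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Illustrative logits: a small ramp plus one dominant entry per row
logits = 0.1 * np.arange(25.0).reshape(5, 5)
logits[np.arange(5), [2, 4, 1, 0, 3]] += 4.0

P_soft = row_softmax(logits)
print(P_soft.sum(axis=1))       # every row sums to 1

# A pointer that is a probability vector stays one after a step
p = np.full(5, 0.2) @ P_soft
print(p.sum(), (p >= 0).all())

# Defuzzify: keep only the largest entry in each row
pointers = P_soft.argmax(axis=1)
P_hard = np.eye(5)[pointers]
print(pointers)
```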

More information about `bistable_loss` can be found in [boolean-satisfiability.ipynb](boolean-satisfiability.ipynb).

## Verifying the Circular Property

After training, we verify that the learned matrix satisfies the circular property $\mathbf{P}^n = \mathbf{I}$, which is essential for circular linked lists. The final cells of this notebook show that both the soft (continuous) and the defuzzified (discrete) versions of the learned matrix pass:
- **Soft matrix**: $\mathbf{P}^5 \approx \mathbf{I}$ (identity after rounding)
- **Discrete matrix**: $\mathbf{D}^5 = \mathbf{I}$ (exactly identity)

This confirms that the learned linked list implements circular traversal semantics.

### Loading the Bistable Loss

The `bistable_loss` used by the permutation loss below is imported from the project's `library.loss` module:

```python
from library.loss import bistable_loss
```

```python
@tf.function
def permute_matrix_loss(P, cycle_length=1, cycle_weight=0):
    """Penalty that is zero when P is a permutation matrix (plus an optional cycle term)."""
    loss = 0

    P_square = tf.math.square(P)
    axis_1_sum = tf.reduce_sum(P_square, axis=1)  # per-row sum of squared entries
    axis_0_sum = tf.reduce_sum(P_square, axis=0)  # per-column sum of squared entries

    # Penalize rows and columns whose squared entries do not sum to one
    loss += tf.nn.l2_loss(axis_1_sum - 1)
    loss += tf.nn.l2_loss(axis_0_sum - 1)

    # Penalize entries that are not close to 0 or 1
    loss += tf.math.reduce_sum(bistable_loss(P))

    # Cycle loss: compare P^cycle_length with the identity
    Q = P
    for _ in tf.range(cycle_length - 1):
        Q = P @ Q
    cycle_loss = tf.nn.l2_loss(Q - tf.eye(tf.shape(Q)[0]))
    loss += cycle_loss * cycle_weight

    return loss

# Identity: a valid permutation
test1 = tf.constant([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]
], dtype=tf.float32)

# Swap of the first two nodes: also a valid permutation
test2 = tf.constant([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1]
], dtype=tf.float32)

# Negative entry: row/column terms are satisfied, the bistable term is not
test3 = tf.constant([
    [-1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
], dtype=tf.float32)

# Entry of 2: violates the row, column, and bistable terms
test4 = tf.constant([
    [2, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
], dtype=tf.float32)

# Entries far too small: squared rows/columns sum to 0.01 instead of 1
test5 = tf.constant([
    [0.1, 0, 0],
    [0, 0.1, 0],
    [0, 0, 0.1],
], dtype=tf.float32)

# Blended rows: two nodes share their outgoing edges
test6 = tf.constant([
    [0.5, 0.5, 0],
    [0.5, 0.5, 0],
    [0, 0, 1],
], dtype=tf.float32)

# Same swap as test2, evaluated below with the cycle term enabled
test7 = tf.constant([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
], dtype=tf.float32)

print(permute_matrix_loss(test1))
print(permute_matrix_loss(test2))
print(permute_matrix_loss(test3))
print(permute_matrix_loss(test4))
print(permute_matrix_loss(test5))
print(permute_matrix_loss(test6))
print(permute_matrix_loss(test7, 1, 1))  # cycle_length=1, weight=1: penalizes P != I
```

```
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(13.0, shape=(), dtype=float32)
tf.Tensor(2.9646, shape=(), dtype=float32)
tf.Tensor(0.75, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
```


## Summary and Applications

We implemented a differentiable linked list and, on the five-element example in this notebook, trained it to recover the pointer structure that traverses the data in sorted order.

### Key Achievements

1. ✅ **Preserved Semantics**: Maintains linked list traversal behavior
2. ✅ **Full Differentiability**: Enables gradient-based optimization
3. ✅ **Learned Structure**: Discovered a pointer configuration that traverses the data in the target order
4. ✅ **Circular Properties**: Correctly implements circular traversal
5. ✅ **Discrete Convergence**: The soft matrix converged to a valid permutation in this example

### Key Techniques

- **Superposition Pointers**: Replace discrete indices with probability distributions
- **Matrix Multiplication Traversal**: Navigation through matrix operations
- **Multi-Constraint Training**: Simultaneous task and structure learning
- **Soft-to-Hard Convergence**: Continuous optimization yielding discrete structures

### Broader Applications

This differentiable indirection framework enables:

- **Neural Data Structures**: Trainable stacks, queues, trees, graphs
- **Learnable Algorithms**: Sorting, searching, graph traversal algorithms  
- **Memory-Augmented Networks**: Neural Turing Machines, Differentiable Neural Computers
- **Attention Mechanisms**: Transformer architectures with learned addressing
- **Program Synthesis**: Learning program control flow and data access patterns

### Implications

Differentiable indirection bridges the gap between:
- **Traditional algorithms** (discrete, exact, non-trainable)
- **Neural networks** (continuous, approximate, trainable)

This opens new possibilities for **end-to-end learning** of complete algorithmic systems, where both the computation and the data access patterns can be optimized simultaneously through gradient descent.

## Training
For convergence, applying a [softmax](https://en.wikipedia.org/wiki/Softmax_function) to $P$ before traversing and computing the loss is critical. We have not yet established why. One likely factor is that softmax makes every row of $P$ non-negative and sum to 1, so the pointer $\mathbf{p}$ stays a probability vector at every step.

Initializing $P$ with an invalid permutation matrix converged faster than initializing it with a valid one. The reason is still an open question.

To make sure that $P$ is learning a permutation matrix rather than a linear combination of the input, we also print a defuzzified result `y_defuzz`. It is generated by taking the row-wise `argmax` of $P$ and iterating over the resulting hard pointers.

```python
opt = tf.keras.optimizers.Adam()

@tf.function
def train_step(data, nexts, target_data):
    data_length = tf.shape(data)[0]

    with tf.GradientTape() as tape:
        # Row-wise softmax turns the raw logits into a row-stochastic matrix
        P_soft = tf.nn.softmax(nexts, axis=1)
        actual_data = iterate_over(data, P_soft)
        # Task loss plus permutation loss (cycle length n, cycle weight 1)
        loss = tf.nn.l2_loss(actual_data - target_data)
        loss += permute_matrix_loss(P_soft, data_length, 1)

    grads = tape.gradient(loss, nexts)
    opt.apply_gradients(zip([grads], [nexts]))

    return loss, actual_data

data = tf.constant([1, 3, 2, 5, 4], dtype=tf.float32)
target_data = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32)
data_len = tf.shape(data)[0]

# Initialization used here: an invalid matrix where every node points to node 1
nexts = tf.Variable(tf.one_hot([1, 1, 1, 1, 1], data_len), dtype=tf.float32)
# Alternative initializations (commented out):
# nexts = tf.Variable(tf.one_hot([2, 4, 1, 0, 3], data_len), dtype=tf.float32)  # the valid target permutation
# nexts = tf.Variable(tf.one_hot([0, 0, 0, 0, 0], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([4, 4, 4, 4, 4], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([3, 3, 3, 3, 3], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.one_hot([1, 2, 3, 4, 5], data_len), dtype=tf.float32)
# nexts = tf.Variable(tf.random.uniform((data_len, data_len), 0, 1))

tf.print('|   loss  |   y_pred  | y_defuzz |   P_pred  |   P_actual   |')
for i in range(10000):
    loss, actual_data = train_step(data, nexts, target_data)
    if i % 1000 == 0:
        # Defuzzify: hard pointers from the row-wise argmax, then traverse again
        argmax_next = tf.argmax(nexts, 1)
        defuzzified = tf.one_hot(argmax_next, data_len)
        defuzzified_data = iterate_over(data, defuzzified)
        # [2, 4, 1, 0, 3] is the known correct pointer array, shown for comparison
        tf.print(loss, tf.round(actual_data), defuzzified_data, argmax_next, [2, 4, 1, 0, 3])

tf.print(nexts)
```

```
|   loss  |   y_pred  | y_defuzz |   P_pred  |   P_actual   |
8.7374239 [1 3 3 3 3] [1 3 3 3 3] [1 1 1 1 1] [2, 4, 1, 0, 3]
5.66760302 [1 2 3 4 4] [1 2 3 5 5] [2 3 1 3 3] [2, 4, 1, 0, 3]
3.87311959 [1 2 3 4 4] [1 2 3 4 5] [2 4 1 3 3] [2, 4, 1, 0, 3]
1.50285113 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.205469772 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0740788057 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0350763313 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0185657572 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0103952968 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.00600781338 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
[[-2.18890929 -2.56413817 4.51624966 -2.87268877 -3.21818304]
 [-2.38402414 -2.36196756 -2.74159765 -1.29020119 4.7228055]
 [-2.63246131 5.06587076 -2.92895341 -2.58874464 -2.98058629]
 [3.2075038 -3.05285668 -3.33832383 -1.79198897 -2.805336]
 [-2.95142794 -2.45104289 -3.21477652 4.53860283 -2.0
```

### Verifying cyclic permutation

The cells below confirm that $P^n = I$ (here $n = 5$) for both the soft and the defuzzified matrices.

```python
# Soft matrix from the trained logits: print its argmax pointers, then round P^5
P = tf.nn.softmax(nexts, axis=1)
tf.print(tf.argmax(P, 1))
tf.round(P @ P @ P @ P @ P)
```

```
[2 4 1 0 3]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```

```python
# Defuzzified (hard) matrix built from the row-wise argmax
argmax_next = tf.argmax(P, 1)
DQ = tf.one_hot(argmax_next, data_len)
DQ @ DQ @ DQ @ DQ @ DQ
```

```
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>
```