# Differentiable Stacks

Many implementations of differentiable stacks appear in the literature on Neural Turing Machines and related memory-augmented models, for example [Learning to Transduce with Unbounded Memory](http://papers.nips.cc/paper/5648-learning-to-transduce-with-unbounded-memory.pdf) and [Inferring Algorithmic Patterns with Stack-Augmented Recurrent Nets](https://papers.nips.cc/paper/5857-inferring-algorithmic-patterns-with-stack-augmented-recurrent-nets.pdf).

A differentiable stack maintains the LIFO (Last In, First Out) behavior of classical stacks while being fully differentiable for gradient-based optimization. This enables neural networks to learn algorithms that manipulate stack-like memory structures.

**Key principles:**
1. **Deterministic and lossless forward pass** - With a one-hot index, push and pop store and retrieve elements exactly
2. **Well-defined gradients** - Operations written functionally (returning new tensors) support backpropagation through PyTorch's automatic differentiation

PyTorch's autograd handles gradient flow through these stack operations, so the differentiable implementation closely mirrors a classical stack.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Differentiable Stack Module

The stack uses two components:
- **Buffer**: A fixed-size tensor that stores all stack elements
- **Index**: A one-hot vector marking the next free slot (one position past the current top)

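**Soft operations**: Instead of discrete indexing (which breaks differentiability), we use:
- Soft assignment: the buffer is updated as a weighted combination of its old contents and the new element
- Superposition lookup: the result is a sum of buffer rows weighted by the index, similar to attention
- One-hot vectors shifted with `torch.roll` to track position (the shift is circular, so pushing past the last slot wraps around to slot 0)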
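
As a small illustration of the pointer mechanics described above, the sketch below rolls a one-hot vector forward and back. The variable names are hypothetical and the 4-slot size is an arbitrary choice for the example.

```python
import torch
import torch.nn.functional as F

# One-hot pointer over a 4-slot buffer, starting at slot 0 (size chosen for illustration)
pointer = F.one_hot(torch.tensor(0), 4).float()

# shifts=1 moves the hot entry one slot forward, as push does
after_push = torch.roll(pointer, shifts=1, dims=0)

# shifts=-1 moves it one slot back, as pop does
after_pop = torch.roll(after_push, shifts=-1, dims=0)

# torch.roll is circular: rolling back from slot 0 lands on the last slot
wrapped = torch.roll(pointer, shifts=-1, dims=0)

print(after_push)  # tensor([0., 1., 0., 0.])
print(after_pop)   # tensor([1., 0., 0., 0.])
print(wrapped)     # tensor([0., 0., 0., 1.])
```

The circular shift means there is no overflow or underflow check: pushing onto a full stack overwrites slot 0, and popping an empty stack reads the last slot.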

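This PyTorch module encapsulates all stack operations. Its `push` and `pop` methods update state through `.data`, so it illustrates stack behavior but does not propagate gradients to the pushed elements; the functional version in `ListReverser` does.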

In [None]:
class DifferentiableStack(nn.Module):
    """A differentiable stack data structure implemented as a PyTorch module.

    Note: This implementation uses in-place operations for demonstration purposes.
    For gradient-based learning, use the functional operations in ListReverser.
    """

    def __init__(self, stack_shape, device=None):
        """Initialize the differentiable stack.

        Args:
            stack_shape: Shape of the stack buffer (max_size, element_dim, ...)
            device: Device to place tensors on (defaults to CPU)
        """
        super().__init__()

        if device is None:
            device = torch.device("cpu")

        self.stack_shape = stack_shape
        self.device = device

        # Initialize buffer and index as learnable parameters
        buffer = torch.zeros(stack_shape, dtype=torch.float32, device=device)
        index = F.one_hot(torch.tensor(0, device=device), stack_shape[0]).float()

        self.register_parameter("buffer", nn.Parameter(buffer))
        self.register_parameter("index", nn.Parameter(index))

    def _soft_assign(self, buffer, index, element):
        """Soft assignment operation for differentiable indexing."""
        if buffer.dim() == 1:
            return buffer + index * (element - buffer)
        else:
            # Expand index to match buffer dimensions
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return buffer + index * (element.unsqueeze(0) - buffer)

    def _soft_lookup(self, buffer, index):
        """Soft lookup operation for differentiable indexing."""
        if buffer.dim() == 1:
            return torch.sum(index * buffer)
        else:
            # Expand index to match buffer dimensions
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return torch.sum(index * buffer, dim=0)

    def push(self, element):
        """Push an element onto the stack.

        Note: Uses .data assignment which breaks gradients for optimization.
        For learnable operations, use functional approach in ListReverser.
        """
        # Update buffer at current index position
        self.buffer.data = self._soft_assign(self.buffer, self.index, element)
        # Shift index pointer forward
        self.index.data = torch.roll(self.index, shifts=1, dims=0)

    def pop(self):
        """Pop an element from the stack."""
        # Shift index pointer back
        self.index.data = torch.roll(self.index, shifts=-1, dims=0)
        # Get element at current position
        element = self._soft_lookup(self.buffer, self.index)
        return element

    def peek(self):
        """Peek at the top element without removing it."""
        # Get index of top element
        peek_index = torch.roll(self.index, shifts=-1, dims=0)
        element = self._soft_lookup(self.buffer, peek_index)
        return element


# Create and test the stack
stack = DifferentiableStack((3, 3))
print("Initial stack buffer:")
print(stack.buffer)
print("Initial stack index:", stack.index)

Initial stack buffer:
Parameter containing:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], requires_grad=True)
Initial stack index: Parameter containing:
tensor([1., 0., 0.], requires_grad=True)


In [3]:
class StackFromBuffer(nn.Module):
    """Create a differentiable stack from an existing buffer."""

    def __init__(self, buffer):
        """Initialize stack with pre-existing buffer.

        Args:
            buffer: Pre-existing tensor to use as stack buffer
        """
        super().__init__()

        # Register buffer as parameter
        self.register_parameter("buffer", nn.Parameter(buffer.clone()))

        # Initialize index pointing to first position
        device = buffer.device
        stack_shape = buffer.shape
        index = F.one_hot(torch.tensor(0, device=device), stack_shape[0]).float()
        self.register_parameter("index", nn.Parameter(index))

    def _soft_assign(self, buffer, index, element):
        """Soft assignment operation for differentiable indexing."""
        if buffer.dim() == 1:
            return buffer + index * (element - buffer)
        else:
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return buffer + index * (element.unsqueeze(0) - buffer)

    def _soft_lookup(self, buffer, index):
        """Soft lookup operation for differentiable indexing."""
        if buffer.dim() == 1:
            return torch.sum(index * buffer)
        else:
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return torch.sum(index * buffer, dim=0)

    def push(self, element):
        """Push an element onto the stack."""
        self.buffer.data = self._soft_assign(self.buffer, self.index, element)
        self.index.data = torch.roll(self.index, shifts=1, dims=0)

    def pop(self):
        """Pop an element from the stack."""
        self.index.data = torch.roll(self.index, shifts=-1, dims=0)
        element = self._soft_lookup(self.buffer, self.index)
        return element

    def peek(self):
        """Peek at the top element without removing it."""
        peek_index = torch.roll(self.index, shifts=-1, dims=0)
        element = self._soft_lookup(self.buffer, peek_index)
        return element


# Example: Create stack from existing buffer
buffer = torch.ones((3, 3), dtype=torch.float32)
stack_from_buffer = StackFromBuffer(buffer)
print("Stack from buffer:")
print("Buffer:")
print(stack_from_buffer.buffer)
print("Index:", stack_from_buffer.index)

Stack from buffer:
Buffer:
Parameter containing:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
Index: Parameter containing:
tensor([1., 0., 0.], requires_grad=True)


## Understanding Soft Operations

The key to differentiable stacks lies in **soft operations** that approximate discrete indexing while maintaining gradients.

In [4]:
class SoftOperations(nn.Module):
    """Demonstrates the soft assignment and lookup operations used in differentiable stacks."""

    def __init__(self):
        super().__init__()

    def soft_assign(self, buffer, index, element):
        """Soft assignment: buffer[i] = element (differentiable version)

        Instead of discrete assignment, we use:
        buffer = buffer + index * (element - buffer)

        When index is one-hot [0,1,0], this updates only the selected position.
        """
        if buffer.dim() == 1:
            return buffer + index * (element - buffer)
        else:
            # Expand index to match buffer dimensions
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return buffer + index * (element.unsqueeze(0) - buffer)

    def soft_lookup(self, buffer, index):
        """Soft lookup: element = buffer[i] (differentiable version)

        Instead of discrete indexing, we use weighted sum:
        element = sum(index * buffer)

        When index is one-hot [0,1,0], this selects only one element.
        """
        if buffer.dim() == 1:
            return torch.sum(index * buffer)
        else:
            # Expand index to match buffer dimensions
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return torch.sum(index * buffer, dim=0)


# Demonstrate soft operations
soft_ops = SoftOperations()

# Test buffer and index
test_buffer = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
test_index = torch.tensor([0.0, 1.0, 0.0])  # Points to second row
test_element = torch.tensor([10.0, 11.0, 12.0])

print("Original buffer:")
print(test_buffer)
print("Index (one-hot, points to row 1):", test_index)
print("Element to assign:", test_element)

# Test soft assignment
new_buffer = soft_ops.soft_assign(test_buffer, test_index, test_element)
print("\nAfter soft assignment:")
print(new_buffer)

# Test soft lookup
looked_up = soft_ops.soft_lookup(new_buffer, test_index)
print("Soft lookup result:", looked_up)

Original buffer:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
Index (one-hot, points to row 1): tensor([0., 1., 0.])
Element to assign: tensor([10., 11., 12.])

After soft assignment:
tensor([[ 1.,  2.,  3.],
        [10., 11., 12.],
        [ 7.,  8.,  9.]])
Soft lookup result: tensor([10., 11., 12.])
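
The demonstration above uses a one-hot index, where soft lookup coincides with ordinary indexing. The sketch below, which assumes a hypothetical fractional index not used elsewhere in this notebook, shows the "superposition" case: the result is a blend of rows, and the index itself receives a gradient.

```python
import torch

buffer = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

# Hypothetical soft pointer: 75% weight on row 0, 25% on row 1
soft_index = torch.tensor([0.75, 0.25, 0.0], requires_grad=True)

# Same formula as soft_lookup for a 2-D buffer: weighted sum over rows
blended = torch.sum(soft_index.unsqueeze(-1) * buffer, dim=0)
print(blended)  # 0.75 * [1, 2, 3] + 0.25 * [4, 5, 6] = [1.75, 2.75, 3.75]

# The gradient of blended.sum() w.r.t. each index weight is that row's sum
blended.sum().backward()
print(soft_index.grad)  # tensor([ 6., 15., 24.])
```

Because the lookup is a plain weighted sum, a model that outputs the index weights could in principle be trained through it, which is what makes the operation "soft".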


## Stack Operations with Gradient Tracking

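The next cell pushes three elements onto a `DifferentiableStack` and then inspects the gradients. Because `push` writes through `.data`, autograd does not record the update: the returned buffer is a leaf parameter with no history, so `elements.grad` is `None`.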

In [5]:
class StackPushModule(nn.Module):
    """Demonstrates stack push operation with gradient tracking."""

    def __init__(self, stack_shape):
        super().__init__()
        # Store the shape only; a fresh stack is built inside forward()
        self.stack_shape = stack_shape

    def forward(self, elements):
        """Push multiple elements and return final buffer state."""
        # Create fresh stack for this forward pass
        stack = DifferentiableStack(self.stack_shape)

        for element in elements:
            stack.push(element)
        return stack.buffer


# Create elements to push
elements = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
    dtype=torch.float32,
    requires_grad=True,
)

# Create stack module
stack_push_module = StackPushModule((3, 3))

# Forward pass with gradient tracking
final_buffer = stack_push_module(elements)

print("Final buffer after pushing 3 elements:")
print(final_buffer)

# Test gradients
loss = final_buffer.sum()
loss.backward()
print("\nGradients with respect to input elements:")
print(elements.grad)

Final buffer after pushing 3 elements:
Parameter containing:
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], requires_grad=True)

Gradients with respect to input elements:
None
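
The `None` gradient above comes from the `.data` assignment in `push`. A minimal functional sketch of the same push, assuming a hypothetical helper named `functional_push` that returns new tensors instead of mutating state, keeps the computation graph intact:

```python
import torch
import torch.nn.functional as F


def functional_push(buffer, index, element):
    """Return a new (buffer, index) pair instead of mutating state."""
    # Soft assignment for a 2-D buffer: only the row selected by index changes
    new_buffer = buffer + index.unsqueeze(-1) * (element.unsqueeze(0) - buffer)
    # Advance the pointer by one slot
    new_index = torch.roll(index, shifts=1, dims=0)
    return new_buffer, new_index


elements = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
    requires_grad=True,
)

buffer = torch.zeros((3, 3))
index = F.one_hot(torch.tensor(0), 3).float()
for element in elements:
    buffer, index = functional_push(buffer, index, element)

buffer.sum().backward()
print(elements.grad)  # every entry is 1: each pushed value appears once in the buffer
```

This is the same pattern that `ListReverser._stack_push` uses later in the notebook.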


## Stack Pop Operation

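The pop operation moves the stack pointer back one slot and then reads the element at that position with a soft lookup. In the cell below, `StackFromBuffer` starts with its pointer at slot 0, so the first `torch.roll(..., shifts=-1)` wraps around to the last row and returns `[3, 3, 3]`. The popped elements carry a `grad_fn`, but `StackFromBuffer` wraps a clone of the input in a new `nn.Parameter`, which is a fresh leaf tensor; gradients therefore accumulate on the stack's own `buffer` parameter, and `buffer.grad` on the original tensor is `None`.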

In [6]:
class StackPopModule(nn.Module):
    """Demonstrates stack pop operation with gradient tracking."""

    def __init__(self, initial_buffer):
        super().__init__()
        self.initial_buffer = initial_buffer

    def forward(self):
        """Pop two elements and return them."""
        # Create fresh stack from initial buffer
        stack = StackFromBuffer(self.initial_buffer)

        element1 = stack.pop()
        element2 = stack.pop()
        return element1, element2, stack.buffer, stack.index


# Create a buffer with initial values
buffer = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
    dtype=torch.float32,
    requires_grad=True,
)

# Create stack module with pre-filled buffer
stack_pop_module = StackPopModule(buffer)

# Perform pop operations with gradient tracking
element1, element2, final_buffer, final_index = stack_pop_module()

print("First popped element (should be [3,3,3]):", element1)
print("Second popped element (should be [2,2,2]):", element2)
print("Stack buffer after pops:")
print(final_buffer)
print("Stack index after pops:", final_index)

# Test gradients
loss = element1.sum() + element2.sum()
loss.backward()
print("\nGradients with respect to buffer:")
print(buffer.grad)

First popped element (should be [3,3,3]): tensor([3., 3., 3.], grad_fn=<SumBackward1>)
Second popped element (should be [2,2,2]): tensor([2., 2., 2.], grad_fn=<SumBackward1>)
Stack buffer after pops:
Parameter containing:
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], requires_grad=True)
Stack index after pops: Parameter containing:
tensor([0., 1., 0.], requires_grad=True)

Gradients with respect to buffer:
None
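
For comparison, here is a sketch of a pop that reads directly from the original tensor, so the gradient reaches it. The helper name `functional_pop` is hypothetical and the code assumes a 2-D buffer.

```python
import torch
import torch.nn.functional as F


def functional_pop(buffer, index):
    """Return the top element and the new index without mutating anything."""
    # Move the pointer back one slot (wraps from slot 0 to the last slot)
    new_index = torch.roll(index, shifts=-1, dims=0)
    # Soft lookup: weighted sum over rows
    element = torch.sum(new_index.unsqueeze(-1) * buffer, dim=0)
    return element, new_index


buffer = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
    requires_grad=True,
)
index = F.one_hot(torch.tensor(0), 3).float()

top, index = functional_pop(buffer, index)
second, index = functional_pop(buffer, index)

(top.sum() + second.sum()).backward()
print(buffer.grad)  # rows 1 and 2 were read, row 0 was not
```

The gradient is 1 for the two rows that were popped and 0 for the untouched row, which is what discrete indexing would give as well.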


## Stack Peek Operation

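The peek operation reads the top element without modifying the stack state: it rolls a temporary copy of the index rather than the stored one. As with pop, gradients reach the stack's internal parameter rather than the original `buffer` tensor, so `buffer.grad` is `None`.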

In [7]:
class StackPeekModule(nn.Module):
    """Demonstrates stack peek operation with gradient tracking."""

    def __init__(self, initial_buffer):
        super().__init__()
        self.initial_buffer = initial_buffer

    def forward(self):
        """Peek at the top element without modifying stack."""
        stack = StackFromBuffer(self.initial_buffer)
        return stack.peek(), stack.buffer, stack.index


# Create a buffer with initial values
buffer = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
    dtype=torch.float32,
    requires_grad=True,
)

# Create stack module
stack_peek_module = StackPeekModule(buffer)

# Peek at top element
peeked_element, final_buffer, final_index = stack_peek_module()

print("Peeked element (should be [3,3,3]):", peeked_element)
print("Stack buffer (unchanged):")
print(final_buffer)
print("Stack index (unchanged):", final_index)

# Test gradients
loss = peeked_element.sum()
loss.backward()
print("\nGradients with respect to buffer:")
print(buffer.grad)

Peeked element (should be [3,3,3]): tensor([3., 3., 3.], grad_fn=<SumBackward1>)
Stack buffer (unchanged):
Parameter containing:
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], requires_grad=True)
Stack index (unchanged): Parameter containing:
tensor([1., 0., 0.], requires_grad=True)

Gradients with respect to buffer:
None


## Application: List Reversal with Dual Stacks

A classic application demonstrating differentiable stacks: reversing a list using two stacks. The algorithm itself is fixed rather than learned, but because every step is differentiable, gradients flow from the reversed output back to the input.

**Algorithm:**
1. Push all elements from the input list into Stack 1
2. Pop all elements from Stack 1 and push them into Stack 2  
3. Stack 2's buffer contains the reversed list
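
For reference, the three steps above can be sketched with ordinary Python lists as classical stacks. The function name is hypothetical and this version is not differentiable; it only shows the control flow that the differentiable version mirrors.

```python
def reverse_with_two_stacks(items):
    """Reverse a sequence by moving it through two classical stacks."""
    # Step 1: push every element onto stack 1
    stack1 = []
    for item in items:
        stack1.append(item)

    # Step 2: pop from stack 1 and push onto stack 2 until stack 1 is empty
    stack2 = []
    while stack1:
        stack2.append(stack1.pop())

    # Step 3: read bottom to top, stack 2 holds the reversed sequence
    return stack2


print(reverse_with_two_stacks([1, 2, 3, 4]))  # [4, 3, 2, 1]
```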

In [8]:
class ListReverser(nn.Module):
    """Module that reverses a list using two differentiable stacks."""

    def __init__(self, max_length, element_dim):
        super().__init__()
        # Stored for reference only; forward() sizes its buffers from the input length
        self.max_length = max_length
        self.element_dim = element_dim

    def _soft_assign(self, buffer, index, element):
        """Soft assignment operation for differentiable indexing."""
        if buffer.dim() == 1:
            return buffer + index * (element - buffer)
        else:
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return buffer + index * (element.unsqueeze(0) - buffer)

    def _soft_lookup(self, buffer, index):
        """Soft lookup operation for differentiable indexing."""
        if buffer.dim() == 1:
            return torch.sum(index * buffer)
        else:
            for _ in range(buffer.dim() - 1):
                index = index.unsqueeze(-1)
            return torch.sum(index * buffer, dim=0)

    def _stack_push(self, buffer, index, element):
        """Functional stack push - returns new state."""
        new_buffer = self._soft_assign(buffer, index, element)
        new_index = torch.roll(index, shifts=1, dims=0)
        return new_buffer, new_index

    def _stack_pop(self, buffer, index):
        """Functional stack pop - returns element and new state."""
        new_index = torch.roll(index, shifts=-1, dims=0)
        element = self._soft_lookup(buffer, new_index)
        return element, buffer, new_index

    def forward(self, input_list):
        """Reverse a list using two stacks.

        Args:
            input_list: Tensor of shape (sequence_length, element_dim)

        Returns:
            Reversed list tensor
        """
        seq_length = input_list.shape[0]
        device = input_list.device

        # Initialize two stacks
        buffer1 = torch.zeros((seq_length, self.element_dim), device=device)
        index1 = F.one_hot(torch.tensor(0, device=device), seq_length).float()

        buffer2 = torch.zeros((seq_length, self.element_dim), device=device)
        index2 = F.one_hot(torch.tensor(0, device=device), seq_length).float()

        # Step 1: Push all elements into stack1
        for i in range(seq_length):
            buffer1, index1 = self._stack_push(buffer1, index1, input_list[i])

        # Step 2: Transfer all elements from stack1 to stack2
        for _ in range(seq_length):
            element, buffer1, index1 = self._stack_pop(buffer1, index1)
            buffer2, index2 = self._stack_push(buffer2, index2, element)

        # Return the buffer of stack2 (contains reversed list)
        return buffer2


# Test the list reverser
reverser = ListReverser(max_length=4, element_dim=4)

# Create input array
input_arr = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0, 3.0],
        [4.0, 4.0, 4.0, 4.0],
    ],
    dtype=torch.float32,
    requires_grad=True,
)

# Reverse the list
reversed_arr = reverser(input_arr)

print("Original array:")
print(input_arr)
print("\nReversed array:")
print(reversed_arr)

# Test gradient flow
loss = reversed_arr.sum()
loss.backward()
print("\nGradients with respect to input:")
print(input_arr.grad)

Original array:
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.],
        [4., 4., 4., 4.]], requires_grad=True)

Reversed array:
tensor([[4., 4., 4., 4.],
        [3., 3., 3., 3.],
        [2., 2., 2., 2.],
        [1., 1., 1., 1.]], grad_fn=<AddBackward0>)

Gradients with respect to input:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
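
One way to sanity-check the reversal is to compare it against `torch.flip` along the first dimension. The sketch below re-implements the two-stack transfer in compact functional form; the helper names `push`, `pop`, and `reverse_rows` are hypothetical, and it assumes a 2-D input.

```python
import torch
import torch.nn.functional as F


def push(buffer, index, element):
    # Soft assignment at the pointer, then advance the pointer
    buffer = buffer + index.unsqueeze(-1) * (element.unsqueeze(0) - buffer)
    return buffer, torch.roll(index, shifts=1, dims=0)


def pop(buffer, index):
    # Move the pointer back, then soft lookup at the new position
    index = torch.roll(index, shifts=-1, dims=0)
    return torch.sum(index.unsqueeze(-1) * buffer, dim=0), index


def reverse_rows(x):
    n, d = x.shape
    buf1, idx1 = torch.zeros(n, d), F.one_hot(torch.tensor(0), n).float()
    buf2, idx2 = torch.zeros(n, d), F.one_hot(torch.tensor(0), n).float()
    for row in x:
        buf1, idx1 = push(buf1, idx1, row)
    for _ in range(n):
        row, idx1 = pop(buf1, idx1)
        buf2, idx2 = push(buf2, idx2, row)
    return buf2


x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
matches = torch.equal(reverse_rows(x), torch.flip(x, dims=[0]))
print(matches)  # True
```

With one-hot indices every multiplication is by exactly 0 or 1, so the comparison can use exact equality rather than a tolerance.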


## Learning Through Backpropagation

This demonstrates the power of differentiable stacks. No network weights are trained here; instead, we optimize the input itself by gradient descent so that its reversal matches a desired target. This works because the functional stack operations stay differentiable throughout the computation graph.

In [None]:
class LearnableListReverser(nn.Module):
    """Thin wrapper around ListReverser; it adds no trainable weights of its own."""

    def __init__(self, max_length, element_dim):
        super().__init__()
        self.reverser = ListReverser(max_length, element_dim)

    def forward(self, input_list):
        return self.reverser(input_list)


# Set up the learning experiment
reverser_model = LearnableListReverser(max_length=4, element_dim=4)

# Learnable input (starts as all ones)
learnable_input = nn.Parameter(torch.ones((4, 4), dtype=torch.float32))

# Target: what we want the reversed output to be
target_reversed = torch.tensor(
    [
        [4.0, 4.0, 4.0, 4.0],
        [3.0, 3.0, 3.0, 3.0],
        [2.0, 2.0, 2.0, 2.0],
        [1.0, 1.0, 1.0, 1.0],
    ],
    dtype=torch.float32,
)

# Optimizer
optimizer = torch.optim.Adam([learnable_input], lr=0.1)

# Training loop
print("Training to learn input that reverses to target...")
print("Initial input:")
print(learnable_input.data)

for epoch in range(100):
    optimizer.zero_grad()

    # Forward pass: reverse the learnable input
    reversed_output = reverser_model(learnable_input)

    # Loss: how different is our output from the target?
    loss = F.mse_loss(reversed_output, target_reversed)

    # Backward pass
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d}, Loss: {loss.item():.6f}")

print("\nFinal learned input:")
print(torch.round(learnable_input))
print("\nTarget reversed output:")
print(target_reversed)
print("\nActual reversed output:")
print(torch.round(reverser_model(learnable_input)))

Training to learn input that reverses to target...
Initial input:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
Epoch   0, Loss: 3.500000
Epoch  10, Loss: 1.278137
Epoch  20, Loss: 0.343994
Epoch  30, Loss: 0.054887
Epoch  40, Loss: 0.016218
Epoch  50, Loss: 0.008572
Epoch  60, Loss: 0.006566
Epoch  70, Loss: 0.001941
Epoch  80, Loss: 0.000025
Epoch  90, Loss: 0.000215

Final learned input:
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.],
        [4., 4., 4., 4.]], grad_fn=<RoundBackward0>)

Target reversed output:
tensor([[4., 4., 4., 4.],
        [3., 3., 3., 3.],
        [2., 2., 2., 2.],
        [1., 1., 1., 1.]])

Actual reversed output:
tensor([[4., 4., 4., 4.],
        [3., 3., 3., 3.],
        [2., 2., 2., 2.],
        [1., 1., 1., 1.]], grad_fn=<RoundBackward0>)


# Summary

This notebook demonstrated differentiable stacks implemented as PyTorch modules:

1. **Core Concepts**: Differentiable stacks maintain LIFO behavior while preserving gradients
2. **Soft Operations**: Use weighted combinations instead of discrete indexing
3. **PyTorch Integration**: Implemented as `nn.Module` classes with learnable parameters
4. **Gradient Flow**: The functional operations support backpropagation through autograd
5. **Applications**: Gradients flow end-to-end through algorithms such as list reversal, so an input can be optimized against a target output

## Important Note on Gradient Flow

**For demonstration purposes**, the `DifferentiableStack` class uses `.data` assignment which breaks gradient flow. This is fine for showing stack behavior but prevents learning.

**For gradient-based optimization**, the `ListReverser` uses **functional operations** that maintain gradients:
- No `.data` assignment 
- Returns new tensors instead of modifying in-place
- Preserves computation graph for backpropagation
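
The contrast between the two styles can be reduced to a few lines. This is a minimal sketch with hypothetical variable names, not code from the classes above: the first variant writes through `.data`, the second returns a new tensor.

```python
import torch
import torch.nn as nn

x = torch.tensor([1.0, 2.0], requires_grad=True)

# Variant 1: write through .data, as DifferentiableStack.push does.
# Autograd does not record the assignment, so x is cut off from the loss.
state = nn.Parameter(torch.zeros(2))
state.data = 2 * x
state.sum().backward()
grad_via_data = x.grad
print(grad_via_data)  # None

# Variant 2: functional update that returns a new tensor.
# The result keeps its history, so the gradient reaches x.
new_state = torch.zeros(2) + 2 * x
new_state.sum().backward()
grad_functional = x.grad.clone()
print(grad_functional)  # tensor([2., 2.])
```

In the first variant the backward pass still runs, but it stops at `state`, which is a leaf tensor; that is the same reason the push, pop, and peek demonstrations print `None`.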

The key insight is that by replacing discrete operations with their "soft" differentiable equivalents **while maintaining functional purity**, we can create data structures that neural networks can learn to manipulate programmatically.