# Day 17: Neural Turing Machines (NTM)

**Paper:** "Neural Turing Machines" - Graves, Wayne, Danihelka (2014)

We explore the internal mechanics of the NTM, specifically how it uses differentiable addressing to read from and write to an external memory bank.

---

## What You'll Learn

1. How content-based addressing uses cosine similarity to find relevant memory locations (Eq 5)
2. How circular convolution implements location-based shifting (Eq 8)
3. How sharpening prevents blurry focus after the shift operation (Eq 9)
4. How the write head performs erase-then-add updates to memory (Eq 3, 4)
5. Why decoupling memory from computation allows NTMs to learn algorithms like copying and sorting

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from implementation import NTM, NTMMemory  # local implementation.py accompanying this notebook

# Set up a small NTM for demonstration
N, M = 16, 8  # N memory slots (rows), M values per slot
controller_size = 32
ntm = NTM(input_size=5, output_size=5, controller_size=controller_size, N=N, M=M)
print(f"NTM initialized with {N} memory slots and {M} dimensions per slot.")
```

## 1. Differentiable Addressing Visualization

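A standard Turing machine's head reads and writes one discrete tape cell at a time. An NTM head instead emits a "blurry" weighting w over all N memory slots: every entry lies in [0, 1] and the entries sum to 1. Because reads and writes are weighted sums over this vector, every step is differentiable, so the whole model can be trained end to end with gradient descent.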
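
To see why a soft weighting is still useful, consider the read operation. In the paper (Eq 2) the read vector is the weighted sum of memory rows, r_t = Σ_i w_t(i) M_t(i). The sketch below is a minimal illustration, not the companion `implementation` module: `read` is a hypothetical helper and the memory contents are random. It checks that a one-hot weighting recovers a single row exactly, while a softmax weighting blends rows and still lets gradients reach the logits that produced it.

```python
import torch
import torch.nn.functional as F

def read(memory, w):
    """Eq 2: r = sum_i w[i] * memory[i], with memory (N, M) and w (N,)."""
    return w @ memory  # (N,) @ (N, M) -> (M,)

torch.manual_seed(0)
N, M = 16, 8
memory = torch.randn(N, M)

# A one-hot weighting acts like a discrete pointer and returns row 3 exactly.
w_hard = torch.zeros(N)
w_hard[3] = 1.0
r_hard = read(memory, w_hard)
print(torch.allclose(r_hard, memory[3]))  # True

# A soft weighting returns a convex blend of all rows.
logits = torch.randn(N, requires_grad=True)
w_soft = F.softmax(logits, dim=0)
r_soft = read(memory, w_soft)

# Any scalar loss on the read vector sends gradients back to the logits.
r_soft.sum().backward()
print(logits.grad.abs().sum().item() > 0)
```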

```python
# Example weighting vector w: softmax of random logits, so entries lie in [0, 1] and sum to 1
torch.manual_seed(42)
w = F.softmax(torch.randn(1, N), dim=1)

plt.figure(figsize=(10, 2))
plt.bar(range(N), w[0].detach().numpy())
plt.title("Example Weighting (w)")
plt.xlabel("Memory Slot")
plt.ylabel("Weight Intensity")
plt.show()
```
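
The random softmax above stands in for the weighting a real head computes. In the paper, the content weighting (Eqs 5-6) is a softmax over the cosine similarity between a key k_t and each memory row, scaled by a key strength β_t, and the interpolation gate g_t (Eq 7) then blends it with the previous step's weighting. The following sketch assumes hypothetical helpers `content_addressing` and `interpolate`, a key built as a noisy copy of row 5, and arbitrary illustrative values for β and g.

```python
import torch
import torch.nn.functional as F

def content_addressing(memory, key, beta):
    """Eqs 5-6: softmax over beta-scaled cosine similarity of key with each row."""
    # memory: (N, M); key: (M,) broadcast against every row -> sim: (N,)
    sim = F.cosine_similarity(memory, key.unsqueeze(0), dim=1, eps=1e-8)
    return F.softmax(beta * sim, dim=0)

def interpolate(w_content, w_prev, g):
    """Eq 7: gate g in [0, 1] blends the content weighting with the previous one."""
    return g * w_content + (1 - g) * w_prev

torch.manual_seed(1)
N, M = 16, 8
memory = torch.randn(N, M)
key = memory[5] + 0.05 * torch.randn(M)  # noisy copy of row 5

peaks = []
for beta in (1.0, 10.0, 50.0):
    w_c = content_addressing(memory, key, beta)
    peaks.append(w_c.max().item())
    print(f"beta={beta:5.1f}  argmax={w_c.argmax().item()}  peak={peaks[-1]:.3f}")

# With g = 0.7, at least 30% of the weight stays on the previous focus (slot 0).
w_prev = torch.zeros(N)
w_prev[0] = 1.0
w_g = interpolate(content_addressing(memory, key, 50.0), w_prev, g=0.7)
```

Because softmax preserves the ordering of similarities, every β selects the same slot; a larger β only makes the peak more pronounced.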

## 2. Sharpening

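The convolutional shift can leak weight into neighboring slots, and over repeated steps the focus disperses. Sharpening (Eq 9) counteracts this by raising each weight to a power γ ≥ 1 and renormalizing, w(i) ← w(i)^γ / Σ_j w(j)^γ, so larger γ concentrates the head on fewer locations.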

```python
def sharpen(w, gamma):
    # Eq 9: raise each weight to the power gamma (>= 1), then renormalize to sum to 1
    w_pow = w ** gamma
    return w_pow / w_pow.sum()

gammas = [1, 2, 5, 20]
fig, axes = plt.subplots(1, len(gammas), figsize=(15, 3))

for i, g in enumerate(gammas):
    w_sh = sharpen(w[0], g)  # w has shape (1, N); sharpen its single row
    axes[i].bar(range(N), w_sh.detach().numpy())
    axes[i].set_title(f"Gamma = {g}")

plt.suptitle("Effect of Sharpening on Focus", fontsize=14)
plt.tight_layout()
plt.show()
```

## 3. Circular Shift

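The convolutional shift lets the NTM make relative moves (e.g., "move to the next slot"). The demo below uses `torch.roll`, which performs only a hard integer shift; this matches Eq 8 when the shift weighting puts all of its mass on a single offset.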

```python
def shift_demo(w, s_val):
    # Hard integer shift: moves all weight s_val slots to the right, wrapping around.
    # This equals Eq 8 only when the shift weighting is one-hot on a single offset.
    return torch.roll(w, shifts=s_val, dims=0)

w_focused = torch.zeros(N)
w_focused[4] = 1.0

plt.figure(figsize=(10, 4))
plt.subplot(211)
plt.bar(range(N), w_focused.numpy())
plt.title("Original Focus at Slot 4")

plt.subplot(212)
plt.bar(range(N), shift_demo(w_focused, 2).numpy())
plt.title("Shifted Focus by +2 Slots")
plt.tight_layout()
plt.show()
```
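
`torch.roll` only models a hard integer move. In the paper, the controller emits a shift weighting s_t, a distribution over allowed offsets, and Eq 8 convolves it circularly with the weighting: w̃(i) = Σ_j w(j) s(i - j), indices modulo N. The sketch below assumes the offsets {-1, 0, +1} that the paper uses as its example; `circular_shift` is a hypothetical helper built from weighted `torch.roll` calls, and `sharpen` repeats the function from the previous section so the block runs on its own. Three soft "+1" steps leave the peak at slot 7 but spread weight to its neighbors, and sharpening with γ = 5 concentrates most of it back on that slot.

```python
import torch

def circular_shift(w, s, offsets=(-1, 0, 1)):
    """Eq 8: w_tilde[i] = sum_j w[j] * s[i - j], with indices taken modulo N.

    s[k] is the probability of moving the focus by offsets[k] slots.
    """
    out = torch.zeros_like(w)
    for prob, k in zip(s, offsets):
        # torch.roll by k sends the mass at slot j to slot (j + k) mod N.
        out = out + prob * torch.roll(w, shifts=k, dims=0)
    return out

def sharpen(w, gamma):
    """Eq 9: raise each weight to gamma >= 1 and renormalize."""
    w_pow = w ** gamma
    return w_pow / w_pow.sum()

N = 16
w0 = torch.zeros(N)
w0[4] = 1.0
s = torch.tensor([0.1, 0.1, 0.8])  # mostly "+1", with some leakage to -1 and 0

blurred = w0
for _ in range(3):  # three soft "+1" steps spread the focus
    blurred = circular_shift(blurred, s)
sharpened = sharpen(blurred, gamma=5.0)

print(f"after 3 shifts: slot {blurred.argmax().item()}, weight {blurred.max().item():.3f}")
print(f"after sharpening: slot {sharpened.argmax().item()}, weight {sharpened.max().item():.3f}")

# The shift is circular: a hard +1 shift from the last slot wraps around to slot 0.
edge = torch.zeros(N)
edge[N - 1] = 1.0
wrapped = circular_shift(edge, torch.tensor([0.0, 0.0, 1.0]))
```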

## 4. Key Takeaways

### What We Learned

1. **Content addressing finds relevant memory rows** - Cosine similarity between a search key and each memory row, scaled by a key strength β and passed through a softmax (Eq 5), produces a soft weighting over locations

2. **Circular convolution enables sequential access** - The shift mechanism lets the controller increment or decrement its focus position, which is necessary for tasks like copying

3. **Sharpening counteracts convolution blur** - Raising weights to a power gamma and renormalizing keeps the focus peaked on specific addresses

4. **Erase-then-add is order-dependent** - The write head first removes old content (Eq 3) then writes new content (Eq 4); reversing these operations produces different results
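
None of the cells above exercise the write head, so here is a minimal sketch of Eqs 3-4 for a single head: each row i is first multiplied elementwise by (1 - w_t(i) e_t), then w_t(i) a_t is added. `write` and `write_add_first` are hypothetical helpers, and the memory, erase vector e and add vector a are arbitrary toy values.

```python
import torch

def write(memory, w, erase, add):
    """Eqs 3-4: erase with w[i] * e, then add w[i] * a, for every row i."""
    # memory: (N, M); w: (N,); erase, add: (M,) with erase entries in [0, 1]
    erased = memory * (1 - torch.outer(w, erase))  # Eq 3
    return erased + torch.outer(w, add)            # Eq 4

def write_add_first(memory, w, erase, add):
    """The same two operations in the reverse order, for comparison only."""
    added = memory + torch.outer(w, add)
    return added * (1 - torch.outer(w, erase))

N, M = 4, 3
memory = torch.ones(N, M)
w = torch.tensor([0.0, 1.0, 0.0, 0.0])   # focus entirely on row 1
erase = torch.ones(M)                     # erase everything under the head
add = torch.tensor([0.5, -2.0, 3.0])      # new content to store

correct = write(memory, w, erase, add)
reversed_order = write_add_first(memory, w, erase, add)
print(correct[1])
print(reversed_order[1])
```

With a one-hot w and a full erase vector, the correct order leaves exactly a in row 1 and the other rows untouched; adding first means the erase also wipes out the new content, so row 1 ends up all zeros.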

### What's Next?

- **Differentiable Neural Computers** (Graves et al., 2016) - Improved memory addressing with temporal links and allocation mechanisms
- **Memory Networks** (Weston et al., 2015) - An alternative approach to external memory for question answering
- **Attention mechanisms** rely on the same kind of soft, differentiable lookup over stored vectors and became the core of the Transformer architecture

---

*The NTM demonstrated that neural networks can learn simple algorithms by separating storage from computation. The four-stage addressing pipeline (content lookup, interpolation, shift, sharpening) laid groundwork for the attention mechanisms used in modern architectures.*