# Day 20: Relational Recurrent Neural Networks (Relational Memory Core, RMC)

## Interactive Walkthrough

In this notebook, we will:
1.  **Build an RMC** from scratch (the code lives in the `implementation` module).
2.  **Visualize** how the memory slots interact through self-attention.
3.  **Compare** RMC vs. LSTM on the "N-th Farthest" task.

### Why this matters
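Standard RNNs such as LSTMs have a memory "bottleneck": everything the network remembers is compressed into a single fixed-size state vector. RMC instead splits its memory into multiple **slots** that interact with each other through multi-head self-attention at every time step, which makes it easier to reason about relationships between remembered items.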
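To make "slots that talk to each other" concrete, here is a minimal, hypothetical RMC-style cell. `TinyRelationalMemory` is an illustrative name, not the notebook's `RelationalMemory`. It follows the paper's general recipe (slots attend over the memory plus the current input, then pass through an MLP with residual connections and layer norm), but it swaps the paper's LSTM-style input/forget gating for a single simplified sigmoid gate.

```python
import torch
import torch.nn as nn


class TinyRelationalMemory(nn.Module):
    """Hypothetical, simplified RMC-style cell (not the notebook's RelationalMemory).

    Memory is a (batch, slots, mem_size) tensor instead of one hidden vector.
    Each step: slots attend over [memory; projected input], go through an MLP
    with residual connections, and are blended with the old memory by a gate.
    """

    def __init__(self, input_size, mem_slots, mem_size, num_heads):
        super().__init__()
        self.input_proj = nn.Linear(input_size, mem_size)
        self.attn = nn.MultiheadAttention(mem_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(mem_size, mem_size), nn.ReLU(), nn.Linear(mem_size, mem_size)
        )
        self.norm1 = nn.LayerNorm(mem_size)
        self.norm2 = nn.LayerNorm(mem_size)
        # Simplified gate (assumption): one sigmoid gate per feature,
        # instead of the paper's LSTM-style input and forget gates.
        self.gate = nn.Linear(2 * mem_size, mem_size)

    def forward(self, x, memory):
        # x: (batch, input_size); memory: (batch, slots, mem_size)
        x = self.input_proj(x).unsqueeze(1)        # (batch, 1, mem_size)
        keys = torch.cat([memory, x], dim=1)       # (batch, slots + 1, mem_size)
        # Each slot (query) attends over every slot plus the input (keys/values).
        attended, weights = self.attn(memory, keys, keys, average_attn_weights=False)
        h = self.norm1(memory + attended)          # residual + layer norm
        h = self.norm2(h + self.mlp(h))            # residual MLP + layer norm
        g = torch.sigmoid(self.gate(torch.cat([h, memory], dim=-1)))
        new_memory = g * h + (1 - g) * memory      # gated blend of new and old memory
        return new_memory, weights                 # weights: (batch, heads, slots, slots + 1)


cell = TinyRelationalMemory(input_size=16, mem_slots=4, mem_size=32, num_heads=4)
mem0 = torch.randn(2, 4, 32)
new_mem, w = cell(torch.randn(2, 16), mem0)
print(new_mem.shape, w.shape)  # torch.Size([2, 4, 32]) torch.Size([2, 4, 4, 5])
```

Because the input is one extra key here, the attention map has `slots + 1` columns; the notebook's `probs` are described as `slots × slots`, so its implementation presumably handles the input differently.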

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# Import our implementation
from implementation import RelationalMemory, StandardLSTM
from visualization import plot_attention_heatmap

## 1. The Relational Memory Core

Let's initialize a small RMC with 4 memory slots.
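DeepMind's reference RMC implementation (in Sonnet) starts memory as an identity matrix, so each slot begins as a distinct one-hot vector, zero-padded when `mem_size > mem_slots` and truncated when it is smaller. Whether `RelationalMemory` does the same when given `memory=None` is an assumption; `initial_memory` below is a hypothetical helper that sketches that scheme.

```python
import torch


def initial_memory(batch_size: int, mem_slots: int, mem_size: int) -> torch.Tensor:
    """Hypothetical helper: identity-style initial memory of shape (batch, slots, mem_size).

    Slot i starts as the one-hot vector e_i; rows are zero-padded if
    mem_size > mem_slots and truncated if mem_size < mem_slots.
    """
    eye = torch.eye(mem_slots).unsqueeze(0).expand(batch_size, -1, -1)  # (batch, slots, slots)
    if mem_size > mem_slots:
        pad = torch.zeros(batch_size, mem_slots, mem_size - mem_slots)
        return torch.cat([eye, pad], dim=-1)
    return eye[:, :, :mem_size].clone()  # clone so the result is not an expanded view


mem = initial_memory(batch_size=1, mem_slots=4, mem_size=32)
print(mem.shape)        # torch.Size([1, 4, 32])
print(mem[0, :, :6])    # first 4 columns form an identity block, the rest are zeros
```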

In [None]:
INPUT_SIZE = 16
MEM_SLOTS = 4
MEM_SIZE = 32
HEADS = 4

rmc = RelationalMemory(mem_slots=MEM_SLOTS, 
                       mem_size=MEM_SIZE, 
                       input_size=INPUT_SIZE, 
                       num_heads=HEADS)

print(f"RMC Created with {MEM_SLOTS} slots of size {MEM_SIZE}.")
print(f"Total Parameters: {sum(p.numel() for p in rmc.parameters())}")

## 2. Visualizing Self-Attention

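What sets RMC apart is that it uses multi-head **self-attention** internally: at each time step, every memory slot attends over the memory slots before the memory is updated. Let's feed it a random input sequence and inspect the attention weights it produces.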
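The attention weights are scaled dot-product probabilities, `softmax(Q Kᵀ / sqrt(d_head))`, computed separately for each head. The sketch below uses random, hypothetical projection matrices (a real RMC learns them) to show what one `(batch, heads, slots, slots)` tensor means: entry `probs[b, h, i, j]` is how much slot `i` reads from slot `j` under head `h`. That the notebook's `probs` are normalized this way is an assumption based on the shape noted in the next cell.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, slots, mem_size, heads = 1, 4, 32, 4
d_head = mem_size // heads

memory = torch.randn(batch, slots, mem_size)
# Hypothetical projection weights for queries and keys.
w_q = torch.randn(mem_size, mem_size) / mem_size ** 0.5
w_k = torch.randn(mem_size, mem_size) / mem_size ** 0.5


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (batch, slots, mem_size) -> (batch, heads, slots, d_head)
    return t.view(batch, slots, heads, d_head).transpose(1, 2)


q = split_heads(memory @ w_q)
k = split_heads(memory @ w_k)

# Each slot (row) scores every slot (column); softmax normalizes each row.
scores = q @ k.transpose(-2, -1) / d_head ** 0.5  # (batch, heads, slots, slots)
probs = F.softmax(scores, dim=-1)

print(probs.shape)        # torch.Size([1, 4, 4, 4])
print(probs.sum(dim=-1))  # every row sums to 1 (up to float rounding)
```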

In [None]:
# Create a random input sequence
batch_size = 1
seq_len = 10
inputs = torch.randn(batch_size, seq_len, INPUT_SIZE)

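# Step through the sequence manually so we can record the attention weights at each step
memory = None  # None: let the RMC create its initial memory
attention_history = []

with torch.no_grad():  # inference only; no gradients are needed for plotting
    for t in range(seq_len):
        input_step = inputs[:, t, :]             # (batch, INPUT_SIZE)
        memory, probs = rmc(input_step, memory)  # probs: (batch, heads, slots, slots)
        attention_history.append(probs)

# Stack to (seq, batch, heads, slots, slots), then drop the batch dim
# (batch_size == 1) to get (seq, heads, slots, slots)
attn_stack = torch.stack(attention_history, dim=0).squeeze(1)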

# Plot
plot_attention_heatmap(attn_stack, "notebook_attention.png")

# Display the saved figure inline
from IPython.display import Image, display
display(Image("notebook_attention.png"))
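`plot_attention_heatmap` comes from the local `visualization` module, which isn't shown here. If you want a similar figure without it, here is a hypothetical stand-in, `plot_attention_grid`, assuming weights shaped `(seq, heads, slots, slots)` like `attn_stack`. It draws one panel per head for a single time step (the last by default); the real helper may lay out its figure differently.

```python
import matplotlib.pyplot as plt
import torch


def plot_attention_grid(attn: torch.Tensor, path: str, step: int = -1) -> None:
    """Hypothetical stand-in for plot_attention_heatmap.

    attn: (seq, heads, slots, slots) attention weights. Plots one time step,
    one heatmap per head, and saves the figure to `path`.
    """
    weights = attn[step].detach().cpu().numpy()  # (heads, slots, slots)
    num_heads = weights.shape[0]
    fig, axes = plt.subplots(1, num_heads, figsize=(3 * num_heads, 3), squeeze=False)
    for h, ax in enumerate(axes[0]):
        im = ax.imshow(weights[h], vmin=0.0, vmax=1.0, cmap="viridis")
        ax.set_title(f"Head {h}")
        ax.set_xlabel("Attended slot")
        ax.set_ylabel("Querying slot")
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)
    fig.savefig(path, dpi=100, bbox_inches="tight")
    plt.close(fig)  # free the figure; the PNG stays on disk


# Example with random row-normalized weights shaped like attn_stack above.
demo = torch.softmax(torch.randn(10, 4, 4, 4), dim=-1)
plot_attention_grid(demo, "attention_sketch.png")
```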

## 3. RMC vs LSTM Comparison

Next, let's set up the LSTM baseline. This cell only compares parameter counts; whether RMC actually learns differently is tested by the training run in Section 4.
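A PyTorch `nn.LSTM` layer has `4 * (hidden * input + hidden * hidden + 2 * hidden)` parameters: four gates, each with an input matrix, a recurrent matrix, and two bias vectors (`bias_ih` and `bias_hh`). The sketch below assumes `StandardLSTM` is one `nn.LSTM` layer plus an `nn.Linear` output head; that structure is a guess, so if the printed count in the next cell differs, the class is built differently.

```python
import torch.nn as nn


def lstm_param_count(input_size: int, hidden_size: int) -> int:
    # Four gates, each with W_ih (hidden x input), W_hh (hidden x hidden),
    # and two bias vectors of length hidden.
    return 4 * (hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size)


def linear_param_count(in_features: int, out_features: int) -> int:
    # Weight matrix plus bias vector.
    return in_features * out_features + out_features


lstm_core = lstm_param_count(16, 64)
head = linear_param_count(64, 16)
print(lstm_core, head, lstm_core + head)  # 20992 1040 22032

# Cross-check the formula against PyTorch's own nn.LSTM.
reference = nn.LSTM(input_size=16, hidden_size=64, batch_first=True)
print(sum(p.numel() for p in reference.parameters()))  # 20992
```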

In [None]:
lstm = StandardLSTM(input_size=INPUT_SIZE, hidden_size=64, output_size=INPUT_SIZE)

print(f"LSTM Parameters: {sum(p.numel() for p in lstm.parameters())}")
print("Note: We try to keep parameter counts roughly similar for fair comparison.")

## 4. Run Training

To train the models, run the script in the terminal:

```bash
python train_minimal.py --model both --visualize
```
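For intuition about what the models are compared on, here is a hypothetical generator for an N-th Farthest style task, loosely following the RMC paper's description: each sequence holds 8 random 16-dimensional vectors, and the question is "which vector is the n-th farthest (Euclidean distance) from vector m?". The input encoding (vector, one-hot ID, one-hot n, one-hot m) and the 0-based `n` are assumptions; `train_minimal.py` may generate its data differently.

```python
import torch
import torch.nn.functional as F


def nth_farthest_batch(batch_size=32, num_vectors=8, vec_dim=16, generator=None):
    """Hypothetical N-th Farthest data generator.

    Each time step's input is [vector, one-hot vector ID, one-hot n, one-hot m].
    The target is the ID of the vector that is n-th farthest from vector m,
    with n = 0 meaning the farthest.
    """
    vectors = torch.randn(batch_size, num_vectors, vec_dim, generator=generator)
    n = torch.randint(0, num_vectors, (batch_size,), generator=generator)
    m = torch.randint(0, num_vectors, (batch_size,), generator=generator)

    # Distances from the reference vector m to every vector in the same sequence.
    ref = vectors[torch.arange(batch_size), m]                      # (batch, vec_dim)
    dists = torch.linalg.norm(vectors - ref.unsqueeze(1), dim=-1)   # (batch, num_vectors)
    order = torch.argsort(dists, dim=-1, descending=True)           # farthest first
    target = order[torch.arange(batch_size), n]                     # (batch,)

    # One-hot codes, broadcast to every time step of the sequence.
    ids = F.one_hot(torch.arange(num_vectors), num_vectors).float()
    ids = ids.unsqueeze(0).expand(batch_size, -1, -1)
    n_code = F.one_hot(n, num_vectors).float().unsqueeze(1).expand(-1, num_vectors, -1)
    m_code = F.one_hot(m, num_vectors).float().unsqueeze(1).expand(-1, num_vectors, -1)
    inputs = torch.cat([vectors, ids, n_code, m_code], dim=-1)  # (batch, num_vectors, vec_dim + 3 * num_vectors)
    return inputs, target


g = torch.Generator().manual_seed(0)
x, y = nth_farthest_batch(batch_size=4, generator=g)
print(x.shape, y.shape)  # torch.Size([4, 8, 40]) torch.Size([4])
```

A sequence model reads the 8 time steps and must output the target ID, which forces it to compare stored vectors against each other rather than just remember them.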