# Day 19: Relational Reasoning (Santoro et al., 2017)

This notebook provides a technical walkthrough of the **Relation Network (RN)** architecture. We verify its core inductive biases (permutation invariance and cardinality awareness) and demonstrate how an explicit relational bottleneck supports reasoning over set-structured data.

## 1. Permutation Invariance

An RN treats its input as a set, so reordering the objects should not change the output. We check this by feeding the same objects in a shuffled order and comparing the two outputs.

In [None]:
import torch
from implementation import RelationNetwork, add_coordinates

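# Small RN over 4-dimensional objects; eval() switches off any train-only layers.
model = RelationNetwork(object_dim=4, output_dim=2)
model.eval()

objects = torch.randn(1, 5, 4)  # (batch, num_objects, object_dim)
indices = torch.randperm(5)     # random reordering of the 5 objects

with torch.no_grad():
    out_orig = model(objects)
    out_shuffled = model(objects[:, indices, :])

# Should sit at floating-point noise level if the model is permutation invariant.
print(f"Max Difference: {torch.abs(out_orig - out_shuffled).max().item():.2e}")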
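
For readers without the `implementation` module at hand, the check above can be reproduced with a minimal sketch of the RN form, f applied to an aggregate of g(o_i, o_j) over object pairs. The class name `TinyRelationNetwork`, the layer sizes, and the choice to aggregate over all ordered pairs (including i = j) are assumptions made for illustration; the notebook's own `RelationNetwork` may differ.

```python
import torch
import torch.nn as nn


class TinyRelationNetwork(nn.Module):
    """Hypothetical minimal RN: f(aggregate over all ordered pairs of g(o_i, o_j))."""

    def __init__(self, object_dim, hidden_dim=32, output_dim=2, aggregator="sum"):
        super().__init__()
        if aggregator not in ("sum", "mean"):
            raise ValueError(f"unknown aggregator: {aggregator}")
        self.aggregator = aggregator
        # g scores one pair of objects; f maps the pooled relations to the output.
        self.g = nn.Sequential(
            nn.Linear(2 * object_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.f = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, objects):
        # objects: (batch, n, object_dim)
        b, n, d = objects.shape
        o_i = objects.unsqueeze(2).expand(b, n, n, d)  # o_i varies along dim 1
        o_j = objects.unsqueeze(1).expand(b, n, n, d)  # o_j varies along dim 2
        pairs = torch.cat([o_i, o_j], dim=-1)          # (b, n, n, 2 * d)
        relations = self.g(pairs).flatten(1, 2)        # (b, n * n, hidden_dim)
        if self.aggregator == "sum":
            pooled = relations.sum(dim=1)
        else:
            pooled = relations.mean(dim=1)
        return self.f(pooled)


torch.manual_seed(0)
rn = TinyRelationNetwork(object_dim=4).eval()
x = torch.randn(1, 5, 4)
perm = torch.randperm(5)
with torch.no_grad():
    max_diff = (rn(x) - rn(x[:, perm])).abs().max().item()
print(f"max difference after shuffling: {max_diff:.2e}")
```

Shuffling the objects only reorders the set of pairs, and both sum and mean ignore order, so any remaining difference comes from floating-point summation order.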

## 2. Cardinality Bias: Sum vs. Mean

In Section 2.1, Santoro et al. choose `sum` as their aggregator because it preserves information about the number of objects. Let's see how `mean` washes this information away.

In [None]:
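# Seed before each constructor so both models start from the same initial weights
# (assuming the aggregator choice does not change the parameter layout).
torch.manual_seed(0)
rn_sum = RelationNetwork(object_dim=4, aggregator='sum').eval()
torch.manual_seed(0)
rn_mean = RelationNetwork(object_dim=4, aggregator='mean').eval()

# Two sets of identical objects that differ only in size.
objs_small = torch.ones(1, 2, 4)
objs_large = torch.ones(1, 10, 4)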

with torch.no_grad():
    sum_diff = torch.abs(rn_sum(objs_small) - rn_sum(objs_large)).mean().item()
    mean_diff = torch.abs(rn_mean(objs_small) - rn_mean(objs_large)).mean().item()

print(f"Sum Aggregator Sensitivity:  {sum_diff:.2f}")
print(f"Mean Aggregator Sensitivity: {mean_diff:.2f}")
print("\nObservation: with identical objects, the mean aggregator makes sets of 2 and 10 look nearly identical to the network.")
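
The effect can be isolated from the rest of the network by looking only at the pooled relation vector. The sketch below assumes a stand-in relation function `g` (a single linear layer with `tanh`, not the notebook's) and aggregation over all ordered pairs. When every object is identical, every pair produces the same relation vector, so the sum grows with the number of pairs (n² under this assumption) while the mean stays fixed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical relation function on concatenated pairs of 4-dimensional objects.
g = nn.Sequential(nn.Linear(8, 16), nn.Tanh())


def pooled_relations(objects, aggregator):
    """Aggregate g(o_i, o_j) over all ordered pairs with 'sum' or 'mean'."""
    b, n, d = objects.shape
    o_i = objects.unsqueeze(2).expand(b, n, n, d)
    o_j = objects.unsqueeze(1).expand(b, n, n, d)
    relations = g(torch.cat([o_i, o_j], dim=-1)).flatten(1, 2)  # (b, n * n, 16)
    return relations.sum(dim=1) if aggregator == "sum" else relations.mean(dim=1)


small, large = torch.ones(1, 2, 4), torch.ones(1, 10, 4)
with torch.no_grad():
    sum_small, sum_large = pooled_relations(small, "sum"), pooled_relations(large, "sum")
    mean_small, mean_large = pooled_relations(small, "mean"), pooled_relations(large, "mean")

# 10 objects give 100 ordered pairs, 2 objects give 4.
ratio = (sum_large.norm() / sum_small.norm()).item()
print(f"sum:  norm ratio large/small = {ratio:.2f}")
print(f"mean: max abs difference     = {(mean_large - mean_small).abs().max().item():.2e}")
```

The norm ratio for the sum should come out close to 25 (100 pairs versus 4), while the mean difference should be at floating-point noise level. An implementation that uses only unordered pairs would give a different ratio but the same qualitative picture.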

## 3. Coordinate Injection

For spatial tasks (like CLEVR), Section 3.1 describes appending (x, y) coordinates to objects. This allows the model to know *where* objects are in the scene.

In [None]:
objs = torch.randn(1, 4, 8)  # 4 objects with 8 features each
objs_with_coords = add_coordinates(objs)
print(f"Original shape:   {objs.shape}")
print(f"With coordinates: {objs_with_coords.shape}")
# The last two features of each object hold its coordinates.
print(f"Coordinate values for first object: {objs_with_coords[0, 0, -2:].tolist()}")
print(f"Coordinate values for last object:  {objs_with_coords[0, -1, -2:].tolist()}")
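
One way such a helper could work is sketched below. This is not the `add_coordinates` from `implementation`, whose coordinate convention is not shown here. The function name `append_grid_coordinates`, the row-major d × d grid layout, and the [-1, 1] normalization are all assumptions chosen for illustration.

```python
import math

import torch


def append_grid_coordinates(objects):
    """Hypothetical helper: tag each object with an (x, y) position in [-1, 1].

    Assumes the n objects are the row-major cells of a d x d grid (n = d * d).
    """
    b, n, _ = objects.shape
    d = math.isqrt(n)
    if d * d != n:
        raise ValueError(f"expected a square number of objects, got {n}")
    axis = torch.linspace(-1.0, 1.0, d, dtype=objects.dtype, device=objects.device)
    ys, xs = torch.meshgrid(axis, axis, indexing="ij")  # ys varies by row, xs by column
    coords = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)  # (n, 2)
    return torch.cat([objects, coords.unsqueeze(0).expand(b, n, 2)], dim=-1)


cells = torch.randn(1, 4, 8)             # a 2 x 2 grid of 8-dimensional cells
tagged = append_grid_coordinates(cells)  # two extra features per cell
print(tagged.shape)
print(tagged[0, :, -2:].tolist())
```

Because the coordinates are concatenated before pairing, the relation function sees both positions of each pair and can learn relations that depend on relative location.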

## 4. Training with Generalization

Finally, we run a training script that measures the generalization gap: the model trains on sets of 5 objects and is tested on sets of 10.

In [None]:
!python train_minimal.py --mode count --train-n 5 --test-n 10 --epochs 10
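
The contents of `train_minimal.py` are not shown in this notebook, so the following is only a sketch of how a size-generalization gap could be measured. The counting task (label = number of objects whose first feature is positive), the `PairSumRegressor` model, and all hyperparameters are hypothetical stand-ins rather than what the script does.

```python
import torch
import torch.nn as nn


def make_count_batch(batch_size, n, object_dim=4):
    """Hypothetical counting task: label = number of objects with a positive first feature."""
    objects = torch.randn(batch_size, n, object_dim)
    counts = (objects[..., 0] > 0).sum(dim=1, keepdim=True).float()
    return objects, counts


class PairSumRegressor(nn.Module):
    """Stand-in for a sum-aggregated RN with a scalar output."""

    def __init__(self, object_dim=4, hidden_dim=32):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(2 * object_dim, hidden_dim), nn.ReLU())
        self.f = nn.Linear(hidden_dim, 1)

    def forward(self, objects):
        b, n, d = objects.shape
        o_i = objects.unsqueeze(2).expand(b, n, n, d)
        o_j = objects.unsqueeze(1).expand(b, n, n, d)
        pooled = self.g(torch.cat([o_i, o_j], dim=-1)).sum(dim=(1, 2))  # (b, hidden_dim)
        return self.f(pooled)


@torch.no_grad()
def mean_abs_error(model, n, batches=20):
    """Average absolute counting error on freshly sampled sets of size n."""
    model.eval()
    errors = []
    for _ in range(batches):
        x, y = make_count_batch(64, n)
        errors.append((model(x) - y).abs().mean().item())
    return sum(errors) / len(errors)


torch.manual_seed(0)
model = PairSumRegressor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(200):  # short run, only meant to illustrate the protocol
    model.train()
    x, y = make_count_batch(64, n=5)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

train_mae = mean_abs_error(model, n=5)
test_mae = mean_abs_error(model, n=10)
print(f"MAE at n=5:  {train_mae:.3f}")
print(f"MAE at n=10: {test_mae:.3f}")
print(f"generalization gap: {test_mae - train_mae:.3f}")
```

The gap is simply the error at the unseen set size minus the error at the training size. No particular numbers are claimed here, since they depend on the seed, the task definition, and the training budget.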