# MNIST Experiments: Positional Encoding Variants

This notebook trains and compares small Vision Transformer (ViT) models on
MNIST, each using a different relative positional encoding (RPE) mechanism:

- RoPE baseline
- Cayley-STRING with dense S
- Reflection-based STRING
- Sparse-S Cayley-STRING (varying sparsity f)



In [None]:
import json

import torch

from data_utils import set_seed
from train_eval import ExperimentConfig, run_experiment

# Pick the CUDA device when PyTorch reports one as available, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

# Seed once at notebook start through the project helper from data_utils
# (presumably seeds the RNGs used during training -- confirm in data_utils).
# NOTE(review): no cell visible here re-seeds afterwards; unless
# run_experiment seeds internally, results depend on cell execution order.
set_seed(42)


In [None]:
# RoPE-only ViT on MNIST: a short sanity-check run.

config_rope_mnist = ExperimentConfig(
    # task
    pos_variant="rope",
    dataset="mnist",
    num_classes=10,
    # input
    in_chans=1,
    img_size=28,
    patch_size=7,
    # model
    depth=4,
    n_heads=4,
    emb_dim=128,
    # training
    lr=3e-4,
    epochs=2,
    batch_size=128,
)

# Run the experiment on the selected device and pretty-print its results.
results_rope_mnist = run_experiment(config_rope_mnist, device=DEVICE)
print(json.dumps(results_rope_mnist, indent=2))


In [None]:
# Cayley-STRING with dense S on MNIST (baseline for the Cayley variants).

config_cayley_dense_mnist = ExperimentConfig(
    # task
    pos_variant="cayley_dense",
    dataset="mnist",
    num_classes=10,
    # input
    in_chans=1,
    img_size=28,
    patch_size=7,
    # model
    depth=4,
    n_heads=4,
    emb_dim=128,
    # training
    lr=3e-4,
    epochs=2,
    batch_size=128,
)

# Run the experiment on the selected device and pretty-print its results.
results_cayley_dense_mnist = run_experiment(config_cayley_dense_mnist, device=DEVICE)
print(json.dumps(results_cayley_dense_mnist, indent=2))


In [None]:
# STRING with the reflection-based variant on MNIST.

config_reflection_mnist = ExperimentConfig(
    # task
    pos_variant="reflection",
    dataset="mnist",
    num_classes=10,
    # input
    in_chans=1,
    img_size=28,
    patch_size=7,
    # model
    depth=4,
    n_heads=4,
    emb_dim=128,
    # training
    lr=3e-4,
    epochs=2,
    batch_size=128,
)

# Run the experiment on the selected device and pretty-print its results.
results_reflection_mnist = run_experiment(config_reflection_mnist, device=DEVICE)
print(json.dumps(results_reflection_mnist, indent=2))


In [None]:
# Sparse-S Cayley-STRING on MNIST: one run per value of f, results collected
# in sweep order.

sparse_results = []
for f in (1.0, 0.5, 0.2, 0.1):
    print(f"\nRunning sparse Cayley-STRING with f={f}...")
    config_sparse = ExperimentConfig(
        # task
        pos_variant="cayley_sparse",
        f_sparse=f,  # the only setting that changes between runs
        dataset="mnist",
        num_classes=10,
        # input
        in_chans=1,
        img_size=28,
        patch_size=7,
        # model
        depth=4,
        n_heads=4,
        emb_dim=128,
        # training
        lr=3e-4,
        epochs=2,
        batch_size=128,
    )
    res = run_experiment(config_sparse, device=DEVICE)
    sparse_results.append(res)

# Dump every run of the sweep in a single JSON array.
print(json.dumps(sparse_results, indent=2))
