# MNIST Experiments: Positional Encoding Variants

This notebook trains and compares small ViT models on MNIST (and, further below, CIFAR-10) using different
relative positional encoding (RPE) mechanisms:

- RoPE baseline
- Cayley-STRING with dense S
- Reflection-based STRING
- Sparse-S Cayley-STRING (varying sparsity f)



In [1]:
import json

import torch

from data_utils import set_seed

from train_eval import ExperimentConfig, run_experiment

# Pick the compute device: CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# Seed once, up front, through the project helper imported from data_utils.
set_seed(42)


Using device: cpu


In [2]:
# Sanity check: a ViT that uses only RoPE, on MNIST.

config_rope_mnist = ExperimentConfig(
    # --- data ---
    dataset="mnist",
    img_size=28,
    in_chans=1,
    num_classes=10,
    # --- model ---
    pos_variant="rope",
    patch_size=7,
    emb_dim=128,
    depth=4,
    n_heads=4,
    # --- optimisation ---
    batch_size=128,
    epochs=2,
    lr=3e-4,
)

# Run the experiment on the selected device, then show the returned results as indented JSON.
results_rope_mnist = run_experiment(config_rope_mnist, device=DEVICE)
print(json.dumps(results_rope_mnist, indent=2))


100.0%
100.0%
100.0%
100.0%


[Epoch 1/2] train_loss=0.5610, train_acc=0.8262, val_loss=0.1995, val_acc=0.9428, time=50.02s
[Epoch 2/2] train_loss=0.1448, train_acc=0.9581, val_loss=0.1536, val_acc=0.9535, time=44.63s
{
  "config": {
    "dataset": "mnist",
    "pos_variant": "rope",
    "img_size": 28,
    "patch_size": 7,
    "in_chans": 1,
    "num_classes": 10,
    "emb_dim": 128,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 2,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 0.1447838130871455,
  "final_train_acc": 0.9581166666666666,
  "final_val_loss": 0.1535803920030594,
  "final_val_acc": 0.9535,
  "avg_epoch_time_sec": 47.322014927864075,
  "inference_time_ms_per_batch": 28.18448543548584
}


In [3]:
# Baseline: Cayley-STRING with a dense S, on MNIST.
# Same settings as the RoPE sanity check; only `pos_variant` differs.

config_cayley_dense_mnist = ExperimentConfig(
    # --- data ---
    dataset="mnist",
    img_size=28,
    in_chans=1,
    num_classes=10,
    # --- model ---
    pos_variant="cayley_dense",
    patch_size=7,
    emb_dim=128,
    depth=4,
    n_heads=4,
    # --- optimisation ---
    batch_size=128,
    epochs=2,
    lr=3e-4,
)

# Run the experiment on the selected device, then show the returned results as indented JSON.
results_cayley_dense_mnist = run_experiment(config_cayley_dense_mnist, device=DEVICE)
print(json.dumps(results_cayley_dense_mnist, indent=2))


[Epoch 1/2] train_loss=0.5377, train_acc=0.8342, val_loss=0.2083, val_acc=0.9396, time=262.99s
[Epoch 2/2] train_loss=0.1486, train_acc=0.9563, val_loss=0.1118, val_acc=0.9691, time=260.01s
{
  "config": {
    "dataset": "mnist",
    "pos_variant": "cayley_dense",
    "img_size": 28,
    "patch_size": 7,
    "in_chans": 1,
    "num_classes": 10,
    "emb_dim": 128,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 2,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 0.14857388066848118,
  "final_train_acc": 0.9563333333333334,
  "final_val_loss": 0.11180496111512184,
  "final_val_acc": 0.9691,
  "avg_epoch_time_sec": 261.49744296073914,
  "inference_time_ms_per_batch": 117.74463653564453
}


In [4]:
# Reflection-based STRING variant, on MNIST.
# Same settings as the other MNIST runs; only `pos_variant` differs.

config_reflection_mnist = ExperimentConfig(
    # --- data ---
    dataset="mnist",
    img_size=28,
    in_chans=1,
    num_classes=10,
    # --- model ---
    pos_variant="reflection",
    patch_size=7,
    emb_dim=128,
    depth=4,
    n_heads=4,
    # --- optimisation ---
    batch_size=128,
    epochs=2,
    lr=3e-4,
)

# Run the experiment on the selected device, then show the returned results as indented JSON.
results_reflection_mnist = run_experiment(config_reflection_mnist, device=DEVICE)
print(json.dumps(results_reflection_mnist, indent=2))


[Epoch 1/2] train_loss=0.6917, train_acc=0.7828, val_loss=0.2794, val_acc=0.9147, time=48.00s
[Epoch 2/2] train_loss=0.2188, train_acc=0.9345, val_loss=0.1761, val_acc=0.9458, time=69.27s
{
  "config": {
    "dataset": "mnist",
    "pos_variant": "reflection",
    "img_size": 28,
    "patch_size": 7,
    "in_chans": 1,
    "num_classes": 10,
    "emb_dim": 128,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 2,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 0.21880235421657562,
  "final_train_acc": 0.9345333333333333,
  "final_val_loss": 0.176052467751503,
  "final_val_acc": 0.9458,
  "avg_epoch_time_sec": 58.63502240180969,
  "inference_time_ms_per_batch": 57.71787166595459
}


In [5]:
# Sparse-S Cayley-STRING on MNIST, swept over the sparsity setting f.

# One result dict per f value, appended in sweep order.
sparse_results = []
for f in (1.0, 0.5, 0.2, 0.1):
    print(f"\nRunning sparse Cayley-STRING with f={f}...")
    config_sparse = ExperimentConfig(
        # --- data ---
        dataset="mnist",
        img_size=28,
        in_chans=1,
        num_classes=10,
        # --- model ---
        pos_variant="cayley_sparse",
        f_sparse=f,
        patch_size=7,
        emb_dim=128,
        depth=4,
        n_heads=4,
        # --- optimisation ---
        batch_size=128,
        epochs=2,
        lr=3e-4,
    )
    res = run_experiment(config_sparse, device=DEVICE)
    sparse_results.append(res)

# Dump the whole sweep at once, after every run has finished.
print(json.dumps(sparse_results, indent=2))



Running sparse Cayley-STRING with f=1.0...
[Epoch 1/2] train_loss=0.5314, train_acc=0.8384, val_loss=0.1728, val_acc=0.9518, time=318.82s
[Epoch 2/2] train_loss=0.1388, train_acc=0.9590, val_loss=0.1082, val_acc=0.9674, time=300.90s

Running sparse Cayley-STRING with f=0.5...
[Epoch 1/2] train_loss=0.5574, train_acc=0.8282, val_loss=0.1715, val_acc=0.9503, time=285.48s
[Epoch 2/2] train_loss=0.1368, train_acc=0.9598, val_loss=0.1022, val_acc=0.9690, time=280.93s

Running sparse Cayley-STRING with f=0.2...
[Epoch 1/2] train_loss=0.5124, train_acc=0.8441, val_loss=0.1993, val_acc=0.9387, time=290.49s
[Epoch 2/2] train_loss=0.1427, train_acc=0.9584, val_loss=0.1266, val_acc=0.9618, time=283.54s

Running sparse Cayley-STRING with f=0.1...
[Epoch 1/2] train_loss=0.5072, train_acc=0.8479, val_loss=0.1800, val_acc=0.9471, time=286.46s
[Epoch 2/2] train_loss=0.1351, train_acc=0.9605, val_loss=0.1021, val_acc=0.9712, time=267.13s
[
  {
    "config": {
      "dataset": "mnist",
      "pos_varia

## CIFAR-10 Experiments

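These are quick runs comparing RoPE, dense Cayley-STRING, Reflection-STRING, and Sparse-S Cayley-STRING on CIFAR-10. Adjust `epochs`/`emb_dim`/`depth` as needed; the current settings are deliberately small to keep runtime manageable. Even so, in the CPU run recorded below the Cayley-STRING variants took roughly 1 to 12 hours per epoch, versus about 200 s for RoPE and Reflection-STRING.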


In [6]:
# CIFAR-10: RoPE and dense Cayley-STRING (quick baseline runs)

# Hyperparameters shared by both baselines. Defining them once guarantees the
# two runs differ only in `pos_variant`, so the comparison stays like-for-like
# and a later tweak cannot be applied to one config but forgotten in the other.
cifar_baseline_kwargs = dict(
    dataset="cifar10",
    img_size=32,
    patch_size=4,
    in_chans=3,
    num_classes=10,
    emb_dim=192,
    depth=4,
    n_heads=4,
    batch_size=128,
    epochs=1,
    lr=3e-4,
)

config_rope_cifar = ExperimentConfig(pos_variant="rope", **cifar_baseline_kwargs)
config_cayley_dense_cifar = ExperimentConfig(
    pos_variant="cayley_dense", **cifar_baseline_kwargs
)

# Run each baseline in turn and show its returned results as indented JSON.
print("Running CIFAR-10 RoPE...")
results_rope_cifar = run_experiment(config_rope_cifar, device=DEVICE)
print(json.dumps(results_rope_cifar, indent=2))

print("\nRunning CIFAR-10 Cayley (dense)...")
results_cayley_dense_cifar = run_experiment(config_cayley_dense_cifar, device=DEVICE)
print(json.dumps(results_cayley_dense_cifar, indent=2))


Running CIFAR-10 RoPE...


100.0%


[Epoch 1/1] train_loss=1.5843, train_acc=0.4229, val_loss=1.3494, val_acc=0.5160, time=196.69s
{
  "config": {
    "dataset": "cifar10",
    "pos_variant": "rope",
    "img_size": 32,
    "patch_size": 4,
    "in_chans": 3,
    "num_classes": 10,
    "emb_dim": 192,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 1,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 1.5842964737319947,
  "final_train_acc": 0.42292,
  "final_val_loss": 1.349433870124817,
  "final_val_acc": 0.516,
  "avg_epoch_time_sec": 196.68830108642578,
  "inference_time_ms_per_batch": 187.40193843841553
}

Running CIFAR-10 Cayley (dense)...


python(20673) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(20690) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[Epoch 1/1] train_loss=1.5750, train_acc=0.4279, val_loss=1.3917, val_acc=0.4933, time=3682.29s


python(21373) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(21387) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


{
  "config": {
    "dataset": "cifar10",
    "pos_variant": "cayley_dense",
    "img_size": 32,
    "patch_size": 4,
    "in_chans": 3,
    "num_classes": 10,
    "emb_dim": 192,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 1,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 1.5749987294387817,
  "final_train_acc": 0.42792,
  "final_val_loss": 1.3916774576187134,
  "final_val_acc": 0.4933,
  "avg_epoch_time_sec": 3682.2861709594727,
  "inference_time_ms_per_batch": 34389.75887298584
}


In [7]:
# CIFAR-10: Reflection-STRING and Sparse-S Cayley-STRING

# Settings common to every run in this cell. Only `pos_variant` (and, for the
# sparse runs, `f_sparse`) changes between experiments, so defining the rest
# once keeps the configs from drifting apart.
cifar_variant_kwargs = dict(
    dataset="cifar10",
    img_size=32,
    patch_size=4,
    in_chans=3,
    num_classes=10,
    emb_dim=192,
    depth=4,
    n_heads=4,
    batch_size=128,
    epochs=1,
    lr=3e-4,
)

config_reflection_cifar = ExperimentConfig(
    pos_variant="reflection", **cifar_variant_kwargs
)

# The sparse sweep runs first: one experiment per sparsity setting f, with the
# result dicts collected in sweep order.
sparse_results_cifar = []
for f in [1.0, 0.3, 0.1]:
    print(f"\nRunning CIFAR-10 Sparse Cayley-STRING with f={f}...")
    config_sparse_cifar = ExperimentConfig(
        pos_variant="cayley_sparse",
        f_sparse=f,
        **cifar_variant_kwargs,
    )
    res = run_experiment(config_sparse_cifar, device=DEVICE)
    sparse_results_cifar.append(res)

# The Reflection-STRING run follows the sweep; its results are printed immediately.
print("\nRunning CIFAR-10 Reflection-STRING...")
results_reflection_cifar = run_experiment(config_reflection_cifar, device=DEVICE)
print(json.dumps(results_reflection_cifar, indent=2))

# The sparse sweep results are printed last, all together.
print("\nSparse CIFAR-10 results:")
print(json.dumps(sparse_results_cifar, indent=2))



Running CIFAR-10 Sparse Cayley-STRING with f=1.0...


python(21577) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(21593) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(42495) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(42520) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[Epoch 1/1] train_loss=1.5589, train_acc=0.4312, val_loss=1.2992, val_acc=0.5271, time=43580.31s


python(43894) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(43908) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



Running CIFAR-10 Sparse Cayley-STRING with f=0.3...


python(44214) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44230) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(64821) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(64840) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[Epoch 1/1] train_loss=1.5592, train_acc=0.4306, val_loss=1.2896, val_acc=0.5385, time=3492.91s


python(65652) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(65668) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



Running CIFAR-10 Sparse Cayley-STRING with f=0.1...


python(65857) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(65889) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(82303) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(82319) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[Epoch 1/1] train_loss=1.5895, train_acc=0.4212, val_loss=1.3522, val_acc=0.5153, time=32885.43s


python(82861) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(82866) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



Running CIFAR-10 Reflection-STRING...


python(83044) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(83059) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(84797) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(84811) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[Epoch 1/1] train_loss=1.6229, train_acc=0.4043, val_loss=1.4004, val_acc=0.4925, time=200.32s


python(85127) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(85143) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


{
  "config": {
    "dataset": "cifar10",
    "pos_variant": "reflection",
    "img_size": 32,
    "patch_size": 4,
    "in_chans": 3,
    "num_classes": 10,
    "emb_dim": 192,
    "depth": 4,
    "n_heads": 4,
    "batch_size": 128,
    "epochs": 1,
    "lr": 0.0003,
    "weight_decay": 0.01,
    "f_sparse": null
  },
  "final_train_loss": 1.6229468309020996,
  "final_train_acc": 0.40426,
  "final_val_loss": 1.4004474094390869,
  "final_val_acc": 0.4925,
  "avg_epoch_time_sec": 200.32258820533752,
  "inference_time_ms_per_batch": 165.4984712600708
}

Sparse CIFAR-10 results:
[
  {
    "config": {
      "dataset": "cifar10",
      "pos_variant": "cayley_sparse",
      "img_size": 32,
      "patch_size": 4,
      "in_chans": 3,
      "num_classes": 10,
      "emb_dim": 192,
      "depth": 4,
      "n_heads": 4,
      "batch_size": 128,
      "epochs": 1,
      "lr": 0.0003,
      "weight_decay": 0.01,
      "f_sparse": 1.0
    },
    "final_train_loss": 1.5588574738311767,
    "final_t