RMAAT: Astrocyte-Inspired Memory Compression and Replay for Efficient Long-Context Transformers

Python 3.10 PyTorch 1.13 License: MIT ICLR 2026

Official implementation of RMAAT (Recurrent Memory Augmented Astromorphic Transformer), accepted at ICLR 2026.

Paper links: arXiv | OpenReview

The quadratic complexity of self-attention limits Transformers on long sequences. RMAAT integrates computational principles derived from astrocytes (glial cells critical for biological memory and synaptic modulation) into a recurrent Transformer framework. It combines segment-based processing with persistent memory tokens, an adaptive compression mechanism governed by a retention factor derived from simulated astrocyte long-term plasticity (LTP), and linear-complexity attention inspired by astrocyte short-term plasticity (STP). Training uses the memory-efficient Astrocytic Memory Replay Backpropagation (AMRB) algorithm.
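
The linear-attention idea can be illustrated with a standard kernel-feature construction. This is a generic sketch of O(N) attention (in the style of positive-feature-map linear attention), not the paper's exact astromorphic formulation:

```python
import numpy as np

def feature_map(x):
    # ELU(x) + 1: a common positive feature map for linear attention
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(q, k, v):
    """O(N) attention over (N, d) arrays: accumulate the (d, d_v) summary
    K^T V once, then read it out per query, instead of forming the
    O(N^2) attention matrix."""
    qf, kf = feature_map(q), feature_map(k)
    kv = kf.T @ v                 # (d, d_v) summary, size independent of N
    z = qf @ kf.sum(axis=0)       # (N,) per-query normalizer
    return (qf @ kv) / z[:, None]
```

Because the `(d, d_v)` summary is fixed-size, cost grows linearly with sequence length N rather than quadratically.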

Architecture Overview

RMAAT processes long sequences by dividing them into fixed-length segments processed recurrently. Key components:

  • Segmented Processing with Memory Tokens — Persistent memory tokens propagate compressed context across segments, enabling long-range dependency modeling without attending over the full sequence.
  • Astromorphic Attention — A linear-complexity O(N) attention mechanism inspired by astrocyte short-term plasticity (STP), replacing standard O(N^2) self-attention within each segment.
  • Memory Retention Factor — Derived from a macro model of neuron-astrocyte LTP dynamics, this factor adaptively compresses memory tokens across segments, implementing biologically motivated context decay.
  • AMRB Training — Astrocytic Memory Replay Backpropagation replays segments from stored compressed memory states during backpropagation, reducing peak GPU memory by up to 4.41x versus standard BPTT (Retrieval: 15.0 -> 3.4 GB, 4.41x; Text: 22.0 -> 5.1 GB, 4.31x) while maintaining equivalent accuracy.
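
As a rough illustration of the segment recurrence, the sketch below carries a compressed memory state across segments, decays it with a retention factor, and caches each incoming state the way AMRB-style replay would. All names and the mean-pooling "attention" stand-in are hypothetical, not the repo's API:

```python
import numpy as np

def process_segments(tokens, seg_len, mem_dim, retain=0.9):
    """Hypothetical sketch of recurrent segment processing with a
    persistent memory state compressed by a retention factor."""
    memory = np.zeros(mem_dim)
    outputs, stored_memories = [], []
    for start in range(0, len(tokens), seg_len):
        seg = tokens[start:start + seg_len]
        # AMRB-style: cache the compressed incoming state so the segment
        # can be replayed during backprop instead of storing activations
        stored_memories.append(memory.copy())
        outputs.append(seg + memory)        # context injected into the segment
        # stand-in for attention over [memory; segment]
        seg_summary = seg.mean(axis=0)
        # retention factor decays old context before mixing in the new summary
        memory = retain * memory + (1 - retain) * seg_summary
    return np.concatenate(outputs), stored_memories
```

Storing only the small per-segment memory states, rather than every segment's activations, is what yields the peak-memory savings reported for AMRB.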

Results

All table values below are copied directly from the camera-ready paper tables.

LRA Benchmark Accuracy (%)

| Model | ListOps (2K) | Text (4K) | Retrieval (8K) | Image (1K) | Pathfinder (1K) | Average |
|---|---|---|---|---|---|---|
| Transformer | 36.4 | 64.3 | 57.5 | 42.4 | 71.4 | 54.4 |
| Nystromformer | 37.2 | 65.5 | 79.6 | 41.6 | 70.9 | 59.0 |
| Luna-256 | 37.3 | 64.6 | 79.3 | 47.4 | 77.7 | 61.3 |
| RMT | 37.4 | 65.0 | 79.3 | 54.6 | 81.5 | 63.6 |
| RMAAT (Ours) | 38.9 | 65.9 | 83.2 | 64.8 | 87.1 | 68.0 |

Peak GPU Memory (GB)

| Model | ListOps | Text | Retrieval | Image | Pathfinder |
|---|---|---|---|---|---|
| Transformer | 4.7 | 6.7 | 5.2 | 7.8 | 5.4 |
| RMT | 20.4 | 24.0 | 18.3 | 22.7 | 12.7 |
| RMAAT (Ours) | 5.2 | 5.1 | 3.4 | 5.3 | 4.7 |

Training Speed (relative to RMT)

| Model | ListOps | Text | Retrieval | Image | Pathfinder |
|---|---|---|---|---|---|
| RMT | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
| RMAAT (Ours) | 1.5x | 1.5x | 1.73x | 1.3x | 0.95x |

Repository Structure

```
RMAAT/
├── configs/
│   └── config_v10.yaml       # Main configuration file
├── model.py                  # RMAAT architecture (AstroAttention, memory tokens, encoder)
├── train.py                  # Training and evaluation loops with AMRB
├── run_train.py              # Entry point — loads config and launches training
├── dataloader.py             # Data loading and preprocessing
├── lra_datasets.py           # Dataset classes for LRA benchmarks
├── lra_config.py             # Tokenizers and configs for LRA tasks
├── utils.py                  # Utility functions (positional encoding, custom embeddings)
├── get_lra_data.sh           # Script to download LRA datasets
├── environment.yml           # Conda environment specification
├── CITATION.bib              # BibTeX citation
├── LICENSE                   # MIT License
└── README.md
```

Setup

1. Clone the repository

```bash
git clone https://github.com/NeuroCompLab-psu/RMAAT.git
cd RMAAT
```

2. Create the Conda environment

```bash
conda env create -f environment.yml
conda activate rmaat
```

3. Download datasets

The script downloads CIFAR-10, Long Range Arena (LRA), and IMDb datasets into a datasets/ directory:

```bash
bash get_lra_data.sh
```

Note: The original LRA dataset hosted on Google Cloud Storage may return a 403 error (the upstream repo was archived in Feb 2025). If the download fails, the script will print instructions for obtaining the data from alternative sources such as the e-lra fork. Place the extracted lra_release/ directory inside datasets/.

Usage

Training

Run training from the repository root:

```bash
CUDA_VISIBLE_DEVICES=0 python run_train.py --config configs/config_v10.yaml
```

Selecting a task

Edit configs/config_v10.yaml to change the dataset and relevant hyperparameters:

| Parameter | Description | Options |
|---|---|---|
| `dataset` | Dataset/task to run | `imdb`, `imdb_long`, `imdb_lra`, `listops`, `cifar10`, `pathfinder32`, `aan` |
| `max_seq_len` | Maximum sequence length | Typical: 4096 (text), 2048 (listops), 1024 (cifar10/pathfinder32), 8192 (aan/retrieval-style setup) |
| `num_segments` | Number of recurrent segments | 2, 4, 8, 16 |
| `num_memory_tokens` | Memory tokens per segment | 1 (default) |
| `attention_type` | Attention mechanism | `astro` (astromorphic), `softmax` (standard) |
| `memory_replay_backprop` | Enable AMRB training | `True` / `False` |
| `astro_mem` | Enable astrocytic memory retention | `True` / `False` |
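
For example, a minimal configuration for the ListOps task might look like the fragment below. Only keys from the table above are shown; consult the shipped configs/config_v10.yaml for the full set of keys and exact defaults:

```yaml
# Illustrative fragment, not the complete config file
dataset: listops
max_seq_len: 2048
num_segments: 8
num_memory_tokens: 1
attention_type: astro
memory_replay_backprop: True
astro_mem: True
```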

Logging

To enable Weights & Biases logging, set `wandb: True` and configure `wandb_run_name` in the config file.

Citation

If you find this work useful, please cite:

@article{mia2026rmaat,
  title         = {{RMAAT}: Astrocyte-Inspired Memory Compression and Replay for Efficient Long-Context Transformers},
  author        = {Mia, Md Zesun Ahmed and Bal, Malyaban and Sengupta, Abhronil},
  journal       = {arXiv preprint arXiv:2601.00426},
  year          = {2026},
  archivePrefix = {arXiv},
  eprint        = {2601.00426},
  primaryClass  = {cs.NE},
  doi           = {10.48550/arXiv.2601.00426},
  url           = {https://arxiv.org/abs/2601.00426}
}

License

This project is licensed under the MIT License. See LICENSE for details.
