A family of future-aware causal masks that let visual tokens preview future context — improving multimodal reasoning while preserving autoregressive decoding.
Xiaohuan Pei¹, Tao Huang², Yanxiang Ma¹, Chang Xu¹†
¹ The University of Sydney · ² Shanghai Jiao Tong University · † Corresponding author
Causal attention is the default mechanism in autoregressive Vision–Language Models (VLMs): visual and textual tokens are concatenated into one sequence and a strict left-to-right mask blocks every token from attending to its future. This mask is inherited unchanged from text-only LLMs, where it is well justified — predicting the next word must not peek ahead.
But visual tokens are not sequential. An image is processed holistically; its regions carry spatial and temporal relationships that do not respect a left-to-right order. Strictly masking the future for visual queries is therefore overly rigid and discards semantic cues that often live in later tokens (a later frame, a label on the right side of a diagram, a future navigation goal).
This repository studies that mismatch and provides a drop-in, fine-tuning-free answer:
- 🔍 An empirical study of how different causal-masking strategies affect vision–language inference across 15 multimodal task types (MileBench).
- 🎭 A family of future-aware causal masks that selectively relax future masking for visual queries only, while preserving causal structure everywhere else.
- ⚡ A lightweight "merge" mechanism that pools future visual context into past prefix positions during prefill — keeping the gains of future access while restoring standard causal-decoding speed (up to ≈3× faster than naive future attention).
Everything is inference-only: the masks are applied to released LLaVA checkpoints with no additional training.
📄 Paper: Rethinking Causal Mask Attention for Vision-Language Inference, ICLR 2026 — OpenReview
We focus on the common VLM layout where m visual tokens precede n text tokens: X = [x^v_1 … x^v_m ; x^t_1 … x^t_n]. Let V denote visual positions and T text positions. For a query at position i, the standard causal mask Mᶜ allows attention only to j ≤ i. The future-aware variants relax this only when the query is a visual token (i ∈ V):
| Mask | Flag (`--attention_mask_mode`) | What a visual query may additionally see | Best for |
|---|---|---|---|
| Mᶜ: Causal (baseline) | `original` | nothing beyond the past | - |
| Mᶠ: Future-Aware Full | `future` | all future tokens (visual and text) | Temporal multi-image reasoning |
| Mᵛ²ᵛ: Visual→Visual | `future_v2v` | future visual tokens only | Visual-relation tasks (change captioning, relation) |
| Mᵛ²ᵗ: Visual→Textual | `future_v2t` | future text tokens only | Text-rich image QA (OCR-VQA, TextVQA) |
Text queries (i ∈ T) always keep the standard causal mask, so the autoregressive contract for generation is never violated.
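The rules in the table can be sketched as a boolean "allowed" matrix. This is a minimal, dependency-free illustration and not the code in `modify_llama.py`: the function name `build_mask` and the list-of-bools layout are assumptions made for the example, and only the four mode strings come from the `--attention_mask_mode` flag.

```python
def build_mask(is_visual, mode="original"):
    """Return allowed[i][j] == True when query i may attend to key j.

    is_visual: list of bools, True for visual positions (V), False for text (T).
    mode: one of "original", "future", "future_v2v", "future_v2t".
    NOTE: hypothetical helper for illustration only.
    """
    if mode not in {"original", "future", "future_v2v", "future_v2t"}:
        raise ValueError(f"unknown mode: {mode}")
    n = len(is_visual)
    # Standard causal base: every query sees itself and its past (j <= i).
    allowed = [[j <= i for j in range(n)] for i in range(n)]
    if mode == "original":
        return allowed
    for i in range(n):
        if not is_visual[i]:
            continue  # text queries always keep the causal mask
        for j in range(i + 1, n):
            if mode == "future":
                allowed[i][j] = True              # any future token
            elif mode == "future_v2v":
                allowed[i][j] = is_visual[j]      # future visual tokens only
            else:  # "future_v2t"
                allowed[i][j] = not is_visual[j]  # future text tokens only
    return allowed


# Toy sequence: m = 2 visual tokens followed by n = 2 text tokens.
layout = [True, True, False, False]
for mode in ("original", "future", "future_v2v", "future_v2t"):
    print(mode)
    for row in build_mask(layout, mode):
        print("  ", "".join("1" if a else "." for a in row))
```

In every mode the rows belonging to text queries are identical to the causal rows, which is the property that keeps generation autoregressive.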
Attention designs. (a) Causal mask. (b) Future-Aware Full Mᶠ. (c) Visual→Visual Mᵛ²ᵛ. (d) Visual→Textual Mᵛ²ᵗ. (e) Light Future-Aware (future scores pooled into a past prefix).
Granting visual tokens full future visibility helps accuracy but adds latency in the decoding phase. Because modern VLMs separate prefill and decode, we push that cost entirely into prefill: a 1-D kernel pooling aggregates each query's valid future attention into a compact summary, which is merged back into the initial (prefix / attention-sink) past positions. Decoding then runs with the original causal mask — no extra future attention, no KV-cache blow-up.
Enable it with --compress_future kernel. Remarkably, merging future context into even a single well-placed prefix token recovers most of the benefit.
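To make the merge idea concrete, here is a toy sketch of one possible reading: for a single query, the scores on its permitted future positions are average-pooled down to `prefix_size` values and added onto the first `prefix_size` past positions, after which only the causal part of the row is kept. The function names, the choice of average pooling, and the additive merge are all assumptions for illustration; the actual kernel and merge rule live in `modify_llama.py` and may differ.

```python
def pool_1d(values, out_len):
    """Average-pool a non-empty list into out_len contiguous, near-equal bins."""
    n = len(values)
    pooled = []
    for k in range(out_len):
        lo = (k * n) // out_len
        hi = max(lo + 1, ((k + 1) * n) // out_len)  # each bin holds >= 1 value
        window = values[lo:hi]
        pooled.append(sum(window) / len(window))
    return pooled


def merge_future_into_prefix(scores, i, future_allowed, prefix_size=1):
    """Return the causal part of row i with pooled future scores folded in.

    scores: attention scores of query i over all key positions.
    future_allowed: future_allowed[j] is True if query i may see key j > i.
    ASSUMPTION: the summary is merged by addition onto positions 0..prefix_size-1.
    """
    causal = list(scores[: i + 1])
    future = [scores[j] for j in range(i + 1, len(scores)) if future_allowed[j]]
    p = min(prefix_size, len(causal))
    if not future or p == 0:
        return causal  # nothing to merge
    for k, summary in enumerate(pool_1d(future, p)):
        causal[k] += summary
    return causal


row = [4.0, 1.0, 2.0, 4.0, 6.0]          # query i = 1 in a 5-token sequence
visible = [False, False, True, True, True]
print(merge_future_into_prefix(row, 1, visible, prefix_size=1))  # [8.0, 1.0]
```

The returned row has the same length as a causal row, which is why decoding can proceed with the original mask and an unchanged KV cache.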
FutureMask/
├── generate.py # Inference entry point (accelerate + DeepSpeed)
├── evaluate.py # Per-task scoring of predictions
├── score.py / eval_score.py # Aggregate metrics across tasks → tables
├── get_results.py # Parse logs/ into a grouped results CSV
├── data.py / utils.py # MileBench dataset + helpers
├── configs/ # Model + accelerate configs (7B / 13B / v1.6 / InternVL / Yi …)
├── scripts/
│ ├── llava-v1.5-7b/ # One script per mask variant (7B)
│ └── llava-v1.5-13b/ # One script per mask variant (13B)
├── workers/ # Model workers (LLaVA / InternVL / Yi)
├── latency/ # Decoding-latency benchmark (speed.py)
├── logs/ # Per-task evaluation logs (committed)
├── results/ # Aggregated metric tables (CSV, committed)
├── figs/ # Figures used in this README
└── LLaVA-mix_merge_v1/llava/ # Patched LLaVA with the future-aware attention
    └── model/
        ├── language_model/llava_llama.py     # mask/kv-mode dispatch
        └── kv_token_merge/modify_llama.py    # future-aware masks + kernel merge
Tested with Python 3.10–3.12, PyTorch 2.4, CUDA 12.x and 🤗 Transformers 4.37, on NVIDIA GPUs with FlashAttention.
conda create -n futuremask python=3.10 -y
conda activate futuremask
pip install -r requirements.txt

Alternatively, create the environment from the provided file:

conda env create -f csp_env.yaml   # creates an env named `csp`
conda activate csp

Note on flash-attn. `flash_attn` and `flash_attention_softmax_n` may need to be installed last (after `torch`), e.g. `pip install flash-attn==2.3.4 --no-build-isolation`.
We evaluate on MileBench, a long-context multimodal benchmark that consolidates 28 tasks (VQAv2, GQA, TextVQA, OCRBench, Spot-the-Diff, and more) under a unified protocol.
- Download the MileBench data (images + annotations) from the official release.
- Place it so each task lives at `<DATA_DIR>/<Task>/`, where `<DATA_DIR>` matches the `DATA_DIR` variable in the run scripts (default `MLBench`):
MLBench/
├── ActionPrediction/
│ ├── ActionPrediction.json
│ └── images/
├── OCR-VQA/
│ ├── OCR-VQA.json
│ └── images/
└── ...
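Before launching a long run, it can help to confirm that the data directory follows the layout above. The check below is a small sketch that relies only on what the tree shows (one `<Task>.json` plus an `images/` folder per task); the helper name `check_milebench_layout` is hypothetical and is not part of the repository.

```python
from pathlib import Path


def check_milebench_layout(data_dir):
    """Return a list of problems; an empty list means the layout looks right."""
    root = Path(data_dir)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    for task_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # Each task is expected to carry <Task>/<Task>.json and <Task>/images/.
        annotation = task_dir / f"{task_dir.name}.json"
        if not annotation.is_file():
            problems.append(f"missing annotation file: {annotation}")
        if not (task_dir / "images").is_dir():
            problems.append(f"missing images directory: {task_dir / 'images'}")
    return problems


for problem in check_milebench_layout("MLBench"):
    print(problem)
```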
Model checkpoints (liuhaotian/llava-v1.5-7b, liuhaotian/llava-v1.5-13b) are pulled from the 🤗 Hub automatically; the paths are set in configs/model_configs.yaml (and the 13b_* configs).
Each mask variant has a ready-to-run script. The general flow is generate.py → evaluate.py → score.py, all wired together inside the scripts.
conda activate csp
# Baseline: standard causal mask (Mᶜ)
bash scripts/llava-v1.5-7b/new_eval_original_7b.sh
# Future-Aware Full mask (Mᶠ)
bash scripts/llava-v1.5-7b/new_eval_future_7b.sh
# Light variant: Mᶠ + prefix merge (fast decoding)
bash scripts/llava-v1.5-7b/new_eval_future_kernel_7b.sh

| Script (`scripts/llava-v1.5-7b/…`) | Mask | `--attention_mask_mode` | `--compress_future` |
|---|---|---|---|
| `new_eval_original_7b.sh` | Mᶜ | `original` | `original` |
| `new_eval_future_7b.sh` | Mᶠ | `future` | `original` |
| `new_eval_future_v2v_7b.sh` | Mᵛ²ᵛ | `future_v2v` | `original` |
| `new_eval_future_v2t_7b.sh` | Mᵛ²ᵗ | `future_v2t` | `original` |
| `new_eval_future_kernel_7b.sh` | Mᶠ + merge | `future` | `kernel` |
| `new_eval_future_v2v_kernel_7b.sh` | Mᵛ²ᵛ + merge | `future_v2v` | `kernel` |
| `new_eval_future_v2t_kernel_7b.sh` | Mᵛ²ᵗ + merge | `future_v2t` | `kernel` |
The scripts/llava-v1.5-13b/ directory mirrors these for the 13B model. Before running, adjust CUDA_VISIBLE_DEVICES, the conda activation line, and DATA_DIR at the top of each script to match your machine.
accelerate launch --config_file ./configs/accelerate_configs.yaml generate.py \
--data_dir MLBench \
--dataset_name OCR-VQA \
--model_name my_run_future \
--model_configs configs/model_configs.yaml \
--kv_mode new \
--attention_mask_mode future \
--compress_future original \
--overwrite
python evaluate.py --data-dir MLBench --dataset OCR-VQA --result-dir outputs/my_run_future

Predictions are written to outputs/<model_name>/<task>/pred.json and scored into eval.json.
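When sweeping several variants over one task, the manual `generate.py` invocation can be assembled programmatically. The sketch below only rearranges the flags shown above; the `VARIANTS` keys and the function `build_generate_cmd` are hypothetical names chosen for the example, with the flag pairs copied from the script table.

```python
import shlex

# (attention_mask_mode, compress_future) per variant, as listed in the script table.
# The dictionary keys are illustrative labels, not names used by the repository.
VARIANTS = {
    "original": ("original", "original"),
    "future": ("future", "original"),
    "future_v2v": ("future_v2v", "original"),
    "future_v2t": ("future_v2t", "original"),
    "future_kernel": ("future", "kernel"),
    "future_v2v_kernel": ("future_v2v", "kernel"),
    "future_v2t_kernel": ("future_v2t", "kernel"),
}


def build_generate_cmd(variant, dataset_name, model_name, data_dir="MLBench"):
    """Return the argument list for one generate.py run."""
    mask_mode, compress = VARIANTS[variant]
    return [
        "accelerate", "launch",
        "--config_file", "./configs/accelerate_configs.yaml",
        "generate.py",
        "--data_dir", data_dir,
        "--dataset_name", dataset_name,
        "--model_name", model_name,
        "--model_configs", "configs/model_configs.yaml",
        "--kv_mode", "new",
        "--attention_mask_mode", mask_mode,
        "--compress_future", compress,
        "--overwrite",
    ]


cmd = build_generate_cmd("future_v2v_kernel", "OCR-VQA", "my_run_v2v_kernel")
print(shlex.join(cmd))
```

The returned list can be handed to `subprocess.run(cmd, check=True)` from the repository root.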
The tables below are Table 4 of the paper — performance across vision–language tasks for the baseline causal mask (Mᶜ), the three future-aware masks (Mᵛ²ᵗ, Mᵛ²ᵛ, Mᶠ), and their lightweight merge variants (prefix size = 1). Metric is Accuracy, except CLEVR† and SpotDiff† which are Rouge-L. Bold = best in column. The repo's own re-runs are also committed under logs/ / results/ and can be regenerated with get_results.py.
Columns: ActionL = ActionLocalization · ActionP = ActionPrediction · ActionS = ActionSequence · CLEVR = CLEVR-Change · Order = CharacterOrder · DocVQA · Nav = EgocentricNavigation · Moving = MovingAttribute · OCRVQA = OCR-VQA · Object = ObjectExistence · SpotDiff = Spot-the-Diff · State = StateChange · TQA.
LLaVA-1.5-7B

| Mask | ActionL | ActionP | ActionS | CLEVR† | Order | DocVQA | Nav | Moving | OCRVQA | Object | SpotDiff† | State | TQA |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Mᶜ | 0.230 | **0.515** | 0.445 | 0.166 | 0.245 | **0.450** | 0.310 | 0.490 | 0.225 | 0.485 | 0.162 | 0.300 | 0.320 |
| Mᵛ²ᵗ | 0.250 | 0.495 | 0.435 | 0.181 | 0.250 | 0.445 | 0.320 | 0.490 | **0.230** | 0.495 | 0.165 | 0.305 | 0.385 |
| Mᵛ²ᵛ | **0.255** | **0.515** | 0.440 | 0.177 | 0.250 | 0.430 | **0.325** | **0.515** | 0.220 | 0.500 | 0.167 | **0.325** | 0.385 |
| Mᶠ | 0.250 | 0.500 | **0.450** | 0.187 | 0.255 | 0.430 | 0.320 | 0.505 | 0.225 | 0.505 | 0.171 | 0.315 | **0.400** |
| Mᵛ²ᵛ+merge | 0.225 | 0.510 | 0.435 | 0.175 | **0.270** | 0.445 | 0.320 | 0.490 | 0.205 | **0.510** | 0.167 | 0.305 | 0.385 |
| Mᵛ²ᵗ+merge | 0.245 | 0.495 | 0.435 | 0.180 | 0.250 | 0.445 | 0.320 | 0.490 | **0.230** | 0.495 | 0.164 | 0.305 | 0.375 |
| Mᶠ+merge | 0.245 | 0.500 | **0.450** | **0.188** | 0.265 | 0.420 | 0.320 | 0.505 | 0.225 | 0.490 | **0.173** | 0.320 | 0.375 |
LLaVA-1.5-13B

| Mask | ActionL | ActionP | ActionS | CLEVR† | Order | DocVQA | Nav | Moving | OCRVQA | Object | SpotDiff† | State | TQA |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Mᶜ | 0.230 | 0.450 | 0.450 | 0.157 | 0.435 | 0.455 | 0.260 | 0.500 | **0.455** | 0.470 | **0.158** | 0.360 | 0.495 |
| Mᵛ²ᵗ | 0.245 | 0.455 | 0.495 | 0.156 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.120 | 0.360 | 0.515 |
| Mᵛ²ᵛ | 0.225 | 0.455 | **0.500** | 0.157 | 0.435 | **0.465** | 0.265 | 0.500 | 0.415 | 0.475 | 0.143 | 0.360 | **0.525** |
| Mᶠ | 0.245 | **0.460** | 0.495 | 0.156 | 0.440 | 0.460 | 0.260 | **0.510** | 0.415 | 0.475 | 0.155 | **0.370** | 0.510 |
| Mᵛ²ᵛ+merge | 0.245 | 0.455 | 0.495 | 0.155 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.141 | 0.360 | 0.515 |
| Mᵛ²ᵗ+merge | 0.245 | 0.455 | 0.495 | 0.155 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.119 | 0.360 | 0.510 |
| Mᶠ+merge | **0.255** | 0.450 | 0.495 | **0.158** | **0.445** | **0.465** | 0.260 | 0.505 | 0.415 | **0.480** | 0.115 | 0.355 | **0.525** |
Takeaways. The best mask is task-dependent, which is the paper's central message. On 7B, Mᵛ²ᵛ (visual→visual) tops several temporal / spatial tasks (Moving 0.490 → 0.515, Nav 0.310 → 0.325, State 0.300 → 0.325), while Mᶠ (full future) gives the strongest TQA (0.320 → 0.400) and, among the non-merge masks, the best CLEVR-Change (0.187) and Spot-the-Diff (0.171). Crucially, the Light +merge variants match or beat their full counterparts on many columns while decoding 1.6–3.1× faster (see the latency table below); for example, 7B Mᶠ+merge attains the table-best CLEVR-Change (0.188) and Spot-the-Diff (0.173). Gains are smaller and mixed on text-dominant columns at 13B (the causal baseline stays best on OCRVQA and Spot-the-Diff), consistent with the finding that strict causal masking still matters for text alignment.
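The per-column winners quoted in the takeaways can be re-derived from the 7B rows with a few lines of Python. The numbers are copied from the table above; the helper `best_per_column` and the un-daggered column names are choices made for this sketch only.

```python
COLUMNS = ["ActionL", "ActionP", "ActionS", "CLEVR", "Order", "DocVQA", "Nav",
           "Moving", "OCRVQA", "Object", "SpotDiff", "State", "TQA"]

# LLaVA-1.5-7B rows, in the column order above (CLEVR and SpotDiff are Rouge-L).
ROWS_7B = {
    "Mᶜ":         [0.230, 0.515, 0.445, 0.166, 0.245, 0.450, 0.310, 0.490, 0.225, 0.485, 0.162, 0.300, 0.320],
    "Mᵛ²ᵗ":       [0.250, 0.495, 0.435, 0.181, 0.250, 0.445, 0.320, 0.490, 0.230, 0.495, 0.165, 0.305, 0.385],
    "Mᵛ²ᵛ":       [0.255, 0.515, 0.440, 0.177, 0.250, 0.430, 0.325, 0.515, 0.220, 0.500, 0.167, 0.325, 0.385],
    "Mᶠ":         [0.250, 0.500, 0.450, 0.187, 0.255, 0.430, 0.320, 0.505, 0.225, 0.505, 0.171, 0.315, 0.400],
    "Mᵛ²ᵛ+merge": [0.225, 0.510, 0.435, 0.175, 0.270, 0.445, 0.320, 0.490, 0.205, 0.510, 0.167, 0.305, 0.385],
    "Mᵛ²ᵗ+merge": [0.245, 0.495, 0.435, 0.180, 0.250, 0.445, 0.320, 0.490, 0.230, 0.495, 0.164, 0.305, 0.375],
    "Mᶠ+merge":   [0.245, 0.500, 0.450, 0.188, 0.265, 0.420, 0.320, 0.505, 0.225, 0.490, 0.173, 0.320, 0.375],
}


def best_per_column(rows, columns):
    """Map each column name to (best score, [masks that reach it])."""
    best = {}
    for c, name in enumerate(columns):
        top = max(values[c] for values in rows.values())
        best[name] = (top, [mask for mask, values in rows.items() if values[c] == top])
    return best


for name, (score, masks) in best_per_column(ROWS_7B, COLUMNS).items():
    print(f"{name:9s} {score:.3f}  {', '.join(masks)}")
```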
Naively granting future access slows decoding; the Light (+merge) variant pools future context into the prefix during prefill and decodes with the standard causal mask, restoring near-baseline speed (up to ≈3×, Table 6 in the paper).
| Attention | No merge (ms/token) | +merge (ms/token) | Speed-up |
|---|---|---|---|
| Mᶠ | 83.18 | 26.54 | 3.1× |
| Mᵛ²ᵛ | 64.13 | 26.40 | 2.4× |
| Mᵛ²ᵗ | 43.04 | 26.10 | 1.6× |
Reproduce with bash latency/speed.sh (see latency/speed.py).
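The speed-up column is simply the ratio of the two latency columns, as this short check shows. The latencies are copied from the table; the helper name `speedup` is illustrative.

```python
# (no-merge ms/token, +merge ms/token) per mask, copied from the latency table.
LATENCY_MS_PER_TOKEN = {
    "Mᶠ": (83.18, 26.54),
    "Mᵛ²ᵛ": (64.13, 26.40),
    "Mᵛ²ᵗ": (43.04, 26.10),
}


def speedup(no_merge_ms, merge_ms):
    """Ratio of per-token decoding latency without and with the merge."""
    return no_merge_ms / merge_ms


for mask, (no_merge, merged) in LATENCY_MS_PER_TOKEN.items():
    print(f"{mask}: {speedup(no_merge, merged):.1f}x")
```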
Every run appends a one-line-per-evaluation record to logs/<Task>.log. Each line is prefixed by the run/model name and contains the parsed metric dict:
new_eval_future_kernel_7b_v1.5: ActionPrediction: {'Accuracy': 0.50, 'image_quantity_level-Accuracy': {...}}
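A record in this shape can be split back into its parts with the standard library. The sketch assumes, based only on the sample line above, that the fields are separated by `": "` and that the metric dict is a plain Python literal; `parse_log_line` is a hypothetical helper, not a function from `get_results.py`.

```python
import ast


def parse_log_line(line):
    """Split '<run>: <task>: {metrics}' into (run, task, metrics dict)."""
    # maxsplit=2 keeps the ': ' separators inside the dict untouched.
    run_name, task, metrics_text = line.strip().split(": ", 2)
    return run_name, task, ast.literal_eval(metrics_text)


example = "new_eval_future_kernel_7b_v1.5: ActionPrediction: {'Accuracy': 0.50}"
run, task, metrics = parse_log_line(example)
print(run, task, metrics["Accuracy"])
```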
The committed logs cover all mask variants for both LLaVA-1.5-7B and 13B across the 28 MileBench tasks. To turn them into a tidy table:
# Aggregate logs/ → results/grouped_log_metrics_table.csv (grouped by model size & mask)
python get_results.py
# Or aggregate the per-run eval.json files under outputs/ into an HTML/markdown table
python eval_score.py
`outputs/` (raw per-sample predictions) is git-ignored because it is large and fully regenerable; `logs/` and `results/` (the scored metrics) are committed so the numbers above are reproducible without re-running inference.
If you want to read or extend the core implementation:
- `LLaVA-mix_merge_v1/llava/model/language_model/llava_llama.py`: dispatches on `kv_mode` / `attention_mask_mode` / `compress_future` and swaps in the patched decoder layer.
- `LLaVA-mix_merge_v1/llava/model/kv_token_merge/modify_llama.py`: builds the future-aware causal masks (Mᶠ, Mᵛ²ᵛ, Mᵛ²ᵗ) and the 1-D kernel-pooling merge into prefix positions.
- `generate.py`: exposes the `--attention_mask_mode` and `--compress_future` flags that select the mask at inference time.
If you find this work useful, please cite:
@inproceedings{pei2026rethinking,
title = {Rethinking Causal Mask Attention for Vision-Language Inference},
author = {Pei, Xiaohuan and Huang, Tao and Ma, Yanxiang and Xu, Chang},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2026},
url = {https://openreview.net/forum?id=DuDFytFC5Z}
}

This codebase builds on the excellent open-source work of MileBench (long-context multimodal benchmark), LLaVA (the base VLM and model code), and LOOK-M (KV-cache evaluation harness, MIT). We thank the authors for releasing their code and data.
Released under the MIT License. Portions derive from MIT-licensed upstream projects (LLaVA, LOOK-M); their original notices are retained.
