TerryPei/FutureMask

Rethinking Causal Mask Attention for Vision–Language Inference

A family of future-aware causal masks that let visual tokens preview future context — improving multimodal reasoning while preserving autoregressive decoding.

Xiaohuan Pei¹, Tao Huang², Yanxiang Ma¹, Chang Xu¹†
¹ The University of Sydney · ² Shanghai Jiao Tong University · † Corresponding author

Figure: Overview of future-aware causal attention for vision–language inference.

Overview

Causal attention is the default mechanism in autoregressive Vision–Language Models (VLMs): visual and textual tokens are concatenated into one sequence and a strict left-to-right mask blocks every token from attending to its future. This mask is inherited unchanged from text-only LLMs, where it is well justified — predicting the next word must not peek ahead.

But visual tokens are not sequential. An image is processed holistically; its regions carry spatial and temporal relationships that do not respect a left-to-right order. Strictly masking the future for visual queries is therefore overly rigid and discards semantic cues that often live in later tokens (a later frame, a label on the right side of a diagram, a future navigation goal).

This repository studies that mismatch and provides a drop-in, fine-tuning-free answer:

  • 🔍 An empirical study of how different causal-masking strategies affect vision–language inference across 15 multimodal task types (MileBench).
  • 🎭 A family of future-aware causal masks that selectively relax future masking for visual queries only, while preserving causal structure everywhere else.
  • A lightweight "merge" mechanism that pools future visual context into past prefix positions during prefill — keeping the gains of future access while restoring standard causal-decoding speed (up to ≈3× faster than naive future attention).

Everything is inference-only: the masks are applied to released LLaVA checkpoints with no additional training.

📄 Paper: Rethinking Causal Mask Attention for Vision-Language Inference, ICLR 2026 — OpenReview


The Future-Aware Mask Family

We focus on the common VLM layout where m visual tokens precede n text tokens: X = [x^v_1 … x^v_m ; x^t_1 … x^t_n]. Let V denote visual positions and T text positions. For a query at position i, the standard causal mask Mᶜ allows attention only to j ≤ i. The future-aware variants relax this only when the query is a visual token (i ∈ V):

| Mask | Flag (--attention_mask_mode) | What a visual query may additionally see | Best for |
|---|---|---|---|
| Mᶜ (Causal, baseline) | original | nothing beyond the past | |
| Mᶠ (Future-Aware Full) | future | all future tokens (visual and text) | Temporal multi-image reasoning |
| Mᵛ²ᵛ (Visual→Visual) | future_v2v | future visual tokens only | Visual-relation tasks (change captioning, relation) |
| Mᵛ²ᵗ (Visual→Textual) | future_v2t | future text tokens only | Text-rich image QA (OCR-VQA, TextVQA) |

Text queries (i ∈ T) always keep the standard causal mask, so the autoregressive contract for generation is never violated.
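
The four mask rules can be written out as boolean matrices. The following is a minimal sketch in plain Python, assuming the layout above (m visual tokens followed by n text tokens); `build_mask` and `show` are hypothetical helpers written for illustration and are not the repository's implementation, which lives in modify_llama.py. The mode names reuse the --attention_mask_mode values from the table.

```python
# Illustrative sketch only: `build_mask` and `show` are hypothetical helpers,
# not code from the repository. Mode names mirror --attention_mask_mode values.

def build_mask(m, n, mode="original"):
    """Return an (m+n) x (m+n) boolean matrix.

    allowed[i][j] is True when query i may attend to key j.
    Positions 0..m-1 are visual tokens, positions m..m+n-1 are text tokens.
    """
    if mode not in {"original", "future", "future_v2v", "future_v2t"}:
        raise ValueError(f"unknown mode: {mode}")
    size = m + n
    # Start from the standard causal mask: j <= i.
    allowed = [[j <= i for j in range(size)] for i in range(size)]
    # Relax future masking for visual queries only (i < m).
    for i in range(m):
        for j in range(i + 1, size):
            key_is_visual = j < m
            if mode == "future":
                allowed[i][j] = True
            elif mode == "future_v2v":
                allowed[i][j] = key_is_visual
            elif mode == "future_v2t":
                allowed[i][j] = not key_is_visual
    return allowed


def show(mask):
    """Render a mask as rows of 1 (visible) and 0 (masked)."""
    return ["".join("1" if v else "0" for v in row) for row in mask]


if __name__ == "__main__":
    for mode in ("original", "future", "future_v2v", "future_v2t"):
        print(mode)
        print("\n".join(show(build_mask(2, 2, mode))))
```

With m = 2 and n = 2, the last two rows (the text queries) are identical in every mode, which is the property that keeps generation autoregressive.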

Figure: Attention designs. (a) Causal mask. (b) Future-Aware Full Mᶠ. (c) Visual→Visual Mᵛ²ᵛ. (d) Visual→Textual Mᵛ²ᵗ. (e) Light Future-Aware (future scores pooled into a past prefix).

⚡ Light Future-Aware Attention (the +merge variant)

Granting visual tokens full future visibility helps accuracy but adds latency in the decoding phase. Because modern VLMs separate prefill and decode, we push that cost entirely into prefill: a 1-D kernel-pooling step aggregates each query's valid future attention into a compact summary, which is merged back into the initial (prefix / attention-sink) past positions. Decoding then runs with the original causal mask, with no extra future attention and no KV-cache blow-up.

Enable it with --compress_future kernel. Remarkably, merging future context into even a single well-placed prefix token recovers most of the benefit.
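
To make the merge idea concrete, here is one possible reading of it for a single query row. This is a sketch under stated assumptions, not the repository's kernel: `merge_future_into_prefix` is a hypothetical helper, it operates on already-normalized attention weights, and it uses sum pooling over equal-width windows so that the row's total mass is unchanged. The actual pooling rule in modify_llama.py may differ.

```python
# Illustrative sketch only. `merge_future_into_prefix` is a hypothetical helper.
# Assumptions (not taken from the repository): weights are post-softmax, and
# the 1-D pooling is a sum over contiguous, equal-width windows.

def merge_future_into_prefix(weights, i, prefix_size=1):
    """Fold query i's future attention into its first `prefix_size` past positions.

    weights: attention weights of query i over all positions 0..L-1.
    Returns a causal row of length i + 1 (positions j <= i only).
    """
    if prefix_size < 1:
        raise ValueError("prefix_size must be at least 1")
    past = list(weights[: i + 1])
    future = list(weights[i + 1:])
    if not future:
        return past  # nothing to merge for the last position
    # Never use more prefix slots than there are past or future positions.
    slots = min(prefix_size, len(past), len(future))
    width = -(-len(future) // slots)  # ceiling division: window size
    for s in range(slots):
        # Pool one window of the future span and add it to prefix slot s.
        past[s] += sum(future[s * width:(s + 1) * width])
    return past


if __name__ == "__main__":
    row = [0.1, 0.2, 0.3, 0.4]  # weights of query 1 over four positions
    print(merge_future_into_prefix(row, i=1, prefix_size=1))
```

After this step the row only covers positions j ≤ i, so a standard causal decode can consume it; the future information survives only as extra weight on the prefix slots.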


Repository Structure

FutureMask/
├── generate.py                 # Inference entry point (accelerate + DeepSpeed)
├── evaluate.py                 # Per-task scoring of predictions
├── score.py / eval_score.py    # Aggregate metrics across tasks → tables
├── get_results.py              # Parse logs/ into a grouped results CSV
├── data.py / utils.py          # MileBench dataset + helpers
├── configs/                    # Model + accelerate configs (7B / 13B / v1.6 / InternVL / Yi …)
├── scripts/
│   ├── llava-v1.5-7b/          # One script per mask variant (7B)
│   └── llava-v1.5-13b/         # One script per mask variant (13B)
├── workers/                    # Model workers (LLaVA / InternVL / Yi)
├── latency/                    # Decoding-latency benchmark (speed.py)
├── logs/                       # Per-task evaluation logs (committed)
├── results/                    # Aggregated metric tables (CSV, committed)
├── figs/                       # Figures used in this README
└── LLaVA-mix_merge_v1/llava/   # Patched LLaVA with the future-aware attention
    └── model/
        ├── language_model/llava_llama.py          # mask/kv-mode dispatch
        └── kv_token_merge/modify_llama.py         # future-aware masks + kernel merge

Environment Setup

Tested with Python 3.10–3.12, PyTorch 2.4, CUDA 12.x and 🤗 Transformers 4.37, on NVIDIA GPUs with FlashAttention.

Option A — pip

conda create -n futuremask python=3.10 -y
conda activate futuremask
pip install -r requirements.txt

Option B — conda (exact pinned environment)

conda env create -f csp_env.yaml   # creates an env named `csp`
conda activate csp

Note on flash-attn. flash_attn and flash_attention_softmax_n may need to be installed last (after torch), e.g. pip install flash-attn==2.3.4 --no-build-isolation.


Data Preparation

We evaluate on MileBench, a long-context multimodal benchmark that consolidates 28 tasks (VQAv2, GQA, TextVQA, OCRBench, Spot-the-Diff, and more) under a unified protocol.

  1. Download the MileBench data (images + annotations) from the official release.
  2. Place it so each task lives at <DATA_DIR>/<Task>/, where <DATA_DIR> matches the DATA_DIR variable in the run scripts (default MLBench):
MLBench/
├── ActionPrediction/
│   ├── ActionPrediction.json
│   └── images/
├── OCR-VQA/
│   ├── OCR-VQA.json
│   └── images/
└── ...
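
Before launching a long run, it can help to confirm that every task folder follows the layout above. The snippet below is a small sketch, assuming only what the tree shows (a `<Task>.json` file and an `images/` directory per task); `find_incomplete_tasks` is a hypothetical helper and is not part of the repository.

```python
# Illustrative sketch only; `find_incomplete_tasks` is a hypothetical helper.
from pathlib import Path


def find_incomplete_tasks(data_dir):
    """Return {task_name: [missing items]} for folders lacking <Task>.json or images/."""
    problems = {}
    for task_dir in sorted(Path(data_dir).iterdir()):
        if not task_dir.is_dir():
            continue  # ignore stray files at the top level
        missing = []
        if not (task_dir / f"{task_dir.name}.json").is_file():
            missing.append(f"{task_dir.name}.json")
        if not (task_dir / "images").is_dir():
            missing.append("images/")
        if missing:
            problems[task_dir.name] = missing
    return problems


if __name__ == "__main__":
    root = Path("MLBench")  # default DATA_DIR named in this README
    if root.is_dir():
        for task, missing in find_incomplete_tasks(root).items():
            print(f"{task}: missing {', '.join(missing)}")
    else:
        print(f"{root} not found")
```

An empty result means every task directory has both expected entries; it does not validate the contents of the annotation files.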

Model checkpoints (liuhaotian/llava-v1.5-7b, liuhaotian/llava-v1.5-13b) are pulled from the 🤗 Hub automatically; the paths are set in configs/model_configs.yaml (and the 13b_* configs).


Usage

Each mask variant has a ready-to-run script. The general flow is generate.py → evaluate.py → score.py, all wired together inside the scripts.

Quick start (LLaVA-1.5-7B)

conda activate csp

# Baseline: standard causal mask  (Mᶜ)
bash scripts/llava-v1.5-7b/new_eval_original_7b.sh

# Future-Aware Full mask          (Mᶠ)
bash scripts/llava-v1.5-7b/new_eval_future_7b.sh

# Light variant: Mᶠ + prefix merge (fast decoding)
bash scripts/llava-v1.5-7b/new_eval_future_kernel_7b.sh

Script ↔ method map

| Script (scripts/llava-v1.5-7b/…) | Mask | --attention_mask_mode | --compress_future |
|---|---|---|---|
| new_eval_original_7b.sh | Mᶜ | original | original |
| new_eval_future_7b.sh | Mᶠ | future | original |
| new_eval_future_v2v_7b.sh | Mᵛ²ᵛ | future_v2v | original |
| new_eval_future_v2t_7b.sh | Mᵛ²ᵗ | future_v2t | original |
| new_eval_future_kernel_7b.sh | Mᶠ + merge | future | kernel |
| new_eval_future_v2v_kernel_7b.sh | Mᵛ²ᵛ + merge | future_v2v | kernel |
| new_eval_future_v2t_kernel_7b.sh | Mᵛ²ᵗ + merge | future_v2t | kernel |

The scripts/llava-v1.5-13b/ directory mirrors these for the 13B model. Before running, adjust CUDA_VISIBLE_DEVICES, the conda activation line, and DATA_DIR at the top of each script to match your machine.

Run a single task manually

accelerate launch --config_file ./configs/accelerate_configs.yaml generate.py \
    --data_dir MLBench \
    --dataset_name OCR-VQA \
    --model_name my_run_future \
    --model_configs configs/model_configs.yaml \
    --kv_mode new \
    --attention_mask_mode future \
    --compress_future original \
    --overwrite

python evaluate.py --data-dir MLBench --dataset OCR-VQA --result-dir outputs/my_run_future

Predictions are written to outputs/<model_name>/<task>/pred.json and scored into eval.json.
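
When sweeping several tasks or mask variants by hand, the manual command can be assembled programmatically. The sketch below is one way to do that; `build_generate_cmd` is a hypothetical wrapper, and every flag and default it emits is copied from the manual command above. Whether `--kv_mode new` is appropriate for every variant is an assumption carried over from that example.

```python
# Illustrative sketch only; `build_generate_cmd` is a hypothetical wrapper.
# All flags and values are copied from the manual command in this README.
import shlex

MASK_MODES = {"original", "future", "future_v2v", "future_v2t"}
COMPRESS_MODES = {"original", "kernel"}


def build_generate_cmd(dataset, run_name, mask_mode="original",
                       compress="original", data_dir="MLBench"):
    """Return the accelerate launch command for one task as an argument list."""
    if mask_mode not in MASK_MODES:
        raise ValueError(f"unknown --attention_mask_mode: {mask_mode}")
    if compress not in COMPRESS_MODES:
        raise ValueError(f"unknown --compress_future: {compress}")
    return [
        "accelerate", "launch",
        "--config_file", "./configs/accelerate_configs.yaml",
        "generate.py",
        "--data_dir", data_dir,
        "--dataset_name", dataset,
        "--model_name", run_name,
        "--model_configs", "configs/model_configs.yaml",
        "--kv_mode", "new",  # assumption: same value as the example above
        "--attention_mask_mode", mask_mode,
        "--compress_future", compress,
        "--overwrite",
    ]


if __name__ == "__main__":
    cmd = build_generate_cmd("OCR-VQA", "my_run_future_kernel", "future", "kernel")
    print(shlex.join(cmd))  # pass `cmd` to subprocess.run(...) to execute it
```

Returning a list rather than a string keeps the command safe to hand to `subprocess.run` without shell quoting issues.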


Results

The tables below are Table 4 of the paper — performance across vision–language tasks for the baseline causal mask (Mᶜ), the three future-aware masks (Mᵛ²ᵗ, Mᵛ²ᵛ, Mᶠ), and their lightweight merge variants (prefix size = 1). Metric is Accuracy, except CLEVR† and SpotDiff† which are Rouge-L. Bold = best in column. The repo's own re-runs are also committed under logs/ / results/ and can be regenerated with get_results.py.

Columns: ActionL = ActionLocalization · ActionP = ActionPrediction · ActionS = ActionSequence · CLEVR = CLEVR-Change · Order = CharacterOrder · DocVQA · Nav = EgocentricNavigation · Moving = MovingAttribute · OCRVQA = OCR-VQA · Object = ObjectExistence · SpotDiff = Spot-the-Diff · State = StateChange · TQA.

LLaVA-1.5-7B

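| Mask | ActionL | ActionP | ActionS | CLEVR† | Order | DocVQA | Nav | Moving | OCRVQA | Object | SpotDiff† | State | TQA |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Mᶜ | 0.230 | **0.515** | 0.445 | 0.166 | 0.245 | **0.450** | 0.310 | 0.490 | 0.225 | 0.485 | 0.162 | 0.300 | 0.320 |
| Mᵛ²ᵗ | 0.250 | 0.495 | 0.435 | 0.181 | 0.250 | 0.445 | 0.320 | 0.490 | **0.230** | 0.495 | 0.165 | 0.305 | 0.385 |
| Mᵛ²ᵛ | **0.255** | **0.515** | 0.440 | 0.177 | 0.250 | 0.430 | **0.325** | **0.515** | 0.220 | 0.500 | 0.167 | **0.325** | 0.385 |
| Mᶠ | 0.250 | 0.500 | **0.450** | 0.187 | 0.255 | 0.430 | 0.320 | 0.505 | 0.225 | 0.505 | 0.171 | 0.315 | **0.400** |
| Mᵛ²ᵛ+merge | 0.225 | 0.510 | 0.435 | 0.175 | **0.270** | 0.445 | 0.320 | 0.490 | 0.205 | **0.510** | 0.167 | 0.305 | 0.385 |
| Mᵛ²ᵗ+merge | 0.245 | 0.495 | 0.435 | 0.180 | 0.250 | 0.445 | 0.320 | 0.490 | **0.230** | 0.495 | 0.164 | 0.305 | 0.375 |
| Mᶠ+merge | 0.245 | 0.500 | **0.450** | **0.188** | 0.265 | 0.420 | 0.320 | 0.505 | 0.225 | 0.490 | **0.173** | 0.320 | 0.375 |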

LLaVA-1.5-13B

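| Mask | ActionL | ActionP | ActionS | CLEVR† | Order | DocVQA | Nav | Moving | OCRVQA | Object | SpotDiff† | State | TQA |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Mᶜ | 0.230 | 0.450 | 0.450 | 0.157 | 0.435 | 0.455 | 0.260 | 0.500 | **0.455** | 0.470 | **0.158** | 0.360 | 0.495 |
| Mᵛ²ᵗ | 0.245 | 0.455 | 0.495 | 0.156 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.120 | 0.360 | 0.515 |
| Mᵛ²ᵛ | 0.225 | 0.455 | **0.500** | 0.157 | 0.435 | **0.465** | 0.265 | 0.500 | 0.415 | 0.475 | 0.143 | 0.360 | **0.525** |
| Mᶠ | 0.245 | **0.460** | 0.495 | 0.156 | 0.440 | 0.460 | 0.260 | **0.510** | 0.415 | 0.475 | 0.155 | **0.370** | 0.510 |
| Mᵛ²ᵛ+merge | 0.245 | 0.455 | 0.495 | 0.155 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.141 | 0.360 | 0.515 |
| Mᵛ²ᵗ+merge | 0.245 | 0.455 | 0.495 | 0.155 | **0.445** | 0.460 | **0.270** | 0.500 | 0.415 | 0.475 | 0.119 | 0.360 | 0.510 |
| Mᶠ+merge | **0.255** | 0.450 | 0.495 | **0.158** | **0.445** | **0.465** | 0.260 | 0.505 | 0.415 | **0.480** | 0.115 | 0.355 | **0.525** |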

Takeaways. The best mask is task-dependent, which is the paper's central message. On 7B, Mᵛ²ᵛ (visual→visual) tops several temporal / spatial tasks (Moving 0.490 → 0.515, Nav 0.310 → 0.325, State 0.300 → 0.325), while Mᶠ (full future) gives the strongest TQA (0.320 → 0.400) and, among the masks without merge, the best CLEVR-Change (0.187) and Spot-the-Diff (0.171). Crucially, the Light +merge variants match or beat their full counterparts on many columns while decoding 1.6–3.1× faster (see below); for example, 7B Mᶠ+merge attains the table-best CLEVR-Change (0.188) and Spot-the-Diff (0.173). Gains are smaller and mixed on text-dominant columns at 13B (the causal baseline stays best on OCRVQA and Spot-the-Diff), consistent with the finding that strict causal masking still matters for text alignment.

Decoding latency — the +merge payoff

Naively granting future access slows decoding. The Light (+merge) variant pools future context into the prefix during prefill and decodes with the standard causal mask, restoring near-baseline speed: up to ≈3× faster than the corresponding mask without merge (Table 6 in the paper).

Figure: Decoding latency of the future-aware masks vs. their light +merge variants.

| Attention | No merge (ms/token) | +merge (ms/token) | Speed-up |
|---|---|---|---|
| Mᶠ | 83.18 | 26.54 | 3.1× |
| Mᵛ²ᵛ | 64.13 | 26.40 | 2.4× |
| Mᵛ²ᵗ | 43.04 | 26.10 | 1.6× |

Reproduce with bash latency/speed.sh (see latency/speed.py).
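
The speed-up column is simply the ratio of the two latency columns. As a quick arithmetic check, the short sketch below recomputes it from the reported ms/token figures; the dictionary keys reuse the --attention_mask_mode names, and `speedup` is a hypothetical helper.

```python
# Latencies in ms/token, copied from the latency table in this README.
# `speedup` is a hypothetical helper used only for this arithmetic check.
latency = {
    "future": (83.18, 26.54),      # Mᶠ
    "future_v2v": (64.13, 26.40),  # Mᵛ²ᵛ
    "future_v2t": (43.04, 26.10),  # Mᵛ²ᵗ
}


def speedup(no_merge_ms, merge_ms):
    """Ratio of per-token latency without merge to latency with merge."""
    return no_merge_ms / merge_ms


if __name__ == "__main__":
    for mode, (no_merge_ms, merge_ms) in latency.items():
        print(f"{mode}: {speedup(no_merge_ms, merge_ms):.1f}x")
```

All three +merge latencies sit near 26 ms/token, so the size of the speed-up is driven almost entirely by how expensive the corresponding no-merge mask is.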


Logs

Every run appends one line per evaluation to logs/<Task>.log. Each line is prefixed by the run/model name and contains the parsed metric dict:

new_eval_future_kernel_7b_v1.5:  ActionPrediction:  {'Accuracy': 0.50, 'image_quantity_level-Accuracy': {...}}
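
Because the metric dict is printed as a Python literal, a single log line can be read back with the standard library. The following is a minimal sketch, assuming that neither the run name nor the task name contains a colon; `parse_log_line` is a hypothetical helper and is not how get_results.py is implemented. The sample line keeps only the `Accuracy` entry, since the nested entry is elided in the example above.

```python
# Illustrative sketch only; `parse_log_line` is a hypothetical helper.
# Assumption: run and task names contain no ':' characters.
import ast


def parse_log_line(line):
    """Split '<run>:  <task>:  {metric dict}' into (run, task, metrics)."""
    run, task, payload = line.split(":", 2)  # the dict itself may contain ':'
    metrics = ast.literal_eval(payload.strip())  # safe parse of a Python literal
    return run.strip(), task.strip(), metrics


# Shortened from the example line; the elided nested entry is left out.
SAMPLE = "new_eval_future_kernel_7b_v1.5:  ActionPrediction:  {'Accuracy': 0.50}"

if __name__ == "__main__":
    run, task, metrics = parse_log_line(SAMPLE)
    print(run, task, metrics["Accuracy"])
```

`ast.literal_eval` only accepts literals, so a malformed or unexpected line raises an error rather than executing anything.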

The committed logs cover all mask variants for both LLaVA-1.5-7B and 13B across the 28 MileBench tasks. To turn them into a tidy table:

# Aggregate logs/ → results/grouped_log_metrics_table.csv (grouped by model size & mask)
python get_results.py

# Or aggregate the per-run eval.json files under outputs/ into an HTML/markdown table
python eval_score.py

outputs/ (raw per-sample predictions) is git-ignored because it is large and fully regenerable; logs/ and results/ (the scored metrics) are committed so the numbers above are reproducible without re-running inference.


Where the Method Lives

If you want to read or extend the core implementation:

  • LLaVA-mix_merge_v1/llava/model/language_model/llava_llama.py — dispatches on kv_mode / attention_mask_mode / compress_future and swaps in the patched decoder layer.
  • LLaVA-mix_merge_v1/llava/model/kv_token_merge/modify_llama.py — builds the future-aware causal masks (Mᶠ, Mᵛ²ᵛ, Mᵛ²ᵗ) and the 1-D kernel-pooling merge into prefix positions.
  • generate.py — exposes the --attention_mask_mode and --compress_future flags that select the mask at inference time.

Citation

If you find this work useful, please cite:

@inproceedings{pei2026rethinking,
  title     = {Rethinking Causal Mask Attention for Vision-Language Inference},
  author    = {Pei, Xiaohuan and Huang, Tao and Ma, Yanxiang and Xu, Chang},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2026},
  url       = {https://openreview.net/forum?id=DuDFytFC5Z}
}

Acknowledgements

This codebase builds on the excellent open-source work of MileBench (long-context multimodal benchmark), LLaVA (the base VLM and model code), and LOOK-M (KV-cache evaluation harness, MIT). We thank the authors for releasing their code and data.

License

Released under the MIT License. Portions derive from MIT-licensed upstream projects (LLaVA, LOOK-M); their original notices are retained.
