[utils] Improve cache utils to support layer-based caches#545

Merged
mhs4670go merged 3 commits into Samsung:main from dayo09:0310-improve
Mar 13, 2026

Conversation

@dayo09
Contributor

@dayo09 dayo09 commented Mar 10, 2026

Let's improve its coverage.

TICO-DCO-1.0-Signed-off-by: Dayoung Lee dayoung.lee@samsung.com

/cc
@parjong Some Qwen-3 VLM export issue will now be covered.
@llFreetimell Llama export issues too.
NOTE: This does not resolve the vmap and lazy_load_decomposition issues.

Related to #417


NOTE

  • StaticLayer's __init__ implementation differs depending on the torch version.
  • The Static/DynamicCache implementation diverges into 'list-based' and 'layer-based' layouts.
    • See the comment at the top of pytree_utils.py for details.

dayo09 added 3 commits March 10, 2026 15:14
@dayo09 dayo09 marked this pull request as ready for review March 10, 2026 09:35
@dayo09 dayo09 requested a review from mhs4670go March 10, 2026 09:35
@llFreetimell llFreetimell self-requested a review March 11, 2026 04:10
@mhs4670go
Contributor

Let me write down some background on pytrees for the reviewers.

Why this pytree registration exists

This PR registers several HuggingFace KV cache classes as PyTorch pytrees so that they can be handled correctly by torch.export / torch.fx.

1. What a pytree is

A pytree is a nested Python structure that can be flattened into a list of tensors and later reconstructed.

Example:

import torch

x = {
    "a": torch.tensor([1, 2]),
    "b": [torch.tensor([3]), torch.tensor([4])],
}

PyTorch internally converts this into:

Flattened tensors: [tensor([1,2]), tensor([3]), tensor([4])]
Structure spec: {"a": *, "b": [*, *]}

Later the original structure can be reconstructed from the flattened tensors + spec.
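The flatten/reconstruct round trip can be sketched without torch. This is a hedged, pure-Python illustration of the idea only; real code would call torch.utils._pytree.tree_flatten and tree_unflatten, and the tuples here stand in for tensors (tensors are leaves, so they are not flattened further).

```python
def flatten(tree):
    """Flatten a nested dict/list into (leaves, spec)."""
    if isinstance(tree, dict):
        specs, leaves = {}, []
        for k, v in tree.items():
            sub_leaves, sub_spec = flatten(v)
            leaves += sub_leaves
            specs[k] = sub_spec
        return leaves, ("dict", specs)
    if isinstance(tree, list):
        specs, leaves = [], []
        for v in tree:
            sub_leaves, sub_spec = flatten(v)
            leaves += sub_leaves
            specs.append(sub_spec)
        return leaves, ("list", specs)
    # anything else is a leaf (a tensor, in the real setting)
    return [tree], ("leaf", None)

def unflatten(leaves, spec):
    """Rebuild the original structure from flat leaves + spec."""
    it = iter(leaves)
    def build(s):
        kind, sub = s
        if kind == "dict":
            return {k: build(v) for k, v in sub.items()}
        if kind == "list":
            return [build(v) for v in sub]
        return next(it)
    return build(spec)

x = {"a": (1, 2), "b": [(3,), (4,)]}   # tuples stand in for tensors
leaves, spec = flatten(x)
assert leaves == [(1, 2), (3,), (4,)]
assert unflatten(leaves, spec) == x
```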

This mechanism is used heavily by:

  • torch.export
  • torch.fx
  • torch.compile
  • torch.func

These systems operate on flat tensor lists, so nested Python objects must be flattened first.

2. Why this is needed for Transformers KV cache

Transformer models often pass a KV cache object through the forward pass.

Example:

DynamicCache
 └ layers
     ├ DynamicLayer
     │   ├ keys
     │   └ values
     ├ DynamicLayer
     │   ├ keys
     │   └ values

However, torch.export only understands standard container types such as:

  • list
  • tuple
  • dict

It does not know how to flatten custom classes like:

  • DynamicCache
  • StaticCache
  • DynamicLayer
  • StaticLayer
  • EncoderDecoderCache

Without registration, exporting a model will fail with errors such as:

TypeError: Cannot flatten object of type DynamicCache

or

torch.export.Unsupported: Python object not supported

3. What register_pytree_node does

We explicitly tell PyTorch how to flatten and reconstruct these objects.

Example:

pytree.register_pytree_node(
    DynamicLayer,
    flatten_fn,
    unflatten_fn
)

Meaning:

  • flatten_fn → convert the object into tensor children
  • unflatten_fn → reconstruct the object from those tensors

For example:

DynamicLayer
 ├ keys
 └ values

is flattened into

[keys, values]

and later reconstructed back into a DynamicLayer instance.
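A torch-free sketch of the flatten_fn/unflatten_fn pair: FakeLayer is a hypothetical stand-in for transformers' DynamicLayer, and the two functions show the shape that register_pytree_node expects (children plus an optional static context).

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:          # stand-in for transformers' DynamicLayer
    keys: list
    values: list

def flatten_fn(layer):
    # tensor children first, then any static (non-tensor) context; none here
    return [layer.keys, layer.values], None

def unflatten_fn(children, context):
    keys, values = children
    return FakeLayer(keys, values)

layer = FakeLayer(keys=[1, 2], values=[3, 4])
children, ctx = flatten_fn(layer)
assert children == [[1, 2], [3, 4]]
assert unflatten_fn(children, ctx) == layer   # round trip restores the layer
```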

4. Why flatten_with_keys exists

flatten_with_keys preserves attribute names during flattening.

Instead of returning:

(keys, values)

it returns:

("keys", tensor)
("values", tensor)

This helps torch.export maintain stable attribute mappings.
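The keyed variant, sketched with the same hypothetical FakeLayer: instead of bare children, each child is paired with the attribute name it came from. (In recent torch versions this corresponds to the flatten_with_keys_fn argument of register_pytree_node.)

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:          # stand-in for transformers' DynamicLayer
    keys: list
    values: list

def flatten_with_keys_fn(layer):
    # (key, child) pairs instead of bare children, plus the static context
    return [("keys", layer.keys), ("values", layer.values)], None

pairs, _ = flatten_with_keys_fn(FakeLayer([1], [2]))
assert pairs == [("keys", [1]), ("values", [2])]
```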

5. Why fx_pytree.register_pytree_flatten_spec is also needed

torch.export internally relies on torch.fx. Therefore FX must also know how to flatten the same objects. fx_pytree.register_pytree_flatten_spec registers the equivalent flatten rule for FX.

6. Supporting multiple Transformers cache layouts

The HuggingFace cache structure changed across versions.

Older versions (≤ ~4.52):

DynamicCache
 ├ key_cache
 └ value_cache

Newer versions:

DynamicCache
 └ layers
     ├ DynamicLayer
     │   ├ keys
     │   └ values

This PR detects the layout by feature detection rather than version strings:

try:
    from transformers.cache_utils import DynamicLayer

and registers the correct pytree behavior accordingly.
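The detection pattern can be completed as follows; this is a sketch of the idea rather than the PR's exact code, and it degrades cleanly when transformers (or the layer-based layout) is unavailable.

```python
# Pick the registration path from what transformers actually exposes,
# not from a version string.
try:
    from transformers.cache_utils import DynamicLayer  # layer-based layout
    LAYER_BASED = True
except ImportError:
    LAYER_BASED = False  # older list-based layout, or transformers missing

if LAYER_BASED:
    pass  # register flatten/unflatten rules for DynamicLayer et al. here
else:
    pass  # register rules for the key_cache/value_cache layout here
```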

7. Summary

This PR enables the following workflow:

Transformers KV cache objects
        ↓
Registered as PyTorch pytrees
        ↓
torch.export / torch.fx can flatten them
        ↓
Model export succeeds

In short:

This code teaches PyTorch how to treat HuggingFace KV cache classes as pytrees so that torch.export can correctly flatten and reconstruct them during graph export.

Contributor

@llFreetimell llFreetimell left a comment


I tested with

  • transformers 4.57.6
  • torch 2.10.0

and checked it works well :)
LGTM!

@mhs4670go mhs4670go merged commit b610f6e into Samsung:main Mar 13, 2026
7 checks passed