[utils] Improve cache utils to support layer-based caches #545

mhs4670go merged 3 commits into Samsung:main
Let me write down some explanation of why this pytree registration exists.

This PR registers several HuggingFace KV cache classes as PyTorch pytrees so that they can be handled correctly by PyTorch's flatten/unflatten machinery.

**1. What a pytree is**

A pytree is a nested Python structure that can be flattened into a flat list of tensors and later reconstructed. Example:

```python
x = {
    "a": torch.tensor([1, 2]),
    "b": [torch.tensor([3]), torch.tensor([4])]
}
```

PyTorch internally converts this into a flat list of leaf tensors plus a `TreeSpec` that records the structure. Later, the original structure can be reconstructed from the flattened tensors + spec. This mechanism is used heavily by tracing and export systems (e.g. `torch.export`).
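The flatten/unflatten round trip described above can be sketched with `torch.utils._pytree` (a sketch of the mechanism, not the exact code in this PR):

```python
# Flatten/unflatten round trip using torch.utils._pytree, the (private
# but widely used) module PyTorch relies on internally.
import torch
import torch.utils._pytree as pytree

x = {
    "a": torch.tensor([1, 2]),
    "b": [torch.tensor([3]), torch.tensor([4])],
}

# Flatten: nested structure -> flat list of leaf tensors + a TreeSpec.
leaves, spec = pytree.tree_flatten(x)
assert len(leaves) == 3  # three leaf tensors in total

# Unflatten: leaves + spec -> the original nested structure.
y = pytree.tree_unflatten(leaves, spec)
assert torch.equal(y["a"], x["a"])
```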
These systems operate on flat tensor lists, so nested Python objects must be flattened first.

**2. Why this is needed for Transformers KV cache**

Transformer models often pass a KV cache object through the forward pass. However, PyTorch only knows how to flatten built-in containers (dicts, lists, tuples); it does not know how to flatten custom classes such as the HuggingFace KV cache classes mentioned above. Without registration, exporting such a model will fail with flattening-related errors.

**3. What**
Let's improve its coverage.
TICO-DCO-1.0-Signed-off-by: Dayoung Lee dayoung.lee@samsung.com
/cc
@parjong Some Qwen-3 VLM export issues will now be covered.
@llFreetimell Llama export issues too.
NOTE: But this doesn't resolve the vmap issue in `lazy_load_decomposition`.
Related to #417