Add return_cache option to TransformerBridge.generate#1337
Conversation
…ge.generate generate(return_cache=True) now also returns an ActivationCache for the full prompt + generated sequence, identical to run_with_cache(output), via one clean recompute forward over the output. Adds names_filter and device passthroughs to scope and offload the cache. Supported for single-sequence decoder-only text generation; encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise a clear NotImplementedError pointing to run_with_cache. Device offload moves cache_dict directly to avoid ActivationCache.to's spurious move_model DeprecationWarning.
|
Heads-up on a small follow-up, not a blocker for this PR. The Once #1336 is fixed, this can be simplified to a I'd recommend keeping this as a self-contained version and switching to the passthrough in a small follow-up PR once #1336 lands. But if you prefer to hold this PR until #1336 is in and use the passthrough here directly, let me know @jlarson4 |
|
I agree with your assessment here, and think it is fine to merge this as-is with the temporary solution. Typically, I'd ask you to add a note to #1336 that it should update this as a side-effect of the solution, but since you're the one handling that issue, I'll trust that you take care of it. Thank you for the thorough investigation of both this issue and the new one you discovered! Great work |
Description
Adds an opt-in
return_cacheflag toTransformerBridge.generate(). Whenreturn_cache=True,generatereturns(output, cache)wherecacheis a standardActivationCacheover the full prompt+generated sequence, identical torun_with_cache(output). This resolves the gap in #697, whererun_with_cacheonly covers the prompt andgeneratereturns no activations. Anames_filterargument lets callers scope the cache, and adeviceargument offloads the returned cache to another device (e.g. CPU); the cache over prompt+max_new_tokens can be large, so the docstring notes the memory cost.Semantics are "recompute one clean forward over the generated sequence," so the cache is consistent with the rest of TransformerLens, includes attention patterns and all hook points, and avoids the cached-eager-attention path behind #1322. For a causal LM this is numerically identical to capturing during generation (verified), without the ragged per-step shapes. This PR covers single-sequence decoder-only text; encoder-decoder / SSM / multimodal / batched / inputs_embeds raise a clear error pointing to the
run_with_cache-on-output workaround. Capturing during generation (for active-hook/steering scenarios) remains available viawith model.hooks(...)aroundgenerateand can be added later as an explicit opt-in.Fixes #697
Type of change
Checklist: