
Prefix caching #57

Draft: IanMagnusson wants to merge 40 commits into main

Conversation

IanMagnusson
Contributor

@IanMagnusson commented Jul 15, 2022

What is this?

Prefix caching for DecoderOnlyRCModel, reusing overlapping prefixes between instances rather than recomputing them.

This is useful in at least two settings (a minimal sketch of the core mechanism follows this list):

  1. When reusing the same set of ICL examples over and over across all test examples in a dataset. When ICL prompts are long and the model is large, the speedup can be nearly 9x.
  2. When scoring multiple-choice examples where the context is the same for each example and only the choice continuation changes. This seems to yield around a 2x speedup for the classification tasks in MetaICL.
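
To make the mechanism concrete, here is a minimal sketch of prefix reuse using HuggingFace past_key_values directly. The model name and prompt strings are illustrative assumptions, and this is not the PR's actual API; DecoderOnlyRCModel wraps this pattern with batching, padding, and a prefix trie.

```python
# Minimal sketch of prefix reuse (illustrative, not the PR's actual API).
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prefix = "Review: great movie. Sentiment: positive\nReview: awful acting. Sentiment:"
continuations = [" positive", " negative"]

with torch.no_grad():
    prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids
    prefix_out = model(prefix_ids, use_cache=True)  # shared prefix computed once
    for cont in continuations:
        cont_ids = tokenizer(cont, return_tensors="pt").input_ids
        # copy the cache so one continuation's keys don't leak into the next
        past = copy.deepcopy(prefix_out.past_key_values)
        out = model(cont_ids, past_key_values=past)
        # out.logits scores the continuation tokens without recomputing the prefix
```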

How do I use it?

Check out experiments/prefix_cache_demo.py for example usage.

Limitations

Issues with max sequence length

One limitation is that caching with batching does not work well when inputs are close to the max input length. This is because both the cached past_keys_values and the continuation input tensor are padded to their largest lengths in the batch, and the sum of the sequence lengths of these two tensors must stay below the model's max length. So the longest usable input length is determined by the largest prefix and the largest continuation in a batch together, even if they belong to different examples.

Overcoming this issue would involve breaking batches into sub-batches or reordering the examples to avoid problematic pairings of long prefixes and continuations. Right now, if this issue is encountered, an assertion is hit that states, "Presently batches with wide range of prefix and input lengths are not supported due overrun of max model size."
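
In other words, the batch-level constraint being asserted is roughly the following (a hypothetical helper for illustration, not the actual code):

```python
def check_batch_fits(prefix_token_ids, continuation_token_ids, model_max_length=1024):
    """Hypothetical illustration of the constraint described above: because cached
    prefixes and continuation inputs are each padded to their batch-wide maximum,
    the longest prefix plus the longest continuation must fit within the model's
    max length, even if they belong to different examples."""
    max_prefix = max(len(p) for p in prefix_token_ids)
    max_continuation = max(len(c) for c in continuation_token_ids)
    assert max_prefix + max_continuation <= model_max_length, (
        "Presently batches with wide range of prefix and input lengths are not "
        "supported due overrun of max model size."
    )
```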

Cached Prefix Memory Usage

When batches contain more than one cached prefix, there is no memory-efficient way to build the past_keys_values input tensor: it must be a concatenation of different prefixes and thus cannot use .expand() to save memory. As a result, caching has a larger memory footprint than non-caching, because prefixes are computed and then copied.

A first level of improvement would be to do the copying in place somehow, so the memory footprint is at least no larger than non-caching. Improving beyond that by making use of .expand() would require splitting the continuation computation into sub-batches that all share the same prefix.
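
The trade-off can be seen with plain tensors (shapes are illustrative; each layer of past_keys_values holds key and value tensors shaped (batch, heads, seq, head_dim)):

```python
import torch
import torch.nn.functional as F

# One shared prefix: .expand() broadcasts the cached keys across the batch as a
# view, so no extra memory is allocated.
shared_prefix_keys = torch.randn(1, 12, 128, 64)          # (batch, heads, seq, head_dim)
batched_view = shared_prefix_keys.expand(8, -1, -1, -1)   # no copy

# Two different prefixes in one batch: they must be padded to a common length and
# concatenated, which materializes a new tensor (a real copy) for every layer of
# the cache. This is the larger memory footprint described above.
prefix_a_keys = torch.randn(1, 12, 128, 64)
prefix_b_keys = torch.randn(1, 12, 96, 64)
prefix_b_padded = F.pad(prefix_b_keys, (0, 0, 0, 128 - 96))  # pad the seq dimension
mixed_batch_keys = torch.cat([prefix_a_keys, prefix_b_padded], dim=0)
```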

Initial Result (outdated)

These results were produced with the first iteration of this design on July 15. See below for more up-to-date results.

Speed up

The screenshot below gives initial results on the speedup from using caching in a few settings. Key takeaways are:

  • Almost 9x speedup on the longest inputs and largest models
  • Left truncation is not an issue, as expected, because the resulting number of distinct cached prefixes is low relative to the total data size (224 prefixes vs. 6540 instances in metaicl::boolq)

[Screenshot: initial speedup results, July 15, 2022]

Reproducibility

F1 and accuracy metrics are exactly reproduced, and logits all agree within torch.allclose() with and without caching, as well as against the code from main (a toy version of this check is sketched after the numbers below). The following are the results of running python -m catwalk --model metaicl::gpt2 --task metaicl::boolq --num_shots 1 --fewshot_seed 100:

with caching

metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])

without caching

metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])

code from main

metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])
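
The logit comparison behind this claim is essentially of the following form (a toy sketch, not the actual test code):

```python
import torch

def logits_reproduced(cached_logits, uncached_logits, rtol=1e-05, atol=1e-08):
    """Toy sketch: both arguments are lists of per-instance logit tensors, one
    from a run with prefix caching and one from a run without it."""
    return all(
        torch.allclose(cached, uncached, rtol=rtol, atol=atol)
        for cached, uncached in zip(cached_logits, uncached_logits)
    )
```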

@IanMagnusson marked this pull request as ready for review July 15, 2022 20:57
@IanMagnusson requested a review from dirkgr July 15, 2022 20:57
@IanMagnusson marked this pull request as draft July 23, 2022 00:15
@IanMagnusson
Contributor Author

Benchmark Speed and Reproduction Results

These results confirm the initial results above, but on the full set of classification evaluation tasks in MetaICL. Some formatting changes were required in the MetaICLTask (accounting for the one large Reproduction Error), but with those fixes we now reproduce metrics close to those from the MetaICL repo itself. Also of note: caching actually produces a slight slowdown when the only overlap between inputs comes from multiple-choice examples sharing the same context sequence across several continuation sequences. This setting is expected to benefit less than the repeated-ICL-examples setting, but ideally it should still see a speedup proportional to the number of choices per example. Future work can attempt to address the bottlenecks that dominate caching time when there is relatively little overlap.

[Screenshot: benchmark speed and reproduction results, July 25, 2022]

Truncation and cacheable prefixes

The following table further explains the better speedup with repeated ICL examples (the 16-shot setting) vs. multiple-choice-only overlap (the 0-shot setting). Of note, the ratio of cacheable prefixes to the total number of inputs is much higher in the 0-shot setting: there is less overlap, so more distinct prefixes must actually be computed at some point, giving less of a speedup.

[Screenshot: truncation and cacheable prefix statistics, July 25, 2022]

@IanMagnusson marked this pull request as ready for review July 26, 2022 01:22
@IanMagnusson marked this pull request as draft July 26, 2022 18:16
@IanMagnusson marked this pull request as ready for review July 26, 2022 23:47
@IanMagnusson
Contributor Author

Improved performance with batched caching

Commit 4e98c01 improves the caching code with batched computation of prefixes when there are multiple non-shared prefixes within a single batch (a rough sketch of the idea follows the screenshot below). The following results show that this gives caching a speedup even in the zero-shot setting, when there are multiple-choice questions.

The results also compare against the original MetaICL repo code. Those results use raw gpt2-large while ours add IA3 adapters, so if we reran our codebase without the IA3 adapters we would expect speed to improve or stay the same.

[Screenshot: batched caching speed results, August 3, 2022]
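
The batched prefix computation is roughly the following idea (an assumed sketch, not the actual code in commit 4e98c01):

```python
import torch

def encode_prefixes_batched(model, tokenizer, prefixes):
    """Assumed sketch: when a batch contains several distinct prefixes, encode them
    together in one padded forward pass instead of one model call per prefix.
    Assumes the tokenizer has a pad token set (e.g. tokenizer.pad_token = tokenizer.eos_token)."""
    enc = tokenizer(prefixes, return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model(**enc, use_cache=True)
    # The returned past_key_values holds the cache for every distinct prefix at once;
    # the attention mask is kept so padded positions stay masked out downstream.
    return out.past_key_values, enc["attention_mask"]
```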

@dirkgr
Member

dirkgr commented Aug 15, 2022

You were saying you're still hacking on this and I should not review it?

@IanMagnusson
Copy link
Contributor Author

IanMagnusson commented Aug 15, 2022

> You were saying you're still hacking on this and I should not review it?

I may change the way that prefixes are cached at the batch level and how they are used for the final inference, to reduce the memory footprint. But the part where I use the trie to organize examples by common prefix won't change. So to save you time, it's best not to review anything other than the trie and sorting stuff yet. But if you want to push the functionality out, this is a working implementation and I'm happy to add my optimizations in a new PR.
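
For reference, grouping instances by shared prefix with a trie is conceptually along these lines (an illustrative sketch, not the branch's actual implementation):

```python
from collections import defaultdict

class PrefixTrieNode:
    """One node per prefix token; instances whose prefix ends here are stored on the node."""
    def __init__(self):
        self.children = defaultdict(PrefixTrieNode)
        self.instances = []

def group_by_prefix(instances):
    """instances: iterable of (prefix_tokens, continuation_tokens) pairs. Instances
    that share a prefix land on the same trie node, so each shared prefix only
    needs to be encoded once."""
    root = PrefixTrieNode()
    for prefix_tokens, continuation_tokens in instances:
        node = root
        for token in prefix_tokens:
            node = node.children[token]
        node.instances.append((prefix_tokens, continuation_tokens))
    return root
```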

@IanMagnusson
Contributor Author

@dirkgr, Qinyuan's work won't be needing any more caching improvements soon, so this branch won't be moving forward any further. I've changed the top post in the PR to reflect the current state of affairs; in particular, note the updated Limitations section, which mentions the issues we talked about in person. I've also added a demo script, experiments/prefix_cache_demo.py, that I hope can be helpful for getting a debug session running easily. Please let me know what else I can do to help move this PR forward.

@IanMagnusson removed the request for review from epwalsh August 30, 2022 02:10
@IanMagnusson marked this pull request as draft September 17, 2022 01:20