Possible fix for conflict between Automated Prefix Caching (vllm-project#2762) and multi-LoRA support (vllm-project#1804) (vllm-project#3263)
jacobthebanana authored and AdrianAbeyta committed Mar 7, 2024
1 parent ca1b39c commit fd6e57e
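
In short: automatic prefix caching keys cached KV blocks by a hash of the token prefix, while multi-LoRA serving can produce different KV contents for the same token prefix depending on the active adapter. The toy sketch below is illustrative only (not vLLM code; every name in it is made up) and shows why a cache key that ignores the adapter ID hands the wrong block to a later request, and why folding the LoRA integer ID into the key avoids that.

# Toy illustration of the conflict (not vLLM code): a prefix cache keyed only
# by token IDs will reuse the same cached block for requests that run under
# different LoRA adapters, even though their KV contents differ.
from typing import Dict, Tuple

BlockKeyOld = Tuple[int, ...]              # token IDs only
BlockKeyNew = Tuple[Tuple[int, ...], int]  # token IDs plus LoRA integer ID

cache_old: Dict[BlockKeyOld, str] = {}
cache_new: Dict[BlockKeyNew, str] = {}

prefix = (101, 7592, 2088, 102)  # same token prefix for two requests

# A request served with LoRA adapter 1 populates the cache.
cache_old[prefix] = "kv block computed under adapter 1"
cache_new[(prefix, 1)] = "kv block computed under adapter 1"

# A later request with no adapter (ID 0) looks up the same prefix.
hit_old = prefix in cache_old          # True: the wrong block is reused
hit_new = (prefix, 0) in cache_new     # False: adapter mismatch, recompute
print(hit_old, hit_new)                # prints: True False

The commit applies the second keying scheme by mixing the sequence's LoRA ID into the block hash, as shown in the vllm/sequence.py diff below.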
Showing 2 changed files with 33 additions and 16 deletions.
46 changes: 31 additions & 15 deletions tests/test_cache_block_hashing.py
@@ -2,8 +2,11 @@
 Run `pytest tests/test_cache_block_hashing.py`.
 """
+from typing import List, Optional
+
 import pytest
 
+from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import TokenizerGroup
 from vllm.sequence import Sequence

@@ -36,7 +39,10 @@ def flatten_2d(li):
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("max_num_seqs", [256])
-def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
+@pytest.mark.parametrize("concurrent_lora_int_ids",
+                         [[None], [1], [None, 1], [None, 1, 2], [1, 2]])
+def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
+                             concurrent_lora_int_ids: List[Optional[int]]):
 
     tokenizer = TokenizerGroup(
         tokenizer_id="facebook/opt-125m",

@@ -48,20 +54,30 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
     hashes = []
 
     for prefix in prefixes:
-        hashes.append([])
-        prompts = [prefix + prompt for prompt in sample_prompts]
-        seq_id = 0
-        for prompt in prompts:
-            hashes[-1].append([])
-            prompt_token_ids = tokenizer.encode(prompt)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                           tokenizer.tokenizer.eos_token_id)
-
-            num_blocks = len(prompt_token_ids) // block_size
-            for idx in range(num_blocks):
-                hashes[-1][-1].append(seq.hash_of_block(idx))
-
-            seq_id += 1
+        for lora_int_id in concurrent_lora_int_ids:
+            lora_request = None
+
+            if lora_int_id is not None:
+                lora_request = LoRARequest(
+                    f"example_lora_{lora_int_id}",
+                    lora_int_id,
+                    f"example/path/to/lora_{lora_int_id}",
+                )
+
+            hashes.append([])
+            prompts = [prefix + prompt for prompt in sample_prompts]
+            seq_id = 0
+            for prompt in prompts:
+                hashes[-1].append([])
+                prompt_token_ids = tokenizer.encode(prompt)
+                seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                               tokenizer.tokenizer.eos_token_id, lora_request)
+
+                num_blocks = len(prompt_token_ids) // block_size
+                for idx in range(num_blocks):
+                    hashes[-1][-1].append(seq.hash_of_block(idx))
+
+                seq_id += 1
 
     # Check that hashes made with two prefixes with different first blocks are
     # different everywhere.
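
The assertions that follow are collapsed in this view. For orientation, `hashes` ends up as a three-level list indexed by (prefix, LoRA ID) combination, then prompt, then block index. A cross-combination uniqueness check of the kind the comment above describes could look roughly like the following hypothetical sketch (not the repository's code):

# Hypothetical sketch of the kind of check described in the comment above;
# the real assertions are collapsed in this diff view.
from itertools import combinations
from typing import List


def assert_no_cross_combination_collisions(hashes: List[List[List[int]]]) -> None:
    # Flatten each (prefix, LoRA ID) combination into one set of block hashes.
    per_combo = [set(h for prompt in combo for h in prompt) for combo in hashes]
    # Block hashes computed under different prefix/LoRA combinations should not
    # collide anywhere; within one combination, prompts that share a first
    # block legitimately share that block's hash.
    for a, b in combinations(per_combo, 2):
        assert not (a & b), "block hash reused across prefix/LoRA combinations"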
3 changes: 2 additions & 1 deletion vllm/sequence.py
@@ -175,7 +175,8 @@ def hash_of_block(self, logical_idx: int) -> int:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
+        return hash(
+            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
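
Taken together with the test above, the effect of this change can be exercised directly: two sequences built from identical prompt tokens but different lora_request values now produce different block hashes, so the prefix cache no longer shares blocks across adapters. A minimal sketch, assuming a vLLM checkout at this commit and reusing only the constructor shapes visible in the diffs above (the EOS ID and token IDs are made-up stand-ins):

# Sketch mirroring the diffs above: identical token prefixes, different LoRA
# adapters, different block hashes. Constructor arguments follow the shapes
# shown in the test diff; the concrete values are assumptions for illustration.
from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence

block_size = 16
prompt = "Hello, my name is"
prompt_token_ids = list(range(2 * block_size))  # stand-in token IDs
eos_token_id = 2  # assumed EOS ID, for illustration only

lora_request = LoRARequest("example_lora_1", 1, "example/path/to/lora_1")

seq_no_lora = Sequence(0, prompt, prompt_token_ids, block_size, eos_token_id)
seq_with_lora = Sequence(1, prompt, prompt_token_ids, block_size, eos_token_id,
                         lora_request)

# Before this commit both calls returned the same value, so the prefix cache
# could hand a block computed under one adapter to a request using another.
assert seq_no_lora.hash_of_block(0) != seq_with_lora.hash_of_block(0)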
