2 changes: 1 addition & 1 deletion src/embeddedllm/backend/onnxruntime_engine.py
@@ -48,7 +48,7 @@ def __init__(self, model_path: str, vision: bool, device: str = "cpu"):
allow_patterns=None,
repo_type="model",
)
model_path = snapshot_path
self.model_path = snapshot_path

self.model_config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
self.device = device
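The one-line change above matters because the constructor keeps reading `self.model_path` afterwards (e.g. for `AutoConfig.from_pretrained`), so the attribute must point at the downloaded snapshot rather than the original repo id. A minimal sketch of the same resolve-then-load pattern; the repo id used here is an assumption for illustration only:

```python
import os

from huggingface_hub import snapshot_download
from transformers import AutoConfig

def resolve_model_path(model_path: str) -> str:
    """Return a local directory for model_path, downloading a Hub snapshot if it is not a local path."""
    if os.path.exists(model_path):
        return model_path
    return snapshot_download(repo_id=model_path, repo_type="model")

# Keep the resolved path so every later call sees the same local directory.
local_path = resolve_model_path("microsoft/Phi-3-mini-4k-instruct-onnx")  # assumed repo id
config = AutoConfig.from_pretrained(local_path, trust_remote_code=True)
```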
304 changes: 264 additions & 40 deletions src/embeddedllm/backend/openvino_engine.py
@@ -1,13 +1,19 @@
import contextlib
from io import BytesIO
import time
import os
from PIL import Image
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import AsyncIterator, List, Optional
from huggingface_hub import snapshot_download

from loguru import logger
from PIL import Image
from transformers import (
AutoConfig,
AutoProcessor,
TextStreamer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TextIteratorStreamer,
@@ -16,7 +22,7 @@
from threading import Thread

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

from embeddedllm.backend.ov_phi3_vision import OvPhi3Vision
from embeddedllm.inputs import PromptInputs
from embeddedllm.protocol import CompletionOutput, RequestOutput
from embeddedllm.sampling_params import SamplingParams
@@ -27,11 +33,14 @@

class OpenVinoEngine(BaseLLMEngine):
def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
self.vision = vision
self.model_path = model_path
self.device = device

self.model_config: AutoConfig = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
self.model_path,
trust_remote_code=True
)
self.device = device

# model_config is used to determine the model's maximum context length
self.max_model_len = _get_and_verify_max_len(
@@ -40,51 +49,88 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
disable_sliding_window=False,
sliding_window_len=self.get_hf_config_sliding_window(),
)

logger.info("Model Context Length: " + str(self.max_model_len))

try:
logger.info("Attempt to load fast tokenizer")
self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
except Exception:
logger.info("Attempt to load slower tokenizer")
self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)

try:
self.model = OVModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, export=False, device=self.device
logger.info("Tokenizer created")

# non vision
if not vision:
self.tokenizer_stream = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
except Exception as e:
model = OVModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
export=True,
quantization_config=OVWeightQuantizationConfig(
try:
self.model = OVModelForCausalLM.from_pretrained(
self.model_path,
trust_remote_code=True,
export=False,
device=self.device
)
except Exception as e:
model = OVModelForCausalLM.from_pretrained(
self.model_path,
trust_remote_code=True,
export=True,
quantization_config=OVWeightQuantizationConfig(
**{
"bits": 4,
"ratio": 1.0,
"sym": True,
"group_size": 128,
"all_layers": None,
}
),
)
self.model = model.to(self.device)

logger.info("Model loaded")

# vision
elif self.vision:
logger.info("Your model is a vision model")

# download a Hub snapshot of the vision model when model_path is not a local directory
if not os.path.exists(model_path):
snapshot_path = snapshot_download(
repo_id=model_path,
allow_patterns=None,
repo_type="model",
)
self.model_path = snapshot_path

try:
# the OpenVINO device name is case sensitive and must be fully capitalized, e.g. "GPU"
self.model = OvPhi3Vision(
self.model_path,
self.device.upper()
)
logger.info("Model loaded")

self.processor = AutoProcessor.from_pretrained(
self.model_path,
trust_remote_code=True
)
logger.info("Processor loaded")
print("processor directory: ",dir(self.processor))
self.tokenizer_stream = TextIteratorStreamer(
self.processor,
**{
"bits": 4,
"ratio": 1.0,
"sym": True,
"group_size": 128,
"all_layers": None,
}
),
)
self.model = model.to(self.device)

logger.info("Model loaded")
self.tokenizer_stream = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
logger.info("Tokenizer created")

self.vision = vision

# if self.vision:
# self.onnx_processor = self.model.create_multimodal_processor()
# self.processor = AutoImageProcessor.from_pretrained(
# self.model_path, trust_remote_code=True
# )
# print(dir(self.processor))
"skip_special_tokens": True,
"skip_prompt": True,
"clean_up_tokenization_spaces": False,
},
)

except Exception as e:
logger.error("EmbeddedLLM Engine only support Phi 3 Vision Model.")
exit()
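For context, the 4-bit fallback in the non-vision branch above corresponds to an ahead-of-time export with optimum-intel. A hedged sketch of that export, where the model id and output directory are assumptions rather than values from this PR:

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# Export a Hugging Face checkpoint to OpenVINO IR with 4-bit symmetric weight quantization,
# using the same settings as the constructor's fallback path.
quant_config = OVWeightQuantizationConfig(
    bits=4, ratio=1.0, sym=True, group_size=128, all_layers=None
)
model = OVModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",  # assumed model id
    trust_remote_code=True,
    export=True,
    quantization_config=quant_config,
)
model.save_pretrained("./phi3-mini-ov-int4")  # assumed output directory
```

Doing the export once and shipping the IR avoids repeating the quantization on every cold start.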

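With both branches in place, engine construction selects the text or vision path up front. A minimal usage sketch against the signature shown in this diff; the model path, repo id, and device string below are assumptions:

```python
from embeddedllm.backend.openvino_engine import OpenVinoEngine

# Text-only model: fast/slow tokenizer, TextIteratorStreamer, and OVModelForCausalLM
# (with the 4-bit export fallback if no pre-exported IR is found).
text_engine = OpenVinoEngine("./models/phi3-mini-ov", vision=False, device="gpu")  # assumed local path

# Vision model: resolved via snapshot_download when the path is not local,
# then loaded through OvPhi3Vision with an AutoProcessor-backed streamer.
vision_engine = OpenVinoEngine(
    "microsoft/Phi-3-vision-128k-instruct",  # assumed repo id
    vision=True,
    device="gpu",
)
```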
async def generate_vision(
self,
@@ -93,7 +139,185 @@ async def generate_vision(
request_id: str,
stream: bool = True,
) -> AsyncIterator[RequestOutput]:
raise NotImplementedError(f"`generate_vision` yet to be implemented.")
# only valid when the engine was initialized with vision=True
if not self.vision:
raise ValueError("Your model is not a vision model. Please set vision=True when initializing the engine.")

prompt_text = inputs['prompt']
input_tokens = self.tokenizer.encode(prompt_text)
file_data = inputs["multi_modal_data"][0]["image_pixel_data"]
mime_type = inputs["multi_modal_data"][0]["mime_type"]
print(f"Detected MIME type: {mime_type}")

assert "image" in mime_type

image = Image.open(BytesIO(file_data))
image_token_length = self.processor.calc_num_image_tokens(image)[0]
prompt_token_length = len(self.tokenizer.encode(prompt_text, return_tensors="pt")[0])

input_token_length = image_token_length + prompt_token_length

# logger.debug(f"Prompt token length: {prompt_token_length}")
# logger.debug(f"Image token length: {image_token_length}")

max_tokens = sampling_params.max_tokens

assert input_token_length is not None

if input_token_length + max_tokens > self.max_model_len:
raise ValueError("Exceed Context Length")

messages = [
{'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
]
prompt = self.processor.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# print("Prompt: ", prompt)

model_inputs = self.processor(prompt, [image], return_tensors="pt")  # kept distinct from the PromptInputs argument

generation_options = {
'max_new_tokens': max_tokens,
'do_sample': False,
}

token_list: List[int] = []
output_text: str = ""
if stream:
generation_options["streamer"] = self.tokenizer_stream
# Include the processor inputs in the generation kwargs
generation_kwargs = {**model_inputs, **generation_options}

if RECORD_TIMING:
started_timestamp = time.time()
first_token_timestamp = 0
first = True
new_tokens = []

try:
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
output_text = ""
first = True
for new_text in self.tokenizer_stream:
if new_text == "":
continue
if RECORD_TIMING:
if first:
first_token_timestamp = time.time()
first = False
output_text += new_text
token_list = self.processor.tokenizer.encode(output_text, return_tensors="pt")

yield RequestOutput(
request_id=request_id,
prompt=inputs,
prompt_token_ids=input_tokens,
finished=False,
outputs=[
CompletionOutput(
index=0,
text=output_text,
token_ids=token_list[0],
cumulative_logprob=-1.0,
)
],
)

if RECORD_TIMING:
new_tokens = token_list[0]

yield RequestOutput(
request_id=request_id,
prompt=inputs,
prompt_token_ids=input_tokens,
finished=True,
outputs=[
CompletionOutput(
index=0,
text=output_text,
token_ids=token_list[0],
cumulative_logprob=-1.0,
finish_reason="stop",
)
],
)

if RECORD_TIMING:
prompt_time = first_token_timestamp - started_timestamp
run_time = time.time() - first_token_timestamp
logger.info(
f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
)

except Exception as e:
logger.error(str(e))

error_output = RequestOutput(
prompt=inputs,
prompt_token_ids=input_tokens,
finished=True,
request_id=request_id,
outputs=[
CompletionOutput(
index=0,
text=output_text,
token_ids=token_list,
cumulative_logprob=-1.0,
finish_reason="error",
stop_reason=str(e),
)
],
)
yield error_output

else:
try:
token_list = self.model.generate(**model_inputs, **generation_options)[0]
output_text = self.processor.tokenizer.decode(
token_list, skip_special_tokens=True
)

yield RequestOutput(
request_id=request_id,
prompt=inputs,
prompt_token_ids=input_tokens,
finished=True,
outputs=[
CompletionOutput(
index=0,
text=output_text,
token_ids=token_list,
cumulative_logprob=-1.0,
finish_reason="stop",
)
],
)

except Exception as e:
logger.error(str(e))

error_output = RequestOutput(
prompt=inputs,
prompt_token_ids=input_tokens,
finished=True,
request_id=request_id,
outputs=[
CompletionOutput(
index=0,
text=output_text,
token_ids=token_list,
cumulative_logprob=-1.0,
finish_reason="error",
stop_reason=str(e),
)
],
)
yield error_output


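A sketch of how the streaming path above might be consumed from async code. The dict below mirrors the keys `generate_vision` reads (`prompt`, `multi_modal_data[0]["image_pixel_data"]`, `mime_type`); the image file, request id, keyword-argument order, and the `SamplingParams(max_tokens=...)` constructor call are assumptions:

```python
import asyncio

from embeddedllm.sampling_params import SamplingParams

async def describe_image(engine, image_bytes: bytes, prompt: str) -> str:
    inputs = {
        "prompt": prompt,
        "multi_modal_data": [{"image_pixel_data": image_bytes, "mime_type": "image/png"}],
    }
    params = SamplingParams(max_tokens=256)  # assumed constructor signature
    text = ""
    async for output in engine.generate_vision(
        inputs=inputs, sampling_params=params, request_id="req-0", stream=True
    ):
        text = output.outputs[0].text  # cumulative text so far
    return text

# Example (assumed file name):
# result = asyncio.run(describe_image(vision_engine, open("cat.png", "rb").read(), "Describe the image."))
```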
async def generate(
self,