Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/engine/optimum/optimum_rr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@


import asyncio
import gc
import logging
from typing import Any, AsyncIterator, Dict, List, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

from src.server.models.optimum import RerankerConfig

from typing import Any, AsyncIterator, Dict

from src.server.model_registry import ModelLoadConfig, ModelRegistry

class Optimum_RR:
    """Reranker engine backed by an OpenVINO (optimum-intel) causal LM.

    Scores query/document pairs by reading the last-position logits of the
    "yes"/"no" tokens and softmaxing over that 2-way choice, with a
    model-specific chat prefix/suffix wrapped around each formatted pair.
    """

    def __init__(self, load_config: ModelLoadConfig):
        # model/tokenizer are populated by load_model(); initialize them to
        # None so unload_model() is safe even if load_model() was never called.
        self.model = None
        self.tokenizer = None
        self.model_path = None
        self.encoder_tokenizer = None
        self.load_config = load_config

    def compute_logits(self, inputs, **kwargs):
        """Return P("yes") for each batch row as a list of floats.

        Keeps only the "no"/"yes" token scores from the last position and
        applies a log-softmax over that pair; the exp of the "yes" column is
        the relevance score.
        """
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        # Use the module-level F alias already imported by this file.
        batch_scores = F.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def format_instruction(self, instruction, query, doc):
        """Format one (instruction, query, document) triple into the model prompt."""
        if instruction is None:
            instruction = "Given a search query, retrieve relevant passages that answer the query"
        output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction, query=query, doc=doc)
        return output

    async def generate_rerankings(self, rr_config: RerankerConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
        """Yield rr_config.documents ranked by relevance to rr_config.query.

        Yields a single list of {"doc": ..., "score": ...} dicts sorted by
        descending score.
        """
        prefix_tokens = self.tokenizer.encode(rr_config.prefix, add_special_tokens=False)
        suffix_tokens = self.tokenizer.encode(rr_config.suffix, add_special_tokens=False)
        self.max_length = rr_config.max_length
        pairs = [self.format_instruction(rr_config.task, rr_config.query, doc) for doc in rr_config.documents]
        # Log (not print) the formatted pairs for debugging.
        logging.debug("Reranker pairs: %s", pairs)
        # Currently hard coding tokenizer args. If these are model independent then it is fine.
        # Otherwise implement the rr_config PreTrainedTokenizerConfig args.
        max_length = 8192
        inputs = self.tokenizer(
            pairs, padding=False, truncation="longest_first", return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
        )

        # Wrap every tokenized pair with the model-specific chat prefix/suffix.
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens

        # Currently hard coding tokenizer args. If these are model independent then it is fine.
        # Otherwise implement the rr_config PreTrainedTokenizerConfig args.
        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)

        scores = self.compute_logits(inputs)

        # Ties on score fall back to comparing the documents themselves.
        ranked_documents = [{"doc": doc, "score": score} for score, doc in sorted(zip(scores, rr_config.documents), reverse=True)]

        yield ranked_documents

    # not implemented
    def collect_metrics(self, rr_config: RerankerConfig, perf_metrics) -> Dict[str, Any]:
        pass

    def load_model(self, loader: ModelLoadConfig):
        """Load model using a ModelLoadConfig configuration and cache the tokenizer.

        Args:
            loader: ModelLoadConfig containing model_path, device, engine, and runtime_config.
        """

        self.model = OVModelForCausalLM.from_pretrained(loader.model_path,
                                                        device=loader.device,
                                                        export=False,
                                                        use_cache=False)

        self.tokenizer = AutoTokenizer.from_pretrained(loader.model_path)
        # IDs of the "no"/"yes" tokens used by compute_logits().
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        logging.info(f"Model loaded successfully: {loader.model_name}")

    async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
        """Unregister model from registry and free memory resources.

        Args:
            registry: ModelRegistry to unregister from
            model_name: Public model name previously registered

        Returns:
            True if the model was found and unregistered, else False.
        """
        removed = await registry.register_unload(model_name)

        if self.model is not None:
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        gc.collect()
        logging.info(f"[{self.load_config.model_name}] weights and tokenizer unloaded and memory cleaned up")
        return removed
1 change: 1 addition & 0 deletions src/server/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def start_server(host: str = "0.0.0.0", openarc_port: int = 8001, reload: bool =
logger.info(" - POST /v1/audio/transcriptions: Whisper only")
logger.info(" - POST /v1/audio/speech: Kokoro only")
logger.info(" - POST /v1/embeddings")
logger.info(" - POST /v1/rerank")


uvicorn.run(
Expand Down
70 changes: 68 additions & 2 deletions src/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from src.server.model_registry import ModelLoadConfig, ModelRegistry, ModelUnloadConfig
from src.server.models.openvino import OV_KokoroGenConfig
from src.server.models.optimum import PreTrainedTokenizerConfig
from src.server.models.ov_genai import OVGenAI_GenConfig, OVGenAI_WhisperGenConfig
from src.server.models.optimum import PreTrainedTokenizerConfig, RerankerConfig
from src.server.worker_registry import WorkerRegistry

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -225,6 +225,16 @@ class EmbeddingsRequest(BaseModel):
#end of openai api
config: Optional[PreTrainedTokenizerConfig] = None

# No openai api to reference
class RerankRequest(BaseModel):
    # Request body for POST /v1/rerank (no OpenAI API to mirror).
    model: str                      # name of a loaded reranker model
    query: str                      # phrase the documents are scored against
    documents: List[str]            # candidate documents to rank
    prefix: Optional[str] = None    # model-specific prompt prefix override
    suffix: Optional[str] = None    # model-specific prompt suffix override
    task: Optional[str] = None      # instruction text override
    config: Optional[PreTrainedTokenizerConfig] = None  # not implemented

@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def openai_list_models():
"""OpenAI-compatible endpoint that lists available models."""
Expand Down Expand Up @@ -639,4 +649,60 @@ async def embeddings(request: EmbeddingsRequest):
raise HTTPException(status_code=400, detail=str(exc))
except Exception as exc:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Embedding failed: {str(exc)}")
raise HTTPException(status_code=500, detail=f"Embedding failed: {str(exc)}")

@app.post("/v1/rerank", dependencies=[Depends(verify_api_key)])
async def rerank(request: RerankRequest):
    """Rerank request.documents against request.query using a loaded reranker model.

    Returns an OpenAI-style "list" envelope whose `data` entries wrap the
    worker's ranked-document payloads. Validation errors map to 400; anything
    else to 500.
    """
    try:
        # Fold the optional tokenizer config into the reranker config, then
        # apply the per-request prompt overrides on top.
        if request.config:
            tok_config = PreTrainedTokenizerConfig.model_validate(request.config)
            base_data = tok_config.model_dump()
            rr_config = RerankerConfig.model_validate(base_data | {"query": request.query, "documents": request.documents})
        else:
            rr_config = RerankerConfig.model_validate({"query": request.query, "documents": request.documents})

        if request.prefix:
            rr_config.prefix = request.prefix
        if request.suffix:
            rr_config.suffix = request.suffix
        if request.task:
            rr_config.task = request.task

        model_name = request.model
        created_ts = int(time.time())
        request_id = f"ov-{uuid.uuid4().hex[:24]}"

        result = await _workers.rerank(model_name, rr_config)
        # Fall back to an empty list so a missing/None "data" key cannot make
        # the iteration below raise TypeError (surfacing as a 500).
        data = result.get("data") or []
        metrics = result.get("metrics", {}) or {}

        prompt_tokens = metrics.get("input_token", 0)
        total_tokens = metrics.get("total_token", prompt_tokens)

        docs = [
            {
                "index": i,
                "object": "ranked_documents",
                "ranked_documents": item,
            }
            for i, item in enumerate(data)
        ]

        response = {
            "id": request_id,
            "object": "list",
            "created": created_ts,
            "model": model_name,
            "data": docs,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
            },
        }

        return response
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Reranking failed: {str(exc)}")
5 changes: 4 additions & 1 deletion src/server/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,15 @@ class ModelType(str, Enum):
- vlm: Image-to-text VLM models
- whisper: Whisper ASR models
- kokoro: Kokoro TTS models
- emb: Text-to-vector models"""
- emb: Text-to-vector models
- rerank: Reranker models"""

LLM = "llm"
VLM = "vlm"
WHISPER = "whisper"
KOKORO = "kokoro"
EMB = "emb"
RERANK = "rerank"

class EngineType(str, Enum):
"""Engine used to load the model.
Expand Down Expand Up @@ -308,6 +310,7 @@ async def status(self) -> dict:
(EngineType.OV_GENAI, ModelType.WHISPER): "src.engine.ov_genai.whisper.OVGenAI_Whisper",
(EngineType.OPENVINO, ModelType.KOKORO): "src.engine.openvino.kokoro.OV_Kokoro",
(EngineType.OV_OPTIMUM, ModelType.EMB): "src.engine.optimum.optimum_emb.Optimum_EMB",
(EngineType.OV_OPTIMUM, ModelType.RERANK): "src.engine.optimum.optimum_rr.Optimum_RR",
}

async def create_model_instance(load_config: ModelLoadConfig) -> Any:
Expand Down
55 changes: 46 additions & 9 deletions src/server/models/optimum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class PreTrainedTokenizerConfig(BaseModel):
Configuration for tokenizer.
"""

text: Union[str, List[str], List[List[str]]] = Field(
text: Union[str, List[str], List[List[str]]] | None = Field(
default=None,
description=(
"The sequence or batch of sequences to be encoded. Each sequence can be a string "
Expand All @@ -17,7 +17,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

text_pair: Union[str, List[str], List[List[str]]] = Field(
text_pair: Union[str, List[str], List[List[str]]] | None = Field(
default=None,
description=(
"The sequence or batch of sequences to be encoded. Each sequence can be a string "
Expand All @@ -27,7 +27,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

text_target: Union[str, List[str], List[List[str]]] = Field(
text_target: Union[str, List[str], List[List[str]]] | None = Field(
default=None,
description=(
"The sequence or batch of sequences to be encoded as target texts. Each sequence can be "
Expand All @@ -37,7 +37,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

text_pair_target: Union[str, List[str], List[List[str]]] = Field(
text_pair_target: Union[str, List[str], List[List[str]]] | None = Field(
default=None,
description=(
"The sequence or batch of sequences to be encoded as target texts. Each sequence can be "
Expand Down Expand Up @@ -85,7 +85,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

max_length: int = Field(
max_length: int | None = Field(
default=None,
description=(
"Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to "
Expand Down Expand Up @@ -113,7 +113,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

pad_to_multiple_of: int = Field(
pad_to_multiple_of: int | None = Field(
default=None,
description=(
"If set will pad the sequence to a multiple of the provided value. Requires padding to be activated. "
Expand All @@ -122,7 +122,7 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

padding_side: str = Field(
padding_side: str | None = Field(
default=None,
description=(
"The side on which the model should have padding applied. Should be selected between ['right', 'left']. "
Expand All @@ -138,15 +138,15 @@ class PreTrainedTokenizerConfig(BaseModel):
)
)

return_token_type_ids: bool = Field(
return_token_type_ids: bool | None = Field(
default=None,
description=(
"Whether to return token type IDs. If left to the default, will return the token type IDs according to the specific "
"tokenizer’s default, defined by the return_outputs attribute. What are token type IDs?"
)
)

return_attention_mask: bool = Field(
return_attention_mask: bool | None = Field(
default=None,
description=(
"Whether to return the attention mask. If left to the default, will return the attention mask according to the "
Expand Down Expand Up @@ -190,3 +190,40 @@ class PreTrainedTokenizerConfig(BaseModel):
"Whether or not to print more information and warnings."
)
)

class RerankerConfig(PreTrainedTokenizerConfig):
    """Reranker request configuration: inherited tokenizer args plus the
    query/documents to score and the model-specific prompt template pieces."""

    # query/documents are annotated optional to match their None defaults,
    # consistent with the `| None` fields on PreTrainedTokenizerConfig.
    query: str | None = Field(
        default=None,
        description=(
            "Phrase to compare documents to."
        )
    )

    documents: List[str] | None = Field(
        default=None,
        description=(
            "Documents to rank."
        )
    )

    prefix: str = Field(
        default='<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n',
        description=(
            "Text to prepend to the start of the query. This is model specific."
        )
    )

    suffix: str = Field(
        default="<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
        description=(
            "Text to append to the end of the query. This is model specific."
        )
    )

    task: str = Field(
        default="Given a search query, retrieve relevant passages that answer the query",
        description=(
            "Prompt command delivered to the model."
        )
    )
Loading