USE 136 - implement create embeddings for OSNeuralSparseDocV3GTE #18
@@ -8,6 +8,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING

+import torch
 from huggingface_hub import snapshot_download
 from transformers import AutoModelForMaskedLM, AutoTokenizer

@@ -26,6 +27,9 @@
 class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
     """OpenSearch Neural Sparse Encoding Doc v3 GTE model.

+    This model generates sparse embeddings for documents by using a masked language
+    model's logits to identify the most relevant tokens.
+
     HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
     """

@@ -40,8 +44,8 @@ def __init__(self, model_path: str | Path) -> None:
         super().__init__(model_path)
         self._model: PreTrainedModel | None = None
         self._tokenizer: DistilBertTokenizerFast | None = None
-        self._special_token_ids: list | None = None
-        self._id_to_token: list | None = None
+        self._special_token_ids: list[int] | None = None
+        self._device: torch.device = torch.device("cpu")

     def download(self) -> Path:
         """Download and prepare model, saving to self.model_path.
@@ -139,29 +143,205 @@ def load(self) -> None:
         if not self.model_path.exists():
             raise FileNotFoundError(f"Model not found at path: {self.model_path}")

-        # load local model and tokenizer
-        self._model = AutoModelForMaskedLM.from_pretrained(
+        # setup device (use CUDA if available, otherwise CPU)
+        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+        # load tokenizer
+        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
             self.model_path,
             trust_remote_code=True,
             local_files_only=True,
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]

+        # load model as AutoModelForMaskedLM (required for sparse embeddings)
+        self._model = AutoModelForMaskedLM.from_pretrained(
             self.model_path,
             trust_remote_code=True,
             local_files_only=True,
         )
+        self._model.to(self._device)  # type: ignore[arg-type]
+        self._model.eval()

-        # setup special tokens
+        # set special token IDs (following model card pattern)
+        # these will be zeroed out in the sparse vectors
         self._special_token_ids = [
-            self._tokenizer.vocab[str(token)]
+            self._tokenizer.vocab[token]  # type: ignore[index]
             for token in self._tokenizer.special_tokens_map.values()
         ]

-        # setup id_to_token mapping
-        self._id_to_token = ["" for _ in range(self._tokenizer.vocab_size)]
-        for token, token_id in self._tokenizer.vocab.items():
-            self._id_to_token[token_id] = token
+        logger.info(
+            f"Model loaded successfully on {self._device}, "
+            f"{time.perf_counter() - start_time:.2f}s"
+        )

+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
+        """Create sparse vector and decoded token weight embeddings for an input text.

+        Args:
+            embedding_input: EmbeddingInput object with a .text attribute
+        """
+        # generate the sparse embeddings
+        sparse_vector, decoded_tokens = self._encode_documents([embedding_input.text])[0]

+        # coerce sparse vector tensor into list[float]
+        sparse_vector_list = sparse_vector.cpu().numpy().tolist()

+        return Embedding(
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
+            model_uri=self.model_uri,
+            embedding_strategy=embedding_input.embedding_strategy,
+            embedding_vector=sparse_vector_list,
Reviewer: Just adding this here from our earlier discussion: even if the vector compresses to nothing, I think we should consider why we're passing it through if we don't have a use case for it. Obviously it's needed for generating token weights, but if OpenSearch doesn't use it, I'm not sure why we're storing it on the object. I won't press beyond this comment, but it feels like we're keeping an unnecessary precursor in addition to the useful output. Happy to be corrected if that's not the case!

Author: I'd wager to say it's kind of an endless topic. Your point and hesitance are well founded and registered. To me, it's a bit of a gamble. It takes time and money to produce embeddings, even as we tune for performance, and the sparse vectors arguably carry more information, given that they represent the embedding across the entire vocabulary. Theoretically, we could perform mathematical operations on the sparse vectors that we can't perform on the decoded token weights. Other folks, myself included, have some interest in this, and storing them keeps that option on the table. Additionally, there might be a good argument for only storing the sparse vectors in the future and decoding the data on the way out, which could be much cheaper to store in the long term. Lastly, hopefully, we'll use a model in the future that produces true dense vectors, which will require storage in that form. If any of these pan out, it'd be nice to have some repetitions and tested schemas for storing this data. I'm certainly not opposed to removing it a few months down the road as we tune the pipeline, but I'd lobby for keeping it in these early days as we develop our understanding. As I led with, though, I don't think there is a right or wrong answer here. We may very well decide it's not useful at some point and cease to store it for this model.

Reviewer: That is all fine, appreciate the discussion! 🙂
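As an aside on the "decoding on the way out" option mentioned above, here is a minimal, hypothetical sketch (not part of this PR) of what that could look like downstream. The helper name and the assumption that the stored embedding_vector is a plain list[float] indexed by vocabulary id are illustrative only:

# hypothetical sketch, not part of this PR: decode a stored sparse vector back into
# a {token: weight} dict, assuming embedding_vector is a list[float] indexed by vocab id
def decode_stored_vector(sparse_vector: list[float], tokenizer) -> dict[str, float]:
    token_weights: dict[str, float] = {}
    for token_id, weight in enumerate(sparse_vector):
        if weight > 0:  # zero entries carry no information
            token_weights[tokenizer.convert_ids_to_tokens(token_id)] = weight
    return token_weights

# usage (hypothetical):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(
#       "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
#   )
#   decoded = decode_stored_vector(stored_embedding.embedding_vector, tokenizer)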
+            embedding_token_weights=decoded_tokens,
+        )

+    def _encode_documents(
+        self,
+        texts: list[str],
+    ) -> list[tuple[torch.Tensor, dict[str, float]]]:
+        """Encode documents into sparse vectors and decoded token weights.

+        This follows the pattern outlined on the HuggingFace model card for document
+        encoding.

+        This method will accommodate MULTIPLE text inputs, and return a list of
+        embeddings, but the calling context of create_embedding() is a SINGULAR input +
+        output. This method keeps the ability to handle multiple inputs + outputs, in the
+        event we want something like a create_multiple_embeddings() method in the future,
+        but only returns a single result.

+        At a very high level, the following is performed:

+        1. We tokenize the input text into "features" using the model's tokenizer.

+        2. The features are fed to the model returning model output logits. These logits
+        are "dense" in the sense there are few zeros, but they are not "dense vectors"
+        (embeddings) in the sense that they meaningfully represent the input document in
+        geometric space; two logit tensors cannot be compared with something like cosine
+        similarity.

+        3. The logits are then converted into a sparse vector, which is a numeric
+        array of floats with the same number of values as the model's vocabulary. Each
+        value's position in the sparse array corresponds to the token id in the
+        vocabulary, and the value itself is the "weight" of this token in the input text.

+        4. Lastly, we convert this sparse vector into a {token:weight} dictionary of the
+        actual token strings and their numerical weight. This dictionary may contain
+        tokens not present in the original text, but will be considerably shorter than
+        the model vocabulary length given all zero and low scoring tokens are dropped.
+        This is the final form that we will ultimately index into OpenSearch.

+        Args:
+            texts: list of strings to create embeddings for
+        """
+        if self._model is None or self._tokenizer is None:
+            raise RuntimeError("Model not loaded. Call load() before create_embedding.")

+        # tokenize the input texts
+        features = self._tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",  # returns PyTorch tensors instead of Python lists
+            return_token_type_ids=False,
+        )

+        # move to CPU or GPU device, depending on what's available
+        features = {k: v.to(self._device) for k, v in features.items()}

+        # pass features to the model and receive model output logits as a tensor
+        with torch.no_grad():
+            output = self._model(**features)[0]

+        # generate sparse vectors from model logits tensor
+        sparse_vectors = self._get_sparse_vectors(features, output)

+        # decode sparse vectors to token-weight dictionaries
+        decoded = self._decode_sparse_vectors(sparse_vectors)

+        # return list of tuple(vector, decoded token weights) embedding results
+        return [(sparse_vectors[i], decoded[i]) for i in range(len(texts))]

+    def _get_sparse_vectors(
+        self, features: dict[str, torch.Tensor], output: torch.Tensor
+    ) -> torch.Tensor:
+        """Convert model logits output to sparse vectors.

+        This follows the HuggingFace model card exactly: https://huggingface.co/
+        opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte#usage-huggingface

+        This implements the get_sparse_vector function from the model card:
+        1. Max pooling with attention mask
+        2. log(1 + log(1 + relu())) transformation
+        3. Zero out special tokens

+        The end result is a sparse vector with a length of the model vocabulary, with each
+        position representing a token in the model vocabulary and each value representing
+        that token's weight relative to the input text.

+        Args:
+            features: Tokenizer output with attention_mask
+            output: Model logits of shape (batch_size, seq_len, vocab_size)

+        Returns:
+            Sparse vectors of shape (batch_size, vocab_size)
+        """
+        # collapse sequence positions: take max logit for each vocab token across all
+        # positions (also masks out padding tokens)
+        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)

+        # compress values to create sparsity: ReLU removes negatives,
+        # double-log shrinks large values
+        values = torch.log(1 + torch.log(1 + torch.relu(values)))

+        # remove special tokens like [CLS], [SEP], [PAD]
+        values[:, self._special_token_ids] = 0

+        return values

+    def _decode_sparse_vectors(
+        self, sparse_vectors: torch.Tensor
+    ) -> list[dict[str, float]]:
+        """Convert sparse vectors to token-weight dictionaries.

+        Handles both single vectors and batches, returning a list of dictionaries mapping
+        token strings to their weights.

+        Args:
+            sparse_vectors: Tensor of shape (batch_size, vocab_size) or (vocab_size,)

+        Returns:
+            List of dictionaries with token-weight pairs
+        """
+        if sparse_vectors.dim() == 1:
+            sparse_vectors = sparse_vectors.unsqueeze(0)

+        # move to CPU for processing
+        sparse_vectors_cpu = sparse_vectors.cpu()

+        results: list[dict] = []
+        for vector in sparse_vectors_cpu:

+            # find non-zero indices and values
+            nonzero_indices = torch.nonzero(vector, as_tuple=False).squeeze(-1)

+            if nonzero_indices.numel() == 0:
+                results.append({})
+                continue

+            # get weights
+            weights = vector[nonzero_indices].tolist()

+            # convert indices to token strings
+            token_ids = nonzero_indices.tolist()
+            tokens = self._tokenizer.convert_ids_to_tokens(token_ids)  # type: ignore[union-attr]

-        logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")
+            # create token:weight dictionary
+            token_dict = {
+                token: weight
+                for token, weight in zip(tokens, weights, strict=True)
+                if token is not None
+            }
+            results.append(token_dict)

-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
-        raise NotImplementedError
+        return results
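For orientation, here is a rough usage sketch of the API this PR adds; it is not taken from the repo. It assumes the model has already been downloaded to model_path, that OSNeuralSparseDocV3GTE and EmbeddingInput are imported from the package, and the EmbeddingInput field values (including the "full_record" strategy string) are hypothetical, with field names taken from the diff above:

# hypothetical usage sketch of the new create_embedding() API
model = OSNeuralSparseDocV3GTE(model_path="/path/to/local/model")  # hypothetical path
model.load()  # loads tokenizer + AutoModelForMaskedLM, selects CUDA if available

embedding_input = EmbeddingInput(
    timdex_record_id="example:record:123",  # hypothetical values
    run_id="run-1",
    run_record_offset=0,
    embedding_strategy="full_record",
    text="A short document to embed.",
)

embedding = model.create_embedding(embedding_input)
# embedding.embedding_vector: list[float], one weight per token in the model vocabulary
# embedding.embedding_token_weights: {token: weight} dict, zero and special tokens dropped
top_tokens = sorted(embedding.embedding_token_weights.items(), key=lambda kv: -kv[1])[:10]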
@@ -70,6 +70,7 @@ ignore = [
     "D102",
     "D103",
     "D104",
+    "EM101",
     "EM102",
     "G004",
     "PLR0912",
Reviewer: Hmm, why text length? 🤔

Author: Keeping it brief here, but in short, it's just a handy little indicator of how long the text used for the embedding was! This might help when we don't see the original text, and maybe unearth instances where it's zero, or huge. But it's purely a guess at helpful data for the interactive Python/shell environment; it won't have any bearing otherwise.