Skip to content

Commit

Permalink
Use batch_size parameter with keybert.backend.SentenceTransformerBack…
Browse files Browse the repository at this point in the history
…end (#210)
  • Loading branch information
adhadse committed Feb 28, 2024
1 parent 1733032 commit dcf31dd
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 24 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
- repo: https://github.com/psf/black
Expand Down
54 changes: 35 additions & 19 deletions keybert/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from keybert._mmr import mmr
from keybert._maxsum import max_sum_distance
from keybert._highlight import highlight_document
from keybert.backend._base import BaseEmbedder
from keybert.backend._utils import select_backend
from keybert.llm._base import BaseLLM
from keybert import KeyLLM
Expand All @@ -38,11 +39,15 @@ class KeyBERT:
</div>
"""

def __init__(self, model="all-MiniLM-L6-v2", llm: BaseLLM = None):
def __init__(
self,
model="all-MiniLM-L6-v2",
llm: BaseLLM = None,
):
"""KeyBERT initialization
Arguments:
model: Use a custom embedding model.
model: Use a custom embedding model or a specific KeyBERT Backend.
The following backends are currently supported:
* SentenceTransformers
* 🤗 Transformers
Expand Down Expand Up @@ -78,7 +83,7 @@ def extract_keywords(
seed_keywords: Union[List[str], List[List[str]]] = None,
doc_embeddings: np.array = None,
word_embeddings: np.array = None,
threshold: float = None
threshold: float = None,
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract keywords and/or keyphrases
Expand Down Expand Up @@ -111,9 +116,9 @@ def extract_keywords(
NOTE: This does not work if multiple documents are passed.
seed_keywords: Seed keywords that may guide the extraction of keywords by
steering the similarities towards the seeded keywords.
NOTE: when multiple documents are passed,
NOTE: when multiple documents are passed,
`seed_keywords`funtions in either of the two ways below:
- globally: when a flat list of str is passed, keywords are shared by all documents,
- globally: when a flat list of str is passed, keywords are shared by all documents,
- locally: when a nested list of str is passed, keywords differs among documents.
doc_embeddings: The embeddings of each document.
word_embeddings: The embeddings of each potential keyword/keyphrase across
Expand Down Expand Up @@ -178,10 +183,12 @@ def extract_keywords(
# Check if the right number of word embeddings are generated compared with the vectorizer
if word_embeddings is not None:
if word_embeddings.shape[0] != len(words):
raise ValueError("Make sure that the `word_embeddings` are generated from the function "
"`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
"`stop_words`, and `min_df` parameters need to have the same values in both "
"`.extract_embeddings` and `.extract_keywords`.")
raise ValueError(
"Make sure that the `word_embeddings` are generated from the function "
"`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
"`stop_words`, and `min_df` parameters need to have the same values in both "
"`.extract_embeddings` and `.extract_keywords`."
)

# Extract embeddings
if doc_embeddings is None:
Expand All @@ -192,15 +199,21 @@ def extract_keywords(
# Guided KeyBERT either local (keywords shared among documents) or global (keywords per document)
if seed_keywords is not None:
if isinstance(seed_keywords[0], str):
seed_embeddings = self.model.embed(seed_keywords).mean(axis=0, keepdims=True)
seed_embeddings = self.model.embed(seed_keywords).mean(
axis=0, keepdims=True
)
elif len(docs) != len(seed_keywords):
raise ValueError("The length of docs must match the length of seed_keywords")
raise ValueError(
"The length of docs must match the length of seed_keywords"
)
else:
seed_embeddings = np.vstack([
self.model.embed(keywords).mean(axis=0, keepdims=True)
for keywords in seed_keywords
])
doc_embeddings = ((doc_embeddings * 3 + seed_embeddings) / 4)
seed_embeddings = np.vstack(
[
self.model.embed(keywords).mean(axis=0, keepdims=True)
for keywords in seed_keywords
]
)
doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4

# Find keywords
all_keywords = []
Expand Down Expand Up @@ -256,18 +269,21 @@ def extract_keywords(
# Fine-tune keywords using an LLM
if self.llm is not None:
import torch

doc_embeddings = torch.from_numpy(doc_embeddings).float()
if torch.cuda.is_available():
doc_embeddings = doc_embeddings.to("cuda")
if isinstance(all_keywords[0], tuple):
candidate_keywords = [[keyword[0] for keyword in all_keywords]]
else:
candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
candidate_keywords = [
[keyword[0] for keyword in keywords] for keywords in all_keywords
]
keywords = self.llm.extract_keywords(
docs,
embeddings=doc_embeddings,
candidate_keywords=candidate_keywords,
threshold=threshold
threshold=threshold,
)
return keywords
return all_keywords
Expand All @@ -279,7 +295,7 @@ def extract_embeddings(
keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = "english",
min_df: int = 1,
vectorizer: CountVectorizer = None
vectorizer: CountVectorizer = None,
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract document and word embeddings for the input documents and the
generated candidate keywords/keyphrases respectively.
Expand Down
3 changes: 2 additions & 1 deletion keybert/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._base import BaseEmbedder
from ._sentencetransformers import SentenceTransformerBackend

__all__ = ["BaseEmbedder"]
__all__ = ["BaseEmbedder", "SentenceTransformerBackend"]
9 changes: 7 additions & 2 deletions keybert/backend/_sentencetransformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class SentenceTransformerBackend(BaseEmbedder):
Arguments:
embedding_model: A sentence-transformers embedding model
encode_kwargs: Additional parameters for the SentenceTransformers.encode() method
Usage:
Expand All @@ -33,7 +34,9 @@ class SentenceTransformerBackend(BaseEmbedder):
```
"""

def __init__(self, embedding_model: Union[str, SentenceTransformer]):
def __init__(
self, embedding_model: Union[str, SentenceTransformer], **encode_kwargs
):
super().__init__()

if isinstance(embedding_model, SentenceTransformer):
Expand All @@ -46,6 +49,7 @@ def __init__(self, embedding_model: Union[str, SentenceTransformer]):
"`from sentence_transformers import SentenceTransformer` \n"
"`model = SentenceTransformer('all-MiniLM-L6-v2')`"
)
self.encode_kwargs = encode_kwargs

def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
"""Embed a list of n documents/words into an n-dimensional
Expand All @@ -59,5 +63,6 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
Document/words embeddings with shape (n, m) with `n` documents/words
that each have an embeddings size of `m`
"""
embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
self.encode_kwargs.update({"show_progress_bar": verbose})
embeddings = self.embedding_model.encode(documents, **self.encode_kwargs)
return embeddings
41 changes: 41 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
from keybert import KeyBERT
from keybert.backend import SentenceTransformerBackend
import sentence_transformers

from sklearn.feature_extraction.text import CountVectorizer
from .utils import get_test_data


doc_one, doc_two = get_test_data()


@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
@pytest.mark.parametrize(
"vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
)
def test_single_doc_sentence_transformer_backend(keyphrase_length, vectorizer):
"""Test whether the keywords are correctly extracted"""
top_n = 5

model_name = "paraphrase-MiniLM-L6-v2"
st_model = sentence_transformers.SentenceTransformer(model_name, device="cpu")

kb_model = KeyBERT(model=SentenceTransformerBackend(st_model, batch_size=128))

keywords = kb_model.extract_keywords(
doc_one,
keyphrase_ngram_range=keyphrase_length,
min_df=1,
top_n=top_n,
vectorizer=vectorizer,
)

assert model_name in kb_model.model.embedding_model.tokenizer.name_or_path
assert isinstance(keywords, list)
assert isinstance(keywords[0], tuple)
assert isinstance(keywords[0][0], str)
assert isinstance(keywords[0][1], float)
assert len(keywords) == top_n
for keyword in keywords:
assert len(keyword[0].split(" ")) <= keyphrase_length[1]

0 comments on commit dcf31dd

Please sign in to comment.