Enable Guided KeyBERT for both local and global seed keywords (#152)
shengbo-ma committed Jan 18, 2023
1 parent 7b763ae commit f91c502
Showing 4 changed files with 63 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -15,7 +15,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.8]
        python-version: [3.7, 3.8]

    steps:
      - uses: actions/checkout@v2
3 changes: 3 additions & 0 deletions .gitignore
@@ -76,3 +76,6 @@ venv.bak/
.idea
.idea/
.vscode

# macOS
.DS_Store
25 changes: 17 additions & 8 deletions keybert/_model.py
@@ -68,7 +68,7 @@ def extract_keywords(
        nr_candidates: int = 20,
        vectorizer: CountVectorizer = None,
        highlight: bool = False,
        seed_keywords: List[str] = None,
        seed_keywords: Union[List[str], List[List[str]]] = None,
        doc_embeddings: np.array = None,
        word_embeddings: np.array = None,
    ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
@@ -103,6 +103,10 @@ def extract_keywords(
                NOTE: This does not work if multiple documents are passed.
        seed_keywords: Seed keywords that may guide the extraction of keywords by
            steering the similarities towards the seeded keywords.
            NOTE: When multiple documents are passed,
            `seed_keywords` functions in either of two ways:
            - globally: when a flat list of str is passed, keywords are shared by all documents,
            - locally: when a nested list of str is passed, keywords differ per document.
        doc_embeddings: The embeddings of each document.
        word_embeddings: The embeddings of each potential keyword/keyphrase across
            the vocabulary of the set of input documents.
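
For reference, the two modes described above map directly onto the public API. A minimal usage sketch (illustrative, not part of this commit), assuming a default KeyBERT model and two toy documents:

    from keybert import KeyBERT

    kw_model = KeyBERT()
    docs = [
        "I spent the whole night wondering where the time goes.",
        "We watched the hockey games on tv all afternoon.",
    ]

    # Global: one flat list of seed keywords shared by every document.
    keywords = kw_model.extract_keywords(
        docs, seed_keywords=["time", "night", "day", "moment"]
    )

    # Local: a nested list supplies one seed list per document;
    # its length must match the number of documents.
    keywords = kw_model.extract_keywords(
        docs,
        seed_keywords=[
            ["time", "night", "day", "moment"],
            ["hockey", "games", "afternoon", "tv"],
        ],
    )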
@@ -176,8 +180,19 @@ def extract_keywords(
            doc_embeddings = self.model.embed(docs)
        if word_embeddings is None:
            word_embeddings = self.model.embed(words)

        # Guided KeyBERT, either global (keywords shared among all documents) or local (keywords per document)
        if seed_keywords is not None:
            seed_embeddings = self.model.embed([" ".join(seed_keywords)])
            if isinstance(seed_keywords[0], str):
                seed_embeddings = self.model.embed(seed_keywords).mean(axis=0, keepdims=True)
            elif len(docs) != len(seed_keywords):
                raise ValueError("The length of docs must match the length of seed_keywords")
            else:
                seed_embeddings = np.vstack([
                    self.model.embed(keywords).mean(axis=0, keepdims=True)
                    for keywords in seed_keywords
                ])
            doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4

        # Find keywords
        all_keywords = []
@@ -190,12 +205,6 @@
            candidate_embeddings = word_embeddings[candidate_indices]
            doc_embedding = doc_embeddings[index].reshape(1, -1)

            # Guided KeyBERT with seed keywords
            if seed_keywords is not None:
                doc_embedding = np.average(
                    [doc_embedding, seed_embeddings], axis=0, weights=[3, 1]
                )

            # Maximal Marginal Relevance (MMR)
            if use_mmr:
                keywords = mmr(
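The new vectorized branch replaces the removed per-document np.average call with an equivalent 3:1 weighted mean, (doc_embeddings * 3 + seed_embeddings) / 4. A standalone numpy sketch of the shape logic (illustrative only; embed is a hypothetical stand-in for the embedding model):

    import numpy as np

    def embed(texts, dim=4):
        # Hypothetical stand-in: one vector per input string.
        rng = np.random.default_rng(len(texts))
        return rng.random((len(texts), dim))

    doc_embeddings = embed(["doc one", "doc two"])  # shape (2, dim)

    # Global seeds pool into a single (1, dim) row that broadcasts over all docs.
    seed_embeddings = embed(["time", "night", "day", "moment"]).mean(axis=0, keepdims=True)

    # Local seeds pool into one row per document, stacked to (2, dim).
    seed_embeddings = np.vstack(
        [embed(kws).mean(axis=0, keepdims=True)
         for kws in [["time", "night"], ["hockey", "games"]]]
    )

    # 3:1 weighted mean, equivalent to the removed
    # np.average([doc, seed], axis=0, weights=[3, 1]) inside the keyword loop.
    doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4

Note that each seed list is now embedded keyword by keyword and mean-pooled, rather than joined into one string before embedding, so every seed keyword contributes equally to the pooled vector.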
44 changes: 42 additions & 2 deletions tests/test_model.py
@@ -103,21 +103,61 @@ def test_extract_keywords_multiple_docs(keyphrase_length, candidates):
    assert keywords_list[0][0][0] == candidates[0]
    assert len(keywords_list[1]) == 0


def test_guided():
    """Test whether the keywords are correctly extracted"""

    # a single doc + a list of seed keywords
    top_n = 5
    seed_keywords = ["time", "night", "day", "moment"]
    keywords = model.extract_keywords(
        doc_one, min_df=1, top_n=top_n, seed_keywords=seed_keywords
    )

    assert isinstance(keywords, list)
    assert isinstance(keywords[0], tuple)
    assert isinstance(keywords[0][0], str)
    assert isinstance(keywords[0][1], float)
    assert len(keywords) == top_n

    # a batch of docs sharing a single list of seed keywords
    top_n = 5
    list_of_docs = [doc_one, doc_two]
    list_of_seed_keywords = ["time", "night", "day", "moment"]
    keywords = model.extract_keywords(
        list_of_docs,
        min_df=1,
        top_n=top_n,
        seed_keywords=list_of_seed_keywords
    )

    assert isinstance(keywords, list)
    assert isinstance(keywords[0], list)
    assert isinstance(keywords[0][0], tuple)
    assert isinstance(keywords[0][0][0], str)
    assert isinstance(keywords[0][0][1], float)
    assert len(keywords[0]) == top_n

    # a batch of docs, each of which has its own seed keywords
    top_n = 5
    list_of_docs = [doc_one, doc_two]
    list_of_seed_keywords = [
        ["time", "night", "day", "moment"],
        ["hockey", "games", "afternoon", "tv"]
    ]
    keywords = model.extract_keywords(
        list_of_docs,
        min_df=1,
        top_n=top_n,
        seed_keywords=list_of_seed_keywords
    )

    assert isinstance(keywords, list)
    assert isinstance(keywords[0], list)
    assert isinstance(keywords[0][0], tuple)
    assert isinstance(keywords[0][0][0], str)
    assert isinstance(keywords[0][0][1], float)
    assert len(keywords[0]) == top_n

def test_empty_doc():
    """Test empty document"""
