diff --git a/README.md b/README.md
index 7f327ed0..c7686ea5 100644
--- a/README.md
+++ b/README.md
@@ -196,6 +196,10 @@ topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=6)
 
 ## Overview
+BERTopic has quite a number of functions, which can quickly become overwhelming. To alleviate this issue, below you will find an overview
+of all methods and a short description of their purpose.
+
+### Common
 For quick access to common functions, here is an overview of BERTopic's main methods:
 
 | Method | Code  |
@@ -208,21 +212,40 @@ For quick access to common functions, here is an overview of BERTopic's main met
 | Get topic freq | `.get_topic_freq()` |
 | Get all topic information| `.get_topic_info()` |
 | Get representative docs per topic | `.get_representative_docs()` |
-| Get topics per class | `.topics_per_class(docs, topics, classes)` |
-| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` |
 | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` |
+| Generate topic labels | `.generate_topic_labels()` |
+| Set topic labels | `.set_topic_labels(my_custom_labels)` |
+| Merge topics | `.merge_topics(docs, topics, topics_to_merge)` |
 | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` |
 | Find topics | `.find_topics("vehicle")` |
 | Save model | `.save("my_model")` |
 | Load model | `BERTopic.load("my_model")` |
 | Get parameters | `.get_params()` |
 
-For an overview of BERTopic's visualization methods:
+### Variations
+There are many different use cases in which topic modeling can be used. As such, a number of
+variations of BERTopic have been developed such that one package can be used across many use cases:
+
+| Method | Code  |
+|-----------------------|---|
+| (semi-) Supervised Topic Modeling | `.fit(docs, y=y)` |
+| Topic Modeling per Class | `.topics_per_class(docs, topics, classes)` |
+| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` |
+| Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` |
+| Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` |
+
+### Visualizations
+Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation.
+Visualizing different aspects of the topic model helps in understanding the model and makes it easier
+to tweak the model to your liking.
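+
+As a quick sketch of how these pieces fit together (a minimal example; `docs` is assumed to be
+any list of strings you would like to model):
+
+```python
+from bertopic import BERTopic
+
+topic_model = BERTopic()
+topics, probs = topic_model.fit_transform(docs)
+
+# Inspect the resulting topics visually before tweaking the model
+topic_model.visualize_topics()
+topic_model.visualize_barchart()
+```
+
+An overview of all visualization methods: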
 | Method | Code  |
 |-----------------------|---|
 | Visualize Topics | `.visualize_topics()` |
+| Visualize Documents | `.visualize_documents()` |
+| Visualize Document Hierarchy | `.visualize_hierarchical_documents()` |
 | Visualize Topic Hierarchy | `.visualize_hierarchy()` |
+| Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` |
 | Visualize Topic Terms | `.visualize_barchart()` |
 | Visualize Topic Similarity | `.visualize_heatmap()` |
 | Visualize Term Score Decline | `.visualize_term_rank()` |
diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index bdcc9ace..7970d657 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -14,8 +14,9 @@ import numpy as np
 import pandas as pd
 from tqdm import tqdm
-from scipy.sparse.csr import csr_matrix
-from typing import List, Tuple, Union, Mapping, Any
+from scipy.sparse import csr_matrix
+from scipy.cluster import hierarchy as sch
+from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
 
 # Models
 import hdbscan
@@ -176,12 +177,13 @@ def __init__(self,
                                             cluster_selection_method='eom',
                                             prediction_data=True)
 
+        # Attributes
         self.topics = None
         self.topic_mapper = None
         self.topic_sizes = None
         self.merged_topics = None
+        self.custom_labels = None
         self.topic_embeddings = None
-        self.topic_sim_matrix = None
         self.representative_docs = None
         self._outliers = 1
 
@@ -616,6 +618,138 @@ def topics_per_class(self,
 
         return topics_per_class
 
+    def hierarchical_topics(self,
+                            docs: List[str],
+                            topics: List[int],
+                            linkage_function: Callable[[csr_matrix], np.ndarray] = None,
+                            distance_function: Callable[[csr_matrix], csr_matrix] = None) -> pd.DataFrame:
+        """ Create a hierarchy of topics
+
+        To create this hierarchy, BERTopic first needs to be fitted.
+        Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF
+        representation using `scipy.cluster.hierarchy.linkage`.
+
+        Based on that hierarchy, we calculate the topic representation at each
+        merged step. This is a local representation, as we only assume that the
+        chosen step is merged and not all others, which typically improves the
+        topic representation.
+
+        Arguments:
+            docs: The documents you used when calling either `fit` or `fit_transform`
+            topics: The topics that were returned when calling either `fit` or `fit_transform`
+            linkage_function: The linkage function to use. Default is:
+                              `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+            distance_function: The distance function to use on the c-TF-IDF matrix.
+                               Default is:
+                               `lambda x: 1 - cosine_similarity(x)`
+
+        Returns:
+            hierarchical_topics: A dataframe that contains a hierarchy of topics
+                                 represented by their parents and their children
+
+        Usage:
+
+        ```python
+        from bertopic import BERTopic
+        topic_model = BERTopic()
+        topics, probs = topic_model.fit_transform(docs)
+        hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+        ```
+
+        A custom linkage function can be used as follows:
+
+        ```python
+        from scipy.cluster import hierarchy as sch
+        from bertopic import BERTopic
+        topic_model = BERTopic()
+        topics, probs = topic_model.fit_transform(docs)
+
+        # Hierarchical topics
+        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
+        hierarchical_topics = topic_model.hierarchical_topics(docs, topics, linkage_function=linkage_function)
+        ```
+        """
+        if distance_function is None:
+            distance_function = lambda x: 1 - cosine_similarity(x)
+
+        if linkage_function is None:
+            linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
+
+        # Calculate linkage
+        embeddings = self.c_tf_idf[self._outliers:]
+        X = distance_function(embeddings)
+        Z = linkage_function(X)
+
+        # Calculate basic bag-of-words to be iteratively merged later
+        documents = pd.DataFrame({"Document": docs,
+                                  "ID": range(len(docs)),
+                                  "Topic": topics})
+        documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
+        documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :]
+        documents = self._preprocess_text(documents_per_topic.Document.values)
+        words = self.vectorizer_model.get_feature_names()
+        bow = self.vectorizer_model.transform(documents)
+
+        # Extract clusters
+        hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics",
+                                            "Child_Left_ID", "Child_Left_Name",
+                                            "Child_Right_ID", "Child_Right_Name"])
+        for index in tqdm(range(len(Z))):
+
+            # Find clustered documents
+            clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers
+            cluster_df = pd.DataFrame({"Topic": range(len(clusters)), "Cluster": clusters})
+            cluster_df = cluster_df.groupby("Cluster").agg({'Topic': lambda x: list(x)}).reset_index()
+            nr_clusters = len(clusters)
+
+            # Extract first topic we find to get the set of topics in a merged topic
+            topic = None
+            val = Z[index][0]
+            while topic is None:
+                if val - len(clusters) < 0:
+                    topic = int(val)
+                else:
+                    val = Z[int(val - len(clusters))][0]
+            clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]]
+
+            # Group bow per cluster, calculate c-TF-IDF and extract words
+            grouped = csr_matrix(bow[clustered_topics].sum(axis=0))
+            c_tf_idf = self.transformer.transform(grouped)
+            words_per_topic = self._extract_words_per_topic(words, c_tf_idf, labels=[0])
+
+            # Extract parent's name and ID
+            parent_id = index + len(clusters)
+            parent_name = "_".join([x[0] for x in words_per_topic[0]][:5])
+
+            # Extract the left child's name and ID
+            Z_id = Z[index][0]
+            child_left_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
+
+            if Z_id - nr_clusters < 0:
+                child_left_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5])
+            else:
+                child_left_name = hier_topics.iloc[int(child_left_id)].Parent_Name
+
+            # Extract the right child's name and ID
+            Z_id = Z[index][1]
+            child_right_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
+
+            if Z_id - nr_clusters < 0:
+                child_right_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5])
+            else:
+                child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name
+
+            # Save results
+            hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name,
+                                                    clustered_topics,
+                                                    int(Z[index][0]), child_left_name,
+                                                    int(Z[index][1]), child_right_name]
+
+        hier_topics["Distance"] = Z[:, 2]
+        hier_topics = hier_topics.sort_values("Parent_ID", ascending=False)
+        hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]].astype(str)
+
+        return hier_topics
+
     def find_topics(self,
                     search_term: str,
                     top_n: int = 5) -> Tuple[List[int], List[float]]:
@@ -750,7 +884,7 @@ def get_topic(self, topic: int) -> Union[Mapping[str, Tuple[str, float]], bool]:
         return False
 
     def get_topic_info(self, topic: int = None) -> pd.DataFrame:
-        """ Get information about each topic including its id, frequency, and name
+        """ Get information about each topic including its ID, frequency, and name.
 
         Arguments:
             topic: A specific topic for which you want the frequency
@@ -766,13 +900,18 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame:
         """
         check_is_fitted(self)
 
-        info = pd.DataFrame(self.topic_sizes.items(), columns=['Topic', 'Count']).sort_values("Count", ascending=False)
+        info = pd.DataFrame(self.topic_sizes.items(), columns=["Topic", "Count"]).sort_values("Topic")
         info["Name"] = info.Topic.map(self.topic_names)
 
+        if self.custom_labels is not None:
+            if len(self.custom_labels) == len(info):
+                labels = {topic - self._outliers: label for topic, label in enumerate(self.custom_labels)}
+                info["CustomName"] = info["Topic"].map(labels)
+
         if topic:
             info = info.loc[info.Topic == topic, :]
 
-        return info
+        return info.reset_index(drop=True)
 
     def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]:
         """ Return the the size of topics (descending order)
@@ -835,6 +974,273 @@ def get_representative_docs(self, topic: int = None) -> List[str]:
         else:
             return self.representative_docs
 
+    @staticmethod
+    def get_topic_tree(hier_topics: pd.DataFrame,
+                       max_distance: float = None,
+                       tight_layout: bool = False) -> str:
+        """ Extract the topic tree such that it can be printed
+
+        Arguments:
+            hier_topics: A dataframe containing the structure of the topic tree.
+                         This is the output of `topic_model.hierarchical_topics()`
+            max_distance: The maximum distance between two topics. This value is
+                          based on the Distance column in `hier_topics`.
+            tight_layout: Whether to use a tight layout (narrow width) for
+                          easier readability if you have hundreds of topics.
+
+        Returns:
+            A tree that has the following structure when printed:
+              .
+              .
+              └─health_medical_disease_patients_hiv
+                   ├─patients_medical_disease_candida_health
+                   │    ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48
+                   │    └─patients_disease_cancer_medical_doctor
+                   │         ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34
+                   │         └─■──pain_drug_patients_disease_diet ── Topic: 26
+                   └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9
+
+            The blocks (■) indicate that the topic is one you can directly access
+            from `topic_model.get_topic`. In other words, they are the original un-grouped topics.
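+
+        Topics that were merged at a distance of `max_distance` or larger are not
+        expanded but printed as a single `O` node. For example (the value `0.5` is
+        an arbitrary, illustrative threshold; pick one based on the `Distance`
+        column of your own hierarchy):
+
+        ```python
+        tree = topic_model.get_topic_tree(hierarchical_topics, max_distance=0.5)
+        print(tree)
+        ```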
+
+        Usage:
+
+        ```python
+        # Train model
+        from bertopic import BERTopic
+        topic_model = BERTopic()
+        topics, probs = topic_model.fit_transform(docs)
+        hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+
+        # Print topic tree
+        tree = topic_model.get_topic_tree(hierarchical_topics)
+        print(tree)
+        ```
+        """
+        width = 1 if tight_layout else 4
+        if max_distance is None:
+            max_distance = hier_topics.Distance.max() + 1
+
+        max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1
+
+        # Extract mapping from ID to name
+        topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name))
+        topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name)))
+        topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()}
+
+        # Create tree
+        tree = {str(row[1].Parent_ID): [str(row[1].Child_Left_ID), str(row[1].Child_Right_ID)]
+                for row in hier_topics.iterrows()}
+
+        def get_tree(start, tree):
+            """ Based on: https://stackoverflow.com/a/51920869/10532563 """
+
+            def _tree(to_print, start, parent, tree, grandpa=None, indent=""):
+
+                # Get distance between merged topics
+                distance = hier_topics.loc[(hier_topics.Child_Left_ID == parent) |
+                                           (hier_topics.Child_Right_ID == parent), "Distance"]
+                distance = distance.values[0] if len(distance) > 0 else 10
+
+                if parent != start:
+                    if grandpa is None:
+                        to_print += topic_to_name[parent]
+                    else:
+                        if int(parent) <= max_original_topic:
+
+                            # Do not append topic ID if they are not merged
+                            if distance < max_distance:
+                                to_print += "■──" + topic_to_name[parent] + f" ── Topic: {parent}" + "\n"
+                            else:
+                                to_print += "O \n"
+                        else:
+                            to_print += topic_to_name[parent] + "\n"
+
+                if parent not in tree:
+                    return to_print
+
+                for child in tree[parent][:-1]:
+                    to_print += indent + "├" + "─"
+                    to_print = _tree(to_print, start, child, tree, parent, indent + "│" + " " * width)
+
+                child = tree[parent][-1]
+                to_print += indent + "└" + "─"
+                to_print = _tree(to_print, start, child, tree, parent, indent + " " * (width+1))
+
+                return to_print
+
+            to_print = "." + "\n"
+            to_print = _tree(to_print, start, start, tree)
+            return to_print
+
+        start = str(hier_topics.Parent_ID.astype(int).max())
+        return get_tree(start, tree)
+
+    def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> None:
+        """ Set custom topic labels in your fitted BERTopic model
+
+        Arguments:
+            topic_labels: If a list of topic labels, it should contain the same number
+                          of labels as there are topics. This must be ordered
+                          from the topic with the lowest ID to the highest ID,
+                          including topic -1 if it exists.
+                          If a dictionary of `topic ID`: `topic_label`, it can have
+                          any number of topics as it will only map the topics found
+                          in the dictionary.
+
+        Usage:
+
+        First, we generate our topic labels with `.generate_topic_labels`, in which
+        we can customize our topic labels:
+
+        ```python
+        topic_labels = topic_model.generate_topic_labels(nr_words=2,
+                                                         topic_prefix=True,
+                                                         word_length=10,
+                                                         separator=", ")
+        ```
+
+        Then, we pass these `topic_labels` to our topic model, which
+        can be accessed at any time with `.custom_labels`:
+
+        ```python
+        topic_model.set_topic_labels(topic_labels)
+        topic_model.custom_labels
+        ```
+
+        You might want to change only a few topic labels instead of all of them.
+
+        To do so, you can pass a dictionary where the keys are the topic IDs and
+        the values are the topic labels:
+
+        ```python
+        topic_model.set_topic_labels({0: "Space", 1: "Sports", 2: "Medicine"})
+        topic_model.custom_labels
+        ```
+        """
+        unique_topics = sorted(set(self._map_predictions(self.hdbscan_model.labels_)))
+
+        if isinstance(topic_labels, dict):
+            if self.custom_labels is not None:
+                original_labels = {topic: label for topic, label in zip(unique_topics, self.custom_labels)}
+            else:
+                info = self.get_topic_info()
+                original_labels = dict(zip(info.Topic, info.Name))
+            custom_labels = [topic_labels.get(topic) if topic_labels.get(topic) else original_labels[topic] for topic in unique_topics]
+
+        elif isinstance(topic_labels, list):
+            if len(topic_labels) == len(unique_topics):
+                custom_labels = topic_labels
+            else:
+                raise ValueError("Make sure that `topic_labels` contains the same number "
+                                 "of labels as there are topics.")
+
+        self.custom_labels = custom_labels
+
+    def generate_topic_labels(self,
+                              nr_words: int = 3,
+                              topic_prefix: bool = True,
+                              word_length: int = None,
+                              separator: str = "_") -> List[str]:
+        """ Get labels for each topic in a user-defined format
+
+        Arguments:
+            nr_words: Top `n` words per topic to use
+            topic_prefix: Whether to use the topic ID as a prefix.
+                          If set to True, the topic ID will be separated
+                          using the `separator`
+            word_length: The maximum length of each word in the topic label.
+                         Some words might be relatively long and setting this
+                         value helps to make sure that all labels have relatively
+                         similar lengths.
+            separator: The string with which the words and topic prefix will be
+                       separated. Underscores are the default but a nice alternative
+                       is `", "`.
+
+        Returns:
+            topic_labels: A list of topic labels sorted from the lowest topic ID to the highest.
+                          If the topic model was trained using HDBSCAN, the lowest topic ID is -1,
+                          otherwise it is 0.
+
+        Usage:
+
+        To create our custom topic labels, usage is rather straightforward:
+
+        ```python
+        topic_labels = topic_model.generate_topic_labels(nr_words=2, separator=", ")
+        ```
+        """
+        unique_topics = sorted(set(self._map_predictions(self.hdbscan_model.labels_)))
+
+        topic_labels = []
+        for topic in unique_topics:
+            words, _ = zip(*self.get_topic(topic))
+
+            if word_length:
+                words = [word[:word_length] for word in words][:nr_words]
+            else:
+                words = list(words)[:nr_words]
+
+            if topic_prefix:
+                topic_label = f"{topic}{separator}" + separator.join(words)
+            else:
+                topic_label = separator.join(words)
+
+            topic_labels.append(topic_label)
+
+        return topic_labels
+
+    def merge_topics(self,
+                     docs: List[str],
+                     topics: List[int],
+                     topics_to_merge: List[Union[Iterable[int], int]]) -> None:
+        """ Merge two or more topics into a single topic
+
+        Arguments:
+            docs: The documents you used when calling either `fit` or `fit_transform`
+            topics: The topics that were returned when calling either `fit` or `fit_transform`
+            topics_to_merge: Either a list of topics or a list of lists of topics
+                             to merge. For example:
+                                [1, 2, 3] will merge topics 1, 2 and 3
+                                [[1, 2], [3, 4]] will merge topics 1 and 2, and
+                                separately merge topics 3 and 4.
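+
+        NOTE: After merging, the c-TF-IDF representations and the sizes of the
+              newly merged topics are re-calculated from their assigned documents.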
+
+        Usage:
+
+        If you want to merge topics 1, 2, and 3:
+
+        ```python
+        topics_to_merge = [1, 2, 3]
+        topic_model.merge_topics(docs, topics, topics_to_merge)
+        ```
+
+        or if you want to merge topics 1 and 2, and separately
+        merge topics 3 and 4:
+
+        ```python
+        topics_to_merge = [[1, 2],
+                           [3, 4]]
+        topic_model.merge_topics(docs, topics, topics_to_merge)
+        ```
+        """
+        check_is_fitted(self)
+        documents = pd.DataFrame({"Document": docs, "Topic": topics})
+
+        mapping = {topic: topic for topic in set(topics)}
+        if isinstance(topics_to_merge[0], int):
+            for topic in sorted(topics_to_merge):
+                mapping[topic] = topics_to_merge[0]
+        elif isinstance(topics_to_merge[0], Iterable):
+            for topic_group in sorted(topics_to_merge):
+                for topic in topic_group:
+                    mapping[topic] = topic_group[0]
+        else:
+            raise ValueError("Make sure that `topics_to_merge` is either "
+                             "a list of topics or a list of lists of topics.")
+
+        documents.Topic = documents.Topic.map(mapping)
+        documents = self._sort_mappings_by_frequency(documents)
+        self._extract_topics(documents)
+        self._update_topic_size(documents)
+
     def reduce_topics(self,
                       docs: List[str],
                       topics: List[int],
@@ -930,9 +1336,208 @@ def visualize_topics(self,
                                         width=width,
                                         height=height)
 
+    def visualize_documents(self,
+                            docs: List[str],
+                            topics: List[int] = None,
+                            embeddings: np.ndarray = None,
+                            reduced_embeddings: np.ndarray = None,
+                            sample: float = None,
+                            hide_annotations: bool = False,
+                            hide_document_hover: bool = False,
+                            custom_labels: bool = False,
+                            width: int = 1200,
+                            height: int = 750) -> go.Figure:
+        """ Visualize documents and their topics in 2D
+
+        Arguments:
+            docs: The documents you used when calling either `fit` or `fit_transform`
+            topics: A selection of topics to visualize.
+                    Not to be confused with the topics that you get from `.fit_transform`.
+                    For example, if you want to visualize only topics 1 through 5:
+                    `topics = [1, 2, 3, 4, 5]`.
+            embeddings: The embeddings of all documents in `docs`.
+            reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+            sample: The percentage of documents in each topic that you would like to keep.
+                    Value can be between 0 and 1. Setting this value to, for example,
+                    0.1 (10% of documents in each topic) makes it easier to visualize
+                    millions of documents as a subset is chosen.
+            hide_annotations: Hide the names of the traces on top of each cluster.
+            hide_document_hover: Hide the content of the documents when hovering over
+                                 specific points. Helps to speed up generation of the visualization.
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
+            width: The width of the figure.
+            height: The height of the figure.
+
+        Usage:
+
+        To visualize the topics simply run:
+
+        ```python
+        topic_model.visualize_documents(docs)
+        ```
+
+        Do note that this re-calculates the embeddings and reduces them to 2D.
+        The advised and preferred pipeline for using this function is as follows:
+
+        ```python
+        from sklearn.datasets import fetch_20newsgroups
+        from sentence_transformers import SentenceTransformer
+        from bertopic import BERTopic
+        from umap import UMAP
+
+        # Prepare embeddings
+        docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+        sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+        embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+        # Train BERTopic
+        topic_model = BERTopic().fit(docs, embeddings)
+
+        # Reduce dimensionality of embeddings, this step is optional
+        # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+        # Run the visualization with the original embeddings
+        topic_model.visualize_documents(docs, embeddings=embeddings)
+
+        # Or, if you have reduced the original embeddings already:
+        topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+        ```
+
+        Or if you want to save the resulting figure:
+
+        ```python
+        fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+        fig.write_html("path/to/file.html")
+        ```
+
+
+        """
+        check_is_fitted(self)
+        return plotting.visualize_documents(self,
+                                            docs=docs,
+                                            topics=topics,
+                                            embeddings=embeddings,
+                                            reduced_embeddings=reduced_embeddings,
+                                            sample=sample,
+                                            hide_annotations=hide_annotations,
+                                            hide_document_hover=hide_document_hover,
+                                            custom_labels=custom_labels,
+                                            width=width,
+                                            height=height)
+
+    def visualize_hierarchical_documents(self,
+                                         docs: List[str],
+                                         hierarchical_topics: pd.DataFrame,
+                                         topics: List[int] = None,
+                                         embeddings: np.ndarray = None,
+                                         reduced_embeddings: np.ndarray = None,
+                                         sample: Union[float, int] = None,
+                                         hide_annotations: bool = False,
+                                         hide_document_hover: bool = True,
+                                         nr_levels: int = 10,
+                                         custom_labels: bool = False,
+                                         width: int = 1200,
+                                         height: int = 750) -> go.Figure:
+        """ Visualize documents and their topics in 2D at different levels of hierarchy
+
+        Arguments:
+            docs: The documents you used when calling either `fit` or `fit_transform`
+            hierarchical_topics: A dataframe that contains a hierarchy of topics
+                                 represented by their parents and their children
+            topics: A selection of topics to visualize.
+                    Not to be confused with the topics that you get from `.fit_transform`.
+                    For example, if you want to visualize only topics 1 through 5:
+                    `topics = [1, 2, 3, 4, 5]`.
+            embeddings: The embeddings of all documents in `docs`.
+            reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+            sample: The percentage of documents in each topic that you would like to keep.
+                    Value can be between 0 and 1. Setting this value to, for example,
+                    0.1 (10% of documents in each topic) makes it easier to visualize
+                    millions of documents as a subset is chosen.
+            hide_annotations: Hide the names of the traces on top of each cluster.
+            hide_document_hover: Hide the content of the documents when hovering over
+                                 specific points. Helps to speed up generation of visualizations.
+            nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
+                       in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances with
+                       equal length. Then, for each list of distances, the merged topics are selected that
+                       have a distance less than or equal to the maximum distance of the selected list of distances.
+                       NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
+                             the length of `hierarchical_topics`.
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
+                           NOTE: Custom labels are only generated for the original
+                           un-merged topics.
+            width: The width of the figure.
+            height: The height of the figure.
+
+        Usage:
+
+        To visualize the topics simply run:
+
+        ```python
+        topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
+        ```
+
+        Do note that this re-calculates the embeddings and reduces them to 2D.
+        The advised and preferred pipeline for using this function is as follows:
+
+        ```python
+        from sklearn.datasets import fetch_20newsgroups
+        from sentence_transformers import SentenceTransformer
+        from bertopic import BERTopic
+        from umap import UMAP
+
+        # Prepare embeddings
+        docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+        sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+        embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+        # Train BERTopic and extract hierarchical topics
+        topic_model = BERTopic()
+        topics, probs = topic_model.fit_transform(docs, embeddings)
+        hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+
+        # Reduce dimensionality of embeddings, this step is optional
+        # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+        # Run the visualization with the original embeddings
+        topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+        # Or, if you have reduced the original embeddings already:
+        topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+        ```
+
+        Or if you want to save the resulting figure:
+
+        ```python
+        fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+        fig.write_html("path/to/file.html")
+        ```
+
+
+        """
+        check_is_fitted(self)
+        return plotting.visualize_hierarchical_documents(self,
+                                                         docs=docs,
+                                                         hierarchical_topics=hierarchical_topics,
+                                                         topics=topics,
+                                                         embeddings=embeddings,
+                                                         reduced_embeddings=reduced_embeddings,
+                                                         sample=sample,
+                                                         hide_annotations=hide_annotations,
+                                                         hide_document_hover=hide_document_hover,
+                                                         nr_levels=nr_levels,
+                                                         custom_labels=custom_labels,
+                                                         width=width,
+                                                         height=height)
+
     def visualize_term_rank(self,
                             topics: List[int] = None,
                             log_scale: bool = False,
+                            custom_labels: bool = False,
                             width: int = 800,
                             height: int = 500) -> go.Figure:
         """ Visualize the ranks of all terms across all topics
@@ -946,6 +1551,8 @@ def visualize_term_rank(self,
             topics: A selection of topics to visualize. These will be colored
                     red where all others will be colored black.
             log_scale: Whether to represent the ranking on a log scale
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
             width: The width of the figure.
             height: The height of the figure.
@@ -980,6 +1587,7 @@ def visualize_term_rank(self,
         return plotting.visualize_term_rank(self,
                                             topics=topics,
                                             log_scale=log_scale,
+                                            custom_labels=custom_labels,
                                             width=width,
                                             height=height)
 
@@ -988,6 +1596,7 @@ def visualize_topics_over_time(self,
                                   top_n_topics: int = None,
                                   topics: List[int] = None,
                                   normalize_frequency: bool = False,
+                                  custom_labels: bool = False,
                                   width: int = 1250,
                                   height: int = 450) -> go.Figure:
         """ Visualize topics over time
@@ -998,6 +1607,8 @@ def visualize_topics_over_time(self,
             top_n_topics: To visualize the most frequent topics instead of all
             topics: Select which topics you would like to be visualized
             normalize_frequency: Whether to normalize each topic's frequency individually
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
             width: The width of the figure.
             height: The height of the figure.
 
@@ -1026,6 +1637,7 @@ def visualize_topics_over_time(self,
                                                    top_n_topics=top_n_topics,
                                                    topics=topics,
                                                    normalize_frequency=normalize_frequency,
+                                                   custom_labels=custom_labels,
                                                    width=width,
                                                    height=height)
 
@@ -1034,6 +1646,7 @@ def visualize_topics_per_class(self,
                                   top_n_topics: int = 10,
                                   topics: List[int] = None,
                                   normalize_frequency: bool = False,
+                                  custom_labels: bool = False,
                                   width: int = 1250,
                                   height: int = 900) -> go.Figure:
         """ Visualize topics per class
@@ -1044,6 +1657,8 @@ def visualize_topics_per_class(self,
             top_n_topics: To visualize the most frequent topics instead of all
             topics: Select which topics you would like to be visualized
             normalize_frequency: Whether to normalize each topic's frequency individually
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
             width: The width of the figure.
             height: The height of the figure.
 
@@ -1072,12 +1687,14 @@ def visualize_topics_per_class(self,
                                                    top_n_topics=top_n_topics,
                                                    topics=topics,
                                                    normalize_frequency=normalize_frequency,
+                                                   custom_labels=custom_labels,
                                                    width=width,
                                                    height=height)
 
     def visualize_distribution(self,
                                probabilities: np.ndarray,
                                min_probability: float = 0.015,
+                               custom_labels: bool = False,
                                width: int = 800,
                                height: int = 600) -> go.Figure:
         """ Visualize the distribution of topic probabilities
@@ -1086,6 +1703,8 @@ def visualize_distribution(self,
             probabilities: An array of probability scores
             min_probability: The minimum probability score to visualize.
                              All others are ignored.
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
             width: The width of the figure.
             height: The height of the figure.
 
@@ -1109,6 +1728,7 @@ def visualize_distribution(self,
         return plotting.visualize_distribution(self,
                                                probabilities=probabilities,
                                                min_probability=min_probability,
+                                               custom_labels=custom_labels,
                                                width=width,
                                                height=height)
 
@@ -1116,9 +1736,13 @@ def visualize_hierarchy(self,
                             orientation: str = "left",
                             topics: List[int] = None,
                             top_n_topics: int = None,
+                            custom_labels: bool = False,
                             width: int = 1000,
                             height: int = 600,
-                            optimal_ordering: bool = False) -> go.Figure:
+                            hierarchical_topics: pd.DataFrame = None,
+                            linkage_function: Callable[[csr_matrix], np.ndarray] = None,
+                            distance_function: Callable[[csr_matrix], csr_matrix] = None,
+                            color_threshold: int = 1) -> go.Figure:
         """ Visualize a hierarchical structure of the topics
 
         A ward linkage function is used to perform the
@@ -1126,17 +1750,33 @@ def visualize_hierarchy(self,
         matrix between topic embeddings.
 
         Arguments:
             orientation: The orientation of the figure.
-                         Either 'left' or 'bottom'
+                         Either 'left' or 'bottom'
             topics: A selection of topics to visualize
             top_n_topics: Only select the top n most frequent topics
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
+                           NOTE: Custom labels are only generated for the original
+                           un-merged topics.
             width: The width of the figure. Only works if orientation is set to 'left'
             height: The height of the figure. Only works if orientation is set to 'bottom'
-            optimal_ordering: If True, the linkage matrix will be reordered so that the distance
-                              between successive leaves is minimal. This results in a more intuitive
-                              tree structure when the data are visualized. defaults to False, because
-                              this algorithm can be slow, particularly on large datasets. See
-                              also the `linkage` function fun `scipy`.
+            hierarchical_topics: A dataframe that contains a hierarchy of topics
+                                 represented by their parents and their children.
+                                 NOTE: The hierarchical topic names are only visualized
+                                 if both `topics` and `top_n_topics` are not set.
+            linkage_function: The linkage function to use. Default is:
+                              `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                              NOTE: Make sure to use the same `linkage_function` as used
+                              in `topic_model.hierarchical_topics`.
+            distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                               `lambda x: 1 - cosine_similarity(x)`
+                               NOTE: Make sure to use the same `distance_function` as used
+                               in `topic_model.hierarchical_topics`.
+            color_threshold: Value at which the separation of clusters will be made, which
+                             will result in different colors for different clusters.
+                             A higher value typically leads to fewer colored clusters.
 
         Returns:
             fig: A plotly figure
@@ -1149,26 +1789,45 @@ def visualize_hierarchy(self,
         topic_model.visualize_hierarchy()
         ```
 
-        Or if you want to save the resulting figure:
+        If you also want to visualize the labels of the hierarchical topics,
+        run the following:
+
+        ```python
+        # Extract hierarchical topics and their representations
+        hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+
+        # Visualize these representations
+        topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+        ```
+
+        If you want to save the resulting figure:
 
         ```python
         fig = topic_model.visualize_hierarchy()
         fig.write_html("path/to/file.html")
         ```
+
         """
         check_is_fitted(self)
         return plotting.visualize_hierarchy(self,
                                             orientation=orientation,
                                             topics=topics,
                                             top_n_topics=top_n_topics,
+                                            custom_labels=custom_labels,
                                             width=width,
                                             height=height,
-                                            optimal_ordering=optimal_ordering)
+                                            hierarchical_topics=hierarchical_topics,
+                                            linkage_function=linkage_function,
+                                            distance_function=distance_function,
+                                            color_threshold=color_threshold
+                                            )
 
     def visualize_heatmap(self,
                           topics: List[int] = None,
                           top_n_topics: int = None,
                           n_clusters: int = None,
+                          custom_labels: bool = False,
                           width: int = 800,
                           height: int = 800) -> go.Figure:
         """ Visualize a heatmap of the topic's similarity matrix
@@ -1181,6 +1840,8 @@ def visualize_heatmap(self,
             top_n_topics: Only select the top n most frequent topics.
             n_clusters: Create n clusters and order the similarity
                         matrix by those clusters.
+            custom_labels: Whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
             width: The width of the figure.
             height: The height of the figure.
@@ -1208,6 +1869,7 @@ def visualize_heatmap(self,
                                              topics=topics,
                                              top_n_topics=top_n_topics,
                                              n_clusters=n_clusters,
+                                             custom_labels=custom_labels,
                                              width=width,
                                              height=height)
 
@@ -1633,8 +2295,6 @@ def _c_tf_idf(self, documents_per_topic: pd.DataFrame, fit: bool = True) -> Tupl
 
         c_tf_idf = self.transformer.transform(X)
 
-        self.topic_sim_matrix = cosine_similarity(c_tf_idf)
-
         return c_tf_idf, words
 
     def _update_topic_size(self, documents: pd.DataFrame):
diff --git a/bertopic/backend/_hftransformers.py b/bertopic/backend/_hftransformers.py
new file mode 100644
index 00000000..b44ba84b
--- /dev/null
+++ b/bertopic/backend/_hftransformers.py
@@ -0,0 +1,96 @@
+import numpy as np
+
+from tqdm import tqdm
+from typing import List
+from torch.utils.data import Dataset
+from sklearn.preprocessing import normalize
+from transformers.pipelines import Pipeline
+
+from bertopic.backend import BaseEmbedder
+
+
+class HFTransformerBackend(BaseEmbedder):
+    """ Hugging Face transformers model
+
+    This uses the `transformers.pipelines.pipeline` to define and create
+    a feature generation pipeline from which embeddings can be extracted.
+
+    Arguments:
+        embedding_model: A Hugging Face feature extraction pipeline
+
+    Usage:
+
+    To use a Hugging Face transformers model, load in a pipeline and point
+    to any model found on their model hub (https://huggingface.co/models):
+
+    ```python
+    from bertopic.backend import HFTransformerBackend
+    from transformers.pipelines import pipeline
+
+    hf_model = pipeline("feature-extraction", model="distilbert-base-cased")
+    embedding_model = HFTransformerBackend(hf_model)
+    ```
+    """
+    def __init__(self, embedding_model: Pipeline):
+        super().__init__()
+
+        if isinstance(embedding_model, Pipeline):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct transformers pipeline. For example: "
+                             "pipeline('feature-extraction', model='distilbert-base-cased', device=0)")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embeddings size of `m`
+        """
+        dataset = MyDataset(documents)
+
+        embeddings = []
+        for document, features in tqdm(zip(documents, self.embedding_model(dataset, truncation=True, padding=True)),
+                                       total=len(dataset), disable=not verbose):
+            embeddings.append(self._embed(document, features))
+
+        return np.array(embeddings)
+
+    def _embed(self,
+               document: str,
+               features: np.ndarray) -> np.ndarray:
+        """ Mean pooling
+
+        Arguments:
+            document: The document for which to extract the attention mask
+            features: The embeddings for each token
+
+        Adopted from:
+        https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers
+        """
+        token_embeddings = np.array(features)
+        attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")["attention_mask"]
+        input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=input_mask_expanded.sum(1).max())
+        embedding = normalize(sum_embeddings / sum_mask)[0]
+        return embedding
+
+
+class MyDataset(Dataset):
+    """ Dataset to pass to `transformers.pipelines.pipeline` """
+    def __init__(self, docs):
+        self.docs = docs
+
+    def __len__(self):
+        return len(self.docs)
+
+    def __getitem__(self, idx):
+        return self.docs[idx]
diff --git a/bertopic/backend/_utils.py b/bertopic/backend/_utils.py
index a2b2eef9..a4db7860 100644
--- a/bertopic/backend/_utils.py
+++ b/bertopic/backend/_utils.py
@@ -1,5 +1,7 @@
 from ._base import BaseEmbedder
 from ._sentencetransformers import SentenceTransformerBackend
+from ._hftransformers import HFTransformerBackend
+from transformers.pipelines import Pipeline
 
 languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese',
              'azerbaijani', 'basque', 'belarusian', 'bengali', 'bengali romanize',
@@ -59,6 +61,10 @@ def select_backend(embedding_model,
     if isinstance(embedding_model, str):
         return SentenceTransformerBackend(embedding_model)
 
+    # Hugging Face embeddings
+    if isinstance(embedding_model, Pipeline):
+        return HFTransformerBackend(embedding_model)
+
     # Select embedding model based on language
     if language:
         if language.lower() in ["English", "english", "en"]:
diff --git a/bertopic/plotting/__init__.py b/bertopic/plotting/__init__.py
index 8edccb73..3cc61cc5 100644
--- a/bertopic/plotting/__init__.py
+++ b/bertopic/plotting/__init__.py
@@ -1,20 +1,24 @@
 from ._topics import visualize_topics
 from ._heatmap import visualize_heatmap
 from ._barchart import visualize_barchart
+from ._documents import visualize_documents
 from ._term_rank import visualize_term_rank
 from ._hierarchy import visualize_hierarchy
 from ._distribution import visualize_distribution
 from ._topics_over_time import visualize_topics_over_time
 from ._topics_per_class import visualize_topics_per_class
+from ._hierarchical_documents import visualize_hierarchical_documents
 
 
 __all__ = [
     "visualize_topics",
     "visualize_heatmap",
     "visualize_barchart",
+    "visualize_documents",
     "visualize_term_rank",
     "visualize_hierarchy",
     "visualize_distribution",
     "visualize_topics_over_time",
-    "visualize_topics_per_class"
+    "visualize_topics_per_class",
+    "visualize_hierarchical_documents"
 ]
diff --git a/bertopic/plotting/_distribution.py b/bertopic/plotting/_distribution.py
index ea2ea0a3..8fe4b67f 100644
--- a/bertopic/plotting/_distribution.py
+++ b/bertopic/plotting/_distribution.py
@@ -5,6 +5,7 @@ def visualize_distribution(topic_model,
                            probabilities: np.ndarray,
                            min_probability: float = 0.015,
+                           custom_labels: bool = False,
                            width: int = 800,
                            height: int = 600) -> go.Figure:
     """ Visualize the distribution of topic probabilities
@@ -14,6 +15,8 @@ def visualize_distribution(topic_model,
         probabilities: An array of probability scores
         min_probability: The minimum probability score to visualize.
                          All others are ignored.
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
         width: The width of the figure.
         height: The height of the figure.
 
@@ -50,16 +53,19 @@ def visualize_distribution(topic_model,
     vals = probabilities[labels_idx].tolist()
 
     # Create labels
-    labels = []
-    for idx in labels_idx:
-        words = topic_model.get_topic(idx)
-        if words:
-            label = [word[0] for word in words[:5]]
-            label = f"Topic {idx}: {'_'.join(label)}"
-            label = label[:40] + "..." if len(label) > 40 else label
-            labels.append(label)
-        else:
-            vals.remove(probabilities[idx])
+    if topic_model.custom_labels is not None and custom_labels:
+        labels = [topic_model.custom_labels[idx + topic_model._outliers] for idx in labels_idx]
+    else:
+        labels = []
+        for idx in labels_idx:
+            words = topic_model.get_topic(idx)
+            if words:
+                label = [word[0] for word in words[:5]]
+                label = f"Topic {idx}: {'_'.join(label)}"
+                label = label[:40] + "..." if len(label) > 40 else label
+                labels.append(label)
+            else:
+                vals.remove(probabilities[idx])
 
     # Create Figure
     fig = go.Figure(go.Bar(
diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py
new file mode 100644
index 00000000..3d937ff1
--- /dev/null
+++ b/bertopic/plotting/_documents.py
@@ -0,0 +1,220 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+
+from umap import UMAP
+from typing import List
+
+
+def visualize_documents(topic_model,
+                        docs: List[str],
+                        topics: List[int] = None,
+                        embeddings: np.ndarray = None,
+                        reduced_embeddings: np.ndarray = None,
+                        sample: float = None,
+                        hide_annotations: bool = False,
+                        hide_document_hover: bool = False,
+                        custom_labels: bool = False,
+                        width: int = 1200,
+                        height: int = 750):
+    """ Visualize documents and their topics in 2D
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of the visualization.
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Usage:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_documents(docs)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic
+    topic_model = BERTopic().fit(docs, embeddings)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_documents(docs, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+
+    """
+    topic_per_doc = topic_model._map_predictions(topic_model.hdbscan_model.labels_)
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s) * sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    unique_topics = set(topic_per_doc)
+    if topics is None:
+        topics = unique_topics
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Prepare text and names
+    if topic_model.custom_labels is not None and custom_labels:
+        names = [topic_model.custom_labels[topic + topic_model._outliers] for topic in unique_topics]
+    else:
+        names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
+
+    # Visualize
+    fig = go.Figure()
+
+    # Outliers and non-selected topics
+    non_selected_topics = set(unique_topics).difference(topics)
+    if len(non_selected_topics) == 0:
+        non_selected_topics = [-1]
+
+    selection = df.loc[df.topic.isin(non_selected_topics), :]
+    selection["text"] = ""
+    selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), "Other documents"]
+
+    fig.add_trace(
+        go.Scattergl(
+            x=selection.x,
+            y=selection.y,
+            hovertext=selection.doc if not hide_document_hover else None,
+            hoverinfo="text",
+            mode='markers+text',
+            name="other",
+            showlegend=False,
+            marker=dict(color='#CFD8DC', size=5, opacity=0.5)
+        )
+    )
+
+    # Selected topics
+    for name, topic in zip(names, unique_topics):
+        if topic in topics and topic != -1:
+            selection = df.loc[df.topic == topic, :]
+            selection["text"] = ""
+
+            if not hide_annotations:
+                selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name]
+
+            fig.add_trace(
+                go.Scattergl(
+                    x=selection.x,
+                    y=selection.y,
+                    hovertext=selection.doc if not hide_document_hover else None,
+                    hoverinfo="text",
+                    text=selection.text,
+                    mode='markers+text',
+                    name=name,
+                    textfont=dict(
+                        size=12,
+                    ),
+                    marker=dict(size=5, opacity=0.5)
+                )
+            )
+
+    # Add grid in a 'plus' shape
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+
+    # Stylize layout
+    fig.update_layout(
+        template="simple_white",
+        title={
+            'text': "Documents and Topics",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height
+    )
+
+    fig.update_xaxes(visible=False)
+    fig.update_yaxes(visible=False)
+    return fig
diff --git a/bertopic/plotting/_heatmap.py b/bertopic/plotting/_heatmap.py
index 7f51c1e0..fe845550 100644
--- a/bertopic/plotting/_heatmap.py
+++ b/bertopic/plotting/_heatmap.py
@@ -11,6 +11,7 @@ def visualize_heatmap(topic_model,
                       topics: List[int] = None,
                       top_n_topics: int = None,
                       n_clusters: int = None,
+                      custom_labels: bool = False,
                       width: int = 800,
                       height: int = 800) -> go.Figure:
     """ Visualize a heatmap of the topic's similarity matrix
@@ -24,6 +25,8 @@ def visualize_heatmap(topic_model,
         top_n_topics: Only select the top n most frequent topics.
         n_clusters: Create n clusters and order the similarity
                     matrix by those clusters.
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
         width: The width of the figure.
         height: The height of the figure.
@@ -90,10 +93,13 @@ def visualize_heatmap(topic_model,
         embeddings = embeddings[indices]
     distance_matrix = cosine_similarity(embeddings)
 
-    # Create nicer labels
-    new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
-    new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
-    new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+    # Create labels
+    if topic_model.custom_labels is not None and custom_labels:
+        new_labels = [topic_model.custom_labels[topic + topic_model._outliers] for topic in sorted_topics]
+    else:
+        new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
 
     fig = px.imshow(distance_matrix,
                     labels=dict(color="Similarity Score"),
diff --git a/bertopic/plotting/_hierarchical_documents.py b/bertopic/plotting/_hierarchical_documents.py
new file mode 100644
index 00000000..3fc55eb8
--- /dev/null
+++ b/bertopic/plotting/_hierarchical_documents.py
@@ -0,0 +1,317 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+
+from umap import UMAP
+from typing import List, Union
+
+
+def visualize_hierarchical_documents(topic_model,
+                                     docs: List[str],
+                                     hierarchical_topics: pd.DataFrame,
+                                     topics: List[int] = None,
+                                     embeddings: np.ndarray = None,
+                                     reduced_embeddings: np.ndarray = None,
+                                     sample: Union[float, int] = None,
+                                     hide_annotations: bool = False,
+                                     hide_document_hover: bool = True,
+                                     nr_levels: int = 10,
+                                     custom_labels: bool = False,
+                                     width: int = 1200,
+                                     height: int = 750) -> go.Figure:
+    """ Visualize documents and their topics in 2D at different levels of hierarchy
+
+    Arguments:
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of visualizations.
+        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
+                   in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances with
+                   equal length. Then, for each list of distances, the merged topics are selected that
+                   have a distance less than or equal to the maximum distance of the selected list of distances.
+                   NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
+                         the length of `hierarchical_topics`.
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Usage:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic and extract hierarchical topics
+    topic_model = BERTopic()
+    topics, probs = topic_model.fit_transform(docs, embeddings)
+    hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+    NOTE:
+    This visualization was inspired by the scatter plot representation of Doc2Map:
+    https://github.com/louisgeisler/Doc2Map
+
+
+    """
+    topic_per_doc = topic_model._map_predictions(topic_model.hdbscan_model.labels_)
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s)*sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Create topic list for each level, levels are created by calculating the distance
+    distances = hierarchical_topics.Distance.to_list()
+    max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
+
+    for index, max_distance in enumerate(max_distances):
+
+        # Get topics below `max_distance`
+        mapping = {topic: topic for topic in df.topic.unique()}
+        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
+        selection.Parent_ID = selection.Parent_ID.astype(int)
+        selection = selection.sort_values("Parent_ID")
+
+        for row in selection.iterrows():
+            for topic in row[1].Topics:
+                mapping[topic] = row[1].Parent_ID
+
+        # Make sure the mappings are mapped 1:1
+        mappings = [True for _ in mapping]
+        while any(mappings):
+            for i, (key, value) in enumerate(mapping.items()):
+                if value in mapping.keys() and key != value:
+                    mapping[key] = mapping[value]
+                else:
+                    mappings[i] = False
+
+        # Create new column
+        df[f"level_{index+1}"] = df.topic.map(mapping)
+        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
+
+    # Prepare topic names of original and merged topics
+    trace_names = []
+    topic_names = {}
+    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
+        if topic < hierarchical_topics.Parent_ID.astype(int).min():
+            if topic_model.get_topic(topic):
+                if topic_model.custom_labels is not None and custom_labels:
+                    trace_name = topic_model.custom_labels[topic + topic_model._outliers]
+                else:
+                    trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
+                topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
+                trace_names.append(trace_name)
+        else:
+            trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
+            plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
+            topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
+            trace_names.append(trace_name)
+
+    # Prepare traces
+    all_traces = []
+    for level in range(len(max_distances)):
+        traces = []
+
+        # Outliers
+        if topic_model._outliers:
+            traces.append(
+                go.Scattergl(
+                    x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
+                    y=df.loc[df[f"level_{level+1}"] == -1, "y"],
+                    mode='markers+text',
+                    name="other",
+                    hoverinfo="text",
+                    hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
+                    showlegend=False,
+                    marker=dict(color='#CFD8DC', size=5, opacity=0.5)
+                )
+            )
+
+        # Selected topics
+        if topics:
+            selection = df.loc[(df.topic.isin(topics)), :]
+            unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
+        else:
+            unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
+
+        for topic in unique_topics:
+            if topic != -1:
+                if topics:
+                    selection = df.loc[(df[f"level_{level+1}"] == topic) &
+                                       (df.topic.isin(topics)), :]
+                else:
+                    selection = df.loc[df[f"level_{level+1}"] == topic, :]
+
+                if not hide_annotations:
+                    selection.loc[len(selection), :] = None
+                    selection["text"] = ""
+                    selection.loc[len(selection) - 1, "x"] = selection.x.mean()
+                    selection.loc[len(selection) - 1, "y"] = selection.y.mean()
+                    selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
+
+                traces.append(
+                    go.Scattergl(
+                        x=selection.x,
+                        y=selection.y,
+                        text=selection.text if not hide_annotations else None,
+                        hovertext=selection.doc if not hide_document_hover else None,
+                        hoverinfo="text",
+                        name=topic_names[int(topic)]["trace_name"],
+                        mode='markers+text',
+                        marker=dict(size=5, opacity=0.5)
+                    )
+                )
+
+        all_traces.append(traces)
+
+    # Track and count traces
+    nr_traces_per_set = [len(traces) for traces in all_traces]
+    trace_indices = [(0, nr_traces_per_set[0])]
+    for index, nr_traces in enumerate(nr_traces_per_set[1:]):
+        start = trace_indices[index][1]
+        end = nr_traces + start
+        trace_indices.append((start, end))
+
+    # Visualization
+    fig = go.Figure()
+    for traces in all_traces:
+        for trace in traces:
+            fig.add_trace(trace)
+
+    for index in range(len(fig.data)):
+        if index >= nr_traces_per_set[0]:
+            fig.data[index].visible = False
+
+    # Create and add slider
+    steps = []
+    for index, indices in enumerate(trace_indices):
+        step = dict(
+            method="update",
+            label=str(index),
+            args=[{"visible": [False] * len(fig.data)}]
+        )
+        for index in range(indices[1]-indices[0]):
+            step["args"][0]["visible"][index+indices[0]] = True
+        steps.append(step)
+
+    sliders = [dict(
+        currentvalue={"prefix": "Level: "},
+        pad={"t": 20},
+        steps=steps
+    )]
+
+    # Add grid in a 'plus' shape
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+
+    # Stylize layout
+    fig.update_layout(
+        sliders=sliders,
+        template="simple_white",
+        title={
+            'text': "Hierarchical Documents and Topics",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+    )
+
+    fig.update_xaxes(visible=False)
+    fig.update_yaxes(visible=False)
+    return fig
diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py
index 0d28b150..b1fef674 100644
--- a/bertopic/plotting/_hierarchy.py
+++ b/bertopic/plotting/_hierarchy.py
@@ -1,6 +1,8 @@
 import numpy as np
-from scipy.cluster.hierarchy import linkage
-from typing import List
+import pandas as pd
+from typing import Callable, List
+from scipy.sparse import csr_matrix
+from scipy.cluster import hierarchy as sch
 from sklearn.metrics.pairwise import cosine_similarity
 
 import plotly.graph_objects as go
@@ -11,9 +13,13 @@ def visualize_hierarchy(topic_model,
                         orientation: str = "left",
                         topics: List[int] = None,
                         top_n_topics: int = None,
+                        custom_labels: bool = False,
                         width: int = 1000,
                         height: int = 600,
-                        optimal_ordering: bool = False) -> go.Figure:
+                        hierarchical_topics: pd.DataFrame = None,
+                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
+                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
+                        color_threshold: int = 1) -> go.Figure:
     """ Visualize a hierarchical structure of the topics
 
     A ward linkage function is used to perform the
@@ -26,13 +32,28 @@ def visualize_hierarchy(topic_model,
                      Either 'left' or 'bottom'
         topics: A selection of topics to visualize
         top_n_topics: Only select the top n most frequent topics
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
        width: The width of the figure.
Only works if orientation is set to 'left'
         height: The height of the figure. Only works if orientation is set to 'bottom'
-        optimal_ordering: If True, the linkage matrix will be reordered so that the distance
-                          between successive leaves is minimal. This results in a more intuitive
-                          tree structure when the data are visualized. defaults to False, because
-                          this algorithm can be slow, particularly on large datasets. See
-                          also the `linkage` function fun `scipy`.
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children.
+                             NOTE: The hierarchical topic names are only visualized
+                             if both `topics` and `top_n_topics` are not set.
+        linkage_function: The linkage function to use. Default is:
+                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                          NOTE: Make sure to use the same `linkage_function` as used
+                          in `topic_model.hierarchical_topics`.
+        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                           `lambda x: 1 - cosine_similarity(x)`
+                           NOTE: Make sure to use the same `distance_function` as used
+                           in `topic_model.hierarchical_topics`.
+        color_threshold: Value at which the separation of clusters will be made which
+                         will result in different colors for different clusters.
+                         A higher value will typically lead to fewer colored clusters.
+
     Returns:
         fig: A plotly figure

@@ -45,7 +66,18 @@ def visualize_hierarchy(topic_model,
     topic_model.visualize_hierarchy()
     ```

-    Or if you want to save the resulting figure:
+    If you also want to visualize the labels of the hierarchical topics,
+    run the following:
+
+    ```python
+    # Extract hierarchical topics and their representations
+    hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+
+    # Visualize these representations
+    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+    ```
+
+    If you want to save the resulting figure:

     ```python
     fig = topic_model.visualize_hierarchy()
@@ -54,12 +86,11 @@ def visualize_hierarchy(topic_model,
     """
+    if distance_function is None:
+        distance_function = lambda x: 1 - cosine_similarity(x)

-    # Select topic embeddings
-    if topic_model.topic_embeddings is not None:
-        embeddings = np.array(topic_model.topic_embeddings)
-    else:
-        embeddings = topic_model.c_tf_idf
+    if linkage_function is None:
+        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)

     # Select topics based on top_n and topics args
     freq_df = topic_model.get_topic_freq()
@@ -74,22 +105,37 @@ def visualize_hierarchy(topic_model,
     # Select embeddings
     all_topics = sorted(list(topic_model.get_topics().keys()))
     indices = np.array([all_topics.index(topic) for topic in topics])
-    embeddings = embeddings[indices]
+    embeddings = topic_model.c_tf_idf[indices]
+
+    # Annotations
+    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
+        annotations = _get_annotations(topic_model=topic_model,
+                                       hierarchical_topics=hierarchical_topics,
+                                       embeddings=embeddings,
+                                       distance_function=distance_function,
+                                       linkage_function=linkage_function,
+                                       orientation=orientation,
+                                       custom_labels=custom_labels)
+    else:
+        annotations = None

     # Create dendogram
-    distance_matrix = 1 - cosine_similarity(embeddings)
-    fig = ff.create_dendrogram(distance_matrix,
+    fig = ff.create_dendrogram(embeddings,
                                orientation=orientation,
-                               linkagefun=lambda x: linkage(x, "ward",
-                                                            optimal_ordering=optimal_ordering),
-                               color_threshold=1)
+                               distfun=distance_function,
+                               linkagefun=linkage_function,
+                               hovertext=annotations,
+
color_threshold=color_threshold) # Create nicer labels axis = "yaxis" if orientation == "left" else "xaxis" - new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) - for x in fig.layout[axis]["ticktext"]] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] + if topic_model.custom_labels is not None and custom_labels: + new_labels = [topic_model.custom_labels[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]] + else: + new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) + for x in fig.layout[axis]["ticktext"]] + new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] + new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] # Stylize layout fig.update_layout( @@ -128,4 +174,101 @@ def visualize_hierarchy(topic_model, height=height, xaxis=dict(tickmode="array", ticktext=new_labels)) + + if hierarchical_topics is not None: + for index in [0, 3]: + axis = "x" if orientation == "left" else "y" + xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + + fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black', + hovertext=hovertext, hoverinfo="text", + mode='markers', showlegend=False)) return fig + + +def _get_annotations(topic_model, + hierarchical_topics: pd.DataFrame, + embeddings: csr_matrix, + linkage_function: Callable[[csr_matrix], np.ndarray], + distance_function: Callable[[csr_matrix], csr_matrix], + orientation: str, + custom_labels: bool = False) -> List[List[str]]: + + """ Get annotations by replicating linkage function calculation in scipy + + Arguments + topic_model: A fitted BERTopic instance. + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children. + NOTE: The hierarchical topic names are only visualized + if both `topics` and `top_n_topics` are not set. + embeddings: The c-TF-IDF matrix on which to model the hierarchy + linkage_function: The linkage function to use. Default is: + `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` + NOTE: Make sure to use the same `linkage_function` as used + in `topic_model.hierarchical_topics`. + distance_function: The distance function to use on the c-TF-IDF matrix. Default is: + `lambda x: 1 - cosine_similarity(x)` + NOTE: Make sure to use the same `distance_function` as used + in `topic_model.hierarchical_topics`. + orientation: The orientation of the figure. + Either 'left' or 'bottom' + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + NOTE: Custom labels are only generated for the original + un-merged topics. 
+
+    Returns:
+        text_annotations: Annotations to be used within Plotly's `ff.create_dendrogram`
+    """
+    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]
+
+    # Calculate linkage
+    X = distance_function(embeddings)
+    Z = linkage_function(X)
+    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)
+
+    # Store the topic numbers (leaves) corresponding to the x-ticks in the dendrogram
+    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
+    x_topic = dict(zip(P['leaves'], x_ticks))
+
+    topic_vals = dict()
+    for key, val in x_topic.items():
+        topic_vals[val] = [key]
+
+    parent_topic = dict(zip(df.Parent_ID, df.Topics))
+
+    # Loop through every trace (scatter plot) in the dendrogram
+    text_annotations = []
+    for index, trace in enumerate(P['icoord']):
+        fst_topic = topic_vals[trace[0]]
+        scnd_topic = topic_vals[trace[2]]
+
+        if len(fst_topic) == 1:
+            if topic_model.custom_labels is not None and custom_labels:
+                fst_name = topic_model.custom_labels[fst_topic[0] + topic_model._outliers]
+            else:
+                fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(fst_topic):
+                    fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        if len(scnd_topic) == 1:
+            if topic_model.custom_labels is not None and custom_labels:
+                scnd_name = topic_model.custom_labels[scnd_topic[0] + topic_model._outliers]
+            else:
+                scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(scnd_topic):
+                    scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        text_annotations.append([fst_name, "", "", scnd_name])
+
+        center = (trace[0] + trace[2]) / 2
+        topic_vals[center] = fst_topic + scnd_topic
+
+    return text_annotations
diff --git a/bertopic/plotting/_term_rank.py b/bertopic/plotting/_term_rank.py
index ced45ceb..a66aea39 100644
--- a/bertopic/plotting/_term_rank.py
+++ b/bertopic/plotting/_term_rank.py
@@ -6,6 +6,7 @@
 def visualize_term_rank(topic_model,
                         topics: List[int] = None,
                         log_scale: bool = False,
+                        custom_labels: bool = False,
                         width: int = 800,
                         height: int = 500) -> go.Figure:
     """ Visualize the ranks of all terms across all topics
@@ -20,6 +21,8 @@
         topics: A selection of topics to visualize.
                 These will be colored red where all others will be colored black.
         log_scale: Whether to represent the ranking on a log scale
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
         width: The width of the figure.
         height: The height of the figure.
@@ -69,9 +72,13 @@ def visualize_term_rank(topic_model, lines = [] for topic, x, y in zip(topic_ids, indices, values): if not any(y > 1.5): + # labels - label = f"Topic {topic}:" + "_".join([word[0] for word in topic_model.get_topic(topic)]) - label = label[:50] + if topic_model.custom_labels is not None and custom_labels: + label = topic_model.custom_labels[topic + topic_model._outliers] + else: + label = f"Topic {topic}:" + "_".join([word[0] for word in topic_model.get_topic(topic)]) + label = label[:50] # line parameters color = "red" if topic in topics else "black" diff --git a/bertopic/plotting/_topics_over_time.py b/bertopic/plotting/_topics_over_time.py index abaeeb84..a13282e4 100644 --- a/bertopic/plotting/_topics_over_time.py +++ b/bertopic/plotting/_topics_over_time.py @@ -9,6 +9,7 @@ def visualize_topics_over_time(topic_model, top_n_topics: int = None, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 450) -> go.Figure: """ Visualize topics over time @@ -20,6 +21,8 @@ def visualize_topics_over_time(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -55,8 +58,11 @@ def visualize_topics_over_time(topic_model, selected_topics = topic_model.get_topic_freq().Topic.values # Prepare data - topic_names = {key: value[:40] + "..." if len(value) > 40 else value - for key, value in topic_model.topic_names.items()} + if topic_model.custom_labels is not None and custom_labels: + topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()} + else: + topic_names = {key: value[:40] + "..." if len(value) > 40 else value + for key, value in topic_model.topic_names.items()} topics_over_time["Name"] = topics_over_time.Topic.map(topic_names) data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"]) diff --git a/bertopic/plotting/_topics_per_class.py b/bertopic/plotting/_topics_per_class.py index 3f064017..0894ea0c 100644 --- a/bertopic/plotting/_topics_per_class.py +++ b/bertopic/plotting/_topics_per_class.py @@ -9,6 +9,7 @@ def visualize_topics_per_class(topic_model, top_n_topics: int = 10, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 900) -> go.Figure: """ Visualize topics per class @@ -20,6 +21,8 @@ def visualize_topics_per_class(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -55,8 +58,11 @@ def visualize_topics_per_class(topic_model, selected_topics = topic_model.get_topic_freq().Topic.values # Prepare data - topic_names = {key: value[:40] + "..." 
if len(value) > 40 else value
-                   for key, value in topic_model.topic_names.items()}
+    if topic_model.custom_labels is not None and custom_labels:
+        topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()}
+    else:
+        topic_names = {key: value[:40] + "..." if len(value) > 40 else value
+                       for key, value in topic_model.topic_names.items()}
     topics_per_class["Name"] = topics_per_class.Topic.map(topic_names)
     data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :]
diff --git a/docs/api/plotting/documents.md b/docs/api/plotting/documents.md
new file mode 100644
index 00000000..5edc249d
--- /dev/null
+++ b/docs/api/plotting/documents.md
@@ -0,0 +1,3 @@
+# `Documents`
+
+::: bertopic.plotting._documents.visualize_documents
diff --git a/docs/api/plotting/hierarchical_documents.md b/docs/api/plotting/hierarchical_documents.md
new file mode 100644
index 00000000..49216b77
--- /dev/null
+++ b/docs/api/plotting/hierarchical_documents.md
@@ -0,0 +1,3 @@
+# `Hierarchical Documents`
+
+::: bertopic.plotting._hierarchical_documents.visualize_hierarchical_documents
diff --git a/docs/getting_started/embeddings/embeddings.md b/docs/getting_started/embeddings/embeddings.md
index dd487772..62c489e3 100644
--- a/docs/getting_started/embeddings/embeddings.md
+++ b/docs/getting_started/embeddings/embeddings.md
@@ -24,6 +24,23 @@ topic_model = BERTopic(embedding_model=sentence_model)
 !!! tip "Tip!"
     This embedding back-end was put here first for a reason, sentence-transformers works amazing out-of-the-box! Playing around with different models can give you great results. Also, make sure to frequently visit [this](https://www.sbert.net/docs/pretrained_models.html) page as new models are often released.
+### 🤗 Hugging Face Transformers
+To use a Hugging Face transformers model, load in a pipeline and point
+to any model found on their model hub (https://huggingface.co/models):
+
+```python
+from bertopic.backend import HFTransformerBackend
+from transformers.pipelines import pipeline
+
+hf_model = pipeline("feature-extraction", model="distilbert-base-cased")
+embedding_model = HFTransformerBackend(hf_model)
+
+topic_model = BERTopic(embedding_model=embedding_model)
+```
+
+!!! tip "Tip!"
+    These transformers also work quite well using `sentence-transformers` which has a number of
+    optimization tricks that make using it a bit faster.
 ### **Flair**
 [Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that
diff --git a/docs/getting_started/hierarchicaltopics/hierarchical_topics.html b/docs/getting_started/hierarchicaltopics/hierarchical_topics.html
new file mode 100644
index 00000000..51ba2364
--- /dev/null
+++ b/docs/getting_started/hierarchicaltopics/hierarchical_topics.html
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md
new file mode 100644
index 00000000..d9bca553
--- /dev/null
+++ b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md
@@ -0,0 +1,348 @@
+When tweaking your topic model, the number of topics that are generated has a large effect on the quality of the topic representations.
+Some topics could be merged, and understanding the effect of doing so will help you decide which topics should and which
+should not be merged.
+
+That is where hierarchical topic modeling comes in. It tries to model the possible hierarchical nature of the topics you have created
+in order to understand which topics are similar to each other. Moreover, you will have more insight into sub-topics that might
+exist in your data.
+
+In BERTopic, we can approximate this potential hierarchy by making use of our topic-term matrix (c-TF-IDF matrix). This matrix
+contains information about the importance of every word in every topic and makes for a nice numerical representation of our topics.
+The smaller the distance between two c-TF-IDF representations, the more similar we assume they are. In practice, this process of merging
+topics is done through the hierarchical clustering capabilities of `scipy` (see [here](https://docs.scipy.org/doc/scipy/reference/cluster.hierarchy.html)).
+It allows for several linkage methods through which we can approximate our topic hierarchy. By default, we use the [ward](https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ward.html#scipy.cluster.hierarchy.ward) linkage method, but many others are available.
+
+Whenever we merge two topics, we can calculate the c-TF-IDF representation of these two merged topics by summing their bag-of-words representations.
+We assume that two sets of topics are merged and that all others are kept the same, regardless of their location in the hierarchy. This helps
+us isolate the potential effect of merging sets of topics. As a result, we can see the topic representation at each level in the tree.
+
+## **Example**
+To demonstrate hierarchical topic modeling with BERTopic, we use the 20 Newsgroups dataset to see how the topics that we uncover are represented in the 20 categories of documents.
+
+First, we train a basic BERTopic model:
+
+```python
+from bertopic import BERTopic
+from sklearn.datasets import fetch_20newsgroups
+
+docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))["data"]
+topic_model = BERTopic(verbose=True)
+topics, probs = topic_model.fit_transform(docs)
+```
+
+Next, we can use our fitted BERTopic model to extract possible hierarchies from our c-TF-IDF matrix:
+
+```python
+hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+```
+
+The resulting `hierarchical_topics` is a dataframe in which merged topics are described. For example, if you were to
+merge two topics, what would the topic representation of the new topic be?
+
+### **Visualizations**
+To visualize these results, we can start by running a familiar function, namely `topic_model.visualize_hierarchy`:
+
+```python
+topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+```
+
+
+If you **hover** over the black circles, you will see the topic representation at that level of the hierarchy. These representations
+help you understand the effect of merging certain topics together.
Some might be logical to merge whilst others might not. Moreover,
+we can now see which sub-topics can be found within certain larger themes.
+
+Although this gives a nice overview of the potential hierarchy, hovering over all black circles can be tiresome. Instead, we can
+use `topic_model.get_topic_tree` to create a text-based representation of this hierarchy. Although the general structure is more difficult
+to view, we can more easily see which topics could be logically merged:
+
+```python
+>>> tree = topic_model.get_topic_tree(hierarchical_topics)
+>>> print(tree)
+.
+└─atheists_atheism_god_moral_atheist
+     ├─atheists_atheism_god_atheist_argument
+     │    ├─■──atheists_atheism_god_atheist_argument ── Topic: 21
+     │    └─■──br_god_exist_genetic_existence ── Topic: 124
+     └─■──moral_morality_objective_immoral_morals ── Topic: 29
+```
+
+ Click here to view the full tree. + + ```bash + . + ├─people_armenian_said_god_armenians + │ ├─god_jesus_jehovah_lord_christ + │ │ ├─god_jesus_jehovah_lord_christ + │ │ │ ├─jehovah_lord_mormon_mcconkie_god + │ │ │ │ ├─■──ra_satan_thou_god_lucifer ── Topic: 94 + │ │ │ │ └─■──jehovah_lord_mormon_mcconkie_unto ── Topic: 78 + │ │ │ └─jesus_mary_god_hell_sin + │ │ │ ├─jesus_hell_god_eternal_heaven + │ │ │ │ ├─hell_jesus_eternal_god_heaven + │ │ │ │ │ ├─■──jesus_tomb_disciples_resurrection_john ── Topic: 69 + │ │ │ │ │ └─■──hell_eternal_god_jesus_heaven ── Topic: 53 + │ │ │ │ └─■──aaron_baptism_sin_law_god ── Topic: 89 + │ │ │ └─■──mary_sin_maria_priest_conception ── Topic: 56 + │ │ └─■──marriage_married_marry_ceremony_marriages ── Topic: 110 + │ └─people_armenian_armenians_said_mr + │ ├─people_armenian_armenians_said_israel + │ │ ├─god_homosexual_homosexuality_atheists_sex + │ │ │ ├─homosexual_homosexuality_sex_gay_homosexuals + │ │ │ │ ├─■──kinsey_sex_gay_men_sexual ── Topic: 44 + │ │ │ │ └─homosexuality_homosexual_sin_homosexuals_gay + │ │ │ │ ├─■──gay_homosexual_homosexuals_sexual_cramer ── Topic: 50 + │ │ │ │ └─■──homosexuality_homosexual_sin_paul_sex ── Topic: 27 + │ │ │ └─god_atheists_atheism_moral_atheist + │ │ │ ├─islam_quran_judas_islamic_book + │ │ │ │ ├─■──jim_context_challenges_articles_quote ── Topic: 36 + │ │ │ │ └─islam_quran_judas_islamic_book + │ │ │ │ ├─■──islam_quran_islamic_rushdie_muslims ── Topic: 31 + │ │ │ │ └─■──judas_scripture_bible_books_greek ── Topic: 33 + │ │ │ └─atheists_atheism_god_moral_atheist + │ │ │ ├─atheists_atheism_god_atheist_argument + │ │ │ │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ │ │ │ └─■──br_god_exist_genetic_existence ── Topic: 124 + │ │ │ └─■──moral_morality_objective_immoral_morals ── Topic: 29 + │ │ └─armenian_armenians_people_israel_said + │ │ ├─armenian_armenians_israel_people_jews + │ │ │ ├─tax_rights_government_income_taxes + │ │ │ │ ├─■──rights_right_slavery_slaves_residence ── Topic: 106 + │ │ │ │ └─tax_government_taxes_income_libertarians + │ │ │ │ ├─■──government_libertarians_libertarian_regulation_party ── Topic: 58 + │ │ │ │ └─■──tax_taxes_income_billion_deficit ── Topic: 41 + │ │ │ └─armenian_armenians_israel_people_jews + │ │ │ ├─gun_guns_militia_firearms_amendment + │ │ │ │ ├─■──blacks_penalty_death_cruel_punishment ── Topic: 55 + │ │ │ │ └─■──gun_guns_militia_firearms_amendment ── Topic: 7 + │ │ │ └─armenian_armenians_israel_jews_turkish + │ │ │ ├─■──israel_israeli_jews_arab_jewish ── Topic: 4 + │ │ │ └─■──armenian_armenians_turkish_armenia_azerbaijan ── Topic: 15 + │ │ └─stephanopoulos_president_mr_myers_ms + │ │ ├─■──serbs_muslims_stephanopoulos_mr_bosnia ── Topic: 35 + │ │ └─■──myers_stephanopoulos_president_ms_mr ── Topic: 87 + │ └─batf_fbi_koresh_compound_gas + │ ├─■──reno_workers_janet_clinton_waco ── Topic: 77 + │ └─batf_fbi_koresh_gas_compound + │ ├─batf_koresh_fbi_warrant_compound + │ │ ├─■──batf_warrant_raid_compound_fbi ── Topic: 42 + │ │ └─■──koresh_batf_fbi_children_compound ── Topic: 61 + │ └─■──fbi_gas_tear_bds_building ── Topic: 23 + └─use_like_just_dont_new + ├─game_team_year_games_like + │ ├─game_team_games_25_year + │ │ ├─game_team_games_25_season + │ │ │ ├─window_printer_use_problem_mhz + │ │ │ │ ├─mhz_wire_simms_wiring_battery + │ │ │ │ │ ├─simms_mhz_battery_cpu_heat + │ │ │ │ │ │ ├─simms_pds_simm_vram_lc + │ │ │ │ │ │ │ ├─■──pds_nubus_lc_slot_card ── Topic: 119 + │ │ │ │ │ │ │ └─■──simms_simm_vram_meg_dram ── Topic: 32 + │ │ │ │ │ │ └─mhz_battery_cpu_heat_speed + │ │ │ │ │ │ 
├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ │ ├─■──fan_cpu_heat_sink_fans ── Topic: 92 + │ │ │ │ │ │ │ │ └─■──mhz_speed_cpu_fpu_clock ── Topic: 22 + │ │ │ │ │ │ │ └─■──monitor_turn_power_computer_electricity ── Topic: 91 + │ │ │ │ │ │ └─battery_batteries_concrete_duo_discharge + │ │ │ │ │ │ ├─■──duo_battery_apple_230_problem ── Topic: 121 + │ │ │ │ │ │ └─■──battery_batteries_concrete_discharge_temperature ── Topic: 75 + │ │ │ │ │ └─wire_wiring_ground_neutral_outlets + │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ │ ├─■──leds_uv_blue_light_boards ── Topic: 66 + │ │ │ │ │ │ │ └─■──wire_wiring_ground_neutral_outlets ── Topic: 120 + │ │ │ │ │ │ └─scope_scopes_phone_dial_number + │ │ │ │ │ │ ├─■──dial_number_phone_line_output ── Topic: 93 + │ │ │ │ │ │ └─■──scope_scopes_motorola_generator_oscilloscope ── Topic: 113 + │ │ │ │ │ └─celp_dsp_sampling_antenna_digital + │ │ │ │ │ ├─■──antenna_antennas_receiver_cable_transmitter ── Topic: 70 + │ │ │ │ │ └─■──celp_dsp_sampling_speech_voice ── Topic: 52 + │ │ │ │ └─window_printer_xv_mouse_windows + │ │ │ │ ├─window_xv_error_widget_problem + │ │ │ │ │ ├─error_symbol_undefined_xterm_rx + │ │ │ │ │ │ ├─■──symbol_error_undefined_doug_parse ── Topic: 63 + │ │ │ │ │ │ └─■──rx_remote_server_xdm_xterm ── Topic: 45 + │ │ │ │ │ └─window_xv_widget_application_expose + │ │ │ │ │ ├─window_widget_expose_application_event + │ │ │ │ │ │ ├─■──gc_mydisplay_draw_gxxor_drawing ── Topic: 103 + │ │ │ │ │ │ └─■──window_widget_application_expose_event ── Topic: 25 + │ │ │ │ │ └─xv_den_polygon_points_algorithm + │ │ │ │ │ ├─■──den_polygon_points_algorithm_polygons ── Topic: 28 + │ │ │ │ │ └─■──xv_24bit_image_bit_images ── Topic: 57 + │ │ │ │ └─printer_fonts_print_mouse_postscript + │ │ │ │ ├─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──scanner_logitech_grayscale_ocr_scanman ── Topic: 108 + │ │ │ │ │ └─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──printer_print_deskjet_hp_ink ── Topic: 18 + │ │ │ │ │ └─■──fonts_font_truetype_tt_atm ── Topic: 49 + │ │ │ │ └─mouse_ghostscript_midi_driver_postscript + │ │ │ │ ├─ghostscript_midi_postscript_files_file + │ │ │ │ │ ├─■──ghostscript_postscript_pageview_ghostview_dsc ── Topic: 104 + │ │ │ │ │ └─midi_sound_file_windows_driver + │ │ │ │ │ ├─■──location_mar_file_host_rwrr ── Topic: 83 + │ │ │ │ │ └─■──midi_sound_driver_blaster_soundblaster ── Topic: 98 + │ │ │ │ └─■──mouse_driver_mice_ball_problem ── Topic: 68 + │ │ │ └─game_team_games_25_season + │ │ │ ├─1st_sale_condition_comics_hulk + │ │ │ │ ├─sale_condition_offer_asking_cd + │ │ │ │ │ ├─condition_stereo_amp_speakers_asking + │ │ │ │ │ │ ├─■──miles_car_amfm_toyota_cassette ── Topic: 62 + │ │ │ │ │ │ └─■──amp_speakers_condition_stereo_audio ── Topic: 24 + │ │ │ │ │ └─games_sale_pom_cds_shipping + │ │ │ │ │ ├─pom_cds_sale_shipping_cd + │ │ │ │ │ │ ├─■──size_shipping_sale_condition_mattress ── Topic: 100 + │ │ │ │ │ │ └─■──pom_cds_cd_sale_picture ── Topic: 37 + │ │ │ │ │ └─■──games_game_snes_sega_genesis ── Topic: 40 + │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ ├─1st_hulk_comics_art_appears + │ │ │ │ │ ├─lens_tape_camera_backup_lenses + │ │ │ │ │ │ ├─■──tape_backup_tapes_drive_4mm ── Topic: 107 + │ │ │ │ │ │ └─■──lens_camera_lenses_zoom_pouch ── Topic: 114 + │ │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ │ ├─■──1st_hulk_comics_art_appears ── Topic: 105 + │ │ │ │ │ └─■──books_book_cover_trek_chemistry ── Topic: 125 + │ │ │ │ └─tickets_hotel_ticket_voucher_package + │ │ │ 
│ ├─■──hotel_voucher_package_vacation_room ── Topic: 74 + │ │ │ │ └─■──tickets_ticket_june_airlines_july ── Topic: 84 + │ │ │ └─game_team_games_season_hockey + │ │ │ ├─game_hockey_team_25_550 + │ │ │ │ ├─■──espn_pt_pts_game_la ── Topic: 17 + │ │ │ │ └─■──team_25_game_hockey_550 ── Topic: 2 + │ │ │ └─■──year_game_hit_baseball_players ── Topic: 0 + │ │ └─bike_car_greek_insurance_msg + │ │ ├─car_bike_insurance_cars_engine + │ │ │ ├─car_insurance_cars_radar_engine + │ │ │ │ ├─insurance_health_private_care_canada + │ │ │ │ │ ├─■──insurance_health_private_care_canada ── Topic: 99 + │ │ │ │ │ └─■──insurance_car_accident_rates_sue ── Topic: 82 + │ │ │ │ └─car_cars_radar_engine_detector + │ │ │ │ ├─car_radar_cars_detector_engine + │ │ │ │ │ ├─■──radar_detector_detectors_ka_alarm ── Topic: 39 + │ │ │ │ │ └─car_cars_mustang_ford_engine + │ │ │ │ │ ├─■──clutch_shift_shifting_transmission_gear ── Topic: 88 + │ │ │ │ │ └─■──car_cars_mustang_ford_v8 ── Topic: 14 + │ │ │ │ └─oil_diesel_odometer_diesels_car + │ │ │ │ ├─odometer_oil_sensor_car_drain + │ │ │ │ │ ├─■──odometer_sensor_speedo_gauge_mileage ── Topic: 96 + │ │ │ │ │ └─■──oil_drain_car_leaks_taillights ── Topic: 102 + │ │ │ │ └─■──diesel_diesels_emissions_fuel_oil ── Topic: 79 + │ │ │ └─bike_riding_ride_bikes_motorcycle + │ │ │ ├─bike_ride_riding_bikes_lane + │ │ │ │ ├─■──bike_ride_riding_lane_car ── Topic: 11 + │ │ │ │ └─■──bike_bikes_miles_honda_motorcycle ── Topic: 19 + │ │ │ └─■──countersteering_bike_motorcycle_rear_shaft ── Topic: 46 + │ │ └─greek_msg_kuwait_greece_water + │ │ ├─greek_msg_kuwait_greece_water + │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ │ ├─greek_kuwait_greece_turkish_greeks + │ │ │ │ │ │ ├─■──greek_greece_turkish_greeks_cyprus ── Topic: 71 + │ │ │ │ │ │ └─■──kuwait_iraq_iran_gulf_arabia ── Topic: 76 + │ │ │ │ │ └─msg_dog_drugs_drug_food + │ │ │ │ │ ├─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──clinton_bush_quayle_reagan_panicking ── Topic: 101 + │ │ │ │ │ │ └─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──cooper_trial_weaver_spence_witnesses ── Topic: 90 + │ │ │ │ │ │ └─■──dog_dogs_bike_trained_springer ── Topic: 67 + │ │ │ │ │ └─msg_drugs_drug_food_chinese + │ │ │ │ │ ├─■──msg_food_chinese_foods_taste ── Topic: 30 + │ │ │ │ │ └─■──drugs_drug_marijuana_cocaine_alcohol ── Topic: 72 + │ │ │ │ └─water_theory_universe_science_larsons + │ │ │ │ ├─water_nuclear_cooling_steam_dept + │ │ │ │ │ ├─■──rocketry_rockets_engines_nuclear_plutonium ── Topic: 115 + │ │ │ │ │ └─water_cooling_steam_dept_plants + │ │ │ │ │ ├─■──water_dept_phd_environmental_atmospheric ── Topic: 97 + │ │ │ │ │ └─■──cooling_water_steam_towers_plants ── Topic: 109 + │ │ │ │ └─theory_universe_larsons_larson_science + │ │ │ │ ├─■──theory_universe_larsons_larson_science ── Topic: 54 + │ │ │ │ └─■──oort_cloud_grbs_gamma_burst ── Topic: 80 + │ │ │ └─helmet_kirlian_photography_lock_wax + │ │ │ ├─helmet_kirlian_photography_leaf_mask + │ │ │ │ ├─kirlian_photography_leaf_pictures_deleted + │ │ │ │ │ ├─deleted_joke_stuff_maddi_nickname + │ │ │ │ │ │ ├─■──joke_maddi_nickname_nicknames_frank ── Topic: 43 + │ │ │ │ │ │ └─■──deleted_stuff_bookstore_joke_motto ── Topic: 81 + │ │ │ │ │ └─■──kirlian_photography_leaf_pictures_aura ── Topic: 85 + │ │ │ │ └─helmet_mask_liner_foam_cb + │ │ │ │ ├─■──helmet_liner_foam_cb_helmets ── Topic: 112 + │ │ │ │ └─■──mask_goalies_77_santore_tl ── Topic: 123 + │ │ │ └─lock_wax_paint_plastic_ear + │ │ │ ├─■──lock_cable_locks_bike_600 ── Topic: 117 + │ │ │ └─wax_paint_ear_plastic_skin + │ │ │ 
├─■──wax_paint_plastic_scratches_solvent ── Topic: 65 + │ │ │ └─■──ear_wax_skin_greasy_acne ── Topic: 116 + │ │ └─m4_mp_14_mw_mo + │ │ ├─m4_mp_14_mw_mo + │ │ │ ├─■──m4_mp_14_mw_mo ── Topic: 111 + │ │ │ └─■──test_ensign_nameless_deane_deanebinahccbrandeisedu ── Topic: 118 + │ │ └─■──ites_cheek_hello_hi_ken ── Topic: 3 + │ └─space_medical_health_disease_cancer + │ ├─medical_health_disease_cancer_patients + │ │ ├─■──cancer_centers_center_medical_research ── Topic: 122 + │ │ └─health_medical_disease_patients_hiv + │ │ ├─patients_medical_disease_candida_health + │ │ │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 + │ │ │ └─patients_disease_cancer_medical_doctor + │ │ │ ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34 + │ │ │ └─■──pain_drug_patients_disease_diet ── Topic: 26 + │ │ └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + │ └─space_launch_nasa_shuttle_orbit + │ ├─space_moon_station_nasa_launch + │ │ ├─■──sky_advertising_billboard_billboards_space ── Topic: 59 + │ │ └─■──space_station_moon_redesign_nasa ── Topic: 16 + │ └─space_mission_hst_launch_orbit + │ ├─space_launch_nasa_orbit_propulsion + │ │ ├─■──space_launch_nasa_propulsion_astronaut ── Topic: 47 + │ │ └─■──orbit_km_jupiter_probe_earth ── Topic: 86 + │ └─■──hst_mission_shuttle_orbit_arrays ── Topic: 60 + └─drive_file_key_windows_use + ├─key_file_jpeg_encryption_image + │ ├─key_encryption_clipper_chip_keys + │ │ ├─■──key_clipper_encryption_chip_keys ── Topic: 1 + │ │ └─■──entry_file_ripem_entries_key ── Topic: 73 + │ └─jpeg_image_file_gif_images + │ ├─motif_graphics_ftp_available_3d + │ │ ├─motif_graphics_openwindows_ftp_available + │ │ │ ├─■──openwindows_motif_xview_windows_mouse ── Topic: 20 + │ │ │ └─■──graphics_widget_ray_3d_available ── Topic: 95 + │ │ └─■──3d_machines_version_comments_contact ── Topic: 38 + │ └─jpeg_image_gif_images_format + │ ├─■──gopher_ftp_files_stuffit_images ── Topic: 51 + │ └─■──jpeg_image_gif_format_images ── Topic: 13 + └─drive_db_card_scsi_windows + ├─db_windows_dos_mov_os2 + │ ├─■──copy_protection_program_software_disk ── Topic: 64 + │ └─■──db_windows_dos_mov_os2 ── Topic: 8 + └─drive_card_scsi_drives_ide + ├─drive_scsi_drives_ide_disk + │ ├─■──drive_scsi_drives_ide_disk ── Topic: 6 + │ └─■──meg_sale_ram_drive_shipping ── Topic: 12 + └─card_modem_monitor_video_drivers + ├─■──card_monitor_video_drivers_vga ── Topic: 5 + └─■──modem_port_serial_irq_com ── Topic: 10 + ``` +
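+Beyond visually inspecting the tree, the `hierarchical_topics` dataframe itself can be queried. As a small
+sketch (assuming only the `Parent_ID`, `Parent_Name`, `Topics`, and `Distance` columns used above), we can
+look up which original topics a merged parent contains and at which distance the merge happens:
+
+```python
+# Sort merges from most to least similar (smallest distance first)
+merges = hierarchical_topics.sort_values("Distance")
+
+# For each of the first merges, show the parent name,
+# the original topics it contains, and the merge distance
+for _, row in merges.head(5).iterrows():
+    print(row["Parent_Name"], "<-", row["Topics"], f"(distance={row['Distance']:.2f})")
+```
+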
+
+
+### **Merge topics**
+
+After seeing the potential hierarchy of your topics, you might want to merge specific
+topics. For example, if topic 1 is
+`1_space_launch_moon_nasa` and topic 2 is `2_spacecraft_solar_space_orbit` it might
+make sense to merge those two topics as they are quite similar in meaning. In BERTopic,
+you can use `.merge_topics` to manually select and merge those topics. Doing so will
+update their topic representation which in turn updates the entire model:
+
+```python
+topics_to_merge = [1, 2]
+topic_model.merge_topics(docs, topics, topics_to_merge)
+```
+
+If you have several groups of topics you want to merge, create a list of lists instead:
+
+```python
+topics_to_merge = [[1, 2],
+                   [3, 4]]
+topic_model.merge_topics(docs, topics, topics_to_merge)
+```
diff --git a/docs/getting_started/tips_and_tricks/skateboarders.jpg b/docs/getting_started/tips_and_tricks/skateboarders.jpg
new file mode 100644
index 00000000..7144906a
Binary files /dev/null and b/docs/getting_started/tips_and_tricks/skateboarders.jpg differ
diff --git a/docs/getting_started/tips_and_tricks/tips_and_tricks.md b/docs/getting_started/tips_and_tricks/tips_and_tricks.md
index b76fbbe6..b75f4409 100644
--- a/docs/getting_started/tips_and_tricks/tips_and_tricks.md
+++ b/docs/getting_started/tips_and_tricks/tips_and_tricks.md
@@ -150,4 +150,204 @@ force a cosine-related distance metric in UMAP:
 ```python
 from cuml.preprocessing import normalize
 embeddings = normalize(embeddings)
-```
\ No newline at end of file
+```
+
+## **Finding similar topics between models**
+
+Whenever you have trained separate BERTopic models on different datasets, it might
+be worthwhile to find the similarities among these models. Is there overlap between
+topics in model A and topics in model B? In other words, can we find topics in model A that are similar to those in model B?
+
+We can compare the topic representations of several models in two ways. First, by comparing the topic embeddings that are created when using the same embedding model across both fitted BERTopic instances. Second, we can compare the c-TF-IDF representations instead, assuming we have fixed the vocabulary in both instances.
+
+This example will go into the former, using the same embedding model across two BERTopic instances. To do this comparison, let's first create an example by training two models, one on an English dataset and one on a Dutch dataset:
+
+```python
+from datasets import load_dataset
+from bertopic import BERTopic
+from sentence_transformers import SentenceTransformer
+from umap import UMAP
+
+# The same embedding model needs to be used for both topic models
+# and since we are dealing with multiple languages, the model needs to be multi-lingual
+sentence_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
+
+# To make this example reproducible
+umap_model = UMAP(n_neighbors=15, n_components=5,
+                  min_dist=0.0, metric='cosine', random_state=42)
+
+# English
+en_dataset = load_dataset("stsb_multi_mt", name="en", split="train").to_pandas().sentence1.tolist()
+en_model = BERTopic(embedding_model=sentence_model, umap_model=umap_model)
+en_model.fit(en_dataset)
+
+# Dutch
+nl_dataset = load_dataset("stsb_multi_mt", name="nl", split="train").to_pandas().sentence1.tolist()
+nl_model = BERTopic(embedding_model=sentence_model, umap_model=umap_model)
+nl_model.fit(nl_dataset)
+```
+
+In the code above, there is one important thing to note and that is the `sentence_model`.
This model needs to be exactly the same in all BERTopic models, otherwise, it is not possible to compare topic models.
+
+Next, we can calculate the similarity between topics in the English topic model `en_model` and the Dutch model `nl_model`. To do so, we can simply calculate the cosine similarity between the `topic_embeddings` of both models:
+
+```python
+from sklearn.metrics.pairwise import cosine_similarity
+sim_matrix = cosine_similarity(en_model.topic_embeddings, nl_model.topic_embeddings)
+```
+
+Now that we know which topics are similar to each other, we can extract the most similar topics. Let's say that we have topic 10 in the `en_model` which represents a topic related to trains:
+
+```python
+>>> topic = 10
+>>> en_model.get_topic(topic)
+[('train', 0.2588080580844999),
+ ('tracks', 0.1392140438801078),
+ ('station', 0.12126454635946024),
+ ('passenger', 0.058057876475695866),
+ ('engine', 0.05123717127783682),
+ ('railroad', 0.048142847325312044),
+ ('waiting', 0.04098973702226946),
+ ('track', 0.03978248702913929),
+ ('subway', 0.03834661195748458),
+ ('steam', 0.03834661195748458)]
+```
+
+To find the matching topic, we extract the most similar topic in the `sim_matrix`:
+
+```python
+>>> import numpy as np
+>>> most_similar_topic = np.argmax(sim_matrix[topic + 1]) - 1
+>>> nl_model.get_topic(most_similar_topic)
+[('trein', 0.24186603209316418),
+ ('spoor', 0.1338118418551581),
+ ('sporen', 0.07683661859111401),
+ ('station', 0.056990389779394225),
+ ('stoommachine', 0.04905829711711234),
+ ('zilveren', 0.04083879598477808),
+ ('treinen', 0.03534099197032758),
+ ('treinsporen', 0.03534099197032758),
+ ('staat', 0.03481332997324445),
+ ('zwarte', 0.03179591746822408)]
+```
+
+It seems to be working as, for example, `trein` is a translation of `train` and `sporen` a translation of `tracks`! You can do this for every single topic to find out which topic in the `en_model` might correspond to a topic in the `nl_model`.
+
+## **Multi-modal data**
+[Concept](https://github.com/MaartenGr/Concept) is a variation
+of BERTopic for multi-modal data, such as images with captions. Although we can use that
+package for multi-modal data, we can perform a small trick with BERTopic to have a similar feature.
+
+BERTopic is a relatively modular approach that attempts to isolate steps from one another. This means,
+for example, that you can use k-Means instead of HDBSCAN or PCA instead of UMAP as it does not make
+any assumptions with respect to the nature of the clustering.
+
+Similarly, you can pass pre-calculated embeddings to BERTopic that represent the documents that you have.
+However, it does not make any assumption with respect to the relationship between those embeddings and
+the documents. This means that we could pass any metadata to BERTopic to cluster on instead of document
+embeddings. In this example, we can separate our embeddings from our documents so that the embeddings
+are generated from images instead of their corresponding captions. Thus, we will cluster image embeddings but
+create the topic representation from the related captions.
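+
+In other words, the shape of the approach is just a few lines (a minimal sketch; `image_embeddings` is a
+hypothetical array of image embeddings aligned one-to-one with the list of captions in `docs`):
+
+```python
+from bertopic import BERTopic
+
+# Cluster on the image embeddings while the topic representations
+# are created from the corresponding captions in `docs`
+topic_model = BERTopic()
+topics, probs = topic_model.fit_transform(docs, image_embeddings)
+```
+
+The walkthrough below builds exactly such a pair of aligned captions and image embeddings.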
+ +In this example, we first need to fetch our data, namely the Flickr 8k dataset that contains images +with captions: + +```python +import os +import glob +import zipfile +import numpy as np +import pandas as pd +from tqdm import tqdm +from PIL import Image +from sentence_transformers import SentenceTransformer, util + +# Flickr 8k images +img_folder = 'photos/' +caps_folder = 'captions/' +if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0: + os.makedirs(img_folder, exist_ok=True) + + if not os.path.exists('Flickr8k_Dataset.zip'): #Download dataset if does not exist + util.http_get('https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip', 'Flickr8k_Dataset.zip') + util.http_get('https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip', 'Flickr8k_text.zip') + + for folder, file in [(img_folder, 'Flickr8k_Dataset.zip'), (caps_folder, 'Flickr8k_text.zip')]: + with zipfile.ZipFile(file, 'r') as zf: + for member in tqdm(zf.infolist(), desc='Extracting'): + zf.extract(member, folder) +images = list(glob.glob('photos/Flicker8k_Dataset/*.jpg')) + +# Prepare dataframe +captions = pd.read_csv("captions/Flickr8k.lemma.token.txt",sep='\t',names=["img_id","img_caption"]) +captions.img_id = captions.apply(lambda row: "photos/Flicker8k_Dataset/" + row.img_id.split(".jpg")[0] + ".jpg", 1) +captions = captions.groupby(["img_id"])["img_caption"].apply(','.join).reset_index() +captions = pd.merge(captions, pd.Series(images, name="img_id"), on="img_id") + +# Extract images together with their documents/captions +images = captions.img_id.to_list() +docs = captions.img_caption.to_list() +``` + +Now that we have our images and captions, we need to generate our image embeddings: + +```python +model = SentenceTransformer('clip-ViT-B-32') + +# Prepare images +batch_size = 32 +nr_iterations = int(np.ceil(len(images) / batch_size)) + +# Embed images per batch +embeddings = [] +for i in tqdm(range(nr_iterations)): + start_index = i * batch_size + end_index = (i * batch_size) + batch_size + + images_to_embed = [Image.open(filepath) for filepath in images[start_index:end_index]] + img_emb = model.encode(images_to_embed, show_progress_bar=False) + embeddings.extend(img_emb.tolist()) + + # Close images + for image in images_to_embed: + image.close() +embeddings = np.array(embeddings) +``` + +Finally, we can fit BERTopic the way we are used to, with documents and embeddings: + +```python +from bertopic import BERTopic +from sklearn.cluster import KMeans +from sklearn.feature_extraction.text import CountVectorizer + +vectorizer_model = CountVectorizer(stop_words="english") +topic_model = BERTopic(vectorizer_model=vectorizer_model) +topics, probs = topic_model.fit_transform(docs, embeddings) +captions["Topic"] = topics +``` + +After fitting our model, let's inspect a topic about skateboarders: + +```python +>>> topic_model.get_topic(2) +[('skateboard', 0.09592033177340711), + ('skateboarder', 0.07792520092546491), + ('trick', 0.07481578896400298), + ('ramp', 0.056952605147927216), + ('skate', 0.03745127816149923), + ('perform', 0.036546213623432654), + ('bicycle', 0.03453483070441857), + ('bike', 0.033233021253898994), + ('jump', 0.026709362981948037), + ('air', 0.025422798170830936)] +``` + +Based on the above output, we can take an image to see if the representation makes sense: + +```python +image = captions.loc[captions.Topic == 2, "img_id"].values.tolist()[0] +Image.open(image) +``` + +![](skateboarders.jpg) \ No newline at end of file diff 
--git a/docs/getting_started/topicreduction/topicreduction.md b/docs/getting_started/topicreduction/topicreduction.md
index f3bb7c21..2c7262d0 100644
--- a/docs/getting_started/topicreduction/topicreduction.md
+++ b/docs/getting_started/topicreduction/topicreduction.md
@@ -6,14 +6,30 @@ so.
 ### **Manual Topic Reduction**
 Each resulting topic has its own
 feature vector constructed from c-TF-IDF. Using those feature vectors, we can find the most similar
-topics and merge them. If we do this iteratively, starting from the least frequent topic, we can reduce the number
-of topics quite easily. We do this until we reach the value of `nr_topics`:
+topics and merge them. If we do this iteratively, starting from the least frequent topic, we can reduce the number of topics quite easily. We do this until we reach the value of `nr_topics`:

 ```python
 from bertopic import BERTopic
 topic_model = BERTopic(nr_topics=20)
 ```

+It is also possible to manually select certain topics that you believe should be merged.
+For example, if topic 1 is `1_space_launch_moon_nasa` and topic 2 is `2_spacecraft_solar_space_orbit`
+it might make sense to merge those two topics:
+
+```python
+topics_to_merge = [1, 2]
+topic_model.merge_topics(docs, topics, topics_to_merge)
+```
+
+If you have several groups of topics you want to merge, create a list of lists instead:
+
+```python
+topics_to_merge = [[1, 2],
+                   [3, 4]]
+topic_model.merge_topics(docs, topics, topics_to_merge)
+```
+
 ### **Automatic Topic Reduction**
 One issue with the approach above is that it will merge topics regardless of whether they are very similar. They
 are simply the most similar out of all options. This can be resolved by reducing the number of topics automatically.
diff --git a/docs/getting_started/topicrepresentation/topicrepresentation.md b/docs/getting_started/topicrepresentation/topicrepresentation.md
index 6f5d3c12..1c1e9a03 100644
--- a/docs/getting_started/topicrepresentation/topicrepresentation.md
+++ b/docs/getting_started/topicrepresentation/topicrepresentation.md
@@ -61,4 +61,66 @@ instead:
 from sklearn.feature_extraction.text import CountVectorizer
 vectorizer_model = CountVectorizer(stop_words="English", ngram_range=(1, 5))
 topic_model.update_topics(docs, topics, vectorizer_model=vectorizer_model)
-```
\ No newline at end of file
+```
+
+### **Custom labels**
+
+The topic labels are currently automatically generated by taking the top 3 words and combining them
+using the `_` separator. Although this is an informative label, in practice, this is definitely not the prettiest nor necessarily the most accurate label. For example, although the topic label
+`1_space_nasa_orbit` is informative, we might prefer a more intuitive label, such as
+`space travel`. The difficulty with creating such topic labels is that much of the interpretation is left to the user. Would `space travel` be more accurate or perhaps `space explorations`? To truly understand which labels are most suited, going through some of the documents in each topic is especially helpful.
+
+Although we can go through every single topic ourselves and try to label them, we can start by creating an overview of labels that have the length and number of words that we are looking for.
To do so, we can generate our list of topic labels with `.generate_topic_labels` and define the number of words, the separator, word length, etc:
+
+```python
+topic_labels = topic_model.generate_topic_labels(nr_words=3,
+                                                 topic_prefix=False,
+                                                 word_length=10,
+                                                 separator=", ")
+```
+
+In the above example, `1_space_nasa_orbit` would turn into `space, nasa, orbit` since we selected 3 words, no topic prefix, and the `, ` separator. We can then either change our `topic_labels` to whatever we want or directly pass them to `.set_topic_labels` so that they can be used across most visualization functions:
+
+```python
+topic_model.set_topic_labels(topic_labels)
+```
+
+It is also possible to only change a few topic labels at a time by passing a dictionary
+where the key represents the *topic ID* and the value is the *topic label*:
+
+```python
+topic_model.set_topic_labels({1: "Space Travel", 7: "Religion"})
+```
+
+Then, to make use of those custom topic labels across visualizations, such as `.visualize_hierarchy()`,
+we can use the `custom_labels=True` parameter that is found in most visualizations.
+
+```python
+fig = topic_model.visualize_barchart(custom_labels=True)
+```
+
+#### Optimize labels
+The great advantage of passing custom labels to BERTopic is that when more accurate zero-shot models are released,
+we can simply use those on top of BERTopic to further fine-tune the labeling. For example, let's say you
+have a set of potential topic labels that you want to use instead of the ones generated by BERTopic. You could
+use the [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli) model to find which user-defined
+labels best represent the BERTopic-generated labels:
+
+```python
+from transformers import pipeline
+classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+# A selected topic representation
+# sequence_to_classify = 'god jesus atheists atheism belief atheist believe exist beliefs existence'
+sequence_to_classify = " ".join([word for word, _ in topic_model.get_topic(1)])
+
+# Our set of potential topic labels
+candidate_labels = ['cooking', 'dancing', 'religion']
+classifier(sequence_to_classify, candidate_labels)
+
+#{'labels': ['religion', 'cooking', 'dancing'],
+# 'scores': [0.850, 0.086, 0.063],
+# 'sequence': 'god jesus atheists atheism belief atheist believe exist beliefs existence'}
+```
diff --git a/docs/getting_started/visualization/documents.html b/docs/getting_started/visualization/documents.html
new file mode 100644
index 00000000..094bf7e6
--- /dev/null
+++ b/docs/getting_started/visualization/documents.html
@@ -0,0 +1,7 @@
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/hierarchical_documents.html b/docs/getting_started/visualization/hierarchical_documents.html new file mode 100644 index 00000000..9638fc0f --- /dev/null +++ b/docs/getting_started/visualization/hierarchical_documents.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/hierarchical_topics.html b/docs/getting_started/visualization/hierarchical_topics.html new file mode 100644 index 00000000..51ba2364 --- /dev/null +++ b/docs/getting_started/visualization/hierarchical_topics.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/visualization.md b/docs/getting_started/visualization/visualization.md index 00518034..1a67dc79 100644 --- a/docs/getting_started/visualization/visualization.md +++ b/docs/getting_started/visualization/visualization.md @@ -19,7 +19,7 @@ from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, probs = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` Then, we simply call `topic_model.visualize_topics()` in order to visualize our topics. The resulting graph is a @@ -32,6 +32,45 @@ Thus, you can play around with the results below: You can use the slider to select the topic which then lights up red. If you hover over a topic, then general information is given about the topic, including the size of the topic and its corresponding words. +## **Visualize Documents** +Using the previous method, we can visualize the topics and get insight into their relationships. However, +you might want a more fine-grained approach where we can visualize the documents inside the topics to see +if they were assigned correctly or whether they make sense. To do so, we can use the `topic_model.visualize_documents()` +function. This function recalculates the document embeddings and reduces them to 2-dimensional space for easier visualization +purposes. This process can be quite expensive, so it is advised to adhere to the following pipeline: + +```python +from sklearn.datasets import fetch_20newsgroups +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# Prepare embeddings +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +sentence_model = SentenceTransformer("all-MiniLM-L6-v2") +embeddings = sentence_model.encode(docs, show_progress_bar=False) + +# Train BERTopic +topic_model = BERTopic().fit(docs, embeddings) + +# Reduce dimensionality of embeddings, this step is optional +# reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + +# Run the visualization with the original embeddings +topic_model.visualize_documents(docs, embeddings=embeddings) + +# Or, if you have reduced the original embeddings already: +topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) +``` + + + + +!!! note + The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the + option to hover over the individual points and see the content of the documents. This was done for demonstration purposes + as saving all those documents in the visualization can be quite expensive and result in large files. However, + it might be interesting to set `hide_document_hover=False` in order to hover over the points and see the content of the documents. ## **Visualize Topic Hierarchy** The topics that were created can be hierarchically reduced. In order to understand the potential hierarchical @@ -46,6 +85,347 @@ of topics that you have created. To visualize this hierarchy, simply call `topic auto since HDBSCAN is used to automatically extract topics. The visualization above closely resembles the actual procedure of `.reduce_topics()` when any number of `nr_topics` is selected. 
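+
+For larger models, it can also help to restrict the dendrogram to the most frequent topics and save the
+resulting figure (a short sketch using the documented `top_n_topics` parameter and Plotly's `write_html`):
+
+```python
+# Only visualize the 30 most frequent topics and save the figure
+fig = topic_model.visualize_hierarchy(top_n_topics=30)
+fig.write_html("path/to/hierarchy.html")
+```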
+Although visualizing this hierarchy gives us information about the structure, it would be helpful to see what happens
+to the topic representations when merging topics. To do so, we first need to calculate the representations of the
+hierarchical topics.
+
+First, we train a basic BERTopic model:
+
+```python
+from bertopic import BERTopic
+from sklearn.datasets import fetch_20newsgroups
+
+docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))["data"]
+topic_model = BERTopic(verbose=True)
+topics, probs = topic_model.fit_transform(docs)
+hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+```
+
+To visualize these results, we simply need to pass the resulting `hierarchical_topics` to our `topic_model.visualize_hierarchy()` function:
+
+```python
+topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+```
+
+
+
+If you **hover** over the black circles, you will see the topic representation at that level of the hierarchy. These representations
+help you understand the effect of merging certain topics. Some merges might be logical whilst others might not. Moreover,
+we can now see which sub-topics can be found within larger themes.
+
+Although this gives a nice overview of the potential hierarchy, hovering over all black circles can be tiresome. Instead, we can
+use `topic_model.get_topic_tree` to create a text-based representation of this hierarchy. Although the general structure is more
+difficult to view, it is easier to see which topics could be logically merged:
+
+```python
+>>> tree = topic_model.get_topic_tree(hierarchical_topics)
+>>> print(tree)
+.
+└─atheists_atheism_god_moral_atheist
+     ├─atheists_atheism_god_atheist_argument
+     │    ├─■──atheists_atheism_god_atheist_argument ── Topic: 21
+     │    └─■──br_god_exist_genetic_existence ── Topic: 124
+     └─■──moral_morality_objective_immoral_morals ── Topic: 29
+```
+
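+The tree is returned as a plain string, so it can also be written to disk and browsed in an editor. A minimal
+sketch; the `tight_layout` flag, which also appears in the accompanying tests, is assumed here to only control
+how compactly the branches are drawn:
+
+```python
+# `get_topic_tree` returns a str, so we can simply save it to a file.
+tree = topic_model.get_topic_tree(hierarchical_topics, tight_layout=False)
+with open("topic_tree.txt", "w") as f:
+    f.write(tree)
+```
+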
+ Click here to view the full tree. + + ```bash + . + ├─people_armenian_said_god_armenians + │ ├─god_jesus_jehovah_lord_christ + │ │ ├─god_jesus_jehovah_lord_christ + │ │ │ ├─jehovah_lord_mormon_mcconkie_god + │ │ │ │ ├─■──ra_satan_thou_god_lucifer ── Topic: 94 + │ │ │ │ └─■──jehovah_lord_mormon_mcconkie_unto ── Topic: 78 + │ │ │ └─jesus_mary_god_hell_sin + │ │ │ ├─jesus_hell_god_eternal_heaven + │ │ │ │ ├─hell_jesus_eternal_god_heaven + │ │ │ │ │ ├─■──jesus_tomb_disciples_resurrection_john ── Topic: 69 + │ │ │ │ │ └─■──hell_eternal_god_jesus_heaven ── Topic: 53 + │ │ │ │ └─■──aaron_baptism_sin_law_god ── Topic: 89 + │ │ │ └─■──mary_sin_maria_priest_conception ── Topic: 56 + │ │ └─■──marriage_married_marry_ceremony_marriages ── Topic: 110 + │ └─people_armenian_armenians_said_mr + │ ├─people_armenian_armenians_said_israel + │ │ ├─god_homosexual_homosexuality_atheists_sex + │ │ │ ├─homosexual_homosexuality_sex_gay_homosexuals + │ │ │ │ ├─■──kinsey_sex_gay_men_sexual ── Topic: 44 + │ │ │ │ └─homosexuality_homosexual_sin_homosexuals_gay + │ │ │ │ ├─■──gay_homosexual_homosexuals_sexual_cramer ── Topic: 50 + │ │ │ │ └─■──homosexuality_homosexual_sin_paul_sex ── Topic: 27 + │ │ │ └─god_atheists_atheism_moral_atheist + │ │ │ ├─islam_quran_judas_islamic_book + │ │ │ │ ├─■──jim_context_challenges_articles_quote ── Topic: 36 + │ │ │ │ └─islam_quran_judas_islamic_book + │ │ │ │ ├─■──islam_quran_islamic_rushdie_muslims ── Topic: 31 + │ │ │ │ └─■──judas_scripture_bible_books_greek ── Topic: 33 + │ │ │ └─atheists_atheism_god_moral_atheist + │ │ │ ├─atheists_atheism_god_atheist_argument + │ │ │ │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ │ │ │ └─■──br_god_exist_genetic_existence ── Topic: 124 + │ │ │ └─■──moral_morality_objective_immoral_morals ── Topic: 29 + │ │ └─armenian_armenians_people_israel_said + │ │ ├─armenian_armenians_israel_people_jews + │ │ │ ├─tax_rights_government_income_taxes + │ │ │ │ ├─■──rights_right_slavery_slaves_residence ── Topic: 106 + │ │ │ │ └─tax_government_taxes_income_libertarians + │ │ │ │ ├─■──government_libertarians_libertarian_regulation_party ── Topic: 58 + │ │ │ │ └─■──tax_taxes_income_billion_deficit ── Topic: 41 + │ │ │ └─armenian_armenians_israel_people_jews + │ │ │ ├─gun_guns_militia_firearms_amendment + │ │ │ │ ├─■──blacks_penalty_death_cruel_punishment ── Topic: 55 + │ │ │ │ └─■──gun_guns_militia_firearms_amendment ── Topic: 7 + │ │ │ └─armenian_armenians_israel_jews_turkish + │ │ │ ├─■──israel_israeli_jews_arab_jewish ── Topic: 4 + │ │ │ └─■──armenian_armenians_turkish_armenia_azerbaijan ── Topic: 15 + │ │ └─stephanopoulos_president_mr_myers_ms + │ │ ├─■──serbs_muslims_stephanopoulos_mr_bosnia ── Topic: 35 + │ │ └─■──myers_stephanopoulos_president_ms_mr ── Topic: 87 + │ └─batf_fbi_koresh_compound_gas + │ ├─■──reno_workers_janet_clinton_waco ── Topic: 77 + │ └─batf_fbi_koresh_gas_compound + │ ├─batf_koresh_fbi_warrant_compound + │ │ ├─■──batf_warrant_raid_compound_fbi ── Topic: 42 + │ │ └─■──koresh_batf_fbi_children_compound ── Topic: 61 + │ └─■──fbi_gas_tear_bds_building ── Topic: 23 + └─use_like_just_dont_new + ├─game_team_year_games_like + │ ├─game_team_games_25_year + │ │ ├─game_team_games_25_season + │ │ │ ├─window_printer_use_problem_mhz + │ │ │ │ ├─mhz_wire_simms_wiring_battery + │ │ │ │ │ ├─simms_mhz_battery_cpu_heat + │ │ │ │ │ │ ├─simms_pds_simm_vram_lc + │ │ │ │ │ │ │ ├─■──pds_nubus_lc_slot_card ── Topic: 119 + │ │ │ │ │ │ │ └─■──simms_simm_vram_meg_dram ── Topic: 32 + │ │ │ │ │ │ └─mhz_battery_cpu_heat_speed + │ │ │ │ │ │ 
├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ │ ├─■──fan_cpu_heat_sink_fans ── Topic: 92 + │ │ │ │ │ │ │ │ └─■──mhz_speed_cpu_fpu_clock ── Topic: 22 + │ │ │ │ │ │ │ └─■──monitor_turn_power_computer_electricity ── Topic: 91 + │ │ │ │ │ │ └─battery_batteries_concrete_duo_discharge + │ │ │ │ │ │ ├─■──duo_battery_apple_230_problem ── Topic: 121 + │ │ │ │ │ │ └─■──battery_batteries_concrete_discharge_temperature ── Topic: 75 + │ │ │ │ │ └─wire_wiring_ground_neutral_outlets + │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ │ ├─■──leds_uv_blue_light_boards ── Topic: 66 + │ │ │ │ │ │ │ └─■──wire_wiring_ground_neutral_outlets ── Topic: 120 + │ │ │ │ │ │ └─scope_scopes_phone_dial_number + │ │ │ │ │ │ ├─■──dial_number_phone_line_output ── Topic: 93 + │ │ │ │ │ │ └─■──scope_scopes_motorola_generator_oscilloscope ── Topic: 113 + │ │ │ │ │ └─celp_dsp_sampling_antenna_digital + │ │ │ │ │ ├─■──antenna_antennas_receiver_cable_transmitter ── Topic: 70 + │ │ │ │ │ └─■──celp_dsp_sampling_speech_voice ── Topic: 52 + │ │ │ │ └─window_printer_xv_mouse_windows + │ │ │ │ ├─window_xv_error_widget_problem + │ │ │ │ │ ├─error_symbol_undefined_xterm_rx + │ │ │ │ │ │ ├─■──symbol_error_undefined_doug_parse ── Topic: 63 + │ │ │ │ │ │ └─■──rx_remote_server_xdm_xterm ── Topic: 45 + │ │ │ │ │ └─window_xv_widget_application_expose + │ │ │ │ │ ├─window_widget_expose_application_event + │ │ │ │ │ │ ├─■──gc_mydisplay_draw_gxxor_drawing ── Topic: 103 + │ │ │ │ │ │ └─■──window_widget_application_expose_event ── Topic: 25 + │ │ │ │ │ └─xv_den_polygon_points_algorithm + │ │ │ │ │ ├─■──den_polygon_points_algorithm_polygons ── Topic: 28 + │ │ │ │ │ └─■──xv_24bit_image_bit_images ── Topic: 57 + │ │ │ │ └─printer_fonts_print_mouse_postscript + │ │ │ │ ├─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──scanner_logitech_grayscale_ocr_scanman ── Topic: 108 + │ │ │ │ │ └─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──printer_print_deskjet_hp_ink ── Topic: 18 + │ │ │ │ │ └─■──fonts_font_truetype_tt_atm ── Topic: 49 + │ │ │ │ └─mouse_ghostscript_midi_driver_postscript + │ │ │ │ ├─ghostscript_midi_postscript_files_file + │ │ │ │ │ ├─■──ghostscript_postscript_pageview_ghostview_dsc ── Topic: 104 + │ │ │ │ │ └─midi_sound_file_windows_driver + │ │ │ │ │ ├─■──location_mar_file_host_rwrr ── Topic: 83 + │ │ │ │ │ └─■──midi_sound_driver_blaster_soundblaster ── Topic: 98 + │ │ │ │ └─■──mouse_driver_mice_ball_problem ── Topic: 68 + │ │ │ └─game_team_games_25_season + │ │ │ ├─1st_sale_condition_comics_hulk + │ │ │ │ ├─sale_condition_offer_asking_cd + │ │ │ │ │ ├─condition_stereo_amp_speakers_asking + │ │ │ │ │ │ ├─■──miles_car_amfm_toyota_cassette ── Topic: 62 + │ │ │ │ │ │ └─■──amp_speakers_condition_stereo_audio ── Topic: 24 + │ │ │ │ │ └─games_sale_pom_cds_shipping + │ │ │ │ │ ├─pom_cds_sale_shipping_cd + │ │ │ │ │ │ ├─■──size_shipping_sale_condition_mattress ── Topic: 100 + │ │ │ │ │ │ └─■──pom_cds_cd_sale_picture ── Topic: 37 + │ │ │ │ │ └─■──games_game_snes_sega_genesis ── Topic: 40 + │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ ├─1st_hulk_comics_art_appears + │ │ │ │ │ ├─lens_tape_camera_backup_lenses + │ │ │ │ │ │ ├─■──tape_backup_tapes_drive_4mm ── Topic: 107 + │ │ │ │ │ │ └─■──lens_camera_lenses_zoom_pouch ── Topic: 114 + │ │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ │ ├─■──1st_hulk_comics_art_appears ── Topic: 105 + │ │ │ │ │ └─■──books_book_cover_trek_chemistry ── Topic: 125 + │ │ │ │ └─tickets_hotel_ticket_voucher_package + │ │ │ 
│ ├─■──hotel_voucher_package_vacation_room ── Topic: 74 + │ │ │ │ └─■──tickets_ticket_june_airlines_july ── Topic: 84 + │ │ │ └─game_team_games_season_hockey + │ │ │ ├─game_hockey_team_25_550 + │ │ │ │ ├─■──espn_pt_pts_game_la ── Topic: 17 + │ │ │ │ └─■──team_25_game_hockey_550 ── Topic: 2 + │ │ │ └─■──year_game_hit_baseball_players ── Topic: 0 + │ │ └─bike_car_greek_insurance_msg + │ │ ├─car_bike_insurance_cars_engine + │ │ │ ├─car_insurance_cars_radar_engine + │ │ │ │ ├─insurance_health_private_care_canada + │ │ │ │ │ ├─■──insurance_health_private_care_canada ── Topic: 99 + │ │ │ │ │ └─■──insurance_car_accident_rates_sue ── Topic: 82 + │ │ │ │ └─car_cars_radar_engine_detector + │ │ │ │ ├─car_radar_cars_detector_engine + │ │ │ │ │ ├─■──radar_detector_detectors_ka_alarm ── Topic: 39 + │ │ │ │ │ └─car_cars_mustang_ford_engine + │ │ │ │ │ ├─■──clutch_shift_shifting_transmission_gear ── Topic: 88 + │ │ │ │ │ └─■──car_cars_mustang_ford_v8 ── Topic: 14 + │ │ │ │ └─oil_diesel_odometer_diesels_car + │ │ │ │ ├─odometer_oil_sensor_car_drain + │ │ │ │ │ ├─■──odometer_sensor_speedo_gauge_mileage ── Topic: 96 + │ │ │ │ │ └─■──oil_drain_car_leaks_taillights ── Topic: 102 + │ │ │ │ └─■──diesel_diesels_emissions_fuel_oil ── Topic: 79 + │ │ │ └─bike_riding_ride_bikes_motorcycle + │ │ │ ├─bike_ride_riding_bikes_lane + │ │ │ │ ├─■──bike_ride_riding_lane_car ── Topic: 11 + │ │ │ │ └─■──bike_bikes_miles_honda_motorcycle ── Topic: 19 + │ │ │ └─■──countersteering_bike_motorcycle_rear_shaft ── Topic: 46 + │ │ └─greek_msg_kuwait_greece_water + │ │ ├─greek_msg_kuwait_greece_water + │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ │ ├─greek_kuwait_greece_turkish_greeks + │ │ │ │ │ │ ├─■──greek_greece_turkish_greeks_cyprus ── Topic: 71 + │ │ │ │ │ │ └─■──kuwait_iraq_iran_gulf_arabia ── Topic: 76 + │ │ │ │ │ └─msg_dog_drugs_drug_food + │ │ │ │ │ ├─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──clinton_bush_quayle_reagan_panicking ── Topic: 101 + │ │ │ │ │ │ └─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──cooper_trial_weaver_spence_witnesses ── Topic: 90 + │ │ │ │ │ │ └─■──dog_dogs_bike_trained_springer ── Topic: 67 + │ │ │ │ │ └─msg_drugs_drug_food_chinese + │ │ │ │ │ ├─■──msg_food_chinese_foods_taste ── Topic: 30 + │ │ │ │ │ └─■──drugs_drug_marijuana_cocaine_alcohol ── Topic: 72 + │ │ │ │ └─water_theory_universe_science_larsons + │ │ │ │ ├─water_nuclear_cooling_steam_dept + │ │ │ │ │ ├─■──rocketry_rockets_engines_nuclear_plutonium ── Topic: 115 + │ │ │ │ │ └─water_cooling_steam_dept_plants + │ │ │ │ │ ├─■──water_dept_phd_environmental_atmospheric ── Topic: 97 + │ │ │ │ │ └─■──cooling_water_steam_towers_plants ── Topic: 109 + │ │ │ │ └─theory_universe_larsons_larson_science + │ │ │ │ ├─■──theory_universe_larsons_larson_science ── Topic: 54 + │ │ │ │ └─■──oort_cloud_grbs_gamma_burst ── Topic: 80 + │ │ │ └─helmet_kirlian_photography_lock_wax + │ │ │ ├─helmet_kirlian_photography_leaf_mask + │ │ │ │ ├─kirlian_photography_leaf_pictures_deleted + │ │ │ │ │ ├─deleted_joke_stuff_maddi_nickname + │ │ │ │ │ │ ├─■──joke_maddi_nickname_nicknames_frank ── Topic: 43 + │ │ │ │ │ │ └─■──deleted_stuff_bookstore_joke_motto ── Topic: 81 + │ │ │ │ │ └─■──kirlian_photography_leaf_pictures_aura ── Topic: 85 + │ │ │ │ └─helmet_mask_liner_foam_cb + │ │ │ │ ├─■──helmet_liner_foam_cb_helmets ── Topic: 112 + │ │ │ │ └─■──mask_goalies_77_santore_tl ── Topic: 123 + │ │ │ └─lock_wax_paint_plastic_ear + │ │ │ ├─■──lock_cable_locks_bike_600 ── Topic: 117 + │ │ │ └─wax_paint_ear_plastic_skin + │ │ │ 
├─■──wax_paint_plastic_scratches_solvent ── Topic: 65 + │ │ │ └─■──ear_wax_skin_greasy_acne ── Topic: 116 + │ │ └─m4_mp_14_mw_mo + │ │ ├─m4_mp_14_mw_mo + │ │ │ ├─■──m4_mp_14_mw_mo ── Topic: 111 + │ │ │ └─■──test_ensign_nameless_deane_deanebinahccbrandeisedu ── Topic: 118 + │ │ └─■──ites_cheek_hello_hi_ken ── Topic: 3 + │ └─space_medical_health_disease_cancer + │ ├─medical_health_disease_cancer_patients + │ │ ├─■──cancer_centers_center_medical_research ── Topic: 122 + │ │ └─health_medical_disease_patients_hiv + │ │ ├─patients_medical_disease_candida_health + │ │ │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 + │ │ │ └─patients_disease_cancer_medical_doctor + │ │ │ ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34 + │ │ │ └─■──pain_drug_patients_disease_diet ── Topic: 26 + │ │ └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + │ └─space_launch_nasa_shuttle_orbit + │ ├─space_moon_station_nasa_launch + │ │ ├─■──sky_advertising_billboard_billboards_space ── Topic: 59 + │ │ └─■──space_station_moon_redesign_nasa ── Topic: 16 + │ └─space_mission_hst_launch_orbit + │ ├─space_launch_nasa_orbit_propulsion + │ │ ├─■──space_launch_nasa_propulsion_astronaut ── Topic: 47 + │ │ └─■──orbit_km_jupiter_probe_earth ── Topic: 86 + │ └─■──hst_mission_shuttle_orbit_arrays ── Topic: 60 + └─drive_file_key_windows_use + ├─key_file_jpeg_encryption_image + │ ├─key_encryption_clipper_chip_keys + │ │ ├─■──key_clipper_encryption_chip_keys ── Topic: 1 + │ │ └─■──entry_file_ripem_entries_key ── Topic: 73 + │ └─jpeg_image_file_gif_images + │ ├─motif_graphics_ftp_available_3d + │ │ ├─motif_graphics_openwindows_ftp_available + │ │ │ ├─■──openwindows_motif_xview_windows_mouse ── Topic: 20 + │ │ │ └─■──graphics_widget_ray_3d_available ── Topic: 95 + │ │ └─■──3d_machines_version_comments_contact ── Topic: 38 + │ └─jpeg_image_gif_images_format + │ ├─■──gopher_ftp_files_stuffit_images ── Topic: 51 + │ └─■──jpeg_image_gif_format_images ── Topic: 13 + └─drive_db_card_scsi_windows + ├─db_windows_dos_mov_os2 + │ ├─■──copy_protection_program_software_disk ── Topic: 64 + │ └─■──db_windows_dos_mov_os2 ── Topic: 8 + └─drive_card_scsi_drives_ide + ├─drive_scsi_drives_ide_disk + │ ├─■──drive_scsi_drives_ide_disk ── Topic: 6 + │ └─■──meg_sale_ram_drive_shipping ── Topic: 12 + └─card_modem_monitor_video_drivers + ├─■──card_monitor_video_drivers_vga ── Topic: 5 + └─■──modem_port_serial_irq_com ── Topic: 10 + ``` +
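+
+Finally, `hierarchical_topics` itself is a regular `pandas.DataFrame`, so the merge steps can also be inspected
+directly. The `Parent_ID` column below is relied upon in the accompanying tests; treat any other column names
+as version-dependent:
+
+```python
+# Each row describes one merge step; parent IDs are newly created topic
+# IDs and are always larger than any of the original topic IDs.
+print(hierarchical_topics.columns)
+print(hierarchical_topics.Parent_ID.astype(int).min())
+```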
+
+## **Visualize Hierarchical Documents**
+We can extend the previous method by calculating the topic representation at different levels of the hierarchy and
+plotting them on a 2D plane. To do so, we first need to calculate the hierarchical topics:
+
+```python
+from sklearn.datasets import fetch_20newsgroups
+from sentence_transformers import SentenceTransformer
+from bertopic import BERTopic
+from umap import UMAP
+
+# Prepare embeddings
+docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+# Train BERTopic and extract hierarchical topics
+topic_model = BERTopic()
+topics, probs = topic_model.fit_transform(docs, embeddings)
+hierarchical_topics = topic_model.hierarchical_topics(docs, topics)
+```
+
+Then, we can visualize the hierarchical documents by supplying either the original embeddings or embeddings whose
+dimensionality we have reduced ourselves:
+
+```python
+# Run the visualization with the original embeddings
+topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+# Or reduce the dimensionality of the embeddings yourself; this step is optional but much faster when iterating:
+reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+```
+
+
+
+!!! note
+    The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the
+    option to hover over the individual points and see the content of the documents. This keeps the resulting visualization
+    smaller so that it fits into RAM. However, it might be interesting to set `hide_document_hover=False` in order to hover
+    over the points and see the content of the documents.
 
 ## **Visualize Terms**
 We can visualize the selected terms for a few topics by creating bar charts out of the c-TF-IDF
 scores for each topic representation. Insights can be gained from the relative c-TF-IDF scores between and within
diff --git a/docs/index.md b/docs/index.md
index 4ec98ddb..776161bf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -79,6 +79,10 @@ frequent topic that was generated, topic 0:
 
 ## **Overview**
+BERTopic has quite a number of functions that can quickly become overwhelming. To alleviate this issue, you will find an overview
+of all methods and a short description of their purpose.
+
+### Common
 For quick access to common functions, here is an overview of BERTopic's main methods:
 
 | Method | Code |
 |-----------------------|---|
@@ -91,21 +95,40 @@ For quick access to common functions, here is an overview of BERTopic's main met
 | Get topic freq | `.get_topic_freq()` |
 | Get all topic information| `.get_topic_info()` |
 | Get representative docs per topic | `.get_representative_docs()` |
-| Get topics per class | `.topics_per_class(docs, topics, classes)` |
-| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` |
 | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` |
+| Generate topic labels | `.generate_topic_labels()` |
+| Set topic labels | `.set_topic_labels(my_custom_labels)` |
+| Merge topics | `.merge_topics(docs, topics, topics_to_merge)` |
 | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` |
 | Find topics | `.find_topics("vehicle")` |
 | Save model | `.save("my_model")` |
 | Load model | `BERTopic.load("my_model")` |
 | Get parameters | `.get_params()` |
 
-For an overview of BERTopic's visualization methods:
+### Variations
+There are many different use cases in which topic modeling can be used. As such, a number of
+variations of BERTopic have been developed such that one package can be used across many use cases:
+
+| Method | Code |
+|-----------------------|---|
+| (semi-) Supervised Topic Modeling | `.fit(docs, y=y)` |
+| Topic Modeling per Class | `.topics_per_class(docs, topics, classes)` |
+| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` |
+| Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` |
+| Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` |
+
+### Visualizations
+Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation.
+Visualizing different aspects of the topic model helps in understanding the model and makes it easier
+to tweak the model to your liking.
| Method | Code | |-----------------------|---| | Visualize Topics | `.visualize_topics()` | +| Visualize Documents | `.visualize_documents()` | +| Visualize Document Hierarchy | `.visualize_hierarchical_documents()` | | Visualize Topic Hierarchy | `.visualize_hierarchy()` | +| Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` | | Visualize Topic Terms | `.visualize_barchart()` | | Visualize Topic Similarity | `.visualize_heatmap()` | | Visualize Term Score Decline | `.visualize_term_rank()` | diff --git a/mkdocs.yml b/mkdocs.yml index 882b6e01..4dbec27f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - (semi)-Supervised Topic Modeling: getting_started/supervised/supervised.md - Dynamic Topic Modeling: getting_started/topicsovertime/topicsovertime.md - Guided Topic Modeling: getting_started/guided/guided.md + - Hierarchical Topic Modeling: getting_started/hierarchicaltopics/hierarchicaltopics.md - FAQ: faq.md - API: - BERTopic: api/bertopic.md @@ -38,12 +39,14 @@ nav: - Word Doc: api/backends/word_doc.md - Plotting: - Barchart: api/plotting/barchart.md + - Documents: api/plotting/documents.md + - DTM: api/plotting/dtm.md + - Hierarchical documents: api/plotting/hierarchical_documents.md + - Hierarchical topics: api/plotting/hierarchy.md - Distribution: api/plotting/distribution.md - Heatmap: api/plotting/heatmap.md - - Hierarchy: api/plotting/hierarchy.md - Term Scores: api/plotting/term.md - Topics: api/plotting/topics.md - - DTM: api/plotting/dtm.md - Topics per Class: api/plotting/topics_per_class.md - Changelog: changelog.md diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index b51c9f60..8546bda8 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -14,7 +14,7 @@ newsgroup_docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][:2000] base_model = BERTopic(language="english", verbose=True, min_topic_size=5) -kmeans_model = BERTopic(language="english", verbose=True, min_topic_size=5, hdbscan_model=KMeans(n_clusters=10, random_state=42)) +kmeans_model = BERTopic(language="english", verbose=True, hdbscan_model=KMeans(n_clusters=10, random_state=42)) @pytest.mark.parametrize("topic_model", [base_model, kmeans_model]) @@ -52,6 +52,17 @@ def test_full_model(topic_model): assert topics_over_time.Frequency.sum() == 2000 assert len(topics_over_time.Topic.unique()) == len(set(topics)) + # Test hierarchical topics + hier_topics = topic_model.hierarchical_topics(newsgroup_docs, topics) + + assert len(hier_topics) > 0 + assert hier_topics.Parent_ID.astype(int).min() > max(topics) + + # Test creation of topic tree + tree = topic_model.get_topic_tree(hier_topics, tight_layout=False) + assert isinstance(tree, str) + assert len(tree) > 10 + # Test find topic similar_topics, similarity = topic_model.find_topics("query", top_n=2) assert len(similar_topics) == 2 @@ -78,3 +89,17 @@ def test_full_model(topic_model): assert topic != updated_topic assert topic == original_topic + + # Test updating topic labels + topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ") + assert len(topic_labels) == len(set(new_topics)) + + # Test setting topic labels + topic_model.set_topic_labels(topic_labels) + assert topic_model.custom_labels == topic_labels + + # Test merging topics + freq = topic_model.get_topic_freq(0) + topics_to_merge = [0, 1] + topic_model.merge_topics(newsgroup_docs, new_topics, topics_to_merge) + assert freq < topic_model.get_topic_freq(0)