From 17c68ce3ef349caaa975404baf4909b8a2588446 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Wed, 22 Jun 2022 15:15:06 +0200 Subject: [PATCH 01/15] Add hierarchical topic modeling --- bertopic/_bertopic.py | 289 +++++++++++++++- bertopic/plotting/_hierarchy.py | 168 +++++++-- .../hierarchical_topics.html | 7 + .../hierarchicaltopics/hierarchicaltopics.md | 325 ++++++++++++++++++ .../visualization/visualization.md | 2 +- mkdocs.yml | 1 + tests/test_bertopic.py | 13 +- 7 files changed, 770 insertions(+), 35 deletions(-) create mode 100644 docs/getting_started/hierarchicaltopics/hierarchical_topics.html create mode 100644 docs/getting_started/hierarchicaltopics/hierarchicaltopics.md diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index bdcc9ace..edf57b61 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -15,7 +15,8 @@ import pandas as pd from tqdm import tqdm from scipy.sparse.csr import csr_matrix -from typing import List, Tuple, Union, Mapping, Any +from scipy.cluster import hierarchy as sch +from typing import List, Tuple, Union, Mapping, Any, Callable # Models import hdbscan @@ -616,6 +617,138 @@ def topics_per_class(self, return topics_per_class + def hierarchical_topics(self, + docs: List[int], + topics: List[int], + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None) -> pd.DataFrame: + """ Create a hierarchy of topics + + To create this hierarchy, BERTopic needs to be already fitted once. + Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF + representation using `scipy.cluster.hierarchy.linkage`. + + Based on that hierarchy, we calculate the topic representation at each + merged step. This is a local representation, as we only assume that the + chosen step is merged and not all others which typically improves the + topic representation. 
+ + Arguments: + docs: The documents you used when calling either `fit` or `fit_transform` + topics: The topics that were returned when calling either `fit` or `fit_transform` + linkage_function: The linkage function to use. Default is: + `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` + distance_function: The distance function to use on the c-TF-IDF matrix. Default is: + `lambda x: 1 - cosine_similarity(x)` + + Returns: + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children + + Usage: + + ```python + from bertopic import BERTopic + topic_model = BERTopic() + topics, probs = topic_model.fit_transform(docs) + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + ``` + + A custom linkage function can be used as follows: + + ```python + from scipy.cluster import hierarchy as sch + from bertopic import BERTopic + topic_model = BERTopic() + topics, probs = topic_model.fit_transform(docs) + + # Hierarchical topics + linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) + hierarchical_topics = topic_model.hierarchical_topics(docs, topics, linkage_function=linkage_function) + ``` + """ + if distance_function is None: + distance_function = lambda x: 1 - cosine_similarity(x) + + if linkage_function is None: + linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) + + # Calculate linkage + embeddings = self.c_tf_idf[self._outliers:] + X = distance_function(embeddings) + Z = linkage_function(X) + + # Calculate basic bag-of-words to be iteratively merged later + documents = pd.DataFrame({"Document": docs, + "ID": range(len(docs)), + "Topic": topics}) + documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join}) + documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :] + documents = self._preprocess_text(documents_per_topic.Document.values) + words = 
self.vectorizer_model.get_feature_names() + bow = self.vectorizer_model.transform(documents) + + # Extract clusters + hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics", + "Child_Left_ID", "Child_Left_Name", + "Child_Right_ID", "Child_Right_Name"]) + for index in tqdm(range(len(Z))): + + # Find clustered documents + clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers + cluster_df = pd.DataFrame({"Topic": range(len(clusters)), "Cluster": clusters}) + cluster_df = cluster_df.groupby("Cluster").agg({'Topic':lambda x: list(x)}).reset_index() + nr_clusters = len(clusters) + + # Extract first topic we find to get the set of topics in a merged topic + topic = None + val = Z[index][0] + while topic is None: + if val - len(clusters) < 0: + topic = int(val) + else: + val = Z[int(val - len(clusters))][0] + clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]] + + # Group bow per cluster, calculate c-TF-IDF and extract words + grouped = csr_matrix(bow[clustered_topics].sum(axis=0)) + c_tf_idf = self.transformer.transform(grouped) + words_per_topic = self._extract_words_per_topic(words, c_tf_idf, labels=[0]) + + # Extract parent's name and ID + parent_id = index + len(clusters) + parent_name = "_".join([x[0] for x in words_per_topic[0]][:5]) + + # Extract child's name and ID + Z_id = Z[index][0] + child_left_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters + + if Z_id - nr_clusters < 0: + child_left_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5]) + else: + child_left_name = hier_topics.iloc[int(child_left_id)].Parent_Name + + # Extract child's name and ID + Z_id = Z[index][1] + child_right_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters + + if Z_id - nr_clusters < 0: + child_right_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5]) + else: + child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name + + # Save results + hier_topics.loc[len(hier_topics), 
:] = [parent_id, parent_name, + clustered_topics, + int(Z[index][0]), child_left_name, + int(Z[index][1]), child_right_name] + + hier_topics["Distance"] = Z[:, 2] + hier_topics = hier_topics.sort_values("Parent_ID", ascending=False) + hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]].astype(str) + + return hier_topics + def find_topics(self, search_term: str, top_n: int = 5) -> Tuple[List[int], List[float]]: @@ -835,6 +968,108 @@ def get_representative_docs(self, topic: int = None) -> List[str]: else: return self.representative_docs + @staticmethod + def get_topic_tree(hier_topics: pd.DataFrame, + max_distance: float = None, + tight_layout: bool = False) -> str: + """ Extract the topic tree such that it can be printed + + Arguments: + hier_topics: A dataframe containing the structure of the topic tree. + This is the output of `topic_model.hierachical_topics()` + max_distance: The maximum distance between two topics. This value is + based on the Distance column in `hier_topics`. + tight_layout: Whether to use a tight layout (narrow width) for + easier readability if you have hundreds of topics. + + Returns: + A tree that has the following structure when printed: + . + . + └─health_medical_disease_patients_hiv + ├─patients_medical_disease_candida_health + │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 + │ └─patients_disease_cancer_medical_doctor + │ ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34 + │ └─■──pain_drug_patients_disease_diet ── Topic: 26 + └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + + The blocks (■) indicate that the topic is one you can directly access + from `topic_model.get_topic`. In other words, they are the original un-grouped topics. 
+ + Usage: + + ```python + # Train model + from bertopic import BERTopic + topic_model = BERTopic() + topics, probs = topic_model.fit_transform(docs) + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + + # Print topic tree + tree = topic_model.get_topic_tree(hierarchical_topics) + print(tree) + ``` + """ + width = 1 if tight_layout else 4 + if max_distance is None: + max_distance = hier_topics.Distance.max() + 1 + + max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1 + + # Extract mapping from ID to name + topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name)) + topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name))) + topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()} + + # Create tree + tree = {str(row[1].Parent_ID): [str(row[1].Child_Left_ID), str(row[1].Child_Right_ID)] + for row in hier_topics.iterrows()} + + def get_tree(start, tree): + """ Based on: https://stackoverflow.com/a/51920869/10532563 """ + + def _tree(to_print, start, parent, tree, grandpa=None, indent=""): + + # Get distance between merged topics + distance = hier_topics.loc[(hier_topics.Child_Left_ID == parent) | + (hier_topics.Child_Right_ID == parent), "Distance"] + distance = distance.values[0] if len(distance) > 0 else 10 + + if parent != start: + if grandpa is None: + to_print += topic_to_name[parent] + else: + if int(parent) <= max_original_topic: + + # Do not append topic ID if they are not merged + if distance < max_distance: + to_print += "■──" + topic_to_name[parent] + f" ── Topic: {parent}" + "\n" + else: + to_print += "O \n" + else: + to_print += topic_to_name[parent] + "\n" + + if parent not in tree: + return to_print + + for child in tree[parent][:-1]: + to_print += indent + "├" + "─" + to_print = _tree(to_print, start, child, tree, parent, indent + "│" + " " * width) + + child = tree[parent][-1] + to_print += indent + "└" + "─" + to_print = _tree(to_print, start, 
child, tree, parent, indent + " " * (width+1)) + + return to_print + + to_print = "." + "\n" + to_print = _tree(to_print, start, start, tree) + return to_print + + start = str(hier_topics.Parent_ID.astype(int).max()) + return get_tree(start, tree) + def reduce_topics(self, docs: List[str], topics: List[int], @@ -1112,13 +1347,16 @@ def visualize_distribution(self, width=width, height=height) - def visualize_hierarchy(self, + def visualize_hierarchy(self, orientation: str = "left", topics: List[int] = None, top_n_topics: int = None, width: int = 1000, height: int = 600, - optimal_ordering: bool = False) -> go.Figure: + hierarchical_topics: pd.DataFrame = None, + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None, + color_threshold: int = 1) -> go.Figure: """ Visualize a hierarchical structure of the topics A ward linkage function is used to perform the @@ -1126,17 +1364,29 @@ def visualize_hierarchy(self, matrix between topic embeddings. Arguments: + topic_model: A fitted BERTopic instance. orientation: The orientation of the figure. - Either 'left' or 'bottom' + Either 'left' or 'bottom' topics: A selection of topics to visualize top_n_topics: Only select the top n most frequent topics width: The width of the figure. Only works if orientation is set to 'left' height: The height of the figure. Only works if orientation is set to 'bottom' - optimal_ordering: If True, the linkage matrix will be reordered so that the distance - between successive leaves is minimal. This results in a more intuitive - tree structure when the data are visualized. defaults to False, because - this algorithm can be slow, particularly on large datasets. See - also the `linkage` function fun `scipy`. + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children. 
+ NOTE: The hierarchical topic names are only visualized + if both `topics` and `top_n_topics` are not set. + linkage_function: The linkage function to use. Default is: + `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` + NOTE: Make sure to use the same `linkage_function` as used + in `topic_model.hierarchical_topics`. + distance_function: The distance function to use on the c-TF-IDF matrix. Default is: + `lambda x: 1 - cosine_similarity(x)` + NOTE: Make sure to use the same `distance_function` as used + in `topic_model.hierarchical_topics`. + color_threshold: Value at which the separation of clusters will be made which + will result in different colors for different clusters. + A higher value will typically lead in less colored clusters. + Returns: fig: A plotly figure @@ -1149,12 +1399,25 @@ def visualize_hierarchy(self, topic_model.visualize_hierarchy() ``` - Or if you want to save the resulting figure: + If you also want the labels visualized of hierarchical topics, + run the following: + + ```python + # Extract hierarchical topics and their representations + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + + # Visualize these representations + topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) + ``` + + If you want to save the resulting figure: ```python fig = topic_model.visualize_hierarchy() fig.write_html("path/to/file.html") ``` + """ check_is_fitted(self) return plotting.visualize_hierarchy(self, @@ -1163,7 +1426,11 @@ def visualize_hierarchy(self, top_n_topics=top_n_topics, width=width, height=height, - optimal_ordering=optimal_ordering) + hierarchical_topics=hierarchical_topics, + linkage_function=linkage_function, + distance_function=distance_function, + color_threshold=color_threshold + ) def visualize_heatmap(self, topics: List[int] = None, diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py index 0d28b150..2a2ccebc 100644 --- a/bertopic/plotting/_hierarchy.py +++ 
b/bertopic/plotting/_hierarchy.py @@ -1,6 +1,8 @@ import numpy as np -from scipy.cluster.hierarchy import linkage -from typing import List +import pandas as pd +from typing import Callable, List +from scipy.sparse.csr import csr_matrix +from scipy.cluster import hierarchy as sch from sklearn.metrics.pairwise import cosine_similarity import plotly.graph_objects as go @@ -13,7 +15,10 @@ def visualize_hierarchy(topic_model, top_n_topics: int = None, width: int = 1000, height: int = 600, - optimal_ordering: bool = False) -> go.Figure: + hierarchical_topics: pd.DataFrame = None, + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None, + color_threshold: int = 1) -> go.Figure: """ Visualize a hierarchical structure of the topics A ward linkage function is used to perform the @@ -28,11 +33,23 @@ def visualize_hierarchy(topic_model, top_n_topics: Only select the top n most frequent topics width: The width of the figure. Only works if orientation is set to 'left' height: The height of the figure. Only works if orientation is set to 'bottom' - optimal_ordering: If True, the linkage matrix will be reordered so that the distance - between successive leaves is minimal. This results in a more intuitive - tree structure when the data are visualized. defaults to False, because - this algorithm can be slow, particularly on large datasets. See - also the `linkage` function fun `scipy`. + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children. + NOTE: The hierarchical topic names are only visualized + if both `topics` and `top_n_topics` are not set. + linkage_function: The linkage function to use. Default is: + `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` + NOTE: Make sure to use the same `linkage_function` as used + in `topic_model.hierarchical_topics`. + distance_function: The distance function to use on the c-TF-IDF matrix. 
Default is: + `lambda x: 1 - cosine_similarity(x)` + NOTE: Make sure to use the same `distance_function` as used + in `topic_model.hierarchical_topics`. + color_threshold: Value at which the separation of clusters will be made which + will result in different colors for different clusters. + A higher value will typically lead in less colored clusters. + + Returns: fig: A plotly figure @@ -45,7 +62,18 @@ def visualize_hierarchy(topic_model, topic_model.visualize_hierarchy() ``` - Or if you want to save the resulting figure: + If you also want the labels visualized of hierarchical topics, + run the following: + + ```python + # Extract hierarchical topics and their representations + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + + # Visualize these representations + topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) + ``` + + If you want to save the resulting figure: ```python fig = topic_model.visualize_hierarchy() @@ -54,12 +82,11 @@ def visualize_hierarchy(topic_model, """ - - # Select topic embeddings - if topic_model.topic_embeddings is not None: - embeddings = np.array(topic_model.topic_embeddings) - else: - embeddings = topic_model.c_tf_idf + if distance_function is None: + distance_function = lambda x: 1 - cosine_similarity(x) + + if linkage_function is None: + linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -74,16 +101,27 @@ def visualize_hierarchy(topic_model, # Select embeddings all_topics = sorted(list(topic_model.get_topics().keys())) indices = np.array([all_topics.index(topic) for topic in topics]) - embeddings = embeddings[indices] + embeddings = topic_model.c_tf_idf[indices] + + # Annotations + if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()): + annotations = _get_annotations(topic_model=topic_model, + hierarchical_topics=hierarchical_topics, + 
embeddings=embeddings, + distance_function=distance_function, + linkage_function=linkage_function, + orientation=orientation) + else: + annotations = None # Create dendogram - distance_matrix = 1 - cosine_similarity(embeddings) - fig = ff.create_dendrogram(distance_matrix, + fig = ff.create_dendrogram(embeddings, orientation=orientation, - linkagefun=lambda x: linkage(x, "ward", - optimal_ordering=optimal_ordering), - color_threshold=1) - + distfun=distance_function, + linkagefun=linkage_function, + hovertext=annotations, + color_threshold=color_threshold) + # Create nicer labels axis = "yaxis" if orientation == "left" else "xaxis" new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) @@ -128,4 +166,90 @@ def visualize_hierarchy(topic_model, height=height, xaxis=dict(tickmode="array", ticktext=new_labels)) + + if hierarchical_topics is not None: + for index in [0, 3]: + axis = "x" if orientation == "left" else "y" + xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] + + fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black', + hovertext=hovertext, hoverinfo="text", + mode='markers', showlegend=False)) return fig + + +def _get_annotations(topic_model, + hierarchical_topics: pd.DataFrame, + embeddings: csr_matrix, + linkage_function: Callable[[csr_matrix], np.ndarray], + distance_function: Callable[[csr_matrix], csr_matrix], + orientation: str) -> List[List[str]]: + + """ Get annotations by replicating linkage function calculation in scipy + + Arguments + topic_model: A fitted BERTopic instance. + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children. 
+ NOTE: The hierarchical topic names are only visualized + if both `topics` and `top_n_topics` are not set. + embeddings: The c-TF-IDF matrix on which to model the hierarchy + linkage_function: The linkage function to use. Default is: + `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` + NOTE: Make sure to use the same `linkage_function` as used + in `topic_model.hierarchical_topics`. + distance_function: The distance function to use on the c-TF-IDF matrix. Default is: + `lambda x: 1 - cosine_similarity(x)` + NOTE: Make sure to use the same `distance_function` as used + in `topic_model.hierarchical_topics`. + orientation: The orientation of the figure. + Either 'left' or 'bottom' + + Returns: + text_annotations: Annotations to be used within Plotly's `ff.create_dendogram` + """ + df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :] + + # Calculate linkage + X = distance_function(embeddings) + Z = linkage_function(X) + P = sch.dendrogram(Z, orientation=orientation, no_plot=True) + + # store topic no.(leaves) corresponding to the x-ticks in dendrogram + x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10) + x_topic = dict(zip(P['leaves'], x_ticks)) + + topic_vals = dict() + for key, val in x_topic.items(): + topic_vals[val] = [key] + + parent_topic = dict(zip(df.Parent_ID, df.Topics)) + + # loop through every trace (scatter plot) in dendrogram + text_annotations = [] + for index, trace in enumerate(P['icoord']): + fst_topic = topic_vals[trace[0]] + scnd_topic = topic_vals[trace[2]] + + if len(fst_topic) == 1: + fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5]) + else: + for key, value in parent_topic.items(): + if set(value) == set(fst_topic): + fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0] + + if len(scnd_topic) == 1: + scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5]) + else: + for key, value in parent_topic.items(): + if set(value) == 
set(scnd_topic): + scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0] + + text_annotations.append([fst_name, "", "", scnd_name]) + + center = (trace[0] + trace[2]) / 2 + topic_vals[center] = fst_topic + scnd_topic + + return text_annotations diff --git a/docs/getting_started/hierarchicaltopics/hierarchical_topics.html b/docs/getting_started/hierarchicaltopics/hierarchical_topics.html new file mode 100644 index 00000000..51ba2364 --- /dev/null +++ b/docs/getting_started/hierarchicaltopics/hierarchical_topics.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md new file mode 100644 index 00000000..9cef8803 --- /dev/null +++ b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md @@ -0,0 +1,325 @@ +When tweaking your topic model, the number of topics that are generated has a large effect on the quality of the topic representations. +Some topics could be merged together and having an understanding of the effect will help you understand which topics should and which +should not be merged. + +That is where hierarchical topic modeling comes in. It tries to model the possible hierarchical nature of the topics you have created +in order to understand which topics are similar to each other. Moreover, you will have more insight into sub-topics that might +exist in your data. + +In BERTopic, we can approximate this potential hierarchy by making use of our topic-term matrix (c-TF-IDF matrix). This matrix +contains information about the importance of every word in every topic and makes for a nice numerical representation of our topics. +The smaller the distance between two c-TF-IDF representations, the more similar we assume they are. In practice, this process of merging +topics is done through the hierarchical clustering capabilities of `scipy` (see [here](https://docs.scipy.org/doc/scipy/reference/cluster.hierarchy.html)). +It allows for several linkage methods through which we can approximate our topic hierarchy. By default, we use [ward](https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ward.html#scipy.cluster.hierarchy.ward) linkage, but many others are available. + +Whenever we merge two topics, we can calculate the c-TF-IDF representation of the merged topic by summing their bag-of-words representations. 
+We assume that two sets of topics are merged and that all others are kept the same, regardless of their location in the hierarchy. This helps +us isolate the potential effect of merging sets of topics. As a result, we can see the topic representation at each level in the tree. + +## **Example** +To demonstrate hierarchical topic modeling with BERTopic, we use the 20 Newsgroups dataset to see how the topics that we uncover are represented in the 20 categories of documents. + +First, we train a basic BERTopic model: + +```python +from bertopic import BERTopic +from sklearn.datasets import fetch_20newsgroups + +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))["data"] +topic_model = BERTopic(verbose=True) +topics, probs = topic_model.fit_transform(docs) +``` + +Next, we can use our fitted BERTopic model to extract possible hierarchies from our c-TF-IDF matrix: + +```python +hierarchical_topics = topic_model.hierarchical_topics(docs, topics) +``` + +The resulting `hierarchical_topics` is a dataframe in which merged topics are described. For example, if you would +merge two topics, what would the topic representation of the new topic be? + +### **Visualizations** +To visualize these results, we can start by running a familiar function, namely `topic_model.visualize_hierarchy`: + +```python +topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) +``` + + +If you **hover** over the black circles, you will see the topic representation at that level of the hierarchy. These representations +help you understand the effect of merging certain topics together. Some might be logical to merge whilst others might not. Moreover, +we can now see which sub-topics can be found within certain larger themes. + +Although this gives a nice overview of the potential hierarchy, hovering over all black circles can be tiresome. Instead, we can +use `topic_model.get_topic_tree` to create a text-based representation of this hierarchy. 
Although the general structure is more difficult +to view, we can see better which topics could be logically merged: + +```python +>>> tree = topic_model.get_topic_tree(hierarchical_topics) +>>> print(tree) +. +└─atheists_atheism_god_moral_atheist + ├─atheists_atheism_god_atheist_argument + │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ └─■──br_god_exist_genetic_existence ── Topic: 124 + └─■──moral_morality_objective_immoral_morals ── Topic: 29 +``` + +
+ Click here to view the full tree. + + ```bash + . + ├─people_armenian_said_god_armenians + │ ├─god_jesus_jehovah_lord_christ + │ │ ├─god_jesus_jehovah_lord_christ + │ │ │ ├─jehovah_lord_mormon_mcconkie_god + │ │ │ │ ├─■──ra_satan_thou_god_lucifer ── Topic: 94 + │ │ │ │ └─■──jehovah_lord_mormon_mcconkie_unto ── Topic: 78 + │ │ │ └─jesus_mary_god_hell_sin + │ │ │ ├─jesus_hell_god_eternal_heaven + │ │ │ │ ├─hell_jesus_eternal_god_heaven + │ │ │ │ │ ├─■──jesus_tomb_disciples_resurrection_john ── Topic: 69 + │ │ │ │ │ └─■──hell_eternal_god_jesus_heaven ── Topic: 53 + │ │ │ │ └─■──aaron_baptism_sin_law_god ── Topic: 89 + │ │ │ └─■──mary_sin_maria_priest_conception ── Topic: 56 + │ │ └─■──marriage_married_marry_ceremony_marriages ── Topic: 110 + │ └─people_armenian_armenians_said_mr + │ ├─people_armenian_armenians_said_israel + │ │ ├─god_homosexual_homosexuality_atheists_sex + │ │ │ ├─homosexual_homosexuality_sex_gay_homosexuals + │ │ │ │ ├─■──kinsey_sex_gay_men_sexual ── Topic: 44 + │ │ │ │ └─homosexuality_homosexual_sin_homosexuals_gay + │ │ │ │ ├─■──gay_homosexual_homosexuals_sexual_cramer ── Topic: 50 + │ │ │ │ └─■──homosexuality_homosexual_sin_paul_sex ── Topic: 27 + │ │ │ └─god_atheists_atheism_moral_atheist + │ │ │ ├─islam_quran_judas_islamic_book + │ │ │ │ ├─■──jim_context_challenges_articles_quote ── Topic: 36 + │ │ │ │ └─islam_quran_judas_islamic_book + │ │ │ │ ├─■──islam_quran_islamic_rushdie_muslims ── Topic: 31 + │ │ │ │ └─■──judas_scripture_bible_books_greek ── Topic: 33 + │ │ │ └─atheists_atheism_god_moral_atheist + │ │ │ ├─atheists_atheism_god_atheist_argument + │ │ │ │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ │ │ │ └─■──br_god_exist_genetic_existence ── Topic: 124 + │ │ │ └─■──moral_morality_objective_immoral_morals ── Topic: 29 + │ │ └─armenian_armenians_people_israel_said + │ │ ├─armenian_armenians_israel_people_jews + │ │ │ ├─tax_rights_government_income_taxes + │ │ │ │ ├─■──rights_right_slavery_slaves_residence ── Topic: 106 + │ │ 
│ │ └─tax_government_taxes_income_libertarians + │ │ │ │ ├─■──government_libertarians_libertarian_regulation_party ── Topic: 58 + │ │ │ │ └─■──tax_taxes_income_billion_deficit ── Topic: 41 + │ │ │ └─armenian_armenians_israel_people_jews + │ │ │ ├─gun_guns_militia_firearms_amendment + │ │ │ │ ├─■──blacks_penalty_death_cruel_punishment ── Topic: 55 + │ │ │ │ └─■──gun_guns_militia_firearms_amendment ── Topic: 7 + │ │ │ └─armenian_armenians_israel_jews_turkish + │ │ │ ├─■──israel_israeli_jews_arab_jewish ── Topic: 4 + │ │ │ └─■──armenian_armenians_turkish_armenia_azerbaijan ── Topic: 15 + │ │ └─stephanopoulos_president_mr_myers_ms + │ │ ├─■──serbs_muslims_stephanopoulos_mr_bosnia ── Topic: 35 + │ │ └─■──myers_stephanopoulos_president_ms_mr ── Topic: 87 + │ └─batf_fbi_koresh_compound_gas + │ ├─■──reno_workers_janet_clinton_waco ── Topic: 77 + │ └─batf_fbi_koresh_gas_compound + │ ├─batf_koresh_fbi_warrant_compound + │ │ ├─■──batf_warrant_raid_compound_fbi ── Topic: 42 + │ │ └─■──koresh_batf_fbi_children_compound ── Topic: 61 + │ └─■──fbi_gas_tear_bds_building ── Topic: 23 + └─use_like_just_dont_new + ├─game_team_year_games_like + │ ├─game_team_games_25_year + │ │ ├─game_team_games_25_season + │ │ │ ├─window_printer_use_problem_mhz + │ │ │ │ ├─mhz_wire_simms_wiring_battery + │ │ │ │ │ ├─simms_mhz_battery_cpu_heat + │ │ │ │ │ │ ├─simms_pds_simm_vram_lc + │ │ │ │ │ │ │ ├─■──pds_nubus_lc_slot_card ── Topic: 119 + │ │ │ │ │ │ │ └─■──simms_simm_vram_meg_dram ── Topic: 32 + │ │ │ │ │ │ └─mhz_battery_cpu_heat_speed + │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ │ ├─■──fan_cpu_heat_sink_fans ── Topic: 92 + │ │ │ │ │ │ │ │ └─■──mhz_speed_cpu_fpu_clock ── Topic: 22 + │ │ │ │ │ │ │ └─■──monitor_turn_power_computer_electricity ── Topic: 91 + │ │ │ │ │ │ └─battery_batteries_concrete_duo_discharge + │ │ │ │ │ │ ├─■──duo_battery_apple_230_problem ── Topic: 121 + │ │ │ │ │ │ └─■──battery_batteries_concrete_discharge_temperature ── Topic: 
75 + │ │ │ │ │ └─wire_wiring_ground_neutral_outlets + │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ │ ├─■──leds_uv_blue_light_boards ── Topic: 66 + │ │ │ │ │ │ │ └─■──wire_wiring_ground_neutral_outlets ── Topic: 120 + │ │ │ │ │ │ └─scope_scopes_phone_dial_number + │ │ │ │ │ │ ├─■──dial_number_phone_line_output ── Topic: 93 + │ │ │ │ │ │ └─■──scope_scopes_motorola_generator_oscilloscope ── Topic: 113 + │ │ │ │ │ └─celp_dsp_sampling_antenna_digital + │ │ │ │ │ ├─■──antenna_antennas_receiver_cable_transmitter ── Topic: 70 + │ │ │ │ │ └─■──celp_dsp_sampling_speech_voice ── Topic: 52 + │ │ │ │ └─window_printer_xv_mouse_windows + │ │ │ │ ├─window_xv_error_widget_problem + │ │ │ │ │ ├─error_symbol_undefined_xterm_rx + │ │ │ │ │ │ ├─■──symbol_error_undefined_doug_parse ── Topic: 63 + │ │ │ │ │ │ └─■──rx_remote_server_xdm_xterm ── Topic: 45 + │ │ │ │ │ └─window_xv_widget_application_expose + │ │ │ │ │ ├─window_widget_expose_application_event + │ │ │ │ │ │ ├─■──gc_mydisplay_draw_gxxor_drawing ── Topic: 103 + │ │ │ │ │ │ └─■──window_widget_application_expose_event ── Topic: 25 + │ │ │ │ │ └─xv_den_polygon_points_algorithm + │ │ │ │ │ ├─■──den_polygon_points_algorithm_polygons ── Topic: 28 + │ │ │ │ │ └─■──xv_24bit_image_bit_images ── Topic: 57 + │ │ │ │ └─printer_fonts_print_mouse_postscript + │ │ │ │ ├─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──scanner_logitech_grayscale_ocr_scanman ── Topic: 108 + │ │ │ │ │ └─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──printer_print_deskjet_hp_ink ── Topic: 18 + │ │ │ │ │ └─■──fonts_font_truetype_tt_atm ── Topic: 49 + │ │ │ │ └─mouse_ghostscript_midi_driver_postscript + │ │ │ │ ├─ghostscript_midi_postscript_files_file + │ │ │ │ │ ├─■──ghostscript_postscript_pageview_ghostview_dsc ── Topic: 104 + │ │ │ │ │ └─midi_sound_file_windows_driver + │ │ │ │ │ ├─■──location_mar_file_host_rwrr ── Topic: 83 + │ │ │ │ │ └─■──midi_sound_driver_blaster_soundblaster ── Topic: 98 
+ │ │ │ │ └─■──mouse_driver_mice_ball_problem ── Topic: 68 + │ │ │ └─game_team_games_25_season + │ │ │ ├─1st_sale_condition_comics_hulk + │ │ │ │ ├─sale_condition_offer_asking_cd + │ │ │ │ │ ├─condition_stereo_amp_speakers_asking + │ │ │ │ │ │ ├─■──miles_car_amfm_toyota_cassette ── Topic: 62 + │ │ │ │ │ │ └─■──amp_speakers_condition_stereo_audio ── Topic: 24 + │ │ │ │ │ └─games_sale_pom_cds_shipping + │ │ │ │ │ ├─pom_cds_sale_shipping_cd + │ │ │ │ │ │ ├─■──size_shipping_sale_condition_mattress ── Topic: 100 + │ │ │ │ │ │ └─■──pom_cds_cd_sale_picture ── Topic: 37 + │ │ │ │ │ └─■──games_game_snes_sega_genesis ── Topic: 40 + │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ ├─1st_hulk_comics_art_appears + │ │ │ │ │ ├─lens_tape_camera_backup_lenses + │ │ │ │ │ │ ├─■──tape_backup_tapes_drive_4mm ── Topic: 107 + │ │ │ │ │ │ └─■──lens_camera_lenses_zoom_pouch ── Topic: 114 + │ │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ │ ├─■──1st_hulk_comics_art_appears ── Topic: 105 + │ │ │ │ │ └─■──books_book_cover_trek_chemistry ── Topic: 125 + │ │ │ │ └─tickets_hotel_ticket_voucher_package + │ │ │ │ ├─■──hotel_voucher_package_vacation_room ── Topic: 74 + │ │ │ │ └─■──tickets_ticket_june_airlines_july ── Topic: 84 + │ │ │ └─game_team_games_season_hockey + │ │ │ ├─game_hockey_team_25_550 + │ │ │ │ ├─■──espn_pt_pts_game_la ── Topic: 17 + │ │ │ │ └─■──team_25_game_hockey_550 ── Topic: 2 + │ │ │ └─■──year_game_hit_baseball_players ── Topic: 0 + │ │ └─bike_car_greek_insurance_msg + │ │ ├─car_bike_insurance_cars_engine + │ │ │ ├─car_insurance_cars_radar_engine + │ │ │ │ ├─insurance_health_private_care_canada + │ │ │ │ │ ├─■──insurance_health_private_care_canada ── Topic: 99 + │ │ │ │ │ └─■──insurance_car_accident_rates_sue ── Topic: 82 + │ │ │ │ └─car_cars_radar_engine_detector + │ │ │ │ ├─car_radar_cars_detector_engine + │ │ │ │ │ ├─■──radar_detector_detectors_ka_alarm ── Topic: 39 + │ │ │ │ │ └─car_cars_mustang_ford_engine + │ │ │ │ │ ├─■──clutch_shift_shifting_transmission_gear ── 
Topic: 88 + │ │ │ │ │ └─■──car_cars_mustang_ford_v8 ── Topic: 14 + │ │ │ │ └─oil_diesel_odometer_diesels_car + │ │ │ │ ├─odometer_oil_sensor_car_drain + │ │ │ │ │ ├─■──odometer_sensor_speedo_gauge_mileage ── Topic: 96 + │ │ │ │ │ └─■──oil_drain_car_leaks_taillights ── Topic: 102 + │ │ │ │ └─■──diesel_diesels_emissions_fuel_oil ── Topic: 79 + │ │ │ └─bike_riding_ride_bikes_motorcycle + │ │ │ ├─bike_ride_riding_bikes_lane + │ │ │ │ ├─■──bike_ride_riding_lane_car ── Topic: 11 + │ │ │ │ └─■──bike_bikes_miles_honda_motorcycle ── Topic: 19 + │ │ │ └─■──countersteering_bike_motorcycle_rear_shaft ── Topic: 46 + │ │ └─greek_msg_kuwait_greece_water + │ │ ├─greek_msg_kuwait_greece_water + │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ │ ├─greek_kuwait_greece_turkish_greeks + │ │ │ │ │ │ ├─■──greek_greece_turkish_greeks_cyprus ── Topic: 71 + │ │ │ │ │ │ └─■──kuwait_iraq_iran_gulf_arabia ── Topic: 76 + │ │ │ │ │ └─msg_dog_drugs_drug_food + │ │ │ │ │ ├─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──clinton_bush_quayle_reagan_panicking ── Topic: 101 + │ │ │ │ │ │ └─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──cooper_trial_weaver_spence_witnesses ── Topic: 90 + │ │ │ │ │ │ └─■──dog_dogs_bike_trained_springer ── Topic: 67 + │ │ │ │ │ └─msg_drugs_drug_food_chinese + │ │ │ │ │ ├─■──msg_food_chinese_foods_taste ── Topic: 30 + │ │ │ │ │ └─■──drugs_drug_marijuana_cocaine_alcohol ── Topic: 72 + │ │ │ │ └─water_theory_universe_science_larsons + │ │ │ │ ├─water_nuclear_cooling_steam_dept + │ │ │ │ │ ├─■──rocketry_rockets_engines_nuclear_plutonium ── Topic: 115 + │ │ │ │ │ └─water_cooling_steam_dept_plants + │ │ │ │ │ ├─■──water_dept_phd_environmental_atmospheric ── Topic: 97 + │ │ │ │ │ └─■──cooling_water_steam_towers_plants ── Topic: 109 + │ │ │ │ └─theory_universe_larsons_larson_science + │ │ │ │ ├─■──theory_universe_larsons_larson_science ── Topic: 54 + │ │ │ │ └─■──oort_cloud_grbs_gamma_burst ── Topic: 80 + │ │ │ 
└─helmet_kirlian_photography_lock_wax + │ │ │ ├─helmet_kirlian_photography_leaf_mask + │ │ │ │ ├─kirlian_photography_leaf_pictures_deleted + │ │ │ │ │ ├─deleted_joke_stuff_maddi_nickname + │ │ │ │ │ │ ├─■──joke_maddi_nickname_nicknames_frank ── Topic: 43 + │ │ │ │ │ │ └─■──deleted_stuff_bookstore_joke_motto ── Topic: 81 + │ │ │ │ │ └─■──kirlian_photography_leaf_pictures_aura ── Topic: 85 + │ │ │ │ └─helmet_mask_liner_foam_cb + │ │ │ │ ├─■──helmet_liner_foam_cb_helmets ── Topic: 112 + │ │ │ │ └─■──mask_goalies_77_santore_tl ── Topic: 123 + │ │ │ └─lock_wax_paint_plastic_ear + │ │ │ ├─■──lock_cable_locks_bike_600 ── Topic: 117 + │ │ │ └─wax_paint_ear_plastic_skin + │ │ │ ├─■──wax_paint_plastic_scratches_solvent ── Topic: 65 + │ │ │ └─■──ear_wax_skin_greasy_acne ── Topic: 116 + │ │ └─m4_mp_14_mw_mo + │ │ ├─m4_mp_14_mw_mo + │ │ │ ├─■──m4_mp_14_mw_mo ── Topic: 111 + │ │ │ └─■──test_ensign_nameless_deane_deanebinahccbrandeisedu ── Topic: 118 + │ │ └─■──ites_cheek_hello_hi_ken ── Topic: 3 + │ └─space_medical_health_disease_cancer + │ ├─medical_health_disease_cancer_patients + │ │ ├─■──cancer_centers_center_medical_research ── Topic: 122 + │ │ └─health_medical_disease_patients_hiv + │ │ ├─patients_medical_disease_candida_health + │ │ │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 + │ │ │ └─patients_disease_cancer_medical_doctor + │ │ │ ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34 + │ │ │ └─■──pain_drug_patients_disease_diet ── Topic: 26 + │ │ └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + │ └─space_launch_nasa_shuttle_orbit + │ ├─space_moon_station_nasa_launch + │ │ ├─■──sky_advertising_billboard_billboards_space ── Topic: 59 + │ │ └─■──space_station_moon_redesign_nasa ── Topic: 16 + │ └─space_mission_hst_launch_orbit + │ ├─space_launch_nasa_orbit_propulsion + │ │ ├─■──space_launch_nasa_propulsion_astronaut ── Topic: 47 + │ │ └─■──orbit_km_jupiter_probe_earth ── Topic: 86 + │ └─■──hst_mission_shuttle_orbit_arrays ── Topic: 60 + 
└─drive_file_key_windows_use + ├─key_file_jpeg_encryption_image + │ ├─key_encryption_clipper_chip_keys + │ │ ├─■──key_clipper_encryption_chip_keys ── Topic: 1 + │ │ └─■──entry_file_ripem_entries_key ── Topic: 73 + │ └─jpeg_image_file_gif_images + │ ├─motif_graphics_ftp_available_3d + │ │ ├─motif_graphics_openwindows_ftp_available + │ │ │ ├─■──openwindows_motif_xview_windows_mouse ── Topic: 20 + │ │ │ └─■──graphics_widget_ray_3d_available ── Topic: 95 + │ │ └─■──3d_machines_version_comments_contact ── Topic: 38 + │ └─jpeg_image_gif_images_format + │ ├─■──gopher_ftp_files_stuffit_images ── Topic: 51 + │ └─■──jpeg_image_gif_format_images ── Topic: 13 + └─drive_db_card_scsi_windows + ├─db_windows_dos_mov_os2 + │ ├─■──copy_protection_program_software_disk ── Topic: 64 + │ └─■──db_windows_dos_mov_os2 ── Topic: 8 + └─drive_card_scsi_drives_ide + ├─drive_scsi_drives_ide_disk + │ ├─■──drive_scsi_drives_ide_disk ── Topic: 6 + │ └─■──meg_sale_ram_drive_shipping ── Topic: 12 + └─card_modem_monitor_video_drivers + ├─■──card_monitor_video_drivers_vga ── Topic: 5 + └─■──modem_port_serial_irq_com ── Topic: 10 + ``` +
diff --git a/docs/getting_started/visualization/visualization.md b/docs/getting_started/visualization/visualization.md index 00518034..c2334c97 100644 --- a/docs/getting_started/visualization/visualization.md +++ b/docs/getting_started/visualization/visualization.md @@ -19,7 +19,7 @@ from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, probs = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` Then, we simply call `topic_model.visualize_topics()` in order to visualize our topics. The resulting graph is a diff --git a/mkdocs.yml b/mkdocs.yml index 882b6e01..d3413342 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - (semi)-Supervised Topic Modeling: getting_started/supervised/supervised.md - Dynamic Topic Modeling: getting_started/topicsovertime/topicsovertime.md - Guided Topic Modeling: getting_started/guided/guided.md + - Hierarchical Topic Modeling: getting_started/hierarchicaltopics/hierarchicaltopics.md - FAQ: faq.md - API: - BERTopic: api/bertopic.md diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index b51c9f60..e8d722a1 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -14,7 +14,7 @@ newsgroup_docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][:2000] base_model = BERTopic(language="english", verbose=True, min_topic_size=5) -kmeans_model = BERTopic(language="english", verbose=True, min_topic_size=5, hdbscan_model=KMeans(n_clusters=10, random_state=42)) +kmeans_model = BERTopic(language="english", verbose=True, hdbscan_model=KMeans(n_clusters=10, random_state=42)) @pytest.mark.parametrize("topic_model", [base_model, kmeans_model]) @@ -52,6 +52,17 @@ def test_full_model(topic_model): assert topics_over_time.Frequency.sum() == 2000 assert len(topics_over_time.Topic.unique()) == len(set(topics)) + # Test hierarchical topics + 
hier_topics = topic_model.hierarchical_topics(newsgroup_docs, topics) + + assert len(hier_topics) > 0 + assert hier_topics.Parent_ID.astype(int).min() > max(topics) + + # Test creation of topic tree + tree = topic_model.get_topic_tree(hier_topics, tight_layout=False) + assert isinstance(tree, str) + assert len(tree) > 10 + # Test find topic similar_topics, similarity = topic_model.find_topics("query", top_n=2) assert len(similar_topics) == 2 From 6934bbb4eeee24fb1ea266776cae14401b43077a Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Tue, 28 Jun 2022 08:39:57 +0200 Subject: [PATCH 02/15] Add .visualize_documents() and documentation updates --- bertopic/_bertopic.py | 104 +++++- bertopic/plotting/_documents.py | 205 +++++++++++ bertopic/plotting/_hierarchy.py | 61 ++-- docs/api/plotting/documents.md | 3 + .../visualization/documents.html | 7 + .../visualization/hierarchical_topics.html | 7 + .../visualization/visualization.md | 343 ++++++++++++++++++ mkdocs.yml | 1 + 8 files changed, 692 insertions(+), 39 deletions(-) create mode 100644 bertopic/plotting/_documents.py create mode 100644 docs/api/plotting/documents.md create mode 100644 docs/getting_started/visualization/documents.html create mode 100644 docs/getting_started/visualization/hierarchical_topics.html diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index edf57b61..002982cf 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1165,6 +1165,94 @@ def visualize_topics(self, width=width, height=height) + def visualize_documents(self, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: float = None, + hide_annotations: bool = False, + hide_document_hover: bool = False, + width: int = 1200, + height: int = 750) -> go.Figure: + """ Visualize documents and their topics in 2D + + Arguments: + topic_model: A fitted BERTopic instance. 
+ docs: The documents you used when calling either `fit` or `fit_transform` + topics: A selection of topics to visualize. + Not to be confused with the topics that you get from `.fit_transform`. + For example, if you want to visualize only topics 1 through 5: + `topics = [1, 2, 3, 4, 5]`. + embeddings: The embeddings of all documents in `docs`. + reduced_embeddings: The 2D reduced embeddings of all documents in `docs`. + sample: The percentage of documents in each topic that you would like to keep. + Value can be between 0 and 1. Setting this value to, for example, + 0.1 (10% of documents in each topic) makes it easier to visualize + millions of documents as a subset is chosen. + hide_annotations: Hide the names of the traces on top of each cluster. + hide_document_hover: Hide the content of the documents when hovering over + specific points. Helps to speed up generation of the visualization. + width: The width of the figure. + height: The height of the figure. + + Usage: + + To visualize the topics simply run: + + ```python + topic_model.visualize_documents(docs) + ``` + + Do note that this re-calculates the embeddings and reduces them to 2D. 
+ The advised and preferred pipeline for using this function is as follows: + + ```python + from sklearn.datasets import fetch_20newsgroups + from sentence_transformers import SentenceTransformer + from bertopic import BERTopic + from umap import UMAP + + # Prepare embeddings + docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = sentence_model.encode(docs, show_progress_bar=False) + + # Train BERTopic + topic_model = BERTopic().fit(docs, embeddings) + + # Reduce dimensionality of embeddings, this step is optional + # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + + # Run the visualization with the original embeddings + topic_model.visualize_documents(docs, embeddings=embeddings) + + # Or, if you have reduced the original embeddings already: + topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) + ``` + + Or if you want to save the resulting figure: + + ```python + fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) + fig.write_html("path/to/file.html") + ``` + + + """ + check_is_fitted(self) + return plotting.visualize_documents(self, + docs=docs, + topics=topics, + embeddings=embeddings, + reduced_embeddings=reduced_embeddings, + sample=sample, + hide_annotations=hide_annotations, + hide_document_hover=hide_document_hover, + width=width, + height=height) + def visualize_term_rank(self, topics: List[int] = None, log_scale: bool = False, @@ -1347,7 +1435,7 @@ def visualize_distribution(self, width=width, height=height) - def visualize_hierarchy(self, + def visualize_hierarchy(self, orientation: str = "left", topics: List[int] = None, top_n_topics: int = None, @@ -1374,19 +1462,19 @@ def visualize_hierarchy(self, hierarchical_topics: A dataframe that contains a hierarchy of topics represented by their parents and their children. 
NOTE: The hierarchical topic names are only visualized - if both `topics` and `top_n_topics` are not set. + if both `topics` and `top_n_topics` are not set. linkage_function: The linkage function to use. Default is: `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` - NOTE: Make sure to use the same `linkage_function` as used + NOTE: Make sure to use the same `linkage_function` as used in `topic_model.hierarchical_topics`. distance_function: The distance function to use on the c-TF-IDF matrix. Default is: `lambda x: 1 - cosine_similarity(x)` - NOTE: Make sure to use the same `distance_function` as used + NOTE: Make sure to use the same `distance_function` as used in `topic_model.hierarchical_topics`. - color_threshold: Value at which the separation of clusters will be made which - will result in different colors for different clusters. + color_threshold: Value at which the separation of clusters will be made which + will result in different colors for different clusters. A higher value will typically lead in less colored clusters. 
- + Returns: fig: A plotly figure @@ -1399,7 +1487,7 @@ def visualize_hierarchy(self, topic_model.visualize_hierarchy() ``` - If you also want the labels visualized of hierarchical topics, + If you also want the labels visualized of hierarchical topics, run the following: ```python diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py new file mode 100644 index 00000000..d1a33a05 --- /dev/null +++ b/bertopic/plotting/_documents.py @@ -0,0 +1,205 @@ +import numpy as np +import pandas as pd +import plotly.graph_objects as go + +from umap import UMAP +from typing import List + + +def visualize_documents(topic_model, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: float = None, + hide_annotations: bool = False, + hide_document_hover: bool = False, + width: int = 1200, + height: int = 750): + """ Visualize documents and their topics in 2D + + Arguments: + topic_model: A fitted BERTopic instance. + docs: The documents you used when calling either `fit` or `fit_transform` + topics: A selection of topics to visualize. + Not to be confused with the topics that you get from `.fit_transform`. + For example, if you want to visualize only topics 1 through 5: + `topics = [1, 2, 3, 4, 5]`. + embeddings: The embeddings of all documents in `docs`. + reduced_embeddings: The 2D reduced embeddings of all documents in `docs`. + sample: The percentage of documents in each topic that you would like to keep. + Value can be between 0 and 1. Setting this value to, for example, + 0.1 (10% of documents in each topic) makes it easier to visualize + millions of documents as a subset is chosen. + hide_annotations: Hide the names of the traces on top of each cluster. + hide_document_hover: Hide the content of the documents when hovering over + specific points. Helps to speed up generation of visualizatin. + width: The width of the figure. + height: The height of the figure. 
+ + Usage: + + To visualize the topics simply run: + + ```python + topic_model.visualize_documents(docs) + ``` + + Do note that this re-calculates the embeddings and reduces them to 2D. + The advised and preferred pipeline for using this function is as follows: + + ```python + from sklearn.datasets import fetch_20newsgroups + from sentence_transformers import SentenceTransformer + from bertopic import BERTopic + from umap import UMAP + + # Prepare embeddings + docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = sentence_model.encode(docs, show_progress_bar=False) + + # Train BERTopic + topic_model = BERTopic().fit(docs, embeddings) + + # Reduce dimensionality of embeddings, this step is optional + # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + + # Run the visualization with the original embeddings + topic_model.visualize_documents(docs, embeddings=embeddings) + + # Or, if you have reduced the original embeddings already: + topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) + ``` + + Or if you want to save the resulting figure: + + ```python + fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) + fig.write_html("path/to/file.html") + ``` + + + """ + topic_per_doc = topic_model._map_predictions(topic_model.hdbscan_model.labels_) + + # Sample the data to optimize for visualization and dimensionality reduction + if sample is None or sample > 1: + sample = 1 + + indices = [] + for topic in set(topic_per_doc): + s = np.where(np.array(topic_per_doc) == topic)[0] + size = len(s) if len(s) < 100 else int(len(s) * sample) + indices.extend(np.random.choice(s, size=size, replace=False)) + indices = np.array(indices) + + df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]}) + df["doc"] = [docs[index] for index in indices] + 
df["topic"] = [topic_per_doc[index] for index in indices] + + # Extract embeddings if not already done + if sample is None: + if embeddings is None and reduced_embeddings is None: + embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + else: + embeddings_to_reduce = embeddings + else: + if embeddings is not None: + embeddings_to_reduce = embeddings[indices] + elif embeddings is None and reduced_embeddings is None: + embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + + # Reduce input embeddings + if reduced_embeddings is None: + umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce) + embeddings_2d = umap_model.embedding_ + elif sample is not None and reduced_embeddings is not None: + embeddings_2d = reduced_embeddings[indices] + elif sample is None and reduced_embeddings is not None: + embeddings_2d = reduced_embeddings + + unique_topics = set(topic_per_doc) + if topics is None: + topics = unique_topics + + # Combine data + df["x"] = embeddings_2d[:, 0] + df["y"] = embeddings_2d[:, 1] + + # Prepare text and names + names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics] + + # Visualize + fig = go.Figure() + + # Outliers and non-selected topics + non_selected_topics = set(unique_topics).difference(topics) + if len(non_selected_topics) == 0: + non_selected_topics = [-1] + + selection = df.loc[df.topic.isin(non_selected_topics), :] + selection["text"] = "" + selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), "Other documents"] + + fig.add_trace( + go.Scattergl( + x=selection.x, + y=selection.y, + hovertext=selection.doc if not hide_document_hover else None, + hoverinfo="text", + mode='markers+text', + name="other", + showlegend=False, + marker=dict(color='#CFD8DC', size=5, opacity=0.5) + ) + ) + + # Selected topics + for name, topic in 
zip(names, unique_topics): + if topic in topics and topic != -1: + selection = df.loc[df.topic == topic, :] + selection["text"] = "" + + if not hide_annotations: + selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name] + + fig.add_trace( + go.Scattergl( + x=selection.x, + y=selection.y, + hovertext=selection.doc if not hide_document_hover else None, + hoverinfo="text", + text=selection.text, + mode='markers+text', + name=name, + textfont=dict( + size=12, + ), + marker=dict(size=5, opacity=0.5) + ) + ) + + # Add grid in a 'plus' shape + x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15)) + y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15)) + fig.add_shape(type="line", + x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1], + line=dict(color="#CFD8DC", width=2)) + fig.add_shape(type="line", + x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2, + line=dict(color="#9E9E9E", width=2)) + fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10) + fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10) + + # Stylize layout + fig.update_layout( + template="simple_white", + width=width, + height=height + ) + + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + return fig diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py index 2a2ccebc..b7d0c50f 100644 --- a/bertopic/plotting/_hierarchy.py +++ b/bertopic/plotting/_hierarchy.py @@ -36,20 +36,19 @@ def visualize_hierarchy(topic_model, hierarchical_topics: A dataframe that contains a hierarchy of topics represented by their parents and their children. NOTE: The hierarchical topic names are only visualized - if both `topics` and `top_n_topics` are not set. + if both `topics` and `top_n_topics` are not set. linkage_function: The linkage function to use. 
Default is: `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` - NOTE: Make sure to use the same `linkage_function` as used + NOTE: Make sure to use the same `linkage_function` as used in `topic_model.hierarchical_topics`. distance_function: The distance function to use on the c-TF-IDF matrix. Default is: `lambda x: 1 - cosine_similarity(x)` - NOTE: Make sure to use the same `distance_function` as used + NOTE: Make sure to use the same `distance_function` as used in `topic_model.hierarchical_topics`. - color_threshold: Value at which the separation of clusters will be made which - will result in different colors for different clusters. + color_threshold: Value at which the separation of clusters will be made which + will result in different colors for different clusters. A higher value will typically lead in less colored clusters. - Returns: fig: A plotly figure @@ -62,7 +61,7 @@ def visualize_hierarchy(topic_model, topic_model.visualize_hierarchy() ``` - If you also want the labels visualized of hierarchical topics, + If you also want the labels visualized of hierarchical topics, run the following: ```python @@ -83,8 +82,8 @@ def visualize_hierarchy(topic_model, style="width:1000px; height: 680px; border: 0px;""> """ if distance_function is None: - distance_function = lambda x: 1 - cosine_similarity(x) - + distance_function = lambda x: 1 - cosine_similarity(x) + if linkage_function is None: linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) @@ -105,11 +104,11 @@ def visualize_hierarchy(topic_model, # Annotations if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()): - annotations = _get_annotations(topic_model=topic_model, - hierarchical_topics=hierarchical_topics, - embeddings=embeddings, - distance_function=distance_function, - linkage_function=linkage_function, + annotations = _get_annotations(topic_model=topic_model, + hierarchical_topics=hierarchical_topics, + embeddings=embeddings, + 
distance_function=distance_function, + linkage_function=linkage_function, orientation=orientation) else: annotations = None @@ -121,7 +120,7 @@ def visualize_hierarchy(topic_model, linkagefun=linkage_function, hovertext=annotations, color_threshold=color_threshold) - + # Create nicer labels axis = "yaxis" if orientation == "left" else "xaxis" new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) @@ -173,18 +172,18 @@ def visualize_hierarchy(topic_model, xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] - + fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black', - hovertext=hovertext, hoverinfo="text", + hovertext=hovertext, hoverinfo="text", mode='markers', showlegend=False)) return fig -def _get_annotations(topic_model, - hierarchical_topics: pd.DataFrame, - embeddings: csr_matrix, +def _get_annotations(topic_model, + hierarchical_topics: pd.DataFrame, + embeddings: csr_matrix, linkage_function: Callable[[csr_matrix], np.ndarray], - distance_function: Callable[[csr_matrix], csr_matrix], + distance_function: Callable[[csr_matrix], csr_matrix], orientation: str) -> List[List[str]]: """ Get annotations by replicating linkage function calculation in scipy @@ -194,19 +193,19 @@ def _get_annotations(topic_model, hierarchical_topics: A dataframe that contains a hierarchy of topics represented by their parents and their children. NOTE: The hierarchical topic names are only visualized - if both `topics` and `top_n_topics` are not set. + if both `topics` and `top_n_topics` are not set. embeddings: The c-TF-IDF matrix on which to model the hierarchy linkage_function: The linkage function to use. 
Default is: `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)` - NOTE: Make sure to use the same `linkage_function` as used + NOTE: Make sure to use the same `linkage_function` as used in `topic_model.hierarchical_topics`. distance_function: The distance function to use on the c-TF-IDF matrix. Default is: `lambda x: 1 - cosine_similarity(x)` - NOTE: Make sure to use the same `distance_function` as used + NOTE: Make sure to use the same `distance_function` as used in `topic_model.hierarchical_topics`. orientation: The orientation of the figure. Either 'left' or 'bottom' - + Returns: text_annotations: Annotations to be used within Plotly's `ff.create_dendogram` """ @@ -216,7 +215,7 @@ def _get_annotations(topic_model, X = distance_function(embeddings) Z = linkage_function(X) P = sch.dendrogram(Z, orientation=orientation, no_plot=True) - + # store topic no.(leaves) corresponding to the x-ticks in dendrogram x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10) x_topic = dict(zip(P['leaves'], x_ticks)) @@ -224,7 +223,7 @@ def _get_annotations(topic_model, topic_vals = dict() for key, val in x_topic.items(): topic_vals[val] = [key] - + parent_topic = dict(zip(df.Parent_ID, df.Topics)) # loop through every trace (scatter plot) in dendrogram @@ -232,24 +231,24 @@ def _get_annotations(topic_model, for index, trace in enumerate(P['icoord']): fst_topic = topic_vals[trace[0]] scnd_topic = topic_vals[trace[2]] - + if len(fst_topic) == 1: fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5]) else: for key, value in parent_topic.items(): if set(value) == set(fst_topic): fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0] - + if len(scnd_topic) == 1: scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5]) else: for key, value in parent_topic.items(): if set(value) == set(scnd_topic): scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0] - + text_annotations.append([fst_name, "", "", 
scnd_name]) center = (trace[0] + trace[2]) / 2 topic_vals[center] = fst_topic + scnd_topic - + return text_annotations diff --git a/docs/api/plotting/documents.md b/docs/api/plotting/documents.md new file mode 100644 index 00000000..5edc249d --- /dev/null +++ b/docs/api/plotting/documents.md @@ -0,0 +1,3 @@ +# `Documents` + +::: bertopic.plotting._documents.visualize_documents diff --git a/docs/getting_started/visualization/documents.html b/docs/getting_started/visualization/documents.html new file mode 100644 index 00000000..094bf7e6 --- /dev/null +++ b/docs/getting_started/visualization/documents.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/hierarchical_topics.html b/docs/getting_started/visualization/hierarchical_topics.html new file mode 100644 index 00000000..51ba2364 --- /dev/null +++ b/docs/getting_started/visualization/hierarchical_topics.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/visualization.md b/docs/getting_started/visualization/visualization.md index c2334c97..09055a56 100644 --- a/docs/getting_started/visualization/visualization.md +++ b/docs/getting_started/visualization/visualization.md @@ -32,6 +32,46 @@ Thus, you can play around with the results below: You can use the slider to select the topic which then lights up red. If you hover over a topic, then general information is given about the topic, including the size of the topic and its corresponding words. +## **Visualize Documents** +Using the previous method, we can visualize the overall topics and get insight into their relationships. However, +you might want a more fine-grained approach where we can visualize the documents inside the topics to see +if they were assigned correctly or whether they make sense. To do so, we can use the `topic_model.visualize_documents()` +function. This function recalculates the document embeddings and reduces them to 2-dimensional space for easier visualization +purposes. 
This process can be quite expensive, so it is advised to adhere to the following pipeline: + +```python +from sklearn.datasets import fetch_20newsgroups +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# Prepare embeddings +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +sentence_model = SentenceTransformer("all-MiniLM-L6-v2") +embeddings = sentence_model.encode(docs, show_progress_bar=False) + +# Train BERTopic +topic_model = BERTopic().fit(docs, embeddings) + +# Reduce dimensionality of embeddings, this step is optional +# reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + +# Run the visualization with the original embeddings +topic_model.visualize_documents(docs, embeddings=embeddings) + +# Or, if you have reduced the original embeddings already: +topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) +``` + + + + +!!! note + The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the + option to hover over the individual points and see the content of the documents. This was done for demonstration purposes + as saving all those documents in the visualization can be quite expensive and result in large files. However, + it might be interesting to set `hide_document_hover=False` in order to hover over the points and see the content of the documents. + ## **Visualize Topic Hierarchy** The topics that were created can be hierarchically reduced. In order to understand the potential hierarchical @@ -46,6 +86,309 @@ of topics that you have created. To visualize this hierarchy, simply call `topic auto since HDBSCAN is used to automatically extract topics. The visualization above closely resembles the actual procedure of `.reduce_topics()` when any number of `nr_topics` is selected. 
+Although visualizing this hierarchy gives us information about the structure, it would be helpful to see what happens +to the topic representations when merging topics. To do so, we first need to calculate the representations of the +hierarchical topics: + + +First, we train a basic BERTopic model: + +```python +from bertopic import BERTopic +from sklearn.datasets import fetch_20newsgroups + +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))["data"] +topic_model = BERTopic(verbose=True) +topics, probs = topic_model.fit_transform(docs) +hierarchical_topics = topic_model.hierarchical_topics(docs, topics) +``` + +To visualize these results, we simply need to pass the resulting `hierarchical_topics` to our `topic_model.visualize_hierarchy()` function: + +```python +topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) +``` + + + +If you **hover** over the black circles, you will see the topic representation at that level of the hierarchy. These representations +help you understand the effect of merging certain topics together. Some might be logical to merge whilst others might not. Moreover, +we can now see which sub-topics can be found within certain larger themes. + +Although this gives a nice overview of the potential hierarchy, hovering over all black circles can be tiresome. Instead, we can +use `topic_model.get_topic_tree` to create a text-based representation of this hierarchy. Although the general structure is more difficult +to view, we can see better which topics could be logically merged: + +```python +>>> tree = topic_model.get_topic_tree(hierarchical_topics) +>>> print(tree) +. +└─atheists_atheism_god_moral_atheist + ├─atheists_atheism_god_atheist_argument + │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ └─■──br_god_exist_genetic_existence ── Topic: 124 + └─■──moral_morality_objective_immoral_morals ── Topic: 29 +``` + +
+ Click here to view the full tree. + + ```bash + . + ├─people_armenian_said_god_armenians + │ ├─god_jesus_jehovah_lord_christ + │ │ ├─god_jesus_jehovah_lord_christ + │ │ │ ├─jehovah_lord_mormon_mcconkie_god + │ │ │ │ ├─■──ra_satan_thou_god_lucifer ── Topic: 94 + │ │ │ │ └─■──jehovah_lord_mormon_mcconkie_unto ── Topic: 78 + │ │ │ └─jesus_mary_god_hell_sin + │ │ │ ├─jesus_hell_god_eternal_heaven + │ │ │ │ ├─hell_jesus_eternal_god_heaven + │ │ │ │ │ ├─■──jesus_tomb_disciples_resurrection_john ── Topic: 69 + │ │ │ │ │ └─■──hell_eternal_god_jesus_heaven ── Topic: 53 + │ │ │ │ └─■──aaron_baptism_sin_law_god ── Topic: 89 + │ │ │ └─■──mary_sin_maria_priest_conception ── Topic: 56 + │ │ └─■──marriage_married_marry_ceremony_marriages ── Topic: 110 + │ └─people_armenian_armenians_said_mr + │ ├─people_armenian_armenians_said_israel + │ │ ├─god_homosexual_homosexuality_atheists_sex + │ │ │ ├─homosexual_homosexuality_sex_gay_homosexuals + │ │ │ │ ├─■──kinsey_sex_gay_men_sexual ── Topic: 44 + │ │ │ │ └─homosexuality_homosexual_sin_homosexuals_gay + │ │ │ │ ├─■──gay_homosexual_homosexuals_sexual_cramer ── Topic: 50 + │ │ │ │ └─■──homosexuality_homosexual_sin_paul_sex ── Topic: 27 + │ │ │ └─god_atheists_atheism_moral_atheist + │ │ │ ├─islam_quran_judas_islamic_book + │ │ │ │ ├─■──jim_context_challenges_articles_quote ── Topic: 36 + │ │ │ │ └─islam_quran_judas_islamic_book + │ │ │ │ ├─■──islam_quran_islamic_rushdie_muslims ── Topic: 31 + │ │ │ │ └─■──judas_scripture_bible_books_greek ── Topic: 33 + │ │ │ └─atheists_atheism_god_moral_atheist + │ │ │ ├─atheists_atheism_god_atheist_argument + │ │ │ │ ├─■──atheists_atheism_god_atheist_argument ── Topic: 21 + │ │ │ │ └─■──br_god_exist_genetic_existence ── Topic: 124 + │ │ │ └─■──moral_morality_objective_immoral_morals ── Topic: 29 + │ │ └─armenian_armenians_people_israel_said + │ │ ├─armenian_armenians_israel_people_jews + │ │ │ ├─tax_rights_government_income_taxes + │ │ │ │ ├─■──rights_right_slavery_slaves_residence ── Topic: 106 + │ │ 
│ │ └─tax_government_taxes_income_libertarians + │ │ │ │ ├─■──government_libertarians_libertarian_regulation_party ── Topic: 58 + │ │ │ │ └─■──tax_taxes_income_billion_deficit ── Topic: 41 + │ │ │ └─armenian_armenians_israel_people_jews + │ │ │ ├─gun_guns_militia_firearms_amendment + │ │ │ │ ├─■──blacks_penalty_death_cruel_punishment ── Topic: 55 + │ │ │ │ └─■──gun_guns_militia_firearms_amendment ── Topic: 7 + │ │ │ └─armenian_armenians_israel_jews_turkish + │ │ │ ├─■──israel_israeli_jews_arab_jewish ── Topic: 4 + │ │ │ └─■──armenian_armenians_turkish_armenia_azerbaijan ── Topic: 15 + │ │ └─stephanopoulos_president_mr_myers_ms + │ │ ├─■──serbs_muslims_stephanopoulos_mr_bosnia ── Topic: 35 + │ │ └─■──myers_stephanopoulos_president_ms_mr ── Topic: 87 + │ └─batf_fbi_koresh_compound_gas + │ ├─■──reno_workers_janet_clinton_waco ── Topic: 77 + │ └─batf_fbi_koresh_gas_compound + │ ├─batf_koresh_fbi_warrant_compound + │ │ ├─■──batf_warrant_raid_compound_fbi ── Topic: 42 + │ │ └─■──koresh_batf_fbi_children_compound ── Topic: 61 + │ └─■──fbi_gas_tear_bds_building ── Topic: 23 + └─use_like_just_dont_new + ├─game_team_year_games_like + │ ├─game_team_games_25_year + │ │ ├─game_team_games_25_season + │ │ │ ├─window_printer_use_problem_mhz + │ │ │ │ ├─mhz_wire_simms_wiring_battery + │ │ │ │ │ ├─simms_mhz_battery_cpu_heat + │ │ │ │ │ │ ├─simms_pds_simm_vram_lc + │ │ │ │ │ │ │ ├─■──pds_nubus_lc_slot_card ── Topic: 119 + │ │ │ │ │ │ │ └─■──simms_simm_vram_meg_dram ── Topic: 32 + │ │ │ │ │ │ └─mhz_battery_cpu_heat_speed + │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ ├─mhz_cpu_speed_heat_fan + │ │ │ │ │ │ │ │ ├─■──fan_cpu_heat_sink_fans ── Topic: 92 + │ │ │ │ │ │ │ │ └─■──mhz_speed_cpu_fpu_clock ── Topic: 22 + │ │ │ │ │ │ │ └─■──monitor_turn_power_computer_electricity ── Topic: 91 + │ │ │ │ │ │ └─battery_batteries_concrete_duo_discharge + │ │ │ │ │ │ ├─■──duo_battery_apple_230_problem ── Topic: 121 + │ │ │ │ │ │ └─■──battery_batteries_concrete_discharge_temperature ── Topic: 
75 + │ │ │ │ │ └─wire_wiring_ground_neutral_outlets + │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ ├─wire_wiring_ground_neutral_outlets + │ │ │ │ │ │ │ ├─■──leds_uv_blue_light_boards ── Topic: 66 + │ │ │ │ │ │ │ └─■──wire_wiring_ground_neutral_outlets ── Topic: 120 + │ │ │ │ │ │ └─scope_scopes_phone_dial_number + │ │ │ │ │ │ ├─■──dial_number_phone_line_output ── Topic: 93 + │ │ │ │ │ │ └─■──scope_scopes_motorola_generator_oscilloscope ── Topic: 113 + │ │ │ │ │ └─celp_dsp_sampling_antenna_digital + │ │ │ │ │ ├─■──antenna_antennas_receiver_cable_transmitter ── Topic: 70 + │ │ │ │ │ └─■──celp_dsp_sampling_speech_voice ── Topic: 52 + │ │ │ │ └─window_printer_xv_mouse_windows + │ │ │ │ ├─window_xv_error_widget_problem + │ │ │ │ │ ├─error_symbol_undefined_xterm_rx + │ │ │ │ │ │ ├─■──symbol_error_undefined_doug_parse ── Topic: 63 + │ │ │ │ │ │ └─■──rx_remote_server_xdm_xterm ── Topic: 45 + │ │ │ │ │ └─window_xv_widget_application_expose + │ │ │ │ │ ├─window_widget_expose_application_event + │ │ │ │ │ │ ├─■──gc_mydisplay_draw_gxxor_drawing ── Topic: 103 + │ │ │ │ │ │ └─■──window_widget_application_expose_event ── Topic: 25 + │ │ │ │ │ └─xv_den_polygon_points_algorithm + │ │ │ │ │ ├─■──den_polygon_points_algorithm_polygons ── Topic: 28 + │ │ │ │ │ └─■──xv_24bit_image_bit_images ── Topic: 57 + │ │ │ │ └─printer_fonts_print_mouse_postscript + │ │ │ │ ├─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──scanner_logitech_grayscale_ocr_scanman ── Topic: 108 + │ │ │ │ │ └─printer_fonts_print_font_deskjet + │ │ │ │ │ ├─■──printer_print_deskjet_hp_ink ── Topic: 18 + │ │ │ │ │ └─■──fonts_font_truetype_tt_atm ── Topic: 49 + │ │ │ │ └─mouse_ghostscript_midi_driver_postscript + │ │ │ │ ├─ghostscript_midi_postscript_files_file + │ │ │ │ │ ├─■──ghostscript_postscript_pageview_ghostview_dsc ── Topic: 104 + │ │ │ │ │ └─midi_sound_file_windows_driver + │ │ │ │ │ ├─■──location_mar_file_host_rwrr ── Topic: 83 + │ │ │ │ │ └─■──midi_sound_driver_blaster_soundblaster ── Topic: 98 
+ │ │ │ │ └─■──mouse_driver_mice_ball_problem ── Topic: 68 + │ │ │ └─game_team_games_25_season + │ │ │ ├─1st_sale_condition_comics_hulk + │ │ │ │ ├─sale_condition_offer_asking_cd + │ │ │ │ │ ├─condition_stereo_amp_speakers_asking + │ │ │ │ │ │ ├─■──miles_car_amfm_toyota_cassette ── Topic: 62 + │ │ │ │ │ │ └─■──amp_speakers_condition_stereo_audio ── Topic: 24 + │ │ │ │ │ └─games_sale_pom_cds_shipping + │ │ │ │ │ ├─pom_cds_sale_shipping_cd + │ │ │ │ │ │ ├─■──size_shipping_sale_condition_mattress ── Topic: 100 + │ │ │ │ │ │ └─■──pom_cds_cd_sale_picture ── Topic: 37 + │ │ │ │ │ └─■──games_game_snes_sega_genesis ── Topic: 40 + │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ ├─1st_hulk_comics_art_appears + │ │ │ │ │ ├─lens_tape_camera_backup_lenses + │ │ │ │ │ │ ├─■──tape_backup_tapes_drive_4mm ── Topic: 107 + │ │ │ │ │ │ └─■──lens_camera_lenses_zoom_pouch ── Topic: 114 + │ │ │ │ │ └─1st_hulk_comics_art_appears + │ │ │ │ │ ├─■──1st_hulk_comics_art_appears ── Topic: 105 + │ │ │ │ │ └─■──books_book_cover_trek_chemistry ── Topic: 125 + │ │ │ │ └─tickets_hotel_ticket_voucher_package + │ │ │ │ ├─■──hotel_voucher_package_vacation_room ── Topic: 74 + │ │ │ │ └─■──tickets_ticket_june_airlines_july ── Topic: 84 + │ │ │ └─game_team_games_season_hockey + │ │ │ ├─game_hockey_team_25_550 + │ │ │ │ ├─■──espn_pt_pts_game_la ── Topic: 17 + │ │ │ │ └─■──team_25_game_hockey_550 ── Topic: 2 + │ │ │ └─■──year_game_hit_baseball_players ── Topic: 0 + │ │ └─bike_car_greek_insurance_msg + │ │ ├─car_bike_insurance_cars_engine + │ │ │ ├─car_insurance_cars_radar_engine + │ │ │ │ ├─insurance_health_private_care_canada + │ │ │ │ │ ├─■──insurance_health_private_care_canada ── Topic: 99 + │ │ │ │ │ └─■──insurance_car_accident_rates_sue ── Topic: 82 + │ │ │ │ └─car_cars_radar_engine_detector + │ │ │ │ ├─car_radar_cars_detector_engine + │ │ │ │ │ ├─■──radar_detector_detectors_ka_alarm ── Topic: 39 + │ │ │ │ │ └─car_cars_mustang_ford_engine + │ │ │ │ │ ├─■──clutch_shift_shifting_transmission_gear ── 
Topic: 88 + │ │ │ │ │ └─■──car_cars_mustang_ford_v8 ── Topic: 14 + │ │ │ │ └─oil_diesel_odometer_diesels_car + │ │ │ │ ├─odometer_oil_sensor_car_drain + │ │ │ │ │ ├─■──odometer_sensor_speedo_gauge_mileage ── Topic: 96 + │ │ │ │ │ └─■──oil_drain_car_leaks_taillights ── Topic: 102 + │ │ │ │ └─■──diesel_diesels_emissions_fuel_oil ── Topic: 79 + │ │ │ └─bike_riding_ride_bikes_motorcycle + │ │ │ ├─bike_ride_riding_bikes_lane + │ │ │ │ ├─■──bike_ride_riding_lane_car ── Topic: 11 + │ │ │ │ └─■──bike_bikes_miles_honda_motorcycle ── Topic: 19 + │ │ │ └─■──countersteering_bike_motorcycle_rear_shaft ── Topic: 46 + │ │ └─greek_msg_kuwait_greece_water + │ │ ├─greek_msg_kuwait_greece_water + │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ ├─greek_msg_kuwait_greece_dog + │ │ │ │ │ ├─greek_kuwait_greece_turkish_greeks + │ │ │ │ │ │ ├─■──greek_greece_turkish_greeks_cyprus ── Topic: 71 + │ │ │ │ │ │ └─■──kuwait_iraq_iran_gulf_arabia ── Topic: 76 + │ │ │ │ │ └─msg_dog_drugs_drug_food + │ │ │ │ │ ├─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──clinton_bush_quayle_reagan_panicking ── Topic: 101 + │ │ │ │ │ │ └─dog_dogs_cooper_trial_weaver + │ │ │ │ │ │ ├─■──cooper_trial_weaver_spence_witnesses ── Topic: 90 + │ │ │ │ │ │ └─■──dog_dogs_bike_trained_springer ── Topic: 67 + │ │ │ │ │ └─msg_drugs_drug_food_chinese + │ │ │ │ │ ├─■──msg_food_chinese_foods_taste ── Topic: 30 + │ │ │ │ │ └─■──drugs_drug_marijuana_cocaine_alcohol ── Topic: 72 + │ │ │ │ └─water_theory_universe_science_larsons + │ │ │ │ ├─water_nuclear_cooling_steam_dept + │ │ │ │ │ ├─■──rocketry_rockets_engines_nuclear_plutonium ── Topic: 115 + │ │ │ │ │ └─water_cooling_steam_dept_plants + │ │ │ │ │ ├─■──water_dept_phd_environmental_atmospheric ── Topic: 97 + │ │ │ │ │ └─■──cooling_water_steam_towers_plants ── Topic: 109 + │ │ │ │ └─theory_universe_larsons_larson_science + │ │ │ │ ├─■──theory_universe_larsons_larson_science ── Topic: 54 + │ │ │ │ └─■──oort_cloud_grbs_gamma_burst ── Topic: 80 + │ │ │ 
└─helmet_kirlian_photography_lock_wax + │ │ │ ├─helmet_kirlian_photography_leaf_mask + │ │ │ │ ├─kirlian_photography_leaf_pictures_deleted + │ │ │ │ │ ├─deleted_joke_stuff_maddi_nickname + │ │ │ │ │ │ ├─■──joke_maddi_nickname_nicknames_frank ── Topic: 43 + │ │ │ │ │ │ └─■──deleted_stuff_bookstore_joke_motto ── Topic: 81 + │ │ │ │ │ └─■──kirlian_photography_leaf_pictures_aura ── Topic: 85 + │ │ │ │ └─helmet_mask_liner_foam_cb + │ │ │ │ ├─■──helmet_liner_foam_cb_helmets ── Topic: 112 + │ │ │ │ └─■──mask_goalies_77_santore_tl ── Topic: 123 + │ │ │ └─lock_wax_paint_plastic_ear + │ │ │ ├─■──lock_cable_locks_bike_600 ── Topic: 117 + │ │ │ └─wax_paint_ear_plastic_skin + │ │ │ ├─■──wax_paint_plastic_scratches_solvent ── Topic: 65 + │ │ │ └─■──ear_wax_skin_greasy_acne ── Topic: 116 + │ │ └─m4_mp_14_mw_mo + │ │ ├─m4_mp_14_mw_mo + │ │ │ ├─■──m4_mp_14_mw_mo ── Topic: 111 + │ │ │ └─■──test_ensign_nameless_deane_deanebinahccbrandeisedu ── Topic: 118 + │ │ └─■──ites_cheek_hello_hi_ken ── Topic: 3 + │ └─space_medical_health_disease_cancer + │ ├─medical_health_disease_cancer_patients + │ │ ├─■──cancer_centers_center_medical_research ── Topic: 122 + │ │ └─health_medical_disease_patients_hiv + │ │ ├─patients_medical_disease_candida_health + │ │ │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 + │ │ │ └─patients_disease_cancer_medical_doctor + │ │ │ ├─■──hiv_medical_cancer_patients_doctor ── Topic: 34 + │ │ │ └─■──pain_drug_patients_disease_diet ── Topic: 26 + │ │ └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + │ └─space_launch_nasa_shuttle_orbit + │ ├─space_moon_station_nasa_launch + │ │ ├─■──sky_advertising_billboard_billboards_space ── Topic: 59 + │ │ └─■──space_station_moon_redesign_nasa ── Topic: 16 + │ └─space_mission_hst_launch_orbit + │ ├─space_launch_nasa_orbit_propulsion + │ │ ├─■──space_launch_nasa_propulsion_astronaut ── Topic: 47 + │ │ └─■──orbit_km_jupiter_probe_earth ── Topic: 86 + │ └─■──hst_mission_shuttle_orbit_arrays ── Topic: 60 + 
└─drive_file_key_windows_use + ├─key_file_jpeg_encryption_image + │ ├─key_encryption_clipper_chip_keys + │ │ ├─■──key_clipper_encryption_chip_keys ── Topic: 1 + │ │ └─■──entry_file_ripem_entries_key ── Topic: 73 + │ └─jpeg_image_file_gif_images + │ ├─motif_graphics_ftp_available_3d + │ │ ├─motif_graphics_openwindows_ftp_available + │ │ │ ├─■──openwindows_motif_xview_windows_mouse ── Topic: 20 + │ │ │ └─■──graphics_widget_ray_3d_available ── Topic: 95 + │ │ └─■──3d_machines_version_comments_contact ── Topic: 38 + │ └─jpeg_image_gif_images_format + │ ├─■──gopher_ftp_files_stuffit_images ── Topic: 51 + │ └─■──jpeg_image_gif_format_images ── Topic: 13 + └─drive_db_card_scsi_windows + ├─db_windows_dos_mov_os2 + │ ├─■──copy_protection_program_software_disk ── Topic: 64 + │ └─■──db_windows_dos_mov_os2 ── Topic: 8 + └─drive_card_scsi_drives_ide + ├─drive_scsi_drives_ide_disk + │ ├─■──drive_scsi_drives_ide_disk ── Topic: 6 + │ └─■──meg_sale_ram_drive_shipping ── Topic: 12 + └─card_modem_monitor_video_drivers + ├─■──card_monitor_video_drivers_vga ── Topic: 5 + └─■──modem_port_serial_irq_com ── Topic: 10 + ``` +
+ + ## **Visualize Terms** We can visualize the selected terms for a few topics by creating bar charts out of the c-TF-IDF scores for each topic representation. Insights can be gained from the relative c-TF-IDF scores between and within diff --git a/mkdocs.yml b/mkdocs.yml index d3413342..5ae3a4bd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,7 @@ nav: - Word Doc: api/backends/word_doc.md - Plotting: - Barchart: api/plotting/barchart.md + - Documents: api/plotting/documents.md - Distribution: api/plotting/distribution.md - Heatmap: api/plotting/heatmap.md - Hierarchy: api/plotting/hierarchy.md From ee67eb428d04a272a583db73f9cbb88897a5230c Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Tue, 28 Jun 2022 10:25:22 +0200 Subject: [PATCH 03/15] Add title, small documentation changes --- bertopic/plotting/__init__.py | 2 ++ bertopic/plotting/_documents.py | 9 +++++++++ docs/getting_started/visualization/visualization.md | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/bertopic/plotting/__init__.py b/bertopic/plotting/__init__.py index 8edccb73..c493504d 100644 --- a/bertopic/plotting/__init__.py +++ b/bertopic/plotting/__init__.py @@ -1,6 +1,7 @@ from ._topics import visualize_topics from ._heatmap import visualize_heatmap from ._barchart import visualize_barchart +from ._documents import visualize_documents from ._term_rank import visualize_term_rank from ._hierarchy import visualize_hierarchy from ._distribution import visualize_distribution @@ -12,6 +13,7 @@ "visualize_topics", "visualize_heatmap", "visualize_barchart", + "visualize_documents", "visualize_term_rank", "visualize_hierarchy", "visualize_distribution", diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py index d1a33a05..99beb42c 100644 --- a/bertopic/plotting/_documents.py +++ b/bertopic/plotting/_documents.py @@ -196,6 +196,15 @@ def visualize_documents(topic_model, # Stylize layout fig.update_layout( template="simple_white", + title={ + 'text': 
"Documents and Topics", + 'x': 0.5, + 'xanchor': 'center', + 'yanchor': 'top', + 'font': dict( + size=22, + color="Black") + }, width=width, height=height ) diff --git a/docs/getting_started/visualization/visualization.md b/docs/getting_started/visualization/visualization.md index 09055a56..fce5274b 100644 --- a/docs/getting_started/visualization/visualization.md +++ b/docs/getting_started/visualization/visualization.md @@ -33,7 +33,7 @@ You can use the slider to select the topic which then lights up red. If you hove information is given about the topic, including the size of the topic and its corresponding words. ## **Visualize Documents** -Using the previous method, we can visualize the overall topics and get insight into their relationships. However, +Using the previous method, we can visualize the topics and get insight into their relationships. However, you might want a more fine-grained approach where we can visualize the documents inside the topics to see if they were assigned correctly or whether they make sense. To do so, we can use the `topic_model.visualize_documents()` function. 
This function recalculates the document embeddings and reduces them to 2-dimensional space for easier visualization From 78db7bc8c9c9f7a215fc1d71241b4981fa15d4ea Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Tue, 28 Jun 2022 14:17:36 +0200 Subject: [PATCH 04/15] Added .visualize_hierarchical_documents, documentation, and small changes --- bertopic/_bertopic.py | 102 +++++- bertopic/plotting/__init__.py | 4 +- bertopic/plotting/_documents.py | 2 +- bertopic/plotting/_hierarchical_documents.py | 306 ++++++++++++++++++ docs/api/plotting/hierarchical_documents.md | 3 + .../visualization/hierarchical_documents.html | 7 + .../visualization/visualization.md | 41 ++- mkdocs.yml | 5 +- 8 files changed, 463 insertions(+), 7 deletions(-) create mode 100644 bertopic/plotting/_hierarchical_documents.py create mode 100644 docs/api/plotting/hierarchical_documents.md create mode 100644 docs/getting_started/visualization/hierarchical_documents.html diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 002982cf..c19c101d 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1192,7 +1192,7 @@ def visualize_documents(self, millions of documents as a subset is chosen. hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over - specific points. Helps to speed up generation of visualizatin. + specific points. Helps to speed up generation of visualization. width: The width of the figure. height: The height of the figure. 
@@ -1253,6 +1253,106 @@ def visualize_documents(self, width=width, height=height) + def visualize_hierarchical_documents(self, + docs: List[str], + hierarchical_topics: pd.DataFrame, + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: Union[float, int] = None, + hide_annotations: bool = False, + hide_document_hover: bool = True, + nr_levels: int = 10, + width: int = 1200, + height: int = 750) -> go.Figure: + """ Visualize documents and their topics in 2D at different levels of hierarchy + + Arguments: + docs: The documents you used when calling either `fit` or `fit_transform` + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children + topics: A selection of topics to visualize. + Not to be confused with the topics that you get from `.fit_transform`. + For example, if you want to visualize only topics 1 through 5: + `topics = [1, 2, 3, 4, 5]`. + embeddings: The embeddings of all documents in `docs`. + reduced_embeddings: The 2D reduced embeddings of all documents in `docs`. + sample: The percentage of documents in each topic that you would like to keep. + Value can be between 0 and 1. Setting this value to, for example, + 0.1 (10% of documents in each topic) makes it easier to visualize + millions of documents as a subset is chosen. + hide_annotations: Hide the names of the traces on top of each cluster. + hide_document_hover: Hide the content of the documents when hovering over + specific points. Helps to speed up generation of visualizations. + nr_levels: The number of levels to be visualized in the hierarchy. First, the distances + in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances with + equal length. Then, for each list of distances, the merged topics are selected that + have a distance less or equal to the maximum distance of the selected list of distances. 
+ NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to + the length of `hierarchical_topics`. + width: The width of the figure. + height: The height of the figure. + + Usage: + + To visualize the topics simply run: + + ```python + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics) + ``` + + Do note that this re-calculates the embeddings and reduces them to 2D. + The advised and preferred pipeline for using this function is as follows: + + ```python + from sklearn.datasets import fetch_20newsgroups + from sentence_transformers import SentenceTransformer + from bertopic import BERTopic + from umap import UMAP + + # Prepare embeddings + docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = sentence_model.encode(docs, show_progress_bar=False) + + # Train BERTopic and extract hierarchical topics + topic_model = BERTopic().fit(docs, embeddings) + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + + # Reduce dimensionality of embeddings, this step is optional + # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + + # Run the visualization with the original embeddings + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings) + + # Or, if you have reduced the original embeddings already: + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) + ``` + + Or if you want to save the resulting figure: + + ```python + fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) + fig.write_html("path/to/file.html") + ``` + + + """ + check_is_fitted(self) + return plotting.visualize_hierarchical_documents(self, + docs=docs, + hierarchical_topics=hierarchical_topics, + topics=topics, + 
embeddings=embeddings, + reduced_embeddings=reduced_embeddings, + sample=sample, + hide_annotations=hide_annotations, + hide_document_hover=hide_document_hover, + nr_levels=nr_levels, + width=width, + height=height) + def visualize_term_rank(self, topics: List[int] = None, log_scale: bool = False, diff --git a/bertopic/plotting/__init__.py b/bertopic/plotting/__init__.py index c493504d..3cc61cc5 100644 --- a/bertopic/plotting/__init__.py +++ b/bertopic/plotting/__init__.py @@ -7,6 +7,7 @@ from ._distribution import visualize_distribution from ._topics_over_time import visualize_topics_over_time from ._topics_per_class import visualize_topics_per_class +from ._hierarchical_documents import visualize_hierarchical_documents __all__ = [ @@ -18,5 +19,6 @@ "visualize_hierarchy", "visualize_distribution", "visualize_topics_over_time", - "visualize_topics_per_class" + "visualize_topics_per_class", + "visualize_hierarchical_documents" ] diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py index 99beb42c..aa380079 100644 --- a/bertopic/plotting/_documents.py +++ b/bertopic/plotting/_documents.py @@ -33,7 +33,7 @@ def visualize_documents(topic_model, millions of documents as a subset is chosen. hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over - specific points. Helps to speed up generation of visualizatin. + specific points. Helps to speed up generation of visualization. width: The width of the figure. height: The height of the figure. 
diff --git a/bertopic/plotting/_hierarchical_documents.py b/bertopic/plotting/_hierarchical_documents.py new file mode 100644 index 00000000..6c3d839f --- /dev/null +++ b/bertopic/plotting/_hierarchical_documents.py @@ -0,0 +1,306 @@ +import numpy as np +import pandas as pd +import plotly.graph_objects as go + +from umap import UMAP +from typing import List, Union + + +def visualize_hierarchical_documents(topic_model, + docs: List[str], + hierarchical_topics: pd.DataFrame, + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: Union[float, int] = None, + hide_annotations: bool = False, + hide_document_hover: bool = True, + nr_levels: int = 10, + width: int = 1200, + height: int = 750) -> go.Figure: + """ Visualize documents and their topics in 2D at different levels of hierarchy + + Arguments: + docs: The documents you used when calling either `fit` or `fit_transform` + hierarchical_topics: A dataframe that contains a hierarchy of topics + represented by their parents and their children + topics: A selection of topics to visualize. + Not to be confused with the topics that you get from `.fit_transform`. + For example, if you want to visualize only topics 1 through 5: + `topics = [1, 2, 3, 4, 5]`. + embeddings: The embeddings of all documents in `docs`. + reduced_embeddings: The 2D reduced embeddings of all documents in `docs`. + sample: The percentage of documents in each topic that you would like to keep. + Value can be between 0 and 1. Setting this value to, for example, + 0.1 (10% of documents in each topic) makes it easier to visualize + millions of documents as a subset is chosen. + hide_annotations: Hide the names of the traces on top of each cluster. + hide_document_hover: Hide the content of the documents when hovering over + specific points. Helps to speed up generation of visualizations. + nr_levels: The number of levels to be visualized in the hierarchy. 
First, the distances + in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances with + equal length. Then, for each list of distances, the merged topics are selected that + have a distance less than or equal to the maximum distance of the selected list of distances. + NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to + the length of `hierarchical_topics`. + width: The width of the figure. + height: The height of the figure. + + Usage: + + To visualize the topics simply run: + + ```python + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics) + ``` + + Do note that this re-calculates the embeddings and reduces them to 2D. + The advised and preferred pipeline for using this function is as follows: + + ```python + from sklearn.datasets import fetch_20newsgroups + from sentence_transformers import SentenceTransformer + from bertopic import BERTopic + from umap import UMAP + + # Prepare embeddings + docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = sentence_model.encode(docs, show_progress_bar=False) + + # Train BERTopic and extract hierarchical topics + topic_model = BERTopic().fit(docs, embeddings) + hierarchical_topics = topic_model.hierarchical_topics(docs, topics) + + # Reduce dimensionality of embeddings, this step is optional + # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) + + # Run the visualization with the original embeddings + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings) + + # Or, if you have reduced the original embeddings already: + topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) + ``` + + Or if you want to save the resulting figure: + + ```python + fig = 
topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) + fig.write_html("path/to/file.html") + ``` + + + """ + topic_per_doc = topic_model._map_predictions(topic_model.hdbscan_model.labels_) + + # Sample the data to optimize for visualization and dimensionality reduction + if sample is None or sample > 1: + sample = 1 + + indices = [] + for topic in set(topic_per_doc): + s = np.where(np.array(topic_per_doc) == topic)[0] + size = len(s) if len(s) < 100 else int(len(s)*sample) + indices.extend(np.random.choice(s, size=size, replace=False)) + indices = np.array(indices) + + df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]}) + df["doc"] = [docs[index] for index in indices] + df["topic"] = [topic_per_doc[index] for index in indices] + + # Extract embeddings if not already done + if sample is None: + if embeddings is None and reduced_embeddings is None: + embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + else: + embeddings_to_reduce = embeddings + else: + if embeddings is not None: + embeddings_to_reduce = embeddings[indices] + elif embeddings is None and reduced_embeddings is None: + embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + + # Reduce input embeddings + if reduced_embeddings is None: + umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce) + embeddings_2d = umap_model.embedding_ + elif sample is not None and reduced_embeddings is not None: + embeddings_2d = reduced_embeddings[indices] + elif sample is None and reduced_embeddings is not None: + embeddings_2d = reduced_embeddings + + # Combine data + df["x"] = embeddings_2d[:, 0] + df["y"] = embeddings_2d[:, 1] + + # Create topic list for each level, levels are created by calculating the distance + distances = hierarchical_topics.Distance.to_list() + max_distances = [distances[indices[-1]] for indices in 
np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1] + + for index, max_distance in enumerate(max_distances): + + # Get topics below `max_distance` + mapping = {topic: topic for topic in df.topic.unique()} + selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :] + selection.Parent_ID = selection.Parent_ID.astype(int) + selection = selection.sort_values("Parent_ID") + + for row in selection.iterrows(): + for topic in row[1].Topics: + mapping[topic] = row[1].Parent_ID + + # Make sure the mappings are mapped 1:1 + mappings = [True for _ in mapping] + while any(mappings): + for i, (key, value) in enumerate(mapping.items()): + if value in mapping.keys() and key != value: + mapping[key] = mapping[value] + else: + mappings[i] = False + + # Create new column + df[f"level_{index+1}"] = df.topic.map(mapping) + df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int) + + # Prepare topic names of original and merged topics + trace_names = [] + topic_names = {} + for topic in range(hierarchical_topics.Parent_ID.astype(int).max()): + if topic < hierarchical_topics.Parent_ID.astype(int).min(): + if topic_model.get_topic(topic): + plot_text = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3]) + trace_name = f"{topic}_" + "_".join([word for word, _ in topic_model.get_topic(topic)][:3]) + topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]} + trace_names.append(trace_name) + else: + trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0] + plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]]) + topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]} + trace_names.append(trace_name) + + # Prepare traces + all_traces = [] + for level in range(len(max_distances)): + traces = [] + + # Outliers + if topic_model._outliers: + traces.append( + go.Scattergl( + 
x=df.loc[(df[f"level_{level+1}"] == -1), "x"], + y=df.loc[df[f"level_{level+1}"] == -1, "y"], + mode='markers+text', + name="other", + hoverinfo="text", + hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None, + showlegend=False, + marker=dict(color='#CFD8DC', size=5, opacity=0.5) + ) + ) + + # Selected topics + if topics: + selection = df.loc[(df.topic.isin(topics)), :] + unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()]) + else: + unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()]) + + for topic in unique_topics: + if topic != -1: + if topics: + selection = df.loc[(df[f"level_{level+1}"] == topic) & + (df.topic.isin(topics)), :] + else: + selection = df.loc[df[f"level_{level+1}"] == topic, :] + + if not hide_annotations: + selection.loc[len(selection), :] = None + selection["text"] = "" + selection.loc[len(selection) - 1, "x"] = selection.x.mean() + selection.loc[len(selection) - 1, "y"] = selection.y.mean() + selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"] + + traces.append( + go.Scattergl( + x=selection.x, + y=selection.y, + text=selection.text if not hide_annotations else None, + hovertext=selection.doc if not hide_document_hover else None, + hoverinfo="text", + name=topic_names[int(topic)]["trace_name"], + mode='markers+text', + marker=dict(size=5, opacity=0.5) + ) + ) + + all_traces.append(traces) + + # Track and count traces + nr_traces_per_set = [len(traces) for traces in all_traces] + trace_indices = [(0, nr_traces_per_set[0])] + for index, nr_traces in enumerate(nr_traces_per_set[1:]): + start = trace_indices[index][1] + end = nr_traces + start + trace_indices.append((start, end)) + + # Visualization + fig = go.Figure() + for traces in all_traces: + for trace in traces: + fig.add_trace(trace) + + for index in range(len(fig.data)): + if index >= nr_traces_per_set[0]: + fig.data[index].visible = False + + # Create and 
add slider + steps = [] + for index, indices in enumerate(trace_indices): + step = dict( + method="update", + label=str(index), + args=[{"visible": [False] * len(fig.data)}] + ) + for index in range(indices[1]-indices[0]): + step["args"][0]["visible"][index+indices[0]] = True + steps.append(step) + + sliders = [dict( + currentvalue={"prefix": "Level: "}, + pad={"t": 20}, + steps=steps + )] + + # Add grid in a 'plus' shape + x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15)) + y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15)) + fig.add_shape(type="line", + x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1], + line=dict(color="#CFD8DC", width=2)) + fig.add_shape(type="line", + x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2, + line=dict(color="#9E9E9E", width=2)) + fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10) + fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10) + + # Stylize layout + fig.update_layout( + sliders=sliders, + template="simple_white", + title={ + 'text': "Hierarchical Documents and Topics", + 'x': 0.5, + 'xanchor': 'center', + 'yanchor': 'top', + 'font': dict( + size=22, + color="Black") + }, + width=width, + height=height, + ) + + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + return fig diff --git a/docs/api/plotting/hierarchical_documents.md b/docs/api/plotting/hierarchical_documents.md new file mode 100644 index 00000000..49216b77 --- /dev/null +++ b/docs/api/plotting/hierarchical_documents.md @@ -0,0 +1,3 @@ +# `Hierarchical Documents` + +::: bertopic.plotting._hierarchical_documents.visualize_hierarchical_documents diff --git a/docs/getting_started/visualization/hierarchical_documents.html b/docs/getting_started/visualization/hierarchical_documents.html new file mode 100644 index 00000000..9638fc0f --- /dev/null +++ 
b/docs/getting_started/visualization/hierarchical_documents.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/getting_started/visualization/visualization.md b/docs/getting_started/visualization/visualization.md index fce5274b..1a67dc79 100644 --- a/docs/getting_started/visualization/visualization.md +++ b/docs/getting_started/visualization/visualization.md @@ -70,8 +70,7 @@ topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the option to hover over the individual points and see the content of the documents. This was done for demonstration purposes as saving all those documents in the visualization can be quite expensive and result in large files. However, - it might be interesting to set `hide_document_hover=False` in order to hover over the points and see the content of the documents. - + it might be interesting to set `hide_document_hover=False` in order to hover over the points and see the content of the documents. ## **Visualize Topic Hierarchy** The topics that were created can be hierarchically reduced. In order to understand the potential hierarchical @@ -388,6 +387,44 @@ to view, we can see better which topics could be logically merged: ``` +## **Visualize Hierarchical Document** +We can extend the previous method by calculating the topic representation at different levels of the hierarchy and +plotting them on a 2D-plane. 
To do so, we first need to calculate the hierarchical topics: + +```python +from sklearn.datasets import fetch_20newsgroups +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# Prepare embeddings +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +sentence_model = SentenceTransformer("all-MiniLM-L6-v2") +embeddings = sentence_model.encode(docs, show_progress_bar=False) + +# Train BERTopic and extract hierarchical topics +topic_model = BERTopic().fit(docs, embeddings) +hierarchical_topics = topic_model.hierarchical_topics(docs, topics) +``` +Then, we can visualize the hierarchical documents by either supplying it with our embeddings or by +reducing their dimensionality ourselves: + +```python +# Run the visualization with the original embeddings +topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings) + +# Reduce dimensionality of embeddings, this step is optional but much faster to perform iteratively: +reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) +topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) +``` + + + +!!! note + The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the + option to hover over the individual points and see the content of the documents. This makes the resulting visualization + smaller and fit into your RAM. However, it might be interesting to set `hide_document_hover=False` in order to hover + over the points and see the content of the documents. 
## **Visualize Terms** We can visualize the selected terms for a few topics by creating bar charts out of the c-TF-IDF scores diff --git a/mkdocs.yml b/mkdocs.yml index 5ae3a4bd..4dbec27f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -40,12 +40,13 @@ nav: - Plotting: - Barchart: api/plotting/barchart.md - Documents: api/plotting/documents.md + - DTM: api/plotting/dtm.md + - Hierarchical documents: api/plotting/hierarchical_documents.md + - Hierarchical topics: api/plotting/hierarchy.md - Distribution: api/plotting/distribution.md - Heatmap: api/plotting/heatmap.md - - Hierarchy: api/plotting/hierarchy.md - Term Scores: api/plotting/term.md - Topics: api/plotting/topics.md - - DTM: api/plotting/dtm.md - Topics per Class: api/plotting/topics_per_class.md - Changelog: changelog.md From ec0b06198cced75287a67dc21242254db20fcc72 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Tue, 28 Jun 2022 14:28:05 +0200 Subject: [PATCH 05/15] Update tables with new functions --- README.md | 15 +++++++++++++-- docs/index.md | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7f327ed0..b4307b26 100644 --- a/README.md +++ b/README.md @@ -208,8 +208,6 @@ For quick access to common functions, here is an overview of BERTopic's main met | Get topic freq | `.get_topic_freq()` | | Get all topic information| `.get_topic_info()` | | Get representative docs per topic | `.get_representative_docs()` | -| Get topics per class | `.topics_per_class(docs, topics, classes)` | -| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` | | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | @@ -217,12 +215,25 @@ For quick access to common functions, here is an overview of BERTopic's main met | Load model | `BERTopic.load("my_model")` | | Get parameters | `.get_params()` | +For an overview 
of extensions to and variations of BERTopic, such as topics over time: + +| Method | Code | +|-----------------------|---| +| (semi-) Supervised Topic Modeling | `.fit(docs, y=y)` | +| Topic Modeling per Class | `.topics_per_class(docs, topics, classes)` | +| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` | +| Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` | +| Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` | + For an overview of BERTopic's visualization methods: | Method | Code | |-----------------------|---| | Visualize Topics | `.visualize_topics()` | +| Visualize Documents | `.visualize_documents()` | +| Visualize Document Hierarchy | `.visualize_hierarchical_documents()` | | Visualize Topic Hierarchy | `.visualize_hierarchy()` | +| Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` | | Visualize Topic Terms | `.visualize_barchart()` | | Visualize Topic Similarity | `.visualize_heatmap()` | | Visualize Term Score Decline | `.visualize_term_rank()` | diff --git a/docs/index.md b/docs/index.md index 4ec98ddb..3ac0792a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -91,8 +91,6 @@ For quick access to common functions, here is an overview of BERTopic's main met | Get topic freq | `.get_topic_freq()` | | Get all topic information| `.get_topic_info()` | | Get representative docs per topic | `.get_representative_docs()` | -| Get topics per class | `.topics_per_class(docs, topics, classes)` | -| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` | | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | @@ -100,12 +98,25 @@ For quick access to common functions, here is an overview of BERTopic's main met | Load model | `BERTopic.load("my_model")` | | Get parameters | `.get_params()` | +For an overview of extensions to and variations 
of BERTopic, such as topics over time: + +| Method | Code | +|-----------------------|---| +| (semi-) Supervised Topic Modeling | `.fit(docs, y=y)` | +| Topic Modeling per Class | `.topics_per_class(docs, topics, classes)` | +| Dynamic Topic Modeling | `.topics_over_time(docs, topics, timestamps)` | +| Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` | +| Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` | + For an overview of BERTopic's visualization methods: | Method | Code | |-----------------------|---| | Visualize Topics | `.visualize_topics()` | +| Visualize Documents | `.visualize_documents()` | +| Visualize Document Hierarchy | `.visualize_hierarchical_documents()` | | Visualize Topic Hierarchy | `.visualize_hierarchy()` | +| Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` | | Visualize Topic Terms | `.visualize_barchart()` | | Visualize Topic Similarity | `.visualize_heatmap()` | | Visualize Term Score Decline | `.visualize_term_rank()` | From 2fdc4d634596e750823437d3f1b2c0b9d5ca54dd Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Tue, 28 Jun 2022 14:35:12 +0200 Subject: [PATCH 06/15] Update styling tables --- README.md | 13 +++++++++++-- docs/index.md | 13 +++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b4307b26..759a6d11 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,10 @@ topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=6) ## Overview +BERTopic has quite a number of functions that quickly can become overwhelming. To alleviate this issue, you will find an overview +of all methods and a short description of its purpose. 
+ +### Common For quick access to common functions, here is an overview of BERTopic's main methods: | Method | Code | @@ -215,7 +219,9 @@ For quick access to common functions, here is an overview of BERTopic's main met | Load model | `BERTopic.load("my_model")` | | Get parameters | `.get_params()` | -For an overview of extensions to and variations of BERTopic, such as topics over time: +### Variations +There are many different use cases in which topic modeling can be used. As such, a number of +variations of BERTopic have been developed such that one package can be used across many use cases: | Method | Code | |-----------------------|---| @@ -225,7 +231,10 @@ For an overview of extensions to and variations of BERTopic, such as topics over | Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` | | Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` | -For an overview of BERTopic's visualization methods: +### Visualizations +Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation. +Visualizing different aspects of the topic model helps in understanding the model and makes it easier +to tweak the model to your liking. | Method | Code | |-----------------------|---| diff --git a/docs/index.md b/docs/index.md index 3ac0792a..fb99f872 100644 --- a/docs/index.md +++ b/docs/index.md @@ -79,6 +79,10 @@ frequent topic that was generated, topic 0: ## **Overview** +BERTopic has quite a number of functions that quickly can become overwhelming. To alleviate this issue, you will find an overview +of all methods and a short description of their purpose.
+ +### Common For quick access to common functions, here is an overview of BERTopic's main methods: | Method | Code | @@ -98,7 +102,9 @@ For quick access to common functions, here is an overview of BERTopic's main met | Load model | `BERTopic.load("my_model")` | | Get parameters | `.get_params()` | -For an overview of extensions to and variations of BERTopic, such as topics over time: +### Variations +There are many different use cases in which topic modeling can be used. As such, a number of +variations of BERTopic have been developed such that one package can be used across many use cases: | Method | Code | |-----------------------|---| @@ -108,7 +114,10 @@ For an overview of extensions to and variations of BERTopic, such as topics over | Hierarchical Topic Modeling | `.hierarchical_topics(docs, topics)` | | Guided Topic Modeling | `BERTopic(seed_topic_list=seed_topic_list)` | -For an overview of BERTopic's visualization methods: +### Visualizations +Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation. +Visualizing different aspects of the topic model helps in understanding the model and makes it easier +to tweak the model to your liking.
| Method | Code | |-----------------------|---| From 20328b1be8521a3fe3cc28e2c793787d4699f97c Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Wed, 29 Jun 2022 15:23:54 +0200 Subject: [PATCH 07/15] Added native Hugging Face support --- bertopic/backend/_hftransformers.py | 96 +++++++++++++++++++++++++++++ bertopic/backend/_utils.py | 6 ++ 2 files changed, 102 insertions(+) create mode 100644 bertopic/backend/_hftransformers.py diff --git a/bertopic/backend/_hftransformers.py b/bertopic/backend/_hftransformers.py new file mode 100644 index 00000000..b44ba84b --- /dev/null +++ b/bertopic/backend/_hftransformers.py @@ -0,0 +1,96 @@ +import numpy as np + +from tqdm import tqdm +from typing import List +from torch.utils.data import Dataset +from sklearn.preprocessing import normalize +from transformers.pipelines import Pipeline + +from bertopic.backend import BaseEmbedder + + +class HFTransformerBackend(BaseEmbedder): + """ Hugging Face transformers model + + This uses the `transformers.pipelines.pipeline` to define and create + a feature generation pipeline from which embeddings can be extracted. + + Arguments: + embedding_model: A Hugging Face feature extraction pipeline + + Usage: + + To use a Hugging Face transformers model, load in a pipeline and point + to any model found on their model hub (https://huggingface.co/models): + + ```python + from bertopic.backend import HFTransformerBackend + from transformers.pipelines import pipeline + + hf_model = pipeline("feature-extraction", model="distilbert-base-cased") + embedding_model = HFTransformerBackend(hf_model) + ``` + """ + def __init__(self, embedding_model: Pipeline): + super().__init__() + + if isinstance(embedding_model, Pipeline): + self.embedding_model = embedding_model + else: + raise ValueError("Please select a correct transformers pipeline. 
For example: " + "pipeline('feature-extraction', model='distilbert-base-cased', device=0)") + + def embed(self, + documents: List[str], + verbose: bool = False) -> np.ndarray: + """ Embed a list of n documents/words into an n-dimensional + matrix of embeddings + + Arguments: + documents: A list of documents or words to be embedded + verbose: Controls the verbosity of the process + + Returns: + Document/words embeddings with shape (n, m) with `n` documents/words + that each have an embeddings size of `m` + """ + dataset = MyDataset(documents) + + embeddings = [] + for document, features in tqdm(zip(documents, self.embedding_model(dataset, truncation=True, padding=True)), + total=len(dataset), disable=not verbose): + embeddings.append(self._embed(document, features)) + + return np.array(embeddings) + + def _embed(self, + document: str, + features: np.ndarray) -> np.ndarray: + """ Mean pooling + + Arguments: + document: The document for which to extract the attention mask + features: The embeddings for each token + + Adopted from: + https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers + """ + token_embeddings = np.array(features) + attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")["attention_mask"] + input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape) + sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=input_mask_expanded.sum(1).max()) + embedding = normalize(sum_embeddings / sum_mask)[0] + return embedding + + +class MyDataset(Dataset): + """ Dataset to pass to `transformers.pipelines.pipeline` """ + def __init__(self, docs): + self.docs = docs + + def __len__(self): + return len(self.docs) + + def __getitem__(self, idx): + return self.docs[idx] diff --git a/bertopic/backend/_utils.py b/bertopic/backend/_utils.py index a2b2eef9..a4db7860 
100644 --- a/bertopic/backend/_utils.py +++ b/bertopic/backend/_utils.py @@ -1,5 +1,7 @@ from ._base import BaseEmbedder from ._sentencetransformers import SentenceTransformerBackend +from ._hftransformers import HFTransformerBackend +from transformers.pipelines import Pipeline languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bengali romanize', @@ -59,6 +61,10 @@ def select_backend(embedding_model, if isinstance(embedding_model, str): return SentenceTransformerBackend(embedding_model) + # Hugging Face embeddings + if isinstance(embedding_model, Pipeline): + return HFTransformerBackend(embedding_model) + # Select embedding model based on language if language: if language.lower() in ["English", "english", "en"]: From c5296428a15f4a916095549fdc3160f4cb05b9b5 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Sat, 2 Jul 2022 16:38:50 +0200 Subject: [PATCH 08/15] Add custom labeling options --- bertopic/_bertopic.py | 160 ++++++++++++++++++- bertopic/plotting/_distribution.py | 26 +-- bertopic/plotting/_documents.py | 8 +- bertopic/plotting/_heatmap.py | 14 +- bertopic/plotting/_hierarchical_documents.py | 13 +- bertopic/plotting/_hierarchy.py | 38 +++-- bertopic/plotting/_term_rank.py | 11 +- bertopic/plotting/_topics_over_time.py | 10 +- bertopic/plotting/_topics_per_class.py | 10 +- tests/test_bertopic.py | 8 + 10 files changed, 263 insertions(+), 35 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index c19c101d..396725fb 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -14,7 +14,7 @@ import numpy as np import pandas as pd from tqdm import tqdm -from scipy.sparse.csr import csr_matrix +from scipy.sparse import csr_matrix from scipy.cluster import hierarchy as sch from typing import List, Tuple, Union, Mapping, Any, Callable @@ -177,10 +177,12 @@ def __init__(self, cluster_selection_method='eom', prediction_data=True) + # Attributes self.topics = 
None self.topic_mapper = None self.topic_sizes = None self.merged_topics = None + self.custom_labels = None self.topic_embeddings = None self.topic_sim_matrix = None self.representative_docs = None @@ -883,7 +885,7 @@ def get_topic(self, topic: int) -> Union[Mapping[str, Tuple[str, float]], bool]: return False def get_topic_info(self, topic: int = None) -> pd.DataFrame: - """ Get information about each topic including its id, frequency, and name + """ Get information about each topic including its ID, frequency, and name. Arguments: topic: A specific topic for which you want the frequency @@ -902,6 +904,11 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame: info = pd.DataFrame(self.topic_sizes.items(), columns=['Topic', 'Count']).sort_values("Count", ascending=False) info["Name"] = info.Topic.map(self.topic_names) + if self.custom_labels is not None: + if len(self.custom_labels) == len(info): + labels = {topic - self._outliers: label for topic, label in enumerate(self.custom_labels)} + info["CustomName"] = info["Topic"].map(labels) + if topic: info = info.loc[info.Topic == topic, :] @@ -1070,6 +1077,119 @@ def _tree(to_print, start, parent, tree, grandpa=None, indent=""): start = str(hier_topics.Parent_ID.astype(int).max()) return get_tree(start, tree) + def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> None: + """ Set custom topic labels in your fitted BERTopic model + + Arguments: + topic_labels: If a list of topic labels, it should contain the same number + of labels as there are topics. This must be ordered + from the topic with the lowest ID to the highest ID, + including topic -1 if it exists. + If a dictionary of `topic ID`: `topic_label`, it can have + any number of topics as it will only map the topics found + in the dictionary. 
+ + Usage: + + First, we define our topic labels with `.generate_topic_labels` in which + we can customize our topic labels: + + ```python + topic_labels = topic_model.generate_topic_labels(nr_words=2, + topic_prefix=True, + word_length=10, + separator=", ") + ``` + + Then, we pass these `topic_labels` to our topic model which + can be accessed at any time with `.custom_labels`: + + ```python + topic_model.set_topic_labels(topic_labels) + topic_model.custom_labels + ``` + + You might want to change only a few topic labels instead of all of them. + To do so, you can pass a dictionary where the keys are the topic IDs and + the values are the topic labels: + + ```python + topic_model.set_topic_labels({0: "Space", 1: "Sports", 2: "Medicine"}) + topic_model.custom_labels + ``` + """ + unique_topics = sorted(set(self._map_predictions(self.hdbscan_model.labels_))) + + if isinstance(topic_labels, dict): + if self.custom_labels is not None: + original_labels = {topic: label for topic, label in zip(unique_topics, self.custom_labels)} + else: + info = self.get_topic_info() + original_labels = dict(zip(info.Topic, info.Name)) + custom_labels = [topic_labels.get(topic) if topic_labels.get(topic) else original_labels[topic] for topic in unique_topics] + + elif isinstance(topic_labels, list): + if len(topic_labels) == len(unique_topics): + custom_labels = topic_labels + else: + raise ValueError("Make sure that `topic_labels` contains the same number " + "of labels as that there are topics.") + + self.custom_labels = custom_labels + + def generate_topic_labels(self, + nr_words: int = 3, + topic_prefix: bool = True, + word_length: int = None, + separator: str = "_") -> List[str]: + """ Get labels for each topic in a user-defined format + + Arguments: + original_labels: + nr_words: Top `n` words per topic to use + topic_prefix: Whether to use the topic ID as a prefix.
+ If set to True, the topic ID will be separated + using the `separator` + word_length: The maximum length of each word in the topic label. + Some words might be relatively long and setting this + value helps to make sure that all labels have relatively + similar lengths. + separator: The string with which the words and topic prefix will be + separated. Underscores are the default but a nice alternative + is `", "`. + + Returns: + topic_labels: A list of topic labels sorted from the lowest topic ID to the highest. + If the topic model was trained using HDBSCAN, the lowest topic ID is -1, + otherwise it is 0. + + Usage: + + To create our custom topic labels, usage is rather straightforward: + + ```python + topic_labels = topic_model.generate_topic_labels(nr_words=2, separator=", ") + ``` + """ + unique_topics = sorted(set(self._map_predictions(self.hdbscan_model.labels_))) + topic_labels = [] + for topic in unique_topics: + words, _ = zip(*self.get_topic(topic)) + + if word_length: + words = [word[:word_length] for word in words][:nr_words] + else: + words = list(words)[:nr_words] + + if topic_prefix: + topic_label = f"{topic}{separator}" + separator.join(words) + else: + topic_label = separator.join(words) + + topic_labels.append(topic_label) + + return topic_labels + def reduce_topics(self, docs: List[str], topics: List[int], @@ -1173,6 +1293,7 @@ def visualize_documents(self, sample: float = None, hide_annotations: bool = False, hide_document_hover: bool = False, + custom_labels: bool = False, width: int = 1200, height: int = 750) -> go.Figure: """ Visualize documents and their topics in 2D @@ -1193,6 +1314,8 @@ def visualize_documents(self, hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over specific points. Helps to speed up generation of visualization. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`.
width: The width of the figure. height: The height of the figure. @@ -1250,6 +1373,7 @@ def visualize_documents(self, sample=sample, hide_annotations=hide_annotations, hide_document_hover=hide_document_hover, + custom_labels=custom_labels, width=width, height=height) @@ -1263,6 +1387,7 @@ def visualize_hierarchical_documents(self, hide_annotations: bool = False, hide_document_hover: bool = True, nr_levels: int = 10, + custom_labels: bool = False, width: int = 1200, height: int = 750) -> go.Figure: """ Visualize documents and their topics in 2D at different levels of hierarchy @@ -1290,6 +1415,10 @@ def visualize_hierarchical_documents(self, have a distance less or equal to the maximum distance of the selected list of distances. NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to the length of `hierarchical_topics`. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + NOTE: Custom labels are only generated for the original + un-merged topics. width: The width of the figure. height: The height of the figure. @@ -1350,12 +1479,14 @@ def visualize_hierarchical_documents(self, hide_annotations=hide_annotations, hide_document_hover=hide_document_hover, nr_levels=nr_levels, + custom_labels=custom_labels, width=width, height=height) def visualize_term_rank(self, topics: List[int] = None, log_scale: bool = False, + custom_labels: bool = False, width: int = 800, height: int = 500) -> go.Figure: """ Visualize the ranks of all terms across all topics @@ -1369,6 +1500,8 @@ def visualize_term_rank(self, topics: A selection of topics to visualize. These will be colored red where all others will be colored black. log_scale: Whether to represent the ranking on a log scale + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
@@ -1403,6 +1536,7 @@ def visualize_term_rank(self, return plotting.visualize_term_rank(self, topics=topics, log_scale=log_scale, + custom_labels=custom_labels, width=width, height=height) @@ -1411,6 +1545,7 @@ def visualize_topics_over_time(self, top_n_topics: int = None, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 450) -> go.Figure: """ Visualize topics over time @@ -1421,6 +1556,8 @@ def visualize_topics_over_time(self, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1449,6 +1586,7 @@ def visualize_topics_over_time(self, top_n_topics=top_n_topics, topics=topics, normalize_frequency=normalize_frequency, + custom_labels=custom_labels, width=width, height=height) @@ -1457,6 +1595,7 @@ def visualize_topics_per_class(self, top_n_topics: int = 10, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 900) -> go.Figure: """ Visualize topics per class @@ -1467,6 +1606,8 @@ def visualize_topics_per_class(self, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
@@ -1495,12 +1636,14 @@ def visualize_topics_per_class(self, top_n_topics=top_n_topics, topics=topics, normalize_frequency=normalize_frequency, + custom_labels=custom_labels, width=width, height=height) def visualize_distribution(self, probabilities: np.ndarray, min_probability: float = 0.015, + custom_labels: bool = False, width: int = 800, height: int = 600) -> go.Figure: """ Visualize the distribution of topic probabilities @@ -1509,6 +1652,8 @@ def visualize_distribution(self, probabilities: An array of probability scores min_probability: The minimum probability score to visualize. All others are ignored. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1532,6 +1677,7 @@ def visualize_distribution(self, return plotting.visualize_distribution(self, probabilities=probabilities, min_probability=min_probability, + custom_labels=custom_labels, width=width, height=height) @@ -1539,6 +1685,7 @@ def visualize_hierarchy(self, orientation: str = "left", topics: List[int] = None, top_n_topics: int = None, + custom_labels: bool = False, width: int = 1000, height: int = 600, hierarchical_topics: pd.DataFrame = None, @@ -1557,6 +1704,10 @@ def visualize_hierarchy(self, Either 'left' or 'bottom' topics: A selection of topics to visualize top_n_topics: Only select the top n most frequent topics + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + NOTE: Custom labels are only generated for the original + un-merged topics. width: The width of the figure. Only works if orientation is set to 'left' height: The height of the figure. 
Only works if orientation is set to 'bottom' hierarchical_topics: A dataframe that contains a hierarchy of topics @@ -1612,6 +1763,7 @@ def visualize_hierarchy(self, orientation=orientation, topics=topics, top_n_topics=top_n_topics, + custom_labels=custom_labels, width=width, height=height, hierarchical_topics=hierarchical_topics, @@ -1624,6 +1776,7 @@ def visualize_heatmap(self, topics: List[int] = None, top_n_topics: int = None, n_clusters: int = None, + custom_labels: bool = False, width: int = 800, height: int = 800) -> go.Figure: """ Visualize a heatmap of the topic's similarity matrix @@ -1636,6 +1789,8 @@ def visualize_heatmap(self, top_n_topics: Only select the top n most frequent topics. n_clusters: Create n clusters and order the similarity matrix by those clusters. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1663,6 +1818,7 @@ def visualize_heatmap(self, topics=topics, top_n_topics=top_n_topics, n_clusters=n_clusters, + custom_labels=custom_labels, width=width, height=height) diff --git a/bertopic/plotting/_distribution.py b/bertopic/plotting/_distribution.py index ea2ea0a3..8fe4b67f 100644 --- a/bertopic/plotting/_distribution.py +++ b/bertopic/plotting/_distribution.py @@ -5,6 +5,7 @@ def visualize_distribution(topic_model, probabilities: np.ndarray, min_probability: float = 0.015, + custom_labels: bool = False, width: int = 800, height: int = 600) -> go.Figure: """ Visualize the distribution of topic probabilities @@ -14,6 +15,8 @@ def visualize_distribution(topic_model, probabilities: An array of probability scores min_probability: The minimum probability score to visualize. All others are ignored. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
@@ -50,16 +53,19 @@ def visualize_distribution(topic_model, vals = probabilities[labels_idx].tolist() # Create labels - labels = [] - for idx in labels_idx: - words = topic_model.get_topic(idx) - if words: - label = [word[0] for word in words[:5]] - label = f"Topic {idx}: {'_'.join(label)}" - label = label[:40] + "..." if len(label) > 40 else label - labels.append(label) - else: - vals.remove(probabilities[idx]) + if topic_model.custom_labels is not None and custom_labels: + labels = [topic_model.custom_labels[idx + topic_model._outliers] for idx in labels_idx] + else: + labels = [] + for idx in labels_idx: + words = topic_model.get_topic(idx) + if words: + label = [word[0] for word in words[:5]] + label = f"Topic {idx}: {'_'.join(label)}" + label = label[:40] + "..." if len(label) > 40 else label + labels.append(label) + else: + vals.remove(probabilities[idx]) # Create Figure fig = go.Figure(go.Bar( diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py index aa380079..3d937ff1 100644 --- a/bertopic/plotting/_documents.py +++ b/bertopic/plotting/_documents.py @@ -14,6 +14,7 @@ def visualize_documents(topic_model, sample: float = None, hide_annotations: bool = False, hide_document_hover: bool = False, + custom_labels: bool = False, width: int = 1200, height: int = 750): """ Visualize documents and their topics in 2D @@ -34,6 +35,8 @@ def visualize_documents(topic_model, hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over specific points. Helps to speed up generation of visualization. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
@@ -129,7 +132,10 @@ def visualize_documents(topic_model, df["y"] = embeddings_2d[:, 1] # Prepare text and names - names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics] + if topic_model.custom_labels is not None and custom_labels: + names = [topic_model.custom_labels[topic + topic_model._outliers] for topic in unique_topics] + else: + names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics] # Visualize fig = go.Figure() diff --git a/bertopic/plotting/_heatmap.py b/bertopic/plotting/_heatmap.py index 7f51c1e0..fe845550 100644 --- a/bertopic/plotting/_heatmap.py +++ b/bertopic/plotting/_heatmap.py @@ -11,6 +11,7 @@ def visualize_heatmap(topic_model, topics: List[int] = None, top_n_topics: int = None, n_clusters: int = None, + custom_labels: bool = False, width: int = 800, height: int = 800) -> go.Figure: """ Visualize a heatmap of the topic's similarity matrix @@ -24,6 +25,8 @@ def visualize_heatmap(topic_model, top_n_topics: Only select the top n most frequent topics. n_clusters: Create n clusters and order the similarity matrix by those clusters. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -90,10 +93,13 @@ def visualize_heatmap(topic_model, embeddings = embeddings[indices] distance_matrix = cosine_similarity(embeddings) - # Create nicer labels - new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." 
for label in new_labels] + # Create labels + if topic_model.custom_labels is not None and custom_labels: + new_labels = [topic_model.custom_labels[topic + topic_model._outliers] for topic in sorted_topics] + else: + new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics] + new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] + new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] fig = px.imshow(distance_matrix, labels=dict(color="Similarity Score"), diff --git a/bertopic/plotting/_hierarchical_documents.py b/bertopic/plotting/_hierarchical_documents.py index 6c3d839f..c8162925 100644 --- a/bertopic/plotting/_hierarchical_documents.py +++ b/bertopic/plotting/_hierarchical_documents.py @@ -16,6 +16,7 @@ def visualize_hierarchical_documents(topic_model, hide_annotations: bool = False, hide_document_hover: bool = True, nr_levels: int = 10, + custom_labels: bool = False, width: int = 1200, height: int = 750) -> go.Figure: """ Visualize documents and their topics in 2D at different levels of hierarchy @@ -43,6 +44,10 @@ def visualize_hierarchical_documents(topic_model, have a distance less or equal to the maximum distance of the selected list of distances. NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to the length of `hierarchical_topics`. + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + NOTE: Custom labels are only generated for the original + un-merged topics. width: The width of the figure. height: The height of the figure. 
@@ -169,9 +174,11 @@ def visualize_hierarchical_documents(topic_model, for topic in range(hierarchical_topics.Parent_ID.astype(int).max()): if topic < hierarchical_topics.Parent_ID.astype(int).min(): if topic_model.get_topic(topic): - plot_text = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3]) - trace_name = f"{topic}_" + "_".join([word for word, _ in topic_model.get_topic(topic)][:3]) - topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]} + if topic_model.custom_labels is not None and custom_labels: + trace_name = topic_model.custom_labels[topic + topic_model._outliers] + else: + trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3]) + topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]} trace_names.append(trace_name) else: trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0] diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py index b7d0c50f..b1fef674 100644 --- a/bertopic/plotting/_hierarchy.py +++ b/bertopic/plotting/_hierarchy.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd from typing import Callable, List -from scipy.sparse.csr import csr_matrix +from scipy.sparse import csr_matrix from scipy.cluster import hierarchy as sch from sklearn.metrics.pairwise import cosine_similarity @@ -13,6 +13,7 @@ def visualize_hierarchy(topic_model, orientation: str = "left", topics: List[int] = None, top_n_topics: int = None, + custom_labels: bool = False, width: int = 1000, height: int = 600, hierarchical_topics: pd.DataFrame = None, @@ -31,6 +32,10 @@ def visualize_hierarchy(topic_model, Either 'left' or 'bottom' topics: A selection of topics to visualize top_n_topics: Only select the top n most frequent topics + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. 
+ NOTE: Custom labels are only generated for the original + un-merged topics. width: The width of the figure. Only works if orientation is set to 'left' height: The height of the figure. Only works if orientation is set to 'bottom' hierarchical_topics: A dataframe that contains a hierarchy of topics @@ -109,7 +114,8 @@ def visualize_hierarchy(topic_model, embeddings=embeddings, distance_function=distance_function, linkage_function=linkage_function, - orientation=orientation) + orientation=orientation, + custom_labels=custom_labels) else: annotations = None @@ -123,10 +129,13 @@ def visualize_hierarchy(topic_model, # Create nicer labels axis = "yaxis" if orientation == "left" else "xaxis" - new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) - for x in fig.layout[axis]["ticktext"]] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] + if topic_model.custom_labels is not None and custom_labels: + new_labels = [topic_model.custom_labels[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]] + else: + new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) + for x in fig.layout[axis]["ticktext"]] + new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] + new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] # Stylize layout fig.update_layout( @@ -184,7 +193,8 @@ def _get_annotations(topic_model, embeddings: csr_matrix, linkage_function: Callable[[csr_matrix], np.ndarray], distance_function: Callable[[csr_matrix], csr_matrix], - orientation: str) -> List[List[str]]: + orientation: str, + custom_labels: bool = False) -> List[List[str]]: """ Get annotations by replicating linkage function calculation in scipy @@ -205,6 +215,10 @@ def _get_annotations(topic_model, in `topic_model.hierarchical_topics`. 
orientation: The orientation of the figure. Either 'left' or 'bottom' + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + NOTE: Custom labels are only generated for the original + un-merged topics. Returns: text_annotations: Annotations to be used within Plotly's `ff.create_dendogram` @@ -233,14 +247,20 @@ def _get_annotations(topic_model, scnd_topic = topic_vals[trace[2]] if len(fst_topic) == 1: - fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5]) + if topic_model.custom_labels is not None and custom_labels: + fst_name = topic_model.custom_labels[fst_topic[0] + topic_model._outliers] + else: + fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5]) else: for key, value in parent_topic.items(): if set(value) == set(fst_topic): fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0] if len(scnd_topic) == 1: - scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5]) + if topic_model.custom_labels is not None and custom_labels: + scnd_name = topic_model.custom_labels[scnd_topic[0] + topic_model._outliers] + else: + scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5]) else: for key, value in parent_topic.items(): if set(value) == set(scnd_topic): diff --git a/bertopic/plotting/_term_rank.py b/bertopic/plotting/_term_rank.py index ced45ceb..a66aea39 100644 --- a/bertopic/plotting/_term_rank.py +++ b/bertopic/plotting/_term_rank.py @@ -6,6 +6,7 @@ def visualize_term_rank(topic_model, topics: List[int] = None, log_scale: bool = False, + custom_labels: bool = False, width: int = 800, height: int = 500) -> go.Figure: """ Visualize the ranks of all terms across all topics @@ -20,6 +21,8 @@ def visualize_term_rank(topic_model, topics: A selection of topics to visualize. These will be colored red where all others will be colored black. 
log_scale: Whether to represent the ranking on a log scale + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -69,9 +72,13 @@ def visualize_term_rank(topic_model, lines = [] for topic, x, y in zip(topic_ids, indices, values): if not any(y > 1.5): + # labels - label = f"Topic {topic}:" + "_".join([word[0] for word in topic_model.get_topic(topic)]) - label = label[:50] + if topic_model.custom_labels is not None and custom_labels: + label = topic_model.custom_labels[topic + topic_model._outliers] + else: + label = f"Topic {topic}:" + "_".join([word[0] for word in topic_model.get_topic(topic)]) + label = label[:50] # line parameters color = "red" if topic in topics else "black" diff --git a/bertopic/plotting/_topics_over_time.py b/bertopic/plotting/_topics_over_time.py index abaeeb84..a13282e4 100644 --- a/bertopic/plotting/_topics_over_time.py +++ b/bertopic/plotting/_topics_over_time.py @@ -9,6 +9,7 @@ def visualize_topics_over_time(topic_model, top_n_topics: int = None, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 450) -> go.Figure: """ Visualize topics over time @@ -20,6 +21,8 @@ def visualize_topics_over_time(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -55,8 +58,11 @@ def visualize_topics_over_time(topic_model, selected_topics = topic_model.get_topic_freq().Topic.values # Prepare data - topic_names = {key: value[:40] + "..." 
if len(value) > 40 else value - for key, value in topic_model.topic_names.items()} + if topic_model.custom_labels is not None and custom_labels: + topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()} + else: + topic_names = {key: value[:40] + "..." if len(value) > 40 else value + for key, value in topic_model.topic_names.items()} topics_over_time["Name"] = topics_over_time.Topic.map(topic_names) data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"]) diff --git a/bertopic/plotting/_topics_per_class.py b/bertopic/plotting/_topics_per_class.py index 3f064017..0894ea0c 100644 --- a/bertopic/plotting/_topics_per_class.py +++ b/bertopic/plotting/_topics_per_class.py @@ -9,6 +9,7 @@ def visualize_topics_per_class(topic_model, top_n_topics: int = 10, topics: List[int] = None, normalize_frequency: bool = False, + custom_labels: bool = False, width: int = 1250, height: int = 900) -> go.Figure: """ Visualize topics per class @@ -20,6 +21,8 @@ def visualize_topics_per_class(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually + custom_labels: Whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -55,8 +58,11 @@ def visualize_topics_per_class(topic_model, selected_topics = topic_model.get_topic_freq().Topic.values # Prepare data - topic_names = {key: value[:40] + "..." if len(value) > 40 else value - for key, value in topic_model.topic_names.items()} + if topic_model.custom_labels is not None and custom_labels: + topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()} + else: + topic_names = {key: value[:40] + "..." 
if len(value) > 40 else value + for key, value in topic_model.topic_names.items()} topics_per_class["Name"] = topics_per_class.Topic.map(topic_names) data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :] diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index e8d722a1..2e5f0b74 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -89,3 +89,11 @@ def test_full_model(topic_model): assert topic != updated_topic assert topic == original_topic + + # Test updating topic labels + topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ") + assert len(topic_labels) == len(set(new_topics)) + + # Test setting topic labels + topic_model.set_topic_labels(topic_labels) + assert topic_model.custom_labels == topic_labels \ No newline at end of file From 2107b808e2204df4bf3c4e64549929786993db3d Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 08:28:58 +0200 Subject: [PATCH 09/15] Fix #572 --- bertopic/_bertopic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 396725fb..ccd3f2c1 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -912,7 +912,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame: if topic: info = info.loc[info.Topic == topic, :] - return info + return info.reset_index(drop=True) def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]: """ Return the the size of topics (descending order) From aac064c03a679e9acc0bd4e5210c15552dfba2a5 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 08:42:03 +0200 Subject: [PATCH 10/15] Fix #581 --- bertopic/_bertopic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index ccd3f2c1..d20c48d3 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -901,7 +901,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame: 
""" check_is_fitted(self) - info = pd.DataFrame(self.topic_sizes.items(), columns=['Topic', 'Count']).sort_values("Count", ascending=False) + info = pd.DataFrame(self.topic_sizes.items(), columns=["Topic", "Count"]).sort_values("Topic", ascending=False) info["Name"] = info.Topic.map(self.topic_names) if self.custom_labels is not None: From 90c73a6fe44fc2ab5d294a7376cdac25703c1cbe Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 10:36:32 +0200 Subject: [PATCH 11/15] Added example for finding similar topics between two models --- .../tips_and_tricks/tips_and_tricks.md | 83 ++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/docs/getting_started/tips_and_tricks/tips_and_tricks.md b/docs/getting_started/tips_and_tricks/tips_and_tricks.md index b76fbbe6..bc71d4c5 100644 --- a/docs/getting_started/tips_and_tricks/tips_and_tricks.md +++ b/docs/getting_started/tips_and_tricks/tips_and_tricks.md @@ -150,4 +150,85 @@ force a cosine-related distance metric in UMAP: ```python from cuml.preprocessing import normalize embeddings = normalize(embeddings) -``` \ No newline at end of file +``` + +## **Finding similar topics between models** + +Whenever you have trained separate BERTopic models on different datasets, it might +be worthwhile to find the similarities among these models. Is there overlap between +topics in model A and topics in model B? In other words, can we find topics in model A that are similar to those in model B? + +We can compare the topic representations of several models in two ways. First, by comparing the topic embeddings that are created when using the same embedding model across both fitted BERTopic instances. Second, we can compare the c-TF-IDF representations instead assuming we have fixed the vocabulary in both instances. + +This example will go into the former, using the same embedding model across two BERTopic instances. 
To do this comparison, let's first create an example where I trained two models, one on an English dataset and one on a Dutch dataset: + +```python +from datasets import load_dataset +from bertopic import BERTopic +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# The same embedding model needs to be used for both topic models +# and since we are dealing with multiple languages, the model needs to be multi-lingual +sentence_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2") + +# To make this example reproducible +umap_model = UMAP(n_neighbors=15, n_components=5, + min_dist=0.0, metric='cosine', random_state=42) + +# English +en_dataset = load_dataset("stsb_multi_mt", name="en", split="train").to_pandas().sentence1.tolist() +en_model = BERTopic(embedding_model=sentence_model, umap_model=umap_model) +en_model.fit(en_dataset) + +# Dutch +nl_dataset = load_dataset("stsb_multi_mt", name="nl", split="train").to_pandas().sentence1.tolist() +nl_model = BERTopic(embedding_model=sentence_model, umap_model=umap_model) +nl_model.fit(nl_dataset) +``` + +In the code above, there is one important thing to note and that is the `sentence_model`. This model needs to be exactly the same in all BERTopic models, otherwise, it is not possible to compare topic models. + +Next, we can calculate the similarity between topics in the English topic model `en_model` and the Dutch model `nl_model`. To do so, we can simply calculate the cosine similarity between the `topic_embeddings` of both models: + +```python +from sklearn.metrics.pairwise import cosine_similarity +sim_matrix = cosine_similarity(en_model.topic_embeddings, nl_model.topic_embeddings) +``` + +Now that we know which topics are similar to each other, we can extract the most similar topics. 
Let's say that we have topic 10 in the `en_model` which represents a topic related to trains: + +```python +>>> topic = 10 +>>> en_model.get_topic(topic) +[('train', 0.2588080580844999), + ('tracks', 0.1392140438801078), + ('station', 0.12126454635946024), + ('passenger', 0.058057876475695866), + ('engine', 0.05123717127783682), + ('railroad', 0.048142847325312044), + ('waiting', 0.04098973702226946), + ('track', 0.03978248702913929), + ('subway', 0.03834661195748458), + ('steam', 0.03834661195748458)] +``` + +To find the matching topic, we extract the most similar topic in the `sim_matrix`: + +```python +>>> most_similar_topic = np.argmax(sim_matrix[topic + 1])-1 +>>> nl_model.get_topic(most_similar_topic) +[('trein', 0.24186603209316418), + ('spoor', 0.1338118418551581), + ('sporen', 0.07683661859111401), + ('station', 0.056990389779394225), + ('stoommachine', 0.04905829711711234), + ('zilveren', 0.04083879598477808), + ('treinen', 0.03534099197032758), + ('treinsporen', 0.03534099197032758), + ('staat', 0.03481332997324445), + ('zwarte', 0.03179591746822408)] +``` + +It seems to be working as, for example, `trein` is a translation of `train` and `sporen` a translation of `tracks`! You can do this for every single topic to find out which topic in the `en_model` might correspond to a topic in the `nl_model`. 
From b11542e8b52c78077b00367c95a2c8db7cac5115 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 11:08:02 +0200 Subject: [PATCH 12/15] Add documentation for custom labels --- README.md | 2 + .../topicrepresentation.md | 64 ++++++++++++++++++- docs/index.md | 2 + 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 759a6d11..56c15a29 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,8 @@ For quick access to common functions, here is an overview of BERTopic's main met | Get all topic information| `.get_topic_info()` | | Get representative docs per topic | `.get_representative_docs()` | | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | +| Generate topic labels | `.generate_topic_labels()` | +| Set topic labels | `.set_topic_labels(my_custom_labels)` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | | Save model | `.save("my_model")` | diff --git a/docs/getting_started/topicrepresentation/topicrepresentation.md b/docs/getting_started/topicrepresentation/topicrepresentation.md index 6f5d3c12..1c1e9a03 100644 --- a/docs/getting_started/topicrepresentation/topicrepresentation.md +++ b/docs/getting_started/topicrepresentation/topicrepresentation.md @@ -61,4 +61,66 @@ instead: from sklearn.feature_extraction.text import CountVectorizer vectorizer_model = CountVectorizer(stop_words="English", ngram_range=(1, 5)) topic_model.update_topics(docs, topics, vectorizer_model=vectorizer_model) -``` \ No newline at end of file +``` + +### **Custom labels** + +The topic labels are currently automatically generated by taking the top 3 words and combining them +using the `_` separator. Although this is an informative label, in practice, this is definitely not the prettiest nor necessarily the most accurate label. 
For example, although the topic label +`1_space_nasa_orbit` is informative, we would prefer to have a bit more intuitive label, such as +`space travel`. The difficulty with creating such topic labels is that much of the interpretation is left to the user. Would `space travel` be more accurate or perhaps `space explorations`? To truly understand which labels are most suited, going into some of the documents in topics is especially helpful. + +Although we can go through every single topic ourselves and try to label them, we can start by creating an overview of labels that have the length and number of words that we are looking for. To do so, we can generate our list of topic labels with `.generate_topic_labels` and define the number of words, the separator, word length, etc: + +```python +topic_labels = topic_model.generate_topic_labels( +nr_words=3, + topic_prefix=False, + word_length=10, + separator=", ") +``` + +In the above example, `1_space_nasa_orbit` would turn into `space, nasa, orbit` since we selected 3 words, no topic prefix, and the `, ` separator. We can then either change our `topic_labels` to whatever we want or directly pass them to `.set_topic_labels` so that they can be used across most visualization functions: + +```python +topic_model.set_topic_labels(topic_labels) +``` + +It is also possible to only change a few topic labels at a time by passing a dictionary +where the key represents the *topic ID* and the value is the *topic label*: + +```python +topic_model.set_topic_labels({1: "Space Travel", 7: "Religion"}) +``` + +Then, to make use of those custom topic labels across visualizations, such as `.visualize_hierarchy()`, +we can use the `custom_labels=True` parameter that is found in most visualizations. 
+ +```python +fig = topic_model.visualize_barchart(custom_labels=True) +``` + +#### Optimize labels +The great advantage of passing custom labels to BERTopic is that when more accurate zero-shot models are released, +we can simply use those on top of BERTopic to further fine-tune the labeling. For example, let's say you +have a set of potential topic labels that you want to use instead of the ones generated by BERTopic. You could +use the [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli) model to find which user-defined +labels best represent the BERTopic-generated labels: + + +```python +from transformers import pipeline +classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + +# A selected topic representation +# sequence_to_classify = 'god jesus atheists atheism belief atheist believe exist beliefs existence' +sequence_to_classify = " ".join([word for word, _ in topic_model.get_topic(1)]) + +# Our set of potential topic labels +candidate_labels = ['cooking', 'dancing', 'religion'] +classifier(sequence_to_classify, candidate_labels) + +#{'labels': ['cooking', 'dancing', 'religion'], +# 'scores': [0.086, 0.063, 0.850], +# 'sequence': 'god jesus atheists atheism belief atheist believe exist beliefs existence'} +``` diff --git a/docs/index.md b/docs/index.md index fb99f872..9a30b2ff 100644 --- a/docs/index.md +++ b/docs/index.md @@ -96,6 +96,8 @@ For quick access to common functions, here is an overview of BERTopic's main met | Get all topic information| `.get_topic_info()` | | Get representative docs per topic | `.get_representative_docs()` | | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | +| Generate topic labels | `.generate_topic_labels()` | +| Set topic labels | `.set_topic_labels(my_custom_labels)` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | | Save model | `.save("my_model")` | From 
5fca843e895c7fe99e42f985b171c6267f11dbef Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 11:43:41 +0200 Subject: [PATCH 13/15] Add multi-modal example, fix sorting, add hugging face example --- bertopic/_bertopic.py | 2 +- docs/getting_started/embeddings/embeddings.md | 17 +++ .../tips_and_tricks/skateboarders.jpg | Bin 0 -> 29510 bytes .../tips_and_tricks/tips_and_tricks.md | 119 ++++++++++++++++++ 4 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 docs/getting_started/tips_and_tricks/skateboarders.jpg diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index d20c48d3..2b0d598e 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -901,7 +901,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame: """ check_is_fitted(self) - info = pd.DataFrame(self.topic_sizes.items(), columns=["Topic", "Count"]).sort_values("Topic", ascending=False) + info = pd.DataFrame(self.topic_sizes.items(), columns=["Topic", "Count"]).sort_values("Topic") info["Name"] = info.Topic.map(self.topic_names) if self.custom_labels is not None: diff --git a/docs/getting_started/embeddings/embeddings.md b/docs/getting_started/embeddings/embeddings.md index dd487772..62c489e3 100644 --- a/docs/getting_started/embeddings/embeddings.md +++ b/docs/getting_started/embeddings/embeddings.md @@ -24,6 +24,23 @@ topic_model = BERTopic(embedding_model=sentence_model) !!! tip "Tip!" This embedding back-end was put here first for a reason, sentence-transformers works amazing out-of-the-box! Playing around with different models can give you great results. Also, make sure to frequently visit [this](https://www.sbert.net/docs/pretrained_models.html) page as new models are often released. 
+### 🤗 Hugging Face Transformers +To use a Hugging Face transformers model, load in a pipeline and point +to any model found on their model hub (https://huggingface.co/models): + +```python +from bertopic.backend import HFTransformerBackend +from transformers.pipelines import pipeline + +hf_model = pipeline("feature-extraction", model="distilbert-base-cased") +embedding_model = HFTransformerBackend(hf_model) + +topic_model = BERTopic(embedding_model=embedding_model) +``` + +!!! tip "Tip!" + These transformers also work quite well using `sentence-transformers` which has a number of + optimization tricks that make using it a bit faster. ### **Flair** [Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that diff --git a/docs/getting_started/tips_and_tricks/skateboarders.jpg b/docs/getting_started/tips_and_tricks/skateboarders.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7144906a96988d306be3aac0b8aaf8e9c83cfde9 GIT binary patch literal 29510 zcmbTdcQ~7W_&@qY5PKE1X{lAzh}ITWdyk^1Dq2FRS*u8>QH|J2YQ%_Ht2K+F_SRUf zO>9-w+Vhvs_jj(}xvulqIp@jsPM#}I#&f^l_jtYT`+i^jxLO1lbTxG}0T2iT*b^VX z)e_*M5deP<00sttC;$Ky04azAAS0fEh(iU$`G3weLBas|f8Hk{UX%ie&j2HF7!t>S zhdpup_nrTF^?VBV@Rjg*=F10_kd~0TOZ=x9KpmihKqw&OR1_2x)KpZ|bZiWCw6t{G zEUb)d{JggWZt~u|A$a@V9YGPO$c>wlYEn>nMHLm5TM`d+)RnaFDXS>{_ah)GYHB)K zIxYqVE=A#+!ixW&kE=F-k%B}TlnMsj07w`?U`EhYC%{X*dotp@{5LiK*8w5{lai4` zC@86@i65wC07yV!FbOG`jEt0&_~{Vhb%2zSjOm7y8u>Mor;wXo%+kojTnhgC)$J^% z-~R~6ICzIsQn6lVW9PUfC?qT*3YEQkPflJz{egz2mbQ+rnfXHtORGoLj!w^>JG;PL zeSH01`3D3BMZAfOijH|3o0OcAn)W_D1Cy6ufGaF2E-C%=xu&+RzM=8Uw~o%P?w;Ph ze*D<@#N^bEpVLdrgq78`^^MJ~y}$qV4-SuxPfq`n3j~1ww^;v2vj2x%j6}IeNJ+t@ zkpJWYk@yn_n30t1h7>uInhE5o*R`9{NDAiriMiG7l>9QLe^?y6zf-XaK$mXq{U_T0 zNcR6uu<-w1lKmgS{%^Ua0U9ufxOiYjKo!`}lgsrw=f@;Pr4Bf=2^>^POpwcdnAp!M zVU9&nLRhr_6;Gd4Uz&p^9__HH-Aq%~47Qm2QvAeFSUYJoacFwKdnxwkw1NWy$XBjX za6nTKxZ#La--7}+j|ME+?eIV>%8$K(#1Vp)hR3zM>rT_ZTid*j{!yOF&hFFMT6p{gZ&O9q
z{Nu@o;!T^JKk3$YR2>^s-yc5|W%_-`jj@lwY00~{siKm9E^-C@NwYa%^>N&izYsp8 zZwS^#tB4?q2H zB`$EK4a?x8b#TFP26x?)!vngc)y8-lPbbOJr%yO{M#e*H6DS0{hWNy_1|87r7(k>3 z0WW~kxN8Z`zQpOV-I?Rn_2VNE#*(C^Q);;a`pfT?MnQV>NNPQOGORI34NJ*?mX{ek zk!QDT_T={{nK?T~|KC`wD3$oMx|=($c4e}9PJ6dXEMq!YD$#US{jzo;&(gZI z+$XUs>oGdvi+}WQb=l^fwJ05eN&k9bP;8U?(#f+L9b98L`G@rUZxPX2%2fs&l_TwIn&2>Pn>^`9lopbe!q-AA0lJm^twIJtZVeb zcVxfx?zEl}V`9b5^5ZVee|=pKT&+);m!;P_H~*0~7;l+ZImwgSFJ1-& zQxI@A{t*sO7-#Dgd1{1LuGu=PMJ@}>h$#W`FV;Rm=1)x>#nlJZ_bz=Ku=mXjZp6R~ z5ge}a@|vgS)v>LfA!Oow*Pu0kNh~0|qfdg~2O9HR0AXC=4Hqm)ERd?QBT{C=)Qu-9 zeJO$xr_zvz{j^$tn3m!3DbT)Ms-tteDRE2DJm}&L<7~T_YJ-d>SJ${m#rTx+_W@^j z&EvZH=M7QM-ePXNX!MjO^Ew&eEq{RVP5gLgumh5ya5`yG*eYA6y8h`3*wC}xnD_Mw6ibVbZ`564l(9HLMYg^iV4M)&7d3oNsRGdx(b<#s0T$9lD*Q_||Gy$W6Fd2+Z5 zo&ucG)FD3e`M$NkEKTz@hnB4Ba_uGFzVJbd?6t$66?6m~-uxT*qjDsv4E=+}e={_X zx02Nap8k!V@A4P#7w_Am?e(s8c`{j_(EEgQ$uRbr><$Ecs}1exmNNTokD7mgj*aVD zCP$~x+O7_&o^p^sdo9y|gb4$uQP+n8(dAk2>t6l0R_(+CndTp=Lb z0|gw?5X@B8q|9bO9zr1gw^`|AXSGN<9wAym^T~NSHoWt4rWyTrS=)ox32j3I8o?Pi zH)XF8cCIxkR&51amTIm0olVP6lS&bQ+s98$LrM2Bjdg7&gaB%mgK1YGgIi9~ z?2l9ZXmu7Ixy{_%;s-ItoVC&T925?H+`7)#Y)Ej5pvjdoFtf0P+*CBsJZ*n$@4#|$ z<30bjsr8osVSaebZHLzGZcLMFr;CXZz?gr7R+qhP?tzL2N5y#SAH9+Ait_ zP#Spgs%n#mqIngoilfnyhC%|ZvH=N#H4OtmozW?P$$bEa1-B@qhsOfex{IB_l2L3K+1-1n0C~0h0w+z#CPA-CKS8CtdD4qW;n=G2MSs zBculF<#x{6dp2!V@;%bNWKA?X{X(0TB`2drxeAV+bqT-NwY-)l$X0RjD*& zUUjIA#R>;Ma|lkh?ptaR84b%5ZeLFxQi1#Pr(UnQ0_fmjE`W99$m5S{;dcZd<<(Pd1vsq78k&ZvY4B?o^T&crE<~80 z9!IYQD67eTeBmrq5oF5ldgp+svI-qV7EX_21=BpkvZ zI2;j`hM*EA&M7e*MCK7sNzBhuifXNIMVD({2$R%=Z9m)TCC+JGFZL7anz{B+$9A&a z<8tstOV}0gc;qiLb8JAD*_SIIzRXY+6|1f=!*b#{8+gc4Psyq9yo70BI*@$DF%xXv z6|ItEO$tH~N)Co@*xN*bw_#M#@+&;>u=(vB7-fPS^BZ!Ip~5I&NUjFOI>qY(oqJSO zE{e{47OLv#@u)}dc&xg}5mf5hqqwvqoITu8}7u%c_)C~(-P z8YE08r%{qvKN2btgo!rgwcle+lQD@3&oSd5>uYBx(S@qzEMypUvngG|nd}v?+s<5G zvc1B!nG~WFp_g&Uou?^l#3X$M^hy}R2p3%cY~9ZWByiqGl&gbyrVY)znRHSHy1$?s zmOkjL5YfnOxwN?ilJ;R4ecC3Qa=ii5QV36YO>X0ajmtwAXrH;FE7#-pkoR@+yC@Z@ 
zQHvDQfa>oyeRGNJKOoGF>HGA@WzwOVS3pl#cV_i*C2D;$QDgOo^7q%3EDWB*AB{-A z#BUqg7%?R#LlM~p+9LD!M^$>3pO5U?hTLKv`E>ej7_XliDnz7GT94rhC}CGx8Hm7#%f?@IL=5~)ltw1-U!5HnQXjFF?xOg20h2J?1O$1JgJ z7_q!v*w#U1^QZTWj}?QYwBbO94y=1j;4E&x1jNU?u@^DS;CE_P3KK2_&A*2N8Ov#J zK)MQJY$o$aDiIe36BxA)j7c+8_f^p3I|r+*l6XAgcDIi}c6V&}@K7K)T%`uP=$zC4 zF%d$_)3|B(temJgL03Tazmdyu5oXjHo2>QfXP@V--#&2w+hQ{-wxFo;m_Xjij9WX3 zf&^t1Gi!ub`yv(aP@u>w2anD!T9`y!7vG*t*BK7z#X|1b9^#ZeRjmORdAlIzOmxla zit{8g{q#4-Bsii^{-ddnHs#V4V6IY_?E(^MDFMa|uun}I`l}9D#6X(k5BC96JV<^G zec=M~R2Eh;22B%fUU3WrSNzXJq{z58KrUEF(SLwiJEdn+ourxnDCxqbSnOBZ4Hqiu zIUSGrRHrBzbcd27HE8x&XO?xENO~yO}%P_#sIq{aAJ=e88vh{5(MRolE z5w*H`)$RmM;;BpWAH+Jw97#|+PJfwTbFtr&G1AyG#BTg$TXyCWo4rd2uaFMP|-?KS{}cN0fSm|u6E8j~== zHjD_mMV{--G<|#IO!IWwJkKEntH8sxglipxj|2`>hjU<*#0CEbuc6U}Z)Vqu0X)jW zL~2sb473Ou%5gsz`A{y9@;7lRb`U+v z#d4(nWdhWoncQ7F#|kI6`o+JmuYijbQx2J)Q19FPwpW1V^(Kzu-}kE1vy{JBI#>4M z1D-taclc$Qrfk&j)_86rubAM{$$5{EHL;;`z30bD4;;c9$(Dnp%jtBdktF-v@NWty zvZXUj>&Z=G&xk~aXb7{hfpLJWjv~S%i%#(jU)|2mergKz_O-2V>b9KOhdxgqYzbD; zSGocidseL;(+ZM~W_BH8`~ae#Bx(W_h{DW*O(jQot;zVnlf{54O6;0L+k>^r*4?g98sT?f*dAohvF7iW) zp&-d^NZjam2|u>78j?x4+=mfNP%)efR; zGQ{hTKIe>0`f1z#uaz2VI zppRZq?xdKw(uPLdB@lk@maRTO(tIi-Ao^1MG)j$hXzG%rCRT%G9p}8|NnIkI>p_Y( zANxV8$4hDTDQ@Qm?=GFH?s$Crv`G)s`ncp3K=s`Mv*olH$H`0eNTEi`)_~c0jUQP1 zV!&$NyZEukMtd5{DZzT3As9swpH@@>nyD0VQXGB%8$TAbhy09 z9^GV2V^h6w{YTG`roulCS_zyDew~-7J+i|4t#pRyF_t6f(Xah4?YTUs($%dJW5rdTX$& zbTA`eKFJHr*i*VWVK9FuyHxvuk_F$2GnDpP&t*jfOmfEnihvi6O4Y>0>XUMnONhdR ziDLX`rh-Hci=0Fv`c%f3@JK3QWq6D*7SeMR#^FWeNGzPAZ57dLp)D@O;@sEI1u*u6 zp=AEzA3oaig7|i@Z$Z$B?@5BURDKtShqTf!=d<*RYZh4&ejjGi|E&&4toH68RB@Dr ztkn+^r<5l2Cau3?nZ66_l~=i zKoI?2S$@~6xSpS@O1`tL1hc{De!e|hHd1SyVO-Kxb0aWyNI`k=@!{bOn4f)+G^(QufElN2IvyDcYxw9Wjp3 z80{D4T(iB4hyGNhU0Bk~l-HJ!Rg6=2NowCA=qIczUHaKTh&OMy@u2DoC|mT0F2fOK z|Jx~y8H6>=H57D9vvi8239WJ))oRNv7l4)UPU|>8A`sU*1Tfc}0x1nxk_0MpVdmfC zOhod7gCeZ{Qy#B!*3L+&4SP$Iw3kF6l?D4=QzH;}mvPI5Ln7Ts;RPQ$9BqC({<3RK zAF=3+@`pzBLKY{34uyp*|BOm#>bk3B?_YO$s 
zCp{DuuCmqA9(C@Iq3bMOd+)$gg*`14uA&D3>9<$F>)6oc22M?tlwSyVVw;?3!M@z4 z-db;q4R{mAcKT~XfHBI!b3I z1SQqZaJ;|bR}BkDOqc&tMCRuRDzpFbrk*(Le;`8F1?0FIcFj96=x@*s3m(EjLPDkmmw(0|YMU^T zh=~K_$Vmjqu#gx!sT7RCxg#kuQy4E6H-`s122Z;HOQYr8|KMQjeGO7VedvT!oCKXr zkMdC96%dnXSnGS{Iq$4`)LURG<)mcsqVq$8>Nd6j^6Ys*>AxIXjvwy;AM>|R-$8aZ zE&|u{n(BnRnW0d?cSOlpQ)?h`#2T=5%nVI&ObRw8k@~hiY78j*^d4-(XuD@u7!Pti zr4^}w!Kvz%2auyO$}f0F!FRtyp#&IXuT30IZ5Q))HF;>dECnpIFr9GYeclmJ=MTu5x(5&`yAZ(fkenSyAU0e@T}$Uvgha}axm0l;O9(Y6=xH3S%+;D?2P zG(}>pg@GNi|x$ofa3f|*?unbXM7$od!HRk`Jy z=Feo{i1~z=cSV@RJ3`@JMxiDbuP|&~um`R(-!H{9zc_B|`dFTG8R-{l&;O z`J+meJ@Njy=W7(xR^jzS)jjgyY&ER20%2&=&s#EZFe^+qUktE=<<7<}()K{f(9@kc zVlPjrCSt0lfo4*EBw7;(@Akn%*tu$*evJ?cK&OKxGXjNM83a?kU1*_hA`@`S#dYpR9@a(GZ-5`vGaLm<1Q z+OUmXu7FHCFyq6korSR${gb~7KYVjh$yKK&rXUSKTo$K&bnTp95JMs~>-2Wc|8-|6 z-49bOf(dn?-Qobh$Sck?0Xb8}@ZgI8EalEmeMK`wvm+a!R$=CymDYvKlT{cKhmkd4 z+78>zM)+e_-=_&HS7S&uS>n?wx*CPDi|_@erB8&o#z^EP5@A%{gYA}l1ff&2TY<}o zY6;T4RN>dzJjM~V6hC?9x)1^+L5VX^>|~6-ru?RXGA+q%jRd@NPEM>SwIb?n;684| z&}h{{8(b5VIh7WWI&;Znpw9bml>Xe8hBysA=w??L&k5n>{9bF8xb9+_9R+!}&AA((KAKiHWn6b5Nh^1qbpEl`Hd z9G@BkIy{*i*U?fZ#^jnUm${VNt)~`JTy*kS$NaO_IE4aojE85?g)rN;tyz z8*qPr3dBF<1pAK6L5g|u5VWZJ`wwlSZM9_5D%g9n7|-~83o zV(xpwY0l8p#~H8I1(;(9zxDzX-xiASNJbi3Wka4gS~`Zbk<(Pbma3ueW zOrFJ;B6f&`1xv0?-JfpygKxd6edyl6^DLjZYDs=oJVJ~F*6qTgeV<3;#x9IySst#O z)51c=-Nx3pE|h8ZxgObfP)>PGuDv_3QhbkI<<_{**w7X5Q{zTHG?m4|?91VgR}0#O zuj=Hdf&E?U#J>^Yw9*F3;)hKVa{RUMaIfSke0oHfK5|6DdEM@gJZ36*xjT(4VPz;~XQ90r^b$ql}0q!H`l>ns{tnAKYXqP6@x zVDdmk+&gyJylM*1jPE6jn6KcG+R~l*7Ch8z<0FGb3=8Gxt8x1Gf?w& zzwLe-N#`PH9KVd96O|W4-vyUU&-oS{!Ejfuo8p`Tt8ORzfS{!N1h2xGvmUIFuYSZBw>ZI!9ID*%;;k>%t9A8sfRc4M2c+D~>k z0nL}?C#wI(Ez^%Oh6DSBe{ptCQKB)M&1zw8Q_#=KN~xkwD^uJ0n>TuzKZ z^o~_`OO19yB+++rd&a-}-j)ih^~FO@U#KkHvsN$JYh%PR1zu?hWBy_>|U*h4TwawB48ITcLX)zG(` zKx_QEyO#|TVUUm%*TZSRC%_FEC8iCCV7q|tJwypKHlI%AsWwcZtX5!qb20lmY7~&3 z%_RFJTC1P%wiV33qb2sR@TDSRpYeUdWz_e>%ESqD0Ew3;uZ%u9}`>PH7c z4w88~N>*mt%V_nRiCna<$sf9SEoBRfbcLL%>KaYk>HlW@=^k9q!B(w(0`0D*jhM 
zi-z_vYD2-xvf5T0i?t67Yg&+{*IKPnTDg66#XM zig(PcpZtl3AFFQ1jOgQ6)dstGxW{xhrBtwcy04XsB2Wg~ljsLrgXG9U!ta5IS&Zjs z&3xo@`J?<)h|KjfYQy!F_q^8E)y81dpGnTN- z5Ioxa`ZI1z!_S{Jf&;j1zfym8a`>#Xvu&P=c>Ow}hxwNywu^{3rAzp|(`Z^!oc-&B z+1~MmHI4A5D%Rd$sOm^_MDd7l-mv*$%Fef;l2Q}(pQV`rL(W|>IzEQ z*(9>$Q43^VGeZ==C$M5x**&v;BeE?yR>N1}!;=Oq2z^k*YiMZXIGG`g$G4NU$g$)B z4~BIP3T`ag=cvPGMtRf*rr%V zw{ei`X3>feJ(whh916%F$PEn@#~#-$?xFs{C^J#${cmzYja(Rg$RV$;C`?@s_AM

oRsQ68=%+HC@(%h&=q8#fGv zNqSCk&mUNI%L{lL$P620hKJh<^UKddOnC^1o{N@V;m)N-omAp&4+R3)PBz&Ug zOCLM}x`QrLI(7I`2LK8(siP|8D}duc-Gk&X&ABFUk-jO3`#29AWrY9;LV{Vb1={pM zju;?T>Zo--ufW+`6T%%5BD{$lR(9>=vu}7#5W)oty`6fAlI1fKB1Av#LWE{xN4hwZfW(YPC(sw+NMvh1Uj8{^D`({l3OX$0gUouV$?tXnV4f7G zYYXTaVA)eOH=oGpecOiwz3I8G00p3Tk0Yu+CTvW%bJ(UF%++#P(@31(yT~}lNTE^x zQoQ^t#>}8ZZAjC}s~tg=t+XC|zLxY1rD}6y0cpo?;|+&n5_}}PahV9SNL4W zQdh-S<|ua)!NTGYXA-)FAP*5imq&RoNdL|MGzz*|)9ZbA{WogA5cfr7;8~*(>E0Oc zduEX6%?C)v`^6hTmxK$$^K^VE$)j-Oc(VSp;1JE>&3-j6E|em98CfxjHfUPM!P{Gz z2#iY~RIzaS<;KL((@yQ0n3~M-Nc}y_hAYxIJ&bus;AyPn-(!w_{j}u$I@1oAnE1&s z5vXgh8VfWgq087lC_SzH^?9hMjIf)GkeW5|{^&dD_)1#494dOfRhf(;!cmvsk)Xum zy3G`{-z7!YZZF@qpP!kt3Zsx)U-*j?3#|dZX1>C0_mCifq6PUO$SUb5xk~FlHAdBi z%hWdz`5jmv^PkX~T7C|=Fo{U(#9eX&&9dgCkLEnf31sb^boA_|qH#5=5!Pwb{A@xE zSX!QM3F8Pq903YY`2m7~g*wvR4}(7QK*}X=YX26P5vcr3n-#G?q(9KaE*v zL!&EnyTo?|#OV7gW)nhW)GQQw>3>J6wwTI1HECRnBGp0hc>OammT6Mj%hmwA+J#ih zU2bRwT)6cpX~iGw$r?7EX@wnGB#ZLz=%5*#L|3xYhea!}X9u_M99;p$1@D5n9V>NEi#?l!E8r`PHvk{={aB^yJc%>@sCUoPfDAYN1G#0%7^N9At=yH3}%mo_`wOv z#h@Q6e$1_QDF6++foMN6Q$qV#8u#Fv^V~m%FDIf}%fSL>ncqz%3yG!^c$e=Aa3O5! 
zw)z*!%rGO~B42p)H0(T@vM`kRI-13Pq@s8M2bG-jZxRzw*7UCAcW@HK1XB+mA#rpR zgb@aI42FOkFR0v9$VlYwaJY)!El$h%J3QyDRmRbderYO`{dJ!;@m@Luim}H!^D%u> zHa)MAu6@wm$dJQ1)@gtK;MWx)R$Lt4cg;~hQaL%;XyS8w%31goKS_-|(KidCxrQ7o zH;x1+qQjemJ0HEg?%EPPt#Eo?rjttBJ+oe1I`t;Ya10Cy=vfoFIp5>{g6h-JK{?b| zOSyC~F617M$!6AsCmlp-6d?Bw^3jwAVZq7`I0CVz-7y8m7R;)09bw6~{GCFfj?AFs zW!?kEu-0wcQ)6j!oGd>R#)Sk&csf<6H3wp5Odx zwW5D-j?CzItg6D+&e##3Tzu|I&XZDO>kw;e%vUf3rV@#vT+YZA21O!RsJ{jrVCkWg zg=dU^I^DtlP2#N=!DvT&4SMplyA3F3H%uh964R*|5rXM z%|NmebSYmd9gNnTr6&!67e;}v<(_^wmif}H6{IQT9chAInG7e}k{>A#IQ#+*XfhqS zeHhWRC~DW$25+Q@hPOeG@vG(;=QA@#Ph^UgMa@Kh`M=6&Td^nzk(sQKmaQ2Ke7TKj zY&(?0f)jrbX9Vp6fUDP%<86KB5j+6TMakU~m3`oi(|#|mO^6wkxXi`rQl0toqbyhu z8LOBRYem!=rMC3TV%kSTZfb5X zZ{-=R#`QNZ`j{Tdh6|+iQ}=)azk!nEy6jyOHCPMyqw|nWvttolsquP$GRdbp(yFst zWkO#VxK)>%3?Fe|X%ffB6cg1!#Q`eav58qfG)Ly19&;E8qhmUgvr`?I`JG-DRpJ6h z-a5IQ>y_**USQYbpPGrkL*A$T%Ve9x?BtB2L>5o47p1Fa-53Amf7pY7UEDU5K}Z&z zSVRfdrL)0ydW0q`n0&H2cz3Dn*?H&LM~)&TR_DD3a*iSa!9_ZMp^HP|-flf&)gMU( z1;mO-ztA6Z5EG)DW4znfF3)%gFq(f38-njD1I@#%O%%dLT*7vjiP!n04t_G0J4|nM z_WeX&K15TDw1AnN1>ZxXZruXz^Q*eT?|Ic1Dvo-_0^Bw@K88y99$DZqzF@0a+Ws^a zW@X0^A|gUVEajPw5*wUIg<|1K9QK-nUM7l*#$<|R#qN1BMxvb&Wl}v?Kt5Iorn}c`i@fTAyOPxPERaFs1JoJI*zlGbnpyA7I@c*dK8ym{GRpB09V2$G@A( z1fDAwvy2>O+S8rI$-cf<4F2cpA{{J+>+czi&Y_)q>Qf>ZTG-rbjp(RYxc`@uN?+~i z0#bK~g%ws1@;%es=T1h52g8dK#PS=nA6q-npW3{ekn9KuX3T)SI>LMXkHEtp!>cn- z=#^HMH>g`0$bIEa#PP89*!+f^LA@i}o5h+u@BC37 z9gB{9&$Wq%bSu1Qt(To7?RZPGksdeCl6vQ~ly zmqXwziFqAKd{|Kb6Ru9%oL_&nRf%;lS0;eeJq93Yr5L#^5dv6 z9ZenHkapVwSeUAA!Q005FHvP)#k0^UCRQQrul!LeT$86PHVQfHV{pu zA1=R=6pz?~`a&z=+wd*{xg?h0N9Kqp`38-uFc~2dmZfI(jd|E`SZ!MZ%PjNo5bkl>7Lw)ZBsM z#k{ZsGF%{-%$vz(cdy%Q`?W6l+)hvbsgB_rMb^^-IUYUE){=iNBv^ajDGlSkb8-b} z)+fQ}s6lP#u3eWArJMQ|hutng-o0KYM0Ty*?>NTw9ceB(Got&rWdroEF300w=9s>d zmh=kvF=cdCVwFy*>em#XyRH61D`>pqktBb7zH@v!(JGsNyYP?69UPRC0`_}$cb0GD zFmKuA`5;un^K7nzzUwDu$_i|D|M9$+-$|7?Y<Q$LFP=Y^;Ot$B61r!? 
z_D*31d%zMcD&qdTGjWtym*OjdW(auS)`NXLW5uFJ>f}}Zh0cn~>7QoT;k6jd_)8<& zB z75J97-inhO$w(Fc$&=0Zg#Jx<_P*sytnpi~+NC68r$^8D z7;~1Dy>=f}#!Q^w8fYbV!FNG^ijzwBes8)0aMvQ8fzt30O8E*$lkZ~gQ8jzbjh`AA z+myQ6NsoSqQEkZ$vHo$#ee0r8_}1=IfCi|tDH^pZvus(WfF=$*pXZqEc%EZKh_ZW_M`J3V18!aZ+`z413Y zEQY90U)yHA<#F5)u8A1KoE#QH(89oSoJQUH#SNrts-CF8wS4DQP8j!A>eqt__an~- z5zHSIH-cCyMuRdZKfW$*?<vMYm=T(Vc6`uc|PLW6>RUF`C}V$Ggh0W+SI_;g5~@@3Ox!f?`z$;48kVB+h^ zfNxP5{)>*lKYWqGT7J;C`K^Ol(2@DhlR^j!V{N~G)#$?jHs(pb#mOiqeeyu;3+@uc z%5%#(rkxjm+N;8JXYjRU$9~jLYHa(&cpOgN@_Uf^j3~opbM<`2*w0nDY%WI_&BlDL zJ$ir7?po6QG5D)ba%{6^nU+EdTDpa%!3$IpuUuZe369VaLOE8~oCzkn(Pj!2zXeu~ zPS5b`zpM-EU#tmCJo2FJ#J_HwT^)H+?lon-^4=$5PxCwFJ%F4}W=y%E49B_UZ1I>L zhL^oi`2DO2EpABZp=dsJNd$|zi_PeGXm9=2jTH|cz=^YmT#$T@%y|yR6 zpiSMa6-PeX+739v{bH)PW|OHLzxD{hkJu`qdv>|FE$gZeI*1PURb)-ESySuOUO_t> z6PiflgT% zGWk)h{bqBIRrVGU1{sG-v7>K?%cDFFy1g>P;z*ur;Ab_P(pWv#`F~H^iuITy(-~g3 z8N#LCe+%L+qK)saWdQeajR}m6Y4u-xai}Dg&OWG#dX2THa`_E`0V$L(cCSweTnn%h z##B=B(Z|H5m-X#94iy_Z*Tt1o{a;bPt}Z^< zwJ5JSE-xTuu=TxTyjBM_-m-Q(dCCiY;T`aBiF@L77t{S>|5UX;$GNxb%F4yz zl-H&|>rf=-BtLjhm1T9wQl0o{_k9vw7)Q+3>yTl%t(cB0CL70x9-dfnyyzO)eyozU znDsEWuh_1a}w0X`{rjmwzJ*M09hh?Dm?3Z|spYd$Gg>c9ZC>42r3noy-aL-&gsEg&>OV|GWz7u}mI`sD+$J^e z+z{dd7KwPZ1PU(4F(HyKAwHe%j=9PQU6HtPvspINy8>^C9ZHq#TdM1*ADeadl3`p-OGtp^&b&L)$@qQ~WoG|}LSKVtl_AtJfudi)BKWA`auJYJ5A zPQkLORjg+oholjrnu83w6ti!zuSJ|*C_Ic&vwKpuAP{tM{+#e!NQq^$ZCB$LYvX;r z1A(+w`5ss_X{etj+!1p*&JXGEhV^(|GC!95to8DF=iU^zn+es9D=0Z%~jXaEk!5voqnC|W}Nw<^e! 
z50$#)(VCq#5%H;2Qv1xO2@Ipk(wf2O@pLZL80hULo5WJR%S5_ApBh%vd5#auuHU`8 zHY~!bCa6gsNSLqkS8#l82EU1<>WTZY0>U@87@25OYo^147=OCva#owCCK?(9{dRud z=|T@HI|?bK4iIWd&0ucwd#+68i*w9Ag64nnjqyldSN=jZ)ExD|=j_`R5GDFiM^7-j zXfi=C=w3&Qt)AoElwT~%7b^h-?<1b$ipDY#4{N@JgVIlQwO~=*3>ICh4e;TFuA03#jwuB#~c4)1R~JVJV0}+})d< zjMA(~Z~P~(s{7Rl+qcsNP>()q3Nrolo6Mg@IR@)k-Ba&Mv0@Ak@q}y5_wbE^t$vY( zU+1sd3QZ}<3ZNiCaw0sv`YTnuVd8Ir;mvvOnTSI(Hb-#6(jJU zGqmHrJcL862lGB+*Glyz3$i<)ptX8Z<<5^S)Jdue?{e8?N7W>OO6%to9yQ%^7U;O5 zFqbD{FB6&0hqZ8mzhZTN{i=BCW&^)9<|Nq=Ps@Skax9vxZI(S${naoibCI`XOt~to zHXGQ(U+2|>dtNX-E?6#3Yj!&3Gek-ob6$SAmzOZ2HsTT1jW#k&b&4^h;L1F;L?0vT z19D8-_)d1FHbez%Y5D|4$UMnz>L;?Qk?{? zo0aYWipPa5KlgMj3PNlTcJSYj=E)3CZZ0pnUDX1f2L61o#{Rl@R)w2y_R}q*N zb7b2V_vvqFFEC|Q+D%>?;Yax{nr*>!)nSyjMQEsIb1eso!cJTBma$CA!MSQG6oTn3 z@XChhAC`!f56umFdQd#BZkRxh&a{)2o_M0d0M3_1UFrTPASy+*u6DU7!SBoVv zb*u{AFd{W3GgYeV_A!0xgks&Hn*DImNhnI%oSrxn7m`2^tzZgwrj;5jU@U#z+r(0^A0E z7;_V=wukAR!=7Eyf9Q(#inF^qyy`LIy`y^6Y-@Mc85HT4cvJm3G;>)?inRPVO5dLm zbM!VX0w3RgXcDWvr-T>K{<=www)?~sx~8AO%eoK8FQDb6+nJ7r-61yxbD+^qQndyF zT(rvaf&8t=e4y$na?3^C8UE6* zg!v5Vv!P&-;^R6fC#+It`O4YJwzg@?Jbg)2&ldz5*|zjQ{(#c_xq11~@dV$*gKX-Y zOEPSoeX>^sS~d;+6|uz}X05zfqvgJLgRLXp(DK~Hj##5_ zy*JL#7wsT+Q_5-ta-rd62ec%T_xeM$FYgy%Nr}7hXrYnSuo2ZgqwqDRfe>7$)7!7k zNb*NUujvSzw$GYs{3*n;ic~8YewHc0ZGCPIT9<<3<|!4gMxM*)GTV8 zzE}3`vqx@8dvU`v!NCuenyx?nCC59NX_4p5KL#Dc{7AC;SC>+~tN%PL0<06ZCL-Fa z`l??hT3N25aE(G@u&JWCef63`X93xI$;jNn+B@41KQjlLvL2w?-=Tib()wPKeQ`?L z8anQOd;OG{qvL6NMS8HS&sFDk<1LS8*gVu6wSOn4*xRUHO_x1-k1Ho=^gh#0TSDw! 
zCByoEeGfvan~kgX|RrBA4g7*SNse38rbL6^ve6 zTDXTxvxl^O&!y~YqH(S__Od0(;hm?L5Qz6g;gNQ3q}9;v zRgz5q<@y7+kRW{Tt^MyxOv1DR9&~C4H{8M+rU@jA%?z*o9^m{dq7z|l zcE`gbJ2Y}hSz}I=>ZU9Twy_+z~ByyRE`S;l^02CnZ!J})DB9e2#-p_9iF4p%y zlCJ-OrjPC&Eo>VSi(a$frN@9b31|Uz8(eG$|rlhjWn4!N*wpEHh%P_GT1TfKVTl z;MdMwLc-P9>CYCGE;OTwX(4O(0#3S?4*Z4y)9&x9=4F zF3R@(nQ?n~RuTW!ZF%wg{6mA+Y%(f*X%t1e_|Wh zm2xMvIks^Bx>+Iq_vH*}g}ix6X??kEk74jQ%lL~h-@ht*c?I<9=aOU!H4Itwb_WEz zoAP@SxmUmw+6z&kYUA0#WZW&paJl=iYJjMNnKIpk&g})BtZkj^BAeD3m@lRbl_R(( zKq9f%fKsdj&USfIPE{8oM9j`#@+pq~JVlQKvGV)NnStBHEqcCNaA|3CE}qMA(A(ue z<{5SN-Wx}usgdl%A#kar3hCK@dZr{Cx`4XLh?wUCCyYiM#KI8?<0@=g)(bAv`RJa% zz}0KA5-O(>aXOW4Yt`p2;rLS0Ib0$3`)#J5$`i9gTxKU-?~-_K>VTrdddhWfyisn_ zxh+DS2AM1YS-eofrPhnc8g#QEN1%v|Ypw*Y>{_U?Qblp>+3%WdDbCO(neEs)e!)T38qMc3e+l^f zyWXn(Y+R>fwNV*!$D#e*+}tlE3KD@t^aGUA*H=+(D$(T@6Z^rBO*aYE*NhJB1bqFd zDCPN_yaRFeE$KU|_gx^h4*&cp1-zMdtR3gY@5l{`ytJ>EMtW=XX!zgFpt zKEZ;gCKDWY9g|@lp2G?`x-gZHpF-0JW^Wmc3_AF16wsL?@)z_QlmkMlSs$p5tRDG* zT9+NsKz0dAKO0`V^xaAzR#ZV1{fSMja}UM5HsmXnCW~4p1mK0?A=vYh+VH$HhtiwR zg%05XUq>C)**A1l2cU_tCd0dxFXy_Y!#GBhojdFpUNYJ}7VVwPo0bx0qRHUoma;m; zxVd@HioK|YO1)S7vVPW@i~U~#1~K{0W&Y9r7XHZp?TZm&|3cSZ?Bh-4- zw_%+Xr^ApiZ#6Hf!k^UUx?L|q({xdrJ8&h&49<5y=iBhbX4ysJd8VF~AVl2MPALqn zFs!?k&S?RaWZPg>AaKkx_|bGjM8iF6t|lzNot6oC#xgjoNf#iIf!8%mNb7VxDI~Dd z(8S8uR_^84^gBpD<5XhX2{NpYYck65+|H-vM$R+0rtiX}W|p2B+TB`M*e$dGu$8&X ze8l6XdGxEA7U#5{;ma1D>PFn%NC#-o;I%T?7Uq%ZO&l7`v9bNr&+!^6MtUggk9tC& zbnDi%5S;P#uNd(W&xv&r{{T-7{VF6}4EYyt6t+BSw%Pvw0Q9L_l=eLe+niRMe2(7r zsfnwmAKkbky9<%_fjAtK(DPQM63&wdI&_CKA zod9C^o1H!{d2{sQx%kF8q{fL?sXnx|AR*?GfG4mKwc@Od7jPjM^jF8qPvb?GicXWM z9}ViZ?`|c}l|Ft^p2yUC)goEvJ|MDvFT{;!JYrdl%<&)qE2|%rU~~qcw|nWg#EWSP zu47{}gl~NDocq;?(bKFoF?24iZ|p8^VI|Tit~WpAK+oq~UZrn+r)pC9u)GZ;u~0Gw z3Fs)e(bZUZE*T)w*5*=@xqF!uf%7NJg&&tcTJBB;FaQ8x0oYb@Wg)qwCZJ*(L(+gE zW7mp_k6J@3OIh=ZKxA$5x2O9@^{Uqm@^M{JF=ig1Hz99&zUaeFa2Vu9NI!*B5<9EX zZ7rZ^;FTn|w_hq)+afcLg?oef*BulWGipzHA!S#TOz598R)%F^Prx%tpqTf z;QQB%Yy5bRQU3t1mLJNZMZl!vBi@^4GpXBuzx@qL+^4bV7NC%%X0@XLC>5!RtES`2 
zf302|jCK?Nqu12YocmBQyK|h+KsE&l-4717*kejoB;I`+etW1%GaSE1>Rs5OE%spSi%B#{un1F-t_qCkr@ zgj5t1hou4|BZ`Ea=8($rilaKx10G@iki$Ldq{ATBR5mhZHSiu*J?lz1%$qUEL5@iK2QkhinpfU zOKET-I54^UtTzQdqKB}SW9)=~MKILeXqu4Q|j zGFFK_6-HuHN#cPFEyY4lT02OuaV6~L2P}xkV+8srQ(L?=Ojn{X3zZ!d;C>Yn8X4Y$jQuJGHavje?@>r+WB&jNCY#vJ{{VjJ z4~cZ%Hu>i#{&~^Pg3N1pn@t_&KScik>(`a)nY-6-?ju*)^Rq5TIO2<`Y&D>%Y6AZN zt1Ljt#Z`p`zM%|~22dE0$F&I*-iKdnt6K#C?;v|tq&mf}GnYT`8h0Ybg^!3d`*mR6 zZcaKU{{UXEf5Jy$jy&VXK!f~f=0RpJ{3JF!aWVR+{{R{Xh^%G`yG;<)c7yc3p z6OJVx@Dctrf5Jmy$1?u_-y{5J=0dX<{t_KLe-yssX#W5jL*hL&e-yv>#T>|XVt>L* zrsM$AXZ^wcG^^r0KLnPLaos59LCVZMS5PdhAz9EVz|W!Y^{3oH7&0=yS37{? zcRfJwR@jmjCP~$VsM4nFd5`jw^ryy=#7cp|C*09+XxFiYC%V(MM^+zYXIXHh?%FcH z&-AT36C+xzl8ij8(h=#n=M|RB+3OO?rqDoBGwVPNZ%U1T9@(G@@>Myy(gP0V_DLo> zlfbGH&M;K+O?5|O85#18(;&wj{{Skt_IA9x)9m76l3ci0$nM?0TvK92+c++u@>ll@ zJ@6h&IXMKj+!2nsJl1sf^Sp#d8thSn`@lv=2d+ArLo{^An(tGvo^iO#a0p;H#dn1D zKDC^RLwnL^=BXIe=zg?P8Q-;qhJWMAFUp(t8xnc3htFC z_PyrF#9&Ozv~(Ca{*(`Ls_L>tLaHe~=Al{M;{M%aMFVjFeP|m;UukO@RRuu$3fG3x zB^&|onn?=gjg6JOvN4h%`+5&r)B7cu>17}N^G?FJpJlQqCdwY3l<{k32mE|e6k;Qs)j#qILbO|6*0`y5&JT6RHpFWIc|{{SA5AF9#zOC!(N;`jMzvJ0>lc1QkP z;p6?9Mzyt7!dXL~b>fwfU75P@hAlaKyJc6AhFA<^70;4DqeUOpilz16~KZD6X&wj@`dgT=cA2%(A--WbkkWI|}zS^o@EO6pqJGww+e+ z`H&EVmIUqOb?UX(?c0UgS8o`{N`{I|ZhFz0k%`SlNy(rFF*Or2dQb!sXO7%eIgiYF zqy}WLN#?Vz)(5CPt3LKD##u%W@~)^mqZw{xfx1w=4?NY|X!CUhizrF$qcXtTfD0o0 z$_{-<#W@KYky+lsdu;MZw=zQWPZUK64$4W#9P)Zn&tr3I=S+$&3n`9$(*68#oF2gU zsVJ-An78&00j z9vA>JYsPh0^R3^5#x7s^(qlzWEi!q=1vVTD`)$47nnP!EY1KwKt%$hoR6{g$#yh0i zImc?zAd*J|oPB5k2m9Q1rvN%1%7Kf5MJ5QK2Z4`zLyY}+`p^R70+62lF+dpJB5l@j z@Kl*PA8O}_!OctDXq|h3*i!bS0#jSJ){qEmC&XornH*@P-b;i;QPr660Q%EFbDIAE zhjpzEz+T%DW?Qq3OJ@L|u9ct4|wGua*_QfydNQ0HnxvN@|R<`jgJ)*;@MEA{{ zw66JP8Q^nT8Z^2tudTywkV_jfGYFW1p~pfEHh~xCg2G8{^@&$A*{i1FeqZkbeezG| zSY8|PHSOPtUMcrXm!|5IO+jHyBbs0}x>QUlFd>i(RXKm1l*|TJ8iYGCM&qAMW185}oxioGjQK@Hqm)@d z{p7w|Y*C)4b~vphD3j{5Nu*j?toE(;hk*IPa*C1@p55_HlI^ES^;I%PN-2&rChYv# zJ$dQ*)F*QJbh)Xc!m)Uk<`f}~y}IvU0Lw7;>FHe;txMcUZ+cJSsTkDc9{kZrqtWsM 
zccsrJtdLXl53P9axY*aV3=Tff_~MfqDEVOH)YL5DH5*`0Mgs%uNeu3+$SSATrEA1O zpzcR=RERXR!VHh=iqjZB<3PjdoYE1K)_@FP9(g>{;OC!806K65JM^Ff)Y2XXc|Ozu zk>Vv&ZwcF$-++AubEly6se6qRDH(5CT0jS-F!rDXHJ_|r!KYi>UtBKax4ecy>_Y>Z z08bfe=KftnO^!gJmoYd$#hi?OHOSa#r%}4{uP-J!8NvVs5PAfWG10W!RwP>eqvja% zSnxicw9Pn0rRoWH40|O+g zq*`JYl0ziSM(G!h0nX9vPBZkWPRhZ%xzXo7)Q6fy2SH6IG{9jr#;F1as*`u13{6;k zi`(bopq)j21t?$pq=EtcYNkw0tyjc0 zb2>)L9o%T4aKj+uy+XSEn3v{viWVTr3~hkK3=gGe6GL^$INJ{8R;Ep9VK`d8pKF-eGoQliT?A9x*b`h$dGy=ZVTEWb97MLMRcz=yOgT&N=m<0Y_g-QU-cZ0tQk@<8bOd zDM{ms0LbwO2F5;lFU50?amQ1hDqiD6h&gQ2sQ^<5=8yrbFBlB?Zc;zztbh6t6u{)I zV@3*?e2g*AMaN8fR)D!kc4eLRoac9Xb_oK`FVWGrwypZ@?=Y8p(S0*XL1(@DhuF#YObphQnqB{U4j)RTV2AKuMC z(Fo&jn;V8PTf<{1Ze+@yMba?jIoN*_S@+DZbt@S9Hz(9pOK^#D(j-_nq!3(0PVC5} zAEi;hiXT66#F44vVmT*1wFOu?o*=>IAP~eH00&WBY%oJ>0zq#Wjx-om z6)$iXJ?T2Iqyu6kq5lAQ;);=;kUF2|K;!CbLPHrj_pdPV1;p|Aj{Y=}!1F;MF(1Mx z8B|Ynspfg^6!y+dWZqm!E#xKLkI&rD+Y(1#WsRVAt!O#H-N-bCYH8;eBk5YhAEhx1 zz!>XDao-dSL5yHFew6NWfrI$a0+8EIFi&jJ$i{h~03wDi6+<>Mzzmu~J-&i~GPV9} z)431DbM8K1aZ>jh9G&Uj^&lr|Z2%#xe-s8f1lxM6p&$3qpl2roB0$Km=E%e6E&PQn zdvm#!I6pQ?tFVU~rQoHwRa2e^O0O)$Wt9=xhXvH+eL19}h zfe4>;Vv3QTk^XT@)YgQJ=s+jEd6$WyZ}>=c7Z@reT>dI313XRpuP7-aBPSJ?d-B~! 
zk=9Yo3=!B`u?hzOEKgzDm6b`uC?np0 z4gkV{KAcmQfE)-E@8Qlm&;lBQ;zEc3j4g|=suI7>yIN+S*)Gw-;uoq;7huw@|iqQ#{ zZKau_Ik$F+daz-P{y3;EH4AgK1DvlwX}wsH*m!ndJ4KC=g8~vsGzHoL1aa#?5y`kO-^uz@a7KBcL@&mA4|=IC^PCk2-heQ5 zB#|dKZxCcv3%hsSouF>ETM=S3vWRpCeoDM5f z6`0ukWOS%o?_!1mj-4o}8R-Cfan_TePaWXn%G+eXCvpn8#sKQEtzR znDfXDTi~9U>+4Jdo(EsTmz?uJ2z@ruNNvU<^yTwv?R0Ka~~NT^Nwn+A{Df|jl9Q2*u0Ie zlbY1i7)5YU2Pr#}qi!~u8w7OPJ1my-NSPqa-F>*NRT&kTL(j^&H3$snfO+qVObGm& zfJkr@54v{&nkiT=dP#AfhJfJq?Zs*!&U5M7u(d0Vx6&bJ+q4{j00+}F(JZeUz0a5T z(E-bFGT*IJ@r~Rt=rP9|6D+q1yKp$*o_{KDU`9oxj8kfIK&_isCJ-CHy+SO|wc<$6 zTrjQXIOigP0AMkbN*PJUM?D2F6oHRkDF9wM=71Ea>zZo1=kE?E0Zg%m0OXU?ig9)B zaqU1E8r-V(F_q+Qu5&**Ip?P}Jw}NlIcN0a6(vgntZ_jj9cTh3bpZUsj@hjJNtSyX zn^G8=VPW#GILBHR;~|^CQfbTP3(wphUwmVFraNOit-x4J zB!?r8znCZSt7Idvu~q*7Nrq{^Wz*)7vM!{L{m?5b-4+7EeT7Bbha;;D zkxmwu3b+>yp4nPf0S%qu!2414`<;JUU$f2|{a`-WqQF<4+IL-|;34(JRf|i!k3Z?g zOq`O#tvdj}Vbp;?=@AwG0BJw2YB$sMGQWBj`U+MVfj5VBPdShFd6*NO;&4BmM0`7~ zAN6*!DDECaKi0HFvaG%o*PcNkw+unZ8^v1w$kjjPwx4+K-)UTa9MfVXhf38|zjtSP zp3JI8^R3HW7gTr<$sB6jHVcqJ{EZ>#Z)vuVX9F$Ykb#)CRzgVp2o+(P1Gov73yg*& zWAvq9xjb`PAi_4|&>F5a8JBqZar`H5N(FP8@h$AeSP7N0hUxfGK__?pwc>I= z?63IKPO;*EM%^M#bDxlZQ(4|-mLYv_#f$#{Fn0qV+y|e|wAMG8(I&cGvbuvJ;}|{b zN_W1b$qm6o`A>eHmCbneQwEn{6Yn(a-8$ia`t_qGJDkO?wWx!DfPcViMlTTRstzFp z9=KY|Q%e^uyhp7DQ60a`56klB6|%bDi!K7PY4Sb~UoRib)V^S{CH~g&i8=num!Ckt z<4U@}i#zFahW`MrWBh3M$7q4od|jS2>zonOYJZI&)qG$v{{RxpgYFeM{b_sQWJ&$4 z;`hh=O6E{Kk$?7T1^)nxoBm;LJ@TLYRX=_CX2{dX9R_4i@u__O02ttpvZ~{* zWJCREw8KK2@8T6?eOBB@3FB~d#^dY`YdGHTXI_{*(W=}<%NcyrDy}|KGtaeK*vG5m z*i#U`Qg9iV((QHv;Z6SpH3 z%|+l#V4zQTb8&%#k`y23N-}q4G>ex04AP?8y4_r_05L2DXD=eCp=ZJEDqgjQm|Y$3PBD}LMnh$5cH$TQ`pi0JcRz}{HZ*p zAHo2p0u)Y1$|(w<{_OxlsQhV?*wZv)s}3m);4tTC^<#lf!nnhF0deLi zWjG^~il~~6%z?KqFmZ#$A&iZT8)R}=xZKLenpoy(g5^LVLjF_%I%hp;=72eG7D=*e zaqQpHiUfOZUp#ak)SINwOo~<=#F3bB#&J~C9jDD0=B^@^#4IC?S&8G`xbGjzuQrzm zJr&&kJ88X%MqPvvJ`0dR^sb*q)9og^jyYr8r#yZY1>EZ*w~E<6aRM&Uz(pA~Hc2_? 
zao&p#<62U!jl=2Qm(6p}AC&+p-_(#jc%``C=pZQV z)PLJ${{R{^3y=FEo<-fM{{Xh@{xqv#%y2F(kJ4r5^`I_DG<$>hOh-PFDgLw_7Wkec zU;FG)V7TAxOaAeX`{-0prQN4k*a1LXh}y_KDXjx9QURa{-Yu)SrHvQW+%NN}WR23> z?FHdzj|##?ApJR}Cd|jY^9}}h;|jSoT1n+laTsj&bJCsAhQT5Ufw*uZAcK#mAFV;> zMA;}E#xs-NmBx^$?pQHVv}0=H0RI4uKW`g%ZoxRk0~OJWo&9oigT`t|hS8BgiIb@B zPAAfUIS&_{+}rcnzoit>C%T!S2mN56Pi{^;(MX_oiLw!twm2EADPLNfa>u8L0Bi*buCpQnUOAW^P=LO^mCX323y0L8dzexsSi=4M^Yx~|31JX83RL=YU7n4m zwY14K#EFsjQCmEpT7u2NG(+z?;161OnghsnKA1GbG|d3$=sWeK6GU4+ViJAm7#ZPF z{;>z+N#?*m<)VIc0INN>{Pd3_jAo=47Ubj1L-^1k!dOo`)MwPw8c0Fq)E`XH154Yn z=4im__o?DYIte5E=m0Gc{{WVW&rWH(q5x(@=cher0Su7!^B0eLUotU+<|y6MngCtO zbMvVDX?iIkz{gSzFc|n%2Y22aVCOW35{g?T100Zh{b>Q4f1~M>C!Ki|NPp+3`A`1< zUa)7~gGN^zhLN5yUamk;}=`cXkAweI@;DcK}{_2?*uKo9|1&tcH@7u}K<69dppLV15*U zgubTPA{d4nBTK6xaHH(_%sDszRw zKTc>79AMsA?(?2<09MGx77tn@ zwK&1Z?@u3abJ~$abDYBbWNkV3;=H-m=2cNAg42+RT0h1+$c+8pO&E|fMEE!yV-;^_?${3b{&gb-&Nr2sHv9XnKKl5x_2A&`x}^;jBc zL~qlcX@K%QbJlZ~=x6{%xghRn$<6}~gV&E*oiYPtmO1{EfQ-%a#x^)s1eb0AH7?~0PJZa;27oDF zL@q-kN52FgGrICRejc?(^G=ckE~9eLGLYrIVaIHGQ<5cc$~t=W4mqZ_NYAwZL^>SzrWT+I@*llb*Y&1^GDC`uPyU5W36c5!RM{Aj z_|OMAt&P*ipF>449@27qbfg5IZaY#L9G@(0{^|ZzzW2B>+^F>Y>|P*F%L}R z1k?Q~fZ$GgnqwY0KTatD5F4@&OjCdd1D-wT0UUrcf-rq)n@;V=gX!-;(EYek{VA){ zCyzlugKjw>bIHv%1lZ)Za!*s=(tsq&jQaJajO@y&$sKXpm=6h*JCqZSK9t4^2`3%C zf{+TFsA8mYNXhr0E=e3`Z2(5=_XgbYkEJ}p7zcxs*BoYmAh`_+B0#*z!9d3V`k!o6 zjI7!H<_0iEGnzvv-Dwd@cSj>5;0y!7{3|v~d94`ygCt`=;XoAh=$UTrBw@F5ugjn3 z?^l~0X#o=TJkwbGhJYXkAfIZFATd7l0St%leQL37y!WPrMJsYDH90x=rUi_E0nb{k zcYh7I+Fvc`4>T|^E+D+rB{H(?LHBTIr)DGEPXN<8d;KXfi1{(M9PK~KyoOP-cguep z53e*KkMd+F9kI0Z=hCY`{C~iDaCo2%>qa|&$stuYu%rW7M1hMXn;x*yolo z?RHB6%CAqtm+Y~IMr1;Ni#<(RLb&NKBZP%9udqVRkFEt2!1Z?gYAlicYc^69xk5d^ zrFk9C7gHP_ByZ+uLmV6TStsd(P~0gOR>Y|cNzE}Ec7q^Rv}b`*2V9(@pHWWzXc(uD zr7lmVDTo5#9=x7&NsJF_09DC7GfGJvF`r!10jq#B-2Sv4nf$2$(T`ku@k#)0K%fME zaDSCG2*K})09QHb*S2XQ+7B7`B=w*NtMb{#PfuY@4T*N`B>K<;xY%%b9d~x1Wsfa^ z>E3_@#@yg@$))Op1Fw1j3X*^iUO;i!8enF-+-etMGaRV_1nw$8KaF3F 
zRx}EM4i^iB&flr4O2QHDW`$44haBN@Gx}3Z;l^i?PJe}CYysPxb*3WBdv*=H4`d>W zV0wGF9C43YXeEysz%_|LcDrFoBKrSgc89hBIfMik(&7V)?(@gefERHfkTM6gG=Hr* zJOO|x0lx{gG$ zjhKIb7$2QQYiN;4x@&kq1MhC!e;T-#Ds>AUvGr-)&K*i`YnowKPfEYs)!S74_BzEsW00BS)9XO=H>zV*Dp8V2+ zFiAKa=mD#ScJ%Bh8Qs_dXaQWWPL$3Le~X+@0*r0+0-B`YU;sTR0o#cLsLvltYh!U3 z3~|WK0682ozsxg^qNLnH^v(qUARxzr-&$kxI3#z?06XOheYwe?9ocLTyTt%8h}f)g zg#_n4DT?19Cs5l?_Nm_Z;4dbr2_?OqWtgI(pY>aZ zKT%C?476GqBiODY=d4YNkVyz3eB@@wW*If1B>=aVQL<0vB9oje{{V$vX%U9kk9h<8 z+yV5Y1ygnA1`x-$uw&Eks1+nqP7Y5b(qkREZ$N&uPzSCuVnzV-PQVGt>C&KN5@hZR z&#o)YH7j2-`bOP@y}z9_G4VvI6dVi=!>v)&zDcdrWA7Z*GEF-tA1^!#+=%W)=N{D{ zblPbxWS?ABpcF4XMFSN5d89l7b3hF{Msb|ecK~@GS^zg^A1^cpB+vp5<#J6T2HE?@ z1fTAV3IK5HkU6J%{{V#mM#Phh{c}y+7dhxU&;z$+06qCM)jV<1fEVv_e-CPjq}ws! zNdkZvm`LPffOFQY6ND;yk<+aJIJgUp6HP333@;t10P`5W+!6?c(Jtie(bT5_Uo#vb (=jKsf_(#c5<0B}m9zd4z-2OBVfUjc&WLQCN=T zWd3x-21Qh33|sd90B}`lKHWNjBKb+zVBEd|ffuy+koqqR~(4Rp_+LR=!`D{ZDHm++hD(ae;w1K|Mr`+9K5-<@m zpnB%G`|Am#-88$G&2qzV$5WhBKtUaadKhtSlFRB0tax1sEnI>k28xpslro@f~b%!z69+|L|cNZi=Z`OOO>6&+6Q zlmRB$4ggYf=}YCy{lm{3nqWw_?nVpXk=Rls2j)I%s~)6|XcfwOpV};Kns^#HT=1yJ zZsYZ&)8~@n6XvvOA`hKI?gR0u_pjsxo@i!pvMC5M#CE1bHO0-axQaGJBjxixQhD^M z5o5@?INul<=vdPXQbn6$=P-ee2@k0S&3T&`FA@m5)Gig|()$3g2r%9odN+9{IiGYo3OE=L@n=Sdcz z%#E>RnX`kjI-ji|CyI8JM9pO=+Dd}ygG%>n7Ch9IGI%9(Ob&Nix{3`_StMwqv=+QHUUYImmkC{R6sOWjI3g2 z54YqBJJX=DkwGP2dFok-KaBtp11?%FR2*ebdPKLn`(6khPTmGHlRy%xxgkVlNBbjz z?N2EoObxjsgN%NZ0O&2tz!v2rxR;|bUZb}adecR^m6d+WYxw2;Q^qad<_iUx~+^4Z5>+Ow{impfRF z2Se^@2z`Z(rRJqEy0wOKnA_&Y8yt4cX%$o;B832S0~8Fy&?S;xe_cr=XIR=TftZlI zH&05{AJo$jWR2M3xv%w}D7Z)YmwIq{A98^kS9Z47Qbv~+@l3185+Gm?<5MgVK?+3_ zNgQeb2#Nq7g*C9z4cG4OkyGVU5jY$S`qLPB(WsaAgDdyBqy&j0o1MGc0Zs;e>N(xS z@PNbu-1FCp0LFjc?DM~Q{*;I3=lH9nazNvbC;}Z!2=v!z#>O#{O7L?}$=TPipa#1z zDgihro@$1$OKDRp9D)WxpbM4?qQWdF8BFjy51k(*1k|ioGG2v+L%Y$!Iyn8J=ZsR$1#M*)0lQBoYY&KJ-!p|Ji>> topic_model.get_topic(2) +[('skateboard', 0.09592033177340711), + ('skateboarder', 0.07792520092546491), + ('trick', 0.07481578896400298), + ('ramp', 
0.056952605147927216), + ('skate', 0.03745127816149923), + ('perform', 0.036546213623432654), + ('bicycle', 0.03453483070441857), + ('bike', 0.033233021253898994), + ('jump', 0.026709362981948037), + ('air', 0.025422798170830936)] +``` + +Based on the above output, we can take an image to see if the representation makes sense: + +```python +image = captions.loc[captions.Topic == 2, "img_id"].values.tolist()[0] +Image.open(image) +``` + +![](skateboarders.jpg) \ No newline at end of file From 8b5084bf021f69e8495fa1ec19b1b76adb26de38 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 11:53:52 +0200 Subject: [PATCH 14/15] Remove self.topic_sim_matrix as it was not used anywhere --- bertopic/_bertopic.py | 3 --- bertopic/plotting/_hierarchical_documents.py | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 2b0d598e..bdc771d8 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -184,7 +184,6 @@ def __init__(self, self.merged_topics = None self.custom_labels = None self.topic_embeddings = None - self.topic_sim_matrix = None self.representative_docs = None self._outliers = 1 @@ -2244,8 +2243,6 @@ def _c_tf_idf(self, documents_per_topic: pd.DataFrame, fit: bool = True) -> Tupl c_tf_idf = self.transformer.transform(X) - self.topic_sim_matrix = cosine_similarity(c_tf_idf) - return c_tf_idf, words def _update_topic_size(self, documents: pd.DataFrame): diff --git a/bertopic/plotting/_hierarchical_documents.py b/bertopic/plotting/_hierarchical_documents.py index c8162925..3fc55eb8 100644 --- a/bertopic/plotting/_hierarchical_documents.py +++ b/bertopic/plotting/_hierarchical_documents.py @@ -94,6 +94,10 @@ def visualize_hierarchical_documents(topic_model, fig.write_html("path/to/file.html") ``` + NOTE: + This visualization was inspired by the scatter plot representation of Doc2Map: + https://github.com/louisgeisler/Doc2Map + """ From 08129c8e881eb3f7d1ec2bff814530fa6f852948 Mon Sep 
17 00:00:00 2001 From: MaartenGr Date: Mon, 4 Jul 2022 13:34:37 +0200 Subject: [PATCH 15/15] Added .merge_topics() to manually merge topics --- README.md | 1 + bertopic/_bertopic.py | 226 +++++++++++------- .../hierarchicaltopics/hierarchicaltopics.md | 23 ++ .../topicreduction/topicreduction.md | 20 +- docs/index.md | 1 + tests/test_bertopic.py | 8 +- 6 files changed, 189 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index 56c15a29..c7686ea5 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,7 @@ For quick access to common functions, here is an overview of BERTopic's main met | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | | Generate topic labels | `.generate_topic_labels()` | | Set topic labels | `.set_topic_labels(my_custom_labels)` | +| Merge topics | `.merge_topics(docs, topics, topics_to_merge)` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | | Save model | `.save("my_model")` | diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index bdc771d8..7970d657 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -16,7 +16,7 @@ from tqdm import tqdm from scipy.sparse import csr_matrix from scipy.cluster import hierarchy as sch -from typing import List, Tuple, Union, Mapping, Any, Callable +from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable # Models import hdbscan @@ -618,20 +618,20 @@ def topics_per_class(self, return topics_per_class - def hierarchical_topics(self, - docs: List[int], + def hierarchical_topics(self, + docs: List[int], topics: List[int], linkage_function: Callable[[csr_matrix], np.ndarray] = None, distance_function: Callable[[csr_matrix], csr_matrix] = None) -> pd.DataFrame: """ Create a hierarchy of topics - + To create this hierarchy, BERTopic needs to be already fitted once. 
Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF - representation using `scipy.cluster.hierarchy.linkage`. + representation using `scipy.cluster.hierarchy.linkage`. - Based on that hierarchy, we calculate the topic representation at each - merged step. This is a local representation, as we only assume that the - chosen step is merged and not all others which typically improves the + Based on that hierarchy, we calculate the topic representation at each + merged step. This is a local representation, as we only assume that the + chosen step is merged and not all others which typically improves the topic representation. Arguments: @@ -670,10 +670,10 @@ def hierarchical_topics(self, """ if distance_function is None: distance_function = lambda x: 1 - cosine_similarity(x) - + if linkage_function is None: linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) - + # Calculate linkage embeddings = self.c_tf_idf[self._outliers:] X = distance_function(embeddings) @@ -681,8 +681,8 @@ def hierarchical_topics(self, # Calculate basic bag-of-words to be iteratively merged later documents = pd.DataFrame({"Document": docs, - "ID": range(len(docs)), - "Topic": topics}) + "ID": range(len(docs)), + "Topic": topics}) documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join}) documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :] documents = self._preprocess_text(documents_per_topic.Document.values) @@ -690,17 +690,17 @@ def hierarchical_topics(self, bow = self.vectorizer_model.transform(documents) # Extract clusters - hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics", + hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics", "Child_Left_ID", "Child_Left_Name", "Child_Right_ID", "Child_Right_Name"]) for index in tqdm(range(len(Z))): - + # Find clustered documents clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers 
cluster_df = pd.DataFrame({"Topic": range(len(clusters)), "Cluster": clusters}) - cluster_df = cluster_df.groupby("Cluster").agg({'Topic':lambda x: list(x)}).reset_index() + cluster_df = cluster_df.groupby("Cluster").agg({'Topic': lambda x: list(x)}).reset_index() nr_clusters = len(clusters) - + # Extract first topic we find to get the set of topics in a merged topic topic = None val = Z[index][0] @@ -709,38 +709,38 @@ def hierarchical_topics(self, topic = int(val) else: val = Z[int(val - len(clusters))][0] - clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]] + clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]] # Group bow per cluster, calculate c-TF-IDF and extract words grouped = csr_matrix(bow[clustered_topics].sum(axis=0)) c_tf_idf = self.transformer.transform(grouped) words_per_topic = self._extract_words_per_topic(words, c_tf_idf, labels=[0]) - + # Extract parent's name and ID parent_id = index + len(clusters) parent_name = "_".join([x[0] for x in words_per_topic[0]][:5]) - + # Extract child's name and ID Z_id = Z[index][0] child_left_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters - + if Z_id - nr_clusters < 0: child_left_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5]) else: child_left_name = hier_topics.iloc[int(child_left_id)].Parent_Name - + # Extract child's name and ID Z_id = Z[index][1] child_right_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters - + if Z_id - nr_clusters < 0: child_right_name = "_".join([x[0] for x in self.get_topic(Z_id)][:5]) else: child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name - + # Save results - hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name, - clustered_topics, + hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name, + clustered_topics, int(Z[index][0]), child_left_name, int(Z[index][1]), child_right_name] @@ -975,23 +975,23 @@ def get_representative_docs(self, topic: int = None) -> List[str]: 
return self.representative_docs @staticmethod - def get_topic_tree(hier_topics: pd.DataFrame, - max_distance: float = None, + def get_topic_tree(hier_topics: pd.DataFrame, + max_distance: float = None, tight_layout: bool = False) -> str: """ Extract the topic tree such that it can be printed Arguments: hier_topics: A dataframe containing the structure of the topic tree. This is the output of `topic_model.hierachical_topics()` - max_distance: The maximum distance between two topics. This value is - based on the Distance column in `hier_topics`. - tight_layout: Whether to use a tight layout (narrow width) for + max_distance: The maximum distance between two topics. This value is + based on the Distance column in `hier_topics`. + tight_layout: Whether to use a tight layout (narrow width) for easier readability if you have hundreds of topics. Returns: A tree that has the following structure when printed: . - . + . └─health_medical_disease_patients_hiv ├─patients_medical_disease_candida_health │ ├─■──candida_yeast_infection_gonorrhea_infections ── Topic: 48 @@ -1017,29 +1017,29 @@ def get_topic_tree(hier_topics: pd.DataFrame, print(tree) ``` """ - width = 1 if tight_layout else 4 + width = 1 if tight_layout else 4 if max_distance is None: max_distance = hier_topics.Distance.max() + 1 max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1 - + # Extract mapping from ID to name topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name)) topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name))) topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()} - + # Create tree tree = {str(row[1].Parent_ID): [str(row[1].Child_Left_ID), str(row[1].Child_Right_ID)] for row in hier_topics.iterrows()} def get_tree(start, tree): """ Based on: https://stackoverflow.com/a/51920869/10532563 """ - + def _tree(to_print, start, parent, tree, grandpa=None, indent=""): # Get distance between merged topics - 
distance = hier_topics.loc[(hier_topics.Child_Left_ID == parent) | - (hier_topics.Child_Right_ID == parent), "Distance"] + distance = hier_topics.loc[(hier_topics.Child_Left_ID == parent) | + (hier_topics.Child_Right_ID == parent), "Distance"] distance = distance.values[0] if len(distance) > 0 else 10 if parent != start: @@ -1047,71 +1047,71 @@ def _tree(to_print, start, parent, tree, grandpa=None, indent=""): to_print += topic_to_name[parent] else: if int(parent) <= max_original_topic: - + # Do not append topic ID if they are not merged if distance < max_distance: - to_print += "■──" + topic_to_name[parent] + f" ── Topic: {parent}" + "\n" + to_print += "■──" + topic_to_name[parent] + f" ── Topic: {parent}" + "\n" else: to_print += "O \n" else: - to_print += topic_to_name[parent] + "\n" - + to_print += topic_to_name[parent] + "\n" + if parent not in tree: return to_print for child in tree[parent][:-1]: - to_print += indent + "├" + "─" + to_print += indent + "├" + "─" to_print = _tree(to_print, start, child, tree, parent, indent + "│" + " " * width) - + child = tree[parent][-1] - to_print += indent + "└" + "─" + to_print += indent + "└" + "─" to_print = _tree(to_print, start, child, tree, parent, indent + " " * (width+1)) - + return to_print to_print = "." + "\n" to_print = _tree(to_print, start, start, tree) return to_print - + start = str(hier_topics.Parent_ID.astype(int).max()) return get_tree(start, tree) def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> None: """ Set custom topic labels in your fitted BERTopic model - + Arguments: topic_labels: If a list of topic labels, it should contain the same number - of labels as there are topics. This must be ordered - from the topic with the lowest ID to the highest ID, + of labels as there are topics. This must be ordered + from the topic with the lowest ID to the highest ID, including topic -1 if it exists. 
- If a dictionary of `topic ID`: `topic_label`, it can have - any number of topics as it will only map the topics found - in the dictionary. - + If a dictionary of `topic ID`: `topic_label`, it can have + any number of topics as it will only map the topics found + in the dictionary. + Usage: - + First, we define our topic labels with `.get_topic_labels` in which we can customize our topic labels: - + ```python topic_labels = topic_model.get_topic_labels(nr_words=2, topic_prefix=True, word_length=10, separator=", ") ``` - - Then, we pass these `topic_labels` to our topic model which + + Then, we pass these `topic_labels` to our topic model which can be accessed at any time with `.custom_labels`: - + ```python topic_model.set_topic_labels(topic_labels) topic_model.custom_labels ``` - + You might want to change only a few topic labels instead of all of them. - To do so, you can pass a dictionary where the keys are the topic IDs and + To do so, you can pass a dictionary where the keys are the topic IDs and its keys the topic labels: - + ```python topic_model.set_topic_labels({0: "Space", 1: "Sports", 2: "Medicine"}) topic_model.custom_labels @@ -1120,31 +1120,31 @@ def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> unique_topics = sorted(set(self._map_predictions(self.hdbscan_model.labels_))) if isinstance(topic_labels, dict): - if self.custom_labels is not None: - original_labels = {topic: label for topic, label in zip(unique_topics, self.custom_labels)} - else: - info = self.get_topic_info() - original_labels = dict(zip(info.Topic, info.Name)) - custom_labels = [topic_labels.get(topic) if topic_labels.get(topic) else original_labels[topic] for topic in unique_topics] + if self.custom_labels is not None: + original_labels = {topic: label for topic, label in zip(unique_topics, self.custom_labels)} + else: + info = self.get_topic_info() + original_labels = dict(zip(info.Topic, info.Name)) + custom_labels = [topic_labels.get(topic) if 
topic_labels.get(topic) else original_labels[topic] for topic in unique_topics] elif isinstance(topic_labels, list): if len(topic_labels) == len(unique_topics): custom_labels = topic_labels else: raise ValueError("Make sure that `topic_labels` contains the same number " - "of labels as that there are topics.") + "of labels as that there are topics.") self.custom_labels = custom_labels - def generate_topic_labels(self, - nr_words: int = 3, - topic_prefix: bool = True, - word_length: int = None, + def generate_topic_labels(self, + nr_words: int = 3, + topic_prefix: bool = True, + word_length: int = None, separator: str = "_") -> List[str]: """ Get labels for each topic in a user-defined format Arguments: - original_labels: + original_labels: nr_words: Top `n` words per topic to use topic_prefix: Whether to use the topic ID as a prefix. If set to True, the topic ID will be separated @@ -1161,11 +1161,11 @@ def generate_topic_labels(self, topic_labels: A list of topic labels sorted from the lowest topic ID to the highest. If the topic model was trained using HDBSCAN, the lowest topic ID is -1, otherwise it is 0. - + Usage: - + To create our custom topic labels, usage is rather straightforward: - + ```python topic_labels = topic_model.get_topic_labels(nr_words=2, separator=", ") ``` @@ -1186,9 +1186,61 @@ def generate_topic_labels(self, topic_label = separator.join(words) topic_labels.append(topic_label) - + return topic_labels + def merge_topics(self, + docs: List[str], + topics: List[int], + topics_to_merge: List[Union[Iterable[int], int]]) -> None: + """ + Arguments: + docs: The documents you used when calling either `fit` or `fit_transform` + topics: The topics that were returned when calling either `fit` or `fit_transform` + topics_to_merge: Either a list of topics or a list of list of topics + to merge. For example: + [1, 2, 3] will merge topics 1, 2 and 3 + [[1, 2], [3, 4]] will merge topics 1 and 2, and + separately merge topics 3 and 4. 
+ + Usage: + + If you want to merge topics 1, 2, and 3: + + ```python + topics_to_merge = [1, 2, 3] + topic_model.merge_topics(docs, topics, topics_to_merge) + ``` + + or if you want to merge topics 1 and 2, and separately + merge topics 3 and 4: + + ```python + topics_to_merge = [[1, 2], + [3, 4]] + topic_model.merge_topics(docs, topics, topics_to_merge) + ``` + """ + check_is_fitted(self) + documents = pd.DataFrame({"Document": docs, "Topic": topics}) + + mapping = {topic: topic for topic in set(topics)} + if isinstance(topics_to_merge[0], int): + for topic in sorted(topics_to_merge): + mapping[topic] = topics_to_merge[0] + elif isinstance(topics_to_merge[0], Iterable): + for topic_group in sorted(topics_to_merge): + for topic in topic_group: + mapping[topic] = topic_group[0] + else: + raise ValueError("Make sure that `topics_to_merge` is either" + "a list of topics or a list of list of topics.") + + documents.Topic = documents.Topic.map(mapping) + documents = self._sort_mappings_by_frequency(documents) + self._extract_topics(documents) + self._update_topic_size(documents) + def reduce_topics(self, docs: List[str], topics: List[int], @@ -1313,7 +1365,7 @@ def visualize_documents(self, hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over specific points. Helps to speed up generation of visualization. - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1414,9 +1466,9 @@ def visualize_hierarchical_documents(self, have a distance less or equal to the maximum distance of the selected list of distances. NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to the length of `hierarchical_topics`.
- custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. - NOTE: Custom labels are only generated for the original + NOTE: Custom labels are only generated for the original un-merged topics. width: The width of the figure. height: The height of the figure. @@ -1499,7 +1551,7 @@ def visualize_term_rank(self, topics: A selection of topics to visualize. These will be colored red where all others will be colored black. log_scale: Whether to represent the ranking on a log scale - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1555,7 +1607,7 @@ def visualize_topics_over_time(self, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1605,7 +1657,7 @@ def visualize_topics_per_class(self, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
@@ -1651,7 +1703,7 @@ def visualize_distribution(self, probabilities: An array of probability scores min_probability: The minimum probability score to visualize. All others are ignored. - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. @@ -1703,9 +1755,9 @@ def visualize_hierarchy(self, Either 'left' or 'bottom' topics: A selection of topics to visualize top_n_topics: Only select the top n most frequent topics - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. - NOTE: Custom labels are only generated for the original + NOTE: Custom labels are only generated for the original un-merged topics. width: The width of the figure. Only works if orientation is set to 'left' height: The height of the figure. Only works if orientation is set to 'bottom' @@ -1788,7 +1840,7 @@ def visualize_heatmap(self, top_n_topics: Only select the top n most frequent topics. n_clusters: Create n clusters and order the similarity matrix by those clusters. - custom_labels: Whether to use custom topic labels that were defined using + custom_labels: Whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. width: The width of the figure. height: The height of the figure. 
diff --git a/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md index 9cef8803..d9bca553 100644 --- a/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md +++ b/docs/getting_started/hierarchicaltopics/hierarchicaltopics.md @@ -323,3 +323,26 @@ to view, we can see better which topics could be logically merged: └─■──modem_port_serial_irq_com ── Topic: 10 ``` + + +### **Merge topics** + +After seeing the potential hierarchy of your topics, you might want to merge specific +topics. For example, if topic 1 is +`1_space_launch_moon_nasa` and topic 2 is `2_spacecraft_solar_space_orbit` it might +make sense to merge those two topics as they are quite similar in meaning. In BERTopic, +you can use `.merge_topics` to manually select and merge those topics. Doing so will +update their topic representation which in turn updates the entire model: + +```python +topics_to_merge = [1, 2] +topic_model.merge_topics(docs, topics, topics_to_merge) +``` + +If you have several groups of topics you want to merge, create a list of lists instead: + +```python +topics_to_merge = [[1, 2], + [3, 4]] +topic_model.merge_topics(docs, topics, topics_to_merge) +``` diff --git a/docs/getting_started/topicreduction/topicreduction.md b/docs/getting_started/topicreduction/topicreduction.md index f3bb7c21..2c7262d0 100644 --- a/docs/getting_started/topicreduction/topicreduction.md +++ b/docs/getting_started/topicreduction/topicreduction.md @@ -6,14 +6,30 @@ so. ### **Manual Topic Reduction** Each resulting topic has its own feature vector constructed from c-TF-IDF. Using those feature vectors, we can find the most similar -topics and merge them. If we do this iteratively, starting from the least frequent topic, we can reduce the number -of topics quite easily. We do this until we reach the value of `nr_topics`: +topics and merge them.
If we do this iteratively, starting from the least frequent topic, we can reduce the number of topics quite easily. We do this until we reach the value of `nr_topics`: ```python from bertopic import BERTopic topic_model = BERTopic(nr_topics=20) ``` +It is also possible to manually select certain topics that you believe should be merged. +For example, if topic 1 is `1_space_launch_moon_nasa` and topic 2 is `2_spacecraft_solar_space_orbit` +it might make sense to merge those two topics: + +```python +topics_to_merge = [1, 2] +topic_model.merge_topics(docs, topics, topics_to_merge) +``` + +If you have several groups of topics you want to merge, create a list of lists instead: + +```python +topics_to_merge = [[1, 2], + [3, 4]] +topic_model.merge_topics(docs, topics, topics_to_merge) +``` + ### **Automatic Topic Reduction** One issue with the approach above is that it will merge topics regardless of whether they are very similar. They are simply the most similar out of all options. This can be resolved by reducing the number of topics automatically.
diff --git a/docs/index.md b/docs/index.md index 9a30b2ff..776161bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -98,6 +98,7 @@ For quick access to common functions, here is an overview of BERTopic's main met | Update topic representation | `.update_topics(docs, topics, n_gram_range=(1, 3))` | | Generate topic labels | `.generate_topic_labels()` | | Set topic labels | `.set_topic_labels(my_custom_labels)` | +| Merge topics | `.merge_topics(docs, topics, topics_to_merge)` | | Reduce nr of topics | `.reduce_topics(docs, topics, nr_topics=30)` | | Find topics | `.find_topics("vehicle")` | | Save model | `.save("my_model")` | diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index 2e5f0b74..8546bda8 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -96,4 +96,10 @@ def test_full_model(topic_model): # Test setting topic labels topic_model.set_topic_labels(topic_labels) - assert topic_model.custom_labels == topic_labels \ No newline at end of file + assert topic_model.custom_labels == topic_labels + + # Test merging topics + freq = topic_model.get_topic_freq(0) + topics_to_merge = [0, 1] + topic_model.merge_topics(newsgroup_docs, new_topics, topics_to_merge) + assert freq < topic_model.get_topic_freq(0)