In [None]:
!pip install bertopic
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN

# Load the complete 20 Newsgroups corpus, asking scikit-learn to remove headers,
# footers and quoted replies, and keep only the list of post texts.
docs = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
)['data']

# Encode every post with the pretrained all-MiniLM-L6-v2 sentence-transformers
# checkpoint; a progress bar is shown while the corpus is being encoded.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedding_model.encode(docs, show_progress_bar=True)


In [None]:
# Peek at the embeddings: report the matrix shape, then display the first few
# rows. (Calling .view() with no arguments only created a redundant array view.)
print(corpus_embeddings.shape)
corpus_embeddings[:5]

In [None]:
# Stop words stay in the documents while the embeddings are computed, because the
# transformer was trained on natural text that includes them. They are filtered
# out afterwards, when the topic keywords are extracted, by handing BERTopic a
# CountVectorizer.
#
# ngram_range is set on the vectorizer itself: BERTopic ignores its own
# n_gram_range argument whenever a custom vectorizer_model is supplied, so this
# is the only place where 1- to 3-grams can actually be enabled.
# NOTE(review): BERTopic appears to fit the vectorizer on one merged document per
# topic, so max_df / min_df would be fractions of topics rather than of posts —
# confirm the thresholds are what is intended.
vectorizer_model = CountVectorizer(
    stop_words="english",
    ngram_range=(1, 3),
    max_df=.95,
    min_df=.01,
)


In [None]:
# Parameters for HDBSCAN (clustering) and UMAP (dimensionality reduction).
# prediction_data=True asks HDBSCAN to keep the extra data it needs to score
# points after fitting.
hdbscan_model = HDBSCAN(min_cluster_size=30, metric='euclidean', prediction_data=True)

# UMAP is stochastic; pinning random_state makes the reduced space, and therefore
# the resulting topics, reproducible across kernel restarts.
umap_model = UMAP(
    n_neighbors=15,
    n_components=10,
    metric='cosine',
    low_memory=False,
    random_state=42,
)

In [None]:
# Train BERTopic on the documents together with their precomputed embeddings.
#
# n_gram_range and min_topic_size are deliberately not passed: BERTopic ignores
# both when a custom vectorizer_model / hdbscan_model is supplied, so they had no
# effect here. The n-gram range belongs on the CountVectorizer (ngram_range) and
# the minimum cluster size on HDBSCAN (min_cluster_size).
model = BERTopic(
    vectorizer_model=vectorizer_model,
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    nr_topics=40,                  # reduce the number of topics to 40
    top_n_words=10,                # keywords kept per topic
    calculate_probabilities=True,  # also compute per-topic probabilities
).fit(docs, corpus_embeddings)

In [None]:
# Assign every document to a topic and collect the per-topic document counts.
topics, probabilities = model.transform(docs, corpus_embeddings)
df_topic_freq = model.get_topic_freq()
print(df_topic_freq)

# Count only real topics: the outlier bucket (topic -1) is excluded when it is
# present, instead of unconditionally subtracting one row.
topics_count = int((df_topic_freq["Topic"] != -1).sum())

# Bar chart of the top keywords for every real topic.
model.visualize_barchart(top_n_topics=topics_count)

In [None]:
# Display the summary table BERTopic reports for the fitted topics.
model.get_topic_info()

In [None]:
# Build the document scatter plot for topics 0 through 6, drawing on a 60% sample
# of the documents; annotations stay visible and per-document hover text is off.
fig = model.visualize_documents(
    docs,
    embeddings=corpus_embeddings,
    sample=.6,
    topics=list(range(7)),
    hide_annotations=False,
    hide_document_hover=True,
)

In [None]:
# Save the document cluster figure as an SVG in the current working directory.
# NOTE(review): static image export presumably needs an export backend such as
# kaleido installed — confirm for the target environment.
fig.write_image("./clusters.svg")

In [None]:
# Build the keyword bar chart with the library's default settings and save it as
# an SVG in the current working directory.
fig2 = model.visualize_barchart()
fig2.write_image("./barchart.svg")

In [None]:
# Infer topics for unseen documents. The encoder that produced the training
# embeddings is reused instead of loading a second copy of the same checkpoint;
# this also guarantees the new vectors come from the same embedding space.
sentence_model = embedding_model
new_docs = ["I'm looking for a new graphics card", "when is the next nasa launch"]
embeddings = sentence_model.encode(new_docs)
topics, probs = model.transform(new_docs, embeddings)
print(topics)

In [None]:
# Print the most recently computed topic assignments together with their
# probabilities.
print(topics,probs)

In [None]:
# Fetch the per-topic summary table and the full topic representations from the
# fitted model; the summary table is the cell's displayed output.
topic_info = model.get_topic_info()
all_topics = model.get_topics()
topic_info