In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

from mypackage.elastic import Document
from sentence_transformers import SentenceTransformer
from mypackage.sentence import SentenceChain, doc_to_sentences, iterative_merge
from mypackage.clustering import chain_clustering, visualize_clustering, group_chains_by_label

# Sentence-embedding model; the same instance is passed to doc_to_sentences
# and select_backend further down in the notebook.
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Load one cached document, with text_path pointing at "article".
CACHED_DOC_PATH = "../cached_docs/doc_0001.json"
doc = Document.from_json(CACHED_DOC_PATH, text_path="article")

2025-04-06 15:13:12.985317: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743941593.000276   19461 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743941593.004576   19461 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Turn the document into sentence objects using the embedding model.
sentences = doc_to_sentences(doc, model)

# Iteratively merge sentences: threshold 0.6, no limit on the number of
# rounds, "average" pooling.
merged = iterative_merge(
    sentences,
    threshold=0.6,
    round_limit=None,
    pooling_method="average",
)

# Cluster the merged items. `clusters` is consumed below as a mapping of
# label -> members (it is iterated with .items()).
labels, clusters = chain_clustering(merged)

In [3]:
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.backend._utils import select_backend
from bertopic.cluster import BaseCluster
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.dimensionality import BaseDimensionalityReduction

In [4]:
# c-TF-IDF weighting for the topic representations, with frequent words reduced.
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

# Base-class instances for the embedding, dimensionality-reduction and
# clustering slots of BERTopic.
# NOTE(review): presumably these act as no-op placeholders so the precomputed
# cluster labels are used as-is — confirm against the BERTopic docs.
empty_cluster_model = BaseCluster()
empty_dimensionality_model = BaseDimensionalityReduction()
empty_embedding_model = BaseEmbedder()

In [5]:
# Assemble the topic model from the components built above: the three base-class
# stand-ins plus the c-TF-IDF transformer.
topic_model = BERTopic(
    embedding_model=empty_embedding_model,
    umap_model=empty_dimensionality_model,
    hdbscan_model=empty_cluster_model,
    ctfidf_model=ctfidf_model,
)

In [6]:
from itertools import chain
import numpy as np

# Granularity of the documents handed to BERTopic:
#   0 -> one document per cluster (member texts joined by a blank line)
#   1 -> one document per cluster member, labelled with its cluster id
dista = 1

# Keep only clusters with a non-negative label.
# NOTE(review): negative labels are presumably noise/outliers from the
# clustering step — confirm against chain_clustering.
kept_clusters = {label: members for label, members in clusters.items() if label >= 0}

if dista == 0:
    docs = ["\n\n".join(member.text for member in members) for members in kept_clusters.values()]
    labels = list(kept_clusters)

elif dista == 1:
    # Build (text, label) pairs directly instead of pushing both through a
    # string ndarray and parsing the labels back out with astype(int).
    # The loop variable is deliberately not called `chain`, so it does not
    # shadow itertools.chain.
    text_label_pairs = [
        (member.text, label)
        for label, members in kept_clusters.items()
        for member in members
    ]
    docs = [text for text, _ in text_label_pairs]
    labels = np.array([label for _, label in text_label_pairs], dtype=int)
    print(labels)

else:
    # Fail loudly: falling through would either raise a NameError or silently
    # fit on stale `docs`/`labels` left over from a previous run of this cell.
    raise ValueError(f"dista must be 0 or 1, got {dista!r}")

# Fit with the precomputed cluster ids supplied as targets (y).
topics, probs = topic_model.fit_transform(docs, y=labels)

[ 6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  1  1  1  1  1  1
  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  9  9  9  9  9  9  9  9
  9  9  9  9  9  9  7  7  7  7  7  7  7  7  7  7  7  7  3  3  3  3  3  5
  5  5  5  5  5  5  5  5  5  5  5  5  8  8  8  8  8  8  8  8  8  8 10 10
 10 10 10 10 10 10 10  4  4  4  4  4  4  0  0  0  0  0  0  0  0  0  0  0
  0  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2]


In [7]:
from itertools import starmap, groupby

# Per-topic summary table from the fitted model.
info_df = topic_model.get_topic_info()

# Print the 'Representation' entry of every topic, one per line.
for representation in info_df['Representation']:
    print(representation)

# Build {topic id: "cluster_NN"} from the topic mapper's pairs, visiting them
# in ascending order of the mapped topic id.
cluster_to_topic = topic_model.topic_mapper_.get_mappings()
newdic = {
    topic: f"cluster_{cluster:02}"
    for cluster, topic in sorted(cluster_to_topic.items(), key=lambda pair: pair[1])
}

print(newdic)

['qc', 'gr', 'no', '2013', 'et', 'al', 'arxiv', '0264', '9381', '1088']
['sgwb', 'gravitational', 'waves', 'ptas', 'xcite', 'also', 'has', 'sources', 'anisotropy', 'of']
['phys', 'allen', 'takahashi', 'soda', 'ottewill', 'zaldarriaga', 'kuroyanagi', 'jackiw', 'hellings', 'mingarelli']
['cg', 'mode', 'dominated', 'xmath218', 'xmath3', 'region', 'xmath2', 'by', 'when', 'applies']
['pulsars', 'pulsar', 'signals', 'optimal', 'to', 'time', 'filter', 'correlated', 'xmath1', 'two']
['parameters', 'b_', 'stokes', 'ast', 'times', 'with', 'parameter', 'harmonics', 'spin', 'electromagnetic']
['1103', 'physrevd', '10', 'ph', 'doi', 'astro', 'arxiv', 'physrevlett', '85', 'phys']
['gv', 'orfs', 'fig', 'curve', 'generalized', 'red', 'orf', 'for', 'dipole', 'find']
['xi', 'right', 'left', 'cos', 'gamma', 'frac', '34', 'gamma_', '11', '12']
['curve', 'shows', 'gi', 'dashed', 'blue', 'dark', 'green', 'xmath266', 'dash', 'xmath267']
['average', 'ensemble', 'xmath65', 'bracket', 'temporal', 'over', 'here'

In [8]:
# Let's see where each chain belongs: flatten the clusters into
# [text, label] pairs, skipping clusters with a negative label.
docs = [
    [member.text, label]
    for label, cluster in clusters.items()
    if label >= 0
    for member in cluster
]

In [10]:
# Attach the sentence-transformer backend to the already-fitted model so that
# find_topics can be called with a free-text search term.
# NOTE(review): the recorded result ([0], [1.0]) looks degenerate for a query
# unrelated to the document — confirm that the topic embeddings were actually
# computed with this backend.
topic_model.embedding_model = select_backend(model)
search_result = topic_model.find_topics(search_term="what are sunspots")
print(search_result)

([0], [1.0])
