In [None]:
!pip install octis

from octis.evaluation_metrics.coherence_metrics import Coherence
from octis.evaluation_metrics.diversity_metrics import TopicDiversity
from bertopic import BERTopic
from gensim.utils import simple_preprocess
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

In [None]:
from datasets import load_dataset

# DialogSum from the Hugging Face hub; list the columns of the train split
# to confirm which fields (dialogue text, topic label, ...) are available.
dataset = load_dataset("knkarthick/dialogsum")
print("Features: {}".format(dataset["train"].column_names))



  0%|          | 0/3 [00:00<?, ?it/s]

Features: ['id', 'dialogue', 'summary', 'topic']


In [None]:
from functools import lru_cache
from typing import List


@lru_cache(maxsize=1)
def _english_stopwords() -> frozenset:
    """Return NLTK's English stopwords as a frozenset, built once and cached.

    The original code rebuilt this list for every sentence and then did a
    linear-time list membership test per token.
    """
    return frozenset(stopwords.words("english"))


def preprocess(text: List[str]) -> List[str]:
    """Normalize every document in ``text``.

    Args:
        text: Raw documents, one string per document.

    Returns:
        One normalized string per input document, in the same order.
    """
    return [normalize(sentence) for sentence in text]


def normalize(sentence: str) -> str:
    """Tokenize ``sentence`` with gensim and drop English stopwords.

    ``simple_preprocess`` is called with ``min_len=1`` so one-character tokens
    are kept; surviving tokens are re-joined with single spaces.

    Args:
        sentence: A single raw document.

    Returns:
        The space-joined non-stopword tokens (empty string if none remain).
    """
    eng_stopwords = _english_stopwords()
    return " ".join(
        token
        for token in simple_preprocess(sentence, min_len=1)
        if token not in eng_stopwords
    )

In [None]:
# Topic modelling and every metric below run on the held-out test split.
data = dataset["test"]
corpus = preprocess(list(data["dialogue"]))

In [None]:
# Sentence-embedding backbone for BERTopic
# (alternative that was tried: 'distilbert-base-nli-mean-tokens').
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
topic_model = BERTopic(
    embedding_model=emb_model,
    language="english",
    verbose=True,
)
topics, _ = topic_model.fit_transform(corpus)
# Number of distinct topic ids assigned; the outlier id -1 is counted too
# when any document received it.
print(len(set(topics)))

Batches:   0%|          | 0/47 [00:00<?, ?it/s]

2022-12-12 06:15:08,670 - BERTopic - Transformed documents to Embeddings
2022-12-12 06:15:15,404 - BERTopic - Reduced dimensionality
2022-12-12 06:15:15,480 - BERTopic - Clustered reduced embeddings


58


In [None]:
from tqdm import tqdm
from collections import defaultdict 

def calc_topic_acc(topic_model, topics: List[int], true_topics: List[str], max_k=5):
    """Compute TopicAcc@k for every cut-off k = 1..max_k.

    Each document's gold topic label is used as a text query to
    ``topic_model.find_topics``. The document counts as a hit at rank r when
    the r-th retrieved (non-outlier) topic contains at least one document whose
    gold label equals the query. TopicAcc@k is the fraction of documents with a
    hit at some rank <= k, so every value lies in [0, 1].

    Args:
        topic_model: Fitted model exposing ``find_topics(query, top_n=...)``
            that returns ``(topic_ids, similarities)``.
        topics: Predicted topic id per document; -1 is treated as "no topic".
        true_topics: Gold topic label per document, aligned with ``topics``.
        max_k: Largest cut-off to report.

    Returns:
        Dict mapping each k in 1..max_k to its accuracy.

    Raises:
        ValueError: If ``topics`` and ``true_topics`` differ in length.
    """
    from collections import Counter, defaultdict

    if len(topics) != len(true_topics):
        raise ValueError(
            f"topics and true_topics must be aligned, got {len(topics)} "
            f"and {len(true_topics)} items"
        )

    num_true = {k: 0 for k in range(1, max_k + 1)}
    if len(true_topics) == 0:
        # Nothing to score; avoid dividing by zero below.
        return {k: 0.0 for k in num_true}

    # Gold labels present in each predicted cluster (outlier id -1 excluded).
    labels_by_cluster = defaultdict(set)
    for pred_topic, true_topic in zip(topics, true_topics):
        if pred_topic != -1:
            labels_by_cluster[pred_topic].add(true_topic)

    # All documents sharing a gold label issue the same query, so query each
    # distinct label once and weight its hits by how many documents carry it.
    label_counts = Counter(true_topics)
    for label, count in tqdm(label_counts.items()):
        retrieved, _ = topic_model.find_topics(label, top_n=max_k + 1)
        # One extra topic is requested so that dropping the outlier id still
        # leaves up to max_k candidates.
        candidates = [t for t in retrieved if t != -1][:max_k]

        for rank, cluster in enumerate(candidates, start=1):
            if label in labels_by_cluster.get(cluster, ()):
                # The first hit at `rank` counts for every cut-off k >= rank.
                # Stop here so a document is credited at most once per cut-off.
                for k in range(rank, max_k + 1):
                    num_true[k] += count
                break

    total = len(true_topics)
    return {k: hits / total for k, hits in num_true.items()}
    
    
def prettify_output(model, topics, topk: int = 10):
    """Collect the top words of every non-outlier topic in an OCTIS-style dict.

    The topic ids are taken from ``topics`` itself, with the outlier id -1
    skipped. Deriving the ids this way, rather than assuming -1 is present
    and subtracting one from the count, avoids dropping the last topic when
    no document was labelled as an outlier.

    Args:
        model: Fitted topic model whose ``get_topic(topic_id)`` returns a
            sequence of entries whose first element is the word.
        topics: Predicted topic id per document.
        topk: Number of leading words to keep per topic.

    Returns:
        ``{"topics": [[word, ...], ...]}``, one word list per topic id, in
        ascending id order. This is the dict passed to the OCTIS ``score``
        calls in this notebook.
    """
    topic_ids = sorted({topic_id for topic_id in topics if topic_id != -1})
    bertopic_topics = [
        [entry[0] for entry in model.get_topic(topic_id)[:topk]]
        for topic_id in topic_ids
    ]
    return {"topics": bertopic_topics}


In [None]:
tacc = calc_topic_acc(topic_model, topics, data["topic"])
# Leading newline separates the result from the progress-bar output above.
print("\nTopicAcc@k:", tacc)

100%|██████████| 1500/1500 [00:14<00:00, 105.57it/s]


TopicAcc@k: {1: 0.3446666666666667, 2: 0.46266666666666667, 3: 0.5573333333333333, 4: 0.6306666666666667, 5: 0.6693333333333333}





In [None]:
# Number of top words per topic used by the TC/TD metrics below; it should
# match the 10 words per topic that prettify_output(topic_model, topics) keeps.
topk = 10
output_tm = prettify_output(topic_model, topics)

In [None]:
# NPMI topic coherence, estimated on the whitespace-tokenized preprocessed corpus.
npmi = Coherence(
    texts=[doc.split() for doc in corpus],
    topk=topk,
    measure="c_npmi",
)
tc = npmi.score(output_tm)
print("TC =", tc)

TC = -0.20888024771852007


In [None]:
# Topic diversity over the top-`topk` words of every topic in output_tm.
topic_diversity = TopicDiversity(
    topk=topk,
)
td = topic_diversity.score(output_tm)
print("TD =", td)

TD = 0.8736842105263158
