In [9]:
import pickle
import os
import torch
import sys
from collections import Counter
from gensim.models import LdaModel
from bertopic import BERTopic
from topic_specificity import calculate_specificity_for_all_topics
from TopicSpecificityBerTopic.topic_specificity_bertopic import calculate_specificity_bertopic

In [10]:
# Experimental grid: each corpus is combined with each embedding model and
# each topic-model family.
collections = ['20ng', 'wsj', 'wiki']
lang_models = ['doc2vec', 'sbert', 'repllama']
topic_models = ['lda', 'bertopic']

# Single source of truth for the optimal LDA topic count of each corpus.
OPTIMAL_LDA_TOPICS = {'wsj': 50, 'wiki': 80, '20ng': 70}

# BERTopic counterpart of the table above (originally marked as a placeholder
# for future BERTopic-specific topic counts).
OPTIMAL_BERTOPIC_TOPICS = {'wsj': 50, 'wiki': 80, '20ng': 40}

In [11]:
# -----------------------------------------------
# helper: fetch keywords  ------------------------
# -----------------------------------------------
def topic_keywords(tid, model, model_type, topn=10):
    """Return up to ``topn`` keyword strings describing topic ``tid``.

    Parameters
    ----------
    tid : int
        Topic id, passed through unchanged to the model.
    model : object
        Must expose ``show_topic(tid, topn=...)`` when ``model_type`` is
        ``'lda'``; otherwise must expose ``get_topic(tid)``. Both are expected
        to yield ``(word, weight)`` pairs.
    model_type : str
        ``'lda'`` selects the ``show_topic`` branch; any other value selects
        the ``get_topic`` branch.
    topn : int, optional
        Maximum number of keywords to return (default 10).

    Returns
    -------
    list
        The first element of each ``(word, weight)`` pair, in model order.
        An empty list when ``get_topic`` reports that the topic is unknown.
    """
    if model_type == 'lda':
        return [word for word, _ in model.show_topic(tid, topn=topn)]

    # BERTopic's get_topic returns False (not an empty list) for a topic id it
    # does not know; slicing that value would raise TypeError, so fall back to
    # an empty sequence instead.
    pairs = model.get_topic(tid) or []
    return [word for word, _ in pairs[:topn]]

In [13]:
# Compare the specificity of mapped vs. unmapped topics for every
# (collection, topic model, language model) combination and print a short
# report per combination.

def _load_pickle(path):
    """Unpickle the object stored at ``path``, closing the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on trusted, locally produced result files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float('nan')


for coll in collections:
    for topic_model in topic_models:
        # 1) Load the fitted topic model and compute per-topic specificity.
        if topic_model == 'lda':
            n_topics = OPTIMAL_LDA_TOPICS[coll]
            lda_path = os.path.join('Results', 'LDA', f"{coll}_lda{n_topics}.model")
            if not os.path.exists(lda_path):
                sys.exit(f"[ERROR] Cannot find LDA model: {lda_path}")
            lda = LdaModel.load(lda_path)

            corpus_path = os.path.join('Results', 'LDA', f"{coll}_corpus.pkl")
            if not os.path.exists(corpus_path):
                sys.exit(f"[ERROR] Corpus not found: {corpus_path}")
            corpus = _load_pickle(corpus_path)

            scores = calculate_specificity_for_all_topics(
                model=lda,
                corpus=corpus,
                mode='lda',
                threshold_mode='gmm',
                specificity_mode='diff'
            )
            active_model = lda

        else:  # bertopic
            n_topics = OPTIMAL_BERTOPIC_TOPICS[coll]

            bt_path = os.path.join('Results', 'BERTOPIC', f"{coll}_bertopic_{n_topics}.model")
            if not os.path.exists(bt_path):
                sys.exit(f"[ERROR] BERTopic model not found: {bt_path}")
            bt = BERTopic.load(bt_path)

            scores = calculate_specificity_bertopic(
                bt,
                threshold_mode='gmm',
                specificity_mode='diff'
            )
            active_model = bt

        for lang_model in lang_models:
            # 2) Load mapping results for this (coll, lang_model, topic_model).
            #    The repllama mappings for wsj and 20ng are stored without the
            #    topic count in the filename.
            if lang_model == 'repllama' and coll in ['wsj', '20ng']:
                map_file = os.path.join('Results', f"{coll}_{lang_model}_{topic_model}_mapping.pkl")
            else:
                map_file = os.path.join('Results', f"{coll}_{lang_model}_{topic_model}_{n_topics}_mapping.pkl")
            if not os.path.exists(map_file):
                sys.exit(f"[ERROR] Mapping file not found: {map_file}")
            mappings = _load_pickle(map_file)

            # 3) Split topic ids into mapped and unmapped.
            #    For BERTopic the number of topics is taken from the score
            #    container rather than from the configured count.
            if topic_model == 'bertopic':
                real_n_topics = len(scores)
            else:
                real_n_topics = n_topics
            # Each mapping entry is unpacked as a pair whose second element is
            # the topic id.
            # NOTE(review): if a mapping can contain BERTopic's outlier id -1
            # and `scores` is a list, scores[-1] would silently read the last
            # entry — confirm against the mapping producer.
            mapped_topics = [t for _, t in mappings]
            unmapped_topics = set(range(real_n_topics)) - set(mapped_topics)
            mapped_scores = [scores[t] for t in mapped_topics]
            unmapped_scores = [scores[t] for t in unmapped_topics]

            # Either group can be empty (e.g. every topic got mapped); report
            # NaN instead of crashing with ZeroDivisionError.
            avg_mapped_score = _mean_or_nan(mapped_scores)
            avg_unmapped_score = _mean_or_nan(unmapped_scores)

            # 4) Pick the topics to show.
            # (a) most frequently mapped topics, ties broken by higher score
            freq = Counter(mapped_topics)
            top_mapped = sorted(freq, key=lambda t: (-freq[t], -scores[t]))[:3]

            # (b) unmapped topics with the lowest specificity
            worst_unmapped = sorted(unmapped_topics, key=lambda t: scores[t])[:3]

            # 5) Pretty-print the report.
            print(f"—— {coll.upper()} | {topic_model.upper()} | {lang_model} ——")
            print(f"Avg specificity: mapped={avg_mapped_score:.4f}, "
                  f"unmapped={avg_unmapped_score:.4f}")

            print("\nTop-3 mapped topics (by frequency):")
            for rank, tid in enumerate(top_mapped, 1):
                kw = ", ".join(topic_keywords(tid, active_model, topic_model))
                print(f"  {rank}. topic {tid:>3} | freq={freq[tid]:>3} | "
                      f"score={scores[tid]:.4f} | {kw}")

            print("\nWorst-3 unmapped topics (by specificity):")
            for rank, tid in enumerate(worst_unmapped, 1):
                kw = ", ".join(topic_keywords(tid, active_model, topic_model))
                print(f"  {rank}. topic {tid:>3} | score={scores[tid]:.4f} | {kw}")

            print("-" * 60)

—— 20NG | LDA | doc2vec ——
Avg specificity: mapped=0.0974, unmapped=0.0513

Top-3 mapped topics (by frequency):
  1. topic   2 | freq=  9 | score=0.1707 | think, go, like, know, time, get, want, thing, come, way
  2. topic  22 | freq=  7 | score=0.1095 | god, jesus, church, christ, sin, love, christian, bible, lord, say
  3. topic   0 | freq=  6 | score=0.0763 | space, earth, planet, moon, launch, solar, orbit, spacecraft, system, mission

Worst-3 unmapped topics (by specificity):
  1. topic  24 | score=0.0240 | orthodox, son, slipper, till, candida, italy, beast, abstract, explore, presentation
  2. topic  68 | score=0.0265 | keyboard, sub, rgb, virtual, thanx, custom, interior, silicon, plastic, fish
  3. topic  54 | score=0.0278 | dry, push, playback, evolution, depth, clean, pop, educational, reverse, quit
------------------------------------------------------------
—— 20NG | LDA | sbert ——
Avg specificity: mapped=0.0954, unmapped=0.0448

Top-3 mapped topics (by frequency):
  1. to